├── .github ├── CONTRIBUTING.md ├── CONTRIBUTING_zh-CN.md ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── custom.md │ └── feature_request.md └── pull_request_template.md ├── .gitignore ├── LICENSE ├── README.md ├── README_zh-CN.md ├── assets ├── Latte │ ├── Latte网络结构.png │ ├── S-AdaLN.png │ ├── T2V1.png │ ├── T2V2.png │ ├── ViT.png │ ├── patch_embedding.png │ ├── result.jpg │ ├── training_FVD.png │ └── 模型配置.png ├── SD3论文领读.png ├── logo.jpg ├── qrcode.png ├── sora-reproduce.png ├── sora夜谈.png └── wechatqrcode.jpg ├── codes ├── OpenDiT │ ├── .gitignore │ ├── .isort.cfg │ ├── .pre-commit-config.yaml │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── figure │ │ ├── dit_loss.png │ │ ├── dit_results.png │ │ ├── end2end.png │ │ ├── fastseq_exp.png │ │ ├── fastseq_overview.png │ │ ├── logo.png │ │ └── wechat.jpg │ ├── opendit │ │ ├── __init__.py │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── diffusion_utils.py │ │ │ ├── gaussian_diffusion.py │ │ │ ├── respace.py │ │ │ └── timestep_sampler.py │ │ ├── embed │ │ │ ├── clip_text_emb.py │ │ │ ├── label_emb.py │ │ │ ├── patch_emb.py │ │ │ ├── pos_emb.py │ │ │ └── time_emb.py │ │ ├── kernels │ │ │ ├── __init__.py │ │ │ ├── fused_modulate.py │ │ │ └── k_fused_modulate.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── dit.py │ │ │ └── latte.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── attn.py │ │ │ └── block.py │ │ ├── utils │ │ │ ├── ckpt_utils.py │ │ │ ├── data_utils.py │ │ │ ├── debug_utils.py │ │ │ ├── download.py │ │ │ ├── operation.py │ │ │ ├── pg_utils.py │ │ │ ├── train_utils.py │ │ │ └── video_utils.py │ │ └── vae │ │ │ ├── attention.py │ │ │ ├── data.py │ │ │ ├── download.py │ │ │ ├── reconstruct.py │ │ │ ├── utils.py │ │ │ ├── vqvae.py │ │ │ └── wrapper.py │ ├── preprocess.py │ ├── requirements.txt │ ├── sample.py │ ├── sample_img.sh │ ├── sample_video.sh │ ├── setup.py │ ├── tests │ │ ├── test_checkpoint.py │ │ ├── test_clip.py │ │ ├── test_dataloader.py │ │ ├── test_ema_sharding.py │ │ ├── test_fastseq_parallel.py │ │ ├── test_flash_attention.py │ │ ├── test_fused_modulate.py │ │ ├── test_model │ │ │ ├── org_dit.py │ │ │ └── test_model.py │ │ ├── test_rearrange_model.py │ │ └── test_ulysses_parallel.py │ ├── train.py │ ├── train_img.sh │ ├── train_video.sh │ └── videos │ │ ├── art-museum.mp4 │ │ ├── demo.csv │ │ ├── lagos.mp4 │ │ ├── man-on-the-cloud.mp4 │ │ └── suv-in-the-dust.mp4 ├── README.md ├── README_zh-CN.md ├── SiT │ ├── .gitignore │ ├── LICENSE.txt │ ├── README.md │ ├── download.py │ ├── environment.yml │ ├── models.py │ ├── run_SiT.ipynb │ ├── sample.py │ ├── sample_ddp.py │ ├── train.py │ ├── train_utils.py │ ├── transport │ │ ├── __init__.py │ │ ├── integrators.py │ │ ├── path.py │ │ ├── transport.py │ │ └── utils.py │ ├── visuals │ │ ├── visual.png │ │ └── visual_2.png │ └── wandb_utils.py ├── StableCascade │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── WEIGHTS_LICENSE │ ├── __init__.py │ ├── configs │ │ └── training │ │ │ ├── controlnet_c_3b_canny.yaml │ │ │ ├── controlnet_c_3b_identity.yaml │ │ │ ├── controlnet_c_3b_inpainting.yaml │ │ │ ├── controlnet_c_3b_sr.yaml │ │ │ ├── finetune_b_3b.yaml │ │ │ ├── finetune_b_700m.yaml │ │ │ ├── finetune_c_1b.yaml │ │ │ ├── finetune_c_3b.yaml │ │ │ ├── finetune_c_3b_lora.yaml │ │ │ ├── finetune_c_3b_lowres.yaml │ │ │ └── finetune_c_3b_v.yaml │ ├── core │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── bucketeer.py │ │ ├── scripts │ │ │ ├── __init__.py │ │ │ └── cli.py │ │ ├── templates │ │ │ ├── __init__.py │ │ │ └── diffusion.py │ │ └── utils │ │ │ ├── 
__init__.py │ │ │ ├── base_dto.py │ │ │ └── save_and_load.py │ ├── figures │ │ ├── collage_1.jpg │ │ ├── collage_2.jpg │ │ ├── collage_3.jpg │ │ ├── collage_4.jpg │ │ ├── comparison-inference-speed.jpg │ │ ├── comparison.png │ │ ├── controlnet-canny.jpg │ │ ├── controlnet-face.jpg │ │ ├── controlnet-paint.jpg │ │ ├── controlnet-sr.jpg │ │ ├── fernando.jpg │ │ ├── fernando_original.jpg │ │ ├── image-to-image-example-rodent.jpg │ │ ├── image-variations-example-headset.jpg │ │ ├── model-overview.jpg │ │ ├── original.jpg │ │ ├── reconstructed.jpg │ │ └── text-to-image-example-penguin.jpg │ ├── gdf │ │ ├── __init__.py │ │ ├── loss_weights.py │ │ ├── noise_conditions.py │ │ ├── readme.md │ │ ├── samplers.py │ │ ├── scalers.py │ │ ├── schedulers.py │ │ └── targets.py │ ├── gradio_app │ │ ├── app.py │ │ └── style.css │ ├── models │ │ ├── download_models.sh │ │ └── readme.md │ ├── modules │ │ ├── __init__.py │ │ ├── cnet_modules │ │ │ ├── face_id │ │ │ │ └── arcface.py │ │ │ ├── inpainting │ │ │ │ ├── saliency_model.pt │ │ │ │ └── saliency_model.py │ │ │ └── pidinet │ │ │ │ ├── __init__.py │ │ │ │ ├── ckpts │ │ │ │ └── table5_pidinet.pth │ │ │ │ ├── model.py │ │ │ │ └── util.py │ │ ├── common.py │ │ ├── controlnet.py │ │ ├── effnet.py │ │ ├── lora.py │ │ ├── previewer.py │ │ ├── stage_a.py │ │ ├── stage_b.py │ │ └── stage_c.py │ ├── requirements.txt │ └── train │ │ ├── __init__.py │ │ ├── base.py │ │ ├── example_train.sh │ │ ├── readme.md │ │ ├── train_b.py │ │ ├── train_c.py │ │ ├── train_c_controlnet.py │ │ └── train_c_lora.py ├── WALT │ └── README.md └── pipeline │ ├── README.md │ └── README_zh-CN.md ├── docs ├── HOT_NEWS_BASELINES_GUIDES.md ├── HOT_NEWS_BASELINES_GUIDES_zh-CN.md ├── Minisora_LPRS │ ├── 0001.jpg │ ├── 0002.jpg │ ├── 0003.jpg │ ├── 0004.jpg │ ├── 0005.jpg │ ├── 0006.jpg │ ├── 0007.jpg │ ├── 0008.jpg │ ├── 0009.jpg │ ├── 0010.jpg │ ├── 0011.jpg │ ├── 0012.jpg │ ├── 0013.jpg │ ├── 0014.jpg │ └── 0015.jpg ├── README.md ├── README_zh-CN.md ├── survey_README.md └── survey_README_zh-CN.md └── notes ├── Latte.md ├── PixArt-Σ 论文精读翻译.pdf ├── PixArt-Σ论文解析.pdf ├── README.md ├── README_zh-CN.md ├── SD3_zh-CN.md ├── dataset_note.md └── latte论文精读翻译.pdf /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to the Mini Sora community 2 | 3 | English | [简体中文](./CONTRIBUTING_zh-CN.md) 4 | 5 | The Mini Sora open-source community is positioned as a community-driven initiative (**free of charge and devoid of any exploitation**) organized spontaneously by community members. 6 | 7 | We really hope that you can contribute to the Mini Sora open source community and help us make it better than it is now! If you're making your first contribution to Mini Sora, you can check out the list of issues labeled [good first PR](https://github.com/mini-sora/minisora/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+PR%22). 8 | 9 | ## Submitting a Pull Request (PR) 10 | 11 | As a contributor, before submitting your request, here are the guidelines we hope you follow: 12 | 13 | 1. Firstly, please search in the [Mini Sora GitHub](https://github.com/mini-sora/minisora/pulls) to see if there are any open or closed pull requests related to the content you intend to submit. We assume you wouldn't want to duplicate existing work. 14 | 15 | 2. Next, [fork](https://github.com/mini-sora/minisora/fork) the [minisora](https://github.com/mini-sora/minisora) repository, and download your forked repository to your local machine. 
16 | 17 | ``` 18 | git clone 【your-forked-repository-url】 19 | ``` 20 | 21 | 3. To add the original Mini Sora repository as a remote and facilitate syncing with the latest updates: 22 | 23 | ``` 24 | git remote add upstream https://github.com/mini-sora/minisora 25 | ``` 26 | 27 | 4. Sync the code from the main repository to your local machine, and then sync it back to your forked remote repository. 28 | 29 | ``` 30 | # Pull the latest code from the upstream branch 31 | git fetch upstream 32 | 33 | # Switch to the main branch 34 | git checkout main 35 | 36 | # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream 37 | git merge upstream/main 38 | 39 | # Additionally, sync the local main branch to the remote branch of your forked repository 40 | git push origin main 41 | ``` 42 | 43 | > Note: Before starting each submission, please synchronize the code from the main repository. 44 | 45 | 46 | 47 | 5. In your forked repository, please create a branch for submitting your changes. The branch name should preferably be meaningful. 48 | 49 | ``` 50 | git checkout -b my-docs-branch main 51 | ``` 52 | 53 | 6. While making modifications on your branch and committing changes, please adhere to our [Commit Message Format](#Commit-Message-Format) for composing commit descriptions. If you're adding a paper, please follow the [Paper Naming Convention](#Paper-Naming-Convention) when filling in the paper index information. 54 | 55 | ``` 56 | git commit -m "[docs]: xxxx" 57 | ``` 58 | 59 | 7. Submit your data to your GitHub repository. 60 | 61 | ``` 62 | git push origin my-docs-branch 63 | ``` 64 | 65 | 8. Go back to the GitHub repository page and submit a pull request to `minisora:main`. 66 | 67 | 9. The works added to the "最近更新" Section should be related to Video Generation and should be a Very Impressive Work. 68 | 69 | ## Commit Message Format 70 | 71 | The commit message must include both `` and `` sections. 72 | 73 | ``` 74 | []: 75 | │ │ 76 | │ └─⫸ Briefly describe your changes, without ending with a period. 77 | │ 78 | └─⫸ Commit Type: |docs|feat|fix|refactor| 79 | ``` 80 | 81 | ### Type 82 | 83 | * **docs**:When you modify a document or add a document, select `docs` 84 | 85 | The following types are reserved for future code collaboration if needed. 86 | 87 | * **feat**:Here it refers to a new feature. 88 | * **fix**:fix bug 89 | * **refactor**: Refactor code, no new features or bug fixes involved 90 | 91 | ### summary 92 | 93 | * Describe the modifications in English, without ending with a period (.) 
94 | 95 | > eg: git commit -m "[docs]: add a contributing.md file" 96 | 97 | ## Paper Naming Convention 98 | 99 | 100 | > Format —— [**Journal Name**] Paper Title 101 | > 102 | > Example1——[**CVPR 24** paper] **lovieChat**: From Dense Token to Sparse Memory for Long Video Understanding 103 | > 104 | > Example2——[**Paper**] **Sora**: Creating video from text 105 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 如何向Mini Sora 社区贡献 2 | 3 | [English](./CONTRIBUTING.md) | 简体中文 4 | 5 | Mini Sora 开源社区定位为由社区同学自发组织的开源社区(**免费不收取任何费用、不割韭菜**) 6 | 7 | 我们非常希望你们能够为 Mini Sora 开源社区做出贡献,并且帮助我们把它做得比现在更好!如果你首次为 Mini Sora 做贡献,可以查看 [good first PR](https://github.com/mini-sora/minisora/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+PR%22) 的 issue 列表。 8 | 9 | ## 提交请求(PR) 10 | 11 | 作为贡献者,在你提交你的请求之前,以下是我们希望你遵循的规范: 12 | 13 | 1. 首先,在 [Mini Sora GitHub](https://github.com/mini-sora/minisora/pulls) 中搜索与您想要提交相关的内容开放或关闭的 PR。我们想您也不希望重复现有的工作。 14 | 15 | 2. 然后 [Fork](https://github.com/mini-sora/minisora/fork) [minisora](https://github.com/mini-sora/minisora) 仓库,并下载你的仓库到本地 16 | 17 | ``` 18 | git clone 【你fork的仓库地址】 19 | ``` 20 | 21 | 3. 添加mini-sora原仓库,方便同步远程仓库最新的更新 22 | 23 | ``` 24 | git remote add upstream https://github.com/mini-sora/minisora 25 | ``` 26 | 27 | 4. 同步主仓库代码到你本地,以及同步回你fork的远程仓库 28 | 29 | ``` 30 | # 从upstream分支上,拉取最新代码 31 | git fetch upstream 32 | # 切换到main分支上 33 | git checkout main 34 | # 把upstream分支上的更新内容合并到mian上,本地的main分支就和上游同步了 35 | git merge upstream/main 36 | # 还需把本地main同步到【你fork的仓库地址】的远程分支 37 | git push origin main 38 | ``` 39 | 40 | > 注意:每次开始提交前,请先同步主仓库的代码 41 | 42 | 43 | 44 | 5. 在你自己fork的仓库,请创建一个分支用于提交你的变更内容。分支名尽可能的有一定意义。 45 | 46 | ``` 47 | git checkout -b my-docs-branch main 48 | ``` 49 | 50 | 6. 在你的分支上面进行修改,提交commit时,请按照我们的[Commit消息格式](#Commit消息格式)进行编写commit描述,当你是添加了论文,在填写论文索引时,请按照[论文命名规范](#论文命名规范)填写索引数据。 51 | 52 | ``` 53 | git commit -m "[docs]: xxxx" 54 | ``` 55 | 56 | 7. 提交数据到你的GitHub仓库 57 | 58 | ``` 59 | git push origin my-docs-branch 60 | ``` 61 | 62 | 8. 回到GitHub仓库页面,提交PR到`minisora:main` 63 | 64 | 9. 加入"最近更新"的工作栏目中添加的工作应该与视频生成有关,并且应该是非常令人印象深刻的工作。 65 | 66 | ## Commit消息格式 67 | 68 | 提交的 commit message 必须包含``和``两部分 69 | 70 | ``` 71 | []: 72 | │ │ 73 | │ └─⫸ 简短描述你的修改内容,结尾没有句号 74 | │ 75 | └─⫸ Commit Type: |docs|feat|fix|refactor| 76 | ``` 77 | 78 | ### Type 79 | 80 | * **docs**:当你修改了文档,或者添加了文档,选择`docs` 81 | 82 | 以下类型是如果后续涉及到代码协作预留 83 | 84 | * **feat**:这里是指一个新的功能 85 | * **fix**:修复bug 86 | * **refactor**: 重构代码,不涉及新功能或者bug修复 87 | 88 | ### summary 89 | 90 | * 用英文描述修改的内容,不要用句号(.)结尾 91 | 92 | > eg: git commit -m "[docs]: add a contributing.md file" 93 | 94 | ## 论文命名规范 95 | 96 | >格式—— [**期刊名**] 论文名称 97 | > 98 | >样例1——[**CVPR 24** paper] **lovieChat**: From Dense Token to Sparse Memory for Long Video Understanding 99 | > 100 | >样例2——[**Paper**] **Sora**: Creating video from text 101 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. 
See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Sora issue template 3 | about: A document issue template that might be more suitable for minisora's needs. 4 | title: '[Add/Update/Remove/Correct/Other] - Brief Description' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | ## **Update Type in Title** 10 | 11 | > Example 12 | > - Add new paper/project/resource 13 | > - Update existing paper/project/resource information 14 | > - Remove outdated paper/project/resource 15 | > - Correct document formatting/link/spelling errors 16 | > - Other document updates 17 | 18 | ## **Detailed Description** 19 | 20 | **Content Name/Link**: [Provide the name or link of the content that needs to be updated] 21 | 22 | **Current Status/Issue**: [Briefly describe the current status or problem with the content] 23 | 24 | **Update Details**: [Describe in detail the updates you wish to make, including new information, modifications, or removal of outdated content] 25 | 26 | ## **Additional Information** 27 | **Reason for Update**: [Explain why the update is necessary and how it will benefit the community] 28 | 29 | **Deadline (if any)**: [If there is a specific deadline for the update, please provide it] 30 | 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Pull Request Description 2 | 3 | Please read [CONTRIBUTING](https://github.com/mini-sora/minisora/blob/main/.github/CONTRIBUTING.md) manual before submitting your PR. 4 | 5 | ### Title 6 | [Add/Update/Remove/Correct/Other] - Brief Description 7 | 8 | ### Summary 9 | This pull request adds information about the xx paper and xx project to the README file. XXX is xxx video diffusion model for xxx task. 
10 | 11 | ### Changes 12 | - Added a section in the README with links to the XXX paper, GitHub repository, and project website. 13 | - Ensured that the links are functional and the information is up-to-date. 14 | 15 | ### Review Notes 16 | - Please review the added content for accuracy and formatting. 17 | - Check that the links provided are correct and lead to the intended resources. 18 | - **Please check if both the Chinese and English READMEs have been updated simultaneously.** 19 | 20 | ### Related Issues 21 | - Closes #issue-number (if there's an associated issue) 22 | 23 | ### Additional Comments 24 | - The addition of XXX to our README will provide valuable resources for users interested in efficient video generation. 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .ipynb_checkpoints/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | data/conversations.json 10 | inference/ 11 | inference_results/ 12 | output/ 13 | train_data/ 14 | log/ 15 | *.DS_Store 16 | *.vs 17 | *.user 18 | *~ 19 | *.vscode 20 | *.idea 21 | 22 | *.log 23 | .clang-format 24 | .clang_format.hook 25 | temp/ 26 | build/ 27 | dist/ 28 | paddleocr.egg-info/ 29 | /deploy/android_demo/app/OpenCV/ 30 | /deploy/android_demo/app/PaddleLite/ 31 | /deploy/android_demo/app/.cxx/ 32 | /deploy/android_demo/app/cache/ 33 | test_tipc/web/models/ 34 | test_tipc/web/node_modules/ 35 | test_video/ -------------------------------------------------------------------------------- /assets/Latte/Latte网络结构.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/Latte网络结构.png -------------------------------------------------------------------------------- /assets/Latte/S-AdaLN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/S-AdaLN.png -------------------------------------------------------------------------------- /assets/Latte/T2V1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/T2V1.png -------------------------------------------------------------------------------- /assets/Latte/T2V2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/T2V2.png -------------------------------------------------------------------------------- /assets/Latte/ViT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/ViT.png -------------------------------------------------------------------------------- /assets/Latte/patch_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/patch_embedding.png -------------------------------------------------------------------------------- /assets/Latte/result.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/result.jpg -------------------------------------------------------------------------------- /assets/Latte/training_FVD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/training_FVD.png -------------------------------------------------------------------------------- /assets/Latte/模型配置.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/Latte/模型配置.png -------------------------------------------------------------------------------- /assets/SD3论文领读.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/SD3论文领读.png -------------------------------------------------------------------------------- /assets/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/logo.jpg -------------------------------------------------------------------------------- /assets/qrcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/qrcode.png -------------------------------------------------------------------------------- /assets/sora-reproduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/sora-reproduce.png -------------------------------------------------------------------------------- /assets/sora夜谈.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/sora夜谈.png -------------------------------------------------------------------------------- /assets/wechatqrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/assets/wechatqrcode.jpg -------------------------------------------------------------------------------- /codes/OpenDiT/.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | processed/ 3 | profile/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | docs/.build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # IDE 137 | .idea/ 138 | .vscode/ 139 | 140 | # macos 141 | *.DS_Store 142 | #data/ 143 | 144 | docs/.build 145 | 146 | # pytorch checkpoint 147 | *.pt 148 | 149 | # ignore version.py generated by setup.py 150 | colossalai/version.py 151 | 152 | # ignore any kernel build files 153 | .o 154 | .so 155 | 156 | # ignore python interface defition file 157 | .pyi 158 | 159 | # ignore coverage test file 160 | coverage.lcov 161 | coverage.xml 162 | 163 | # ignore testmon and coverage files 164 | .coverage 165 | .testmondata* 166 | 167 | # ignore data files 168 | datasets 169 | -------------------------------------------------------------------------------- /codes/OpenDiT/.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length = 120 3 | multi_line_output=3 4 | include_trailing_comma = true 5 | ignore_comments = true 6 | profile = black 7 | honor_noqa = true 8 | -------------------------------------------------------------------------------- /codes/OpenDiT/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/PyCQA/autoflake 4 | rev: v2.2.1 5 | hooks: 6 | - id: autoflake 7 | name: autoflake (python) 8 | args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] 9 | 10 | - repo: https://github.com/pycqa/isort 11 | rev: 5.12.0 12 | hooks: 13 | - id: isort 14 | name: sort all imports (python) 15 | 16 | - repo: https://github.com/psf/black-pre-commit-mirror 17 | rev: 23.9.1 18 | hooks: 19 | - id: black 20 | name: black formatter 21 | args: ['--line-length=120', 
'--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] 22 | 23 | - repo: https://github.com/pre-commit/mirrors-clang-format 24 | rev: v13.0.1 25 | hooks: 26 | - id: clang-format 27 | name: clang formatter 28 | types_or: [c++, c] 29 | 30 | - repo: https://github.com/pre-commit/pre-commit-hooks 31 | rev: v4.3.0 32 | hooks: 33 | - id: check-yaml 34 | - id: check-merge-conflict 35 | - id: check-case-conflict 36 | - id: trailing-whitespace 37 | - id: end-of-file-fixer 38 | - id: mixed-line-ending 39 | args: ['--fix=lf'] 40 | -------------------------------------------------------------------------------- /codes/OpenDiT/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Coding Standards 2 | 3 | ### Unit Tests 4 | We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. 5 | 6 | To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run 7 | ```bash 8 | pip install -r requirements/requirements-test.txt 9 | ``` 10 | If you encounter an error telling "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your python version to 3.8 or 3.9 and try again. 11 | 12 | If you only want to run CPU tests, you can run 13 | 14 | ```bash 15 | pytest -m cpu tests/ 16 | ``` 17 | 18 | If you have 8 GPUs on your machine, you can run the full test 19 | 20 | ```bash 21 | pytest tests/ 22 | ``` 23 | 24 | If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch. 25 | 26 | 27 | ### Code Style 28 | 29 | We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below. 30 | 31 | ```shell 32 | # these commands are executed under the Colossal-AI directory 33 | pip install pre-commit 34 | pre-commit install 35 | ``` 36 | 37 | Code format checking will be automatically executed when you commit your changes. 
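As a concrete illustration of the marker-based test selection mentioned above, a CPU-only test might look like the sketch below. This is a minimal example, not taken from the repository; it assumes a `cpu` marker is registered in the project's pytest configuration, and the reference computation it checks simply mirrors the `x * (1 + scale) + shift` modulation used elsewhere in OpenDiT.

```python
import pytest
import torch


@pytest.mark.cpu  # selected by `pytest -m cpu tests/`; assumes the marker is registered
def test_reference_modulate_cpu():
    # Unfused reference of the adaLN-style modulation used in OpenDiT.
    x = torch.randn(2, 8, 16)
    scale = torch.randn(2, 1, 16)
    shift = torch.randn(2, 1, 16)
    out = x * (1 + scale) + shift
    assert out.shape == x.shape
    assert torch.isfinite(out).all()
```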
38 | -------------------------------------------------------------------------------- /codes/OpenDiT/figure/dit_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/dit_loss.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/dit_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/dit_results.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/end2end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/end2end.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/fastseq_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/fastseq_exp.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/fastseq_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/fastseq_overview.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/logo.png -------------------------------------------------------------------------------- /codes/OpenDiT/figure/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/figure/wechat.jpg -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/opendit/__init__.py -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos and Meta DiT 2 | # DiT: https://github.com/facebookresearch/DiT/tree/main 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | from . 
import gaussian_diffusion as gd 8 | from .respace import SpacedDiffusion, space_timesteps 9 | 10 | 11 | def create_diffusion( 12 | timestep_respacing, 13 | noise_schedule="linear", 14 | use_kl=False, 15 | sigma_small=False, 16 | predict_xstart=False, 17 | learn_sigma=True, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000, 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), 34 | model_var_type=( 35 | (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) 36 | if not learn_sigma 37 | else gd.ModelVarType.LEARNED_RANGE 38 | ), 39 | loss_type=loss_type 40 | # rescale_timesteps=rescale_timesteps, 41 | ) 42 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos and Meta DiT 2 | # DiT: https://github.com/facebookresearch/DiT/tree/main 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | import numpy as np 8 | import torch as th 9 | 10 | 11 | def normal_kl(mean1, logvar1, mean2, logvar2): 12 | """ 13 | Compute the KL divergence between two gaussians. 14 | Shapes are automatically broadcasted, so batches can be compared to 15 | scalars, among other use cases. 16 | """ 17 | tensor = None 18 | for obj in (mean1, logvar1, mean2, logvar2): 19 | if isinstance(obj, th.Tensor): 20 | tensor = obj 21 | break 22 | assert tensor is not None, "at least one argument must be a Tensor" 23 | 24 | # Force variances to be Tensors. Broadcasting helps convert scalars to 25 | # Tensors, but it does not work for th.exp(). 26 | logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] 27 | 28 | return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) 29 | 30 | 31 | def approx_standard_normal_cdf(x): 32 | """ 33 | A fast approximation of the cumulative distribution function of the 34 | standard normal. 35 | """ 36 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 37 | 38 | 39 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 40 | """ 41 | Compute the log-likelihood of a continuous Gaussian distribution. 42 | :param x: the targets 43 | :param means: the Gaussian mean Tensor. 44 | :param log_scales: the Gaussian log stddev Tensor. 45 | :return: a tensor like x of log probabilities (in nats). 
46 | """ 47 | centered_x = x - means 48 | inv_stdv = th.exp(-log_scales) 49 | normalized_x = centered_x * inv_stdv 50 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 51 | return log_probs 52 | 53 | 54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 55 | """ 56 | Compute the log-likelihood of a Gaussian distribution discretizing to a 57 | given image. 58 | :param x: the target images. It is assumed that this was uint8 values, 59 | rescaled to the range [-1, 1]. 60 | :param means: the Gaussian mean Tensor. 61 | :param log_scales: the Gaussian log stddev Tensor. 62 | :return: a tensor like x of log probabilities (in nats). 63 | """ 64 | assert x.shape == means.shape == log_scales.shape 65 | centered_x = x - means 66 | inv_stdv = th.exp(-log_scales) 67 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 68 | cdf_plus = approx_standard_normal_cdf(plus_in) 69 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 70 | cdf_min = approx_standard_normal_cdf(min_in) 71 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 72 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 73 | cdf_delta = cdf_plus - cdf_min 74 | log_probs = th.where( 75 | x < -0.999, 76 | log_cdf_plus, 77 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 78 | ) 79 | assert log_probs.shape == x.shape 80 | return log_probs 81 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/embed/clip_text_emb.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch.nn as nn 3 | import transformers 4 | from transformers import CLIPTextModel, CLIPTokenizer 5 | 6 | transformers.logging.set_verbosity_error() 7 | 8 | 9 | class AbstractEncoder(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def encode(self, *args, **kwargs): 14 | raise NotImplementedError 15 | 16 | 17 | class FrozenCLIPEmbedder(AbstractEncoder): 18 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 19 | 20 | def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77): 21 | super().__init__() 22 | self.tokenizer = CLIPTokenizer.from_pretrained(path) 23 | self.transformer = CLIPTextModel.from_pretrained(path) 24 | self.device = device 25 | self.max_length = max_length 26 | self._freeze() 27 | 28 | def _freeze(self): 29 | self.transformer = self.transformer.eval() 30 | for param in self.parameters(): 31 | param.requires_grad = False 32 | 33 | def forward(self, text): 34 | batch_encoding = self.tokenizer( 35 | text, 36 | truncation=True, 37 | max_length=self.max_length, 38 | return_length=True, 39 | return_overflowing_tokens=False, 40 | padding="max_length", 41 | return_tensors="pt", 42 | ) 43 | tokens = batch_encoding["input_ids"].to(self.device) 44 | outputs = self.transformer(input_ids=tokens) 45 | 46 | z = outputs.last_hidden_state 47 | pooled_z = outputs.pooler_output 48 | return z, pooled_z 49 | 50 | def encode(self, text): 51 | return self(text) 52 | 53 | 54 | class TextEmbedder(nn.Module): 55 | """ 56 | Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. 
57 | """ 58 | 59 | def __init__(self, path, hidden_size, dropout_prob=0.1): 60 | super().__init__() 61 | self.text_encoder = FrozenCLIPEmbedder(path=path) 62 | self.dropout_prob = dropout_prob 63 | 64 | output_dim = self.text_encoder.transformer.config.hidden_size 65 | self.output_projection = nn.Linear(output_dim, hidden_size) 66 | 67 | def token_drop(self, text_prompts, force_drop_ids=None): 68 | """ 69 | Drops text to enable classifier-free guidance. 70 | """ 71 | if force_drop_ids is None: 72 | drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob 73 | else: 74 | # TODO 75 | drop_ids = force_drop_ids == 1 76 | labels = list(numpy.where(drop_ids, "", text_prompts)) 77 | # print(labels) 78 | return labels 79 | 80 | def forward(self, text_prompts, train, force_drop_ids=None): 81 | use_dropout = self.dropout_prob > 0 82 | if (train and use_dropout) or (force_drop_ids is not None): 83 | text_prompts = self.token_drop(text_prompts, force_drop_ids) 84 | embeddings, pooled_embeddings = self.text_encoder(text_prompts) 85 | # return embeddings, pooled_embeddings 86 | text_embeddings = self.output_projection(pooled_embeddings) 87 | return text_embeddings 88 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/embed/label_emb.py: -------------------------------------------------------------------------------- 1 | # Modified from Meta DiT 2 | 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # -------------------------------------------------------- 6 | # References: 7 | # DiT: https://github.com/facebookresearch/DiT/tree/main 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | 13 | import torch 14 | from torch import nn 15 | 16 | 17 | class LabelEmbedder(nn.Module): 18 | """ 19 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 20 | """ 21 | 22 | def __init__(self, num_classes, hidden_size, dropout_prob): 23 | super().__init__() 24 | use_cfg_embedding = dropout_prob > 0 25 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 26 | self.num_classes = num_classes 27 | self.dropout_prob = dropout_prob 28 | 29 | def token_drop(self, labels, force_drop_ids=None): 30 | """ 31 | Drops labels to enable classifier-free guidance. 32 | """ 33 | if force_drop_ids is None: 34 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 35 | else: 36 | drop_ids = force_drop_ids == 1 37 | labels = torch.where(drop_ids, self.num_classes, labels) 38 | return labels 39 | 40 | def forward(self, labels, train, force_drop_ids=None): 41 | use_dropout = self.dropout_prob > 0 42 | if (train and use_dropout) or (force_drop_ids is not None): 43 | labels = self.token_drop(labels, force_drop_ids) 44 | embeddings = self.embedding_table(labels) 45 | return embeddings 46 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/embed/patch_emb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class PatchEmbed3D(nn.Module): 12 | """Video to Patch Embedding. 13 | 14 | Args: 15 | patch_size (int): Patch token size. Default: (2,4,4). 16 | in_chans (int): Number of input video channels. Default: 3. 17 | embed_dim (int): Number of linear projection output channels. Default: 96. 18 | norm_layer (nn.Module, optional): Normalization layer. Default: None 19 | """ 20 | 21 | def __init__( 22 | self, 23 | patch_size=(2, 4, 4), 24 | in_chans=3, 25 | embed_dim=96, 26 | norm_layer=None, 27 | flatten=True, 28 | ): 29 | super().__init__() 30 | self.patch_size = patch_size 31 | self.flatten = flatten 32 | 33 | self.in_chans = in_chans 34 | self.embed_dim = embed_dim 35 | 36 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 37 | if norm_layer is not None: 38 | self.norm = norm_layer(embed_dim) 39 | else: 40 | self.norm = None 41 | 42 | def forward(self, x): 43 | """Forward function.""" 44 | # padding 45 | _, _, D, H, W = x.size() 46 | if W % self.patch_size[2] != 0: 47 | x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) 48 | if H % self.patch_size[1] != 0: 49 | x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) 50 | if D % self.patch_size[0] != 0: 51 | x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) 52 | 53 | x = self.proj(x) # (B C T H W) 54 | if self.norm is not None: 55 | D, Wh, Ww = x.size(2), x.size(3), x.size(4) 56 | x = x.flatten(2).transpose(1, 2) 57 | x = self.norm(x) 58 | x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) 59 | if self.flatten: 60 | x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC 61 | return x 62 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/embed/pos_emb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
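# The 3D sin-cos table built by get_3d_sincos_pos_embed below devotes embed_dim // 4
# channels to temporal (frame-index) frequencies and embed_dim // 4 * 3 channels to the
# 2D spatial grid (hence the assert that embed_dim is divisible by 4); the two parts are
# repeated across each other's axes and concatenated into a
# [t_size * grid_size * grid_size, embed_dim] array, temporal channels first.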
6 | 7 | import numpy as np 8 | 9 | 10 | def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False): 11 | """ 12 | grid_size: int of the grid height and width 13 | t_size: int of the temporal size 14 | return: 15 | pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 16 | """ 17 | assert embed_dim % 4 == 0 18 | embed_dim_spatial = embed_dim // 4 * 3 19 | embed_dim_temporal = embed_dim // 4 20 | 21 | # spatial 22 | grid_h = np.arange(grid_size, dtype=np.float32) 23 | grid_w = np.arange(grid_size, dtype=np.float32) 24 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 25 | grid = np.stack(grid, axis=0) 26 | 27 | grid = grid.reshape([2, 1, grid_size, grid_size]) 28 | pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) 29 | 30 | # temporal 31 | grid_t = np.arange(t_size, dtype=np.float32) 32 | pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) 33 | 34 | # concate: [T, H, W] order 35 | pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] 36 | pos_embed_temporal = np.repeat(pos_embed_temporal, grid_size**2, axis=1) # [T, H*W, D // 4] 37 | pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] 38 | pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] 39 | 40 | pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) 41 | pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] 42 | 43 | if cls_token: 44 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 45 | return pos_embed 46 | 47 | 48 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 49 | """ 50 | grid_size: int of the grid height and width 51 | return: 52 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 53 | """ 54 | grid_h = np.arange(grid_size, dtype=np.float32) 55 | grid_w = np.arange(grid_size, dtype=np.float32) 56 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 57 | grid = np.stack(grid, axis=0) 58 | 59 | grid = grid.reshape([2, 1, grid_size, grid_size]) 60 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 61 | if cls_token and extra_tokens > 0: 62 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 63 | return pos_embed 64 | 65 | 66 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 67 | assert embed_dim % 2 == 0 68 | 69 | # use half of dimensions to encode grid_h 70 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 71 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 72 | 73 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 74 | return emb 75 | 76 | 77 | def get_1d_sincos_pos_embed(embed_dim, length): 78 | pos = np.arange(0, length)[..., None] 79 | return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) 80 | 81 | 82 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 83 | """ 84 | embed_dim: output dimension for each position 85 | pos: a list of positions to be encoded: size (M,) 86 | out: (M, D) 87 | """ 88 | assert embed_dim % 2 == 0 89 | omega = np.arange(embed_dim // 2, dtype=np.float64) 90 | omega /= embed_dim / 2.0 91 | omega = 1.0 / 10000**omega # (D/2,) 92 | 93 | pos = pos.reshape(-1) # (M,) 94 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 95 | 96 | emb_sin = np.sin(out) # (M, D/2) 97 | emb_cos = np.cos(out) # (M, D/2) 98 | 99 | emb = 
np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 100 | return emb 101 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/embed/time_emb.py: -------------------------------------------------------------------------------- 1 | # Modified from Meta DiT 2 | 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # -------------------------------------------------------- 6 | # References: 7 | # DiT: https://github.com/facebookresearch/DiT/tree/main 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | 13 | import math 14 | 15 | import torch 16 | from torch import nn 17 | 18 | 19 | class TimestepEmbedder(nn.Module): 20 | """ 21 | Embeds scalar timesteps into vector representations. 22 | """ 23 | 24 | def __init__(self, hidden_size, frequency_embedding_size=256): 25 | super().__init__() 26 | self.mlp = nn.Sequential( 27 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 28 | nn.SiLU(), 29 | nn.Linear(hidden_size, hidden_size, bias=True), 30 | ) 31 | self.frequency_embedding_size = frequency_embedding_size 32 | 33 | @staticmethod 34 | def timestep_embedding(t, dim, max_period=10000): 35 | """ 36 | Create sinusoidal timestep embeddings. 37 | :param t: a 1-D Tensor of N indices, one per batch element. 38 | These may be fractional. 39 | :param dim: the dimension of the output. 40 | :param max_period: controls the minimum frequency of the embeddings. 41 | :return: an (N, D) Tensor of positional embeddings. 42 | """ 43 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 44 | half = dim // 2 45 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 46 | device=t.device 47 | ) 48 | args = t[:, None].float() * freqs[None] 49 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 50 | if dim % 2: 51 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 52 | return embedding 53 | 54 | def forward(self, t, dtype): 55 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 56 | if t_freq.dtype != dtype: 57 | t_freq = t_freq.to(dtype) 58 | t_emb = self.mlp(t_freq) 59 | return t_emb 60 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/opendit/kernels/__init__.py -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/kernels/fused_modulate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | 4 | from .k_fused_modulate import _modulate_bwd, _modulate_fwd 5 | 6 | 7 | class _FusedModulate(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x, scale, shift): 10 | y = torch.empty_like(x) 11 | batch, seq_len, dim = x.shape 12 | M = batch * seq_len 13 | N = dim 14 | x = x.view(-1, dim).contiguous() 15 | scale = scale.view(-1, dim).contiguous() 16 | shift = shift.view(-1, dim).contiguous() 17 | 18 | def grid(meta): 19 | return ( 20 | triton.cdiv(batch * seq_len, meta["BLOCK_M"]), 21 | triton.cdiv(dim, 
meta["BLOCK_N"]), 22 | ) 23 | 24 | _modulate_fwd[grid](x, y, scale, shift, x.stride(0), scale.stride(0), M, N, seq_len) 25 | 26 | ctx.save_for_backward(x, scale) 27 | ctx.batch = batch 28 | ctx.seq_len = seq_len 29 | ctx.dim = dim 30 | return y 31 | 32 | @staticmethod 33 | def backward(ctx, dy): # pragma: no cover # this is covered, but called directly from C++ 34 | x, scale = ctx.saved_tensors 35 | 36 | batch, seq_len, dim = ctx.batch, ctx.seq_len, ctx.dim 37 | M = batch * seq_len 38 | N = dim 39 | 40 | # allocate output 41 | dy = dy.contiguous() 42 | dx = torch.empty_like(dy) 43 | dscale = torch.empty_like(dy) 44 | dshift = torch.sum(dy, dim=1) 45 | 46 | def grid(meta): 47 | return ( 48 | triton.cdiv(batch * seq_len, meta["BLOCK_M"]), 49 | triton.cdiv(dim, meta["BLOCK_N"]), 50 | ) 51 | 52 | _modulate_bwd[grid](dx, x, dy, scale, dscale, x.stride(0), scale.stride(0), M, N, seq_len) 53 | 54 | dscale = torch.sum(dscale, dim=1) 55 | return dx, dscale, dshift 56 | 57 | 58 | def fused_modulate( 59 | x: torch.Tensor, 60 | scale: torch.Tensor, 61 | shift: torch.Tensor, 62 | ) -> torch.Tensor: 63 | return _FusedModulate.apply(x, scale, shift) 64 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/kernels/k_fused_modulate.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | CONFIG_LIST = [ 5 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_stages=2, num_warps=4), 6 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=2, num_warps=4), 7 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_stages=2, num_warps=4), 8 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=2, num_warps=4), 9 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=2, num_warps=4), 10 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=2, num_warps=4), 11 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=2, num_warps=4), 12 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_stages=2, num_warps=4), 13 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_stages=2, num_warps=4), 14 | ] 15 | 16 | 17 | @triton.autotune( 18 | configs=CONFIG_LIST, 19 | key=["M", "N"], 20 | ) 21 | @triton.jit 22 | def _modulate_fwd( 23 | x_ptr, # *Pointer* to first input vector. 24 | output_ptr, # *Pointer* to output vector. 25 | scale_ptr, 26 | shift_ptr, 27 | m_stride, 28 | s_stride, 29 | M, 30 | N, 31 | seq_len, 32 | BLOCK_M: tl.constexpr, # Number of elements each program should process. 33 | BLOCK_N: tl.constexpr, 34 | # NOTE: `constexpr` so it can be used as a shape value. 35 | ): 36 | row_id = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. 37 | rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) 38 | s_rows = (row_id // seq_len) * BLOCK_M 39 | col_id = tl.program_id(axis=1) 40 | cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) 41 | 42 | x_ptrs = x_ptr + rows[:, None] * m_stride + cols[None, :] 43 | scale_ptrs = scale_ptr + s_rows * s_stride + cols[None, :] 44 | shift_ptrs = shift_ptr + s_rows * s_stride + cols[None, :] 45 | 46 | col_mask = cols[None, :] < N 47 | block_mask = (rows[:, None] < M) & col_mask 48 | s_block_mask = col_mask 49 | x = tl.load(x_ptrs, mask=block_mask, other=0.0) 50 | scale = tl.load(scale_ptrs, mask=s_block_mask, other=0.0) 51 | shift = tl.load(shift_ptrs, mask=s_block_mask, other=0.0) 52 | 53 | output = x * (1 + scale) + shift 54 | # Write x + y back to DRAM. 
55 | tl.store(output_ptr + rows[:, None] * m_stride + cols[None, :], output, mask=block_mask) 56 | 57 | 58 | @triton.autotune( 59 | configs=CONFIG_LIST, 60 | key=["M", "N"], 61 | ) 62 | @triton.jit 63 | def _modulate_bwd( 64 | dx_ptr, # *Pointer* to first input vector. 65 | x_ptr, 66 | dy_ptr, # *Pointer* to output vector. 67 | scale_ptr, 68 | dscale_ptr, 69 | m_stride, 70 | s_stride, 71 | M, 72 | N, 73 | seq_len, 74 | BLOCK_M: tl.constexpr, # Number of elements each program should process. 75 | BLOCK_N: tl.constexpr, 76 | # NOTE: `constexpr` so it can be used as a shape value. 77 | ): 78 | row_id = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. 79 | rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) 80 | s_rows = (row_id // seq_len) * BLOCK_M 81 | col_id = tl.program_id(axis=1) 82 | cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) 83 | 84 | x_ptrs = x_ptr + rows[:, None] * m_stride + cols[None, :] 85 | dy_ptrs = dy_ptr + rows[:, None] * m_stride + cols[None, :] 86 | dx_ptrs = dx_ptr + rows[:, None] * m_stride + cols[None, :] 87 | dscale_ptrs = dscale_ptr + rows[:, None] * m_stride + cols[None, :] 88 | 89 | scale_ptrs = scale_ptr + s_rows * s_stride + cols[None, :] 90 | 91 | col_mask = cols[None, :] < N 92 | block_mask = (rows[:, None] < M) & col_mask 93 | s_block_mask = col_mask 94 | x = tl.load(x_ptrs, mask=block_mask, other=0.0) 95 | dy = tl.load(dy_ptrs, mask=block_mask, other=0.0) 96 | scale = tl.load(scale_ptrs, mask=s_block_mask, other=0.0) 97 | 98 | dx = dy * (1 + scale) 99 | dscale = dy * x 100 | # Write x + y back to DRAM. 101 | tl.store(dx_ptrs, dx, mask=block_mask) 102 | tl.store(dscale_ptrs, dscale, mask=block_mask) 103 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/opendit/models/__init__.py -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/opendit/modules/__init__.py -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | # Print debug information on selected rank 5 | def print_rank(var_name, var_value, rank=0): 6 | if dist.get_rank() == rank: 7 | print(f"[Rank {rank}] {var_name}: {var_value}") 8 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/utils/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
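# find_model below resolves a checkpoint in one of three ways: a name listed in
# `pretrained_models` is fetched via download_model; a directory is treated as a
# Hugging Face-style sharded checkpoint located through its *index.json weight map;
# any other path is loaded directly with torch.load, using the "ema" weights when the
# file was produced by train.py.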
6 | 7 | """ 8 | Functions for downloading pre-trained DiT models 9 | """ 10 | import os 11 | import json 12 | import torch 13 | from torchvision.datasets.utils import download_url 14 | 15 | pretrained_models = {"DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"} 16 | 17 | 18 | def find_model(model_name): 19 | """ 20 | Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. 21 | """ 22 | if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints 23 | return download_model(model_name) 24 | else: # Load a custom DiT checkpoint: 25 | if not os.path.isfile(model_name): 26 | # if the model_name is a directory, then we assume we should load it in the Hugging Face manner 27 | # i.e. the model weights are sharded into multiple files and there is an index.json file 28 | # walk through the files in the directory and find the index.json file 29 | index_file = [os.path.join(model_name, f) for f in os.listdir(model_name) if "index.json" in f] 30 | assert len(index_file) == 1, f"Could not find index.json in {model_name}" 31 | 32 | # process index json 33 | with open (index_file[0], "r") as f: 34 | index_data = json.load(f) 35 | 36 | bin_to_weight_mapping = dict() 37 | for k, v in index_data['weight_map'].items(): 38 | if v in bin_to_weight_mapping: 39 | bin_to_weight_mapping[v].append(k) 40 | else: 41 | bin_to_weight_mapping[v] = [k] 42 | 43 | # make state dict 44 | state_dict = dict() 45 | for bin_name, weight_list in bin_to_weight_mapping.items(): 46 | bin_path = os.path.join(model_name, bin_name) 47 | bin_state_dict = torch.load(bin_path, map_location=lambda storage, loc: storage) 48 | for weight in weight_list: 49 | state_dict[weight] = bin_state_dict[weight] 50 | return state_dict 51 | else: 52 | # if it is a file, we just load it directly in the typical PyTorch manner 53 | assert os.path.exists(model_name), f"Could not find DiT checkpoint at {model_name}" 54 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 55 | if "ema" in checkpoint: # supports checkpoints from train.py 56 | checkpoint = checkpoint["ema"] 57 | return checkpoint 58 | 59 | 60 | def download_model(model_name): 61 | """ 62 | Downloads a pre-trained DiT model from the web. 
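Checkpoints are cached under pretrained_models/ and are only fetched from dl.fbaipublicfiles.com when the local file does not already exist.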
63 | """ 64 | assert model_name in pretrained_models 65 | local_path = f"pretrained_models/{model_name}" 66 | if not os.path.isfile(local_path): 67 | os.makedirs("pretrained_models", exist_ok=True) 68 | web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}" 69 | download_url(web_path, "pretrained_models") 70 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 71 | return model 72 | 73 | 74 | if __name__ == "__main__": 75 | # Download all DiT checkpoints 76 | for model in pretrained_models: 77 | download_model(model) 78 | print("Done.") 79 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/utils/pg_utils.py: -------------------------------------------------------------------------------- 1 | from colossalai.cluster.process_group_mesh import ProcessGroupMesh 2 | from torch.distributed import ProcessGroup 3 | 4 | 5 | class ProcessGroupManager(ProcessGroupMesh): 6 | def __init__(self, *size: int, dp_axis, sp_axis): 7 | super().__init__(*size) 8 | self.dp_axis = dp_axis 9 | self.sp_axis = sp_axis 10 | self._dp_group: ProcessGroup = self.get_group_along_axis(self.dp_axis) 11 | self._sp_group: ProcessGroup = self.get_group_along_axis(self.sp_axis) 12 | 13 | @property 14 | def dp_group(self) -> ProcessGroup: 15 | return self._dp_group 16 | 17 | @property 18 | def sp_group(self) -> ProcessGroup: 19 | return self._sp_group 20 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer 6 | 7 | from opendit.models.dit import DiT 8 | from opendit.models.latte import Latte 9 | 10 | 11 | def get_model_numel(model: torch.nn.Module) -> int: 12 | return sum(p.numel() for p in model.parameters()) 13 | 14 | 15 | def format_numel_str(numel: int) -> str: 16 | B = 1024**3 17 | M = 1024**2 18 | K = 1024 19 | if numel >= B: 20 | return f"{numel / B:.2f} B" 21 | elif numel >= M: 22 | return f"{numel / M:.2f} M" 23 | elif numel >= K: 24 | return f"{numel / K:.2f} K" 25 | else: 26 | return f"{numel}" 27 | 28 | 29 | def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: 30 | dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) 31 | tensor.div_(dist.get_world_size()) 32 | return tensor 33 | 34 | 35 | @torch.no_grad() 36 | def update_ema( 37 | ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True 38 | ) -> None: 39 | """ 40 | Step the EMA model towards the current model. 
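For each parameter this computes ema = decay * ema + (1 - decay) * param; the positional embedding ("pos_embed") and parameters with requires_grad=False are skipped, and when a sharded LowLevelZeroOptimizer is used the fp32 master weights are read from the optimizer so the EMA is updated in full precision.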
41 | """ 42 | if not (isinstance(model, DiT) or isinstance(model, Latte)): 43 | model = model.module 44 | ema_params = OrderedDict(ema_model.named_parameters()) 45 | model_params = OrderedDict(model.named_parameters()) 46 | 47 | for name, param in model_params.items(): 48 | if name == "pos_embed": 49 | continue 50 | if param.requires_grad == False: 51 | continue 52 | if not sharded: 53 | param_data = param.data 54 | ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) 55 | else: 56 | if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer): 57 | param_id = id(param) 58 | master_param = optimizer._param_store.working_to_master_param[param_id] 59 | param_data = master_param.data 60 | else: 61 | param_data = param.data 62 | ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) 63 | 64 | 65 | def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: 66 | """ 67 | Set requires_grad flag for all parameters in a model. 68 | """ 69 | for p in model.parameters(): 70 | p.requires_grad = flag 71 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/vae/download.py: -------------------------------------------------------------------------------- 1 | # This code is copied from https://github.com/wilson1yan/VideoGPT 2 | # Copyright (c) 2021 Wilson Yan. All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | 10 | import gdown 11 | import torch 12 | 13 | from .vqvae import VQVAE 14 | 15 | 16 | def download(id, fname, root=os.path.expanduser("~/.cache/videogpt")): 17 | os.makedirs(root, exist_ok=True) 18 | destination = os.path.join(root, fname) 19 | 20 | if os.path.exists(destination): 21 | return destination 22 | 23 | gdown.download(id=id, output=destination, quiet=False) 24 | return destination 25 | 26 | 27 | _VQVAE = { 28 | "bair_stride4x2x2": "1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L", # trained on 16 frames of 64 x 64 images 29 | "ucf101_stride4x4x4": "1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5", # trained on 16 frames of 128 x 128 images 30 | "kinetics_stride4x4x4": "1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB", # trained on 16 frames of 128 x 128 images 31 | "kinetics_stride2x4x4": "1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB", # trained on 16 frames of 128 x 128 images 32 | } 33 | 34 | 35 | def load_vqvae(model_name, device=torch.device("cpu"), root=os.path.expanduser("~/.cache/videogpt")): 36 | assert model_name in _VQVAE, f"Invalid model_name: {model_name}" 37 | filepath = download(_VQVAE[model_name], model_name, root=root) 38 | vqvae = VQVAE.load_from_checkpoint(filepath).to(device) 39 | vqvae.eval() 40 | 41 | return vqvae 42 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/vae/reconstruct.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from diffusers.models import AutoencoderKL 5 | from torchvision.io import write_video 6 | from torchvision.utils import save_image 7 | 8 | from opendit.vae.wrapper import AutoencoderKLWrapper 9 | 10 | 11 | def t2v(x): 12 | x = (x * 0.5 + 0.5).clamp(0, 1) 13 | x = (x * 255).to(torch.uint8) 14 | x = x.permute(1, 2, 3, 0).cpu() 15 | return x 16 | 17 | 18 | def save_sample(x, real=None): 19 | B = x.size(0) 20 | nrows = B // int(B**0.5) 21 | if x.size(2) == 1: 22 | path = "sample.png" 23 | x = x.squeeze(2) 24 | if real is not None: 25 | real 
= real.squeeze(2) 26 | x = torch.cat([real, x], dim=-1) 27 | save_image(x, path, nrow=nrows, normalize=True, value_range=(-1, 1)) 28 | print(f"Sampled images saved to {path}") 29 | else: 30 | path_dir = "sample_videos" 31 | os.makedirs(path_dir, exist_ok=True) 32 | for i in range(B): 33 | path = os.path.join(path_dir, f"sample_{i}.mp4") 34 | x_i = t2v(x[i]) 35 | if real is not None: 36 | real_i = t2v(real[i]) 37 | x_i = torch.cat([real_i, x_i], dim=-2) 38 | write_video(path, x_i, fps=20, video_codec="h264") 39 | print(f"Sampled video saved to {path}") 40 | 41 | 42 | @torch.no_grad() 43 | def reconstruct(args, data) -> None: 44 | device = "cuda" if torch.cuda.is_available() else "cpu" 45 | vae = AutoencoderKL.from_pretrained(args.vae) 46 | vae = AutoencoderKLWrapper(vae) 47 | vae = vae.to(device) 48 | data = data.to(device) 49 | 50 | x = vae.encode(data) 51 | x = vae.decode(x) 52 | save_sample(x, real=data) 53 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/vae/utils.py: -------------------------------------------------------------------------------- 1 | # This code is copied from https://github.com/wilson1yan/VideoGPT 2 | # Copyright (c) 2021 Wilson Yan. All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import numpy as np 10 | import skvideo.io 11 | 12 | 13 | # Shifts src_tf dim to dest dim 14 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) 15 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 16 | n_dims = len(x.shape) 17 | if src_dim < 0: 18 | src_dim = n_dims + src_dim 19 | if dest_dim < 0: 20 | dest_dim = n_dims + dest_dim 21 | 22 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 23 | 24 | dims = list(range(n_dims)) 25 | del dims[src_dim] 26 | 27 | permutation = [] 28 | ctr = 0 29 | for i in range(n_dims): 30 | if i == dest_dim: 31 | permutation.append(src_dim) 32 | else: 33 | permutation.append(dims[ctr]) 34 | ctr += 1 35 | x = x.permute(permutation) 36 | if make_contiguous: 37 | x = x.contiguous() 38 | return x 39 | 40 | 41 | # reshapes tensor start from dim i (inclusive) 42 | # to dim j (exclusive) to the desired shape 43 | # e.g. 
if x.shape = (b, thw, c) then 44 | # view_range(x, 1, 2, (t, h, w)) returns 45 | # x of shape (b, t, h, w, c) 46 | def view_range(x, i, j, shape): 47 | shape = tuple(shape) 48 | 49 | n_dims = len(x.shape) 50 | if i < 0: 51 | i = n_dims + i 52 | 53 | if j is None: 54 | j = n_dims 55 | elif j < 0: 56 | j = n_dims + j 57 | 58 | assert 0 <= i < j <= n_dims 59 | 60 | x_shape = x.shape 61 | target_shape = x_shape[:i] + shape + x_shape[j:] 62 | return x.view(target_shape) 63 | 64 | 65 | def tensor_slice(x, begin, size): 66 | assert all([b >= 0 for b in begin]) 67 | size = [l - b if s == -1 else s for s, b, l in zip(size, begin, x.shape)] 68 | assert all([s >= 0 for s in size]) 69 | 70 | slices = [slice(b, b + s) for b, s in zip(begin, size)] 71 | return x[slices] 72 | 73 | 74 | def save_video_grid(video, fname, nrow=None): 75 | b, c, t, h, w = video.shape 76 | video = video.permute(0, 2, 3, 4, 1) 77 | video = (video.cpu().numpy() * 255).astype("uint8") 78 | 79 | if nrow is None: 80 | nrow = math.ceil(math.sqrt(b)) 81 | ncol = math.ceil(b / nrow) 82 | padding = 1 83 | video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype="uint8") 84 | for i in range(b): 85 | r = i // ncol 86 | c = i % ncol 87 | 88 | start_r = (padding + h) * r 89 | start_c = (padding + w) * c 90 | video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i] 91 | 92 | skvideo.io.vwrite(fname, video_grid, inputdict={"-r": "5"}) 93 | print("saved videos to", fname) 94 | -------------------------------------------------------------------------------- /codes/OpenDiT/opendit/vae/wrapper.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch import nn 3 | 4 | 5 | class AutoencoderKLWrapper(nn.Module): 6 | def __init__(self, vae): 7 | super().__init__() 8 | self.module = vae 9 | self.out_channels = vae.config.latent_channels 10 | self.patch_size = [1, 8, 8] 11 | 12 | def encode(self, x): 13 | # x is (B, C, T, H, W) 14 | B = x.shape[0] 15 | x = rearrange(x, "b c t h w -> (b t) c h w") 16 | x = self.module.encode(x).latent_dist.sample().mul_(0.18215) 17 | x = rearrange(x, "(b t) c h w -> b c t h w", b=B) 18 | return x 19 | 20 | def decode(self, x): 21 | # x is (B, C, T, H, W) 22 | B = x.shape[0] 23 | x = rearrange(x, "b c t h w -> (b t) c h w") 24 | x = self.module.decode(x / 0.18215).sample 25 | x = rearrange(x, "(b t) c h w -> b c t h w", b=B) 26 | return x 27 | -------------------------------------------------------------------------------- /codes/OpenDiT/preprocess.py: -------------------------------------------------------------------------------- 1 | # This script is used to generate a csv file for the UCF101 dataset. 2 | # The csv file contains the path to each video and its corresponding class. 3 | # The csv file will be used to load the dataset in the training script. 
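# Illustrative example (hypothetical file name): a generated row looks like
# `path/to/ucf101/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi,Apply Eye Makeup`,
# i.e. the caption is the parent class-directory name with spaces inserted before capital letters by split_by_capital() below.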
4 | 5 | 6 | import csv 7 | import os 8 | 9 | 10 | def get_filelist(file_path): 11 | Filelist = [] 12 | for home, dirs, files in os.walk(file_path): 13 | for filename in files: 14 | Filelist.append(os.path.join(home, filename)) 15 | return Filelist 16 | 17 | 18 | def split_by_capital(name): 19 | # BoxingPunchingBag -> Boxing Punching Bag 20 | new_name = "" 21 | for i in range(len(name)): 22 | if name[i].isupper() and i != 0: 23 | new_name += " " 24 | new_name += name[i] 25 | return new_name 26 | 27 | 28 | root = "path/to/ucf101" 29 | split = "train" 30 | 31 | root = os.path.expanduser(root) 32 | video_lists = get_filelist(os.path.join(root, split)) 33 | classes = [x.split("/")[-2] for x in video_lists] 34 | classes = [split_by_capital(x) for x in classes] 35 | samples = list(zip(video_lists, classes)) 36 | 37 | with open(f"preprocess/ucf101_{split}.csv", "w") as f: 38 | writer = csv.writer(f) 39 | writer.writerows(samples) 40 | 41 | print(f"Saved {len(samples)} samples to preprocess/ucf101_{split}.csv.") 42 | -------------------------------------------------------------------------------- /codes/OpenDiT/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | psutil 4 | packaging 5 | pre-commit 6 | rich 7 | click 8 | fabric 9 | contexttimer 10 | ninja 11 | torch>=1.13 12 | safetensors 13 | einops 14 | pydantic 15 | ray 16 | protobuf 17 | pytorch_lightning 18 | h5py 19 | gdown 20 | scikit-video 21 | pyav 22 | tensorboard 23 | timm 24 | matplotlib 25 | accelerate 26 | diffusers 27 | transformers 28 | flash_attn 29 | colossalai 30 | -------------------------------------------------------------------------------- /codes/OpenDiT/sample_img.sh: -------------------------------------------------------------------------------- 1 | python sample.py \ 2 | --model DiT-XL/2 \ 3 | --image_size 256 \ 4 | --num_classes 10 \ 5 | --ckpt ckpt_path 6 | -------------------------------------------------------------------------------- /codes/OpenDiT/sample_video.sh: -------------------------------------------------------------------------------- 1 | python sample.py \ 2 | --model VDiT-XL/1x2x2 \ 3 | --use_video \ 4 | --ckpt ckpt_path \ 5 | --num_frames 16 \ 6 | --image_size 256 \ 7 | --frame_interval 3 8 | -------------------------------------------------------------------------------- /codes/OpenDiT/setup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def fetch_requirements(path) -> List[str]: 7 | """ 8 | This function reads the requirements file. 9 | 10 | Args: 11 | path (str): the path to the requirements file. 12 | 13 | Returns: 14 | The lines in the requirements file. 15 | """ 16 | with open(path, "r") as fd: 17 | return [r.strip() for r in fd.readlines()] 18 | 19 | 20 | def fetch_readme() -> str: 21 | """ 22 | This function reads the README.md file in the current directory. 23 | 24 | Returns: 25 | The lines in the README file. 
26 | """ 27 | with open("README.md", encoding="utf-8") as f: 28 | return f.read() 29 | 30 | 31 | setup( 32 | name="opendit", 33 | version="0.1.0", 34 | packages=find_packages( 35 | exclude=( 36 | "videos", 37 | "tests", 38 | "figure", 39 | "*.egg-info", 40 | ) 41 | ), 42 | description="OpenDiT", 43 | long_description=fetch_readme(), 44 | long_description_content_type="text/markdown", 45 | license="Apache Software License 2.0", 46 | install_requires=fetch_requirements("requirements.txt"), 47 | python_requires=">=3.6", 48 | classifiers=[ 49 | "Programming Language :: Python :: 3", 50 | "License :: OSI Approved :: Apache Software License", 51 | "Environment :: GPU :: NVIDIA CUDA", 52 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 53 | "Topic :: System :: Distributed Computing", 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import colossalai 5 | import pytest 6 | import torch 7 | import torch.distributed as dist 8 | from colossalai.booster import Booster 9 | from colossalai.booster.plugin import LowLevelZeroPlugin 10 | from colossalai.nn.optimizer import HybridAdam 11 | from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn 12 | from colossalai.zero import LowLevelZeroOptimizer 13 | 14 | from opendit.models.dit import DiT 15 | 16 | 17 | def run_zero_checkpoint(stage: int, shard: bool, offload: bool): 18 | plugin = LowLevelZeroPlugin(precision="fp16", stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) 19 | booster = Booster(plugin=plugin) 20 | model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4, dtype=torch.float16).half() 21 | criterion = lambda x: x.mean() 22 | optimizer = HybridAdam((model.parameters()), lr=0.001) 23 | model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) 24 | 25 | x = torch.randn(2, 4, 32, 32).cuda().requires_grad_(True) 26 | y = torch.randint(0, 10, (2,)).cuda() 27 | t = torch.randint(0, 10, (2,)).cuda() 28 | output = model(x, y, t) 29 | loss = criterion(output) 30 | booster.backward(loss, optimizer) 31 | optimizer.step() 32 | 33 | tempdir = "./tempdir" 34 | if dist.get_rank() == 0: 35 | if os.path.exists(tempdir): 36 | shutil.rmtree(tempdir) 37 | os.makedirs(tempdir) 38 | dist.barrier() 39 | 40 | model_ckpt_path = f"{tempdir}/model" 41 | optimizer_ckpt_path = f"{tempdir}/optimizer" 42 | booster.save_model(model, model_ckpt_path, shard=shard) 43 | booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) 44 | 45 | dist.barrier() 46 | 47 | new_model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4, dtype=torch.float16).half() 48 | new_optimizer = HybridAdam((new_model.parameters()), lr=0.001) 49 | new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer) 50 | 51 | booster.load_model(new_model, model_ckpt_path) 52 | check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) 53 | # check master weight 54 | assert isinstance(new_optimizer, LowLevelZeroOptimizer) 55 | working_param_id_set = set(id(p) for p in new_model.parameters()) 56 | for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): 57 | assert p_id in working_param_id_set 58 | working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] 59 | padding = 
new_optimizer._param_store.get_param_padding_size(working_param) 60 | padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) 61 | working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] 62 | assert torch.equal( 63 | working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) 64 | ) 65 | 66 | booster.load_optimizer(new_optimizer, optimizer_ckpt_path) 67 | check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) 68 | dist.barrier() 69 | 70 | if dist.get_rank() == 0: 71 | shutil.rmtree(tempdir) 72 | dist.barrier() 73 | 74 | 75 | def run_dist(rank, world_size, port, stage: int, shard: bool, offload: bool): 76 | colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") 77 | run_zero_checkpoint(stage=stage, shard=shard, offload=offload) 78 | 79 | 80 | @pytest.mark.parametrize("stage", [2]) 81 | @pytest.mark.parametrize("shard", [True, False]) 82 | @pytest.mark.parametrize("offload", [False, True]) 83 | @rerun_if_address_is_in_use() 84 | def test_zero_checkpoint(stage, shard, offload): 85 | spawn(run_dist, 2, stage=stage, shard=shard, offload=offload) 86 | 87 | 88 | if __name__ == "__main__": 89 | test_zero_checkpoint(2, True, False) 90 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from opendit.embed.clip_text_emb import TextEmbedder 4 | 5 | if __name__ == "__main__": 6 | r""" 7 | Returns: 8 | 9 | Examples from CLIPTextModel: 10 | 11 | ```python 12 | >>> from transformers import AutoTokenizer, CLIPTextModel 13 | 14 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") 15 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 16 | 17 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 18 | 19 | >>> outputs = model(**inputs) 20 | >>> last_hidden_state = outputs.last_hidden_state 21 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 22 | ```""" 23 | 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | text_encoder = TextEmbedder(path="openai/clip-vit-base-patch32", dropout_prob=0.00001).to(device) 27 | 28 | text_prompt = [ 29 | ["a photo of a cat", "a photo of a cat"], 30 | ["a photo of a dog", "a photo of a cat"], 31 | ["a photo of a dog human", "a photo of a cat"], 32 | ] 33 | # text_prompt = ('None', 'None', 'None') 34 | output, pooled_output = text_encoder(text_prompts=text_prompt, train=False) 35 | # print(output) 36 | print(output.shape) 37 | print(pooled_output.shape) 38 | # print(output.shape) 39 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import colossalai 2 | import torch 3 | import torch.distributed as dist 4 | from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn 5 | from colossalai.utils import get_current_device 6 | from torchvision import transforms 7 | from torchvision.datasets import CIFAR10 8 | 9 | from opendit.utils.data_utils import center_crop_arr, prepare_dataloader 10 | from opendit.utils.pg_utils import ProcessGroupManager 11 | 12 | WORKERS = 4 13 | 14 | 15 | @parameterize("batch_size", [2]) 16 | 
@parameterize("sequence_parallel_size", [2, 4]) 17 | @parameterize("image_size", [256]) 18 | def run_dataloader_test(batch_size, sequence_parallel_size, image_size, num_workers=0, data_path="../datasets"): 19 | sp_size = sequence_parallel_size 20 | dp_size = dist.get_world_size() // sp_size 21 | pg_manager = ProcessGroupManager(dp_size, sp_size, dp_axis=0, sp_axis=1) 22 | device = get_current_device() 23 | 24 | # Setup data: 25 | transform = transforms.Compose( 26 | [ 27 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 31 | ] 32 | ) 33 | dataset = CIFAR10(data_path, transform=transform, download=True) 34 | dataloader = prepare_dataloader( 35 | dataset, 36 | batch_size=batch_size, 37 | shuffle=True, 38 | drop_last=True, 39 | pin_memory=True, 40 | num_workers=num_workers, 41 | pg_manager=pg_manager, 42 | ) 43 | dataloader_iter = iter(dataloader) 44 | x, y = next(dataloader_iter) 45 | x = x.to(device) 46 | y = y.to(device) 47 | 48 | x_list = [torch.empty_like(x) for _ in range(dist.get_world_size())] 49 | y_list = [torch.empty_like(y) for _ in range(dist.get_world_size())] 50 | 51 | dist.all_gather(x_list, x) 52 | dist.all_gather(y_list, y) 53 | 54 | sp_group_ranks = pg_manager.get_ranks_in_group(pg_manager.sp_group) 55 | dp_group_ranks = pg_manager.get_ranks_in_group(pg_manager.dp_group) 56 | 57 | for rank in sp_group_ranks: 58 | if rank != dist.get_rank(): 59 | assert torch.allclose( 60 | x_list[rank], x_list[dist.get_rank()] 61 | ), f"x in rank {rank} and {dist.get_rank()} are not equal in the same sequence parallel group." 62 | assert torch.allclose( 63 | y_list[rank], y_list[dist.get_rank()] 64 | ), f"y in rank {rank} and {dist.get_rank()} are not equal in the same sequence parallel group." 65 | 66 | for rank in dp_group_ranks: 67 | if rank != dist.get_rank(): 68 | assert not torch.allclose( 69 | x_list[rank], x_list[dist.get_rank()] 70 | ), f"x in rank {rank} and {dist.get_rank()} are equal in the same data parallel group." 71 | assert not torch.allclose( 72 | y_list[rank], y_list[dist.get_rank()] 73 | ), f"y in rank {rank} and {dist.get_rank()} are equal in the same data parallel group." 
74 | 75 | 76 | def check_dataloader(rank, world_size, port): 77 | colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 78 | run_dataloader_test() 79 | 80 | 81 | @rerun_if_address_is_in_use() 82 | def test_sequence_parallel_dataloader(): 83 | spawn(check_dataloader, WORKERS) 84 | 85 | 86 | if __name__ == "__main__": 87 | test_sequence_parallel_dataloader() 88 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_ema_sharding.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | 4 | import colossalai 5 | import torch 6 | import torch.distributed as dist 7 | from colossalai.booster import Booster 8 | from colossalai.booster.plugin import LowLevelZeroPlugin 9 | from colossalai.nn.optimizer import HybridAdam 10 | from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn 11 | 12 | from opendit.models.dit import DiT 13 | from opendit.utils.ckpt_utils import model_gathering, record_model_param_shape 14 | from opendit.utils.operation import model_sharding 15 | from opendit.utils.train_utils import update_ema 16 | 17 | 18 | def assert_params_equal(model1, model2): 19 | for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()): 20 | assert name1 == name2 21 | if name1 == "pos_embed": 22 | continue 23 | assert torch.allclose(param1, param2) 24 | 25 | 26 | @clear_cache_before_run() 27 | def run_ema_sharding(): 28 | plugin = LowLevelZeroPlugin(precision="fp16", stage=2, max_norm=1.0, initial_scale=32) 29 | booster = Booster(plugin=plugin) 30 | model = DiT(depth=2, hidden_size=64, patch_size=2, num_heads=4, dtype=torch.float16).cuda().half() 31 | 32 | ema_sharding = deepcopy(model).eval() 33 | model_param_shape = record_model_param_shape(ema_sharding) 34 | model_sharding(ema_sharding) 35 | ema_no_sharding = deepcopy(model).eval() 36 | ema_to_read = deepcopy(model).eval() 37 | 38 | criterion = lambda x: x.mean() 39 | optimizer = HybridAdam((model.parameters()), lr=0.001) 40 | model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) 41 | 42 | x = torch.randn(2, 4, 32, 32).cuda().requires_grad_(True) 43 | y = torch.randint(0, 10, (2,)).cuda() 44 | t = torch.randint(0, 10, (2,)).cuda() 45 | output = model(x, y, t) 46 | loss = criterion(output) 47 | booster.backward(loss, optimizer) 48 | optimizer.step() 49 | 50 | update_ema(ema_sharding, model.module, optimizer=optimizer, sharded=True, decay=0.5) 51 | update_ema(ema_no_sharding, model.module, optimizer=optimizer, sharded=False, decay=0.5) 52 | 53 | # should be equal after update 54 | gather_ema_sharding = deepcopy(ema_sharding) 55 | model_gathering(gather_ema_sharding, model_param_shape) 56 | if dist.get_rank() == 0: 57 | assert_params_equal(gather_ema_sharding, ema_no_sharding) 58 | dist.barrier() 59 | 60 | # should be same after read again 61 | if dist.get_rank() == 0: 62 | torch.save(gather_ema_sharding.state_dict(), "tmp.pth") 63 | ema_to_read.load_state_dict(torch.load("tmp.pth")) 64 | assert_params_equal(gather_ema_sharding, ema_to_read) 65 | os.remove("tmp.pth") 66 | dist.barrier() 67 | 68 | # should be same after sharding again 69 | if dist.get_rank() == 0: 70 | model_sharding(gather_ema_sharding) 71 | assert_params_equal(gather_ema_sharding, ema_sharding) 72 | dist.barrier() 73 | 74 | 75 | def run_dist(rank, world_size, port): 76 | colossalai.launch(config=(dict()), 
rank=rank, world_size=world_size, port=port, host="localhost") 77 | run_ema_sharding() 78 | torch.cuda.empty_cache() 79 | 80 | 81 | @rerun_if_address_is_in_use() 82 | def test_ema_sharding(): 83 | spawn(run_dist, 2) 84 | 85 | 86 | if __name__ == "__main__": 87 | test_ema_sharding() 88 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_flash_attention.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import colossalai 4 | import flash_attn 5 | import pytest 6 | import torch 7 | from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn 8 | from torch.testing import assert_close 9 | 10 | from opendit.modules.attn import DistAttention 11 | 12 | torch.manual_seed(1024) 13 | 14 | WORKERS = 1 15 | DTYPE = torch.float16 16 | 17 | 18 | def _run_flash_attn(seq_len, hidden_dim, head_num, batch_size, use_flash_attn): 19 | seq_len = seq_len 20 | hidden_dim = hidden_dim 21 | head_num = head_num 22 | batch_size = batch_size 23 | 24 | # set dtype as bf16 25 | torch.set_default_dtype(DTYPE) 26 | 27 | x = torch.randn(batch_size, seq_len, hidden_dim).cuda() 28 | x_naive_attn = x.clone().requires_grad_(True) 29 | x_flash_attn = x.clone().requires_grad_(True) 30 | 31 | # DistAttention without flash attention 32 | dist_attn_without_flashattn = DistAttention( 33 | dim=hidden_dim, 34 | num_heads=head_num, 35 | enable_flashattn=use_flash_attn, 36 | sequence_parallel_size=1, 37 | sequence_parallel_group=None, 38 | ).cuda() 39 | 40 | dist_attn_with_flashattn = copy.deepcopy(dist_attn_without_flashattn) 41 | setattr(dist_attn_with_flashattn, "enable_flashattn", True) 42 | 43 | naive_attn_output = dist_attn_without_flashattn(x_naive_attn) 44 | flash_attn_output = dist_attn_with_flashattn(x_flash_attn) 45 | 46 | assert_close(naive_attn_output, flash_attn_output, atol=1e-4, rtol=1e-4) 47 | 48 | # Attention backward 49 | naive_attn_output.sum().backward() 50 | qkv_grad_naive_attn = dist_attn_without_flashattn.qkv.weight.grad 51 | o_grad_naive_attn = dist_attn_without_flashattn.proj.weight.grad 52 | x_grad_naive_attn = x_naive_attn.grad 53 | 54 | flash_attn_output.sum().backward() 55 | qkv_grad_flash_attn = dist_attn_with_flashattn.qkv.weight.grad 56 | o_grad_flash_attn = dist_attn_with_flashattn.proj.weight.grad 57 | x_grad_flash_attn = x_flash_attn.grad 58 | 59 | # backward result check 60 | assert_close(qkv_grad_naive_attn, qkv_grad_flash_attn, atol=1e-3, rtol=1e-3) 61 | assert_close(o_grad_naive_attn, o_grad_flash_attn, atol=1e-3, rtol=1e-3) 62 | assert_close(x_grad_naive_attn, x_grad_flash_attn, atol=1e-3, rtol=1e-3) 63 | 64 | 65 | @parameterize("seq_len", [256]) 66 | @parameterize("hidden_dim", [1152]) 67 | @parameterize("head_num", [16]) 68 | @parameterize("batch_size", [2]) 69 | @parameterize("use_flash_attn", [True]) 70 | def run_flash_attn(seq_len, hidden_dim, head_num, batch_size, use_flash_attn): 71 | _run_flash_attn(seq_len, hidden_dim, head_num, batch_size, use_flash_attn) 72 | 73 | 74 | def check_all2all_attn(rank, world_size, port): 75 | colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") 76 | run_flash_attn() 77 | 78 | 79 | @pytest.mark.skipif(flash_attn.__version__ < "2.4.1", reason="requires flashattn 2.4.1 or higher") 80 | @rerun_if_address_is_in_use() 81 | def test_flash_attn(): 82 | spawn(check_all2all_attn, nprocs=WORKERS) 83 | 84 | 85 | if __name__ == "__main__": 86 | test_flash_attn() 87 | 
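# Note: this test assumes a CUDA device and flash-attn >= 2.4.1 (see the skipif guard above); it can be run directly or via `pytest tests/test_flash_attention.py`.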
-------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_fused_modulate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from opendit.kernels.fused_modulate import fused_modulate 4 | 5 | 6 | def test_fused_modulate(): 7 | x1 = torch.rand((1, 20, 100), requires_grad=True).cuda() 8 | x1.retain_grad() 9 | shift1 = torch.rand((1, 100), requires_grad=True).cuda() 10 | shift1.retain_grad() 11 | scale1 = torch.rand((1, 100), requires_grad=True).cuda() 12 | scale1.retain_grad() 13 | x2 = x1.clone().detach().requires_grad_() 14 | shift2 = shift1.clone().detach().requires_grad_() 15 | scale2 = scale1.clone().detach().requires_grad_() 16 | 17 | out1 = fused_modulate(x1, scale1, shift1) 18 | out1.mean().backward() 19 | out2 = x2 * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1) 20 | out2.mean().backward() 21 | 22 | assert torch.allclose(out1, out2, atol=1e-6), f"\nout1:\n{out1}\nout2:\n{out2}\n" 23 | assert torch.allclose(x1.grad, x2.grad, atol=1e-6), f"\nx1.grad:\n{x1.grad}\nx2.grad:\n{x2.grad}\n" 24 | assert torch.allclose( 25 | scale1.grad, scale2.grad, atol=1e-4 26 | ), f"\nscale1.grad:\n{scale1.grad}\nscale2.grad:\n{scale2.grad}\n" 27 | assert torch.allclose( 28 | shift1.grad, shift2.grad, atol=1e-4 29 | ), f"\nshift1.grad:\n{shift1.grad}\nshift2.grad:\n{shift2.grad}\n" 30 | 31 | 32 | if __name__ == "__main__": 33 | test_fused_modulate() 34 | -------------------------------------------------------------------------------- /codes/OpenDiT/tests/test_model/test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from org_dit import DiT_S_2 as ORG_MODEL 3 | from torch import nn 4 | 5 | from opendit.models.dit import DiT_S_2 as NEW_MODEL 6 | 7 | 8 | def initialize_weights(model): 9 | # use normal distribution to initialize all weights for test 10 | 11 | def _basic_init(module): 12 | if isinstance(module, nn.Linear): 13 | nn.init.normal_(module.weight, std=0.02) 14 | if module.bias is not None: 15 | nn.init.normal_(module.bias, std=0.02) 16 | 17 | model.apply(_basic_init) 18 | 19 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 20 | w = model.x_embedder.proj.weight.data 21 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 22 | nn.init.normal_(model.x_embedder.proj.bias, std=0.02) 23 | 24 | # Initialize label embedding table: 25 | nn.init.normal_(model.y_embedder.embedding_table.weight, std=0.02) 26 | 27 | # Initialize timestep embedding MLP: 28 | nn.init.normal_(model.t_embedder.mlp[0].weight, std=0.02) 29 | nn.init.normal_(model.t_embedder.mlp[2].weight, std=0.02) 30 | 31 | # Zero-out adaLN modulation layers in DiT blocks: 32 | for block in model.blocks: 33 | nn.init.normal_(block.adaLN_modulation[-1].weight, std=0.02) 34 | nn.init.normal_(block.adaLN_modulation[-1].bias, std=0.02) 35 | 36 | # Zero-out output layers: 37 | nn.init.normal_(model.final_layer.adaLN_modulation[-1].weight, std=0.02) 38 | nn.init.normal_(model.final_layer.adaLN_modulation[-1].bias, std=0.02) 39 | nn.init.normal_(model.final_layer.linear.weight, std=0.02) 40 | nn.init.normal_(model.final_layer.linear.bias, std=0.02) 41 | 42 | 43 | def test_model(): 44 | torch.manual_seed(0) 45 | org_model = ORG_MODEL().cuda() 46 | initialize_weights(org_model) 47 | torch.manual_seed(0) 48 | new_model = NEW_MODEL().cuda() 49 | initialize_weights(new_model) 50 | 51 | # Check if the model parameters are equal 52 | for (org_name, org_param), (new_name, 
new_param) in zip(org_model.named_parameters(), new_model.named_parameters()): 53 | assert org_name == new_name 54 | assert torch.equal(org_param, new_param), f"Parameter {org_name} is not equal\n{org_param}\n {new_param}" 55 | 56 | x1 = torch.randn(2, 4, 32, 32).cuda().requires_grad_(True) 57 | y1 = torch.randint(0, 10, (2,)).cuda() 58 | t1 = torch.randint(0, 10, (2,)).cuda() 59 | x2 = x1.clone().detach().requires_grad_(True) 60 | y2 = y1.clone().detach() 61 | t2 = t1.clone().detach() 62 | 63 | org_output = org_model(x1, t1, y1) 64 | new_output = new_model(x2, t2, y2) 65 | assert torch.allclose( 66 | org_output, new_output, atol=1e-5 67 | ), f"Max diff: {torch.max(torch.abs(org_output - new_output))}, Mean diff: {torch.mean(torch.abs(org_output - new_output))}" 68 | 69 | org_output.mean().backward() 70 | new_output.mean().backward() 71 | assert torch.allclose( 72 | x1.grad, x2.grad, atol=1e-5 73 | ), f"Max diff: {torch.max(torch.abs(x1.grad - x2.grad))}, Mean diff: {torch.mean(torch.abs(x1.grad - x2.grad))}" 74 | 75 | 76 | if __name__ == "__main__": 77 | test_model() 78 | -------------------------------------------------------------------------------- /codes/OpenDiT/train_img.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=2 train.py \ 2 | --model DiT-XL/2 \ 3 | --batch_size 2 \ 4 | --num_classes 10 5 | -------------------------------------------------------------------------------- /codes/OpenDiT/train_video.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nproc_per_node=2 train.py \ 2 | --model VDiT-XL/1x2x2 \ 3 | --use_video \ 4 | --data_path ./videos/demo.csv \ 5 | --batch_size 1 \ 6 | --num_frames 16 \ 7 | --image_size 256 \ 8 | --frame_interval 3 9 | -------------------------------------------------------------------------------- /codes/OpenDiT/videos/art-museum.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/videos/art-museum.mp4 -------------------------------------------------------------------------------- /codes/OpenDiT/videos/demo.csv: -------------------------------------------------------------------------------- 1 | ./videos/art-museum.mp4,art-museum 2 | ./videos/lagos.mp4,lagos 3 | ./videos/man-on-the-cloud.mp4,man on the cloud 4 | ./videos/suv-in-the-dust.mp4,suv in the dust 5 | -------------------------------------------------------------------------------- /codes/OpenDiT/videos/lagos.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/videos/lagos.mp4 -------------------------------------------------------------------------------- /codes/OpenDiT/videos/man-on-the-cloud.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/videos/man-on-the-cloud.mp4 -------------------------------------------------------------------------------- /codes/OpenDiT/videos/suv-in-the-dust.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/OpenDiT/videos/suv-in-the-dust.mp4 
-------------------------------------------------------------------------------- /codes/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # Mini Sora 社区Sora复现小组 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 |
20 | 21 |
22 | 23 | [English](README.md) | 简体中文 24 | 25 | 26 |
27 | 28 | ## Mini Sora Reproduction Goals 29 | 30 | 1. **GPU-Friendly**: ideally the demands on GPU memory and GPU count are modest, e.g. training and inference should be feasible on 8x A100, 4x A6000, or a single RTX 4090 31 | 2. **Training-Efficiency**: good results without needing to train for very long 32 | 3. **Inference-Efficiency**: when generating videos at inference time, length and resolution need not be high, e.g. 3-10s at 480p is acceptable 33 | 34 | The three main candidate papers below serve as baselines for the subsequent Sora reproduction. The community has already (02/29) forked the [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT) and [SiT](https://github.com/willisma/SiT) code into the codes folder, and contributors are welcome to submit PRs that migrate the baseline code toward the Sora reproduction work. [**Update**] 03/02, added [StableCascade](https://github.com/Stability-AI/StableCascade) codes 35 | 36 | - DiT with **OpenDiT** 37 | - OpenDiT uses distributed training; image generation was trained on 8x A100. 38 | - OpenDiT uses the SD VAE encoder with SD pre-trained weights, which in actual testing works better than VideoGPT's VQ-VAE. 39 | - The Sora lead previously worked on DALLE3; decoding for video generation uses a DALLE3-like diffusion approach, so the compression/encoding side should work in the reverse direction of DALLE3 40 | - **SiT** 41 | - **W.A.L.T** (not yet released) 42 | - **StableCascade** 43 | - ToDo: make it as a video-based model with additional temp layer in the near future 44 | 45 | ## Datasets 46 | 47 | ... 48 | 49 | ## Model Architecture 50 | 51 | ... 52 | 53 | ## Compute Requirements 54 | 55 | ... 56 | 57 | 68 | 69 | ## Sora Reproduction Group - MiniSora Community WeChat Group 70 | 71 |
72 | 73 | 74 |
 
75 |
76 |
77 |
78 | 79 | ## Star History 80 | 81 | [![Star History Chart](https://api.star-history.com/svg?repos=mini-sora/minisora&type=Date)](https://star-history.com/#mini-sora/minisora&Date) 82 | 83 | ## 如何向Mini Sora 社区贡献 84 | 85 | 我们非常希望你们能够为 Mini Sora 开源社区做出贡献,并且帮助我们把它做得比现在更好! 86 | 87 | 具体查看[贡献指南](../.github/CONTRIBUTING_zh-CN.md) 88 | 89 | ## 社区贡献者 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 100 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 101 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 102 | [forks-url]: https://github.com/mini-sora/minisora/network/members 103 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 104 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 105 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 106 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 107 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 108 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 109 | -------------------------------------------------------------------------------- /codes/SiT/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wandb 3 | 4 | .DS_store 5 | samples 6 | results 7 | pretrained_models -------------------------------------------------------------------------------- /codes/SiT/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /codes/SiT/download.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | 4 | """ 5 | Functions for downloading pre-trained SiT models 6 | """ 7 | from torchvision.datasets.utils import download_url 8 | import torch 9 | import os 10 | 11 | 12 | pretrained_models = {'SiT-XL-2-256x256.pt'} 13 | 14 | 15 | def find_model(model_name): 16 | """ 17 | Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path. 18 | """ 19 | if model_name in pretrained_models: 20 | return download_model(model_name) 21 | else: 22 | assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}' 23 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 24 | if "ema" in checkpoint: # supports checkpoints from train.py 25 | checkpoint = checkpoint["ema"] 26 | return checkpoint 27 | 28 | 29 | def download_model(model_name): 30 | """ 31 | Downloads a pre-trained SiT model from the web. 32 | """ 33 | assert model_name in pretrained_models 34 | local_path = f'pretrained_models/{model_name}' 35 | if not os.path.isfile(local_path): 36 | os.makedirs('pretrained_models', exist_ok=True) 37 | web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0' 38 | download_url(web_path, 'pretrained_models', filename=model_name) 39 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 40 | return model 41 | -------------------------------------------------------------------------------- /codes/SiT/environment.yml: -------------------------------------------------------------------------------- 1 | name: SiT 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pytorch >= 1.13 8 | - torchvision 9 | - pytorch-cuda >=11.7 10 | - pip 11 | - pip: 12 | - timm 13 | - diffusers 14 | - accelerate 15 | - torchdiffeq 16 | - wandb 17 | -------------------------------------------------------------------------------- /codes/SiT/train_utils.py: -------------------------------------------------------------------------------- 1 | def none_or_str(value): 2 | if value == 'None': 3 | return None 4 | return value 5 | 6 | def parse_transport_args(parser): 7 | group = parser.add_argument_group("Transport arguments") 8 | group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"]) 9 | group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"]) 10 | group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"]) 11 | group.add_argument("--sample-eps", type=float) 12 | group.add_argument("--train-eps", type=float) 13 | 14 | def parse_ode_args(parser): 15 | group = parser.add_argument_group("ODE arguments") 16 | group.add_argument("--sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") 17 | group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") 18 | group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") 19 | group.add_argument("--reverse", action="store_true") 20 | group.add_argument("--likelihood", action="store_true") 21 | 22 | def parse_sde_args(parser): 23 | group = parser.add_argument_group("SDE arguments") 24 | group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) 25 | group.add_argument("--diffusion-form", type=str, default="sigma", \ 26 | choices=["constant", "SBDM", "sigma", "linear", "decreasing", 
"increasing-decreasing"],\ 27 | help="form of diffusion coefficient in the SDE") 28 | group.add_argument("--diffusion-norm", type=float, default=1.0) 29 | group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\ 30 | help="form of last step taken in the SDE") 31 | group.add_argument("--last-step-size", type=float, default=0.04, \ 32 | help="size of the last step taken") -------------------------------------------------------------------------------- /codes/SiT/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | ): 10 | """function for creating Transport object 11 | **Note**: model prediction defaults to velocity 12 | Args: 13 | - path_type: type of path to use; default to linear 14 | - learn_score: set model prediction to score 15 | - learn_noise: set model prediction to noise 16 | - velocity_weighted: weight loss by velocity weight 17 | - likelihood_weighted: weight loss by likelihood weight 18 | - train_eps: small epsilon for avoiding instability during training 19 | - sample_eps: small epsilon for avoiding instability during sampling 20 | """ 21 | 22 | if prediction == "noise": 23 | model_type = ModelType.NOISE 24 | elif prediction == "score": 25 | model_type = ModelType.SCORE 26 | else: 27 | model_type = ModelType.VELOCITY 28 | 29 | if loss_weight == "velocity": 30 | loss_type = WeightType.VELOCITY 31 | elif loss_weight == "likelihood": 32 | loss_type = WeightType.LIKELIHOOD 33 | else: 34 | loss_type = WeightType.NONE 35 | 36 | path_choice = { 37 | "Linear": PathType.LINEAR, 38 | "GVP": PathType.GVP, 39 | "VP": PathType.VP, 40 | } 41 | 42 | path_type = path_choice[path_type] 43 | 44 | if (path_type in [PathType.VP]): 45 | train_eps = 1e-5 if train_eps is None else train_eps 46 | sample_eps = 1e-3 if train_eps is None else sample_eps 47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 48 | train_eps = 1e-3 if train_eps is None else train_eps 49 | sample_eps = 1e-3 if train_eps is None else sample_eps 50 | else: # velocity & [GVP, LINEAR] is stable everywhere 51 | train_eps = 0 52 | sample_eps = 0 53 | 54 | # create flow state 55 | state = Transport( 56 | model_type=model_type, 57 | path_type=path_type, 58 | loss_type=loss_type, 59 | train_eps=train_eps, 60 | sample_eps=sample_eps, 61 | ) 62 | 63 | return state -------------------------------------------------------------------------------- /codes/SiT/transport/integrators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | class sde: 9 | """SDE solver class""" 10 | def __init__( 11 | self, 12 | drift, 13 | diffusion, 14 | *, 15 | t0, 16 | t1, 17 | num_steps, 18 | sampler_type, 19 | ): 20 | assert t0 < t1, "SDE sampler has to be in forward time" 21 | 22 | self.num_timesteps = num_steps 23 | self.t = th.linspace(t0, t1, num_steps) 24 | self.dt = self.t[1] - self.t[0] 25 | self.drift = drift 26 | self.diffusion = diffusion 27 | self.sampler_type = sampler_type 28 | 29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 30 | w_cur = 
th.randn(x.size()).to(x) 31 | t = th.ones(x.size(0)).to(x) * t 32 | dw = w_cur * th.sqrt(self.dt) 33 | drift = self.drift(x, t, model, **model_kwargs) 34 | diffusion = self.diffusion(x, t) 35 | mean_x = x + drift * self.dt 36 | x = mean_x + th.sqrt(2 * diffusion) * dw 37 | return x, mean_x 38 | 39 | def __Heun_step(self, x, _, t, model, **model_kwargs): 40 | w_cur = th.randn(x.size()).to(x) 41 | dw = w_cur * th.sqrt(self.dt) 42 | t_cur = th.ones(x.size(0)).to(x) * t 43 | diffusion = self.diffusion(x, t_cur) 44 | xhat = x + th.sqrt(2 * diffusion) * dw 45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 46 | xp = xhat + self.dt * K1 47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step 49 | 50 | def __forward_fn(self): 51 | """TODO: generalize here by adding all private functions ending with steps to it""" 52 | sampler_dict = { 53 | "Euler": self.__Euler_Maruyama_step, 54 | "Heun": self.__Heun_step, 55 | } 56 | 57 | try: 58 | sampler = sampler_dict[self.sampler_type] 59 | except: 60 | raise NotImplementedError("Smapler type not implemented.") 61 | 62 | return sampler 63 | 64 | def sample(self, init, model, **model_kwargs): 65 | """forward loop of sde""" 66 | x = init 67 | mean_x = init 68 | samples = [] 69 | sampler = self.__forward_fn() 70 | for ti in self.t[:-1]: 71 | with th.no_grad(): 72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 73 | samples.append(x) 74 | 75 | return samples 76 | 77 | class ode: 78 | """ODE solver class""" 79 | def __init__( 80 | self, 81 | drift, 82 | *, 83 | t0, 84 | t1, 85 | sampler_type, 86 | num_steps, 87 | atol, 88 | rtol, 89 | ): 90 | assert t0 < t1, "ODE sampler has to be in forward time" 91 | 92 | self.drift = drift 93 | self.t = th.linspace(t0, t1, num_steps) 94 | self.atol = atol 95 | self.rtol = rtol 96 | self.sampler_type = sampler_type 97 | 98 | def sample(self, x, model, **model_kwargs): 99 | 100 | device = x[0].device if isinstance(x, tuple) else x.device 101 | def _fn(t, x): 102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 103 | model_output = self.drift(x, t, model, **model_kwargs) 104 | return model_output 105 | 106 | t = self.t.to(device) 107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 109 | samples = odeint( 110 | _fn, 111 | x, 112 | t, 113 | method=self.sampler_type, 114 | atol=atol, 115 | rtol=rtol 116 | ) 117 | return samples -------------------------------------------------------------------------------- /codes/SiT/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class EasyDict: 4 | 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | def mean_flat(x): 13 | """ 14 | Take the mean over all non-batch dimensions. 
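For example, an input of shape (B, C, H, W) produces a tensor of shape (B,) containing the per-sample mean.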
15 | """ 16 | return th.mean(x, dim=list(range(1, len(x.size())))) 17 | 18 | def log_state(state): 19 | result = [] 20 | 21 | sorted_state = dict(sorted(state.items())) 22 | for key, value in sorted_state.items(): 23 | # Check if the value is an instance of a class 24 | if "= 768'] 34 | - ['height', 'lambda h: h >= 768'] 35 | 36 | # ema_start_iters: 5000 37 | # ema_iters: 100 38 | # ema_beta: 0.9 39 | 40 | webdataset_path: 41 | - s3://path/to/your/first/dataset/on/s3 42 | - s3://path/to/your/second/dataset/on/s3 43 | effnet_checkpoint_path: models/effnet_encoder.safetensors 44 | previewer_checkpoint_path: models/previewer.safetensors 45 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/controlnet_c_3b_identity.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_identity 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 200000 18 | backup_every: 2000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_bottleneck_mode: 'simple' 25 | controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] 26 | controlnet_filter: IdentityFilter 27 | controlnet_filter_params: 28 | max_faces: 4 29 | p_drop: 0.05 30 | p_full: 0.3 31 | # offset_noise: 0.1 32 | 33 | # CUSTOM CAPTIONS GETTER & FILTERS 34 | captions_getter: ['txt', identity] 35 | dataset_filters: 36 | - ['width', 'lambda w: w >= 768'] 37 | - ['height', 'lambda h: h >= 768'] 38 | 39 | # ema_start_iters: 5000 40 | # ema_iters: 100 41 | # ema_beta: 0.9 42 | 43 | webdataset_path: 44 | - s3://path/to/your/first/dataset/on/s3 45 | - s3://path/to/your/second/dataset/on/s3 46 | effnet_checkpoint_path: models/effnet_encoder.safetensors 47 | previewer_checkpoint_path: models/previewer.safetensors 48 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/controlnet_c_3b_inpainting.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_inpainting 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 10000 18 | backup_every: 2000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 25 | controlnet_filter: InpaintFilter 26 | controlnet_filter_params: 27 | thresold: [0.04, 0.4] 28 | p_outpaint: 0.4 
29 | offset_noise: 0.1 30 | 31 | # CUSTOM CAPTIONS GETTER & FILTERS 32 | captions_getter: ['txt', identity] 33 | dataset_filters: 34 | - ['width', 'lambda w: w >= 768'] 35 | - ['height', 'lambda h: h >= 768'] 36 | 37 | # ema_start_iters: 5000 38 | # ema_iters: 100 39 | # ema_beta: 0.9 40 | 41 | webdataset_path: 42 | - s3://path/to/your/first/dataset/on/s3 43 | - s3://path/to/your/second/dataset/on/s3 44 | effnet_checkpoint_path: models/effnet_encoder.safetensors 45 | previewer_checkpoint_path: models/previewer.safetensors 46 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/controlnet_c_3b_sr.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_controlnet_sr 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 30000 18 | backup_every: 5000 19 | save_every: 1000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # ControlNet specific 24 | controlnet_bottleneck_mode: 'large' 25 | controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] 26 | controlnet_filter: SREffnetFilter 27 | controlnet_filter_params: 28 | scale_factor: 0.5 29 | offset_noise: 0.1 30 | 31 | # CUSTOM CAPTIONS GETTER & FILTERS 32 | captions_getter: ['txt', identity] 33 | dataset_filters: 34 | - ['width', 'lambda w: w >= 768'] 35 | - ['height', 'lambda h: h >= 768'] 36 | 37 | # ema_start_iters: 5000 38 | # ema_iters: 100 39 | # ema_beta: 0.9 40 | 41 | webdataset_path: 42 | - s3://path/to/your/first/dataset/on/s3 43 | - s3://path/to/your/second/dataset/on/s3 44 | effnet_checkpoint_path: models/effnet_encoder.safetensors 45 | previewer_checkpoint_path: models/previewer.safetensors 46 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_b_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_b_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 256 14 | image_size: 1024 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | shift: 4 17 | grad_accum_steps: 1 18 | updates: 100000 19 | backup_every: 20000 20 | save_every: 1000 21 | warmup_updates: 1 22 | use_fsdp: True 23 | 24 | # GDF 25 | adaptive_loss_weight: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | stage_a_checkpoint_path: models/stage_a.safetensors 36 | generator_checkpoint_path: models/stage_b_bf16.safetensors 37 | -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_b_700m.yaml: 
-------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_b_700m_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 700M 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 1024 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | shift: 4 17 | grad_accum_steps: 1 18 | updates: 10000 19 | backup_every: 20000 20 | save_every: 2000 21 | warmup_updates: 1 22 | use_fsdp: True 23 | 24 | # GDF 25 | adaptive_loss_weight: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | stage_a_checkpoint_path: models/stage_a.safetensors 36 | generator_checkpoint_path: models/stage_b_lite_bf16.safetensors 37 | -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_c_1b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_1b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 1B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 1024 14 | image_size: 768 15 | # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 10000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | # adaptive_loss_weight: True 25 | 26 | # ema_start_iters: 5000 27 | # ema_iters: 100 28 | # ema_beta: 0.9 29 | 30 | webdataset_path: 31 | - s3://path/to/your/first/dataset/on/s3 32 | - s3://path/to/your/second/dataset/on/s3 33 | effnet_checkpoint_path: models/effnet_encoder.safetensors 34 | previewer_checkpoint_path: models/previewer.safetensors 35 | generator_checkpoint_path: models/stage_c_lite_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_c_3b.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 768 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | 26 | # ema_start_iters: 5000 27 | # ema_iters: 100 28 | # ema_beta: 0.9 29 | 30 | webdataset_path: 31 | - s3://path/to/your/first/dataset/on/s3 32 | - s3://path/to/your/second/dataset/on/s3 33 | effnet_checkpoint_path: models/effnet_encoder.safetensors 34 | previewer_checkpoint_path: models/previewer.safetensors 35 | generator_checkpoint_path: models/stage_c_bf16.safetensors 
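The finetuning and ControlNet configs above are plain YAML files whose keys (`experiment_id`, `lr`, `batch_size`, `updates`, `webdataset_path`, and the various `*_checkpoint_path` entries) are read by the training scripts under `codes/StableCascade/train/`. Below is a minimal sketch of how such a file can be loaded, assuming only PyYAML and Munch (the same attribute-dict library used by `core/utils/base_dto.py` in this repo); the helper name and the example path are illustrative, not the repo's actual entry point.

```python
# Minimal sketch: read one of the YAML training configs above into an
# attribute-style dict. Illustrative only -- the real trainers define their
# own config handling.
import yaml
from munch import Munch

def load_training_config(path: str) -> Munch:
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)        # keys such as lr, batch_size, updates, ...
    return Munch.fromDict(raw)         # attribute access, e.g. cfg.batch_size

if __name__ == "__main__":
    cfg = load_training_config("configs/training/finetune_c_3b.yaml")
    print(cfg.experiment_id, cfg.lr, cfg.batch_size)
    print(cfg.webdataset_path)         # list of s3 shard paths
```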
-------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_c_3b_lora.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_lora 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 32 14 | image_size: 768 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 4 17 | updates: 10000 18 | backup_every: 1000 19 | save_every: 100 20 | warmup_updates: 1 21 | # use_fsdp: True -> FSDP doesn't work at the moment for LoRA 22 | use_fsdp: False 23 | 24 | # GDF 25 | # adaptive_loss_weight: True 26 | 27 | # LoRA specific 28 | module_filters: ['.attn'] 29 | rank: 4 30 | train_tokens: 31 | # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized 32 | - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails 33 | 34 | 35 | # ema_start_iters: 5000 36 | # ema_iters: 100 37 | # ema_beta: 0.9 38 | 39 | webdataset_path: 40 | - s3://path/to/your/first/dataset/on/s3 41 | - s3://path/to/your/second/dataset/on/s3 42 | effnet_checkpoint_path: models/effnet_encoder.safetensors 43 | previewer_checkpoint_path: models/previewer.safetensors 44 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_c_3b_lowres.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 1024 14 | image_size: 384 15 | multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | 26 | # CUSTOM CAPTIONS GETTER & FILTERS 27 | # captions_getter: ['json', captions_getter] 28 | # dataset_filters: 29 | # - ['normalized_score', 'lambda s: s > 9.0'] 30 | # - ['pgen_normalized_score', 'lambda s: s > 3.0'] 31 | 32 | # ema_start_iters: 5000 33 | # ema_iters: 100 34 | # ema_beta: 0.9 35 | 36 | webdataset_path: 37 | - s3://path/to/your/first/dataset/on/s3 38 | - s3://path/to/your/second/dataset/on/s3 39 | effnet_checkpoint_path: models/effnet_encoder.safetensors 40 | previewer_checkpoint_path: models/previewer.safetensors 41 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/configs/training/finetune_c_3b_v.yaml: -------------------------------------------------------------------------------- 1 | # GLOBAL STUFF 2 | experiment_id: stage_c_3b_finetuning 3 | checkpoint_path: /path/to/checkpoint 4 | output_path: /path/to/output 5 | model_version: 3.6B 6 | 7 | # WandB 8 | wandb_project: StableCascade 9 | wandb_entity: wandb_username 10 | 11 | # TRAINING PARAMS 12 | lr: 1.0e-4 13 | batch_size: 512 14 | image_size: 768 15 | 
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] 16 | grad_accum_steps: 1 17 | updates: 100000 18 | backup_every: 20000 19 | save_every: 2000 20 | warmup_updates: 1 21 | use_fsdp: True 22 | 23 | # GDF 24 | adaptive_loss_weight: True 25 | edm_objective: True 26 | 27 | # ema_start_iters: 5000 28 | # ema_iters: 100 29 | # ema_beta: 0.9 30 | 31 | webdataset_path: 32 | - s3://path/to/your/first/dataset/on/s3 33 | - s3://path/to/your/second/dataset/on/s3 34 | effnet_checkpoint_path: models/effnet_encoder.safetensors 35 | previewer_checkpoint_path: models/previewer.safetensors 36 | generator_checkpoint_path: models/stage_c_bf16.safetensors -------------------------------------------------------------------------------- /codes/StableCascade/core/data/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import yaml 4 | import os 5 | from .bucketeer import Bucketeer 6 | 7 | class MultiFilter(): 8 | def __init__(self, rules, default=False): 9 | self.rules = rules 10 | self.default = default 11 | 12 | def __call__(self, x): 13 | try: 14 | x_json = x['json'] 15 | if isinstance(x_json, bytes): 16 | x_json = json.loads(x_json) 17 | validations = [] 18 | for k, r in self.rules.items(): 19 | if isinstance(k, tuple): 20 | v = r(*[x_json[kv] for kv in k]) 21 | else: 22 | v = r(x_json[k]) 23 | validations.append(v) 24 | return all(validations) 25 | except Exception: 26 | return False 27 | 28 | class MultiGetter(): 29 | def __init__(self, rules): 30 | self.rules = rules 31 | 32 | def __call__(self, x_json): 33 | if isinstance(x_json, bytes): 34 | x_json = json.loads(x_json) 35 | outputs = [] 36 | for k, r in self.rules.items(): 37 | if isinstance(k, tuple): 38 | v = r(*[x_json[kv] for kv in k]) 39 | else: 40 | v = r(x_json[k]) 41 | outputs.append(v) 42 | if len(outputs) == 1: 43 | outputs = outputs[0] 44 | return outputs 45 | 46 | def setup_webdataset_path(paths, cache_path=None): 47 | if cache_path is None or not os.path.exists(cache_path): 48 | tar_paths = [] 49 | if isinstance(paths, str): 50 | paths = [paths] 51 | for path in paths: 52 | if path.strip().endswith(".tar"): 53 | # Avoid looking up s3 if we already have a tar file 54 | tar_paths.append(path) 55 | continue 56 | bucket = "/".join(path.split("/")[:3]) 57 | result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) 58 | files = result.stdout.decode('utf-8').split() 59 | files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] 60 | tar_paths += files 61 | 62 | with open(cache_path, 'w', encoding='utf-8') as outfile: 63 | yaml.dump(tar_paths, outfile, default_flow_style=False) 64 | else: 65 | with open(cache_path, 'r', encoding='utf-8') as file: 66 | tar_paths = yaml.safe_load(file) 67 | 68 | tar_paths_str = ",".join([f"{p}" for p in tar_paths]) 69 | return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" 70 | -------------------------------------------------------------------------------- /codes/StableCascade/core/data/bucketeer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torchtools.transforms import SmartCrop 5 | import math 6 | 7 | class Bucketeer(): 8 | def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, 
interpolate_nearest=False): 9 | assert crop_mode in ['center', 'random', 'smart'] 10 | self.crop_mode = crop_mode 11 | self.ratios = ratios 12 | if reverse_list: 13 | for r in list(ratios): 14 | if 1/r not in self.ratios: 15 | self.ratios.append(1/r) 16 | self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios] 17 | self.batch_size = dataloader.batch_size 18 | self.iterator = iter(dataloader) 19 | self.buckets = {s: [] for s in self.sizes} 20 | self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None 21 | self.p_random_ratio = p_random_ratio 22 | self.interpolate_nearest = interpolate_nearest 23 | 24 | def get_available_batch(self): 25 | for b in self.buckets: 26 | if len(self.buckets[b]) >= self.batch_size: 27 | batch = self.buckets[b][:self.batch_size] 28 | self.buckets[b] = self.buckets[b][self.batch_size:] 29 | return batch 30 | return None 31 | 32 | def get_closest_size(self, x): 33 | if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: 34 | best_size_idx = np.random.randint(len(self.ratios)) 35 | else: 36 | w, h = x.size(-1), x.size(-2) 37 | best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) 38 | return self.sizes[best_size_idx] 39 | 40 | def get_resize_size(self, orig_size, tgt_size): 41 | if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: 42 | alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) 43 | resize_size = max(alt_min, min(tgt_size)) 44 | else: 45 | alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) 46 | resize_size = max(alt_max, max(tgt_size)) 47 | return resize_size 48 | 49 | def __next__(self): 50 | batch = self.get_available_batch() 51 | while batch is None: 52 | elements = next(self.iterator) 53 | for dct in elements: 54 | img = dct['images'] 55 | size = self.get_closest_size(img) 56 | resize_size = self.get_resize_size(img.shape[-2:], size) 57 | if self.interpolate_nearest: 58 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) 59 | else: 60 | img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) 61 | if self.crop_mode == 'center': 62 | img = torchvision.transforms.functional.center_crop(img, size) 63 | elif self.crop_mode == 'random': 64 | img = torchvision.transforms.RandomCrop(size)(img) 65 | elif self.crop_mode == 'smart': 66 | self.smartcrop.output_size = size 67 | img = self.smartcrop(img) 68 | self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) 69 | batch = self.get_available_batch() 70 | 71 | out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} 72 | return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} 73 | -------------------------------------------------------------------------------- /codes/StableCascade/core/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/core/scripts/__init__.py -------------------------------------------------------------------------------- /codes/StableCascade/core/scripts/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from .. 
import WarpCore 4 | from .. import templates 5 | 6 | 7 | def template_init(args): 8 | return '''' 9 | 10 | 11 | '''.strip() 12 | 13 | 14 | def init_template(args): 15 | parser = argparse.ArgumentParser(description='WarpCore template init tool') 16 | parser.add_argument('-t', '--template', type=str, default='WarpCore') 17 | args = parser.parse_args(args) 18 | 19 | if args.template == 'WarpCore': 20 | template_cls = WarpCore 21 | else: 22 | try: 23 | template_cls = __import__(args.template) 24 | except ModuleNotFoundError: 25 | template_cls = getattr(templates, args.template) 26 | print(template_cls) 27 | 28 | 29 | def main(): 30 | if len(sys.argv) < 2: 31 | print('Usage: core ') 32 | sys.exit(1) 33 | if sys.argv[1] == 'init': 34 | init_template(sys.argv[2:]) 35 | else: 36 | print('Unknown command') 37 | sys.exit(1) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /codes/StableCascade/core/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import DiffusionCore -------------------------------------------------------------------------------- /codes/StableCascade/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN 2 | from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail 3 | 4 | # MOVE IT SOMERWHERE ELSE 5 | def update_weights_ema(tgt_model, src_model, beta=0.999): 6 | for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): 7 | self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) 8 | for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): 9 | self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) -------------------------------------------------------------------------------- /codes/StableCascade/core/utils/base_dto.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass, _MISSING_TYPE 3 | from munch import Munch 4 | 5 | EXPECTED = "___REQUIRED___" 6 | EXPECTED_TRAIN = "___REQUIRED_TRAIN___" 7 | 8 | # pylint: disable=invalid-field-call 9 | def nested_dto(x, raw=False): 10 | return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) 11 | 12 | @dataclass(frozen=True) 13 | class Base: 14 | training: bool = None 15 | def __new__(cls, **kwargs): 16 | training = kwargs.get('training', True) 17 | setteable_fields = cls.setteable_fields(**kwargs) 18 | mandatory_fields = cls.mandatory_fields(**kwargs) 19 | invalid_kwargs = [ 20 | {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) 21 | ] 22 | print(mandatory_fields) 23 | assert ( 24 | len(invalid_kwargs) == 0 25 | ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." 26 | missing_kwargs = [f for f in mandatory_fields if f not in kwargs] 27 | assert ( 28 | len(missing_kwargs) == 0 29 | ), f"Required fields missing initializing this DTO: {missing_kwargs}." 
30 | return object.__new__(cls) 31 | 32 | 33 | @classmethod 34 | def setteable_fields(cls, **kwargs): 35 | return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] 36 | 37 | @classmethod 38 | def mandatory_fields(cls, **kwargs): 39 | training = kwargs.get('training', True) 40 | return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] 41 | 42 | @classmethod 43 | def from_dict(cls, kwargs): 44 | for k in kwargs: 45 | if isinstance(kwargs[k], (dict, list, tuple)): 46 | kwargs[k] = Munch.fromDict(kwargs[k]) 47 | return cls(**kwargs) 48 | 49 | def to_dict(self): 50 | # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes 51 | selfdict = {} 52 | for k in dataclasses.fields(self): 53 | selfdict[k.name] = getattr(self, k.name) 54 | if isinstance(selfdict[k.name], Munch): 55 | selfdict[k.name] = selfdict[k.name].toDict() 56 | return selfdict 57 | -------------------------------------------------------------------------------- /codes/StableCascade/core/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from pathlib import Path 5 | import safetensors 6 | import wandb 7 | 8 | 9 | def create_folder_if_necessary(path): 10 | path = "/".join(path.split("/")[:-1]) 11 | Path(path).mkdir(parents=True, exist_ok=True) 12 | 13 | 14 | def safe_save(ckpt, path): 15 | try: 16 | os.remove(f"{path}.bak") 17 | except OSError: 18 | pass 19 | try: 20 | os.rename(path, f"{path}.bak") 21 | except OSError: 22 | pass 23 | if path.endswith(".pt") or path.endswith(".ckpt"): 24 | torch.save(ckpt, path) 25 | elif path.endswith(".json"): 26 | with open(path, "w", encoding="utf-8") as f: 27 | json.dump(ckpt, f, indent=4) 28 | elif path.endswith(".safetensors"): 29 | safetensors.torch.save_file(ckpt, path) 30 | else: 31 | raise ValueError(f"File extension not supported: {path}") 32 | 33 | 34 | def load_or_fail(path, wandb_run_id=None): 35 | accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] 36 | try: 37 | assert any( 38 | [path.endswith(ext) for ext in accepted_extensions] 39 | ), f"Automatic loading not supported for this extension: {path}" 40 | if not os.path.exists(path): 41 | checkpoint = None 42 | elif path.endswith(".pt") or path.endswith(".ckpt"): 43 | checkpoint = torch.load(path, map_location="cpu") 44 | elif path.endswith(".json"): 45 | with open(path, "r", encoding="utf-8") as f: 46 | checkpoint = json.load(f) 47 | elif path.endswith(".safetensors"): 48 | checkpoint = {} 49 | with safetensors.safe_open(path, framework="pt", device="cpu") as f: 50 | for key in f.keys(): 51 | checkpoint[key] = f.get_tensor(key) 52 | return checkpoint 53 | except Exception as e: 54 | if wandb_run_id is not None: 55 | wandb.alert( 56 | title=f"Corrupt checkpoint for run {wandb_run_id}", 57 | text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", 58 | ) 59 | raise e 60 | -------------------------------------------------------------------------------- /codes/StableCascade/figures/collage_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/collage_1.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/collage_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/collage_2.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/collage_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/collage_3.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/collage_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/collage_4.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/comparison-inference-speed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/comparison-inference-speed.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/comparison.png -------------------------------------------------------------------------------- /codes/StableCascade/figures/controlnet-canny.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/controlnet-canny.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/controlnet-face.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/controlnet-face.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/controlnet-paint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/controlnet-paint.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/controlnet-sr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/controlnet-sr.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/fernando.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/fernando.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/fernando_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/fernando_original.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/image-to-image-example-rodent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/image-to-image-example-rodent.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/image-variations-example-headset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/image-variations-example-headset.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/model-overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/model-overview.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/original.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/reconstructed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/reconstructed.jpg -------------------------------------------------------------------------------- /codes/StableCascade/figures/text-to-image-example-penguin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/figures/text-to-image-example-penguin.jpg -------------------------------------------------------------------------------- /codes/StableCascade/gdf/loss_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # --- Loss Weighting 5 | class BaseLossWeight(): 6 | def weight(self, logSNR): 7 | raise NotImplementedError("this method needs to be overridden") 8 | 9 | def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): 10 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 11 | if shift != 1: 12 | logSNR = logSNR.clone() + 2 * np.log(shift) 13 | return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) 14 | 15 | class ComposedLossWeight(BaseLossWeight): 16 | def __init__(self, div, mul): 17 | self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul 18 | self.div = [div] if 
isinstance(div, BaseLossWeight) else div 19 | 20 | def weight(self, logSNR): 21 | prod, div = 1, 1 22 | for m in self.mul: 23 | prod *= m.weight(logSNR) 24 | for d in self.div: 25 | div *= d.weight(logSNR) 26 | return prod/div 27 | 28 | class ConstantLossWeight(BaseLossWeight): 29 | def __init__(self, v=1): 30 | self.v = v 31 | 32 | def weight(self, logSNR): 33 | return torch.ones_like(logSNR) * self.v 34 | 35 | class SNRLossWeight(BaseLossWeight): 36 | def weight(self, logSNR): 37 | return logSNR.exp() 38 | 39 | class P2LossWeight(BaseLossWeight): 40 | def __init__(self, k=1.0, gamma=1.0, s=1.0): 41 | self.k, self.gamma, self.s = k, gamma, s 42 | 43 | def weight(self, logSNR): 44 | return (self.k + (logSNR * self.s).exp()) ** -self.gamma 45 | 46 | class SNRPlusOneLossWeight(BaseLossWeight): 47 | def weight(self, logSNR): 48 | return logSNR.exp() + 1 49 | 50 | class MinSNRLossWeight(BaseLossWeight): 51 | def __init__(self, max_snr=5): 52 | self.max_snr = max_snr 53 | 54 | def weight(self, logSNR): 55 | return logSNR.exp().clamp(max=self.max_snr) 56 | 57 | class MinSNRPlusOneLossWeight(BaseLossWeight): 58 | def __init__(self, max_snr=5): 59 | self.max_snr = max_snr 60 | 61 | def weight(self, logSNR): 62 | return (logSNR.exp() + 1).clamp(max=self.max_snr) 63 | 64 | class TruncatedSNRLossWeight(BaseLossWeight): 65 | def __init__(self, min_snr=1): 66 | self.min_snr = min_snr 67 | 68 | def weight(self, logSNR): 69 | return logSNR.exp().clamp(min=self.min_snr) 70 | 71 | class SechLossWeight(BaseLossWeight): 72 | def __init__(self, div=2): 73 | self.div = div 74 | 75 | def weight(self, logSNR): 76 | return 1/(logSNR/self.div).cosh() 77 | 78 | class DebiasedLossWeight(BaseLossWeight): 79 | def weight(self, logSNR): 80 | return 1/logSNR.exp().sqrt() 81 | 82 | class SigmoidLossWeight(BaseLossWeight): 83 | def __init__(self, s=1): 84 | self.s = s 85 | 86 | def weight(self, logSNR): 87 | return (logSNR * self.s).sigmoid() 88 | 89 | class AdaptiveLossWeight(BaseLossWeight): 90 | def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): 91 | self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) 92 | self.bucket_losses = torch.ones(buckets) 93 | self.weight_range = weight_range 94 | 95 | def weight(self, logSNR): 96 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) 97 | return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) 98 | 99 | def update_buckets(self, logSNR, loss, beta=0.99): 100 | indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() 101 | self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) 102 | -------------------------------------------------------------------------------- /codes/StableCascade/gdf/noise_conditions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class BaseNoiseCond(): 5 | def __init__(self, *args, shift=1, clamp_range=None, **kwargs): 6 | clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range 7 | self.shift = shift 8 | self.clamp_range = clamp_range 9 | self.setup(*args, **kwargs) 10 | 11 | def setup(self, *args, **kwargs): 12 | pass # this method is optional, override it if required 13 | 14 | def cond(self, logSNR): 15 | raise NotImplementedError("this method needs to be overriden") 16 | 17 | def __call__(self, logSNR): 18 | if self.shift != 1: 19 | logSNR = logSNR.clone() + 2 * np.log(self.shift) 20 | 
return self.cond(logSNR).clamp(*self.clamp_range) 21 | 22 | class CosineTNoiseCond(BaseNoiseCond): 23 | def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] 24 | self.s = torch.tensor([s]) 25 | self.clamp_range = clamp_range 26 | self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 27 | 28 | def cond(self, logSNR): 29 | var = logSNR.sigmoid() 30 | var = var.clamp(*self.clamp_range) 31 | s, min_var = self.s.to(var.device), self.min_var.to(var.device) 32 | t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s 33 | return t 34 | 35 | class EDMNoiseCond(BaseNoiseCond): 36 | def cond(self, logSNR): 37 | return -logSNR/8 38 | 39 | class SigmoidNoiseCond(BaseNoiseCond): 40 | def cond(self, logSNR): 41 | return (-logSNR).sigmoid() 42 | 43 | class LogSNRNoiseCond(BaseNoiseCond): 44 | def cond(self, logSNR): 45 | return logSNR 46 | 47 | class EDMSigmaNoiseCond(BaseNoiseCond): 48 | def setup(self, sigma_data=1): 49 | self.sigma_data = sigma_data 50 | 51 | def cond(self, logSNR): 52 | return torch.exp(-logSNR / 2) * self.sigma_data 53 | 54 | class RectifiedFlowsNoiseCond(BaseNoiseCond): 55 | def cond(self, logSNR): 56 | _a = logSNR.exp() - 1 57 | _a[_a == 0] = 1e-3 # Avoid division by zero 58 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 59 | return a 60 | 61 | # Any NoiseCond that cannot be described easily as a continuous function of t 62 | # It needs to define self.x and self.y in the setup() method 63 | class PiecewiseLinearNoiseCond(BaseNoiseCond): 64 | def setup(self): 65 | self.x = None 66 | self.y = None 67 | 68 | def piecewise_linear(self, y, xs, ys): 69 | indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) 70 | x_min, x_max = xs[indices], xs[indices+1] 71 | y_min, y_max = ys[indices], ys[indices+1] 72 | x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) 73 | return x 74 | 75 | def cond(self, logSNR): 76 | var = logSNR.sigmoid() 77 | t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) 78 | return t 79 | 80 | class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): 81 | def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): 82 | self.total_steps = total_steps 83 | linear_range_sqrt = [r**0.5 for r in linear_range] 84 | self.x = torch.linspace(0, 1, total_steps+1) 85 | 86 | alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 87 | self.y = alphas.cumprod(dim=-1) 88 | 89 | def cond(self, logSNR): 90 | return super().cond(logSNR).clamp(0, 1) 91 | 92 | class DiscreteNoiseCond(BaseNoiseCond): 93 | def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): 94 | self.noise_cond = noise_cond 95 | self.steps = steps 96 | self.continuous_range = continuous_range 97 | 98 | def cond(self, logSNR): 99 | cond = self.noise_cond(logSNR) 100 | cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) 101 | return cond.mul(self.steps).long() 102 | -------------------------------------------------------------------------------- /codes/StableCascade/gdf/readme.md: -------------------------------------------------------------------------------- 1 | # Generic Diffusion Framework (GDF) 2 | 3 | # Basic usage 4 | GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM 5 | , EDM, Rectified Flows, etc.) 
and makes it very easy to switch between them or combine different parts of different 6 | frameworks 7 | 8 | Using GDF is very straighforward, first of all just define an instance of the GDF class: 9 | 10 | ```python 11 | from gdf import GDF 12 | from gdf import CosineSchedule 13 | from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight 14 | 15 | gdf = GDF( 16 | schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), 17 | input_scaler=VPScaler(), target=EpsilonTarget(), 18 | noise_cond=CosineTNoiseCond(), 19 | loss_weight=P2LossWeight(), 20 | ) 21 | ``` 22 | 23 | You need to define the following components: 24 | * **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution. 25 | * **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule. 26 | * **Input Scaler**: If you want to use Variance Preserving or LERP (rectified flows) 27 | * **Target**: What the target is during training, usually: epsilon, x0 or v 28 | * **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8` 29 | * **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use 30 | 31 | All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just: 32 | ```python 33 | class VPScaler(): 34 | def __call__(self, logSNR): 35 | a_squared = logSNR.sigmoid() 36 | a = a_squared.sqrt() 37 | b = (1-a_squared).sqrt() 38 | return a, b 39 | 40 | ``` 41 | 42 | So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc... 43 | 44 | ### Training 45 | 46 | When you define your training loop you can get all you need by just doing: 47 | ```python 48 | shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper sugested for high resolution 49 | for inputs, extra_conditions in dataloader_iterator: 50 | noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) 51 | pred = diffusion_model(noised, noise_cond, extra_conditions) 52 | 53 | loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 54 | loss_adjusted = (loss * loss_weight).mean() 55 | 56 | loss_adjusted.backward() 57 | optimizer.step() 58 | optimizer.zero_grad(set_to_none=True) 59 | ``` 60 | 61 | And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the 62 | training from the GDF class. 
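
To make the six values returned by `gdf.diffuse` more concrete, here is a rough sketch of a logSNR-centric diffuse step assembled purely from the component descriptions above (train schedule → input scaler → target → noise conditioning → loss weight). The function name and exact call conventions are an assumption for illustration; this is not the actual `GDF.diffuse` implementation.

```python
# Rough sketch of a logSNR-centric "diffuse" step built from the components
# described above. Illustrative only -- not the actual GDF.diffuse code.
import torch

def diffuse_sketch(x0, schedule, input_scaler, target, noise_cond, loss_weight,
                   shift=1, loss_shift=1):
    logSNR = schedule(x0.size(0), shift=shift).to(x0.device)   # one logSNR per sample
    a, b = input_scaler(logSNR)                                # e.g. VPScaler: a^2 = sigmoid(logSNR)
    a = a.view(-1, *[1] * (x0.dim() - 1))                      # broadcast over C, H, W
    b = b.view(-1, *[1] * (x0.dim() - 1))
    noise = torch.randn_like(x0)
    noised = a * x0 + b * noise                                # forward diffusion
    tgt = target(x0, noise, logSNR, a, b)                      # e.g. EpsilonTarget returns noise
    return noised, noise, tgt, logSNR, noise_cond(logSNR), loss_weight(logSNR, shift=loss_shift)
```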
63 | 64 | ### Sampling 65 | 66 | The other important part is sampling, when you want to use this framework to sample you can just do the following: 67 | 68 | ```python 69 | from gdf import DDPMSampler 70 | 71 | shift = 1 72 | sampling_configs = { 73 | "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, 74 | "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) 75 | } 76 | 77 | *_, (sampled, _, _) = gdf.sample( 78 | diffusion_model, {"cond": extra_conditions}, latents.shape, 79 | unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, 80 | device=device, **sampling_configs 81 | ) 82 | ``` 83 | 84 | # Available modules 85 | 86 | TODO 87 | -------------------------------------------------------------------------------- /codes/StableCascade/gdf/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SimpleSampler(): 4 | def __init__(self, gdf): 5 | self.gdf = gdf 6 | self.current_step = -1 7 | 8 | def __call__(self, *args, **kwargs): 9 | self.current_step += 1 10 | return self.step(*args, **kwargs) 11 | 12 | def init_x(self, shape): 13 | return torch.randn(*shape) 14 | 15 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 16 | raise NotImplementedError("You should override the 'apply' function.") 17 | 18 | class DDIMSampler(SimpleSampler): 19 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): 20 | a, b = self.gdf.input_scaler(logSNR) 21 | if len(a.shape) == 1: 22 | a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) 23 | 24 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 25 | if len(a_prev.shape) == 1: 26 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 27 | 28 | sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 29 | # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 30 | x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) 31 | return x 32 | 33 | class DDPMSampler(DDIMSampler): 34 | def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): 35 | return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) 36 | 37 | class LCMSampler(SimpleSampler): 38 | def step(self, x, x0, epsilon, logSNR, logSNR_prev): 39 | a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) 40 | if len(a_prev.shape) == 1: 41 | a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) 42 | return x0 * a_prev + torch.randn_like(epsilon) * b_prev 43 | -------------------------------------------------------------------------------- /codes/StableCascade/gdf/scalers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BaseScaler(): 4 | def __init__(self): 5 | self.stretched_limits = None 6 | 7 | def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): 8 | min_logSNR = schedule(torch.ones(1), shift=shift) 9 | max_logSNR = schedule(torch.zeros(1), shift=shift) 10 | 11 | min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] 12 | max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] 13 | self.stretched_limits = [min_a, max_a, min_b, max_b] 14 | return self.stretched_limits 15 | 16 | def stretch_limits(self, a, b): 17 | min_a, max_a, min_b, max_b = self.stretched_limits 18 | return (a - min_a) 
/ (max_a - min_a), (b - min_b) / (max_b - min_b) 19 | 20 | def scalers(self, logSNR): 21 | raise NotImplementedError("this method needs to be overridden") 22 | 23 | def __call__(self, logSNR): 24 | a, b = self.scalers(logSNR) 25 | if self.stretched_limits is not None: 26 | a, b = self.stretch_limits(a, b) 27 | return a, b 28 | 29 | class VPScaler(BaseScaler): 30 | def scalers(self, logSNR): 31 | a_squared = logSNR.sigmoid() 32 | a = a_squared.sqrt() 33 | b = (1-a_squared).sqrt() 34 | return a, b 35 | 36 | class LERPScaler(BaseScaler): 37 | def scalers(self, logSNR): 38 | _a = logSNR.exp() - 1 39 | _a[_a == 0] = 1e-3 # Avoid division by zero 40 | a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) 41 | b = 1-a 42 | return a, b 43 | -------------------------------------------------------------------------------- /codes/StableCascade/gdf/targets.py: -------------------------------------------------------------------------------- 1 | class EpsilonTarget(): 2 | def __call__(self, x0, epsilon, logSNR, a, b): 3 | return epsilon 4 | 5 | def x0(self, noised, pred, logSNR, a, b): 6 | return (noised - pred * b) / a 7 | 8 | def epsilon(self, noised, pred, logSNR, a, b): 9 | return pred 10 | 11 | class X0Target(): 12 | def __call__(self, x0, epsilon, logSNR, a, b): 13 | return x0 14 | 15 | def x0(self, noised, pred, logSNR, a, b): 16 | return pred 17 | 18 | def epsilon(self, noised, pred, logSNR, a, b): 19 | return (noised - pred * a) / b 20 | 21 | class VTarget(): 22 | def __call__(self, x0, epsilon, logSNR, a, b): 23 | return a * epsilon - b * x0 24 | 25 | def x0(self, noised, pred, logSNR, a, b): 26 | squared_sum = a**2 + b**2 27 | return a/squared_sum * noised - b/squared_sum * pred 28 | 29 | def epsilon(self, noised, pred, logSNR, a, b): 30 | squared_sum = a**2 + b**2 31 | return b/squared_sum * noised + a/squared_sum * pred 32 | 33 | class RectifiedFlowsTarget(): 34 | def __call__(self, x0, epsilon, logSNR, a, b): 35 | return epsilon - x0 36 | 37 | def x0(self, noised, pred, logSNR, a, b): 38 | return noised - pred * b 39 | 40 | def epsilon(self, noised, pred, logSNR, a, b): 41 | return noised + pred * a 42 | -------------------------------------------------------------------------------- /codes/StableCascade/gradio_app/style.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | text-align: center; 3 | justify-content: center; 4 | } 5 | [role="tabpanel"]{border: 0} 6 | #duplicate-button { 7 | margin: auto; 8 | color: #fff; 9 | background: #1565c0; 10 | border-radius: 100vh; 11 | } 12 | 13 | .gradio-container { 14 | max-width: 690px! 
important; 15 | } 16 | 17 | #share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;} 18 | div#share-btn-container > div {flex-direction: row;background: black;align-items: center} 19 | #share-btn-container:hover {background-color: #060606} 20 | #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;} 21 | #share-btn * {all: unset} 22 | #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;} 23 | #share-btn-container .wrap {display: none !important} 24 | #share-btn-container.hidden {display: none!important} -------------------------------------------------------------------------------- /codes/StableCascade/models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if at least two arguments were provided (excluding the optional first one) 4 | if [ $# -lt 2 ]; then 5 | echo "Insufficient arguments provided. At least two arguments are required." 6 | exit 1 7 | fi 8 | 9 | # Check for the optional "essential" argument and download the essential models if present 10 | if [ "$1" == "essential" ]; then 11 | echo "Downloading Essential Models (EfficientNet, Stage A, Previewer)" 12 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors -P . -q --show-progress 13 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors -P . -q --show-progress 14 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors -P . -q --show-progress 15 | shift # Move the arguments, $2 becomes $1, $3 becomes $2, etc. 16 | fi 17 | 18 | # Now, $1 is the second argument due to the potential shift above 19 | second_argument="$1" 20 | binary_decision="${2:-bfloat16}" # Use default or specific binary value if provided 21 | 22 | case $second_argument in 23 | big-big) 24 | if [ "$binary_decision" == "bfloat16" ]; then 25 | echo "Downloading Large Stage B & Large Stage C" 26 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress 27 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress 28 | else 29 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress 30 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress 31 | fi 32 | ;; 33 | big-small) 34 | if [ "$binary_decision" == "bfloat16" ]; then 35 | echo "Downloading Large Stage B & Small Stage C (BFloat16)" 36 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress 37 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress 38 | else 39 | echo "Downloading Large Stage B & Small Stage C" 40 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress 41 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . 
-q --show-progress 42 | fi 43 | ;; 44 | small-big) 45 | if [ "$binary_decision" == "bfloat16" ]; then 46 | echo "Downloading Small Stage B & Large Stage C (BFloat16)" 47 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress 48 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress 49 | else 50 | echo "Downloading Small Stage B & Large Stage C" 51 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress 52 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress 53 | fi 54 | ;; 55 | small-small) 56 | if [ "$binary_decision" == "bfloat16" ]; then 57 | echo "Downloading Small Stage B & Small Stage C (BFloat16)" 58 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress 59 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress 60 | else 61 | echo "Downloading Small Stage B & Small Stage C" 62 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress 63 | wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . -q --show-progress 64 | fi 65 | ;; 66 | *) 67 | echo "Invalid second argument. Please provide a valid argument: big-big, big-small, small-big, or small-small." 68 | exit 2 69 | ;; 70 | esac 71 | -------------------------------------------------------------------------------- /codes/StableCascade/models/readme.md: -------------------------------------------------------------------------------- 1 | # Download Models 2 | 3 | As there are many models provided, let's make sure you only download the ones you need. 4 | The ``download_models.sh`` will make that very easy. The basic usage looks like this:
5 | ```bash 6 | bash download_models.sh essential variant bfloat16 7 | ``` 8 | 9 | **essential**
10 | This is optional and determines if you want to download the EfficientNet, Stage A & Previewer. 11 | If this is the first time you run this command, you should definitely do it, because we need it. 12 | 13 | **variant**
14 | This determines which varient you want to use for **Stage B** and **Stage C**. 15 | There are four options: 16 | 17 | | | Stage C (Large) | Stage C (Lite) | 18 | |---------------------|-----------------|----------------| 19 | | **Stage B (Large)** | big-big | big-small | 20 | | **Stage B (Lite)** | small-big | small-small | 21 | 22 | 23 | So if you want to download the large Stage B & large Stage C you can execute:
24 | ```bash 25 | bash download_models.sh essential big-big bfloat16 26 | ``` 27 | 28 | **bfloat16**
29 | The last argument is optional as well, and simply determines in which precision you download Stage B & Stage C. 30 | If you want a faster download, choose _bfloat16_ (if your machine supports it), otherwise use _float32_. 31 | 32 | ### Recommendation 33 | If your GPU allows for it, you should definitely go for the **large** Stage C, which has 3.6 billion parameters. 34 | It is a lot better and was finetuned a lot more. Also, the ControlNet and Lora examples are only for the large Stage C at the moment. 35 | For Stage B the difference is not so big. The **large** Stage B is better at reconstructing small details, 36 | but if your GPU is not so powerful, just go for the smaller one. 37 | 38 | ### Remark 39 | Unfortunately, you can not run the models in float16 at the moment. Only bfloat16 or float32 work for now. However, 40 | with some investigation, it should be possible to fix the overflowing and allow for inference in float16 as well. -------------------------------------------------------------------------------- /codes/StableCascade/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .effnet import EfficientNetEncoder 2 | from .stage_c import StageC 3 | from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock 4 | from .previewer import Previewer 5 | from .controlnet import ControlNet, ControlNetDeliverer 6 | from . import controlnet as controlnet_filters -------------------------------------------------------------------------------- /codes/StableCascade/modules/cnet_modules/inpainting/saliency_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/modules/cnet_modules/inpainting/saliency_model.pt -------------------------------------------------------------------------------- /codes/StableCascade/modules/cnet_modules/inpainting/saliency_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | 9 | # MICRO RESNET 10 | class ResBlock(nn.Module): 11 | def __init__(self, channels): 12 | super(ResBlock, self).__init__() 13 | 14 | self.resblock = nn.Sequential( 15 | nn.ReflectionPad2d(1), 16 | nn.Conv2d(channels, channels, kernel_size=3), 17 | nn.InstanceNorm2d(channels, affine=True), 18 | nn.ReLU(), 19 | nn.ReflectionPad2d(1), 20 | nn.Conv2d(channels, channels, kernel_size=3), 21 | nn.InstanceNorm2d(channels, affine=True), 22 | ) 23 | 24 | def forward(self, x): 25 | out = self.resblock(x) 26 | return out + x 27 | 28 | 29 | class Upsample2d(nn.Module): 30 | def __init__(self, scale_factor): 31 | super(Upsample2d, self).__init__() 32 | 33 | self.interp = nn.functional.interpolate 34 | self.scale_factor = scale_factor 35 | 36 | def forward(self, x): 37 | x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') 38 | return x 39 | 40 | 41 | class MicroResNet(nn.Module): 42 | def __init__(self): 43 | super(MicroResNet, self).__init__() 44 | 45 | self.downsampler = nn.Sequential( 46 | nn.ReflectionPad2d(4), 47 | nn.Conv2d(3, 8, kernel_size=9, stride=4), 48 | nn.InstanceNorm2d(8, affine=True), 49 | nn.ReLU(), 50 | nn.ReflectionPad2d(1), 51 | nn.Conv2d(8, 16, kernel_size=3, stride=2), 52 | nn.InstanceNorm2d(16, affine=True), 53 | nn.ReLU(), 54 | nn.ReflectionPad2d(1), 55 | nn.Conv2d(16, 32, 
kernel_size=3, stride=2), 56 | nn.InstanceNorm2d(32, affine=True), 57 | nn.ReLU(), 58 | ) 59 | 60 | self.residual = nn.Sequential( 61 | ResBlock(32), 62 | nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), 63 | ResBlock(64), 64 | ) 65 | 66 | self.segmentator = nn.Sequential( 67 | nn.ReflectionPad2d(1), 68 | nn.Conv2d(64, 16, kernel_size=3), 69 | nn.InstanceNorm2d(16, affine=True), 70 | nn.ReLU(), 71 | Upsample2d(scale_factor=2), 72 | nn.ReflectionPad2d(4), 73 | nn.Conv2d(16, 1, kernel_size=9), 74 | nn.Sigmoid() 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.downsampler(x) 79 | out = self.residual(out) 80 | out = self.segmentator(out) 81 | return out 82 | -------------------------------------------------------------------------------- /codes/StableCascade/modules/cnet_modules/pidinet/__init__.py: -------------------------------------------------------------------------------- 1 | # Pidinet 2 | # https://github.com/hellozhuo/pidinet 3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | from einops import rearrange 8 | from .model import pidinet 9 | from .util import annotator_ckpts_path, safe_step 10 | 11 | 12 | class PidiNetDetector: 13 | def __init__(self, device): 14 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" 15 | modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") 16 | if not os.path.exists(modelpath): 17 | from basicsr.utils.download_util import load_file_from_url 18 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 19 | self.netNetwork = pidinet() 20 | self.netNetwork.load_state_dict( 21 | {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) 22 | self.netNetwork.to(device).eval().requires_grad_(False) 23 | 24 | def __call__(self, input_image): # , safe=False): 25 | return self.netNetwork(input_image)[-1] 26 | # assert input_image.ndim == 3 27 | # input_image = input_image[:, :, ::-1].copy() 28 | # with torch.no_grad(): 29 | # image_pidi = torch.from_numpy(input_image).float().cuda() 30 | # image_pidi = image_pidi / 255.0 31 | # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') 32 | # edge = self.netNetwork(image_pidi)[-1] 33 | 34 | # if safe: 35 | # edge = safe_step(edge) 36 | # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 37 | # return edge[0][0] 38 | -------------------------------------------------------------------------------- /codes/StableCascade/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/codes/StableCascade/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth -------------------------------------------------------------------------------- /codes/StableCascade/modules/cnet_modules/pidinet/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 8 | 9 | 10 | def HWC3(x): 11 | assert x.dtype == np.uint8 12 | if x.ndim == 2: 13 | x = x[:, :, None] 14 | assert x.ndim == 3 15 | H, W, C = x.shape 16 | assert C == 1 or C == 3 or C == 4 17 | if C == 3: 18 | return x 19 | if C == 1: 20 | return np.concatenate([x, x, x], axis=2) 21 | if C == 4: 22 | color = x[:, :, 0:3].astype(np.float32) 23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 24 | y = color * 
alpha + 255.0 * (1.0 - alpha) 25 | y = y.clip(0, 255).astype(np.uint8) 26 | return y 27 | 28 | 29 | def resize_image(input_image, resolution): 30 | H, W, C = input_image.shape 31 | H = float(H) 32 | W = float(W) 33 | k = float(resolution) / min(H, W) 34 | H *= k 35 | W *= k 36 | H = int(np.round(H / 64.0)) * 64 37 | W = int(np.round(W / 64.0)) * 64 38 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 39 | return img 40 | 41 | 42 | def nms(x, t, s): 43 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 44 | 45 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 46 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 47 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 48 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 49 | 50 | y = np.zeros_like(x) 51 | 52 | for f in [f1, f2, f3, f4]: 53 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 54 | 55 | z = np.zeros_like(y, dtype=np.uint8) 56 | z[y > t] = 255 57 | return z 58 | 59 | 60 | def make_noise_disk(H, W, C, F): 61 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 62 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 63 | noise = noise[F: F + H, F: F + W] 64 | noise -= np.min(noise) 65 | noise /= np.max(noise) 66 | if C == 1: 67 | noise = noise[:, :, None] 68 | return noise 69 | 70 | 71 | def min_max_norm(x): 72 | x -= np.min(x) 73 | x /= np.maximum(np.max(x), 1e-5) 74 | return x 75 | 76 | 77 | def safe_step(x, step=2): 78 | y = x.astype(np.float32) * float(step + 1) 79 | y = y.astype(np.int32).astype(np.float32) / float(step) 80 | return y 81 | 82 | 83 | def img2mask(img, H, W, low=10, high=90): 84 | assert img.ndim == 3 or img.ndim == 2 85 | assert img.dtype == np.uint8 86 | 87 | if img.ndim == 3: 88 | y = img[:, :, random.randrange(0, img.shape[2])] 89 | else: 90 | y = img 91 | 92 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 93 | 94 | if random.uniform(0, 1) < 0.5: 95 | y = 255 - y 96 | 97 | return y < np.percentile(y, random.randrange(low, high)) 98 | -------------------------------------------------------------------------------- /codes/StableCascade/modules/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Linear(torch.nn.Linear): 5 | def reset_parameters(self): 6 | return None 7 | 8 | class Conv2d(torch.nn.Conv2d): 9 | def reset_parameters(self): 10 | return None 11 | 12 | 13 | class Attention2D(nn.Module): 14 | def __init__(self, c, nhead, dropout=0.0): 15 | super().__init__() 16 | self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) 17 | 18 | def forward(self, x, kv, self_attn=False): 19 | orig_shape = x.shape 20 | x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 21 | if self_attn: 22 | kv = torch.cat([x, kv], dim=1) 23 | x = self.attn(x, kv, kv, need_weights=False)[0] 24 | x = x.permute(0, 2, 1).view(*orig_shape) 25 | return x 26 | 27 | 28 | class LayerNorm2d(nn.LayerNorm): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | 32 | def forward(self, x): 33 | return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 34 | 35 | 36 | class GlobalResponseNorm(nn.Module): 37 | "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" 38 | def __init__(self, dim): 39 | 
super().__init__() 40 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 41 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 42 | 43 | def forward(self, x): 44 | Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) 45 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 46 | return self.gamma * (x * Nx) + self.beta + x 47 | 48 | 49 | class ResBlock(nn.Module): 50 | def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): 51 | super().__init__() 52 | self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) 53 | # self.depthwise = SAMBlock(c, num_heads, expansion) 54 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 55 | self.channelwise = nn.Sequential( 56 | Linear(c + c_skip, c * 4), 57 | nn.GELU(), 58 | GlobalResponseNorm(c * 4), 59 | nn.Dropout(dropout), 60 | Linear(c * 4, c) 61 | ) 62 | 63 | def forward(self, x, x_skip=None): 64 | x_res = x 65 | x = self.norm(self.depthwise(x)) 66 | if x_skip is not None: 67 | x = torch.cat([x, x_skip], dim=1) 68 | x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 69 | return x + x_res 70 | 71 | 72 | class AttnBlock(nn.Module): 73 | def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): 74 | super().__init__() 75 | self.self_attn = self_attn 76 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 77 | self.attention = Attention2D(c, nhead, dropout) 78 | self.kv_mapper = nn.Sequential( 79 | nn.SiLU(), 80 | Linear(c_cond, c) 81 | ) 82 | 83 | def forward(self, x, kv): 84 | kv = self.kv_mapper(kv) 85 | x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) 86 | return x 87 | 88 | 89 | class FeedForwardBlock(nn.Module): 90 | def __init__(self, c, dropout=0.0): 91 | super().__init__() 92 | self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) 93 | self.channelwise = nn.Sequential( 94 | Linear(c, c * 4), 95 | nn.GELU(), 96 | GlobalResponseNorm(c * 4), 97 | nn.Dropout(dropout), 98 | Linear(c * 4, c) 99 | ) 100 | 101 | def forward(self, x): 102 | x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 103 | return x 104 | 105 | 106 | class TimestepBlock(nn.Module): 107 | def __init__(self, c, c_timestep, conds=['sca']): 108 | super().__init__() 109 | self.mapper = Linear(c_timestep, c * 2) 110 | self.conds = conds 111 | for cname in conds: 112 | setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) 113 | 114 | def forward(self, x, t): 115 | t = t.chunk(len(self.conds) + 1, dim=1) 116 | a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) 117 | for i, c in enumerate(self.conds): 118 | ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) 119 | a, b = a + ac, b + bc 120 | return x * (1 + a) + b 121 | -------------------------------------------------------------------------------- /codes/StableCascade/modules/effnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | 4 | 5 | # EfficientNet 6 | class EfficientNetEncoder(nn.Module): 7 | def __init__(self, c_latent=16): 8 | super().__init__() 9 | self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() 10 | self.mapper = nn.Sequential( 11 | nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), 12 | nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 13 | ) 14 | 15 | def forward(self, x): 16 | return self.mapper(self.backbone(x)) 17 | 18 | 
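For orientation, here is a minimal usage sketch of the `EfficientNetEncoder` defined above. It is not part of the repository: the import path assumes the StableCascade root is on `PYTHONPATH` (so `modules/__init__.py` is importable), the 768×768 input size is only an example, and the tensor is dummy data. EfficientNetV2-S downsamples by 32×, so this input yields a 16×24×24 latent.

```python
# Minimal usage sketch for EfficientNetEncoder (illustrative only; assumptions noted above).
import torch

from modules import EfficientNetEncoder  # exported in modules/__init__.py

# Instantiating the encoder downloads the torchvision EfficientNetV2-S weights on first use.
encoder = EfficientNetEncoder(c_latent=16).eval()

images = torch.randn(1, 3, 768, 768)  # dummy RGB batch in NCHW layout (example size)

with torch.no_grad():
    latents = encoder(images)  # backbone features (1280 ch) -> 1x1 conv -> 16-channel latent

print(latents.shape)  # expected: torch.Size([1, 16, 24, 24])
```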
-------------------------------------------------------------------------------- /codes/StableCascade/modules/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LoRA(nn.Module): 6 | def __init__(self, layer, name='weight', rank=16, alpha=1): 7 | super().__init__() 8 | weight = getattr(layer, name) 9 | self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) 10 | self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) 11 | nn.init.normal_(self.lora_up, mean=0, std=1) 12 | 13 | self.scale = alpha / rank 14 | self.enabled = True 15 | 16 | def forward(self, original_weights): 17 | if self.enabled: 18 | lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) 19 | lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale 20 | return original_weights + lora_weights 21 | else: 22 | return original_weights 23 | 24 | 25 | def apply_lora(model, filters=None, rank=16): 26 | def check_parameter(module, name): 27 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 28 | getattr(module, name), nn.Parameter) 29 | 30 | for name, module in model.named_modules(): 31 | if filters is None or any([f in name for f in filters]): 32 | if check_parameter(module, "weight"): 33 | device, dtype = module.weight.device, module.weight.dtype 34 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) 35 | elif check_parameter(module, "in_proj_weight"): 36 | device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype 37 | torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) 38 | 39 | 40 | class ReToken(nn.Module): 41 | def __init__(self, indices=None): 42 | super().__init__() 43 | assert indices is not None 44 | self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) 45 | self.register_buffer('indices', torch.tensor(indices)) 46 | self.enabled = True 47 | 48 | def forward(self, embeddings): 49 | if self.enabled: 50 | embeddings = embeddings.clone() 51 | for i, idx in enumerate(self.indices): 52 | embeddings[idx] += self.embeddings[i] 53 | return embeddings 54 | 55 | 56 | def apply_retoken(module, indices=None): 57 | def check_parameter(module, name): 58 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( 59 | getattr(module, name), nn.Parameter) 60 | 61 | if check_parameter(module, "weight"): 62 | device, dtype = module.weight.device, module.weight.dtype 63 | torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) 64 | 65 | 66 | def remove_lora(model, leave_parametrized=True): 67 | for module in model.modules(): 68 | if torch.nn.utils.parametrize.is_parametrized(module, "weight"): 69 | nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) 70 | elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 71 | nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) 72 | -------------------------------------------------------------------------------- /codes/StableCascade/modules/previewer.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 5 | class Previewer(nn.Module): 6 | def __init__(self, c_in=16, c_hidden=512, c_out=3): 7 | super().__init__() 8 | self.blocks = nn.Sequential( 9 | nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels 10 | nn.GELU(), 11 | nn.BatchNorm2d(c_hidden), 12 | 13 | nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), 14 | nn.GELU(), 15 | nn.BatchNorm2d(c_hidden), 16 | 17 | nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 18 | nn.GELU(), 19 | nn.BatchNorm2d(c_hidden // 2), 20 | 21 | nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), 22 | nn.GELU(), 23 | nn.BatchNorm2d(c_hidden // 2), 24 | 25 | nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 26 | nn.GELU(), 27 | nn.BatchNorm2d(c_hidden // 4), 28 | 29 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 30 | nn.GELU(), 31 | nn.BatchNorm2d(c_hidden // 4), 32 | 33 | nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 34 | nn.GELU(), 35 | nn.BatchNorm2d(c_hidden // 4), 36 | 37 | nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), 38 | nn.GELU(), 39 | nn.BatchNorm2d(c_hidden // 4), 40 | 41 | nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), 42 | ) 43 | 44 | def forward(self, x): 45 | return self.blocks(x) 46 | -------------------------------------------------------------------------------- /codes/StableCascade/requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | accelerate>=0.25.0 3 | torch==2.1.2+cu118 4 | torchvision==0.16.2+cu118 5 | transformers>=4.30.0 6 | numpy>=1.23.5 7 | kornia>=0.7.0 8 | insightface>=0.7.3 9 | opencv-python>=4.8.1.78 10 | tqdm>=4.66.1 11 | matplotlib>=3.7.4 12 | webdataset>=0.2.79 13 | wandb>=0.16.2 14 | munch>=4.0.0 15 | onnxruntime>=1.16.3 16 | einops>=0.7.0 17 | onnx2torch>=1.5.13 18 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 19 | torchtools @ git+https://github.com/pabloppp/pytorch-tools 20 | -------------------------------------------------------------------------------- /codes/StableCascade/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_b import WurstCore as WurstCoreB 2 | from .train_c import WurstCore as WurstCoreC 3 | from .train_c_controlnet import WurstCore as ControlNetCore 4 | from .train_c_lora import WurstCore as LoraCore -------------------------------------------------------------------------------- /codes/StableCascade/train/example_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=A100 3 | #SBATCH --nodes=1 4 | #SBATCH --gpus-per-node=8 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --exclusive 7 | #SBATCH --job-name=your_job_name 8 | #SBATCH --account your_account_name 9 | 10 | module load openmpi 11 | module load cuda/11.8 12 | export NCCL_PROTO=simple 13 | 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | 18 | export NCCL_DEBUG=info 19 | export PYTHONFAULTHANDLER=1 20 | 21 | export CUDA_LAUNCH_BLOCKING=0 22 | export OMPI_MCA_mtl_base_verbose=1 23 | export FI_EFA_ENABLE_SHM_TRANSFER=0 24 | 
export FI_PROVIDER=efa 25 | export FI_EFA_TX_MIN_CREDITS=64 26 | export NCCL_TREE_THRESHOLD=0 27 | 28 | export PYTHONWARNINGS="ignore" 29 | export CXX=g++ 30 | 31 | source /path/to/your/python/environment/bin/activate 32 | 33 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 34 | export MASTER_ADDR=$master_addr 35 | export MASTER_PORT=33751 36 | export PYTHONPATH=./StableWurst 37 | echo "r$SLURM_NODEID master: $MASTER_ADDR" 38 | echo "r$SLURM_NODEID Launching python script" 39 | 40 | cd /path/to/your/directory 41 | rm dist_file 42 | srun python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml -------------------------------------------------------------------------------- /codes/WALT/README.md: -------------------------------------------------------------------------------- 1 | Add WALT codes later after the codes of it are released 2 | -------------------------------------------------------------------------------- /codes/pipeline/README.md: -------------------------------------------------------------------------------- 1 | # Pipeline 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 |
20 | 21 |
22 | 23 | English | [简体中文](./README_zh-CN.md) 24 | 25 |
26 | 27 | ## Objectives 28 | 29 | A codebase equipped with video-understanding labeling functionality and a review feature for rich annotations. 30 | 31 | ## Performance 32 | 33 | Generate data for SVD training, with the expectation that the trained model will be capable of producing high-quality videos ranging from one minute to several minutes in length. 34 | 35 | [your-project-path]: mini-sora/minisora 36 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 37 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 38 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 39 | [forks-url]: https://github.com/mini-sora/minisora/network/members 40 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 41 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 42 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 43 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 44 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 45 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 46 | -------------------------------------------------------------------------------- /codes/pipeline/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # Mini Sora复现Pipeline计划 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 |
20 | 21 |
22 | 23 | [English](./README.md) | 简体中文 24 | 25 |
26 | 27 | ## 目标 28 | 29 | 形成一个带有视频理解labeling功能, 且对丰富标注具有审查功能的repo 30 | 31 | ## 效果 32 | 33 | 生成用于SVD训练的数据,以期待满足训练完成后模型可以生成长度一分钟乃至十几分钟的高质量视频 34 | 35 | [your-project-path]: mini-sora/minisora 36 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 37 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 38 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 39 | [forks-url]: https://github.com/mini-sora/minisora/network/members 40 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 41 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 42 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 43 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 44 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 45 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 46 | -------------------------------------------------------------------------------- /docs/HOT_NEWS_BASELINES_GUIDES.md: -------------------------------------------------------------------------------- 1 | # MiniSora Contribution Guides for Hot News and Baseline Sections 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 | 14 |
15 | 16 | 17 |
 
18 |
19 |
20 | 21 |
22 | 23 | English | [简体中文](./HOT_NEWS_BASELINES_GUIDES_zh-CN.md) 24 | 25 |
26 | 27 | ## Baseline Paper Criteria 28 | 29 | A paper can be listed as a "Baseline" if it meets one of the following criteria: 30 | 31 | 1. Nominated for best paper (including best paper) at a top conference or published in a top journal. 32 | 2. Has more than 100 citations. 33 | 3. Has more than 500 GitHub Stars. 34 | 35 | ## Hot News Criteria 36 | 37 | A paper can be listed as "Hot News" if it meets one of the following criteria: 38 | 39 | 1. Nominated for best paper (including best paper) at a top conference or published in a top journal. 40 | 2. Has more than 500 GitHub Stars. 41 | 3. Optimizations, improvements, or fixes of **research works that meet either of the above two criteria**, such as **OpenDiT** and **OpenSora**. 42 | 4. Significant follow-up work to previous achievements, such as refactoring of the original code or tuning of the models, like **SD3**. 43 | 44 | ## How to Add Research Work to Hot News and Baseline 45 | 46 | 1. Propose the paper title (or Blog link, GitHub link, etc.) in an Issue. 47 | 2. Wait for a response from a MiniSora member for verification. 48 | 3. Once agreed, you can add the research work to these two sections in a PR. 49 | 4. Note: PRs adding papers that are not approved will be rejected. 50 | 51 | [your-project-path]: mini-sora/minisora 52 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 53 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 54 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 55 | [forks-url]: https://github.com/mini-sora/minisora/network/members 56 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 57 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 58 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 59 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 60 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 61 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 62 | -------------------------------------------------------------------------------- /docs/HOT_NEWS_BASELINES_GUIDES_zh-CN.md: -------------------------------------------------------------------------------- 1 | # MiniSora 热点更新和Baseline模型模块贡献指南 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 | 14 |
15 | 16 | 17 |
 
18 |
19 |
20 | 21 |
22 | 23 | [English](./HOT_NEWS_BASELINES_GUIDES.md) | 简体中文 24 | 25 |
26 | 27 | ## Baseline 论文标准 28 | 29 | 满足以下要求之一即可被列为“Baseline”: 30 | 31 | 1. 顶会最佳论文提名(含最佳论文)或者Top期刊论文 32 | 2. 引用数大于100 33 | 3. GitHub Star数大于500 34 | 35 | ## 热点更新标准 36 | 37 | 满足以下要求之一即可被列为“热点更新”: 38 | 39 | 1. 顶会最佳论文提名(含最佳论文)或者Top期刊论文 40 | 2. GitHub Star数大于500 41 | 3. **满足以上两项的有关工作**的优化 、改进或修复,如**OpenDiT**和**OpenSora**等 42 | 4. 之前工作成果显著的后续工作,对原始代码的重构或模型的调参等,如**SD3**等 43 | 44 | ## 如何添加研究工作到热点更新和Baseline 45 | 46 | 1. 在Issue提出有关论文的论文名(或者Blog链接,GitHub链接等) 47 | 2. 等待MiniSora member验证回复, 48 | 3. 确认同意后,方可在PR中添加研究工作到这两个部分中 49 | 4. 注意:不被认同的论文添加将会被拒绝PR merge 50 | 51 | ### [更多社区贡献说明文档](../.github/CONTRIBUTING_zh-CN.md) 52 | 53 | [your-project-path]: mini-sora/minisora 54 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 55 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 56 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 57 | [forks-url]: https://github.com/mini-sora/minisora/network/members 58 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 59 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 60 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 61 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 62 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 63 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 64 | -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0001.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0002.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0003.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0004.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0005.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0006.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0007.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0007.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0008.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0009.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0010.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0011.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0012.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0013.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0014.jpg -------------------------------------------------------------------------------- /docs/Minisora_LPRS/0015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/docs/Minisora_LPRS/0015.jpg -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # MiniSora Docs 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 | 20 |
21 | 22 | English | [简体中文](README_zh-CN.md) 23 | 24 |
25 | 26 | The "docs" directory contains various documents related to the minisora repository. 27 | 28 | * [Community Contribution Guidelines Document](../.github/CONTRIBUTING.md) 29 | 30 | [your-project-path]: mini-sora/minisora 31 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 32 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 33 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 34 | [forks-url]: https://github.com/mini-sora/minisora/network/members 35 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 36 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 37 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 38 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 39 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 40 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE -------------------------------------------------------------------------------- /docs/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # MiniSora Docs 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 | 14 |
15 | 16 | 17 |
 
18 |
19 |
20 | 21 |
22 | 23 | [English](README.md) | 简体中文 24 | 25 |
26 | 27 | docs目录下存放的是minisora仓库相关的一些文档 28 | 29 | * [社区贡献说明文档](../.github/CONTRIBUTING_zh-CN.md) 30 | 31 | 32 | [your-project-path]: mini-sora/minisora 33 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 34 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 35 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 36 | [forks-url]: https://github.com/mini-sora/minisora/network/members 37 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 38 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 39 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 40 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 41 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 42 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 43 | -------------------------------------------------------------------------------- /docs/survey_README.md: -------------------------------------------------------------------------------- 1 | # Introduction of MiniSora and Latest Progress in Replicating Sora 2 | 3 | 4 | [![Contributors][contributors-shield]][contributors-url] 5 | [![Forks][forks-shield]][forks-url] 6 | [![Issues][issues-shield]][issues-url] 7 | [![MIT License][license-shield]][license-url] 8 | [![Stargazers][stars-shield]][stars-url] 9 |
10 | 11 | 12 |
13 | 14 | 15 |
 
16 |
17 |
18 | 19 |
20 | 21 | English | [简体中文](./survey_README_zh-CN.md) 22 | 23 |
24 | 25 | If you are interested in joining our Sora Survey Paper Team, please leave your messages in the Issues and PRs section. We will respond to your inquiries as promptly as possible. 26 | 27 | ![](./Minisora_LPRS/0001.jpg) 28 | ![](./Minisora_LPRS/0002.jpg) 29 | ![](./Minisora_LPRS/0003.jpg) 30 | ![](./Minisora_LPRS/0004.jpg) 31 | ![](./Minisora_LPRS/0005.jpg) 32 | ![](./Minisora_LPRS/0006.jpg) 33 | ![](./Minisora_LPRS/0007.jpg) 34 | ![](./Minisora_LPRS/0008.jpg) 35 | ![](./Minisora_LPRS/0009.jpg) 36 | ![](./Minisora_LPRS/0010.jpg) 37 | ![](./Minisora_LPRS/0011.jpg) 38 | ![](./Minisora_LPRS/0012.jpg) 39 | ![](./Minisora_LPRS/0013.jpg) 40 | ![](./Minisora_LPRS/0014.jpg) 41 | ![](./Minisora_LPRS/0015.jpg) 42 | 43 | [your-project-path]: mini-sora/minisora 44 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 45 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 46 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 47 | [forks-url]: https://github.com/mini-sora/minisora/network/members 48 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 49 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 50 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 51 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 52 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 53 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE -------------------------------------------------------------------------------- /docs/survey_README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # MiniSora简介与Sora复现最新进展 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 |
 
16 |
17 |
18 | 19 |
20 | 21 | [English](./survey_README.md) | 简体中文 22 | 23 |
24 | 25 | 如果您有兴趣加入我们的Sora综述论文团队,请在Issue和PR留下您的信息。我们将尽快回复您的问题。 26 | 27 | ![](./Minisora_LPRS/0001.jpg) 28 | ![](./Minisora_LPRS/0002.jpg) 29 | ![](./Minisora_LPRS/0003.jpg) 30 | ![](./Minisora_LPRS/0004.jpg) 31 | ![](./Minisora_LPRS/0005.jpg) 32 | ![](./Minisora_LPRS/0006.jpg) 33 | ![](./Minisora_LPRS/0007.jpg) 34 | ![](./Minisora_LPRS/0008.jpg) 35 | ![](./Minisora_LPRS/0009.jpg) 36 | ![](./Minisora_LPRS/0010.jpg) 37 | ![](./Minisora_LPRS/0011.jpg) 38 | ![](./Minisora_LPRS/0012.jpg) 39 | ![](./Minisora_LPRS/0013.jpg) 40 | ![](./Minisora_LPRS/0014.jpg) 41 | ![](./Minisora_LPRS/0015.jpg) 42 | 43 | 44 | [your-project-path]: mini-sora/minisora 45 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 46 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 47 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 48 | [forks-url]: https://github.com/mini-sora/minisora/network/members 49 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 50 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 51 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 52 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 53 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 54 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 55 | -------------------------------------------------------------------------------- /notes/PixArt-Σ 论文精读翻译.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/notes/PixArt-Σ 论文精读翻译.pdf -------------------------------------------------------------------------------- /notes/PixArt-Σ论文解析.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/notes/PixArt-Σ论文解析.pdf -------------------------------------------------------------------------------- /notes/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Paper reading and round-table discussion of Minisora Community 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 | 20 |
21 | 22 | English | [简体中文](./README_zh-CN.md) 23 | 24 |
25 | 26 | ## Paper Reading Notes 27 | 28 | [**Data Notes(zh-CN)**](./dataset_note.md) 29 | 30 | [**Latte Paper Intensive Translation(zh-CN)**](./latte%E8%AE%BA%E6%96%87%E7%B2%BE%E8%AF%BB%E7%BF%BB%E8%AF%91.pdf) 31 | 32 | [**Latte Paper Interpretation(zh-CN)**](./Latte.md) 33 | 34 | [**Parallel Acceleration Method for Diffusion Models: DistriFusion Paper Discussion(zh-CN)**](https://mp.weixin.qq.com/s/K6juxdW5RdBpFERmVrsUkA) 35 | 36 | **SD3 Paper Interpretation([zh-CN](./SD3_zh-CN.md), [ZhiHu](https://zhuanlan.zhihu.com/p/686273242))** 37 | 38 | ## Previous round-table discussions 39 | 40 | ### Night Talk with Sora: Video Diffusion Overview 41 | 42 | **Speaker**: Xing Zhen, PhD candidate in the Vision and Learning Lab, Fudan University 43 | 44 | **Highlights**: Fundamentals of Diffusion Modeling for Image Generation / Development of Diffusion Modeling for Text to Video / An Introduction to the Technology and Reproduction Challenges Behind Sora 45 | 46 | **Time of Online Live**: 02/28 20:00-21:00 47 | 48 | **Bilibili**: [Sora 之 Video Diffusion 综述(zh-CN)](https://www.bilibili.com/video/BV1cJ4m1e7sQ) 49 | 50 | **PPT**: [FeiShu Link](https://aicarrier.feishu.cn/file/Ds0BbCAo6oTazdxxo3Zciw1Nnne) 51 | 52 | **ZhiHu Notes**: [A Survey on Generative Diffusion Model(zh-CN)](https://zhuanlan.zhihu.com/p/684795460) 53 | 54 | ------------------ 55 | 56 | ### Paper Interpretation of Stable Diffusion 3 paper: MM-DiT 57 | 58 | **Speaker**: MMagic Core Contributors 59 | 60 | **Highlights**: MMagic core contributors will lead us in interpreting the Stable Diffusion 3 paper, discussing the architecture details and design principles of Stable Diffusion 3. 61 | 62 | **Time of Online Live**: 03/12 20:00 63 | 64 | **PPT**: [FeiShu Link](https://aicarrier.feishu.cn/file/NXnTbo5eqo8xNYxeHnecjLdJnQq) 65 | 66 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 67 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 68 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 69 | [forks-url]: https://github.com/mini-sora/minisora/network/members 70 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 71 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 72 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 73 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 74 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 75 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 76 | -------------------------------------------------------------------------------- /notes/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | 2 | # Minisora 社区论文阅读和圆桌讨论 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Issues][issues-shield]][issues-url] 8 | [![MIT License][license-shield]][license-url] 9 | [![Stargazers][stars-shield]][stars-url] 10 |
11 | 12 | 13 |
14 | 15 | 16 |
 
17 |
18 |
19 | 20 |
21 | 22 | [English](./README.md) | 简体中文 23 | 24 |
25 | 26 | ## 论文阅读笔记 27 | 28 | [**数据Notes**](./dataset_note.md) 29 | 30 | [**Latte论文精读翻译**](./latte%E8%AE%BA%E6%96%87%E7%B2%BE%E8%AF%BB%E7%BF%BB%E8%AF%91.pdf) 31 | 32 | [**Latte论文解读**](./Latte.md) 33 | 34 | [**扩散模型并行加速方法: DistriFusion论文共读**](https://mp.weixin.qq.com/s/K6juxdW5RdBpFERmVrsUkA) 35 | 36 | **SD3论文精读([zh-CN](./SD3_zh-CN.md), [ZhiHu](https://zhuanlan.zhihu.com/p/686273242))** 37 | 38 | ## 往期圆桌讨论 39 | 40 | ### Sora 夜谈之Video Diffusion 综述 41 | 42 | **主讲**: 邢桢 复旦大学视觉与学习实验室博士生 43 | 44 | **直播看点**: 图像生成扩散模型基础/文生视频扩散模型的发展/浅谈 Sora 背后技术和复现挑战 45 | 46 | **在线直播时间**: 02/28 20:00-21:00 47 | 48 | **B站回放**: [Sora 之 Video Diffusion 综述](https://www.bilibili.com/video/BV1cJ4m1e7sQ) 49 | 50 | **PPT**: [飞书链接](https://aicarrier.feishu.cn/file/Ds0BbCAo6oTazdxxo3Zciw1Nnne) 51 | 52 | **知乎Notes**: [A Survey on Generative Diffusion Model 生成扩散模型综述](https://zhuanlan.zhihu.com/p/684795460) 53 | 54 | --------- 55 | 56 | ### Stable Diffusion 3 论文(MM-DiT)解读 57 | 58 | **主讲**:MMagic 核心贡献者 59 | 60 | **直播看点**:MMagic 核心贡献者为我们领读 Stable Diffusion 3 论文,介绍 Stable Diffusion 3 的架构细节和设计思路。 61 | 62 | **在线直播时间**:03/12 20:00 63 | 64 | **PPT**: [飞书链接](https://aicarrier.feishu.cn/file/NXnTbo5eqo8xNYxeHnecjLdJnQq) 65 | 66 | [contributors-shield]: https://img.shields.io/github/contributors/mini-sora/minisora.svg?style=flat-square 67 | [contributors-url]: https://github.com/mini-sora/minisora/graphs/contributors 68 | [forks-shield]: https://img.shields.io/github/forks/mini-sora/minisora.svg?style=flat-square 69 | [forks-url]: https://github.com/mini-sora/minisora/network/members 70 | [stars-shield]: https://img.shields.io/github/stars/mini-sora/minisora.svg?style=flat-square 71 | [stars-url]: https://github.com/mini-sora/minisora/stargazers 72 | [issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square 73 | [issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg 74 | [license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square 75 | [license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE 76 | -------------------------------------------------------------------------------- /notes/dataset_note.md: -------------------------------------------------------------------------------- 1 | # 数据 2 | 3 | ## Sora技术报告节选 4 | > We first train a **highly descriptive captioner model** and then use it to produce text captions for all videos in our training set. 5 | > 6 | > 我们首先训练一个高度描述性的字幕模型,然后用它为我们训练集中的所有视频生成文字字幕。 7 | 8 | 用AI assistant辅助生成高质量数据集,用于进行更大规模AI的训练。 9 | 10 | 11 | > Similar to DALL·E 3, we also leverage GPT to turn short user prompts into longer detailed captions that are sent to the video model. This enables Sora to generate high quality videos that accurately follow user prompts. 
12 | > 13 | > 与DALL·E 3类似,我们也利用GPT将用户的简短提示转化为发送给视频模型的更长的详细字幕。这使得Sora能够生成准确遵循用户提示的高质量视频。 14 | 15 | AI assistant - Prompt工程: 将人为输入的prompt改写成更高质量的prompt。即,重述提示词技术(RePrompt)。 16 | 17 | ## 视频-文本生成模型 18 | 19 | 用于高质量大规模数据集构建。 20 | 21 | - **BibiGPT**: [Bilibili](https://www.bilibili.com/video/BV1fX4y1Q7Ux/?vd_source=dd5a650b0ad84edd0d54bb18196ecb86) , [Github](https://github.com/JimmyLv/BibiGPT-v1) 22 | 23 | - **VideoBERT**:一个早期尝试,利用自监督学习在大规模未标注视频上训练,以理解视频内容并生成描述文本。 24 | 25 | - **S3D (Separable 3D ConvNets)**:使用3D卷积网络处理视频帧,捕捉时间和空间信息,常与语言模型结合用于视频理解和描述。 26 | 27 | - **HERO (Hierarchical Encoder Representation for Video+Language Omni-representation)**:专为视频+语言任务设计,通过分层编码器提高了视频理解和文本生成的能力。 28 | 29 | - **ViLBERT (Vision-and-Language BERT)** 和 **VideoBERT**:这两个模型结合了视觉和语言信息,通过类似BERT的结构实现跨模态理解,可以用于视频描述生成。 30 | 31 | - **OpenAI's CLIP**:虽然CLIP主要设计用于图像和文本的匹配,但其强大的跨模态理解能力也被应用于视频内容的理解和描述。 32 | 33 | - **Transformer-based models**:基于Transformer的模型,如GPT和BERT,通过预训练和微调,也可以应用于视频描述的任务。 34 | 35 | ## 重述提示词技术 36 | 37 | - **RePrompt**: Automatic Prompt Editing to Refine AI-Generative Art Towards Precise Expressions [Paper](https://arxiv.org/abs/2302.09466) 38 | - Retrieval-Enhanced Visual Prompt Learning for Few-shot Classification [Paper](https://arxiv.org/pdf/2306.02243.pdf) 39 | - [deepgram.com/ai-apps/reprompt](https://deepgram.com/ai-apps/reprompt) , [foundr.ai/product/reprompt](https://foundr.ai/product/reprompt) 40 | -------------------------------------------------------------------------------- /notes/latte论文精读翻译.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mini-sora/minisora/c831ce3cbac4c9a81a315c6901145c023814897c/notes/latte论文精读翻译.pdf --------------------------------------------------------------------------------
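As a closing illustration of the RePrompt idea described in `notes/dataset_note.md` (expanding a short user prompt into a detailed caption before it reaches the video model), here is a minimal sketch. `call_llm`, the template text, and the example prompt are hypothetical placeholders introduced purely for illustration; they are not code from this repository or from the RePrompt paper.

```python
# Hypothetical RePrompt-style helper: rewrite a terse user prompt into a
# detailed caption. `call_llm` is a placeholder for whatever LLM client you use.
REPROMPT_TEMPLATE = (
    "Rewrite the following short video prompt into a single detailed caption. "
    "Describe the subject, scene, camera motion, lighting and visual style.\n\n"
    "Prompt: {prompt}\nDetailed caption:"
)


def call_llm(text: str) -> str:
    """Placeholder: send `text` to an LLM and return its reply."""
    raise NotImplementedError("plug in your preferred LLM client here")


def reprompt(short_prompt: str) -> str:
    """Expand a short user prompt into a richer caption before video sampling."""
    return call_llm(REPROMPT_TEMPLATE.format(prompt=short_prompt))


# Example (hypothetical):
#   reprompt("a corgi surfing at sunset")
#   -> "A golden corgi balances on a small surfboard on a rolling wave at sunset, ..."
```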