├── .github └── workflows │ └── docker_build.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Contribution_Guidelines.md ├── I2V.md ├── Prompt_Refiner.md ├── Report-v1.0.0-cn.md ├── Report-v1.0.0.md ├── Report-v1.1.0.md ├── Report-v1.2.0.md ├── Report-v1.3.0.md ├── T2V.md └── VAE.md ├── examples ├── cond_pix_path.txt ├── cond_prompt.txt ├── rec_image.py ├── rec_video.py └── sora.txt ├── opensora ├── __init__.py ├── acceleration │ ├── __init__.py │ ├── communications.py │ └── parallel_states.py ├── adaptor │ ├── __init__.py │ ├── bf16_optimizer.py │ ├── engine.py │ ├── modules.py │ ├── stage_1_and_2.py │ ├── utils.py │ └── zp_manager.py ├── dataset │ ├── __init__.py │ ├── inpaint_dataset.py │ ├── t2v_datasets.py │ ├── transform.py │ └── virtual_disk.py ├── models │ ├── __init__.py │ ├── causalvideovae │ │ ├── __init__.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── ddp_sampler.py │ │ │ ├── transform.py │ │ │ └── video_dataset.py │ │ ├── eval │ │ │ ├── cal_fvd.py │ │ │ ├── cal_lpips.py │ │ │ ├── cal_psnr.py │ │ │ ├── cal_ssim.py │ │ │ ├── eval.py │ │ │ ├── fvd │ │ │ │ ├── styleganv │ │ │ │ │ └── fvd.py │ │ │ │ └── videogpt │ │ │ │ │ ├── fvd.py │ │ │ │ │ └── pytorch_i3d.py │ │ │ └── script │ │ │ │ ├── cal_clip_score.sh │ │ │ │ ├── cal_fvd.sh │ │ │ │ ├── cal_lpips.sh │ │ │ │ ├── cal_psnr.sh │ │ │ │ └── cal_ssim.sh │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── configuration_videobase.py │ │ │ ├── dataset_videobase.py │ │ │ ├── ema_model.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator.py │ │ │ │ ├── lpips.py │ │ │ │ └── perceptual_loss.py │ │ │ ├── modeling_videobase.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── block.py │ │ │ │ ├── conv.py │ │ │ │ ├── normalize.py │ │ │ │ ├── ops.py │ │ │ │ ├── quant.py │ │ │ │ ├── resnet_block.py │ │ │ │ ├── updownsample.py │ │ │ │ └── wavelet.py │ │ │ ├── registry.py │ │ │ ├── trainer_videobase.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── distrib_utils.py │ │ │ │ ├── module_utils.py │ │ │ │ ├── scheduler_utils.py │ │ │ │ ├── video_utils.py │ │ │ │ └── wavelet_utils.py │ │ │ └── vae │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_causalvae.py │ │ │ │ └── modeling_wfvae.py │ │ ├── sample │ │ │ └── rec_video_vae.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dataset_utils.py │ │ │ ├── downloader.py │ │ │ └── video_utils.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── common.py │ │ ├── curope │ │ │ ├── __init__.py │ │ │ ├── curope.cpp │ │ │ ├── curope3d.py │ │ │ ├── kernels.cu │ │ │ └── setup.py │ │ └── opensora_v1_3 │ │ │ ├── __init__.py │ │ │ ├── modeling_inpaint.py │ │ │ ├── modeling_opensora.py │ │ │ └── modules.py │ ├── frame_interpolation │ │ ├── cfgs │ │ │ └── AMT-G.yaml │ │ ├── interpolation.py │ │ ├── networks │ │ │ ├── AMT-G.py │ │ │ ├── __init__.py │ │ │ └── blocks │ │ │ │ ├── __init__.py │ │ │ │ ├── feat_enc.py │ │ │ │ ├── ifrnet.py │ │ │ │ ├── multi_flow.py │ │ │ │ └── raft.py │ │ ├── readme.md │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── build_utils.py │ │ │ ├── dist_utils.py │ │ │ ├── flow_utils.py │ │ │ └── utils.py │ ├── prompt_refiner │ │ ├── inference.py │ │ ├── merge.py │ │ └── train.py │ └── text_encoder │ │ ├── __init__.py │ │ ├── clip.py │ │ └── t5.py ├── npu_config.py ├── sample │ ├── caption_refiner.py │ ├── pipeline_inpaint.py │ ├── pipeline_opensora.py │ ├── rec_image.py │ ├── rec_video.py │ └── sample.py ├── serve │ ├── gradio_utils.py │ ├── gradio_web_server.py │ ├── gradio_web_server_i2v.py │ └── style.css ├── train │ ├── train_causalvae.py │ ├── 
train_inpaint.py │ └── train_t2v_diffusers.py └── utils │ ├── communications.py │ ├── dataset_utils.py │ ├── downloader.py │ ├── ema.py │ ├── ema_utils.py │ ├── freeinit_utils.py │ ├── lora_utils.py │ ├── mask_utils.py │ ├── parallel_states.py │ ├── sample_utils.py │ ├── taming_download.py │ └── utils.py ├── pyproject.toml └── scripts ├── accelerate_configs ├── ddp_config.yaml ├── deepspeed_zero2_config.yaml ├── deepspeed_zero2_offload_config.yaml ├── deepspeed_zero3_config.yaml ├── deepspeed_zero3_offload_config.yaml ├── default_config.yaml ├── hostfile ├── multi_node_example.yaml ├── multi_node_example_by_ddp.yaml ├── zero2.json ├── zero2_npu.json ├── zero2_offload.json ├── zero3.json └── zero3_offload.json ├── causalvae ├── eval.sh ├── prepare_eval.sh ├── rec_image.sh ├── rec_video.sh ├── train.sh └── wfvae_4dim.json ├── slurm └── placeholder ├── text_condition ├── gpu │ ├── sample_inpaint_v1_3.sh │ ├── sample_t2v_v1_3.sh │ ├── train_inpaint_v1_3.sh │ └── train_t2v_v1_3.sh └── npu │ ├── sample_inpaint_v1_3.sh │ ├── sample_t2v_v1_3.sh │ ├── train_inpaint_v1_3.sh │ └── train_t2v_v1_3.sh ├── train_configs └── mask_config.yaml └── train_data └── merge_data.txt /.github/workflows/docker_build.yml: -------------------------------------------------------------------------------- 1 | name: docker-build 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - "main" 8 | paths: 9 | - "docker/Dockerfile" 10 | 11 | jobs: 12 | build-Open-Sora: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - 16 | name: Checkout 17 | uses: actions/checkout@v4 18 | - 19 | name: Set up QEMU 20 | uses: docker/setup-qemu-action@v3 21 | - 22 | name: Set up Docker Buildx 23 | uses: docker/setup-buildx-action@v3 24 | - 25 | name: Login to Docker Hub 26 | uses: docker/login-action@v3 27 | with: 28 | username: ${{ secrets.DOCKERHUB_USERNAME }} 29 | password: ${{ secrets.DOCKERHUB_TOKEN }} 30 | - 31 | name: Build and push Open-Sora image 32 | uses: docker/build-push-action@v5 33 | with: 34 | context: . 
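          # Note (descriptive comment, not in the original workflow): the build context is the
          # repository root ("."), while the Dockerfile itself lives under ./docker (see the
          # `file:` key below); per the `on:` section above, this job runs on pushes to main that
          # touch docker/Dockerfile, or when dispatched manually.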
35 | file: ./docker/Dockerfile 36 | push: true 37 | platforms: linux/amd64, linux/arm64, linux/s390x, linux/ppc64le 38 | tags: ${{ secrets.DOCKERHUB_USERNAME }}/open-sora -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ucf101_stride4x4x4 2 | __pycache__ 3 | *.mp4 4 | .ipynb_checkpoints 5 | *.pth 6 | UCF-101/ 7 | results/ 8 | build/ 9 | opensora.egg-info/ 10 | wandb/ 11 | .idea 12 | *.ipynb 13 | *.jpg 14 | *.mp3 15 | *.safetensors 16 | *.mp4 17 | *.png 18 | *.gif 19 | *.pth 20 | *.pt 21 | cache_dir/ 22 | wandb/ 23 | test* 24 | sample_video*/ 25 | 512* 26 | 720* 27 | 1024* 28 | *debug* 29 | private* 30 | .deepspeed_env 31 | 256* 32 | sample_image*/ 33 | taming* 34 | *test* 35 | sft* 36 | flash* 37 | 65x256* 38 | alpha_vae 39 | *node* 40 | cache/ 41 | Open-Sora-Plan_models/ 42 | sample_image*cfg* 43 | *tmp* 44 | *pymp* 45 | check.py 46 | bucket.py 47 | whileinf.py 48 | validation_dir/ 49 | runs/ 50 | samples/ 51 | inpaint*/ 52 | bs32x8x1* 53 | *tmp* 54 | *pymp* 55 | check.py 56 | bucket.py 57 | whileinf.py 58 | bs4x8x16_* 59 | *.zip 60 | *validation/ 61 | bs1x8x32* 62 | bs16x8x1* 63 | bs8x8x2* 64 | bs8x8x1* 65 | bs8x8x8* 66 | bs1x8x16* 67 | checklora.py 68 | dim4todim8.py 69 | *vae8_any*320x320* 70 | samples/ 71 | runs/ 72 | *validation/ 73 | training_log*txt 74 | filter_motion* 75 | json2*.py 76 | motionfun* 77 | res_dist* 78 | filter_json_aes_m* 79 | stage2*.json 80 | kernel_meta 81 | ge_check_op.json 82 | WFVAE_DISTILL_FORMAL 83 | read_video* 84 | bs32x8x2* 85 | filter_json_aes_m* 86 | json2json* 87 | makenpu_json* 88 | *make_small_json* 89 | *schedule_noise* 90 | test* 91 | gpu_profiling* 92 | gyy_dense* 93 | torchelasti* 94 | *VEnhancer* 95 | *spdemo* 96 | i2v.txt 97 | *run_i2v* 98 | *curope* 99 | any* 100 | *nomotion* 101 | log* 102 | *svg 103 | *k8s* 104 | *rf* 105 | *lzj* 106 | final* 107 | opensora/train/*debug.py 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 PKU-YUAN-Lab (袁粒课题组-北大信工) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docs/Contribution_Guidelines.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Open-Sora Plan Community 2 | 3 | The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights! 4 | 5 | ## Submitting a Pull Request (PR) 6 | 7 | As a contributor, before submitting your request, kindly follow these guidelines: 8 | 9 | 1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work. 10 | 11 | 2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine. 12 | 13 | ```bash 14 | git clone [your-forked-repository-url] 15 | ``` 16 | 17 | 3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates: 18 | 19 | ```bash 20 | git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan 21 | ``` 22 | 23 | 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository. 24 | 25 | ``` 26 | # Pull the latest code from the upstream branch 27 | git fetch upstream 28 | 29 | # Switch to the main branch 30 | git checkout main 31 | 32 | # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream 33 | git merge upstream/main 34 | 35 | # Additionally, sync the local main branch to the remote branch of your forked repository 36 | git push origin main 37 | ``` 38 | 39 | 40 | > Note: Sync the code from the main repository before each submission. 41 | 42 | 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful. 43 | 44 | ```bash 45 | git checkout -b my-docs-branch main 46 | ``` 47 | 48 | 6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format). 49 | 50 | ```bash 51 | git commit -m "[docs]: xxxx" 52 | ``` 53 | 54 | 7. Push your changes to your GitHub repository. 55 | 56 | ```bash 57 | git push origin my-docs-branch 58 | ``` 59 | 60 | 8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page. 61 | 62 | ## Commit Message Format 63 | 64 | Commit messages must include both `` and `` sections. 65 | 66 | ```bash 67 | []: 68 | │ │ 69 | │ └─⫸ Briefly describe your changes, without ending with a period. 70 | │ 71 | └─⫸ Commit Type: |docs|feat|fix|refactor| 72 | ``` 73 | 74 | ### Type 75 | 76 | * **docs**: Modify or add documents. 77 | * **feat**: Introduce a new feature. 78 | * **fix**: Fix a bug. 79 | * **refactor**: Restructure code, excluding new features or bug fixes. 80 | 81 | ### Summary 82 | 83 | Describe modifications in English, without ending with a period. 84 | 85 | > e.g., git commit -m "[docs]: add a contributing.md file" 86 | 87 | This guideline is borrowed by [minisora](https://github.com/mini-sora/minisora). We sincerely appreciate MiniSora authors for their awesome templates. 
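To make the commit format above easy to check locally, here is a minimal sketch of a `commit-msg` hook written in Python. It is not part of the repository's official tooling — the hook wiring and file handling are assumptions; only the `[type]: summary` convention (types `docs|feat|fix|refactor`, no trailing period) comes from this guide.

```python
#!/usr/bin/env python3
# Minimal sketch (not official tooling): validate "[type]: summary" commit messages.
import re
import sys

ALLOWED_TYPES = ("docs", "feat", "fix", "refactor")
# Example of a valid message: "[docs]: add a contributing.md file"
PATTERN = re.compile(r"^\[(?:%s)\]: .+$" % "|".join(ALLOWED_TYPES))


def is_valid(message: str) -> bool:
    """Check the first line: an allowed type, a non-empty summary, no trailing period."""
    first_line = message.splitlines()[0].strip() if message.strip() else ""
    return bool(PATTERN.match(first_line)) and not first_line.endswith(".")


if __name__ == "__main__":
    # git passes the path of the file holding the commit message as the first argument
    with open(sys.argv[1], encoding="utf-8") as f:
        if not is_valid(f.read()):
            sys.exit('Commit message should look like "[docs|feat|fix|refactor]: summary" '
                     "and must not end with a period.")
```

Saved as `.git/hooks/commit-msg` (an assumption about local setup, not a project requirement), this rejects commits whose first line does not follow the format described above.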
88 | -------------------------------------------------------------------------------- /docs/I2V.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### Data prepare 4 | 5 | Data preparation aligns with T2V section. 6 | 7 | ### Training 8 | 9 | Training on GPUs: 10 | 11 | ```bash 12 | bash scripts/text_condition/gpu/train_inpaint_v1_3.sh 13 | ``` 14 | 15 | Training on NPUs: 16 | 17 | ```bash 18 | bash scripts/text_condition/npu/train_inpaint_v1_3.sh 19 | ``` 20 | 21 | There are additional parameters you need to understand beyond those introduced in the T2V section. 22 | 23 | | Argparse | Usage | 24 | | ------------------------------------- | ------------------------------------------------------------ | 25 | | `--default_text_ratio` 0.5 | During I2V training, a portion of the text is replaced with a default text to account for cases where the user provides an image without accompanying text. | 26 | | `--mask_config` | The path of the `mask_config` file. | 27 | | `--add_noise_to_condition` | Adding a small amount of noise to conditional frames during training to improve generalization. | 28 | 29 | In Open-Sora Plan V1.3, all mask ratio settings are specified in the `mask_config` file, located at `scripts/train_configs/mask_config.yaml`. The parameters include: 30 | 31 | | Argparse | Usage | 32 | | ---------------------------- | ------------------------------------------------------------ | 33 | | `min_clear_ratio` | The minimum ratio of frames retained during continuation and random masking. | 34 | | `max_clear_ratio` | The maximum ratio of frames retained during continuation and random masking. | 35 | | `mask_type_ratio_dict_video` | During training, specify the ratio for each mask task. For video data, there are six mask types: `t2iv`, `i2v`, `transition`, `continuation`, `clear`, and `random_temporal`. These inputs will be normalized to ensure their sum equals one. | 36 | | `mask_type_ratio_dict_image` | During training, specify the ratio for each mask task. For image data, there are two mask types: `t2iv` and `clear`. These inputs will be normalized to ensure their sum equals one. | 37 | 38 | ### Inference 39 | 40 | Inference on GPUs: 41 | 42 | ```bash 43 | bash scripts/text_condition/gpu/sample_inpaint_v1_3.sh 44 | ``` 45 | 46 | Inference on NPUs: 47 | 48 | ```bash 49 | bash scripts/text_condition/npu/sample_inpaint_v1_3.sh 50 | ``` 51 | 52 | In the current version, we have only open-sourced the 93x480p version of the Image-to-Video (I2V) model. We recommend configuration `--guidance_scale 7.5 --num_sampling_steps 100 --sample_method EulerAncestralDiscrete` for sampling. 53 | 54 | **Inference on 93×480p**, the speed on H100 and Ascend 910B. 55 | 56 | | Size | 1 H100 | 1 Ascend 910B | 57 | | ------- | ------------ | ------------- | 58 | | 93×480p | 150s/100step | 292s/100step | 59 | 60 | During inference, you can specify `--nproc_per_node` and set the `--sp` parameter to choose between single-gpu/npu mode, DDP (Distributed Data Parallel) mode, or SP (Sequential Parallel) mode for inference. 61 | 62 | The following are key parameters required for inference: 63 | 64 | | Argparse | Usage | 65 | | ---------------------------------------------- | ------------------------------------------------------------ | 66 | | `--height` 352 `--width` 640 `--crop_for_hw` | When `crop_for_hw` is specified, the I2V model operates in fixed-resolution mode, generating outputs at the user-specified height and width. 
| 67 | | `--max_hxw` 236544 | When `crop_for_hw` is not specified, the I2V model operates in arbitrary resolution mode, resizing outputs to the greatest common divisor of the resolutions in the input image list. In this case, the `--max_hxw` parameter must be provided, with a default value of 236544. | 68 | | `--text_prompt` | The path to the `prompt` file, where each line represents a prompt. Each line must correspond precisely to each line in `--conditional_pixel_values_path`. | 69 | | `--conditional_pixel_values_path` | The input path for control information can contain one or multiple images or videos, with each line controlling the generation of one video. It must correspond precisely to each prompt in `--text_prompt`. | 70 | | `--mask_type` | Specify the mask type used for the current inference; available types are listed in the `MaskType` class in `opensora/utils/mask_utils.py`, which are six mask types: `t2iv`, `i2v`, `transition`, `continuation`, `clear`, and `random_temporal`. This parameter can be omitted when performing I2V and Transition tasks. | 71 | | `--noise_strength` | The noise strength added to conditional frames, which defaults to 0 (no noise added). | 72 | 73 | Before inference, you need to create two text files: one named `prompt.txt` and another named `conditional_pixel_values_path`. Each line of text in `prompt.txt` should correspond to the paths on each line in `conditional_pixel_values_path`. 74 | 75 | For example, if the content of `prompt.txt` is: 76 | 77 | ``` 78 | this is a prompt of i2v task. 79 | this is a prompt of transition task. 80 | ``` 81 | 82 | Then the content of `conditional_pixel_values_path` should be: 83 | 84 | ``` 85 | /path/to/image_0.png 86 | /path/to/image_1_0.png,/path/to/image_1_1.png 87 | ``` 88 | 89 | This means we will execute a image-to-video task using `/path/to/image_0.png` and "this is a prompt of i2v task." For the transition task, we'll use `/path/to/image_1_0.png` and `/path/to/image_1_1.png` (note that these two paths are separated by a comma without any spaces) along with "this is a prompt of transition task." 90 | 91 | After creating the files, make sure to specify their paths in the `sample_inpaint_v1_3.sh` script. 92 | -------------------------------------------------------------------------------- /docs/Prompt_Refiner.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | 3 | We have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). The details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner). 4 | 5 | In fact, it is a JSON file with the following structure. 6 | 7 | ``` 8 | [ 9 | { 10 | "instruction": "Refine the sentence: \"A newly married couple sharing a piece of there wedding cake.\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.", 11 | "input": "", 12 | "output": "The newlywed couple, dressed in elegant attire..." 13 | }, 14 | ... 15 | ] 16 | ``` 17 | 18 | ## Train 19 | 20 | `--data_path` is the path to the prepared JSON file. 21 | `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files. 
22 | `--lora_out_path` is the path where the LoRA model will be saved. 23 | 24 | ``` 25 | cd opensora/models/prompt_refiner 26 | CUDA_VISIBLE_DEVICES=0 python train.py \ 27 | --data_path path/to/data.json \ 28 | --model_path path/to/llama_model \ 29 | --lora_out_path path/to/save/lora_model 30 | ``` 31 | 32 | ## Merge 33 | 34 | `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and some weight files. 35 | `--lora_in_path` is the directory containing the pre-trained LoRA model. 36 | `--lora_out_path` is the path for the merged model. 37 | 38 | ``` 39 | cd opensora/models/prompt_refiner 40 | CUDA_VISIBLE_DEVICES=0 python merge.py \ 41 | --base_path path/to/llama_model \ 42 | --lora_in_path path/to/save/lora_model \ 43 | --lora_out_path path/to/save/merge_model 44 | ``` 45 | 46 | ## Inference 47 | 48 | `--model_path` is the directory containing the weights (LLaMA 3.1 or the merged LoRA weights), including `config.json` and some weight files. 49 | `--prompt` is the text you want to input, which will be refined. 50 | 51 | ``` 52 | cd opensora/models/prompt_refiner 53 | CUDA_VISIBLE_DEVICES=0 python inference.py \ 54 | --model_path path/to/save/merge_model \ 55 | --prompt "your prompt to refine" 56 | ``` -------------------------------------------------------------------------------- /docs/VAE.md: -------------------------------------------------------------------------------- 1 | 2 | ### Data prepare 3 | The organization of the training data is simple: we only need to put all the videos recursively in a directory. This makes training more convenient when using multiple datasets. 4 | ``` shell 5 | Training Dataset 6 | |——sub_dataset1 7 | |——sub_sub_dataset1 8 | |——video1.mp4 9 | |——video2.mp4 10 | ...... 11 | |——sub_sub_dataset2 12 | |——video3.mp4 13 | |——video4.mp4 14 | ...... 15 | |——sub_dataset2 16 | |——video5.mp4 17 | |——video6.mp4 18 | ...... 19 | |——video7.mp4 20 | |——video8.mp4 21 | ``` 22 | 23 | ### Training 24 | ``` shell 25 | bash scripts/causalvae/train.sh 26 | ``` 27 | We introduce the important args for training. 28 | 29 | | Argparse | Usage | 30 | |:---|:---| 31 | |_Training size_|| 32 | |`--num_frames`|The number of frames used from each training video| 33 | |`--resolution`|The resolution of the input to the VAE| 34 | |`--batch_size`|The local batch size on each GPU| 35 | |`--sample_rate`|The frame interval when loading training videos| 36 | |_Data processing_|| 37 | |`--video_path`|/path/to/dataset| 38 | |_Load weights_|| 39 | |`--model_name`| `CausalVAE` or `WFVAE`| 40 | |`--model_config`|/path/to/config.json The model config of the VAE. Use this parameter if you want to train from scratch.| 41 | |`--pretrained_model_name_or_path`|A directory containing a model checkpoint and its config. Using this parameter will only load the weights, not the optimizer state.| 42 | |`--resume_from_checkpoint`|/path/to/checkpoint It will resume the training process from the checkpoint, including the weights and the optimizer state.| 43 | 44 | ### Inference 45 | 46 | ``` shell 47 | bash scripts/causalvae/rec_video.sh 48 | ``` 49 | We introduce the important args for inference. 
50 | | Argparse | Usage | 51 | |:---|:---| 52 | |_Ouoput video size_|| 53 | |`--num_frames`|The number of frames of generated videos| 54 | |`--height`|The resolution of generated videos| 55 | |`--width`|The resolution of generated videos| 56 | |_Data processing_|| 57 | |`--video_path`|The path to the original video| 58 | |`--rec_path`|The path to the generated video| 59 | |_Load weights_|| 60 | |`--ae_path`|/path/to/model_dir. A directory containing the checkpoint of VAE is used for inference and its model config.json| 61 | |_Other_|| 62 | |`--enable_tilintg`|Use tiling to deal with videos of high resolution and long duration| 63 | |`--save_memory`|Save memory to inference but lightly influence quality| 64 | 65 | 66 | ### Evaluation 67 | 68 | The evaluation process consists of two steps: 69 | 70 | Reconstruct videos in batches: `bash scripts/causalvae/prepare_eval.sh` 71 | Evaluate video metrics: `bash scripts/causalvae/eval.sh` 72 | 73 | To simplify the evaluation, environment variables are used for control. For step 1 (`bash scripts/causalvae/prepare_eval.sh`): 74 | 75 | ```bash 76 | # Experiment name 77 | EXP_NAME=wfvae 78 | # Video parameters 79 | SAMPLE_RATE=1 80 | NUM_FRAMES=33 81 | RESOLUTION=256 82 | # Model weights 83 | CKPT=ckpt 84 | # Select subset size (0 for full set) 85 | SUBSET_SIZE=0 86 | # Dataset directory 87 | DATASET_DIR=test_video 88 | ``` 89 | 90 | For step 2 (`scripts/causalvae/eval.sh`): 91 | 92 | ```bash 93 | # Experiment name 94 | EXP_NAME=wfvae-4dim 95 | # Video parameters 96 | SAMPLE_RATE=1 97 | NUM_FRAMES=33 98 | RESOLUTION=256 99 | # Evaluation metric 100 | METRIC=lpips 101 | # Select subset size (0 for full set) 102 | SUBSET_SIZE=0 103 | # Path to the ground truth videos, which can be saved during video reconstruction by setting `--output_origin` 104 | ORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin 105 | # Path to the reconstructed videos 106 | RECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} 107 | ``` -------------------------------------------------------------------------------- /examples/cond_pix_path.txt: -------------------------------------------------------------------------------- 1 | examples/test_img1.png 2 | examples/test_img2.png 3 | examples/test_img3.png -------------------------------------------------------------------------------- /examples/cond_prompt.txt: -------------------------------------------------------------------------------- 1 | A rocket ascends slowly into the sky. 2 | Along the coast, variously sized boats float on the lake. 3 | The landscape at sunset is profound and expansive. -------------------------------------------------------------------------------- /examples/rec_image.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from PIL import Image 4 | import torch 5 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda 6 | from torch.nn import functional as F 7 | import argparse 8 | import numpy as np 9 | from opensora.models.causalvideovae import ae_wrapper 10 | 11 | def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor: 12 | transform = Compose( 13 | [ 14 | ToTensor(), 15 | Lambda(lambda x: 2. 
* x - 1.), 16 | Resize(size=short_size), 17 | ] 18 | ) 19 | outputs = transform(video_data) 20 | outputs = outputs.unsqueeze(0).unsqueeze(2) 21 | return outputs 22 | 23 | def main(args: argparse.Namespace): 24 | image_path = args.image_path 25 | short_size = args.short_size 26 | device = args.device 27 | kwarg = {} 28 | 29 | # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device) 30 | vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device) 31 | if args.enable_tiling: 32 | vae.vae.enable_tiling() 33 | vae.vae.tile_overlap_factor = args.tile_overlap_factor 34 | vae.eval() 35 | vae = vae.to(device) 36 | vae = vae.half() 37 | 38 | with torch.no_grad(): 39 | x_vae = preprocess(Image.open(image_path), short_size) 40 | x_vae = x_vae.to(device, dtype=torch.float16) # b c t h w 41 | latents = vae.encode(x_vae) 42 | latents = latents.to(torch.float16) 43 | image_recon = vae.decode(latents) # b t c h w 44 | x = image_recon[0, 0, :, :, :] 45 | x = x.squeeze() 46 | x = x.detach().cpu().numpy() 47 | x = np.clip(x, -1, 1) 48 | x = (x + 1) / 2 49 | x = (255*x).astype(np.uint8) 50 | x = x.transpose(1,2,0) 51 | image = Image.fromarray(x) 52 | image.save(args.rec_path) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--image_path', type=str, default='') 58 | parser.add_argument('--rec_path', type=str, default='') 59 | parser.add_argument('--ae', type=str, default='') 60 | parser.add_argument('--ae_path', type=str, default='') 61 | parser.add_argument('--model_path', type=str, default='results/pretrained') 62 | parser.add_argument('--short_size', type=int, default=336) 63 | parser.add_argument('--device', type=str, default='cuda') 64 | parser.add_argument('--tile_overlap_factor', type=float, default=0.25) 65 | parser.add_argument('--enable_tiling', action='store_true') 66 | 67 | args = parser.parse_args() 68 | main(args) -------------------------------------------------------------------------------- /opensora/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /opensora/acceleration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/acceleration/__init__.py -------------------------------------------------------------------------------- /opensora/acceleration/parallel_states.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_npu 3 | import torch.distributed as dist 4 | import os 5 | try: 6 | from lcalib.functional import lcal_initialize 7 | enable_LCCL = True 8 | except: 9 | lcal_initialize = None 10 | enable_LCCL = False 11 | class COMM_INFO: 12 | def __init__(self): 13 | self.group = None 14 | self.world_size = 0 15 | self.rank = -1 16 | 17 | lccl_info = COMM_INFO() 18 | hccl_info = COMM_INFO() 19 | _SEQUENCE_PARALLEL_STATE = False 20 | def initialize_sequence_parallel_state(sequence_parallel_size): 21 | global _SEQUENCE_PARALLEL_STATE 22 | if sequence_parallel_size > 1: 23 | _SEQUENCE_PARALLEL_STATE = True 24 | initialize_sequence_parallel_group(sequence_parallel_size) 25 | 26 | def set_sequence_parallel_state(state): 27 | global _SEQUENCE_PARALLEL_STATE 28 | _SEQUENCE_PARALLEL_STATE = state 29 | 30 | def get_sequence_parallel_state(): 31 
| return _SEQUENCE_PARALLEL_STATE 32 | 33 | def initialize_sequence_parallel_group(sequence_parallel_size): 34 | """Initialize the sequence parallel group.""" 35 | rank = int(os.getenv('RANK', '0')) 36 | world_size = int(os.getenv("WORLD_SIZE", '1')) 37 | assert world_size % sequence_parallel_size == 0, "world_size must be divisible by sequence_parallel_size" 38 | # hccl 39 | hccl_info.world_size = sequence_parallel_size 40 | hccl_info.rank = rank 41 | num_sequence_parallel_groups: int = world_size // sequence_parallel_size 42 | for i in range(num_sequence_parallel_groups): 43 | ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) 44 | group = dist.new_group(ranks) 45 | if rank in ranks: 46 | hccl_info.group = group 47 | 48 | if enable_LCCL: 49 | assert sequence_parallel_size == 8, "sequence_parallel_size should be 8 when enable_LCCL is True" 50 | rank %= sequence_parallel_size 51 | lccl_info.world_size = sequence_parallel_size 52 | lccl_info.group = lcal_initialize(rank, sequence_parallel_size) 53 | lccl_info.rank = rank 54 | 55 | def destroy_sequence_parallel_group(): 56 | """Destroy the sequence parallel group.""" 57 | dist.destroy_process_group() 58 | -------------------------------------------------------------------------------- /opensora/adaptor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/adaptor/__init__.py -------------------------------------------------------------------------------- /opensora/adaptor/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def fp32_layer_norm_forward(self, inputs: torch.Tensor) -> torch.Tensor: 7 | origin_dtype = inputs.dtype 8 | return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, 9 | self.bias.float() if self.bias is not None else None, self.eps).to(origin_dtype) 10 | 11 | 12 | def fp32_silu_forward(self, inputs: torch.Tensor) -> torch.Tensor: 13 | return torch.nn.functional.silu(inputs.float(), inplace=self.inplace).to(inputs.dtype) 14 | 15 | 16 | def fp32_gelu_forward(self, inputs: torch.Tensor) -> torch.Tensor: 17 | return torch.nn.functional.gelu(inputs.float(), approximate=self.approximate).to(inputs.dtype) 18 | 19 | 20 | def replace_with_fp32_forwards(): 21 | nn.GELU.forward = fp32_gelu_forward 22 | nn.SiLU.forward = fp32_silu_forward 23 | nn.LayerNorm.forward = fp32_layer_norm_forward 24 | -------------------------------------------------------------------------------- /opensora/adaptor/zp_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.distributed as dist 4 | 5 | 6 | class ZPManager(object): 7 | def __init__(self, zp_size=8): 8 | self.rank = int(os.getenv('RANK', '0')) 9 | self.world_size = int(os.getenv("WORLD_SIZE", '1')) 10 | self.zp_size = zp_size 11 | self.zp_group = None 12 | self.zp_rank = None 13 | self.is_initialized = False 14 | 15 | def init_group(self): 16 | if self.is_initialized: 17 | return 18 | 19 | self.is_initialized = True 20 | 21 | """Initialize the sequence parallel group.""" 22 | num_zp_groups: int = self.world_size // self.zp_size 23 | for i in range(num_zp_groups): 24 | ranks = range(i * self.zp_size, (i + 1) * self.zp_size) 25 | group = 
dist.new_group(ranks) 26 | if self.rank in ranks: 27 | self.zp_group = group 28 | self.zp_rank = self.rank % self.zp_size 29 | 30 | 31 | zp_manager = ZPManager() 32 | -------------------------------------------------------------------------------- /opensora/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from transformers import AutoTokenizer, AutoImageProcessor 3 | 4 | from torchvision import transforms 5 | from torchvision.transforms import Lambda 6 | 7 | try: 8 | import torch_npu 9 | except: 10 | torch_npu = None 11 | 12 | from opensora.dataset.t2v_datasets import T2V_dataset 13 | from opensora.dataset.inpaint_dataset import Inpaint_dataset 14 | from opensora.models.causalvideovae import ae_norm, ae_denorm 15 | from opensora.dataset.transform import ToTensorVideo, TemporalRandomCrop, MaxHWResizeVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo, NormalizeVideo, ToTensorAfterResize 16 | 17 | 18 | 19 | def getdataset(args): 20 | temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x 21 | norm_fun = ae_norm[args.ae] 22 | if args.force_resolution: 23 | resize = [CenterCropResizeVideo((args.max_height, args.max_width)), ] 24 | else: 25 | resize = [ 26 | MaxHWResizeVideo(args.max_hxw), 27 | SpatialStrideCropVideo(stride=args.hw_stride), 28 | ] 29 | 30 | tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) 31 | tokenizer_2 = None 32 | if args.text_encoder_name_2 is not None: 33 | tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir) 34 | if args.dataset == 't2v': 35 | transform = transforms.Compose([ 36 | ToTensorVideo(), 37 | *resize, 38 | norm_fun 39 | ]) # also work for img, because img is video when frame=1 40 | return T2V_dataset( 41 | args, transform=transform, temporal_sample=temporal_sample, 42 | tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2 43 | ) 44 | elif args.dataset == 'i2v' or args.dataset == 'inpaint': 45 | resize_transform = Compose(resize) 46 | transform = Compose([ 47 | ToTensorAfterResize(), 48 | norm_fun, 49 | ]) 50 | return Inpaint_dataset( 51 | args, resize_transform=resize_transform, transform=transform, 52 | temporal_sample=temporal_sample, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2 53 | ) 54 | raise NotImplementedError(args.dataset) 55 | 56 | 57 | if __name__ == "__main__": 58 | ''' 59 | python opensora/dataset/__init__.py 60 | ''' 61 | from accelerate import Accelerator 62 | from opensora.dataset.t2v_datasets import dataset_prog 63 | from opensora.utils.dataset_utils import LengthGroupedSampler, Collate 64 | from torch.utils.data import DataLoader 65 | import random 66 | from torch import distributed as dist 67 | from tqdm import tqdm 68 | args = type('args', (), 69 | { 70 | 'ae': 'WFVAEModel_D32_4x8x8', 71 | 'dataset': 't2v', 72 | 'model_max_length': 512, 73 | 'max_height': 640, 74 | 'max_width': 640, 75 | 'hw_stride': 16, 76 | 'num_frames': 93, 77 | 'compress_kv_factor': 1, 78 | 'interpolation_scale_t': 1, 79 | 'interpolation_scale_h': 1, 80 | 'interpolation_scale_w': 1, 81 | 'cache_dir': '../cache_dir', 82 | 'data': '/home/image_data/gyy/mmdit/Open-Sora-Plan/scripts/train_data/current_hq_on_npu.txt', 83 | 'train_fps': 18, 84 | 'drop_short_ratio': 0.0, 85 | 'speed_factor': 1.0, 86 | 'cfg': 0.1, 87 | 'text_encoder_name_1': 'google/mt5-xxl', 88 | 'text_encoder_name_2': None, 89 | 'dataloader_num_workers': 8, 90 | 'force_resolution': False, 91 | 
'use_decord': True, 92 | 'group_data': True, 93 | 'train_batch_size': 1, 94 | 'gradient_accumulation_steps': 1, 95 | 'ae_stride': 8, 96 | 'ae_stride_t': 4, 97 | 'patch_size': 2, 98 | 'patch_size_t': 1, 99 | 'total_batch_size': 256, 100 | 'sp_size': 1, 101 | 'max_hxw': 384*384, 102 | 'min_hxw': 384*288, 103 | # 'max_hxw': 236544, 104 | # 'min_hxw': 102400, 105 | } 106 | ) 107 | # accelerator = Accelerator() 108 | dataset = getdataset(args) 109 | # data = next(iter(dataset)) 110 | # import ipdb;ipdb.set_trace() 111 | # print() 112 | sampler = LengthGroupedSampler( 113 | args.train_batch_size, 114 | world_size=1, 115 | gradient_accumulation_size=args.gradient_accumulation_steps, 116 | initial_global_step=0, 117 | lengths=dataset.lengths, 118 | group_data=args.group_data, 119 | ) 120 | train_dataloader = DataLoader( 121 | dataset, 122 | shuffle=False, 123 | # pin_memory=True, 124 | collate_fn=Collate(args), 125 | batch_size=args.train_batch_size, 126 | num_workers=args.dataloader_num_workers, 127 | sampler=sampler, 128 | drop_last=False, 129 | prefetch_factor=4 130 | ) 131 | import ipdb;ipdb.set_trace() 132 | import imageio 133 | import numpy as np 134 | from einops import rearrange 135 | while True: 136 | for idx, i in enumerate(tqdm(train_dataloader)): 137 | pixel_values = i[0][0] 138 | pixel_values_ = (pixel_values+1)/2 139 | pixel_values_ = rearrange(pixel_values_, 'c t h w -> t h w c') * 255.0 140 | pixel_values_ = pixel_values_.numpy().astype(np.uint8) 141 | imageio.mimwrite(f'output{idx}.mp4', pixel_values_, fps=args.train_fps) 142 | dist.barrier() 143 | pass -------------------------------------------------------------------------------- /opensora/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .causalvideovae import CausalVAEModelWrapper, WFVAEModelWrapper -------------------------------------------------------------------------------- /opensora/models/causalvideovae/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Lambda 2 | from .model.vae import CausalVAEModel, WFVAEModel 3 | from einops import rearrange 4 | import torch 5 | try: 6 | import torch_npu 7 | from opensora.npu_config import npu_config 8 | except: 9 | torch_npu = None 10 | npu_config = None 11 | pass 12 | import torch.nn as nn 13 | import torch 14 | 15 | class CausalVAEModelWrapper(nn.Module): 16 | def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): 17 | super(CausalVAEModelWrapper, self).__init__() 18 | self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) 19 | 20 | def encode(self, x): 21 | x = self.vae.encode(x).sample().mul_(0.18215) 22 | return x 23 | def decode(self, x): 24 | x = self.vae.decode(x / 0.18215) 25 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous() 26 | return x 27 | 28 | def dtype(self): 29 | return self.vae.dtype 30 | 31 | class WFVAEModelWrapper(nn.Module): 32 | def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): 33 | super(WFVAEModelWrapper, self).__init__() 34 | self.vae = WFVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) 35 | self.register_buffer('shift', torch.tensor(self.vae.config.shift)[None, :, None, None, None]) 36 | self.register_buffer('scale', torch.tensor(self.vae.config.scale)[None, :, None, None, None]) 37 | 38 | def encode(self, x): 39 | x = (self.vae.encode(x).sample() - 
self.shift.to(x.device, dtype=x.dtype)) * self.scale.to(x.device, dtype=x.dtype) 40 | return x 41 | 42 | def decode(self, x): 43 | x = x / self.scale.to(x.device, dtype=x.dtype) + self.shift.to(x.device, dtype=x.dtype) 44 | x = self.vae.decode(x) 45 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous() 46 | return x 47 | 48 | def dtype(self): 49 | return self.vae.dtype 50 | 51 | ae_wrapper = { 52 | 'CausalVAEModel_D4_2x8x8': CausalVAEModelWrapper, 53 | 'CausalVAEModel_D8_2x8x8': CausalVAEModelWrapper, 54 | 'CausalVAEModel_D4_4x8x8': CausalVAEModelWrapper, 55 | 'CausalVAEModel_D8_4x8x8': CausalVAEModelWrapper, 56 | 'WFVAEModel_D8_4x8x8': WFVAEModelWrapper, 57 | 'WFVAEModel_D16_4x8x8': WFVAEModelWrapper, 58 | 'WFVAEModel_D32_4x8x8': WFVAEModelWrapper, 59 | 'WFVAEModel_D32_8x8x8': WFVAEModelWrapper, 60 | } 61 | 62 | ae_stride_config = { 63 | 'CausalVAEModel_D4_2x8x8': [2, 8, 8], 64 | 'CausalVAEModel_D8_2x8x8': [2, 8, 8], 65 | 'CausalVAEModel_D4_4x8x8': [4, 8, 8], 66 | 'CausalVAEModel_D8_4x8x8': [4, 8, 8], 67 | 'WFVAEModel_D8_4x8x8': [4, 8, 8], 68 | 'WFVAEModel_D16_4x8x8': [4, 8, 8], 69 | 'WFVAEModel_D32_4x8x8': [4, 8, 8], 70 | 'WFVAEModel_D32_8x8x8': [8, 8, 8], 71 | } 72 | 73 | ae_channel_config = { 74 | 'CausalVAEModel_D4_2x8x8': 4, 75 | 'CausalVAEModel_D8_2x8x8': 8, 76 | 'CausalVAEModel_D4_4x8x8': 4, 77 | 'CausalVAEModel_D8_4x8x8': 8, 78 | 'WFVAEModel_D8_4x8x8': 8, 79 | 'WFVAEModel_D16_4x8x8': 16, 80 | 'WFVAEModel_D32_4x8x8': 32, 81 | 'WFVAEModel_D32_8x8x8': 32, 82 | } 83 | 84 | ae_denorm = { 85 | 'CausalVAEModel_D4_2x8x8': lambda x: (x + 1.) / 2., 86 | 'CausalVAEModel_D8_2x8x8': lambda x: (x + 1.) / 2., 87 | 'CausalVAEModel_D4_4x8x8': lambda x: (x + 1.) / 2., 88 | 'CausalVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2., 89 | 'WFVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2., 90 | 'WFVAEModel_D16_4x8x8': lambda x: (x + 1.) / 2., 91 | 'WFVAEModel_D32_4x8x8': lambda x: (x + 1.) / 2., 92 | 'WFVAEModel_D32_8x8x8': lambda x: (x + 1.) / 2., 93 | } 94 | 95 | ae_norm = { 96 | 'CausalVAEModel_D4_2x8x8': Lambda(lambda x: 2. * x - 1.), 97 | 'CausalVAEModel_D8_2x8x8': Lambda(lambda x: 2. * x - 1.), 98 | 'CausalVAEModel_D4_4x8x8': Lambda(lambda x: 2. * x - 1.), 99 | 'CausalVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.), 100 | 'WFVAEModel_D8_4x8x8': Lambda(lambda x: 2. * x - 1.), 101 | 'WFVAEModel_D16_4x8x8': Lambda(lambda x: 2. * x - 1.), 102 | 'WFVAEModel_D32_4x8x8': Lambda(lambda x: 2. * x - 1.), 103 | 'WFVAEModel_D32_8x8x8': Lambda(lambda x: 2. 
* x - 1.), 104 | } -------------------------------------------------------------------------------- /opensora/models/causalvideovae/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/causalvideovae/dataset/__init__.py -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/cal_fvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | def trans(x): 6 | # if greyscale images add channel 7 | if x.shape[-3] == 1: 8 | x = x.repeat(1, 1, 3, 1, 1) 9 | 10 | # permute BTCHW -> BCTHW 11 | x = x.permute(0, 2, 1, 3, 4) 12 | 13 | return x 14 | 15 | def calculate_fvd(videos1, videos2, device, method='styleganv'): 16 | 17 | if method == 'styleganv': 18 | from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained 19 | elif method == 'videogpt': 20 | from fvd.videogpt.fvd import load_i3d_pretrained 21 | from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats 22 | from fvd.videogpt.fvd import frechet_distance 23 | 24 | print("calculate_fvd...") 25 | 26 | # videos [batch_size, timestamps, channel, h, w] 27 | 28 | assert videos1.shape == videos2.shape 29 | 30 | i3d = load_i3d_pretrained(device=device) 31 | fvd_results = [] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # BTCHW -> BCTHW 35 | # videos -> [batch_size, channel, timestamps, h, w] 36 | 37 | videos1 = trans(videos1) 38 | videos2 = trans(videos2) 39 | 40 | fvd_results = {} 41 | 42 | # for calculate FVD, each clip_timestamp must >= 10 43 | for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): 44 | 45 | # get a video clip 46 | # videos_clip [batch_size, channel, timestamps[:clip], h, w] 47 | videos_clip1 = videos1[:, :, : clip_timestamp] 48 | videos_clip2 = videos2[:, :, : clip_timestamp] 49 | 50 | # get FVD features 51 | feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) 52 | feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) 53 | 54 | # calculate FVD when timestamps[:clip] 55 | fvd_results[clip_timestamp] = frechet_distance(feats1, feats2) 56 | 57 | result = { 58 | "value": fvd_results, 59 | "video_setting": videos1.shape, 60 | "video_setting_name": "batch_size, channel, time, heigth, width", 61 | } 62 | 63 | return result 64 | 65 | # test code / using example 66 | 67 | def main(): 68 | NUMBER_OF_VIDEOS = 8 69 | VIDEO_LENGTH = 50 70 | CHANNEL = 3 71 | SIZE = 64 72 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 73 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 74 | device = torch.device("cuda") 75 | # device = torch.device("cpu") 76 | 77 | import json 78 | result = calculate_fvd(videos1, videos2, device, method='videogpt') 79 | print(json.dumps(result, indent=4)) 80 | 81 | result = calculate_fvd(videos1, videos2, device, method='styleganv') 82 | print(json.dumps(result, indent=4)) 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/cal_lpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | import torch 7 
| import lpips 8 | 9 | spatial = True # Return a spatial map of perceptual distance. 10 | 11 | # Linearly calibrated models (LPIPS) 12 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg' 13 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' 14 | 15 | def trans(x): 16 | # if greyscale images add channel 17 | if x.shape[-3] == 1: 18 | x = x.repeat(1, 1, 3, 1, 1) 19 | 20 | # value range [0, 1] -> [-1, 1] 21 | x = x * 2 - 1 22 | 23 | return x 24 | 25 | def calculate_lpips(videos1, videos2, device): 26 | # image should be RGB, IMPORTANT: normalized to [-1,1] 27 | print("calculate_lpips...") 28 | 29 | assert videos1.shape == videos2.shape 30 | 31 | # videos [batch_size, timestamps, channel, h, w] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # value range [0, 1] -> [-1, 1] 35 | videos1 = trans(videos1) 36 | videos2 = trans(videos2) 37 | 38 | lpips_results = [] 39 | 40 | for video_num in tqdm(range(videos1.shape[0])): 41 | # get a video 42 | # video [timestamps, channel, h, w] 43 | video1 = videos1[video_num] 44 | video2 = videos2[video_num] 45 | 46 | lpips_results_of_a_video = [] 47 | for clip_timestamp in range(len(video1)): 48 | # get a img 49 | # img [timestamps[x], channel, h, w] 50 | # img [channel, h, w] tensor 51 | 52 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 53 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 54 | 55 | loss_fn.to(device) 56 | 57 | # calculate lpips of a video 58 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 59 | lpips_results.append(lpips_results_of_a_video) 60 | 61 | lpips_results = np.array(lpips_results) 62 | 63 | lpips = {} 64 | lpips_std = {} 65 | 66 | for clip_timestamp in range(len(video1)): 67 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp]) 68 | lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp]) 69 | 70 | 71 | result = { 72 | "value": lpips, 73 | "value_std": lpips_std, 74 | "video_setting": video1.shape, 75 | "video_setting_name": "time, channel, heigth, width", 76 | } 77 | 78 | return result 79 | 80 | # test code / using example 81 | 82 | def main(): 83 | NUMBER_OF_VIDEOS = 8 84 | VIDEO_LENGTH = 50 85 | CHANNEL = 3 86 | SIZE = 64 87 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 88 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 89 | device = torch.device("cuda") 90 | # device = torch.device("cpu") 91 | 92 | import json 93 | result = calculate_lpips(videos1, videos2, device) 94 | print(json.dumps(result, indent=4)) 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/cal_psnr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | def img_psnr_cuda(img1, img2): 7 | # [0,1] 8 | # compute mse 9 | # mse = np.mean((img1-img2)**2) 10 | mse = torch.mean((img1 / 1.0 - img2 / 1.0) ** 2) 11 | # compute psnr 12 | if mse < 1e-10: 13 | return 100 14 | psnr = 20 * torch.log10(1 / torch.sqrt(mse)) 15 | return psnr 16 | 17 | 18 | def img_psnr(img1, img2): 19 | # [0,1] 20 | # compute mse 21 | # mse = np.mean((img1-img2)**2) 22 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 23 | # compute psnr 24 | if mse < 1e-10: 25 | return 100 
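    # With images scaled to [0, 1], the peak signal value is 1, so the formula below reduces to
    # PSNR = 20 * log10(1 / sqrt(MSE)) = -10 * log10(MSE) in dB; the early return above guards
    # against MSE ~ 0 and caps the score at 100 dB.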
26 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 27 | return psnr 28 | 29 | 30 | def trans(x): 31 | return x 32 | 33 | def calculate_psnr(videos1, videos2): 34 | print("calculate_psnr...") 35 | 36 | # videos [batch_size, timestamps, channel, h, w] 37 | 38 | assert videos1.shape == videos2.shape 39 | 40 | videos1 = trans(videos1) 41 | videos2 = trans(videos2) 42 | 43 | psnr_results = [] 44 | 45 | for video_num in tqdm(range(videos1.shape[0])): 46 | # get a video 47 | # video [timestamps, channel, h, w] 48 | video1 = videos1[video_num] 49 | video2 = videos2[video_num] 50 | 51 | psnr_results_of_a_video = [] 52 | for clip_timestamp in range(len(video1)): 53 | # get a img 54 | # img [timestamps[x], channel, h, w] 55 | # img [channel, h, w] numpy 56 | 57 | img1 = video1[clip_timestamp].numpy() 58 | img2 = video2[clip_timestamp].numpy() 59 | 60 | # calculate psnr of a video 61 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 62 | 63 | psnr_results.append(psnr_results_of_a_video) 64 | 65 | psnr_results = np.array(psnr_results) # [batch_size, num_frames] 66 | psnr = {} 67 | psnr_std = {} 68 | 69 | for clip_timestamp in range(len(video1)): 70 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp]) 71 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp]) 72 | 73 | result = { 74 | "value": psnr, 75 | "value_std": psnr_std, 76 | "video_setting": video1.shape, 77 | "video_setting_name": "time, channel, heigth, width", 78 | } 79 | 80 | return result 81 | 82 | # test code / using example 83 | 84 | def main(): 85 | NUMBER_OF_VIDEOS = 8 86 | VIDEO_LENGTH = 50 87 | CHANNEL = 3 88 | SIZE = 64 89 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 90 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 91 | 92 | import json 93 | result = calculate_psnr(videos1, videos2) 94 | print(json.dumps(result, indent=4)) 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import cv2 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01 ** 2 8 | C2 = 0.03 ** 2 9 | img1 = img1.astype(np.float64) 10 | img2 = img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1 ** 2 16 | mu2_sq = mu2 ** 2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 22 | (sigma1_sq + sigma2_sq + C2)) 23 | return ssim_map.mean() 24 | 25 | 26 | def calculate_ssim_function(img1, img2): 27 | # [0,1] 28 | # ssim is the only metric extremely sensitive to gray being compared to b/w 29 | if not img1.shape == img2.shape: 30 | raise ValueError('Input images must have the same dimensions.') 31 | if img1.ndim == 2: 32 | return ssim(img1, img2) 33 | elif img1.ndim == 3: 34 | if img1.shape[0] == 3: 35 | ssims = [] 36 | for i in range(3): 37 | ssims.append(ssim(img1[i], img2[i])) 38 | return 
np.array(ssims).mean() 39 | elif img1.shape[0] == 1: 40 | return ssim(np.squeeze(img1), np.squeeze(img2)) 41 | else: 42 | raise ValueError('Wrong input image dimensions.') 43 | 44 | def trans(x): 45 | return x 46 | 47 | def calculate_ssim(videos1, videos2): 48 | print("calculate_ssim...") 49 | 50 | # videos [batch_size, timestamps, channel, h, w] 51 | 52 | assert videos1.shape == videos2.shape 53 | 54 | videos1 = trans(videos1) 55 | videos2 = trans(videos2) 56 | 57 | ssim_results = [] 58 | 59 | for video_num in tqdm(range(videos1.shape[0])): 60 | # get a video 61 | # video [timestamps, channel, h, w] 62 | video1 = videos1[video_num] 63 | video2 = videos2[video_num] 64 | 65 | ssim_results_of_a_video = [] 66 | for clip_timestamp in range(len(video1)): 67 | # get a img 68 | # img [timestamps[x], channel, h, w] 69 | # img [channel, h, w] numpy 70 | 71 | img1 = video1[clip_timestamp].numpy() 72 | img2 = video2[clip_timestamp].numpy() 73 | 74 | # calculate ssim of a video 75 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 76 | 77 | ssim_results.append(ssim_results_of_a_video) 78 | 79 | ssim_results = np.array(ssim_results) 80 | 81 | ssim = {} 82 | ssim_std = {} 83 | 84 | for clip_timestamp in range(len(video1)): 85 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp]) 86 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp]) 87 | 88 | result = { 89 | "value": ssim, 90 | "value_std": ssim_std, 91 | "video_setting": video1.shape, 92 | "video_setting_name": "time, channel, heigth, width", 93 | } 94 | 95 | return result 96 | 97 | # test code / using example 98 | 99 | def main(): 100 | NUMBER_OF_VIDEOS = 8 101 | VIDEO_LENGTH = 50 102 | CHANNEL = 3 103 | SIZE = 64 104 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 105 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 106 | device = torch.device("cuda") 107 | 108 | import json 109 | result = calculate_ssim(videos1, videos2) 110 | print(json.dumps(result, indent=4)) 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import torch.nn.functional as F 5 | 6 | # https://github.com/universome/fvd-comparison 7 | 8 | 9 | def load_i3d_pretrained(device=torch.device('cpu')): 10 | i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" 11 | filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') 12 | print(filepath) 13 | if not os.path.exists(filepath): 14 | print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") 15 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") 16 | i3d = torch.jit.load(filepath).eval().to(device) 17 | i3d = torch.nn.DataParallel(i3d) 18 | return i3d 19 | 20 | 21 | def get_feats(videos, detector, device, bs=10): 22 | # videos : torch.tensor BCTHW [0, 1] 23 | detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer. 
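    # The TorchScript I3D detector returns a 400-dimensional feature vector per clip, so the
    # features are accumulated batch by batch (bs clips at a time) into an (N, 400) array below,
    # before the Frechet distance is computed between the two feature sets.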
24 | feats = np.empty((0, 400)) 25 | with torch.no_grad(): 26 | for i in range((len(videos)-1)//bs + 1): 27 | feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()]) 28 | return feats 29 | 30 | 31 | def get_fvd_feats(videos, i3d, device, bs=10): 32 | # videos in [0, 1] as torch tensor BCTHW 33 | # videos = [preprocess_single(video) for video in videos] 34 | embeddings = get_feats(videos, i3d, device, bs) 35 | return embeddings 36 | 37 | 38 | def preprocess_single(video, resolution=224, sequence_length=None): 39 | # video: CTHW, [0, 1] 40 | c, t, h, w = video.shape 41 | 42 | # temporal crop 43 | if sequence_length is not None: 44 | assert sequence_length <= t 45 | video = video[:, :sequence_length] 46 | 47 | # scale shorter side to resolution 48 | scale = resolution / min(h, w) 49 | if h < w: 50 | target_size = (resolution, math.ceil(w * scale)) 51 | else: 52 | target_size = (math.ceil(h * scale), resolution) 53 | video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False) 54 | 55 | # center crop 56 | c, t, h, w = video.shape 57 | w_start = (w - resolution) // 2 58 | h_start = (h - resolution) // 2 59 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 60 | 61 | # [0, 1] -> [-1, 1] 62 | video = (video - 0.5) * 2 63 | 64 | return video.contiguous() 65 | 66 | 67 | """ 68 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py 69 | """ 70 | from typing import Tuple 71 | from scipy.linalg import sqrtm 72 | import numpy as np 73 | 74 | 75 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 76 | mu = feats.mean(axis=0) # [d] 77 | sigma = np.cov(feats, rowvar=False) # [d, d] 78 | return mu, sigma 79 | 80 | 81 | def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 82 | mu_gen, sigma_gen = compute_stats(feats_fake) 83 | mu_real, sigma_real = compute_stats(feats_real) 84 | m = np.square(mu_gen - mu_real).sum() 85 | if feats_fake.shape[0]>1: 86 | s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 87 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 88 | else: 89 | fid = np.real(m) 90 | return float(fid) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import einops 7 | 8 | def load_i3d_pretrained(device=torch.device('cpu')): 9 | i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI" 10 | filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt') 11 | print(filepath) 12 | if not os.path.exists(filepath): 13 | print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") 14 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") 15 | from .pytorch_i3d import InceptionI3d 16 | i3d = InceptionI3d(400, in_channels=3).eval().to(device) 17 | i3d.load_state_dict(torch.load(filepath, map_location=device)) 18 | i3d = torch.nn.DataParallel(i3d) 19 | return i3d 20 | 21 | def preprocess_single(video, resolution, sequence_length=None): 22 | # video: THWC, {0, ..., 255} 23 | video = video.permute(0, 3, 
1, 2).float() / 255. # TCHW 24 | t, c, h, w = video.shape 25 | 26 | # temporal crop 27 | if sequence_length is not None: 28 | assert sequence_length <= t 29 | video = video[:sequence_length] 30 | 31 | # scale shorter side to resolution 32 | scale = resolution / min(h, w) 33 | if h < w: 34 | target_size = (resolution, math.ceil(w * scale)) 35 | else: 36 | target_size = (math.ceil(h * scale), resolution) 37 | video = F.interpolate(video, size=target_size, mode='bilinear', 38 | align_corners=False) 39 | 40 | # center crop 41 | t, c, h, w = video.shape 42 | w_start = (w - resolution) // 2 43 | h_start = (h - resolution) // 2 44 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 45 | video = video.permute(1, 0, 2, 3).contiguous() # CTHW 46 | 47 | video -= 0.5 48 | 49 | return video 50 | 51 | def preprocess(videos, target_resolution=224): 52 | # we should tras videos in [0-1] [b c t h w] as th.float 53 | # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array 54 | videos = einops.rearrange(videos, 'b c t h w -> b t h w c') 55 | videos = (videos*255).numpy().astype(np.uint8) 56 | 57 | b, t, h, w, c = videos.shape 58 | videos = torch.from_numpy(videos) 59 | videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) 60 | return videos * 2 # [-0.5, 0.5] -> [-1, 1] 61 | 62 | def get_fvd_logits(videos, i3d, device, bs=10): 63 | videos = preprocess(videos) 64 | embeddings = get_logits(i3d, videos, device, bs=10) 65 | return embeddings 66 | 67 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 68 | def _symmetric_matrix_square_root(mat, eps=1e-10): 69 | u, s, v = torch.svd(mat) 70 | si = torch.where(s < eps, s, torch.sqrt(s)) 71 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 72 | 73 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 74 | def trace_sqrt_product(sigma, sigma_v): 75 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 76 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 77 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 78 | 79 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 80 | def cov(m, rowvar=False): 81 | '''Estimate a covariance matrix given data. 82 | 83 | Covariance indicates the level to which two variables vary together. 84 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 85 | then the covariance matrix element `C_{ij}` is the covariance of 86 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 87 | 88 | Args: 89 | m: A 1-D or 2-D array containing multiple variables and observations. 90 | Each row of `m` represents a variable, and each column a single 91 | observation of all those variables. 92 | rowvar: If `rowvar` is True, then each row represents a 93 | variable, with observations in the columns. Otherwise, the 94 | relationship is transposed: each column represents a variable, 95 | while the rows contain observations. 96 | 97 | Returns: 98 | The covariance matrix of the variables. 
99 | ''' 100 | if m.dim() > 2: 101 | raise ValueError('m has more than 2 dimensions') 102 | if m.dim() < 2: 103 | m = m.view(1, -1) 104 | if not rowvar and m.size(0) != 1: 105 | m = m.t() 106 | 107 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 108 | m -= torch.mean(m, dim=1, keepdim=True) 109 | mt = m.t() # if complex: mt = m.t().conj() 110 | return fact * m.matmul(mt).squeeze() 111 | 112 | 113 | def frechet_distance(x1, x2): 114 | x1 = x1.flatten(start_dim=1) 115 | x2 = x2.flatten(start_dim=1) 116 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 117 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 118 | mean = torch.sum((m - m_w) ** 2) 119 | if x1.shape[0]>1: 120 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 121 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 122 | fd = trace + mean 123 | else: 124 | fd = np.real(mean) 125 | return float(fd) 126 | 127 | 128 | def get_logits(i3d, videos, device, bs=10): 129 | # assert videos.shape[0] % 16 == 0 130 | with torch.no_grad(): 131 | logits = [] 132 | for i in range(0, videos.shape[0], bs): 133 | batch = videos[i:i + bs].to(device) 134 | # logits.append(i3d.module.extract_features(batch)) # wrong 135 | logits.append(i3d(batch)) # right 136 | logits = torch.cat(logits, dim=0) 137 | return logits 138 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/script/cal_clip_score.sh: -------------------------------------------------------------------------------- 1 | # clip_score cross modality 2 | python eval_clip_score.py \ 3 | --real_path path/to/image \ 4 | --generated_path path/to/text \ 5 | --batch-size 50 \ 6 | --device "cuda" 7 | 8 | # clip_score within the same modality 9 | python eval_clip_score.py \ 10 | --real_path path/to/textA \ 11 | --generated_path path/to/textB \ 12 | --real_flag txt \ 13 | --generated_flag txt \ 14 | --batch-size 50 \ 15 | --device "cuda" 16 | 17 | python eval_clip_score.py \ 18 | --real_path path/to/imageA \ 19 | --generated_path path/to/imageB \ 20 | --real_flag img \ 21 | --generated_flag img \ 22 | --batch-size 50 \ 23 | --device "cuda" 24 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/script/cal_fvd.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir path/to/imageA\ 3 | --generated_video_dir path/to/imageB \ 4 | --batch_size 10 \ 5 | --crop_size 64 \ 6 | --num_frames 20 \ 7 | --device 'cuda' \ 8 | --metric 'fvd' \ 9 | --fvd_method 'styleganv' 10 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/script/cal_lpips.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir path/to/imageA\ 3 | --generated_video_dir path/to/imageB \ 4 | --batch_size 10 \ 5 | --num_frames 20 \ 6 | --crop_size 64 \ 7 | --device 'cuda' \ 8 | --metric 'lpips' -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/script/cal_psnr.sh: -------------------------------------------------------------------------------- 1 | 2 | python eval_common_metric.py \ 3 | --real_video_dir /data/xiaogeng_liu/data/video1 \ 4 | --generated_video_dir /data/xiaogeng_liu/data/video2 \ 5 | --batch_size 10 \ 6 | --num_frames 20 \ 7 | --crop_size 64 \ 8 | --device 'cuda' \ 9 | --metric 
'psnr' -------------------------------------------------------------------------------- /opensora/models/causalvideovae/eval/script/cal_ssim.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir /data/xiaogeng_liu/data/video1 \ 3 | --generated_video_dir /data/xiaogeng_liu/data/video2 \ 4 | --batch_size 10 \ 5 | --num_frames 20 \ 6 | --crop_size 64 \ 7 | --device 'cuda' \ 8 | --metric 'ssim' -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import ModelRegistry 2 | from .vae import ( 3 | CausalVAEModel, WFVAEModel 4 | ) 5 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/configuration_videobase.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | from typing import TypeVar, Dict, Any 4 | from diffusers import ConfigMixin 5 | 6 | T = TypeVar('T', bound='VideoBaseConfiguration') 7 | class VideoBaseConfiguration(ConfigMixin): 8 | config_name = "VideoBaseConfiguration" 9 | _nested_config_fields: Dict[str, Any] = {} 10 | 11 | def __init__(self, **kwargs): 12 | pass 13 | 14 | def to_dict(self) -> Dict[str, Any]: 15 | d = {} 16 | for key, value in vars(self).items(): 17 | if isinstance(value, VideoBaseConfiguration): 18 | d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances 19 | elif isinstance(value, tuple): 20 | d[key] = list(value) 21 | else: 22 | d[key] = value 23 | return d 24 | 25 | def to_yaml_file(self, yaml_path: str): 26 | with open(yaml_path, 'w') as yaml_file: 27 | yaml.dump(self.to_dict(), yaml_file, default_flow_style=False) 28 | 29 | @classmethod 30 | def load_from_yaml(cls: T, yaml_path: str) -> T: 31 | with open(yaml_path, 'r') as yaml_file: 32 | config_dict = yaml.safe_load(yaml_file) 33 | for field, field_type in cls._nested_config_fields.items(): 34 | if field in config_dict: 35 | config_dict[field] = field_type.load_from_dict(config_dict[field]) 36 | return cls(**config_dict) 37 | 38 | @classmethod 39 | def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T: 40 | # Process nested configuration objects 41 | for field, field_type in cls._nested_config_fields.items(): 42 | if field in config_dict: 43 | config_dict[field] = field_type.load_from_dict(config_dict[field]) 44 | return cls(**config_dict) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/dataset_videobase.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | from glob import glob 4 | 5 | from torchvision import transforms 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Lambda 11 | 12 | from ..dataset.transform import ToTensorVideo, CenterCropVideo 13 | from ..utils.dataset_utils import DecordInit 14 | 15 | def TemporalRandomCrop(total_frames, size): 16 | """ 17 | Performs a random temporal crop on a video sequence. 18 | 19 | This function randomly selects a continuous frame sequence of length `size` from a video sequence. 
20 | `total_frames` indicates the total number of frames in the video sequence, and `size` represents the length of the frame sequence to be cropped. 21 | 22 | Parameters: 23 | - total_frames (int): The total number of frames in the video sequence. 24 | - size (int): The length of the frame sequence to be cropped. 25 | 26 | Returns: 27 | - (int, int): A tuple containing two integers. The first integer is the starting frame index of the cropped sequence, 28 | and the second integer is the ending frame index (inclusive) of the cropped sequence. 29 | """ 30 | rand_end = max(0, total_frames - size - 1) 31 | begin_index = random.randint(0, rand_end) 32 | end_index = min(begin_index + size, total_frames) 33 | return begin_index, end_index 34 | 35 | def resize(x, resolution): 36 | height, width = x.shape[-2:] 37 | resolution = min(2 * resolution, height, width) 38 | aspect_ratio = width / height 39 | if width <= height: 40 | new_width = resolution 41 | new_height = int(resolution / aspect_ratio) 42 | else: 43 | new_height = resolution 44 | new_width = int(resolution * aspect_ratio) 45 | resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True) 46 | return resized_x 47 | 48 | class VideoDataset(data.Dataset): 49 | """ Generic dataset for videos files stored in folders 50 | Returns BCTHW videos in the range [-0.5, 0.5] """ 51 | video_exts = ['avi', 'mp4', 'webm'] 52 | def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True): 53 | 54 | self.train = train 55 | self.sequence_length = sequence_length 56 | self.sample_rate = sample_rate 57 | self.resolution = resolution 58 | self.v_decoder = DecordInit() 59 | self.video_folder = video_folder 60 | self.dynamic_sample = dynamic_sample 61 | 62 | self.transform = transforms.Compose([ 63 | ToTensorVideo(), 64 | # Lambda(lambda x: resize(x, self.resolution)), 65 | CenterCropVideo(self.resolution), 66 | Lambda(lambda x: 2.0 * x - 1.0) 67 | ]) 68 | print('Building datasets...') 69 | self.samples = self._make_dataset() 70 | 71 | def _make_dataset(self): 72 | samples = [] 73 | samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True) 74 | for ext in self.video_exts], []) 75 | return samples 76 | 77 | def __len__(self): 78 | return len(self.samples) 79 | 80 | def __getitem__(self, idx): 81 | video_path = self.samples[idx] 82 | try: 83 | video = self.decord_read(video_path) 84 | video = self.transform(video) # T C H W -> T C H W 85 | video = video.transpose(0, 1) # T C H W -> C T H W 86 | return dict(video=video, label="") 87 | except Exception as e: 88 | print(f'Error with {e}, {video_path}') 89 | return self.__getitem__(random.randint(0, self.__len__()-1)) 90 | 91 | def decord_read(self, path): 92 | decord_vr = self.v_decoder(path) 93 | total_frames = len(decord_vr) 94 | # Sampling video frames 95 | if self.dynamic_sample: 96 | sample_rate = random.randint(1, self.sample_rate) 97 | else: 98 | sample_rate = self.sample_rate 99 | size = self.sequence_length * sample_rate 100 | start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size) 101 | # assert end_frame_ind - start_frame_ind >= self.num_frames 102 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int) 103 | 104 | video_data = decord_vr.get_batch(frame_indice).asnumpy() 105 | video_data = torch.from_numpy(video_data) 106 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) 107 | 
return video_data -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/ema_model.py: -------------------------------------------------------------------------------- 1 | class EMA: 2 | def __init__(self, model, decay): 3 | self.model = model 4 | self.decay = decay 5 | self.shadow = {} 6 | self.backup = {} 7 | 8 | def register(self): 9 | for name, param in self.model.named_parameters(): 10 | if param.requires_grad: 11 | self.shadow[name] = param.data.clone() 12 | 13 | def update(self): 14 | for name, param in self.model.named_parameters(): 15 | if name in self.shadow: 16 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 17 | self.shadow[name] = new_average.clone() 18 | 19 | def apply_shadow(self): 20 | for name, param in self.model.named_parameters(): 21 | if name in self.shadow: 22 | self.backup[name] = param.data 23 | param.data = self.shadow[name] 24 | 25 | def restore(self): 26 | for name, param in self.model.named_parameters(): 27 | if name in self.shadow: 28 | param.data = self.backup[name] 29 | self.backup = {} 30 | 31 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual_loss import LPIPSWithDiscriminator3D 2 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/losses/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | from ..modules.conv import CausalConv3d 4 | from einops import rearrange 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1: 9 | nn.init.normal_(m.weight.data, 0.0, 0.02) 10 | elif classname.find('BatchNorm') != -1: 11 | nn.init.normal_(m.weight.data, 1.0, 0.02) 12 | nn.init.constant_(m.bias.data, 0) 13 | 14 | def weights_init_conv(m): 15 | if hasattr(m, 'conv'): 16 | m = m.conv 17 | classname = m.__class__.__name__ 18 | if classname.find('Conv') != -1: 19 | nn.init.normal_(m.weight.data, 0.0, 0.02) 20 | elif classname.find('BatchNorm') != -1: 21 | nn.init.normal_(m.weight.data, 1.0, 0.02) 22 | nn.init.constant_(m.bias.data, 0) 23 | 24 | class NLayerDiscriminator3D(nn.Module): 25 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" 26 | def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): 27 | """ 28 | Construct a 3D PatchGAN discriminator 29 | 30 | Parameters: 31 | input_nc (int) -- the number of channels in input volumes 32 | ndf (int) -- the number of filters in the last conv layer 33 | n_layers (int) -- the number of conv layers in the discriminator 34 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm 35 | """ 36 | super(NLayerDiscriminator3D, self).__init__() 37 | if not use_actnorm: 38 | norm_layer = nn.BatchNorm3d 39 | else: 40 | raise NotImplementedError("Not implemented.") 41 | if type(norm_layer) == functools.partial: 42 | use_bias = norm_layer.func != nn.BatchNorm3d 43 | else: 44 | use_bias = norm_layer != nn.BatchNorm3d 45 | 46 | kw = 3 47 | padw = 1 48 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 49 | nf_mult = 1 50 | nf_mult_prev = 1 51 | for n in range(1, n_layers): # gradually increase the number of filters 52 | nf_mult_prev = nf_mult 53 | 
nf_mult = min(2 ** n, 8) 54 | sequence += [ 55 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias), 56 | norm_layer(ndf * nf_mult), 57 | nn.LeakyReLU(0.2, True) 58 | ] 59 | 60 | nf_mult_prev = nf_mult 61 | nf_mult = min(2 ** n_layers, 8) 62 | sequence += [ 63 | nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), 64 | norm_layer(ndf * nf_mult), 65 | nn.LeakyReLU(0.2, True) 66 | ] 67 | 68 | sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 69 | self.main = nn.Sequential(*sequence) 70 | 71 | def forward(self, input): 72 | """Standard forward.""" 73 | return self.main(input) 74 | 75 | 76 | 77 | 78 | 79 | # class NLayerDiscriminator3D(nn.Module): 80 | # """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" 81 | # def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): 82 | # """ 83 | # Construct a 3D PatchGAN discriminator 84 | 85 | # Parameters: 86 | # input_nc (int) -- the number of channels in input volumes 87 | # ndf (int) -- the number of filters in the last conv layer 88 | # n_layers (int) -- the number of conv layers in the discriminator 89 | # use_actnorm (bool) -- flag to use actnorm instead of batchnorm 90 | # """ 91 | # super(NLayerDiscriminator3D, self).__init__() 92 | # if not use_actnorm: 93 | # norm_layer = nn.BatchNorm3d 94 | # else: 95 | # raise NotImplementedError("Not implemented.") 96 | # if type(norm_layer) == functools.partial: 97 | # use_bias = norm_layer.func != nn.BatchNorm3d 98 | # else: 99 | # use_bias = norm_layer != nn.BatchNorm3d 100 | 101 | # kw = 4 102 | # padw = 1 103 | # sequence = [CausalConv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 104 | # nf_mult = 1 105 | # nf_mult_prev = 1 106 | # for n in range(1, n_layers): # gradually increase the number of filters 107 | # nf_mult_prev = nf_mult 108 | # nf_mult = min(2 ** n, 8) 109 | # sequence += [ 110 | # CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(2 if n==1 else 1,2,2), padding=padw, bias=use_bias), 111 | # norm_layer(ndf * nf_mult), 112 | # nn.LeakyReLU(0.2, True) 113 | # ] 114 | 115 | # nf_mult_prev = nf_mult 116 | # nf_mult = min(2 ** n_layers, 8) 117 | # sequence += [ 118 | # CausalConv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), 119 | # norm_layer(ndf * nf_mult), 120 | # nn.LeakyReLU(0.2, True) 121 | # ] 122 | 123 | # sequence += [CausalConv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 124 | # self.main = nn.Sequential(*sequence) 125 | 126 | # def forward(self, input): 127 | # """Standard forward.""" 128 | # return self.main(input) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | from .....utils.taming_download import get_ckpt_path 8 | 9 | class LPIPS(nn.Module): 10 | # Learned perceptual metric 11 | def __init__(self, use_dropout=True): 12 | super().__init__() 13 | self.scaling_layer = ScalingLayer() 14 
| self.chns = [64, 128, 256, 512, 512] # vg16 features 15 | self.net = vgg16(pretrained=True, requires_grad=False) 16 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 17 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 18 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 19 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 20 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 21 | self.load_from_pretrained() 22 | for param in self.parameters(): 23 | param.requires_grad = False 24 | 25 | def load_from_pretrained(self, name="vgg_lpips"): 26 | ckpt = get_ckpt_path(name, ".cache/lpips") 27 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 28 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 29 | 30 | @classmethod 31 | def from_pretrained(cls, name="vgg_lpips"): 32 | if name != "vgg_lpips": 33 | raise NotImplementedError 34 | model = cls() 35 | ckpt = get_ckpt_path(name) 36 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 37 | return model 38 | 39 | def forward(self, input, target): 40 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 41 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 42 | feats0, feats1, diffs = {}, {}, {} 43 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 44 | for kk in range(len(self.chns)): 45 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 46 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 47 | 48 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 49 | val = res[0] 50 | for l in range(1, len(self.chns)): 51 | val += res[l] 52 | return val 53 | 54 | 55 | class ScalingLayer(nn.Module): 56 | def __init__(self): 57 | super(ScalingLayer, self).__init__() 58 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 59 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 60 | 61 | def forward(self, inp): 62 | return (inp - self.shift) / self.scale 63 | 64 | 65 | class NetLinLayer(nn.Module): 66 | """ A single linear layer which does a 1x1 conv """ 67 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 68 | super(NetLinLayer, self).__init__() 69 | layers = [nn.Dropout(), ] if (use_dropout) else [] 70 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 71 | self.model = nn.Sequential(*layers) 72 | 73 | 74 | class vgg16(torch.nn.Module): 75 | def __init__(self, requires_grad=False, pretrained=True): 76 | super(vgg16, self).__init__() 77 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 78 | self.slice1 = torch.nn.Sequential() 79 | self.slice2 = torch.nn.Sequential() 80 | self.slice3 = torch.nn.Sequential() 81 | self.slice4 = torch.nn.Sequential() 82 | self.slice5 = torch.nn.Sequential() 83 | self.N_slices = 5 84 | for x in range(4): 85 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 86 | for x in range(4, 9): 87 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(9, 16): 89 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(16, 23): 91 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(23, 30): 93 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 94 | if not requires_grad: 95 | for param in self.parameters(): 96 | param.requires_grad = False 97 | 98 
| def forward(self, X): 99 | h = self.slice1(X) 100 | h_relu1_2 = h 101 | h = self.slice2(h) 102 | h_relu2_2 = h 103 | h = self.slice3(h) 104 | h_relu3_3 = h 105 | h = self.slice4(h) 106 | h_relu4_3 = h 107 | h = self.slice5(h) 108 | h_relu5_3 = h 109 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 110 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 111 | return out 112 | 113 | 114 | def normalize_tensor(x,eps=1e-10): 115 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 116 | return x/(norm_factor+eps) 117 | 118 | 119 | def spatial_average(x, keepdim=True): 120 | return x.mean([2,3],keepdim=keepdim) 121 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modeling_videobase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import ModelMixin, ConfigMixin 3 | from torch import nn 4 | import os 5 | import json 6 | from diffusers.configuration_utils import ConfigMixin 7 | from diffusers.models.modeling_utils import ModelMixin 8 | from typing import Optional, Union 9 | import glob 10 | 11 | 12 | class VideoBaseAE(ModelMixin, ConfigMixin): 13 | config_name = "config.json" 14 | 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) 17 | 18 | def encode(self, x: torch.Tensor, *args, **kwargs): 19 | pass 20 | 21 | def decode(self, encoding: torch.Tensor, *args, **kwargs): 22 | pass 23 | 24 | @property 25 | def num_training_steps(self) -> int: 26 | """Total training steps inferred from datamodule and devices.""" 27 | if self.trainer.max_steps: 28 | return self.trainer.max_steps 29 | 30 | limit_batches = self.trainer.limit_train_batches 31 | batches = len(self.train_dataloader()) 32 | batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) 33 | 34 | num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) 35 | if self.trainer.tpu_cores: 36 | num_devices = max(num_devices, self.trainer.tpu_cores) 37 | 38 | effective_accum = self.trainer.accumulate_grad_batches * num_devices 39 | return (batches // effective_accum) * self.trainer.max_epochs 40 | 41 | @classmethod 42 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): 43 | ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt')) 44 | if ckpt_files: 45 | # Adapt to checkpoint 46 | last_ckpt_file = ckpt_files[-1] 47 | config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) 48 | model = cls.from_config(config_file) 49 | model.init_from_ckpt(last_ckpt_file) 50 | return model 51 | else: 52 | return super().from_pretrained(pretrained_model_name_or_path, **kwargs) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .block import Block 2 | from .attention import * 3 | from .conv import * 4 | from .normalize import * 5 | from .resnet_block import * 6 | from .updownsample import * 7 | from .wavelet import * -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import 
torch.nn.functional as F 3 | from .normalize import Normalize 4 | from .conv import CausalConv3d 5 | import torch 6 | from .block import Block 7 | 8 | try: 9 | import torch_npu 10 | from opensora.npu_config import npu_config, set_run_dtype 11 | except: 12 | torch_npu = None 13 | npu_config = None 14 | # from xformers import ops as xops 15 | 16 | class AttnBlock3D(Block): 17 | """Compatible with old versions, there are issues, use with caution.""" 18 | def __init__(self, in_channels): 19 | super().__init__() 20 | self.in_channels = in_channels 21 | 22 | self.norm = Normalize(in_channels) 23 | self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 24 | self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 25 | self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 26 | self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 27 | 28 | def forward(self, x): 29 | h_ = x 30 | h_ = self.norm(h_) 31 | q = self.q(h_) 32 | k = self.k(h_) 33 | v = self.v(h_) 34 | 35 | # compute attention 36 | b, c, t, h, w = q.shape 37 | q = q.reshape(b * t, c, h * w) 38 | q = q.permute(0, 2, 1) # b,hw,c 39 | k = k.reshape(b * t, c, h * w) # b,c,hw 40 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 41 | w_ = w_ * (int(c) ** (-0.5)) 42 | w_ = torch.nn.functional.softmax(w_, dim=2) 43 | 44 | # attend to values 45 | v = v.reshape(b * t, c, h * w) 46 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 47 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 48 | h_ = h_.reshape(b, c, t, h, w) 49 | 50 | h_ = self.proj_out(h_) 51 | 52 | return x + h_ 53 | 54 | class AttnBlock3DFix(nn.Module): 55 | """ 56 | Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. 
57 | """ 58 | def __init__(self, in_channels, norm_type="groupnorm"): 59 | super().__init__() 60 | self.in_channels = in_channels 61 | 62 | self.norm = Normalize(in_channels, norm_type=norm_type) 63 | self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 64 | self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 65 | self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 66 | self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) 67 | 68 | def forward(self, x): 69 | h_ = x 70 | h_ = self.norm(h_) 71 | q = self.q(h_) 72 | k = self.k(h_) 73 | v = self.v(h_) 74 | 75 | b, c, t, h, w = q.shape 76 | q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() 77 | k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() 78 | v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous() 79 | 80 | if torch_npu is None: 81 | # attn_output = xops.memory_efficient_attention( 82 | # q, k, v, 83 | # scale=c ** -0.5 84 | # ) 85 | q = q.view(b * t, -1, 1, c).transpose(1, 2) 86 | k = k.view(b * t, -1, 1, c).transpose(1, 2) 87 | v = v.view(b * t, -1, 1, c).transpose(1, 2) 88 | 89 | attn_output = F.scaled_dot_product_attention( 90 | q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False 91 | ) 92 | attn_output = attn_output.transpose(1, 2).reshape(b * t, -1, 1 * c) 93 | 94 | else: 95 | # print('npu_config.enable_FA, q.dtype == torch.float32', npu_config.enable_FA, q.dtype == torch.float32) 96 | if npu_config.enable_FA and q.dtype == torch.float32: 97 | dtype = torch.bfloat16 98 | else: 99 | dtype = None 100 | with set_run_dtype(q, dtype): 101 | query, key, value = npu_config.set_current_run_dtype([q, k, v]) 102 | hidden_states = npu_config.run_attention(query, key, value, atten_mask=None, input_layout="BSH", 103 | head_dim=c, head_num=1) 104 | 105 | attn_output = npu_config.restore_dtype(hidden_states) 106 | 107 | attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3) 108 | h_ = self.proj_out(attn_output) 109 | 110 | return x + h_ 111 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Block(nn.Module): 4 | def __init__(self, *args, **kwargs) -> None: 5 | super().__init__(*args, **kwargs) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/conv.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch_npu 3 | from opensora.npu_config import npu_config 4 | except: 5 | torch_npu = None 6 | npu_config = None 7 | 8 | import torch.nn as nn 9 | from typing import Union, Tuple 10 | import torch 11 | from .block import Block 12 | from .ops import cast_tuple 13 | from .ops import video_to_image 14 | from torch.utils.checkpoint import checkpoint 15 | import torch.nn.functional as F 16 | from collections import deque 17 | 18 | class Conv2d(nn.Conv2d): 19 | def __init__( 20 | self, 21 | in_channels: int, 22 | out_channels: int, 23 | kernel_size: Union[int, Tuple[int]] = 3, 24 | stride: Union[int, Tuple[int]] = 1, 25 | padding: Union[str, int, Tuple[int]] = 0, 26 | dilation: Union[int, Tuple[int]] = 1, 27 | groups: int = 1, 28 | bias: bool = True, 29 | padding_mode: str = "zeros", 30 | device=None, 31 | dtype=None, 32 | ) -> None: 33 | super().__init__( 34 | in_channels, 
35 | out_channels, 36 | kernel_size, 37 | stride, 38 | padding, 39 | dilation, 40 | groups, 41 | bias, 42 | padding_mode, 43 | device, 44 | dtype, 45 | ) 46 | 47 | @video_to_image 48 | def forward(self, x): 49 | return super().forward(x) 50 | 51 | 52 | 53 | class CausalConv3d(Block): 54 | def __init__( 55 | self, 56 | chan_in, 57 | chan_out, 58 | kernel_size: Union[int, Tuple[int, int, int]], 59 | enable_cached=False, 60 | bias=True, 61 | **kwargs 62 | ): 63 | super().__init__() 64 | self.kernel_size = cast_tuple(kernel_size, 3) 65 | self.time_kernel_size = self.kernel_size[0] 66 | self.chan_in = chan_in 67 | self.chan_out = chan_out 68 | self.stride = kwargs.pop("stride", 1) 69 | self.padding = kwargs.pop("padding", 0) 70 | self.padding = list(cast_tuple(self.padding, 3)) 71 | self.padding[0] = 0 72 | self.stride = cast_tuple(self.stride, 3) 73 | self.conv = nn.Conv3d( 74 | chan_in, 75 | chan_out, 76 | self.kernel_size, 77 | stride=self.stride, 78 | padding=self.padding, 79 | bias=bias 80 | ) 81 | self.enable_cached = enable_cached 82 | 83 | self.is_first_chunk = True 84 | 85 | self.causal_cached = deque() 86 | self.cache_offset = 0 87 | 88 | def forward(self, x): 89 | if self.is_first_chunk: 90 | first_frame_pad = x[:, :, :1, :, :].repeat( 91 | (1, 1, self.time_kernel_size - 1, 1, 1) 92 | ) 93 | else: 94 | first_frame_pad = self.causal_cached.popleft() 95 | 96 | x = torch.concatenate((first_frame_pad, x), dim=2) 97 | 98 | if self.enable_cached and self.time_kernel_size != 1: 99 | if (self.time_kernel_size - 1) // self.stride[0] != 0: 100 | if self.cache_offset == 0: 101 | self.causal_cached.append(x[:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone()) 102 | else: 103 | self.causal_cached.append(x[:, :, :-self.cache_offset][:, :, -(self.time_kernel_size - 1) // self.stride[0]:].clone()) 104 | else: 105 | self.causal_cached.append(x[:, :, 0:0, :, :].clone()) 106 | elif self.enable_cached: 107 | self.causal_cached.append(x[:, :, 0:0, :, :].clone()) 108 | 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class CausalConv3d_GC(CausalConv3d): 114 | def __init__( 115 | self, 116 | chan_in, 117 | chan_out, 118 | kernel_size: Union[int, Tuple[int]], 119 | init_method="random", 120 | **kwargs 121 | ): 122 | super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs) 123 | 124 | def forward(self, x): 125 | # 1 + 16 16 as video, 1 as image 126 | first_frame_pad = x[:, :, :1, :, :].repeat( 127 | (1, 1, self.time_kernel_size - 1, 1, 1) 128 | ) # b c t h w 129 | x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 130 | return checkpoint(self.conv, x) 131 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/normalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .block import Block 4 | from einops import rearrange 5 | 6 | class GroupNorm(Block): 7 | def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: 8 | super().__init__(*args, **kwargs) 9 | self.norm = torch.nn.GroupNorm( 10 | num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True 11 | ) 12 | def forward(self, x): 13 | return self.norm(x) 14 | 15 | class LayerNorm(Block): 16 | def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None: 17 | super().__init__(*args, **kwargs) 18 | self.norm = torch.nn.LayerNorm(num_channels, eps=eps, elementwise_affine=True) 19 | def forward(self, x): 20 | if 
x.dim() == 5: 21 | x = rearrange(x, "b c t h w -> b t h w c") 22 | x = self.norm(x) 23 | x = rearrange(x, "b t h w c -> b c t h w") 24 | else: 25 | x = rearrange(x, "b c h w -> b h w c") 26 | x = self.norm(x) 27 | x = rearrange(x, "b h w c -> b c h w") 28 | return x 29 | 30 | def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): 31 | if norm_type == "groupnorm": 32 | return torch.nn.GroupNorm( 33 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 34 | ) 35 | elif norm_type == "layernorm": 36 | return LayerNorm(num_channels=in_channels, eps=1e-6) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | def video_to_image(func): 5 | def wrapper(self, x, *args, **kwargs): 6 | if x.dim() == 5: 7 | t = x.shape[2] 8 | if True: 9 | x = rearrange(x, "b c t h w -> (b t) c h w") 10 | x = func(self, x, *args, **kwargs) 11 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t) 12 | else: 13 | # Conv 2d slice infer 14 | result = [] 15 | for i in range(t): 16 | frame = x[:, :, i, :, :] 17 | frame = func(self, frame, *args, **kwargs) 18 | result.append(frame.unsqueeze(2)) 19 | x = torch.concatenate(result, dim=2) 20 | return x 21 | return wrapper 22 | 23 | def nonlinearity(x): 24 | return x * torch.sigmoid(x) 25 | 26 | def cast_tuple(t, length=1): 27 | return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) 28 | 29 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 30 | n_dims = len(x.shape) 31 | if src_dim < 0: 32 | src_dim = n_dims + src_dim 33 | if dest_dim < 0: 34 | dest_dim = n_dims + dest_dim 35 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 36 | dims = list(range(n_dims)) 37 | del dims[src_dim] 38 | permutation = [] 39 | ctr = 0 40 | for i in range(n_dims): 41 | if i == dest_dim: 42 | permutation.append(src_dim) 43 | else: 44 | permutation.append(dims[ctr]) 45 | ctr += 1 46 | x = x.permute(permutation) 47 | if make_contiguous: 48 | x = x.contiguous() 49 | return x -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/modules/quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from .ops import shift_dim 7 | 8 | class Codebook(nn.Module): 9 | def __init__(self, n_codes, embedding_dim): 10 | super().__init__() 11 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) 12 | self.register_buffer("N", torch.zeros(n_codes)) 13 | self.register_buffer("z_avg", self.embeddings.data.clone()) 14 | 15 | self.n_codes = n_codes 16 | self.embedding_dim = embedding_dim 17 | self._need_init = True 18 | 19 | def _tile(self, x): 20 | d, ew = x.shape 21 | if d < self.n_codes: 22 | n_repeats = (self.n_codes + d - 1) // d 23 | std = 0.01 / np.sqrt(ew) 24 | x = x.repeat(n_repeats, 1) 25 | x = x + torch.randn_like(x) * std 26 | return x 27 | 28 | def _init_embeddings(self, z): 29 | # z: [b, c, t, h, w] 30 | self._need_init = False 31 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) 32 | y = self._tile(flat_inputs) 33 | 34 | d = y.shape[0] 35 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 36 | if dist.is_initialized(): 37 | dist.broadcast(_k_rand, 0) 38 | 
self.embeddings.data.copy_(_k_rand) 39 | self.z_avg.data.copy_(_k_rand) 40 | self.N.data.copy_(torch.ones(self.n_codes)) 41 | 42 | def forward(self, z): 43 | # z: [b, c, t, h, w] 44 | if self._need_init and self.training: 45 | self._init_embeddings(z) 46 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) 47 | distances = ( 48 | (flat_inputs**2).sum(dim=1, keepdim=True) 49 | - 2 * flat_inputs @ self.embeddings.t() 50 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) 51 | ) 52 | 53 | encoding_indices = torch.argmin(distances, dim=1) 54 | encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) 55 | encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) 56 | 57 | embeddings = F.embedding(encoding_indices, self.embeddings) 58 | embeddings = shift_dim(embeddings, -1, 1) 59 | 60 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) 61 | 62 | # EMA codebook update 63 | if self.training: 64 | n_total = encode_onehot.sum(dim=0) 65 | encode_sum = flat_inputs.t() @ encode_onehot 66 | if dist.is_initialized(): 67 | dist.all_reduce(n_total) 68 | dist.all_reduce(encode_sum) 69 | 70 | self.N.data.mul_(0.99).add_(n_total, alpha=0.01) 71 | self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) 72 | 73 | n = self.N.sum() 74 | weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n 75 | encode_normalized = self.z_avg / weights.unsqueeze(1) 76 | self.embeddings.data.copy_(encode_normalized) 77 | 78 | y = self._tile(flat_inputs) 79 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 80 | if dist.is_initialized(): 81 | dist.broadcast(_k_rand, 0) 82 | 83 | usage = (self.N.view(self.n_codes, 1) >= 1).float() 84 | self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) 85 | 86 | embeddings_st = (embeddings - z).detach() + z 87 | 88 | avg_probs = torch.mean(encode_onehot, dim=0) 89 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 90 | 91 | return dict( 92 | embeddings=embeddings_st, 93 | encodings=encoding_indices, 94 | commitment_loss=commitment_loss, 95 | perplexity=perplexity, 96 | ) 97 | 98 | def dictionary_lookup(self, encodings): 99 | embeddings = F.embedding(encodings, self.embeddings) 100 | return embeddings -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/registry.py: -------------------------------------------------------------------------------- 1 | class ModelRegistry: 2 | _models = {} 3 | 4 | @classmethod 5 | def register(cls, model_name): 6 | def decorator(model_class): 7 | cls._models[model_name] = model_class 8 | return model_class 9 | return decorator 10 | 11 | @classmethod 12 | def get_model(cls, model_name): 13 | return cls._models.get(model_name) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/trainer_videobase.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | import os 5 | import torch 6 | from transformers.utils import WEIGHTS_NAME 7 | import json 8 | 9 | class VideoBaseTrainer(Trainer): 10 | 11 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 12 | output_dir = output_dir if output_dir is not None else self.args.output_dir 13 | os.makedirs(output_dir, exist_ok=True) 14 | if state_dict is None: 15 | state_dict = self.model.state_dict() 16 | 17 | # get model config 18 | model_config = 
self.model.config.to_dict() 19 | 20 | # add more information 21 | model_config['model'] = self.model.__class__.__name__ 22 | 23 | with open(os.path.join(output_dir, "config.json"), "w") as file: 24 | json.dump(self.model.config.to_dict(), file) 25 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 26 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 27 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/causalvideovae/model/utils/__init__.py -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/utils/distrib_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class DiagonalGaussianDistribution(object): 5 | def __init__(self, parameters, deterministic=False): 6 | self.parameters = parameters 7 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 8 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 9 | self.deterministic = deterministic 10 | self.std = torch.exp(0.5 * self.logvar) 11 | self.var = torch.exp(self.logvar) 12 | if self.deterministic: 13 | self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) 14 | 15 | def sample(self): 16 | x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype) 17 | return x 18 | 19 | def kl(self, other=None): 20 | if self.deterministic: 21 | return torch.Tensor([0.]) 22 | else: 23 | if other is None: 24 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 25 | + self.var - 1.0 - self.logvar, 26 | dim=[1, 2, 3]) 27 | else: 28 | return 0.5 * torch.sum( 29 | torch.pow(self.mean - other.mean, 2) / other.var 30 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 31 | dim=[1, 2, 3]) 32 | 33 | def nll(self, sample, dims=[1,2,3]): 34 | if self.deterministic: 35 | return torch.Tensor([0.]) 36 | logtwopi = np.log(2.0 * np.pi) 37 | return 0.5 * torch.sum( 38 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 39 | dim=dims) 40 | 41 | def mode(self): 42 | return self.mean 43 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/utils/module_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | Module = str 4 | MODULES_BASE = "opensora.models.causalvideovae.model.modules." 
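# Usage sketch for the resolver below (kept as comments so import behaviour is
# unchanged; the argument values are illustrative, not taken from a real config):
#
#   conv_cls = resolve_str_to_obj("conv.CausalConv3d")
#   # -> opensora.models.causalvideovae.model.modules.conv.CausalConv3d
#   norm_cls = resolve_str_to_obj("torch.nn.GroupNorm", append=False)
#   # -> torch.nn.GroupNorm, resolved from its fully qualified path
#   layer = conv_cls(chan_in=128, chan_out=128, kernel_size=3)
#
# which lets callers (e.g. model configs) refer to building blocks by dotted
# name relative to MODULES_BASE and resolve them to classes at load time.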
5 | 6 | def resolve_str_to_obj(str_val, append=True): 7 | if append: 8 | str_val = MODULES_BASE + str_val 9 | module_name, class_name = str_val.rsplit('.', 1) 10 | module = importlib.import_module(module_name) 11 | return getattr(module, class_name) 12 | 13 | def create_instance(module_class_str: str, **kwargs): 14 | module_name, class_name = module_class_str.rsplit('.', 1) 15 | module = importlib.import_module(module_name) 16 | class_ = getattr(module, class_name) 17 | return class_(**kwargs) -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/utils/scheduler_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def cosine_scheduler(step, max_steps, value_base=1, value_end=0): 4 | step = torch.tensor(step) 5 | cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps)) 6 | value = value_end + (value_base - value_end) * cosine_value 7 | return value -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def tensor_to_video(x): 5 | x = (x * 2 - 1).detach().cpu() 6 | x = torch.clamp(x, -1, 1) 7 | x = (x + 1) / 2 8 | x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w 9 | x = (255 * x).astype(np.uint8) 10 | return x -------------------------------------------------------------------------------- /opensora/models/causalvideovae/model/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_causalvae import CausalVAEModel 2 | from .modeling_wfvae import WFVAEModel 3 | from einops import rearrange 4 | from torch import nn -------------------------------------------------------------------------------- /opensora/models/causalvideovae/sample/rec_video_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import torch 4 | import sys 5 | from torch.utils.data import DataLoader, Subset 6 | import os 7 | from accelerate import Accelerator 8 | 9 | sys.path.append(".") 10 | from opensora.models.causalvideovae.model import * 11 | from opensora.models.causalvideovae.dataset.video_dataset import ValidVideoDataset 12 | from opensora.models.causalvideovae.utils.video_utils import custom_to_video 13 | 14 | @torch.no_grad() 15 | def main(args: argparse.Namespace): 16 | accelerator = Accelerator() 17 | device = accelerator.device 18 | 19 | real_video_dir = args.real_video_dir 20 | generated_video_dir = args.generated_video_dir 21 | sample_rate = args.sample_rate 22 | resolution = args.resolution 23 | crop_size = args.crop_size 24 | num_frames = args.num_frames 25 | sample_rate = args.sample_rate 26 | device = args.device 27 | sample_fps = args.sample_fps 28 | batch_size = args.batch_size 29 | num_workers = args.num_workers 30 | subset_size = args.subset_size 31 | 32 | if not os.path.exists(args.generated_video_dir): 33 | os.makedirs(args.generated_video_dir, exist_ok=True) 34 | 35 | data_type = torch.bfloat16 36 | 37 | # ---- Load Model ---- 38 | device = args.device 39 | model_cls = ModelRegistry.get_model(args.model_name) 40 | vae = model_cls.from_pretrained(args.from_pretrained) 41 | vae = vae.to(device).to(data_type) 42 | if args.enable_tiling: 43 | vae.enable_tiling() 44 | vae.tile_overlap_factor = 
args.tile_overlap_factor 45 | 46 | # ---- Prepare Dataset ---- 47 | dataset = ValidVideoDataset( 48 | real_video_dir=real_video_dir, 49 | num_frames=num_frames, 50 | sample_rate=sample_rate, 51 | crop_size=crop_size, 52 | resolution=resolution, 53 | ) 54 | if subset_size: 55 | indices = range(subset_size) 56 | dataset = Subset(dataset, indices=indices) 57 | 58 | dataloader = DataLoader( 59 | dataset, batch_size=batch_size, pin_memory=False, num_workers=num_workers 60 | ) 61 | dataloader = accelerator.prepare(dataloader) 62 | 63 | # ---- Inference ---- 64 | for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process): 65 | x, file_names = batch['video'], batch['file_name'] 66 | x = x.to(device=device, dtype=data_type) # b c t h w 67 | x = x * 2 - 1 68 | encode_result = vae.encode(x) 69 | if isinstance(encode_result, tuple): 70 | encode_result = encode_result[0] 71 | latents = encode_result.sample().to(data_type) 72 | video_recon = vae.decode(latents) 73 | if isinstance(video_recon, tuple): 74 | video_recon = video_recon[0] 75 | for idx, video in enumerate(video_recon): 76 | output_path = os.path.join(generated_video_dir, file_names[idx]) 77 | if args.output_origin: 78 | os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True) 79 | origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx]) 80 | custom_to_video( 81 | x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path 82 | ) 83 | custom_to_video( 84 | video, fps=sample_fps / sample_rate, output_file=output_path 85 | ) 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--real_video_dir", type=str, default="") 90 | parser.add_argument("--generated_video_dir", type=str, default="") 91 | parser.add_argument("--from_pretrained", type=str, default="") 92 | parser.add_argument("--sample_fps", type=int, default=30) 93 | parser.add_argument("--resolution", type=int, default=336) 94 | parser.add_argument("--crop_size", type=int, default=None) 95 | parser.add_argument("--num_frames", type=int, default=17) 96 | parser.add_argument("--sample_rate", type=int, default=1) 97 | parser.add_argument("--batch_size", type=int, default=1) 98 | parser.add_argument("--num_workers", type=int, default=8) 99 | parser.add_argument("--subset_size", type=int, default=None) 100 | parser.add_argument("--tile_overlap_factor", type=float, default=0.25) 101 | parser.add_argument('--enable_tiling', action='store_true') 102 | parser.add_argument('--output_origin', action='store_true') 103 | parser.add_argument("--model_name", type=str, default=None, help="") 104 | parser.add_argument("--device", type=str, default="cuda") 105 | 106 | args = parser.parse_args() 107 | main(args) 108 | 109 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/causalvideovae/utils/__init__.py -------------------------------------------------------------------------------- /opensora/models/causalvideovae/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from einops import rearrange 3 | import decord 4 | from torch.nn import functional as F 5 | import torch 6 | 7 | 8 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', 
'.PNG'] 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 12 | 13 | class DecordInit(object): 14 | """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" 15 | 16 | def __init__(self, num_threads=1): 17 | self.num_threads = num_threads 18 | self.ctx = decord.cpu(0) 19 | 20 | def __call__(self, filename): 21 | """Perform the Decord initialization. 22 | Args: 23 | results (dict): The resulting dict to be modified and passed 24 | to the next transform in pipeline. 25 | """ 26 | reader = decord.VideoReader(filename, 27 | ctx=self.ctx, 28 | num_threads=self.num_threads) 29 | return reader 30 | 31 | def __repr__(self): 32 | repr_str = (f'{self.__class__.__name__}(' 33 | f'sr={self.sr},' 34 | f'num_threads={self.num_threads})') 35 | return repr_str 36 | 37 | def pad_to_multiple(number, ds_stride): 38 | remainder = number % ds_stride 39 | if remainder == 0: 40 | return number 41 | else: 42 | padding = ds_stride - remainder 43 | return number + padding 44 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/utils/downloader.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import os 3 | 4 | opensora_cache_home = os.path.expanduser( 5 | os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) 6 | ) 7 | 8 | 9 | def gdown_download(id, fname, cache_dir=None): 10 | cache_dir = opensora_cache_home if not cache_dir else cache_dir 11 | 12 | os.makedirs(cache_dir, exist_ok=True) 13 | destination = os.path.join(cache_dir, fname) 14 | if os.path.exists(destination): 15 | return destination 16 | 17 | gdown.download(id=id, output=destination, quiet=False) 18 | return destination 19 | -------------------------------------------------------------------------------- /opensora/models/causalvideovae/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import numpy.typing as npt 4 | import cv2 5 | from decord import VideoReader, cpu 6 | 7 | def array_to_video( 8 | image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4" 9 | ) -> None: 10 | """b h w c""" 11 | height, width, channels = image_array[0].shape 12 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 13 | video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) 14 | 15 | for image in image_array: 16 | image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 17 | video_writer.write(image_rgb) 18 | 19 | video_writer.release() 20 | 21 | def custom_to_video( 22 | x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4" 23 | ) -> None: 24 | x = x.detach().cpu() 25 | x = torch.clamp(x, -1, 1) 26 | x = (x + 1) / 2 27 | x = x.permute(1, 2, 3, 0).float().numpy() 28 | x = (255 * x).astype(np.uint8) 29 | array_to_video(x, fps=fps, output_file=output_file) 30 | return 31 | 32 | def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: 33 | decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8) 34 | total_frames = len(decord_vr) 35 | sample_frames_len = sample_rate * num_frames 36 | 37 | if total_frames > sample_frames_len: 38 | s = 0 39 | e = s + sample_frames_len 40 | num_frames = num_frames 41 | else: 42 | s = 0 43 | e = total_frames 44 | num_frames = int(total_frames / sample_frames_len * num_frames) 45 | print( 46 | f"sample_frames_len {sample_frames_len}, only can sample 
{num_frames * sample_rate}", 47 | video_path, 48 | total_frames, 49 | ) 50 | 51 | frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) 52 | video_data = decord_vr.get_batch(frame_id_list).asnumpy() 53 | video_data = torch.from_numpy(video_data) 54 | video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) 55 | return video_data 56 | 57 | def tensor_to_video(x): 58 | """[0-1] tensor to video""" 59 | x = (x * 2 - 1).detach().cpu() 60 | x = torch.clamp(x, -1, 1) 61 | x = (x + 1) / 2 62 | x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> t c h w 63 | x = (255 * x).astype(np.uint8) 64 | return x -------------------------------------------------------------------------------- /opensora/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models 2 | from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models 3 | 4 | 5 | Diffusion_models = {} 6 | Diffusion_models.update(OpenSora_v1_3_models) 7 | Diffusion_models.update(OpenSoraInpaint_v1_3_models) 8 | 9 | from .opensora_v1_3.modeling_opensora import OpenSora_v1_3_models_class 10 | from .opensora_v1_3.modeling_inpaint import OpenSoraInpaint_v1_3_models_class 11 | 12 | Diffusion_models_class = {} 13 | Diffusion_models_class.update(OpenSora_v1_3_models_class) 14 | Diffusion_models_class.update(OpenSoraInpaint_v1_3_models_class) 15 | -------------------------------------------------------------------------------- /opensora/models/diffusion/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | from typing import Any, Dict, Optional, Tuple 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from diffusers.models.attention_processor import Attention as Attention_ 8 | try: 9 | import torch_npu 10 | from opensora.npu_config import npu_config, set_run_dtype 11 | from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info as xccl_info 12 | from opensora.acceleration.communications import all_to_all_SBH 13 | except: 14 | torch_npu = None 15 | npu_config = None 16 | set_run_dtype = None 17 | from opensora.utils.parallel_states import get_sequence_parallel_state, nccl_info as xccl_info 18 | from opensora.utils.communications import all_to_all_SBH 19 | 20 | class PatchEmbed2D(nn.Module): 21 | """2D Image to Patch Embedding but with video""" 22 | 23 | def __init__( 24 | self, 25 | patch_size=16, 26 | in_channels=3, 27 | embed_dim=768, 28 | bias=True, 29 | ): 30 | super().__init__() 31 | self.proj = nn.Conv2d( 32 | in_channels, embed_dim, 33 | kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias 34 | ) 35 | 36 | def forward(self, latent): 37 | b, _, _, _, _ = latent.shape 38 | latent = rearrange(latent, 'b c t h w -> (b t) c h w') 39 | latent = self.proj(latent) 40 | latent = rearrange(latent, '(b t) c h w -> b (t h w) c', b=b) 41 | return latent 42 | 43 | 44 | class PositionGetter3D(object): 45 | """ return positions of patches """ 46 | 47 | def __init__(self, ): 48 | self.cache_positions = {} 49 | 50 | def __call__(self, b, t, h, w, device): 51 | if not (b,t,h,w) in self.cache_positions: 52 | x = torch.arange(w, device=device) 53 | y = torch.arange(h, device=device) 54 | z = torch.arange(t, device=device) 55 | pos = torch.cartesian_prod(z, y, x) 56 | # print('PositionGetter3D', PositionGetter3D) 57 | pos = 
pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone() 58 | poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) 59 | max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) 60 | 61 | self.cache_positions[b, t, h, w] = (poses, max_poses) 62 | pos = self.cache_positions[b, t, h, w] 63 | 64 | return pos 65 | 66 | class RoPE3D(torch.nn.Module): 67 | 68 | def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): 69 | super().__init__() 70 | self.base = freq 71 | self.F0 = F0 72 | self.interpolation_scale_t = interpolation_scale_thw[0] 73 | self.interpolation_scale_h = interpolation_scale_thw[1] 74 | self.interpolation_scale_w = interpolation_scale_thw[2] 75 | self.cache = {} 76 | 77 | def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): 78 | if (D, seq_len, device, dtype) not in self.cache: 79 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) 80 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale 81 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) 82 | freqs = torch.cat((freqs, freqs), dim=-1) 83 | cos = freqs.cos() # (Seq, Dim) 84 | sin = freqs.sin() 85 | self.cache[D, seq_len, device, dtype] = (cos, sin) 86 | return self.cache[D, seq_len, device, dtype] 87 | 88 | @staticmethod 89 | def rotate_half(x): 90 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 91 | return torch.cat((-x2, x1), dim=-1) 92 | 93 | def apply_rope1d(self, tokens, pos1d, cos, sin): 94 | assert pos1d.ndim == 2 95 | # for (ntokens x batch_size x nheads x dim) 96 | cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] 97 | sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] 98 | 99 | return (tokens * cos) + (self.rotate_half(tokens) * sin) 100 | 101 | def forward(self, tokens, positions): 102 | """ 103 | input: 104 | * tokens: ntokens x batch_size x nheads x dim 105 | * positions: batch_size x ntokens x 3 (t, y and x position of each token) 106 | output: 107 | * tokens after appplying RoPE3D (ntokens x batch_size x nheads x dim) 108 | """ 109 | assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" 110 | D = tokens.size(3) // 3 111 | poses, max_poses = positions 112 | assert len(poses) == 3 and poses[0].ndim == 2# Batch, Seq, 3 113 | cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) 114 | cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) 115 | cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) 116 | # split features into three along the feature dimension, and apply rope1d on each half 117 | t, y, x = tokens.chunk(3, dim=-1) 118 | t = self.apply_rope1d(t, poses[0], cos_t, sin_t) 119 | y = self.apply_rope1d(y, poses[1], cos_y, sin_y) 120 | x = self.apply_rope1d(x, poses[2], cos_x, sin_x) 121 | tokens = torch.cat((t, y, x), dim=-1) 122 | return tokens -------------------------------------------------------------------------------- /opensora/models/diffusion/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
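# cuRoPE3D wraps a fused CUDA kernel for 3D rotary position embeddings (see kernels.cu and
# curope.cpp in this directory); a pure-PyTorch RoPE3D with the same role lives in
# opensora/models/diffusion/common.py. The extension is not built automatically: compile it
# first with setup.py (e.g. `python setup.py build_ext --inplace`, the command noted in
# curope3d.py) before importing cuRoPE3D.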
3 | 4 | from .curope3d import cuRoPE3D -------------------------------------------------------------------------------- /opensora/models/diffusion/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | 8 | // forward declaration 9 | void rope_3d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_3d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 6; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 3; x++) { // t and y and then x (3d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_3d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,3 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 3, "positions.shape[2] must be equal to 3"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_3d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_3d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_3d", &rope_3d, "RoPE 3d forward/backward"); 69 | } -------------------------------------------------------------------------------- /opensora/models/diffusion/curope/curope3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE3D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_3d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_3d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE3D(torch.nn.Module): 33 | def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=None): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | # tokens.transpose(1,2): B,N,H,D 40 | # positions: B,N,3 41 | # print('tokens.transpose(1,2).shape, positions.shape, self.base, self.F0', 42 | # tokens.transpose(1,2).shape, positions.shape, self.base, self.F0) 43 | cuRoPE3D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 44 | return tokens -------------------------------------------------------------------------------- /opensora/models/diffusion/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_3d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 6; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..4Q : 4Q..5Q : 6Q..D] 41 | // u_T v_T u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_3d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 3, "bad pos.shape"); 94 | TORCH_CHECK(D % 6 == 0, "token dim must be multiple of 6"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/6); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_3d_cuda", ([&] { 102 | rope_3d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } -------------------------------------------------------------------------------- /opensora/models/diffusion/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
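# Build notes (assumed typical usage, matching the commands referenced in curope3d.py):
#   python setup.py install               # install the `curope` extension into the environment
#   python setup.py build_ext --inplace   # or build it in place next to this source tree
# A CUDA toolkit compatible with the installed PyTorch build is required to compile kernels.cu.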
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) -------------------------------------------------------------------------------- /opensora/models/diffusion/opensora_v1_3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/diffusion/opensora_v1_3/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/cfgs/AMT-G.yaml: -------------------------------------------------------------------------------- 1 | 2 | seed: 2023 3 | 4 | network: 5 | name: networks.AMT-G.Model 6 | params: 7 | corr_radius: 3 8 | corr_lvls: 4 9 | num_flows: 5 -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/frame_interpolation/networks/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/frame_interpolation/networks/blocks/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/blocks/ifrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.flow_utils import warp 5 | 6 | 7 | def resize(x, scale_factor): 8 | return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) 9 | 10 | def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 11 | return nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), 13 | nn.PReLU(out_channels) 14 | ) 15 | 16 | class ResBlock(nn.Module): 17 | def __init__(self, in_channels, side_channels, bias=True): 18 | super(ResBlock, self).__init__() 19 | self.side_channels = side_channels 20 | self.conv1 = nn.Sequential( 21 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 22 | nn.PReLU(in_channels) 23 | ) 24 | self.conv2 = nn.Sequential( 25 
| nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 26 | nn.PReLU(side_channels) 27 | ) 28 | self.conv3 = nn.Sequential( 29 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), 30 | nn.PReLU(in_channels) 31 | ) 32 | self.conv4 = nn.Sequential( 33 | nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), 34 | nn.PReLU(side_channels) 35 | ) 36 | self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) 37 | self.prelu = nn.PReLU(in_channels) 38 | 39 | def forward(self, x): 40 | out = self.conv1(x) 41 | 42 | res_feat = out[:, :-self.side_channels, ...] 43 | side_feat = out[:, -self.side_channels:, :, :] 44 | side_feat = self.conv2(side_feat) 45 | out = self.conv3(torch.cat([res_feat, side_feat], 1)) 46 | 47 | res_feat = out[:, :-self.side_channels, ...] 48 | side_feat = out[:, -self.side_channels:, :, :] 49 | side_feat = self.conv4(side_feat) 50 | out = self.conv5(torch.cat([res_feat, side_feat], 1)) 51 | 52 | out = self.prelu(x + out) 53 | return out 54 | 55 | class Encoder(nn.Module): 56 | def __init__(self, channels, large=False): 57 | super(Encoder, self).__init__() 58 | self.channels = channels 59 | prev_ch = 3 60 | for idx, ch in enumerate(channels, 1): 61 | k = 7 if large and idx == 1 else 3 62 | p = 3 if k ==7 else 1 63 | self.register_module(f'pyramid{idx}', 64 | nn.Sequential( 65 | convrelu(prev_ch, ch, k, 2, p), 66 | convrelu(ch, ch, 3, 1, 1) 67 | )) 68 | prev_ch = ch 69 | 70 | def forward(self, in_x): 71 | fs = [] 72 | for idx in range(len(self.channels)): 73 | out_x = getattr(self, f'pyramid{idx+1}')(in_x) 74 | fs.append(out_x) 75 | in_x = out_x 76 | return fs 77 | 78 | class InitDecoder(nn.Module): 79 | def __init__(self, in_ch, out_ch, skip_ch) -> None: 80 | super().__init__() 81 | self.convblock = nn.Sequential( 82 | convrelu(in_ch*2+1, in_ch*2), 83 | ResBlock(in_ch*2, skip_ch), 84 | nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True) 85 | ) 86 | def forward(self, f0, f1, embt): 87 | h, w = f0.shape[2:] 88 | embt = embt.repeat(1, 1, h, w) 89 | out = self.convblock(torch.cat([f0, f1, embt], 1)) 90 | flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) 91 | ft_ = out[:, 4:, ...] 92 | return flow0, flow1, ft_ 93 | 94 | class IntermediateDecoder(nn.Module): 95 | def __init__(self, in_ch, out_ch, skip_ch) -> None: 96 | super().__init__() 97 | self.convblock = nn.Sequential( 98 | convrelu(in_ch*3+4, in_ch*3), 99 | ResBlock(in_ch*3, skip_ch), 100 | nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True) 101 | ) 102 | def forward(self, ft_, f0, f1, flow0_in, flow1_in): 103 | f0_warp = warp(f0, flow0_in) 104 | f1_warp = warp(f1, flow1_in) 105 | f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) 106 | out = self.convblock(f_in) 107 | flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) 108 | ft_ = out[:, 4:, ...] 
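# The ConvTranspose2d in convblock doubles the spatial resolution, so the incoming flows are
# upsampled by 2x and their displacement values scaled by 2.0 before the predicted residual
# flow from `out` is added.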
109 | flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) 110 | flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) 111 | return flow0, flow1, ft_ -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/blocks/multi_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping 14 | comb_block: An nn.Seqential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conduct a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead. 23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch*3+4, in_ch*3), 52 | ResBlock(in_ch*3, skip_ch), 53 | nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, flow0, flow1): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/readme.md: -------------------------------------------------------------------------------- 1 | #### Frame Interpolation 2 | 3 | We use AMT as our frame interpolation model. (Thanks [AMT](https://github.com/MCG-NKU/AMT)) After sampling, you can use frame interpolation model to interpolate your video smoothly. 4 | 5 | 1. 
Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT), we recommend using the largest model AMT-G to achieve the best performance. 6 | 2. Run the script of frame interpolation. 7 | ``` 8 | python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/floder --frame_rate 30 9 | ``` 10 | 3. The output video will be stored at output_path and its duration time is equal `the total number of frames after frame interpolation / the frame rate` 11 | ##### Frame Interpolation Specific Settings 12 | 13 | * `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). We use AMT-G as our frame interpolation model. 14 | * `--niter`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. 15 | * `--input`: Path of the input video. 16 | * `--output_path`: Folder Path of the output video. 17 | * `--frame_rate"`: Frame rate of the output video. 18 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/opensora/models/frame_interpolation/utils/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/build_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def base_build_fn(module, cls, params): 5 | return getattr(importlib.import_module( 6 | module, package=None), cls)(**params) 7 | 8 | 9 | def build_from_cfg(config): 10 | module, cls = config['name'].rsplit(".", 1) 11 | params = config.get('params', {}) 12 | return base_build_fn(module, cls, params) 13 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif 
os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | 49 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import ImageFile 4 | import torch.nn.functional as F 5 | ImageFile.LOAD_TRUNCATED_IMAGES = True 6 | 7 | 8 | def warp(img, flow): 9 | B, _, H, W = flow.shape 10 | xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) 11 | yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) 12 | grid = torch.cat([xx, yy], 1).to(img) 13 | flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) 14 | grid_ = (grid + flow_).permute(0, 2, 3, 1) 15 | output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) 16 | return output 17 | 18 | 19 | def make_colorwheel(): 20 | """ 21 | Generates a color wheel for optical flow visualization as presented in: 22 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 23 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 24 | Code follows the original C++ source code of Daniel Scharstein. 25 | Code follows the the Matlab source code of Deqing Sun. 26 | Returns: 27 | np.ndarray: Color wheel 28 | """ 29 | 30 | RY = 15 31 | YG = 6 32 | GC = 4 33 | CB = 11 34 | BM = 13 35 | MR = 6 36 | 37 | ncols = RY + YG + GC + CB + BM + MR 38 | colorwheel = np.zeros((ncols, 3)) 39 | col = 0 40 | 41 | # RY 42 | colorwheel[0:RY, 0] = 255 43 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 44 | col = col+RY 45 | # YG 46 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 47 | colorwheel[col:col+YG, 1] = 255 48 | col = col+YG 49 | # GC 50 | colorwheel[col:col+GC, 1] = 255 51 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 52 | col = col+GC 53 | # CB 54 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 55 | colorwheel[col:col+CB, 2] = 255 56 | col = col+CB 57 | # BM 58 | colorwheel[col:col+BM, 2] = 255 59 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 60 | col = col+BM 61 | # MR 62 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 63 | colorwheel[col:col+MR, 0] = 255 64 | return colorwheel 65 | 66 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 67 | """ 68 | Applies the flow color wheel to (possibly clipped) flow components u and v. 69 | According to the C++ source code of Daniel Scharstein 70 | According to the Matlab source code of Deqing Sun 71 | Args: 72 | u (np.ndarray): Input horizontal flow of shape [H,W] 73 | v (np.ndarray): Input vertical flow of shape [H,W] 74 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
75 | Returns: 76 | np.ndarray: Flow visualization image of shape [H,W,3] 77 | """ 78 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 79 | colorwheel = make_colorwheel() # shape [55x3] 80 | ncols = colorwheel.shape[0] 81 | rad = np.sqrt(np.square(u) + np.square(v)) 82 | a = np.arctan2(-v, -u)/np.pi 83 | fk = (a+1) / 2*(ncols-1) 84 | k0 = np.floor(fk).astype(np.int32) 85 | k1 = k0 + 1 86 | k1[k1 == ncols] = 0 87 | f = fk - k0 88 | for i in range(colorwheel.shape[1]): 89 | tmp = colorwheel[:,i] 90 | col0 = tmp[k0] / 255.0 91 | col1 = tmp[k1] / 255.0 92 | col = (1-f)*col0 + f*col1 93 | idx = (rad <= 1) 94 | col[idx] = 1 - rad[idx] * (1-col[idx]) 95 | col[~idx] = col[~idx] * 0.75 # out of range 96 | # Note the 2-i => BGR instead of RGB 97 | ch_idx = 2-i if convert_to_bgr else i 98 | flow_image[:,:,ch_idx] = np.floor(255 * col) 99 | return flow_image 100 | 101 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 102 | """ 103 | Expects a two dimensional flow image of shape. 104 | Args: 105 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 106 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 107 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 108 | Returns: 109 | np.ndarray: Flow visualization image of shape [H,W,3] 110 | """ 111 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 112 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 113 | if clip_flow is not None: 114 | flow_uv = np.clip(flow_uv, 0, clip_flow) 115 | u = flow_uv[:,:,0] 116 | v = flow_uv[:,:,1] 117 | rad = np.sqrt(np.square(u) + np.square(v)) 118 | rad_max = np.max(rad) 119 | epsilon = 1e-5 120 | u = u / (rad_max + epsilon) 121 | v = v / (rad_max + epsilon) 122 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /opensora/models/prompt_refiner/inference.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | def get_output(prompt): 7 | template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ 8 | "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ 9 | "Make sure it is a fluent sentence, not nonsense." 
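# The raw prompt is wrapped in the fixed refinement instruction above and then rendered with
# the tokenizer's chat template; for the Llama-3.1 based checkpoints used in this folder this
# is assumed to yield the <|start_header_id|>system / user / assistant<|end_header_id|>
# layout that train.py fine-tunes on.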
10 | prompt = template.format(prompt) 11 | messages = [ 12 | {"role": "system", "content": "You are a caption refiner."}, 13 | {"role": "user", "content": prompt} 14 | ] 15 | 16 | input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 17 | model_inputs = tokenizer([input_ids], return_tensors="pt").to(device) 18 | generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512) 19 | generated_ids = [ 20 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 21 | ] 22 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 23 | print('\nInput\n:', prompt) 24 | print('\nOutput\n:', response) 25 | return response 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--mode_path", type=str, default="llama3_8B_lora_merged_cn") 30 | parser.add_argument("--prompt", type=str, default='a dog is running.') 31 | args = parser.parse_args() 32 | return args 33 | 34 | if __name__ == '__main__': 35 | args = parse_args() 36 | device = torch.device('cuda') 37 | tokenizer = AutoTokenizer.from_pretrained(args.mode_path, trust_remote_code=True) 38 | model = AutoModelForCausalLM.from_pretrained(args.mode_path,torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval() 39 | 40 | response = get_output(args.prompt) -------------------------------------------------------------------------------- /opensora/models/prompt_refiner/merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from peft import PeftModel 4 | import torch 5 | import argparse 6 | 7 | 8 | def get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path): 9 | model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto",trust_remote_code=True) 10 | model = PeftModel.from_pretrained(model, lora_model_input_path) 11 | merged_model = model.merge_and_unload() 12 | merged_model.save_pretrained(lora_model_output_path, safe_serialization=True) 13 | print("Merge lora to base model") 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True) 16 | tokenizer.save_pretrained(lora_model_output_path) 17 | print("Save tokenizer") 18 | 19 | def get_model_result(base_model_path, fintune_model_path): 20 | tokenizer = AutoTokenizer.from_pretrained(base_model_path) 21 | device = "cuda" 22 | 23 | fintune_model = AutoModelForCausalLM.from_pretrained( 24 | fintune_model_path, 25 | device_map="auto", 26 | torch_dtype=torch.bfloat16, 27 | ).eval() 28 | 29 | base_model = AutoModelForCausalLM.from_pretrained( 30 | base_model_path, 31 | device_map="auto", 32 | torch_dtype=torch.bfloat16, 33 | ).eval() 34 | 35 | template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ 36 | "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ 37 | "Make sure it is a fluent sentence, not nonsense." 
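# Sanity-check prompt: "a dog和一只猫" mixes English and Chinese ("a dog and a cat"), so the
# comparison below exercises the bilingual refiner (the merged checkpoint defaults to
# llama3_8B_lora_merged_cn) against the untouched base model.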
38 | 39 | prompt = "a dog和一只猫" 40 | prompt = template.format(prompt) 41 | messages = [ 42 | {"role": "system", "content": "You are a caption refiner."}, 43 | {"role": "user", "content": prompt} 44 | ] 45 | text = tokenizer.apply_chat_template( 46 | messages, 47 | tokenize=False, 48 | add_generation_prompt=True 49 | ) 50 | 51 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 52 | 53 | def get_result(model_inputs, model): 54 | generated_ids = model.generate( 55 | model_inputs.input_ids, 56 | max_new_tokens=512, 57 | eos_token_id=tokenizer.get_vocab()["<|eot_id|>"] 58 | ) 59 | generated_ids = [ 60 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 61 | ] 62 | 63 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 64 | return response 65 | 66 | base_model_response = get_result(model_inputs, base_model) 67 | fintune_model_response = get_result(model_inputs, fintune_model) 68 | print("\nInput\n", prompt) 69 | print("\nResult before fine-tune:\n", base_model_response) 70 | print("\nResult after fine-tune:\n", fintune_model_response) 71 | 72 | def parse_args(): 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--base_path", type=str, default="Meta-Llama-3___1-8B-Instruct") 75 | parser.add_argument("--lora_in_path", type=str, default="llama3_1_instruct_lora/checkpoint-1008") 76 | parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora/llama3_8B_lora_merged_cn") 77 | args = parser.parse_args() 78 | return args 79 | 80 | if __name__ == '__main__': 81 | args = parse_args() 82 | get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path) 83 | get_model_result(args.base_path, args.lora_out_path) -------------------------------------------------------------------------------- /opensora/models/prompt_refiner/train.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset 2 | import pandas as pd 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig 4 | from peft import LoraConfig, TaskType, get_peft_model 5 | import torch 6 | import argparse 7 | 8 | ins = "Refine the sentence to contain subject description, action, scene description. " \ 9 | "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ 10 | "Make sure it is a fluent sentence, not nonsense." 
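# Assumed layout of the records in --data_path, based on the fields read in process_func
# below: each JSON object carries an "instruction" (typically the refinement template above),
# a short "input" caption, and the refined "output" caption. Illustrative example (values are
# made up, not taken from the dataset):
#   {"instruction": "Refine the sentence ...", "input": "a dog is running.",
#    "output": "A golden retriever sprints across a sunlit park lawn at dusk."}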
11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--data_path", type=str, default='refine_32255.json') 15 | parser.add_argument("--model_path", type=str, default='Meta-Llama-3___1-8B-Instruct') 16 | parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora") 17 | args = parser.parse_args() 18 | return args 19 | 20 | args = parse_args() 21 | 22 | 23 | df = pd.read_json(args.data_path) 24 | ds = Dataset.from_pandas(df) 25 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True) 26 | tokenizer.pad_token = tokenizer.eos_token 27 | 28 | def process_func(example): 29 | MAX_LENGTH = 2048 30 | input_ids, attention_mask, labels = [], [], [] 31 | instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False) # add_special_tokens 不在开头加 special_tokens 32 | response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False) 33 | input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id] 34 | attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1] 35 | labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id] 36 | if len(input_ids) > MAX_LENGTH: 37 | input_ids = input_ids[:MAX_LENGTH] 38 | attention_mask = attention_mask[:MAX_LENGTH] 39 | labels = labels[:MAX_LENGTH] 40 | return { 41 | "input_ids": input_ids, 42 | "attention_mask": attention_mask, 43 | "labels": labels 44 | } 45 | 46 | tokenized_id = ds.map(process_func, remove_columns=ds.column_names) 47 | 48 | 49 | model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto",torch_dtype=torch.bfloat16) 50 | print(model) 51 | model.enable_input_require_grads() 52 | 53 | config = LoraConfig( 54 | task_type=TaskType.CAUSAL_LM, 55 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 56 | inference_mode=False, 57 | r=64, 58 | lora_alpha=64, 59 | lora_dropout=0.1 60 | ) 61 | print(config) 62 | 63 | model = get_peft_model(model, config) 64 | model.print_trainable_parameters() 65 | 66 | args = TrainingArguments( 67 | output_dir=args.lora_out_path, 68 | per_device_train_batch_size=32, 69 | gradient_accumulation_steps=1, 70 | logging_steps=1, 71 | num_train_epochs=1, 72 | save_steps=20, 73 | dataloader_num_workers=4, 74 | learning_rate=1.5e-4, 75 | warmup_ratio=0.1, 76 | save_on_each_node=True, 77 | gradient_checkpointing=True, 78 | report_to='wandb', 79 | ) 80 | 81 | trainer = Trainer( 82 | model=model, 83 | args=args, 84 | train_dataset=tokenized_id, 85 | data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True), 86 | ) 87 | 88 | trainer.train() -------------------------------------------------------------------------------- /opensora/models/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from opensora.models.text_encoder.clip import CLIPWrapper 2 | from opensora.models.text_encoder.t5 import T5Wrapper 3 | 4 | text_encoder = { 5 | 'google/mt5-xl': T5Wrapper, 6 | 'google/mt5-xxl': T5Wrapper, 7 | 'google/umt5-xl': T5Wrapper, 8 | 'google/umt5-xxl': T5Wrapper, 9 | 'google/t5-v1_1-xl': T5Wrapper, 10 | 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, 11 | 'openai/clip-vit-large-patch14': 
CLIPWrapper, 12 | 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k': CLIPWrapper 13 | } 14 | 15 | def get_text_warpper(text_encoder_name): 16 | """deprecation""" 17 | encoder_key = None 18 | for key in text_encoder.keys(): 19 | if key in text_encoder_name: 20 | encoder_key = key 21 | break 22 | text_enc = text_encoder.get(encoder_key, None) 23 | assert text_enc is not None 24 | return text_enc 25 | -------------------------------------------------------------------------------- /opensora/models/text_encoder/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import CLIPTextModelWithProjection 4 | 5 | try: 6 | import torch_npu 7 | except: 8 | torch_npu = None 9 | 10 | class CLIPWrapper(nn.Module): 11 | def __init__(self, args, **kwargs): 12 | super(CLIPWrapper, self).__init__() 13 | self.model_name = args.text_encoder_name_2 14 | if torch_npu is not None: 15 | self.model_name = '/home/save_dir/pretrained/clip/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/bc7788f151930d91b58474715fdce5524ad9a189' 16 | else: 17 | self.model_name = '/storage/cache_dir/CLIP-ViT-bigG-14-laion2B-39B-b160k' 18 | print(f'Loading CLIP model from {self.model_name}...') 19 | self.text_enc = CLIPTextModelWithProjection.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() 20 | 21 | def forward(self, input_ids, attention_mask): 22 | text_encoder_embs = self.text_enc(input_ids=input_ids, output_hidden_states=True)[0] 23 | return text_encoder_embs.detach() 24 | -------------------------------------------------------------------------------- /opensora/models/text_encoder/t5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import T5EncoderModel 4 | 5 | try: 6 | import torch_npu 7 | except: 8 | torch_npu = None 9 | 10 | class T5Wrapper(nn.Module): 11 | def __init__(self, args, **kwargs): 12 | super(T5Wrapper, self).__init__() 13 | self.model_name = args.text_encoder_name_1 14 | print(f'Loading T5 model from {self.model_name}...') 15 | self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() 16 | 17 | def forward(self, input_ids, attention_mask): 18 | text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] 19 | return text_encoder_embs.detach() 20 | -------------------------------------------------------------------------------- /opensora/sample/caption_refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | 5 | 6 | 7 | TEMPLATE = """ 8 | Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ 9 | "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ 10 | "Make sure it is a fluent sentence, not nonsense. 
11 | """ 12 | 13 | class OpenSoraCaptionRefiner(nn.Module): 14 | def __init__(self, args, dtype, device): 15 | super().__init__() 16 | self.tokenizer = AutoTokenizer.from_pretrained( 17 | args.caption_refiner, trust_remote_code=True 18 | ) 19 | self.model = AutoModelForCausalLM.from_pretrained( 20 | args.caption_refiner, torch_dtype=dtype, trust_remote_code=True 21 | ).to(device).eval() 22 | self.device = device 23 | 24 | def get_refiner_output(self, prompt): 25 | prompt = TEMPLATE.format(prompt) 26 | messages = [ 27 | {"role": "system", "content": "You are a caption refiner."}, 28 | {"role": "user", "content": prompt} 29 | ] 30 | input_ids = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 31 | model_inputs = self.tokenizer([input_ids], return_tensors="pt").to(self.device) 32 | generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=512) 33 | generated_ids = [ 34 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 35 | ] 36 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 37 | return response -------------------------------------------------------------------------------- /opensora/sample/rec_image.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from PIL import Image 4 | import torch 5 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Lambda 6 | from torch.nn import functional as F 7 | import argparse 8 | import numpy as np 9 | from opensora.models.causalvideovae import ae_wrapper 10 | 11 | def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor: 12 | transform = Compose( 13 | [ 14 | ToTensor(), 15 | Lambda(lambda x: 2. 
* x - 1.), 16 | Resize(size=short_size), 17 | ] 18 | ) 19 | outputs = transform(video_data) 20 | outputs = outputs.unsqueeze(0).unsqueeze(2) 21 | return outputs 22 | 23 | def main(args: argparse.Namespace): 24 | image_path = args.image_path 25 | short_size = args.short_size 26 | device = args.device 27 | kwarg = {} 28 | 29 | # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device) 30 | vae = ae_wrapper[args.ae](args.ae_path, **kwarg).eval().to(device) 31 | if args.enable_tiling: 32 | vae.vae.enable_tiling() 33 | vae.vae.tile_overlap_factor = args.tile_overlap_factor 34 | vae.eval() 35 | vae = vae.to(device) 36 | vae = vae.half() 37 | 38 | with torch.no_grad(): 39 | x_vae = preprocess(Image.open(image_path), short_size) 40 | x_vae = x_vae.to(device, dtype=torch.float16) # b c t h w 41 | latents = vae.encode(x_vae) 42 | latents = latents.to(torch.float16) 43 | image_recon = vae.decode(latents) # b t c h w 44 | x = image_recon[0, 0, :, :, :] 45 | x = x.squeeze() 46 | x = x.detach().cpu().numpy() 47 | x = np.clip(x, -1, 1) 48 | x = (x + 1) / 2 49 | x = (255*x).astype(np.uint8) 50 | x = x.transpose(1,2,0) 51 | image = Image.fromarray(x) 52 | image.save(args.rec_path) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--image_path', type=str, default='') 58 | parser.add_argument('--rec_path', type=str, default='') 59 | parser.add_argument('--ae', type=str, default='') 60 | parser.add_argument('--ae_path', type=str, default='') 61 | parser.add_argument('--model_path', type=str, default='results/pretrained') 62 | parser.add_argument('--short_size', type=int, default=336) 63 | parser.add_argument('--device', type=str, default='cuda') 64 | parser.add_argument('--tile_overlap_factor', type=float, default=0.25) 65 | parser.add_argument('--enable_tiling', action='store_true') 66 | 67 | args = parser.parse_args() 68 | main(args) 69 | -------------------------------------------------------------------------------- /opensora/sample/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | try: 4 | import torch_npu 5 | from opensora.npu_config import npu_config 6 | except: 7 | torch_npu = None 8 | npu_config = None 9 | pass 10 | from opensora.utils.sample_utils import ( 11 | init_gpu_env, init_npu_env, prepare_pipeline, get_args, 12 | run_model_and_save_samples, run_model_and_save_samples_npu 13 | ) 14 | from opensora.sample.caption_refiner import OpenSoraCaptionRefiner 15 | 16 | if __name__ == "__main__": 17 | args = get_args() 18 | dtype = torch.float16 19 | 20 | if torch_npu is not None: 21 | npu_config.print_msg(args) 22 | npu_config.conv_dtype = dtype 23 | init_npu_env(args) 24 | else: 25 | args = init_gpu_env(args) 26 | 27 | device = torch.cuda.current_device() 28 | if args.num_frames != 1 and args.enhance_video is not None: 29 | from opensora.sample.VEnhancer.enhance_a_video import VEnhancer 30 | enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=device) 31 | else: 32 | enhance_video_model = None 33 | pipeline = prepare_pipeline(args, dtype, device) 34 | if args.caption_refiner is not None: 35 | caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device) 36 | else: 37 | caption_refiner_model = None 38 | 39 | if npu_config is not None and npu_config.on_npu and npu_config.profiling: 40 | run_model_and_save_samples_npu(args, pipeline, caption_refiner_model, enhance_video_model) 41 | else: 
42 | run_model_and_save_samples(args, pipeline, caption_refiner_model, enhance_video_model) 43 | -------------------------------------------------------------------------------- /opensora/serve/style.css: -------------------------------------------------------------------------------- 1 | .gradio-container{width:1280px!important} -------------------------------------------------------------------------------- /opensora/utils/downloader.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import os 3 | 4 | opensora_cache_home = os.path.expanduser( 5 | os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) 6 | ) 7 | 8 | 9 | def gdown_download(id, fname, cache_dir=None): 10 | cache_dir = opensora_cache_home if not cache_dir else cache_dir 11 | 12 | os.makedirs(cache_dir, exist_ok=True) 13 | destination = os.path.join(cache_dir, fname) 14 | if os.path.exists(destination): 15 | return destination 16 | 17 | gdown.download(id=id, output=destination, quiet=False) 18 | return destination 19 | -------------------------------------------------------------------------------- /opensora/utils/ema_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from peft import get_peft_model, PeftModel 3 | import os 4 | from copy import deepcopy 5 | import torch 6 | import json 7 | from diffusers.training_utils import EMAModel as diffuser_EMAModel 8 | 9 | 10 | 11 | class EMAModel(diffuser_EMAModel): 12 | def __init__(self, parameters, **kwargs): 13 | self.lora_config = kwargs.pop('lora_config', None) 14 | super().__init__(parameters, **kwargs) 15 | 16 | @classmethod 17 | def from_pretrained(cls, path, model_cls, lora_config, model_base) -> "EMAModel": 18 | # 1. load model 19 | if lora_config is not None: 20 | # 1.1 load origin model 21 | model_base = model_cls.from_pretrained(model_base) # model_base 22 | config = model_base.config 23 | # 1.2 convert to lora model automatically and load lora weight 24 | model = PeftModel.from_pretrained(model_base, path) # lora_origin_model 25 | else: 26 | model = model_cls.from_pretrained(path) 27 | config = model.config 28 | # 3. ema the whole model 29 | ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config, lora_config=lora_config) 30 | # 4. load ema_config, e.g decay... 31 | with open(os.path.join(path, 'ema_config.json'), 'r') as f: 32 | state_dict = json.load(f) 33 | ema_model.load_state_dict(state_dict) 34 | return ema_model 35 | 36 | def save_pretrained(self, path): 37 | if self.model_cls is None: 38 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") 39 | 40 | if self.model_config is None: 41 | raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") 42 | # 1. init a base model randomly 43 | model = self.model_cls.from_config(self.model_config) 44 | # 1.1 convert lora_model 45 | if self.lora_config is not None: 46 | model = get_peft_model(model, self.lora_config) 47 | # 2. ema_model copy to model 48 | self.copy_to(model.parameters()) 49 | # 3. save weight 50 | if self.lora_config is not None: 51 | model.save_pretrained(path) # only lora weight 52 | merge_model = model.merge_and_unload() 53 | merge_model.save_pretrained(path) # merge_model weight 54 | else: 55 | merge_model.save_pretrained(path) # model weight 56 | # 4. save ema_config, e.g decay... 
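# The EMA shadow parameters were already copied into the model saved above, so only the
# scalar EMA settings (decay schedule, step counters, ...) are kept in ema_config.json.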
57 | state_dict = self.state_dict() # lora_model weight 58 | state_dict.pop("shadow_params", None) 59 | with open(os.path.join(path, 'ema_config.json'), 'w') as f: 60 | json.dump(state_dict, f, indent=2) -------------------------------------------------------------------------------- /opensora/utils/freeinit_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.fft as fft 3 | import math 4 | 5 | 6 | def freq_mix_3d(x, noise, LPF): 7 | """ 8 | Noise reinitialization. 9 | 10 | Args: 11 | x: diffused latent 12 | noise: randomly sampled noise 13 | LPF: low pass filter 14 | """ 15 | # FFT 16 | x_freq = fft.fftn(x, dim=(-3, -2, -1)) 17 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) 18 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) 19 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) 20 | 21 | # frequency mix 22 | HPF = 1 - LPF 23 | x_freq_low = x_freq * LPF 24 | noise_freq_high = noise_freq * HPF 25 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain 26 | 27 | # IFFT 28 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) 29 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real 30 | 31 | return x_mixed 32 | 33 | 34 | def get_freq_filter(shape, device, filter_type, n, d_s, d_t): 35 | """ 36 | Form the frequency filter for noise reinitialization. 37 | 38 | Args: 39 | shape: shape of latent (B, C, T, H, W) 40 | filter_type: type of the freq filter 41 | n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian 42 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 43 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 44 | """ 45 | if filter_type == "gaussian": 46 | return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 47 | elif filter_type == "ideal": 48 | return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 49 | elif filter_type == "box": 50 | return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 51 | elif filter_type == "butterworth": 52 | return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) 53 | else: 54 | raise NotImplementedError 55 | 56 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): 57 | """ 58 | Compute the gaussian low pass filter mask. 59 | 60 | Args: 61 | shape: shape of the filter (volume) 62 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 63 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 64 | """ 65 | T, H, W = shape[-3], shape[-2], shape[-1] 66 | mask = torch.zeros(shape) 67 | if d_s==0 or d_t==0: 68 | return mask 69 | for t in range(T): 70 | for h in range(H): 71 | for w in range(W): 72 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 73 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) 74 | return mask 75 | 76 | 77 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): 78 | """ 79 | Compute the butterworth low pass filter mask. 
80 | 81 | Args: 82 | shape: shape of the filter (volume) 83 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian 84 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 85 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 86 | """ 87 | T, H, W = shape[-3], shape[-2], shape[-1] 88 | mask = torch.zeros(shape) 89 | if d_s==0 or d_t==0: 90 | return mask 91 | for t in range(T): 92 | for h in range(H): 93 | for w in range(W): 94 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 95 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) 96 | return mask 97 | 98 | 99 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): 100 | """ 101 | Compute the ideal low pass filter mask. 102 | 103 | Args: 104 | shape: shape of the filter (volume) 105 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 106 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 107 | """ 108 | T, H, W = shape[-3], shape[-2], shape[-1] 109 | mask = torch.zeros(shape) 110 | if d_s==0 or d_t==0: 111 | return mask 112 | for t in range(T): 113 | for h in range(H): 114 | for w in range(W): 115 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 116 | mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 117 | return mask 118 | 119 | 120 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): 121 | """ 122 | Compute the ideal low pass filter mask (approximated version). 123 | 124 | Args: 125 | shape: shape of the filter (volume) 126 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 127 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 128 | """ 129 | T, H, W = shape[-3], shape[-2], shape[-1] 130 | mask = torch.zeros(shape) 131 | if d_s==0 or d_t==0: 132 | return mask 133 | 134 | threshold_s = round(int(H // 2) * d_s) 135 | threshold_t = round(T // 2 * d_t) 136 | 137 | cframe, crow, ccol = T // 2, H // 2, W //2 138 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 139 | 140 | return mask -------------------------------------------------------------------------------- /opensora/utils/lora_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from peft import get_peft_model, PeftModel 3 | import os 4 | from copy import deepcopy 5 | import torch 6 | import json 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | if hasattr(param, "ds_id"): 12 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 13 | if not ignore_status: 14 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 15 | with zero.GatheredParameters([param]): 16 | param = param.data.detach().cpu().clone() 17 | else: 18 | param = param.detach().cpu().clone() 19 | return param 20 | 21 | # Borrowed from peft.utils.get_peft_model_state_dict 22 | def get_peft_state_maybe_zero_3(named_params, bias): 23 | if bias == "none": 24 | to_return = {k: t for k, t in named_params if "lora_" in k} 25 | elif bias == "all": 26 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 27 | elif bias == "lora_only": 28 | to_return = {} 29 | maybe_lora_bias = {} 30 | lora_bias_names = set() 31 | for k, t in named_params: 32 | if "lora_" in k: 33 | to_return[k] = t 34 | bias_name = k.split("lora_")[0] + "bias" 35 | lora_bias_names.add(bias_name) 36 | elif 
"bias" in k: 37 | maybe_lora_bias[k] = t 38 | for k, t in maybe_lora_bias: 39 | if bias_name in lora_bias_names: 40 | to_return[bias_name] = t 41 | else: 42 | raise NotImplementedError 43 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} 44 | return to_return 45 | -------------------------------------------------------------------------------- /opensora/utils/parallel_states.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import os 4 | 5 | class COMM_INFO: 6 | def __init__(self): 7 | self.group = None 8 | self.world_size = 0 9 | self.rank = -1 10 | 11 | nccl_info = COMM_INFO() 12 | _SEQUENCE_PARALLEL_STATE = False 13 | def initialize_sequence_parallel_state(sequence_parallel_size): 14 | global _SEQUENCE_PARALLEL_STATE 15 | if sequence_parallel_size > 1: 16 | _SEQUENCE_PARALLEL_STATE = True 17 | initialize_sequence_parallel_group(sequence_parallel_size) 18 | 19 | def set_sequence_parallel_state(state): 20 | global _SEQUENCE_PARALLEL_STATE 21 | _SEQUENCE_PARALLEL_STATE = state 22 | 23 | def get_sequence_parallel_state(): 24 | return _SEQUENCE_PARALLEL_STATE 25 | 26 | def initialize_sequence_parallel_group(sequence_parallel_size): 27 | """Initialize the sequence parallel group.""" 28 | rank = int(os.getenv('RANK', '0')) 29 | world_size = int(os.getenv("WORLD_SIZE", '1')) 30 | assert world_size % sequence_parallel_size == 0, "world_size must be divisible by sequence_parallel_size" 31 | # hccl 32 | nccl_info.world_size = sequence_parallel_size 33 | nccl_info.rank = rank 34 | num_sequence_parallel_groups: int = world_size // sequence_parallel_size 35 | for i in range(num_sequence_parallel_groups): 36 | ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) 37 | group = dist.new_group(ranks) 38 | if rank in ranks: 39 | nccl_info.group = group 40 | 41 | 42 | def destroy_sequence_parallel_group(): 43 | """Destroy the sequence parallel group.""" 44 | dist.destroy_process_group() 45 | -------------------------------------------------------------------------------- /opensora/utils/taming_download.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CompVis/taming-transformers.git""" 2 | 3 | import os, hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | URL_MAP = { 8 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 9 | } 10 | 11 | CKPT_MAP = { 12 | "vgg_lpips": "vgg.pth" 13 | } 14 | 15 | MD5_MAP = { 16 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 17 | } 18 | 19 | 20 | def download(url, local_path, chunk_size=1024): 21 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 22 | with requests.get(url, stream=True) as r: 23 | total_size = int(r.headers.get("content-length", 0)) 24 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 25 | with open(local_path, "wb") as f: 26 | for data in r.iter_content(chunk_size=chunk_size): 27 | if data: 28 | f.write(data) 29 | pbar.update(chunk_size) 30 | 31 | 32 | def md5_hash(path): 33 | with open(path, "rb") as f: 34 | content = f.read() 35 | return hashlib.md5(content).hexdigest() 36 | 37 | 38 | def get_ckpt_path(name, root, check=False): 39 | assert name in URL_MAP 40 | path = os.path.join(root, CKPT_MAP[name]) 41 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 42 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 43 | 
download(URL_MAP[name], path) 44 | md5 = md5_hash(path) 45 | assert md5 == MD5_MAP[name], md5 46 | return path 47 | 48 | 49 | class KeyNotFoundError(Exception): 50 | def __init__(self, cause, keys=None, visited=None): 51 | self.cause = cause 52 | self.keys = keys 53 | self.visited = visited 54 | messages = list() 55 | if keys is not None: 56 | messages.append("Key not found: {}".format(keys)) 57 | if visited is not None: 58 | messages.append("Visited: {}".format(visited)) 59 | messages.append("Cause:\n{}".format(cause)) 60 | message = "\n".join(messages) 61 | super().__init__(message) 62 | 63 | 64 | def retrieve( 65 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 66 | ): 67 | """Given a nested list or dict return the desired value at key expanding 68 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 69 | is done in-place. 70 | 71 | Parameters 72 | ---------- 73 | list_or_dict : list or dict 74 | Possibly nested list or dictionary. 75 | key : str 76 | key/to/value, path like string describing all keys necessary to 77 | consider to get to the desired value. List indices can also be 78 | passed here. 79 | splitval : str 80 | String that defines the delimiter between keys of the 81 | different depth levels in `key`. 82 | default : obj 83 | Value returned if :attr:`key` is not found. 84 | expand : bool 85 | Whether to expand callable nodes on the path or not. 86 | 87 | Returns 88 | ------- 89 | The desired value or if :attr:`default` is not ``None`` and the 90 | :attr:`key` is not found returns ``default``. 91 | 92 | Raises 93 | ------ 94 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 95 | ``None``. 96 | """ 97 | 98 | keys = key.split(splitval) 99 | 100 | success = True 101 | try: 102 | visited = [] 103 | parent = None 104 | last_key = None 105 | for key in keys: 106 | if callable(list_or_dict): 107 | if not expand: 108 | raise KeyNotFoundError( 109 | ValueError( 110 | "Trying to get past callable node with expand=False." 111 | ), 112 | keys=keys, 113 | visited=visited, 114 | ) 115 | list_or_dict = list_or_dict() 116 | parent[last_key] = list_or_dict 117 | 118 | last_key = key 119 | parent = list_or_dict 120 | 121 | try: 122 | if isinstance(list_or_dict, dict): 123 | list_or_dict = list_or_dict[key] 124 | else: 125 | list_or_dict = list_or_dict[int(key)] 126 | except (KeyError, IndexError, ValueError) as e: 127 | raise KeyNotFoundError(e, keys=keys, visited=visited) 128 | 129 | visited += [key] 130 | # final expansion of retrieved value 131 | if expand and callable(list_or_dict): 132 | list_or_dict = list_or_dict() 133 | parent[last_key] = list_or_dict 134 | except KeyNotFoundError as e: 135 | if default is None: 136 | raise e 137 | else: 138 | list_or_dict = default 139 | success = False 140 | 141 | if not pass_success: 142 | return list_or_dict 143 | else: 144 | return list_or_dict, success 145 | 146 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "opensora" 7 | version = "1.3.0" 8 | description = "Reproduce OpenAI's Sora." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "transformers==4.44.2", "tokenizers==0.19.1", 17 | "albumentations==1.4.0", "av==11.0.0", "decord==0.6.0", "einops==0.7.0", "fastapi==0.110.0", 18 | "gdown==5.1.0", "h5py==3.10.0", "idna==3.8", 'imageio==2.34.0', "matplotlib==3.7.5", "numpy==1.24.4", 19 | "omegaconf==2.1.1", "opencv-python==4.9.0.80", "opencv-python-headless==4.9.0.80", "pandas==2.0.3", "pillow==10.2.0", 20 | "pydub==0.25.1", "pytorchvideo==0.1.5", "PyYAML==6.0.2", "regex==2024.7.24", 21 | "requests==2.32.3", "scikit-learn==1.3.2", "scipy==1.10.1", "six==1.16.0", "test-tube==0.7.5", 22 | "timm==0.9.16", "torchdiffeq==0.2.3", "torchmetrics==1.3.2", "tqdm==4.66.5", "urllib3==2.2.2", "uvicorn==0.27.1", 23 | "scikit-video==1.1.11", "imageio-ffmpeg==0.4.9", "sentencepiece==0.1.99", "beautifulsoup4==4.12.3", 24 | "ftfy==6.1.3", "moviepy==1.0.3", "wandb==0.16.3", "tensorboard==2.14.0", "pydantic==2.6.4", "gradio==4.0.0", 25 | "torch==2.1.0", "torchvision==0.16.0", "xformers==0.0.22.post7", "accelerate==0.34.0", "diffusers==0.30.2", "deepspeed==0.12.6" 26 | ] 27 | 28 | [project.optional-dependencies] 29 | dev = ["mypy==1.8.0"] 30 | 31 | 32 | [project.urls] 33 | "Homepage" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan" 34 | "Bug Tracker" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues" 35 | 36 | [tool.setuptools.packages.find] 37 | exclude = ["assets*", "docker*", "docs", "scripts*"] 38 | 39 | [tool.wheel] 40 | exclude = ["assets*", "docker*", "docs", "scripts*"] 41 | 42 | [tool.mypy] 43 | warn_return_any = true 44 | warn_unused_configs = true 45 | ignore_missing_imports = true 46 | disallow_untyped_calls = true 47 | check_untyped_defs = true 48 | no_implicit_optional = true 49 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/ddp_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fsdp_config: {} 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: 29501 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 1 10 | gpu_ids: 0, 11 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero2_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29513 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false 14 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2_offload.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | 
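# 8 processes = one per GPU listed below; zero2_offload.json additionally moves optimizer state to CPU memory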
gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero3_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero3.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero3_offload.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fsdp_config: {} 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: 29501 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | gpu_ids: 0,1,2,3,4,5,6,7 12 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/hostfile: -------------------------------------------------------------------------------- 1 | 100.64.24.30 slots=8 2 | 100.64.24.6 slots=8 3 | 100.64.24.7 slots=8 4 | 100.64.24.8 slots=8 5 | 100.64.24.10 slots=8 6 | 100.64.24.11 slots=8 7 | 100.64.24.13 slots=8 8 | 100.64.24.14 slots=8 9 | 100.64.24.17 slots=8 10 | 100.64.24.19 slots=8 11 | 100.64.24.26 slots=8 12 | 100.64.24.27 slots=8 13 | 100.64.24.28 slots=8 14 | 100.64.24.29 slots=8 15 | 100.64.24.31 slots=8 16 | 100.64.24.32 slots=8 -------------------------------------------------------------------------------- /scripts/accelerate_configs/multi_node_example.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2.json 5 | deepspeed_hostfile: scripts/accelerate_configs/hostfile 6 | fsdp_config: {} 7 | machine_rank: 0 8 | main_process_ip: 100.64.24.30 9 | main_process_port: 29522 10 | main_training_function: main 11 | num_machines: 16 12 | num_processes: 128 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/multi_node_example_by_ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fsdp_config: {} 4 | main_process_port: 29501 5 | main_training_function: main 6 | num_machines: 32 7 | 
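# 256 processes = 32 machines x 8 GPUs per machine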
num_processes: 256 8 | rdzv_backend: static 9 | same_network: true 10 | tpu_env: [] 11 | tpu_use_cluster: false 12 | tpu_use_sudo: false 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false, 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "communication_data_type": "fp32", 14 | "gradient_clipping": 1.0, 15 | "train_micro_batch_size_per_gpu": "auto", 16 | "train_batch_size": "auto", 17 | "gradient_accumulation_steps": "auto", 18 | "zero_optimization": { 19 | "stage": 2, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": 5e8 24 | } 25 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero2_npu.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false, 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "communication_data_type": "fp32", 14 | "gradient_clipping": 1.0, 15 | "train_micro_batch_size_per_gpu": "auto", 16 | "train_batch_size": "auto", 17 | "gradient_accumulation_steps": "auto", 18 | "zero_optimization": { 19 | "stage": 2, 20 | "overlap_comm": true, 21 | "allgather_bucket_size": 536870912, 22 | "contiguous_gradients": true, 23 | "reduce_bucket_size": 536870912 24 | } 25 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "communication_data_type": "fp32", 14 | "gradient_clipping": 1.0, 15 | "train_micro_batch_size_per_gpu": "auto", 16 | "train_batch_size": "auto", 17 | "gradient_accumulation_steps": "auto", 18 | "zero_optimization": { 19 | "stage": 2, 20 | "offload_optimizer": { 21 | "device": "cpu" 22 | }, 23 | "overlap_comm": true, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1e9, 26 | "reduce_bucket_size": 5e8, 27 | "round_robin_gradients": true 28 | } 29 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "communication_data_type": "fp32", 14 | "gradient_clipping": 1.0, 15 | "train_micro_batch_size_per_gpu": "auto", 16 | "train_batch_size": "auto", 17 | "gradient_accumulation_steps": "auto", 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": 5e8, 24 | "stage3_prefetch_bucket_size": 5e8, 25 | "stage3_param_persistence_threshold": 
"auto", 26 | "stage3_max_live_parameters": 1e9, 27 | "stage3_max_reuse_distance": 1e9, 28 | "stage3_gather_16bit_weights_on_model_save": true 29 | } 30 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "offload_param": { 20 | "device": "cpu", 21 | "pin_memory": true 22 | }, 23 | "overlap_comm": true, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1e9, 26 | "reduce_bucket_size": 5e8, 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1e9, 30 | "stage3_max_reuse_distance": 1e9, 31 | "gather_16bit_weights_on_model_save": true 32 | }, 33 | "gradient_accumulation_steps": "auto", 34 | "gradient_clipping": "auto", 35 | "train_batch_size": "auto", 36 | "train_micro_batch_size_per_gpu": "auto", 37 | "steps_per_print": 1e5, 38 | "wall_clock_breakdown": false 39 | } -------------------------------------------------------------------------------- /scripts/causalvae/eval.sh: -------------------------------------------------------------------------------- 1 | EXP_NAME=wfvae-4dim 2 | SAMPLE_RATE=1 3 | NUM_FRAMES=33 4 | RESOLUTION=256 5 | METRIC=lpips 6 | SUBSET_SIZE=0 7 | ORIGIN_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE}/origin 8 | RECON_DIR=video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} 9 | 10 | python opensora/models/causalvideovae/eval/eval.py \ 11 | --batch_size 8 \ 12 | --real_video_dir ${ORIGIN_DIR} \ 13 | --generated_video_dir ${RECON_DIR} \ 14 | --device cuda:1 \ 15 | --sample_fps 1 \ 16 | --sample_rate ${SAMPLE_RATE} \ 17 | --num_frames ${NUM_FRAMES} \ 18 | --resolution ${RESOLUTION} \ 19 | --crop_size ${RESOLUTION} \ 20 | --subset_size ${SUBSET_SIZE} \ 21 | --metric ${METRIC} -------------------------------------------------------------------------------- /scripts/causalvae/prepare_eval.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | DATASET_DIR=test_video 3 | EXP_NAME=wfvae 4 | SAMPLE_RATE=1 5 | NUM_FRAMES=33 6 | RESOLUTION=256 7 | CKPT=ckpt 8 | SUBSET_SIZE=0 9 | 10 | accelerate launch \ 11 | --config_file scripts/accelerate_configs/default_config.yaml \ 12 | opensora/models/causalvideovae/sample/rec_video_vae.py \ 13 | --batch_size 1 \ 14 | --real_video_dir ${DATASET_DIR} \ 15 | --generated_video_dir video_gen/${EXP_NAME}_sr${SAMPLE_RATE}_nf${NUM_FRAMES}_res${RESOLUTION}_subset${SUBSET_SIZE} \ 16 | --device cuda \ 17 | --sample_fps 24 \ 18 | --sample_rate ${SAMPLE_RATE} \ 19 | --num_frames ${NUM_FRAMES} \ 20 | --resolution ${RESOLUTION} \ 21 | --subset_size ${SUBSET_SIZE} \ 22 | --num_workers 8 \ 23 | --from_pretrained ${CKPT} \ 24 | --model_name WFVAE \ 25 | --output_origin \ 26 | --crop_size ${RESOLUTION} 27 | -------------------------------------------------------------------------------- /scripts/causalvae/rec_image.sh: -------------------------------------------------------------------------------- 1 | 
CUDA_VISIBLE_DEVICES=0 python examples/rec_image.py \ 2 | --ae WFVAEModel_D8_4x8x8 \ 3 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 4 | --image_path /storage/dataset/image/anytext3m/ocr_data/Art/images/gt_5544.jpg \ 5 | --rec_path rec_.jpg \ 6 | --device cuda \ 7 | --short_size 512 -------------------------------------------------------------------------------- /scripts/causalvae/rec_video.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python examples/rec_video.py \ 2 | --ae WFVAEModel_D8_4x8x8 \ 3 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 4 | --video_path /storage/lcm/WF-VAE/testvideo/gm1190263332-337350271.mp4 \ 5 | --rec_path rec_tile_.mp4 \ 6 | --device cuda \ 7 | --sample_rate 1 \ 8 | --num_frames 65 \ 9 | --height 512 \ 10 | --width 512 \ 11 | --fps 30 \ 12 | --enable_tiling -------------------------------------------------------------------------------- /scripts/causalvae/train.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=WFVAE 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 3 | export GLOO_SOCKET_IFNAME=bond0 4 | export NCCL_SOCKET_IFNAME=bond0 5 | export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_IB_TC=162 8 | export NCCL_IB_TIMEOUT=22 9 | export NCCL_PXN_DISABLE=0 10 | export NCCL_IB_QPS_PER_CONNECTION=4 11 | export NCCL_ALGO=Ring 12 | export OMP_NUM_THREADS=1 13 | export MKL_NUM_THREADS=1 14 | 15 | EXP_NAME=TRAIN 16 | 17 | torchrun \ 18 | --nnodes=1 --nproc_per_node=8 \ 19 | --master_addr=localhost \ 20 | --master_port=12133 \ 21 | opensora/train/train_causalvae.py \ 22 | --exp_name ${EXP_NAME} \ 23 | --video_path /storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_train/ \ 24 | --eval_video_path /storage/dataset/vae_eval/OpenMMLab___Kinetics-400/raw/Kinetics-400/videos_val/ \ 25 | --model_name WFVAE \ 26 | --model_config scripts/causalvae/wfvae_4dim.json \ 27 | --resolution 256 \ 28 | --num_frames 25 \ 29 | --batch_size 1 \ 30 | --lr 0.00001 \ 31 | --epochs 4 \ 32 | --disc_start 0 \ 33 | --save_ckpt_step 5000 \ 34 | --eval_steps 1000 \ 35 | --eval_batch_size 1 \ 36 | --eval_num_frames 33 \ 37 | --eval_sample_rate 1 \ 38 | --eval_subset_size 500 \ 39 | --eval_lpips \ 40 | --ema \ 41 | --ema_decay 0.999 \ 42 | --perceptual_weight 1.0 \ 43 | --loss_type l1 \ 44 | --sample_rate 1 \ 45 | --disc_cls opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator3D \ 46 | --wavelet_loss \ 47 | --wavelet_weight 0.1 -------------------------------------------------------------------------------- /scripts/causalvae/wfvae_4dim.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "WFVAEModel", 3 | "_diffusers_version": "0.30.2", 4 | "base_channels": 128, 5 | "connect_res_layer_num": 1, 6 | "decoder_energy_flow_hidden_size": 128, 7 | "decoder_num_resblocks": 2, 8 | "dropout": 0.0, 9 | "encoder_energy_flow_hidden_size": 128, 10 | "encoder_num_resblocks": 2, 11 | "l1_dowmsample_block": "Downsample", 12 | "l1_downsample_wavelet": "HaarWaveletTransform2D", 13 | "l1_upsample_block": "Upsample", 14 | "l1_upsample_wavelet": "InverseHaarWaveletTransform2D", 15 | "l2_dowmsample_block": "Spatial2xTime2x3DDownsample", 16 | "l2_downsample_wavelet": "HaarWaveletTransform3D", 17 | "l2_upsample_block": "Spatial2xTime2x3DUpsample", 18 | "l2_upsample_wavelet": "InverseHaarWaveletTransform3D", 19 | "latent_dim": 4, 20 | 
"norm_type": "layernorm", 21 | "t_interpolation": "trilinear", 22 | "use_attention": true 23 | } -------------------------------------------------------------------------------- /scripts/slurm/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Open-Sora-Plan/469d4e8810c326811e1be7e1c17b845503633210/scripts/slurm/placeholder -------------------------------------------------------------------------------- /scripts/text_condition/gpu/sample_inpaint_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \ 3 | -m opensora.sample.sample \ 4 | --model_type "inpaint" \ 5 | --model_path model_path \ 6 | --version v1_3 \ 7 | --num_frames 93 \ 8 | --height 352 \ 9 | --width 640 \ 10 | --max_hxw 236544 \ 11 | --crop_for_hw \ 12 | --cache_dir "../cache_dir" \ 13 | --text_encoder_name_1 "/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl" \ 14 | --text_prompt examples/cond_prompt.txt \ 15 | --conditional_pixel_values_path examples/cond_pix_path.txt \ 16 | --ae WFVAEModel_D8_4x8x8 \ 17 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 18 | --save_img_path "./save_path" \ 19 | --fps 18 \ 20 | --guidance_scale 7.5 \ 21 | --num_sampling_steps 100 \ 22 | --max_sequence_length 512 \ 23 | --sample_method EulerAncestralDiscrete \ 24 | --seed 1234 \ 25 | --num_samples_per_prompt 1 \ 26 | --rescale_betas_zero_snr \ 27 | --prediction_type "v_prediction" \ 28 | --noise_strength 0.0 \ -------------------------------------------------------------------------------- /scripts/text_condition/gpu/sample_t2v_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29514 \ 3 | -m opensora.sample.sample \ 4 | --model_path /storage/ongoing/9.29/mmdit/Open-Sora-Plan/final_ft_any93x352x640_v1_3_bs512_lr1e-5_snr5.0_fps16_zsnr_nofix_16node/checkpoint-5500/model_ema \ 5 | --version v1_3 \ 6 | --num_frames 93 \ 7 | --height 352 \ 8 | --width 640 \ 9 | --cache_dir "../cache_dir" \ 10 | --text_encoder_name_1 "/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl" \ 11 | --text_prompt "examples/sora.txt" \ 12 | --ae WFVAEModel_D8_4x8x8 \ 13 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 14 | --save_img_path "./train_1_3_nomotion_fps18" \ 15 | --fps 18 \ 16 | --guidance_scale 7.5 \ 17 | --num_sampling_steps 100 \ 18 | --max_sequence_length 512 \ 19 | --sample_method EulerAncestralDiscrete \ 20 | --seed 1234 \ 21 | --num_samples_per_prompt 1 \ 22 | --rescale_betas_zero_snr \ 23 | --prediction_type "v_prediction" -------------------------------------------------------------------------------- /scripts/text_condition/gpu/train_inpaint_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | export HF_DATASETS_OFFLINE=1 3 | export TRANSFORMERS_OFFLINE=1 4 | export PDSH_RCMD_TYPE=ssh 5 | # NCCL setting 6 | export GLOO_SOCKET_IFNAME=bond0 7 | export NCCL_SOCKET_IFNAME=bond0 8 | export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 9 | export NCCL_IB_GID_INDEX=3 10 | export NCCL_IB_TC=162 11 | export NCCL_IB_TIMEOUT=25 12 | export NCCL_PXN_DISABLE=0 13 | export NCCL_IB_QPS_PER_CONNECTION=4 14 | export NCCL_ALGO=Ring 15 | export OMP_NUM_THREADS=1 16 | export MKL_NUM_THREADS=1 17 | export NCCL_IB_RETRY_CNT=32 18 | # export 
NCCL_ALGO=Tree 19 | 20 | accelerate launch \ 21 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 22 | opensora/train/train_inpaint.py \ 23 | --model OpenSoraInpaint_v1_3-2B/122 \ 24 | --text_encoder_name_1 google/mt5-xxl \ 25 | --cache_dir "../../cache_dir/" \ 26 | --dataset inpaint \ 27 | --data "scripts/train_data/video_data.txt" \ 28 | --ae WFVAEModel_D8_4x8x8 \ 29 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 30 | --sample_rate 1 \ 31 | --num_frames 93 \ 32 | --max_hxw 236544 \ 33 | --min_hxw 102400 \ 34 | --interpolation_scale_t 1.0 \ 35 | --interpolation_scale_h 1.0 \ 36 | --interpolation_scale_w 1.0 \ 37 | --gradient_checkpointing \ 38 | --train_batch_size=1 \ 39 | --dataloader_num_workers 8 \ 40 | --gradient_accumulation_steps=1 \ 41 | --max_train_steps=1000000 \ 42 | --learning_rate=1e-5 \ 43 | --lr_scheduler="constant" \ 44 | --lr_warmup_steps=0 \ 45 | --mixed_precision="bf16" \ 46 | --report_to="wandb" \ 47 | --checkpointing_steps=1000 \ 48 | --allow_tf32 \ 49 | --model_max_length 512 \ 50 | --use_ema \ 51 | --ema_start_step 0 \ 52 | --cfg 0.1 \ 53 | --resume_from_checkpoint="latest" \ 54 | --speed_factor 1.0 \ 55 | --ema_decay 0.9999 \ 56 | --drop_short_ratio 0.0 \ 57 | --hw_stride 32 \ 58 | --sparse1d --sparse_n 4 \ 59 | --train_fps 18 \ 60 | --seed 1234 \ 61 | --trained_data_global_step 0 \ 62 | --group_data \ 63 | --use_decord \ 64 | --prediction_type "v_prediction" \ 65 | --output_dir="debug" \ 66 | --rescale_betas_zero_snr \ 67 | --mask_config scripts/train_configs/mask_config.yaml \ 68 | --add_noise_to_condition \ 69 | --default_text_ratio 0.5 70 | # --pretrained "" 71 | -------------------------------------------------------------------------------- /scripts/text_condition/gpu/train_t2v_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | export HF_DATASETS_OFFLINE=1 3 | export TRANSFORMERS_OFFLINE=1 4 | export PDSH_RCMD_TYPE=ssh 5 | # NCCL setting 6 | export GLOO_SOCKET_IFNAME=bond0 7 | export NCCL_SOCKET_IFNAME=bond0 8 | export NCCL_IB_HCA=mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1 9 | export NCCL_IB_GID_INDEX=3 10 | export NCCL_IB_TC=162 11 | export NCCL_IB_TIMEOUT=25 12 | export NCCL_PXN_DISABLE=0 13 | export NCCL_IB_QPS_PER_CONNECTION=4 14 | export NCCL_ALGO=Ring 15 | export OMP_NUM_THREADS=1 16 | export MKL_NUM_THREADS=1 17 | export NCCL_IB_RETRY_CNT=32 18 | # export NCCL_ALGO=Tree 19 | 20 | accelerate launch \ 21 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 22 | opensora/train/train_t2v_diffusers.py \ 23 | --model OpenSoraT2V_v1_3-2B/122 \ 24 | --text_encoder_name_1 google/mt5-xxl \ 25 | --cache_dir "../../cache_dir/" \ 26 | --dataset t2v \ 27 | --data "scripts/train_data/merge_data.txt" \ 28 | --ae WFVAEModel_D8_4x8x8 \ 29 | --ae_path "/storage/lcm/WF-VAE/results/latent8" \ 30 | --sample_rate 1 \ 31 | --num_frames 1 \ 32 | --max_height 352 \ 33 | --max_width 640 \ 34 | --interpolation_scale_t 1.0 \ 35 | --interpolation_scale_h 1.0 \ 36 | --interpolation_scale_w 1.0 \ 37 | --gradient_checkpointing \ 38 | --train_batch_size=4 \ 39 | --dataloader_num_workers 16 \ 40 | --gradient_accumulation_steps=1 \ 41 | --max_train_steps=1000000 \ 42 | --learning_rate=1e-5 \ 43 | --lr_scheduler="constant" \ 44 | --lr_warmup_steps=0 \ 45 | --mixed_precision="bf16" \ 46 | --report_to="wandb" \ 47 | --checkpointing_steps=500 \ 48 | --allow_tf32 \ 49 | --model_max_length 512 \ 50 | --use_ema \ 51 | --ema_start_step 0 \ 52 | --cfg 0.1 \ 53 | --resume_from_checkpoint="latest" \ 54 | 
--speed_factor 1.0 \ 55 | --ema_decay 0.9999 \ 56 | --drop_short_ratio 0.0 \ 57 | --pretrained "" \ 58 | --hw_stride 32 \ 59 | --sparse1d --sparse_n 4 \ 60 | --train_fps 16 \ 61 | --seed 1234 \ 62 | --trained_data_global_step 0 \ 63 | --group_data \ 64 | --use_decord \ 65 | --prediction_type "v_prediction" \ 66 | --snr_gamma 5.0 \ 67 | --force_resolution \ 68 | --rescale_betas_zero_snr \ 69 | --output_dir="debug" -------------------------------------------------------------------------------- /scripts/text_condition/npu/sample_inpaint_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | export TASK_QUEUE_ENABLE=0 3 | torchrun --nnodes=1 --nproc_per_node 8 --master_port 29522 \ 4 | -m opensora.sample.sample \ 5 | --model_type "inpaint" \ 6 | --model_path model_path \ 7 | --version v1_3 \ 8 | --num_frames 93 \ 9 | --crop_for_hw \ 10 | --height 352 \ 11 | --width 640 \ 12 | --max_hxw 236544 \ 13 | --cache_dir "../cache_dir" \ 14 | --text_encoder_name_1 "/home/save_dir/pretrained/mt5-xxl" \ 15 | --text_prompt /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/prompt.txt \ 16 | --conditional_pixel_values_path /home/image_data/gyy/mmdit/Open-Sora-Plan/validation_dir/cond_imgs_path.txt \ 17 | --ae WFVAEModel_D8_4x8x8 \ 18 | --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ 19 | --save_img_path "./test" \ 20 | --fps 18 \ 21 | --guidance_scale 7.5 \ 22 | --num_sampling_steps 50 \ 23 | --max_sequence_length 512 \ 24 | --sample_method EulerAncestralDiscrete \ 25 | --seed 2514 \ 26 | --num_samples_per_prompt 1 \ 27 | --prediction_type "v_prediction" \ 28 | --rescale_betas_zero_snr \ 29 | --noise_strength 0.0 \ 30 | # --mask_type i2v \ 31 | # --enable_tiling 32 | -------------------------------------------------------------------------------- /scripts/text_condition/npu/sample_t2v_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node 8 --master_port 29513 \ 3 | -m opensora.sample.sample \ 4 | --model_path model_path \ 5 | --version v1_3 \ 6 | --num_frames 93 \ 7 | --height 352 \ 8 | --width 640 \ 9 | --cache_dir "../cache_dir" \ 10 | --text_encoder_name_1 "/home/save_dir/pretrained/mt5-xxl" \ 11 | --text_prompt examples/sora_refine.txt \ 12 | --ae WFVAEModel_D8_4x8x8 \ 13 | --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ 14 | --save_img_path "./test" \ 15 | --fps 18 \ 16 | --guidance_scale 7.5 \ 17 | --num_sampling_steps 100 \ 18 | --max_sequence_length 512 \ 19 | --sample_method EulerAncestralDiscrete \ 20 | --seed 1234 \ 21 | --num_samples_per_prompt 1 \ 22 | --rescale_betas_zero_snr \ 23 | --prediction_type "v_prediction" -------------------------------------------------------------------------------- /scripts/text_condition/npu/train_inpaint_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | export PROJECT=$PROJECT_NAME 3 | # export PROJECT='test' 4 | export HF_DATASETS_OFFLINE=1 5 | export TRANSFORMERS_OFFLINE=1 6 | 7 | export TASK_QUEUE_ENABLE=0 8 | export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE 9 | export MULTI_STREAM_MEMORY_REUSE=1 10 | export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True 11 | # export HCCL_ALGO="level0:NA;level1:H-D_R" 12 | # --machine_rank=${MACHINE_RANK} \ 13 | # --main_process_ip=${MAIN_PROCESS_IP_VALUE} \ 14 | # multi_node_example_by_deepspeed.yaml 15 | # deepspeed_zero2_config.yaml 16 | 17 | accelerate launch \ 18 | --config_file 
scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 19 | opensora/train/train_inpaint.py \ 20 | --model OpenSoraInpaint_v1_3-2B/122 \ 21 | --text_encoder_name_1 google/mt5-xxl \ 22 | --cache_dir "../../cache_dir/" \ 23 | --dataset inpaint \ 24 | --data "scripts/train_data/video_data.txt" \ 25 | --ae WFVAEModel_D8_4x8x8 \ 26 | --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ 27 | --vae_fp32 \ 28 | --sample_rate 1 \ 29 | --num_frames 93 \ 30 | --max_hxw 236544 \ 31 | --min_hxw 102400 \ 32 | --snr_gamma 5.0 \ 33 | --interpolation_scale_t 1.0 \ 34 | --interpolation_scale_h 1.0 \ 35 | --interpolation_scale_w 1.0 \ 36 | --gradient_checkpointing \ 37 | --train_batch_size=1 \ 38 | --dataloader_num_workers 8 \ 39 | --gradient_accumulation_steps=1 \ 40 | --max_train_steps=1000000 \ 41 | --learning_rate=1e-5 \ 42 | --lr_scheduler="constant" \ 43 | --lr_warmup_steps=0 \ 44 | --mixed_precision="bf16" \ 45 | --report_to="wandb" \ 46 | --checkpointing_steps=500 \ 47 | --allow_tf32 \ 48 | --model_max_length 512 \ 49 | --use_ema \ 50 | --ema_start_step 0 \ 51 | --cfg 0.1 \ 52 | --speed_factor 1.0 \ 53 | --ema_decay 0.9999 \ 54 | --drop_short_ratio 0.0 \ 55 | --hw_stride 32 \ 56 | --sparse1d --sparse_n=4 \ 57 | --train_fps 16 \ 58 | --seed 1234 \ 59 | --trained_data_global_step 0 \ 60 | --group_data \ 61 | --use_decord \ 62 | --prediction_type "v_prediction" \ 63 | --output_dir="/home/save_dir/runs/$PROJECT" \ 64 | --mask_config scripts/train_configs/mask_config.yaml \ 65 | --add_noise_to_condition \ 66 | --default_text_ratio 0.5 \ 67 | --resume_from_checkpoint="latest" 68 | # --pretrained "/home/save_dir/pretrained/93x640x640_144k_ema" 69 | # --force_resolution 70 | # --force_resolution \ 71 | # --max_height 352 \ 72 | # --max_width 640 \ 73 | -------------------------------------------------------------------------------- /scripts/text_condition/npu/train_t2v_v1_3.sh: -------------------------------------------------------------------------------- 1 | 2 | export PROJECT=$PROJECT_NAME 3 | # export PROJECT='test' 4 | export HF_DATASETS_OFFLINE=1 5 | export TRANSFORMERS_OFFLINE=1 6 | 7 | export TASK_QUEUE_ENABLE=0 8 | export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE 9 | export MULTI_STREAM_MEMORY_REUSE=1 10 | export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True 11 | # export HCCL_ALGO="level0:NA;level1:H-D_R" 12 | # --machine_rank=${MACHINE_RANK} \ 13 | # --main_process_ip=${MAIN_PROCESS_IP_VALUE} \ 14 | # multi_node_example_by_deepspeed.yaml 15 | # deepspeed_zero2_config.yaml 16 | 17 | accelerate launch \ 18 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 19 | opensora/train/train_t2v_diffusers.py \ 20 | --model OpenSoraT2V_v1_3-2B/122 \ 21 | --text_encoder_name_1 google/mt5-xxl \ 22 | --cache_dir "../../cache_dir/" \ 23 | --dataset t2v \ 24 | --data "scripts/train_data/video_data_debug_on_npu.txt" \ 25 | --ae WFVAEModel_D8_4x8x8 \ 26 | --ae_path "/home/save_dir/lzj/formal_8dim/latent8" \ 27 | --sample_rate 1 \ 28 | --num_frames 93 \ 29 | --max_height 352 \ 30 | --max_width 640 \ 31 | --force_resolution \ 32 | --interpolation_scale_t 1.0 \ 33 | --interpolation_scale_h 1.0 \ 34 | --interpolation_scale_w 1.0 \ 35 | --gradient_checkpointing \ 36 | --train_batch_size=1 \ 37 | --dataloader_num_workers 8 \ 38 | --gradient_accumulation_steps=1 \ 39 | --max_train_steps=1000000 \ 40 | --learning_rate=1e-5 \ 41 | --lr_scheduler="constant" \ 42 | --lr_warmup_steps=0 \ 43 | --mixed_precision="bf16" \ 44 | --report_to="wandb" \ 45 | --checkpointing_steps=500 \ 46 | --allow_tf32 \ 47 | 
--model_max_length 512 \ 48 | --use_ema \ 49 | --ema_start_step 0 \ 50 | --cfg 0.1 \ 51 | --resume_from_checkpoint="latest" \ 52 | --speed_factor 1.0 \ 53 | --ema_decay 0.9999 \ 54 | --drop_short_ratio 0.0 \ 55 | --pretrained "/home/save_dir/pretrained/93x640x640_144k_ema" \ 56 | --hw_stride 32 \ 57 | --sparse1d --sparse_n 4 \ 58 | --train_fps 16 \ 59 | --seed 1234 \ 60 | --trained_data_global_step 0 \ 61 | --group_data \ 62 | --use_decord \ 63 | --prediction_type "v_prediction" \ 64 | --snr_gamma 5.0 \ 65 | --rescale_betas_zero_snr \ 66 | --output_dir="debug" -------------------------------------------------------------------------------- /scripts/train_configs/mask_config.yaml: -------------------------------------------------------------------------------- 1 | # mask processor args 2 | min_clear_ratio: 0.0 3 | max_clear_ratio: 1.0 4 | 5 | # mask_type_ratio_dict_video 6 | mask_type_ratio_dict_video: 7 | t2iv: 1 8 | i2v: 8 9 | transition: 8 10 | continuation: 2 11 | clear: 0 12 | random_temporal: 1 13 | 14 | mask_type_ratio_dict_image: 15 | t2iv: 0 16 | clear: 0 -------------------------------------------------------------------------------- /scripts/train_data/merge_data.txt: -------------------------------------------------------------------------------- 1 | /storage/dataset/recap_datacomp_1b_data/output,/storage/anno_pkl/img_nocn_res160_pkl/recap_64part_filter_aes_res160_pkl/part0_7036495.pkl --------------------------------------------------------------------------------
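A short, informal sketch of how the two data-side files above fit together; the helper below is hypothetical and not part of the repository, and how the training code actually consumes these files is an assumption.

# Illustrative only: the mask_type_ratio_dict_video entries in mask_config.yaml are relative weights,
# so they would be normalized before sampling a mask type (1 + 8 + 8 + 2 + 0 + 1 = 20, i.e. i2v ~ 40%).
mask_type_ratio_dict_video = {"t2iv": 1, "i2v": 8, "transition": 8, "continuation": 2, "clear": 0, "random_temporal": 1}
total = sum(mask_type_ratio_dict_video.values())
probs = {k: v / total for k, v in mask_type_ratio_dict_video.items()}

# Illustrative only: each line of scripts/train_data/merge_data.txt is "<media_root>,<annotation_pkl>".
def parse_merge_data(path):  # hypothetical helper, not the repository's API
    pairs = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                root, anno = line.split(",", 1)
                pairs.append((root, anno))
    return pairs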