├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── docker ├── LICENSE ├── README.md ├── build_docker.png ├── docker_build.sh ├── docker_run.sh ├── dockerfile.base ├── packages.txt ├── ports.txt ├── postinstallscript.sh ├── requirements.txt ├── run_docker.png └── setup_env.sh ├── docs ├── Contribution_Guidelines.md ├── Data.md ├── EVAL.md ├── Report-v1.0.0.md └── VQVAE.md ├── examples ├── get_latents_std.py ├── prompt_list_0.txt ├── rec_image.py ├── rec_imvi_vae.py ├── rec_video.py ├── rec_video_ae.py └── rec_video_vae.py ├── nodes.py ├── opensora ├── __init__.py ├── dataset │ ├── __init__.py │ ├── extract_feature_dataset.py │ ├── feature_datasets.py │ ├── landscope.py │ ├── sky_datasets.py │ ├── t2v_datasets.py │ ├── transform.py │ └── ucf101.py ├── eval │ ├── cal_flolpips.py │ ├── cal_fvd.py │ ├── cal_lpips.py │ ├── cal_psnr.py │ ├── cal_ssim.py │ ├── eval_clip_score.py │ ├── eval_common_metric.py │ ├── flolpips │ │ ├── correlation │ │ │ └── correlation.py │ │ ├── flolpips.py │ │ ├── pretrained_networks.py │ │ ├── pwcnet.py │ │ └── utils.py │ ├── fvd │ │ ├── styleganv │ │ │ ├── fvd.py │ │ │ └── i3d_torchscript.pt │ │ └── videogpt │ │ │ ├── fvd.py │ │ │ ├── i3d_pretrained_400.pt │ │ │ └── pytorch_i3d.py │ └── script │ │ ├── cal_clip_score.sh │ │ ├── cal_fvd.sh │ │ ├── cal_lpips.sh │ │ ├── cal_psnr.sh │ │ └── cal_ssim.sh ├── models │ ├── __init__.py │ ├── ae │ │ ├── __init__.py │ │ ├── imagebase │ │ │ ├── __init__.py │ │ │ ├── vae │ │ │ │ └── vae.py │ │ │ └── vqvae │ │ │ │ ├── model.py │ │ │ │ ├── quantize.py │ │ │ │ ├── vqgan.py │ │ │ │ └── vqvae.py │ │ └── videobase │ │ │ ├── __init__.py │ │ │ ├── causal_vae │ │ │ ├── __init__.py │ │ │ └── modeling_causalvae.py │ │ │ ├── causal_vqvae │ │ │ ├── __init__.py │ │ │ ├── configuration_causalvqvae.py │ │ │ ├── modeling_causalvqvae.py │ │ │ └── trainer_causalvqvae.py │ │ │ ├── configuration_videobase.py │ │ │ ├── dataset_videobase.py │ │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── discriminator.py │ │ │ ├── lpips.py │ │ │ └── perceptual_loss.py │ │ │ ├── modeling_videobase.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── block.py │ │ │ ├── conv.py │ │ │ ├── normalize.py │ │ │ ├── ops.py │ │ │ ├── quant.py │ │ │ ├── resnet_block.py │ │ │ └── updownsample.py │ │ │ ├── trainer_videobase.py │ │ │ ├── utils │ │ │ ├── distrib_utils.py │ │ │ ├── module_utils.py │ │ │ ├── scheduler_utils.py │ │ │ └── video_utils.py │ │ │ └── vqvae │ │ │ ├── __init__.py │ │ │ ├── configuration_vqvae.py │ │ │ ├── modeling_vqvae.py │ │ │ └── trainer_vqvae.py │ ├── captioner │ │ └── caption_refiner │ │ │ ├── README.md │ │ │ ├── caption_refiner.py │ │ │ ├── dataset │ │ │ └── test_videos │ │ │ │ ├── captions.json │ │ │ │ ├── video1.gif │ │ │ │ └── video2.gif │ │ │ ├── demo_for_refiner.py │ │ │ └── gpt_combinator.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── diffusion_utils.py │ │ │ ├── gaussian_diffusion.py │ │ │ ├── gaussian_diffusion_t2v.py │ │ │ ├── respace.py │ │ │ └── timestep_sampler.py │ │ ├── latte │ │ │ ├── modeling_latte.py │ │ │ ├── modules.py │ │ │ └── pos.py │ │ ├── transport │ │ │ ├── __init__.py │ │ │ ├── integrators.py │ │ │ ├── path.py │ │ │ ├── transport.py │ │ │ └── utils.py │ │ └── utils │ │ │ ├── curope │ │ │ ├── __init__.py │ │ │ ├── curope.cpp │ │ │ ├── curope2d.py │ │ │ ├── kernels.cu │ │ │ └── setup.py │ │ │ └── pos_embed.py │ ├── frame_interpolation │ │ ├── cfgs │ │ │ └── AMT-G.yaml │ │ ├── interpolation.py │ │ ├── networks │ │ │ ├── AMT-G.py │ │ │ ├── __init__.py │ │ │ 
└── blocks │ │ │ │ ├── __init__.py │ │ │ │ ├── feat_enc.py │ │ │ │ ├── ifrnet.py │ │ │ │ ├── multi_flow.py │ │ │ │ └── raft.py │ │ ├── readme.md │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── build_utils.py │ │ │ ├── dist_utils.py │ │ │ ├── flow_utils.py │ │ │ └── utils.py │ ├── super_resolution │ │ ├── README.md │ │ ├── basicsr │ │ │ ├── __init__.py │ │ │ ├── archs │ │ │ │ ├── __init__.py │ │ │ │ ├── arch_util.py │ │ │ │ ├── rgt_arch.py │ │ │ │ └── vgg_arch.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── data_sampler.py │ │ │ │ ├── data_util.py │ │ │ │ ├── paired_image_dataset.py │ │ │ │ ├── prefetch_dataloader.py │ │ │ │ ├── single_image_dataset.py │ │ │ │ └── transforms.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── loss_util.py │ │ │ │ └── losses.py │ │ │ ├── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── metric_util.py │ │ │ │ └── psnr_ssim.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── lr_scheduler.py │ │ │ │ ├── rgt_model.py │ │ │ │ └── sr_model.py │ │ │ ├── test_img.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── dist_util.py │ │ │ │ ├── file_client.py │ │ │ │ ├── img_util.py │ │ │ │ ├── logger.py │ │ │ │ ├── matlab_functions.py │ │ │ │ ├── misc.py │ │ │ │ ├── options.py │ │ │ │ └── registry.py │ │ ├── options │ │ │ └── test │ │ │ │ ├── test_RGT_x2.yml │ │ │ │ ├── test_RGT_x4.yml │ │ │ │ └── test_single_config.yml │ │ └── run.py │ └── text_encoder │ │ ├── __init__.py │ │ ├── clip.py │ │ └── t5.py ├── sample │ ├── pipeline_videogen.py │ ├── sample.py │ ├── sample_t2v.py │ └── transport_sample.py ├── serve │ ├── gradio_utils.py │ └── gradio_web_server.py ├── train │ ├── train.py │ ├── train_causalvae.py │ ├── train_t2v.py │ ├── train_t2v_feature.py │ ├── train_t2v_t5_feature.py │ └── train_videogpt.py └── utils │ ├── dataset_utils.py │ ├── downloader.py │ ├── taming_download.py │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── accelerate_configs │ ├── ddp_config.yaml │ ├── deepspeed_zero2_config.yaml │ ├── deepspeed_zero2_offload_config.yaml │ ├── deepspeed_zero3_config.yaml │ ├── deepspeed_zero3_offload_config.yaml │ ├── default_config.yaml │ ├── hostfile │ ├── multi_node_example.yaml │ ├── zero2.json │ ├── zero2_offload.json │ ├── zero3.json │ └── zero3_offload.json ├── causalvae │ ├── eval.sh │ ├── gen_video.sh │ ├── release.json │ └── train.sh ├── class_condition │ ├── sample.sh │ ├── train_imgae.sh │ └── train_vidae.sh ├── slurm │ └── placeholder ├── text_condition │ ├── sample_image.sh │ ├── sample_video.sh │ ├── train_imageae.sh │ ├── train_videoae_17x256x256.sh │ ├── train_videoae_65x256x256.sh │ └── train_videoae_65x512x512.sh ├── un_condition │ ├── sample.sh │ ├── train_imgae.sh │ └── train_vidae.sh └── videogpt │ ├── train_videogpt.sh │ ├── train_videogpt_dsz2.sh │ └── train_videogpt_dsz3.sh ├── wf.json └── wf.png /.gitignore: -------------------------------------------------------------------------------- 1 | ucf101_stride4x4x4 2 | __pycache__ 3 | *.mp4 4 | .ipynb_checkpoints 5 | *.pth 6 | UCF-101/ 7 | results/ 8 | vae 9 | build/ 10 | opensora.egg-info/ 11 | wandb/ 12 | .idea 13 | *.ipynb 14 | *.jpg 15 | *.mp3 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-Open-Sora-Plan 2 | 3 | ## workflow 4 | 5 | https://github.com/chaojie/ComfyUI-Open-Sora-Plan/blob/main/wf.json 6 | 7 | 8 | 9 | ### [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) 10 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS'] -------------------------------------------------------------------------------- /docker/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SimonLee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker4ML 2 | 3 | Useful docker scripts for ML developement. 
4 | [https://github.com/SimonLeeGit/Docker4ML](https://github.com/SimonLeeGit/Docker4ML) 5 | 6 | ## Build Docker Image 7 | 8 | ```bash 9 | bash docker_build.sh 10 | ``` 11 | 12 | ![build_docker](build_docker.png) 13 | 14 | ## Run Docker Container as Development Envirnoment 15 | 16 | ```bash 17 | bash docker_run.sh 18 | ``` 19 | 20 | ![run_docker](run_docker.png) 21 | 22 | ## Custom Docker Config 23 | 24 | ### Config [setup_env.sh](./setup_env.sh) 25 | 26 | You can modify this file to custom your settings. 27 | 28 | ```bash 29 | TAG=ml:dev 30 | BASE_TAG=nvcr.io/nvidia/pytorch:23.12-py3 31 | ``` 32 | 33 | #### TAG 34 | 35 | Your built docker image tag, you can set it as what you what. 36 | 37 | #### BASE_TAG 38 | 39 | The base docker image tag for your built docker image, here we use nvidia pytorch images. 40 | You can check it from [https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) 41 | 42 | Also, you can use other docker image as base, such as: [ubuntu](https://hub.docker.com/_/ubuntu/tags) 43 | 44 | ### USER_NAME 45 | 46 | Your user name used in docker container. 47 | 48 | ### USER_PASSWD 49 | 50 | Your user password used in docker container. 51 | 52 | ### Config [requriements.txt](./requirements.txt) 53 | 54 | You can add your default installed python libraries here. 55 | 56 | ```txt 57 | transformers==4.27.1 58 | ``` 59 | 60 | By default, it has some libs installed, you can check it from [https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html) 61 | 62 | ### Config [packages.txt](./packages.txt) 63 | 64 | You can add your default apt-get installed packages here. 65 | 66 | ```txt 67 | wget 68 | curl 69 | git 70 | ``` 71 | 72 | ### Config [ports.txt](./ports.txt) 73 | 74 | You can add some ports enabled for docker container here. 75 | 76 | ```txt 77 | -p 6006:6006 78 | -p 8080:8080 79 | ``` 80 | 81 | ### Config [postinstallscript.sh](./postinstallscript.sh) 82 | 83 | You can add your custom script to run when build docker image. 84 | 85 | ## Q&A 86 | 87 | If you have any use problems, please contact to . 88 | -------------------------------------------------------------------------------- /docker/build_docker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/docker/build_docker.png -------------------------------------------------------------------------------- /docker/docker_build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | WORK_DIR=$(dirname "$(readlink -f "$0")") 4 | cd $WORK_DIR 5 | 6 | source setup_env.sh 7 | 8 | docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . 
-f dockerfile.base 9 | -------------------------------------------------------------------------------- /docker/docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | WORK_DIR=$(dirname "$(readlink -f "$0")") 4 | source $WORK_DIR/setup_env.sh 5 | 6 | RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")" 7 | 8 | if [ -n "$RUNNING_IDS" ]; then 9 | # Initialize an array to hold the container IDs 10 | declare -a container_ids=($RUNNING_IDS) 11 | 12 | # Get the first container ID using array indexing 13 | ID=${container_ids[0]} 14 | 15 | # Print the first container ID 16 | echo ' ' 17 | echo "The running container ID is: $ID, enter it!" 18 | else 19 | echo ' ' 20 | echo "Not found running containers, run it!" 21 | 22 | # Run a new docker container instance 23 | ID=$(docker run \ 24 | --rm \ 25 | --gpus all \ 26 | -itd \ 27 | --ipc=host \ 28 | --ulimit memlock=-1 \ 29 | --ulimit stack=67108864 \ 30 | -e DISPLAY=$DISPLAY \ 31 | -v /tmp/.X11-unix/:/tmp/.X11-unix/ \ 32 | -v $PWD:/home/$USER_NAME/workspace \ 33 | -w /home/$USER_NAME/workspace \ 34 | $(cat $WORK_DIR/ports.txt) \ 35 | $TAG) 36 | fi 37 | 38 | docker logs $ID 39 | 40 | echo ' ' 41 | echo ' ' 42 | echo '=========================================' 43 | echo ' ' 44 | 45 | docker exec -it $ID bash 46 | -------------------------------------------------------------------------------- /docker/dockerfile.base: -------------------------------------------------------------------------------- 1 | ARG BASE_TAG 2 | FROM ${BASE_TAG} 3 | ARG USER_NAME=myuser 4 | ARG USER_PASSWD=111111 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | # Pre-install packages, pip install requirements and run post install script. 8 | COPY packages.txt . 9 | COPY requirements.txt . 10 | COPY postinstallscript.sh . 11 | RUN apt-get update && apt-get install -y sudo $(cat packages.txt) 12 | RUN pip install --no-cache-dir -r requirements.txt 13 | RUN bash postinstallscript.sh 14 | 15 | # Create a new user and group using the username argument 16 | RUN groupadd -r ${USER_NAME} && useradd -r -m -g${USER_NAME} ${USER_NAME} 17 | RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd 18 | RUN usermod -aG sudo ${USER_NAME} 19 | USER ${USER_NAME} 20 | ENV USER=${USER_NAME} 21 | WORKDIR /home/${USER_NAME}/workspace 22 | 23 | # Set the prompt to highlight the username 24 | RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc 25 | -------------------------------------------------------------------------------- /docker/packages.txt: -------------------------------------------------------------------------------- 1 | wget 2 | curl 3 | git -------------------------------------------------------------------------------- /docker/ports.txt: -------------------------------------------------------------------------------- 1 | -p 6006:6006 -------------------------------------------------------------------------------- /docker/postinstallscript.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # this script will run when build docker image. 
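# Example post-install step (an assumption for illustration, not part of the original script;
# replace or remove it to suit your image):
# pip install --no-cache-dir jupyterlab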
3 | 4 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools>=61.0 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | transformers==4.32.0 5 | albumentations==1.4.0 6 | av==11.0.0 7 | decord==0.6.0 8 | einops==0.3.0 9 | fastapi==0.110.0 10 | accelerate==0.21.0 11 | gdown==5.1.0 12 | h5py==3.10.0 13 | idna==3.6 14 | imageio==2.34.0 15 | matplotlib==3.7.5 16 | numpy==1.24.4 17 | omegaconf==2.1.1 18 | opencv-python==4.9.0.80 19 | opencv-python-headless==4.9.0.80 20 | pandas==2.0.3 21 | pillow==10.2.0 22 | pydub==0.25.1 23 | pytorch-lightning==1.4.2 24 | pytorchvideo==0.1.5 25 | PyYAML==6.0.1 26 | regex==2023.12.25 27 | requests==2.31.0 28 | scikit-learn==1.3.2 29 | scipy==1.10.1 30 | six==1.16.0 31 | tensorboard==2.14.0 32 | test-tube==0.7.5 33 | timm==0.9.16 34 | torchdiffeq==0.2.3 35 | torchmetrics==0.5.0 36 | tqdm==4.66.2 37 | urllib3==2.2.1 38 | uvicorn==0.27.1 39 | diffusers==0.24.0 40 | scikit-video==1.1.11 41 | -------------------------------------------------------------------------------- /docker/run_docker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/docker/run_docker.png -------------------------------------------------------------------------------- /docker/setup_env.sh: -------------------------------------------------------------------------------- 1 | # Docker tag for new build image 2 | TAG=open_sora_plan:dev 3 | 4 | # Base docker image tag used by docker build 5 | BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3 6 | 7 | # User name used in docker container 8 | USER_NAME=developer 9 | 10 | # User password used in docker container 11 | USER_PASSWD=666666 -------------------------------------------------------------------------------- /docs/Contribution_Guidelines.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Open-Sora Plan Community 2 | 3 | The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights! 4 | 5 | ## Submitting a Pull Request (PR) 6 | 7 | As a contributor, before submitting your request, kindly follow these guidelines: 8 | 9 | 1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work. 10 | 11 | 2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine. 12 | 13 | ```bash 14 | git clone [your-forked-repository-url] 15 | ``` 16 | 17 | 3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates: 18 | 19 | ```bash 20 | git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan 21 | ``` 22 | 23 | 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository. 
24 | 25 | ``` 26 | # Pull the latest code from the upstream branch 27 | git fetch upstream 28 | 29 | # Switch to the main branch 30 | git checkout main 31 | 32 | # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream 33 | git merge upstream/main 34 | 35 | # Additionally, sync the local main branch to the remote branch of your forked repository 36 | git push origin main 37 | ``` 38 | 39 | 40 | > Note: Sync the code from the main repository before each submission. 41 | 42 | 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful. 43 | 44 | ```bash 45 | git checkout -b my-docs-branch main 46 | ``` 47 | 48 | 6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format). 49 | 50 | ```bash 51 | git commit -m "[docs]: xxxx" 52 | ``` 53 | 54 | 7. Push your changes to your GitHub repository. 55 | 56 | ```bash 57 | git push origin my-docs-branch 58 | ``` 59 | 60 | 8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page. 61 | 62 | ## Commit Message Format 63 | 64 | Commit messages must include both `` and `` sections. 65 | 66 | ```bash 67 | []: 68 | │ │ 69 | │ └─⫸ Briefly describe your changes, without ending with a period. 70 | │ 71 | └─⫸ Commit Type: |docs|feat|fix|refactor| 72 | ``` 73 | 74 | ### Type 75 | 76 | * **docs**: Modify or add documents. 77 | * **feat**: Introduce a new feature. 78 | * **fix**: Fix a bug. 79 | * **refactor**: Restructure code, excluding new features or bug fixes. 80 | 81 | ### Summary 82 | 83 | Describe modifications in English, without ending with a period. 84 | 85 | > e.g., git commit -m "[docs]: add a contributing.md file" 86 | 87 | This guideline is borrowed by [minisora](https://github.com/mini-sora/minisora). We sincerely appreciate MiniSora authors for their awesome templates. 88 | -------------------------------------------------------------------------------- /docs/Data.md: -------------------------------------------------------------------------------- 1 | 2 | **We need more dataset**, please refer to the [open-sora-Dataset](https://github.com/shaodong233/open-sora-Dataset) for details. 3 | 4 | 5 | ## Sky 6 | 7 | 8 | This is an un-condition datasets. [Link](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo) 9 | 10 | ``` 11 | sky_timelapse 12 | ├── readme 13 | ├── sky_test 14 | ├── sky_train 15 | ├── test_videofolder.py 16 | └── video_folder.py 17 | ``` 18 | 19 | ## UCF101 20 | 21 | We test the code with UCF-101 dataset. In order to download UCF-101 dataset, you can download the necessary files in [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `ucf101` directory with the following structure 22 | ``` 23 | UCF-101/ 24 | ApplyEyeMakeup/ 25 | v1.avi 26 | ... 27 | ... 28 | YoYo/ 29 | v1.avi 30 | ... 31 | ``` 32 | 33 | 34 | ## Offline feature extraction 35 | Coming soon... 36 | -------------------------------------------------------------------------------- /docs/EVAL.md: -------------------------------------------------------------------------------- 1 | # Evaluate the generated videos quality 2 | 3 | You can easily calculate the following video quality metrics, which supports the batch-wise process. 4 | - **CLIP-SCORE**: It uses the pretrained CLIP model to measure the cosine similarity between two modalities. 
5 | - **FVD**: Frechét Video Distance 6 | - **SSIM**: structural similarity index measure 7 | - **LPIPS**: learned perceptual image patch similarity 8 | - **PSNR**: peak-signal-to-noise ratio 9 | 10 | # Requirement 11 | ## Environment 12 | - install Pytorch (torch>=1.7.1) 13 | - install CLIP 14 | ``` 15 | pip install git+https://github.com/openai/CLIP.git 16 | ``` 17 | - install clip-cose from PyPi 18 | ``` 19 | pip install clip-score 20 | ``` 21 | - Other package 22 | ``` 23 | pip install lpips 24 | pip install scipy (scipy==1.7.3/1.9.3, if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**) 25 | pip install numpy 26 | pip install pillow 27 | pip install torchvision>=0.8.2 28 | pip install ftfy 29 | pip install regex 30 | pip install tqdm 31 | ``` 32 | ## Pretrain model 33 | - FVD 34 | Before you cacluate FVD, you should first download the FVD pre-trained model. You can manually download any of the following and put it into FVD folder. 35 | - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt) 36 | - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI) 37 | 38 | 39 | ## Other Notices 40 | 1. Make sure the pixel value of videos should be in [0, 1]. 41 | 2. We average SSIM when images have 3 channels, ssim is the only metric extremely sensitive to gray being compared to b/w. 42 | 3. Because the i3d model downsamples in the time dimension, `frames_num` should > 10 when calculating FVD, so FVD calculation begins from 10-th frame, like upper example. 43 | 4. For grayscale videos, we multiply to 3 channels 44 | 5. data input specifications for clip_score 45 | > - Image Files:All images should be stored in a single directory. The image files can be in either .png or .jpg format. 46 | > 47 | > - Text Files: All text data should be contained in plain text files in a separate directory. These text files should have the extension .txt. 48 | > 49 | > Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory. 50 | > 51 | > Directory Structure Example: 52 | > ``` 53 | > ├── path/to/image 54 | > │ ├── cat.png 55 | > │ ├── dog.png 56 | > │ └── bird.jpg 57 | > └── path/to/text 58 | > ├── cat.txt 59 | > ├── dog.txt 60 | > └── bird.txt 61 | > ``` 62 | 63 | 6. data input specifications for fvd, psnr, ssim, lpips 64 | 65 | > Directory Structure Example: 66 | > ``` 67 | > ├── path/to/generated_image 68 | > │ ├── cat.mp4 69 | > │ ├── dog.mp4 70 | > │ └── bird.mp4 71 | > └── path/to/real_image 72 | > ├── cat.mp4 73 | > ├── dog.mp4 74 | > └── bird.mp4 75 | > ``` 76 | 77 | 78 | 79 | # Usage 80 | 81 | ``` 82 | # you change the file path and need to set the frame_num, resolution etc... 
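# note: fvd/psnr/ssim/lpips compare two directories of videos paired by file name
# (see the directory-structure example above), while clip_score pairs images with
# same-named .txt caption files; edit the paths inside each *.sh script before running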
83 | 84 | # clip_score cross modality 85 | cd opensora/eval 86 | bash script/cal_clip_score.sh 87 | 88 | 89 | 90 | # fvd 91 | cd opensora/eval 92 | bash script/cal_fvd.sh 93 | 94 | # psnr 95 | cd opensora/eval 96 | bash eval/script/cal_psnr.sh 97 | 98 | 99 | # ssim 100 | cd opensora/eval 101 | bash eval/script/cal_ssim.sh 102 | 103 | 104 | # lpips 105 | cd opensora/eval 106 | bash eval/script/cal_lpips.sh 107 | ``` 108 | 109 | # Acknowledgement 110 | The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality). -------------------------------------------------------------------------------- /docs/VQVAE.md: -------------------------------------------------------------------------------- 1 | # VQVAE Documentation 2 | 3 | # Introduction 4 | 5 | Vector Quantized Variational AutoEncoders (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation. 6 | 7 | # Usage 8 | 9 | ## Initialization 10 | 11 | To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is a part of the `opensora.models.ae` module. 12 | 13 | ```python 14 | from opensora.models.ae import VideoGPTVQVAE 15 | 16 | vqvae = VideoGPTVQVAE() 17 | ``` 18 | 19 | ### Training 20 | 21 | To train the VQVAE model, you can use the `train_videogpt.sh` script. This script will train the model using the parameters specified in the script. 22 | 23 | ```bash 24 | bash scripts/videogpt/train_videogpt.sh 25 | ``` 26 | 27 | ### Loading Pretrained Models 28 | 29 | You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model. 30 | 31 | ```python 32 | vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2") 33 | ``` 34 | 35 | Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method. 36 | 37 | ```python 38 | vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000") 39 | ``` 40 | 41 | ### Encoding and Decoding 42 | 43 | You can encode a video using the `encode` method. This method will return the encodings and embeddings of the video. 44 | 45 | ```python 46 | encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True) 47 | ``` 48 | 49 | You can reconstruct a video from its encodings using the decode method. 50 | 51 | ```python 52 | video_recon = vqvae.decode(encodings) 53 | ``` 54 | 55 | ## Testing 56 | 57 | You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this. 
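For reference, here is a minimal sketch of that reconstruction round trip. It assumes a hypothetical `load_video_tensor` helper that returns a float tensor shaped `(batch, channels, time, height, width)` in the range the model expects; `examples/rec_video.py` remains the authoritative script.

```python
import torch

from opensora.models.ae import VideoGPTVQVAE

# Load a pretrained checkpoint (same call as in the section above).
vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
vqvae.eval()

with torch.no_grad():
    # load_video_tensor is a placeholder for your own loader, returning (B, C, T, H, W).
    x_vae = load_video_tensor("input.mp4")
    # Encode to discrete codes (plus embeddings), then decode back to pixels.
    encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
    video_recon = vqvae.decode(encodings)
```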
-------------------------------------------------------------------------------- /examples/get_latents_std.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Subset 3 | import sys 4 | sys.path.append(".") 5 | from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset 6 | 7 | num_workers = 4 8 | batch_size = 12 9 | 10 | torch.manual_seed(0) 11 | torch.set_grad_enabled(False) 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000' 16 | data_path = '/remote-home1/dataset/UCF-101' 17 | video_num_frames = 17 18 | resolution = 128 19 | sample_rate = 10 20 | 21 | vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path) 22 | vae.to(device) 23 | 24 | dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate) 25 | subset_indices = list(range(1000)) 26 | subset_dataset = Subset(dataset, subset_indices) 27 | loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True) 28 | 29 | all_latents = [] 30 | for video_data in loader: 31 | video_data = video_data['video'].to(device) 32 | latents = vae.encode(video_data).sample() 33 | all_latents.append(video_data.cpu()) 34 | 35 | all_latents_tensor = torch.cat(all_latents) 36 | std = all_latents_tensor.std().item() 37 | normalizer = 1 / std 38 | print(f'{normalizer = }') -------------------------------------------------------------------------------- /examples/rec_image.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from PIL import Image 4 | import torch 5 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize 6 | from torch.nn import functional as F 7 | from opensora.models.ae.videobase import CausalVAEModel 8 | import argparse 9 | import numpy as np 10 | 11 | def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor: 12 | transform = Compose( 13 | [ 14 | ToTensor(), 15 | Normalize((0.5), (0.5)), 16 | Resize(size=short_size), 17 | ] 18 | ) 19 | outputs = transform(video_data) 20 | outputs = outputs.unsqueeze(0).unsqueeze(2) 21 | return outputs 22 | 23 | def main(args: argparse.Namespace): 24 | image_path = args.image_path 25 | resolution = args.resolution 26 | device = args.device 27 | 28 | vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt) 29 | vqvae.eval() 30 | vqvae = vqvae.to(device) 31 | 32 | with torch.no_grad(): 33 | x_vae = preprocess(Image.open(image_path), resolution) 34 | x_vae = x_vae.to(device) 35 | latents = vqvae.encode(x_vae) 36 | recon = vqvae.decode(latents.sample()) 37 | x = recon[0, :, 0, :, :] 38 | x = x.squeeze() 39 | x = x.detach().cpu().numpy() 40 | x = np.clip(x, -1, 1) 41 | x = (x + 1) / 2 42 | x = (255*x).astype(np.uint8) 43 | x = x.transpose(1,2,0) 44 | image = Image.fromarray(x) 45 | image.save(args.rec_path) 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--image-path', type=str, default='') 51 | parser.add_argument('--rec-path', type=str, default='') 52 | parser.add_argument('--ckpt', type=str, default='') 53 | parser.add_argument('--resolution', type=int, default=336) 54 | parser.add_argument('--device', type=str, default='cuda') 55 | 56 | args = parser.parse_args() 57 | main(args) 58 | 
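    # Example invocation (paths below are placeholders, not files shipped with the repo):
    #   python examples/rec_image.py --image-path input.jpg --rec-path recon.jpg \
    #       --ckpt path/to/causalvae/checkpoint --resolution 336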
-------------------------------------------------------------------------------- /opensora/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /opensora/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from transformers import AutoTokenizer 3 | 4 | from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset 5 | from torchvision import transforms 6 | from torchvision.transforms import Lambda 7 | 8 | from .t2v_datasets import T2V_dataset 9 | from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo 10 | 11 | 12 | ae_norm = { 13 | 'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.), 14 | 'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5), 15 | 'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5), 16 | 'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5), 17 | 'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5), 18 | "bair_stride4x2x2": Lambda(lambda x: x - 0.5), 19 | "ucf101_stride4x4x4": Lambda(lambda x: x - 0.5), 20 | "kinetics_stride4x4x4": Lambda(lambda x: x - 0.5), 21 | "kinetics_stride2x4x4": Lambda(lambda x: x - 0.5), 22 | 'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 23 | 'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 24 | 'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.), 25 | 'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.), 26 | 'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.), 27 | 28 | } 29 | ae_denorm = { 30 | 'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2., 31 | 'CausalVQVAEModel_4x4x4': lambda x: x + 0.5, 32 | 'CausalVQVAEModel_4x8x8': lambda x: x + 0.5, 33 | 'VQVAEModel_4x4x4': lambda x: x + 0.5, 34 | 'VQVAEModel_4x8x8': lambda x: x + 0.5, 35 | "bair_stride4x2x2": lambda x: x + 0.5, 36 | "ucf101_stride4x4x4": lambda x: x + 0.5, 37 | "kinetics_stride4x4x4": lambda x: x + 0.5, 38 | "kinetics_stride2x4x4": lambda x: x + 0.5, 39 | 'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5, 40 | 'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5, 41 | 'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2., 42 | 'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2., 43 | 'vqgan_gumbel_f8': lambda x: (x + 1.) 
/ 2., 44 | } 45 | 46 | def getdataset(args): 47 | temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate) # 16 x 48 | norm_fun = ae_norm[args.ae] 49 | if args.dataset == 't2v': 50 | if args.multi_scale: 51 | resize = [ 52 | LongSideResizeVideo(args.max_image_size, skip_low_resolution=True), 53 | SpatialStrideCropVideo(args.stride) 54 | ] 55 | else: 56 | resize = [CenterCropResizeVideo(args.max_image_size), ] 57 | transform = transforms.Compose([ 58 | ToTensorVideo(), 59 | *resize, 60 | # RandomHorizontalFlipVideo(p=0.5), # in case their caption have position decription 61 | norm_fun 62 | ]) 63 | tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) 64 | return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer) 65 | raise NotImplementedError(args.dataset) 66 | -------------------------------------------------------------------------------- /opensora/dataset/extract_feature_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | from opensora.utils.dataset_utils import DecordInit, is_image_file 11 | 12 | 13 | class ExtractVideo2Feature(Dataset): 14 | def __init__(self, args, transform): 15 | self.data_path = args.data_path 16 | self.transform = transform 17 | self.v_decoder = DecordInit() 18 | self.samples = list(glob(f'{self.data_path}')) 19 | 20 | def __len__(self): 21 | return len(self.samples) 22 | 23 | def __getitem__(self, idx): 24 | video_path = self.samples[idx] 25 | video = self.decord_read(video_path) 26 | video = self.transform(video) # T C H W -> T C H W 27 | return video, video_path 28 | 29 | def tv_read(self, path): 30 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') 31 | total_frames = len(vframes) 32 | frame_indice = list(range(total_frames)) 33 | video = vframes[frame_indice] 34 | return video 35 | 36 | def decord_read(self, path): 37 | decord_vr = self.v_decoder(path) 38 | total_frames = len(decord_vr) 39 | frame_indice = list(range(total_frames)) 40 | video_data = decord_vr.get_batch(frame_indice).asnumpy() 41 | video_data = torch.from_numpy(video_data) 42 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) 43 | return video_data 44 | 45 | 46 | 47 | class ExtractImage2Feature(Dataset): 48 | def __init__(self, args, transform): 49 | self.data_path = args.data_path 50 | self.transform = transform 51 | self.data_all = list(glob(f'{self.data_path}')) 52 | 53 | def __len__(self): 54 | return len(self.data_all) 55 | 56 | def __getitem__(self, index): 57 | path = self.data_all[index] 58 | video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0) 59 | video_frame = video_frame.permute(0, 3, 1, 2) 60 | video_frame = self.transform(video_frame) # T C H W 61 | # video_frame = video_frame.transpose(0, 1) # T C H W -> C T H W 62 | 63 | return video_frame, path 64 | 65 | -------------------------------------------------------------------------------- /opensora/dataset/landscope.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from glob import glob 4 | 5 | import decord 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from decord import VideoReader, cpu 10 | from 
torch.utils.data import Dataset 11 | from torchvision.transforms import Compose, Lambda, ToTensor 12 | from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo 13 | from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample 14 | from torch.nn import functional as F 15 | import random 16 | 17 | from opensora.utils.dataset_utils import DecordInit 18 | 19 | 20 | class Landscope(Dataset): 21 | def __init__(self, args, transform, temporal_sample): 22 | self.data_path = args.data_path 23 | self.num_frames = args.num_frames 24 | self.transform = transform 25 | self.temporal_sample = temporal_sample 26 | self.v_decoder = DecordInit() 27 | 28 | self.samples = self._make_dataset() 29 | self.use_image_num = args.use_image_num 30 | self.use_img_from_vid = args.use_img_from_vid 31 | if self.use_image_num != 0 and not self.use_img_from_vid: 32 | self.img_cap_list = self.get_img_cap_list() 33 | 34 | 35 | def _make_dataset(self): 36 | paths = list(glob(os.path.join(self.data_path, '**', '*.mp4'), recursive=True)) 37 | 38 | return paths 39 | 40 | def __len__(self): 41 | return len(self.samples) 42 | 43 | def __getitem__(self, idx): 44 | video_path = self.samples[idx] 45 | try: 46 | video = self.tv_read(video_path) 47 | video = self.transform(video) # T C H W -> T C H W 48 | video = video.transpose(0, 1) # T C H W -> C T H W 49 | if self.use_image_num != 0 and self.use_img_from_vid: 50 | select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int) 51 | assert self.num_frames >= self.use_image_num 52 | images = video[:, select_image_idx] # c, num_img, h, w 53 | video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w 54 | elif self.use_image_num != 0 and not self.use_img_from_vid: 55 | images, captions = self.img_cap_list[idx] 56 | raise NotImplementedError 57 | else: 58 | pass 59 | return video, 1 60 | except Exception as e: 61 | print(f'Error with {e}, {video_path}') 62 | return self.__getitem__(random.randint(0, self.__len__()-1)) 63 | 64 | def tv_read(self, path): 65 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') 66 | total_frames = len(vframes) 67 | 68 | # Sampling video frames 69 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) 70 | # assert end_frame_ind - start_frame_ind >= self.num_frames 71 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) 72 | video = vframes[frame_indice] # (T, C, H, W) 73 | 74 | return video 75 | 76 | def decord_read(self, path): 77 | decord_vr = self.v_decoder(path) 78 | total_frames = len(decord_vr) 79 | # Sampling video frames 80 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) 81 | # assert end_frame_ind - start_frame_ind >= self.num_frames 82 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) 83 | 84 | video_data = decord_vr.get_batch(frame_indice).asnumpy() 85 | video_data = torch.from_numpy(video_data) 86 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) 87 | return video_data 88 | 89 | def get_img_cap_list(self): 90 | raise NotImplementedError 91 | -------------------------------------------------------------------------------- /opensora/dataset/ucf101.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import decord 5 | import numpy as np 6 | import torch 7 | 
import torchvision 8 | from decord import VideoReader, cpu 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import Compose, Lambda, ToTensor 11 | from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo 12 | from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample 13 | from torch.nn import functional as F 14 | import random 15 | 16 | from opensora.utils.dataset_utils import DecordInit 17 | 18 | 19 | class UCF101(Dataset): 20 | def __init__(self, args, transform, temporal_sample): 21 | self.data_path = args.data_path 22 | self.num_frames = args.num_frames 23 | self.transform = transform 24 | self.temporal_sample = temporal_sample 25 | self.v_decoder = DecordInit() 26 | 27 | self.classes = sorted(os.listdir(self.data_path)) 28 | self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)} 29 | self.samples = self._make_dataset() 30 | 31 | 32 | def _make_dataset(self): 33 | dataset = [] 34 | for class_name in self.classes: 35 | class_path = os.path.join(self.data_path, class_name) 36 | for fname in os.listdir(class_path): 37 | if fname.endswith('.avi'): 38 | item = (os.path.join(class_path, fname), self.class_to_idx[class_name]) 39 | dataset.append(item) 40 | return dataset 41 | 42 | def __len__(self): 43 | return len(self.samples) 44 | 45 | def __getitem__(self, idx): 46 | video_path, label = self.samples[idx] 47 | try: 48 | video = self.tv_read(video_path) 49 | video = self.transform(video) # T C H W -> T C H W 50 | video = video.transpose(0, 1) # T C H W -> C T H W 51 | return video, label 52 | except Exception as e: 53 | print(f'Error with {e}, {video_path}') 54 | return self.__getitem__(random.randint(0, self.__len__()-1)) 55 | 56 | def tv_read(self, path): 57 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') 58 | total_frames = len(vframes) 59 | 60 | # Sampling video frames 61 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) 62 | # assert end_frame_ind - start_frame_ind >= self.num_frames 63 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) 64 | video = vframes[frame_indice] # (T, C, H, W) 65 | 66 | return video 67 | 68 | def decord_read(self, path): 69 | decord_vr = self.v_decoder(path) 70 | total_frames = len(decord_vr) 71 | # Sampling video frames 72 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) 73 | # assert end_frame_ind - start_frame_ind >= self.num_frames 74 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) 75 | 76 | video_data = decord_vr.get_batch(frame_indice).asnumpy() 77 | video_data = torch.from_numpy(video_data) 78 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) 79 | return video_data 80 | 81 | -------------------------------------------------------------------------------- /opensora/eval/cal_flolpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | from einops import rearrange 6 | import sys 7 | sys.path.append(".") 8 | from opensora.eval.flolpips.pwcnet import Network as PWCNet 9 | from opensora.eval.flolpips.flolpips import FloLPIPS 10 | 11 | loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False) 12 | flownet = PWCNet().eval().requires_grad_(False) 13 | 14 | def trans(x): 15 | return 
x 16 | 17 | 18 | def calculate_flolpips(videos1, videos2, device): 19 | global loss_fn, flownet 20 | 21 | print("calculate_flowlpips...") 22 | loss_fn = loss_fn.to(device) 23 | flownet = flownet.to(device) 24 | 25 | if videos1.shape != videos2.shape: 26 | print("Warning: the shape of videos are not equal.") 27 | min_frames = min(videos1.shape[1], videos2.shape[1]) 28 | videos1 = videos1[:, :min_frames] 29 | videos2 = videos2[:, :min_frames] 30 | 31 | videos1 = trans(videos1) 32 | videos2 = trans(videos2) 33 | 34 | flolpips_results = [] 35 | for video_num in tqdm(range(videos1.shape[0])): 36 | video1 = videos1[video_num].to(device) 37 | video2 = videos2[video_num].to(device) 38 | frames_rec = video1[:-1] 39 | frames_rec_next = video1[1:] 40 | frames_gt = video2[:-1] 41 | frames_gt_next = video2[1:] 42 | t, c, h, w = frames_gt.shape 43 | flow_gt = flownet(frames_gt, frames_gt_next) 44 | flow_dis = flownet(frames_rec, frames_rec_next) 45 | flow_diff = flow_gt - flow_dis 46 | flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True) 47 | flolpips_results.append(flolpips.cpu().numpy().tolist()) 48 | 49 | flolpips_results = np.array(flolpips_results) # [batch_size, num_frames] 50 | flolpips = {} 51 | flolpips_std = {} 52 | 53 | for clip_timestamp in range(flolpips_results.shape[1]): 54 | flolpips[clip_timestamp] = np.mean(flolpips_results[:,clip_timestamp], axis=-1) 55 | flolpips_std[clip_timestamp] = np.std(flolpips_results[:,clip_timestamp], axis=-1) 56 | 57 | result = { 58 | "value": flolpips, 59 | "value_std": flolpips_std, 60 | "video_setting": video1.shape, 61 | "video_setting_name": "time, channel, heigth, width", 62 | "result": flolpips_results, 63 | "details": flolpips_results.tolist() 64 | } 65 | 66 | return result 67 | 68 | # test code / using example 69 | 70 | def main(): 71 | NUMBER_OF_VIDEOS = 8 72 | VIDEO_LENGTH = 50 73 | CHANNEL = 3 74 | SIZE = 64 75 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 76 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 77 | 78 | import json 79 | result = calculate_flolpips(videos1, videos2, "cuda:0") 80 | print(json.dumps(result, indent=4)) 81 | 82 | if __name__ == "__main__": 83 | main() -------------------------------------------------------------------------------- /opensora/eval/cal_fvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | def trans(x): 6 | # if greyscale images add channel 7 | if x.shape[-3] == 1: 8 | x = x.repeat(1, 1, 3, 1, 1) 9 | 10 | # permute BTCHW -> BCTHW 11 | x = x.permute(0, 2, 1, 3, 4) 12 | 13 | return x 14 | 15 | def calculate_fvd(videos1, videos2, device, method='styleganv'): 16 | 17 | if method == 'styleganv': 18 | from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained 19 | elif method == 'videogpt': 20 | from fvd.videogpt.fvd import load_i3d_pretrained 21 | from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats 22 | from fvd.videogpt.fvd import frechet_distance 23 | 24 | print("calculate_fvd...") 25 | 26 | # videos [batch_size, timestamps, channel, h, w] 27 | 28 | assert videos1.shape == videos2.shape 29 | 30 | i3d = load_i3d_pretrained(device=device) 31 | fvd_results = [] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # BTCHW -> BCTHW 35 | # videos -> [batch_size, channel, timestamps, h, w] 36 | 37 | videos1 = trans(videos1) 38 | videos2 = 
trans(videos2) 39 | 40 | fvd_results = {} 41 | 42 | # for calculate FVD, each clip_timestamp must >= 10 43 | for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): 44 | 45 | # get a video clip 46 | # videos_clip [batch_size, channel, timestamps[:clip], h, w] 47 | videos_clip1 = videos1[:, :, : clip_timestamp] 48 | videos_clip2 = videos2[:, :, : clip_timestamp] 49 | 50 | # get FVD features 51 | feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) 52 | feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) 53 | 54 | # calculate FVD when timestamps[:clip] 55 | fvd_results[clip_timestamp] = frechet_distance(feats1, feats2) 56 | 57 | result = { 58 | "value": fvd_results, 59 | "video_setting": videos1.shape, 60 | "video_setting_name": "batch_size, channel, time, heigth, width", 61 | } 62 | 63 | return result 64 | 65 | # test code / using example 66 | 67 | def main(): 68 | NUMBER_OF_VIDEOS = 8 69 | VIDEO_LENGTH = 50 70 | CHANNEL = 3 71 | SIZE = 64 72 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 73 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 74 | device = torch.device("cuda") 75 | # device = torch.device("cpu") 76 | 77 | import json 78 | result = calculate_fvd(videos1, videos2, device, method='videogpt') 79 | print(json.dumps(result, indent=4)) 80 | 81 | result = calculate_fvd(videos1, videos2, device, method='styleganv') 82 | print(json.dumps(result, indent=4)) 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /opensora/eval/cal_lpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | import torch 7 | import lpips 8 | 9 | spatial = True # Return a spatial map of perceptual distance. 
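# With spatial=True, loss_fn.forward() returns a per-pixel distance map rather than a scalar;
# calculate_lpips() below collapses that map to a single value via .mean().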
10 | 11 | # Linearly calibrated models (LPIPS) 12 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg' 13 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' 14 | 15 | def trans(x): 16 | # if greyscale images add channel 17 | if x.shape[-3] == 1: 18 | x = x.repeat(1, 1, 3, 1, 1) 19 | 20 | # value range [0, 1] -> [-1, 1] 21 | x = x * 2 - 1 22 | 23 | return x 24 | 25 | def calculate_lpips(videos1, videos2, device): 26 | # image should be RGB, IMPORTANT: normalized to [-1,1] 27 | print("calculate_lpips...") 28 | 29 | assert videos1.shape == videos2.shape 30 | 31 | # videos [batch_size, timestamps, channel, h, w] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # value range [0, 1] -> [-1, 1] 35 | videos1 = trans(videos1) 36 | videos2 = trans(videos2) 37 | 38 | lpips_results = [] 39 | 40 | for video_num in tqdm(range(videos1.shape[0])): 41 | # get a video 42 | # video [timestamps, channel, h, w] 43 | video1 = videos1[video_num] 44 | video2 = videos2[video_num] 45 | 46 | lpips_results_of_a_video = [] 47 | for clip_timestamp in range(len(video1)): 48 | # get a img 49 | # img [timestamps[x], channel, h, w] 50 | # img [channel, h, w] tensor 51 | 52 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 53 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 54 | 55 | loss_fn.to(device) 56 | 57 | # calculate lpips of a video 58 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 59 | lpips_results.append(lpips_results_of_a_video) 60 | 61 | lpips_results = np.array(lpips_results) 62 | 63 | lpips = {} 64 | lpips_std = {} 65 | 66 | for clip_timestamp in range(len(video1)): 67 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp]) 68 | lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp]) 69 | 70 | 71 | result = { 72 | "value": lpips, 73 | "value_std": lpips_std, 74 | "video_setting": video1.shape, 75 | "video_setting_name": "time, channel, heigth, width", 76 | } 77 | 78 | return result 79 | 80 | # test code / using example 81 | 82 | def main(): 83 | NUMBER_OF_VIDEOS = 8 84 | VIDEO_LENGTH = 50 85 | CHANNEL = 3 86 | SIZE = 64 87 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 88 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 89 | device = torch.device("cuda") 90 | # device = torch.device("cpu") 91 | 92 | import json 93 | result = calculate_lpips(videos1, videos2, device) 94 | print(json.dumps(result, indent=4)) 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /opensora/eval/cal_psnr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | def img_psnr(img1, img2): 7 | # [0,1] 8 | # compute mse 9 | # mse = np.mean((img1-img2)**2) 10 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 11 | # compute psnr 12 | if mse < 1e-10: 13 | return 100 14 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 15 | return psnr 16 | 17 | def trans(x): 18 | return x 19 | 20 | def calculate_psnr(videos1, videos2): 21 | print("calculate_psnr...") 22 | 23 | # videos [batch_size, timestamps, channel, h, w] 24 | 25 | assert videos1.shape == videos2.shape 26 | 27 | videos1 = trans(videos1) 28 | videos2 = trans(videos2) 29 | 30 | psnr_results = [] 31 | 32 | for 
video_num in tqdm(range(videos1.shape[0])): 33 | # get a video 34 | # video [timestamps, channel, h, w] 35 | video1 = videos1[video_num] 36 | video2 = videos2[video_num] 37 | 38 | psnr_results_of_a_video = [] 39 | for clip_timestamp in range(len(video1)): 40 | # get a img 41 | # img [timestamps[x], channel, h, w] 42 | # img [channel, h, w] numpy 43 | 44 | img1 = video1[clip_timestamp].numpy() 45 | img2 = video2[clip_timestamp].numpy() 46 | 47 | # calculate psnr of a video 48 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 49 | 50 | psnr_results.append(psnr_results_of_a_video) 51 | 52 | psnr_results = np.array(psnr_results) # [batch_size, num_frames] 53 | psnr = {} 54 | psnr_std = {} 55 | 56 | for clip_timestamp in range(len(video1)): 57 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp]) 58 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp]) 59 | 60 | result = { 61 | "value": psnr, 62 | "value_std": psnr_std, 63 | "video_setting": video1.shape, 64 | "video_setting_name": "time, channel, heigth, width", 65 | } 66 | 67 | return result 68 | 69 | # test code / using example 70 | 71 | def main(): 72 | NUMBER_OF_VIDEOS = 8 73 | VIDEO_LENGTH = 50 74 | CHANNEL = 3 75 | SIZE = 64 76 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 77 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 78 | 79 | import json 80 | result = calculate_psnr(videos1, videos2) 81 | print(json.dumps(result, indent=4)) 82 | 83 | if __name__ == "__main__": 84 | main() -------------------------------------------------------------------------------- /opensora/eval/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import cv2 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01 ** 2 8 | C2 = 0.03 ** 2 9 | img1 = img1.astype(np.float64) 10 | img2 = img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1 ** 2 16 | mu2_sq = mu2 ** 2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 22 | (sigma1_sq + sigma2_sq + C2)) 23 | return ssim_map.mean() 24 | 25 | 26 | def calculate_ssim_function(img1, img2): 27 | # [0,1] 28 | # ssim is the only metric extremely sensitive to gray being compared to b/w 29 | if not img1.shape == img2.shape: 30 | raise ValueError('Input images must have the same dimensions.') 31 | if img1.ndim == 2: 32 | return ssim(img1, img2) 33 | elif img1.ndim == 3: 34 | if img1.shape[0] == 3: 35 | ssims = [] 36 | for i in range(3): 37 | ssims.append(ssim(img1[i], img2[i])) 38 | return np.array(ssims).mean() 39 | elif img1.shape[0] == 1: 40 | return ssim(np.squeeze(img1), np.squeeze(img2)) 41 | else: 42 | raise ValueError('Wrong input image dimensions.') 43 | 44 | def trans(x): 45 | return x 46 | 47 | def calculate_ssim(videos1, videos2): 48 | print("calculate_ssim...") 49 | 50 | # videos [batch_size, timestamps, channel, h, w] 51 | 52 | assert videos1.shape == videos2.shape 53 | 54 | videos1 = trans(videos1) 55 | videos2 
= trans(videos2) 56 | 57 | ssim_results = [] 58 | 59 | for video_num in tqdm(range(videos1.shape[0])): 60 | # get a video 61 | # video [timestamps, channel, h, w] 62 | video1 = videos1[video_num] 63 | video2 = videos2[video_num] 64 | 65 | ssim_results_of_a_video = [] 66 | for clip_timestamp in range(len(video1)): 67 | # get a img 68 | # img [timestamps[x], channel, h, w] 69 | # img [channel, h, w] numpy 70 | 71 | img1 = video1[clip_timestamp].numpy() 72 | img2 = video2[clip_timestamp].numpy() 73 | 74 | # calculate ssim of a video 75 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 76 | 77 | ssim_results.append(ssim_results_of_a_video) 78 | 79 | ssim_results = np.array(ssim_results) 80 | 81 | ssim = {} 82 | ssim_std = {} 83 | 84 | for clip_timestamp in range(len(video1)): 85 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp]) 86 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp]) 87 | 88 | result = { 89 | "value": ssim, 90 | "value_std": ssim_std, 91 | "video_setting": video1.shape, 92 | "video_setting_name": "time, channel, heigth, width", 93 | } 94 | 95 | return result 96 | 97 | # test code / using example 98 | 99 | def main(): 100 | NUMBER_OF_VIDEOS = 8 101 | VIDEO_LENGTH = 50 102 | CHANNEL = 3 103 | SIZE = 64 104 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 105 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 106 | device = torch.device("cuda") 107 | 108 | import json 109 | result = calculate_ssim(videos1, videos2) 110 | print(json.dumps(result, indent=4)) 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /opensora/eval/flolpips/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | 6 | def normalize_tensor(in_feat,eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 8 | return in_feat/(norm_factor+eps) 9 | 10 | def l2(p0, p1, range=255.): 11 | return .5*np.mean((p0 / range - p1 / range)**2) 12 | 13 | def dssim(p0, p1, range=255.): 14 | from skimage.measure import compare_ssim 15 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 16 | 17 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 18 | image_numpy = image_tensor[0].cpu().float().numpy() 19 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 20 | return image_numpy.astype(imtype) 21 | 22 | def tensor2np(tensor_obj): 23 | # change dimension of a tensor object into a numpy array 24 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 25 | 26 | def np2tensor(np_obj): 27 | # change dimenion of np array into tensor array 28 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 29 | 30 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 31 | # image tensor to lab tensor 32 | from skimage import color 33 | 34 | img = tensor2im(image_tensor) 35 | img_lab = color.rgb2lab(img) 36 | if(mc_only): 37 | img_lab[:,:,0] = img_lab[:,:,0]-50 38 | if(to_norm and not mc_only): 39 | img_lab[:,:,0] = img_lab[:,:,0]-50 40 | img_lab = img_lab/100. 
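    # skimage's rgb2lab returns L in roughly [0, 100] and a/b in roughly [-128, 127];
    # shifting L by 50 centers it around zero and, when to_norm is set, the division
    # by 100 rescales the whole array to an approximately unit range before it is
    # converted back to a torch tensor below.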
41 | 42 | return np2tensor(img_lab) 43 | 44 | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'): 45 | if pix_fmt == '420': 46 | multiplier = 1 47 | uv_factor = 2 48 | elif pix_fmt == '444': 49 | multiplier = 2 50 | uv_factor = 1 51 | else: 52 | print('Pixel format {} is not supported'.format(pix_fmt)) 53 | return 54 | 55 | if bit_depth == 8: 56 | datatype = np.uint8 57 | stream.seek(iFrame*1.5*width*height*multiplier) 58 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) 59 | 60 | # read chroma samples and upsample since original is 4:2:0 sampling 61 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 62 | reshape((height//uv_factor, width//uv_factor)) 63 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 64 | reshape((height//uv_factor, width//uv_factor)) 65 | 66 | else: 67 | datatype = np.uint16 68 | stream.seek(iFrame*3*width*height*multiplier) 69 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) 70 | 71 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 72 | reshape((height//uv_factor, width//uv_factor)) 73 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 74 | reshape((height//uv_factor, width//uv_factor)) 75 | 76 | if pix_fmt == '420': 77 | yuv = np.empty((height*3//2, width), dtype=datatype) 78 | yuv[0:height,:] = Y 79 | 80 | yuv[height:height+height//4,:] = U.reshape(-1, width) 81 | yuv[height+height//4:,:] = V.reshape(-1, width) 82 | 83 | if bit_depth != 8: 84 | yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8) 85 | 86 | #convert to rgb 87 | rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) 88 | 89 | else: 90 | yvu = np.stack([Y,V,U],axis=2) 91 | if bit_depth != 8: 92 | yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8) 93 | rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) 94 | 95 | return rgb 96 | -------------------------------------------------------------------------------- /opensora/eval/fvd/styleganv/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import torch.nn.functional as F 5 | 6 | # https://github.com/universome/fvd-comparison 7 | 8 | 9 | def load_i3d_pretrained(device=torch.device('cpu')): 10 | i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" 11 | filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') 12 | print(filepath) 13 | if not os.path.exists(filepath): 14 | print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") 15 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") 16 | i3d = torch.jit.load(filepath).eval().to(device) 17 | i3d = torch.nn.DataParallel(i3d) 18 | return i3d 19 | 20 | 21 | def get_feats(videos, detector, device, bs=10): 22 | # videos : torch.tensor BCTHW [0, 1] 23 | detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer. 
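    # The loop below feeds the TorchScript I3D detector batches of `bs` videos
    # (each preprocessed to 224x224 and rescaled to [-1, 1]) and stacks the
    # resulting 400-dim feature vectors into a [num_videos, 400] numpy array.
    # Note that `np` is imported near the bottom of this module, which still
    # works because the module-level import runs before this function is called.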
24 | feats = np.empty((0, 400)) 25 | with torch.no_grad(): 26 | for i in range((len(videos)-1)//bs + 1): 27 | feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()]) 28 | return feats 29 | 30 | 31 | def get_fvd_feats(videos, i3d, device, bs=10): 32 | # videos in [0, 1] as torch tensor BCTHW 33 | # videos = [preprocess_single(video) for video in videos] 34 | embeddings = get_feats(videos, i3d, device, bs) 35 | return embeddings 36 | 37 | 38 | def preprocess_single(video, resolution=224, sequence_length=None): 39 | # video: CTHW, [0, 1] 40 | c, t, h, w = video.shape 41 | 42 | # temporal crop 43 | if sequence_length is not None: 44 | assert sequence_length <= t 45 | video = video[:, :sequence_length] 46 | 47 | # scale shorter side to resolution 48 | scale = resolution / min(h, w) 49 | if h < w: 50 | target_size = (resolution, math.ceil(w * scale)) 51 | else: 52 | target_size = (math.ceil(h * scale), resolution) 53 | video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False) 54 | 55 | # center crop 56 | c, t, h, w = video.shape 57 | w_start = (w - resolution) // 2 58 | h_start = (h - resolution) // 2 59 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 60 | 61 | # [0, 1] -> [-1, 1] 62 | video = (video - 0.5) * 2 63 | 64 | return video.contiguous() 65 | 66 | 67 | """ 68 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py 69 | """ 70 | from typing import Tuple 71 | from scipy.linalg import sqrtm 72 | import numpy as np 73 | 74 | 75 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 76 | mu = feats.mean(axis=0) # [d] 77 | sigma = np.cov(feats, rowvar=False) # [d, d] 78 | return mu, sigma 79 | 80 | 81 | def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 82 | mu_gen, sigma_gen = compute_stats(feats_fake) 83 | mu_real, sigma_real = compute_stats(feats_real) 84 | m = np.square(mu_gen - mu_real).sum() 85 | if feats_fake.shape[0]>1: 86 | s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 87 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 88 | else: 89 | fid = np.real(m) 90 | return float(fid) -------------------------------------------------------------------------------- /opensora/eval/fvd/styleganv/i3d_torchscript.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/eval/fvd/styleganv/i3d_torchscript.pt -------------------------------------------------------------------------------- /opensora/eval/fvd/videogpt/i3d_pretrained_400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt -------------------------------------------------------------------------------- /opensora/eval/script/cal_clip_score.sh: -------------------------------------------------------------------------------- 1 | # clip_score cross modality 2 | python eval_clip_score.py \ 3 | --real_path path/to/image \ 4 | --generated_path path/to/text \ 5 | --batch-size 50 \ 6 | --device "cuda" 7 | 8 | # clip_score within the same modality 9 | python eval_clip_score.py \ 10 | --real_path 
path/to/textA \ 11 | --generated_path path/to/textB \ 12 | --real_flag txt \ 13 | --generated_flag txt \ 14 | --batch-size 50 \ 15 | --device "cuda" 16 | 17 | python eval_clip_score.py \ 18 | --real_path path/to/imageA \ 19 | --generated_path path/to/imageB \ 20 | --real_flag img \ 21 | --generated_flag img \ 22 | --batch-size 50 \ 23 | --device "cuda" 24 | -------------------------------------------------------------------------------- /opensora/eval/script/cal_fvd.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir path/to/imageA\ 3 | --generated_video_dir path/to/imageB \ 4 | --batch_size 10 \ 5 | --crop_size 64 \ 6 | --num_frames 20 \ 7 | --device 'cuda' \ 8 | --metric 'fvd' \ 9 | --fvd_method 'styleganv' 10 | -------------------------------------------------------------------------------- /opensora/eval/script/cal_lpips.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir path/to/imageA\ 3 | --generated_video_dir path/to/imageB \ 4 | --batch_size 10 \ 5 | --num_frames 20 \ 6 | --crop_size 64 \ 7 | --device 'cuda' \ 8 | --metric 'lpips' -------------------------------------------------------------------------------- /opensora/eval/script/cal_psnr.sh: -------------------------------------------------------------------------------- 1 | 2 | python eval_common_metric.py \ 3 | --real_video_dir /data/xiaogeng_liu/data/video1 \ 4 | --generated_video_dir /data/xiaogeng_liu/data/video2 \ 5 | --batch_size 10 \ 6 | --num_frames 20 \ 7 | --crop_size 64 \ 8 | --device 'cuda' \ 9 | --metric 'psnr' -------------------------------------------------------------------------------- /opensora/eval/script/cal_ssim.sh: -------------------------------------------------------------------------------- 1 | python eval_common_metric.py \ 2 | --real_video_dir /data/xiaogeng_liu/data/video1 \ 3 | --generated_video_dir /data/xiaogeng_liu/data/video2 \ 4 | --batch_size 10 \ 5 | --num_frames 20 \ 6 | --crop_size 64 \ 7 | --device 'cuda' \ 8 | --metric 'ssim' -------------------------------------------------------------------------------- /opensora/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/__init__.py -------------------------------------------------------------------------------- /opensora/models/ae/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagebase import imagebase_ae, imagebase_ae_stride, imagebase_ae_channel 2 | from .videobase import videobase_ae, videobase_ae_stride, videobase_ae_channel 3 | from .videobase import ( 4 | VQVAEConfiguration, 5 | VQVAEModel, 6 | VQVAETrainer, 7 | CausalVQVAEModel, 8 | CausalVQVAEConfiguration, 9 | CausalVQVAETrainer 10 | ) 11 | 12 | ae_stride_config = {} 13 | ae_stride_config.update(imagebase_ae_stride) 14 | ae_stride_config.update(videobase_ae_stride) 15 | 16 | ae_channel_config = {} 17 | ae_channel_config.update(imagebase_ae_channel) 18 | ae_channel_config.update(videobase_ae_channel) 19 | 20 | def getae(args): 21 | """deprecation""" 22 | ae = imagebase_ae.get(args.ae, None) or videobase_ae.get(args.ae, None) 23 | assert ae is not None 24 | return ae(args.ae) 25 | 26 | def getae_wrapper(ae): 27 | """deprecation""" 28 | ae = imagebase_ae.get(ae, None) or 
videobase_ae.get(ae, None) 29 | assert ae is not None 30 | return ae -------------------------------------------------------------------------------- /opensora/models/ae/imagebase/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae.vae import HFVAEWrapper 2 | from .vae.vae import SDVAEWrapper 3 | from .vqvae.vqvae import SDVQVAEWrapper 4 | 5 | vae = ['stabilityai/sd-vae-ft-mse', 'stabilityai/sd-vae-ft-ema'] 6 | vqvae = ['vqgan_imagenet_f16_1024', 'vqgan_imagenet_f16_16384', 'vqgan_gumbel_f8'] 7 | 8 | imagebase_ae_stride = { 9 | 'stabilityai/sd-vae-ft-mse': [1, 8, 8], 10 | 'stabilityai/sd-vae-ft-ema': [1, 8, 8], 11 | 'vqgan_imagenet_f16_1024': [1, 16, 16], 12 | 'vqgan_imagenet_f16_16384': [1, 16, 16], 13 | 'vqgan_gumbel_f8': [1, 8, 8], 14 | } 15 | 16 | imagebase_ae_channel = { 17 | 'stabilityai/sd-vae-ft-mse': 4, 18 | 'stabilityai/sd-vae-ft-ema': 4, 19 | 'vqgan_imagenet_f16_1024': -1, 20 | 'vqgan_imagenet_f16_16384': -1, 21 | 'vqgan_gumbel_f8': -1, 22 | } 23 | 24 | imagebase_ae = { 25 | 'stabilityai/sd-vae-ft-mse': HFVAEWrapper, 26 | 'stabilityai/sd-vae-ft-ema': HFVAEWrapper, 27 | 'vqgan_imagenet_f16_1024': SDVQVAEWrapper, 28 | 'vqgan_imagenet_f16_16384': SDVQVAEWrapper, 29 | 'vqgan_gumbel_f8': SDVQVAEWrapper, 30 | } -------------------------------------------------------------------------------- /opensora/models/ae/imagebase/vae/vae.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch import nn 3 | from diffusers.models import AutoencoderKL 4 | 5 | 6 | class HFVAEWrapper(nn.Module): 7 | def __init__(self, hfvae='mse'): 8 | super(HFVAEWrapper, self).__init__() 9 | self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir='cache_dir') 10 | def encode(self, x): # b c h w 11 | t = 0 12 | if x.ndim == 5: 13 | b, c, t, h, w = x.shape 14 | x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous() 15 | x = self.vae.encode(x).latent_dist.sample().mul_(0.18215) 16 | if t != 0: 17 | x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous() 18 | return x 19 | def decode(self, x): 20 | t = 0 21 | if x.ndim == 5: 22 | b, c, t, h, w = x.shape 23 | x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous() 24 | x = self.vae.decode(x / 0.18215).sample 25 | if t != 0: 26 | x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous() 27 | return x 28 | 29 | class SDVAEWrapper(nn.Module): 30 | def __init__(self): 31 | super(SDVAEWrapper, self).__init__() 32 | raise NotImplementedError 33 | 34 | def encode(self, x): # b c h w 35 | raise NotImplementedError 36 | 37 | def decode(self, x): 38 | raise NotImplementedError -------------------------------------------------------------------------------- /opensora/models/ae/imagebase/vqvae/vqvae.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import yaml 3 | import torch 4 | from omegaconf import OmegaConf 5 | from .vqgan import VQModel, GumbelVQ 6 | 7 | def load_config(config_path, display=False): 8 | config = OmegaConf.load(config_path) 9 | if display: 10 | print(yaml.dump(OmegaConf.to_container(config))) 11 | return config 12 | 13 | 14 | def load_vqgan(config, ckpt_path=None, is_gumbel=False): 15 | if is_gumbel: 16 | model = GumbelVQ(**config.model.params) 17 | else: 18 | model = VQModel(**config.model.params) 19 | if ckpt_path is not None: 20 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 21 | missing, unexpected = model.load_state_dict(sd, 
strict=False) 22 | return model.eval() 23 | 24 | 25 | class SDVQVAEWrapper(nn.Module): 26 | def __init__(self, name): 27 | super(SDVQVAEWrapper, self).__init__() 28 | raise NotImplementedError 29 | 30 | def encode(self, x): # b c h w 31 | raise NotImplementedError 32 | 33 | def decode(self, x): 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/__init__.py: -------------------------------------------------------------------------------- 1 | from .vqvae import ( 2 | VQVAEConfiguration, 3 | VQVAEModel, 4 | VQVAETrainer, 5 | VQVAEModelWrapper 6 | ) 7 | from .causal_vqvae import ( 8 | CausalVQVAEConfiguration, 9 | CausalVQVAETrainer, 10 | CausalVQVAEModel, CausalVQVAEModelWrapper 11 | ) 12 | from .causal_vae import ( 13 | CausalVAEModel, CausalVAEModelWrapper 14 | ) 15 | 16 | 17 | videobase_ae_stride = { 18 | 'CausalVAEModel_4x8x8': [4, 8, 8], 19 | 'CausalVQVAEModel_4x4x4': [4, 4, 4], 20 | 'CausalVQVAEModel_4x8x8': [4, 8, 8], 21 | 'VQVAEModel_4x4x4': [4, 4, 4], 22 | 'OpenVQVAEModel_4x4x4': [4, 4, 4], 23 | 'VQVAEModel_4x8x8': [4, 8, 8], 24 | 'bair_stride4x2x2': [4, 2, 2], 25 | 'ucf101_stride4x4x4': [4, 4, 4], 26 | 'kinetics_stride4x4x4': [4, 4, 4], 27 | 'kinetics_stride2x4x4': [2, 4, 4], 28 | } 29 | 30 | videobase_ae_channel = { 31 | 'CausalVAEModel_4x8x8': 4, 32 | 'CausalVQVAEModel_4x4x4': 4, 33 | 'CausalVQVAEModel_4x8x8': 4, 34 | 'VQVAEModel_4x4x4': 4, 35 | 'OpenVQVAEModel_4x4x4': 4, 36 | 'VQVAEModel_4x8x8': 4, 37 | 'bair_stride4x2x2': 256, 38 | 'ucf101_stride4x4x4': 256, 39 | 'kinetics_stride4x4x4': 256, 40 | 'kinetics_stride2x4x4': 256, 41 | } 42 | 43 | videobase_ae = { 44 | 'CausalVAEModel_4x8x8': CausalVAEModelWrapper, 45 | 'CausalVQVAEModel_4x4x4': CausalVQVAEModelWrapper, 46 | 'CausalVQVAEModel_4x8x8': CausalVQVAEModelWrapper, 47 | 'VQVAEModel_4x4x4': VQVAEModelWrapper, 48 | 'VQVAEModel_4x8x8': VQVAEModelWrapper, 49 | "bair_stride4x2x2": VQVAEModelWrapper, 50 | "ucf101_stride4x4x4": VQVAEModelWrapper, 51 | "kinetics_stride4x4x4": VQVAEModelWrapper, 52 | "kinetics_stride2x4x4": VQVAEModelWrapper, 53 | } 54 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/causal_vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_causalvae import CausalVAEModel 2 | 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | class CausalVAEModelWrapper(nn.Module): 7 | def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs): 8 | super(CausalVAEModelWrapper, self).__init__() 9 | # if os.path.exists(ckpt): 10 | # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) 11 | self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) 12 | def encode(self, x): # b c t h w 13 | # x = self.vae.encode(x).sample() 14 | x = self.vae.encode(x).sample().mul_(0.18215) 15 | return x 16 | def decode(self, x): 17 | # x = self.vae.decode(x) 18 | x = self.vae.decode(x / 0.18215) 19 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous() 20 | return x 21 | 22 | def dtype(self): 23 | return self.vae.dtype 24 | # 25 | # def device(self): 26 | # return self.vae.device -------------------------------------------------------------------------------- /opensora/models/ae/videobase/causal_vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_causalvqvae import 
CausalVQVAEConfiguration 2 | from .modeling_causalvqvae import CausalVQVAEModel 3 | from .trainer_causalvqvae import CausalVQVAETrainer 4 | 5 | 6 | from einops import rearrange 7 | from torch import nn 8 | 9 | class CausalVQVAEModelWrapper(nn.Module): 10 | def __init__(self, ckpt): 11 | super(CausalVQVAEModelWrapper, self).__init__() 12 | self.vqvae = CausalVQVAEModel.load_from_checkpoint(ckpt) 13 | def encode(self, x): # b c t h w 14 | x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) 15 | return x 16 | def decode(self, x): 17 | vq_output = self.vqvae.codebook(x) 18 | x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) 19 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous() 20 | return x -------------------------------------------------------------------------------- /opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py: -------------------------------------------------------------------------------- 1 | from ..configuration_videobase import VideoBaseConfiguration 2 | from typing import Union, Tuple 3 | 4 | class CausalVQVAEConfiguration(VideoBaseConfiguration): 5 | def __init__( 6 | self, 7 | embedding_dim: int = 256, 8 | n_codes: int = 2048, 9 | n_hiddens: int = 240, 10 | n_res_layers: int = 4, 11 | resolution: int = 128, 12 | sequence_length: int = 16, 13 | time_downsample: int = 4, 14 | spatial_downsample: int = 8, 15 | no_pos_embd: bool = True, 16 | **kwargs, 17 | ): 18 | super().__init__(**kwargs) 19 | 20 | self.embedding_dim = embedding_dim 21 | self.n_codes = n_codes 22 | self.n_hiddens = n_hiddens 23 | self.n_res_layers = n_res_layers 24 | self.resolution = resolution 25 | self.sequence_length = sequence_length 26 | self.time_downsample = time_downsample 27 | self.spatial_downsample = spatial_downsample 28 | self.no_pos_embd = no_pos_embd 29 | 30 | self.hidden_size = n_hiddens 31 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py: -------------------------------------------------------------------------------- 1 | from ..trainer_videobase import VideoBaseTrainer 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | import os 5 | import torch 6 | from transformers.utils import WEIGHTS_NAME 7 | import json 8 | 9 | class CausalVQVAETrainer(VideoBaseTrainer): 10 | 11 | def compute_loss(self, model, inputs, return_outputs=False): 12 | model = model.module 13 | x = inputs.get("video") 14 | x = x / 2 15 | z = model.pre_vq_conv(model.encoder(x)) 16 | vq_output = model.codebook(z) 17 | x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) 18 | recon_loss = F.mse_loss(x_recon, x) / 0.06 19 | commitment_loss = vq_output['commitment_loss'] 20 | loss = recon_loss + commitment_loss 21 | return loss 22 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/configuration_videobase.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | from typing import TypeVar, Dict, Any 4 | from diffusers import ConfigMixin 5 | 6 | T = TypeVar('T', bound='VideoBaseConfiguration') 7 | class VideoBaseConfiguration(ConfigMixin): 8 | config_name = "VideoBaseConfiguration" 9 | _nested_config_fields: Dict[str, Any] = {} 10 | 11 | def __init__(self, **kwargs): 12 | pass 13 | 14 | def to_dict(self) -> Dict[str, Any]: 15 | d = {} 16 | for key, value in vars(self).items(): 17 | if isinstance(value, VideoBaseConfiguration): 18 
| d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances 19 | elif isinstance(value, tuple): 20 | d[key] = list(value) 21 | else: 22 | d[key] = value 23 | return d 24 | 25 | def to_yaml_file(self, yaml_path: str): 26 | with open(yaml_path, 'w') as yaml_file: 27 | yaml.dump(self.to_dict(), yaml_file, default_flow_style=False) 28 | 29 | @classmethod 30 | def load_from_yaml(cls: T, yaml_path: str) -> T: 31 | with open(yaml_path, 'r') as yaml_file: 32 | config_dict = yaml.safe_load(yaml_file) 33 | for field, field_type in cls._nested_config_fields.items(): 34 | if field in config_dict: 35 | config_dict[field] = field_type.load_from_dict(config_dict[field]) 36 | return cls(**config_dict) 37 | 38 | @classmethod 39 | def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T: 40 | # Process nested configuration objects 41 | for field, field_type in cls._nested_config_fields.items(): 42 | if field in config_dict: 43 | config_dict[field] = field_type.load_from_dict(config_dict[field]) 44 | return cls(**config_dict) -------------------------------------------------------------------------------- /opensora/models/ae/videobase/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual_loss import SimpleLPIPS, LPIPSWithDiscriminator, LPIPSWithDiscriminator3D -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modeling_videobase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import ModelMixin, ConfigMixin 3 | from torch import nn 4 | import os 5 | import json 6 | import pytorch_lightning as pl 7 | from diffusers.configuration_utils import ConfigMixin 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from typing import Optional, Union 10 | import glob 11 | 12 | class VideoBaseAE(nn.Module): 13 | _supports_gradient_checkpointing = False 14 | 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) 17 | 18 | @classmethod 19 | def load_from_checkpoint(cls, model_path): 20 | with open(os.path.join(model_path, "config.json"), "r") as file: 21 | config = json.load(file) 22 | state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") 23 | if 'state_dict' in state_dict: 24 | state_dict = state_dict['state_dict'] 25 | model = cls(config=cls.CONFIGURATION_CLS(**config)) 26 | model.load_state_dict(state_dict) 27 | return model 28 | 29 | @classmethod 30 | def download_and_load_model(cls, model_name, cache_dir=None): 31 | pass 32 | 33 | def encode(self, x: torch.Tensor, *args, **kwargs): 34 | pass 35 | 36 | def decode(self, encoding: torch.Tensor, *args, **kwargs): 37 | pass 38 | 39 | class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin): 40 | config_name = "config.json" 41 | 42 | def __init__(self, *args, **kwargs) -> None: 43 | super().__init__(*args, **kwargs) 44 | 45 | def encode(self, x: torch.Tensor, *args, **kwargs): 46 | pass 47 | 48 | def decode(self, encoding: torch.Tensor, *args, **kwargs): 49 | pass 50 | 51 | @property 52 | def num_training_steps(self) -> int: 53 | """Total training steps inferred from datamodule and devices.""" 54 | if self.trainer.max_steps: 55 | return self.trainer.max_steps 56 | 57 | limit_batches = self.trainer.limit_train_batches 58 | batches = len(self.train_dataloader()) 59 | batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * 
batches) 60 | 61 | num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) 62 | if self.trainer.tpu_cores: 63 | num_devices = max(num_devices, self.trainer.tpu_cores) 64 | 65 | effective_accum = self.trainer.accumulate_grad_batches * num_devices 66 | return (batches // effective_accum) * self.trainer.max_epochs 67 | 68 | @classmethod 69 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): 70 | ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt')) 71 | if ckpt_files: 72 | # Adapt to PyTorch Lightning 73 | last_ckpt_file = ckpt_files[-1] 74 | config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) 75 | model = cls.from_config(config_file) 76 | print("init from {}".format(last_ckpt_file)) 77 | model.init_from_ckpt(last_ckpt_file) 78 | return model 79 | else: 80 | return super().from_pretrained(pretrained_model_name_or_path, **kwargs) -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .block import Block 2 | from .attention import ( 3 | AttnBlock3D, 4 | AttnBlock3DFix, 5 | AttnBlock, 6 | LinAttnBlock, 7 | LinearAttention, 8 | TemporalAttnBlock 9 | ) 10 | from .conv import CausalConv3d, Conv2d 11 | from .normalize import GroupNorm, Normalize 12 | from .resnet_block import ResnetBlock2D, ResnetBlock3D 13 | from .updownsample import ( 14 | SpatialDownsample2x, 15 | SpatialUpsample2x, 16 | TimeDownsample2x, 17 | TimeUpsample2x, 18 | Upsample, 19 | Downsample, 20 | TimeDownsampleRes2x, 21 | TimeUpsampleRes2x, 22 | TimeDownsampleResAdv2x, 23 | TimeUpsampleResAdv2x 24 | ) 25 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Block(nn.Module): 4 | def __init__(self, *args, **kwargs) -> None: 5 | super().__init__(*args, **kwargs) -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Union, Tuple 3 | import torch.nn.functional as F 4 | import torch 5 | from .block import Block 6 | from .ops import cast_tuple 7 | from einops import rearrange 8 | from .ops import video_to_image 9 | 10 | class Conv2d(nn.Conv2d): 11 | def __init__( 12 | self, 13 | in_channels: int, 14 | out_channels: int, 15 | kernel_size: Union[int, Tuple[int]] = 3, 16 | stride: Union[int, Tuple[int]] = 1, 17 | padding: Union[str, int, Tuple[int]] = 0, 18 | dilation: Union[int, Tuple[int]] = 1, 19 | groups: int = 1, 20 | bias: bool = True, 21 | padding_mode: str = "zeros", 22 | device=None, 23 | dtype=None, 24 | ) -> None: 25 | super().__init__( 26 | in_channels, 27 | out_channels, 28 | kernel_size, 29 | stride, 30 | padding, 31 | dilation, 32 | groups, 33 | bias, 34 | padding_mode, 35 | device, 36 | dtype, 37 | ) 38 | 39 | @video_to_image 40 | def forward(self, x): 41 | return super().forward(x) 42 | 43 | 44 | class CausalConv3d(nn.Module): 45 | def __init__( 46 | self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs 47 | ): 48 | super().__init__() 49 | self.kernel_size = cast_tuple(kernel_size, 3) 50 | 
self.time_kernel_size = self.kernel_size[0] 51 | self.chan_in = chan_in 52 | self.chan_out = chan_out 53 | stride = kwargs.pop("stride", 1) 54 | padding = kwargs.pop("padding", 0) 55 | padding = list(cast_tuple(padding, 3)) 56 | padding[0] = 0 57 | stride = cast_tuple(stride, 3) 58 | self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) 59 | self._init_weights(init_method) 60 | 61 | def _init_weights(self, init_method): 62 | ks = torch.tensor(self.kernel_size) 63 | if init_method == "avg": 64 | assert ( 65 | self.kernel_size[1] == 1 and self.kernel_size[2] == 1 66 | ), "only support temporal up/down sample" 67 | assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" 68 | weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) 69 | 70 | eyes = torch.concat( 71 | [ 72 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3, 73 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3, 74 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3, 75 | ], 76 | dim=-1, 77 | ) 78 | weight[:, :, :, 0, 0] = eyes 79 | 80 | self.conv.weight = nn.Parameter( 81 | weight, 82 | requires_grad=True, 83 | ) 84 | elif init_method == "zero": 85 | self.conv.weight = nn.Parameter( 86 | torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), 87 | requires_grad=True, 88 | ) 89 | if self.conv.bias is not None: 90 | nn.init.constant_(self.conv.bias, 0) 91 | 92 | def forward(self, x): 93 | # 1 + 16 16 as video, 1 as image 94 | first_frame_pad = x[:, :, :1, :, :].repeat( 95 | (1, 1, self.time_kernel_size - 1, 1, 1) 96 | ) # b c t h w 97 | x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 98 | return self.conv(x) -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/normalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .block import Block 4 | 5 | class GroupNorm(Block): 6 | def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: 7 | super().__init__(*args, **kwargs) 8 | self.norm = torch.nn.GroupNorm( 9 | num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True 10 | ) 11 | def forward(self, x): 12 | return self.norm(x) 13 | 14 | def Normalize(in_channels, num_groups=32): 15 | return torch.nn.GroupNorm( 16 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 17 | ) 18 | 19 | class ActNorm(nn.Module): 20 | def __init__(self, num_features, logdet=False, affine=True, 21 | allow_reverse_init=False): 22 | assert affine 23 | super().__init__() 24 | self.logdet = logdet 25 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 26 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 27 | self.allow_reverse_init = allow_reverse_init 28 | 29 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 30 | 31 | def initialize(self, input): 32 | with torch.no_grad(): 33 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 34 | mean = ( 35 | flatten.mean(1) 36 | .unsqueeze(1) 37 | .unsqueeze(2) 38 | .unsqueeze(3) 39 | .permute(1, 0, 2, 3) 40 | ) 41 | std = ( 42 | flatten.std(1) 43 | .unsqueeze(1) 44 | .unsqueeze(2) 45 | .unsqueeze(3) 46 | .permute(1, 0, 2, 3) 47 | ) 48 | 49 | self.loc.data.copy_(-mean) 50 | self.scale.data.copy_(1 / (std + 1e-6)) 51 | 52 | def forward(self, input, reverse=False): 53 | if reverse: 54 | return self.reverse(input) 55 | if len(input.shape) == 2: 56 | input = input[:,:,None,None] 
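            # 2D inputs (N, C) are expanded to (N, C, 1, 1) so that the per-channel
            # loc/scale parameters (shaped 1 x C x 1 x 1) broadcast correctly;
            # `squeeze` records that the trailing dims must be removed again below.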
57 | squeeze = True 58 | else: 59 | squeeze = False 60 | 61 | _, _, height, width = input.shape 62 | 63 | if self.training and self.initialized.item() == 0: 64 | self.initialize(input) 65 | self.initialized.fill_(1) 66 | 67 | h = self.scale * (input + self.loc) 68 | 69 | if squeeze: 70 | h = h.squeeze(-1).squeeze(-1) 71 | 72 | if self.logdet: 73 | log_abs = torch.log(torch.abs(self.scale)) 74 | logdet = height*width*torch.sum(log_abs) 75 | logdet = logdet * torch.ones(input.shape[0]).to(input) 76 | return h, logdet 77 | 78 | return h 79 | 80 | def reverse(self, output): 81 | if self.training and self.initialized.item() == 0: 82 | if not self.allow_reverse_init: 83 | raise RuntimeError( 84 | "Initializing ActNorm in reverse direction is " 85 | "disabled by default. Use allow_reverse_init=True to enable." 86 | ) 87 | else: 88 | self.initialize(output) 89 | self.initialized.fill_(1) 90 | 91 | if len(output.shape) == 2: 92 | output = output[:,:,None,None] 93 | squeeze = True 94 | else: 95 | squeeze = False 96 | 97 | h = output / self.scale - self.loc 98 | 99 | if squeeze: 100 | h = h.squeeze(-1).squeeze(-1) 101 | return h 102 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | def video_to_image(func): 5 | def wrapper(self, x, *args, **kwargs): 6 | if x.dim() == 5: 7 | t = x.shape[2] 8 | x = rearrange(x, "b c t h w -> (b t) c h w") 9 | x = func(self, x, *args, **kwargs) 10 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t) 11 | return x 12 | return wrapper 13 | 14 | def nonlinearity(x): 15 | return x * torch.sigmoid(x) 16 | 17 | def cast_tuple(t, length=1): 18 | return t if isinstance(t, tuple) else ((t,) * length) 19 | 20 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 21 | n_dims = len(x.shape) 22 | if src_dim < 0: 23 | src_dim = n_dims + src_dim 24 | if dest_dim < 0: 25 | dest_dim = n_dims + dest_dim 26 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 27 | dims = list(range(n_dims)) 28 | del dims[src_dim] 29 | permutation = [] 30 | ctr = 0 31 | for i in range(n_dims): 32 | if i == dest_dim: 33 | permutation.append(src_dim) 34 | else: 35 | permutation.append(dims[ctr]) 36 | ctr += 1 37 | x = x.permute(permutation) 38 | if make_contiguous: 39 | x = x.contiguous() 40 | return x -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from .ops import shift_dim 7 | 8 | class Codebook(nn.Module): 9 | def __init__(self, n_codes, embedding_dim): 10 | super().__init__() 11 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) 12 | self.register_buffer("N", torch.zeros(n_codes)) 13 | self.register_buffer("z_avg", self.embeddings.data.clone()) 14 | 15 | self.n_codes = n_codes 16 | self.embedding_dim = embedding_dim 17 | self._need_init = True 18 | 19 | def _tile(self, x): 20 | d, ew = x.shape 21 | if d < self.n_codes: 22 | n_repeats = (self.n_codes + d - 1) // d 23 | std = 0.01 / np.sqrt(ew) 24 | x = x.repeat(n_repeats, 1) 25 | x = x + torch.randn_like(x) * std 26 | return x 27 | 28 | def _init_embeddings(self, z): 29 | # z: [b, c, t, h, w] 30 | 
self._need_init = False 31 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) 32 | y = self._tile(flat_inputs) 33 | 34 | d = y.shape[0] 35 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 36 | if dist.is_initialized(): 37 | dist.broadcast(_k_rand, 0) 38 | self.embeddings.data.copy_(_k_rand) 39 | self.z_avg.data.copy_(_k_rand) 40 | self.N.data.copy_(torch.ones(self.n_codes)) 41 | 42 | def forward(self, z): 43 | # z: [b, c, t, h, w] 44 | if self._need_init and self.training: 45 | self._init_embeddings(z) 46 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) 47 | distances = ( 48 | (flat_inputs**2).sum(dim=1, keepdim=True) 49 | - 2 * flat_inputs @ self.embeddings.t() 50 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) 51 | ) 52 | 53 | encoding_indices = torch.argmin(distances, dim=1) 54 | encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) 55 | encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) 56 | 57 | embeddings = F.embedding(encoding_indices, self.embeddings) 58 | embeddings = shift_dim(embeddings, -1, 1) 59 | 60 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) 61 | 62 | # EMA codebook update 63 | if self.training: 64 | n_total = encode_onehot.sum(dim=0) 65 | encode_sum = flat_inputs.t() @ encode_onehot 66 | if dist.is_initialized(): 67 | dist.all_reduce(n_total) 68 | dist.all_reduce(encode_sum) 69 | 70 | self.N.data.mul_(0.99).add_(n_total, alpha=0.01) 71 | self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) 72 | 73 | n = self.N.sum() 74 | weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n 75 | encode_normalized = self.z_avg / weights.unsqueeze(1) 76 | self.embeddings.data.copy_(encode_normalized) 77 | 78 | y = self._tile(flat_inputs) 79 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] 80 | if dist.is_initialized(): 81 | dist.broadcast(_k_rand, 0) 82 | 83 | usage = (self.N.view(self.n_codes, 1) >= 1).float() 84 | self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) 85 | 86 | embeddings_st = (embeddings - z).detach() + z 87 | 88 | avg_probs = torch.mean(encode_onehot, dim=0) 89 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 90 | 91 | return dict( 92 | embeddings=embeddings_st, 93 | encodings=encoding_indices, 94 | commitment_loss=commitment_loss, 95 | perplexity=perplexity, 96 | ) 97 | 98 | def dictionary_lookup(self, encodings): 99 | embeddings = F.embedding(encodings, self.embeddings) 100 | return embeddings -------------------------------------------------------------------------------- /opensora/models/ae/videobase/modules/resnet_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange, pack, unpack 4 | from .normalize import Normalize 5 | from .ops import nonlinearity, video_to_image 6 | from .conv import CausalConv3d 7 | from .block import Block 8 | 9 | class ResnetBlock2D(Block): 10 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 11 | dropout): 12 | super().__init__() 13 | self.in_channels = in_channels 14 | self.out_channels = in_channels if out_channels is None else out_channels 15 | self.use_conv_shortcut = conv_shortcut 16 | 17 | self.norm1 = Normalize(in_channels) 18 | self.conv1 = torch.nn.Conv2d( 19 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 20 | ) 21 | self.norm2 = Normalize(out_channels) 22 | self.dropout = torch.nn.Dropout(dropout) 23 | self.conv2 = torch.nn.Conv2d( 24 | 
out_channels, out_channels, kernel_size=3, stride=1, padding=1 25 | ) 26 | if self.in_channels != self.out_channels: 27 | if self.use_conv_shortcut: 28 | self.conv_shortcut = torch.nn.Conv2d( 29 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 30 | ) 31 | else: 32 | self.nin_shortcut = torch.nn.Conv2d( 33 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 34 | ) 35 | 36 | @video_to_image 37 | def forward(self, x): 38 | h = x 39 | h = self.norm1(h) 40 | h = nonlinearity(h) 41 | h = self.conv1(h) 42 | h = self.norm2(h) 43 | h = nonlinearity(h) 44 | h = self.dropout(h) 45 | h = self.conv2(h) 46 | if self.in_channels != self.out_channels: 47 | if self.use_conv_shortcut: 48 | x = self.conv_shortcut(x) 49 | else: 50 | x = self.nin_shortcut(x) 51 | x = x + h 52 | return x 53 | 54 | class ResnetBlock3D(Block): 55 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): 56 | super().__init__() 57 | self.in_channels = in_channels 58 | self.out_channels = in_channels if out_channels is None else out_channels 59 | self.use_conv_shortcut = conv_shortcut 60 | 61 | self.norm1 = Normalize(in_channels) 62 | self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) 63 | self.norm2 = Normalize(out_channels) 64 | self.dropout = torch.nn.Dropout(dropout) 65 | self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) 66 | if self.in_channels != self.out_channels: 67 | if self.use_conv_shortcut: 68 | self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) 69 | else: 70 | self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) 71 | 72 | def forward(self, x): 73 | h = x 74 | h = self.norm1(h) 75 | h = nonlinearity(h) 76 | h = self.conv1(h) 77 | h = self.norm2(h) 78 | h = nonlinearity(h) 79 | h = self.dropout(h) 80 | h = self.conv2(h) 81 | if self.in_channels != self.out_channels: 82 | if self.use_conv_shortcut: 83 | x = self.conv_shortcut(x) 84 | else: 85 | x = self.nin_shortcut(x) 86 | return x + h -------------------------------------------------------------------------------- /opensora/models/ae/videobase/trainer_videobase.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | import os 5 | import torch 6 | from transformers.utils import WEIGHTS_NAME 7 | import json 8 | 9 | class VideoBaseTrainer(Trainer): 10 | 11 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 12 | output_dir = output_dir if output_dir is not None else self.args.output_dir 13 | os.makedirs(output_dir, exist_ok=True) 14 | if state_dict is None: 15 | state_dict = self.model.state_dict() 16 | 17 | # get model config 18 | model_config = self.model.config.to_dict() 19 | 20 | # add more information 21 | model_config['model'] = self.model.__class__.__name__ 22 | 23 | with open(os.path.join(output_dir, "config.json"), "w") as file: 24 | json.dump(self.model.config.to_dict(), file) 25 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 26 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 27 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/utils/distrib_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class DiagonalGaussianDistribution(object): 5 | def __init__(self, parameters, 
deterministic=False): 6 | self.parameters = parameters 7 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 8 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 9 | self.deterministic = deterministic 10 | self.std = torch.exp(0.5 * self.logvar) 11 | self.var = torch.exp(self.logvar) 12 | if self.deterministic: 13 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 14 | 15 | def sample(self): 16 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 17 | return x 18 | 19 | def kl(self, other=None): 20 | if self.deterministic: 21 | return torch.Tensor([0.]) 22 | else: 23 | if other is None: 24 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 25 | + self.var - 1.0 - self.logvar, 26 | dim=[1, 2, 3]) 27 | else: 28 | return 0.5 * torch.sum( 29 | torch.pow(self.mean - other.mean, 2) / other.var 30 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 31 | dim=[1, 2, 3]) 32 | 33 | def nll(self, sample, dims=[1,2,3]): 34 | if self.deterministic: 35 | return torch.Tensor([0.]) 36 | logtwopi = np.log(2.0 * np.pi) 37 | return 0.5 * torch.sum( 38 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 39 | dim=dims) 40 | 41 | def mode(self): 42 | return self.mean 43 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/utils/module_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | Module = str 4 | MODULES_BASE = "opensora.models.ae.videobase.modules." 5 | 6 | def resolve_str_to_obj(str_val, append=True): 7 | if append: 8 | str_val = MODULES_BASE + str_val 9 | module_name, class_name = str_val.rsplit('.', 1) 10 | module = importlib.import_module(module_name) 11 | return getattr(module, class_name) 12 | 13 | def create_instance(module_class_str: str, **kwargs): 14 | module_name, class_name = module_class_str.rsplit('.', 1) 15 | module = importlib.import_module(module_name) 16 | class_ = getattr(module, class_name) 17 | return class_(**kwargs) -------------------------------------------------------------------------------- /opensora/models/ae/videobase/utils/scheduler_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def cosine_scheduler(step, max_steps, value_base=1, value_end=0): 4 | step = torch.tensor(step) 5 | cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps)) 6 | value = value_end + (value_base - value_end) * cosine_value 7 | return value -------------------------------------------------------------------------------- /opensora/models/ae/videobase/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def tensor_to_video(x): 5 | x = x.detach().cpu() 6 | x = torch.clamp(x, -1, 1) 7 | x = (x + 1) / 2 8 | x = x.permute(1, 0, 2, 3).float().numpy() # c t h w -> 9 | x = (255 * x).astype(np.uint8) 10 | return x -------------------------------------------------------------------------------- /opensora/models/ae/videobase/vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch import nn 3 | 4 | from .configuration_vqvae import VQVAEConfiguration 5 | from .modeling_vqvae import VQVAEModel 6 | from .trainer_vqvae import VQVAETrainer 7 | 8 | videovqvae = [ 9 | "bair_stride4x2x2", 10 | "ucf101_stride4x4x4", 11 | 
"kinetics_stride4x4x4", 12 | "kinetics_stride2x4x4", 13 | ] 14 | videovae = [] 15 | 16 | class VQVAEModelWrapper(nn.Module): 17 | def __init__(self, ckpt='kinetics_stride4x4x4'): 18 | super(VQVAEModelWrapper, self).__init__() 19 | if ckpt in videovqvae: 20 | self.vqvae = VQVAEModel.download_and_load_model(ckpt) 21 | else: 22 | self.vqvae = VQVAEModel.load_from_checkpoint(ckpt) 23 | def encode(self, x): # b c t h w 24 | x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) 25 | return x 26 | def decode(self, x): 27 | vq_output = self.vqvae.codebook(x) 28 | x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) 29 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous() 30 | return x 31 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/vqvae/configuration_vqvae.py: -------------------------------------------------------------------------------- 1 | from ..configuration_videobase import VideoBaseConfiguration 2 | from typing import Union, Tuple 3 | 4 | class VQVAEConfiguration(VideoBaseConfiguration): 5 | def __init__( 6 | self, 7 | embedding_dim: int = 256, 8 | n_codes: int = 2048, 9 | n_hiddens: int = 240, 10 | n_res_layers: int = 4, 11 | resolution: int = 128, 12 | sequence_length: int = 16, 13 | downsample: Union[Tuple[int, int, int], str] = (4, 4, 4), 14 | no_pos_embd: bool = True, 15 | **kwargs, 16 | ): 17 | super().__init__(**kwargs) 18 | 19 | self.embedding_dim = embedding_dim 20 | self.n_codes = n_codes 21 | self.n_hiddens = n_hiddens 22 | self.n_res_layers = n_res_layers 23 | self.resolution = resolution 24 | self.sequence_length = sequence_length 25 | 26 | if isinstance(downsample, str): 27 | self.downsample = tuple(map(int, downsample.split(","))) 28 | else: 29 | self.downsample = downsample 30 | 31 | self.no_pos_embd = no_pos_embd 32 | 33 | self.hidden_size = n_hiddens 34 | -------------------------------------------------------------------------------- /opensora/models/ae/videobase/vqvae/trainer_vqvae.py: -------------------------------------------------------------------------------- 1 | from ..trainer_videobase import VideoBaseTrainer 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | import os 5 | import torch 6 | from transformers.utils import WEIGHTS_NAME 7 | import json 8 | 9 | class VQVAETrainer(VideoBaseTrainer): 10 | 11 | def compute_loss(self, model, inputs, return_outputs=False): 12 | model = model.module 13 | x = inputs.get("video") 14 | x = x / 2 15 | z = model.pre_vq_conv(model.encoder(x)) 16 | vq_output = model.codebook(z) 17 | x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) 18 | recon_loss = F.mse_loss(x_recon, x) / 0.06 19 | commitment_loss = vq_output['commitment_loss'] 20 | loss = recon_loss + commitment_loss 21 | return loss 22 | 23 | -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/README.md: -------------------------------------------------------------------------------- 1 | # Refiner for Video Caption 2 | 3 | Transform the short caption annotations from video datasets into the long and detailed caption annotations. 4 | 5 | * Add detailed description for background scene. 6 | * Add detailed description for object attributes, including color, material, pose. 7 | * Add detailed description for object-level spatial relationship. 
8 | 9 | ## 🛠️ Extra Requirements and Installation 10 | 11 | * openai == 0.28.0 12 | * jsonlines == 4.0.0 13 | * nltk == 3.8.1 14 | * Install the LLaMA-Accessory: 15 | 16 | you also need to download the weight of SPHINX to ./ckpt/ folder 17 | 18 | ## 🗝️ Refining 19 | 20 | The refining instruction is in [demo_for_refiner.py](demo_for_refiner.py). 21 | 22 | ```bash 23 | python demo_for_refiner.py --root_path $path_to_repo$ --api_key $openai_api_key$ 24 | ``` 25 | 26 | ### Refining Demos 27 | 28 | ```bash 29 | [original caption]: A red mustang parked in a showroom with american flags hanging from the ceiling. 30 | ``` 31 | 32 | ```bash 33 | [refine caption]: This scene depicts a red Mustang parked in a showroom with American flags hanging from the ceiling. The showroom likely serves as a space for showcasing and purchasing cars, and the Mustang is displayed prominently near the flags and ceiling. The scene also features a large window and other objects. Overall, it seems to take place in a car show or dealership. 34 | ``` 35 | 36 | - [ ] Add GPT-3.5-Turbo for caption summarization. ⌛ [WIP] 37 | - [ ] Add LLAVA-1.6. ⌛ [WIP] 38 | - [ ] More descriptions. ⌛ [WIP] -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json: -------------------------------------------------------------------------------- 1 | {"video1.gif": "A red mustang parked in a showroom with american flags hanging from the ceiling.", "video2.gif": "An aerial view of a city with a river running through it."} -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/dataset/test_videos/video1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/captioner/caption_refiner/dataset/test_videos/video1.gif -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/dataset/test_videos/video2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/captioner/caption_refiner/dataset/test_videos/video2.gif -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/demo_for_refiner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from caption_refiner import CaptionRefiner 3 | from gpt_combinator import caption_summary, caption_qa 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 7 | parser.add_argument("--root_path", required=True, help="The path to repo.") 8 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 9 | args = parser.parse_args() 10 | return args 11 | 12 | if __name__ == "__main__": 13 | args = parse_args() 14 | myrefiner = CaptionRefiner( 15 | sample_num=6, add_detect=True, add_pos=True, add_attr=True, 16 | openai_api_key = args.api_key, 17 | openai_api_base = "https://one-api.bltcy.top/v1", 18 | ) 19 | 20 | results = myrefiner.caption_refine( 21 | video_path="./dataset/test_videos/video1.gif", 22 | org_caption="A red mustang parked in a showroom with american flags hanging from the ceiling.", 
23 | model_path = args.root_path + "/ckpts/SPHINX-Tiny", 24 | ) 25 | 26 | final_caption = myrefiner.gpt_summary(results) 27 | 28 | print(final_caption) 29 | -------------------------------------------------------------------------------- /opensora/models/captioner/caption_refiner/gpt_combinator.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import ast 3 | 4 | def caption_qa(caption_list, api_key, api_base): 5 | openai.api_key = api_key 6 | openai.api_base = api_base 7 | 8 | question = "What is the color of a red apple" 9 | answer = "red" 10 | pred = "green" 11 | try: 12 | # Compute the correctness score 13 | completion = openai.ChatCompletion.create( 14 | model="gpt-3.5-turbo", 15 | # model="gpt-4", 16 | # model="gpt-4-vision-compatible", 17 | messages=[ 18 | { 19 | "role": "system", 20 | "content": 21 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " 22 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 23 | "------" 24 | "##INSTRUCTIONS: " 25 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 26 | "- Consider synonyms or paraphrases as valid matches.\n" 27 | "- Evaluate the correctness of the prediction compared to the answer." 28 | }, 29 | { 30 | "role": "user", 31 | "content": 32 | "Please evaluate the following video-based question-answer pair:\n\n" 33 | f"Question: {question}\n" 34 | f"Correct Answer: {answer}\n" 35 | f"Predicted Answer: {pred}\n\n" 36 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 37 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 38 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 39 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." 40 | } 41 | ] 42 | ) 43 | # Convert response to a Python dictionary. 44 | response_message = completion["choices"][0]["message"]["content"] 45 | response_dict = ast.literal_eval(response_message) 46 | print(response_dict) 47 | 48 | except Exception as e: 49 | print(f"Error processing file : {e}") 50 | 51 | 52 | def caption_summary(long_caption, api_key, api_base): 53 | """ 54 | apply GPT3-Turbo as the combination for original caption and the prompted captions for a video 55 | """ 56 | openai.api_key = api_key 57 | openai.api_base = api_base 58 | 59 | try: 60 | # Compute the correctness score 61 | completion = openai.ChatCompletion.create( 62 | model="gpt-3.5-turbo", 63 | messages=[ 64 | { 65 | "role": "system", 66 | "content": 67 | "You are an intelligent chatbot designed for summarizing from a long sentence. " 68 | }, 69 | { 70 | "role": "user", 71 | "content": 72 | "Please summarize the following sentences. Make it shorter than 70 words." 73 | f"the long sentence: {long_caption}\n" 74 | "Provide your summarization with less than 70 words. " 75 | "DO NOT PROVIDE ANY OTHER TEXT OR EXPLANATION. Only provide the summary sentence. 
" 76 | } 77 | ] 78 | ) 79 | # "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 80 | # "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 81 | # "For example, your response should look like this: {'summary': 'your summary sentence'}." 82 | 83 | # Convert response to a Python dictionary. 84 | response_message = completion["choices"][0]["message"]["content"] 85 | response_dict = ast.literal_eval(response_message) 86 | 87 | except Exception as e: 88 | print(f"Error processing file : {e}") 89 | 90 | return response_dict 91 | 92 | if __name__ == "__main__": 93 | caption_summary() -------------------------------------------------------------------------------- /opensora/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .latte.modeling_latte import Latte_models 3 | 4 | Diffusion_models = {} 5 | Diffusion_models.update(Latte_models) 6 | 7 | -------------------------------------------------------------------------------- /opensora/models/diffusion/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T 7 | 8 | 9 | def create_diffusion( 10 | timestep_respacing, 11 | noise_schedule="linear", 12 | use_kl=False, 13 | sigma_small=False, 14 | predict_xstart=False, 15 | learn_sigma=True, 16 | # learn_sigma=False, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | from . import gaussian_diffusion as gd 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | 49 | def create_diffusion_T( 50 | timestep_respacing, 51 | noise_schedule="linear", 52 | use_kl=False, 53 | sigma_small=False, 54 | predict_xstart=False, 55 | learn_sigma=True, 56 | # learn_sigma=False, 57 | rescale_learned_sigmas=False, 58 | diffusion_steps=1000 59 | ): 60 | from . 
import gaussian_diffusion_t2v as gd 61 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 62 | if use_kl: 63 | loss_type = gd.LossType.RESCALED_KL 64 | elif rescale_learned_sigmas: 65 | loss_type = gd.LossType.RESCALED_MSE 66 | else: 67 | loss_type = gd.LossType.MSE 68 | if timestep_respacing is None or timestep_respacing == "": 69 | timestep_respacing = [diffusion_steps] 70 | return SpacedDiffusion_T( 71 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 72 | betas=betas, 73 | model_mean_type=( 74 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 75 | ), 76 | model_var_type=( 77 | ( 78 | gd.ModelVarType.FIXED_LARGE 79 | if not sigma_small 80 | else gd.ModelVarType.FIXED_SMALL 81 | ) 82 | if not learn_sigma 83 | else gd.ModelVarType.LEARNED_RANGE 84 | ), 85 | loss_type=loss_type 86 | # rescale_timesteps=rescale_timesteps, 87 | ) 88 | -------------------------------------------------------------------------------- /opensora/models/diffusion/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 
68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /opensora/models/diffusion/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | ): 10 | """function for creating Transport object 11 | **Note**: model prediction defaults to velocity 12 | Args: 13 | - path_type: type of path to use; default to linear 14 | - learn_score: set model prediction to score 15 | - learn_noise: set model prediction to noise 16 | - velocity_weighted: weight loss by velocity weight 17 | - likelihood_weighted: weight loss by likelihood weight 18 | - train_eps: small epsilon for avoiding instability during training 19 | - sample_eps: small epsilon for avoiding instability during sampling 20 | """ 21 | 22 | if prediction == "noise": 23 | model_type = ModelType.NOISE 24 | elif prediction == "score": 25 | model_type = ModelType.SCORE 26 | else: 27 | model_type = ModelType.VELOCITY 28 | 29 | if loss_weight == "velocity": 30 | loss_type = WeightType.VELOCITY 31 | elif loss_weight == "likelihood": 32 | loss_type = WeightType.LIKELIHOOD 33 | else: 34 | loss_type = WeightType.NONE 35 | 36 | path_choice = { 37 | "Linear": PathType.LINEAR, 38 | "GVP": PathType.GVP, 39 | "VP": PathType.VP, 40 | } 41 | 42 | path_type = path_choice[path_type] 43 | 44 | if (path_type in [PathType.VP]): 45 | train_eps = 1e-5 if train_eps is None else train_eps 46 | sample_eps = 1e-3 if train_eps is None else sample_eps 47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 48 | train_eps = 1e-3 if train_eps is None else train_eps 49 | sample_eps = 1e-3 if train_eps is None else sample_eps 50 | else: # velocity & [GVP, LINEAR] is stable everywhere 51 | train_eps = 0 52 | sample_eps = 0 53 | 54 | # create flow state 55 | state = Transport( 56 | model_type=model_type, 57 | path_type=path_type, 58 | loss_type=loss_type, 59 | train_eps=train_eps, 60 | sample_eps=sample_eps, 61 | ) 62 | 63 | return state -------------------------------------------------------------------------------- /opensora/models/diffusion/transport/integrators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | class sde: 9 | """SDE solver class""" 10 | def __init__( 11 | self, 
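        # A minimal construction sketch for this solver (illustrative; the toy
        # drift/diffusion callables below are assumptions -- in practice they are
        # supplied by the Sampler defined in transport.py on top of
        # create_transport() above):
        #
        #   solver = sde(
        #       drift=lambda x, t, model, **kw: -x,       # toy drift that ignores the model
        #       diffusion=lambda x, t: th.ones_like(x),   # constant diffusion coefficient
        #       t0=0.0, t1=1.0, num_steps=50, sampler_type="Euler",
        #   )
        #   xs = solver.sample(th.randn(4, 8), model=None)  # list of 49 intermediate states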
12 | drift, 13 | diffusion, 14 | *, 15 | t0, 16 | t1, 17 | num_steps, 18 | sampler_type, 19 | ): 20 | assert t0 < t1, "SDE sampler has to be in forward time" 21 | 22 | self.num_timesteps = num_steps 23 | self.t = th.linspace(t0, t1, num_steps) 24 | self.dt = self.t[1] - self.t[0] 25 | self.drift = drift 26 | self.diffusion = diffusion 27 | self.sampler_type = sampler_type 28 | 29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 30 | w_cur = th.randn(x.size()).to(x) 31 | t = th.ones(x.size(0)).to(x) * t 32 | dw = w_cur * th.sqrt(self.dt) 33 | drift = self.drift(x, t, model, **model_kwargs) 34 | diffusion = self.diffusion(x, t) 35 | mean_x = x + drift * self.dt 36 | x = mean_x + th.sqrt(2 * diffusion) * dw 37 | return x, mean_x 38 | 39 | def __Heun_step(self, x, _, t, model, **model_kwargs): 40 | w_cur = th.randn(x.size()).to(x) 41 | dw = w_cur * th.sqrt(self.dt) 42 | t_cur = th.ones(x.size(0)).to(x) * t 43 | diffusion = self.diffusion(x, t_cur) 44 | xhat = x + th.sqrt(2 * diffusion) * dw 45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 46 | xp = xhat + self.dt * K1 47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step 49 | 50 | def __forward_fn(self): 51 | """TODO: generalize here by adding all private functions ending with steps to it""" 52 | sampler_dict = { 53 | "Euler": self.__Euler_Maruyama_step, 54 | "Heun": self.__Heun_step, 55 | } 56 | 57 | try: 58 | sampler = sampler_dict[self.sampler_type] 59 | except: 60 | raise NotImplementedError("Smapler type not implemented.") 61 | 62 | return sampler 63 | 64 | def sample(self, init, model, **model_kwargs): 65 | """forward loop of sde""" 66 | x = init 67 | mean_x = init 68 | samples = [] 69 | sampler = self.__forward_fn() 70 | for ti in self.t[:-1]: 71 | with th.no_grad(): 72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 73 | samples.append(x) 74 | 75 | return samples 76 | 77 | class ode: 78 | """ODE solver class""" 79 | def __init__( 80 | self, 81 | drift, 82 | *, 83 | t0, 84 | t1, 85 | sampler_type, 86 | num_steps, 87 | atol, 88 | rtol, 89 | ): 90 | assert t0 < t1, "ODE sampler has to be in forward time" 91 | 92 | self.drift = drift 93 | self.t = th.linspace(t0, t1, num_steps) 94 | self.atol = atol 95 | self.rtol = rtol 96 | self.sampler_type = sampler_type 97 | 98 | def sample(self, x, model, **model_kwargs): 99 | 100 | device = x[0].device if isinstance(x, tuple) else x.device 101 | def _fn(t, x): 102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 103 | model_output = self.drift(x, t, model, **model_kwargs) 104 | return model_output 105 | 106 | t = self.t.to(device) 107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 109 | samples = odeint( 110 | _fn, 111 | x, 112 | t, 113 | method=self.sampler_type, 114 | atol=atol, 115 | rtol=rtol 116 | ) 117 | return samples -------------------------------------------------------------------------------- /opensora/models/diffusion/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class EasyDict: 4 | 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | def mean_flat(x): 13 | """ 14 | Take the mean 
over all non-batch dimensions. 15 | """ 16 | return th.mean(x, dim=list(range(1, len(x.size())))) 17 | 18 | def log_state(state): 19 | result = [] 20 | 21 | sorted_state = dict(sorted(state.items())) 22 | for key, value in sorted_state.items(): 23 | # Check if the value is an instance of a class 24 | if " 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /opensora/models/diffusion/utils/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 40 | return tokens -------------------------------------------------------------------------------- /opensora/models/diffusion/utils/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /opensora/models/diffusion/utils/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
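# Usage sketch for the extension built by this script (illustrative only; the
# shapes follow the TORCH_CHECKs in curope.cpp / kernels.cu, and the example
# tensors below are assumptions):
#
#   # build first, e.g.:  python setup.py build_ext --inplace
#   import torch
#   from curope2d import cuRoPE2D
#
#   rope = cuRoPE2D(freq=100.0)
#   base = torch.randn(2, 64, 8, 32, device="cuda")       # contiguous (B, N, heads, D), D % 4 == 0
#   tokens = base.transpose(1, 2)                          # (B, heads, N, D) view expected by the module
#   pos = torch.randint(0, 16, (2, 64, 2), device="cuda")  # integer (y, x) positions, shape (B, N, 2)
#   out = rope(tokens, pos)                                 # rotates the tokens in place and returns them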
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) 35 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/cfgs/AMT-G.yaml: -------------------------------------------------------------------------------- 1 | 2 | seed: 2023 3 | 4 | network: 5 | name: networks.AMT-G.Model 6 | params: 7 | corr_radius: 3 8 | corr_lvls: 4 9 | num_flows: 5 -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/frame_interpolation/networks/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/frame_interpolation/networks/blocks/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/networks/blocks/multi_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping 14 | comb_block: An nn.Seqential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conduct a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead. 
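    return:
        imgt_pred with shape [b, 3, h, w]; the num_flows warped candidates are
        averaged over the flow dimension and then refined by 'comb_block'.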
23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch*3+4, in_ch*3), 52 | ResBlock(in_ch*3, skip_ch), 53 | nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, flow0, flow1): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/readme.md: -------------------------------------------------------------------------------- 1 | #### Frame Interpolation 2 | 3 | We use AMT as our frame interpolation model. (Thanks [AMT](https://github.com/MCG-NKU/AMT)) After sampling, you can use frame interpolation model to interpolate your video smoothly. 4 | 5 | 1. Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT), we recommend using the largest model AMT-G to achieve the best performance. 6 | 2. Run the script of frame interpolation. 7 | ``` 8 | python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/floder --frame_rate 30 9 | ``` 10 | 3. The output video will be stored at output_path and its duration time is equal `the total number of frames after frame interpolation / the frame rate` 11 | ##### Frame Interpolation Specific Settings 12 | 13 | * `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). We use AMT-G as our frame interpolation model. 14 | * `--niter`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. 15 | * `--input`: Path of the input video. 16 | * `--output_path`: Folder Path of the output video. 17 | * `--frame_rate"`: Frame rate of the output video. 
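
The output length follows directly from these settings; a small sketch (the 65-frame input is only an illustrative number):

```python
def interpolated_frames(m: int, niters: int) -> int:
    """With m input frames, niters interpolation passes give 2**niters * (m - 1) + 1 frames."""
    return 2 ** niters * (m - 1) + 1

frames_out = interpolated_frames(m=65, niters=1)   # 129 output frames
duration_s = frames_out / 30                       # seconds at --frame_rate 30
print(frames_out, round(duration_s, 2))            # 129 4.3
```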
18 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/frame_interpolation/utils/__init__.py -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/build_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def base_build_fn(module, cls, params): 5 | return getattr(importlib.import_module( 6 | module, package=None), cls)(**params) 7 | 8 | 9 | def build_from_cfg(config): 10 | module, cls = config['name'].rsplit(".", 1) 11 | params = config.get('params', {}) 12 | return base_build_fn(module, cls, params) 13 | -------------------------------------------------------------------------------- /opensora/models/frame_interpolation/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | 49 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Environment Preparation 3 | 4 | For video super resolution, please prepare your own python envirment from [RGT](https://github.com/zhengchen1999/RGT) and down the [ckpt](https://drive.google.com/drive/folders/1zxrr31Kp2D_N9a-OUAPaJEn_yTaSXTfZ) into the folder like 5 | ```bash 6 | ./experiments/pretrained_models/RGT_x2.pth 7 | ``` 8 | 9 | ## Video Super Resolution 10 | The inferencing instruction is in [run.py](run.py). 11 | ```bash 12 | python run.py --SR x4 --root_path /path_to_root --input_dir /path_to_input_dir --output_dir /path_to_video_output 13 | ``` 14 | You can configure some more detailed parameters in [run.py](run.py) such as . 
15 | ```bash 16 | --mul_numwork 16 --use_chop False 17 | ``` 18 | We recommend using `` --use_chop = False `` when memory allows. 19 | Note that in our tests. 20 | 21 | A single frame of 256x256 requires about 3G RAM-Usage, and a single 4090 card can process about one frame per second. 22 | 23 | A single frame of 512x512 takes about 19G RAM-Usage, and a single 4090 takes about 5 seconds to process a frame. 24 | 25 | 26 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | from .archs import * 2 | from .data import * 3 | from .metrics import * 4 | from .models import * 5 | from .test import * 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
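    Example (a minimal sketch; 'dataset', 'world_size' and 'rank' are assumed
    to come from the caller's distributed setup):
        sampler = EnlargedSampler(dataset, num_replicas=world_size, rank=rank, ratio=100)
        loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)
        for epoch in range(total_epochs):
            sampler.set_epoch(epoch)  # reshuffle deterministically each epoch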
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
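    Example (a minimal sketch; 'loader' and 'opt' are assumed, and 'opt' must
    provide the 'num_gpu' key used below to pick the device):
        prefetcher = CUDAPrefetcher(loader, opt)
        batch = prefetcher.next()
        while batch is not None:
            ...  # train on batch (tensors already moved to GPU asynchronously)
            batch = prefetcher.next()
        prefetcher.reset()  # restart from the beginning of the dataloader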
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | from basicsr.utils.matlab_functions import rgb2ycbcr 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class SingleImageDataset(data.Dataset): 13 | """Read only lq images in the test phase. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 16 | 17 | There are two modes: 18 | 1. 'meta_info_file': Use meta information file to generate paths. 19 | 2. 'folder': Scan folders to generate paths. 20 | 21 | Args: 22 | opt (dict): Config for train datasets. It contains the following keys: 23 | dataroot_lq (str): Data root path for lq. 24 | meta_info_file (str): Path for meta information file. 25 | io_backend (dict): IO backend type and other kwarg. 
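    Example (a minimal sketch; the folder path is an assumption):
        opt = dict(dataroot_lq='datasets/LR', io_backend=dict(type='disk'))
        dataset = SingleImageDataset(opt)
        item = dataset[0]  # {'lq': CHW float tensor in RGB, 'lq_path': str}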
26 | """ 27 | 28 | def __init__(self, opt): 29 | super(SingleImageDataset, self).__init__() 30 | self.opt = opt 31 | # file client (io backend) 32 | self.file_client = None 33 | self.io_backend_opt = opt['io_backend'] 34 | self.mean = opt['mean'] if 'mean' in opt else None 35 | self.std = opt['std'] if 'std' in opt else None 36 | self.lq_folder = opt['dataroot_lq'] 37 | 38 | if self.io_backend_opt['type'] == 'lmdb': 39 | self.io_backend_opt['db_paths'] = [self.lq_folder] 40 | self.io_backend_opt['client_keys'] = ['lq'] 41 | self.paths = paths_from_lmdb(self.lq_folder) 42 | elif 'meta_info_file' in self.opt: 43 | with open(self.opt['meta_info_file'], 'r') as fin: 44 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 51 | 52 | # load lq image 53 | lq_path = self.paths[index] 54 | img_bytes = self.file_client.get(lq_path, 'lq') 55 | img_lq = imfrombytes(img_bytes, float32=True) 56 | 57 | # color space transform 58 | if 'color' in self.opt and self.opt['color'] == 'y': 59 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 60 | 61 | # BGR to RGB, HWC to CHW, numpy to tensor 62 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 63 | # normalize 64 | if self.mean is not None or self.std is not None: 65 | normalize(img_lq, self.mean, self.std, inplace=True) 66 | return {'lq': img_lq, 'lq_path': lq_path} 67 | 68 | def __len__(self): 69 | return len(self.paths) 70 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils import get_root_logger 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must contain: 19 | type (str): Model type. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 
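    Example:
        >>> import torch
        >>> loss = torch.tensor([1.0, 2.0, 3.0])
        >>> reduce_loss(loss, 'mean')
        tensor(2.)
        >>> reduce_loss(loss, 'sum')
        tensor(6.)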
14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must contain: 14 | type (str): Model type. 
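    Example (a minimal sketch; the keyword names in 'data' must match the
    registered metric's signature, which is an assumption here):
        opt = dict(type='calculate_psnr', crop_border=4, test_y_channel=True)
        psnr = calculate_metric(dict(img=sr_img, img2=gt_img), opt)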
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | model_type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 
11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
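        # Worked example (mirrors the config in the class docstring; 'model' is
        # an assumption): with periods=[10, 10, 10, 10] the cumulative periods
        # computed below are [10, 20, 30, 40].
        #   optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        #   scheduler = CosineAnnealingRestartLR(
        #       optimizer, periods=[10, 10, 10, 10],
        #       restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)
        #   # call scheduler.step() once per iteration as usual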
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/test_img.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | from basicsr.data import build_dataloader, build_dataset 5 | from basicsr.models import build_model 6 | from basicsr.utils import get_root_logger, get_time_str, make_exp_dirs 7 | from basicsr.utils.options import dict2str, parse_options 8 | 9 | 10 | def image_sr(args): 11 | # parse options, set distributed setting, set ramdom seed 12 | opt, _ = parse_options(args.root_path,args.SR,is_train=False) 13 | torch.backends.cudnn.benchmark = True 14 | # torch.backends.cudnn.deterministic = True 15 | 16 | # create test dataset and dataloader 17 | test_loaders = [] 18 | for _, dataset_opt in sorted(opt['datasets'].items()): 19 | dataset_opt['dataroot_lq'] = osp.join(args.output_dir, f'temp_LR') 20 | if args.SR == 'x4': 21 | opt['upscale'] = opt['network_g']['upscale'] = 4 22 | opt['val']['suffix'] = 'x4' 23 | opt['path']['pretrain_network_g'] = osp.join(args.root_path, f'experiments/pretrained_models/RGT_x4.pth') 24 | if args.SR == 'x2': 25 | opt['upscale'] = opt['network_g']['upscale'] = 2 26 | opt['val']['suffix'] = 'x2' 27 | 28 | # test_set = build_dataset(dataset_opt) 29 | # test_loader = build_dataloader( 30 | # test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | # test_loaders.append(test_loader) 32 | 33 | opt['path']['pretrain_network_g'] = args.ckpt_path 34 | opt['val']['use_chop'] = args.use_chop 35 | opt['path']['visualization'] = osp.join(args.output_dir, f'temp_results') 36 | opt['path']['results_root'] = osp.join(args.output_dir, f'temp_results') 37 | 38 | # create model 39 | model = build_model(opt) 40 | for test_loader in test_loaders: 41 | test_set_name = test_loader.dataset.opt['name'] 42 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 43 | 44 | 45 | if __name__ == '__main__': 46 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 47 | # print(root_path) 48 | # image_sr(root_path) -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # 
img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'AvgTimer', 18 | 'init_tb_logger', 19 | 'init_wandb_logger', 20 | 'get_root_logger', 21 | 'get_env_info', 22 | # misc.py 23 | 'set_random_seed', 24 | 'get_time_str', 25 | 'mkdir_and_rename', 26 | 'make_exp_dirs', 27 | 'scandir', 28 | 'check_resume', 29 | 'sizeof_fmt', 30 | ] 31 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
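The SLURM scheduler is expected to export ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST``; from these the function derives and sets ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` before calling ``dist.init_process_group``.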
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 
48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /opensora/models/super_resolution/options/test/test_RGT_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_RGT_x2 3 | model_type: RGTModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | task: SR 11 | name: Set5 12 | type: PairedImageDataset 13 | dataroot_gt: datasets/benchmark/Set5/HR 14 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 15 | filename_tmpl: '{}x2' 16 | io_backend: 17 | type: disk 18 | 19 | # test_2: # the 2st test dataset 20 | # task: SR 21 | # name: Set14 22 | # type: PairedImageDataset 23 | # dataroot_gt: datasets/benchmark/Set14/HR 24 | # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2 25 | # filename_tmpl: '{}x2' 26 | # io_backend: 27 | # type: disk 28 | 29 | # test_3: # the 3st test dataset 30 | # task: SR 31 | # name: B100 32 | # type: PairedImageDataset 33 | # dataroot_gt: datasets/benchmark/B100/HR 34 | # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2 35 | # filename_tmpl: '{}x2' 36 | # io_backend: 37 | # type: disk 38 | 39 | # test_4: # the 4st test dataset 40 | # task: SR 41 | # name: Urban100 42 | # type: PairedImageDataset 43 | # dataroot_gt: datasets/benchmark/Urban100/HR 44 | # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2 45 | # filename_tmpl: '{}x2' 46 | # io_backend: 47 | # type: disk 48 | 49 | # test_5: # the 5st test dataset 50 | # task: SR 51 | # name: Manga109 52 | # type: PairedImageDataset 53 | # dataroot_gt: datasets/benchmark/Manga109/HR 54 | # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2 55 | # filename_tmpl: '{}_LRBI_x2' 56 | # io_backend: 57 | # type: disk 58 | 59 | 60 | # network structures 61 | network_g: 62 | type: RGT 63 | upscale: 2 64 | in_chans: 3 65 | img_size: 64 66 | img_range: 1. 
67 | depth: [6,6,6,6,6,6,6,6] 68 | embed_dim: 180 69 | num_heads: [6,6,6,6,6,6,6,6] 70 | mlp_ratio: 2 71 | resi_connection: '1conv' 72 | split_size: [8,32] 73 | c_ratio: 0.5 74 | 75 | # path 76 | path: 77 | pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x2.pth 78 | strict_load_g: True 79 | 80 | # validation settings 81 | val: 82 | save_img: True 83 | suffix: ~ # add suffix to saved images, if None, use exp name 84 | use_chop: False # True to save memory, if img too large 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 2 90 | test_y_channel: True 91 | ssim: 92 | type: calculate_ssim 93 | crop_border: 2 94 | test_y_channel: True -------------------------------------------------------------------------------- /opensora/models/super_resolution/options/test/test_RGT_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_RGT_x4 3 | model_type: RGTModel 4 | scale: 4 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | task: SR 11 | name: Set5 12 | type: PairedImageDataset 13 | dataroot_gt: datasets/benchmark/Set5/HR 14 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 15 | filename_tmpl: '{}x4' 16 | io_backend: 17 | type: disk 18 | 19 | # test_2: # the 2st test dataset 20 | # task: SR 21 | # name: Set14 22 | # type: PairedImageDataset 23 | # dataroot_gt: datasets/benchmark/Set14/HR 24 | # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 25 | # filename_tmpl: '{}x4' 26 | # io_backend: 27 | # type: disk 28 | 29 | # test_3: # the 3st test dataset 30 | # task: SR 31 | # name: B100 32 | # type: PairedImageDataset 33 | # dataroot_gt: datasets/benchmark/B100/HR 34 | # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 35 | # filename_tmpl: '{}x4' 36 | # io_backend: 37 | # type: disk 38 | 39 | # test_4: # the 4st test dataset 40 | # task: SR 41 | # name: Urban100 42 | # type: PairedImageDataset 43 | # dataroot_gt: datasets/benchmark/Urban100/HR 44 | # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 45 | # filename_tmpl: '{}x4' 46 | # io_backend: 47 | # type: disk 48 | 49 | # test_5: # the 5st test dataset 50 | # task: SR 51 | # name: Manga109 52 | # type: PairedImageDataset 53 | # dataroot_gt: datasets/benchmark/Manga109/HR 54 | # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 55 | # filename_tmpl: '{}_LRBI_x4' 56 | # io_backend: 57 | # type: disk 58 | 59 | 60 | # network structures 61 | network_g: 62 | type: RGT 63 | upscale: 4 64 | in_chans: 3 65 | img_size: 64 66 | img_range: 1. 
67 | depth: [6,6,6,6,6,6,6,6] 68 | embed_dim: 180 69 | num_heads: [6,6,6,6,6,6,6,6] 70 | mlp_ratio: 2 71 | resi_connection: '1conv' 72 | split_size: [8,32] 73 | c_ratio: 0.5 74 | 75 | # path 76 | path: 77 | pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth 78 | strict_load_g: True 79 | 80 | # validation settings 81 | val: 82 | save_img: True 83 | suffix: ~ # add suffix to saved images, if None, use exp name 84 | use_chop: False # True to save memory, if img too large 85 | 86 | metrics: 87 | psnr: # metric name, can be arbitrary 88 | type: calculate_psnr 89 | crop_border: 4 90 | test_y_channel: True 91 | ssim: 92 | type: calculate_ssim 93 | crop_border: 4 94 | test_y_channel: True -------------------------------------------------------------------------------- /opensora/models/super_resolution/options/test/test_single_config.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_single 3 | model_type: RGTModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | datasets: 9 | test_1: # the 1st test dataset 10 | name: Single 11 | type: SingleImageDataset 12 | dataroot_lq: /test 13 | io_backend: 14 | type: disk 15 | 16 | 17 | # network structures 18 | network_g: 19 | type: RGT 20 | upscale: 2 21 | in_chans: 3 22 | img_size: 64 23 | img_range: 1. 24 | depth: [6,6,6,6,6,6,6,6] 25 | embed_dim: 180 26 | num_heads: [6,6,6,6,6,6,6,6] 27 | mlp_ratio: 2 28 | resi_connection: '1conv' 29 | split_size: [8,32] 30 | c_ratio: 0.5 31 | 32 | # path 33 | path: 34 | pretrain_network_g: /test 35 | strict_load_g: True 36 | 37 | # validation settings 38 | val: 39 | save_img: True 40 | suffix: ~ # add suffix to saved images, if None, use exp name 41 | use_chop: False # True to save memory, if img too large -------------------------------------------------------------------------------- /opensora/models/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import T5EncoderModel, CLIPModel, CLIPProcessor 4 | 5 | from opensora.utils.utils import get_precision 6 | 7 | 8 | class T5Wrapper(nn.Module): 9 | def __init__(self, args, **kwargs): 10 | super(T5Wrapper, self).__init__() 11 | self.model_name = args.text_encoder_name 12 | self.text_enc = T5EncoderModel.from_pretrained(self.model_name, cache_dir=args.cache_dir, **kwargs).eval() 13 | 14 | def forward(self, input_ids, attention_mask): 15 | text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] 16 | return text_encoder_embs.detach() 17 | 18 | class CLIPWrapper(nn.Module): 19 | def __init__(self, args): 20 | super(CLIPWrapper, self).__init__() 21 | self.model_name = args.text_encoder_name 22 | dtype = get_precision(args) 23 | model_kwargs = {'cache_dir': args.cache_dir, 'low_cpu_mem_usage': True, 'torch_dtype': dtype} 24 | self.text_enc = CLIPModel.from_pretrained(self.model_name, **model_kwargs).eval() 25 | 26 | def forward(self, input_ids, attention_mask): 27 | text_encoder_embs = self.text_enc.get_text_features(input_ids=input_ids, attention_mask=attention_mask) 28 | return text_encoder_embs.detach() 29 | 30 | 31 | 32 | text_encoder = { 33 | 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, 34 | 'openai/clip-vit-large-patch14': CLIPWrapper 35 | } 36 | 37 | 38 | def get_text_enc(args): 39 | """deprecation""" 40 | text_enc = text_encoder.get(args.text_encoder_name, None) 41 | assert text_enc is not None 42 | 
return text_enc(args) 43 | 44 | def get_text_warpper(text_encoder_name): 45 | """deprecation""" 46 | text_enc = text_encoder.get(text_encoder_name, None) 47 | assert text_enc is not None 48 | return text_enc 49 | -------------------------------------------------------------------------------- /opensora/train/train_videogpt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | 5 | from opensora.models.ae.videobase.dataset_videobase import VideoDataset 6 | from opensora.models.ae.videobase import ( 7 | VQVAEModel, 8 | VQVAEConfiguration, 9 | VQVAETrainer, 10 | ) 11 | import argparse 12 | from typing import Optional 13 | from accelerate.utils import set_seed 14 | from transformers import HfArgumentParser, TrainingArguments 15 | from dataclasses import dataclass, field, asdict 16 | 17 | 18 | @dataclass 19 | class VQVAEArgument: 20 | embedding_dim: int = (field(default=256),) 21 | n_codes: int = (field(default=2048),) 22 | n_hiddens: int = (field(default=240),) 23 | n_res_layers: int = (field(default=4),) 24 | resolution: int = (field(default=128),) 25 | sequence_length: int = (field(default=16),) 26 | downsample: str = (field(default="4,4,4"),) 27 | no_pos_embd: bool = (True,) 28 | data_path: str = field(default=None, metadata={"help": "data path"}) 29 | 30 | 31 | @dataclass 32 | class VQVAETrainingArgument(TrainingArguments): 33 | remove_unused_columns: Optional[bool] = field( 34 | default=False, 35 | metadata={ 36 | "help": "Remove columns not required by the model when using an nlp.Dataset." 37 | }, 38 | ) 39 | 40 | 41 | def train(args, vqvae_args: VQVAEArgument, training_args: VQVAETrainingArgument): 42 | # Load Config 43 | config = VQVAEConfiguration( 44 | embedding_dim=vqvae_args.embedding_dim, 45 | n_codes=vqvae_args.n_codes, 46 | n_hiddens=vqvae_args.n_hiddens, 47 | n_res_layers=vqvae_args.n_res_layers, 48 | resolution=vqvae_args.resolution, 49 | sequence_length=vqvae_args.sequence_length, 50 | downsample=vqvae_args.downsample, 51 | no_pos_embd=vqvae_args.no_pos_embd, 52 | ) 53 | # Load Model 54 | model = VQVAEModel(config) 55 | # Load Dataset 56 | dataset = VideoDataset( 57 | args.data_path, 58 | sequence_length=args.sequence_length, 59 | resolution=config.resolution, 60 | ) 61 | # Load Trainer 62 | trainer = VQVAETrainer(model, training_args, train_dataset=dataset) 63 | trainer.train() 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = HfArgumentParser((VQVAEArgument, VQVAETrainingArgument)) 68 | vqvae_args, training_args = parser.parse_args_into_dataclasses() 69 | args = argparse.Namespace(**vars(vqvae_args), **vars(training_args)) 70 | set_seed(args.seed) 71 | 72 | train(args, vqvae_args, training_args) 73 | -------------------------------------------------------------------------------- /opensora/utils/downloader.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import os 3 | 4 | opensora_cache_home = os.path.expanduser( 5 | os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) 6 | ) 7 | 8 | 9 | def gdown_download(id, fname, cache_dir=None): 10 | cache_dir = opensora_cache_home if not cache_dir else cache_dir 11 | 12 | os.makedirs(cache_dir, exist_ok=True) 13 | destination = os.path.join(cache_dir, fname) 14 | if os.path.exists(destination): 15 | return destination 16 | 17 | gdown.download(id=id, output=destination, quiet=False) 18 | return destination 19 | 
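A minimal usage sketch for the gdown_download helper above: the Google Drive id and file name below are placeholders rather than real Open-Sora-Plan assets, and the import path simply assumes the repository layout shown in this dump.

from opensora.utils.downloader import gdown_download

# First call downloads into ~/.cache/opensora (or $OPENSORA_HOME) and returns the local path;
# later calls with the same fname reuse the cached file without re-downloading.
ckpt_path = gdown_download(id="<google-drive-file-id>", fname="model_placeholder.pth", cache_dir=None)
print(ckpt_path)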
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "opensora" 7 | version = "1.0.0" 8 | description = "Reproduce OpenAI's Sora." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "albumentations", "av", "decord", "einops", 17 | "gdown", "h5py", "idna", 'imageio', "matplotlib", 18 | "omegaconf", "opencv-python", "opencv-python-headless", "pandas", "pillow", 19 | "pydub", "pytorch-lightning", "pytorchvideo", "PyYAML", "regex", 20 | "requests", "scikit-learn", "scipy", "six", "test-tube", 21 | "timm", "torchdiffeq", "torchmetrics", "tqdm", "urllib3", "uvicorn", 22 | "diffusers", "scikit-video", "imageio-ffmpeg", "sentencepiece", "beautifulsoup4", 23 | "ftfy", "moviepy", "wandb", "tensorboard" 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.9.5", "pydantic==1.10.13"] 28 | dev = ["mypy==1.8.0"] 29 | 30 | 31 | [project.urls] 32 | "Homepage" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan" 33 | "Bug Tracker" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues" 34 | 35 | [tool.setuptools.packages.find] 36 | exclude = ["assets*", "docker*", "docs", "scripts*"] 37 | 38 | [tool.wheel] 39 | exclude = ["assets*", "docker*", "docs", "scripts*"] 40 | 41 | [tool.mypy] 42 | warn_return_any = true 43 | warn_unused_configs = true 44 | ignore_missing_imports = true 45 | disallow_untyped_calls = true 46 | check_untyped_defs = true 47 | no_implicit_optional = true 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations 2 | av 3 | decord 4 | einops 5 | gdown 6 | h5py 7 | idna 8 | imageio 9 | matplotlib 10 | omegaconf 11 | opencv-python 12 | opencv-python-headless 13 | pandas 14 | pillow 15 | pydub 16 | pytorch-lightning 17 | pytorchvideo 18 | PyYAML 19 | regex 20 | requests 21 | scikit-learn 22 | scipy 23 | six 24 | test-tube 25 | timm 26 | torchdiffeq 27 | torchmetrics 28 | tqdm 29 | urllib3 30 | uvicorn 31 | diffusers 32 | scikit-video 33 | imageio-ffmpeg 34 | sentencepiece 35 | beautifulsoup4 36 | ftfy 37 | moviepy 38 | wandb 39 | tensorboard 40 | accelerate -------------------------------------------------------------------------------- /scripts/accelerate_configs/ddp_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fsdp_config: {} 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: 29501 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 8 10 | gpu_ids: 0,1,2,3,4,5,6,7 11 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero2_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | 
main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2_offload.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero3_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero3.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero3_offload.json 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: 29501 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | gpu_ids: 0,1,2,3,4,5,6,7 13 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fsdp_config: {} 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: 29501 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | gpu_ids: 0,1,2,3,4,5,6,7 12 | use_cpu: false -------------------------------------------------------------------------------- /scripts/accelerate_configs/hostfile: -------------------------------------------------------------------------------- 1 | gpu55 slots=8 # your server name and GPU in total 2 | gpu117 slots=8 3 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/multi_node_example.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | deepspeed_config: 4 | deepspeed_config_file: scripts/accelerate_configs/zero2.json 5 | deepspeed_hostfile: /remote-home1/yeyang/Open-Sora-Plan/scripts/accelerate_configs/hostfile 6 | fsdp_config: {} 7 | machine_rank: 0 8 | main_process_ip: 10.10.10.55 9 | main_process_port: 29501 10 | main_training_function: main 11 | num_machines: 2 12 | num_processes: 16 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | 
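# Example launch with this config (training flags elided; the text_condition scripts below show the full argument lists with the single-node zero2 config):
# accelerate launch --config_file scripts/accelerate_configs/multi_node_example.yaml opensora/train/train_t2v.py ...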
-------------------------------------------------------------------------------- /scripts/accelerate_configs/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": 5e8 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu" 20 | }, 21 | "overlap_comm": true, 22 | "contiguous_gradients": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": 5e8 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": 5e8, 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/accelerate_configs/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "offload_param": { 20 | "device": "cpu", 21 | "pin_memory": true 22 | }, 23 | "overlap_comm": true, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1e9, 26 | "reduce_bucket_size": 5e8, 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1e9, 30 | "stage3_max_reuse_distance": 1e9, 31 | "gather_16bit_weights_on_model_save": true 32 | }, 33 | "gradient_accumulation_steps": "auto", 34 | "gradient_clipping": "auto", 
35 | "train_batch_size": "auto", 36 | "train_micro_batch_size_per_gpu": "auto", 37 | "steps_per_print": 1e5, 38 | "wall_clock_breakdown": false 39 | } -------------------------------------------------------------------------------- /scripts/causalvae/eval.sh: -------------------------------------------------------------------------------- 1 | python opensora/eval/eval_common_metric.py \ 2 | --batch_size 2 \ 3 | --real_video_dir ..//test_eval/release/origin \ 4 | --generated_video_dir ../test_eval/release \ 5 | --device cuda \ 6 | --sample_fps 10 \ 7 | --crop_size 256 \ 8 | --resolution 256 \ 9 | --num_frames 17 \ 10 | --sample_rate 1 \ 11 | --subset_size 100 \ 12 | --metric ssim -------------------------------------------------------------------------------- /scripts/causalvae/gen_video.sh: -------------------------------------------------------------------------------- 1 | python examples/rec_video_vae.py \ 2 | --batch_size 1 \ 3 | --real_video_dir ../test_eval/eyes_test \ 4 | --generated_video_dir ../test_eval/eyes_gen \ 5 | --device cuda \ 6 | --sample_fps 10 \ 7 | --sample_rate 1 \ 8 | --num_frames 17 \ 9 | --resolution 512 \ 10 | --crop_size 512 \ 11 | --num_workers 8 \ 12 | --ckpt results/pretrained_488 \ 13 | --enable_tiling -------------------------------------------------------------------------------- /scripts/causalvae/release.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "CausalVAEModel", 3 | "_diffusers_version": "0.27.2", 4 | "attn_resolutions": [], 5 | "decoder_attention": "AttnBlock3D", 6 | "decoder_conv_in": "CausalConv3d", 7 | "decoder_conv_out": "CausalConv3d", 8 | "decoder_mid_resnet": "ResnetBlock3D", 9 | "decoder_resnet_blocks": [ 10 | "ResnetBlock3D", 11 | "ResnetBlock3D", 12 | "ResnetBlock3D", 13 | "ResnetBlock3D" 14 | ], 15 | "decoder_spatial_upsample": [ 16 | "", 17 | "SpatialUpsample2x", 18 | "SpatialUpsample2x", 19 | "SpatialUpsample2x" 20 | ], 21 | "decoder_temporal_upsample": [ 22 | "", 23 | "", 24 | "TimeUpsample2x", 25 | "TimeUpsample2x" 26 | ], 27 | "double_z": true, 28 | "dropout": 0.0, 29 | "embed_dim": 4, 30 | "encoder_attention": "AttnBlock3D", 31 | "encoder_conv_in": "CausalConv3d", 32 | "encoder_conv_out": "CausalConv3d", 33 | "encoder_mid_resnet": "ResnetBlock3D", 34 | "encoder_resnet_blocks": [ 35 | "ResnetBlock3D", 36 | "ResnetBlock3D", 37 | "ResnetBlock3D", 38 | "ResnetBlock3D" 39 | ], 40 | "encoder_spatial_downsample": [ 41 | "SpatialDownsample2x", 42 | "SpatialDownsample2x", 43 | "SpatialDownsample2x", 44 | "" 45 | ], 46 | "encoder_temporal_downsample": [ 47 | "TimeDownsample2x", 48 | "TimeDownsample2x", 49 | "", 50 | "" 51 | ], 52 | "hidden_size": 128, 53 | "hidden_size_mult": [ 54 | 1, 55 | 2, 56 | 4, 57 | 4 58 | ], 59 | "loss_params": { 60 | "disc_start": 2001, 61 | "disc_weight": 0.5, 62 | "kl_weight": 1e-06, 63 | "logvar_init": 0.0 64 | }, 65 | "loss_type": "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator", 66 | "lr": 1e-05, 67 | "num_res_blocks": 2, 68 | "q_conv": "CausalConv3d", 69 | "resolution": 256, 70 | "z_channels": 4 71 | } 72 | -------------------------------------------------------------------------------- /scripts/causalvae/train.sh: -------------------------------------------------------------------------------- 1 | python opensora/train/train_causalvae.py \ 2 | --exp_name "exp_name" \ 3 | --batch_size 1 \ 4 | --precision bf16 \ 5 | --max_steps 40000 \ 6 | --save_steps 100 \ 7 | --output_dir results/causalvae_ \ 8 | --video_path 
/remote-home1/dataset/data_split_tt \ 9 | --video_num_frames 1 \ 10 | --resolution 32 \ 11 | --sample_rate 1 \ 12 | --n_nodes 1 \ 13 | --devices 1 \ 14 | --num_workers 8 \ 15 | --load_from_checkpoint ./results/pretrained_488/ -------------------------------------------------------------------------------- /scripts/class_condition/sample.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --num_processes 1 \ 3 | --main_process_port 29502 \ 4 | opensora/sample/sample.py \ 5 | --model Latte-XL/122 \ 6 | --ae stabilityai/sd-vae-ft-mse \ 7 | --ckpt ucf101-f16s3-128-imgvae188-bf16-ckpt-flash/checkpoint-98500 \ 8 | --train_classcondition \ 9 | --num_classes 101 \ 10 | --fps 10 \ 11 | --num_frames 16 \ 12 | --image_size 128 \ 13 | --num_sampling_steps 500 \ 14 | --attention_mode flash \ 15 | --mixed_precision bf16 16 | -------------------------------------------------------------------------------- /scripts/class_condition/train_imgae.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="ucf101-f16s3-128-imgvae188-bf16-ckpt-flash" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/ddp_config.yaml \ 6 | opensora/train/train.py \ 7 | --model Latte-XL/122 \ 8 | --dataset ucf101 \ 9 | --ae stabilityai/sd-vae-ft-mse \ 10 | --data_path /remote-home/yeyang/UCF-101 \ 11 | --train_classcondition \ 12 | --num_classes 101 \ 13 | --sample_rate 3 \ 14 | --num_frames 16 \ 15 | --max_image_size 128 \ 16 | --gradient_checkpointing \ 17 | --attention_mode flash \ 18 | --train_batch_size=8 --dataloader_num_workers 10 \ 19 | --gradient_accumulation_steps=1 \ 20 | --max_train_steps=1000000 \ 21 | --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ 22 | --mixed_precision="bf16" \ 23 | --report_to="wandb" \ 24 | --checkpointing_steps=500 \ 25 | --output_dir="ucf101-f16s3-128-imgvae188-bf16-ckpt-flash" \ 26 | --allow_tf32 27 | -------------------------------------------------------------------------------- /scripts/class_condition/train_vidae.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="ucf101-f16s3-128-causalvideovae444-bf16-ckpt-flash" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/ddp_config.yaml \ 6 | opensora/train/train.py \ 7 | --model Latte-XL/122 \ 8 | --dataset ucf101 \ 9 | --ae CausalVQVAEModel_4x4x4 \ 10 | --data_path /remote-home/yeyang/UCF-101 \ 11 | --train_classcondition \ 12 | --num_classes 101 \ 13 | --sample_rate 3 \ 14 | --num_frames 16 \ 15 | --max_image_size 128 \ 16 | --gradient_checkpointing \ 17 | --attention_mode flash \ 18 | --train_batch_size=8 --dataloader_num_workers 10 \ 19 | --gradient_accumulation_steps=1 \ 20 | --max_train_steps=1000000 \ 21 | --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ 22 | --mixed_precision="bf16" \ 23 | --report_to="wandb" \ 24 | --checkpointing_steps=500 \ 25 | --output_dir="ucf101-f16s3-128-causalvideovae444-bf16-ckpt-flash" \ 26 | --allow_tf32 27 | -------------------------------------------------------------------------------- /scripts/slurm/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/scripts/slurm/placeholder 
-------------------------------------------------------------------------------- /scripts/text_condition/sample_image.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python opensora/sample/sample_t2v.py \ 2 | --model_path LanguageBind/Open-Sora-Plan-v1.0.0 \ 3 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 4 | --text_prompt examples/prompt_list_0.txt \ 5 | --ae CausalVAEModel_4x8x8 \ 6 | --version 65x512x512 \ 7 | --save_img_path "./sample_images/prompt_list_0" \ 8 | --fps 24 \ 9 | --guidance_scale 7.5 \ 10 | --num_sampling_steps 250 \ 11 | --enable_tiling \ 12 | --force_images 13 | -------------------------------------------------------------------------------- /scripts/text_condition/sample_video.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python opensora/sample/sample_t2v.py \ 2 | --model_path LanguageBind/Open-Sora-Plan-v1.0.0 \ 3 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 4 | --text_prompt examples/prompt_list_7.txt \ 5 | --ae CausalVAEModel_4x8x8 \ 6 | --version 65x512x512 \ 7 | --save_img_path "./sample_videos/prompt_list_7" \ 8 | --fps 24 \ 9 | --guidance_scale 7.5 \ 10 | --num_sampling_steps 250 \ 11 | --enable_tiling 12 | -------------------------------------------------------------------------------- /scripts/text_condition/train_imageae.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="t2v-f16s3-img4-128-imgvae188-bf16-gc-xformers" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 6 | opensora/train/train_t2v.py \ 7 | --model LatteT2V-XL/122 \ 8 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 9 | --dataset t2v \ 10 | --ae stabilityai/sd-vae-ft-mse \ 11 | --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ 12 | --video_folder /remote-home1/dataset/data_split \ 13 | --sample_rate 1 \ 14 | --num_frames 17 \ 15 | --max_image_size 256 \ 16 | --gradient_checkpointing \ 17 | --attention_mode xformers \ 18 | --train_batch_size=4 \ 19 | --dataloader_num_workers 10 \ 20 | --gradient_accumulation_steps=1 \ 21 | --max_train_steps=1000000 \ 22 | --learning_rate=2e-05 \ 23 | --lr_scheduler="constant" \ 24 | --lr_warmup_steps=0 \ 25 | --mixed_precision="bf16" \ 26 | --report_to="wandb" \ 27 | --checkpointing_steps=500 \ 28 | --output_dir="t2v-f17-256-img4-imagevae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ 29 | --allow_tf32 \ 30 | --pretrained t2v.pt \ 31 | --use_deepspeed \ 32 | --model_max_length 300 \ 33 | --use_image_num 4 \ 34 | --use_img_from_vid 35 | -------------------------------------------------------------------------------- /scripts/text_condition/train_videoae_17x256x256.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="t2v-f17-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 6 | opensora/train/train_t2v.py \ 7 | --model LatteT2V-XL/122 \ 8 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 9 | --dataset t2v \ 10 | --ae CausalVAEModel_4x8x8 \ 11 | --ae_path CausalVAEModel_4x8x8 \ 12 | --data_path /remote-home1/dataset/sharegpt4v_path_cap_64x512x512.json \ 13 | --video_folder /remote-home1/dataset/data_split_tt \ 14 | --sample_rate 1 \ 15 | --num_frames 17 \ 16 | --max_image_size 256 \ 17 | 
--gradient_checkpointing \ 18 | --attention_mode xformers \ 19 | --train_batch_size=4 \ 20 | --dataloader_num_workers 10 \ 21 | --gradient_accumulation_steps=1 \ 22 | --max_train_steps=1000000 \ 23 | --learning_rate=2e-05 \ 24 | --lr_scheduler="constant" \ 25 | --lr_warmup_steps=0 \ 26 | --mixed_precision="bf16" \ 27 | --report_to="wandb" \ 28 | --checkpointing_steps=500 \ 29 | --output_dir="t2v-f17-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ 30 | --allow_tf32 \ 31 | --pretrained t2v.pt \ 32 | --use_deepspeed \ 33 | --model_max_length 300 \ 34 | --use_image_num 4 \ 35 | --use_img_from_vid 36 | -------------------------------------------------------------------------------- /scripts/text_condition/train_videoae_65x256x256.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="t2v-f65-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 6 | opensora/train/train_t2v.py \ 7 | --model LatteT2V-XL/122 \ 8 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 9 | --dataset t2v \ 10 | --ae CausalVAEModel_4x8x8 \ 11 | --ae_path CausalVAEModel_4x8x8 \ 12 | --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ 13 | --video_folder /remote-home1/dataset/data_split_tt \ 14 | --sample_rate 1 \ 15 | --num_frames 65 \ 16 | --max_image_size 256 \ 17 | --gradient_checkpointing \ 18 | --attention_mode xformers \ 19 | --train_batch_size=4 \ 20 | --dataloader_num_workers 10 \ 21 | --gradient_accumulation_steps=1 \ 22 | --max_train_steps=1000000 \ 23 | --learning_rate=2e-05 \ 24 | --lr_scheduler="constant" \ 25 | --lr_warmup_steps=0 \ 26 | --mixed_precision="bf16" \ 27 | --report_to="wandb" \ 28 | --checkpointing_steps=500 \ 29 | --output_dir="t2v-f65-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ 30 | --allow_tf32 \ 31 | --pretrained t2v.pt \ 32 | --use_deepspeed \ 33 | --model_max_length 300 \ 34 | --use_image_num 4 \ 35 | --use_img_from_vid 36 | -------------------------------------------------------------------------------- /scripts/text_condition/train_videoae_65x512x512.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="t2v-f65-256-img16-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" 4 | accelerate launch \ 5 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 6 | opensora/train/train_t2v.py \ 7 | --model LatteT2V-XL/122 \ 8 | --text_encoder_name DeepFloyd/t5-v1_1-xxl \ 9 | --dataset t2v \ 10 | --ae CausalVAEModel_4x8x8 \ 11 | --ae_path CausalVAEModel_4x8x8 \ 12 | --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ 13 | --video_folder /remote-home1/dataset/data_split_tt \ 14 | --sample_rate 1 \ 15 | --num_frames 65 \ 16 | --max_image_size 512 \ 17 | --gradient_checkpointing \ 18 | --attention_mode xformers \ 19 | --train_batch_size=2 \ 20 | --dataloader_num_workers 10 \ 21 | --gradient_accumulation_steps=1 \ 22 | --max_train_steps=1000000 \ 23 | --learning_rate=2e-05 \ 24 | --lr_scheduler="constant" \ 25 | --lr_warmup_steps=0 \ 26 | --mixed_precision="bf16" \ 27 | --report_to="wandb" \ 28 | --checkpointing_steps=500 \ 29 | --output_dir="t2v-f65-512-img16-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ 30 | --allow_tf32 \ 31 | --pretrained t2v.pt \ 32 | --use_deepspeed \ 33 | --model_max_length 300 \ 34 | --use_image_num 16 \ 35 | --use_img_from_vid \ 36 | 
--enable_tiling 37 | -------------------------------------------------------------------------------- /scripts/un_condition/sample.sh: -------------------------------------------------------------------------------- 1 | 2 | accelerate launch \ 3 | --num_processes 1 \ 4 | --main_process_port 29502 \ 5 | opensora/sample/sample.py \ 6 | --model Latte-XL/122 \ 7 | --ae CausalVQVAEModel \ 8 | --ckpt sky-f17s3-128-causalvideovae488-bf16-ckpt-flash-log/checkpoint-45500 \ 9 | --fps 10 \ 10 | --num_frames 17 \ 11 | --image_size 128 \ 12 | --num_sampling_steps 250 \ 13 | --attention_mode flash \ 14 | --mixed_precision bf16 \ 15 | --num_sample 10 -------------------------------------------------------------------------------- /scripts/un_condition/train_imgae.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="sky-f16s3-128-imgae188-bf16-ckpt-flash-log" 4 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 accelerate launch \ 5 | --config_file scripts/accelerate_configs/ddp_config.yaml \ 6 | opensora/train/train.py \ 7 | --model Latte-XL/122 \ 8 | --dataset sky \ 9 | --ae stabilityai/sd-vae-ft-mse \ 10 | --data_path /remote-home/yeyang/sky_timelapse/sky_train/ \ 11 | --sample_rate 3 \ 12 | --num_frames 16 \ 13 | --max_image_size 128 \ 14 | --gradient_checkpointing \ 15 | --attention_mode flash \ 16 | --train_batch_size=8 --dataloader_num_workers 10 \ 17 | --gradient_accumulation_steps=1 \ 18 | --max_train_steps=1000000 \ 19 | --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ 20 | --mixed_precision="bf16" \ 21 | --report_to="wandb" \ 22 | --checkpointing_steps=500 \ 23 | --output_dir="sky-f16s3-128-imgae188-bf16-ckpt-flash-log" \ 24 | --allow_tf32 25 | 26 | -------------------------------------------------------------------------------- /scripts/un_condition/train_vidae.sh: -------------------------------------------------------------------------------- 1 | export WANDB_KEY="" 2 | export ENTITY="" 3 | export PROJECT="sky-f17s3-128-causalvideovae444-bf16-ckpt-flash-log" 4 | HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 accelerate launch \ 5 | --config_file scripts/accelerate_configs/ddp_config.yaml \ 6 | opensora/train/train.py \ 7 | --model Latte-XL/122 \ 8 | --dataset sky \ 9 | --ae CausalVQVAEModel_4x4x4 \ 10 | --data_path /remote-home/yeyang/sky_timelapse/sky_train/ \ 11 | --sample_rate 3 \ 12 | --num_frames 17 \ 13 | --max_image_size 128 \ 14 | --gradient_checkpointing \ 15 | --attention_mode flash \ 16 | --train_batch_size=8 --dataloader_num_workers 10 \ 17 | --gradient_accumulation_steps=1 \ 18 | --max_train_steps=1000000 \ 19 | --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ 20 | --mixed_precision="bf16" \ 21 | --report_to="wandb" \ 22 | --checkpointing_steps=500 \ 23 | --output_dir="sky-f17s3-128-causalvideovae444-bf16-ckpt-flash-log" \ 24 | --allow_tf32 25 | 26 | -------------------------------------------------------------------------------- /scripts/videogpt/train_videogpt.sh: -------------------------------------------------------------------------------- 1 | 2 | accelerate launch \ 3 | --config_file scripts/accelerate_configs/ddp_config.yaml \ 4 | opensora/train/train_videogpt.py \ 5 | --do_train \ 6 | --seed 1234 \ 7 | --data_path "/remote-home/yeyang/UCF-101/" \ 8 | --per_device_train_batch_size 1 \ 9 | --gradient_accumulation_steps 1 \ 10 | --learning_rate 7e-4 \ 11 | --weight_decay 0. 
\ 12 | --max_steps 20000 \ 13 | --lr_scheduler_type cosine \ 14 | --max_grad_norm 1.0 \ 15 | --save_strategy steps \ 16 | --save_total_limit 5 \ 17 | --logging_steps 5 \ 18 | --save_steps 1000 \ 19 | --n_codes 2048 \ 20 | --n_hiddens 240 \ 21 | --embedding_dim 4 \ 22 | --n_res_layers 4 \ 23 | --downsample "4,4,4" \ 24 | --resolution 240 \ 25 | --sequence_length 16 \ 26 | --output_dir results/videogpt_488_256_16 \ 27 | --bf16 True \ 28 | --fp16 False \ 29 | --report_to tensorboard \ 30 | --dataloader_num_workers 10 31 | -------------------------------------------------------------------------------- /scripts/videogpt/train_videogpt_dsz2.sh: -------------------------------------------------------------------------------- 1 | export ACCELERATE_GRADIENT_ACCUMULATION_STEPS=1 2 | 3 | accelerate launch \ 4 | --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ 5 | opensora/train/train_videogpt.py \ 6 | --do_train \ 7 | --seed 1234 \ 8 | --data_path "datasets/UCF-101/" \ 9 | --per_device_train_batch_size 32 \ 10 | --gradient_accumulation_steps $ACCELERATE_GRADIENT_ACCUMULATION_STEPS \ 11 | --learning_rate 7e-4 \ 12 | --weight_decay 0. \ 13 | --num_train_epochs 2 \ 14 | --lr_scheduler_type cosine \ 15 | --max_grad_norm 1.0 \ 16 | --save_strategy steps \ 17 | --save_total_limit 5 \ 18 | --logging_steps 5 \ 19 | --save_steps 10000 \ 20 | --n_codes 1024 \ 21 | --n_hiddens 240 \ 22 | --embedding_dim 4 \ 23 | --n_res_layers 4 \ 24 | --downsample "4,4,4" \ 25 | --resolution 128 \ 26 | --sequence_length 16 \ 27 | --output_dir results/videogpt_444_128 \ 28 | --bf16 True \ 29 | --fp16 False \ 30 | --report_to tensorboard 31 | -------------------------------------------------------------------------------- /scripts/videogpt/train_videogpt_dsz3.sh: -------------------------------------------------------------------------------- 1 | export ACCELERATE_GRADIENT_ACCUMULATION_STEPS=1 2 | 3 | accelerate launch \ 4 | --config_file scripts/accelerate_configs/deepspeed_zero3_config.yaml \ 5 | opensora/train/train_videogpt.py \ 6 | --do_train \ 7 | --seed 1234 \ 8 | --data_path "datasets/UCF-101/" \ 9 | --per_device_train_batch_size 32 \ 10 | --gradient_accumulation_steps $ACCELERATE_GRADIENT_ACCUMULATION_STEPS \ 11 | --learning_rate 7e-4 \ 12 | --weight_decay 0. \ 13 | --num_train_epochs 2 \ 14 | --lr_scheduler_type cosine \ 15 | --max_grad_norm 1.0 \ 16 | --save_strategy steps \ 17 | --save_total_limit 5 \ 18 | --logging_steps 5 \ 19 | --save_steps 10000 \ 20 | --n_codes 1024 \ 21 | --n_hiddens 240 \ 22 | --embedding_dim 4 \ 23 | --n_res_layers 4 \ 24 | --downsample "4,4,4" \ 25 | --resolution 128 \ 26 | --sequence_length 16 \ 27 | --output_dir results/videogpt_444_128 \ 28 | --bf16 True \ 29 | --fp16 False \ 30 | --report_to tensorboard 31 | -------------------------------------------------------------------------------- /wf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/wf.png --------------------------------------------------------------------------------