├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── docker
├── LICENSE
├── README.md
├── build_docker.png
├── docker_build.sh
├── docker_run.sh
├── dockerfile.base
├── packages.txt
├── ports.txt
├── postinstallscript.sh
├── requirements.txt
├── run_docker.png
└── setup_env.sh
├── docs
├── Contribution_Guidelines.md
├── Data.md
├── EVAL.md
├── Report-v1.0.0.md
└── VQVAE.md
├── examples
├── get_latents_std.py
├── prompt_list_0.txt
├── rec_image.py
├── rec_imvi_vae.py
├── rec_video.py
├── rec_video_ae.py
└── rec_video_vae.py
├── nodes.py
├── opensora
├── __init__.py
├── dataset
│ ├── __init__.py
│ ├── extract_feature_dataset.py
│ ├── feature_datasets.py
│ ├── landscope.py
│ ├── sky_datasets.py
│ ├── t2v_datasets.py
│ ├── transform.py
│ └── ucf101.py
├── eval
│ ├── cal_flolpips.py
│ ├── cal_fvd.py
│ ├── cal_lpips.py
│ ├── cal_psnr.py
│ ├── cal_ssim.py
│ ├── eval_clip_score.py
│ ├── eval_common_metric.py
│ ├── flolpips
│ │ ├── correlation
│ │ │ └── correlation.py
│ │ ├── flolpips.py
│ │ ├── pretrained_networks.py
│ │ ├── pwcnet.py
│ │ └── utils.py
│ ├── fvd
│ │ ├── styleganv
│ │ │ ├── fvd.py
│ │ │ └── i3d_torchscript.pt
│ │ └── videogpt
│ │ │ ├── fvd.py
│ │ │ ├── i3d_pretrained_400.pt
│ │ │ └── pytorch_i3d.py
│ └── script
│ │ ├── cal_clip_score.sh
│ │ ├── cal_fvd.sh
│ │ ├── cal_lpips.sh
│ │ ├── cal_psnr.sh
│ │ └── cal_ssim.sh
├── models
│ ├── __init__.py
│ ├── ae
│ │ ├── __init__.py
│ │ ├── imagebase
│ │ │ ├── __init__.py
│ │ │ ├── vae
│ │ │ │ └── vae.py
│ │ │ └── vqvae
│ │ │ │ ├── model.py
│ │ │ │ ├── quantize.py
│ │ │ │ ├── vqgan.py
│ │ │ │ └── vqvae.py
│ │ └── videobase
│ │ │ ├── __init__.py
│ │ │ ├── causal_vae
│ │ │ ├── __init__.py
│ │ │ └── modeling_causalvae.py
│ │ │ ├── causal_vqvae
│ │ │ ├── __init__.py
│ │ │ ├── configuration_causalvqvae.py
│ │ │ ├── modeling_causalvqvae.py
│ │ │ └── trainer_causalvqvae.py
│ │ │ ├── configuration_videobase.py
│ │ │ ├── dataset_videobase.py
│ │ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── discriminator.py
│ │ │ ├── lpips.py
│ │ │ └── perceptual_loss.py
│ │ │ ├── modeling_videobase.py
│ │ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── block.py
│ │ │ ├── conv.py
│ │ │ ├── normalize.py
│ │ │ ├── ops.py
│ │ │ ├── quant.py
│ │ │ ├── resnet_block.py
│ │ │ └── updownsample.py
│ │ │ ├── trainer_videobase.py
│ │ │ ├── utils
│ │ │ ├── distrib_utils.py
│ │ │ ├── module_utils.py
│ │ │ ├── scheduler_utils.py
│ │ │ └── video_utils.py
│ │ │ └── vqvae
│ │ │ ├── __init__.py
│ │ │ ├── configuration_vqvae.py
│ │ │ ├── modeling_vqvae.py
│ │ │ └── trainer_vqvae.py
│ ├── captioner
│ │ └── caption_refiner
│ │ │ ├── README.md
│ │ │ ├── caption_refiner.py
│ │ │ ├── dataset
│ │ │ └── test_videos
│ │ │ │ ├── captions.json
│ │ │ │ ├── video1.gif
│ │ │ │ └── video2.gif
│ │ │ ├── demo_for_refiner.py
│ │ │ └── gpt_combinator.py
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── diffusion
│ │ │ ├── __init__.py
│ │ │ ├── diffusion_utils.py
│ │ │ ├── gaussian_diffusion.py
│ │ │ ├── gaussian_diffusion_t2v.py
│ │ │ ├── respace.py
│ │ │ └── timestep_sampler.py
│ │ ├── latte
│ │ │ ├── modeling_latte.py
│ │ │ ├── modules.py
│ │ │ └── pos.py
│ │ ├── transport
│ │ │ ├── __init__.py
│ │ │ ├── integrators.py
│ │ │ ├── path.py
│ │ │ ├── transport.py
│ │ │ └── utils.py
│ │ └── utils
│ │ │ ├── curope
│ │ │ ├── __init__.py
│ │ │ ├── curope.cpp
│ │ │ ├── curope2d.py
│ │ │ ├── kernels.cu
│ │ │ └── setup.py
│ │ │ └── pos_embed.py
│ ├── frame_interpolation
│ │ ├── cfgs
│ │ │ └── AMT-G.yaml
│ │ ├── interpolation.py
│ │ ├── networks
│ │ │ ├── AMT-G.py
│ │ │ ├── __init__.py
│ │ │ └── blocks
│ │ │ │ ├── __init__.py
│ │ │ │ ├── feat_enc.py
│ │ │ │ ├── ifrnet.py
│ │ │ │ ├── multi_flow.py
│ │ │ │ └── raft.py
│ │ ├── readme.md
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── build_utils.py
│ │ │ ├── dist_utils.py
│ │ │ ├── flow_utils.py
│ │ │ └── utils.py
│ ├── super_resolution
│ │ ├── README.md
│ │ ├── basicsr
│ │ │ ├── __init__.py
│ │ │ ├── archs
│ │ │ │ ├── __init__.py
│ │ │ │ ├── arch_util.py
│ │ │ │ ├── rgt_arch.py
│ │ │ │ └── vgg_arch.py
│ │ │ ├── data
│ │ │ │ ├── __init__.py
│ │ │ │ ├── data_sampler.py
│ │ │ │ ├── data_util.py
│ │ │ │ ├── paired_image_dataset.py
│ │ │ │ ├── prefetch_dataloader.py
│ │ │ │ ├── single_image_dataset.py
│ │ │ │ └── transforms.py
│ │ │ ├── losses
│ │ │ │ ├── __init__.py
│ │ │ │ ├── loss_util.py
│ │ │ │ └── losses.py
│ │ │ ├── metrics
│ │ │ │ ├── __init__.py
│ │ │ │ ├── metric_util.py
│ │ │ │ └── psnr_ssim.py
│ │ │ ├── models
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_model.py
│ │ │ │ ├── lr_scheduler.py
│ │ │ │ ├── rgt_model.py
│ │ │ │ └── sr_model.py
│ │ │ ├── test_img.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dist_util.py
│ │ │ │ ├── file_client.py
│ │ │ │ ├── img_util.py
│ │ │ │ ├── logger.py
│ │ │ │ ├── matlab_functions.py
│ │ │ │ ├── misc.py
│ │ │ │ ├── options.py
│ │ │ │ └── registry.py
│ │ ├── options
│ │ │ └── test
│ │ │ │ ├── test_RGT_x2.yml
│ │ │ │ ├── test_RGT_x4.yml
│ │ │ │ └── test_single_config.yml
│ │ └── run.py
│ └── text_encoder
│ │ ├── __init__.py
│ │ ├── clip.py
│ │ └── t5.py
├── sample
│ ├── pipeline_videogen.py
│ ├── sample.py
│ ├── sample_t2v.py
│ └── transport_sample.py
├── serve
│ ├── gradio_utils.py
│ └── gradio_web_server.py
├── train
│ ├── train.py
│ ├── train_causalvae.py
│ ├── train_t2v.py
│ ├── train_t2v_feature.py
│ ├── train_t2v_t5_feature.py
│ └── train_videogpt.py
└── utils
│ ├── dataset_utils.py
│ ├── downloader.py
│ ├── taming_download.py
│ └── utils.py
├── pyproject.toml
├── requirements.txt
├── scripts
├── accelerate_configs
│ ├── ddp_config.yaml
│ ├── deepspeed_zero2_config.yaml
│ ├── deepspeed_zero2_offload_config.yaml
│ ├── deepspeed_zero3_config.yaml
│ ├── deepspeed_zero3_offload_config.yaml
│ ├── default_config.yaml
│ ├── hostfile
│ ├── multi_node_example.yaml
│ ├── zero2.json
│ ├── zero2_offload.json
│ ├── zero3.json
│ └── zero3_offload.json
├── causalvae
│ ├── eval.sh
│ ├── gen_video.sh
│ ├── release.json
│ └── train.sh
├── class_condition
│ ├── sample.sh
│ ├── train_imgae.sh
│ └── train_vidae.sh
├── slurm
│ └── placeholder
├── text_condition
│ ├── sample_image.sh
│ ├── sample_video.sh
│ ├── train_imageae.sh
│ ├── train_videoae_17x256x256.sh
│ ├── train_videoae_65x256x256.sh
│ └── train_videoae_65x512x512.sh
├── un_condition
│ ├── sample.sh
│ ├── train_imgae.sh
│ └── train_vidae.sh
└── videogpt
│ ├── train_videogpt.sh
│ ├── train_videogpt_dsz2.sh
│ └── train_videogpt_dsz3.sh
├── wf.json
└── wf.png
/.gitignore:
--------------------------------------------------------------------------------
1 | ucf101_stride4x4x4
2 | __pycache__
3 | *.mp4
4 | .ipynb_checkpoints
5 | *.pth
6 | UCF-101/
7 | results/
8 | vae
9 | build/
10 | opensora.egg-info/
11 | wandb/
12 | .idea
13 | *.ipynb
14 | *.jpg
15 | *.mp3
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-Open-Sora-Plan
2 |
3 | ## workflow
4 |
5 | https://github.com/chaojie/ComfyUI-Open-Sora-Plan/blob/main/wf.json
6 |
7 | ![workflow](wf.png)
8 |
9 | ### [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan)
10 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS']
--------------------------------------------------------------------------------
/docker/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SimonLee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Docker4ML
2 |
3 | Useful docker scripts for ML development.
4 | [https://github.com/SimonLeeGit/Docker4ML](https://github.com/SimonLeeGit/Docker4ML)
5 |
6 | ## Build Docker Image
7 |
8 | ```bash
9 | bash docker_build.sh
10 | ```
11 |
12 | ![build_docker](./build_docker.png)
13 |
14 | ## Run Docker Container as Development Environment
15 |
16 | ```bash
17 | bash docker_run.sh
18 | ```
19 |
20 | ![run_docker](./run_docker.png)
21 |
22 | ## Custom Docker Config
23 |
24 | ### Config [setup_env.sh](./setup_env.sh)
25 |
26 | You can modify this file to customize your settings.
27 |
28 | ```bash
29 | TAG=ml:dev
30 | BASE_TAG=nvcr.io/nvidia/pytorch:23.12-py3
31 | ```
32 |
33 | #### TAG
34 |
35 | The tag for your built docker image; you can set it to whatever you want.
36 |
37 | #### BASE_TAG
38 |
39 | The base docker image tag for your built docker image; here we use the NVIDIA PyTorch images.
40 | You can check it from [https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags)
41 |
42 | You can also use another docker image as the base, such as [ubuntu](https://hub.docker.com/_/ubuntu/tags).
43 |
44 | #### USER_NAME
45 |
46 | The user name used in the docker container.
47 |
48 | #### USER_PASSWD
49 |
50 | The user password used in the docker container.
51 |
52 | ### Config [requirements.txt](./requirements.txt)
53 |
54 | You can add the Python libraries to install by default here.
55 |
56 | ```txt
57 | transformers==4.27.1
58 | ```
59 |
60 | By default, the base image already has some libraries installed; you can check them at [https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html)
61 |
62 | ### Config [packages.txt](./packages.txt)
63 |
64 | You can add the packages to install with apt-get by default here.
65 |
66 | ```txt
67 | wget
68 | curl
69 | git
70 | ```
71 |
72 | ### Config [ports.txt](./ports.txt)
73 |
74 | You can list the ports to expose for the docker container here.
75 |
76 | ```txt
77 | -p 6006:6006
78 | -p 8080:8080
79 | ```
80 |
81 | ### Config [postinstallscript.sh](./postinstallscript.sh)
82 |
83 | You can add a custom script here to run when building the docker image.
84 |
85 | ## Q&A
86 |
87 | If you have any problems using it, please contact .
88 |
--------------------------------------------------------------------------------
/docker/build_docker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/docker/build_docker.png
--------------------------------------------------------------------------------
/docker/docker_build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | WORK_DIR=$(dirname "$(readlink -f "$0")")
4 | cd $WORK_DIR
5 |
6 | source setup_env.sh
7 |
8 | docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . -f dockerfile.base
9 |
--------------------------------------------------------------------------------
/docker/docker_run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | WORK_DIR=$(dirname "$(readlink -f "$0")")
4 | source $WORK_DIR/setup_env.sh
5 |
6 | RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")"
7 |
8 | if [ -n "$RUNNING_IDS" ]; then
9 | # Initialize an array to hold the container IDs
10 | declare -a container_ids=($RUNNING_IDS)
11 |
12 | # Get the first container ID using array indexing
13 | ID=${container_ids[0]}
14 |
15 | # Print the first container ID
16 | echo ' '
17 | echo "The running container ID is: $ID, enter it!"
18 | else
19 | echo ' '
20 | echo "Not found running containers, run it!"
21 |
22 | # Run a new docker container instance
23 | ID=$(docker run \
24 | --rm \
25 | --gpus all \
26 | -itd \
27 | --ipc=host \
28 | --ulimit memlock=-1 \
29 | --ulimit stack=67108864 \
30 | -e DISPLAY=$DISPLAY \
31 | -v /tmp/.X11-unix/:/tmp/.X11-unix/ \
32 | -v $PWD:/home/$USER_NAME/workspace \
33 | -w /home/$USER_NAME/workspace \
34 | $(cat $WORK_DIR/ports.txt) \
35 | $TAG)
36 | fi
37 |
38 | docker logs $ID
39 |
40 | echo ' '
41 | echo ' '
42 | echo '========================================='
43 | echo ' '
44 |
45 | docker exec -it $ID bash
46 |
--------------------------------------------------------------------------------
/docker/dockerfile.base:
--------------------------------------------------------------------------------
1 | ARG BASE_TAG
2 | FROM ${BASE_TAG}
3 | ARG USER_NAME=myuser
4 | ARG USER_PASSWD=111111
5 | ARG DEBIAN_FRONTEND=noninteractive
6 |
7 | # Pre-install packages, pip install requirements and run post install script.
8 | COPY packages.txt .
9 | COPY requirements.txt .
10 | COPY postinstallscript.sh .
11 | RUN apt-get update && apt-get install -y sudo $(cat packages.txt)
12 | RUN pip install --no-cache-dir -r requirements.txt
13 | RUN bash postinstallscript.sh
14 |
15 | # Create a new user and group using the username argument
16 | RUN groupadd -r ${USER_NAME} && useradd -r -m -g${USER_NAME} ${USER_NAME}
17 | RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd
18 | RUN usermod -aG sudo ${USER_NAME}
19 | USER ${USER_NAME}
20 | ENV USER=${USER_NAME}
21 | WORKDIR /home/${USER_NAME}/workspace
22 |
23 | # Set the prompt to highlight the username
24 | RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc
25 |
--------------------------------------------------------------------------------
/docker/packages.txt:
--------------------------------------------------------------------------------
1 | wget
2 | curl
3 | git
--------------------------------------------------------------------------------
/docker/ports.txt:
--------------------------------------------------------------------------------
1 | -p 6006:6006
--------------------------------------------------------------------------------
/docker/postinstallscript.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # This script runs when the docker image is built.
3 |
4 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools>=61.0
2 | torch==2.0.1
3 | torchvision==0.15.2
4 | transformers==4.32.0
5 | albumentations==1.4.0
6 | av==11.0.0
7 | decord==0.6.0
8 | einops==0.3.0
9 | fastapi==0.110.0
10 | accelerate==0.21.0
11 | gdown==5.1.0
12 | h5py==3.10.0
13 | idna==3.6
14 | imageio==2.34.0
15 | matplotlib==3.7.5
16 | numpy==1.24.4
17 | omegaconf==2.1.1
18 | opencv-python==4.9.0.80
19 | opencv-python-headless==4.9.0.80
20 | pandas==2.0.3
21 | pillow==10.2.0
22 | pydub==0.25.1
23 | pytorch-lightning==1.4.2
24 | pytorchvideo==0.1.5
25 | PyYAML==6.0.1
26 | regex==2023.12.25
27 | requests==2.31.0
28 | scikit-learn==1.3.2
29 | scipy==1.10.1
30 | six==1.16.0
31 | tensorboard==2.14.0
32 | test-tube==0.7.5
33 | timm==0.9.16
34 | torchdiffeq==0.2.3
35 | torchmetrics==0.5.0
36 | tqdm==4.66.2
37 | urllib3==2.2.1
38 | uvicorn==0.27.1
39 | diffusers==0.24.0
40 | scikit-video==1.1.11
41 |
--------------------------------------------------------------------------------
/docker/run_docker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/docker/run_docker.png
--------------------------------------------------------------------------------
/docker/setup_env.sh:
--------------------------------------------------------------------------------
1 | # Docker tag for new build image
2 | TAG=open_sora_plan:dev
3 |
4 | # Base docker image tag used by docker build
5 | BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3
6 |
7 | # User name used in docker container
8 | USER_NAME=developer
9 |
10 | # User password used in docker container
11 | USER_PASSWD=666666
--------------------------------------------------------------------------------
/docs/Contribution_Guidelines.md:
--------------------------------------------------------------------------------
1 | # Contributing to the Open-Sora Plan Community
2 |
3 | The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights!
4 |
5 | ## Submitting a Pull Request (PR)
6 |
7 | As a contributor, before submitting your request, kindly follow these guidelines:
8 |
9 | 1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.
10 |
11 | 2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine.
12 |
13 | ```bash
14 | git clone [your-forked-repository-url]
15 | ```
16 |
17 | 3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:
18 |
19 | ```bash
20 | git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan
21 | ```
22 |
23 | 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.
24 |
25 | ```
26 | # Pull the latest code from the upstream branch
27 | git fetch upstream
28 |
29 | # Switch to the main branch
30 | git checkout main
31 |
32 | # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream
33 | git merge upstream/main
34 |
35 | # Additionally, sync the local main branch to the remote branch of your forked repository
36 | git push origin main
37 | ```
38 |
39 |
40 | > Note: Sync the code from the main repository before each submission.
41 |
42 | 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.
43 |
44 | ```bash
45 | git checkout -b my-docs-branch main
46 | ```
47 |
48 | 6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format).
49 |
50 | ```bash
51 | git commit -m "[docs]: xxxx"
52 | ```
53 |
54 | 7. Push your changes to your GitHub repository.
55 |
56 | ```bash
57 | git push origin my-docs-branch
58 | ```
59 |
60 | 8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.
61 |
62 | ## Commit Message Format
63 |
64 | Commit messages must include both `<type>` and `<summary>` sections.
65 |
66 | ```bash
67 | [<type>]: <summary>
68 | │ │
69 | │ └─⫸ Briefly describe your changes, without ending with a period.
70 | │
71 | └─⫸ Commit Type: |docs|feat|fix|refactor|
72 | ```
73 |
74 | ### Type
75 |
76 | * **docs**: Modify or add documents.
77 | * **feat**: Introduce a new feature.
78 | * **fix**: Fix a bug.
79 | * **refactor**: Restructure code, excluding new features or bug fixes.
80 |
81 | ### Summary
82 |
83 | Describe modifications in English, without ending with a period.
84 |
85 | > e.g., git commit -m "[docs]: add a contributing.md file"
86 |
87 | This guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely thank the MiniSora authors for their awesome templates.
88 |
--------------------------------------------------------------------------------
/docs/Data.md:
--------------------------------------------------------------------------------
1 |
2 | **We need more datasets**; please refer to the [open-sora-Dataset](https://github.com/shaodong233/open-sora-Dataset) for details.
3 |
4 |
5 | ## Sky
6 |
7 |
8 | This is an unconditional dataset. [Link](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo)
9 |
10 | ```
11 | sky_timelapse
12 | ├── readme
13 | ├── sky_test
14 | ├── sky_train
15 | ├── test_videofolder.py
16 | └── video_folder.py
17 | ```
18 |
19 | ## UCF101
20 |
21 | We test the code with the UCF-101 dataset. You can download the necessary files from [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `UCF-101` directory with the following structure; the loader scans it as sketched after the listing.
22 | ```
23 | UCF-101/
24 | ApplyEyeMakeup/
25 | v1.avi
26 | ...
27 | ...
28 | YoYo/
29 | v1.avi
30 | ...
31 | ```
32 |
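The loader in `opensora/dataset/ucf101.py` builds its sample list by treating each sub-directory as a class and collecting the `.avi` files inside it. A minimal sketch of that scan (the `data_path` value is a placeholder for wherever you extracted the archive):

```python
import os

data_path = "UCF-101"  # placeholder: path to the extracted dataset
classes = sorted(os.listdir(data_path))
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

samples = []
for class_name in classes:
    class_path = os.path.join(data_path, class_name)
    for fname in os.listdir(class_path):
        if fname.endswith(".avi"):
            samples.append((os.path.join(class_path, fname), class_to_idx[class_name]))

print(f"{len(classes)} classes, {len(samples)} videos")
```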
33 |
34 | ## Offline feature extraction
35 | Coming soon...
36 |
--------------------------------------------------------------------------------
/docs/EVAL.md:
--------------------------------------------------------------------------------
1 | # Evaluate the quality of the generated videos
2 |
3 | You can easily calculate the following video quality metrics, all of which support batch-wise processing.
4 | - **CLIP-SCORE**: It uses the pretrained CLIP model to measure the cosine similarity between two modalities (a minimal similarity sketch follows the environment setup below).
5 | - **FVD**: Fréchet Video Distance (see the formula after this list)
6 | - **SSIM**: structural similarity index measure
7 | - **LPIPS**: learned perceptual image patch similarity
8 | - **PSNR**: peak-signal-to-noise ratio
9 |
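For reference, FVD fits a Gaussian to the I3D features of the real and generated videos and computes the Fréchet distance between the two fits:

$$
\mathrm{FVD} = \lVert \mu_r - \mu_g \rVert^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the feature means and covariances of the real and generated videos.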
10 | # Requirement
11 | ## Environment
12 | - install PyTorch (torch>=1.7.1)
13 | - install CLIP
14 | ```
15 | pip install git+https://github.com/openai/CLIP.git
16 | ```
17 | - install clip-score from PyPI
18 | ```
19 | pip install clip-score
20 | ```
21 | - Other package
22 | ```
23 | pip install lpips
24 | pip install scipy (use scipy==1.7.3 or 1.9.3; if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**)
25 | pip install numpy
26 | pip install pillow
27 | pip install torchvision>=0.8.2
28 | pip install ftfy
29 | pip install regex
30 | pip install tqdm
31 | ```
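
Once CLIP is installed, the cosine-similarity idea behind CLIP-SCORE can be sanity-checked in a few lines. This is an illustrative sketch only, not the `clip-score` CLI that the scripts use; the image file and prompt are placeholders:

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("cat.png")).unsqueeze(0).to(device)  # placeholder image
text = clip.tokenize(["a photo of a cat"]).to(device)              # placeholder caption

with torch.no_grad():
    img_feat = model.encode_image(image)
    txt_feat = model.encode_text(text)
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
    print("cosine similarity:", (img_feat @ txt_feat.T).item())
```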
32 | ## Pretrain model
33 | - FVD
34 | Before you calculate FVD, you should first download the FVD pre-trained model. You can manually download either of the following and put it into the corresponding `fvd` folder (`fvd/styleganv` or `fvd/videogpt`).
35 | - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
36 | - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)
37 |
38 |
39 | ## Other Notices
40 | 1. Make sure the pixel values of the videos are in [0, 1] (see the input-preparation sketch after these notes).
41 | 2. We average SSIM over the channels when images have 3 channels; SSIM is the only metric extremely sensitive to gray being compared to b/w.
42 | 3. Because the I3D model downsamples in the time dimension, `frames_num` should be > 10 when calculating FVD, so the FVD calculation begins from the 10th frame, as in the example above.
43 | 4. For grayscale videos, we replicate the single channel to 3 channels.
44 | 5. Data input specifications for clip_score
45 | > - Image files: All images should be stored in a single directory. The image files can be in either .png or .jpg format.
46 | >
47 | > - Text Files: All text data should be contained in plain text files in a separate directory. These text files should have the extension .txt.
48 | >
49 | > Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory.
50 | >
51 | > Directory Structure Example:
52 | > ```
53 | > ├── path/to/image
54 | > │ ├── cat.png
55 | > │ ├── dog.png
56 | > │ └── bird.jpg
57 | > └── path/to/text
58 | > ├── cat.txt
59 | > ├── dog.txt
60 | > └── bird.txt
61 | > ```
62 |
63 | 6. Data input specifications for fvd, psnr, ssim, lpips
64 |
65 | > Directory Structure Example:
66 | > ```
67 | > ├── path/to/generated_image
68 | > │ ├── cat.mp4
69 | > │ ├── dog.mp4
70 | > │ └── bird.mp4
71 | > └── path/to/real_image
72 | > ├── cat.mp4
73 | > ├── dog.mp4
74 | > └── bird.mp4
75 | > ```
76 |
77 |
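A minimal sketch of what the notes above amount to when preparing inputs (the dummy data and shapes are placeholders; the eval scripts use their own video readers):

```python
import torch

def prepare_video(frames_uint8: torch.Tensor) -> torch.Tensor:
    """frames_uint8: (frames, channels, h, w) uint8 tensor for one video."""
    video = frames_uint8.float() / 255.0      # note 1: pixel values in [0, 1]
    if video.shape[1] == 1:                   # note 4: grayscale -> 3 channels
        video = video.repeat(1, 3, 1, 1)
    return video

# the metrics expect a batch shaped (batch_size, frames, channels, h, w);
# keep frames > 10 so FVD has enough temporal context (note 3)
dummy = torch.randint(0, 256, (20, 1, 64, 64), dtype=torch.uint8)
videos = torch.stack([prepare_video(dummy)])
print(videos.shape, videos.min().item(), videos.max().item())
```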
78 |
79 | # Usage
80 |
81 | ```
82 | # change the file paths and set frame_num, resolution, etc. as needed
83 |
84 | # clip_score cross modality
85 | cd opensora/eval
86 | bash script/cal_clip_score.sh
87 |
88 |
89 |
90 | # fvd
91 | cd opensora/eval
92 | bash script/cal_fvd.sh
93 |
94 | # psnr
95 | cd opensora/eval
96 | bash script/cal_psnr.sh
97 |
98 |
99 | # ssim
100 | cd opensora/eval
101 | bash script/cal_ssim.sh
102 |
103 |
104 | # lpips
105 | cd opensora/eval
106 | bash script/cal_lpips.sh
107 | ```
108 |
109 | # Acknowledgement
110 | The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality).
--------------------------------------------------------------------------------
/docs/VQVAE.md:
--------------------------------------------------------------------------------
1 | # VQVAE Documentation
2 |
3 | # Introduction
4 |
5 | The Vector Quantized Variational AutoEncoder (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation.
6 |
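At its core, the quantization step maps each continuous latent vector to its nearest entry in a learned codebook, with a straight-through estimator so gradients can still reach the encoder. The sketch below is illustrative only; the repo's real implementations live in files such as `opensora/models/ae/imagebase/vqvae/quantize.py` and `opensora/models/ae/videobase/modules/quant.py`:

```python
import torch

def vector_quantize(z, codebook):
    """z: (N, D) encoder outputs; codebook: (K, D) embedding vectors."""
    distances = torch.cdist(z, codebook)   # (N, K) pairwise L2 distances
    indices = distances.argmin(dim=1)      # nearest codebook entry per vector
    z_q = codebook[indices]                # quantized latents
    z_q = z + (z_q - z).detach()           # straight-through gradient estimator
    return z_q, indices

z = torch.randn(8, 256)
codebook = torch.randn(512, 256)
z_q, idx = vector_quantize(z, codebook)
print(z_q.shape, idx.shape)  # torch.Size([8, 256]) torch.Size([8])
```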
7 | # Usage
8 |
9 | ## Initialization
10 |
11 | To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is a part of the `opensora.models.ae` module.
12 |
13 | ```python
14 | from opensora.models.ae import VideoGPTVQVAE
15 |
16 | vqvae = VideoGPTVQVAE()
17 | ```
18 |
19 | ### Training
20 |
21 | To train the VQVAE model, you can use the `train_videogpt.sh` script, which trains the model with the parameters specified inside it.
22 |
23 | ```bash
24 | bash scripts/videogpt/train_videogpt.sh
25 | ```
26 |
27 | ### Loading Pretrained Models
28 |
29 | You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model.
30 |
31 | ```python
32 | vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
33 | ```
34 |
35 | Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method.
36 |
37 | ```python
38 | vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000")
39 | ```
40 |
41 | ### Encoding and Decoding
42 |
43 | You can encode a video using the `encode` method. This method will return the encodings and embeddings of the video.
44 |
45 | ```python
46 | encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
47 | ```
48 |
49 | You can reconstruct a video from its encodings using the `decode` method.
50 |
51 | ```python
52 | video_recon = vqvae.decode(encodings)
53 | ```
54 |
55 | ## Testing
56 |
57 | You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this.
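
For a quick end-to-end check, the calls shown above can be chained into a single round trip. This is a minimal sketch; the input shape `(batch, channels, frames, height, width)` and the roughly `[-0.5, 0.5]` value range are assumptions about the VideoGPT-style model, and `examples/rec_video.py` remains the reference script:

```python
import torch
from opensora.models.ae import VideoGPTVQVAE

vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2").eval()

# dummy clip: (batch, channels, frames, height, width), values roughly in [-0.5, 0.5]
x_vae = torch.rand(1, 3, 16, 64, 64) - 0.5

with torch.no_grad():
    encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
    video_recon = vqvae.decode(encodings)

print(video_recon.shape)  # should match x_vae.shape
```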
--------------------------------------------------------------------------------
/examples/get_latents_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader, Subset
3 | import sys
4 | sys.path.append(".")
5 | from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset
6 |
7 | num_workers = 4
8 | batch_size = 12
9 |
10 | torch.manual_seed(0)
11 | torch.set_grad_enabled(False)
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 | pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000'
16 | data_path = '/remote-home1/dataset/UCF-101'
17 | video_num_frames = 17
18 | resolution = 128
19 | sample_rate = 10
20 |
21 | vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path)
22 | vae.to(device)
23 |
24 | dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate)
25 | subset_indices = list(range(1000))
26 | subset_dataset = Subset(dataset, subset_indices)
27 | loader = DataLoader(subset_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
28 |
29 | all_latents = []
30 | for video_data in loader:
31 | video_data = video_data['video'].to(device)
32 | latents = vae.encode(video_data).sample()
33 | all_latents.append(latents.cpu())  # collect the latents (not the raw video) so their std can be measured
34 |
35 | all_latents_tensor = torch.cat(all_latents)
36 | std = all_latents_tensor.std().item()
37 | normalizer = 1 / std
38 | print(f'{normalizer = }')
--------------------------------------------------------------------------------
/examples/rec_image.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | from PIL import Image
4 | import torch
5 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize
6 | from torch.nn import functional as F
7 | from opensora.models.ae.videobase import CausalVAEModel
8 | import argparse
9 | import numpy as np
10 |
11 | def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
12 | transform = Compose(
13 | [
14 | ToTensor(),
15 | Normalize((0.5), (0.5)),
16 | Resize(size=short_size),
17 | ]
18 | )
19 | outputs = transform(video_data)
20 | outputs = outputs.unsqueeze(0).unsqueeze(2)
21 | return outputs
22 |
23 | def main(args: argparse.Namespace):
24 | image_path = args.image_path
25 | resolution = args.resolution
26 | device = args.device
27 |
28 | vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
29 | vqvae.eval()
30 | vqvae = vqvae.to(device)
31 |
32 | with torch.no_grad():
33 | x_vae = preprocess(Image.open(image_path), resolution)
34 | x_vae = x_vae.to(device)
35 | latents = vqvae.encode(x_vae)
36 | recon = vqvae.decode(latents.sample())
37 | x = recon[0, :, 0, :, :]
38 | x = x.squeeze()
39 | x = x.detach().cpu().numpy()
40 | x = np.clip(x, -1, 1)
41 | x = (x + 1) / 2
42 | x = (255*x).astype(np.uint8)
43 | x = x.transpose(1,2,0)
44 | image = Image.fromarray(x)
45 | image.save(args.rec_path)
46 |
47 |
48 | if __name__ == '__main__':
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument('--image-path', type=str, default='')
51 | parser.add_argument('--rec-path', type=str, default='')
52 | parser.add_argument('--ckpt', type=str, default='')
53 | parser.add_argument('--resolution', type=int, default=336)
54 | parser.add_argument('--device', type=str, default='cuda')
55 |
56 | args = parser.parse_args()
57 | main(args)
58 |
--------------------------------------------------------------------------------
/opensora/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/opensora/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from transformers import AutoTokenizer
3 |
4 | from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset
5 | from torchvision import transforms
6 | from torchvision.transforms import Lambda
7 |
8 | from .t2v_datasets import T2V_dataset
9 | from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo
10 |
11 |
12 | ae_norm = {
13 | 'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.),
14 | 'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
15 | 'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
16 | 'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
17 | 'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
18 | "bair_stride4x2x2": Lambda(lambda x: x - 0.5),
19 | "ucf101_stride4x4x4": Lambda(lambda x: x - 0.5),
20 | "kinetics_stride4x4x4": Lambda(lambda x: x - 0.5),
21 | "kinetics_stride2x4x4": Lambda(lambda x: x - 0.5),
22 | 'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
23 | 'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
24 | 'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.),
25 | 'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.),
26 | 'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.),
27 |
28 | }
29 | ae_denorm = {
30 | 'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2.,
31 | 'CausalVQVAEModel_4x4x4': lambda x: x + 0.5,
32 | 'CausalVQVAEModel_4x8x8': lambda x: x + 0.5,
33 | 'VQVAEModel_4x4x4': lambda x: x + 0.5,
34 | 'VQVAEModel_4x8x8': lambda x: x + 0.5,
35 | "bair_stride4x2x2": lambda x: x + 0.5,
36 | "ucf101_stride4x4x4": lambda x: x + 0.5,
37 | "kinetics_stride4x4x4": lambda x: x + 0.5,
38 | "kinetics_stride2x4x4": lambda x: x + 0.5,
39 | 'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5,
40 | 'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5,
41 | 'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2.,
42 | 'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2.,
43 | 'vqgan_gumbel_f8': lambda x: (x + 1.) / 2.,
44 | }
45 |
46 | def getdataset(args):
47 | temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate) # 16 x
48 | norm_fun = ae_norm[args.ae]
49 | if args.dataset == 't2v':
50 | if args.multi_scale:
51 | resize = [
52 | LongSideResizeVideo(args.max_image_size, skip_low_resolution=True),
53 | SpatialStrideCropVideo(args.stride)
54 | ]
55 | else:
56 | resize = [CenterCropResizeVideo(args.max_image_size), ]
57 | transform = transforms.Compose([
58 | ToTensorVideo(),
59 | *resize,
60 | # RandomHorizontalFlipVideo(p=0.5), # in case their caption have position decription
61 | norm_fun
62 | ])
63 | tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir)
64 | return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
65 | raise NotImplementedError(args.dataset)
66 |
--------------------------------------------------------------------------------
/opensora/dataset/extract_feature_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 |
10 | from opensora.utils.dataset_utils import DecordInit, is_image_file
11 |
12 |
13 | class ExtractVideo2Feature(Dataset):
14 | def __init__(self, args, transform):
15 | self.data_path = args.data_path
16 | self.transform = transform
17 | self.v_decoder = DecordInit()
18 | self.samples = list(glob(f'{self.data_path}'))
19 |
20 | def __len__(self):
21 | return len(self.samples)
22 |
23 | def __getitem__(self, idx):
24 | video_path = self.samples[idx]
25 | video = self.decord_read(video_path)
26 | video = self.transform(video) # T C H W -> T C H W
27 | return video, video_path
28 |
29 | def tv_read(self, path):
30 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
31 | total_frames = len(vframes)
32 | frame_indice = list(range(total_frames))
33 | video = vframes[frame_indice]
34 | return video
35 |
36 | def decord_read(self, path):
37 | decord_vr = self.v_decoder(path)
38 | total_frames = len(decord_vr)
39 | frame_indice = list(range(total_frames))
40 | video_data = decord_vr.get_batch(frame_indice).asnumpy()
41 | video_data = torch.from_numpy(video_data)
42 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
43 | return video_data
44 |
45 |
46 |
47 | class ExtractImage2Feature(Dataset):
48 | def __init__(self, args, transform):
49 | self.data_path = args.data_path
50 | self.transform = transform
51 | self.data_all = list(glob(f'{self.data_path}'))
52 |
53 | def __len__(self):
54 | return len(self.data_all)
55 |
56 | def __getitem__(self, index):
57 | path = self.data_all[index]
58 | video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
59 | video_frame = video_frame.permute(0, 3, 1, 2)
60 | video_frame = self.transform(video_frame) # T C H W
61 | # video_frame = video_frame.transpose(0, 1) # T C H W -> C T H W
62 |
63 | return video_frame, path
64 |
65 |
--------------------------------------------------------------------------------
/opensora/dataset/landscope.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from glob import glob
4 |
5 | import decord
6 | import numpy as np
7 | import torch
8 | import torchvision
9 | from decord import VideoReader, cpu
10 | from torch.utils.data import Dataset
11 | from torchvision.transforms import Compose, Lambda, ToTensor
12 | from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
13 | from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
14 | from torch.nn import functional as F
15 | import random
16 |
17 | from opensora.utils.dataset_utils import DecordInit
18 |
19 |
20 | class Landscope(Dataset):
21 | def __init__(self, args, transform, temporal_sample):
22 | self.data_path = args.data_path
23 | self.num_frames = args.num_frames
24 | self.transform = transform
25 | self.temporal_sample = temporal_sample
26 | self.v_decoder = DecordInit()
27 |
28 | self.samples = self._make_dataset()
29 | self.use_image_num = args.use_image_num
30 | self.use_img_from_vid = args.use_img_from_vid
31 | if self.use_image_num != 0 and not self.use_img_from_vid:
32 | self.img_cap_list = self.get_img_cap_list()
33 |
34 |
35 | def _make_dataset(self):
36 | paths = list(glob(os.path.join(self.data_path, '**', '*.mp4'), recursive=True))
37 |
38 | return paths
39 |
40 | def __len__(self):
41 | return len(self.samples)
42 |
43 | def __getitem__(self, idx):
44 | video_path = self.samples[idx]
45 | try:
46 | video = self.tv_read(video_path)
47 | video = self.transform(video) # T C H W -> T C H W
48 | video = video.transpose(0, 1) # T C H W -> C T H W
49 | if self.use_image_num != 0 and self.use_img_from_vid:
50 | select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
51 | assert self.num_frames >= self.use_image_num
52 | images = video[:, select_image_idx] # c, num_img, h, w
53 | video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
54 | elif self.use_image_num != 0 and not self.use_img_from_vid:
55 | images, captions = self.img_cap_list[idx]
56 | raise NotImplementedError
57 | else:
58 | pass
59 | return video, 1
60 | except Exception as e:
61 | print(f'Error with {e}, {video_path}')
62 | return self.__getitem__(random.randint(0, self.__len__()-1))
63 |
64 | def tv_read(self, path):
65 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
66 | total_frames = len(vframes)
67 |
68 | # Sampling video frames
69 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
70 | # assert end_frame_ind - start_frame_ind >= self.num_frames
71 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
72 | video = vframes[frame_indice] # (T, C, H, W)
73 |
74 | return video
75 |
76 | def decord_read(self, path):
77 | decord_vr = self.v_decoder(path)
78 | total_frames = len(decord_vr)
79 | # Sampling video frames
80 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
81 | # assert end_frame_ind - start_frame_ind >= self.num_frames
82 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
83 |
84 | video_data = decord_vr.get_batch(frame_indice).asnumpy()
85 | video_data = torch.from_numpy(video_data)
86 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
87 | return video_data
88 |
89 | def get_img_cap_list(self):
90 | raise NotImplementedError
91 |
--------------------------------------------------------------------------------
/opensora/dataset/ucf101.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 |
4 | import decord
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from decord import VideoReader, cpu
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import Compose, Lambda, ToTensor
11 | from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
12 | from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
13 | from torch.nn import functional as F
14 | import random
15 |
16 | from opensora.utils.dataset_utils import DecordInit
17 |
18 |
19 | class UCF101(Dataset):
20 | def __init__(self, args, transform, temporal_sample):
21 | self.data_path = args.data_path
22 | self.num_frames = args.num_frames
23 | self.transform = transform
24 | self.temporal_sample = temporal_sample
25 | self.v_decoder = DecordInit()
26 |
27 | self.classes = sorted(os.listdir(self.data_path))
28 | self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
29 | self.samples = self._make_dataset()
30 |
31 |
32 | def _make_dataset(self):
33 | dataset = []
34 | for class_name in self.classes:
35 | class_path = os.path.join(self.data_path, class_name)
36 | for fname in os.listdir(class_path):
37 | if fname.endswith('.avi'):
38 | item = (os.path.join(class_path, fname), self.class_to_idx[class_name])
39 | dataset.append(item)
40 | return dataset
41 |
42 | def __len__(self):
43 | return len(self.samples)
44 |
45 | def __getitem__(self, idx):
46 | video_path, label = self.samples[idx]
47 | try:
48 | video = self.tv_read(video_path)
49 | video = self.transform(video) # T C H W -> T C H W
50 | video = video.transpose(0, 1) # T C H W -> C T H W
51 | return video, label
52 | except Exception as e:
53 | print(f'Error with {e}, {video_path}')
54 | return self.__getitem__(random.randint(0, self.__len__()-1))
55 |
56 | def tv_read(self, path):
57 | vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
58 | total_frames = len(vframes)
59 |
60 | # Sampling video frames
61 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
62 | # assert end_frame_ind - start_frame_ind >= self.num_frames
63 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
64 | video = vframes[frame_indice] # (T, C, H, W)
65 |
66 | return video
67 |
68 | def decord_read(self, path):
69 | decord_vr = self.v_decoder(path)
70 | total_frames = len(decord_vr)
71 | # Sampling video frames
72 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
73 | # assert end_frame_ind - start_frame_ind >= self.num_frames
74 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
75 |
76 | video_data = decord_vr.get_batch(frame_indice).asnumpy()
77 | video_data = torch.from_numpy(video_data)
78 | video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
79 | return video_data
80 |
81 |
--------------------------------------------------------------------------------
/opensora/eval/cal_flolpips.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 | from einops import rearrange
6 | import sys
7 | sys.path.append(".")
8 | from opensora.eval.flolpips.pwcnet import Network as PWCNet
9 | from opensora.eval.flolpips.flolpips import FloLPIPS
10 |
11 | loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False)
12 | flownet = PWCNet().eval().requires_grad_(False)
13 |
14 | def trans(x):
15 | return x
16 |
17 |
18 | def calculate_flolpips(videos1, videos2, device):
19 | global loss_fn, flownet
20 |
21 | print("calculate_flowlpips...")
22 | loss_fn = loss_fn.to(device)
23 | flownet = flownet.to(device)
24 |
25 | if videos1.shape != videos2.shape:
26 | print("Warning: the shape of videos are not equal.")
27 | min_frames = min(videos1.shape[1], videos2.shape[1])
28 | videos1 = videos1[:, :min_frames]
29 | videos2 = videos2[:, :min_frames]
30 |
31 | videos1 = trans(videos1)
32 | videos2 = trans(videos2)
33 |
34 | flolpips_results = []
35 | for video_num in tqdm(range(videos1.shape[0])):
36 | video1 = videos1[video_num].to(device)
37 | video2 = videos2[video_num].to(device)
38 | frames_rec = video1[:-1]
39 | frames_rec_next = video1[1:]
40 | frames_gt = video2[:-1]
41 | frames_gt_next = video2[1:]
42 | t, c, h, w = frames_gt.shape
43 | flow_gt = flownet(frames_gt, frames_gt_next)
44 | flow_dis = flownet(frames_rec, frames_rec_next)
45 | flow_diff = flow_gt - flow_dis
46 | flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True)
47 | flolpips_results.append(flolpips.cpu().numpy().tolist())
48 |
49 | flolpips_results = np.array(flolpips_results) # [batch_size, num_frames]
50 | flolpips = {}
51 | flolpips_std = {}
52 |
53 | for clip_timestamp in range(flolpips_results.shape[1]):
54 | flolpips[clip_timestamp] = np.mean(flolpips_results[:,clip_timestamp], axis=-1)
55 | flolpips_std[clip_timestamp] = np.std(flolpips_results[:,clip_timestamp], axis=-1)
56 |
57 | result = {
58 | "value": flolpips,
59 | "value_std": flolpips_std,
60 | "video_setting": video1.shape,
61 | "video_setting_name": "time, channel, heigth, width",
62 | "result": flolpips_results,
63 | "details": flolpips_results.tolist()
64 | }
65 |
66 | return result
67 |
68 | # test code / using example
69 |
70 | def main():
71 | NUMBER_OF_VIDEOS = 8
72 | VIDEO_LENGTH = 50
73 | CHANNEL = 3
74 | SIZE = 64
75 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
76 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
77 |
78 | import json
79 | result = calculate_flolpips(videos1, videos2, "cuda:0")
80 | print(json.dumps(result, indent=4))
81 |
82 | if __name__ == "__main__":
83 | main()
--------------------------------------------------------------------------------
/opensora/eval/cal_fvd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 |
5 | def trans(x):
6 | # if greyscale images add channel
7 | if x.shape[-3] == 1:
8 | x = x.repeat(1, 1, 3, 1, 1)
9 |
10 | # permute BTCHW -> BCTHW
11 | x = x.permute(0, 2, 1, 3, 4)
12 |
13 | return x
14 |
15 | def calculate_fvd(videos1, videos2, device, method='styleganv'):
16 |
17 | if method == 'styleganv':
18 | from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
19 | elif method == 'videogpt':
20 | from fvd.videogpt.fvd import load_i3d_pretrained
21 | from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
22 | from fvd.videogpt.fvd import frechet_distance
23 |
24 | print("calculate_fvd...")
25 |
26 | # videos [batch_size, timestamps, channel, h, w]
27 |
28 | assert videos1.shape == videos2.shape
29 |
30 | i3d = load_i3d_pretrained(device=device)
31 | fvd_results = []
32 |
33 | # support grayscale input, if grayscale -> channel*3
34 | # BTCHW -> BCTHW
35 | # videos -> [batch_size, channel, timestamps, h, w]
36 |
37 | videos1 = trans(videos1)
38 | videos2 = trans(videos2)
39 |
40 | fvd_results = {}
41 |
42 | # for calculate FVD, each clip_timestamp must >= 10
43 | for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
44 |
45 | # get a video clip
46 | # videos_clip [batch_size, channel, timestamps[:clip], h, w]
47 | videos_clip1 = videos1[:, :, : clip_timestamp]
48 | videos_clip2 = videos2[:, :, : clip_timestamp]
49 |
50 | # get FVD features
51 | feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
52 | feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
53 |
54 | # calculate FVD when timestamps[:clip]
55 | fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)
56 |
57 | result = {
58 | "value": fvd_results,
59 | "video_setting": videos1.shape,
60 | "video_setting_name": "batch_size, channel, time, heigth, width",
61 | }
62 |
63 | return result
64 |
65 | # test code / using example
66 |
67 | def main():
68 | NUMBER_OF_VIDEOS = 8
69 | VIDEO_LENGTH = 50
70 | CHANNEL = 3
71 | SIZE = 64
72 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
73 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
74 | device = torch.device("cuda")
75 | # device = torch.device("cpu")
76 |
77 | import json
78 | result = calculate_fvd(videos1, videos2, device, method='videogpt')
79 | print(json.dumps(result, indent=4))
80 |
81 | result = calculate_fvd(videos1, videos2, device, method='styleganv')
82 | print(json.dumps(result, indent=4))
83 |
84 | if __name__ == "__main__":
85 | main()
86 |
--------------------------------------------------------------------------------
/opensora/eval/cal_lpips.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 |
6 | import torch
7 | import lpips
8 |
9 | spatial = True # Return a spatial map of perceptual distance.
10 |
11 | # Linearly calibrated models (LPIPS)
12 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
13 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
14 |
15 | def trans(x):
16 | # if greyscale images add channel
17 | if x.shape[-3] == 1:
18 | x = x.repeat(1, 1, 3, 1, 1)
19 |
20 | # value range [0, 1] -> [-1, 1]
21 | x = x * 2 - 1
22 |
23 | return x
24 |
25 | def calculate_lpips(videos1, videos2, device):
26 | # image should be RGB, IMPORTANT: normalized to [-1,1]
27 | print("calculate_lpips...")
28 |
29 | assert videos1.shape == videos2.shape
30 |
31 | # videos [batch_size, timestamps, channel, h, w]
32 |
33 | # support grayscale input, if grayscale -> channel*3
34 | # value range [0, 1] -> [-1, 1]
35 | videos1 = trans(videos1)
36 | videos2 = trans(videos2)
37 |
38 | lpips_results = []
39 |
40 | for video_num in tqdm(range(videos1.shape[0])):
41 | # get a video
42 | # video [timestamps, channel, h, w]
43 | video1 = videos1[video_num]
44 | video2 = videos2[video_num]
45 |
46 | lpips_results_of_a_video = []
47 | for clip_timestamp in range(len(video1)):
48 | # get a img
49 | # img [timestamps[x], channel, h, w]
50 | # img [channel, h, w] tensor
51 |
52 | img1 = video1[clip_timestamp].unsqueeze(0).to(device)
53 | img2 = video2[clip_timestamp].unsqueeze(0).to(device)
54 |
55 | loss_fn.to(device)
56 |
57 | # calculate lpips of a video
58 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
59 | lpips_results.append(lpips_results_of_a_video)
60 |
61 | lpips_results = np.array(lpips_results)
62 |
63 | lpips = {}
64 | lpips_std = {}
65 |
66 | for clip_timestamp in range(len(video1)):
67 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])
68 | lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])
69 |
70 |
71 | result = {
72 | "value": lpips,
73 | "value_std": lpips_std,
74 | "video_setting": video1.shape,
75 | "video_setting_name": "time, channel, heigth, width",
76 | }
77 |
78 | return result
79 |
80 | # test code / using example
81 |
82 | def main():
83 | NUMBER_OF_VIDEOS = 8
84 | VIDEO_LENGTH = 50
85 | CHANNEL = 3
86 | SIZE = 64
87 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
88 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
89 | device = torch.device("cuda")
90 | # device = torch.device("cpu")
91 |
92 | import json
93 | result = calculate_lpips(videos1, videos2, device)
94 | print(json.dumps(result, indent=4))
95 |
96 | if __name__ == "__main__":
97 | main()
--------------------------------------------------------------------------------
/opensora/eval/cal_psnr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 |
6 | def img_psnr(img1, img2):
7 | # [0,1]
8 | # compute mse
9 | # mse = np.mean((img1-img2)**2)
10 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
11 | # compute psnr
12 | if mse < 1e-10:
13 | return 100
14 | psnr = 20 * math.log10(1 / math.sqrt(mse))
15 | return psnr
16 |
17 | def trans(x):
18 | return x
19 |
20 | def calculate_psnr(videos1, videos2):
21 | print("calculate_psnr...")
22 |
23 | # videos [batch_size, timestamps, channel, h, w]
24 |
25 | assert videos1.shape == videos2.shape
26 |
27 | videos1 = trans(videos1)
28 | videos2 = trans(videos2)
29 |
30 | psnr_results = []
31 |
32 | for video_num in tqdm(range(videos1.shape[0])):
33 | # get a video
34 | # video [timestamps, channel, h, w]
35 | video1 = videos1[video_num]
36 | video2 = videos2[video_num]
37 |
38 | psnr_results_of_a_video = []
39 | for clip_timestamp in range(len(video1)):
40 | # get a img
41 | # img [timestamps[x], channel, h, w]
42 | # img [channel, h, w] numpy
43 |
44 | img1 = video1[clip_timestamp].numpy()
45 | img2 = video2[clip_timestamp].numpy()
46 |
47 | # calculate psnr of a video
48 | psnr_results_of_a_video.append(img_psnr(img1, img2))
49 |
50 | psnr_results.append(psnr_results_of_a_video)
51 |
52 | psnr_results = np.array(psnr_results) # [batch_size, num_frames]
53 | psnr = {}
54 | psnr_std = {}
55 |
56 | for clip_timestamp in range(len(video1)):
57 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])
58 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])
59 |
60 | result = {
61 | "value": psnr,
62 | "value_std": psnr_std,
63 | "video_setting": video1.shape,
64 | "video_setting_name": "time, channel, heigth, width",
65 | }
66 |
67 | return result
68 |
69 | # test code / using example
70 |
71 | def main():
72 | NUMBER_OF_VIDEOS = 8
73 | VIDEO_LENGTH = 50
74 | CHANNEL = 3
75 | SIZE = 64
76 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
77 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
78 |
79 | import json
80 | result = calculate_psnr(videos1, videos2)
81 | print(json.dumps(result, indent=4))
82 |
83 | if __name__ == "__main__":
84 | main()
--------------------------------------------------------------------------------
/opensora/eval/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import cv2
5 |
6 | def ssim(img1, img2):
7 | C1 = 0.01 ** 2
8 | C2 = 0.03 ** 2
9 | img1 = img1.astype(np.float64)
10 | img2 = img2.astype(np.float64)
11 | kernel = cv2.getGaussianKernel(11, 1.5)
12 | window = np.outer(kernel, kernel.transpose())
13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15 | mu1_sq = mu1 ** 2
16 | mu2_sq = mu2 ** 2
17 | mu1_mu2 = mu1 * mu2
18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
22 | (sigma1_sq + sigma2_sq + C2))
23 | return ssim_map.mean()
24 |
25 |
26 | def calculate_ssim_function(img1, img2):
27 | # [0,1]
28 | # ssim is the only metric extremely sensitive to gray being compared to b/w
29 | if not img1.shape == img2.shape:
30 | raise ValueError('Input images must have the same dimensions.')
31 | if img1.ndim == 2:
32 | return ssim(img1, img2)
33 | elif img1.ndim == 3:
34 | if img1.shape[0] == 3:
35 | ssims = []
36 | for i in range(3):
37 | ssims.append(ssim(img1[i], img2[i]))
38 | return np.array(ssims).mean()
39 | elif img1.shape[0] == 1:
40 | return ssim(np.squeeze(img1), np.squeeze(img2))
41 | else:
42 | raise ValueError('Wrong input image dimensions.')
43 |
44 | def trans(x):
45 | return x
46 |
47 | def calculate_ssim(videos1, videos2):
48 | print("calculate_ssim...")
49 |
50 | # videos [batch_size, timestamps, channel, h, w]
51 |
52 | assert videos1.shape == videos2.shape
53 |
54 | videos1 = trans(videos1)
55 | videos2 = trans(videos2)
56 |
57 | ssim_results = []
58 |
59 | for video_num in tqdm(range(videos1.shape[0])):
60 | # get a video
61 | # video [timestamps, channel, h, w]
62 | video1 = videos1[video_num]
63 | video2 = videos2[video_num]
64 |
65 | ssim_results_of_a_video = []
66 | for clip_timestamp in range(len(video1)):
67 | # get a img
68 | # img [timestamps[x], channel, h, w]
69 | # img [channel, h, w] numpy
70 |
71 | img1 = video1[clip_timestamp].numpy()
72 | img2 = video2[clip_timestamp].numpy()
73 |
74 | # calculate ssim of a video
75 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
76 |
77 | ssim_results.append(ssim_results_of_a_video)
78 |
79 | ssim_results = np.array(ssim_results)
80 |
81 | ssim = {}
82 | ssim_std = {}
83 |
84 | for clip_timestamp in range(len(video1)):
85 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])
86 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])
87 |
88 | result = {
89 | "value": ssim,
90 | "value_std": ssim_std,
91 | "video_setting": video1.shape,
92 | "video_setting_name": "time, channel, heigth, width",
93 | }
94 |
95 | return result
96 |
97 | # test code / using example
98 |
99 | def main():
100 | NUMBER_OF_VIDEOS = 8
101 | VIDEO_LENGTH = 50
102 | CHANNEL = 3
103 | SIZE = 64
104 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
105 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
106 | device = torch.device("cuda")
107 |
108 | import json
109 | result = calculate_ssim(videos1, videos2)
110 | print(json.dumps(result, indent=4))
111 |
112 | if __name__ == "__main__":
113 | main()
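
As a sanity check, the hand-rolled Gaussian-window SSIM above should track scikit-image's implementation when matching parameters are used; a sketch, assuming scikit-image >= 0.19 (which provides channel_axis) is installed:

```python
import numpy as np
from skimage.metrics import structural_similarity

from opensora.eval.cal_ssim import calculate_ssim_function

img1 = np.random.rand(3, 64, 64)  # [channel, h, w], values in [0, 1]
img2 = np.random.rand(3, 64, 64)

ours = calculate_ssim_function(img1, img2)
# Mirror the 11x11 Gaussian window (sigma=1.5) and population covariance used above.
ref = structural_similarity(
    img1, img2, channel_axis=0, data_range=1.0,
    gaussian_weights=True, sigma=1.5, use_sample_covariance=False,
)
print(ours, ref)  # close; a small gap is expected from the valid-crop border handling above
```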
--------------------------------------------------------------------------------
/opensora/eval/flolpips/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 |
5 |
6 | def normalize_tensor(in_feat,eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
8 | return in_feat/(norm_factor+eps)
9 |
10 | def l2(p0, p1, range=255.):
11 | return .5*np.mean((p0 / range - p1 / range)**2)
12 |
13 | def dssim(p0, p1, range=255.):
14 | from skimage.metrics import structural_similarity  # compare_ssim was removed from skimage.measure
15 | return (1 - structural_similarity(p0, p1, data_range=range, channel_axis=-1)) / 2.
16 |
17 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
18 | image_numpy = image_tensor[0].cpu().float().numpy()
19 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
20 | return image_numpy.astype(imtype)
21 |
22 | def tensor2np(tensor_obj):
23 | # change dimension of a tensor object into a numpy array
24 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
25 |
26 | def np2tensor(np_obj):
27 | # change dimenion of np array into tensor array
28 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
29 |
30 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
31 | # image tensor to lab tensor
32 | from skimage import color
33 |
34 | img = tensor2im(image_tensor)
35 | img_lab = color.rgb2lab(img)
36 | if(mc_only):
37 | img_lab[:,:,0] = img_lab[:,:,0]-50
38 | if(to_norm and not mc_only):
39 | img_lab[:,:,0] = img_lab[:,:,0]-50
40 | img_lab = img_lab/100.
41 |
42 | return np2tensor(img_lab)
43 |
44 | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'):
45 | if pix_fmt == '420':
46 | multiplier = 1
47 | uv_factor = 2
48 | elif pix_fmt == '444':
49 | multiplier = 2
50 | uv_factor = 1
51 | else:
52 | print('Pixel format {} is not supported'.format(pix_fmt))
53 | return
54 |
55 | if bit_depth == 8:
56 | datatype = np.uint8
57 | stream.seek(iFrame*width*height*multiplier*3//2)  # integer byte offset; seek() rejects floats
58 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
59 |
60 | # read chroma samples and upsample since original is 4:2:0 sampling
61 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
62 | reshape((height//uv_factor, width//uv_factor))
63 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
64 | reshape((height//uv_factor, width//uv_factor))
65 |
66 | else:
67 | datatype = np.uint16
68 | stream.seek(iFrame*3*width*height*multiplier)
69 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
70 |
71 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
72 | reshape((height//uv_factor, width//uv_factor))
73 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
74 | reshape((height//uv_factor, width//uv_factor))
75 |
76 | if pix_fmt == '420':
77 | yuv = np.empty((height*3//2, width), dtype=datatype)
78 | yuv[0:height,:] = Y
79 |
80 | yuv[height:height+height//4,:] = U.reshape(-1, width)
81 | yuv[height+height//4:,:] = V.reshape(-1, width)
82 |
83 | if bit_depth != 8:
84 | yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8)
85 |
86 | #convert to rgb
87 | rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
88 |
89 | else:
90 | yvu = np.stack([Y,V,U],axis=2)
91 | if bit_depth != 8:
92 | yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8)
93 | rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)
94 |
95 | return rgb
96 |
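
read_frame_yuv2rgb expects an open binary stream over a raw planar YUV file; a short sketch with a hypothetical 8-bit 4:2:0 clip (the file name and geometry are placeholders):

```python
from opensora.eval.flolpips.utils import read_frame_yuv2rgb

width, height = 1920, 1080  # geometry of the raw clip (placeholder values)
with open("clip_1920x1080_420_8bit.yuv", "rb") as stream:  # hypothetical file
    frame0 = read_frame_yuv2rgb(stream, width, height, iFrame=0, bit_depth=8, pix_fmt='420')

print(frame0.shape, frame0.dtype)  # (1080, 1920, 3) uint8, RGB
```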
--------------------------------------------------------------------------------
/opensora/eval/fvd/styleganv/fvd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import math
4 | import torch.nn.functional as F
5 |
6 | # https://github.com/universome/fvd-comparison
7 |
8 |
9 | def load_i3d_pretrained(device=torch.device('cpu')):
10 | i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
11 | filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
12 | print(filepath)
13 | if not os.path.exists(filepath):
14 | print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
15 | os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
16 | i3d = torch.jit.load(filepath).eval().to(device)
17 | i3d = torch.nn.DataParallel(i3d)
18 | return i3d
19 |
20 |
21 | def get_feats(videos, detector, device, bs=10):
22 | # videos : torch.tensor BCTHW [0, 1]
23 | detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
24 | feats = np.empty((0, 400))
25 | with torch.no_grad():
26 | for i in range((len(videos)-1)//bs + 1):
27 | feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
28 | return feats
29 |
30 |
31 | def get_fvd_feats(videos, i3d, device, bs=10):
32 | # videos in [0, 1] as torch tensor BCTHW
33 | # videos = [preprocess_single(video) for video in videos]
34 | embeddings = get_feats(videos, i3d, device, bs)
35 | return embeddings
36 |
37 |
38 | def preprocess_single(video, resolution=224, sequence_length=None):
39 | # video: CTHW, [0, 1]
40 | c, t, h, w = video.shape
41 |
42 | # temporal crop
43 | if sequence_length is not None:
44 | assert sequence_length <= t
45 | video = video[:, :sequence_length]
46 |
47 | # scale shorter side to resolution
48 | scale = resolution / min(h, w)
49 | if h < w:
50 | target_size = (resolution, math.ceil(w * scale))
51 | else:
52 | target_size = (math.ceil(h * scale), resolution)
53 | video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
54 |
55 | # center crop
56 | c, t, h, w = video.shape
57 | w_start = (w - resolution) // 2
58 | h_start = (h - resolution) // 2
59 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
60 |
61 | # [0, 1] -> [-1, 1]
62 | video = (video - 0.5) * 2
63 |
64 | return video.contiguous()
65 |
66 |
67 | """
68 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
69 | """
70 | from typing import Tuple
71 | from scipy.linalg import sqrtm
72 | import numpy as np
73 |
74 |
75 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
76 | mu = feats.mean(axis=0) # [d]
77 | sigma = np.cov(feats, rowvar=False) # [d, d]
78 | return mu, sigma
79 |
80 |
81 | def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
82 | mu_gen, sigma_gen = compute_stats(feats_fake)
83 | mu_real, sigma_real = compute_stats(feats_real)
84 | m = np.square(mu_gen - mu_real).sum()
85 | if feats_fake.shape[0]>1:
86 | s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
87 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
88 | else:
89 | fid = np.real(m)
90 | return float(fid)
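
Putting the helpers above together, FVD is the Fréchet distance between I3D feature statistics of real and generated clips; a minimal end-to-end sketch (assumes a CUDA device and that i3d_torchscript.pt has been downloaded next to this file):

```python
import torch
from opensora.eval.fvd.styleganv.fvd import load_i3d_pretrained, get_fvd_feats, frechet_distance

device = torch.device("cuda")
i3d = load_i3d_pretrained(device)

# Toy batches in BCTHW layout with values in [0, 1]; real use would load decoded videos.
real = torch.rand(8, 3, 16, 64, 64)
fake = torch.rand(8, 3, 16, 64, 64)

feats_real = get_fvd_feats(real, i3d, device, bs=4)
feats_fake = get_fvd_feats(fake, i3d, device, bs=4)
print("FVD:", frechet_distance(feats_fake, feats_real))
```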
--------------------------------------------------------------------------------
/opensora/eval/fvd/styleganv/i3d_torchscript.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/eval/fvd/styleganv/i3d_torchscript.pt
--------------------------------------------------------------------------------
/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt
--------------------------------------------------------------------------------
/opensora/eval/script/cal_clip_score.sh:
--------------------------------------------------------------------------------
1 | # clip_score cross modality
2 | python eval_clip_score.py \
3 | --real_path path/to/image \
4 | --generated_path path/to/text \
5 | --batch-size 50 \
6 | --device "cuda"
7 |
8 | # clip_score within the same modality
9 | python eval_clip_score.py \
10 | --real_path path/to/textA \
11 | --generated_path path/to/textB \
12 | --real_flag txt \
13 | --generated_flag txt \
14 | --batch-size 50 \
15 | --device "cuda"
16 |
17 | python eval_clip_score.py \
18 | --real_path path/to/imageA \
19 | --generated_path path/to/imageB \
20 | --real_flag img \
21 | --generated_flag img \
22 | --batch-size 50 \
23 | --device "cuda"
24 |
--------------------------------------------------------------------------------
/opensora/eval/script/cal_fvd.sh:
--------------------------------------------------------------------------------
1 | python eval_common_metric.py \
2 | --real_video_dir path/to/imageA \
3 | --generated_video_dir path/to/imageB \
4 | --batch_size 10 \
5 | --crop_size 64 \
6 | --num_frames 20 \
7 | --device 'cuda' \
8 | --metric 'fvd' \
9 | --fvd_method 'styleganv'
10 |
--------------------------------------------------------------------------------
/opensora/eval/script/cal_lpips.sh:
--------------------------------------------------------------------------------
1 | python eval_common_metric.py \
2 | --real_video_dir path/to/imageA \
3 | --generated_video_dir path/to/imageB \
4 | --batch_size 10 \
5 | --num_frames 20 \
6 | --crop_size 64 \
7 | --device 'cuda' \
8 | --metric 'lpips'
--------------------------------------------------------------------------------
/opensora/eval/script/cal_psnr.sh:
--------------------------------------------------------------------------------
1 |
2 | python eval_common_metric.py \
3 | --real_video_dir /data/xiaogeng_liu/data/video1 \
4 | --generated_video_dir /data/xiaogeng_liu/data/video2 \
5 | --batch_size 10 \
6 | --num_frames 20 \
7 | --crop_size 64 \
8 | --device 'cuda' \
9 | --metric 'psnr'
--------------------------------------------------------------------------------
/opensora/eval/script/cal_ssim.sh:
--------------------------------------------------------------------------------
1 | python eval_common_metric.py \
2 | --real_video_dir /data/xiaogeng_liu/data/video1 \
3 | --generated_video_dir /data/xiaogeng_liu/data/video2 \
4 | --batch_size 10 \
5 | --num_frames 20 \
6 | --crop_size 64 \
7 | --device 'cuda' \
8 | --metric 'ssim'
--------------------------------------------------------------------------------
/opensora/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/__init__.py
--------------------------------------------------------------------------------
/opensora/models/ae/__init__.py:
--------------------------------------------------------------------------------
1 | from .imagebase import imagebase_ae, imagebase_ae_stride, imagebase_ae_channel
2 | from .videobase import videobase_ae, videobase_ae_stride, videobase_ae_channel
3 | from .videobase import (
4 | VQVAEConfiguration,
5 | VQVAEModel,
6 | VQVAETrainer,
7 | CausalVQVAEModel,
8 | CausalVQVAEConfiguration,
9 | CausalVQVAETrainer
10 | )
11 |
12 | ae_stride_config = {}
13 | ae_stride_config.update(imagebase_ae_stride)
14 | ae_stride_config.update(videobase_ae_stride)
15 |
16 | ae_channel_config = {}
17 | ae_channel_config.update(imagebase_ae_channel)
18 | ae_channel_config.update(videobase_ae_channel)
19 |
20 | def getae(args):
21 | """deprecation"""
22 | ae = imagebase_ae.get(args.ae, None) or videobase_ae.get(args.ae, None)
23 | assert ae is not None
24 | return ae(args.ae)
25 |
26 | def getae_wrapper(ae):
27 | """deprecation"""
28 | ae = imagebase_ae.get(ae, None) or videobase_ae.get(ae, None)
29 | assert ae is not None
30 | return ae
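
A short sketch of how these registries are meant to be consumed: look up the wrapper class together with its stride and latent-channel configuration by autoencoder name (the name used here is one of the keys defined above):

```python
from opensora.models.ae import ae_stride_config, ae_channel_config, getae_wrapper

ae_name = "CausalVAEModel_4x8x8"
wrapper_cls = getae_wrapper(ae_name)                      # -> CausalVAEModelWrapper
t_stride, h_stride, w_stride = ae_stride_config[ae_name]  # (4, 8, 8)
latent_channels = ae_channel_config[ae_name]              # 4
print(wrapper_cls.__name__, (t_stride, h_stride, w_stride), latent_channels)
```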
--------------------------------------------------------------------------------
/opensora/models/ae/imagebase/__init__.py:
--------------------------------------------------------------------------------
1 | from .vae.vae import HFVAEWrapper
2 | from .vae.vae import SDVAEWrapper
3 | from .vqvae.vqvae import SDVQVAEWrapper
4 |
5 | vae = ['stabilityai/sd-vae-ft-mse', 'stabilityai/sd-vae-ft-ema']
6 | vqvae = ['vqgan_imagenet_f16_1024', 'vqgan_imagenet_f16_16384', 'vqgan_gumbel_f8']
7 |
8 | imagebase_ae_stride = {
9 | 'stabilityai/sd-vae-ft-mse': [1, 8, 8],
10 | 'stabilityai/sd-vae-ft-ema': [1, 8, 8],
11 | 'vqgan_imagenet_f16_1024': [1, 16, 16],
12 | 'vqgan_imagenet_f16_16384': [1, 16, 16],
13 | 'vqgan_gumbel_f8': [1, 8, 8],
14 | }
15 |
16 | imagebase_ae_channel = {
17 | 'stabilityai/sd-vae-ft-mse': 4,
18 | 'stabilityai/sd-vae-ft-ema': 4,
19 | 'vqgan_imagenet_f16_1024': -1,
20 | 'vqgan_imagenet_f16_16384': -1,
21 | 'vqgan_gumbel_f8': -1,
22 | }
23 |
24 | imagebase_ae = {
25 | 'stabilityai/sd-vae-ft-mse': HFVAEWrapper,
26 | 'stabilityai/sd-vae-ft-ema': HFVAEWrapper,
27 | 'vqgan_imagenet_f16_1024': SDVQVAEWrapper,
28 | 'vqgan_imagenet_f16_16384': SDVQVAEWrapper,
29 | 'vqgan_gumbel_f8': SDVQVAEWrapper,
30 | }
--------------------------------------------------------------------------------
/opensora/models/ae/imagebase/vae/vae.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | from torch import nn
3 | from diffusers.models import AutoencoderKL
4 |
5 |
6 | class HFVAEWrapper(nn.Module):
7 | def __init__(self, hfvae='mse'):
8 | super(HFVAEWrapper, self).__init__()
9 | self.vae = AutoencoderKL.from_pretrained(hfvae, cache_dir='cache_dir')
10 | def encode(self, x): # b c h w
11 | t = 0
12 | if x.ndim == 5:
13 | b, c, t, h, w = x.shape
14 | x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
15 | x = self.vae.encode(x).latent_dist.sample().mul_(0.18215)
16 | if t != 0:
17 | x = rearrange(x, '(b t) c h w -> b c t h w', t=t).contiguous()
18 | return x
19 | def decode(self, x):
20 | t = 0
21 | if x.ndim == 5:
22 | b, c, t, h, w = x.shape
23 | x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous()
24 | x = self.vae.decode(x / 0.18215).sample
25 | if t != 0:
26 | x = rearrange(x, '(b t) c h w -> b t c h w', t=t).contiguous()
27 | return x
28 |
29 | class SDVAEWrapper(nn.Module):
30 | def __init__(self):
31 | super(SDVAEWrapper, self).__init__()
32 | raise NotImplementedError
33 |
34 | def encode(self, x): # b c h w
35 | raise NotImplementedError
36 |
37 | def decode(self, x):
38 | raise NotImplementedError
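
A shape-level sketch of HFVAEWrapper, assuming the stabilityai/sd-vae-ft-mse weights can be fetched from the Hugging Face hub; with the 1x8x8 stride listed in the registry above, a 256x256 clip maps to 32x32 latents:

```python
import torch
from opensora.models.ae.imagebase.vae.vae import HFVAEWrapper

vae = HFVAEWrapper('stabilityai/sd-vae-ft-mse').eval()

video = torch.rand(1, 3, 4, 256, 256) * 2 - 1   # b c t h w, roughly [-1, 1] as the SD VAE expects
with torch.no_grad():
    latents = vae.encode(video)                 # -> [1, 4, 4, 32, 32]   (b c t h w)
    recon = vae.decode(latents)                 # -> [1, 4, 3, 256, 256] (b t c h w)
print(latents.shape, recon.shape)
```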
--------------------------------------------------------------------------------
/opensora/models/ae/imagebase/vqvae/vqvae.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import yaml
3 | import torch
4 | from omegaconf import OmegaConf
5 | from .vqgan import VQModel, GumbelVQ
6 |
7 | def load_config(config_path, display=False):
8 | config = OmegaConf.load(config_path)
9 | if display:
10 | print(yaml.dump(OmegaConf.to_container(config)))
11 | return config
12 |
13 |
14 | def load_vqgan(config, ckpt_path=None, is_gumbel=False):
15 | if is_gumbel:
16 | model = GumbelVQ(**config.model.params)
17 | else:
18 | model = VQModel(**config.model.params)
19 | if ckpt_path is not None:
20 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
21 | missing, unexpected = model.load_state_dict(sd, strict=False)
22 | return model.eval()
23 |
24 |
25 | class SDVQVAEWrapper(nn.Module):
26 | def __init__(self, name):
27 | super(SDVQVAEWrapper, self).__init__()
28 | raise NotImplementedError
29 |
30 | def encode(self, x): # b c h w
31 | raise NotImplementedError
32 |
33 | def decode(self, x):
34 | raise NotImplementedError
35 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/__init__.py:
--------------------------------------------------------------------------------
1 | from .vqvae import (
2 | VQVAEConfiguration,
3 | VQVAEModel,
4 | VQVAETrainer,
5 | VQVAEModelWrapper
6 | )
7 | from .causal_vqvae import (
8 | CausalVQVAEConfiguration,
9 | CausalVQVAETrainer,
10 | CausalVQVAEModel, CausalVQVAEModelWrapper
11 | )
12 | from .causal_vae import (
13 | CausalVAEModel, CausalVAEModelWrapper
14 | )
15 |
16 |
17 | videobase_ae_stride = {
18 | 'CausalVAEModel_4x8x8': [4, 8, 8],
19 | 'CausalVQVAEModel_4x4x4': [4, 4, 4],
20 | 'CausalVQVAEModel_4x8x8': [4, 8, 8],
21 | 'VQVAEModel_4x4x4': [4, 4, 4],
22 | 'OpenVQVAEModel_4x4x4': [4, 4, 4],
23 | 'VQVAEModel_4x8x8': [4, 8, 8],
24 | 'bair_stride4x2x2': [4, 2, 2],
25 | 'ucf101_stride4x4x4': [4, 4, 4],
26 | 'kinetics_stride4x4x4': [4, 4, 4],
27 | 'kinetics_stride2x4x4': [2, 4, 4],
28 | }
29 |
30 | videobase_ae_channel = {
31 | 'CausalVAEModel_4x8x8': 4,
32 | 'CausalVQVAEModel_4x4x4': 4,
33 | 'CausalVQVAEModel_4x8x8': 4,
34 | 'VQVAEModel_4x4x4': 4,
35 | 'OpenVQVAEModel_4x4x4': 4,
36 | 'VQVAEModel_4x8x8': 4,
37 | 'bair_stride4x2x2': 256,
38 | 'ucf101_stride4x4x4': 256,
39 | 'kinetics_stride4x4x4': 256,
40 | 'kinetics_stride2x4x4': 256,
41 | }
42 |
43 | videobase_ae = {
44 | 'CausalVAEModel_4x8x8': CausalVAEModelWrapper,
45 | 'CausalVQVAEModel_4x4x4': CausalVQVAEModelWrapper,
46 | 'CausalVQVAEModel_4x8x8': CausalVQVAEModelWrapper,
47 | 'VQVAEModel_4x4x4': VQVAEModelWrapper,
48 | 'VQVAEModel_4x8x8': VQVAEModelWrapper,
49 | "bair_stride4x2x2": VQVAEModelWrapper,
50 | "ucf101_stride4x4x4": VQVAEModelWrapper,
51 | "kinetics_stride4x4x4": VQVAEModelWrapper,
52 | "kinetics_stride2x4x4": VQVAEModelWrapper,
53 | }
54 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/causal_vae/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_causalvae import CausalVAEModel
2 |
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | class CausalVAEModelWrapper(nn.Module):
7 | def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
8 | super(CausalVAEModelWrapper, self).__init__()
9 | # if os.path.exists(ckpt):
10 | # self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
11 | self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
12 | def encode(self, x): # b c t h w
13 | # x = self.vae.encode(x).sample()
14 | x = self.vae.encode(x).sample().mul_(0.18215)
15 | return x
16 | def decode(self, x):
17 | # x = self.vae.decode(x)
18 | x = self.vae.decode(x / 0.18215)
19 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous()
20 | return x
21 |
22 | def dtype(self):
23 | return self.vae.dtype
24 | #
25 | # def device(self):
26 | # return self.vae.device
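
A usage sketch for the wrapper above, assuming a local directory of pretrained CausalVAE weights (the path is a placeholder); with the 4x8x8 stride from the registry, 1 + 16 input frames compress to roughly 1 + 4 latent frames:

```python
import torch
from opensora.models.ae.videobase.causal_vae import CausalVAEModelWrapper

vae = CausalVAEModelWrapper("path/to/CausalVAEModel_4x8x8").eval()  # placeholder checkpoint dir

video = torch.rand(1, 3, 17, 256, 256) * 2 - 1   # b c t h w: 1 image frame + 16 video frames
with torch.no_grad():
    z = vae.encode(video)     # roughly [1, 4, 5, 32, 32] under the 4x8x8 stride
    recon = vae.decode(z)     # back to b t c h w
print(z.shape, recon.shape)
```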
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/causal_vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_causalvqvae import CausalVQVAEConfiguration
2 | from .modeling_causalvqvae import CausalVQVAEModel
3 | from .trainer_causalvqvae import CausalVQVAETrainer
4 |
5 |
6 | from einops import rearrange
7 | from torch import nn
8 |
9 | class CausalVQVAEModelWrapper(nn.Module):
10 | def __init__(self, ckpt):
11 | super(CausalVQVAEModelWrapper, self).__init__()
12 | self.vqvae = CausalVQVAEModel.load_from_checkpoint(ckpt)
13 | def encode(self, x): # b c t h w
14 | x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x))
15 | return x
16 | def decode(self, x):
17 | vq_output = self.vqvae.codebook(x)
18 | x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings']))
19 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous()
20 | return x
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py:
--------------------------------------------------------------------------------
1 | from ..configuration_videobase import VideoBaseConfiguration
2 | from typing import Union, Tuple
3 |
4 | class CausalVQVAEConfiguration(VideoBaseConfiguration):
5 | def __init__(
6 | self,
7 | embedding_dim: int = 256,
8 | n_codes: int = 2048,
9 | n_hiddens: int = 240,
10 | n_res_layers: int = 4,
11 | resolution: int = 128,
12 | sequence_length: int = 16,
13 | time_downsample: int = 4,
14 | spatial_downsample: int = 8,
15 | no_pos_embd: bool = True,
16 | **kwargs,
17 | ):
18 | super().__init__(**kwargs)
19 |
20 | self.embedding_dim = embedding_dim
21 | self.n_codes = n_codes
22 | self.n_hiddens = n_hiddens
23 | self.n_res_layers = n_res_layers
24 | self.resolution = resolution
25 | self.sequence_length = sequence_length
26 | self.time_downsample = time_downsample
27 | self.spatial_downsample = spatial_downsample
28 | self.no_pos_embd = no_pos_embd
29 |
30 | self.hidden_size = n_hiddens
31 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py:
--------------------------------------------------------------------------------
1 | from ..trainer_videobase import VideoBaseTrainer
2 | import torch.nn.functional as F
3 | from typing import Optional
4 | import os
5 | import torch
6 | from transformers.utils import WEIGHTS_NAME
7 | import json
8 |
9 | class CausalVQVAETrainer(VideoBaseTrainer):
10 |
11 | def compute_loss(self, model, inputs, return_outputs=False):
12 | model = model.module
13 | x = inputs.get("video")
14 | x = x / 2
15 | z = model.pre_vq_conv(model.encoder(x))
16 | vq_output = model.codebook(z)
17 | x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))
18 | recon_loss = F.mse_loss(x_recon, x) / 0.06
19 | commitment_loss = vq_output['commitment_loss']
20 | loss = recon_loss + commitment_loss
21 | return loss
22 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/configuration_videobase.py:
--------------------------------------------------------------------------------
1 | import json
2 | import yaml
3 | from typing import TypeVar, Dict, Any
4 | from diffusers import ConfigMixin
5 |
6 | T = TypeVar('T', bound='VideoBaseConfiguration')
7 | class VideoBaseConfiguration(ConfigMixin):
8 | config_name = "VideoBaseConfiguration"
9 | _nested_config_fields: Dict[str, Any] = {}
10 |
11 | def __init__(self, **kwargs):
12 | pass
13 |
14 | def to_dict(self) -> Dict[str, Any]:
15 | d = {}
16 | for key, value in vars(self).items():
17 | if isinstance(value, VideoBaseConfiguration):
18 | d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances
19 | elif isinstance(value, tuple):
20 | d[key] = list(value)
21 | else:
22 | d[key] = value
23 | return d
24 |
25 | def to_yaml_file(self, yaml_path: str):
26 | with open(yaml_path, 'w') as yaml_file:
27 | yaml.dump(self.to_dict(), yaml_file, default_flow_style=False)
28 |
29 | @classmethod
30 | def load_from_yaml(cls: T, yaml_path: str) -> T:
31 | with open(yaml_path, 'r') as yaml_file:
32 | config_dict = yaml.safe_load(yaml_file)
33 | for field, field_type in cls._nested_config_fields.items():
34 | if field in config_dict:
35 | config_dict[field] = field_type.load_from_dict(config_dict[field])
36 | return cls(**config_dict)
37 |
38 | @classmethod
39 | def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T:
40 | # Process nested configuration objects
41 | for field, field_type in cls._nested_config_fields.items():
42 | if field in config_dict:
43 | config_dict[field] = field_type.load_from_dict(config_dict[field])
44 | return cls(**config_dict)
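
A sketch of the intended round trip: any concrete configuration (VQVAEConfiguration is used here) can be dumped to YAML and restored through the class methods above; note that tuples come back as lists:

```python
from opensora.models.ae.videobase.vqvae.configuration_vqvae import VQVAEConfiguration

cfg = VQVAEConfiguration(embedding_dim=256, n_codes=2048, downsample=(4, 4, 4))
cfg.to_yaml_file("vqvae_config.yaml")                # tuples are serialized as lists

restored = VQVAEConfiguration.load_from_yaml("vqvae_config.yaml")
print(restored.embedding_dim, restored.downsample)   # 256 [4, 4, 4]
```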
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .perceptual_loss import SimpleLPIPS, LPIPSWithDiscriminator, LPIPSWithDiscriminator3D
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modeling_videobase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import ModelMixin, ConfigMixin
3 | from torch import nn
4 | import os
5 | import json
6 | import pytorch_lightning as pl
7 | from diffusers.configuration_utils import ConfigMixin
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from typing import Optional, Union
10 | import glob
11 |
12 | class VideoBaseAE(nn.Module):
13 | _supports_gradient_checkpointing = False
14 |
15 | def __init__(self, *args, **kwargs) -> None:
16 | super().__init__(*args, **kwargs)
17 |
18 | @classmethod
19 | def load_from_checkpoint(cls, model_path):
20 | with open(os.path.join(model_path, "config.json"), "r") as file:
21 | config = json.load(file)
22 | state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
23 | if 'state_dict' in state_dict:
24 | state_dict = state_dict['state_dict']
25 | model = cls(config=cls.CONFIGURATION_CLS(**config))
26 | model.load_state_dict(state_dict)
27 | return model
28 |
29 | @classmethod
30 | def download_and_load_model(cls, model_name, cache_dir=None):
31 | pass
32 |
33 | def encode(self, x: torch.Tensor, *args, **kwargs):
34 | pass
35 |
36 | def decode(self, encoding: torch.Tensor, *args, **kwargs):
37 | pass
38 |
39 | class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin):
40 | config_name = "config.json"
41 |
42 | def __init__(self, *args, **kwargs) -> None:
43 | super().__init__(*args, **kwargs)
44 |
45 | def encode(self, x: torch.Tensor, *args, **kwargs):
46 | pass
47 |
48 | def decode(self, encoding: torch.Tensor, *args, **kwargs):
49 | pass
50 |
51 | @property
52 | def num_training_steps(self) -> int:
53 | """Total training steps inferred from datamodule and devices."""
54 | if self.trainer.max_steps:
55 | return self.trainer.max_steps
56 |
57 | limit_batches = self.trainer.limit_train_batches
58 | batches = len(self.train_dataloader())
59 | batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
60 |
61 | num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
62 | if self.trainer.tpu_cores:
63 | num_devices = max(num_devices, self.trainer.tpu_cores)
64 |
65 | effective_accum = self.trainer.accumulate_grad_batches * num_devices
66 | return (batches // effective_accum) * self.trainer.max_epochs
67 |
68 | @classmethod
69 | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
70 | ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt'))
71 | if ckpt_files:
72 | # Adapt to PyTorch Lightning
73 | last_ckpt_file = ckpt_files[-1]
74 | config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
75 | model = cls.from_config(config_file)
76 | print("init from {}".format(last_ckpt_file))
77 | model.init_from_ckpt(last_ckpt_file)
78 | return model
79 | else:
80 | return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .block import Block
2 | from .attention import (
3 | AttnBlock3D,
4 | AttnBlock3DFix,
5 | AttnBlock,
6 | LinAttnBlock,
7 | LinearAttention,
8 | TemporalAttnBlock
9 | )
10 | from .conv import CausalConv3d, Conv2d
11 | from .normalize import GroupNorm, Normalize
12 | from .resnet_block import ResnetBlock2D, ResnetBlock3D
13 | from .updownsample import (
14 | SpatialDownsample2x,
15 | SpatialUpsample2x,
16 | TimeDownsample2x,
17 | TimeUpsample2x,
18 | Upsample,
19 | Downsample,
20 | TimeDownsampleRes2x,
21 | TimeUpsampleRes2x,
22 | TimeDownsampleResAdv2x,
23 | TimeUpsampleResAdv2x
24 | )
25 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class Block(nn.Module):
4 | def __init__(self, *args, **kwargs) -> None:
5 | super().__init__(*args, **kwargs)
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/conv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Union, Tuple
3 | import torch.nn.functional as F
4 | import torch
5 | from .block import Block
6 | from .ops import cast_tuple
7 | from einops import rearrange
8 | from .ops import video_to_image
9 |
10 | class Conv2d(nn.Conv2d):
11 | def __init__(
12 | self,
13 | in_channels: int,
14 | out_channels: int,
15 | kernel_size: Union[int, Tuple[int]] = 3,
16 | stride: Union[int, Tuple[int]] = 1,
17 | padding: Union[str, int, Tuple[int]] = 0,
18 | dilation: Union[int, Tuple[int]] = 1,
19 | groups: int = 1,
20 | bias: bool = True,
21 | padding_mode: str = "zeros",
22 | device=None,
23 | dtype=None,
24 | ) -> None:
25 | super().__init__(
26 | in_channels,
27 | out_channels,
28 | kernel_size,
29 | stride,
30 | padding,
31 | dilation,
32 | groups,
33 | bias,
34 | padding_mode,
35 | device,
36 | dtype,
37 | )
38 |
39 | @video_to_image
40 | def forward(self, x):
41 | return super().forward(x)
42 |
43 |
44 | class CausalConv3d(nn.Module):
45 | def __init__(
46 | self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
47 | ):
48 | super().__init__()
49 | self.kernel_size = cast_tuple(kernel_size, 3)
50 | self.time_kernel_size = self.kernel_size[0]
51 | self.chan_in = chan_in
52 | self.chan_out = chan_out
53 | stride = kwargs.pop("stride", 1)
54 | padding = kwargs.pop("padding", 0)
55 | padding = list(cast_tuple(padding, 3))
56 | padding[0] = 0
57 | stride = cast_tuple(stride, 3)
58 | self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
59 | self._init_weights(init_method)
60 |
61 | def _init_weights(self, init_method):
62 | ks = torch.tensor(self.kernel_size)
63 | if init_method == "avg":
64 | assert (
65 | self.kernel_size[1] == 1 and self.kernel_size[2] == 1
66 | ), "only support temporal up/down sample"
67 | assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
68 | weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
69 |
70 | eyes = torch.concat(
71 | [
72 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3,
73 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3,
74 | torch.eye(self.chan_in).unsqueeze(-1) * 1/3,
75 | ],
76 | dim=-1,
77 | )
78 | weight[:, :, :, 0, 0] = eyes
79 |
80 | self.conv.weight = nn.Parameter(
81 | weight,
82 | requires_grad=True,
83 | )
84 | elif init_method == "zero":
85 | self.conv.weight = nn.Parameter(
86 | torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
87 | requires_grad=True,
88 | )
89 | if self.conv.bias is not None:
90 | nn.init.constant_(self.conv.bias, 0)
91 |
92 | def forward(self, x):
93 | # 1 + 16 16 as video, 1 as image
94 | first_frame_pad = x[:, :, :1, :, :].repeat(
95 | (1, 1, self.time_kernel_size - 1, 1, 1)
96 | ) # b c t h w
97 | x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
98 | return self.conv(x)
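
The forward pass above pads time by repeating the first frame time_kernel_size - 1 times, so frame t never sees later frames and the temporal length is preserved; a quick shape check with illustrative sizes (stride 1 assumed):

```python
import torch
from opensora.models.ae.videobase.modules.conv import CausalConv3d

conv = CausalConv3d(chan_in=8, chan_out=8, kernel_size=3, padding=1)

x = torch.randn(1, 8, 17, 32, 32)   # b c t h w: 1 image frame + 16 video frames
y = conv(x)
print(x.shape, "->", y.shape)       # time stays 17; padding=1 keeps the spatial size
```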
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/normalize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .block import Block
4 |
5 | class GroupNorm(Block):
6 | def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
7 | super().__init__(*args, **kwargs)
8 | self.norm = torch.nn.GroupNorm(
9 | num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True
10 | )
11 | def forward(self, x):
12 | return self.norm(x)
13 |
14 | def Normalize(in_channels, num_groups=32):
15 | return torch.nn.GroupNorm(
16 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
17 | )
18 |
19 | class ActNorm(nn.Module):
20 | def __init__(self, num_features, logdet=False, affine=True,
21 | allow_reverse_init=False):
22 | assert affine
23 | super().__init__()
24 | self.logdet = logdet
25 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
26 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
27 | self.allow_reverse_init = allow_reverse_init
28 |
29 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
30 |
31 | def initialize(self, input):
32 | with torch.no_grad():
33 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
34 | mean = (
35 | flatten.mean(1)
36 | .unsqueeze(1)
37 | .unsqueeze(2)
38 | .unsqueeze(3)
39 | .permute(1, 0, 2, 3)
40 | )
41 | std = (
42 | flatten.std(1)
43 | .unsqueeze(1)
44 | .unsqueeze(2)
45 | .unsqueeze(3)
46 | .permute(1, 0, 2, 3)
47 | )
48 |
49 | self.loc.data.copy_(-mean)
50 | self.scale.data.copy_(1 / (std + 1e-6))
51 |
52 | def forward(self, input, reverse=False):
53 | if reverse:
54 | return self.reverse(input)
55 | if len(input.shape) == 2:
56 | input = input[:,:,None,None]
57 | squeeze = True
58 | else:
59 | squeeze = False
60 |
61 | _, _, height, width = input.shape
62 |
63 | if self.training and self.initialized.item() == 0:
64 | self.initialize(input)
65 | self.initialized.fill_(1)
66 |
67 | h = self.scale * (input + self.loc)
68 |
69 | if squeeze:
70 | h = h.squeeze(-1).squeeze(-1)
71 |
72 | if self.logdet:
73 | log_abs = torch.log(torch.abs(self.scale))
74 | logdet = height*width*torch.sum(log_abs)
75 | logdet = logdet * torch.ones(input.shape[0]).to(input)
76 | return h, logdet
77 |
78 | return h
79 |
80 | def reverse(self, output):
81 | if self.training and self.initialized.item() == 0:
82 | if not self.allow_reverse_init:
83 | raise RuntimeError(
84 | "Initializing ActNorm in reverse direction is "
85 | "disabled by default. Use allow_reverse_init=True to enable."
86 | )
87 | else:
88 | self.initialize(output)
89 | self.initialized.fill_(1)
90 |
91 | if len(output.shape) == 2:
92 | output = output[:,:,None,None]
93 | squeeze = True
94 | else:
95 | squeeze = False
96 |
97 | h = output / self.scale - self.loc
98 |
99 | if squeeze:
100 | h = h.squeeze(-1).squeeze(-1)
101 | return h
102 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 | def video_to_image(func):
5 | def wrapper(self, x, *args, **kwargs):
6 | if x.dim() == 5:
7 | t = x.shape[2]
8 | x = rearrange(x, "b c t h w -> (b t) c h w")
9 | x = func(self, x, *args, **kwargs)
10 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
11 | return x
12 | return wrapper
13 |
14 | def nonlinearity(x):
15 | return x * torch.sigmoid(x)
16 |
17 | def cast_tuple(t, length=1):
18 | return t if isinstance(t, tuple) else ((t,) * length)
19 |
20 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
21 | n_dims = len(x.shape)
22 | if src_dim < 0:
23 | src_dim = n_dims + src_dim
24 | if dest_dim < 0:
25 | dest_dim = n_dims + dest_dim
26 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
27 | dims = list(range(n_dims))
28 | del dims[src_dim]
29 | permutation = []
30 | ctr = 0
31 | for i in range(n_dims):
32 | if i == dest_dim:
33 | permutation.append(src_dim)
34 | else:
35 | permutation.append(dims[ctr])
36 | ctr += 1
37 | x = x.permute(permutation)
38 | if make_contiguous:
39 | x = x.contiguous()
40 | return x
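
Two quick examples of the helpers above: shift_dim moves a single axis to a new position, and cast_tuple broadcasts a scalar argument to a tuple (the tensor shape is illustrative):

```python
import torch
from opensora.models.ae.videobase.modules.ops import shift_dim, cast_tuple, nonlinearity

x = torch.randn(2, 4, 8, 16, 16)        # b c t h w
print(shift_dim(x, 1, -1).shape)        # channels moved last -> [2, 8, 16, 16, 4]

print(cast_tuple(3, 3))                 # (3, 3, 3)
print(cast_tuple((2, 4, 4), 3))         # tuples pass through unchanged -> (2, 4, 4)

print(nonlinearity(torch.tensor(0.0)))  # swish: x * sigmoid(x) -> tensor(0.)
```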
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/quant.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from .ops import shift_dim
7 |
8 | class Codebook(nn.Module):
9 | def __init__(self, n_codes, embedding_dim):
10 | super().__init__()
11 | self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
12 | self.register_buffer("N", torch.zeros(n_codes))
13 | self.register_buffer("z_avg", self.embeddings.data.clone())
14 |
15 | self.n_codes = n_codes
16 | self.embedding_dim = embedding_dim
17 | self._need_init = True
18 |
19 | def _tile(self, x):
20 | d, ew = x.shape
21 | if d < self.n_codes:
22 | n_repeats = (self.n_codes + d - 1) // d
23 | std = 0.01 / np.sqrt(ew)
24 | x = x.repeat(n_repeats, 1)
25 | x = x + torch.randn_like(x) * std
26 | return x
27 |
28 | def _init_embeddings(self, z):
29 | # z: [b, c, t, h, w]
30 | self._need_init = False
31 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
32 | y = self._tile(flat_inputs)
33 |
34 | d = y.shape[0]
35 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
36 | if dist.is_initialized():
37 | dist.broadcast(_k_rand, 0)
38 | self.embeddings.data.copy_(_k_rand)
39 | self.z_avg.data.copy_(_k_rand)
40 | self.N.data.copy_(torch.ones(self.n_codes))
41 |
42 | def forward(self, z):
43 | # z: [b, c, t, h, w]
44 | if self._need_init and self.training:
45 | self._init_embeddings(z)
46 | flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
47 | distances = (
48 | (flat_inputs**2).sum(dim=1, keepdim=True)
49 | - 2 * flat_inputs @ self.embeddings.t()
50 | + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
51 | )
52 |
53 | encoding_indices = torch.argmin(distances, dim=1)
54 | encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
55 | encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
56 |
57 | embeddings = F.embedding(encoding_indices, self.embeddings)
58 | embeddings = shift_dim(embeddings, -1, 1)
59 |
60 | commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
61 |
62 | # EMA codebook update
63 | if self.training:
64 | n_total = encode_onehot.sum(dim=0)
65 | encode_sum = flat_inputs.t() @ encode_onehot
66 | if dist.is_initialized():
67 | dist.all_reduce(n_total)
68 | dist.all_reduce(encode_sum)
69 |
70 | self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
71 | self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
72 |
73 | n = self.N.sum()
74 | weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
75 | encode_normalized = self.z_avg / weights.unsqueeze(1)
76 | self.embeddings.data.copy_(encode_normalized)
77 |
78 | y = self._tile(flat_inputs)
79 | _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
80 | if dist.is_initialized():
81 | dist.broadcast(_k_rand, 0)
82 |
83 | usage = (self.N.view(self.n_codes, 1) >= 1).float()
84 | self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
85 |
86 | embeddings_st = (embeddings - z).detach() + z
87 |
88 | avg_probs = torch.mean(encode_onehot, dim=0)
89 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
90 |
91 | return dict(
92 | embeddings=embeddings_st,
93 | encodings=encoding_indices,
94 | commitment_loss=commitment_loss,
95 | perplexity=perplexity,
96 | )
97 |
98 | def dictionary_lookup(self, encodings):
99 | embeddings = F.embedding(encodings, self.embeddings)
100 | return embeddings
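
A minimal sketch of the quantization step above: run a random latent volume through the codebook in eval mode (so the data-dependent init and the EMA update are skipped) and inspect what it returns; shapes are illustrative:

```python
import torch
from opensora.models.ae.videobase.modules.quant import Codebook

codebook = Codebook(n_codes=2048, embedding_dim=256).eval()

z = torch.randn(2, 256, 4, 8, 8)      # b c t h w latents from an encoder
out = codebook(z)
print(out["embeddings"].shape)        # [2, 256, 4, 8, 8], straight-through quantized latents
print(out["encodings"].shape)         # [2, 4, 8, 8], integer code indices
print(out["commitment_loss"].item(), out["perplexity"].item())
```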
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/modules/resnet_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange, pack, unpack
4 | from .normalize import Normalize
5 | from .ops import nonlinearity, video_to_image
6 | from .conv import CausalConv3d
7 | from .block import Block
8 |
9 | class ResnetBlock2D(Block):
10 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
11 | dropout):
12 | super().__init__()
13 | self.in_channels = in_channels
14 | self.out_channels = in_channels if out_channels is None else out_channels
15 | self.use_conv_shortcut = conv_shortcut
16 |
17 | self.norm1 = Normalize(in_channels)
18 | self.conv1 = torch.nn.Conv2d(
19 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
20 | )
21 | self.norm2 = Normalize(out_channels)
22 | self.dropout = torch.nn.Dropout(dropout)
23 | self.conv2 = torch.nn.Conv2d(
24 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
25 | )
26 | if self.in_channels != self.out_channels:
27 | if self.use_conv_shortcut:
28 | self.conv_shortcut = torch.nn.Conv2d(
29 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
30 | )
31 | else:
32 | self.nin_shortcut = torch.nn.Conv2d(
33 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
34 | )
35 |
36 | @video_to_image
37 | def forward(self, x):
38 | h = x
39 | h = self.norm1(h)
40 | h = nonlinearity(h)
41 | h = self.conv1(h)
42 | h = self.norm2(h)
43 | h = nonlinearity(h)
44 | h = self.dropout(h)
45 | h = self.conv2(h)
46 | if self.in_channels != self.out_channels:
47 | if self.use_conv_shortcut:
48 | x = self.conv_shortcut(x)
49 | else:
50 | x = self.nin_shortcut(x)
51 | x = x + h
52 | return x
53 |
54 | class ResnetBlock3D(Block):
55 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
56 | super().__init__()
57 | self.in_channels = in_channels
58 | self.out_channels = in_channels if out_channels is None else out_channels
59 | self.use_conv_shortcut = conv_shortcut
60 |
61 | self.norm1 = Normalize(in_channels)
62 | self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
63 | self.norm2 = Normalize(out_channels)
64 | self.dropout = torch.nn.Dropout(dropout)
65 | self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
66 | if self.in_channels != self.out_channels:
67 | if self.use_conv_shortcut:
68 | self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
69 | else:
70 | self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
71 |
72 | def forward(self, x):
73 | h = x
74 | h = self.norm1(h)
75 | h = nonlinearity(h)
76 | h = self.conv1(h)
77 | h = self.norm2(h)
78 | h = nonlinearity(h)
79 | h = self.dropout(h)
80 | h = self.conv2(h)
81 | if self.in_channels != self.out_channels:
82 | if self.use_conv_shortcut:
83 | x = self.conv_shortcut(x)
84 | else:
85 | x = self.nin_shortcut(x)
86 | return x + h
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/trainer_videobase.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer
2 | import torch.nn.functional as F
3 | from typing import Optional
4 | import os
5 | import torch
6 | from transformers.utils import WEIGHTS_NAME
7 | import json
8 |
9 | class VideoBaseTrainer(Trainer):
10 |
11 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
12 | output_dir = output_dir if output_dir is not None else self.args.output_dir
13 | os.makedirs(output_dir, exist_ok=True)
14 | if state_dict is None:
15 | state_dict = self.model.state_dict()
16 |
17 | # get model config
18 | model_config = self.model.config.to_dict()
19 |
20 | # add more information
21 | model_config['model'] = self.model.__class__.__name__
22 |
23 | with open(os.path.join(output_dir, "config.json"), "w") as file:
24 | json.dump(self.model.config.to_dict(), file)
25 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
26 | torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
27 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/utils/distrib_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | class DiagonalGaussianDistribution(object):
5 | def __init__(self, parameters, deterministic=False):
6 | self.parameters = parameters
7 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
8 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
9 | self.deterministic = deterministic
10 | self.std = torch.exp(0.5 * self.logvar)
11 | self.var = torch.exp(self.logvar)
12 | if self.deterministic:
13 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
14 |
15 | def sample(self):
16 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
17 | return x
18 |
19 | def kl(self, other=None):
20 | if self.deterministic:
21 | return torch.Tensor([0.])
22 | else:
23 | if other is None:
24 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
25 | + self.var - 1.0 - self.logvar,
26 | dim=[1, 2, 3])
27 | else:
28 | return 0.5 * torch.sum(
29 | torch.pow(self.mean - other.mean, 2) / other.var
30 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
31 | dim=[1, 2, 3])
32 |
33 | def nll(self, sample, dims=[1,2,3]):
34 | if self.deterministic:
35 | return torch.Tensor([0.])
36 | logtwopi = np.log(2.0 * np.pi)
37 | return 0.5 * torch.sum(
38 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
39 | dim=dims)
40 |
41 | def mode(self):
42 | return self.mean
43 |
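
A short sketch of how the posterior above is typically used: the encoder emits 2*C channels (mean and log-variance stacked along dim 1), the distribution samples with the reparameterization trick, and kl() gives a per-sample KL term against the standard normal; shapes are illustrative:

```python
import torch
from opensora.models.ae.videobase.utils.distrib_utils import DiagonalGaussianDistribution

latent_channels = 4
moments = torch.randn(2, 2 * latent_channels, 32, 32)   # [mean | logvar] along dim 1

posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()     # [2, 4, 32, 32], mean + std * eps
kl = posterior.kl()        # KL to N(0, I), one value per sample -> shape [2]
print(z.shape, kl.shape)
```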
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/utils/module_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | Module = str
4 | MODULES_BASE = "opensora.models.ae.videobase.modules."
5 |
6 | def resolve_str_to_obj(str_val, append=True):
7 | if append:
8 | str_val = MODULES_BASE + str_val
9 | module_name, class_name = str_val.rsplit('.', 1)
10 | module = importlib.import_module(module_name)
11 | return getattr(module, class_name)
12 |
13 | def create_instance(module_class_str: str, **kwargs):
14 | module_name, class_name = module_class_str.rsplit('.', 1)
15 | module = importlib.import_module(module_name)
16 | class_ = getattr(module, class_name)
17 | return class_(**kwargs)
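
These helpers let configs reference building blocks by string; a quick example, resolving one class relative to MODULES_BASE and one fully qualified name:

```python
from opensora.models.ae.videobase.utils.module_utils import resolve_str_to_obj

Conv = resolve_str_to_obj("conv.CausalConv3d")                  # resolved under MODULES_BASE
Norm = resolve_str_to_obj("torch.nn.GroupNorm", append=False)   # fully qualified path
print(Conv.__name__, Norm.__name__)
```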
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/utils/scheduler_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def cosine_scheduler(step, max_steps, value_base=1, value_end=0):
4 | step = torch.tensor(step)
5 | cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps))
6 | value = value_end + (value_base - value_end) * cosine_value
7 | return value
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def tensor_to_video(x):
5 | x = x.detach().cpu()
6 | x = torch.clamp(x, -1, 1)
7 | x = (x + 1) / 2
8 | x = x.permute(1, 0, 2, 3).float().numpy() # c t h w ->
9 | x = (255 * x).astype(np.uint8)
10 | return x
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | from torch import nn
3 |
4 | from .configuration_vqvae import VQVAEConfiguration
5 | from .modeling_vqvae import VQVAEModel
6 | from .trainer_vqvae import VQVAETrainer
7 |
8 | videovqvae = [
9 | "bair_stride4x2x2",
10 | "ucf101_stride4x4x4",
11 | "kinetics_stride4x4x4",
12 | "kinetics_stride2x4x4",
13 | ]
14 | videovae = []
15 |
16 | class VQVAEModelWrapper(nn.Module):
17 | def __init__(self, ckpt='kinetics_stride4x4x4'):
18 | super(VQVAEModelWrapper, self).__init__()
19 | if ckpt in videovqvae:
20 | self.vqvae = VQVAEModel.download_and_load_model(ckpt)
21 | else:
22 | self.vqvae = VQVAEModel.load_from_checkpoint(ckpt)
23 | def encode(self, x): # b c t h w
24 | x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x))
25 | return x
26 | def decode(self, x):
27 | vq_output = self.vqvae.codebook(x)
28 | x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings']))
29 | x = rearrange(x, 'b c t h w -> b t c h w').contiguous()
30 | return x
31 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/vqvae/configuration_vqvae.py:
--------------------------------------------------------------------------------
1 | from ..configuration_videobase import VideoBaseConfiguration
2 | from typing import Union, Tuple
3 |
4 | class VQVAEConfiguration(VideoBaseConfiguration):
5 | def __init__(
6 | self,
7 | embedding_dim: int = 256,
8 | n_codes: int = 2048,
9 | n_hiddens: int = 240,
10 | n_res_layers: int = 4,
11 | resolution: int = 128,
12 | sequence_length: int = 16,
13 | downsample: Union[Tuple[int, int, int], str] = (4, 4, 4),
14 | no_pos_embd: bool = True,
15 | **kwargs,
16 | ):
17 | super().__init__(**kwargs)
18 |
19 | self.embedding_dim = embedding_dim
20 | self.n_codes = n_codes
21 | self.n_hiddens = n_hiddens
22 | self.n_res_layers = n_res_layers
23 | self.resolution = resolution
24 | self.sequence_length = sequence_length
25 |
26 | if isinstance(downsample, str):
27 | self.downsample = tuple(map(int, downsample.split(",")))
28 | else:
29 | self.downsample = downsample
30 |
31 | self.no_pos_embd = no_pos_embd
32 |
33 | self.hidden_size = n_hiddens
34 |
--------------------------------------------------------------------------------
/opensora/models/ae/videobase/vqvae/trainer_vqvae.py:
--------------------------------------------------------------------------------
1 | from ..trainer_videobase import VideoBaseTrainer
2 | import torch.nn.functional as F
3 | from typing import Optional
4 | import os
5 | import torch
6 | from transformers.utils import WEIGHTS_NAME
7 | import json
8 |
9 | class VQVAETrainer(VideoBaseTrainer):
10 |
11 | def compute_loss(self, model, inputs, return_outputs=False):
12 | model = model.module
13 | x = inputs.get("video")
14 | x = x / 2
15 | z = model.pre_vq_conv(model.encoder(x))
16 | vq_output = model.codebook(z)
17 | x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))
18 | recon_loss = F.mse_loss(x_recon, x) / 0.06
19 | commitment_loss = vq_output['commitment_loss']
20 | loss = recon_loss + commitment_loss
21 | return loss
22 |
23 |
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/README.md:
--------------------------------------------------------------------------------
1 | # Refiner for Video Caption
2 |
3 | Transforms the short caption annotations from video datasets into long, detailed caption annotations.
4 |
5 | * Add detailed description for background scene.
6 | * Add detailed description for object attributes, including color, material, pose.
7 | * Add detailed description for object-level spatial relationship.
8 |
9 | ## 🛠️ Extra Requirements and Installation
10 |
11 | * openai == 0.28.0
12 | * jsonlines == 4.0.0
13 | * nltk == 3.8.1
14 | * Install the LLaMA-Accessory:
15 |
16 | You also need to download the SPHINX weights to the ./ckpt/ folder.
17 |
18 | ## 🗝️ Refining
19 |
20 | The refining instruction is in [demo_for_refiner.py](demo_for_refiner.py).
21 |
22 | ```bash
23 | python demo_for_refiner.py --root_path $path_to_repo$ --api_key $openai_api_key$
24 | ```
25 |
26 | ### Refining Demos
27 |
28 | ```bash
29 | [original caption]: A red mustang parked in a showroom with american flags hanging from the ceiling.
30 | ```
31 |
32 | ```bash
33 | [refine caption]: This scene depicts a red Mustang parked in a showroom with American flags hanging from the ceiling. The showroom likely serves as a space for showcasing and purchasing cars, and the Mustang is displayed prominently near the flags and ceiling. The scene also features a large window and other objects. Overall, it seems to take place in a car show or dealership.
34 | ```
35 |
36 | - [ ] Add GPT-3.5-Turbo for caption summarization. ⌛ [WIP]
37 | - [ ] Add LLAVA-1.6. ⌛ [WIP]
38 | - [ ] More descriptions. ⌛ [WIP]
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json:
--------------------------------------------------------------------------------
1 | {"video1.gif": "A red mustang parked in a showroom with american flags hanging from the ceiling.", "video2.gif": "An aerial view of a city with a river running through it."}
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/dataset/test_videos/video1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/captioner/caption_refiner/dataset/test_videos/video1.gif
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/dataset/test_videos/video2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-Open-Sora-Plan/b060ff6d7a85a27eec5ff9b81b599d03c4ac1bc6/opensora/models/captioner/caption_refiner/dataset/test_videos/video2.gif
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/demo_for_refiner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from caption_refiner import CaptionRefiner
3 | from gpt_combinator import caption_summary, caption_qa
4 |
5 | def parse_args():
6 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
7 | parser.add_argument("--root_path", required=True, help="The path to repo.")
8 | parser.add_argument("--api_key", required=True, help="OpenAI API key.")
9 | args = parser.parse_args()
10 | return args
11 |
12 | if __name__ == "__main__":
13 | args = parse_args()
14 | myrefiner = CaptionRefiner(
15 | sample_num=6, add_detect=True, add_pos=True, add_attr=True,
16 | openai_api_key = args.api_key,
17 | openai_api_base = "https://one-api.bltcy.top/v1",
18 | )
19 |
20 | results = myrefiner.caption_refine(
21 | video_path="./dataset/test_videos/video1.gif",
22 | org_caption="A red mustang parked in a showroom with american flags hanging from the ceiling.",
23 | model_path = args.root_path + "/ckpts/SPHINX-Tiny",
24 | )
25 |
26 | final_caption = myrefiner.gpt_summary(results)
27 |
28 | print(final_caption)
29 |
--------------------------------------------------------------------------------
/opensora/models/captioner/caption_refiner/gpt_combinator.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import ast
3 |
4 | def caption_qa(caption_list, api_key, api_base):
5 | openai.api_key = api_key
6 | openai.api_base = api_base
7 |
8 |     question = "What is the color of a red apple"  # placeholder QA triplet; caption_list is not used yet
9 | answer = "red"
10 | pred = "green"
11 | try:
12 | # Compute the correctness score
13 | completion = openai.ChatCompletion.create(
14 | model="gpt-3.5-turbo",
15 | # model="gpt-4",
16 | # model="gpt-4-vision-compatible",
17 | messages=[
18 | {
19 | "role": "system",
20 | "content":
21 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
22 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
23 | "------"
24 | "##INSTRUCTIONS: "
25 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
26 | "- Consider synonyms or paraphrases as valid matches.\n"
27 | "- Evaluate the correctness of the prediction compared to the answer."
28 | },
29 | {
30 | "role": "user",
31 | "content":
32 | "Please evaluate the following video-based question-answer pair:\n\n"
33 | f"Question: {question}\n"
34 | f"Correct Answer: {answer}\n"
35 | f"Predicted Answer: {pred}\n\n"
36 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
37 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
38 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
39 |                         "For example, your response should look like this: {'pred': 'yes', 'score': 4}."
40 | }
41 | ]
42 | )
43 | # Convert response to a Python dictionary.
44 | response_message = completion["choices"][0]["message"]["content"]
45 | response_dict = ast.literal_eval(response_message)
46 | print(response_dict)
47 |
48 | except Exception as e:
49 |         print(f"Error during GPT call: {e}")
50 |
51 |
52 | def caption_summary(long_caption, api_key, api_base):
53 | """
54 |     Use GPT-3.5-Turbo to combine the original caption and the prompted captions of a video into a single summary.
55 | """
56 | openai.api_key = api_key
57 | openai.api_base = api_base
58 |
59 | try:
60 |         # Summarize the combined caption with the chat model
61 | completion = openai.ChatCompletion.create(
62 | model="gpt-3.5-turbo",
63 | messages=[
64 | {
65 | "role": "system",
66 | "content":
67 | "You are an intelligent chatbot designed for summarizing from a long sentence. "
68 | },
69 | {
70 | "role": "user",
71 | "content":
72 | "Please summarize the following sentences. Make it shorter than 70 words."
73 | f"the long sentence: {long_caption}\n"
74 | "Provide your summarization with less than 70 words. "
75 | "DO NOT PROVIDE ANY OTHER TEXT OR EXPLANATION. Only provide the summary sentence. "
76 | }
77 | ]
78 | )
79 | # "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
80 | # "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
81 | # "For example, your response should look like this: {'summary': 'your summary sentence'}."
82 |
83 | # Convert response to a Python dictionary.
84 | response_message = completion["choices"][0]["message"]["content"]
85 |         response_dict = ast.literal_eval(response_message)  # NOTE: raises if the reply is a plain sentence rather than a Python literal
86 |
87 |     except Exception as e:
88 |         print(f"Error during GPT call: {e}")
89 |         response_dict = None  # keep the return below from raising NameError when the call fails
90 | 
91 |     return response_dict
92 | 
93 | if __name__ == "__main__":
94 |     # Example call (requires a valid key and base URL):
95 |     # caption_summary(long_caption, api_key="...", api_base="...")
96 |     pass
--------------------------------------------------------------------------------
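
Note: both helpers above parse the chat reply with `ast.literal_eval`, which raises if the model returns a plain sentence instead of a Python-literal dict. A minimal, hypothetical fallback parser (not part of this repository) might look like:

```python
import ast

def safe_parse(response_message: str):
    """Return a Python literal when the reply is dict-like
    (e.g. "{'pred': 'yes', 'score': 4}"), otherwise the raw text."""
    try:
        return ast.literal_eval(response_message)
    except (ValueError, SyntaxError):
        return response_message
```
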
/opensora/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .latte.modeling_latte import Latte_models
3 |
4 | Diffusion_models = {}
5 | Diffusion_models.update(Latte_models)
6 |
7 |
--------------------------------------------------------------------------------
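
Note: `Diffusion_models` is a plain name-to-constructor registry populated from `latte/modeling_latte.py`. A minimal sketch of querying it (the available key names are defined in that module and are not repeated here):

```python
from opensora.models.diffusion import Diffusion_models

print(sorted(Diffusion_models))  # list the registered Latte variants
# Hypothetical lookup; constructor arguments depend on the chosen variant:
# model_cls = Diffusion_models[name]
```
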
/opensora/models/diffusion/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T
7 |
8 |
9 | def create_diffusion(
10 | timestep_respacing,
11 | noise_schedule="linear",
12 | use_kl=False,
13 | sigma_small=False,
14 | predict_xstart=False,
15 | learn_sigma=True,
16 | # learn_sigma=False,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | from . import gaussian_diffusion as gd
21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22 | if use_kl:
23 | loss_type = gd.LossType.RESCALED_KL
24 | elif rescale_learned_sigmas:
25 | loss_type = gd.LossType.RESCALED_MSE
26 | else:
27 | loss_type = gd.LossType.MSE
28 | if timestep_respacing is None or timestep_respacing == "":
29 | timestep_respacing = [diffusion_steps]
30 | return SpacedDiffusion(
31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32 | betas=betas,
33 | model_mean_type=(
34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35 | ),
36 | model_var_type=(
37 | (
38 | gd.ModelVarType.FIXED_LARGE
39 | if not sigma_small
40 | else gd.ModelVarType.FIXED_SMALL
41 | )
42 | if not learn_sigma
43 | else gd.ModelVarType.LEARNED_RANGE
44 | ),
45 | loss_type=loss_type
46 | # rescale_timesteps=rescale_timesteps,
47 | )
48 |
49 | def create_diffusion_T(
50 | timestep_respacing,
51 | noise_schedule="linear",
52 | use_kl=False,
53 | sigma_small=False,
54 | predict_xstart=False,
55 | learn_sigma=True,
56 | # learn_sigma=False,
57 | rescale_learned_sigmas=False,
58 | diffusion_steps=1000
59 | ):
60 | from . import gaussian_diffusion_t2v as gd
61 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
62 | if use_kl:
63 | loss_type = gd.LossType.RESCALED_KL
64 | elif rescale_learned_sigmas:
65 | loss_type = gd.LossType.RESCALED_MSE
66 | else:
67 | loss_type = gd.LossType.MSE
68 | if timestep_respacing is None or timestep_respacing == "":
69 | timestep_respacing = [diffusion_steps]
70 | return SpacedDiffusion_T(
71 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
72 | betas=betas,
73 | model_mean_type=(
74 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
75 | ),
76 | model_var_type=(
77 | (
78 | gd.ModelVarType.FIXED_LARGE
79 | if not sigma_small
80 | else gd.ModelVarType.FIXED_SMALL
81 | )
82 | if not learn_sigma
83 | else gd.ModelVarType.LEARNED_RANGE
84 | ),
85 | loss_type=loss_type
86 | # rescale_timesteps=rescale_timesteps,
87 | )
88 |
--------------------------------------------------------------------------------
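
Note: a minimal usage sketch for `create_diffusion`, assuming only the signature shown above and that the package is importable from the repo root; the string passed as `timestep_respacing` follows the usual IDDPM/DiT convention handled by `space_timesteps`:

```python
from opensora.models.diffusion.diffusion import create_diffusion

# Full 1000-step schedule for training (empty string means no respacing).
train_diffusion = create_diffusion(timestep_respacing="")

# Respaced 250-step schedule, as typically used for sampling.
sample_diffusion = create_diffusion(timestep_respacing="250")
```
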
/opensora/models/diffusion/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
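
Note: for reference, `normal_kl` above is the closed-form KL divergence between two diagonal Gaussians; with `logvar_i` = log σ_i², the returned value is:

```latex
D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\bigr)
  = \frac{1}{2}\left(-1 + \log\sigma_2^2 - \log\sigma_1^2
      + \frac{\sigma_1^2}{\sigma_2^2}
      + \frac{(\mu_1-\mu_2)^2}{\sigma_2^2}\right)
```

A quick sanity check: identical means and log-variances give exactly zero.
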
/opensora/models/diffusion/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler
2 |
3 | def create_transport(
4 | path_type='Linear',
5 | prediction="velocity",
6 | loss_weight=None,
7 | train_eps=None,
8 | sample_eps=None,
9 | ):
10 | """function for creating Transport object
11 | **Note**: model prediction defaults to velocity
12 | Args:
13 |     - path_type: type of path to use; defaults to linear
14 |     - prediction: model prediction type; one of
15 |       "velocity" (default), "score", or "noise"
16 |     - loss_weight: optional loss weighting applied to the training objective;
17 |       "velocity", "likelihood", or None for no weighting
18 |     - train_eps: small epsilon for avoiding instability during training
19 |     - sample_eps: small epsilon for avoiding instability during sampling
20 | """
21 |
22 | if prediction == "noise":
23 | model_type = ModelType.NOISE
24 | elif prediction == "score":
25 | model_type = ModelType.SCORE
26 | else:
27 | model_type = ModelType.VELOCITY
28 |
29 | if loss_weight == "velocity":
30 | loss_type = WeightType.VELOCITY
31 | elif loss_weight == "likelihood":
32 | loss_type = WeightType.LIKELIHOOD
33 | else:
34 | loss_type = WeightType.NONE
35 |
36 | path_choice = {
37 | "Linear": PathType.LINEAR,
38 | "GVP": PathType.GVP,
39 | "VP": PathType.VP,
40 | }
41 |
42 | path_type = path_choice[path_type]
43 |
44 | if (path_type in [PathType.VP]):
45 | train_eps = 1e-5 if train_eps is None else train_eps
46 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
48 | train_eps = 1e-3 if train_eps is None else train_eps
49 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
50 | else: # velocity & [GVP, LINEAR] is stable everywhere
51 | train_eps = 0
52 | sample_eps = 0
53 |
54 | # create flow state
55 | state = Transport(
56 | model_type=model_type,
57 | path_type=path_type,
58 | loss_type=loss_type,
59 | train_eps=train_eps,
60 | sample_eps=sample_eps,
61 | )
62 |
63 | return state
--------------------------------------------------------------------------------
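
Note: a minimal sketch of building a `Transport` object with the defaults documented above (the import path is assumed from this repository's layout; sampling helpers such as `Sampler` live in `transport.py`):

```python
from opensora.models.diffusion.transport import create_transport

# Velocity prediction on the linear path; both eps values resolve to 0 here.
transport = create_transport(path_type="Linear", prediction="velocity")
```
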
/opensora/models/diffusion/transport/integrators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 | from torchdiffeq import odeint
5 | from functools import partial
6 | from tqdm import tqdm
7 |
8 | class sde:
9 | """SDE solver class"""
10 | def __init__(
11 | self,
12 | drift,
13 | diffusion,
14 | *,
15 | t0,
16 | t1,
17 | num_steps,
18 | sampler_type,
19 | ):
20 | assert t0 < t1, "SDE sampler has to be in forward time"
21 |
22 | self.num_timesteps = num_steps
23 | self.t = th.linspace(t0, t1, num_steps)
24 | self.dt = self.t[1] - self.t[0]
25 | self.drift = drift
26 | self.diffusion = diffusion
27 | self.sampler_type = sampler_type
28 |
29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
30 | w_cur = th.randn(x.size()).to(x)
31 | t = th.ones(x.size(0)).to(x) * t
32 | dw = w_cur * th.sqrt(self.dt)
33 | drift = self.drift(x, t, model, **model_kwargs)
34 | diffusion = self.diffusion(x, t)
35 | mean_x = x + drift * self.dt
36 | x = mean_x + th.sqrt(2 * diffusion) * dw
37 | return x, mean_x
38 |
39 | def __Heun_step(self, x, _, t, model, **model_kwargs):
40 | w_cur = th.randn(x.size()).to(x)
41 | dw = w_cur * th.sqrt(self.dt)
42 | t_cur = th.ones(x.size(0)).to(x) * t
43 | diffusion = self.diffusion(x, t_cur)
44 | xhat = x + th.sqrt(2 * diffusion) * dw
45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs)
46 | xp = xhat + self.dt * K1
47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
49 |
50 | def __forward_fn(self):
51 | """TODO: generalize here by adding all private functions ending with steps to it"""
52 | sampler_dict = {
53 | "Euler": self.__Euler_Maruyama_step,
54 | "Heun": self.__Heun_step,
55 | }
56 |
57 | try:
58 | sampler = sampler_dict[self.sampler_type]
59 |         except KeyError:
60 |             raise NotImplementedError("Sampler type not implemented.")
61 |
62 | return sampler
63 |
64 | def sample(self, init, model, **model_kwargs):
65 | """forward loop of sde"""
66 | x = init
67 | mean_x = init
68 | samples = []
69 | sampler = self.__forward_fn()
70 | for ti in self.t[:-1]:
71 | with th.no_grad():
72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
73 | samples.append(x)
74 |
75 | return samples
76 |
77 | class ode:
78 | """ODE solver class"""
79 | def __init__(
80 | self,
81 | drift,
82 | *,
83 | t0,
84 | t1,
85 | sampler_type,
86 | num_steps,
87 | atol,
88 | rtol,
89 | ):
90 | assert t0 < t1, "ODE sampler has to be in forward time"
91 |
92 | self.drift = drift
93 | self.t = th.linspace(t0, t1, num_steps)
94 | self.atol = atol
95 | self.rtol = rtol
96 | self.sampler_type = sampler_type
97 |
98 | def sample(self, x, model, **model_kwargs):
99 |
100 | device = x[0].device if isinstance(x, tuple) else x.device
101 | def _fn(t, x):
102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
103 | model_output = self.drift(x, t, model, **model_kwargs)
104 | return model_output
105 |
106 | t = self.t.to(device)
107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
109 | samples = odeint(
110 | _fn,
111 | x,
112 | t,
113 | method=self.sampler_type,
114 | atol=atol,
115 | rtol=rtol
116 | )
117 | return samples
--------------------------------------------------------------------------------
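
Note: a toy, self-contained sketch of driving the Euler-Maruyama branch of the `sde` solver above with hand-written drift/diffusion callables; in the real pipeline these callables come from `Transport`, so the dynamics and shapes below are purely illustrative assumptions:

```python
import torch as th
from opensora.models.diffusion.transport.integrators import sde

# Toy dynamics pulling samples toward zero; `model` is ignored by these callables.
drift = lambda x, t, model, **kwargs: -x
diffusion = lambda x, t: 0.1 * th.ones_like(x)

solver = sde(drift, diffusion, t0=0.0, t1=1.0, num_steps=50, sampler_type="Euler")
samples = solver.sample(th.randn(4, 3), model=None)  # list of 49 intermediate states
print(samples[-1].shape)  # torch.Size([4, 3])
```
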
/opensora/models/diffusion/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | class EasyDict:
4 |
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 | def mean_flat(x):
13 | """
14 | Take the mean over all non-batch dimensions.
15 | """
16 | return th.mean(x, dim=list(range(1, len(x.size()))))
17 |
18 | def log_state(state):
19 | result = []
20 |
21 | sorted_state = dict(sorted(state.items()))
22 | for key, value in sorted_state.items():
23 | # Check if the value is an instance of a class
24 | if "