├── .gitignore ├── README.md ├── assets ├── 1.gif ├── 10.gif ├── 11.gif ├── 12.gif ├── 2.gif ├── 3.gif ├── 4.gif ├── 5.gif ├── 6.gif ├── 7.gif ├── 8.gif └── 9.gif ├── requirements.txt ├── run.sh ├── sat ├── arguments.py ├── asset │ └── prompt.txt ├── base_model.py ├── configs │ ├── cogvideox_2b.yaml │ └── inference.yaml ├── demo.py ├── diffusion_video.py ├── dit_video_concat.py ├── run.sh ├── sample_video.py ├── sgm │ ├── __init__.py │ ├── lr_scheduler.py │ ├── models │ │ ├── __init__.py │ │ └── autoencoder.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoding │ │ │ ├── __init__.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_loss.py │ │ │ │ ├── lpips.py │ │ │ │ └── video_loss.py │ │ │ ├── lpips │ │ │ │ ├── __init__.py │ │ │ │ ├── loss │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lpips.py │ │ │ │ ├── model │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── model.py │ │ │ │ ├── util.py │ │ │ │ └── vqperceptual.py │ │ │ ├── magvit2_pytorch.py │ │ │ ├── regularizers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── finite_scalar_quantization.py │ │ │ │ ├── lookup_free_quantization.py │ │ │ │ └── quantize.py │ │ │ ├── temporal_ae.py │ │ │ └── vqvae │ │ │ │ ├── movq_dec_3d.py │ │ │ │ ├── movq_dec_3d_dev.py │ │ │ │ ├── movq_enc_3d.py │ │ │ │ ├── movq_modules.py │ │ │ │ ├── quantize.py │ │ │ │ └── vqvae_blocks.py │ │ ├── cp_enc_dec.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── denoiser.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── discretizer.py │ │ │ ├── guiders.py │ │ │ ├── lora.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── sampling.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── util.py │ │ │ └── wrappers.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ └── video_attention.py │ ├── util.py │ └── webds.py ├── transformer_gating.py └── vae_modules │ ├── attention.py │ ├── autoencoder.py │ ├── cp_enc_dec.py │ ├── ema.py │ ├── regularizers.py │ └── utils.py └── tools ├── caption ├── README.md ├── README_ja.md ├── README_zh.md └── assests │ └── cogvlm2-video-example.png ├── convert_weight_sat2hf.py └── venhancer ├── README.md ├── README_ja.md └── README_zh.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .hypothesis 3 | demo_videos 4 | results 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | .idea/ 166 | .vscode/ 167 | 168 | # macos 169 | *.DS_Store 170 | *.json 171 | 172 | /sat/ckpt 173 | /sat/outputs 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RepVideo: Rethinking Cross-Layer Representation for Video Generation 2 | 3 | 6 | 7 |
8 | 9 | 10 | Chenyang Si1†, 11 | 12 | Weichen Fan1†, 13 | 14 | Zhengyao Lv2, 15 | 16 | Ziqi Huang1, 17 | 18 | Yu Qiao2, 19 | 20 | Ziwei Liu1✉ 21 | 22 |
23 |
24 | S-Lab, Nanyang Technological University1      Shanghai Artificial Intelligence Laboratory 2 25 |
† Equal contribution.    ✉ Corresponding Author.
26 |
27 | 28 |

29 | 30 |
31 | Paper | 32 | Project Page 33 |
34 | 37 | 38 | --- 39 | 40 | ![](https://img.shields.io/badge/RepVideo-v0.1-darkcyan) 41 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVchitect%2FRepVideo&count_bg=%23BDC4B7&title_bg=%2342C4A8&icon=octopusdeploy.svg&icon_color=%23E7E7E7&title=visitors&edge_flat=true)](https://hits.seeyoufarm.com) 42 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Farxiv.org%2Fpdf%2F2501.08994&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Paper&edge_flat=false)](https://hits.seeyoufarm.com) 43 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVchitect%2FRepVid-Webpage&count_bg=%23BE4C4C&title_bg=%235E5D64&icon=&icon_color=%23E7E7E7&title=Page&edge_flat=false)](https://hits.seeyoufarm.com) 44 | 45 | ## 🔥 Update and News 46 | - [2025.01.25] 🔥 Inference code and [checkpoint](https://huggingface.co/Vchitect/RepVideo) are released. 47 | 48 | 49 | ## :astonished: Gallery 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
79 | 
80 | 
81 | ## Installation
82 | 
83 | ### 1. Create a conda environment and download models
84 | 
85 | 
86 | ```bash
87 | conda create -n RepVid python==3.10
88 | conda activate RepVid
89 | pip install -r requirements.txt
90 | 
91 | 
92 | mkdir ckpt
93 | cd ckpt
94 | mkdir t5-v1_1-xxl && cd t5-v1_1-xxl
95 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/config.json
96 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model-00001-of-00002.safetensors
97 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model-00002-of-00002.safetensors
98 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model.safetensors.index.json
99 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/added_tokens.json
100 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/special_tokens_map.json
101 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/spiece.model
102 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/tokenizer_config.json
103 | 
104 | cd ../
105 | mkdir vae && cd vae
106 | wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
107 | mv 'index.html?dl=1' vae.zip
108 | unzip vae.zip
109 | ```
110 | ### 2. Download our latest checkpoint
111 | ```bash
112 | git lfs clone https://huggingface.co/Vchitect/RepVideo
113 | # Then modify the "load" path in "sat/configs/inference.yaml" accordingly (see the configuration sketch at the end of this README).
114 | ```
115 | 
116 | ## Inference
117 | 
118 | ~~~bash
119 | cd sat
120 | bash run.sh
121 | ~~~
122 | 
123 | ## BibTeX
124 | ```
125 | @article{si2025RepVideo,
126 | title={RepVideo: Rethinking Cross-Layer Representation for Video Generation},
127 | author={Si, Chenyang and Fan, Weichen and Lv, Zhengyao and Huang, Ziqi and Qiao, Yu and Liu, Ziwei},
128 | journal={arXiv preprint arXiv:2501.08994},
129 | year={2025}
130 | }
131 | ```
132 | 
133 | ## 🔑 License
134 | 
135 | This code is licensed under Apache-2.0. The framework is fully open for academic research, and free commercial use is also permitted.
136 | 
137 | 
138 | ## Disclaimer
139 | 
140 | We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond its capabilities. Generating pornographic, violent, or gory content, or content that demeans or harms people or their environment, culture, or religion, is prohibited. Users are solely liable for their actions. The project contributors are not legally affiliated with, nor accountable for, users' behavior. Use the generative model responsibly, adhering to ethical and legal standards.
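## Inference Configuration Reference

For reference, below is a minimal sketch of the `args` block in `sat/configs/inference.yaml`, showing the fields you will most likely edit before running `bash run.sh`. The values mirror the defaults shipped in this repo; point `load` at the directory produced by the checkpoint download above, and `input_file` at a text file with one prompt per line (see `sat/asset/prompt.txt` for an example):

```yaml
args:
  mode: inference
  latent_channels: 16
  load: ckpt/RepVideo              # path to the downloaded RepVideo checkpoint
  batch_size: 1
  input_type: txt                  # read prompts from a text file
  input_file: asset/prompt.txt     # one prompt per line
  sampling_num_frames: 13
  sampling_fps: 8
  fp16: True
  output_dir: outputs/cog/train_ncfg_channel_1_step_orth   # generated .mp4 files are written under this directory
  force_inference: True
```

These relative paths appear to be resolved from the `sat/` directory where `run.sh` is launched (the repo's `.gitignore` already lists `/sat/ckpt` and `/sat/outputs`).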
141 | 142 | -------------------------------------------------------------------------------- /assets/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/1.gif -------------------------------------------------------------------------------- /assets/10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/10.gif -------------------------------------------------------------------------------- /assets/11.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/11.gif -------------------------------------------------------------------------------- /assets/12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/12.gif -------------------------------------------------------------------------------- /assets/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/2.gif -------------------------------------------------------------------------------- /assets/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/3.gif -------------------------------------------------------------------------------- /assets/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/4.gif -------------------------------------------------------------------------------- /assets/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/5.gif -------------------------------------------------------------------------------- /assets/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/6.gif -------------------------------------------------------------------------------- /assets/7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/7.gif -------------------------------------------------------------------------------- /assets/8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/8.gif -------------------------------------------------------------------------------- /assets/9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/9.gif -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | diffusers>=0.31.0 2 | accelerate>=1.1.1 3 | transformers>=4.46.2 4 | numpy==2.2.1 5 | torch>=2.5.0 6 | torchvision>=0.20.0 7 | sentencepiece>=0.2.0 8 | SwissArmyTransformer>=0.4.12 9 | gradio>=5.5.0 10 | imageio>=2.35.1 11 | imageio-ffmpeg>=0.5.1 12 | openai>=1.54.0 13 | moviepy>=1.0.3 14 | scikit-video>=1.1.11 15 | SwissArmyTransformer>=0.4.12 16 | omegaconf>=2.3.0 17 | pytorch_lightning>=2.4.0 18 | kornia>=0.7.3 19 | beartype>=0.19.0 20 | fsspec>=2024.2.0 21 | safetensors>=0.4.5 22 | scipy>=1.14.1 23 | decord>=0.6.0 24 | wandb>=0.18.5 25 | deepspeed>=0.15.3 26 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | srun -p video-aigc-1 --gres=gpu:1 --quotatype=auto --ntasks-per-node=1 -n1 -N1 --cpus-per-task=12 python ckpt.py -------------------------------------------------------------------------------- /sat/asset/prompt.txt: -------------------------------------------------------------------------------- 1 | A cat holding a sign. 2 | -------------------------------------------------------------------------------- /sat/configs/cogvideox_2b.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | scale_factor: 1.15258426 3 | disable_first_stage_autocast: true 4 | log_keys: 5 | - txt 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | quantize_c_noise: False 12 | 13 | weighting_config: 14 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 15 | scaling_config: 16 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling 17 | discretization_config: 18 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 19 | params: 20 | shift_scale: 3.0 21 | 22 | network_config: 23 | target: dit_video_concat.DiffusionTransformer 24 | params: 25 | time_embed_dim: 512 26 | elementwise_affine: True 27 | num_frames: 49 28 | time_compressed_rate: 4 29 | latent_width: 90 30 | latent_height: 60 31 | # latent_width: 48 32 | # latent_height: 32 33 | num_layers: 30 34 | patch_size: 2 35 | in_channels: 16 36 | out_channels: 16 37 | hidden_size: 1920 38 | adm_in_channels: 256 39 | num_attention_heads: 30 40 | 41 | transformer_args: 42 | checkpoint_activations: True ## using gradient checkpointing 43 | vocab_size: 1 44 | max_sequence_length: 64 45 | layernorm_order: pre 46 | skip_init: false 47 | model_parallel_size: 1 48 | is_decoder: false 49 | 50 | modules: 51 | pos_embed_config: 52 | target: dit_video_concat.Basic3DPositionEmbeddingMixin 53 | params: 54 | text_length: 226 55 | height_interpolation: 1.875 56 | width_interpolation: 1.875 57 | 58 | patch_embed_config: 59 | target: dit_video_concat.ImagePatchEmbeddingMixin 60 | params: 61 | text_hidden_size: 4096 62 | 63 | adaln_layer_config: 64 | target: dit_video_concat.AdaLNMixin 65 | params: 66 | qk_ln: True 67 | 68 | final_layer_config: 69 | target: dit_video_concat.FinalLayerMixin 70 | 71 | conditioner_config: 72 | target: sgm.modules.GeneralConditioner 73 | params: 74 | emb_models: 75 | - is_trainable: false 76 | input_key: txt 77 | ucg_rate: 0.1 78 | target: sgm.modules.encoders.modules.FrozenT5Embedder 79 | params: 80 | model_dir: "ckpt/t5-v1_1-xxl" 81 | max_length: 226 82 | 83 | first_stage_config: 84 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper 85 | params: 86 
| cp_size: 1 87 | ckpt_path: "ckpt/vae/vae/3d-vae.pt" 88 | ignore_keys: [ 'loss' ] 89 | 90 | loss_config: 91 | target: torch.nn.Identity 92 | 93 | regularizer_config: 94 | target: vae_modules.regularizers.DiagonalGaussianRegularizer 95 | 96 | encoder_config: 97 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D 98 | params: 99 | double_z: true 100 | z_channels: 16 101 | resolution: 256 102 | in_channels: 3 103 | out_ch: 3 104 | ch: 128 105 | ch_mult: [ 1, 2, 2, 4 ] 106 | attn_resolutions: [ ] 107 | num_res_blocks: 3 108 | dropout: 0.0 109 | gather_norm: True 110 | 111 | decoder_config: 112 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D 113 | params: 114 | double_z: True 115 | z_channels: 16 116 | resolution: 256 117 | in_channels: 3 118 | out_ch: 3 119 | ch: 128 120 | ch_mult: [ 1, 2, 2, 4 ] 121 | attn_resolutions: [ ] 122 | num_res_blocks: 3 123 | dropout: 0.0 124 | gather_norm: False 125 | 126 | loss_fn_config: 127 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss 128 | params: 129 | offset_noise_level: 0 130 | sigma_sampler_config: 131 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 132 | params: 133 | uniform_sampling: True 134 | num_idx: 1000 135 | discretization_config: 136 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 137 | params: 138 | shift_scale: 3.0 139 | 140 | sampler_config: 141 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler 142 | params: 143 | num_steps: 50 144 | verbose: True 145 | 146 | discretization_config: 147 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization 148 | params: 149 | shift_scale: 3.0 150 | 151 | guider_config: 152 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG 153 | params: 154 | scale: 6 155 | exp: 5 156 | num_steps: 50 -------------------------------------------------------------------------------- /sat/configs/inference.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | latent_channels: 16 3 | mode: inference 4 | load: ckpt/RepVideo 5 | 6 | 7 | batch_size: 1 8 | input_type: txt 9 | input_file: asset/prompt.txt 10 | 11 | sampling_num_frames: 13 12 | sampling_fps: 8 13 | fp16: True 14 | 15 | output_dir: outputs/cog/train_ncfg_channel_1_step_orth 16 | force_inference: True -------------------------------------------------------------------------------- /sat/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | from typing import List, Union 5 | from tqdm import tqdm 6 | import imageio 7 | import torch 8 | import numpy as np 9 | from einops import rearrange 10 | import torchvision.transforms as TT 11 | from sat.model.base_model import get_model 12 | from sat.training.model_io import load_checkpoint 13 | from sat import mpu 14 | from diffusion_video import SATVideoDiffusionEngine 15 | from torchvision.transforms.functional import resize 16 | from torchvision.transforms import InterpolationMode 17 | import gradio as gr 18 | from arguments import get_args 19 | 20 | # Load model once at the beginning 21 | class ModelHandler: 22 | def __init__(self): 23 | self.model = None 24 | self.first_stage_model = None 25 | 26 | def load_model(self, args): 27 | if self.model is None: 28 | self.model = get_model(args, SATVideoDiffusionEngine) 29 | load_checkpoint(self.model, args) 30 | self.model.eval() 31 | self.first_stage_model = self.model.first_stage_model 32 | 33 | def get_model(self): 34 | 
return self.model 35 | 36 | def get_first_stage_model(self): 37 | return self.first_stage_model 38 | 39 | model_handler = ModelHandler() 40 | 41 | # Utility functions 42 | def get_unique_embedder_keys_from_conditioner(conditioner): 43 | return list(set([x.input_key for x in conditioner.embedders])) 44 | 45 | def get_batch(keys, value_dict, N: List[int], T=None, device="cuda"): 46 | batch = {} 47 | batch_uc = {} 48 | 49 | for key in keys: 50 | if key == "txt": 51 | batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() 52 | batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() 53 | else: 54 | batch[key] = value_dict[key] 55 | 56 | if T is not None: 57 | batch["num_video_frames"] = T 58 | 59 | for key in batch.keys(): 60 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 61 | batch_uc[key] = torch.clone(batch[key]) 62 | return batch, batch_uc 63 | 64 | def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5): 65 | os.makedirs(save_path, exist_ok=True) 66 | 67 | for i, vid in enumerate(video_batch): 68 | gif_frames = [] 69 | for frame in vid: 70 | frame = rearrange(frame, "c h w -> h w c") 71 | frame = (255.0 * frame).cpu().numpy().astype(np.uint8) 72 | gif_frames.append(frame) 73 | now_save_path = os.path.join(save_path, f"{i:06d}.mp4") 74 | with imageio.get_writer(now_save_path, fps=fps) as writer: 75 | for frame in gif_frames: 76 | writer.append_data(frame) 77 | 78 | # Main inference function 79 | def infer(prompt: str, sampling_num_frames: int, batch_size: int, latent_channels: int, sampling_fps: int): 80 | args = get_args(['--base', 'configs/cogvideox_2b.yaml', 'configs/inference.yaml', '--seed', '42']) 81 | args = argparse.Namespace(**vars(args)) 82 | del args.deepspeed_config 83 | args.model_config.first_stage_config.params.cp_size = 1 84 | args.model_config.network_config.params.transformer_args.model_parallel_size = 1 85 | args.model_config.network_config.params.transformer_args.checkpoint_activations = False 86 | args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False 87 | 88 | model_handler.load_model(args) 89 | model = model_handler.get_model() 90 | first_stage_model = model_handler.get_first_stage_model() 91 | 92 | image_size = [480, 720] 93 | T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 94 | num_samples = [1] 95 | 96 | value_dict = {"prompt": prompt, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)} 97 | 98 | batch, batch_uc = get_batch( 99 | get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples 100 | ) 101 | 102 | c, uc = model.conditioner.get_unconditional_conditioning( 103 | batch, 104 | batch_uc=batch_uc, 105 | force_uc_zero_embeddings=["txt"], 106 | ) 107 | for k in c: 108 | if not k == "crossattn": 109 | c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) 110 | 111 | samples_z = model.sample( 112 | c, 113 | uc=uc, 114 | batch_size=batch_size, 115 | shape=(T, C, H // F, W // F), 116 | ) 117 | samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() 118 | 119 | latent = 1.0 / model.scale_factor * samples_z 120 | 121 | # recons = [] 122 | # loop_num = (T - 1) // 2 123 | # for i in range(loop_num): 124 | # start_frame, end_frame = i * 2 + 1, i * 2 + 3 if i != 0 else (0, 3) 125 | # recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous()) 126 | # 
recons.append(recon) 127 | recons = [] 128 | loop_num = (T - 1) // 2 129 | for i in range(loop_num): 130 | if i == 0: 131 | start_frame, end_frame = 0, 3 132 | else: 133 | start_frame, end_frame = i * 2 + 1, i * 2 + 3 134 | if i == loop_num - 1: 135 | clear_fake_cp_cache = True 136 | else: 137 | clear_fake_cp_cache = False 138 | with torch.no_grad(): 139 | recon = first_stage_model.decode( 140 | latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache 141 | ) 142 | 143 | recons.append(recon) 144 | 145 | recon = torch.cat(recons, dim=2).to(torch.float32) 146 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() 147 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() 148 | 149 | save_path = "outputs/demo" 150 | save_video_as_grid_and_mp4(samples, save_path, fps=sampling_fps) 151 | return os.path.join(save_path, "000000.mp4") 152 | 153 | # Gradio Interface 154 | def demo_interface(prompt): 155 | video_path = infer(prompt, 16, 1, 8, 5) 156 | return video_path 157 | 158 | with gr.Blocks() as demo: 159 | gr.Markdown("""# RepVideo Gradio Demo 160 | Generate high-quality videos based on text prompts. 161 | """) 162 | 163 | with gr.Row(): 164 | with gr.Column(): 165 | prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here.", lines=3) 166 | generate_button = gr.Button("Generate Video") 167 | 168 | with gr.Column(): 169 | video_output = gr.Video(label="Generated Video") 170 | 171 | generate_button.click( 172 | demo_interface, 173 | inputs=[prompt], 174 | outputs=[video_output] 175 | ) 176 | 177 | demo.launch(server_name="127.0.0.1", server_port=7860) 178 | -------------------------------------------------------------------------------- /sat/run.sh: -------------------------------------------------------------------------------- 1 | 2 | export CUDA_VISIBLE_DEVICES=0 3 | python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed 42 4 | -------------------------------------------------------------------------------- /sat/sample_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | from typing import List, Union 5 | from tqdm import tqdm 6 | from omegaconf import ListConfig 7 | import imageio 8 | 9 | import torch 10 | import numpy as np 11 | from einops import rearrange 12 | import torchvision.transforms as TT 13 | 14 | from sat.model.base_model import get_model 15 | from sat.training.model_io import load_checkpoint 16 | from sat import mpu 17 | 18 | from diffusion_video import SATVideoDiffusionEngine 19 | from arguments import get_args 20 | from torchvision.transforms.functional import center_crop, resize 21 | from torchvision.transforms import InterpolationMode 22 | 23 | 24 | def read_from_cli(): 25 | cnt = 0 26 | try: 27 | while True: 28 | x = input("Please input English text (Ctrl-D quit): ") 29 | yield x.strip(), cnt 30 | cnt += 1 31 | except EOFError as e: 32 | pass 33 | 34 | 35 | def read_from_file(p, rank=0, world_size=1): 36 | with open(p, "r") as fin: 37 | cnt = -1 38 | for l in fin: 39 | cnt += 1 40 | if cnt % world_size != rank: 41 | continue 42 | yield l.strip(), cnt 43 | 44 | 45 | def get_unique_embedder_keys_from_conditioner(conditioner): 46 | return list(set([x.input_key for x in conditioner.embedders])) 47 | 48 | 49 | def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"): 50 | batch = {} 51 | batch_uc = {} 52 | 53 | for key in keys: 54 | if key == "txt": 55 | 
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() 56 | batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() 57 | else: 58 | batch[key] = value_dict[key] 59 | 60 | if T is not None: 61 | batch["num_video_frames"] = T 62 | 63 | for key in batch.keys(): 64 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 65 | batch_uc[key] = torch.clone(batch[key]) 66 | return batch, batch_uc 67 | 68 | 69 | def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): 70 | os.makedirs(save_path, exist_ok=True) 71 | 72 | for i, vid in enumerate(video_batch): 73 | gif_frames = [] 74 | for frame in vid: 75 | frame = rearrange(frame, "c h w -> h w c") 76 | frame = (255.0 * frame).cpu().numpy().astype(np.uint8) 77 | gif_frames.append(frame) 78 | now_save_path = os.path.join(save_path, f"{i:06d}.mp4") 79 | with imageio.get_writer(now_save_path, fps=fps) as writer: 80 | for frame in gif_frames: 81 | writer.append_data(frame) 82 | 83 | 84 | def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): 85 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: 86 | arr = resize( 87 | arr, 88 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], 89 | interpolation=InterpolationMode.BICUBIC, 90 | ) 91 | else: 92 | arr = resize( 93 | arr, 94 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], 95 | interpolation=InterpolationMode.BICUBIC, 96 | ) 97 | 98 | h, w = arr.shape[2], arr.shape[3] 99 | arr = arr.squeeze(0) 100 | 101 | delta_h = h - image_size[0] 102 | delta_w = w - image_size[1] 103 | 104 | if reshape_mode == "random" or reshape_mode == "none": 105 | top = np.random.randint(0, delta_h + 1) 106 | left = np.random.randint(0, delta_w + 1) 107 | elif reshape_mode == "center": 108 | top, left = delta_h // 2, delta_w // 2 109 | else: 110 | raise NotImplementedError 111 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) 112 | return arr 113 | 114 | 115 | def sampling_main(args, model_cls): 116 | if isinstance(model_cls, type): 117 | model = get_model(args, model_cls) 118 | else: 119 | model = model_cls 120 | 121 | load_checkpoint(model, args) 122 | model.eval() 123 | 124 | if args.input_type == "cli": 125 | data_iter = read_from_cli() 126 | elif args.input_type == "txt": 127 | rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size() 128 | print("rank and world_size", rank, world_size) 129 | data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size) 130 | else: 131 | raise NotImplementedError 132 | 133 | image_size = [480, 720] 134 | 135 | sample_func = model.sample 136 | T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 137 | num_samples = [1] 138 | force_uc_zero_embeddings = ["txt"] 139 | device = model.device 140 | with torch.no_grad(): 141 | for text, cnt in tqdm(data_iter): 142 | # reload model on GPU 143 | model.to(device) 144 | print("rank:", rank, "start to process", text, cnt) 145 | # TODO: broadcast image2video 146 | value_dict = { 147 | "prompt": text, 148 | "negative_prompt": "", 149 | "num_frames": torch.tensor(T).unsqueeze(0), 150 | } 151 | 152 | batch, batch_uc = get_batch( 153 | get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples 154 | ) 155 | for key in batch: 156 | if isinstance(batch[key], torch.Tensor): 157 | print(key, 
batch[key].shape)
158 | elif isinstance(batch[key], list):
159 | print(key, [len(l) for l in batch[key]])
160 | else:
161 | print(key, batch[key])
162 | c, uc = model.conditioner.get_unconditional_conditioning(
163 | batch,
164 | batch_uc=batch_uc,
165 | force_uc_zero_embeddings=force_uc_zero_embeddings,
166 | )
167 | 
168 | for k in c:
169 | if not k == "crossattn":
170 | c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
171 | for index in range(args.batch_size):
172 | # reload model on GPU
173 | model.to(device)
174 | 
175 | 
176 | 
177 | samples_z = sample_func(
178 | c,
179 | uc=uc,
180 | batch_size=1,
181 | shape=(T, C, H // F, W // F),
182 | )
183 | # print(samples_z.shape)
184 | samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
185 | 
186 | # Unload the model from GPU to save GPU memory
187 | model.to("cpu")
188 | torch.cuda.empty_cache()
189 | first_stage_model = model.first_stage_model
190 | first_stage_model = first_stage_model.to(device)
191 | 
192 | 
193 | 
194 | latent = 1.0 / model.scale_factor * samples_z
195 | 
196 | # Decode latent serial to save GPU memory
197 | recons = []
198 | loop_num = (T - 1) // 2
199 | for i in range(loop_num):
200 | if i == 0:
201 | start_frame, end_frame = 0, 3
202 | else:
203 | start_frame, end_frame = i * 2 + 1, i * 2 + 3
204 | if i == loop_num - 1:
205 | clear_fake_cp_cache = True
206 | else:
207 | clear_fake_cp_cache = False
208 | with torch.no_grad():
209 | recon = first_stage_model.decode(
210 | latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
211 | )
212 | 
213 | recons.append(recon)
214 | 
215 | recon = torch.cat(recons, dim=2).to(torch.float32)
216 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
217 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
218 | 
219 | save_path = os.path.join(
220 | args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
221 | )
222 | if mpu.get_model_parallel_rank() == 0:
223 | save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
224 | 
225 | 
226 | if __name__ == "__main__":
227 | if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
228 | os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
229 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
230 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
231 | py_parser = argparse.ArgumentParser(add_help=False)
232 | known, args_list = py_parser.parse_known_args()
233 | 
234 | args = get_args(args_list)
235 | args = argparse.Namespace(**vars(args), **vars(known))
236 | del args.deepspeed_config
237 | args.model_config.first_stage_config.params.cp_size = 1
238 | args.model_config.network_config.params.transformer_args.model_parallel_size = 1
239 | args.model_config.network_config.params.transformer_args.checkpoint_activations = False
240 | args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
241 | 
242 | sampling_main(args, model_cls=SATVideoDiffusionEngine)
243 | 
--------------------------------------------------------------------------------
/sat/sgm/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import AutoencodingEngine
2 | from .util import get_configs_path, instantiate_from_config
3 | 
4 | __version__ = "0.1.0"
5 | 
--------------------------------------------------------------------------------
/sat/sgm/lr_scheduler.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 32 | self.last_lr = lr 33 | return lr 34 | else: 35 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 36 | t = min(t, 1.0) 37 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) 38 | self.last_lr = lr 39 | return lr 40 | 41 | def __call__(self, n, **kwargs): 42 | return self.schedule(n, **kwargs) 43 | 44 | 45 | class LambdaWarmUpCosineScheduler2: 46 | """ 47 | supports repeated iterations, configurable via lists 48 | note: use with a base_lr of 1.0. 49 | """ 50 | 51 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 52 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 53 | self.lr_warm_up_steps = warm_up_steps 54 | self.f_start = f_start 55 | self.f_min = f_min 56 | self.f_max = f_max 57 | self.cycle_lengths = cycle_lengths 58 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 59 | self.last_f = 0.0 60 | self.verbosity_interval = verbosity_interval 61 | 62 | def find_in_interval(self, n): 63 | interval = 0 64 | for cl in self.cum_cycles[1:]: 65 | if n <= cl: 66 | return interval 67 | interval += 1 68 | 69 | def schedule(self, n, **kwargs): 70 | cycle = self.find_in_interval(n) 71 | n = n - self.cum_cycles[cycle] 72 | if self.verbosity_interval > 0: 73 | if n % self.verbosity_interval == 0: 74 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 75 | if n < self.lr_warm_up_steps[cycle]: 76 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 77 | self.last_f = f 78 | return f 79 | else: 80 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 81 | t = min(t, 1.0) 82 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) 83 | self.last_f = f 84 | return f 85 | 86 | def __call__(self, n, **kwargs): 87 | return self.schedule(n, **kwargs) 88 | 89 | 90 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 91 | def schedule(self, n, **kwargs): 92 | cycle = self.find_in_interval(n) 93 | n = n - self.cum_cycles[cycle] 94 | if self.verbosity_interval > 0: 95 | if n % self.verbosity_interval == 0: 96 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") 97 | 98 | if n < self.lr_warm_up_steps[cycle]: 99 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 100 | self.last_f = f 101 | return f 102 | else: 103 | f = ( 104 | self.f_min[cycle] 105 | + (self.f_max[cycle] - 
self.f_min[cycle]) 106 | * (self.cycle_lengths[cycle] - n) 107 | / (self.cycle_lengths[cycle]) 108 | ) 109 | self.last_f = f 110 | return f 111 | -------------------------------------------------------------------------------- /sat/sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | -------------------------------------------------------------------------------- /sat/sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | from .video_loss import VideoAutoencoderLoss 9 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous()) 41 | loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() 42 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 43 | 44 | if self.perceptual_weight_on_inputs > 0.0: 45 | image_reconstructions = default(image_reconstructions, 
self.decoder.decode(latent_predictions)) 46 | if self.scale_input_to_tgt_size: 47 | image_inputs = torch.nn.functional.interpolate( 48 | image_inputs, 49 | image_reconstructions.shape[2:], 50 | mode="bicubic", 51 | antialias=True, 52 | ) 53 | elif self.scale_tgt_to_input_size: 54 | image_reconstructions = torch.nn.functional.interpolate( 55 | image_reconstructions, 56 | image_inputs.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | 61 | perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous()) 62 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 63 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() 64 | return loss, log 65 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 31 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 32 | 33 | @classmethod 34 | def from_pretrained(cls, name="vgg_lpips"): 35 | if name != "vgg_lpips": 36 | raise NotImplementedError 37 | model = cls() 38 | ckpt = get_ckpt_path(name) 39 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 40 | return model 41 | 42 | def forward(self, input, target): 43 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 44 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 45 | feats0, feats1, diffs = {}, {}, {} 46 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 47 | for kk in range(len(self.chns)): 48 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 49 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 50 | 51 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 52 | val = res[0] 53 | for l in range(1, len(self.chns)): 54 | val += res[l] 55 | return val 56 | 57 | 58 | class ScalingLayer(nn.Module): 59 | def __init__(self): 60 | super(ScalingLayer, self).__init__() 61 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) 62 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) 63 | 64 | def forward(self, inp): 65 | return (inp - self.shift) / self.scale 66 | 67 | 68 | class NetLinLayer(nn.Module): 69 | """A single linear layer which does a 1x1 conv""" 70 | 71 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 72 | super(NetLinLayer, self).__init__() 73 | layers = ( 74 | [ 75 | nn.Dropout(), 76 | ] 77 | if (use_dropout) 78 | else [] 79 | ) 80 | layers += [ 81 | nn.Conv2d(chn_in, 
chn_out, 1, stride=1, padding=0, bias=False), 82 | ] 83 | self.model = nn.Sequential(*layers) 84 | 85 | 86 | class vgg16(torch.nn.Module): 87 | def __init__(self, requires_grad=False, pretrained=True): 88 | super(vgg16, self).__init__() 89 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 90 | self.slice1 = torch.nn.Sequential() 91 | self.slice2 = torch.nn.Sequential() 92 | self.slice3 = torch.nn.Sequential() 93 | self.slice4 = torch.nn.Sequential() 94 | self.slice5 = torch.nn.Sequential() 95 | self.N_slices = 5 96 | for x in range(4): 97 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 98 | for x in range(4, 9): 99 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 100 | for x in range(9, 16): 101 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 102 | for x in range(16, 23): 103 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 104 | for x in range(23, 30): 105 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 106 | if not requires_grad: 107 | for param in self.parameters(): 108 | param.requires_grad = False 109 | 110 | def forward(self, X): 111 | h = self.slice1(X) 112 | h_relu1_2 = h 113 | h = self.slice2(h) 114 | h_relu2_2 = h 115 | h = self.slice3(h) 116 | h_relu3_3 = h 117 | h = self.slice4(h) 118 | h_relu4_3 = h 119 | h = self.slice5(h) 120 | h_relu5_3 = h 121 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) 122 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 123 | return out 124 | 125 | 126 | def normalize_tensor(x, eps=1e-10): 127 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 128 | return x / (norm_factor + eps) 129 | 130 | 131 | def spatial_average(x, keepdim=True): 132 | return x.mean([2, 3], keepdim=keepdim) 133 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | try: 12 | nn.init.normal_(m.weight.data, 0.0, 0.02) 13 | except: 14 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02) 15 | elif classname.find("BatchNorm") != -1: 16 | nn.init.normal_(m.weight.data, 1.0, 0.02) 17 | nn.init.constant_(m.bias.data, 0) 18 | 19 | 20 | class NLayerDiscriminator(nn.Module): 21 | """Defines a PatchGAN discriminator as in Pix2Pix 22 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 23 | """ 24 | 25 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 26 | """Construct a PatchGAN discriminator 27 | Parameters: 28 | input_nc (int) -- the number of channels in input images 29 | ndf (int) -- the number of filters in the last conv layer 30 | n_layers (int) -- the number of conv layers in the discriminator 31 | norm_layer -- normalization layer 32 | """ 33 | super(NLayerDiscriminator, self).__init__() 34 | if not use_actnorm: 35 | norm_layer = nn.BatchNorm2d 36 | else: 37 | norm_layer = ActNorm 38 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 39 | use_bias = norm_layer.func != nn.BatchNorm2d 40 | else: 41 | use_bias = norm_layer != nn.BatchNorm2d 42 | 43 | kw = 4 44 | padw = 1 45 | sequence = [ 46 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 47 | nn.LeakyReLU(0.2, True), 48 | ] 49 | nf_mult = 1 50 | nf_mult_prev = 1 51 | for n in range(1, n_layers): # gradually increase the number of filters 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2**n, 8) 54 | sequence += [ 55 | nn.Conv2d( 56 | ndf * nf_mult_prev, 57 | ndf * nf_mult, 58 | kernel_size=kw, 59 | stride=2, 60 | padding=padw, 61 | bias=use_bias, 62 | ), 63 | norm_layer(ndf * nf_mult), 64 | nn.LeakyReLU(0.2, True), 65 | ] 66 | 67 | nf_mult_prev = nf_mult 68 | nf_mult = min(2**n_layers, 8) 69 | sequence += [ 70 | nn.Conv2d( 71 | ndf * nf_mult_prev, 72 | ndf * nf_mult, 73 | kernel_size=kw, 74 | stride=1, 75 | padding=padw, 76 | bias=use_bias, 77 | ), 78 | norm_layer(ndf * nf_mult), 79 | nn.LeakyReLU(0.2, True), 80 | ] 81 | 82 | sequence += [ 83 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 84 | ] # output 1 channel prediction map 85 | self.main = nn.Sequential(*sequence) 86 | 87 | def forward(self, input): 88 | """Standard forward.""" 89 | return self.main(input) 90 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 
| 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): 47 | assert affine 48 | super().__init__() 49 | self.logdet = logdet 50 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 51 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 52 | self.allow_reverse_init = allow_reverse_init 53 | 54 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 55 | 56 | def initialize(self, input): 57 | with torch.no_grad(): 58 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 59 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 60 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 61 | 62 | self.loc.data.copy_(-mean) 63 | self.scale.data.copy_(1 / (std + 1e-6)) 64 | 65 | def forward(self, input, reverse=False): 66 | if reverse: 67 | return self.reverse(input) 68 | if len(input.shape) == 2: 69 | input = input[:, :, None, None] 70 | squeeze = True 71 | else: 72 | squeeze = False 73 | 74 | _, _, height, width = input.shape 75 | 76 | if self.training and self.initialized.item() == 0: 77 | self.initialize(input) 78 | self.initialized.fill_(1) 79 | 80 | h = self.scale * (input + self.loc) 81 | 82 | if squeeze: 83 | h = h.squeeze(-1).squeeze(-1) 84 | 85 | if self.logdet: 86 | log_abs = torch.log(torch.abs(self.scale)) 87 | logdet = height * width * torch.sum(log_abs) 88 | logdet = logdet * torch.ones(input.shape[0]).to(input) 89 | return h, logdet 90 | 91 | return h 92 | 93 | def reverse(self, output): 94 | if self.training and self.initialized.item() == 0: 95 | if not self.allow_reverse_init: 96 | raise RuntimeError( 97 | "Initializing ActNorm in reverse direction is " 98 | "disabled by default. Use allow_reverse_init=True to enable." 
99 | ) 100 | else: 101 | self.initialize(output) 102 | self.initialized.fill_(1) 103 | 104 | if len(output.shape) == 2: 105 | output = output[:, :, None, None] 106 | squeeze = True 107 | else: 108 | squeeze = False 109 | 110 | h = output / self.scale - self.loc 111 | 112 | if squeeze: 113 | h = h.squeeze(-1).squeeze(-1) 114 | return h 115 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) 15 | ) 16 | return d_loss 17 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution 9 | from .base import AbstractRegularizer 10 | 11 | 12 | class DiagonalGaussianRegularizer(AbstractRegularizer): 13 | def __init__(self, sample: bool = True): 14 | super().__init__() 15 | self.sample = sample 16 | 17 | def get_trainable_parameters(self) -> Any: 18 | yield from () 19 | 20 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 21 | log = dict() 22 | posterior = DiagonalGaussianDistribution(z) 23 | if self.sample: 24 | z = posterior.sample() 25 | else: 26 | z = posterior.mode() 27 | kl_loss = posterior.kl() 28 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 29 | log["kl_loss"] = kl_loss 30 | return z, log 31 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 30 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 31 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 32 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 33 | avg_probs = encodings.mean(0) 34 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 35 | cluster_use = torch.sum(avg_probs > 0) 36 | return perplexity, cluster_use 37 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 3 | Code adapted from Jax version in Appendix A.1 4 | """ 5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import Module 11 | from torch import Tensor, int32 12 | from torch.cuda.amp import autocast 13 | 14 | from einops import rearrange, pack, unpack 15 | 16 | # helper functions 17 | 18 | 19 | def exists(v): 20 | return v is not None 21 | 22 | 23 | def default(*args): 24 | for arg in args: 25 | if exists(arg): 26 | return arg 27 | return None 28 | 29 | 30 | def pack_one(t, pattern): 31 | return pack([t], pattern) 32 | 33 | 34 | def unpack_one(t, ps, pattern): 35 | return unpack(t, ps, pattern)[0] 36 | 37 | 38 | # tensor helpers 39 | 40 | 41 | def round_ste(z: Tensor) -> Tensor: 42 | """Round with straight through gradients.""" 43 | zhat = z.round() 44 | return z + (zhat - z).detach() 45 | 46 | 47 | # main class 48 | 49 | 50 | class FSQ(Module): 51 | def __init__( 52 | self, 53 | levels: List[int], 54 | dim: Optional[int] = None, 55 | num_codebooks=1, 56 | keep_num_codebooks_dim: Optional[bool] = None, 57 | scale: Optional[float] = None, 58 | ): 59 | super().__init__() 60 | _levels = torch.tensor(levels, dtype=int32) 61 | self.register_buffer("_levels", _levels, persistent=False) 62 | 63 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) 64 | self.register_buffer("_basis", _basis, persistent=False) 65 | 66 | self.scale = scale 67 | 68 | codebook_dim = len(levels) 69 | self.codebook_dim = codebook_dim 70 | 71 | effective_codebook_dim = codebook_dim * num_codebooks 72 | self.num_codebooks = num_codebooks 73 | self.effective_codebook_dim = effective_codebook_dim 74 | 75 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 76 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 77 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 78 | 79 | self.dim = default(dim, len(_levels) * num_codebooks) 80 | 81 | has_projections = self.dim != effective_codebook_dim 82 | self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() 83 | self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() 84 | self.has_projections = has_projections 85 | 86 | self.codebook_size = self._levels.prod().item() 87 | 88 | implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) 89 | self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) 90 | 91 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: 92 | """Bound `z`, an array of shape (..., d).""" 93 | half_l = (self._levels - 1) * (1 + eps) / 2 94 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) 95 | shift = (offset / half_l).atanh() 96 | return (z + shift).tanh() * half_l - offset 97 | 98 | def quantize(self, 
z: Tensor) -> Tensor: 99 | """Quantizes z, returns quantized zhat, same shape as z.""" 100 | quantized = round_ste(self.bound(z)) 101 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 102 | return quantized / half_width 103 | 104 | def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor: 105 | half_width = self._levels // 2 106 | return (zhat_normalized * half_width) + half_width 107 | 108 | def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor: 109 | half_width = self._levels // 2 110 | return (zhat - half_width) / half_width 111 | 112 | def codes_to_indices(self, zhat: Tensor) -> Tensor: 113 | """Converts a `code` to an index in the codebook.""" 114 | assert zhat.shape[-1] == self.codebook_dim 115 | zhat = self._scale_and_shift(zhat) 116 | return (zhat * self._basis).sum(dim=-1).to(int32) 117 | 118 | def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor: 119 | """Inverse of `codes_to_indices`.""" 120 | 121 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 122 | 123 | indices = rearrange(indices, "... -> ... 1") 124 | codes_non_centered = (indices // self._basis) % self._levels 125 | codes = self._scale_and_shift_inverse(codes_non_centered) 126 | 127 | if self.keep_num_codebooks_dim: 128 | codes = rearrange(codes, "... c d -> ... (c d)") 129 | 130 | if project_out: 131 | codes = self.project_out(codes) 132 | 133 | if is_img_or_video: 134 | codes = rearrange(codes, "b ... d -> b d ...") 135 | 136 | return codes 137 | 138 | @autocast(enabled=False) 139 | def forward(self, z: Tensor) -> Tensor: 140 | """ 141 | einstein notation 142 | b - batch 143 | n - sequence (or flattened spatial dimensions) 144 | d - feature dimension 145 | c - number of codebook dim 146 | """ 147 | 148 | is_img_or_video = z.ndim >= 4 149 | 150 | # standardize image or video into (batch, seq, dimension) 151 | 152 | if is_img_or_video: 153 | z = rearrange(z, "b d ... -> b ... d") 154 | z, ps = pack_one(z, "b * d") 155 | 156 | assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" 157 | 158 | z = self.project_in(z) 159 | 160 | z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) 161 | 162 | codes = self.quantize(z) 163 | indices = self.codes_to_indices(codes) 164 | 165 | codes = rearrange(codes, "b n c d -> b n (c d)") 166 | 167 | out = self.project_out(codes) 168 | 169 | # reconstitute image or video dimensions 170 | 171 | if is_img_or_video: 172 | out = unpack_one(out, ps, "b * d") 173 | out = rearrange(out, "b ... d -> b d ...") 174 | 175 | indices = unpack_one(indices, ps, "b * c") 176 | 177 | if not self.keep_num_codebooks_dim: 178 | indices = rearrange(indices, "... 1 -> ...") 179 | 180 | return out, indices 181 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lookup Free Quantization 3 | Proposed in https://arxiv.org/abs/2310.05737 4 | 5 | In the simplest setup, each dimension is quantized into {-1, 1}. 6 | An entropy penalty is used to encourage utilization. 
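(Added note, not part of the upstream docstring.) Concretely, with codebook_dim = log2(codebook_size),
the quantizer below maps each latent channel to q(x)_i = +codebook_scale if x_i > 0 else -codebook_scale,
and the integer code index is read off the sign pattern as sum_i 1[x_i > 0] * 2**(codebook_dim - 1 - i),
so no explicit nearest-neighbour codebook lookup is needed.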
7 | """ 8 | 9 | from math import log2, ceil 10 | from collections import namedtuple 11 | 12 | import torch 13 | from torch import nn, einsum 14 | import torch.nn.functional as F 15 | from torch.nn import Module 16 | from torch.cuda.amp import autocast 17 | 18 | from einops import rearrange, reduce, pack, unpack 19 | 20 | # constants 21 | 22 | Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"]) 23 | 24 | LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"]) 25 | 26 | # helper functions 27 | 28 | 29 | def exists(v): 30 | return v is not None 31 | 32 | 33 | def default(*args): 34 | for arg in args: 35 | if exists(arg): 36 | return arg() if callable(arg) else arg 37 | return None 38 | 39 | 40 | def pack_one(t, pattern): 41 | return pack([t], pattern) 42 | 43 | 44 | def unpack_one(t, ps, pattern): 45 | return unpack(t, ps, pattern)[0] 46 | 47 | 48 | # entropy 49 | 50 | 51 | def log(t, eps=1e-5): 52 | return t.clamp(min=eps).log() 53 | 54 | 55 | def entropy(prob): 56 | return (-prob * log(prob)).sum(dim=-1) 57 | 58 | 59 | # class 60 | 61 | 62 | class LFQ(Module): 63 | def __init__( 64 | self, 65 | *, 66 | dim=None, 67 | codebook_size=None, 68 | entropy_loss_weight=0.1, 69 | commitment_loss_weight=0.25, 70 | diversity_gamma=1.0, 71 | straight_through_activation=nn.Identity(), 72 | num_codebooks=1, 73 | keep_num_codebooks_dim=None, 74 | codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer 75 | frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy 76 | ): 77 | super().__init__() 78 | 79 | # some assert validations 80 | 81 | assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ" 82 | assert ( 83 | not exists(codebook_size) or log2(codebook_size).is_integer() 84 | ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" 85 | 86 | codebook_size = default(codebook_size, lambda: 2**dim) 87 | codebook_dim = int(log2(codebook_size)) 88 | 89 | codebook_dims = codebook_dim * num_codebooks 90 | dim = default(dim, codebook_dims) 91 | 92 | has_projections = dim != codebook_dims 93 | self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity() 94 | self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity() 95 | self.has_projections = has_projections 96 | 97 | self.dim = dim 98 | self.codebook_dim = codebook_dim 99 | self.num_codebooks = num_codebooks 100 | 101 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 102 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 103 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 104 | 105 | # straight through activation 106 | 107 | self.activation = straight_through_activation 108 | 109 | # entropy aux loss related weights 110 | 111 | assert 0 < frac_per_sample_entropy <= 1.0 112 | self.frac_per_sample_entropy = frac_per_sample_entropy 113 | 114 | self.diversity_gamma = diversity_gamma 115 | self.entropy_loss_weight = entropy_loss_weight 116 | 117 | # codebook scale 118 | 119 | self.codebook_scale = codebook_scale 120 | 121 | # commitment loss 122 | 123 | self.commitment_loss_weight = commitment_loss_weight 124 | 125 | # for no auxiliary loss, during inference 126 | 127 | self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1)) 128 | self.register_buffer("zero", torch.tensor(0.0), persistent=False) 
129 | 130 | # codes 131 | 132 | all_codes = torch.arange(codebook_size) 133 | bits = ((all_codes[..., None].int() & self.mask) != 0).float() 134 | codebook = self.bits_to_codes(bits) 135 | 136 | self.register_buffer("codebook", codebook, persistent=False) 137 | 138 | def bits_to_codes(self, bits): 139 | return bits * self.codebook_scale * 2 - self.codebook_scale 140 | 141 | @property 142 | def dtype(self): 143 | return self.codebook.dtype 144 | 145 | def indices_to_codes(self, indices, project_out=True): 146 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 147 | 148 | if not self.keep_num_codebooks_dim: 149 | indices = rearrange(indices, "... -> ... 1") 150 | 151 | # indices to codes, which are bits of either -1 or 1 152 | 153 | bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) 154 | 155 | codes = self.bits_to_codes(bits) 156 | 157 | codes = rearrange(codes, "... c d -> ... (c d)") 158 | 159 | # whether to project codes out to original dimensions 160 | # if the input feature dimensions were not log2(codebook size) 161 | 162 | if project_out: 163 | codes = self.project_out(codes) 164 | 165 | # rearrange codes back to original shape 166 | 167 | if is_img_or_video: 168 | codes = rearrange(codes, "b ... d -> b d ...") 169 | 170 | return codes 171 | 172 | @autocast(enabled=False) 173 | def forward( 174 | self, 175 | x, 176 | inv_temperature=100.0, 177 | return_loss_breakdown=False, 178 | mask=None, 179 | ): 180 | """ 181 | einstein notation 182 | b - batch 183 | n - sequence (or flattened spatial dimensions) 184 | d - feature dimension, which is also log2(codebook size) 185 | c - number of codebook dim 186 | """ 187 | 188 | x = x.float() 189 | 190 | is_img_or_video = x.ndim >= 4 191 | 192 | # standardize image or video into (batch, seq, dimension) 193 | 194 | if is_img_or_video: 195 | x = rearrange(x, "b d ... -> b ... d") 196 | x, ps = pack_one(x, "b * d") 197 | 198 | assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}" 199 | 200 | x = self.project_in(x) 201 | 202 | # split out number of codebooks 203 | 204 | x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) 205 | 206 | # quantize by eq 3. 207 | 208 | original_input = x 209 | 210 | codebook_value = torch.ones_like(x) * self.codebook_scale 211 | quantized = torch.where(x > 0, codebook_value, -codebook_value) 212 | 213 | # use straight-through gradients (optionally with custom activation fn) if training 214 | 215 | if self.training: 216 | x = self.activation(x) 217 | x = x + (quantized - x).detach() 218 | else: 219 | x = quantized 220 | 221 | # calculate indices 222 | 223 | indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") 224 | 225 | # entropy aux loss 226 | 227 | if self.training: 228 | # the same as euclidean distance up to a constant 229 | distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook) 230 | 231 | prob = (-distance * inv_temperature).softmax(dim=-1) 232 | 233 | # account for mask 234 | 235 | if exists(mask): 236 | prob = prob[mask] 237 | else: 238 | prob = rearrange(prob, "b n ... 
-> (b n) ...") 239 | 240 | # whether to only use a fraction of probs, for reducing memory 241 | 242 | if self.frac_per_sample_entropy < 1.0: 243 | num_tokens = prob.shape[0] 244 | num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy) 245 | rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens 246 | per_sample_probs = prob[rand_mask] 247 | else: 248 | per_sample_probs = prob 249 | 250 | # calculate per sample entropy 251 | 252 | per_sample_entropy = entropy(per_sample_probs).mean() 253 | 254 | # distribution over all available tokens in the batch 255 | 256 | avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean") 257 | codebook_entropy = entropy(avg_prob).mean() 258 | 259 | # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions 260 | # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch 261 | 262 | entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy 263 | else: 264 | # if not training, just return dummy 0 265 | entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero 266 | 267 | # commit loss 268 | 269 | if self.training: 270 | commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none") 271 | 272 | if exists(mask): 273 | commit_loss = commit_loss[mask] 274 | 275 | commit_loss = commit_loss.mean() 276 | else: 277 | commit_loss = self.zero 278 | 279 | # merge back codebook dim 280 | 281 | x = rearrange(x, "b n c d -> b n (c d)") 282 | 283 | # project out to feature dimension if needed 284 | 285 | x = self.project_out(x) 286 | 287 | # reconstitute image or video dimensions 288 | 289 | if is_img_or_video: 290 | x = unpack_one(x, ps, "b * d") 291 | x = rearrange(x, "b ... d -> b d ...") 292 | 293 | indices = unpack_one(indices, ps, "b * c") 294 | 295 | # whether to remove single codebook dim 296 | 297 | if not self.keep_num_codebooks_dim: 298 | indices = rearrange(indices, "... 
1 -> ...") 299 | 300 | # complete aux loss 301 | 302 | aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight 303 | 304 | ret = Return(x, indices, aux_loss) 305 | 306 | if not return_loss_breakdown: 307 | return ret 308 | 309 | return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss) 310 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/temporal_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Union 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | 6 | from sgm.modules.diffusionmodules.model import ( 7 | XFORMERS_IS_AVAILABLE, 8 | AttnBlock, 9 | Decoder, 10 | MemoryEfficientAttnBlock, 11 | ResnetBlock, 12 | ) 13 | from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding 14 | from sgm.modules.video_attention import VideoTransformerBlock 15 | from sgm.util import partialclass 16 | 17 | 18 | class VideoResBlock(ResnetBlock): 19 | def __init__( 20 | self, 21 | out_channels, 22 | *args, 23 | dropout=0.0, 24 | video_kernel_size=3, 25 | alpha=0.0, 26 | merge_strategy="learned", 27 | **kwargs, 28 | ): 29 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) 30 | if video_kernel_size is None: 31 | video_kernel_size = [3, 1, 1] 32 | self.time_stack = ResBlock( 33 | channels=out_channels, 34 | emb_channels=0, 35 | dropout=dropout, 36 | dims=3, 37 | use_scale_shift_norm=False, 38 | use_conv=False, 39 | up=False, 40 | down=False, 41 | kernel_size=video_kernel_size, 42 | use_checkpoint=False, 43 | skip_t_emb=True, 44 | ) 45 | 46 | self.merge_strategy = merge_strategy 47 | if self.merge_strategy == "fixed": 48 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 49 | elif self.merge_strategy == "learned": 50 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 51 | else: 52 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 53 | 54 | def get_alpha(self, bs): 55 | if self.merge_strategy == "fixed": 56 | return self.mix_factor 57 | elif self.merge_strategy == "learned": 58 | return torch.sigmoid(self.mix_factor) 59 | else: 60 | raise NotImplementedError() 61 | 62 | def forward(self, x, temb, skip_video=False, timesteps=None): 63 | if timesteps is None: 64 | timesteps = self.timesteps 65 | 66 | b, c, h, w = x.shape 67 | 68 | x = super().forward(x, temb) 69 | 70 | if not skip_video: 71 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 72 | 73 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 74 | 75 | x = self.time_stack(x, temb) 76 | 77 | alpha = self.get_alpha(bs=b // timesteps) 78 | x = alpha * x + (1.0 - alpha) * x_mix 79 | 80 | x = rearrange(x, "b c t h w -> (b t) c h w") 81 | return x 82 | 83 | 84 | class AE3DConv(torch.nn.Conv2d): 85 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): 86 | super().__init__(in_channels, out_channels, *args, **kwargs) 87 | if isinstance(video_kernel_size, Iterable): 88 | padding = [int(k // 2) for k in video_kernel_size] 89 | else: 90 | padding = int(video_kernel_size // 2) 91 | 92 | self.time_mix_conv = torch.nn.Conv3d( 93 | in_channels=out_channels, 94 | out_channels=out_channels, 95 | kernel_size=video_kernel_size, 96 | padding=padding, 97 | ) 98 | 99 | def forward(self, input, timesteps, skip_video=False): 100 | x = super().forward(input) 101 | if skip_video: 102 | return x 103 | x = 
rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 104 | x = self.time_mix_conv(x) 105 | return rearrange(x, "b c t h w -> (b t) c h w") 106 | 107 | 108 | class VideoBlock(AttnBlock): 109 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): 110 | super().__init__(in_channels) 111 | # no context, single headed, as in base class 112 | self.time_mix_block = VideoTransformerBlock( 113 | dim=in_channels, 114 | n_heads=1, 115 | d_head=in_channels, 116 | checkpoint=False, 117 | ff_in=True, 118 | attn_mode="softmax", 119 | ) 120 | 121 | time_embed_dim = self.in_channels * 4 122 | self.video_time_embed = torch.nn.Sequential( 123 | torch.nn.Linear(self.in_channels, time_embed_dim), 124 | torch.nn.SiLU(), 125 | torch.nn.Linear(time_embed_dim, self.in_channels), 126 | ) 127 | 128 | self.merge_strategy = merge_strategy 129 | if self.merge_strategy == "fixed": 130 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 131 | elif self.merge_strategy == "learned": 132 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 133 | else: 134 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 135 | 136 | def forward(self, x, timesteps, skip_video=False): 137 | if skip_video: 138 | return super().forward(x) 139 | 140 | x_in = x 141 | x = self.attention(x) 142 | h, w = x.shape[2:] 143 | x = rearrange(x, "b c h w -> b (h w) c") 144 | 145 | x_mix = x 146 | num_frames = torch.arange(timesteps, device=x.device) 147 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 148 | num_frames = rearrange(num_frames, "b t -> (b t)") 149 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 150 | emb = self.video_time_embed(t_emb) # b, n_channels 151 | emb = emb[:, None, :] 152 | x_mix = x_mix + emb 153 | 154 | alpha = self.get_alpha() 155 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 156 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 157 | 158 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 159 | x = self.proj_out(x) 160 | 161 | return x_in + x 162 | 163 | def get_alpha( 164 | self, 165 | ): 166 | if self.merge_strategy == "fixed": 167 | return self.mix_factor 168 | elif self.merge_strategy == "learned": 169 | return torch.sigmoid(self.mix_factor) 170 | else: 171 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 172 | 173 | 174 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): 175 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): 176 | super().__init__(in_channels) 177 | # no context, single headed, as in base class 178 | self.time_mix_block = VideoTransformerBlock( 179 | dim=in_channels, 180 | n_heads=1, 181 | d_head=in_channels, 182 | checkpoint=False, 183 | ff_in=True, 184 | attn_mode="softmax-xformers", 185 | ) 186 | 187 | time_embed_dim = self.in_channels * 4 188 | self.video_time_embed = torch.nn.Sequential( 189 | torch.nn.Linear(self.in_channels, time_embed_dim), 190 | torch.nn.SiLU(), 191 | torch.nn.Linear(time_embed_dim, self.in_channels), 192 | ) 193 | 194 | self.merge_strategy = merge_strategy 195 | if self.merge_strategy == "fixed": 196 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 197 | elif self.merge_strategy == "learned": 198 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) 199 | else: 200 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 201 | 202 | def forward(self, x, timesteps, skip_time_block=False): 203 | 
if skip_time_block: 204 | return super().forward(x) 205 | 206 | x_in = x 207 | x = self.attention(x) 208 | h, w = x.shape[2:] 209 | x = rearrange(x, "b c h w -> b (h w) c") 210 | 211 | x_mix = x 212 | num_frames = torch.arange(timesteps, device=x.device) 213 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 214 | num_frames = rearrange(num_frames, "b t -> (b t)") 215 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 216 | emb = self.video_time_embed(t_emb) # b, n_channels 217 | emb = emb[:, None, :] 218 | x_mix = x_mix + emb 219 | 220 | alpha = self.get_alpha() 221 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 222 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 223 | 224 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 225 | x = self.proj_out(x) 226 | 227 | return x_in + x 228 | 229 | def get_alpha( 230 | self, 231 | ): 232 | if self.merge_strategy == "fixed": 233 | return self.mix_factor 234 | elif self.merge_strategy == "learned": 235 | return torch.sigmoid(self.mix_factor) 236 | else: 237 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 238 | 239 | 240 | def make_time_attn( 241 | in_channels, 242 | attn_type="vanilla", 243 | attn_kwargs=None, 244 | alpha: float = 0, 245 | merge_strategy: str = "learned", 246 | ): 247 | assert attn_type in [ 248 | "vanilla", 249 | "vanilla-xformers", 250 | ], f"attn_type {attn_type} not supported for spatio-temporal attention" 251 | print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels") 252 | if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": 253 | print( 254 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " 255 | f"This is not a problem in Pytorch >= 2.0. 
FYI, you are running with PyTorch version {torch.__version__}" 256 | ) 257 | attn_type = "vanilla" 258 | 259 | if attn_type == "vanilla": 260 | assert attn_kwargs is None 261 | return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy) 262 | elif attn_type == "vanilla-xformers": 263 | print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") 264 | return partialclass( 265 | MemoryEfficientVideoBlock, 266 | in_channels, 267 | alpha=alpha, 268 | merge_strategy=merge_strategy, 269 | ) 270 | else: 271 | return NotImplementedError() 272 | 273 | 274 | class Conv2DWrapper(torch.nn.Conv2d): 275 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 276 | return super().forward(input) 277 | 278 | 279 | class VideoDecoder(Decoder): 280 | available_time_modes = ["all", "conv-only", "attn-only"] 281 | 282 | def __init__( 283 | self, 284 | *args, 285 | video_kernel_size: Union[int, list] = 3, 286 | alpha: float = 0.0, 287 | merge_strategy: str = "learned", 288 | time_mode: str = "conv-only", 289 | **kwargs, 290 | ): 291 | self.video_kernel_size = video_kernel_size 292 | self.alpha = alpha 293 | self.merge_strategy = merge_strategy 294 | self.time_mode = time_mode 295 | assert ( 296 | self.time_mode in self.available_time_modes 297 | ), f"time_mode parameter has to be in {self.available_time_modes}" 298 | super().__init__(*args, **kwargs) 299 | 300 | def get_last_layer(self, skip_time_mix=False, **kwargs): 301 | if self.time_mode == "attn-only": 302 | raise NotImplementedError("TODO") 303 | else: 304 | return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight 305 | 306 | def _make_attn(self) -> Callable: 307 | if self.time_mode not in ["conv-only", "only-last-conv"]: 308 | return partialclass( 309 | make_time_attn, 310 | alpha=self.alpha, 311 | merge_strategy=self.merge_strategy, 312 | ) 313 | else: 314 | return super()._make_attn() 315 | 316 | def _make_conv(self) -> Callable: 317 | if self.time_mode != "attn-only": 318 | return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) 319 | else: 320 | return Conv2DWrapper 321 | 322 | def _make_resblock(self) -> Callable: 323 | if self.time_mode not in ["attn-only", "only-last-conv"]: 324 | return partialclass( 325 | VideoResBlock, 326 | video_kernel_size=self.video_kernel_size, 327 | alpha=self.alpha, 328 | merge_strategy=self.merge_strategy, 329 | ) 330 | else: 331 | return super()._make_resblock() 332 | -------------------------------------------------------------------------------- /sat/sgm/modules/autoencoding/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import einsum 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer2(nn.Module): 10 | """ 11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 13 | """ 14 | 15 | # NOTE: due to a bug the beta term was applied to the wrong term. for 16 | # backwards compatibility we use the buggy version by default, but you can 17 | # specify legacy=False to fix it. 
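    # (Clarifying comment, not in the original source.) Concretely, forward() below computes
    #   legacy=True : loss = ||sg[z_q] - z||^2 + beta * ||z_q - sg[z]||^2
    #   legacy=False: loss = beta * ||sg[z_q] - z||^2 + ||z_q - sg[z]||^2
    # where sg[.] denotes stop-gradient (.detach()), i.e. the beta weight sits on the
    # codebook-update term in the legacy variant and on the encoder-commitment term in the fixed one.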
18 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): 19 | super().__init__() 20 | self.n_e = n_e 21 | self.e_dim = e_dim 22 | self.beta = beta 23 | self.legacy = legacy 24 | 25 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 27 | 28 | self.remap = remap 29 | if self.remap is not None: 30 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 31 | self.re_embed = self.used.shape[0] 32 | self.unknown_index = unknown_index # "random" or "extra" or integer 33 | if self.unknown_index == "extra": 34 | self.unknown_index = self.re_embed 35 | self.re_embed = self.re_embed + 1 36 | print( 37 | f"Remapping {self.n_e} indices to {self.re_embed} indices. " 38 | f"Using {self.unknown_index} for unknown indices." 39 | ) 40 | else: 41 | self.re_embed = n_e 42 | 43 | self.sane_index_shape = sane_index_shape 44 | 45 | def remap_to_used(self, inds): 46 | ishape = inds.shape 47 | assert len(ishape) > 1 48 | inds = inds.reshape(ishape[0], -1) 49 | used = self.used.to(inds) 50 | match = (inds[:, :, None] == used[None, None, ...]).long() 51 | new = match.argmax(-1) 52 | unknown = match.sum(2) < 1 53 | if self.unknown_index == "random": 54 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 55 | else: 56 | new[unknown] = self.unknown_index 57 | return new.reshape(ishape) 58 | 59 | def unmap_to_all(self, inds): 60 | ishape = inds.shape 61 | assert len(ishape) > 1 62 | inds = inds.reshape(ishape[0], -1) 63 | used = self.used.to(inds) 64 | if self.re_embed > self.used.shape[0]: # extra token 65 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 66 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 67 | return back.reshape(ishape) 68 | 69 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 70 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" 71 | assert rescale_logits == False, "Only for interface compatible with Gumbel" 72 | assert return_logits == False, "Only for interface compatible with Gumbel" 73 | # reshape z -> (batch, height, width, channel) and flatten 74 | z = rearrange(z, "b c h w -> b h w c").contiguous() 75 | z_flattened = z.view(-1, self.e_dim) 76 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 77 | 78 | d = ( 79 | torch.sum(z_flattened**2, dim=1, keepdim=True) 80 | + torch.sum(self.embedding.weight**2, dim=1) 81 | - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) 82 | ) 83 | 84 | min_encoding_indices = torch.argmin(d, dim=1) 85 | z_q = self.embedding(min_encoding_indices).view(z.shape) 86 | perplexity = None 87 | min_encodings = None 88 | 89 | # compute loss for embedding 90 | if not self.legacy: 91 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) 92 | else: 93 | loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) 94 | 95 | # preserve gradients 96 | z_q = z + (z_q - z).detach() 97 | 98 | # reshape back to match original input shape 99 | z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() 100 | 101 | if self.remap is not None: 102 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis 103 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 104 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 105 | 106 
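        # (Added comment, not in the original.) With sane_index_shape=True the flat indices are
        # reshaped below to (batch, height, width), matching the spatial layout of z_q.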
| if self.sane_index_shape: 107 | min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) 108 | 109 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 110 | 111 | def get_codebook_entry(self, indices, shape): 112 | # shape specifying (batch, height, width, channel) 113 | if self.remap is not None: 114 | indices = indices.reshape(shape[0], -1) # add batch axis 115 | indices = self.unmap_to_all(indices) 116 | indices = indices.reshape(-1) # flatten again 117 | 118 | # get quantized latent vectors 119 | z_q = self.embedding(indices) 120 | 121 | if shape is not None: 122 | z_q = z_q.view(shape) 123 | # reshape back to match original input shape 124 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 125 | 126 | return z_q 127 | 128 | 129 | class GumbelQuantize(nn.Module): 130 | """ 131 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 132 | Gumbel Softmax trick quantizer 133 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 134 | https://arxiv.org/abs/1611.01144 135 | """ 136 | 137 | def __init__( 138 | self, 139 | num_hiddens, 140 | embedding_dim, 141 | n_embed, 142 | straight_through=True, 143 | kl_weight=5e-4, 144 | temp_init=1.0, 145 | use_vqinterface=True, 146 | remap=None, 147 | unknown_index="random", 148 | ): 149 | super().__init__() 150 | 151 | self.embedding_dim = embedding_dim 152 | self.n_embed = n_embed 153 | 154 | self.straight_through = straight_through 155 | self.temperature = temp_init 156 | self.kl_weight = kl_weight 157 | 158 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 159 | self.embed = nn.Embedding(n_embed, embedding_dim) 160 | 161 | self.use_vqinterface = use_vqinterface 162 | 163 | self.remap = remap 164 | if self.remap is not None: 165 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 166 | self.re_embed = self.used.shape[0] 167 | self.unknown_index = unknown_index # "random" or "extra" or integer 168 | if self.unknown_index == "extra": 169 | self.unknown_index = self.re_embed 170 | self.re_embed = self.re_embed + 1 171 | print( 172 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 173 | f"Using {self.unknown_index} for unknown indices." 174 | ) 175 | else: 176 | self.re_embed = n_embed 177 | 178 | def remap_to_used(self, inds): 179 | ishape = inds.shape 180 | assert len(ishape) > 1 181 | inds = inds.reshape(ishape[0], -1) 182 | used = self.used.to(inds) 183 | match = (inds[:, :, None] == used[None, None, ...]).long() 184 | new = match.argmax(-1) 185 | unknown = match.sum(2) < 1 186 | if self.unknown_index == "random": 187 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 188 | else: 189 | new[unknown] = self.unknown_index 190 | return new.reshape(ishape) 191 | 192 | def unmap_to_all(self, inds): 193 | ishape = inds.shape 194 | assert len(ishape) > 1 195 | inds = inds.reshape(ishape[0], -1) 196 | used = self.used.to(inds) 197 | if self.re_embed > self.used.shape[0]: # extra token 198 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 199 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 200 | return back.reshape(ishape) 201 | 202 | def forward(self, z, temp=None, return_logits=False): 203 | # force hard = True when we are in eval mode, as we must quantize. 
actually, always true seems to work 204 | hard = self.straight_through if self.training else True 205 | temp = self.temperature if temp is None else temp 206 | 207 | logits = self.proj(z) 208 | if self.remap is not None: 209 | # continue only with used logits 210 | full_zeros = torch.zeros_like(logits) 211 | logits = logits[:, self.used, ...] 212 | 213 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 214 | if self.remap is not None: 215 | # go back to all entries but unused set to zero 216 | full_zeros[:, self.used, ...] = soft_one_hot 217 | soft_one_hot = full_zeros 218 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) 219 | 220 | # + kl divergence to the prior loss 221 | qy = F.softmax(logits, dim=1) 222 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 223 | 224 | ind = soft_one_hot.argmax(dim=1) 225 | if self.remap is not None: 226 | ind = self.remap_to_used(ind) 227 | if self.use_vqinterface: 228 | if return_logits: 229 | return z_q, diff, (None, None, ind), logits 230 | return z_q, diff, (None, None, ind) 231 | return z_q, diff, ind 232 | 233 | def get_codebook_entry(self, indices, shape): 234 | b, h, w, c = shape 235 | assert b * h * w == indices.shape[0] 236 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) 237 | if self.remap is not None: 238 | indices = self.unmap_to_all(indices) 239 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 240 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) 241 | return z_q 242 | -------------------------------------------------------------------------------- /sat/sgm/modules/cp_enc_dec.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.distributed 4 | import torch.nn as nn 5 | from ..util import ( 6 | get_context_parallel_group, 7 | get_context_parallel_rank, 8 | get_context_parallel_world_size, 9 | 10 | ) 11 | 12 | _USE_CP = True 13 | 14 | 15 | def cast_tuple(t, length=1): 16 | return t if isinstance(t, tuple) else ((t,) * length) 17 | 18 | 19 | def divisible_by(num, den): 20 | return (num % den) == 0 21 | 22 | 23 | def is_odd(n): 24 | return not divisible_by(n, 2) 25 | 26 | 27 | def exists(v): 28 | return v is not None 29 | 30 | 31 | def pair(t): 32 | return t if isinstance(t, tuple) else (t, t) 33 | 34 | 35 | def get_timestep_embedding(timesteps, embedding_dim): 36 | """ 37 | This matches the implementation in Denoising Diffusion Probabilistic Models: 38 | From Fairseq. 39 | Build sinusoidal embeddings. 40 | This matches the implementation in tensor2tensor, but differs slightly 41 | from the description in Section 3.5 of "Attention Is All You Need". 
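    (Added note, not part of the original docstring.) Roughly, with half = embedding_dim // 2 the
    embedding computed below is
        emb[t, i]        = sin(t * exp(-log(10000) * i / (half - 1)))   for i in [0, half)
        emb[t, half + i] = cos(t * exp(-log(10000) * i / (half - 1)))
    followed by a single zero pad when embedding_dim is odd.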
42 | """ 43 | assert len(timesteps.shape) == 1 44 | 45 | half_dim = embedding_dim // 2 46 | emb = math.log(10000) / (half_dim - 1) 47 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 48 | emb = emb.to(device=timesteps.device) 49 | emb = timesteps.float()[:, None] * emb[None, :] 50 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 51 | if embedding_dim % 2 == 1: # zero pad 52 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 53 | return emb 54 | 55 | 56 | def nonlinearity(x): 57 | # swish 58 | return x * torch.sigmoid(x) 59 | 60 | 61 | def leaky_relu(p=0.1): 62 | return nn.LeakyReLU(p) 63 | 64 | 65 | def _split(input_, dim): 66 | cp_world_size = get_context_parallel_world_size() 67 | 68 | if cp_world_size == 1: 69 | return input_ 70 | 71 | cp_rank = get_context_parallel_rank() 72 | 73 | # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) 74 | 75 | inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() 76 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 77 | dim_size = input_.size()[dim] // cp_world_size 78 | 79 | input_list = torch.split(input_, dim_size, dim=dim) 80 | output = input_list[cp_rank] 81 | 82 | if cp_rank == 0: 83 | output = torch.cat([inpu_first_frame_, output], dim=dim) 84 | output = output.contiguous() 85 | 86 | # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) 87 | 88 | return output 89 | 90 | 91 | def _gather(input_, dim): 92 | cp_world_size = get_context_parallel_world_size() 93 | 94 | # Bypass the function if context parallel is 1 95 | if cp_world_size == 1: 96 | return input_ 97 | 98 | group = get_context_parallel_group() 99 | cp_rank = get_context_parallel_rank() 100 | 101 | # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 102 | 103 | input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() 104 | if cp_rank == 0: 105 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() 106 | 107 | tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ 108 | torch.empty_like(input_) for _ in range(cp_world_size - 1) 109 | ] 110 | 111 | if cp_rank == 0: 112 | input_ = torch.cat([input_first_frame_, input_], dim=dim) 113 | 114 | tensor_list[cp_rank] = input_ 115 | torch.distributed.all_gather(tensor_list, input_, group=group) 116 | 117 | output = torch.cat(tensor_list, dim=dim).contiguous() 118 | 119 | # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) 120 | 121 | return output 122 | 123 | 124 | def _conv_split(input_, dim, kernel_size): 125 | cp_world_size = get_context_parallel_world_size() 126 | 127 | # Bypass the function if context parallel is 1 128 | if cp_world_size == 1: 129 | return input_ 130 | 131 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) 132 | 133 | cp_rank = get_context_parallel_rank() 134 | 135 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size 136 | 137 | if cp_rank == 0: 138 | output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) 139 | else: 140 | output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose( 141 | dim, 0 142 | ) 143 | output = output.contiguous() 144 | 145 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) 146 | 147 | return output 148 | 149 | 150 | def _conv_gather(input_, dim, kernel_size): 151 | cp_world_size = get_context_parallel_world_size() 152 | 153 | # Bypass the function if 
context parallel is 1 154 | if cp_world_size == 1: 155 | return input_ 156 | 157 | group = get_context_parallel_group() 158 | cp_rank = get_context_parallel_rank() 159 | 160 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) 161 | 162 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() 163 | if cp_rank == 0: 164 | input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() 165 | else: 166 | input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous() 167 | 168 | tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ 169 | torch.empty_like(input_) for _ in range(cp_world_size - 1) 170 | ] 171 | if cp_rank == 0: 172 | input_ = torch.cat([input_first_kernel_, input_], dim=dim) 173 | 174 | tensor_list[cp_rank] = input_ 175 | torch.distributed.all_gather(tensor_list, input_, group=group) 176 | 177 | # Note: torch.cat already creates a contiguous tensor. 178 | output = torch.cat(tensor_list, dim=dim).contiguous() 179 | 180 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) 181 | 182 | return output -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser 2 | from .discretizer import Discretization 3 | from .model import Decoder, Encoder, Model 4 | from .openaimodel import UNetModel 5 | from .sampling import BaseDiffusionSampler 6 | from .wrappers import OpenAIWrapper 7 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | 8 | 9 | class Denoiser(nn.Module): 10 | def __init__(self, weighting_config, scaling_config): 11 | super().__init__() 12 | 13 | self.weighting = instantiate_from_config(weighting_config) 14 | self.scaling = instantiate_from_config(scaling_config) 15 | 16 | def possibly_quantize_sigma(self, sigma): 17 | return sigma 18 | 19 | def possibly_quantize_c_noise(self, c_noise): 20 | return c_noise 21 | 22 | def w(self, sigma): 23 | return self.weighting(sigma) 24 | 25 | def forward( 26 | self, 27 | network: nn.Module, 28 | input: torch.Tensor, 29 | sigma: torch.Tensor, 30 | cond: Dict, 31 | **additional_model_inputs, 32 | ) -> torch.Tensor: 33 | sigma = self.possibly_quantize_sigma(sigma) 34 | sigma_shape = sigma.shape 35 | sigma = append_dims(sigma, input.ndim) 36 | c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) 37 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 38 | return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip 39 | 40 | 41 | class DiscreteDenoiser(Denoiser): 42 | def __init__( 43 | self, 44 | weighting_config, 45 | scaling_config, 46 | num_idx, 47 | discretization_config, 48 | do_append_zero=False, 49 | quantize_c_noise=True, 50 | flip=True, 51 | ): 52 | super().__init__(weighting_config, scaling_config) 53 | sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) 54 | self.sigmas = sigmas 55 | # self.register_buffer("sigmas", sigmas) 56 | self.quantize_c_noise = 
quantize_c_noise 57 | 58 | def sigma_to_idx(self, sigma): 59 | dists = sigma - self.sigmas.to(sigma.device)[:, None] 60 | return dists.abs().argmin(dim=0).view(sigma.shape) 61 | 62 | def idx_to_sigma(self, idx): 63 | return self.sigmas.to(idx.device)[idx] 64 | 65 | def possibly_quantize_sigma(self, sigma): 66 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 67 | 68 | def possibly_quantize_c_noise(self, c_noise): 69 | if self.quantize_c_noise: 70 | return self.sigma_to_idx(c_noise) 71 | else: 72 | return c_noise 73 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 10 | pass 11 | 12 | 13 | class EDMScaling: 14 | def __init__(self, sigma_data: float = 0.5): 15 | self.sigma_data = sigma_data 16 | 17 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 18 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 19 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 20 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 21 | c_noise = 0.25 * sigma.log() 22 | return c_skip, c_out, c_in, c_noise 23 | 24 | 25 | class EpsScaling: 26 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 27 | c_skip = torch.ones_like(sigma, device=sigma.device) 28 | c_out = -sigma 29 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 30 | c_noise = sigma.clone() 31 | return c_skip, c_out, c_in, c_noise 32 | 33 | 34 | class VScaling: 35 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 36 | c_skip = 1.0 / (sigma**2 + 1.0) 37 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 38 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 39 | c_noise = sigma.clone() 40 | return c_skip, c_out, c_in, c_noise 41 | 42 | 43 | class VScalingWithEDMcNoise(DenoiserScaling): 44 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 45 | c_skip = 1.0 / (sigma**2 + 1.0) 46 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 47 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 48 | c_noise = 0.25 * sigma.log() 49 | return c_skip, c_out, c_in, c_noise 50 | 51 | 52 | class VideoScaling: # similar to VScaling 53 | def __call__( 54 | self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs 55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 56 | c_skip = alphas_cumprod_sqrt 57 | c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5) 58 | c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) 59 | c_noise = additional_model_inputs["idx"].clone() 60 | return c_skip, c_out, c_in, c_noise 61 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, 
sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray: 12 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 13 | 14 | 15 | class Discretization: 16 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False): 17 | if return_idx: 18 | sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx) 19 | else: 20 | sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) 21 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 22 | if return_idx: 23 | return sigmas if not flip else torch.flip(sigmas, (0,)), idx 24 | else: 25 | return sigmas if not flip else torch.flip(sigmas, (0,)) 26 | 27 | @abstractmethod 28 | def get_sigmas(self, n, device): 29 | pass 30 | 31 | 32 | class EDMDiscretization(Discretization): 33 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 34 | self.sigma_min = sigma_min 35 | self.sigma_max = sigma_max 36 | self.rho = rho 37 | 38 | def get_sigmas(self, n, device="cpu"): 39 | ramp = torch.linspace(0, 1, n, device=device) 40 | min_inv_rho = self.sigma_min ** (1 / self.rho) 41 | max_inv_rho = self.sigma_max ** (1 / self.rho) 42 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 43 | return sigmas 44 | 45 | 46 | class LegacyDDPMDiscretization(Discretization): 47 | def __init__( 48 | self, 49 | linear_start=0.00085, 50 | linear_end=0.0120, 51 | num_timesteps=1000, 52 | ): 53 | super().__init__() 54 | self.num_timesteps = num_timesteps 55 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) 56 | alphas = 1.0 - betas 57 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 58 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 59 | 60 | def get_sigmas(self, n, device="cpu"): 61 | if n < self.num_timesteps: 62 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 63 | alphas_cumprod = self.alphas_cumprod[timesteps] 64 | elif n == self.num_timesteps: 65 | alphas_cumprod = self.alphas_cumprod 66 | else: 67 | raise ValueError 68 | 69 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 70 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 71 | return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029 72 | 73 | 74 | class ZeroSNRDDPMDiscretization(Discretization): 75 | def __init__( 76 | self, 77 | linear_start=0.00085, 78 | linear_end=0.0120, 79 | num_timesteps=1000, 80 | shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) 81 | keep_start=False, 82 | post_shift=False, 83 | ): 84 | super().__init__() 85 | if keep_start and not post_shift: 86 | linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) 87 | self.num_timesteps = 
num_timesteps 88 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) 89 | alphas = 1.0 - betas 90 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 91 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 92 | 93 | # SNR shift 94 | if not post_shift: 95 | self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) 96 | 97 | self.post_shift = post_shift 98 | self.shift_scale = shift_scale 99 | 100 | def get_sigmas(self, n, device="cpu", return_idx=False): 101 | if n < self.num_timesteps: 102 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 103 | alphas_cumprod = self.alphas_cumprod[timesteps] 104 | elif n == self.num_timesteps: 105 | alphas_cumprod = self.alphas_cumprod 106 | else: 107 | raise ValueError 108 | 109 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 110 | alphas_cumprod = to_torch(alphas_cumprod) 111 | alphas_cumprod_sqrt = alphas_cumprod.sqrt() 112 | alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() 113 | alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() 114 | 115 | alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T 116 | alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) 117 | 118 | if self.post_shift: 119 | alphas_cumprod_sqrt = ( 120 | alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) 121 | ) ** 0.5 122 | 123 | if return_idx: 124 | return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps 125 | else: 126 | return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99 127 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Optional, Tuple, Union 4 | from functools import partial 5 | import math 6 | 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | from ...util import append_dims, default, instantiate_from_config 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: 19 | pass 20 | 21 | 22 | class VanillaCFG: 23 | """ 24 | implements parallelized CFG 25 | """ 26 | 27 | def __init__(self, scale, dyn_thresh_config=None): 28 | self.scale = scale 29 | scale_schedule = lambda scale, sigma: scale # independent of step 30 | self.scale_schedule = partial(scale_schedule, scale) 31 | self.dyn_thresh = instantiate_from_config( 32 | default( 33 | dyn_thresh_config, 34 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, 35 | ) 36 | ) 37 | 38 | def __call__(self, x, sigma, scale=None): 39 | x_u, x_c = x.chunk(2) 40 | scale_value = default(scale, self.scale_schedule(sigma)) 41 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 42 | return x_pred 43 | 44 | def prepare_inputs(self, x, s, c, uc): 45 | c_out = dict() 46 | 47 | for k in c: 48 | if k in ["vector", "crossattn", "concat"]: 49 | c_out[k] = torch.cat((uc[k], c[k]), 0) 50 | else: 51 | assert c[k] == uc[k] 52 | c_out[k] = c[k] 53 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 54 | 55 | 56 | class DynamicCFG(VanillaCFG): 57 | def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): 58 
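# DynamicCFG ramps the guidance weight over the sampling trajectory instead of keeping it
# fixed: at step i of num_steps the effective scale is
#   1 + scale * (1 - cos(pi * (i / num_steps) ** exp)) / 2,
# growing smoothly from 1 at the first step towards 1 + scale at the last.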
| super().__init__(scale, dyn_thresh_config) 59 | scale_schedule = ( 60 | lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 61 | ) 62 | self.scale_schedule = partial(scale_schedule, scale) 63 | self.dyn_thresh = instantiate_from_config( 64 | default( 65 | dyn_thresh_config, 66 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, 67 | ) 68 | ) 69 | 70 | def __call__(self, x, sigma, step_index, scale=None): 71 | x_u, x_c = x.chunk(2) 72 | scale_value = self.scale_schedule(sigma, step_index.item()) 73 | x_pred = self.dyn_thresh(x_u, x_c, scale_value) 74 | return x_pred 75 | 76 | 77 | class IdentityGuider: 78 | def __call__(self, x, sigma): 79 | return x 80 | 81 | def prepare_inputs(self, x, s, c, uc): 82 | c_out = dict() 83 | 84 | for k in c: 85 | c_out[k] = c[k] 86 | 87 | return x, s, c_out 88 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from omegaconf import ListConfig 7 | import math 8 | 9 | from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler 10 | from ...util import append_dims, instantiate_from_config 11 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 12 | 13 | # import rearrange 14 | from einops import rearrange 15 | import random 16 | from sat import mpu 17 | 18 | 19 | class StandardDiffusionLoss(nn.Module): 20 | def __init__( 21 | self, 22 | sigma_sampler_config, 23 | type="l2", 24 | offset_noise_level=0.0, 25 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, 26 | ): 27 | super().__init__() 28 | 29 | assert type in ["l2", "l1", "lpips"] 30 | 31 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 32 | 33 | self.type = type 34 | self.offset_noise_level = offset_noise_level 35 | 36 | if type == "lpips": 37 | self.lpips = LPIPS().eval() 38 | 39 | if not batch2model_keys: 40 | batch2model_keys = [] 41 | 42 | if isinstance(batch2model_keys, str): 43 | batch2model_keys = [batch2model_keys] 44 | 45 | self.batch2model_keys = set(batch2model_keys) 46 | 47 | def __call__(self, network, denoiser, conditioner, input, batch): 48 | cond = conditioner(batch) 49 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} 50 | 51 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 52 | noise = torch.randn_like(input) 53 | if self.offset_noise_level > 0.0: 54 | noise = ( 55 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level 56 | ) 57 | noise = noise.to(input.dtype) 58 | noised_input = input.float() + noise * append_dims(sigmas, input.ndim) 59 | model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs) 60 | w = append_dims(denoiser.w(sigmas), input.ndim) 61 | return self.get_loss(model_output, input, w) 62 | 63 | def get_loss(self, model_output, target, w): 64 | if self.type == "l2": 65 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) 66 | elif self.type == "l1": 67 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) 68 | elif self.type == "lpips": 69 | loss = self.lpips(model_output, target).reshape(-1) 70 | return loss 71 | 72 | 73 | 
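# A minimal, self-contained sketch (illustrative names only): StandardDiffusionLoss above
# corrupts the clean latent as x_noised = x + sigma * eps and scores the denoiser output with
# a sigma-dependent weight w(sigma). The helper below redoes that weighted-L2 reduction with
# plain torch, borrowing the EDMWeighting formula (sigma_data = 0.5) from denoiser_weighting.py.
def _weighted_l2_sketch():
    import torch

    x = torch.randn(2, 4, 8, 8)                   # stand-in for clean latents
    sigma = torch.rand(2) * 2 + 0.1               # one noise level per sample
    sigma_b = sigma.view(-1, 1, 1, 1)             # same role as append_dims(sigma, x.ndim)
    x_noised = x + sigma_b * torch.randn_like(x)  # EDM-style corruption
    model_output = x_noised                       # stand-in for denoiser(network, ...)
    w = (sigma**2 + 0.5**2) / (sigma * 0.5) ** 2  # EDMWeighting(sigma_data=0.5)
    w_b = w.view(-1, 1, 1, 1)
    # same reduction as get_loss(..., type="l2"): mean over all non-batch dimensions
    return torch.mean((w_b * (model_output - x) ** 2).reshape(x.shape[0], -1), 1)  # shape (2,)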
class VideoDiffusionLoss(StandardDiffusionLoss): 74 | def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs): 75 | self.fixed_frames = fixed_frames 76 | self.block_scale = block_scale 77 | self.block_size = block_size 78 | self.min_snr_value = min_snr_value 79 | super().__init__(**kwargs) 80 | 81 | def __call__(self, network, denoiser, conditioner, input, batch): 82 | cond = conditioner(batch) 83 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} 84 | 85 | alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True) 86 | alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) 87 | idx = idx.to(input.device) 88 | 89 | noise = torch.randn_like(input) 90 | 91 | # broadcast noise 92 | mp_size = mpu.get_model_parallel_world_size() 93 | global_rank = torch.distributed.get_rank() // mp_size 94 | src = global_rank * mp_size 95 | torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group()) 96 | torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group()) 97 | torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()) 98 | 99 | additional_model_inputs["idx"] = idx 100 | 101 | if self.offset_noise_level > 0.0: 102 | noise = ( 103 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level 104 | ) 105 | 106 | noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims( 107 | (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim 108 | ) 109 | 110 | model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs) 111 | w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred 112 | 113 | if self.min_snr_value is not None: 114 | w = min(w, self.min_snr_value) 115 | return self.get_loss(model_output, input, w) 116 | 117 | def get_loss(self, model_output, target, w): 118 | if self.type == "l2": 119 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) 120 | elif self.type == "l1": 121 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) 122 | elif self.type == "lpips": 123 | loss = self.lpips(model_output, target).reshape(-1) 124 | return loss 125 | 126 | 127 | def get_3d_position_ids(frame_len, h, w): 128 | i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w) 129 | j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w) 130 | k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w) 131 | position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3) 132 | return position_ids 133 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | from einops import rearrange 6 | 7 | 8 | class NoDynamicThresholding: 9 | def __call__(self, uncond, cond, scale): 10 | scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale 11 | return uncond + scale * (cond - uncond) 12 | 13 | 14 | class StaticThresholding: 15 | def __call__(self, uncond, cond, scale): 16 | result = uncond + scale * (cond - uncond) 17 | result = torch.clamp(result, min=-1.0, max=1.0) 18 | return result 19 | 20 | 21 | def dynamic_threshold(x, p=0.95): 22 | N, T, C, H, W = 
x.shape 23 | x = rearrange(x, "n t c h w -> n c (t h w)") 24 | l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) 25 | s = torch.maximum(-l, r) 26 | threshold_mask = (s > 1).expand(-1, -1, H * W * T) 27 | if threshold_mask.any(): 28 | x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) 29 | x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W) 30 | return x 31 | 32 | 33 | def dynamic_thresholding2(x0): 34 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 35 | origin_dtype = x0.dtype 36 | x0 = x0.to(torch.float32) 37 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 38 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) 39 | x0 = torch.clamp(x0, -s, s) # / s 40 | return x0.to(origin_dtype) 41 | 42 | 43 | def latent_dynamic_thresholding(x0): 44 | p = 0.9995 45 | origin_dtype = x0.dtype 46 | x0 = x0.to(torch.float32) 47 | s = torch.quantile(torch.abs(x0), p, dim=2) 48 | s = append_dims(s, x0.dim()) 49 | x0 = torch.clamp(x0, -s, s) / s 50 | return x0.to(origin_dtype) 51 | 52 | 53 | def dynamic_thresholding3(x0): 54 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 55 | origin_dtype = x0.dtype 56 | x0 = x0.to(torch.float32) 57 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 58 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) 59 | x0 = torch.clamp(x0, -s, s) # / s 60 | return x0.to(origin_dtype) 61 | 62 | 63 | class DynamicThresholding: 64 | def __call__(self, uncond, cond, scale): 65 | mean = uncond.mean() 66 | std = uncond.std() 67 | result = uncond + scale * (cond - uncond) 68 | result_mean, result_std = result.mean(), result.std() 69 | result = (result - result_mean) / result_std * std 70 | # result = dynamic_thresholding3(result) 71 | return result 72 | 73 | 74 | class DynamicThresholdingV1: 75 | def __init__(self, scale_factor): 76 | self.scale_factor = scale_factor 77 | 78 | def __call__(self, uncond, cond, scale): 79 | result = uncond + scale * (cond - uncond) 80 | unscaled_result = result / self.scale_factor 81 | B, T, C, H, W = unscaled_result.shape 82 | flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)") 83 | means = flattened.mean(dim=2).unsqueeze(2) 84 | recentered = flattened - means 85 | magnitudes = recentered.abs().max() 86 | normalized = recentered / magnitudes 87 | thresholded = latent_dynamic_thresholding(normalized) 88 | denormalized = thresholded * magnitudes 89 | uncentered = denormalized + means 90 | unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W) 91 | scaled_result = unflattened * self.scale_factor 92 | return scaled_result 93 | 94 | 95 | class DynamicThresholdingV2: 96 | def __call__(self, uncond, cond, scale): 97 | B, T, C, H, W = uncond.shape 98 | diff = cond - uncond 99 | mim_target = uncond + diff * 4.0 100 | cfg_target = uncond + diff * 8.0 101 | 102 | mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)") 103 | cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)") 104 | mim_means = mim_flattened.mean(dim=2).unsqueeze(2) 105 | cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) 106 | mim_centered = mim_flattened - mim_means 107 | cfg_centered = cfg_flattened - cfg_means 108 | 109 | mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) 110 | cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) 111 | 112 | cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref 113 | 114 | result = cfg_renormalized + cfg_means 115 | 
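# The renormalization above is a CFG-rescale trick: the strongly guided target (scale 8.0)
# is re-centered per channel, its standard deviation is matched to that of the mildly guided
# target (scale 4.0), and its own mean is added back, which tames the over-saturation that
# large guidance scales tend to cause.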
unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W) 116 | 117 | return unflattened 118 | 119 | 120 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 121 | if order - 1 > i: 122 | raise ValueError(f"Order {order} too high for step {i}") 123 | 124 | def fn(tau): 125 | prod = 1.0 126 | for k in range(order): 127 | if j == k: 128 | continue 129 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 130 | return prod 131 | 132 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 133 | 134 | 135 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 136 | if not eta: 137 | return sigma_to, 0.0 138 | sigma_up = torch.minimum( 139 | sigma_to, 140 | eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 141 | ) 142 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 143 | return sigma_down, sigma_up 144 | 145 | 146 | def to_d(x, sigma, denoised): 147 | return (x - denoised) / append_dims(sigma, x.ndim) 148 | 149 | 150 | def to_neg_log_sigma(sigma): 151 | return sigma.log().neg() 152 | 153 | 154 | def to_sigma(neg_log_sigma): 155 | return neg_log_sigma.neg().exp() 156 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed 3 | 4 | from sat import mpu 5 | 6 | from ...util import default, instantiate_from_config 7 | 8 | 9 | class EDMSampling: 10 | def __init__(self, p_mean=-1.2, p_std=1.2): 11 | self.p_mean = p_mean 12 | self.p_std = p_std 13 | 14 | def __call__(self, n_samples, rand=None): 15 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 16 | return log_sigma.exp() 17 | 18 | 19 | class DiscreteSampling: 20 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): 21 | self.num_idx = num_idx 22 | self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) 23 | world_size = mpu.get_data_parallel_world_size() 24 | self.uniform_sampling = uniform_sampling 25 | if self.uniform_sampling: 26 | i = 1 27 | while True: 28 | if world_size % i != 0 or num_idx % (world_size // i) != 0: 29 | i += 1 30 | else: 31 | self.group_num = world_size // i 32 | break 33 | 34 | assert self.group_num > 0 35 | assert world_size % self.group_num == 0 36 | self.group_width = world_size // self.group_num # the number of rank in one group 37 | self.sigma_interval = self.num_idx // self.group_num 38 | 39 | def idx_to_sigma(self, idx): 40 | return self.sigmas[idx] 41 | 42 | def __call__(self, n_samples, rand=None, return_idx=False): 43 | if self.uniform_sampling: 44 | rank = mpu.get_data_parallel_rank() 45 | group_index = rank // self.group_width 46 | idx = default( 47 | rand, 48 | torch.randint( 49 | group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) 50 | ), 51 | ) 52 | else: 53 | idx = default( 54 | rand, 55 | torch.randint(0, self.num_idx, (n_samples,)), 56 | ) 57 | if return_idx: 58 | return self.idx_to_sigma(idx), idx 59 | else: 60 | return self.idx_to_sigma(idx) 61 | 62 | 63 | class PartialDiscreteSampling: 64 | def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): 65 | self.total_num_idx = total_num_idx 66 | self.partial_num_idx = partial_num_idx 67 | self.sigmas = instantiate_from_config(discretization_config)( 68 | total_num_idx, 
do_append_zero=do_append_zero, flip=flip 69 | ) 70 | 71 | def idx_to_sigma(self, idx): 72 | return self.sigmas[idx] 73 | 74 | def __call__(self, n_samples, rand=None): 75 | idx = default( 76 | rand, 77 | # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), 78 | torch.randint(0, self.partial_num_idx, (n_samples,)), 79 | ) 80 | return self.idx_to_sigma(idx) 81 | -------------------------------------------------------------------------------- /sat/sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model 14 | else lambda x: x 15 | ) 16 | self.diffusion_model = compile(diffusion_model) 17 | self.dtype = dtype 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: 25 | for key in c: 26 | c[key] = c[key].to(self.dtype) 27 | 28 | if x.dim() == 4: 29 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 30 | elif x.dim() == 5: 31 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2) 32 | else: 33 | raise ValueError("Input tensor must be 4D or 5D") 34 | 35 | return self.diffusion_model( 36 | x, 37 | timesteps=t, 38 | context=c.get("crossattn", None), 39 | y=c.get("vector", None), 40 | **kwargs, 41 | ) 42 | -------------------------------------------------------------------------------- /sat/sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 37 | # device=self.parameters.device 38 | # ) 39 | x = self.mean + self.std * 
torch.randn_like(self.mean).to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.0]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum( 48 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3], 50 | ) 51 | else: 52 | return 0.5 * torch.sum( 53 | torch.pow(self.mean - other.mean, 2) / other.var 54 | + self.var / other.var 55 | - 1.0 56 | - self.logvar 57 | + other.logvar, 58 | dim=[1, 2, 3], 59 | ) 60 | 61 | def nll(self, sample, dims=[1, 2, 3]): 62 | if self.deterministic: 63 | return torch.Tensor([0.0]) 64 | logtwopi = np.log(2.0 * np.pi) 65 | return 0.5 * torch.sum( 66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 67 | dim=dims, 68 | ) 69 | 70 | def mode(self): 71 | return self.mean 72 | 73 | 74 | def normal_kl(mean1, logvar1, mean2, logvar2): 75 | """ 76 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 77 | Compute the KL divergence between two gaussians. 78 | Shapes are automatically broadcasted, so batches can be compared to 79 | scalars, among other use cases. 80 | """ 81 | tensor = None 82 | for obj in (mean1, logvar1, mean2, logvar2): 83 | if isinstance(obj, torch.Tensor): 84 | tensor = obj 85 | break 86 | assert tensor is not None, "at least one argument must be a Tensor" 87 | 88 | # Force variances to be Tensors. Broadcasting helps convert scalars to 89 | # Tensors, but it does not work for torch.exp(). 90 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 91 | 92 | return 0.5 * ( 93 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 94 | ) 95 | -------------------------------------------------------------------------------- /sat/sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), 16 | ) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | # remove as '.'-character is not allowed in buffers 21 | s_name = name.replace(".", "") 22 | self.m_name2s_name.update({name: s_name}) 23 | self.register_buffer(s_name, p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def reset_num_updates(self): 28 | del self.num_updates 29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | 
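# The in-place update below is shadow <- shadow - (1 - decay) * (shadow - param),
# which is algebraically the usual EMA step shadow <- decay * shadow + (1 - decay) * param.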
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /sat/sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sat/sgm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import nullcontext 3 | from functools import partial 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import kornia 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange, repeat 11 | from omegaconf import ListConfig 12 | from torch.utils.checkpoint import checkpoint 13 | from transformers import ( 14 | T5EncoderModel, 15 | T5Tokenizer, 16 | ) 17 | 18 | from ...util import ( 19 | append_dims, 20 | autocast, 21 | count_params, 22 | default, 23 | disabled_train, 24 | expand_dims_like, 25 | instantiate_from_config, 26 | ) 27 | 28 | 29 | class AbstractEmbModel(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self._is_trainable = None 33 | self._ucg_rate = None 34 | self._input_key = None 35 | 36 | @property 37 | def is_trainable(self) -> bool: 38 | return self._is_trainable 39 | 40 | @property 41 | def ucg_rate(self) -> Union[float, torch.Tensor]: 42 | return self._ucg_rate 43 | 44 | @property 45 | def input_key(self) -> str: 46 | return self._input_key 47 | 48 | @is_trainable.setter 49 | def is_trainable(self, value: bool): 50 | self._is_trainable = value 51 | 52 | @ucg_rate.setter 53 | def ucg_rate(self, value: Union[float, torch.Tensor]): 54 | self._ucg_rate = value 55 | 56 | @input_key.setter 57 | def input_key(self, value: str): 58 | self._input_key = value 59 | 60 | @is_trainable.deleter 61 | def is_trainable(self): 62 | del self._is_trainable 63 | 64 | @ucg_rate.deleter 65 | def ucg_rate(self): 66 | del self._ucg_rate 67 | 68 | @input_key.deleter 69 | def input_key(self): 70 | del self._input_key 71 | 72 | 73 | class 
GeneralConditioner(nn.Module): 74 | OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} 75 | KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} 76 | 77 | def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]): 78 | super().__init__() 79 | embedders = [] 80 | for n, embconfig in enumerate(emb_models): 81 | embedder = instantiate_from_config(embconfig) 82 | assert isinstance( 83 | embedder, AbstractEmbModel 84 | ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" 85 | embedder.is_trainable = embconfig.get("is_trainable", False) 86 | embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) 87 | if not embedder.is_trainable: 88 | embedder.train = disabled_train 89 | for param in embedder.parameters(): 90 | param.requires_grad = False 91 | embedder.eval() 92 | print( 93 | f"Initialized embedder #{n}: {embedder.__class__.__name__} " 94 | f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" 95 | ) 96 | 97 | if "input_key" in embconfig: 98 | embedder.input_key = embconfig["input_key"] 99 | elif "input_keys" in embconfig: 100 | embedder.input_keys = embconfig["input_keys"] 101 | else: 102 | raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") 103 | 104 | embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) 105 | if embedder.legacy_ucg_val is not None: 106 | embedder.ucg_prng = np.random.RandomState() 107 | 108 | embedders.append(embedder) 109 | self.embedders = nn.ModuleList(embedders) 110 | 111 | if len(cor_embs) > 0: 112 | assert len(cor_p) == 2 ** len(cor_embs) 113 | self.cor_embs = cor_embs 114 | self.cor_p = cor_p 115 | 116 | def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: 117 | assert embedder.legacy_ucg_val is not None 118 | p = embedder.ucg_rate 119 | val = embedder.legacy_ucg_val 120 | for i in range(len(batch[embedder.input_key])): 121 | if embedder.ucg_prng.choice(2, p=[1 - p, p]): 122 | batch[embedder.input_key][i] = val 123 | return batch 124 | 125 | def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict: 126 | assert embedder.legacy_ucg_val is not None 127 | val = embedder.legacy_ucg_val 128 | for i in range(len(batch[embedder.input_key])): 129 | if cond_or_not[i]: 130 | batch[embedder.input_key][i] = val 131 | return batch 132 | 133 | def get_single_embedding( 134 | self, 135 | embedder, 136 | batch, 137 | output, 138 | cond_or_not: Optional[np.ndarray] = None, 139 | force_zero_embeddings: Optional[List] = None, 140 | ): 141 | embedding_context = nullcontext if embedder.is_trainable else torch.no_grad 142 | with embedding_context(): 143 | if hasattr(embedder, "input_key") and (embedder.input_key is not None): 144 | if embedder.legacy_ucg_val is not None: 145 | if cond_or_not is None: 146 | batch = self.possibly_get_ucg_val(embedder, batch) 147 | else: 148 | batch = self.surely_get_ucg_val(embedder, batch, cond_or_not) 149 | emb_out = embedder(batch[embedder.input_key]) 150 | elif hasattr(embedder, "input_keys"): 151 | emb_out = embedder(*[batch[k] for k in embedder.input_keys]) 152 | assert isinstance( 153 | emb_out, (torch.Tensor, list, tuple) 154 | ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" 155 | if not isinstance(emb_out, (list, tuple)): 156 | emb_out = [emb_out] 157 | for emb in emb_out: 158 | out_key = self.OUTPUT_DIM2KEYS[emb.dim()] 159 | if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: 160 | 
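# Classifier-free-guidance dropout: with probability ucg_rate a sample's embedding is zeroed
# (the Bernoulli mask below keeps it with probability 1 - ucg_rate), so the conditioner also
# produces the unconditional inputs needed for CFG at sampling time.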
if cond_or_not is None: 161 | emb = ( 162 | expand_dims_like( 163 | torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), 164 | emb, 165 | ) 166 | * emb 167 | ) 168 | else: 169 | emb = ( 170 | expand_dims_like( 171 | torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), 172 | emb, 173 | ) 174 | * emb 175 | ) 176 | if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: 177 | emb = torch.zeros_like(emb) 178 | if out_key in output: 179 | output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key]) 180 | else: 181 | output[out_key] = emb 182 | return output 183 | 184 | def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict: 185 | output = dict() 186 | if force_zero_embeddings is None: 187 | force_zero_embeddings = [] 188 | 189 | if len(self.cor_embs) > 0: 190 | batch_size = len(batch[list(batch.keys())[0]]) 191 | rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p) 192 | for emb_idx in self.cor_embs: 193 | cond_or_not = rand_idx % 2 194 | rand_idx //= 2 195 | output = self.get_single_embedding( 196 | self.embedders[emb_idx], 197 | batch, 198 | output=output, 199 | cond_or_not=cond_or_not, 200 | force_zero_embeddings=force_zero_embeddings, 201 | ) 202 | 203 | for i, embedder in enumerate(self.embedders): 204 | if i in self.cor_embs: 205 | continue 206 | output = self.get_single_embedding( 207 | embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings 208 | ) 209 | return output 210 | 211 | def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None): 212 | if force_uc_zero_embeddings is None: 213 | force_uc_zero_embeddings = [] 214 | ucg_rates = list() 215 | for embedder in self.embedders: 216 | ucg_rates.append(embedder.ucg_rate) 217 | embedder.ucg_rate = 0.0 218 | cor_embs = self.cor_embs 219 | cor_p = self.cor_p 220 | self.cor_embs = [] 221 | self.cor_p = [] 222 | 223 | c = self(batch_c) 224 | uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) 225 | 226 | for embedder, rate in zip(self.embedders, ucg_rates): 227 | embedder.ucg_rate = rate 228 | self.cor_embs = cor_embs 229 | self.cor_p = cor_p 230 | 231 | return c, uc 232 | 233 | 234 | class FrozenT5Embedder(AbstractEmbModel): 235 | """Uses the T5 transformer encoder for text""" 236 | 237 | def __init__( 238 | self, 239 | model_dir="google/t5-v1_1-xxl", 240 | device="cuda", 241 | max_length=77, 242 | freeze=True, 243 | cache_dir=None, 244 | ): 245 | super().__init__() 246 | if model_dir is not "google/t5-v1_1-xxl": 247 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir) 248 | self.transformer = T5EncoderModel.from_pretrained(model_dir) 249 | else: 250 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir) 251 | self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir) 252 | self.device = device 253 | self.max_length = max_length 254 | if freeze: 255 | self.freeze() 256 | 257 | def freeze(self): 258 | self.transformer = self.transformer.eval() 259 | 260 | for param in self.parameters(): 261 | param.requires_grad = False 262 | 263 | # @autocast 264 | def forward(self, text): 265 | batch_encoding = self.tokenizer( 266 | text, 267 | truncation=True, 268 | max_length=self.max_length, 269 | return_length=True, 270 | return_overflowing_tokens=False, 271 | padding="max_length", 272 | return_tensors="pt", 273 | ) 274 | tokens = 
batch_encoding["input_ids"].to(self.device) 275 | with torch.autocast("cuda", enabled=False): 276 | outputs = self.transformer(input_ids=tokens) 277 | z = outputs.last_hidden_state 278 | return z 279 | 280 | def encode(self, text): 281 | return self(text) 282 | -------------------------------------------------------------------------------- /sat/sgm/modules/video_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..modules.attention import * 4 | from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding 5 | 6 | 7 | class TimeMixSequential(nn.Sequential): 8 | def forward(self, x, context=None, timesteps=None): 9 | for layer in self: 10 | x = layer(x, context, timesteps) 11 | 12 | return x 13 | 14 | 15 | class VideoTransformerBlock(nn.Module): 16 | ATTENTION_MODES = { 17 | "softmax": CrossAttention, 18 | "softmax-xformers": MemoryEfficientCrossAttention, 19 | } 20 | 21 | def __init__( 22 | self, 23 | dim, 24 | n_heads, 25 | d_head, 26 | dropout=0.0, 27 | context_dim=None, 28 | gated_ff=True, 29 | checkpoint=True, 30 | timesteps=None, 31 | ff_in=False, 32 | inner_dim=None, 33 | attn_mode="softmax", 34 | disable_self_attn=False, 35 | disable_temporal_crossattention=False, 36 | switch_temporal_ca_to_sa=False, 37 | ): 38 | super().__init__() 39 | 40 | attn_cls = self.ATTENTION_MODES[attn_mode] 41 | 42 | self.ff_in = ff_in or inner_dim is not None 43 | if inner_dim is None: 44 | inner_dim = dim 45 | 46 | assert int(n_heads * d_head) == inner_dim 47 | 48 | self.is_res = inner_dim == dim 49 | 50 | if self.ff_in: 51 | self.norm_in = nn.LayerNorm(dim) 52 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff) 53 | 54 | self.timesteps = timesteps 55 | self.disable_self_attn = disable_self_attn 56 | if self.disable_self_attn: 57 | self.attn1 = attn_cls( 58 | query_dim=inner_dim, 59 | heads=n_heads, 60 | dim_head=d_head, 61 | context_dim=context_dim, 62 | dropout=dropout, 63 | ) # is a cross-attention 64 | else: 65 | self.attn1 = attn_cls( 66 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 67 | ) # is a self-attention 68 | 69 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) 70 | 71 | if disable_temporal_crossattention: 72 | if switch_temporal_ca_to_sa: 73 | raise ValueError 74 | else: 75 | self.attn2 = None 76 | else: 77 | self.norm2 = nn.LayerNorm(inner_dim) 78 | if switch_temporal_ca_to_sa: 79 | self.attn2 = attn_cls( 80 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 81 | ) # is a self-attention 82 | else: 83 | self.attn2 = attn_cls( 84 | query_dim=inner_dim, 85 | context_dim=context_dim, 86 | heads=n_heads, 87 | dim_head=d_head, 88 | dropout=dropout, 89 | ) # is self-attn if context is none 90 | 91 | self.norm1 = nn.LayerNorm(inner_dim) 92 | self.norm3 = nn.LayerNorm(inner_dim) 93 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 94 | 95 | self.checkpoint = checkpoint 96 | if self.checkpoint: 97 | print(f"{self.__class__.__name__} is using checkpointing") 98 | 99 | def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor: 100 | if self.checkpoint: 101 | return checkpoint(self._forward, x, context, timesteps) 102 | else: 103 | return self._forward(x, context, timesteps=timesteps) 104 | 105 | def _forward(self, x, context=None, timesteps=None): 106 | assert self.timesteps or timesteps 107 | assert not (self.timesteps and timesteps) or 
self.timesteps == timesteps 108 | timesteps = self.timesteps or timesteps 109 | B, S, C = x.shape 110 | x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) 111 | 112 | if self.ff_in: 113 | x_skip = x 114 | x = self.ff_in(self.norm_in(x)) 115 | if self.is_res: 116 | x += x_skip 117 | 118 | if self.disable_self_attn: 119 | x = self.attn1(self.norm1(x), context=context) + x 120 | else: 121 | x = self.attn1(self.norm1(x)) + x 122 | 123 | if self.attn2 is not None: 124 | if self.switch_temporal_ca_to_sa: 125 | x = self.attn2(self.norm2(x)) + x 126 | else: 127 | x = self.attn2(self.norm2(x), context=context) + x 128 | x_skip = x 129 | x = self.ff(self.norm3(x)) 130 | if self.is_res: 131 | x += x_skip 132 | 133 | x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps) 134 | return x 135 | 136 | def get_last_layer(self): 137 | return self.ff.net[-1].weight 138 | 139 | 140 | str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 141 | 142 | 143 | class SpatialVideoTransformer(SpatialTransformer): 144 | def __init__( 145 | self, 146 | in_channels, 147 | n_heads, 148 | d_head, 149 | depth=1, 150 | dropout=0.0, 151 | use_linear=False, 152 | context_dim=None, 153 | use_spatial_context=False, 154 | timesteps=None, 155 | merge_strategy: str = "fixed", 156 | merge_factor: float = 0.5, 157 | time_context_dim=None, 158 | ff_in=False, 159 | checkpoint=False, 160 | time_depth=1, 161 | attn_mode="softmax", 162 | disable_self_attn=False, 163 | disable_temporal_crossattention=False, 164 | max_time_embed_period: int = 10000, 165 | dtype="fp32", 166 | ): 167 | super().__init__( 168 | in_channels, 169 | n_heads, 170 | d_head, 171 | depth=depth, 172 | dropout=dropout, 173 | attn_type=attn_mode, 174 | use_checkpoint=checkpoint, 175 | context_dim=context_dim, 176 | use_linear=use_linear, 177 | disable_self_attn=disable_self_attn, 178 | ) 179 | self.time_depth = time_depth 180 | self.depth = depth 181 | self.max_time_embed_period = max_time_embed_period 182 | 183 | time_mix_d_head = d_head 184 | n_time_mix_heads = n_heads 185 | 186 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) 187 | 188 | inner_dim = n_heads * d_head 189 | if use_spatial_context: 190 | time_context_dim = context_dim 191 | 192 | self.time_stack = nn.ModuleList( 193 | [ 194 | VideoTransformerBlock( 195 | inner_dim, 196 | n_time_mix_heads, 197 | time_mix_d_head, 198 | dropout=dropout, 199 | context_dim=time_context_dim, 200 | timesteps=timesteps, 201 | checkpoint=checkpoint, 202 | ff_in=ff_in, 203 | inner_dim=time_mix_inner_dim, 204 | attn_mode=attn_mode, 205 | disable_self_attn=disable_self_attn, 206 | disable_temporal_crossattention=disable_temporal_crossattention, 207 | ) 208 | for _ in range(self.depth) 209 | ] 210 | ) 211 | 212 | assert len(self.time_stack) == len(self.transformer_blocks) 213 | 214 | self.use_spatial_context = use_spatial_context 215 | self.in_channels = in_channels 216 | 217 | time_embed_dim = self.in_channels * 4 218 | self.time_pos_embed = nn.Sequential( 219 | linear(self.in_channels, time_embed_dim), 220 | nn.SiLU(), 221 | linear(time_embed_dim, self.in_channels), 222 | ) 223 | 224 | self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy) 225 | self.dtype = str_to_dtype[dtype] 226 | 227 | def forward( 228 | self, 229 | x: torch.Tensor, 230 | context: Optional[torch.Tensor] = None, 231 | time_context: Optional[torch.Tensor] = None, 232 | timesteps: Optional[int] = None, 233 | image_only_indicator: Optional[torch.Tensor] = 
None, 234 | ) -> torch.Tensor: 235 | _, _, h, w = x.shape 236 | x_in = x 237 | spatial_context = None 238 | if exists(context): 239 | spatial_context = context 240 | 241 | if self.use_spatial_context: 242 | assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}" 243 | 244 | time_context = context 245 | time_context_first_timestep = time_context[::timesteps] 246 | time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w) 247 | elif time_context is not None and not self.use_spatial_context: 248 | time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) 249 | if time_context.ndim == 2: 250 | time_context = rearrange(time_context, "b c -> b 1 c") 251 | 252 | x = self.norm(x) 253 | if not self.use_linear: 254 | x = self.proj_in(x) 255 | x = rearrange(x, "b c h w -> b (h w) c") 256 | if self.use_linear: 257 | x = self.proj_in(x) 258 | 259 | num_frames = torch.arange(timesteps, device=x.device) 260 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 261 | num_frames = rearrange(num_frames, "b t -> (b t)") 262 | t_emb = timestep_embedding( 263 | num_frames, 264 | self.in_channels, 265 | repeat_only=False, 266 | max_period=self.max_time_embed_period, 267 | dtype=self.dtype, 268 | ) 269 | emb = self.time_pos_embed(t_emb) 270 | emb = emb[:, None, :] 271 | 272 | for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)): 273 | x = block( 274 | x, 275 | context=spatial_context, 276 | ) 277 | 278 | x_mix = x 279 | x_mix = x_mix + emb 280 | 281 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) 282 | x = self.time_mixer( 283 | x_spatial=x, 284 | x_temporal=x_mix, 285 | image_only_indicator=image_only_indicator, 286 | ) 287 | if self.use_linear: 288 | x = self.proj_out(x) 289 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 290 | if not self.use_linear: 291 | x = self.proj_out(x) 292 | out = x + x_in 293 | return out 294 | -------------------------------------------------------------------------------- /sat/vae_modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), 16 | ) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | # remove as '.'-character is not allowed in buffers 21 | s_name = name.replace(".", "") 22 | self.m_name2s_name.update({name: s_name}) 23 | self.register_buffer(s_name, p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def reset_num_updates(self): 28 | del self.num_updates 29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if 
m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /sat/vae_modules/regularizers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class DiagonalGaussianDistribution(object): 11 | def __init__(self, parameters, deterministic=False): 12 | self.parameters = parameters 13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 15 | self.deterministic = deterministic 16 | self.std = torch.exp(0.5 * self.logvar) 17 | self.var = torch.exp(self.logvar) 18 | if self.deterministic: 19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 20 | 21 | def sample(self): 22 | # x = self.mean + self.std * torch.randn(self.mean.shape).to( 23 | # device=self.parameters.device 24 | # ) 25 | x = self.mean + self.std * torch.randn_like(self.mean) 26 | return x 27 | 28 | def kl(self, other=None): 29 | if self.deterministic: 30 | return torch.Tensor([0.0]) 31 | else: 32 | if other is None: 33 | return 0.5 * torch.sum( 34 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 35 | dim=[1, 2, 3], 36 | ) 37 | else: 38 | return 0.5 * torch.sum( 39 | torch.pow(self.mean - other.mean, 2) / other.var 40 | + self.var / other.var 41 | - 1.0 42 | - self.logvar 43 | + other.logvar, 44 | dim=[1, 2, 3], 45 | ) 46 | 47 | def nll(self, sample, dims=[1, 2, 3]): 48 | if self.deterministic: 49 | return torch.Tensor([0.0]) 50 | logtwopi = np.log(2.0 * np.pi) 51 | return 0.5 * torch.sum( 52 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 53 | dim=dims, 54 | ) 55 | 56 | def mode(self): 57 | return self.mean 58 | 59 | 60 | class AbstractRegularizer(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 65 | raise 
NotImplementedError() 66 | 67 | @abstractmethod 68 | def get_trainable_parameters(self) -> Any: 69 | raise NotImplementedError() 70 | 71 | 72 | class IdentityRegularizer(AbstractRegularizer): 73 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 74 | return z, dict() 75 | 76 | def get_trainable_parameters(self) -> Any: 77 | yield from () 78 | 79 | 80 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: 81 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 82 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 83 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 84 | avg_probs = encodings.mean(0) 85 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 86 | cluster_use = torch.sum(avg_probs > 0) 87 | return perplexity, cluster_use 88 | 89 | 90 | class DiagonalGaussianRegularizer(AbstractRegularizer): 91 | def __init__(self, sample: bool = True): 92 | super().__init__() 93 | self.sample = sample 94 | 95 | def get_trainable_parameters(self) -> Any: 96 | yield from () 97 | 98 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 99 | log = dict() 100 | posterior = DiagonalGaussianDistribution(z) 101 | if self.sample: 102 | z = posterior.sample() 103 | else: 104 | z = posterior.mode() 105 | kl_loss = posterior.kl() 106 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 107 | log["kl_loss"] = kl_loss 108 | return z, log 109 | -------------------------------------------------------------------------------- /tools/caption/README.md: -------------------------------------------------------------------------------- 1 | # Video Caption 2 | 3 | Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video 4 | data into textual descriptions to provide the essential training data for text-to-video models. 5 | 6 | ## Video Caption via CogVLM2-Video 7 | 8 |

9 | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | 💬 [Online Demo](http://cogvlm2-online.cogviewai.cn:7868/) 10 |

11 | 12 | CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities. 13 | Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption: 14 |
15 | ![cogvlm2-video-example](assests/cogvlm2-video-example.png) 16 |
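If the model is served behind a simple captioning endpoint instead of being loaded in-process, the request loop can be as small as the following sketch; the URL, port, and JSON field names here are hypothetical placeholders, not the actual CogVLM2-Video API:

```python
import base64
import requests

CAPTION_ENDPOINT = "http://localhost:8000/caption"  # hypothetical self-hosted service

def caption_video(video_path: str, prompt: str = "Please describe this video in detail.") -> str:
    # Encode the clip and send it together with the prompt; the field names are assumptions.
    with open(video_path, "rb") as f:
        payload = {"prompt": prompt, "video": base64.b64encode(f.read()).decode("utf-8")}
    response = requests.post(CAPTION_ENDPOINT, json=payload, timeout=300)
    response.raise_for_status()
    return response.json()["caption"]  # assumed response schema

# Example: print(caption_video("demo.mp4"))
```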
17 | 18 | Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions. -------------------------------------------------------------------------------- /tools/caption/README_ja.md: -------------------------------------------------------------------------------- 1 | # ビデオキャプション 2 | 3 | 通常、ほとんどのビデオデータには対応する説明文が付いていないため、ビデオデータをテキストの説明に変換して、テキストからビデオへのモデルに必要なトレーニングデータを提供する必要があります。 4 | 5 | ## CogVLM2-Video を使用したビデオキャプション 6 | 7 |

🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [ブログ](https://cogvlm2-video.github.io/) | 💬 [オンラインデモ](http://cogvlm2-online.cogviewai.cn:7868/) 9 |

10 | 11 | CogVLM2-Video は、タイムスタンプベースの質問応答機能を備えた多機能なビデオ理解モデルです。ユーザーは `このビデオを詳細に説明してください。` などのプロンプトをモデルに入力して、詳細なビデオキャプションを取得できます: 12 |
13 | ![cogvlm2-video-example](assests/cogvlm2-video-example.png) 14 |
15 | 16 | ユーザーは提供された[コード](https://github.com/THUDM/CogVLM2/tree/main/video_demo)を使用してモデルをロードするか、RESTful API を構成してビデオキャプションを生成できます。 17 | -------------------------------------------------------------------------------- /tools/caption/README_zh.md: -------------------------------------------------------------------------------- 1 | # 视频Caption 2 | 3 | 通常,大多数视频数据不带有相应的描述性文本,因此需要将视频数据转换为文本描述,以提供必要的训练数据用于文本到视频模型。 4 | 5 | ## 通过 CogVLM2-Video 模型生成视频Caption 6 | 7 | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/) 8 | 9 | CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的问题回答能力。用户可以输入诸如 `请详细描述这个视频` 的提示语给模型,以获得详细的视频Caption: 10 | 11 | 12 |
13 | ![cogvlm2-video-example](assests/cogvlm2-video-example.png) 14 |
15 | 16 | 用户可以使用提供的[代码](https://github.com/THUDM/CogVLM2/tree/main/video_demo)加载模型或配置 RESTful API 来生成视频Caption。 -------------------------------------------------------------------------------- /tools/caption/assests/cogvlm2-video-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/tools/caption/assests/cogvlm2-video-example.png -------------------------------------------------------------------------------- /tools/convert_weight_sat2hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to convert and generate video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline. 3 | 4 | Note: 5 | This script requires the `diffusers>=0.30.0` library to be installed. 6 | 7 | Run the script: 8 | $ python convert_and_generate.py --transformer_ckpt_path --vae_ckpt_path --output_path --text_encoder_path 9 | 10 | Functions: 11 | - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place. 12 | - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place. 13 | - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place. 14 | - remove_keys_inplace: Removes specified keys from the state_dict in-place. 15 | - replace_up_keys_inplace: Replaces keys in the "up" block in-place. 16 | - get_state_dict: Extracts the state_dict from a saved checkpoint. 17 | - update_state_dict_inplace: Updates the state_dict with new key assignments in-place. 18 | - convert_transformer: Converts a transformer checkpoint to the CogVideoX format. 19 | - convert_vae: Converts a VAE checkpoint to the CogVideoX format. 20 | - get_args: Parses command-line arguments for the script. 21 | - generate_video: Generates a video from a text prompt using the CogVideoX pipeline. 
22 | """ 23 | 24 | import argparse 25 | from typing import Any, Dict 26 | 27 | import torch 28 | from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel 29 | from transformers import T5EncoderModel, T5Tokenizer 30 | 31 | 32 | # Function to reassign the query, key, and value weights in-place 33 | def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): 34 | to_q_key = key.replace("query_key_value", "to_q") 35 | to_k_key = key.replace("query_key_value", "to_k") 36 | to_v_key = key.replace("query_key_value", "to_v") 37 | to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) 38 | state_dict[to_q_key] = to_q 39 | state_dict[to_k_key] = to_k 40 | state_dict[to_v_key] = to_v 41 | state_dict.pop(key) 42 | 43 | 44 | # Function to reassign layer normalization for query and key in-place 45 | def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): 46 | layer_id, weight_or_bias = key.split(".")[-2:] 47 | 48 | if "query" in key: 49 | new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" 50 | elif "key" in key: 51 | new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" 52 | 53 | state_dict[new_key] = state_dict.pop(key) 54 | 55 | 56 | # Function to reassign adaptive layer normalization in-place 57 | def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): 58 | layer_id, _, weight_or_bias = key.split(".")[-3:] 59 | 60 | weights_or_biases = state_dict[key].chunk(12, dim=0) 61 | norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) 62 | norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) 63 | 64 | norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" 65 | state_dict[norm1_key] = norm1_weights_or_biases 66 | 67 | norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" 68 | state_dict[norm2_key] = norm2_weights_or_biases 69 | 70 | state_dict.pop(key) 71 | 72 | 73 | # Function to remove keys from state_dict in-place 74 | def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): 75 | state_dict.pop(key) 76 | 77 | 78 | # Function to replace keys in the "up" block in-place 79 | def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): 80 | key_split = key.split(".") 81 | layer_index = int(key_split[2]) 82 | replace_layer_index = 4 - 1 - layer_index 83 | 84 | key_split[1] = "up_blocks" 85 | key_split[2] = str(replace_layer_index) 86 | new_key = ".".join(key_split) 87 | 88 | state_dict[new_key] = state_dict.pop(key) 89 | 90 | 91 | # Dictionary for renaming transformer keys 92 | TRANSFORMER_KEYS_RENAME_DICT = { 93 | "transformer.final_layernorm": "norm_final", 94 | "transformer": "transformer_blocks", 95 | "attention": "attn1", 96 | "mlp": "ff.net", 97 | "dense_h_to_4h": "0.proj", 98 | "dense_4h_to_h": "2", 99 | ".layers": "", 100 | "dense": "to_out.0", 101 | "input_layernorm": "norm1.norm", 102 | "post_attn1_layernorm": "norm2.norm", 103 | "time_embed.0": "time_embedding.linear_1", 104 | "time_embed.2": "time_embedding.linear_2", 105 | "mixins.patch_embed": "patch_embed", 106 | "mixins.final_layer.norm_final": "norm_out.norm", 107 | "mixins.final_layer.linear": "proj_out", 108 | "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", 109 | } 110 | 111 | # Dictionary for handling special keys in transformer 112 | TRANSFORMER_SPECIAL_KEYS_REMAP = { 113 | "query_key_value": reassign_query_key_value_inplace, 114 | 
"query_layernorm_list": reassign_query_key_layernorm_inplace, 115 | "key_layernorm_list": reassign_query_key_layernorm_inplace, 116 | "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, 117 | "embed_tokens": remove_keys_inplace, 118 | } 119 | 120 | # Dictionary for renaming VAE keys 121 | VAE_KEYS_RENAME_DICT = { 122 | "block.": "resnets.", 123 | "down.": "down_blocks.", 124 | "downsample": "downsamplers.0", 125 | "upsample": "upsamplers.0", 126 | "nin_shortcut": "conv_shortcut", 127 | "encoder.mid.block_1": "encoder.mid_block.resnets.0", 128 | "encoder.mid.block_2": "encoder.mid_block.resnets.1", 129 | "decoder.mid.block_1": "decoder.mid_block.resnets.0", 130 | "decoder.mid.block_2": "decoder.mid_block.resnets.1", 131 | } 132 | 133 | # Dictionary for handling special keys in VAE 134 | VAE_SPECIAL_KEYS_REMAP = { 135 | "loss": remove_keys_inplace, 136 | "up.": replace_up_keys_inplace, 137 | } 138 | 139 | # Maximum length of the tokenizer (Must be 226) 140 | TOKENIZER_MAX_LENGTH = 226 141 | 142 | 143 | # Function to extract the state_dict from a saved checkpoint 144 | def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: 145 | state_dict = saved_dict 146 | if "model" in saved_dict.keys(): 147 | state_dict = state_dict["model"] 148 | if "module" in saved_dict.keys(): 149 | state_dict = state_dict["module"] 150 | if "state_dict" in saved_dict.keys(): 151 | state_dict = state_dict["state_dict"] 152 | return state_dict 153 | 154 | 155 | # Function to update the state_dict with new key assignments in-place 156 | def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: 157 | state_dict[new_key] = state_dict.pop(old_key) 158 | 159 | 160 | # Function to convert a transformer checkpoint to the CogVideoX format 161 | def convert_transformer(ckpt_path: str): 162 | PREFIX_KEY = "model.diffusion_model." 
163 | 
164 |     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
165 |     transformer = CogVideoXTransformer3DModel()
166 | 
167 |     for key in list(original_state_dict.keys()):
168 |         new_key = key[len(PREFIX_KEY) :]
169 |         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
170 |             new_key = new_key.replace(replace_key, rename_key)
171 |         update_state_dict_inplace(original_state_dict, key, new_key)
172 | 
173 |     for key in list(original_state_dict.keys()):
174 |         for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
175 |             if special_key not in key:
176 |                 continue
177 |             handler_fn_inplace(key, original_state_dict)
178 | 
179 |     transformer.load_state_dict(original_state_dict, strict=True)
180 |     return transformer
181 | 
182 | 
183 | # Function to convert a VAE checkpoint to the CogVideoX format
184 | def convert_vae(ckpt_path: str):
185 |     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
186 |     vae = AutoencoderKLCogVideoX()
187 | 
188 |     for key in list(original_state_dict.keys()):
189 |         new_key = key[:]
190 |         for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
191 |             new_key = new_key.replace(replace_key, rename_key)
192 |         update_state_dict_inplace(original_state_dict, key, new_key)
193 | 
194 |     for key in list(original_state_dict.keys()):
195 |         for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
196 |             if special_key not in key:
197 |                 continue
198 |             handler_fn_inplace(key, original_state_dict)
199 | 
200 |     vae.load_state_dict(original_state_dict, strict=True)
201 |     return vae
202 | 
203 | 
204 | # Function to parse command-line arguments for the script
205 | def get_args():
206 |     parser = argparse.ArgumentParser()
207 |     parser.add_argument(
208 |         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
209 |     )
210 |     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
211 |     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
212 |     parser.add_argument(
213 |         "--text_encoder_path",
214 |         type=str,
215 |         required=False,
216 |         default="google/t5-v1_1-xxl",
217 |         help="Path or Hugging Face repo id of the T5 text encoder",
218 |     )
219 |     parser.add_argument(
220 |         "--text_encoder_cache_dir",
221 |         type=str,
222 |         default=None,
223 |         help="Path to the text encoder cache directory.
Not needed if text_encoder_path points to a local directory.",
224 |     )
225 |     parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
226 |     parser.add_argument(
227 |         "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
228 |     )
229 |     return parser.parse_args()
230 | 
231 | 
232 | if __name__ == "__main__":
233 |     args = get_args()
234 | 
235 |     transformer = None
236 |     vae = None
237 | 
238 |     if args.transformer_ckpt_path is not None:
239 |         transformer = convert_transformer(args.transformer_ckpt_path)
240 |     if args.vae_ckpt_path is not None:
241 |         vae = convert_vae(args.vae_ckpt_path)
242 | 
243 |     tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
244 |     text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
245 | 
246 |     scheduler = CogVideoXDDIMScheduler.from_config(
247 |         {
248 |             "snr_shift_scale": 3.0,
249 |             "beta_end": 0.012,
250 |             "beta_schedule": "scaled_linear",
251 |             "beta_start": 0.00085,
252 |             "clip_sample": False,
253 |             "num_train_timesteps": 1000,
254 |             "prediction_type": "v_prediction",
255 |             "rescale_betas_zero_snr": True,
256 |             "set_alpha_to_one": True,
257 |             "timestep_spacing": "linspace",
258 |         }
259 |     )
260 | 
261 |     pipe = CogVideoXPipeline(
262 |         tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
263 |     )
264 | 
265 |     if args.fp16:
266 |         pipe = pipe.to(dtype=torch.float16)
267 | 
268 |     pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
269 | 
--------------------------------------------------------------------------------
/tools/venhancer/README.md:
--------------------------------------------------------------------------------
1 | # Enhance CogVideoX-Generated Videos with VEnhancer
2 | 
3 | This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, including
4 | achieving higher frame rates and higher resolutions.
5 | 
6 | ## Model Introduction
7 | 
8 | VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in
9 | a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal
10 | super-resolution. Additionally, it provides flexible control to modify the refinement strength, enabling it to handle
11 | diverse video artifacts.
12 | 
13 | VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle
14 | block from a pre-trained video diffusion model to build a trainable conditional network. This video ControlNet accepts
15 | low-resolution keyframes and noisy full-frame latents as inputs. In addition to the time step t and the prompt, the proposed
16 | video-aware conditioning also takes the noise augmentation level σ and the downscaling factor s as additional network
17 | conditioning inputs.
18 | 
19 | ## Hardware Requirements
20 | 
21 | + Operating System: Linux (required by the xformers dependency)
22 | + Hardware: An NVIDIA GPU with at least 60GB of VRAM per card. Machines such as the H100 or A100 are recommended.
23 | 
24 | ## Quick Start
25 | 
26 | 1. Clone the repository and install dependencies as per the official instructions:
27 | 
28 | ```shell
29 | git clone https://github.com/Vchitect/VEnhancer.git
30 | cd VEnhancer
31 | ## Torch and other dependencies can use those from CogVideoX.
## If you need to create a new environment, use the following commands:
32 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
33 | 
34 | ## Install required dependencies
35 | pip install -r requirements.txt
36 | ```
37 | 

2. Run the code:

```shell
python enhance_a_video.py \
--up_scale 4 --target_fps 24 --noise_aug 250 \
--solver_mode 'fast' --steps 15 \
--input_path inputs/000000.mp4 \
--prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' \
--save_dir 'results/'
```

38 | Where:
39 | 
40 | - `input_path` is the path to the input video.
41 | - `prompt` is the description of the video content. The prompt used by this tool should be short, no more than 77
42 | tokens. You may need to simplify the prompt used for generating the CogVideoX video.
43 | - `target_fps` is the target frame rate for the video. Typically, 16 fps is already smooth, and 24 fps is the default
44 | value.
45 | - `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to around 2K or below.
46 | - `noise_aug` depends on the input video quality. Lower-quality inputs need higher noise levels, which correspond to
47 | stronger refinement: use 250~300 for very low-quality videos and <= 200 for good ones.
48 | - `steps`: if you want fewer steps, first change solver_mode to "normal", then reduce the number of steps. The
49 | "fast" solver_mode uses a fixed number of steps (15).
50 | The code will automatically download the required models from Hugging Face during execution.
51 | 
52 | Typical runtime logs are as follows:
53 | 
54 | ```shell
55 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
56 | @torch.library.impl_abstract("xformers_flash::flash_fwd")
57 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
58 | @torch.library.impl_abstract("xformers_flash::flash_bwd")
59 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
60 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
61 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
62 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
63 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 64 | load_dict = torch.load(cfg.model_path, map_location='cpu') 65 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status 66 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion 67 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4 68 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere 69 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49 70 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0 71 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0 72 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720) 73 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982) 74 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250 75 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8 76 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982]) 77 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. 78 | with amp.autocast(enabled=True): 79 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0 80 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1 81 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2 82 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3 83 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4 84 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5 85 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6 86 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7 87 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8 88 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9 89 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10 90 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11 91 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12 92 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13 93 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished. 94 | 95 | ``` 96 | 97 | Running on a single A100 GPU, enhancing each 6-second CogVideoX generated video with default settings will consume 60GB 98 | of VRAM and take 40-50 minutes. 
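
If you need to enhance a batch of CogVideoX clips rather than a single file, a minimal shell sketch is shown below. It assumes the same `enhance_a_video.py` entry point and flags documented above, an `inputs/` directory of `.mp4` clips, and a hypothetical sidecar `<clip>.txt` file next to each clip holding its shortened prompt; adjust these to your own layout.

```shell
# Sketch: enhance every CogVideoX clip in inputs/ with the flags from the example above.
# Assumes a hypothetical <clip>.txt file next to each <clip>.mp4 containing its shortened prompt.
for f in inputs/*.mp4; do
  python enhance_a_video.py \
    --up_scale 4 --target_fps 24 --noise_aug 250 \
    --solver_mode 'fast' --steps 15 \
    --input_path "$f" \
    --prompt "$(cat "${f%.mp4}.txt")" \
    --save_dir 'results/'
done
```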
--------------------------------------------------------------------------------
/tools/venhancer/README_ja.md:
--------------------------------------------------------------------------------
1 | 
2 | # VEnhancer で CogVideoX によって生成されたビデオを強化する
3 | 
4 | このチュートリアルでは、VEnhancer ツールを使用して、CogVideoX で生成されたビデオを強化し、より高いフレームレートと高い解像度を実現する方法を説明します。
5 | 
6 | ## モデルの紹介
7 | 
8 | VEnhancer は、空間超解像、時間超解像(フレーム補間)、およびビデオのリファインメントを統一されたフレームワークで実現します。空間または時間の超解像のために、さまざまなアップサンプリング係数(例:1x〜8x)に柔軟に対応できます。さらに、多様なビデオアーティファクトを処理するために、リファインメント強度を変更する柔軟な制御を提供します。
9 | 
10 | VEnhancer は ControlNet の設計に従い、事前訓練されたビデオ拡散モデルのマルチフレームエンコーダーとミドルブロックのアーキテクチャとウェイトをコピーして、トレーニング可能な条件ネットワークを構築します。このビデオ ControlNet は、低解像度のキーフレームとノイズを含む完全なフレームを入力として受け取ります。さらに、タイムステップ t とプロンプトに加えて、提案されたビデオ対応条件により、ノイズ増幅レベル σ およびダウンスケーリングファクター s が追加のネットワーク条件として使用されます。
11 | 
12 | ## ハードウェア要件
13 | 
14 | + オペレーティングシステム: Linux (xformers 依存関係が必要)
15 | + ハードウェア: 単一カードあたり少なくとも 60GB の VRAM を持つ NVIDIA GPU。H100、A100 などのマシンを推奨します。
16 | 
17 | ## クイックスタート
18 | 
19 | 1. 公式の指示に従ってリポジトリをクローンし、依存関係をインストールします。
20 | 
21 | ```shell
22 | git clone https://github.com/Vchitect/VEnhancer.git
23 | cd VEnhancer
24 | ## Torch などの依存関係は CogVideoX の依存関係を使用できます。新しい環境を作成する必要がある場合は、以下のコマンドを使用してください。
25 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
26 | 
27 | ## 必須の依存関係をインストールします。
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | 2. コードを実行します。
32 | 
33 | ```shell
34 | python enhance_a_video.py --up_scale 4 --target_fps 24 --noise_aug 250 --solver_mode 'fast' --steps 15 --input_path inputs/000000.mp4 --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' --save_dir 'results/'
35 | ```
36 | 
37 | 各パラメータの意味は次のとおりです:
38 | 
39 | - `input_path` は入力ビデオのパスです。
40 | - `prompt` はビデオ内容の説明です。このツールで使用するプロンプトは短くし、77 トークンを超えないようにしてください。CogVideoX でビデオを生成する際に使用したプロンプトを簡略化する必要があるかもしれません。
41 | - `target_fps` はビデオの目標フレームレートです。通常は 16 fps で十分滑らかで、デフォルト値は 24 fps です。
42 | - `up_scale` は 2、3、4 のいずれかに設定することを推奨します。目標解像度は約 2K 以下に制限されます。
43 | - `noise_aug` の値は入力ビデオの品質に依存します。品質が低いビデオほど高いノイズレベルが必要で、より強いリファインメントに対応します。非常に低品質のビデオには 250〜300、高品質のビデオには 200 以下を設定してください。
44 | - `steps` ステップ数を減らしたい場合は、まず solver_mode を「normal」に変更してから、ステップ数を減らしてください。「fast」モードのステップ数は 15 ステップに固定されています。
45 | 
46 | 
47 | コードの実行中に、必要なモデルは Hugging Face から自動的にダウンロードされます。
48 | 実行ログは通常以下の通りです:
49 | ```shell
50 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
51 | @torch.library.impl_abstract("xformers_flash::flash_fwd")
52 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
53 | @torch.library.impl_abstract("xformers_flash::flash_bwd")
54 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
55 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 56 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 57 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder 58 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 59 | load_dict = torch.load(cfg.model_path, map_location='cpu') 60 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status 61 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion 62 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4 63 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere 64 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49 65 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0 66 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0 67 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720) 68 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982) 69 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250 70 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8 71 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982]) 72 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. 
73 | with amp.autocast(enabled=True):
74 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
75 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
76 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
77 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
78 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
79 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
80 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
81 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
82 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
83 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
84 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
85 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
86 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
87 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
88 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
89 | 
90 | ```
91 | 
92 | A100 GPU を単一で使用している場合、CogVideoX によって生成された 6 秒間のビデオを強化するには、デフォルト設定で 60GB の VRAM を消費し、40〜50 分かかります。
93 | 
--------------------------------------------------------------------------------
/tools/venhancer/README_zh.md:
--------------------------------------------------------------------------------
1 | # 使用 VEnhancer 对 CogVideoX 生成的视频进行增强
2 | 
3 | 本教程将使用 VEnhancer 工具对 CogVideoX 生成的视频进行增强,包括更高的帧率和更高的分辨率。
4 | 
5 | ## 模型介绍
6 | 
7 | VEnhancer 在一个统一的框架中实现了空间超分辨率、时间超分辨率(帧插值)和视频优化。它可以灵活地适应不同的上采样因子(例如,1x~
8 | 8x)用于空间或时间超分辨率。此外,它提供了灵活的控制,以修改优化强度,从而处理多样化的视频伪影。
9 | 
10 | VEnhancer 遵循 ControlNet 的设计,复制了预训练的视频扩散模型的多帧编码器和中间块的架构和权重,构建了一个可训练的条件网络。这个视频
11 | ControlNet 接受低分辨率关键帧和包含噪声的完整帧作为输入。此外,除了时间步 t 和提示词外,我们提出的视频感知条件还将噪声增强的噪声级别
12 | σ 和降尺度因子 s 作为附加的网络条件输入。
13 | 
14 | ## 硬件需求
15 | 
16 | + 操作系统: Linux (需要依赖xformers)
17 | + 硬件: NVIDIA GPU 并至少保证单卡显存超过60G,推荐使用 H100,A100等机器。
18 | 
19 | ## 快速上手
20 | 
21 | 1. 按照官方指引克隆仓库并安装依赖
22 | 
23 | ```shell
24 | git clone https://github.com/Vchitect/VEnhancer.git
25 | cd VEnhancer
26 | ## torch等依赖可以使用CogVideoX的依赖,如果你需要创建一个新的环境,可以使用以下命令
27 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
28 | 
29 | ## 安装必须的依赖
30 | pip install -r requirements.txt
31 | ```
32 | 
33 | 2. 运行代码
34 | 
35 | ```shell
36 | python enhance_a_video.py \
37 | --up_scale 4 --target_fps 24 --noise_aug 250 \
38 | --solver_mode 'fast' --steps 15 \
39 | --input_path inputs/000000.mp4 \
40 | --prompt 'Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere' \
41 | --save_dir 'results/'
42 | ```
43 | 
44 | 其中:
45 | 
46 | - `input_path` 是输入视频的路径
47 | - `prompt` 是视频内容的描述。此工具使用的提示词应更短,不超过 77 个 token。您可能需要简化用于生成CogVideoX视频的提示词。
48 | - `target_fps` 是视频的目标帧率。通常,16 fps已经很流畅,默认值为24 fps。
49 | - `up_scale` 推荐设置为2、3或4。目标分辨率限制在2k左右及以下。
50 | - `noise_aug` 的值取决于输入视频的质量。质量较低的视频需要更高的噪声级别,这对应于更强的优化。250~300适用于非常低质量的视频。对于高质量视频,设置为≤200。
51 | - `steps` 如果想减少步数,请先将solver_mode改为"normal",然后减少步数。"fast"模式的步数是固定的(15步)。
52 | 代码在执行过程中会自动从Hugging Face下载所需的模型。
53 | 
54 | 
55 | 
56 | 运行日志通常如下:
57 | 
58 | ```shell
59 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
60 | @torch.library.impl_abstract("xformers_flash::flash_fwd") 61 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch. 62 | @torch.library.impl_abstract("xformers_flash::flash_bwd") 63 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt 64 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 65 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 66 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder 67 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
68 | load_dict = torch.load(cfg.model_path, map_location='cpu')
69 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
70 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
71 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
72 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
73 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
74 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
75 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
76 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
77 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
78 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
79 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
80 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
81 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
82 | with amp.autocast(enabled=True):
83 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
84 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
85 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
86 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
87 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
88 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
89 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
90 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
91 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
92 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
93 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
94 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
95 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
96 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
97 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
98 | 
99 | ```
100 | 
101 | 使用A100单卡运行,对于每个CogVideoX生成的6秒视频,按照默认配置,会消耗60GB显存,并用时40-50分钟。
--------------------------------------------------------------------------------