├── .gitignore
├── README.md
├── assets
│   ├── 1.gif
│   ├── 10.gif
│   ├── 11.gif
│   ├── 12.gif
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── 6.gif
│   ├── 7.gif
│   ├── 8.gif
│   └── 9.gif
├── requirements.txt
├── run.sh
├── sat
│   ├── arguments.py
│   ├── asset
│   │   └── prompt.txt
│   ├── base_model.py
│   ├── configs
│   │   ├── cogvideox_2b.yaml
│   │   └── inference.yaml
│   ├── demo.py
│   ├── diffusion_video.py
│   ├── dit_video_concat.py
│   ├── run.sh
│   ├── sample_video.py
│   ├── sgm
│   │   ├── __init__.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   └── autoencoder.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── autoencoding
│   │   │   │   ├── __init__.py
│   │   │   │   ├── losses
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── discriminator_loss.py
│   │   │   │   │   ├── lpips.py
│   │   │   │   │   └── video_loss.py
│   │   │   │   ├── lpips
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── loss
│   │   │   │   │   │   ├── .gitignore
│   │   │   │   │   │   ├── LICENSE
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── lpips.py
│   │   │   │   │   ├── model
│   │   │   │   │   │   ├── LICENSE
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── model.py
│   │   │   │   │   ├── util.py
│   │   │   │   │   └── vqperceptual.py
│   │   │   │   ├── magvit2_pytorch.py
│   │   │   │   ├── regularizers
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── finite_scalar_quantization.py
│   │   │   │   │   ├── lookup_free_quantization.py
│   │   │   │   │   └── quantize.py
│   │   │   │   ├── temporal_ae.py
│   │   │   │   └── vqvae
│   │   │   │       ├── movq_dec_3d.py
│   │   │   │       ├── movq_dec_3d_dev.py
│   │   │   │       ├── movq_enc_3d.py
│   │   │   │       ├── movq_modules.py
│   │   │   │       ├── quantize.py
│   │   │   │       └── vqvae_blocks.py
│   │   │   ├── cp_enc_dec.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── denoiser.py
│   │   │   │   ├── denoiser_scaling.py
│   │   │   │   ├── denoiser_weighting.py
│   │   │   │   ├── discretizer.py
│   │   │   │   ├── guiders.py
│   │   │   │   ├── lora.py
│   │   │   │   ├── loss.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   ├── sampling.py
│   │   │   │   ├── sampling_utils.py
│   │   │   │   ├── sigma_sampling.py
│   │   │   │   ├── util.py
│   │   │   │   └── wrappers.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   └── video_attention.py
│   │   ├── util.py
│   │   └── webds.py
│   ├── transformer_gating.py
│   └── vae_modules
│       ├── attention.py
│       ├── autoencoder.py
│       ├── cp_enc_dec.py
│       ├── ema.py
│       ├── regularizers.py
│       └── utils.py
└── tools
    ├── caption
    │   ├── README.md
    │   ├── README_ja.md
    │   ├── README_zh.md
    │   └── assests
    │       └── cogvlm2-video-example.png
    ├── convert_weight_sat2hf.py
    └── venhancer
        ├── README.md
        ├── README_ja.md
        └── README_zh.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .hypothesis
3 | demo_videos
4 | results
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | .idea/
166 | .vscode/
167 |
168 | # macos
169 | *.DS_Store
170 | *.json
171 |
172 | /sat/ckpt
173 | /sat/outputs
174 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RepVideo: Rethinking Cross-Layer Representation for Video Generation
2 |
3 |
6 |
7 |
23 |
 24 | S-Lab, Nanyang Technological University¹   Shanghai Artificial Intelligence Laboratory²
 25 | †Equal contribution. ✉Corresponding Author.
26 |
27 |
28 |
29 |
30 |
34 |
37 |
38 | ---
39 |
40 | 
41 | [](https://hits.seeyoufarm.com)
44 |
45 | ## 🔥 Update and News
46 | - [2025.01.25] 🔥 Inference code and [checkpoint](https://huggingface.co/Vchitect/RepVideo) are released.
47 |
48 |
49 | ## :astonished: Gallery
50 |
 51 | Sample videos generated by RepVideo (see `assets/1.gif`–`assets/12.gif`).
 52 |
81 | ## Installation
82 |
83 | ### 1. Create a conda environment and download models
84 |
85 |
86 | ```bash
87 | conda create -n RepVid python==3.10
88 | conda activate RepVid
89 | pip install -r requirements.txt
90 |
91 |
92 | mkdir ckpt
93 | cd ckpt
 94 | mkdir t5-v1_1-xxl && cd t5-v1_1-xxl
95 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/config.json
96 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model-00001-of-00002.safetensors
97 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model-00002-of-00002.safetensors
98 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/text_encoder/model.safetensors.index.json
99 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/added_tokens.json
100 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/special_tokens_map.json
101 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/spiece.model
102 | wget https://huggingface.co/THUDM/CogVideoX-2b/resolve/main/tokenizer/tokenizer_config.json
103 |
104 | cd ../
105 | mkdir vae && cd vae
106 | wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
107 | mv 'index.html?dl=1' vae.zip
108 | unzip vae.zip
109 | ```
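
After these downloads, the SAT configs expect the weights to sit under `sat/ckpt`; all paths in `sat/configs/cogvideox_2b.yaml` are relative to the `sat` directory, so the layout is roughly:

```
sat/ckpt
├── t5-v1_1-xxl/       # T5 config, tokenizer files and *.safetensors shards
└── vae/
    └── vae/3d-vae.pt
```
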
110 | ### 2. Download our latest checkpoint
111 | ```
112 | git-lfs clone https://huggingface.co/Vchitect/RepVideo
113 | # Then modify the "load" path in "sat/configs/inference.yaml" accordingly.
114 | ```
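
The released `sat/configs/inference.yaml` loads the checkpoint via a relative path, so placing the cloned weights at `sat/ckpt/RepVideo` needs no edit; otherwise, point `load` at the directory you cloned into:

```yaml
args:
  load: ckpt/RepVideo   # relative to the sat/ working directory
```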
115 |
116 | ## Inference
117 |
118 | ~~~bash
119 | cd sat
120 | bash run.sh
121 | ~~~
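
`run.sh` pins a single GPU and launches the SAT sampler directly:

```bash
export CUDA_VISIBLE_DEVICES=0
python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed 42
```

Prompts are read line by line from `asset/prompt.txt` (the `input_file` set in `configs/inference.yaml`), and the resulting MP4s are written under the configured `output_dir`.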
122 |
123 | ## BibTeX
124 | ```
125 | @article{si2025RepVideo,
126 | title={RepVideo: Rethinking Cross-Layer Representation for Video Generation},
127 | author={Si, Chenyang and Fan, Weichen and Lv, Zhengyao and Huang, Ziqi and Qiao, Yu and Liu, Ziwei},
128 |   journal={arXiv preprint arXiv:2501.08994},
129 | year={2025}
130 | }
131 | ```
132 |
133 | ## 🔑 License
134 |
135 | This code is released under the Apache-2.0 license. The framework is fully open for academic research, and free commercial use is also permitted.
136 |
137 |
138 | ## Disclaimer
139 |
140 | We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond its capabilities. It must not be used to generate pornographic, violent, or gory content, nor content that demeans or harms people or their environment, culture, or religion. Users are solely liable for their actions. The project contributors are not legally affiliated with, and are not accountable for, users' behavior. Use the generative model responsibly, adhering to ethical and legal standards.
141 |
142 |
--------------------------------------------------------------------------------
/assets/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/1.gif
--------------------------------------------------------------------------------
/assets/10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/10.gif
--------------------------------------------------------------------------------
/assets/11.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/11.gif
--------------------------------------------------------------------------------
/assets/12.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/12.gif
--------------------------------------------------------------------------------
/assets/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/2.gif
--------------------------------------------------------------------------------
/assets/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/3.gif
--------------------------------------------------------------------------------
/assets/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/4.gif
--------------------------------------------------------------------------------
/assets/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/5.gif
--------------------------------------------------------------------------------
/assets/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/6.gif
--------------------------------------------------------------------------------
/assets/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/7.gif
--------------------------------------------------------------------------------
/assets/8.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/8.gif
--------------------------------------------------------------------------------
/assets/9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/assets/9.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers>=0.31.0
2 | accelerate>=1.1.1
3 | transformers>=4.46.2
4 | numpy==2.2.1
5 | torch>=2.5.0
6 | torchvision>=0.20.0
7 | sentencepiece>=0.2.0
8 | SwissArmyTransformer>=0.4.12
9 | gradio>=5.5.0
10 | imageio>=2.35.1
11 | imageio-ffmpeg>=0.5.1
12 | openai>=1.54.0
13 | moviepy>=1.0.3
14 | scikit-video>=1.1.11
16 | omegaconf>=2.3.0
17 | pytorch_lightning>=2.4.0
18 | kornia>=0.7.3
19 | beartype>=0.19.0
20 | fsspec>=2024.2.0
21 | safetensors>=0.4.5
22 | scipy>=1.14.1
23 | decord>=0.6.0
24 | wandb>=0.18.5
25 | deepspeed>=0.15.3
26 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | srun -p video-aigc-1 --gres=gpu:1 --quotatype=auto --ntasks-per-node=1 -n1 -N1 --cpus-per-task=12 python ckpt.py
--------------------------------------------------------------------------------
/sat/asset/prompt.txt:
--------------------------------------------------------------------------------
1 | A cat holding a sign.
2 |
--------------------------------------------------------------------------------
/sat/configs/cogvideox_2b.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | scale_factor: 1.15258426
3 | disable_first_stage_autocast: true
4 | log_keys:
5 | - txt
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 | quantize_c_noise: False
12 |
13 | weighting_config:
14 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
15 | scaling_config:
16 | target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
17 | discretization_config:
18 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
19 | params:
20 | shift_scale: 3.0
21 |
22 | network_config:
23 | target: dit_video_concat.DiffusionTransformer
24 | params:
25 | time_embed_dim: 512
26 | elementwise_affine: True
27 | num_frames: 49
28 | time_compressed_rate: 4
29 | latent_width: 90
30 | latent_height: 60
31 | # latent_width: 48
32 | # latent_height: 32
33 | num_layers: 30
34 | patch_size: 2
35 | in_channels: 16
36 | out_channels: 16
37 | hidden_size: 1920
38 | adm_in_channels: 256
39 | num_attention_heads: 30
40 |
41 | transformer_args:
42 | checkpoint_activations: True ## using gradient checkpointing
43 | vocab_size: 1
44 | max_sequence_length: 64
45 | layernorm_order: pre
46 | skip_init: false
47 | model_parallel_size: 1
48 | is_decoder: false
49 |
50 | modules:
51 | pos_embed_config:
52 | target: dit_video_concat.Basic3DPositionEmbeddingMixin
53 | params:
54 | text_length: 226
55 | height_interpolation: 1.875
56 | width_interpolation: 1.875
57 |
58 | patch_embed_config:
59 | target: dit_video_concat.ImagePatchEmbeddingMixin
60 | params:
61 | text_hidden_size: 4096
62 |
63 | adaln_layer_config:
64 | target: dit_video_concat.AdaLNMixin
65 | params:
66 | qk_ln: True
67 |
68 | final_layer_config:
69 | target: dit_video_concat.FinalLayerMixin
70 |
71 | conditioner_config:
72 | target: sgm.modules.GeneralConditioner
73 | params:
74 | emb_models:
75 | - is_trainable: false
76 | input_key: txt
77 | ucg_rate: 0.1
78 | target: sgm.modules.encoders.modules.FrozenT5Embedder
79 | params:
80 | model_dir: "ckpt/t5-v1_1-xxl"
81 | max_length: 226
82 |
83 | first_stage_config:
84 | target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
85 | params:
86 | cp_size: 1
87 | ckpt_path: "ckpt/vae/vae/3d-vae.pt"
88 | ignore_keys: [ 'loss' ]
89 |
90 | loss_config:
91 | target: torch.nn.Identity
92 |
93 | regularizer_config:
94 | target: vae_modules.regularizers.DiagonalGaussianRegularizer
95 |
96 | encoder_config:
97 | target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
98 | params:
99 | double_z: true
100 | z_channels: 16
101 | resolution: 256
102 | in_channels: 3
103 | out_ch: 3
104 | ch: 128
105 | ch_mult: [ 1, 2, 2, 4 ]
106 | attn_resolutions: [ ]
107 | num_res_blocks: 3
108 | dropout: 0.0
109 | gather_norm: True
110 |
111 | decoder_config:
112 | target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
113 | params:
114 | double_z: True
115 | z_channels: 16
116 | resolution: 256
117 | in_channels: 3
118 | out_ch: 3
119 | ch: 128
120 | ch_mult: [ 1, 2, 2, 4 ]
121 | attn_resolutions: [ ]
122 | num_res_blocks: 3
123 | dropout: 0.0
124 | gather_norm: False
125 |
126 | loss_fn_config:
127 | target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
128 | params:
129 | offset_noise_level: 0
130 | sigma_sampler_config:
131 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
132 | params:
133 | uniform_sampling: True
134 | num_idx: 1000
135 | discretization_config:
136 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
137 | params:
138 | shift_scale: 3.0
139 |
140 | sampler_config:
141 | target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
142 | params:
143 | num_steps: 50
144 | verbose: True
145 |
146 | discretization_config:
147 | target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
148 | params:
149 | shift_scale: 3.0
150 |
151 | guider_config:
152 | target: sgm.modules.diffusionmodules.guiders.DynamicCFG
153 | params:
154 | scale: 6
155 | exp: 5
156 | num_steps: 50
--------------------------------------------------------------------------------
/sat/configs/inference.yaml:
--------------------------------------------------------------------------------
1 | args:
2 | latent_channels: 16
3 | mode: inference
4 | load: ckpt/RepVideo
5 |
6 |
7 | batch_size: 1
8 | input_type: txt
9 | input_file: asset/prompt.txt
10 |
11 | sampling_num_frames: 13
12 | sampling_fps: 8
13 | fp16: True
14 |
15 | output_dir: outputs/cog/train_ncfg_channel_1_step_orth
16 | force_inference: True
--------------------------------------------------------------------------------
/sat/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | from typing import List, Union
5 | from tqdm import tqdm
6 | import imageio
7 | import torch
8 | import numpy as np
9 | from einops import rearrange
10 | import torchvision.transforms as TT
11 | from sat.model.base_model import get_model
12 | from sat.training.model_io import load_checkpoint
13 | from sat import mpu
14 | from diffusion_video import SATVideoDiffusionEngine
15 | from torchvision.transforms.functional import resize
16 | from torchvision.transforms import InterpolationMode
17 | import gradio as gr
18 | from arguments import get_args
19 |
20 | # Load model once at the beginning
21 | class ModelHandler:
22 | def __init__(self):
23 | self.model = None
24 | self.first_stage_model = None
25 |
26 | def load_model(self, args):
27 | if self.model is None:
28 | self.model = get_model(args, SATVideoDiffusionEngine)
29 | load_checkpoint(self.model, args)
30 | self.model.eval()
31 | self.first_stage_model = self.model.first_stage_model
32 |
33 | def get_model(self):
34 | return self.model
35 |
36 | def get_first_stage_model(self):
37 | return self.first_stage_model
38 |
39 | model_handler = ModelHandler()
40 |
41 | # Utility functions
42 | def get_unique_embedder_keys_from_conditioner(conditioner):
43 | return list(set([x.input_key for x in conditioner.embedders]))
44 |
45 | def get_batch(keys, value_dict, N: List[int], T=None, device="cuda"):
46 | batch = {}
47 | batch_uc = {}
48 |
49 | for key in keys:
50 | if key == "txt":
51 | batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
52 | batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
53 | else:
54 | batch[key] = value_dict[key]
55 |
56 | if T is not None:
57 | batch["num_video_frames"] = T
58 |
59 | for key in batch.keys():
60 | if key not in batch_uc and isinstance(batch[key], torch.Tensor):
61 | batch_uc[key] = torch.clone(batch[key])
62 | return batch, batch_uc
63 |
64 | def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5):
65 | os.makedirs(save_path, exist_ok=True)
66 |
67 | for i, vid in enumerate(video_batch):
68 | gif_frames = []
69 | for frame in vid:
70 | frame = rearrange(frame, "c h w -> h w c")
71 | frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
72 | gif_frames.append(frame)
73 | now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
74 | with imageio.get_writer(now_save_path, fps=fps) as writer:
75 | for frame in gif_frames:
76 | writer.append_data(frame)
77 |
78 | # Main inference function
79 | def infer(prompt: str, sampling_num_frames: int, batch_size: int, latent_channels: int, sampling_fps: int):
80 | args = get_args(['--base', 'configs/cogvideox_2b.yaml', 'configs/inference.yaml', '--seed', '42'])
81 | args = argparse.Namespace(**vars(args))
82 | del args.deepspeed_config
83 | args.model_config.first_stage_config.params.cp_size = 1
84 | args.model_config.network_config.params.transformer_args.model_parallel_size = 1
85 | args.model_config.network_config.params.transformer_args.checkpoint_activations = False
86 | args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
87 |
88 | model_handler.load_model(args)
89 | model = model_handler.get_model()
90 | first_stage_model = model_handler.get_first_stage_model()
91 |
92 | image_size = [480, 720]
93 | T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
94 | num_samples = [1]
95 |
96 | value_dict = {"prompt": prompt, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
97 |
98 | batch, batch_uc = get_batch(
99 | get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
100 | )
101 |
102 | c, uc = model.conditioner.get_unconditional_conditioning(
103 | batch,
104 | batch_uc=batch_uc,
105 | force_uc_zero_embeddings=["txt"],
106 | )
107 | for k in c:
108 | if not k == "crossattn":
109 | c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
110 |
111 | samples_z = model.sample(
112 | c,
113 | uc=uc,
114 | batch_size=batch_size,
115 | shape=(T, C, H // F, W // F),
116 | )
117 | samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
118 |
119 | latent = 1.0 / model.scale_factor * samples_z
120 |
127 | recons = []
128 | loop_num = (T - 1) // 2
129 | for i in range(loop_num):
130 | if i == 0:
131 | start_frame, end_frame = 0, 3
132 | else:
133 | start_frame, end_frame = i * 2 + 1, i * 2 + 3
134 | if i == loop_num - 1:
135 | clear_fake_cp_cache = True
136 | else:
137 | clear_fake_cp_cache = False
138 | with torch.no_grad():
139 | recon = first_stage_model.decode(
140 | latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
141 | )
142 |
143 | recons.append(recon)
144 |
145 | recon = torch.cat(recons, dim=2).to(torch.float32)
146 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
147 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
148 |
149 | save_path = "outputs/demo"
150 | save_video_as_grid_and_mp4(samples, save_path, fps=sampling_fps)
151 | return os.path.join(save_path, "000000.mp4")
152 |
153 | # Gradio Interface
154 | def demo_interface(prompt):
155 | video_path = infer(prompt, 16, 1, 8, 5)
156 | return video_path
157 |
158 | with gr.Blocks() as demo:
159 | gr.Markdown("""# RepVideo Gradio Demo
160 | Generate high-quality videos based on text prompts.
161 | """)
162 |
163 | with gr.Row():
164 | with gr.Column():
165 | prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here.", lines=3)
166 | generate_button = gr.Button("Generate Video")
167 |
168 | with gr.Column():
169 | video_output = gr.Video(label="Generated Video")
170 |
171 | generate_button.click(
172 | demo_interface,
173 | inputs=[prompt],
174 | outputs=[video_output]
175 | )
176 |
177 | demo.launch(server_name="127.0.0.1", server_port=7860)
178 |
--------------------------------------------------------------------------------
/sat/run.sh:
--------------------------------------------------------------------------------
1 |
2 | export CUDA_VISIBLE_DEVICES=0
3 | python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed 42
4 |
--------------------------------------------------------------------------------
/sat/sample_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | from typing import List, Union
5 | from tqdm import tqdm
6 | from omegaconf import ListConfig
7 | import imageio
8 |
9 | import torch
10 | import numpy as np
11 | from einops import rearrange
12 | import torchvision.transforms as TT
13 |
14 | from sat.model.base_model import get_model
15 | from sat.training.model_io import load_checkpoint
16 | from sat import mpu
17 |
18 | from diffusion_video import SATVideoDiffusionEngine
19 | from arguments import get_args
20 | from torchvision.transforms.functional import center_crop, resize
21 | from torchvision.transforms import InterpolationMode
22 |
23 |
24 | def read_from_cli():
25 | cnt = 0
26 | try:
27 | while True:
28 | x = input("Please input English text (Ctrl-D quit): ")
29 | yield x.strip(), cnt
30 | cnt += 1
31 | except EOFError as e:
32 | pass
33 |
34 |
35 | def read_from_file(p, rank=0, world_size=1):
36 | with open(p, "r") as fin:
37 | cnt = -1
38 | for l in fin:
39 | cnt += 1
40 | if cnt % world_size != rank:
41 | continue
42 | yield l.strip(), cnt
43 |
44 |
45 | def get_unique_embedder_keys_from_conditioner(conditioner):
46 | return list(set([x.input_key for x in conditioner.embedders]))
47 |
48 |
49 | def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
50 | batch = {}
51 | batch_uc = {}
52 |
53 | for key in keys:
54 | if key == "txt":
55 | batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
56 | batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
57 | else:
58 | batch[key] = value_dict[key]
59 |
60 | if T is not None:
61 | batch["num_video_frames"] = T
62 |
63 | for key in batch.keys():
64 | if key not in batch_uc and isinstance(batch[key], torch.Tensor):
65 | batch_uc[key] = torch.clone(batch[key])
66 | return batch, batch_uc
67 |
68 |
69 | def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
70 | os.makedirs(save_path, exist_ok=True)
71 |
72 | for i, vid in enumerate(video_batch):
73 | gif_frames = []
74 | for frame in vid:
75 | frame = rearrange(frame, "c h w -> h w c")
76 | frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
77 | gif_frames.append(frame)
78 | now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
79 | with imageio.get_writer(now_save_path, fps=fps) as writer:
80 | for frame in gif_frames:
81 | writer.append_data(frame)
82 |
83 |
84 | def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
85 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
86 | arr = resize(
87 | arr,
88 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
89 | interpolation=InterpolationMode.BICUBIC,
90 | )
91 | else:
92 | arr = resize(
93 | arr,
94 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
95 | interpolation=InterpolationMode.BICUBIC,
96 | )
97 |
98 | h, w = arr.shape[2], arr.shape[3]
99 | arr = arr.squeeze(0)
100 |
101 | delta_h = h - image_size[0]
102 | delta_w = w - image_size[1]
103 |
104 | if reshape_mode == "random" or reshape_mode == "none":
105 | top = np.random.randint(0, delta_h + 1)
106 | left = np.random.randint(0, delta_w + 1)
107 | elif reshape_mode == "center":
108 | top, left = delta_h // 2, delta_w // 2
109 | else:
110 | raise NotImplementedError
111 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
112 | return arr
113 |
114 |
115 | def sampling_main(args, model_cls):
116 | if isinstance(model_cls, type):
117 | model = get_model(args, model_cls)
118 | else:
119 | model = model_cls
120 |
121 | load_checkpoint(model, args)
122 | model.eval()
123 |
124 | if args.input_type == "cli":
125 | data_iter = read_from_cli()
126 | elif args.input_type == "txt":
127 | rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
128 | print("rank and world_size", rank, world_size)
129 | data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
130 | else:
131 | raise NotImplementedError
132 |
133 | image_size = [480, 720]
134 |
135 | sample_func = model.sample
136 | T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
137 | num_samples = [1]
138 | force_uc_zero_embeddings = ["txt"]
139 | device = model.device
140 | with torch.no_grad():
141 | for text, cnt in tqdm(data_iter):
142 | # reload model on GPU
143 | model.to(device)
144 | print("rank:", rank, "start to process", text, cnt)
145 | # TODO: broadcast image2video
146 | value_dict = {
147 | "prompt": text,
148 | "negative_prompt": "",
149 | "num_frames": torch.tensor(T).unsqueeze(0),
150 | }
151 |
152 | batch, batch_uc = get_batch(
153 | get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
154 | )
155 | for key in batch:
156 | if isinstance(batch[key], torch.Tensor):
157 | print(key, batch[key].shape)
158 | elif isinstance(batch[key], list):
159 | print(key, [len(l) for l in batch[key]])
160 | else:
161 | print(key, batch[key])
162 | c, uc = model.conditioner.get_unconditional_conditioning(
163 | batch,
164 | batch_uc=batch_uc,
165 | force_uc_zero_embeddings=force_uc_zero_embeddings,
166 | )
167 |
168 | for k in c:
169 | if not k == "crossattn":
170 | c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
171 | for index in range(args.batch_size):
172 | # reload model on GPU
173 | model.to(device)
177 | samples_z = sample_func(
178 | c,
179 | uc=uc,
180 | batch_size=1,
181 | shape=(T, C, H // F, W // F),
182 | )
183 | # print(samples_z.shape)
184 | samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
185 |
186 | # Unload the model from GPU to save GPU memory
187 | model.to("cpu")
188 | torch.cuda.empty_cache()
189 | first_stage_model = model.first_stage_model
190 | first_stage_model = first_stage_model.to(device)
191 |
192 |
193 |
194 | latent = 1.0 / model.scale_factor * samples_z
195 |
196 | # Decode latent serial to save GPU memory
197 | recons = []
198 | loop_num = (T - 1) // 2
199 | for i in range(loop_num):
200 | if i == 0:
201 | start_frame, end_frame = 0, 3
202 | else:
203 | start_frame, end_frame = i * 2 + 1, i * 2 + 3
204 | if i == loop_num - 1:
205 | clear_fake_cp_cache = True
206 | else:
207 | clear_fake_cp_cache = False
208 | with torch.no_grad():
209 | recon = first_stage_model.decode(
210 | latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
211 | )
212 |
213 | recons.append(recon)
214 |
215 | recon = torch.cat(recons, dim=2).to(torch.float32)
216 | samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
217 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
218 |
219 | save_path = os.path.join(
220 | args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
221 | )
222 | if mpu.get_model_parallel_rank() == 0:
223 | save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
224 |
225 |
226 | if __name__ == "__main__":
227 | if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
228 | os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
229 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
230 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
231 | py_parser = argparse.ArgumentParser(add_help=False)
232 | known, args_list = py_parser.parse_known_args()
233 |
234 | args = get_args(args_list)
235 | args = argparse.Namespace(**vars(args), **vars(known))
236 | del args.deepspeed_config
237 | args.model_config.first_stage_config.params.cp_size = 1
238 | args.model_config.network_config.params.transformer_args.model_parallel_size = 1
239 | args.model_config.network_config.params.transformer_args.checkpoint_activations = False
240 | args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
241 |
242 | sampling_main(args, model_cls=SATVideoDiffusionEngine)
243 |
--------------------------------------------------------------------------------
/sat/sgm/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import AutoencodingEngine
2 | from .util import get_configs_path, instantiate_from_config
3 |
4 | __version__ = "0.1.0"
5 |
--------------------------------------------------------------------------------
/sat/sgm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 |
9 | def __init__(
10 | self,
11 | warm_up_steps,
12 | lr_min,
13 | lr_max,
14 | lr_start,
15 | max_decay_steps,
16 | verbosity_interval=0,
17 | ):
18 | self.lr_warm_up_steps = warm_up_steps
19 | self.lr_start = lr_start
20 | self.lr_min = lr_min
21 | self.lr_max = lr_max
22 | self.lr_max_decay_steps = max_decay_steps
23 | self.last_lr = 0.0
24 | self.verbosity_interval = verbosity_interval
25 |
26 | def schedule(self, n, **kwargs):
27 | if self.verbosity_interval > 0:
28 | if n % self.verbosity_interval == 0:
29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30 | if n < self.lr_warm_up_steps:
31 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
32 | self.last_lr = lr
33 | return lr
34 | else:
35 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
36 | t = min(t, 1.0)
37 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
38 | self.last_lr = lr
39 | return lr
40 |
41 | def __call__(self, n, **kwargs):
42 | return self.schedule(n, **kwargs)
43 |
44 |
45 | class LambdaWarmUpCosineScheduler2:
46 | """
47 | supports repeated iterations, configurable via lists
48 | note: use with a base_lr of 1.0.
49 | """
50 |
51 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
52 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
53 | self.lr_warm_up_steps = warm_up_steps
54 | self.f_start = f_start
55 | self.f_min = f_min
56 | self.f_max = f_max
57 | self.cycle_lengths = cycle_lengths
58 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
59 | self.last_f = 0.0
60 | self.verbosity_interval = verbosity_interval
61 |
62 | def find_in_interval(self, n):
63 | interval = 0
64 | for cl in self.cum_cycles[1:]:
65 | if n <= cl:
66 | return interval
67 | interval += 1
68 |
69 | def schedule(self, n, **kwargs):
70 | cycle = self.find_in_interval(n)
71 | n = n - self.cum_cycles[cycle]
72 | if self.verbosity_interval > 0:
73 | if n % self.verbosity_interval == 0:
74 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
75 | if n < self.lr_warm_up_steps[cycle]:
76 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
77 | self.last_f = f
78 | return f
79 | else:
80 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
81 | t = min(t, 1.0)
82 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
83 | self.last_f = f
84 | return f
85 |
86 | def __call__(self, n, **kwargs):
87 | return self.schedule(n, **kwargs)
88 |
89 |
90 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
91 | def schedule(self, n, **kwargs):
92 | cycle = self.find_in_interval(n)
93 | n = n - self.cum_cycles[cycle]
94 | if self.verbosity_interval > 0:
95 | if n % self.verbosity_interval == 0:
96 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
97 |
98 | if n < self.lr_warm_up_steps[cycle]:
99 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
100 | self.last_f = f
101 | return f
102 | else:
103 | f = (
104 | self.f_min[cycle]
105 | + (self.f_max[cycle] - self.f_min[cycle])
106 | * (self.cycle_lengths[cycle] - n)
107 | / (self.cycle_lengths[cycle])
108 | )
109 | self.last_f = f
110 | return f
111 |
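112 | # Usage sketch (illustrative; hyper-parameters below are placeholders): these
113 | # schedulers return the learning-rate value itself, which is why the docstrings
114 | # say to use them with a base_lr of 1.0, e.g. via torch.optim.lr_scheduler.LambdaLR:
115 | #
116 | #   import torch
117 | #   optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
118 | #   schedule = LambdaWarmUpCosineScheduler(
119 | #       warm_up_steps=1_000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-7, max_decay_steps=100_000
120 | #   )
121 | #   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
122 | #   # call lr_scheduler.step() once per training step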
--------------------------------------------------------------------------------
/sat/sgm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .autoencoder import AutoencodingEngine
2 |
--------------------------------------------------------------------------------
/sat/sgm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoders.modules import GeneralConditioner
2 |
3 | UNCONDITIONAL_CONFIG = {
4 | "target": "sgm.modules.GeneralConditioner",
5 | "params": {"emb_models": []},
6 | }
7 |
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/losses/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "GeneralLPIPSWithDiscriminator",
3 | "LatentLPIPS",
4 | ]
5 |
6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator
7 | from .lpips import LatentLPIPS
8 | from .video_loss import VideoAutoencoderLoss
9 |
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/losses/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ....util import default, instantiate_from_config
5 | from ..lpips.loss.lpips import LPIPS
6 |
7 |
8 | class LatentLPIPS(nn.Module):
9 | def __init__(
10 | self,
11 | decoder_config,
12 | perceptual_weight=1.0,
13 | latent_weight=1.0,
14 | scale_input_to_tgt_size=False,
15 | scale_tgt_to_input_size=False,
16 | perceptual_weight_on_inputs=0.0,
17 | ):
18 | super().__init__()
19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size
20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size
21 | self.init_decoder(decoder_config)
22 | self.perceptual_loss = LPIPS().eval()
23 | self.perceptual_weight = perceptual_weight
24 | self.latent_weight = latent_weight
25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26 |
27 | def init_decoder(self, config):
28 | self.decoder = instantiate_from_config(config)
29 | if hasattr(self.decoder, "encoder"):
30 | del self.decoder.encoder
31 |
32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33 | log = dict()
34 | loss = (latent_inputs - latent_predictions) ** 2
35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36 | image_reconstructions = None
37 | if self.perceptual_weight > 0.0:
38 | image_reconstructions = self.decoder.decode(latent_predictions)
39 | image_targets = self.decoder.decode(latent_inputs)
40 | perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
41 | loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
42 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
43 |
44 | if self.perceptual_weight_on_inputs > 0.0:
45 | image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
46 | if self.scale_input_to_tgt_size:
47 | image_inputs = torch.nn.functional.interpolate(
48 | image_inputs,
49 | image_reconstructions.shape[2:],
50 | mode="bicubic",
51 | antialias=True,
52 | )
53 | elif self.scale_tgt_to_input_size:
54 | image_reconstructions = torch.nn.functional.interpolate(
55 | image_reconstructions,
56 | image_inputs.shape[2:],
57 | mode="bicubic",
58 | antialias=True,
59 | )
60 |
61 | perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
62 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
63 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
64 | return loss, log
65 |
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/loss/.gitignore:
--------------------------------------------------------------------------------
1 | vgg.pth
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/loss/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/loss/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/loss/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | from collections import namedtuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from ..util import get_ckpt_path
10 |
11 |
12 | class LPIPS(nn.Module):
13 | # Learned perceptual metric
14 | def __init__(self, use_dropout=True):
15 | super().__init__()
16 | self.scaling_layer = ScalingLayer()
 17 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
18 | self.net = vgg16(pretrained=True, requires_grad=False)
19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24 | self.load_from_pretrained()
25 | for param in self.parameters():
26 | param.requires_grad = False
27 |
28 | def load_from_pretrained(self, name="vgg_lpips"):
29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
31 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
32 |
33 | @classmethod
34 | def from_pretrained(cls, name="vgg_lpips"):
35 | if name != "vgg_lpips":
36 | raise NotImplementedError
37 | model = cls()
38 | ckpt = get_ckpt_path(name)
39 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
40 | return model
41 |
42 | def forward(self, input, target):
43 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
44 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
45 | feats0, feats1, diffs = {}, {}, {}
46 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
47 | for kk in range(len(self.chns)):
48 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
49 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
50 |
51 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
52 | val = res[0]
53 | for l in range(1, len(self.chns)):
54 | val += res[l]
55 | return val
56 |
57 |
58 | class ScalingLayer(nn.Module):
59 | def __init__(self):
60 | super(ScalingLayer, self).__init__()
61 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
62 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
63 |
64 | def forward(self, inp):
65 | return (inp - self.shift) / self.scale
66 |
67 |
68 | class NetLinLayer(nn.Module):
69 | """A single linear layer which does a 1x1 conv"""
70 |
71 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
72 | super(NetLinLayer, self).__init__()
73 | layers = (
74 | [
75 | nn.Dropout(),
76 | ]
77 | if (use_dropout)
78 | else []
79 | )
80 | layers += [
81 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
82 | ]
83 | self.model = nn.Sequential(*layers)
84 |
85 |
86 | class vgg16(torch.nn.Module):
87 | def __init__(self, requires_grad=False, pretrained=True):
88 | super(vgg16, self).__init__()
89 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
90 | self.slice1 = torch.nn.Sequential()
91 | self.slice2 = torch.nn.Sequential()
92 | self.slice3 = torch.nn.Sequential()
93 | self.slice4 = torch.nn.Sequential()
94 | self.slice5 = torch.nn.Sequential()
95 | self.N_slices = 5
96 | for x in range(4):
97 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
98 | for x in range(4, 9):
99 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
100 | for x in range(9, 16):
101 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
102 | for x in range(16, 23):
103 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
104 | for x in range(23, 30):
105 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
106 | if not requires_grad:
107 | for param in self.parameters():
108 | param.requires_grad = False
109 |
110 | def forward(self, X):
111 | h = self.slice1(X)
112 | h_relu1_2 = h
113 | h = self.slice2(h)
114 | h_relu2_2 = h
115 | h = self.slice3(h)
116 | h_relu3_3 = h
117 | h = self.slice4(h)
118 | h_relu4_3 = h
119 | h = self.slice5(h)
120 | h_relu5_3 = h
121 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
122 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
123 | return out
124 |
125 |
126 | def normalize_tensor(x, eps=1e-10):
127 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
128 | return x / (norm_factor + eps)
129 |
130 |
131 | def spatial_average(x, keepdim=True):
132 | return x.mean([2, 3], keepdim=keepdim)
133 |
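134 | # Usage sketch (illustrative): inputs are image batches of shape (N, 3, H, W)
135 | # scaled to [-1, 1]; the pretrained VGG LPIPS weights are fetched automatically
136 | # by load_from_pretrained() / get_ckpt_path() if not already present.
137 | #
138 | #   lpips = LPIPS().eval()
139 | #   with torch.no_grad():
140 | #       dist = lpips(pred, target)   # per-sample distances, shape (N, 1, 1, 1)
141 | #   loss = dist.mean()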
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/model/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/autoencoding/lpips/model/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/autoencoding/lpips/model/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 | from ..util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | try:
12 | nn.init.normal_(m.weight.data, 0.0, 0.02)
13 | except:
14 | nn.init.normal_(m.conv.weight.data, 0.0, 0.02)
15 | elif classname.find("BatchNorm") != -1:
16 | nn.init.normal_(m.weight.data, 1.0, 0.02)
17 | nn.init.constant_(m.bias.data, 0)
18 |
19 |
20 | class NLayerDiscriminator(nn.Module):
21 | """Defines a PatchGAN discriminator as in Pix2Pix
22 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
23 | """
24 |
25 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
26 | """Construct a PatchGAN discriminator
27 | Parameters:
28 | input_nc (int) -- the number of channels in input images
29 | ndf (int) -- the number of filters in the last conv layer
30 | n_layers (int) -- the number of conv layers in the discriminator
31 | norm_layer -- normalization layer
32 | """
33 | super(NLayerDiscriminator, self).__init__()
34 | if not use_actnorm:
35 | norm_layer = nn.BatchNorm2d
36 | else:
37 | norm_layer = ActNorm
38 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
39 | use_bias = norm_layer.func != nn.BatchNorm2d
40 | else:
41 | use_bias = norm_layer != nn.BatchNorm2d
42 |
43 | kw = 4
44 | padw = 1
45 | sequence = [
46 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
47 | nn.LeakyReLU(0.2, True),
48 | ]
49 | nf_mult = 1
50 | nf_mult_prev = 1
51 | for n in range(1, n_layers): # gradually increase the number of filters
52 | nf_mult_prev = nf_mult
53 | nf_mult = min(2**n, 8)
54 | sequence += [
55 | nn.Conv2d(
56 | ndf * nf_mult_prev,
57 | ndf * nf_mult,
58 | kernel_size=kw,
59 | stride=2,
60 | padding=padw,
61 | bias=use_bias,
62 | ),
63 | norm_layer(ndf * nf_mult),
64 | nn.LeakyReLU(0.2, True),
65 | ]
66 |
67 | nf_mult_prev = nf_mult
68 | nf_mult = min(2**n_layers, 8)
69 | sequence += [
70 | nn.Conv2d(
71 | ndf * nf_mult_prev,
72 | ndf * nf_mult,
73 | kernel_size=kw,
74 | stride=1,
75 | padding=padw,
76 | bias=use_bias,
77 | ),
78 | norm_layer(ndf * nf_mult),
79 | nn.LeakyReLU(0.2, True),
80 | ]
81 |
82 | sequence += [
83 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
84 | ] # output 1 channel prediction map
85 | self.main = nn.Sequential(*sequence)
86 |
87 | def forward(self, input):
88 | """Standard forward."""
89 | return self.main(input)
90 |
--------------------------------------------------------------------------------
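A minimal usage sketch for the PatchGAN discriminator defined above. The import path mirrors the file location under sat/ and the 256x256 input size is purely illustrative; both are assumptions, not something fixed by the file itself.

import torch

# Assumed import path (the repo's sat/ directory on PYTHONPATH).
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

# 3-layer PatchGAN over RGB inputs; BatchNorm2d is used unless use_actnorm=True.
disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)  # normal(0.0, 0.02) init for conv weights, as defined above

x = torch.randn(2, 3, 256, 256)   # a batch of real or generated images
logits = disc(x)                  # one-channel patch-wise real/fake prediction map
print(logits.shape)               # torch.Size([2, 1, 30, 30]) for 256x256 inputs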
/sat/sgm/modules/autoencoding/lpips/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 |
4 | import requests
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10 |
11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12 |
13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14 |
15 |
16 | def download(url, local_path, chunk_size=1024):
17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18 | with requests.get(url, stream=True) as r:
19 | total_size = int(r.headers.get("content-length", 0))
20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21 | with open(local_path, "wb") as f:
22 | for data in r.iter_content(chunk_size=chunk_size):
23 | if data:
24 | f.write(data)
25 |                             pbar.update(len(data))  # advance by the bytes actually written; the last chunk may be short
26 |
27 |
28 | def md5_hash(path):
29 | with open(path, "rb") as f:
30 | content = f.read()
31 | return hashlib.md5(content).hexdigest()
32 |
33 |
34 | def get_ckpt_path(name, root, check=False):
35 | assert name in URL_MAP
36 | path = os.path.join(root, CKPT_MAP[name])
37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39 | download(URL_MAP[name], path)
40 | md5 = md5_hash(path)
41 | assert md5 == MD5_MAP[name], md5
42 | return path
43 |
44 |
45 | class ActNorm(nn.Module):
46 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
47 | assert affine
48 | super().__init__()
49 | self.logdet = logdet
50 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
51 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
52 | self.allow_reverse_init = allow_reverse_init
53 |
54 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
55 |
56 | def initialize(self, input):
57 | with torch.no_grad():
58 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
59 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
60 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
61 |
62 | self.loc.data.copy_(-mean)
63 | self.scale.data.copy_(1 / (std + 1e-6))
64 |
65 | def forward(self, input, reverse=False):
66 | if reverse:
67 | return self.reverse(input)
68 | if len(input.shape) == 2:
69 | input = input[:, :, None, None]
70 | squeeze = True
71 | else:
72 | squeeze = False
73 |
74 | _, _, height, width = input.shape
75 |
76 | if self.training and self.initialized.item() == 0:
77 | self.initialize(input)
78 | self.initialized.fill_(1)
79 |
80 | h = self.scale * (input + self.loc)
81 |
82 | if squeeze:
83 | h = h.squeeze(-1).squeeze(-1)
84 |
85 | if self.logdet:
86 | log_abs = torch.log(torch.abs(self.scale))
87 | logdet = height * width * torch.sum(log_abs)
88 | logdet = logdet * torch.ones(input.shape[0]).to(input)
89 | return h, logdet
90 |
91 | return h
92 |
93 | def reverse(self, output):
94 | if self.training and self.initialized.item() == 0:
95 | if not self.allow_reverse_init:
96 | raise RuntimeError(
97 | "Initializing ActNorm in reverse direction is "
98 | "disabled by default. Use allow_reverse_init=True to enable."
99 | )
100 | else:
101 | self.initialize(output)
102 | self.initialized.fill_(1)
103 |
104 | if len(output.shape) == 2:
105 | output = output[:, :, None, None]
106 | squeeze = True
107 | else:
108 | squeeze = False
109 |
110 | h = output / self.scale - self.loc
111 |
112 | if squeeze:
113 | h = h.squeeze(-1).squeeze(-1)
114 | return h
115 |
--------------------------------------------------------------------------------
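A short sketch of ActNorm's data-dependent initialization and its inverse, assuming the class above is importable from the sat/ tree (the import path is an assumption).

import torch

from sgm.modules.autoencoding.lpips.util import ActNorm  # assumed import path

act = ActNorm(num_features=8)
x = 3.0 * torch.randn(4, 8, 16, 16) + 1.5   # deliberately shifted and scaled activations

act.train()
y = act(x)  # the first call in training mode triggers the data-dependent init

# After initialization each channel is roughly zero-mean, unit-variance.
print(y.mean(dim=(0, 2, 3)))   # ~0
print(y.std(dim=(0, 2, 3)))    # ~1

# reverse() undoes the affine map (up to the 1e-6 stabilizer on the scale).
x_rec = act.reverse(y)
print(torch.allclose(x, x_rec, atol=1e-3))   # True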
/sat/sgm/modules/autoencoding/lpips/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def hinge_d_loss(logits_real, logits_fake):
6 | loss_real = torch.mean(F.relu(1.0 - logits_real))
7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8 | d_loss = 0.5 * (loss_real + loss_fake)
9 | return d_loss
10 |
11 |
12 | def vanilla_d_loss(logits_real, logits_fake):
13 | d_loss = 0.5 * (
14 | torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
15 | )
16 | return d_loss
17 |
--------------------------------------------------------------------------------
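A quick numerical check of the two discriminator losses above: both become small when real logits are large and fake logits are small. The import path and the dummy patch logits are assumptions for illustration.

import torch

from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss  # assumed path

logits_real = torch.randn(8, 1, 30, 30) + 2.0   # discriminator confident on real patches
logits_fake = torch.randn(8, 1, 30, 30) - 2.0   # ... and confident on fake patches

print(hinge_d_loss(logits_real, logits_fake).item())    # small
print(vanilla_d_loss(logits_real, logits_fake).item())  # small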
/sat/sgm/modules/autoencoding/regularizers/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from ....modules.distributions.distributions import DiagonalGaussianDistribution
9 | from .base import AbstractRegularizer
10 |
11 |
12 | class DiagonalGaussianRegularizer(AbstractRegularizer):
13 | def __init__(self, sample: bool = True):
14 | super().__init__()
15 | self.sample = sample
16 |
17 | def get_trainable_parameters(self) -> Any:
18 | yield from ()
19 |
20 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
21 | log = dict()
22 | posterior = DiagonalGaussianDistribution(z)
23 | if self.sample:
24 | z = posterior.sample()
25 | else:
26 | z = posterior.mode()
27 | kl_loss = posterior.kl()
28 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
29 | log["kl_loss"] = kl_loss
30 | return z, log
31 |
--------------------------------------------------------------------------------
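A minimal sketch of how the KL regularizer above is typically driven: the input is read as mean and log-variance stacked along the channel axis (the usual convention of DiagonalGaussianDistribution in this codebase), so a 4-channel latent takes an 8-channel input. The import path and shapes are assumptions.

import torch

from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer  # assumed path

reg = DiagonalGaussianRegularizer(sample=True)

z_params = torch.randn(2, 8, 32, 32)   # (mean, logvar) stacked along channels
z, log = reg(z_params)

print(z.shape)          # torch.Size([2, 4, 32, 32]) -- a sampled latent
print(log["kl_loss"])   # batch-averaged KL to the standard normal prior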
/sat/sgm/modules/autoencoding/regularizers/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 |
9 | class AbstractRegularizer(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
14 | raise NotImplementedError()
15 |
16 | @abstractmethod
17 | def get_trainable_parameters(self) -> Any:
18 | raise NotImplementedError()
19 |
20 |
21 | class IdentityRegularizer(AbstractRegularizer):
22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23 | return z, dict()
24 |
25 | def get_trainable_parameters(self) -> Any:
26 | yield from ()
27 |
28 |
29 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
30 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
31 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
32 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
33 | avg_probs = encodings.mean(0)
34 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
35 | cluster_use = torch.sum(avg_probs > 0)
36 | return perplexity, cluster_use
37 |
--------------------------------------------------------------------------------
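Two edge cases that make measure_perplexity above concrete: uniform code usage gives a perplexity equal to the number of centroids, while a collapsed codebook gives a perplexity of one. The import path is an assumption.

import torch

from sgm.modules.autoencoding.regularizers.base import measure_perplexity  # assumed path

num_centroids = 16

uniform_indices = torch.arange(num_centroids).repeat(64)       # every code used equally often
perplexity, cluster_use = measure_perplexity(uniform_indices, num_centroids)
print(perplexity.item(), cluster_use.item())                    # ~16.0, 16

collapsed_indices = torch.zeros(1024, dtype=torch.long)         # a single code used
perplexity, cluster_use = measure_perplexity(collapsed_indices, num_centroids)
print(perplexity.item(), cluster_use.item())                    # ~1.0, 1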
/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py:
--------------------------------------------------------------------------------
1 | """
2 | Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3 | Code adapted from Jax version in Appendix A.1
4 | """
5 |
6 | from typing import List, Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import Module
11 | from torch import Tensor, int32
12 | from torch.cuda.amp import autocast
13 |
14 | from einops import rearrange, pack, unpack
15 |
16 | # helper functions
17 |
18 |
19 | def exists(v):
20 | return v is not None
21 |
22 |
23 | def default(*args):
24 | for arg in args:
25 | if exists(arg):
26 | return arg
27 | return None
28 |
29 |
30 | def pack_one(t, pattern):
31 | return pack([t], pattern)
32 |
33 |
34 | def unpack_one(t, ps, pattern):
35 | return unpack(t, ps, pattern)[0]
36 |
37 |
38 | # tensor helpers
39 |
40 |
41 | def round_ste(z: Tensor) -> Tensor:
42 | """Round with straight through gradients."""
43 | zhat = z.round()
44 | return z + (zhat - z).detach()
45 |
46 |
47 | # main class
48 |
49 |
50 | class FSQ(Module):
51 | def __init__(
52 | self,
53 | levels: List[int],
54 | dim: Optional[int] = None,
55 | num_codebooks=1,
56 | keep_num_codebooks_dim: Optional[bool] = None,
57 | scale: Optional[float] = None,
58 | ):
59 | super().__init__()
60 | _levels = torch.tensor(levels, dtype=int32)
61 | self.register_buffer("_levels", _levels, persistent=False)
62 |
63 | _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
64 | self.register_buffer("_basis", _basis, persistent=False)
65 |
66 | self.scale = scale
67 |
68 | codebook_dim = len(levels)
69 | self.codebook_dim = codebook_dim
70 |
71 | effective_codebook_dim = codebook_dim * num_codebooks
72 | self.num_codebooks = num_codebooks
73 | self.effective_codebook_dim = effective_codebook_dim
74 |
75 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
76 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
77 | self.keep_num_codebooks_dim = keep_num_codebooks_dim
78 |
79 | self.dim = default(dim, len(_levels) * num_codebooks)
80 |
81 | has_projections = self.dim != effective_codebook_dim
82 | self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
83 | self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
84 | self.has_projections = has_projections
85 |
86 | self.codebook_size = self._levels.prod().item()
87 |
88 | implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
89 | self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
90 |
91 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
92 | """Bound `z`, an array of shape (..., d)."""
93 | half_l = (self._levels - 1) * (1 + eps) / 2
94 | offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
95 | shift = (offset / half_l).atanh()
96 | return (z + shift).tanh() * half_l - offset
97 |
98 | def quantize(self, z: Tensor) -> Tensor:
99 | """Quantizes z, returns quantized zhat, same shape as z."""
100 | quantized = round_ste(self.bound(z))
101 | half_width = self._levels // 2 # Renormalize to [-1, 1].
102 | return quantized / half_width
103 |
104 | def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
105 | half_width = self._levels // 2
106 | return (zhat_normalized * half_width) + half_width
107 |
108 | def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
109 | half_width = self._levels // 2
110 | return (zhat - half_width) / half_width
111 |
112 | def codes_to_indices(self, zhat: Tensor) -> Tensor:
113 | """Converts a `code` to an index in the codebook."""
114 | assert zhat.shape[-1] == self.codebook_dim
115 | zhat = self._scale_and_shift(zhat)
116 | return (zhat * self._basis).sum(dim=-1).to(int32)
117 |
118 | def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
119 | """Inverse of `codes_to_indices`."""
120 |
121 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
122 |
123 | indices = rearrange(indices, "... -> ... 1")
124 | codes_non_centered = (indices // self._basis) % self._levels
125 | codes = self._scale_and_shift_inverse(codes_non_centered)
126 |
127 | if self.keep_num_codebooks_dim:
128 | codes = rearrange(codes, "... c d -> ... (c d)")
129 |
130 | if project_out:
131 | codes = self.project_out(codes)
132 |
133 | if is_img_or_video:
134 | codes = rearrange(codes, "b ... d -> b d ...")
135 |
136 | return codes
137 |
138 | @autocast(enabled=False)
139 | def forward(self, z: Tensor) -> Tensor:
140 | """
141 | einstein notation
142 | b - batch
143 | n - sequence (or flattened spatial dimensions)
144 | d - feature dimension
145 | c - number of codebook dim
146 | """
147 |
148 | is_img_or_video = z.ndim >= 4
149 |
150 | # standardize image or video into (batch, seq, dimension)
151 |
152 | if is_img_or_video:
153 | z = rearrange(z, "b d ... -> b ... d")
154 | z, ps = pack_one(z, "b * d")
155 |
156 | assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
157 |
158 | z = self.project_in(z)
159 |
160 | z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
161 |
162 | codes = self.quantize(z)
163 | indices = self.codes_to_indices(codes)
164 |
165 | codes = rearrange(codes, "b n c d -> b n (c d)")
166 |
167 | out = self.project_out(codes)
168 |
169 | # reconstitute image or video dimensions
170 |
171 | if is_img_or_video:
172 | out = unpack_one(out, ps, "b * d")
173 | out = rearrange(out, "b ... d -> b d ...")
174 |
175 | indices = unpack_one(indices, ps, "b * c")
176 |
177 | if not self.keep_num_codebooks_dim:
178 | indices = rearrange(indices, "... 1 -> ...")
179 |
180 | return out, indices
181 |
--------------------------------------------------------------------------------
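A usage sketch for the FSQ module above with the [8, 5, 5, 5] levels suggested in the paper; since dim equals len(levels) here, no input/output projections are created. The import path and tensor shapes are assumptions.

import torch

from sgm.modules.autoencoding.regularizers.finite_scalar_quantization import FSQ  # assumed path

fsq = FSQ(levels=[8, 5, 5, 5])          # codebook size = 8 * 5 * 5 * 5 = 1000

z = torch.randn(2, 4, 16, 16)           # image-shaped latent: (batch, dim, h, w)
quantized, indices = fsq(z)

print(quantized.shape)                  # torch.Size([2, 4, 16, 16]) -- same shape as z
print(indices.shape)                    # torch.Size([2, 16, 16])    -- one index per position
print(fsq.codebook_size)                # 1000

# indices_to_codes inverts the quantization (exactly here, since there is no projection).
codes = fsq.indices_to_codes(indices)
print(torch.allclose(codes, quantized))  # True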
/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py:
--------------------------------------------------------------------------------
1 | """
2 | Lookup Free Quantization
3 | Proposed in https://arxiv.org/abs/2310.05737
4 |
5 | In the simplest setup, each dimension is quantized into {-1, 1}.
6 | An entropy penalty is used to encourage utilization.
7 | """
8 |
9 | from math import log2, ceil
10 | from collections import namedtuple
11 |
12 | import torch
13 | from torch import nn, einsum
14 | import torch.nn.functional as F
15 | from torch.nn import Module
16 | from torch.cuda.amp import autocast
17 |
18 | from einops import rearrange, reduce, pack, unpack
19 |
20 | # constants
21 |
22 | Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"])
23 |
24 | LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"])
25 |
26 | # helper functions
27 |
28 |
29 | def exists(v):
30 | return v is not None
31 |
32 |
33 | def default(*args):
34 | for arg in args:
35 | if exists(arg):
36 | return arg() if callable(arg) else arg
37 | return None
38 |
39 |
40 | def pack_one(t, pattern):
41 | return pack([t], pattern)
42 |
43 |
44 | def unpack_one(t, ps, pattern):
45 | return unpack(t, ps, pattern)[0]
46 |
47 |
48 | # entropy
49 |
50 |
51 | def log(t, eps=1e-5):
52 | return t.clamp(min=eps).log()
53 |
54 |
55 | def entropy(prob):
56 | return (-prob * log(prob)).sum(dim=-1)
57 |
58 |
59 | # class
60 |
61 |
62 | class LFQ(Module):
63 | def __init__(
64 | self,
65 | *,
66 | dim=None,
67 | codebook_size=None,
68 | entropy_loss_weight=0.1,
69 | commitment_loss_weight=0.25,
70 | diversity_gamma=1.0,
71 | straight_through_activation=nn.Identity(),
72 | num_codebooks=1,
73 | keep_num_codebooks_dim=None,
74 | codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
75 |         frac_per_sample_entropy=1.0,  # set below 1.0 to compute the per-sample entropy on only a random fraction of the probs
76 | ):
77 | super().__init__()
78 |
79 | # some assert validations
80 |
81 | assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
82 | assert (
83 | not exists(codebook_size) or log2(codebook_size).is_integer()
84 | ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
85 |
86 | codebook_size = default(codebook_size, lambda: 2**dim)
87 | codebook_dim = int(log2(codebook_size))
88 |
89 | codebook_dims = codebook_dim * num_codebooks
90 | dim = default(dim, codebook_dims)
91 |
92 | has_projections = dim != codebook_dims
93 | self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
94 | self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
95 | self.has_projections = has_projections
96 |
97 | self.dim = dim
98 | self.codebook_dim = codebook_dim
99 | self.num_codebooks = num_codebooks
100 |
101 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
102 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
103 | self.keep_num_codebooks_dim = keep_num_codebooks_dim
104 |
105 | # straight through activation
106 |
107 | self.activation = straight_through_activation
108 |
109 | # entropy aux loss related weights
110 |
111 | assert 0 < frac_per_sample_entropy <= 1.0
112 | self.frac_per_sample_entropy = frac_per_sample_entropy
113 |
114 | self.diversity_gamma = diversity_gamma
115 | self.entropy_loss_weight = entropy_loss_weight
116 |
117 | # codebook scale
118 |
119 | self.codebook_scale = codebook_scale
120 |
121 | # commitment loss
122 |
123 | self.commitment_loss_weight = commitment_loss_weight
124 |
125 | # for no auxiliary loss, during inference
126 |
127 | self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1))
128 | self.register_buffer("zero", torch.tensor(0.0), persistent=False)
129 |
130 | # codes
131 |
132 | all_codes = torch.arange(codebook_size)
133 | bits = ((all_codes[..., None].int() & self.mask) != 0).float()
134 | codebook = self.bits_to_codes(bits)
135 |
136 | self.register_buffer("codebook", codebook, persistent=False)
137 |
138 | def bits_to_codes(self, bits):
139 | return bits * self.codebook_scale * 2 - self.codebook_scale
140 |
141 | @property
142 | def dtype(self):
143 | return self.codebook.dtype
144 |
145 | def indices_to_codes(self, indices, project_out=True):
146 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
147 |
148 | if not self.keep_num_codebooks_dim:
149 | indices = rearrange(indices, "... -> ... 1")
150 |
151 | # indices to codes, which are bits of either -1 or 1
152 |
153 | bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
154 |
155 | codes = self.bits_to_codes(bits)
156 |
157 | codes = rearrange(codes, "... c d -> ... (c d)")
158 |
159 | # whether to project codes out to original dimensions
160 | # if the input feature dimensions were not log2(codebook size)
161 |
162 | if project_out:
163 | codes = self.project_out(codes)
164 |
165 | # rearrange codes back to original shape
166 |
167 | if is_img_or_video:
168 | codes = rearrange(codes, "b ... d -> b d ...")
169 |
170 | return codes
171 |
172 | @autocast(enabled=False)
173 | def forward(
174 | self,
175 | x,
176 | inv_temperature=100.0,
177 | return_loss_breakdown=False,
178 | mask=None,
179 | ):
180 | """
181 | einstein notation
182 | b - batch
183 | n - sequence (or flattened spatial dimensions)
184 | d - feature dimension, which is also log2(codebook size)
185 | c - number of codebook dim
186 | """
187 |
188 | x = x.float()
189 |
190 | is_img_or_video = x.ndim >= 4
191 |
192 | # standardize image or video into (batch, seq, dimension)
193 |
194 | if is_img_or_video:
195 | x = rearrange(x, "b d ... -> b ... d")
196 | x, ps = pack_one(x, "b * d")
197 |
198 | assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
199 |
200 | x = self.project_in(x)
201 |
202 | # split out number of codebooks
203 |
204 | x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
205 |
206 | # quantize by eq 3.
207 |
208 | original_input = x
209 |
210 | codebook_value = torch.ones_like(x) * self.codebook_scale
211 | quantized = torch.where(x > 0, codebook_value, -codebook_value)
212 |
213 | # use straight-through gradients (optionally with custom activation fn) if training
214 |
215 | if self.training:
216 | x = self.activation(x)
217 | x = x + (quantized - x).detach()
218 | else:
219 | x = quantized
220 |
221 | # calculate indices
222 |
223 | indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
224 |
225 | # entropy aux loss
226 |
227 | if self.training:
228 | # the same as euclidean distance up to a constant
229 | distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook)
230 |
231 | prob = (-distance * inv_temperature).softmax(dim=-1)
232 |
233 | # account for mask
234 |
235 | if exists(mask):
236 | prob = prob[mask]
237 | else:
238 | prob = rearrange(prob, "b n ... -> (b n) ...")
239 |
240 | # whether to only use a fraction of probs, for reducing memory
241 |
242 | if self.frac_per_sample_entropy < 1.0:
243 | num_tokens = prob.shape[0]
244 | num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
245 | rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens
246 | per_sample_probs = prob[rand_mask]
247 | else:
248 | per_sample_probs = prob
249 |
250 | # calculate per sample entropy
251 |
252 | per_sample_entropy = entropy(per_sample_probs).mean()
253 |
254 | # distribution over all available tokens in the batch
255 |
256 | avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean")
257 | codebook_entropy = entropy(avg_prob).mean()
258 |
259 | # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
260 | # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
261 |
262 | entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
263 | else:
264 | # if not training, just return dummy 0
265 | entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
266 |
267 | # commit loss
268 |
269 | if self.training:
270 | commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none")
271 |
272 | if exists(mask):
273 | commit_loss = commit_loss[mask]
274 |
275 | commit_loss = commit_loss.mean()
276 | else:
277 | commit_loss = self.zero
278 |
279 | # merge back codebook dim
280 |
281 | x = rearrange(x, "b n c d -> b n (c d)")
282 |
283 | # project out to feature dimension if needed
284 |
285 | x = self.project_out(x)
286 |
287 | # reconstitute image or video dimensions
288 |
289 | if is_img_or_video:
290 | x = unpack_one(x, ps, "b * d")
291 | x = rearrange(x, "b ... d -> b d ...")
292 |
293 | indices = unpack_one(indices, ps, "b * c")
294 |
295 | # whether to remove single codebook dim
296 |
297 | if not self.keep_num_codebooks_dim:
298 | indices = rearrange(indices, "... 1 -> ...")
299 |
300 | # complete aux loss
301 |
302 | aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
303 |
304 | ret = Return(x, indices, aux_loss)
305 |
306 | if not return_loss_breakdown:
307 | return ret
308 |
309 | return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
310 |
--------------------------------------------------------------------------------
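A usage sketch for LFQ above: with dim=10 each latent dimension is binarized, giving a 2**10 = 1024-entry implicit codebook, and the entropy and commitment terms are only computed in training mode. The import path and shapes are assumptions.

import torch

from sgm.modules.autoencoding.regularizers.lookup_free_quantization import LFQ  # assumed path

lfq = LFQ(dim=10, entropy_loss_weight=0.1, commitment_loss_weight=0.25)
lfq.eval()                                # auxiliary losses are returned as 0 outside training

x = torch.randn(2, 10, 16, 16)            # image-shaped latent: (batch, dim, h, w)
quantized, indices, aux_loss = lfq(x)     # the Return namedtuple defined above

print(quantized.shape)   # torch.Size([2, 10, 16, 16]) -- entries are +/- codebook_scale
print(indices.shape)     # torch.Size([2, 16, 16])     -- integer codes in [0, 1023]
print(aux_loss)          # tensor(0.) in eval mode; entropy + commitment terms when training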
/sat/sgm/modules/autoencoding/temporal_ae.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Iterable, Union
2 |
3 | import torch
4 | from einops import rearrange, repeat
5 |
6 | from sgm.modules.diffusionmodules.model import (
7 | XFORMERS_IS_AVAILABLE,
8 | AttnBlock,
9 | Decoder,
10 | MemoryEfficientAttnBlock,
11 | ResnetBlock,
12 | )
13 | from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
14 | from sgm.modules.video_attention import VideoTransformerBlock
15 | from sgm.util import partialclass
16 |
17 |
18 | class VideoResBlock(ResnetBlock):
19 | def __init__(
20 | self,
21 | out_channels,
22 | *args,
23 | dropout=0.0,
24 | video_kernel_size=3,
25 | alpha=0.0,
26 | merge_strategy="learned",
27 | **kwargs,
28 | ):
29 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
30 | if video_kernel_size is None:
31 | video_kernel_size = [3, 1, 1]
32 | self.time_stack = ResBlock(
33 | channels=out_channels,
34 | emb_channels=0,
35 | dropout=dropout,
36 | dims=3,
37 | use_scale_shift_norm=False,
38 | use_conv=False,
39 | up=False,
40 | down=False,
41 | kernel_size=video_kernel_size,
42 | use_checkpoint=False,
43 | skip_t_emb=True,
44 | )
45 |
46 | self.merge_strategy = merge_strategy
47 | if self.merge_strategy == "fixed":
48 | self.register_buffer("mix_factor", torch.Tensor([alpha]))
49 | elif self.merge_strategy == "learned":
50 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
51 | else:
52 | raise ValueError(f"unknown merge strategy {self.merge_strategy}")
53 |
54 | def get_alpha(self, bs):
55 | if self.merge_strategy == "fixed":
56 | return self.mix_factor
57 | elif self.merge_strategy == "learned":
58 | return torch.sigmoid(self.mix_factor)
59 | else:
60 | raise NotImplementedError()
61 |
62 | def forward(self, x, temb, skip_video=False, timesteps=None):
63 | if timesteps is None:
64 | timesteps = self.timesteps
65 |
66 | b, c, h, w = x.shape
67 |
68 | x = super().forward(x, temb)
69 |
70 | if not skip_video:
71 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
72 |
73 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
74 |
75 | x = self.time_stack(x, temb)
76 |
77 | alpha = self.get_alpha(bs=b // timesteps)
78 | x = alpha * x + (1.0 - alpha) * x_mix
79 |
80 | x = rearrange(x, "b c t h w -> (b t) c h w")
81 | return x
82 |
83 |
84 | class AE3DConv(torch.nn.Conv2d):
85 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
86 | super().__init__(in_channels, out_channels, *args, **kwargs)
87 | if isinstance(video_kernel_size, Iterable):
88 | padding = [int(k // 2) for k in video_kernel_size]
89 | else:
90 | padding = int(video_kernel_size // 2)
91 |
92 | self.time_mix_conv = torch.nn.Conv3d(
93 | in_channels=out_channels,
94 | out_channels=out_channels,
95 | kernel_size=video_kernel_size,
96 | padding=padding,
97 | )
98 |
99 | def forward(self, input, timesteps, skip_video=False):
100 | x = super().forward(input)
101 | if skip_video:
102 | return x
103 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
104 | x = self.time_mix_conv(x)
105 | return rearrange(x, "b c t h w -> (b t) c h w")
106 |
107 |
108 | class VideoBlock(AttnBlock):
109 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
110 | super().__init__(in_channels)
111 | # no context, single headed, as in base class
112 | self.time_mix_block = VideoTransformerBlock(
113 | dim=in_channels,
114 | n_heads=1,
115 | d_head=in_channels,
116 | checkpoint=False,
117 | ff_in=True,
118 | attn_mode="softmax",
119 | )
120 |
121 | time_embed_dim = self.in_channels * 4
122 | self.video_time_embed = torch.nn.Sequential(
123 | torch.nn.Linear(self.in_channels, time_embed_dim),
124 | torch.nn.SiLU(),
125 | torch.nn.Linear(time_embed_dim, self.in_channels),
126 | )
127 |
128 | self.merge_strategy = merge_strategy
129 | if self.merge_strategy == "fixed":
130 | self.register_buffer("mix_factor", torch.Tensor([alpha]))
131 | elif self.merge_strategy == "learned":
132 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
133 | else:
134 | raise ValueError(f"unknown merge strategy {self.merge_strategy}")
135 |
136 | def forward(self, x, timesteps, skip_video=False):
137 | if skip_video:
138 | return super().forward(x)
139 |
140 | x_in = x
141 | x = self.attention(x)
142 | h, w = x.shape[2:]
143 | x = rearrange(x, "b c h w -> b (h w) c")
144 |
145 | x_mix = x
146 | num_frames = torch.arange(timesteps, device=x.device)
147 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
148 | num_frames = rearrange(num_frames, "b t -> (b t)")
149 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
150 | emb = self.video_time_embed(t_emb) # b, n_channels
151 | emb = emb[:, None, :]
152 | x_mix = x_mix + emb
153 |
154 | alpha = self.get_alpha()
155 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
156 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
157 |
158 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
159 | x = self.proj_out(x)
160 |
161 | return x_in + x
162 |
163 | def get_alpha(
164 | self,
165 | ):
166 | if self.merge_strategy == "fixed":
167 | return self.mix_factor
168 | elif self.merge_strategy == "learned":
169 | return torch.sigmoid(self.mix_factor)
170 | else:
171 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
172 |
173 |
174 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
175 | def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
176 | super().__init__(in_channels)
177 | # no context, single headed, as in base class
178 | self.time_mix_block = VideoTransformerBlock(
179 | dim=in_channels,
180 | n_heads=1,
181 | d_head=in_channels,
182 | checkpoint=False,
183 | ff_in=True,
184 | attn_mode="softmax-xformers",
185 | )
186 |
187 | time_embed_dim = self.in_channels * 4
188 | self.video_time_embed = torch.nn.Sequential(
189 | torch.nn.Linear(self.in_channels, time_embed_dim),
190 | torch.nn.SiLU(),
191 | torch.nn.Linear(time_embed_dim, self.in_channels),
192 | )
193 |
194 | self.merge_strategy = merge_strategy
195 | if self.merge_strategy == "fixed":
196 | self.register_buffer("mix_factor", torch.Tensor([alpha]))
197 | elif self.merge_strategy == "learned":
198 | self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
199 | else:
200 | raise ValueError(f"unknown merge strategy {self.merge_strategy}")
201 |
202 | def forward(self, x, timesteps, skip_time_block=False):
203 | if skip_time_block:
204 | return super().forward(x)
205 |
206 | x_in = x
207 | x = self.attention(x)
208 | h, w = x.shape[2:]
209 | x = rearrange(x, "b c h w -> b (h w) c")
210 |
211 | x_mix = x
212 | num_frames = torch.arange(timesteps, device=x.device)
213 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
214 | num_frames = rearrange(num_frames, "b t -> (b t)")
215 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
216 | emb = self.video_time_embed(t_emb) # b, n_channels
217 | emb = emb[:, None, :]
218 | x_mix = x_mix + emb
219 |
220 | alpha = self.get_alpha()
221 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
222 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
223 |
224 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
225 | x = self.proj_out(x)
226 |
227 | return x_in + x
228 |
229 | def get_alpha(
230 | self,
231 | ):
232 | if self.merge_strategy == "fixed":
233 | return self.mix_factor
234 | elif self.merge_strategy == "learned":
235 | return torch.sigmoid(self.mix_factor)
236 | else:
237 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
238 |
239 |
240 | def make_time_attn(
241 | in_channels,
242 | attn_type="vanilla",
243 | attn_kwargs=None,
244 | alpha: float = 0,
245 | merge_strategy: str = "learned",
246 | ):
247 | assert attn_type in [
248 | "vanilla",
249 | "vanilla-xformers",
250 | ], f"attn_type {attn_type} not supported for spatio-temporal attention"
251 | print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
252 | if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
253 | print(
254 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
255 | f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
256 | )
257 | attn_type = "vanilla"
258 |
259 | if attn_type == "vanilla":
260 | assert attn_kwargs is None
261 | return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy)
262 | elif attn_type == "vanilla-xformers":
263 | print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
264 | return partialclass(
265 | MemoryEfficientVideoBlock,
266 | in_channels,
267 | alpha=alpha,
268 | merge_strategy=merge_strategy,
269 | )
270 | else:
271 |         raise NotImplementedError()
272 |
273 |
274 | class Conv2DWrapper(torch.nn.Conv2d):
275 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
276 | return super().forward(input)
277 |
278 |
279 | class VideoDecoder(Decoder):
280 | available_time_modes = ["all", "conv-only", "attn-only"]
281 |
282 | def __init__(
283 | self,
284 | *args,
285 | video_kernel_size: Union[int, list] = 3,
286 | alpha: float = 0.0,
287 | merge_strategy: str = "learned",
288 | time_mode: str = "conv-only",
289 | **kwargs,
290 | ):
291 | self.video_kernel_size = video_kernel_size
292 | self.alpha = alpha
293 | self.merge_strategy = merge_strategy
294 | self.time_mode = time_mode
295 | assert (
296 | self.time_mode in self.available_time_modes
297 | ), f"time_mode parameter has to be in {self.available_time_modes}"
298 | super().__init__(*args, **kwargs)
299 |
300 | def get_last_layer(self, skip_time_mix=False, **kwargs):
301 | if self.time_mode == "attn-only":
302 | raise NotImplementedError("TODO")
303 | else:
304 | return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
305 |
306 | def _make_attn(self) -> Callable:
307 | if self.time_mode not in ["conv-only", "only-last-conv"]:
308 | return partialclass(
309 | make_time_attn,
310 | alpha=self.alpha,
311 | merge_strategy=self.merge_strategy,
312 | )
313 | else:
314 | return super()._make_attn()
315 |
316 | def _make_conv(self) -> Callable:
317 | if self.time_mode != "attn-only":
318 | return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
319 | else:
320 | return Conv2DWrapper
321 |
322 | def _make_resblock(self) -> Callable:
323 | if self.time_mode not in ["attn-only", "only-last-conv"]:
324 | return partialclass(
325 | VideoResBlock,
326 | video_kernel_size=self.video_kernel_size,
327 | alpha=self.alpha,
328 | merge_strategy=self.merge_strategy,
329 | )
330 | else:
331 | return super()._make_resblock()
332 |
--------------------------------------------------------------------------------
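The classes above share one pattern: per-frame spatial features are blended with a temporally mixed copy through a learned scalar gate, alpha = sigmoid(mix_factor). The standalone sketch below illustrates only that blend; the flip() call is a stand-in for the real time_stack / time_mix_block, and all shapes are illustrative assumptions.

import torch
from einops import rearrange

b, t, c, h, w = 2, 8, 16, 32, 32
x = torch.randn(b * t, c, h, w)                        # frames stacked into the batch dim

mix_factor = torch.nn.Parameter(torch.tensor([0.0]))   # the "learned" merge strategy
alpha = torch.sigmoid(mix_factor)                      # 0.5 at initialization

x_spatial = rearrange(x, "(b t) c h w -> b c t h w", t=t)
x_temporal = x_spatial.flip(dims=[2])                  # stand-in for the temporal module

merged = alpha * x_temporal + (1.0 - alpha) * x_spatial
merged = rearrange(merged, "b c t h w -> (b t) c h w")
print(merged.shape)                                    # torch.Size([16, 16, 32, 32])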
/sat/sgm/modules/autoencoding/vqvae/quantize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torch import einsum
6 | from einops import rearrange
7 |
8 |
9 | class VectorQuantizer2(nn.Module):
10 | """
11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
13 | """
14 |
15 | # NOTE: due to a bug the beta term was applied to the wrong term. for
16 | # backwards compatibility we use the buggy version by default, but you can
17 | # specify legacy=False to fix it.
18 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
19 | super().__init__()
20 | self.n_e = n_e
21 | self.e_dim = e_dim
22 | self.beta = beta
23 | self.legacy = legacy
24 |
25 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
27 |
28 | self.remap = remap
29 | if self.remap is not None:
30 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
31 | self.re_embed = self.used.shape[0]
32 | self.unknown_index = unknown_index # "random" or "extra" or integer
33 | if self.unknown_index == "extra":
34 | self.unknown_index = self.re_embed
35 | self.re_embed = self.re_embed + 1
36 | print(
37 | f"Remapping {self.n_e} indices to {self.re_embed} indices. "
38 | f"Using {self.unknown_index} for unknown indices."
39 | )
40 | else:
41 | self.re_embed = n_e
42 |
43 | self.sane_index_shape = sane_index_shape
44 |
45 | def remap_to_used(self, inds):
46 | ishape = inds.shape
47 | assert len(ishape) > 1
48 | inds = inds.reshape(ishape[0], -1)
49 | used = self.used.to(inds)
50 | match = (inds[:, :, None] == used[None, None, ...]).long()
51 | new = match.argmax(-1)
52 | unknown = match.sum(2) < 1
53 | if self.unknown_index == "random":
54 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
55 | else:
56 | new[unknown] = self.unknown_index
57 | return new.reshape(ishape)
58 |
59 | def unmap_to_all(self, inds):
60 | ishape = inds.shape
61 | assert len(ishape) > 1
62 | inds = inds.reshape(ishape[0], -1)
63 | used = self.used.to(inds)
64 | if self.re_embed > self.used.shape[0]: # extra token
65 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
66 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
67 | return back.reshape(ishape)
68 |
69 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
70 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
71 | assert rescale_logits == False, "Only for interface compatible with Gumbel"
72 | assert return_logits == False, "Only for interface compatible with Gumbel"
73 | # reshape z -> (batch, height, width, channel) and flatten
74 | z = rearrange(z, "b c h w -> b h w c").contiguous()
75 | z_flattened = z.view(-1, self.e_dim)
76 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
77 |
78 | d = (
79 | torch.sum(z_flattened**2, dim=1, keepdim=True)
80 | + torch.sum(self.embedding.weight**2, dim=1)
81 | - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
82 | )
83 |
84 | min_encoding_indices = torch.argmin(d, dim=1)
85 | z_q = self.embedding(min_encoding_indices).view(z.shape)
86 | perplexity = None
87 | min_encodings = None
88 |
89 | # compute loss for embedding
90 | if not self.legacy:
91 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
92 | else:
93 | loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
94 |
95 | # preserve gradients
96 | z_q = z + (z_q - z).detach()
97 |
98 | # reshape back to match original input shape
99 | z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
100 |
101 | if self.remap is not None:
102 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
103 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
104 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
105 |
106 | if self.sane_index_shape:
107 | min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
108 |
109 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
110 |
111 | def get_codebook_entry(self, indices, shape):
112 | # shape specifying (batch, height, width, channel)
113 | if self.remap is not None:
114 | indices = indices.reshape(shape[0], -1) # add batch axis
115 | indices = self.unmap_to_all(indices)
116 | indices = indices.reshape(-1) # flatten again
117 |
118 | # get quantized latent vectors
119 | z_q = self.embedding(indices)
120 |
121 | if shape is not None:
122 | z_q = z_q.view(shape)
123 | # reshape back to match original input shape
124 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
125 |
126 | return z_q
127 |
128 |
129 | class GumbelQuantize(nn.Module):
130 | """
131 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
132 | Gumbel Softmax trick quantizer
133 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
134 | https://arxiv.org/abs/1611.01144
135 | """
136 |
137 | def __init__(
138 | self,
139 | num_hiddens,
140 | embedding_dim,
141 | n_embed,
142 | straight_through=True,
143 | kl_weight=5e-4,
144 | temp_init=1.0,
145 | use_vqinterface=True,
146 | remap=None,
147 | unknown_index="random",
148 | ):
149 | super().__init__()
150 |
151 | self.embedding_dim = embedding_dim
152 | self.n_embed = n_embed
153 |
154 | self.straight_through = straight_through
155 | self.temperature = temp_init
156 | self.kl_weight = kl_weight
157 |
158 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
159 | self.embed = nn.Embedding(n_embed, embedding_dim)
160 |
161 | self.use_vqinterface = use_vqinterface
162 |
163 | self.remap = remap
164 | if self.remap is not None:
165 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
166 | self.re_embed = self.used.shape[0]
167 | self.unknown_index = unknown_index # "random" or "extra" or integer
168 | if self.unknown_index == "extra":
169 | self.unknown_index = self.re_embed
170 | self.re_embed = self.re_embed + 1
171 | print(
172 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
173 | f"Using {self.unknown_index} for unknown indices."
174 | )
175 | else:
176 | self.re_embed = n_embed
177 |
178 | def remap_to_used(self, inds):
179 | ishape = inds.shape
180 | assert len(ishape) > 1
181 | inds = inds.reshape(ishape[0], -1)
182 | used = self.used.to(inds)
183 | match = (inds[:, :, None] == used[None, None, ...]).long()
184 | new = match.argmax(-1)
185 | unknown = match.sum(2) < 1
186 | if self.unknown_index == "random":
187 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
188 | else:
189 | new[unknown] = self.unknown_index
190 | return new.reshape(ishape)
191 |
192 | def unmap_to_all(self, inds):
193 | ishape = inds.shape
194 | assert len(ishape) > 1
195 | inds = inds.reshape(ishape[0], -1)
196 | used = self.used.to(inds)
197 | if self.re_embed > self.used.shape[0]: # extra token
198 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
199 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
200 | return back.reshape(ishape)
201 |
202 | def forward(self, z, temp=None, return_logits=False):
203 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
204 | hard = self.straight_through if self.training else True
205 | temp = self.temperature if temp is None else temp
206 |
207 | logits = self.proj(z)
208 | if self.remap is not None:
209 | # continue only with used logits
210 | full_zeros = torch.zeros_like(logits)
211 | logits = logits[:, self.used, ...]
212 |
213 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
214 | if self.remap is not None:
215 | # go back to all entries but unused set to zero
216 | full_zeros[:, self.used, ...] = soft_one_hot
217 | soft_one_hot = full_zeros
218 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
219 |
220 | # + kl divergence to the prior loss
221 | qy = F.softmax(logits, dim=1)
222 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
223 |
224 | ind = soft_one_hot.argmax(dim=1)
225 | if self.remap is not None:
226 | ind = self.remap_to_used(ind)
227 | if self.use_vqinterface:
228 | if return_logits:
229 | return z_q, diff, (None, None, ind), logits
230 | return z_q, diff, (None, None, ind)
231 | return z_q, diff, ind
232 |
233 | def get_codebook_entry(self, indices, shape):
234 | b, h, w, c = shape
235 | assert b * h * w == indices.shape[0]
236 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
237 | if self.remap is not None:
238 | indices = self.unmap_to_all(indices)
239 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
240 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
241 | return z_q
242 |
--------------------------------------------------------------------------------
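A minimal usage sketch for VectorQuantizer2 above; sane_index_shape=True returns the code indices as a (batch, h, w) map. The import path and shapes are assumptions.

import torch

from sgm.modules.autoencoding.vqvae.quantize import VectorQuantizer2  # assumed path

vq = VectorQuantizer2(n_e=512, e_dim=64, beta=0.25, sane_index_shape=True)

z = torch.randn(2, 64, 16, 16)          # encoder output: (batch, e_dim, h, w)
z_q, loss, (perplexity, min_encodings, indices) = vq(z)

print(z_q.shape)       # torch.Size([2, 64, 16, 16]) -- gradients pass straight through
print(loss)            # codebook / commitment loss (see the legacy-beta note above)
print(indices.shape)   # torch.Size([2, 16, 16])

# Decoding from indices uses the (batch, height, width, channel) shape convention.
z_rec = vq.get_codebook_entry(indices.reshape(-1), shape=(2, 16, 16, 64))
print(torch.allclose(z_rec, z_q, atol=1e-5))   # True, up to straight-through rounding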
/sat/sgm/modules/cp_enc_dec.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.distributed
4 | import torch.nn as nn
5 | from ..util import (
6 | get_context_parallel_group,
7 | get_context_parallel_rank,
8 | get_context_parallel_world_size,
9 |
10 | )
11 |
12 | _USE_CP = True
13 |
14 |
15 | def cast_tuple(t, length=1):
16 | return t if isinstance(t, tuple) else ((t,) * length)
17 |
18 |
19 | def divisible_by(num, den):
20 | return (num % den) == 0
21 |
22 |
23 | def is_odd(n):
24 | return not divisible_by(n, 2)
25 |
26 |
27 | def exists(v):
28 | return v is not None
29 |
30 |
31 | def pair(t):
32 | return t if isinstance(t, tuple) else (t, t)
33 |
34 |
35 | def get_timestep_embedding(timesteps, embedding_dim):
36 | """
37 |     Build sinusoidal timestep embeddings.
38 |     This matches the implementation in Denoising Diffusion Probabilistic Models
39 |     (originally from Fairseq) and the implementation in tensor2tensor, but
40 |     differs slightly from the description in Section 3.5 of
41 |     "Attention Is All You Need".
42 | """
43 | assert len(timesteps.shape) == 1
44 |
45 | half_dim = embedding_dim // 2
46 | emb = math.log(10000) / (half_dim - 1)
47 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
48 | emb = emb.to(device=timesteps.device)
49 | emb = timesteps.float()[:, None] * emb[None, :]
50 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
51 | if embedding_dim % 2 == 1: # zero pad
52 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
53 | return emb
54 |
55 |
56 | def nonlinearity(x):
57 | # swish
58 | return x * torch.sigmoid(x)
59 |
60 |
61 | def leaky_relu(p=0.1):
62 | return nn.LeakyReLU(p)
63 |
64 |
65 | def _split(input_, dim):
66 | cp_world_size = get_context_parallel_world_size()
67 |
68 | if cp_world_size == 1:
69 | return input_
70 |
71 | cp_rank = get_context_parallel_rank()
72 |
73 | # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
74 |
75 |     input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
76 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
77 | dim_size = input_.size()[dim] // cp_world_size
78 |
79 | input_list = torch.split(input_, dim_size, dim=dim)
80 | output = input_list[cp_rank]
81 |
82 | if cp_rank == 0:
83 |         output = torch.cat([input_first_frame_, output], dim=dim)
84 | output = output.contiguous()
85 |
86 | # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
87 |
88 | return output
89 |
90 |
91 | def _gather(input_, dim):
92 | cp_world_size = get_context_parallel_world_size()
93 |
94 | # Bypass the function if context parallel is 1
95 | if cp_world_size == 1:
96 | return input_
97 |
98 | group = get_context_parallel_group()
99 | cp_rank = get_context_parallel_rank()
100 |
101 | # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
102 |
103 | input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
104 | if cp_rank == 0:
105 | input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
106 |
107 | tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
108 | torch.empty_like(input_) for _ in range(cp_world_size - 1)
109 | ]
110 |
111 | if cp_rank == 0:
112 | input_ = torch.cat([input_first_frame_, input_], dim=dim)
113 |
114 | tensor_list[cp_rank] = input_
115 | torch.distributed.all_gather(tensor_list, input_, group=group)
116 |
117 | output = torch.cat(tensor_list, dim=dim).contiguous()
118 |
119 | # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
120 |
121 | return output
122 |
123 |
124 | def _conv_split(input_, dim, kernel_size):
125 | cp_world_size = get_context_parallel_world_size()
126 |
127 | # Bypass the function if context parallel is 1
128 | if cp_world_size == 1:
129 | return input_
130 |
131 | # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
132 |
133 | cp_rank = get_context_parallel_rank()
134 |
135 | dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
136 |
137 | if cp_rank == 0:
138 | output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
139 | else:
140 | output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
141 | dim, 0
142 | )
143 | output = output.contiguous()
144 |
145 | # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
146 |
147 | return output
148 |
149 |
150 | def _conv_gather(input_, dim, kernel_size):
151 | cp_world_size = get_context_parallel_world_size()
152 |
153 | # Bypass the function if context parallel is 1
154 | if cp_world_size == 1:
155 | return input_
156 |
157 | group = get_context_parallel_group()
158 | cp_rank = get_context_parallel_rank()
159 |
160 | # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
161 |
162 | input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
163 | if cp_rank == 0:
164 | input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
165 | else:
166 | input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()
167 |
168 | tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
169 | torch.empty_like(input_) for _ in range(cp_world_size - 1)
170 | ]
171 | if cp_rank == 0:
172 | input_ = torch.cat([input_first_kernel_, input_], dim=dim)
173 |
174 | tensor_list[cp_rank] = input_
175 | torch.distributed.all_gather(tensor_list, input_, group=group)
176 |
177 | # Note: torch.cat already creates a contiguous tensor.
178 | output = torch.cat(tensor_list, dim=dim).contiguous()
179 |
180 | # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
181 |
182 | return output
--------------------------------------------------------------------------------
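Besides the context-parallel split/gather helpers (which need an initialized torch.distributed process group), the file above also provides a standard sinusoidal timestep embedding. A standalone sketch of that helper; the import path assumes the repo's sat/ directory is on PYTHONPATH and that sgm.util provides the context-parallel functions imported at the top of the file.

import torch

from sgm.modules.cp_enc_dec import get_timestep_embedding  # assumed path

timesteps = torch.arange(16)                       # one integer timestep per sample
emb = get_timestep_embedding(timesteps, embedding_dim=128)

print(emb.shape)                                   # torch.Size([16, 128])
print(float(emb.min()) >= -1.0, float(emb.max()) <= 1.0)   # sin/cos features lie in [-1, 1]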
/sat/sgm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 | from .denoiser import Denoiser
2 | from .discretizer import Discretization
3 | from .model import Decoder, Encoder, Model
4 | from .openaimodel import UNetModel
5 | from .sampling import BaseDiffusionSampler
6 | from .wrappers import OpenAIWrapper
7 |
--------------------------------------------------------------------------------
/sat/sgm/modules/diffusionmodules/denoiser.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ...util import append_dims, instantiate_from_config
7 |
8 |
9 | class Denoiser(nn.Module):
10 | def __init__(self, weighting_config, scaling_config):
11 | super().__init__()
12 |
13 | self.weighting = instantiate_from_config(weighting_config)
14 | self.scaling = instantiate_from_config(scaling_config)
15 |
16 | def possibly_quantize_sigma(self, sigma):
17 | return sigma
18 |
19 | def possibly_quantize_c_noise(self, c_noise):
20 | return c_noise
21 |
22 | def w(self, sigma):
23 | return self.weighting(sigma)
24 |
25 | def forward(
26 | self,
27 | network: nn.Module,
28 | input: torch.Tensor,
29 | sigma: torch.Tensor,
30 | cond: Dict,
31 | **additional_model_inputs,
32 | ) -> torch.Tensor:
33 | sigma = self.possibly_quantize_sigma(sigma)
34 | sigma_shape = sigma.shape
35 | sigma = append_dims(sigma, input.ndim)
36 | c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
37 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
38 | return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
39 |
40 |
41 | class DiscreteDenoiser(Denoiser):
42 | def __init__(
43 | self,
44 | weighting_config,
45 | scaling_config,
46 | num_idx,
47 | discretization_config,
48 | do_append_zero=False,
49 | quantize_c_noise=True,
50 | flip=True,
51 | ):
52 | super().__init__(weighting_config, scaling_config)
53 | sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
54 | self.sigmas = sigmas
55 | # self.register_buffer("sigmas", sigmas)
56 | self.quantize_c_noise = quantize_c_noise
57 |
58 | def sigma_to_idx(self, sigma):
59 | dists = sigma - self.sigmas.to(sigma.device)[:, None]
60 | return dists.abs().argmin(dim=0).view(sigma.shape)
61 |
62 | def idx_to_sigma(self, idx):
63 | return self.sigmas.to(idx.device)[idx]
64 |
65 | def possibly_quantize_sigma(self, sigma):
66 | return self.idx_to_sigma(self.sigma_to_idx(sigma))
67 |
68 | def possibly_quantize_c_noise(self, c_noise):
69 | if self.quantize_c_noise:
70 | return self.sigma_to_idx(c_noise)
71 | else:
72 | return c_noise
73 |
--------------------------------------------------------------------------------
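The single return line in Denoiser.forward hides the whole preconditioning scheme: the network sees input * c_in together with c_noise, and its output is recombined as output * c_out + input * c_skip. Below is a self-contained sketch of that recombination using EDM-style coefficients as an example; the coefficient formulas and the identity stand-in for the network are assumptions for illustration, since in the repo they come from the configured scaling class and the wrapped model.

import torch

sigma_data = 0.5
sigma = torch.tensor([1.0, 2.0]).view(-1, 1, 1, 1)       # append_dims to the input rank

c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log().flatten()

noised_input = torch.randn(2, 4, 32, 32)
network = lambda x, t, cond: x                           # stand-in for the real UNet / DiT

denoised = network(noised_input * c_in, c_noise, None) * c_out + noised_input * c_skip
print(denoised.shape)                                    # torch.Size([2, 4, 32, 32])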
/sat/sgm/modules/diffusionmodules/denoiser_scaling.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 |
6 |
7 | class DenoiserScaling(ABC):
8 | @abstractmethod
9 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
10 | pass
11 |
12 |
13 | class EDMScaling:
14 | def __init__(self, sigma_data: float = 0.5):
15 | self.sigma_data = sigma_data
16 |
17 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
18 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
19 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
20 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
21 | c_noise = 0.25 * sigma.log()
22 | return c_skip, c_out, c_in, c_noise
23 |
24 |
25 | class EpsScaling:
26 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
27 | c_skip = torch.ones_like(sigma, device=sigma.device)
28 | c_out = -sigma
29 | c_in = 1 / (sigma**2 + 1.0) ** 0.5
30 | c_noise = sigma.clone()
31 | return c_skip, c_out, c_in, c_noise
32 |
33 |
34 | class VScaling:
35 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
36 | c_skip = 1.0 / (sigma**2 + 1.0)
37 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5
38 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
39 | c_noise = sigma.clone()
40 | return c_skip, c_out, c_in, c_noise
41 |
42 |
43 | class VScalingWithEDMcNoise(DenoiserScaling):
44 | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
45 | c_skip = 1.0 / (sigma**2 + 1.0)
46 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5
47 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
48 | c_noise = 0.25 * sigma.log()
49 | return c_skip, c_out, c_in, c_noise
50 |
51 |
52 | class VideoScaling: # similar to VScaling
53 | def __call__(
54 | self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
56 | c_skip = alphas_cumprod_sqrt
57 | c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
58 | c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
59 | c_noise = additional_model_inputs["idx"].clone()
60 | return c_skip, c_out, c_in, c_noise
61 |
--------------------------------------------------------------------------------
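A small comparison of the scaling classes above for a few noise levels; it also makes visible that VScaling and VScalingWithEDMcNoise share c_skip, c_out and c_in and differ only in how the noise level is encoded for the network (raw sigma vs. 0.25 * log(sigma)). The import path is an assumption.

import torch

from sgm.modules.diffusionmodules.denoiser_scaling import (  # assumed path
    EDMScaling,
    EpsScaling,
    VScaling,
    VScalingWithEDMcNoise,
)

sigma = torch.tensor([0.1, 1.0, 10.0])

for scaling in (EDMScaling(sigma_data=0.5), EpsScaling(), VScaling(), VScalingWithEDMcNoise()):
    c_skip, c_out, c_in, c_noise = scaling(sigma)
    print(type(scaling).__name__, c_skip, c_out, c_in, c_noise)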
/sat/sgm/modules/diffusionmodules/denoiser_weighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class UnitWeighting:
5 | def __call__(self, sigma):
6 | return torch.ones_like(sigma, device=sigma.device)
7 |
8 |
9 | class EDMWeighting:
10 | def __init__(self, sigma_data=0.5):
11 | self.sigma_data = sigma_data
12 |
13 | def __call__(self, sigma):
14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15 |
16 |
17 | class VWeighting(EDMWeighting):
18 | def __init__(self):
19 | super().__init__(sigma_data=1.0)
20 |
21 |
22 | class EpsWeighting:
23 | def __call__(self, sigma):
24 | return sigma**-2.0
25 |
--------------------------------------------------------------------------------
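The weighting classes above map a noise level to a per-sample loss weight; VWeighting is simply EDMWeighting with sigma_data = 1.0. A quick sketch, with the import path assumed.

import torch

from sgm.modules.diffusionmodules.denoiser_weighting import (  # assumed path
    EDMWeighting,
    EpsWeighting,
    UnitWeighting,
    VWeighting,
)

sigma = torch.tensor([0.1, 1.0, 10.0])

print(UnitWeighting()(sigma))    # uniform weights
print(EDMWeighting(0.5)(sigma))  # (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
print(VWeighting()(sigma))       # EDM weighting with sigma_data = 1.0
print(EpsWeighting()(sigma))     # 1 / sigma^2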
/sat/sgm/modules/diffusionmodules/discretizer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from functools import partial
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ...modules.diffusionmodules.util import make_beta_schedule
8 | from ...util import append_zero
9 |
10 |
11 | def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
12 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
13 |
14 |
15 | class Discretization:
16 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
17 | if return_idx:
18 | sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
19 | else:
20 | sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
21 | sigmas = append_zero(sigmas) if do_append_zero else sigmas
22 | if return_idx:
23 | return sigmas if not flip else torch.flip(sigmas, (0,)), idx
24 | else:
25 | return sigmas if not flip else torch.flip(sigmas, (0,))
26 |
27 | @abstractmethod
28 | def get_sigmas(self, n, device):
29 | pass
30 |
31 |
32 | class EDMDiscretization(Discretization):
33 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
34 | self.sigma_min = sigma_min
35 | self.sigma_max = sigma_max
36 | self.rho = rho
37 |
38 | def get_sigmas(self, n, device="cpu"):
39 | ramp = torch.linspace(0, 1, n, device=device)
40 | min_inv_rho = self.sigma_min ** (1 / self.rho)
41 | max_inv_rho = self.sigma_max ** (1 / self.rho)
42 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
43 | return sigmas
44 |
45 |
46 | class LegacyDDPMDiscretization(Discretization):
47 | def __init__(
48 | self,
49 | linear_start=0.00085,
50 | linear_end=0.0120,
51 | num_timesteps=1000,
52 | ):
53 | super().__init__()
54 | self.num_timesteps = num_timesteps
55 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
56 | alphas = 1.0 - betas
57 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
58 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
59 |
60 | def get_sigmas(self, n, device="cpu"):
61 | if n < self.num_timesteps:
62 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
63 | alphas_cumprod = self.alphas_cumprod[timesteps]
64 | elif n == self.num_timesteps:
65 | alphas_cumprod = self.alphas_cumprod
66 | else:
67 | raise ValueError
68 |
69 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
70 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
71 | return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
72 |
73 |
74 | class ZeroSNRDDPMDiscretization(Discretization):
75 | def __init__(
76 | self,
77 | linear_start=0.00085,
78 | linear_end=0.0120,
79 | num_timesteps=1000,
80 | shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
81 | keep_start=False,
82 | post_shift=False,
83 | ):
84 | super().__init__()
85 | if keep_start and not post_shift:
86 | linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
87 | self.num_timesteps = num_timesteps
88 | betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
89 | alphas = 1.0 - betas
90 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
91 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
92 |
93 | # SNR shift
94 | if not post_shift:
95 | self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
96 |
97 | self.post_shift = post_shift
98 | self.shift_scale = shift_scale
99 |
100 | def get_sigmas(self, n, device="cpu", return_idx=False):
101 | if n < self.num_timesteps:
102 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
103 | alphas_cumprod = self.alphas_cumprod[timesteps]
104 | elif n == self.num_timesteps:
105 | alphas_cumprod = self.alphas_cumprod
106 | else:
107 | raise ValueError
108 |
109 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
110 | alphas_cumprod = to_torch(alphas_cumprod)
111 | alphas_cumprod_sqrt = alphas_cumprod.sqrt()
112 | alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
113 | alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
114 |
115 | alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
116 | alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
117 |
118 | if self.post_shift:
119 | alphas_cumprod_sqrt = (
120 | alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
121 | ) ** 0.5
122 |
123 | if return_idx:
124 | return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
125 | else:
126 | return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
127 |
--------------------------------------------------------------------------------
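
`EDMDiscretization` above reproduces the rho-spaced noise schedule of Karras et al.; a standalone sketch of that interpolation (same formula, default hyperparameters):

```python
import torch

# Standalone sketch of the rho-interpolation used by EDMDiscretization above:
#   sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
sigma_min, sigma_max, rho, n = 0.002, 80.0, 7.0, 10
ramp = torch.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
print(sigmas)  # monotonically decreasing from sigma_max down to sigma_min
```
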
/sat/sgm/modules/diffusionmodules/guiders.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC, abstractmethod
3 | from typing import Dict, List, Optional, Tuple, Union
4 | from functools import partial
5 | import math
6 |
7 | import torch
8 | from einops import rearrange, repeat
9 |
10 | from ...util import append_dims, default, instantiate_from_config
11 |
12 |
13 | class Guider(ABC):
14 | @abstractmethod
15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
16 | pass
17 |
18 | def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
19 | pass
20 |
21 |
22 | class VanillaCFG:
23 | """
24 |     Implements classifier-free guidance (CFG): the unconditional and conditional branches are batched together and run through the model in a single forward pass.
25 | """
26 |
27 | def __init__(self, scale, dyn_thresh_config=None):
28 | self.scale = scale
29 | scale_schedule = lambda scale, sigma: scale # independent of step
30 | self.scale_schedule = partial(scale_schedule, scale)
31 | self.dyn_thresh = instantiate_from_config(
32 | default(
33 | dyn_thresh_config,
34 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
35 | )
36 | )
37 |
38 | def __call__(self, x, sigma, scale=None):
39 | x_u, x_c = x.chunk(2)
40 | scale_value = default(scale, self.scale_schedule(sigma))
41 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
42 | return x_pred
43 |
44 | def prepare_inputs(self, x, s, c, uc):
45 | c_out = dict()
46 |
47 | for k in c:
48 | if k in ["vector", "crossattn", "concat"]:
49 | c_out[k] = torch.cat((uc[k], c[k]), 0)
50 | else:
51 | assert c[k] == uc[k]
52 | c_out[k] = c[k]
53 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
54 |
55 |
56 | class DynamicCFG(VanillaCFG):
57 | def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
58 | super().__init__(scale, dyn_thresh_config)
59 | scale_schedule = (
60 | lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
61 | )
62 | self.scale_schedule = partial(scale_schedule, scale)
63 | self.dyn_thresh = instantiate_from_config(
64 | default(
65 | dyn_thresh_config,
66 | {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
67 | )
68 | )
69 |
70 | def __call__(self, x, sigma, step_index, scale=None):
71 | x_u, x_c = x.chunk(2)
72 | scale_value = self.scale_schedule(sigma, step_index.item())
73 | x_pred = self.dyn_thresh(x_u, x_c, scale_value)
74 | return x_pred
75 |
76 |
77 | class IdentityGuider:
78 | def __call__(self, x, sigma):
79 | return x
80 |
81 | def prepare_inputs(self, x, s, c, uc):
82 | c_out = dict()
83 |
84 | for k in c:
85 | c_out[k] = c[k]
86 |
87 | return x, s, c_out
88 |
--------------------------------------------------------------------------------
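
`VanillaCFG.prepare_inputs` doubles the batch (unconditional half first, conditional half second) and `__call__` recombines the two model outputs. A standalone sketch of the plain CFG blend with dummy tensors, assuming `NoDynamicThresholding`:

```python
import torch

# Standalone sketch of the classifier-free guidance blend performed by
# VanillaCFG + NoDynamicThresholding: the denoiser runs on a doubled batch,
# then the two halves are recombined with the guidance scale.
scale = 6.0
model_out = torch.randn(2 * 2, 16, 8, 8)   # doubled batch from prepare_inputs (B=2)
x_u, x_c = model_out.chunk(2)              # unconditional / conditional halves
x_pred = x_u + scale * (x_c - x_u)
print(x_pred.shape)                        # torch.Size([2, 16, 8, 8])
```
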
/sat/sgm/modules/diffusionmodules/loss.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from omegaconf import ListConfig
7 | import math
8 |
9 | from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler
10 | from ...util import append_dims, instantiate_from_config
11 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS
12 |
13 | # import rearrange
14 | from einops import rearrange
15 | import random
16 | from sat import mpu
17 |
18 |
19 | class StandardDiffusionLoss(nn.Module):
20 | def __init__(
21 | self,
22 | sigma_sampler_config,
23 | type="l2",
24 | offset_noise_level=0.0,
25 | batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
26 | ):
27 | super().__init__()
28 |
29 | assert type in ["l2", "l1", "lpips"]
30 |
31 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
32 |
33 | self.type = type
34 | self.offset_noise_level = offset_noise_level
35 |
36 | if type == "lpips":
37 | self.lpips = LPIPS().eval()
38 |
39 | if not batch2model_keys:
40 | batch2model_keys = []
41 |
42 | if isinstance(batch2model_keys, str):
43 | batch2model_keys = [batch2model_keys]
44 |
45 | self.batch2model_keys = set(batch2model_keys)
46 |
47 | def __call__(self, network, denoiser, conditioner, input, batch):
48 | cond = conditioner(batch)
49 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
50 |
51 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
52 | noise = torch.randn_like(input)
53 | if self.offset_noise_level > 0.0:
54 | noise = (
55 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
56 | )
57 | noise = noise.to(input.dtype)
58 | noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
59 | model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
60 | w = append_dims(denoiser.w(sigmas), input.ndim)
61 | return self.get_loss(model_output, input, w)
62 |
63 | def get_loss(self, model_output, target, w):
64 | if self.type == "l2":
65 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
66 | elif self.type == "l1":
67 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
68 | elif self.type == "lpips":
69 | loss = self.lpips(model_output, target).reshape(-1)
70 | return loss
71 |
72 |
73 | class VideoDiffusionLoss(StandardDiffusionLoss):
74 | def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
75 | self.fixed_frames = fixed_frames
76 | self.block_scale = block_scale
77 | self.block_size = block_size
78 | self.min_snr_value = min_snr_value
79 | super().__init__(**kwargs)
80 |
81 | def __call__(self, network, denoiser, conditioner, input, batch):
82 | cond = conditioner(batch)
83 | additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
84 |
85 | alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
86 | alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
87 | idx = idx.to(input.device)
88 |
89 | noise = torch.randn_like(input)
90 |
91 | # broadcast noise
92 | mp_size = mpu.get_model_parallel_world_size()
93 | global_rank = torch.distributed.get_rank() // mp_size
94 | src = global_rank * mp_size
95 | torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
96 | torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
97 | torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())
98 |
99 | additional_model_inputs["idx"] = idx
100 |
101 | if self.offset_noise_level > 0.0:
102 | noise = (
103 | noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
104 | )
105 |
106 | noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
107 | (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
108 | )
109 |
110 | model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
111 | w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
112 |
113 | if self.min_snr_value is not None:
114 |             w = torch.clamp(w, max=self.min_snr_value)  # elementwise min-SNR clamp; the built-in min() fails on multi-element tensors
115 | return self.get_loss(model_output, input, w)
116 |
117 | def get_loss(self, model_output, target, w):
118 | if self.type == "l2":
119 | return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
120 | elif self.type == "l1":
121 | return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
122 | elif self.type == "lpips":
123 | loss = self.lpips(model_output, target).reshape(-1)
124 | return loss
125 |
126 |
127 | def get_3d_position_ids(frame_len, h, w):
128 | i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
129 | j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
130 | k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
131 | position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
132 | return position_ids
133 |
--------------------------------------------------------------------------------
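
`VideoDiffusionLoss` noises the clean latent with `sqrt(alpha_bar_t)` and weights the objective by `1 / (1 - alpha_bar_t)` (v-prediction). A minimal standalone sketch of that forward-noising and weighting step; all shapes and values here are illustrative dummies:

```python
import torch

B = 2
x = torch.randn(B, 4, 16, 32, 32)                     # clean video latents (B, T, C, H, W)
noise = torch.randn_like(x)
alphas_cumprod_sqrt = torch.rand(B).clamp(0.1, 0.99)  # one sqrt(alpha_bar_t) per sample

a = alphas_cumprod_sqrt.view(B, 1, 1, 1, 1)           # append_dims equivalent
noised = x * a + noise * (1 - a**2) ** 0.5            # q(x_t | x_0)

pred = torch.randn_like(x)                            # stand-in for the denoiser output
w = 1.0 / (1 - a**2)                                  # v-prediction weight
loss = torch.mean((w * (pred - x) ** 2).reshape(B, -1), 1)
print(noised.shape, loss.shape)                       # per-sample L2 loss of shape (B,)
```
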
/sat/sgm/modules/diffusionmodules/sampling_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy import integrate
3 |
4 | from ...util import append_dims
5 | from einops import rearrange
6 |
7 |
8 | class NoDynamicThresholding:
9 | def __call__(self, uncond, cond, scale):
10 | scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
11 | return uncond + scale * (cond - uncond)
12 |
13 |
14 | class StaticThresholding:
15 | def __call__(self, uncond, cond, scale):
16 | result = uncond + scale * (cond - uncond)
17 | result = torch.clamp(result, min=-1.0, max=1.0)
18 | return result
19 |
20 |
21 | def dynamic_threshold(x, p=0.95):
22 | N, T, C, H, W = x.shape
23 | x = rearrange(x, "n t c h w -> n c (t h w)")
24 | l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
25 | s = torch.maximum(-l, r)
26 | threshold_mask = (s > 1).expand(-1, -1, H * W * T)
27 | if threshold_mask.any():
28 | x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
29 | x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W)
30 | return x
31 |
32 |
33 | def dynamic_thresholding2(x0):
34 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
35 | origin_dtype = x0.dtype
36 | x0 = x0.to(torch.float32)
37 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
38 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
39 | x0 = torch.clamp(x0, -s, s) # / s
40 | return x0.to(origin_dtype)
41 |
42 |
43 | def latent_dynamic_thresholding(x0):
44 | p = 0.9995
45 | origin_dtype = x0.dtype
46 | x0 = x0.to(torch.float32)
47 | s = torch.quantile(torch.abs(x0), p, dim=2)
48 | s = append_dims(s, x0.dim())
49 | x0 = torch.clamp(x0, -s, s) / s
50 | return x0.to(origin_dtype)
51 |
52 |
53 | def dynamic_thresholding3(x0):
54 | p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
55 | origin_dtype = x0.dtype
56 | x0 = x0.to(torch.float32)
57 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
58 | s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
59 | x0 = torch.clamp(x0, -s, s) # / s
60 | return x0.to(origin_dtype)
61 |
62 |
63 | class DynamicThresholding:
64 | def __call__(self, uncond, cond, scale):
65 | mean = uncond.mean()
66 | std = uncond.std()
67 | result = uncond + scale * (cond - uncond)
68 | result_mean, result_std = result.mean(), result.std()
69 | result = (result - result_mean) / result_std * std
70 | # result = dynamic_thresholding3(result)
71 | return result
72 |
73 |
74 | class DynamicThresholdingV1:
75 | def __init__(self, scale_factor):
76 | self.scale_factor = scale_factor
77 |
78 | def __call__(self, uncond, cond, scale):
79 | result = uncond + scale * (cond - uncond)
80 | unscaled_result = result / self.scale_factor
81 | B, T, C, H, W = unscaled_result.shape
82 | flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)")
83 | means = flattened.mean(dim=2).unsqueeze(2)
84 | recentered = flattened - means
85 | magnitudes = recentered.abs().max()
86 | normalized = recentered / magnitudes
87 | thresholded = latent_dynamic_thresholding(normalized)
88 | denormalized = thresholded * magnitudes
89 | uncentered = denormalized + means
90 | unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
91 | scaled_result = unflattened * self.scale_factor
92 | return scaled_result
93 |
94 |
95 | class DynamicThresholdingV2:
96 | def __call__(self, uncond, cond, scale):
97 | B, T, C, H, W = uncond.shape
98 | diff = cond - uncond
99 | mim_target = uncond + diff * 4.0
100 | cfg_target = uncond + diff * 8.0
101 |
102 | mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)")
103 | cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)")
104 | mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
105 | cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
106 | mim_centered = mim_flattened - mim_means
107 | cfg_centered = cfg_flattened - cfg_means
108 |
109 | mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
110 | cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
111 |
112 | cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref
113 |
114 | result = cfg_renormalized + cfg_means
115 | unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
116 |
117 | return unflattened
118 |
119 |
120 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
121 | if order - 1 > i:
122 | raise ValueError(f"Order {order} too high for step {i}")
123 |
124 | def fn(tau):
125 | prod = 1.0
126 | for k in range(order):
127 | if j == k:
128 | continue
129 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
130 | return prod
131 |
132 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
133 |
134 |
135 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
136 | if not eta:
137 | return sigma_to, 0.0
138 | sigma_up = torch.minimum(
139 | sigma_to,
140 | eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
141 | )
142 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
143 | return sigma_down, sigma_up
144 |
145 |
146 | def to_d(x, sigma, denoised):
147 | return (x - denoised) / append_dims(sigma, x.ndim)
148 |
149 |
150 | def to_neg_log_sigma(sigma):
151 | return sigma.log().neg()
152 |
153 |
154 | def to_sigma(neg_log_sigma):
155 | return neg_log_sigma.neg().exp()
156 |
--------------------------------------------------------------------------------
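
`dynamic_thresholding2` above clamps each sample to its 99.5th-percentile absolute value (never below 1). A standalone sketch of that quantile clamp, with `append_dims` inlined so it runs on its own:

```python
import torch

def _append_dims(x, target_dims):
    """Append trailing singleton dims until x has target_dims dimensions."""
    return x[(...,) + (None,) * (target_dims - x.ndim)]

x0 = torch.randn(2, 4, 8, 8) * 3.0
s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), 0.995, dim=1)
s = _append_dims(torch.maximum(s, torch.ones_like(s)), x0.ndim)
clamped = torch.clamp(x0, -s, s)
print(clamped.abs().amax(dim=(1, 2, 3)) <= s.flatten() + 1e-6)  # tensor([True, True])
```
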
/sat/sgm/modules/diffusionmodules/sigma_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed
3 |
4 | from sat import mpu
5 |
6 | from ...util import default, instantiate_from_config
7 |
8 |
9 | class EDMSampling:
10 | def __init__(self, p_mean=-1.2, p_std=1.2):
11 | self.p_mean = p_mean
12 | self.p_std = p_std
13 |
14 | def __call__(self, n_samples, rand=None):
15 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
16 | return log_sigma.exp()
17 |
18 |
19 | class DiscreteSampling:
20 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
21 | self.num_idx = num_idx
22 | self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
23 | world_size = mpu.get_data_parallel_world_size()
24 | self.uniform_sampling = uniform_sampling
25 | if self.uniform_sampling:
26 | i = 1
27 | while True:
28 | if world_size % i != 0 or num_idx % (world_size // i) != 0:
29 | i += 1
30 | else:
31 | self.group_num = world_size // i
32 | break
33 |
34 | assert self.group_num > 0
35 | assert world_size % self.group_num == 0
36 | self.group_width = world_size // self.group_num # the number of rank in one group
37 | self.sigma_interval = self.num_idx // self.group_num
38 |
39 | def idx_to_sigma(self, idx):
40 | return self.sigmas[idx]
41 |
42 | def __call__(self, n_samples, rand=None, return_idx=False):
43 | if self.uniform_sampling:
44 | rank = mpu.get_data_parallel_rank()
45 | group_index = rank // self.group_width
46 | idx = default(
47 | rand,
48 | torch.randint(
49 | group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
50 | ),
51 | )
52 | else:
53 | idx = default(
54 | rand,
55 | torch.randint(0, self.num_idx, (n_samples,)),
56 | )
57 | if return_idx:
58 | return self.idx_to_sigma(idx), idx
59 | else:
60 | return self.idx_to_sigma(idx)
61 |
62 |
63 | class PartialDiscreteSampling:
64 | def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
65 | self.total_num_idx = total_num_idx
66 | self.partial_num_idx = partial_num_idx
67 | self.sigmas = instantiate_from_config(discretization_config)(
68 | total_num_idx, do_append_zero=do_append_zero, flip=flip
69 | )
70 |
71 | def idx_to_sigma(self, idx):
72 | return self.sigmas[idx]
73 |
74 | def __call__(self, n_samples, rand=None):
75 | idx = default(
76 | rand,
77 | # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
78 | torch.randint(0, self.partial_num_idx, (n_samples,)),
79 | )
80 | return self.idx_to_sigma(idx)
81 |
--------------------------------------------------------------------------------
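
`DiscreteSampling` reduces to drawing a random timestep index and looking up its sigma; `uniform_sampling` only restricts each data-parallel group to its own index interval. A standalone sketch of the basic lookup (the sigma table here is a stand-in, not the repository's schedule):

```python
import torch

# Standalone sketch of DiscreteSampling's core behaviour: sample indices, look up sigmas.
num_idx = 1000
sigmas = torch.linspace(14.6, 0.03, num_idx)   # stand-in for a flipped DDPM sigma table
idx = torch.randint(0, num_idx, (4,))          # one random timestep index per sample
print(idx, sigmas[idx])
```
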
/sat/sgm/modules/diffusionmodules/wrappers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from packaging import version
4 |
5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
6 |
7 |
8 | class IdentityWrapper(nn.Module):
9 | def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
10 | super().__init__()
11 | compile = (
12 | torch.compile
13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
14 | else lambda x: x
15 | )
16 | self.diffusion_model = compile(diffusion_model)
17 | self.dtype = dtype
18 |
19 | def forward(self, *args, **kwargs):
20 | return self.diffusion_model(*args, **kwargs)
21 |
22 |
23 | class OpenAIWrapper(IdentityWrapper):
24 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
25 | for key in c:
26 | c[key] = c[key].to(self.dtype)
27 |
28 | if x.dim() == 4:
29 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
30 | elif x.dim() == 5:
31 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
32 | else:
33 | raise ValueError("Input tensor must be 4D or 5D")
34 |
35 | return self.diffusion_model(
36 | x,
37 | timesteps=t,
38 | context=c.get("crossattn", None),
39 | y=c.get("vector", None),
40 | **kwargs,
41 | )
42 |
--------------------------------------------------------------------------------
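
`OpenAIWrapper` casts each conditioning entry to the wrapper dtype and channel-concatenates any `"concat"` conditioning before calling the inner network. A standalone sketch with a dummy inner module (the module and shapes are illustrative, not from the repository):

```python
import torch
import torch.nn as nn

class _DummyDiT(nn.Module):
    """Stand-in for the wrapped diffusion model; just echoes its input."""
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        return x

x = torch.randn(1, 4, 32, 32)                   # 4D latent (B, C, H, W)
c = {"concat": torch.randn(1, 1, 32, 32),       # e.g. a mask channel
     "crossattn": torch.randn(1, 77, 768)}      # text conditioning

# What OpenAIWrapper.forward does for 4D inputs: concat on dim=1, then call the model.
x_in = torch.cat((x, c["concat"]), dim=1)       # (1, 5, 32, 32)
out = _DummyDiT()(x_in, timesteps=torch.tensor([10]), context=c["crossattn"])
print(out.shape)
```
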
/sat/sgm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | # x = self.mean + self.std * torch.randn(self.mean.shape).to(
37 | # device=self.parameters.device
38 | # )
39 | x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device)
40 | return x
41 |
42 | def kl(self, other=None):
43 | if self.deterministic:
44 | return torch.Tensor([0.0])
45 | else:
46 | if other is None:
47 | return 0.5 * torch.sum(
48 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
49 | dim=[1, 2, 3],
50 | )
51 | else:
52 | return 0.5 * torch.sum(
53 | torch.pow(self.mean - other.mean, 2) / other.var
54 | + self.var / other.var
55 | - 1.0
56 | - self.logvar
57 | + other.logvar,
58 | dim=[1, 2, 3],
59 | )
60 |
61 | def nll(self, sample, dims=[1, 2, 3]):
62 | if self.deterministic:
63 | return torch.Tensor([0.0])
64 | logtwopi = np.log(2.0 * np.pi)
65 | return 0.5 * torch.sum(
66 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67 | dim=dims,
68 | )
69 |
70 | def mode(self):
71 | return self.mean
72 |
73 |
74 | def normal_kl(mean1, logvar1, mean2, logvar2):
75 | """
76 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
77 | Compute the KL divergence between two gaussians.
78 | Shapes are automatically broadcasted, so batches can be compared to
79 | scalars, among other use cases.
80 | """
81 | tensor = None
82 | for obj in (mean1, logvar1, mean2, logvar2):
83 | if isinstance(obj, torch.Tensor):
84 | tensor = obj
85 | break
86 | assert tensor is not None, "at least one argument must be a Tensor"
87 |
88 | # Force variances to be Tensors. Broadcasting helps convert scalars to
89 | # Tensors, but it does not work for torch.exp().
90 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
91 |
92 | return 0.5 * (
93 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
94 | )
95 |
--------------------------------------------------------------------------------
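
`DiagonalGaussianDistribution` expects the mean and log-variance stacked along `dim=1`. A small usage sketch, assuming the repository's `sat` directory is on `PYTHONPATH` and its requirements are installed so the module above is importable:

```python
import torch
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution

# Encoder-style moments: mean and log-variance stacked along dim=1 (2 * latent channels).
params = torch.randn(2, 2 * 4, 16, 16)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()   # (2, 4, 16, 16) reparameterized sample
kl = posterior.kl()      # per-sample KL against a standard normal, shape (2,)
print(z.shape, kl.shape)
```
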
/sat/sgm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
16 | )
17 |
18 | for name, p in model.named_parameters():
19 | if p.requires_grad:
20 | # remove as '.'-character is not allowed in buffers
21 | s_name = name.replace(".", "")
22 | self.m_name2s_name.update({name: s_name})
23 | self.register_buffer(s_name, p.clone().detach().data)
24 |
25 | self.collected_params = []
26 |
27 | def reset_num_updates(self):
28 | del self.num_updates
29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
30 |
31 | def forward(self, model):
32 | decay = self.decay
33 |
34 | if self.num_updates >= 0:
35 | self.num_updates += 1
36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
37 |
38 | one_minus_decay = 1.0 - decay
39 |
40 | with torch.no_grad():
41 | m_param = dict(model.named_parameters())
42 | shadow_params = dict(self.named_buffers())
43 |
44 | for key in m_param:
45 | if m_param[key].requires_grad:
46 | sname = self.m_name2s_name[key]
47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
49 | else:
50 | assert not key in self.m_name2s_name
51 |
52 | def copy_to(self, model):
53 | m_param = dict(model.named_parameters())
54 | shadow_params = dict(self.named_buffers())
55 | for key in m_param:
56 | if m_param[key].requires_grad:
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
--------------------------------------------------------------------------------
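
The `store` / `copy_to` / `restore` docstrings above describe the intended evaluation pattern. A minimal usage sketch, assuming the `sat` directory is on `PYTHONPATH` so `sgm.modules.ema` is importable:

```python
import torch
import torch.nn as nn
from sgm.modules.ema import LitEma

model = nn.Linear(8, 8)
ema = LitEma(model, decay=0.999)

# After each optimizer step, fold the new weights into the shadow copy.
model.weight.data.add_(0.01)
ema(model)

# Validate with EMA weights, then put the raw training weights back.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())
```
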
/sat/sgm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/sat/sgm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/sat/sgm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | from contextlib import nullcontext
3 | from functools import partial
4 | from typing import Dict, List, Optional, Tuple, Union
5 |
6 | import kornia
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from einops import rearrange, repeat
11 | from omegaconf import ListConfig
12 | from torch.utils.checkpoint import checkpoint
13 | from transformers import (
14 | T5EncoderModel,
15 | T5Tokenizer,
16 | )
17 |
18 | from ...util import (
19 | append_dims,
20 | autocast,
21 | count_params,
22 | default,
23 | disabled_train,
24 | expand_dims_like,
25 | instantiate_from_config,
26 | )
27 |
28 |
29 | class AbstractEmbModel(nn.Module):
30 | def __init__(self):
31 | super().__init__()
32 | self._is_trainable = None
33 | self._ucg_rate = None
34 | self._input_key = None
35 |
36 | @property
37 | def is_trainable(self) -> bool:
38 | return self._is_trainable
39 |
40 | @property
41 | def ucg_rate(self) -> Union[float, torch.Tensor]:
42 | return self._ucg_rate
43 |
44 | @property
45 | def input_key(self) -> str:
46 | return self._input_key
47 |
48 | @is_trainable.setter
49 | def is_trainable(self, value: bool):
50 | self._is_trainable = value
51 |
52 | @ucg_rate.setter
53 | def ucg_rate(self, value: Union[float, torch.Tensor]):
54 | self._ucg_rate = value
55 |
56 | @input_key.setter
57 | def input_key(self, value: str):
58 | self._input_key = value
59 |
60 | @is_trainable.deleter
61 | def is_trainable(self):
62 | del self._is_trainable
63 |
64 | @ucg_rate.deleter
65 | def ucg_rate(self):
66 | del self._ucg_rate
67 |
68 | @input_key.deleter
69 | def input_key(self):
70 | del self._input_key
71 |
72 |
73 | class GeneralConditioner(nn.Module):
74 | OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
75 | KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
76 |
77 | def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
78 | super().__init__()
79 | embedders = []
80 | for n, embconfig in enumerate(emb_models):
81 | embedder = instantiate_from_config(embconfig)
82 | assert isinstance(
83 | embedder, AbstractEmbModel
84 | ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
85 | embedder.is_trainable = embconfig.get("is_trainable", False)
86 | embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
87 | if not embedder.is_trainable:
88 | embedder.train = disabled_train
89 | for param in embedder.parameters():
90 | param.requires_grad = False
91 | embedder.eval()
92 | print(
93 | f"Initialized embedder #{n}: {embedder.__class__.__name__} "
94 | f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
95 | )
96 |
97 | if "input_key" in embconfig:
98 | embedder.input_key = embconfig["input_key"]
99 | elif "input_keys" in embconfig:
100 | embedder.input_keys = embconfig["input_keys"]
101 | else:
102 | raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
103 |
104 | embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
105 | if embedder.legacy_ucg_val is not None:
106 | embedder.ucg_prng = np.random.RandomState()
107 |
108 | embedders.append(embedder)
109 | self.embedders = nn.ModuleList(embedders)
110 |
111 | if len(cor_embs) > 0:
112 | assert len(cor_p) == 2 ** len(cor_embs)
113 | self.cor_embs = cor_embs
114 | self.cor_p = cor_p
115 |
116 | def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
117 | assert embedder.legacy_ucg_val is not None
118 | p = embedder.ucg_rate
119 | val = embedder.legacy_ucg_val
120 | for i in range(len(batch[embedder.input_key])):
121 | if embedder.ucg_prng.choice(2, p=[1 - p, p]):
122 | batch[embedder.input_key][i] = val
123 | return batch
124 |
125 | def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
126 | assert embedder.legacy_ucg_val is not None
127 | val = embedder.legacy_ucg_val
128 | for i in range(len(batch[embedder.input_key])):
129 | if cond_or_not[i]:
130 | batch[embedder.input_key][i] = val
131 | return batch
132 |
133 | def get_single_embedding(
134 | self,
135 | embedder,
136 | batch,
137 | output,
138 | cond_or_not: Optional[np.ndarray] = None,
139 | force_zero_embeddings: Optional[List] = None,
140 | ):
141 | embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
142 | with embedding_context():
143 | if hasattr(embedder, "input_key") and (embedder.input_key is not None):
144 | if embedder.legacy_ucg_val is not None:
145 | if cond_or_not is None:
146 | batch = self.possibly_get_ucg_val(embedder, batch)
147 | else:
148 | batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
149 | emb_out = embedder(batch[embedder.input_key])
150 | elif hasattr(embedder, "input_keys"):
151 | emb_out = embedder(*[batch[k] for k in embedder.input_keys])
152 | assert isinstance(
153 | emb_out, (torch.Tensor, list, tuple)
154 | ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
155 | if not isinstance(emb_out, (list, tuple)):
156 | emb_out = [emb_out]
157 | for emb in emb_out:
158 | out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
159 | if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
160 | if cond_or_not is None:
161 | emb = (
162 | expand_dims_like(
163 | torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
164 | emb,
165 | )
166 | * emb
167 | )
168 | else:
169 | emb = (
170 | expand_dims_like(
171 | torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
172 | emb,
173 | )
174 | * emb
175 | )
176 | if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
177 | emb = torch.zeros_like(emb)
178 | if out_key in output:
179 | output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
180 | else:
181 | output[out_key] = emb
182 | return output
183 |
184 | def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
185 | output = dict()
186 | if force_zero_embeddings is None:
187 | force_zero_embeddings = []
188 |
189 | if len(self.cor_embs) > 0:
190 | batch_size = len(batch[list(batch.keys())[0]])
191 | rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
192 | for emb_idx in self.cor_embs:
193 | cond_or_not = rand_idx % 2
194 | rand_idx //= 2
195 | output = self.get_single_embedding(
196 | self.embedders[emb_idx],
197 | batch,
198 | output=output,
199 | cond_or_not=cond_or_not,
200 | force_zero_embeddings=force_zero_embeddings,
201 | )
202 |
203 | for i, embedder in enumerate(self.embedders):
204 | if i in self.cor_embs:
205 | continue
206 | output = self.get_single_embedding(
207 | embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
208 | )
209 | return output
210 |
211 | def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
212 | if force_uc_zero_embeddings is None:
213 | force_uc_zero_embeddings = []
214 | ucg_rates = list()
215 | for embedder in self.embedders:
216 | ucg_rates.append(embedder.ucg_rate)
217 | embedder.ucg_rate = 0.0
218 | cor_embs = self.cor_embs
219 | cor_p = self.cor_p
220 | self.cor_embs = []
221 | self.cor_p = []
222 |
223 | c = self(batch_c)
224 | uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
225 |
226 | for embedder, rate in zip(self.embedders, ucg_rates):
227 | embedder.ucg_rate = rate
228 | self.cor_embs = cor_embs
229 | self.cor_p = cor_p
230 |
231 | return c, uc
232 |
233 |
234 | class FrozenT5Embedder(AbstractEmbModel):
235 | """Uses the T5 transformer encoder for text"""
236 |
237 | def __init__(
238 | self,
239 | model_dir="google/t5-v1_1-xxl",
240 | device="cuda",
241 | max_length=77,
242 | freeze=True,
243 | cache_dir=None,
244 | ):
245 | super().__init__()
246 |         if model_dir != "google/t5-v1_1-xxl":
247 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
248 | self.transformer = T5EncoderModel.from_pretrained(model_dir)
249 | else:
250 | self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
251 | self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
252 | self.device = device
253 | self.max_length = max_length
254 | if freeze:
255 | self.freeze()
256 |
257 | def freeze(self):
258 | self.transformer = self.transformer.eval()
259 |
260 | for param in self.parameters():
261 | param.requires_grad = False
262 |
263 | # @autocast
264 | def forward(self, text):
265 | batch_encoding = self.tokenizer(
266 | text,
267 | truncation=True,
268 | max_length=self.max_length,
269 | return_length=True,
270 | return_overflowing_tokens=False,
271 | padding="max_length",
272 | return_tensors="pt",
273 | )
274 | tokens = batch_encoding["input_ids"].to(self.device)
275 | with torch.autocast("cuda", enabled=False):
276 | outputs = self.transformer(input_ids=tokens)
277 | z = outputs.last_hidden_state
278 | return z
279 |
280 | def encode(self, text):
281 | return self(text)
282 |
--------------------------------------------------------------------------------
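
`FrozenT5Embedder` tokenizes to a fixed `max_length` and returns the T5 encoder's last hidden state. A hedged usage sketch, assuming `sat` is on `PYTHONPATH` and the repository's requirements (transformers, sentencepiece, kornia) are installed; a small public T5 checkpoint is substituted for the default `google/t5-v1_1-xxl` purely to keep the example light:

```python
from sgm.modules.encoders.modules import FrozenT5Embedder

# Small checkpoint substituted for illustration; the class defaults to google/t5-v1_1-xxl.
embedder = FrozenT5Embedder(model_dir="google/t5-v1_1-small", device="cpu", max_length=77)
z = embedder(["a panda eating bamboo in the rain"])
print(z.shape)  # (1, 77, d_model): one embedding per padded token position
```
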
/sat/sgm/modules/video_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..modules.attention import *
4 | from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
5 |
6 |
7 | class TimeMixSequential(nn.Sequential):
8 | def forward(self, x, context=None, timesteps=None):
9 | for layer in self:
10 | x = layer(x, context, timesteps)
11 |
12 | return x
13 |
14 |
15 | class VideoTransformerBlock(nn.Module):
16 | ATTENTION_MODES = {
17 | "softmax": CrossAttention,
18 | "softmax-xformers": MemoryEfficientCrossAttention,
19 | }
20 |
21 | def __init__(
22 | self,
23 | dim,
24 | n_heads,
25 | d_head,
26 | dropout=0.0,
27 | context_dim=None,
28 | gated_ff=True,
29 | checkpoint=True,
30 | timesteps=None,
31 | ff_in=False,
32 | inner_dim=None,
33 | attn_mode="softmax",
34 | disable_self_attn=False,
35 | disable_temporal_crossattention=False,
36 | switch_temporal_ca_to_sa=False,
37 | ):
38 | super().__init__()
39 |
40 | attn_cls = self.ATTENTION_MODES[attn_mode]
41 |
42 | self.ff_in = ff_in or inner_dim is not None
43 | if inner_dim is None:
44 | inner_dim = dim
45 |
46 | assert int(n_heads * d_head) == inner_dim
47 |
48 | self.is_res = inner_dim == dim
49 |
50 | if self.ff_in:
51 | self.norm_in = nn.LayerNorm(dim)
52 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff)
53 |
54 | self.timesteps = timesteps
55 | self.disable_self_attn = disable_self_attn
56 | if self.disable_self_attn:
57 | self.attn1 = attn_cls(
58 | query_dim=inner_dim,
59 | heads=n_heads,
60 | dim_head=d_head,
61 | context_dim=context_dim,
62 | dropout=dropout,
63 | ) # is a cross-attention
64 | else:
65 | self.attn1 = attn_cls(
66 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
67 | ) # is a self-attention
68 |
69 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
70 |
71 | if disable_temporal_crossattention:
72 | if switch_temporal_ca_to_sa:
73 | raise ValueError
74 | else:
75 | self.attn2 = None
76 | else:
77 | self.norm2 = nn.LayerNorm(inner_dim)
78 | if switch_temporal_ca_to_sa:
79 | self.attn2 = attn_cls(
80 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
81 | ) # is a self-attention
82 | else:
83 | self.attn2 = attn_cls(
84 | query_dim=inner_dim,
85 | context_dim=context_dim,
86 | heads=n_heads,
87 | dim_head=d_head,
88 | dropout=dropout,
89 | ) # is self-attn if context is none
90 |
91 | self.norm1 = nn.LayerNorm(inner_dim)
92 | self.norm3 = nn.LayerNorm(inner_dim)
93 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
94 |
95 | self.checkpoint = checkpoint
96 | if self.checkpoint:
97 | print(f"{self.__class__.__name__} is using checkpointing")
98 |
99 | def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
100 | if self.checkpoint:
101 | return checkpoint(self._forward, x, context, timesteps)
102 | else:
103 | return self._forward(x, context, timesteps=timesteps)
104 |
105 | def _forward(self, x, context=None, timesteps=None):
106 | assert self.timesteps or timesteps
107 | assert not (self.timesteps and timesteps) or self.timesteps == timesteps
108 | timesteps = self.timesteps or timesteps
109 | B, S, C = x.shape
110 | x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
111 |
112 | if self.ff_in:
113 | x_skip = x
114 | x = self.ff_in(self.norm_in(x))
115 | if self.is_res:
116 | x += x_skip
117 |
118 | if self.disable_self_attn:
119 | x = self.attn1(self.norm1(x), context=context) + x
120 | else:
121 | x = self.attn1(self.norm1(x)) + x
122 |
123 | if self.attn2 is not None:
124 | if self.switch_temporal_ca_to_sa:
125 | x = self.attn2(self.norm2(x)) + x
126 | else:
127 | x = self.attn2(self.norm2(x), context=context) + x
128 | x_skip = x
129 | x = self.ff(self.norm3(x))
130 | if self.is_res:
131 | x += x_skip
132 |
133 | x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps)
134 | return x
135 |
136 | def get_last_layer(self):
137 | return self.ff.net[-1].weight
138 |
139 |
140 | str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
141 |
142 |
143 | class SpatialVideoTransformer(SpatialTransformer):
144 | def __init__(
145 | self,
146 | in_channels,
147 | n_heads,
148 | d_head,
149 | depth=1,
150 | dropout=0.0,
151 | use_linear=False,
152 | context_dim=None,
153 | use_spatial_context=False,
154 | timesteps=None,
155 | merge_strategy: str = "fixed",
156 | merge_factor: float = 0.5,
157 | time_context_dim=None,
158 | ff_in=False,
159 | checkpoint=False,
160 | time_depth=1,
161 | attn_mode="softmax",
162 | disable_self_attn=False,
163 | disable_temporal_crossattention=False,
164 | max_time_embed_period: int = 10000,
165 | dtype="fp32",
166 | ):
167 | super().__init__(
168 | in_channels,
169 | n_heads,
170 | d_head,
171 | depth=depth,
172 | dropout=dropout,
173 | attn_type=attn_mode,
174 | use_checkpoint=checkpoint,
175 | context_dim=context_dim,
176 | use_linear=use_linear,
177 | disable_self_attn=disable_self_attn,
178 | )
179 | self.time_depth = time_depth
180 | self.depth = depth
181 | self.max_time_embed_period = max_time_embed_period
182 |
183 | time_mix_d_head = d_head
184 | n_time_mix_heads = n_heads
185 |
186 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
187 |
188 | inner_dim = n_heads * d_head
189 | if use_spatial_context:
190 | time_context_dim = context_dim
191 |
192 | self.time_stack = nn.ModuleList(
193 | [
194 | VideoTransformerBlock(
195 | inner_dim,
196 | n_time_mix_heads,
197 | time_mix_d_head,
198 | dropout=dropout,
199 | context_dim=time_context_dim,
200 | timesteps=timesteps,
201 | checkpoint=checkpoint,
202 | ff_in=ff_in,
203 | inner_dim=time_mix_inner_dim,
204 | attn_mode=attn_mode,
205 | disable_self_attn=disable_self_attn,
206 | disable_temporal_crossattention=disable_temporal_crossattention,
207 | )
208 | for _ in range(self.depth)
209 | ]
210 | )
211 |
212 | assert len(self.time_stack) == len(self.transformer_blocks)
213 |
214 | self.use_spatial_context = use_spatial_context
215 | self.in_channels = in_channels
216 |
217 | time_embed_dim = self.in_channels * 4
218 | self.time_pos_embed = nn.Sequential(
219 | linear(self.in_channels, time_embed_dim),
220 | nn.SiLU(),
221 | linear(time_embed_dim, self.in_channels),
222 | )
223 |
224 | self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
225 | self.dtype = str_to_dtype[dtype]
226 |
227 | def forward(
228 | self,
229 | x: torch.Tensor,
230 | context: Optional[torch.Tensor] = None,
231 | time_context: Optional[torch.Tensor] = None,
232 | timesteps: Optional[int] = None,
233 | image_only_indicator: Optional[torch.Tensor] = None,
234 | ) -> torch.Tensor:
235 | _, _, h, w = x.shape
236 | x_in = x
237 | spatial_context = None
238 | if exists(context):
239 | spatial_context = context
240 |
241 | if self.use_spatial_context:
242 | assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"
243 |
244 | time_context = context
245 | time_context_first_timestep = time_context[::timesteps]
246 | time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w)
247 | elif time_context is not None and not self.use_spatial_context:
248 | time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
249 | if time_context.ndim == 2:
250 | time_context = rearrange(time_context, "b c -> b 1 c")
251 |
252 | x = self.norm(x)
253 | if not self.use_linear:
254 | x = self.proj_in(x)
255 | x = rearrange(x, "b c h w -> b (h w) c")
256 | if self.use_linear:
257 | x = self.proj_in(x)
258 |
259 | num_frames = torch.arange(timesteps, device=x.device)
260 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
261 | num_frames = rearrange(num_frames, "b t -> (b t)")
262 | t_emb = timestep_embedding(
263 | num_frames,
264 | self.in_channels,
265 | repeat_only=False,
266 | max_period=self.max_time_embed_period,
267 | dtype=self.dtype,
268 | )
269 | emb = self.time_pos_embed(t_emb)
270 | emb = emb[:, None, :]
271 |
272 | for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)):
273 | x = block(
274 | x,
275 | context=spatial_context,
276 | )
277 |
278 | x_mix = x
279 | x_mix = x_mix + emb
280 |
281 | x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
282 | x = self.time_mixer(
283 | x_spatial=x,
284 | x_temporal=x_mix,
285 | image_only_indicator=image_only_indicator,
286 | )
287 | if self.use_linear:
288 | x = self.proj_out(x)
289 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
290 | if not self.use_linear:
291 | x = self.proj_out(x)
292 | out = x + x_in
293 | return out
294 |
--------------------------------------------------------------------------------
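
`VideoTransformerBlock._forward` folds the spatial axis into the batch so attention runs along time, via the rearrange pattern `"(b t) s c -> (b s) t c"`. A tiny standalone sketch of that reshape round-trip:

```python
import torch
from einops import rearrange

# Standalone sketch of the reshape used by VideoTransformerBlock._forward above:
# spatial tokens move into the batch so attention runs along the time axis.
b, t, s, c = 2, 8, 64, 320
x = torch.randn(b * t, s, c)                       # "(b t) s c" layout entering the block
x_time = rearrange(x, "(b t) s c -> (b s) t c", t=t)
print(x_time.shape)                                # (b*s, t, c) = (128, 8, 320)
x_back = rearrange(x_time, "(b s) t c -> (b t) s c", s=s, b=b, t=t)
print(torch.equal(x, x_back))                      # True: a lossless round-trip
```
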
/sat/vae_modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
16 | )
17 |
18 | for name, p in model.named_parameters():
19 | if p.requires_grad:
20 | # remove as '.'-character is not allowed in buffers
21 | s_name = name.replace(".", "")
22 | self.m_name2s_name.update({name: s_name})
23 | self.register_buffer(s_name, p.clone().detach().data)
24 |
25 | self.collected_params = []
26 |
27 | def reset_num_updates(self):
28 | del self.num_updates
29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
30 |
31 | def forward(self, model):
32 | decay = self.decay
33 |
34 | if self.num_updates >= 0:
35 | self.num_updates += 1
36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
37 |
38 | one_minus_decay = 1.0 - decay
39 |
40 | with torch.no_grad():
41 | m_param = dict(model.named_parameters())
42 | shadow_params = dict(self.named_buffers())
43 |
44 | for key in m_param:
45 | if m_param[key].requires_grad:
46 | sname = self.m_name2s_name[key]
47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
49 | else:
50 | assert not key in self.m_name2s_name
51 |
52 | def copy_to(self, model):
53 | m_param = dict(model.named_parameters())
54 | shadow_params = dict(self.named_buffers())
55 | for key in m_param:
56 | if m_param[key].requires_grad:
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
--------------------------------------------------------------------------------
/sat/vae_modules/regularizers.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class DiagonalGaussianDistribution(object):
11 | def __init__(self, parameters, deterministic=False):
12 | self.parameters = parameters
13 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
14 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
15 | self.deterministic = deterministic
16 | self.std = torch.exp(0.5 * self.logvar)
17 | self.var = torch.exp(self.logvar)
18 | if self.deterministic:
19 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
20 |
21 | def sample(self):
22 | # x = self.mean + self.std * torch.randn(self.mean.shape).to(
23 | # device=self.parameters.device
24 | # )
25 | x = self.mean + self.std * torch.randn_like(self.mean)
26 | return x
27 |
28 | def kl(self, other=None):
29 | if self.deterministic:
30 | return torch.Tensor([0.0])
31 | else:
32 | if other is None:
33 | return 0.5 * torch.sum(
34 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
35 | dim=[1, 2, 3],
36 | )
37 | else:
38 | return 0.5 * torch.sum(
39 | torch.pow(self.mean - other.mean, 2) / other.var
40 | + self.var / other.var
41 | - 1.0
42 | - self.logvar
43 | + other.logvar,
44 | dim=[1, 2, 3],
45 | )
46 |
47 | def nll(self, sample, dims=[1, 2, 3]):
48 | if self.deterministic:
49 | return torch.Tensor([0.0])
50 | logtwopi = np.log(2.0 * np.pi)
51 | return 0.5 * torch.sum(
52 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
53 | dim=dims,
54 | )
55 |
56 | def mode(self):
57 | return self.mean
58 |
59 |
60 | class AbstractRegularizer(nn.Module):
61 | def __init__(self):
62 | super().__init__()
63 |
64 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
65 | raise NotImplementedError()
66 |
67 | @abstractmethod
68 | def get_trainable_parameters(self) -> Any:
69 | raise NotImplementedError()
70 |
71 |
72 | class IdentityRegularizer(AbstractRegularizer):
73 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
74 | return z, dict()
75 |
76 | def get_trainable_parameters(self) -> Any:
77 | yield from ()
78 |
79 |
80 | def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
81 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
82 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
83 | encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
84 | avg_probs = encodings.mean(0)
85 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
86 | cluster_use = torch.sum(avg_probs > 0)
87 | return perplexity, cluster_use
88 |
89 |
90 | class DiagonalGaussianRegularizer(AbstractRegularizer):
91 | def __init__(self, sample: bool = True):
92 | super().__init__()
93 | self.sample = sample
94 |
95 | def get_trainable_parameters(self) -> Any:
96 | yield from ()
97 |
98 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
99 | log = dict()
100 | posterior = DiagonalGaussianDistribution(z)
101 | if self.sample:
102 | z = posterior.sample()
103 | else:
104 | z = posterior.mode()
105 | kl_loss = posterior.kl()
106 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
107 | log["kl_loss"] = kl_loss
108 | return z, log
109 |
--------------------------------------------------------------------------------
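
`DiagonalGaussianRegularizer` wraps the distribution above and returns the sampled latent together with a KL log entry. A minimal usage sketch, assuming the `sat` directory is on `PYTHONPATH`:

```python
import torch
from vae_modules.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 2 * 16, 8, 8)   # encoder output: mean and logvar stacked on dim=1
z, log = reg(moments)
print(z.shape, log["kl_loss"])           # (2, 16, 8, 8) and a scalar KL term
```
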
/tools/caption/README.md:
--------------------------------------------------------------------------------
1 | # Video Caption
2 |
3 | Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video
4 | data into textual descriptions to provide the essential training data for text-to-video models.
5 |
6 | ## Video Caption via CogVLM2-Video
7 |
8 |
9 | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | 💬 [Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
10 |
11 |
12 | CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities.
13 | Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
14 |
15 |

16 |
17 |
18 | Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
--------------------------------------------------------------------------------
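
A minimal loading sketch (not from this repository; the model id comes from the links above, and the `AutoModel` / `trust_remote_code` pattern is an assumption based on standard 🤗 Transformers usage — see the linked video_demo code for frame sampling and prompt formatting):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "THUDM/cogvlm2-video-llama3-chat"  # from the Hugging Face link above
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()
# Video frames plus a prompt such as "Please describe this video in detail." are then
# fed to the model as shown in the linked video_demo code.
```
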
/tools/caption/README_ja.md:
--------------------------------------------------------------------------------
1 | # Video Caption
2 |
3 | Typically, most video data does not come with corresponding descriptive text, so the video data needs to be converted into textual descriptions to provide the training data required by text-to-video models.
4 |
5 | ## Video Caption via CogVLM2-Video
6 |
7 |
8 | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | 💬 [Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
9 |
10 |
11 | CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities. Users can give the model a prompt such as `このビデオを詳細に説明してください。` (please describe this video in detail) to obtain a detailed video caption:
12 |
13 |

14 |
15 |
16 | Users can load the model with the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) or set up a RESTful API to generate video captions.
17 |
--------------------------------------------------------------------------------
/tools/caption/README_zh.md:
--------------------------------------------------------------------------------
1 | # Video Caption
2 |
3 | Typically, most video data does not come with corresponding descriptive text, so the video data needs to be converted into textual descriptions to provide the training data required by text-to-video models.
4 |
5 | ## Video Caption via CogVLM2-Video
6 |
7 | 🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
8 |
9 | CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities. Users can give the model a prompt such as `请详细描述这个视频` (please describe this video in detail) to obtain a detailed video caption:
10 |
11 |
12 |
13 |

14 |
15 |
16 | Users can load the model with the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) or configure a RESTful API to generate video captions.
--------------------------------------------------------------------------------
/tools/caption/assests/cogvlm2-video-example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vchitect/RepVideo/7727c0c592a152b9cd09a47d9d1ba3b81792e25f/tools/caption/assests/cogvlm2-video-example.png
--------------------------------------------------------------------------------
/tools/convert_weight_sat2hf.py:
--------------------------------------------------------------------------------
1 | """
2 | This script demonstrates how to convert a CogVideoX checkpoint from the SAT format to the 🤗 Hugging Face Diffusers format and then generate a video from a text prompt with the converted pipeline.
3 |
4 | Note:
5 | This script requires the `diffusers>=0.30.0` library to be installed.
6 |
7 | Run the script:
8 |     $ python convert_weight_sat2hf.py --transformer_ckpt_path <path> --vae_ckpt_path <path> --output_path <path> --text_encoder_path <path>
9 |
10 | Functions:
11 | - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
12 | - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
13 | - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
14 | - remove_keys_inplace: Removes specified keys from the state_dict in-place.
15 | - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
16 | - get_state_dict: Extracts the state_dict from a saved checkpoint.
17 | - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
18 | - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
19 | - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
20 | - get_args: Parses command-line arguments for the script.
21 | The __main__ block assembles the converted modules into a CogVideoXPipeline and saves it with save_pretrained.
22 | """
23 |
24 | import argparse
25 | from typing import Any, Dict
26 |
27 | import torch
28 | from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
29 | from transformers import T5EncoderModel, T5Tokenizer
30 |
31 |
32 | # Function to reassign the query, key, and value weights in-place
33 | def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
34 | to_q_key = key.replace("query_key_value", "to_q")
35 | to_k_key = key.replace("query_key_value", "to_k")
36 | to_v_key = key.replace("query_key_value", "to_v")
37 | to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
38 | state_dict[to_q_key] = to_q
39 | state_dict[to_k_key] = to_k
40 | state_dict[to_v_key] = to_v
41 | state_dict.pop(key)
42 |
43 |
44 | # Function to reassign layer normalization for query and key in-place
45 | def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
46 | layer_id, weight_or_bias = key.split(".")[-2:]
47 |
48 | if "query" in key:
49 | new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
50 | elif "key" in key:
51 | new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
52 |
53 | state_dict[new_key] = state_dict.pop(key)
54 |
55 |
56 | # Function to reassign adaptive layer normalization in-place
57 | def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
58 | layer_id, _, weight_or_bias = key.split(".")[-3:]
59 |
60 | weights_or_biases = state_dict[key].chunk(12, dim=0)
61 | norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
62 | norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
63 |
64 | norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
65 | state_dict[norm1_key] = norm1_weights_or_biases
66 |
67 | norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
68 | state_dict[norm2_key] = norm2_weights_or_biases
69 |
70 | state_dict.pop(key)
71 |
72 |
73 | # Function to remove keys from state_dict in-place
74 | def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
75 | state_dict.pop(key)
76 |
77 |
78 | # Function to replace keys in the "up" block in-place
79 | def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
80 | key_split = key.split(".")
81 | layer_index = int(key_split[2])
82 | replace_layer_index = 4 - 1 - layer_index
83 |
84 | key_split[1] = "up_blocks"
85 | key_split[2] = str(replace_layer_index)
86 | new_key = ".".join(key_split)
87 |
88 | state_dict[new_key] = state_dict.pop(key)
89 |
90 |
91 | # Dictionary for renaming transformer keys
92 | TRANSFORMER_KEYS_RENAME_DICT = {
93 | "transformer.final_layernorm": "norm_final",
94 | "transformer": "transformer_blocks",
95 | "attention": "attn1",
96 | "mlp": "ff.net",
97 | "dense_h_to_4h": "0.proj",
98 | "dense_4h_to_h": "2",
99 | ".layers": "",
100 | "dense": "to_out.0",
101 | "input_layernorm": "norm1.norm",
102 | "post_attn1_layernorm": "norm2.norm",
103 | "time_embed.0": "time_embedding.linear_1",
104 | "time_embed.2": "time_embedding.linear_2",
105 | "mixins.patch_embed": "patch_embed",
106 | "mixins.final_layer.norm_final": "norm_out.norm",
107 | "mixins.final_layer.linear": "proj_out",
108 | "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
109 | }
110 |
111 | # Dictionary for handling special keys in transformer
112 | TRANSFORMER_SPECIAL_KEYS_REMAP = {
113 | "query_key_value": reassign_query_key_value_inplace,
114 | "query_layernorm_list": reassign_query_key_layernorm_inplace,
115 | "key_layernorm_list": reassign_query_key_layernorm_inplace,
116 | "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
117 | "embed_tokens": remove_keys_inplace,
118 | }
119 |
120 | # Dictionary for renaming VAE keys
121 | VAE_KEYS_RENAME_DICT = {
122 | "block.": "resnets.",
123 | "down.": "down_blocks.",
124 | "downsample": "downsamplers.0",
125 | "upsample": "upsamplers.0",
126 | "nin_shortcut": "conv_shortcut",
127 | "encoder.mid.block_1": "encoder.mid_block.resnets.0",
128 | "encoder.mid.block_2": "encoder.mid_block.resnets.1",
129 | "decoder.mid.block_1": "decoder.mid_block.resnets.0",
130 | "decoder.mid.block_2": "decoder.mid_block.resnets.1",
131 | }
132 |
133 | # Dictionary for handling special keys in VAE
134 | VAE_SPECIAL_KEYS_REMAP = {
135 | "loss": remove_keys_inplace,
136 | "up.": replace_up_keys_inplace,
137 | }
138 |
139 | # Maximum length of the tokenizer (Must be 226)
140 | TOKENIZER_MAX_LENGTH = 226
141 |
142 |
143 | # Function to extract the state_dict from a saved checkpoint
144 | def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
145 | state_dict = saved_dict
146 | if "model" in saved_dict.keys():
147 | state_dict = state_dict["model"]
148 | if "module" in saved_dict.keys():
149 | state_dict = state_dict["module"]
150 | if "state_dict" in saved_dict.keys():
151 | state_dict = state_dict["state_dict"]
152 | return state_dict
153 |
154 |
155 | # Function to update the state_dict with new key assignments in-place
156 | def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
157 | state_dict[new_key] = state_dict.pop(old_key)
158 |
159 |
160 | # Function to convert a transformer checkpoint to the CogVideoX format
161 | def convert_transformer(ckpt_path: str):
162 | PREFIX_KEY = "model.diffusion_model."
163 |
164 | original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
165 | transformer = CogVideoXTransformer3DModel()
166 |
167 | for key in list(original_state_dict.keys()):
168 | new_key = key[len(PREFIX_KEY) :]
169 | for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
170 | new_key = new_key.replace(replace_key, rename_key)
171 | update_state_dict_inplace(original_state_dict, key, new_key)
172 |
173 | for key in list(original_state_dict.keys()):
174 | for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
175 | if special_key not in key:
176 | continue
177 | handler_fn_inplace(key, original_state_dict)
178 |
179 | transformer.load_state_dict(original_state_dict, strict=True)
180 | return transformer
181 |
182 |
183 | # Function to convert a VAE checkpoint to the CogVideoX format
184 | def convert_vae(ckpt_path: str):
185 | original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
186 | vae = AutoencoderKLCogVideoX()
187 |
188 | for key in list(original_state_dict.keys()):
189 | new_key = key[:]
190 | for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
191 | new_key = new_key.replace(replace_key, rename_key)
192 | update_state_dict_inplace(original_state_dict, key, new_key)
193 |
194 | for key in list(original_state_dict.keys()):
195 | for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
196 | if special_key not in key:
197 | continue
198 | handler_fn_inplace(key, original_state_dict)
199 |
200 | vae.load_state_dict(original_state_dict, strict=True)
201 | return vae
202 |
203 |
204 | # Function to parse command-line arguments for the script
205 | def get_args():
206 | parser = argparse.ArgumentParser()
207 | parser.add_argument(
208 | "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
209 | )
210 | parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
211 | parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
212 | parser.add_argument(
213 | "--text_encoder_path",
214 | type=str,
215 | required=True,
216 | default="google/t5-v1_1-xxl",
217 |         help="Path or Hugging Face repo id of the T5 text encoder (e.g., google/t5-v1_1-xxl)",
218 | )
219 | parser.add_argument(
220 | "--text_encoder_cache_dir",
221 | type=str,
222 | default=None,
223 | help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
224 | )
225 | parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
226 | parser.add_argument(
227 | "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
228 | )
229 | return parser.parse_args()
230 |
231 |
232 | if __name__ == "__main__":
233 | args = get_args()
234 |
235 | transformer = None
236 | vae = None
237 |
238 | if args.transformer_ckpt_path is not None:
239 | transformer = convert_transformer(args.transformer_ckpt_path)
240 | if args.vae_ckpt_path is not None:
241 | vae = convert_vae(args.vae_ckpt_path)
242 |
243 | tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
244 | text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
245 |
246 | scheduler = CogVideoXDDIMScheduler.from_config(
247 | {
248 | "snr_shift_scale": 3.0,
249 | "beta_end": 0.012,
250 | "beta_schedule": "scaled_linear",
251 | "beta_start": 0.00085,
252 | "clip_sample": False,
253 | "num_train_timesteps": 1000,
254 | "prediction_type": "v_prediction",
255 | "rescale_betas_zero_snr": True,
256 | "set_alpha_to_one": True,
257 | "timestep_spacing": "linspace",
258 | }
259 | )
260 |
261 | pipe = CogVideoXPipeline(
262 | tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
263 | )
264 |
265 | if args.fp16:
266 | pipe = pipe.to(dtype=torch.float16)
267 |
268 | pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
269 |
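
# Example usage after conversion (illustrative note, not part of the original script):
# the saved directory can be loaded back with Diffusers, e.g.
#
#     from diffusers import CogVideoXPipeline
#     pipe = CogVideoXPipeline.from_pretrained("<output_path>", torch_dtype=torch.float16)
#
# where "<output_path>" stands for whatever was passed as --output_path above.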
--------------------------------------------------------------------------------
/tools/venhancer/README.md:
--------------------------------------------------------------------------------
1 | # Enhance CogVideoX Generated Videos with VEnhancer
2 |
3 | This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, including
4 | achieving higher frame rates and higher resolutions.
5 |
6 | ## Model Introduction
7 |
8 | VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in
9 | a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal
10 | super-resolution. Additionally, it provides flexible control to modify the refinement strength, enabling it to handle
11 | diverse video artifacts.
12 |
13 | VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle
14 | block from a pre-trained video diffusion model to build a trainable conditional network. This video ControlNet accepts
15 | low-resolution keyframes and noisy full-frame latents as inputs. In addition to the time step t and prompt, the proposed
16 | video-aware conditioning also includes noise augmentation level σ and downscaling factor s as additional network
17 | conditioning inputs.
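
For intuition, the following is a minimal conceptual sketch (assumed names and dimensions, not VEnhancer's actual code) of how such video-aware conditioning can be wired up in PyTorch: each scalar condition (t, σ, s) gets a sinusoidal embedding, and the three embeddings are merged into a single vector that the ControlNet branch can add to its features.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_embedding(x: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Standard sinusoidal embedding of a scalar condition (one value per batch element)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class VideoAwareConditioning(nn.Module):
    """Merges timestep t, noise augmentation level sigma, and downscaling factor s into one embedding."""

    def __init__(self, dim: int = 256, out_dim: int = 1024):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(3 * dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim))

    def forward(self, t: torch.Tensor, sigma: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        emb = torch.cat(
            [
                sinusoidal_embedding(t, self.dim),
                sinusoidal_embedding(sigma, self.dim),
                sinusoidal_embedding(s, self.dim),
            ],
            dim=-1,
        )
        # The resulting vector would be added to the ControlNet branch features,
        # much like a timestep embedding conditions a standard diffusion U-Net.
        return self.mlp(emb)


cond = VideoAwareConditioning()
emb = cond(torch.tensor([500]), torch.tensor([250]), torch.tensor([8]))  # t, sigma, s
print(emb.shape)  # torch.Size([1, 1024])
```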
18 |
19 | ## Hardware Requirements
20 |
21 | + Operating System: Linux (requires the xformers dependency)
22 | + Hardware: NVIDIA GPU with at least 60GB of VRAM per card; machines such as the H100 or A100 are recommended.
23 |
24 | ## Quick Start
25 |
26 | 1. Clone the repository and install dependencies as per the official instructions:
27 |
28 | ```shell
29 | git clone https://github.com/Vchitect/VEnhancer.git
30 | cd VEnhancer
31 | ## Torch and other dependencies can use those from CogVideoX. If you need to create a new environment, use the following commands:
32 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
33 |
34 | ## Install required dependencies
35 | pip install -r requirements.txt
36 | ```
37 |
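2. Run the code:

```shell
python enhance_a_video.py \
--up_scale 4 --target_fps 24 --noise_aug 250 \
--solver_mode 'fast' --steps 15 \
--input_path inputs/000000.mp4 \
--prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' \
--save_dir 'results/'
```
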
38 | Where:
39 |
40 | - `input_path` is the path to the input video
41 | - `prompt` is the description of the video content. The prompt used by this tool should be shorter, not exceeding 77
42 | words. You may need to simplify the prompt used for generating the CogVideoX video.
43 | - `target_fps` is the target frame rate for the video. Typically, 16 fps is already smooth, with 24 fps as the default
44 | value.
45 | - `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to around 2K and below.
46 | - `noise_aug` depends on the input video quality: lower-quality videos need higher noise levels, which correspond to
47 |   stronger refinement. Use 250~300 for very low-quality videos and <= 200 for good ones.
48 | - `steps`: if you want fewer steps, first change `solver_mode` to "normal", then reduce the number of steps. The
49 |   "fast" solver_mode uses a fixed number of steps (15).
50 | The code will automatically download the required models from Hugging Face during execution.
51 |
52 | Typical runtime logs are as follows:
53 |
54 | ```shell
55 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
56 | @torch.library.impl_abstract("xformers_flash::flash_fwd")
57 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
58 | @torch.library.impl_abstract("xformers_flash::flash_bwd")
59 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
60 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
61 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
62 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
63 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
64 | load_dict = torch.load(cfg.model_path, map_location='cpu')
65 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
66 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
67 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
68 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
69 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
70 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
71 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
72 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
73 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
74 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
75 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
76 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
77 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
78 | with amp.autocast(enabled=True):
79 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
80 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
81 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
82 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
83 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
84 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
85 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
86 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
87 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
88 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
89 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
90 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
91 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
92 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
93 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
94 |
95 | ```
96 |
97 | Running on a single A100 GPU, enhancing each 6-second CogVideoX-generated video with the default settings consumes 60GB
98 | of VRAM and takes 40-50 minutes.
--------------------------------------------------------------------------------
/tools/venhancer/README_ja.md:
--------------------------------------------------------------------------------
1 |
2 | # Enhance CogVideoX Generated Videos with VEnhancer
3 |
4 | This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, achieving higher frame rates and higher resolutions.
5 |
6 | ## Model Introduction
7 |
8 | VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal super-resolution, and it provides flexible control over the refinement strength to handle diverse video artifacts.
9 |
10 | VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle block from a pre-trained video diffusion model to build a trainable conditional network. This video ControlNet accepts low-resolution keyframes and noisy full frames as inputs. In addition to the timestep t and the prompt, the proposed video-aware conditioning also uses the noise augmentation level σ and the downscaling factor s as additional network conditioning inputs.
11 |
12 | ## Hardware Requirements
13 |
14 | + Operating System: Linux (requires the xformers dependency)
15 | + Hardware: NVIDIA GPU with at least 60GB of VRAM per card; machines such as the H100 or A100 are recommended.
16 |
17 | ## Quick Start
18 |
19 | 1. Clone the repository and install dependencies as per the official instructions:
20 |
21 | ```shell
22 | git clone https://github.com/Vchitect/VEnhancer.git
23 | cd VEnhancer
24 | ## Torch and other dependencies can reuse those from CogVideoX. If you need to create a new environment, use the following commands:
25 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
26 |
27 | ## Install required dependencies
28 | pip install -r requirements.txt
29 | ```
30 |
31 | 2. Run the code:
32 |
33 | ```shell
34 | python enhance_a_video.py --up_scale 4 --target_fps 24 --noise_aug 250 --solver_mode 'fast' --steps 15 --input_path inputs/000000.mp4 --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' --save_dir 'results/'
35 | ```
36 |
37 | Where:
38 |
39 | - `input_path` is the path to the input video.
40 | - `prompt` is the description of the video content. The prompt used by this tool should be short, not exceeding 77 words. You may need to simplify the prompt used for generating the CogVideoX video.
41 | - `target_fps` is the target frame rate for the video. Typically, 16 fps is already smooth, with 24 fps as the default value.
42 | - `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to around 2K and below.
43 | - `noise_aug` depends on the input video quality: lower-quality videos need higher noise levels, which correspond to stronger refinement. Use 250~300 for very low-quality videos and <= 200 for good ones.
44 | - `steps`: if you want fewer steps, first change `solver_mode` to "normal", then reduce the number of steps. The "fast" solver_mode uses a fixed number of steps (15).
45 |
46 |
47 | The required models will be downloaded automatically from Hugging Face during execution.
48 |
49 | ```shell
50 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
51 | @torch.library.impl_abstract("xformers_flash::flash_fwd")
52 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
53 | @torch.library.impl_abstract("xformers_flash::flash_bwd")
54 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
55 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
56 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
57 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
58 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
59 | load_dict = torch.load(cfg.model_path, map_location='cpu')
60 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
61 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
62 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
63 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
64 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
65 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
66 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
67 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
68 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
69 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
70 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
71 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
72 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
73 | with amp.autocast(enabled=True):
74 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
75 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
76 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
77 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
78 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
79 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
80 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
81 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
82 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
83 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
84 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
85 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
86 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
87 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
88 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
89 |
90 | ```
91 |
92 | Running on a single A100 GPU, enhancing a 6-second video generated by CogVideoX with the default settings consumes 60GB of VRAM and takes 40-50 minutes.
93 |
--------------------------------------------------------------------------------
/tools/venhancer/README_zh.md:
--------------------------------------------------------------------------------
1 | # Enhance CogVideoX Generated Videos with VEnhancer
2 |
3 | This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, including higher frame rates and higher resolutions.
4 |
5 | ## Model Introduction
6 |
7 | VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal
8 | super-resolution. It also provides flexible control over the refinement strength to handle diverse video artifacts.
9 |
10 | VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle block from a pre-trained video diffusion model to build a trainable conditional network. This video
11 | ControlNet accepts low-resolution keyframes and noisy full frames as inputs. In addition to the timestep t and the prompt, the proposed video-aware conditioning also takes the noise augmentation level
12 | σ and the downscaling factor s as additional network conditioning inputs.
13 |
14 | ## Hardware Requirements
15 |
16 | + Operating System: Linux (requires the xformers dependency)
17 | + Hardware: NVIDIA GPU with at least 60GB of VRAM per card; machines such as the H100 or A100 are recommended.
18 |
19 | ## Quick Start
20 |
21 | 1. Clone the repository and install dependencies as per the official instructions:
22 |
23 | ```shell
24 | git clone https://github.com/Vchitect/VEnhancer.git
25 | cd VEnhancer
26 | ## Torch and other dependencies can reuse those from CogVideoX. If you need to create a new environment, use the following commands:
27 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
28 |
29 | ## Install required dependencies
30 | pip install -r requirements.txt
31 | ```
32 |
33 | 2. Run the code:
34 |
35 | ```shell
36 | python enhance_a_video.py \
37 | --up_scale 4 --target_fps 24 --noise_aug 250 \
38 | --solver_mode 'fast' --steps 15 \
39 | --input_path inputs/000000.mp4 \
40 | --prompt 'Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere' \
41 | --save_dir 'results/'
42 | ```
43 |
44 | Where:
45 |
46 | - `input_path` is the path to the input video.
47 | - `prompt` is the description of the video content. The prompt used by this tool should be short, not exceeding 77 words. You may need to simplify the prompt used for generating the CogVideoX video.
48 | - `target_fps` is the target frame rate for the video. Typically, 16 fps is already smooth, with 24 fps as the default value.
49 | - `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to around 2K and below.
50 | - `noise_aug` depends on the input video quality: lower-quality videos need higher noise levels, which correspond to stronger refinement. Use 250~300 for very low-quality videos and <= 200 for good ones.
51 | - `steps`: if you want fewer steps, first change `solver_mode` to "normal", then reduce the number of steps. The "fast" solver_mode uses a fixed number of steps (15).
52 |
53 |
54 | The required models will be downloaded automatically from Hugging Face during execution.
55 |
56 | Typical runtime logs are as follows:
57 |
58 | ```shell
59 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
60 | @torch.library.impl_abstract("xformers_flash::flash_fwd")
61 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
62 | @torch.library.impl_abstract("xformers_flash::flash_bwd")
63 | 2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
64 | /share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
65 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
66 | 2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder
67 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
68 | load_dict = torch.load(cfg.model_path, map_location='cpu')
69 | 2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status
70 | 2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion
71 | 2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4
72 | 2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere
73 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49
74 | 2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0
75 | 2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0
76 | 2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720)
77 | 2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982)
78 | 2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250
79 | 2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8
80 | 2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982])
81 | /share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
82 | with amp.autocast(enabled=True):
83 | 2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0
84 | 2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1
85 | 2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2
86 | 2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3
87 | 2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4
88 | 2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5
89 | 2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6
90 | 2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7
91 | 2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8
92 | 2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9
93 | 2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10
94 | 2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11
95 | 2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12
96 | 2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13
97 | 2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished.
98 |
99 | ```
100 |
101 | Running on a single A100 GPU, enhancing each 6-second video generated by CogVideoX with the default configuration consumes 60GB of VRAM and takes 40-50 minutes.
--------------------------------------------------------------------------------