├── .gitignore ├── LICENSE ├── opensora_evaluate ├── cal_psnr.py ├── cal_lpips.py └── cal_ssim.py ├── evaluate.py ├── requirements.txt ├── README.md ├── rec_image_eval.py ├── utils.py ├── rec_video_eval.py └── model └── cdt.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./reconstructed_results/ 2 | ./reconstructed_results/video_results/ 3 | ./reconstructed_results/image_results/ 4 | 5 | ./pretrained/ 6 | ./pretrained/cdt_base.ckpt 7 | ./pretrained/cdt_small.ckpt 8 | 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Nianzu Yang and Tongyi Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /opensora_evaluate/cal_psnr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | def img_psnr(img1, img2): 7 | 8 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 9 | if mse < 1e-10: 10 | return 100 11 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 12 | return psnr 13 | 14 | def trans(x): 15 | return x 16 | 17 | def calculate_psnr(videos1, videos2): 18 | 19 | assert videos1.shape == videos2.shape 20 | 21 | videos1 = trans(videos1) 22 | videos2 = trans(videos2) 23 | 24 | psnr_results = [] 25 | 26 | for video_num in range(videos1.shape[0]): 27 | 28 | video1 = videos1[video_num] 29 | video2 = videos2[video_num] 30 | 31 | psnr_results_of_a_video = [] 32 | for clip_timestamp in range(len(video1)): 33 | 34 | img1 = video1[clip_timestamp].numpy() 35 | img2 = video2[clip_timestamp].numpy() 36 | 37 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 38 | 39 | psnr_results.append(psnr_results_of_a_video) 40 | 41 | psnr_results = np.array(psnr_results) 42 | psnr = {} 43 | psnr_std = {} 44 | 45 | for clip_timestamp in range(len(video1)): 46 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp]) 47 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp]) 48 | 49 | result = { 50 | "value": psnr, 51 | "value_std": psnr_std, 52 | "video_setting": video1.shape, 53 | "video_setting_name": "time, channel, heigth, width", 54 | } 55 | 56 | return result 57 | 58 | 59 | def main(): 60 | NUMBER_OF_VIDEOS = 8 61 | VIDEO_LENGTH = 50 62 | CHANNEL = 3 63 | SIZE = 64 64 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 65 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 66 | 67 | import json 68 | result = calculate_psnr(videos1, videos2) 69 | print(json.dumps(result, indent=4)) 70 | 71 | if __name__ == "__main__": 72 | main() -------------------------------------------------------------------------------- /opensora_evaluate/cal_lpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | import lpips 7 | import os 8 | 9 | spatial = True 10 | 11 | 12 | 13 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) 14 | 15 | def trans(x): 16 | 17 | if x.shape[-3] == 1: 18 | x = x.repeat(1, 1, 3, 1, 1) 19 | 20 | x = x * 2 - 1 21 | 22 | return x 23 | 24 | def calculate_lpips(videos1, videos2, device): 25 | 26 | assert videos1.shape == videos2.shape 27 | 28 | videos1 = trans(videos1) 29 | videos2 = trans(videos2) 30 | 31 | lpips_results = [] 32 | 33 | for video_num in range(videos1.shape[0]): 34 | 35 | video1 = videos1[video_num] 36 | video2 = videos2[video_num] 37 | 38 | lpips_results_of_a_video = [] 39 | for clip_timestamp in range(len(video1)): 40 | 41 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 42 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 43 | 44 | loss_fn.to(device) 45 | 46 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 47 | lpips_results.append(lpips_results_of_a_video) 48 | 49 | lpips_results = np.array(lpips_results) 50 | 51 | lpips = {} 52 | lpips_std = {} 53 | 54 | for clip_timestamp in range(len(video1)): 55 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp]) 56 | lpips_std[clip_timestamp] = 
np.std(lpips_results[:,clip_timestamp]) 57 | 58 | 59 | result = { 60 | "value": lpips, 61 | "value_std": lpips_std, 62 | "video_setting": video1.shape, 63 | "video_setting_name": "time, channel, heigth, width", 64 | } 65 | 66 | return result 67 | 68 | 69 | def main(): 70 | NUMBER_OF_VIDEOS = 8 71 | VIDEO_LENGTH = 50 72 | CHANNEL = 3 73 | SIZE = 64 74 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 75 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 76 | device = torch.device("cuda") 77 | 78 | import json 79 | result = calculate_lpips(videos1, videos2, device) 80 | print(json.dumps(result, indent=4)) 81 | 82 | if __name__ == "__main__": 83 | main() -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | # Define constants for dataset paths 5 | WEBVID_PATH = '/mnt/wulanchabu/lpd_wlcb/dataset/webvid/val' 6 | COCO17_PATH = '/cpfs01/Group-rep-learning/multimodal/datasets/coco/val2017' 7 | 8 | def main(args): 9 | # method name 10 | method = args.method 11 | # assert method name is CDT 12 | assert "CDT" in method, "method must be CDT" 13 | # dataset name 14 | dataset = args.dataset 15 | # mode: video or image evaluation 16 | mode = args.mode 17 | # crop size 18 | size = args.size 19 | # number of frames for video evaluation 20 | num_frames = 17 21 | # subset size: 0 means all videos; otherwise, only use subset_size videos for evaluation, you can set the subset_size to 1 for debug 22 | subset_size = args.subset_size 23 | 24 | print(f"Mode: {mode} evaluation") 25 | print(f"Dataset: {dataset}") 26 | 27 | if dataset == 'webvid': 28 | assert mode == 'video', "webvid dataset only supports video evaluation" 29 | # use constant path for webvid dataset 30 | real_data_dir = WEBVID_PATH 31 | elif dataset == 'coco17': 32 | assert mode == 'image', "coco17 dataset only supports image evaluation" 33 | # use constant path for coco17 dataset 34 | real_data_dir = COCO17_PATH 35 | else: 36 | raise ValueError(f"the path of dataset {dataset} is not specified") 37 | 38 | 39 | if "image" in mode: 40 | # other config 41 | base_set_rec = ( 42 | '--real_data_dir {real_data_dir} ' 43 | '--dataset {dataset} ' 44 | '--subset_size {subset_size} ' 45 | ).format(real_data_dir=real_data_dir, 46 | dataset=dataset, subset_size=subset_size) 47 | command = 'python rec_image_eval.py --method {method} '.format(method=method) + base_set_rec 48 | else: 49 | # other config 50 | base_set_rec = ( 51 | '--crop_size {size} ' 52 | '--resolution {size} ' 53 | '--num_frames {num_frames} ' 54 | '--real_data_dir {real_data_dir} ' 55 | '--dataset {dataset} ' 56 | '--subset_size {subset_size} ' 57 | ).format(size=size, num_frames=num_frames, real_data_dir=real_data_dir, 58 | dataset=dataset, subset_size=subset_size) 59 | 60 | command = 'python rec_video_eval.py --method {method} '.format(method=method) + base_set_rec 61 | 62 | print(f"Command: {command}") 63 | 64 | subprocess.call(command, shell=True) 65 | 66 | if __name__ == '__main__': 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--method', type=str, default='CDT_base') 70 | parser.add_argument('--dataset', type=str, default='webvid', choices=['webvid', 'coco17']) 71 | parser.add_argument('--mode', type=str, default='video', choices=['video', 'image']) 72 | parser.add_argument('--size', type=str, 
default='256') 73 | parser.add_argument('--subset_size', type=int, default=0) 74 | args = parser.parse_args() 75 | 76 | main(args) -------------------------------------------------------------------------------- /opensora_evaluate/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import cv2 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01 ** 2 8 | C2 = 0.03 ** 2 9 | img1 = img1.astype(np.float64) 10 | img2 = img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1 ** 2 16 | mu2_sq = mu2 ** 2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 22 | (sigma1_sq + sigma2_sq + C2)) 23 | return ssim_map.mean() 24 | 25 | 26 | def calculate_ssim_function(img1, img2): 27 | if not img1.shape == img2.shape: 28 | raise ValueError('Input images must have the same dimensions.') 29 | if img1.ndim == 2: 30 | return ssim(img1, img2) 31 | elif img1.ndim == 3: 32 | if img1.shape[0] == 3: 33 | ssims = [] 34 | for i in range(3): 35 | ssims.append(ssim(img1[i], img2[i])) 36 | return np.array(ssims).mean() 37 | elif img1.shape[0] == 1: 38 | return ssim(np.squeeze(img1), np.squeeze(img2)) 39 | else: 40 | raise ValueError('Wrong input image dimensions.') 41 | 42 | def trans(x): 43 | return x 44 | 45 | def calculate_ssim(videos1, videos2): 46 | 47 | assert videos1.shape == videos2.shape 48 | 49 | videos1 = trans(videos1) 50 | videos2 = trans(videos2) 51 | 52 | ssim_results = [] 53 | 54 | for video_num in range(videos1.shape[0]): 55 | 56 | video1 = videos1[video_num] 57 | video2 = videos2[video_num] 58 | 59 | ssim_results_of_a_video = [] 60 | for clip_timestamp in range(len(video1)): 61 | 62 | img1 = video1[clip_timestamp].numpy() 63 | img2 = video2[clip_timestamp].numpy() 64 | 65 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 66 | 67 | ssim_results.append(ssim_results_of_a_video) 68 | 69 | ssim_results = np.array(ssim_results) 70 | 71 | ssim = {} 72 | ssim_std = {} 73 | 74 | for clip_timestamp in range(len(video1)): 75 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp]) 76 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp]) 77 | 78 | result = { 79 | "value": ssim, 80 | "value_std": ssim_std, 81 | "video_setting": video1.shape, 82 | "video_setting_name": "time, channel, heigth, width", 83 | } 84 | 85 | return result 86 | 87 | 88 | def main(): 89 | NUMBER_OF_VIDEOS = 8 90 | VIDEO_LENGTH = 50 91 | CHANNEL = 3 92 | SIZE = 64 93 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 94 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 95 | device = torch.device("cuda") 96 | 97 | import json 98 | result = calculate_ssim(videos1, videos2) 99 | print(json.dumps(result, indent=4)) 100 | 101 | if __name__ == "__main__": 102 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
absl-py==2.1.0 2 | accelerate==1.1.0 3 | addict==2.4.0 4 | aiofiles==23.2.1 5 | aiohappyeyeballs==2.4.3 6 | aiohttp==3.10.10 7 | aiosignal==1.3.1 8 | aliyun-python-sdk-core==2.16.0 9 | aliyun-python-sdk-kms==2.16.5 10 | annotated-types==0.7.0 11 | antlr4-python3-runtime==4.9.3 12 | anyio==4.6.2.post1 13 | asttokens==2.4.1 14 | async-timeout==4.0.3 15 | attrs==24.2.0 16 | bcrypt==4.2.0 17 | beautifulsoup4==4.12.3 18 | bitsandbytes==0.44.1 19 | certifi==2024.8.30 20 | cffi==1.17.1 21 | cfgv==3.4.0 22 | chainer==7.8.1 23 | charset-normalizer==3.4.0 24 | clean-fid==0.1.35 25 | click==8.1.7 26 | colorama==0.4.6 27 | colossalai==0.4.6 28 | contourpy==1.3.0 29 | crcmod==1.7 30 | cryptography==43.0.3 31 | cycler==0.12.1 32 | decorator==5.1.1 33 | decord==0.6.0 34 | Deprecated==1.2.14 35 | diffusers==0.32.2 36 | distlib==0.3.9 37 | docstring_parser==0.16 38 | easydict==1.13 39 | einops==0.8.0 40 | exceptiongroup==1.2.2 41 | executing==2.1.0 42 | fabric==3.2.2 43 | facexlib==0.3.0 44 | fairscale==0.4.13 45 | fastapi==0.115.4 46 | ffmpy==0.5.0 47 | filelock==3.16.1 48 | filterpy==1.4.5 49 | fonttools==4.54.1 50 | frozenlist==1.5.0 51 | fsspec==2024.10.0 52 | ftfy==6.3.1 53 | future==1.0.0 54 | fvcore==0.1.5.post20221221 55 | galore-torch==1.0 56 | google==3.0.0 57 | gradio==5.12.0 58 | gradio_client==1.5.4 59 | grpcio==1.68.1 60 | h11==0.14.0 61 | httpcore==1.0.7 62 | httpx==0.28.1 63 | huggingface-hub==0.26.1 64 | icecream==2.1.3 65 | identify==2.6.1 66 | idna==3.10 67 | imageio==2.36.0 68 | importlib_metadata==8.5.0 69 | iniconfig==2.0.0 70 | invoke==2.2.0 71 | iopath==0.1.10 72 | ipython==8.29.0 73 | jedi==0.19.2 74 | Jinja2==3.1.4 75 | jmespath==0.10.0 76 | joblib==1.4.2 77 | jsonschema==4.23.0 78 | jsonschema-specifications==2024.10.1 79 | kiwisolver==1.4.7 80 | lazy_loader==0.4 81 | lightning-utilities==0.11.8 82 | llvmlite==0.43.0 83 | lmdb==1.5.1 84 | loguru==0.7.2 85 | lpips==0.1.4 86 | Markdown==3.7 87 | markdown-it-py==3.0.0 88 | MarkupSafe==2.1.5 89 | matplotlib==3.9.2 90 | matplotlib-inline==0.1.7 91 | mdurl==0.1.2 92 | mediapy==1.2.2 93 | memory-profiler==0.61.0 94 | mmengine==0.10.5 95 | mpmath==1.3.0 96 | msgpack==1.1.0 97 | multidict==6.1.0 98 | mypy-extensions==1.0.0 99 | networkx==3.4.2 100 | ninja==1.11.1.1 101 | nodeenv==1.9.1 102 | numba==0.60.0 103 | numpy==1.26.4 104 | nvidia-cublas-cu12==12.1.3.1 105 | nvidia-cuda-cupti-cu12==12.1.105 106 | nvidia-cuda-nvrtc-cu12==12.1.105 107 | nvidia-cuda-runtime-cu12==12.1.105 108 | nvidia-cudnn-cu12==8.9.2.26 109 | nvidia-cufft-cu12==11.0.2.54 110 | nvidia-curand-cu12==10.3.2.106 111 | nvidia-cusolver-cu12==11.4.5.107 112 | nvidia-cusparse-cu12==12.1.0.106 113 | nvidia-nccl-cu12==2.20.5 114 | nvidia-nvjitlink-cu12==12.4.127 115 | nvidia-nvtx-cu12==12.1.105 116 | omegaconf==2.3.0 117 | onnx==1.17.0 118 | openai-clip==1.0.1 119 | opencv-python==4.10.0.84 120 | opencv-python-headless==4.10.0.84 121 | opencv-transforms==0.0.6 122 | orjson==3.10.15 123 | oss2==2.19.1 124 | packaging==24.1 125 | pandas==2.2.3 126 | paramiko==3.5.0 127 | parso==0.8.4 128 | peft==0.13.2 129 | pexpect==4.9.0 130 | pillow==11.0.0 131 | platformdirs==4.3.6 132 | pluggy==1.5.0 133 | plumbum==1.9.0 134 | portalocker==3.0.0 135 | pre_commit==4.0.1 136 | prompt_toolkit==3.0.48 137 | propcache==0.2.0 138 | protobuf==5.28.3 139 | psutil==6.1.0 140 | ptyprocess==0.7.0 141 | pure_eval==0.2.3 142 | pyarrow==18.1.0 143 | pycparser==2.22 144 | pycryptodome==3.21.0 145 | pydantic==2.9.2 146 | pydantic_core==2.23.4 147 | pydub==0.25.1 148 | Pygments==2.18.0 149 | pyiqa==0.1.13 
150 | PyNaCl==1.5.0 151 | pyparsing==3.2.0 152 | pytest==8.3.4 153 | python-dateutil==2.9.0.post0 154 | python-multipart==0.0.20 155 | pytorch-fid==0.3.0 156 | pytorch-lightning==2.4.0 157 | pytz==2024.2 158 | PyYAML==6.0.2 159 | ray==2.38.0 160 | referencing==0.35.1 161 | regex==2024.9.11 162 | requests==2.32.3 163 | rich==13.9.4 164 | rotary-embedding-torch==0.8.4 165 | rpds-py==0.20.1 166 | rpyc==6.0.0 167 | ruff==0.9.2 168 | safehttpx==0.1.6 169 | safetensors==0.4.5 170 | scikit-image==0.24.0 171 | scikit-learn==1.5.2 172 | scipy==1.14.1 173 | semantic-version==2.10.0 174 | sentencepiece==0.2.0 175 | shellingham==1.5.4 176 | six==1.16.0 177 | sniffio==1.3.1 178 | soupsieve==2.6 179 | stack-data==0.6.3 180 | starlette==0.41.2 181 | sympy==1.13.1 182 | tabulate==0.9.0 183 | tensorboard==2.18.0 184 | tensorboard-data-server==0.7.2 185 | termcolor==2.5.0 186 | threadpoolctl==3.5.0 187 | tifffile==2024.9.20 188 | timm==1.0.11 189 | tokenizers==0.15.2 190 | tomli==2.0.2 191 | tomlkit==0.13.2 192 | torch==2.3.0 193 | torch-dist @ http://visionai-wlcb.oss-cn-wulanchabu.aliyuncs.com/library/whl/torch_dist-1.2.9-py3-none-any.whl#sha256=92ae7dfce3b1adf1c02f041394c382aadc649247f397561db048a111ed7d7a81 194 | torch-fidelity==0.3.0 195 | torchaudio==2.3.0 196 | torchmetrics==1.5.1 197 | torchvision==0.18.0 198 | tqdm==4.66.5 199 | traitlets==5.14.3 200 | transformers==4.37.2 201 | triton==2.3.0 202 | typed-argument-parser==1.10.1 203 | typer==0.15.1 204 | typing-inspect==0.9.0 205 | typing_extensions==4.12.2 206 | tzdata==2024.2 207 | urllib3==2.2.3 208 | uuid==1.30 209 | uvicorn==0.29.0 210 | virtualenv==20.27.1 211 | wcwidth==0.2.13 212 | websockets==14.2 213 | Werkzeug==3.1.3 214 | wrapt==1.16.0 215 | xformers==0.0.26.post1 216 | yacs==0.1.8 217 | yapf==0.40.2 218 | yarl==1.17.1 219 | zipp==3.20.2 220 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Conditioned Diffusion-based Video Tokenizer (CDT) 2 | 3 | Publication 4 | PRs 5 | License 6 | Stars 7 |
8 | 9 | Official implementation for our paper: 10 | 11 | **Rethinking Video Tokenization: A Conditioned Diffusion-based Approach** 12 | 13 | Author List: Nianzu Yang, Pandeng Li, Liming Zhao, Yang Li, Chen-Wei Xie, Yehui Tang, Xudong Lu, Zhihang Liu, Yun Zheng, Yu Liu, Junchi Yan* 14 | 15 | Equal contribution; * Corresponding author 16 | 17 | 18 | # Content 19 | 20 | - [Folder Specification](#folder-specification) 21 | - [Preparation](#preparation) 22 | - [Environment Setup](#environment-setup) 23 | - [Download Pre-trained Models](#download-pre-trained-models) 24 | - [Prepare Data](#prepare-data) 25 | - [Evaluation](#evaluation) 26 | - [Citation](#citation) 27 | - [Acknowledgement](#acknowledgement) 28 | - [Contact](#contact) 29 | 30 | 31 | ## Folder Specification 32 | ```bash 33 | ├── evaluate.py # script for evaluating the performance of CDT on reconstruction task 34 | ├── model # directory of CDT model 35 | │ └── cdt.py # definition of CDT model 36 | ├── opensora_evaluate # scripts for metrics calculation 37 | │ ├── cal_lpips.py # calculate LPIPS 38 | │ ├── cal_psnr.py # calculate PSNR 39 | │ └── cal_ssim.py # calculate SSIM 40 | ├── pretrained # directory of pretrained models, which you should create by yourself 41 | │ ├── cdt_base.ckpt # CDT-base 42 | │ └── cdt_small.ckpt # CDT-small 43 | ├── README.md # README 44 | ├── rec_image_eval.py # script for evaluating the performance of CDT on image reconstruction task 45 | ├── rec_video_eval.py # script for evaluating the performance of CDT on video reconstruction task 46 | ├── requirements.txt # dependencies 47 | └── utils.py # utility functions 48 | ``` 49 | 50 | ## Preparation 51 | 52 | ### Environment Setup 53 | 54 | You can create a new environment and install the dependencies by running the following command: 55 | ```shell 56 | conda create -n cdt python=3.10 57 | conda activate cdt 58 | pip install -r requirements.txt 59 | ``` 60 | 61 | ### Download Pre-trained Models 62 | 63 | We provide the pre-trained models, i.e., CDT-base and CDT-small, on [Hugging Face](https://huggingface.co/yangnianzu/CDT). You can download them and put them in the `pretrained` folder. 64 | 65 | ### Prepare Data 66 | 67 | In our paper, we use two datasets for benchmarking the reconstruction performance: 68 | 69 | - `COCO2017-val` for image reconstruction 70 | - `Webvid-val` for video reconstruction 71 | 72 | You can download these two datasets and put them in the 'data' folder. Next, you need to specify the path of these two datasets in the 'evaluate.py' file. 
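For reference, these are the two constants defined near the top of `evaluate.py`; the values below are placeholders assuming the `data` folder layout described above:

```python
# evaluate.py (dataset path constants)
WEBVID_PATH = './data/webvid/val'     # directory holding the WebVid-val videos
COCO17_PATH = './data/coco/val2017'   # directory holding the COCO2017-val images
```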
73 | 74 | ## Evaluation 75 | 76 | Evaluate the performance of CDT-base on image reconstruction: 77 | ```shell 78 | python evaluate.py --method CDT-base --dataset coco17 --mode image 79 | ``` 80 | 81 | Evaluate the performance of CDT-base on video reconstruction at the 256x256 resolution: 82 | 83 | ```shell 84 | python evaluate.py --method CDT-base --dataset webvid --mode video 85 | ``` 86 | 87 | Evaluate the performance of CDT-base on video reconstruction at the 720x720 resolution: 88 | 89 | ```shell 90 | python evaluate.py --method CDT-base --dataset webvid --mode video --size 720 91 | ``` 92 | 93 | --- 94 | 95 | Evaluate the performance of CDT-small on image reconstruction: 96 | ```shell 97 | python evaluate.py --method CDT-small --dataset coco17 --mode image 98 | ``` 99 | 100 | Evaluate the performance of CDT-small on video reconstruction at the 256x256 resolution: 101 | ```shell 102 | python evaluate.py --method CDT-small --dataset webvid --mode video 103 | ``` 104 | 105 | Evaluate the performance of CDT-small on video reconstruction at the 720x720 resolution: 106 | ```shell 107 | python evaluate.py --method CDT-small --dataset webvid --mode video --size 720 108 | ``` 109 | 110 | The reconstructed images or videos and the evaluation results will be saved in the `reconstructed_results` folder. 111 | 112 | ## Citation 113 | If you find this work useful in your research, please consider citing: 114 | 115 | ```bibtex 116 | @article{yang2025rethinking, 117 | title={Rethinking Video Tokenization: A Conditioned Diffusion-based Approach}, 118 | author={Yang, Nianzu and Li, Pandeng and Zhao, Liming and Li, Yang and Xie, Chen-Wei and Tang, Yehui and Lu, Xudong and Liu, Zhihang and Zheng, Yun and Liu, Yu and Yan, Junchi}, 119 | journal={arXiv preprint arXiv:2503.03708}, 120 | year={2025} 121 | } 122 | ``` 123 | 124 | ## Acknowledgement 125 | 126 | We would like to thank the authors of [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) for their excellent work, which provides the code for the evaluation metrics. 127 | 128 | ## Contact 129 | Welcome to contact us [yangnianzu@sjtu.edu.cn](mailto:yangnianzu@sjtu.edu.cn) for any question. 
-------------------------------------------------------------------------------- /rec_image_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | import sys 8 | from torch.utils.data import DataLoader, Subset 9 | import os 10 | sys.path.append(".") 11 | from utils import AverageMeter, custom_to_images, SimpleImageDataset 12 | from opensora_evaluate.cal_lpips import calculate_lpips 13 | from opensora_evaluate.cal_psnr import calculate_psnr 14 | from opensora_evaluate.cal_ssim import calculate_ssim 15 | import time 16 | from model.cdt import load_cdt 17 | 18 | 19 | @torch.no_grad() 20 | def main(args: argparse.Namespace): 21 | real_data_dir = args.real_data_dir 22 | dataset = args.dataset 23 | device = args.device 24 | batch_size = args.batch_size 25 | num_workers = 4 26 | subset_size = args.subset_size 27 | 28 | if args.data_type == "bfloat16": 29 | data_type = torch.bfloat16 30 | elif args.data_type == "float32": 31 | data_type = torch.float32 32 | else: 33 | raise ValueError(f"Invalid data type: {args.data_type}") 34 | 35 | folder_name = f"{args.method}_{args.data_type}" 36 | 37 | 38 | generated_images_dir = os.path.join('./reconstructed_results/image_results/', dataset, folder_name) 39 | metrics_results = os.path.join('./reconstructed_results/image_results/', dataset, 'results.txt') 40 | 41 | 42 | if not os.path.exists(generated_images_dir): 43 | os.makedirs(generated_images_dir) 44 | 45 | # ---- Load Model ---- 46 | device = args.device 47 | assert 'CDT' in args.method, f"method must be CDT, but got {args.method}" 48 | if 'base' in args.method: 49 | print(f"Loading CDT-base") 50 | vae = load_cdt('base') 51 | print(f"CDT-base Loaded") 52 | elif 'small' in args.method: 53 | print(f"Loading CDT-small") 54 | vae = load_cdt('small') 55 | print(f"CDT-small Loaded") 56 | vae = vae.to(device).to(data_type).eval() 57 | model_size = sum([p.numel() for p in vae.parameters()]) / 1e6 58 | print(f'Successfully loaded {args.method} model with {model_size:.3f} million parameters') 59 | # ---- Load Model ---- 60 | 61 | # ---- Prepare Dataset ---- 62 | dataset = SimpleImageDataset(image_dir=real_data_dir) 63 | print(f"Total images found: {len(dataset)}") 64 | if subset_size: 65 | indices = range(subset_size) 66 | dataset = Subset(dataset, indices=indices) 67 | dataloader = DataLoader( 68 | dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers 69 | ) 70 | # ---- Prepare Dataset 71 | 72 | # ---- Inference ---- 73 | avg_ssim = AverageMeter() 74 | avg_psnr = AverageMeter() 75 | avg_lpips = AverageMeter() 76 | 77 | log_txt = os.path.join(generated_images_dir, 'results.txt') 78 | 79 | step = 0 80 | total_time = 0 81 | total_images = 0 82 | 83 | with open(log_txt, 'a+') as f: 84 | for batch in tqdm(dataloader): 85 | step += 1 86 | x, file_names = batch['image'], batch['file_name'] 87 | original_width = batch['original_width'][0] 88 | original_height = batch['original_height'][0] 89 | torch.cuda.empty_cache() 90 | 91 | x = x.to(device=device, dtype=data_type) 92 | x=x.unsqueeze(2) 93 | start_time = time.time() 94 | video_recon = vae(x) 95 | torch.cuda.synchronize() 96 | end_time = time.time() 97 | total_time += end_time - start_time 98 | total_images += 1 99 | 100 | x, video_recon = x.data.cpu().float(), video_recon.data.cpu().float() 101 | 102 | 103 | if not os.path.exists(generated_images_dir): 104 | 
os.makedirs(generated_images_dir, exist_ok=True) 105 | 106 | video_recon = video_recon.squeeze(2) 107 | for idx, image_recon in enumerate(video_recon): 108 | output_file = os.path.join(generated_images_dir, file_names[idx]) 109 | custom_to_images(image_recon,output_file,original_height,original_width) 110 | 111 | 112 | video_recon = video_recon.unsqueeze(2) 113 | x = torch.clamp(x, -1, 1) 114 | x = (x + 1) / 2 115 | video_recon = torch.clamp(video_recon, -1, 1) 116 | video_recon = (video_recon + 1) / 2 117 | x = x.permute(0,2,1,3,4).float() 118 | video_recon = video_recon.permute(0,2,1,3,4).float() 119 | 120 | # SSIM 121 | tmp_list = list(calculate_ssim(x, video_recon)['value'].values()) 122 | avg_ssim.updata(np.mean(tmp_list)) 123 | 124 | # PSNR 125 | tmp_list = list(calculate_psnr(x, video_recon)['value'].values()) 126 | avg_psnr.updata(np.mean(tmp_list)) 127 | 128 | # LPIPS 129 | tmp_list = list(calculate_lpips(x, video_recon, args.device)['value'].values()) 130 | avg_lpips.updata(np.mean(tmp_list)) 131 | 132 | if step % args.log_every_steps ==0: 133 | result = ( 134 | f'Step: {step}, PSNR: {avg_psnr.avg}\n' 135 | f'Step: {step}, SSIM: {avg_ssim.avg}\n' 136 | f'Step: {step}, LPIPS: {avg_lpips.avg}\n') 137 | print(result, flush=True) 138 | f.write("="*20+'\n') 139 | f.write(result) 140 | 141 | final_result = (f'psnr: {avg_psnr.avg}\n' 142 | f'ssim: {avg_ssim.avg}\n' 143 | f'lpips: {avg_lpips.avg}') 144 | 145 | print("="*20) 146 | print("Final Results:") 147 | print(final_result) 148 | print("="*20) 149 | print(f'Eval Info:\nmethod: {args.method}\nreal_data_dir: {args.real_data_dir}') 150 | print("="*20) 151 | 152 | 153 | 154 | with open(metrics_results, 'a') as f: 155 | f.write("="*20+'\n') 156 | f.write(f'PSNR: {avg_psnr.avg}\n') 157 | f.write(f'SSIM: {avg_ssim.avg}\n') 158 | f.write(f'LPIPS: {avg_lpips.avg}\n') 159 | f.write(f'Time: {total_time}\n') 160 | f.write(f'Images Number: {total_images}\n') 161 | f.write(f'Avg Time: {total_time/total_images:.4f}\n') 162 | f.write(f'Method: {args.method}\n') 163 | f.write(f'Real Data Dir: {args.real_data_dir}\n') 164 | f.write(f'Data Type: {data_type}\n') 165 | f.write("="*20+'\n\n') 166 | # ---- Inference ---- 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--real_data_dir", type=str) 171 | parser.add_argument("--dataset", type=str, default='coco17') 172 | parser.add_argument("--method", type=str) 173 | parser.add_argument("--batch_size", type=int, default=1) 174 | parser.add_argument("--num_workers", type=int, default=8) 175 | parser.add_argument("--subset_size", type=int, default=100) 176 | parser.add_argument("--device", type=str, default="cuda") 177 | parser.add_argument("--data_type", type=str, default="float32", choices=["float32", "bfloat16"]) 178 | parser.add_argument("--log_every_steps", type=int, default=50) 179 | args = parser.parse_args() 180 | main(args) 181 | 182 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import numpy.typing as npt 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from decord import VideoReader, cpu 10 | 11 | 12 | class AverageMeter(object): 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0.0 18 | self.avg = 0.0 19 | self.sum = 0.0 20 | self.cnt = 0.0 21 | 22 
| def updata(self, val, n=1.0): 23 | self.val = val 24 | self.sum += val * n 25 | self.cnt += n 26 | if self.cnt == 0: 27 | self.avg = 1 28 | else: 29 | self.avg = self.sum / self.cnt 30 | 31 | 32 | # =============================== Image =============================== 33 | 34 | class Image_Crop(object): 35 | def __init__(self): 36 | pass 37 | def __call__(self, img): 38 | iw, ih = img.size 39 | ow = (iw // 8) * 8 40 | oh = (ih // 8) * 8 41 | return img.crop((0, 0, ow, oh)) 42 | 43 | def custom_to_images(x, output_dir, h_i, h_w): 44 | x = x.detach().cpu() 45 | x = torch.clamp(x, -1, 1) 46 | x = (x + 1) / 2 47 | x = x.permute(1, 2, 0).float() 48 | x = (x.numpy() * 255).round().astype(np.uint8) 49 | 50 | img = Image.fromarray(x, mode='RGB') 51 | img = T.Resize((h_i, h_w))(img) 52 | img.save(output_dir) 53 | 54 | class SimpleImageDataset(Dataset): 55 | def __init__(self, image_dir): 56 | 57 | self.image_dir = image_dir 58 | self.image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] 59 | self.transform = T.Compose([ 60 | Image_Crop(), 61 | T.ToTensor(), 62 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 63 | ]) 64 | 65 | def __len__(self): 66 | return len(self.image_files) 67 | 68 | def __getitem__(self, idx): 69 | img_name = self.image_files[idx] 70 | img_path = os.path.join(self.image_dir, img_name) 71 | 72 | image = Image.open(img_path).convert('RGB') 73 | original_width, original_height = image.size 74 | image = self.transform(image) 75 | 76 | return { 77 | 'image': image, 78 | 'file_name': img_name, 79 | 'original_width': original_width, 80 | 'original_height': original_height 81 | } 82 | 83 | # =============================== Video =============================== 84 | 85 | class CenterCrop(object): 86 | 87 | def __init__(self, size): 88 | self.size = size 89 | 90 | def __call__(self, img): 91 | # resize 92 | iw, ih, ow, oh = *img.size, *self.size 93 | scale = max(ow / iw, oh / ih) 94 | img = img.resize((round(scale * iw), round(scale * ih)), Image.LANCZOS) 95 | 96 | # center crop 97 | w, h = img.size 98 | if w > ow: 99 | x1 = (w - ow) // 2 100 | img = img.crop((x1, 0, x1 + ow, oh)) 101 | elif h > oh: 102 | y1 = (h - oh) // 2 103 | img = img.crop((0, y1, ow, y1 + oh)) 104 | return img 105 | 106 | 107 | def array_to_video( 108 | image_array: npt.NDArray, fps: float = 30.0, path: str = "output_video.mp4" 109 | ) -> None: 110 | frame_dir = path.replace('&', '_').replace('.mp4', '_temp') 111 | os.makedirs(frame_dir, exist_ok=True) 112 | for fid, frame in enumerate(image_array): 113 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) 114 | cv2.imwrite(tpth, frame[:,:,::-1]) 115 | 116 | 117 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {path}' 118 | os.system(cmd) 119 | os.system(f'rm -rf {frame_dir}') 120 | 121 | 122 | def custom_to_video( 123 | x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4" 124 | ) -> None: 125 | x = x.detach().cpu() 126 | x = torch.clamp(x, -1, 1) 127 | x = (x + 1) / 2 128 | x = x.permute(1, 2, 3, 0).float().numpy() 129 | x = (255 * x).astype(np.uint8) 130 | array_to_video(x, fps=fps, path=output_file) 131 | # breakpoint() 132 | return 133 | 134 | 135 | class RealVideoDataset(Dataset): 136 | def __init__( 137 | self, 138 | real_data_dir, 139 | num_frames, 140 | sample_rate=1, 141 | crop_size=None, 142 | resolution=128, 143 | ) -> None: 144 | super().__init__() 145 | self.real_video_files = 
self._combine_without_prefix(real_data_dir) 146 | self.num_frames = num_frames 147 | self.sample_rate = sample_rate 148 | self.crop_size = crop_size 149 | 150 | def __len__(self): 151 | return len(self.real_video_files) 152 | 153 | def __getitem__(self, index): 154 | if index >= len(self): 155 | raise IndexError 156 | real_video_file = self.real_video_files[index] 157 | real_video_tensor = self._load_video(real_video_file) 158 | video_name = os.path.basename(real_video_file) 159 | return {'video': real_video_tensor, 'file_name': video_name } 160 | 161 | def _load_video(self, video_path): 162 | num_frames = self.num_frames 163 | sample_rate = self.sample_rate 164 | decord_vr = VideoReader(video_path, ctx=cpu(0)) 165 | total_frames = len(decord_vr) 166 | sample_frames_len = sample_rate * num_frames 167 | 168 | if total_frames > sample_frames_len: 169 | s = 0 170 | e = s + sample_frames_len 171 | num_frames = num_frames 172 | else: 173 | s = 0 174 | e = total_frames 175 | num_frames = int(total_frames / sample_frames_len * num_frames) 176 | print( 177 | f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}", 178 | video_path, 179 | total_frames, 180 | ) 181 | 182 | frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) 183 | video_data = decord_vr.get_batch(frame_id_list).asnumpy() 184 | frames = [] 185 | for frame_ in range(video_data.shape[0]): 186 | frame_i = video_data[frame_,:,:,:] 187 | frame_i = Image.fromarray(frame_i) 188 | if frame_i.mode != 'RGB': 189 | frame_i = frame_i.convert('RGB') 190 | frame_i = _preprocess(frame_i, crop_size=self.crop_size) 191 | frames.append(frame_i) 192 | frames = torch.stack(frames, dim=1) 193 | return frames 194 | 195 | def _combine_without_prefix(self, folder_path, prefix="."): 196 | folder = [] 197 | for name in os.listdir(folder_path): 198 | if name[0] == prefix: 199 | continue 200 | folder.append(os.path.join(folder_path, name)) 201 | folder.sort() 202 | return folder 203 | 204 | 205 | def _preprocess(video_data, crop_size=None): 206 | transform=T.Compose([ 207 | CenterCrop([crop_size,crop_size]), 208 | T.ToTensor(), 209 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 210 | ]) 211 | video_outputs = transform(video_data) 212 | return video_outputs -------------------------------------------------------------------------------- /rec_video_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | import numpy as np 7 | import numpy.typing as npt 8 | import torch 9 | from torch.utils.data import DataLoader, Subset 10 | sys.path.append(".") 11 | from utils import * 12 | from opensora_evaluate.cal_lpips import calculate_lpips 13 | from opensora_evaluate.cal_psnr import calculate_psnr 14 | from opensora_evaluate.cal_ssim import calculate_ssim 15 | import time 16 | from model.cdt import load_cdt 17 | 18 | 19 | @torch.no_grad() 20 | def main(args: argparse.Namespace): 21 | real_data_dir = args.real_data_dir 22 | dataset = args.dataset 23 | sample_rate = args.sample_rate 24 | resolution = args.resolution 25 | crop_size = args.crop_size 26 | num_frames = args.num_frames 27 | sample_rate = args.sample_rate 28 | device = args.device 29 | sample_fps = args.sample_fps 30 | batch_size = args.batch_size 31 | num_workers = args.num_workers 32 | subset_size = args.subset_size 33 | 34 | if args.data_type == "bfloat16": 35 | data_type = torch.bfloat16 36 | elif args.data_type == "float32": 37 | 
data_type = torch.float32 38 | else: 39 | raise ValueError(f"Invalid data type: {args.data_type}") 40 | 41 | 42 | folder_name = f"{args.method}_{args.resolution}_{args.data_type}" 43 | 44 | 45 | generated_video_dir = os.path.join('./reconstructed_results/video_results/', dataset, folder_name) 46 | metrics_results = os.path.join('./reconstructed_results/video_results/', dataset, 'results.txt') 47 | 48 | 49 | if not os.path.exists(generated_video_dir): 50 | os.makedirs(generated_video_dir) 51 | 52 | 53 | # ---- Load Model ---- 54 | device = args.device 55 | assert 'CDT' in args.method, f"method must be CDT, but got {args.method}" 56 | if 'base' in args.method: 57 | print(f"Loading CDT-base") 58 | vae = load_cdt('base') 59 | print(f"CDT-base Loaded") 60 | elif 'small' in args.method: 61 | print(f"Loading CDT-small") 62 | vae = load_cdt('small') 63 | print(f"CDT-small Loaded") 64 | vae = vae.to(device).to(data_type).eval() 65 | model_size = sum([p.numel() for p in vae.parameters()]) / 1e6 66 | print(f'Successfully loaded {args.method} model with {model_size:.3f} million parameters') 67 | # ---- Load Model ---- 68 | 69 | 70 | # ---- Prepare Dataset ---- 71 | dataset = RealVideoDataset( 72 | real_data_dir=real_data_dir, 73 | num_frames=num_frames, 74 | sample_rate=sample_rate, 75 | crop_size=crop_size, 76 | resolution=resolution, 77 | ) 78 | if subset_size: 79 | indices = range(subset_size) 80 | dataset = Subset(dataset, indices=indices) 81 | dataloader = DataLoader( 82 | dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers 83 | ) 84 | # ---- Prepare Dataset 85 | 86 | 87 | 88 | # ---- Inference ---- 89 | 90 | avg_ssim = AverageMeter() 91 | avg_psnr = AverageMeter() 92 | avg_lpips = AverageMeter() 93 | 94 | log_txt = os.path.join(generated_video_dir, 'results.txt') 95 | 96 | total_time = 0 97 | total_videos = 0 98 | step = 0 99 | 100 | with open(log_txt, 'a+') as f: 101 | for batch in tqdm(dataloader): 102 | step += 1 103 | x, file_names = batch['video'], batch['file_name'] 104 | if x.size(2) < args.num_frames: 105 | print(file_names) 106 | continue 107 | torch.cuda.empty_cache() 108 | x = x.to(device=device, dtype=data_type) 109 | 110 | start_time = time.time() 111 | video_recon = vae(x) 112 | torch.cuda.synchronize() 113 | end_time = time.time() 114 | total_time += end_time - start_time 115 | total_videos += 1 116 | 117 | 118 | x, video_recon = x.data.cpu().float(), video_recon.data.cpu().float() 119 | 120 | # save reconstructed video 121 | if not os.path.exists(generated_video_dir): 122 | os.makedirs(generated_video_dir, exist_ok=True) 123 | for idx, video in enumerate(video_recon): 124 | output_path = os.path.join(generated_video_dir, file_names[idx]) 125 | custom_to_video( 126 | video, fps=sample_fps / sample_rate, output_file=output_path 127 | ) 128 | 129 | x = torch.clamp(x, -1, 1) 130 | x = (x + 1) / 2 131 | video_recon = torch.clamp(video_recon, -1, 1) 132 | video_recon = (video_recon + 1) / 2 133 | 134 | x = x.permute(0,2,1,3,4).float() 135 | video_recon = video_recon.permute(0,2,1,3,4).float() 136 | 137 | # SSIM 138 | tmp_list = list(calculate_ssim(x, video_recon)['value'].values()) 139 | avg_ssim.updata(np.mean(tmp_list)) 140 | 141 | # PSNR 142 | tmp_list = list(calculate_psnr(x, video_recon)['value'].values()) 143 | avg_psnr.updata(np.mean(tmp_list)) 144 | 145 | # LPIPS 146 | tmp_list = list(calculate_lpips(x, video_recon, args.device)['value'].values()) 147 | avg_lpips.updata(np.mean(tmp_list)) 148 | 149 | if step % args.log_every_steps ==0: 150 | 
result = ( 151 | f'Step: {step}, PSNR: {avg_psnr.avg}\n' 152 | f'Step: {step}, SSIM: {avg_ssim.avg}\n' 153 | f'Step: {step}, LPIPS: {avg_lpips.avg}\n') 154 | print(result, flush=True) 155 | f.write("="*20+'\n') 156 | f.write(result) 157 | 158 | 159 | final_result = (f'psnr: {avg_psnr.avg}\n' 160 | f'ssim: {avg_ssim.avg}\n' 161 | f'lpips: {avg_lpips.avg}') 162 | 163 | print("="*20) 164 | print("Final Results:") 165 | print(final_result) 166 | print("="*20) 167 | print(f'Eval Info:\nmethod: {args.method}\nresolution: {args.resolution}\nnum_frames: {args.num_frames}\nreal_data_dir: {args.real_data_dir}') 168 | print("="*20) 169 | 170 | with open(metrics_results, 'a') as f: 171 | f.write("="*20+'\n') 172 | f.write(f'PSNR: {avg_psnr.avg}\n') 173 | f.write(f'SSIM: {avg_ssim.avg}\n') 174 | f.write(f'LPIPS: {avg_lpips.avg}\n') 175 | f.write(f'Time: {total_time}\n') 176 | f.write(f'Images Number: {total_videos}\n') 177 | f.write(f'Avg Time: {total_time/total_videos:.4f}\n') 178 | f.write(f'Method: {args.method}\n') 179 | f.write(f'Resolution: {args.resolution}\n') 180 | f.write(f'Num Frames: {args.num_frames}\n') 181 | f.write(f'Real Data Dir: {args.real_data_dir}\n') 182 | f.write(f'Data Type: {data_type}\n') 183 | f.write("="*20+'\n\n') 184 | # ---- Inference ---- 185 | 186 | 187 | if __name__ == "__main__": 188 | 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--real_data_dir", type=str) 191 | parser.add_argument("--dataset", type=str, default='webvid') 192 | parser.add_argument("--method", type=str) 193 | parser.add_argument("--sample_fps", type=int, default=10) 194 | parser.add_argument("--resolution", type=int, default=256) 195 | parser.add_argument("--crop_size", type=int, default=256) 196 | parser.add_argument("--log_every_steps", type=int, default=50) 197 | parser.add_argument("--num_frames", type=int, default=17) # number of frames for video evaluation 198 | parser.add_argument("--sample_rate", type=int, default=1) 199 | parser.add_argument("--batch_size", type=int, default=1) 200 | parser.add_argument("--num_workers", type=int, default=8) 201 | parser.add_argument("--subset_size", type=int, default=0) 202 | parser.add_argument("--device", type=str, default="cuda") 203 | parser.add_argument("--data_type", type=str, default="float32", choices=["float32", "bfloat16"]) 204 | 205 | args = parser.parse_args() 206 | main(args) 207 | 208 | -------------------------------------------------------------------------------- /model/cdt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import logging 4 | import numpy as np 5 | from functools import partial 6 | from einops import rearrange, repeat, pack, unpack 7 | from typing import Any, Dict, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | CLIP_FIRST_CHUNK = False 15 | CACHE_T = 2 16 | INT_MAX = 2**31 17 | 18 | 19 | class CausalConv3d(nn.Conv3d): 20 | """ 21 | Causal 3d convolusion. 
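Temporal padding is applied only on the past side (2 * padding[0] frames before the clip, none after), so each output frame depends only on the current and earlier input frames.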
22 | """ 23 | PAD_MODE = "replicate" 24 | 25 | def __init__(self, *args, **kwargs): 26 | super().__init__(*args, **kwargs) 27 | self._padding = ( 28 | self.padding[2], 29 | self.padding[2], 30 | self.padding[1], 31 | self.padding[1], 32 | 2 * self.padding[0], 33 | 0 34 | ) 35 | arg_list = list(args) 36 | if len(arg_list)>=5: 37 | arg_list[4] = 0 38 | elif 'padding' in kwargs: 39 | kwargs['padding']= 0 40 | super().__init__(*arg_list, **kwargs) 41 | nn.init.zeros_(self.weight) 42 | if self.bias is not None: 43 | nn.init.zeros_(self.bias) 44 | 45 | def forward(self, x, cache_x=None): 46 | padding = list(self._padding) 47 | if cache_x is not None and self._padding[4] > 0: 48 | cache_x = cache_x.to(x.device) 49 | x = torch.cat([cache_x, x], dim=2) 50 | padding[4] -= cache_x.shape[2] 51 | x = F.pad(x, padding, mode=self.PAD_MODE) 52 | 53 | if x.numel() > INT_MAX: 54 | t = x.shape[2] 55 | kernel_t = self.kernel_size[0] 56 | num_split = max(1, t - kernel_t + 1) 57 | out_list = [] 58 | for i in range(num_split): 59 | x_s = x[:, :, i:i+kernel_t, :, :] 60 | out_list.append(super().forward(x_s)) 61 | out = torch.cat(out_list, dim=2) 62 | del out_list 63 | else: 64 | out = super().forward(x) 65 | 66 | return out 67 | 68 | 69 | class Residual(nn.Module): 70 | def __init__(self, fn): 71 | super().__init__() 72 | self.fn = fn 73 | 74 | def forward(self, x, *args, **kwargs): 75 | return self.fn(x, *args, **kwargs) + x 76 | 77 | 78 | class Fp32Upsample(nn.Upsample): 79 | 80 | def forward(self, x): 81 | """ 82 | Fix bfloat16 support for nearest neighbor interpolation. 83 | """ 84 | return super().forward(x.float()).type_as(x) 85 | 86 | 87 | class Upsample(nn.Module): 88 | def __init__(self, dim_in, dim_out=None, new_upsample=False): 89 | super().__init__() 90 | if dim_out is None: 91 | dim_out = dim_in 92 | if new_upsample: 93 | self.up = Fp32Upsample(scale_factor=(1.0, 2.0, 2.0), mode='nearest-exact') 94 | self.conv = CausalConv3d(dim_in, dim_out, (1, 5, 5), padding=(0, 2, 2)) 95 | else: 96 | self.up = None 97 | self.conv = nn.ConvTranspose3d(dim_in, dim_out, (1, 5, 5), (1, 2, 2), (0, 2, 2), 98 | output_padding=(0, 1, 1)) 99 | 100 | def forward(self, x): 101 | if self.up is not None: 102 | x = self.up(x) 103 | x = self.conv(x) 104 | return x 105 | 106 | 107 | class Downsample(nn.Module): 108 | def __init__(self, dim_in, dim_out=None): 109 | super().__init__() 110 | if dim_out is None: 111 | dim_out = dim_in 112 | self.conv = CausalConv3d(dim_in, dim_out, (1, 3, 3), (1, 2, 2), (0, 1, 1)) 113 | 114 | def forward(self, x, feat_cache=None, feat_idx=[0]): 115 | x = self.conv(x) 116 | return x 117 | 118 | 119 | class TimeDownsample(nn.Module): 120 | def __init__( 121 | self, 122 | dim, 123 | dim_out = None, 124 | kernel_size = 3, 125 | antialias = False 126 | ): 127 | super().__init__() 128 | dim_out = dim_out or dim 129 | self.time_causal_padding = (kernel_size - 1, 0) 130 | self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2) 131 | nn.init.zeros_(self.conv.bias) 132 | nn.init.zeros_(self.conv.weight) 133 | self.conv.weight.data[range(dim_out), range(dim), -2] = 1.0 134 | self.conv.weight.data[range(dim_out), range(dim), -1] = 1.0 135 | self.cache_t = 1 136 | 137 | def forward(self, x, feat_cache=None, feat_idx=[0]): 138 | x = rearrange(x, 'b c t h w -> b h w c t').contiguous() 139 | x, ps = pack([x], '* c t') 140 | 141 | if feat_cache is not None: 142 | idx = feat_idx[0] 143 | cache_x = x[..., -self.cache_t:].clone() 144 | 145 | if feat_cache[idx] is not None and self.time_causal_padding[0] > 0: 146 | 
x = torch.cat([feat_cache[idx], x], dim=-1) 147 | else: 148 | x = F.pad(x, self.time_causal_padding, mode=CausalConv3d.PAD_MODE) 149 | out = self.conv(x) 150 | feat_cache[idx] = cache_x 151 | feat_idx[0] += 1 152 | else: 153 | x = F.pad(x, self.time_causal_padding, mode=CausalConv3d.PAD_MODE) 154 | out = self.conv(x) 155 | 156 | out = unpack(out, ps, '* c t')[0] 157 | out = rearrange(out, 'b h w c t -> b c t h w').contiguous() 158 | return out 159 | 160 | 161 | class TimeUpsample(nn.Module): 162 | def __init__( 163 | self, 164 | dim 165 | ): 166 | super().__init__() 167 | self.conv = nn.Conv1d(dim, dim*2, 1) 168 | self.init_conv_(self.conv) 169 | 170 | def init_conv_(self, conv): 171 | o, i, t = conv.weight.shape 172 | conv_weight = torch.zeros(o // 2, i, t) 173 | conv_weight[range(o//2), range(i), :] = 1.0 174 | conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...') 175 | 176 | conv.weight.data.copy_(conv_weight) 177 | nn.init.zeros_(conv.bias.data) 178 | 179 | def forward(self, x): 180 | x = rearrange(x, 'b c t h w -> b h w c t').contiguous() 181 | x, ps = pack([x], '* c t') 182 | 183 | out = self.conv(x) 184 | 185 | out = rearrange(out, 'b (c p) t -> b c (t p)', p = 2).contiguous() 186 | out = unpack(out, ps, '* c t')[0] 187 | out = rearrange(out, 'b h w c t -> b c t h w').contiguous() 188 | if CLIP_FIRST_CHUNK: 189 | out = out[:, :, 1:, :, :] 190 | return out 191 | 192 | 193 | class LayerNorm(nn.Module): 194 | def __init__(self, dim, eps=1e-5): 195 | super().__init__() 196 | self.eps = eps 197 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) 198 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1)) 199 | 200 | def forward(self, x): 201 | var = torch.var(x, dim=1, unbiased=False, keepdim=True) 202 | mean = torch.mean(x, dim=1, keepdim=True) 203 | return (x - mean) / (var + self.eps).sqrt() * self.g + self.b 204 | 205 | 206 | class PreNorm(nn.Module): 207 | def __init__(self, dim, fn): 208 | super().__init__() 209 | self.fn = fn 210 | self.norm = LayerNorm(dim) 211 | 212 | def forward(self, x, feat_cache=None, feat_idx=[0]): 213 | x = self.norm(x) 214 | return self.fn(x, feat_cache=feat_cache, feat_idx=feat_idx) 215 | 216 | 217 | class Block(nn.Module): 218 | def __init__(self, dim, dim_out, large_filter=False): 219 | super().__init__() 220 | self.block = nn.ModuleList([ 221 | CausalConv3d(dim, dim_out, (3, 7, 7) if large_filter else (3, 3, 3), 222 | padding=(1, 3, 3) if large_filter else (1, 1, 1)), 223 | LayerNorm(dim_out), nn.ReLU()] 224 | ) 225 | 226 | def forward(self, x, feat_cache=None, feat_idx=[0]): 227 | for layer in self.block: 228 | if isinstance(layer, CausalConv3d) and feat_cache is not None: 229 | idx = feat_idx[0] 230 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 231 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 232 | # cache last frame of last two chunk 233 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) 234 | x = layer(x, feat_cache[idx]) 235 | feat_cache[idx] = cache_x 236 | feat_idx[0] += 1 237 | else: 238 | x = layer(x) 239 | return x 240 | 241 | 242 | 243 | def module_forward(backbone, x, t=None, feat_cache=None, feat_idx=[0]): 244 | if t is None: 245 | if isinstance(backbone, CausalConv3d): 246 | if feat_cache is not None: 247 | idx = feat_idx[0] 248 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 249 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 250 | # cache last frame of last two chunk 251 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) 252 | 
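# feeding the cached tail frames lets CausalConv3d pad this chunk with real history from the previous chunk instead of replicate padding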
x = backbone(x, feat_cache[idx]) 253 | feat_cache[idx] = cache_x 254 | feat_idx[0] += 1 255 | else: 256 | x = backbone(x) 257 | elif type(backbone) in [Upsample, TimeUpsample, nn.Identity]: 258 | x = backbone(x) 259 | else: 260 | x = backbone(x, feat_cache=feat_cache, feat_idx=feat_idx) 261 | else: 262 | x = backbone(x, t, feat_cache=feat_cache, feat_idx=feat_idx) 263 | return x 264 | 265 | class ResnetBlock(nn.Module): 266 | def __init__(self, dim, dim_out, time_emb_dim=None, large_filter=False): 267 | super().__init__() 268 | self.mlp = ( 269 | nn.Sequential(nn.LeakyReLU(0.2), nn.Linear(time_emb_dim, dim_out)) 270 | if time_emb_dim is not None 271 | else None 272 | ) 273 | 274 | self.block1 = Block(dim, dim_out, large_filter) 275 | self.block2 = Block(dim_out, dim_out) 276 | self.res_conv = CausalConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 277 | 278 | def forward(self, x, time_emb=None, feat_cache=None, feat_idx=[0]): 279 | h = self.block1(x, feat_cache=feat_cache, feat_idx=feat_idx) 280 | 281 | if time_emb is not None: 282 | h = h + self.mlp(time_emb)[:, :, None, None, None] 283 | 284 | h = self.block2(h, feat_cache=feat_cache, feat_idx=feat_idx) 285 | 286 | return h + self.res_conv(x) 287 | 288 | 289 | class LinearAttention(nn.Module): 290 | def __init__(self, dim, heads=1, dim_head=None): 291 | super().__init__() 292 | if dim_head is None: 293 | dim_head = dim 294 | self.scale = dim_head ** -0.5 295 | self.heads = heads 296 | hidden_dim = dim_head * heads 297 | self.to_qkv = CausalConv3d(dim, hidden_dim * 3, 1, bias=False) 298 | self.to_out = CausalConv3d(hidden_dim, dim, 1) 299 | 300 | def forward(self, x, feat_cache=None, feat_idx=[0]): 301 | b, c, t, h, w = x.shape 302 | qkv = self.to_qkv(x).chunk(3, dim=1) 303 | q, k, v = map(lambda t: rearrange(t, "b (h c) t x y -> (b t) h c (x y)", h=self.heads), qkv) 304 | 305 | q = q.softmax(dim=-2) 306 | k = k.softmax(dim=-1) 307 | 308 | q = q * self.scale 309 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 310 | 311 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 312 | out = rearrange(out, "(b t) h c (x y) -> b (h c) t x y", h=self.heads, x=h, y=w, t=t) 313 | return self.to_out(out) 314 | 315 | 316 | def count_conv3d(model): 317 | count = 0 318 | for m in model.modules(): 319 | if isinstance(m, CausalConv3d): 320 | count += 1 321 | return count 322 | 323 | def count_time_down_sample(model): 324 | count = 0 325 | for m in model.modules(): 326 | if isinstance(m, TimeDownsample): 327 | count += 1 328 | return count 329 | 330 | 331 | class DiagonalGaussianDistribution(object): 332 | def __init__(self, parameters, deterministic=False): 333 | self.parameters = parameters 334 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 335 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 336 | self.deterministic = deterministic 337 | self.std = torch.exp(0.5 * self.logvar) 338 | self.var = torch.exp(self.logvar) 339 | if self.deterministic: 340 | self.var = self.std = torch.zeros_like(self.mean).to( 341 | device=self.parameters.device 342 | ) 343 | 344 | def sample(self): 345 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 346 | device=self.parameters.device 347 | ) 348 | return x 349 | 350 | def mode(self): 351 | return self.mean 352 | 353 | 354 | class Compressor(nn.Module): 355 | def __init__( 356 | self, 357 | dim=64, 358 | dim_mults=(1, 2, 3, 4), 359 | reverse_dim_mults=(4, 3, 2, 1), 360 | space_down=(1, 1, 1, 1), 361 | time_down=(0, 0, 1, 1), 362 | new_upsample=False, 363 | 
channels=3, 364 | out_channels=None, 365 | latent_dim=64, 366 | ): 367 | super().__init__() 368 | self.channels = channels 369 | out_channels = out_channels or channels 370 | self.space_down = space_down 371 | self.new_upsample = new_upsample 372 | self.reversed_space_down = list(reversed(self.space_down)) 373 | self.time_down = time_down 374 | self.reversed_time_down = list(reversed(self.time_down)) 375 | self.dims = [channels, *map(lambda m: dim * m, dim_mults)] 376 | self.in_out = list(zip(self.dims[:-1], self.dims[1:])) 377 | self.reversed_dims = [*map(lambda m: dim * m, reverse_dim_mults), out_channels] 378 | self.reversed_in_out = list(zip(self.reversed_dims[:-1], self.reversed_dims[1:])) 379 | assert self.dims[-1] == self.reversed_dims[0] 380 | latent_dim = latent_dim or out_channels 381 | self.quant_conv = torch.nn.Conv3d(self.dims[-1], 2 * latent_dim, 1) 382 | self.post_quant_conv = torch.nn.Conv3d(latent_dim, self.dims[-1], 1) 383 | self.quant_conv_res = nn.ModuleList( 384 | [ResnetBlock(dim_in, dim_out) for dim_in, dim_out in self.reversed_in_out[:-1]]+ 385 | [torch.nn.Conv3d(self.reversed_dims[-2], 2 * latent_dim, 1)]) 386 | self.post_quant_conv_res = nn.ModuleList( 387 | [torch.nn.Conv3d(latent_dim, self.dims[1], 1)]+ 388 | [ResnetBlock(dim_in, dim_out) for dim_in, dim_out in self.in_out[1:]]) 389 | # build network 390 | self.build_network() 391 | 392 | # cache the last two frame of feature map 393 | self._conv_num = count_conv3d(self.post_quant_conv_res) + count_conv3d(self.dec) 394 | self._conv_idx = [0] 395 | self._feat_map = [None] * self._conv_num 396 | # cache encode 397 | self._enc_conv_num = count_conv3d(self.quant_conv_res) + count_conv3d(self.enc) + count_time_down_sample(self.enc) 398 | self._enc_conv_idx = [0] 399 | self._enc_feat_map = [None] * self._enc_conv_num 400 | 401 | 402 | @property 403 | def dtype(self): 404 | return self.enc[0][0].block1.block[0].weight.dtype 405 | 406 | def build_network(self): 407 | self.enc = nn.ModuleList([]) 408 | self.dec = nn.ModuleList([]) 409 | 410 | for ind, (dim_in, dim_out) in enumerate(self.in_out): 411 | self.enc.append( 412 | nn.ModuleList( 413 | [ 414 | ResnetBlock(dim_in, dim_out, None, True if ind == 0 else False), 415 | Downsample(dim_out) if self.space_down[ind] else nn.Identity(), 416 | TimeDownsample(dim_out) if self.time_down[ind] else nn.Identity(), 417 | ] 418 | ) 419 | ) 420 | 421 | 422 | 423 | for ind, (dim_in, dim_out) in enumerate(self.reversed_in_out): 424 | is_last = ind >= (len(self.reversed_in_out) - 1) 425 | mapping_or_identity = ResnetBlock(dim_in, dim_out) if not self.reversed_space_down[ind] and is_last else nn.Identity() 426 | upsample_layer = Upsample(dim_out if not is_last else dim_in, dim_out, self.new_upsample) if self.reversed_space_down[ind] else mapping_or_identity 427 | self.dec.append( 428 | nn.ModuleList( 429 | [ 430 | ResnetBlock(dim_in, dim_out if not is_last else dim_in), 431 | upsample_layer, 432 | TimeUpsample(dim_out) if self.reversed_time_down[ind] else nn.Identity(), 433 | ] 434 | ) 435 | ) 436 | 437 | def encode(self, input, deterministic=True): 438 | self._enc_conv_idx = [0] 439 | 440 | input = input.to(self.dtype) 441 | for i, (resnet, down, time_down) in enumerate(self.enc): 442 | input = resnet(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) 443 | input = down(input) 444 | if isinstance(time_down, TimeDownsample): 445 | input = time_down(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) 446 | else: 447 | input = time_down(input) 448 | input = 
input.float() 449 | quant_conv = self.quant_conv.float() 450 | quant_conv_res = self.quant_conv_res.float() 451 | 452 | conv1_out = quant_conv(input) 453 | for layer in quant_conv_res: 454 | if not isinstance(layer, ResnetBlock): 455 | input = layer(input) 456 | else: # ResnetBlock 457 | input = layer(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) 458 | input += conv1_out 459 | 460 | posterior = DiagonalGaussianDistribution(input) 461 | if deterministic: 462 | z = posterior.mode() 463 | else: 464 | z = posterior.sample() 465 | return z, input 466 | 467 | def decode(self, input): 468 | self._conv_idx = [0] 469 | 470 | input = input.float() 471 | post_quant_conv = self.post_quant_conv.float() 472 | post_quant_conv_res = self.post_quant_conv_res.float() 473 | 474 | post1 = post_quant_conv(input) 475 | for layer in post_quant_conv_res: 476 | if not isinstance(layer, ResnetBlock): 477 | input = layer(input) 478 | else: # ResnetBlock 479 | input = layer(input, feat_cache=self._feat_map, feat_idx=self._conv_idx) 480 | input += post1 481 | 482 | input = input.to(self.dtype) 483 | output = [] 484 | for i, (resnet, up, time_up) in enumerate(self.dec): 485 | input = resnet(input, feat_cache=self._feat_map, feat_idx=self._conv_idx) 486 | input = up(input) 487 | input = time_up(input) 488 | output.append(input) 489 | return output[::-1] 490 | 491 | def clear_cache(self): 492 | self._feat_map = [None] * self._conv_num 493 | self._enc_feat_map = [None] * self._enc_conv_num 494 | self._conv_idx = [0] 495 | self._enc_conv_idx = [0] 496 | 497 | 498 | class Unet(nn.Module): 499 | def __init__( 500 | self, 501 | dim, 502 | out_dim=None, 503 | dim_mults=(1, 2, 4, 8), 504 | context_dim=None, 505 | context_out_channels=None, 506 | context_dim_mults=(1, 2, 3, 3), 507 | space_down=(1, 1, 1, 1), 508 | time_down=(0, 0, 0, 1), 509 | channels=3, 510 | with_time_emb=True, 511 | new_upsample=False, 512 | embd_type="01", 513 | condition_times=4, 514 | ): 515 | super().__init__() 516 | self.channels = channels 517 | 518 | dims = [channels, *map(lambda m: dim * m, dim_mults)] 519 | context_dim = context_dim or dim 520 | context_out_channels = context_out_channels or context_dim 521 | context_dims = [context_out_channels, *map(lambda m: context_dim * m, context_dim_mults)] 522 | self.space_down = space_down + [1] * (len(dim_mults)-len(space_down)-1) + [0] 523 | self.reversed_space_down = list(reversed(self.space_down[:-1])) 524 | self.time_down = time_down + [0] * (len(dim_mults)-len(time_down)) 525 | self.reversed_time_down = list(reversed(self.time_down[:-1])) 526 | in_out = list(zip(dims[:-1], dims[1:])) 527 | self.embd_type = embd_type 528 | self.condition_times = condition_times 529 | 530 | if with_time_emb: 531 | if embd_type == "01": 532 | time_dim = dim 533 | self.time_mlp = nn.Sequential(nn.Linear(1, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) 534 | else: 535 | raise NotImplementedError 536 | else: 537 | time_dim = None 538 | self.time_mlp = None 539 | 540 | self.downs = nn.ModuleList([]) 541 | self.ups = nn.ModuleList([]) 542 | num_resolutions = len(in_out) 543 | 544 | for ind, (dim_in, dim_out) in enumerate(in_out): 545 | is_last = ind >= (num_resolutions - 1) 546 | self.downs.append( 547 | nn.ModuleList( 548 | [ 549 | ResnetBlock( 550 | dim_in + context_dims[ind] 551 | if (not is_last) and (ind < self.condition_times) 552 | else dim_in, 553 | dim_out, 554 | time_dim, 555 | True if ind == 0 else False 556 | ), 557 | ResnetBlock(dim_out, dim_out, time_dim), 558 | Residual(PreNorm(dim_out, 
LinearAttention(dim_out))), 559 | Downsample(dim_out) if self.space_down[ind] else nn.Identity(), 560 | TimeDownsample(dim_out) if self.time_down[ind] else nn.Identity(), 561 | ] 562 | ) 563 | ) 564 | 565 | 566 | mid_dim = dims[-1] 567 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_dim) 568 | self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) 569 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_dim) 570 | 571 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 572 | 573 | is_last = ind >= (num_resolutions - 1) 574 | self.ups.append( 575 | nn.ModuleList( 576 | [ 577 | ResnetBlock(dim_out * 2, dim_in, time_dim), 578 | ResnetBlock(dim_in, dim_in, time_dim), 579 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 580 | Upsample(dim_in, new_upsample=new_upsample) if self.reversed_space_down[ind] else nn.Identity(), 581 | TimeUpsample(dim_in) if self.reversed_time_down[ind] else nn.Identity(), 582 | ] 583 | ) 584 | ) 585 | 586 | out_dim = out_dim or channels 587 | self.final_conv = nn.ModuleList([LayerNorm(dim), CausalConv3d(dim, out_dim, (3, 7, 7), padding=(1, 3, 3))]) 588 | 589 | # cache the last two frame of feature map 590 | self._conv_num = count_conv3d(self.ups) + count_conv3d(self.downs) + \ 591 | count_conv3d(self.mid_block1) + count_conv3d(self.mid_block2) + \ 592 | count_conv3d(self.final_conv) + count_time_down_sample(self.downs) 593 | self._conv_idx = [0] 594 | self._feat_map = [None] * self._conv_num 595 | 596 | @property 597 | def dtype(self): 598 | return self.final_conv[1].weight.dtype 599 | 600 | def encode(self, x, t, context): 601 | h = [] 602 | for idx, (backbone, backbone2, attn, downsample, time_downsample) in enumerate(self.downs): 603 | x = torch.cat([x, context[idx]], dim=1) if idx < self.condition_times else x 604 | x = backbone(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 605 | x = backbone2(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 606 | x = attn(x, feat_cache=self._feat_map, feat_idx=self._conv_idx) 607 | h.append(x) 608 | x = downsample(x) 609 | if isinstance(time_downsample, TimeDownsample): 610 | x = time_downsample(x, feat_cache=self._feat_map, feat_idx=self._conv_idx) 611 | else: 612 | x = time_downsample(x) 613 | 614 | x = self.mid_block1(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 615 | return x, h 616 | 617 | def decode(self, x, h, t): 618 | device = x.device 619 | dtype = x.dtype 620 | x = self.mid_attn(x) 621 | x = self.mid_block2(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 622 | for backbone, backbone2, attn, upsample, time_upsample in self.ups: 623 | reference = h.pop() 624 | if x.shape[2:] != reference.shape[2:]: 625 | x = F.interpolate( 626 | x.float(), size=reference.shape[2:], mode='nearest' 627 | ).type_as(x) 628 | x = torch.cat((x, reference), dim=1) 629 | x = module_forward(backbone, x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 630 | x = module_forward(backbone2, x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx) 631 | x = module_forward(attn, x) 632 | x = module_forward(upsample, x) 633 | x = module_forward(time_upsample, x) 634 | x = x.to(device).to(dtype) 635 | 636 | x = self.final_conv[0](x) 637 | if self._feat_map is not None: 638 | idx = self._conv_idx[0] 639 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 640 | if cache_x.shape[2] < 2 and self._feat_map[idx] is not None: 641 | # cache last frame of last two chunk 642 | cache_x = torch.cat([self._feat_map[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) 643 | x = 
self.final_conv[1](x, self._feat_map[idx]) 644 | self._feat_map[idx] = cache_x 645 | self._conv_idx[0] += 1 646 | else: 647 | x = self.final_conv[1](x) 648 | return x 649 | 650 | def forward(self, x, time=None, context=None): 651 | self._conv_idx = [0] 652 | t = None 653 | if self.time_mlp is not None: 654 | time_mlp = self.time_mlp.float() 655 | t = time_mlp(time).to(self.dtype) 656 | 657 | x = x.to(self.dtype) 658 | x, h = self.encode(x, t, context) 659 | return self.decode(x, h, t) 660 | 661 | def clear_cache(self): 662 | self._feat_map = [None] * self._conv_num 663 | self._conv_idx = [0] 664 | 665 | 666 | def extract(a, t, x_shape): 667 | a = a.to(t.device) 668 | b, *_ = t.shape 669 | out = a.gather(-1, t) 670 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 671 | 672 | 673 | def cosine_beta_schedule(timesteps, s=0.008): 674 | """ 675 | cosine schedule 676 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 677 | """ 678 | steps = timesteps + 1 679 | x = np.linspace(0, steps, steps) 680 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 681 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 682 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 683 | return np.clip(betas, a_min=0, a_max=0.999) 684 | 685 | 686 | class GaussianDiffusion(nn.Module): 687 | def __init__( 688 | self, 689 | denoise_fn=None, 690 | context_fn=None, 691 | ae_fn=None, 692 | num_timesteps=8192, 693 | pred_mode="x", 694 | var_schedule="cosine", 695 | ): 696 | super().__init__() 697 | self.denoise_fn = denoise_fn 698 | self.context_fn = context_fn 699 | self.ae_fn = ae_fn 700 | self.otherlogs = {} 701 | self.var_schedule = var_schedule 702 | self.sample_steps = None 703 | assert pred_mode in ["noise", "x", "v"] 704 | self.pred_mode = pred_mode 705 | to_torch = partial(torch.tensor, dtype=torch.float32) 706 | 707 | train_betas = cosine_beta_schedule(num_timesteps) 708 | train_alphas = 1.0 - train_betas 709 | train_alphas_cumprod = np.cumprod(train_alphas, axis=0) 710 | (num_timesteps,) = train_betas.shape 711 | self.num_timesteps = int(num_timesteps) 712 | 713 | self.train_snr=to_torch(train_alphas_cumprod / (1 - train_alphas_cumprod)) 714 | self.train_betas=to_torch(train_betas) 715 | self.train_alphas_cumprod=to_torch(train_alphas_cumprod) 716 | self.train_sqrt_alphas_cumprod=to_torch(np.sqrt(train_alphas_cumprod)) 717 | self.train_sqrt_one_minus_alphas_cumprod=to_torch(np.sqrt(1.0 - train_alphas_cumprod)) 718 | self.train_sqrt_recip_alphas_cumprod=to_torch(np.sqrt(1.0 / train_alphas_cumprod)) 719 | self.train_sqrt_recipm1_alphas_cumprod=to_torch(np.sqrt(1.0 / train_alphas_cumprod - 1)) 720 | 721 | def set_sample_schedule(self, sample_steps, device): 722 | self.sample_steps = sample_steps 723 | if sample_steps != 1: 724 | indice = torch.linspace(0, self.num_timesteps - 1, sample_steps, device=device).long() 725 | else: 726 | indice = torch.tensor([self.num_timesteps - 1], device=device).long() 727 | self.train_alphas_cumprod = self.train_alphas_cumprod.to(device) 728 | self.train_snr = self.train_snr.to(device) 729 | self.alphas_cumprod = self.train_alphas_cumprod[indice] 730 | self.snr = self.train_snr[indice] 731 | self.index = torch.arange(self.num_timesteps, device=device)[indice] 732 | self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0) 733 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 734 | self.sqrt_alphas_cumprod_prev = torch.sqrt(self.alphas_cumprod_prev) 735 | self.one_minus_alphas_cumprod = 1.0 - self.alphas_cumprod 736 
| self.one_minus_alphas_cumprod_prev = 1.0 - self.alphas_cumprod_prev 737 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) 738 | self.sqrt_one_minus_alphas_cumprod_prev = torch.sqrt(1.0 - self.alphas_cumprod_prev) 739 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) 740 | self.sqrt_recip_alphas_cumprod_prev = torch.sqrt(1.0 / self.alphas_cumprod_prev) 741 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) 742 | self.sigma = self.sqrt_one_minus_alphas_cumprod_prev / self.sqrt_one_minus_alphas_cumprod * torch.sqrt(1.0 - self.alphas_cumprod / self.alphas_cumprod_prev) 743 | 744 | def predict_noise_from_start(self, x_t, t, x0): 745 | return ( 746 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ 747 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 748 | ) 749 | 750 | def predict_v(self, x_start, t, noise): 751 | if self.training: 752 | return ( 753 | extract(self.train_sqrt_alphas_cumprod, t, x_start.shape) * noise - 754 | extract(self.train_sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 755 | ) 756 | else: 757 | return ( 758 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - 759 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 760 | ) 761 | 762 | def predict_start_from_v(self, x_t, t, v): 763 | if self.training: 764 | return ( 765 | extract(self.train_sqrt_alphas_cumprod, t, x_t.shape) * x_t - 766 | extract(self.train_sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 767 | ) 768 | else: 769 | return ( 770 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 771 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 772 | ) 773 | 774 | def predict_start_from_noise(self, x_t, t, noise): 775 | if self.training: 776 | return ( 777 | extract(self.train_sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 778 | - extract(self.train_sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 779 | ) 780 | else: 781 | return ( 782 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 783 | - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 784 | ) 785 | 786 | def ddim(self, x, t, context, clip_denoised, eta=0): 787 | if self.denoise_fn.embd_type == "01": 788 | fx = self.denoise_fn(x, self.index[t].float().unsqueeze(-1) / self.num_timesteps, context=context) 789 | else: 790 | fx = self.denoise_fn(x, self.index[t], context=context) 791 | fx = fx.float() 792 | if self.pred_mode == "noise": 793 | x_recon = self.predict_start_from_noise(x, t=t, noise=fx) 794 | elif self.pred_mode == "x": 795 | x_recon = fx 796 | elif self.pred_mode == "v": 797 | x_recon = self.predict_start_from_v(x, t=t, v=fx) 798 | if clip_denoised: 799 | x_recon.clamp_(-1.0, 1.0) 800 | noise = fx if self.pred_mode == "noise" else self.predict_noise_from_start(x, t=t, x0=x_recon) 801 | x_next = ( 802 | extract(self.sqrt_alphas_cumprod_prev, t, x.shape) * x_recon 803 | + torch.sqrt( 804 | (extract(self.one_minus_alphas_cumprod_prev, t, x.shape) 805 | - (eta * extract(self.sigma, t, x.shape)) ** 2).clamp(min=0) 806 | ) 807 | * noise + eta * extract(self.sigma, t, x.shape) * torch.randn_like(noise) 808 | ) 809 | return x_next 810 | 811 | def p_sample(self, x, t, context, clip_denoised, eta=0): 812 | return self.ddim(x=x, t=t, context=context, clip_denoised=clip_denoised, eta=eta) 813 | 814 | def p_sample_loop(self, shape, context, clip_denoised=False, init=None, eta=0): 815 | device = self.alphas_cumprod.device 816 | 817 | b = shape[0] 818 | img = 
torch.zeros(shape, device=device) if init is None else init 819 | for count, i in enumerate(reversed(range(0, self.sample_steps))): 820 | time = torch.full((b,), i, device=device, dtype=torch.long) 821 | img = self.p_sample( 822 | img, 823 | time, 824 | context=context, 825 | clip_denoised=clip_denoised, 826 | eta=eta, 827 | ) 828 | return img 829 | 830 | @torch.no_grad() 831 | def compress( 832 | self, 833 | images, 834 | sample_steps=10, 835 | init=None, 836 | clip_denoised=True, 837 | eta=0, 838 | ): 839 | context_dict = self.context_fn(images, 'test') 840 | self.set_sample_schedule( 841 | self.num_timesteps if (sample_steps is None) else sample_steps, 842 | context_dict["output"][0].device, 843 | ) 844 | return self.p_sample_loop( 845 | images.shape, context_dict["output"], 846 | clip_denoised=clip_denoised, init=init, eta=eta 847 | ) 848 | 849 | @torch.no_grad() 850 | def diffusion_decode( 851 | self, 852 | latent, 853 | sample_steps=30, 854 | init=None, 855 | time=None, 856 | clip_denoised=False, 857 | eta=0, 858 | ): 859 | context = self.context_fn.decode(latent) 860 | 861 | # breakpoint() 862 | self.set_sample_schedule( 863 | self.num_timesteps if (sample_steps is None) else sample_steps, 864 | context[0].device, 865 | ) 866 | 867 | img = init 868 | img = self.p_sample( 869 | img, 870 | time, 871 | context=context, 872 | clip_denoised=clip_denoised, 873 | eta=eta, 874 | ) 875 | return img 876 | 877 | 878 | class ConditionedDiffusionTokenizer(nn.Module): 879 | 880 | def __init__(self, 881 | pretrained=None, # pretrained ckpt path 882 | enc_bs=1, # mini-batch to loop for both enc and dec 883 | enc_frames=13, # mini-batch frames to loop for both enc and dec 884 | dec_bs=1, # mini-batch to loop for dec 885 | dec_frames=4, # mini-batch frames to loop for dec 886 | z_overlap=1, # mini-batch inference overlap 887 | sample_steps=10, # decode diffusion steps 888 | sample_gamma=0.8, # decode noise weight 889 | num_timesteps=8192, 890 | out_channels=3, 891 | context_dim=64, 892 | unet_dim=64, 893 | new_upsample=False, 894 | latent_dim=16, 895 | context_dim_mults = [1, 2, 3, 4], 896 | space_down = [1, 1, 1, 1], 897 | time_down = [0, 0, 1, 1], 898 | condition_times=4, 899 | **kwargs 900 | ): 901 | super().__init__() 902 | self.latent_scale = None 903 | self.enc_bs = enc_bs 904 | self.enc_frames = enc_frames 905 | self.dec_bs = dec_bs or enc_bs 906 | self.dec_frames = dec_frames or enc_frames 907 | self.space_factor = 2 ** sum(space_down) 908 | self.time_factor = 2 ** sum(time_down) 909 | self.sample_steps = sample_steps 910 | self.sample_gamma = sample_gamma 911 | self.z_overlap = z_overlap 912 | self.num_timesteps = num_timesteps 913 | self.condition_times = condition_times 914 | CausalConv3d.PAD_MODE = 'constant' if new_upsample else 'replicate' 915 | context_model = Compressor( 916 | dim=context_dim, 917 | latent_dim=latent_dim, 918 | new_upsample=new_upsample, 919 | channels=3, 920 | out_channels=out_channels, 921 | dim_mults=context_dim_mults, 922 | reverse_dim_mults=reversed(context_dim_mults), 923 | space_down=space_down, 924 | time_down=time_down) 925 | denoise_model = Unet( 926 | dim=unet_dim, 927 | channels=3, 928 | new_upsample=new_upsample, 929 | context_dim=context_dim, 930 | context_out_channels=out_channels, 931 | dim_mults=[1, 2, 3, 4, 5, 6], 932 | context_dim_mults=context_dim_mults, 933 | space_down=space_down, 934 | time_down=time_down, 935 | condition_times=condition_times 936 | ) 937 | self.model = GaussianDiffusion( 938 | context_fn=context_model, 939 | 
denoise_fn=denoise_model, 940 | num_timesteps=self.num_timesteps, 941 | ) 942 | self.z_dim = latent_dim 943 | if pretrained is not None: 944 | self.init_from_ckpt(pretrained) 945 | 946 | def init_from_ckpt( 947 | self, 948 | path: str, 949 | ) -> None: 950 | if path.endswith("safetensors"): 951 | sd = load_safetensors(path) 952 | else: 953 | # breakpoint() 954 | sd = torch.load(path, map_location="cpu", weights_only=False) 955 | sd = sd.get("state_dict", sd) 956 | sd = sd.get("module", sd) 957 | sd = {k:v.to(torch.float32) for k,v in sd.items()} 958 | missing, unexpected = self.load_state_dict(sd, strict=False) 959 | logging.info( 960 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 961 | ) 962 | if len(missing) > 0: 963 | logging.info(f"Missing Keys: {missing}") 964 | if len(unexpected) > 0: 965 | logging.info(f"Unexpected Keys: {unexpected}") 966 | 967 | @torch.no_grad() 968 | def forward(self, x, **kwargs): 969 | latent = self.encode(x) 970 | x_recon = self.decode(latent) 971 | return x_recon 972 | 973 | @torch.no_grad() 974 | def forward_with_info(self, x, **kwargs): 975 | latent, posterior = self.encode(x, return_posterior=True) 976 | x_recon = self.decode(latent) 977 | mu = posterior.mean 978 | logvar = posterior.logvar 979 | z_std = mu.std() 980 | return x_recon, mu, logvar, z_std 981 | 982 | @torch.no_grad() 983 | def encode(self, x, scale_factor=1.0, return_posterior=False, **kwargs): 984 | N, C, T, H, W = x.shape 985 | mini_batch_size = min(self.enc_bs, N) if self.enc_bs else N 986 | enc_frames_ratio = (1024/H * 1024/W) ** 1.0 987 | enc_frames = int(self.enc_frames * enc_frames_ratio/4) * 4 988 | enc_frames = max(enc_frames, 4) 989 | logging.debug(f"enc: {x.shape}, {self.enc_frames}, {enc_frames_ratio}, {enc_frames}") 990 | z_overlap = 0 991 | frame_overlap = 0 992 | 993 | mini_frames = min(enc_frames, T) if enc_frames else T 994 | n_batches = int(math.ceil(N / mini_batch_size)) 995 | n_frame_batches = math.ceil((T-frame_overlap) / max(1, mini_frames-frame_overlap)) 996 | n_frame_batches = max(int(n_frame_batches), 1) 997 | remainder = T % mini_frames 998 | z = list() 999 | mean_std_list = list() 1000 | for i_batch in range(n_batches): 1001 | z_batch = [] 1002 | mean_std_batch = [] 1003 | x_batch = x[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size] 1004 | frame_end = 0 1005 | for i_frames in range(n_frame_batches): 1006 | frame_start = frame_end 1007 | if i_frames == 0 and remainder > 0: 1008 | frame_end = frame_start + remainder 1009 | else: 1010 | frame_end = frame_start + mini_frames 1011 | i_batch_input = x_batch[:, :, frame_start:frame_end, ...] 
1012 | logging.debug(f'enc: {i_frames}, {i_batch_input.shape}') 1013 | i_batch_input = i_batch_input.to(next(self.parameters()).device) 1014 | z_frames, mean_std = self.model.context_fn.encode(i_batch_input, deterministic=True) 1015 | z_batch.append(z_frames[:, :, (z_overlap if i_frames>0 else 0):, ...]) 1016 | mean_std_batch.append(mean_std[:, :, (z_overlap if i_frames>0 else 0):, ...]) 1017 | z_batch = torch.cat(z_batch, dim=2) 1018 | z.append(z_batch) 1019 | mean_std_batch = torch.cat(mean_std_batch, dim=2) 1020 | mean_std_list.append(mean_std_batch) 1021 | self.model.context_fn.clear_cache() 1022 | latent = torch.cat(z, 0) 1023 | mean_std = torch.cat(mean_std_list, 0) 1024 | posterior = DiagonalGaussianDistribution(mean_std) 1025 | latent *= scale_factor 1026 | if return_posterior: 1027 | return latent, posterior 1028 | else: 1029 | return latent 1030 | 1031 | @torch.no_grad() 1032 | def decode(self, latent, scale_factor=1.0, **kwargs): 1033 | N, C, T, H, W = latent.shape 1034 | latent = latent/scale_factor 1035 | mini_batch_size = min(self.dec_bs, N) if self.dec_bs else N 1036 | n_batches = int(math.ceil(N / mini_batch_size)) 1037 | dec_frames_ratio = (1024/self.space_factor/H * 1024/self.space_factor/W) ** 1.0 1038 | dec_frames = int(self.dec_frames * dec_frames_ratio) 1039 | dec_frames = max(1, int(dec_frames)) 1040 | logging.debug(f"dec: {latent.shape}, {self.dec_frames}, {dec_frames_ratio}, {dec_frames}") 1041 | z_overlap = 0 1042 | frame_overlap = 0 1043 | 1044 | assert dec_frames >= z_overlap + 1, f"dec_frames {dec_frames} too small" 1045 | mini_frames = min(dec_frames, T) if dec_frames else T 1046 | n_frame_batches = math.ceil((T-z_overlap) / max(1, (mini_frames-z_overlap))) 1047 | n_frame_batches = max(int(n_frame_batches), 1) 1048 | dec = list() 1049 | target_shape = [N, 3, T*self.time_factor, H*self.space_factor, W*self.space_factor] 1050 | init_noise = self.sample_gamma*torch.randn(target_shape, dtype=latent.dtype, device=latent.device) 1051 | 1052 | for i_batch in range(n_batches): 1053 | x_batch = latent[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size] 1054 | noise_batch = init_noise[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size] 1055 | z_batch_list = [None] * n_frame_batches 1056 | for count, t_idx in enumerate(reversed(range(0, self.sample_steps))): 1057 | for i_frames in range(n_frame_batches): 1058 | global CLIP_FIRST_CHUNK 1059 | CLIP_FIRST_CHUNK = True if i_frames == 0 else False 1060 | 1061 | latent_frame_start = i_frames * (mini_frames-z_overlap) 1062 | latent_frame_end = latent_frame_start+mini_frames 1063 | i_batch_input = x_batch[:, :, latent_frame_start:latent_frame_end, ...] 1064 | if count == 0: 1065 | frame_start = latent_frame_start*self.time_factor 1066 | frame_end = frame_start + (mini_frames)*self.time_factor 1067 | if CLIP_FIRST_CHUNK: 1068 | frame_start += 3 1069 | cur_noise = noise_batch[:, :, frame_start:frame_end, ...] 1070 | else: 1071 | cur_noise = z_batch_list[i_frames] 1072 | time = torch.full((cur_noise.shape[0],), t_idx, device=latent.device, dtype=torch.long) 1073 | logging.debug(f'dec: {i_frames}, {i_batch_input.shape}, {cur_noise.shape}') 1074 | x_rec = self.model.diffusion_decode(i_batch_input, 1075 | sample_steps=self.sample_steps, init=cur_noise, time=time) 1076 | z_batch_list[i_frames] = x_rec[:, :, (0 if i_frames==0 else frame_overlap):, ...] 
1077 | # clear cache 1078 | self.model.context_fn.clear_cache() 1079 | self.model.denoise_fn.clear_cache() 1080 | 1081 | dec.append(torch.cat(z_batch_list, dim=2)) 1082 | del z_batch_list, x_batch, noise_batch 1083 | 1084 | dec = torch.cat(dec, 0) 1085 | return dec 1086 | 1087 | 1088 | def load_cdt_base( 1089 | ckpt = None, 1090 | device='cpu', 1091 | eval=True, 1092 | sampling_step=1, 1093 | **kwargs 1094 | ): 1095 | ckpt = ckpt or './pretrained/cdt_base.ckpt' 1096 | print(f"Loading CDT-base from {ckpt}") 1097 | 1098 | 1099 | model = ConditionedDiffusionTokenizer(pretrained=ckpt, 1100 | enc_frames=kwargs.pop('enc_frames', 4), 1101 | dec_frames=kwargs.pop('dec_frames', 1), 1102 | space_down=[0,1,1,1], 1103 | out_channels=16, 1104 | latent_dim=16, 1105 | context_dim=128, 1106 | new_upsample=True, 1107 | sample_steps=sampling_step, 1108 | sample_gamma=kwargs.pop('sample_gamma', 0.8), 1109 | **kwargs) 1110 | model = model.to(device) 1111 | if eval: 1112 | model = model.eval() 1113 | return model 1114 | 1115 | def load_cdt_small( 1116 | ckpt = None, 1117 | device='cpu', 1118 | eval=True, 1119 | latent_dim=16, 1120 | diffusion_step=8192, 1121 | sampling_step=1, 1122 | condition_times=4, 1123 | **kwargs 1124 | ): 1125 | ckpt = ckpt or "./pretrained/cdt_small.ckpt" 1126 | 1127 | model = ConditionedDiffusionTokenizer(pretrained=ckpt, 1128 | enc_frames=kwargs.pop('enc_frames', 4), 1129 | dec_frames=kwargs.pop('dec_frames', 1), 1130 | space_down=[1, 1, 1, 0], 1131 | out_channels=3, 1132 | latent_dim=latent_dim, 1133 | context_dim=64, 1134 | num_timesteps=diffusion_step, 1135 | condition_times=condition_times, 1136 | sample_steps=sampling_step, 1137 | sample_gamma=kwargs.pop('sample_gamma', 0.8), 1138 | **kwargs) 1139 | model = model.to(device) 1140 | if eval: 1141 | model = model.eval() 1142 | return model 1143 | 1144 | def load_cdt(version='base', dtype=torch.float16, *args, **kwargs): 1145 | VERSIONS = { 1146 | 'base': (load_cdt_base, 1.3), 1147 | 'small': (load_cdt_small, 1.3), 1148 | } 1149 | if version not in VERSIONS: 1150 | print(f"ERROR: wrong version '{version}' of CDT, not in '{VERSIONS.keys()}'") 1151 | return None 1152 | model_func, latent_scale = VERSIONS[version] 1153 | model = model_func(*args, **kwargs) 1154 | model.latent_scale = latent_scale 1155 | model = model.to(dtype) 1156 | return model 1157 | 1158 | --------------------------------------------------------------------------------
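The sampler in GaussianDiffusion above is a plain DDIM update over a cosine noise schedule: cosine_beta_schedule builds the cumulative alphas over num_timesteps training steps, set_sample_schedule re-indexes them onto sample_steps points, and ddim moves from x_t to x_{t-1}. Below is a self-contained sketch of that math (illustrative, not repository code), with eta = 0 and a zero tensor standing in for the Unet's x-prediction.

import numpy as np
import torch
import torch.nn.functional as F

timesteps, sample_steps, s = 8192, 10, 0.008

# cosine_beta_schedule: cumulative alphas from a squared-cosine curve
steps = timesteps + 1
grid = np.linspace(0, steps, steps)
ac = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
ac = ac / ac[0]
betas = np.clip(1 - ac[1:] / ac[:-1], 0, 0.999)
train_ac = torch.tensor(np.cumprod(1.0 - betas), dtype=torch.float32)

# set_sample_schedule: pick sample_steps points from the training schedule
idx = torch.linspace(0, timesteps - 1, sample_steps).long()
abar = train_ac[idx]
abar_prev = F.pad(abar[:-1], (1, 0), value=1.0)

# one deterministic DDIM step at the noisiest index (pred_mode == "x")
t = sample_steps - 1
x_t = torch.randn(1, 3, 4, 8, 8)
x0_pred = torch.zeros_like(x_t)                        # placeholder for the denoise_fn output
noise = ((1 / abar[t]).sqrt() * x_t - x0_pred) / (1 / abar[t] - 1).sqrt()
x_prev = abar_prev[t].sqrt() * x0_pred + (1 - abar_prev[t]).sqrt() * noise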
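The predict_* helpers are the standard reparameterizations of the forward process x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, which is what lets pred_mode switch between "noise", "x" and "v" targets. A quick numeric check with toy tensors (not repository code) that the conversions round-trip:

import torch

abar = torch.tensor(0.37)                # any abar_t in (0, 1)
x0 = torch.randn(2, 3, 4, 8, 8)
eps = torch.randn_like(x0)

x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps
v = abar.sqrt() * eps - (1 - abar).sqrt() * x0                         # predict_v

x0_from_eps = (1 / abar).sqrt() * x_t - (1 / abar - 1).sqrt() * eps    # predict_start_from_noise
x0_from_v = abar.sqrt() * x_t - (1 - abar).sqrt() * v                  # predict_start_from_v
eps_from_x0 = ((1 / abar).sqrt() * x_t - x0) / (1 / abar - 1).sqrt()   # predict_noise_from_start

assert torch.allclose(x0_from_eps, x0, atol=1e-4)
assert torch.allclose(x0_from_v, x0, atol=1e-4)
assert torch.allclose(eps_from_x0, eps, atol=1e-4)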
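ConditionedDiffusionTokenizer.decode streams long videos: the outer loop walks the diffusion steps, the inner loop walks latent-frame chunks of size dec_frames, and the causal convolutions carry their state across chunks in the _feat_map caches, so memory scales with the chunk size rather than the clip length. The sketch below (toy sizes, with z_overlap = frame_overlap = 0 as in the current code) traces how latent chunks map to output frames; the first chunk skips 3 noise frames, which appears to be how a latent of length T' reconstructs 4*T' - 3 = 4*(T' - 1) + 1 frames at time_factor 4.

import math

time_factor, mini_frames, latent_T = 4, 2, 5                 # toy sizes
n_chunks = max(math.ceil(latent_T / mini_frames), 1)
for i in range(n_chunks):
    z_start = i * mini_frames
    z_end = min(z_start + mini_frames, latent_T)
    f_start = z_start * time_factor + (3 if i == 0 else 0)   # CLIP_FIRST_CHUNK drops 3 frames
    f_end = min((z_start + mini_frames) * time_factor, latent_T * time_factor)
    print(f"latent[{z_start}:{z_end}] -> output frames[{f_start}:{f_end}]")
# latent[0:2] -> output frames[3:8]
# latent[2:4] -> output frames[8:16]
# latent[4:5] -> output frames[16:20]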
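Finally, a minimal usage sketch for load_cdt: load a checkpoint and run an encode/decode round trip. The CPU/float32 settings, the [-1, 1] input range, and the 4k+1 frame count are assumptions made for this sketch; rec_video_eval.py and rec_image_eval.py are the scripts that actually drive the model.

import torch
from model.cdt import load_cdt

# Assumptions: CPU + float32 for portability, input normalized to [-1, 1],
# and a 4k+1 frame count to match the causal temporal layout.
model = load_cdt(version='base', dtype=torch.float32, device='cpu',
                 ckpt='./pretrained/cdt_base.ckpt', sampling_step=1)

x = torch.rand(1, 3, 13, 256, 256) * 2 - 1   # N, C, T, H, W video
with torch.no_grad():
    z = model.encode(x)        # CDT-base: 16 latent channels, roughly T/4 x H/8 x W/8
    x_rec = model.decode(z)    # diffusion decode with sampling_step step(s)
print(z.shape, x_rec.shape)

load_cdt also attaches model.latent_scale (1.3 for both variants); encode and decode accept a scale_factor argument if the latent needs to be rescaled for a downstream diffusion model.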