├── .gitignore
├── LICENSE
├── opensora_evaluate
│   ├── cal_psnr.py
│   ├── cal_lpips.py
│   └── cal_ssim.py
├── evaluate.py
├── requirements.txt
├── README.md
├── rec_image_eval.py
├── utils.py
├── rec_video_eval.py
└── model
    └── cdt.py
/.gitignore:
--------------------------------------------------------------------------------
1 | reconstructed_results/
2 | reconstructed_results/video_results/
3 | reconstructed_results/image_results/
4 |
5 | pretrained/
6 | pretrained/cdt_base.ckpt
7 | pretrained/cdt_small.ckpt
8 |
9 | .DS_Store
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Nianzu Yang and Tongyi Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/opensora_evaluate/cal_psnr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 |
6 | def img_psnr(img1, img2):
7 |
8 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
9 | if mse < 1e-10:
10 | return 100
11 | psnr = 20 * math.log10(1 / math.sqrt(mse))
12 | return psnr
13 |
14 | def trans(x):
15 | return x
16 |
17 | def calculate_psnr(videos1, videos2):
18 |
19 | assert videos1.shape == videos2.shape
20 |
21 | videos1 = trans(videos1)
22 | videos2 = trans(videos2)
23 |
24 | psnr_results = []
25 |
26 | for video_num in range(videos1.shape[0]):
27 |
28 | video1 = videos1[video_num]
29 | video2 = videos2[video_num]
30 |
31 | psnr_results_of_a_video = []
32 | for clip_timestamp in range(len(video1)):
33 |
34 | img1 = video1[clip_timestamp].numpy()
35 | img2 = video2[clip_timestamp].numpy()
36 |
37 | psnr_results_of_a_video.append(img_psnr(img1, img2))
38 |
39 | psnr_results.append(psnr_results_of_a_video)
40 |
41 | psnr_results = np.array(psnr_results)
42 | psnr = {}
43 | psnr_std = {}
44 |
45 | for clip_timestamp in range(len(video1)):
46 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])
47 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])
48 |
49 | result = {
50 | "value": psnr,
51 | "value_std": psnr_std,
52 | "video_setting": video1.shape,
53 | "video_setting_name": "time, channel, heigth, width",
54 | }
55 |
56 | return result
57 |
58 |
59 | def main():
60 | NUMBER_OF_VIDEOS = 8
61 | VIDEO_LENGTH = 50
62 | CHANNEL = 3
63 | SIZE = 64
64 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
65 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
66 |
67 | import json
68 | result = calculate_psnr(videos1, videos2)
69 | print(json.dumps(result, indent=4))
70 |
71 | if __name__ == "__main__":
72 | main()
--------------------------------------------------------------------------------
/opensora_evaluate/cal_lpips.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 |
6 | import lpips
7 | import os
8 |
9 | spatial = True
10 |
11 |
12 |
13 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial)
14 |
15 | def trans(x):
16 |
17 | if x.shape[-3] == 1:
18 | x = x.repeat(1, 1, 3, 1, 1)
19 |
20 | x = x * 2 - 1
21 |
22 | return x
23 |
24 | def calculate_lpips(videos1, videos2, device):
25 |
26 | assert videos1.shape == videos2.shape
27 |
28 | videos1 = trans(videos1)
29 | videos2 = trans(videos2)
30 |
31 | lpips_results = []
32 |
33 | for video_num in range(videos1.shape[0]):
34 |
35 | video1 = videos1[video_num]
36 | video2 = videos2[video_num]
37 |
38 | lpips_results_of_a_video = []
39 | for clip_timestamp in range(len(video1)):
40 |
41 | img1 = video1[clip_timestamp].unsqueeze(0).to(device)
42 | img2 = video2[clip_timestamp].unsqueeze(0).to(device)
43 |
44 | loss_fn.to(device)
45 |
46 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
47 | lpips_results.append(lpips_results_of_a_video)
48 |
49 | lpips_results = np.array(lpips_results)
50 |
51 | lpips = {}
52 | lpips_std = {}
53 |
54 | for clip_timestamp in range(len(video1)):
55 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])
56 | lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])
57 |
58 |
59 | result = {
60 | "value": lpips,
61 | "value_std": lpips_std,
62 | "video_setting": video1.shape,
63 | "video_setting_name": "time, channel, heigth, width",
64 | }
65 |
66 | return result
67 |
68 |
69 | def main():
70 | NUMBER_OF_VIDEOS = 8
71 | VIDEO_LENGTH = 50
72 | CHANNEL = 3
73 | SIZE = 64
74 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
75 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
76 | device = torch.device("cuda")
77 |
78 | import json
79 | result = calculate_lpips(videos1, videos2, device)
80 | print(json.dumps(result, indent=4))
81 |
82 | if __name__ == "__main__":
83 | main()
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import argparse
3 |
4 | # Define constants for dataset paths
5 | WEBVID_PATH = '/mnt/wulanchabu/lpd_wlcb/dataset/webvid/val'
6 | COCO17_PATH = '/cpfs01/Group-rep-learning/multimodal/datasets/coco/val2017'
7 |
8 | def main(args):
9 | # method name
10 | method = args.method
11 | # assert method name is CDT
12 | assert "CDT" in method, "method must be CDT"
13 | # dataset name
14 | dataset = args.dataset
15 | # mode: video or image evaluation
16 | mode = args.mode
17 | # crop size
18 | size = args.size
19 | # number of frames for video evaluation
20 | num_frames = 17
21 | # subset size: 0 means evaluate on all videos; otherwise only the first subset_size videos are used (set it to 1 for debugging)
22 | subset_size = args.subset_size
23 |
24 | print(f"Mode: {mode} evaluation")
25 | print(f"Dataset: {dataset}")
26 |
27 | if dataset == 'webvid':
28 | assert mode == 'video', "webvid dataset only supports video evaluation"
29 | # use constant path for webvid dataset
30 | real_data_dir = WEBVID_PATH
31 | elif dataset == 'coco17':
32 | assert mode == 'image', "coco17 dataset only supports image evaluation"
33 | # use constant path for coco17 dataset
34 | real_data_dir = COCO17_PATH
35 | else:
36 | raise ValueError(f"the path of dataset {dataset} is not specified")
37 |
38 |
39 | if "image" in mode:
40 | # other config
41 | base_set_rec = (
42 | '--real_data_dir {real_data_dir} '
43 | '--dataset {dataset} '
44 | '--subset_size {subset_size} '
45 | ).format(real_data_dir=real_data_dir,
46 | dataset=dataset, subset_size=subset_size)
47 | command = 'python rec_image_eval.py --method {method} '.format(method=method) + base_set_rec
48 | else:
49 | # other config
50 | base_set_rec = (
51 | '--crop_size {size} '
52 | '--resolution {size} '
53 | '--num_frames {num_frames} '
54 | '--real_data_dir {real_data_dir} '
55 | '--dataset {dataset} '
56 | '--subset_size {subset_size} '
57 | ).format(size=size, num_frames=num_frames, real_data_dir=real_data_dir,
58 | dataset=dataset, subset_size=subset_size)
59 |
60 | command = 'python rec_video_eval.py --method {method} '.format(method=method) + base_set_rec
61 |
62 | print(f"Command: {command}")
63 |
64 | subprocess.call(command, shell=True)
65 |
66 | if __name__ == '__main__':
67 |
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--method', type=str, default='CDT_base')
70 | parser.add_argument('--dataset', type=str, default='webvid', choices=['webvid', 'coco17'])
71 | parser.add_argument('--mode', type=str, default='video', choices=['video', 'image'])
72 | parser.add_argument('--size', type=str, default='256')
73 | parser.add_argument('--subset_size', type=int, default=0)
74 | args = parser.parse_args()
75 |
76 | main(args)
--------------------------------------------------------------------------------
/opensora_evaluate/cal_ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import cv2
5 |
6 | def ssim(img1, img2):
7 | C1 = 0.01 ** 2
8 | C2 = 0.03 ** 2
9 | img1 = img1.astype(np.float64)
10 | img2 = img2.astype(np.float64)
11 | kernel = cv2.getGaussianKernel(11, 1.5)
12 | window = np.outer(kernel, kernel.transpose())
13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15 | mu1_sq = mu1 ** 2
16 | mu2_sq = mu2 ** 2
17 | mu1_mu2 = mu1 * mu2
18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
22 | (sigma1_sq + sigma2_sq + C2))
23 | return ssim_map.mean()
24 |
25 |
26 | def calculate_ssim_function(img1, img2):
27 | if not img1.shape == img2.shape:
28 | raise ValueError('Input images must have the same dimensions.')
29 | if img1.ndim == 2:
30 | return ssim(img1, img2)
31 | elif img1.ndim == 3:
32 | if img1.shape[0] == 3:
33 | ssims = []
34 | for i in range(3):
35 | ssims.append(ssim(img1[i], img2[i]))
36 | return np.array(ssims).mean()
37 | elif img1.shape[0] == 1:
38 | return ssim(np.squeeze(img1), np.squeeze(img2))
39 | else:
40 | raise ValueError('Wrong input image dimensions.')
41 |
42 | def trans(x):
43 | return x
44 |
45 | def calculate_ssim(videos1, videos2):
46 |
47 | assert videos1.shape == videos2.shape
48 |
49 | videos1 = trans(videos1)
50 | videos2 = trans(videos2)
51 |
52 | ssim_results = []
53 |
54 | for video_num in range(videos1.shape[0]):
55 |
56 | video1 = videos1[video_num]
57 | video2 = videos2[video_num]
58 |
59 | ssim_results_of_a_video = []
60 | for clip_timestamp in range(len(video1)):
61 |
62 | img1 = video1[clip_timestamp].numpy()
63 | img2 = video2[clip_timestamp].numpy()
64 |
65 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
66 |
67 | ssim_results.append(ssim_results_of_a_video)
68 |
69 | ssim_results = np.array(ssim_results)
70 |
71 | ssim = {}
72 | ssim_std = {}
73 |
74 | for clip_timestamp in range(len(video1)):
75 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])
76 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])
77 |
78 | result = {
79 | "value": ssim,
80 | "value_std": ssim_std,
81 | "video_setting": video1.shape,
82 | "video_setting_name": "time, channel, heigth, width",
83 | }
84 |
85 | return result
86 |
87 |
88 | def main():
89 | NUMBER_OF_VIDEOS = 8
90 | VIDEO_LENGTH = 50
91 | CHANNEL = 3
92 | SIZE = 64
93 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
94 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
95 | device = torch.device("cuda")
96 |
97 | import json
98 | result = calculate_ssim(videos1, videos2)
99 | print(json.dumps(result, indent=4))
100 |
101 | if __name__ == "__main__":
102 | main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==1.1.0
3 | addict==2.4.0
4 | aiofiles==23.2.1
5 | aiohappyeyeballs==2.4.3
6 | aiohttp==3.10.10
7 | aiosignal==1.3.1
8 | aliyun-python-sdk-core==2.16.0
9 | aliyun-python-sdk-kms==2.16.5
10 | annotated-types==0.7.0
11 | antlr4-python3-runtime==4.9.3
12 | anyio==4.6.2.post1
13 | asttokens==2.4.1
14 | async-timeout==4.0.3
15 | attrs==24.2.0
16 | bcrypt==4.2.0
17 | beautifulsoup4==4.12.3
18 | bitsandbytes==0.44.1
19 | certifi==2024.8.30
20 | cffi==1.17.1
21 | cfgv==3.4.0
22 | chainer==7.8.1
23 | charset-normalizer==3.4.0
24 | clean-fid==0.1.35
25 | click==8.1.7
26 | colorama==0.4.6
27 | colossalai==0.4.6
28 | contourpy==1.3.0
29 | crcmod==1.7
30 | cryptography==43.0.3
31 | cycler==0.12.1
32 | decorator==5.1.1
33 | decord==0.6.0
34 | Deprecated==1.2.14
35 | diffusers==0.32.2
36 | distlib==0.3.9
37 | docstring_parser==0.16
38 | easydict==1.13
39 | einops==0.8.0
40 | exceptiongroup==1.2.2
41 | executing==2.1.0
42 | fabric==3.2.2
43 | facexlib==0.3.0
44 | fairscale==0.4.13
45 | fastapi==0.115.4
46 | ffmpy==0.5.0
47 | filelock==3.16.1
48 | filterpy==1.4.5
49 | fonttools==4.54.1
50 | frozenlist==1.5.0
51 | fsspec==2024.10.0
52 | ftfy==6.3.1
53 | future==1.0.0
54 | fvcore==0.1.5.post20221221
55 | galore-torch==1.0
56 | google==3.0.0
57 | gradio==5.12.0
58 | gradio_client==1.5.4
59 | grpcio==1.68.1
60 | h11==0.14.0
61 | httpcore==1.0.7
62 | httpx==0.28.1
63 | huggingface-hub==0.26.1
64 | icecream==2.1.3
65 | identify==2.6.1
66 | idna==3.10
67 | imageio==2.36.0
68 | importlib_metadata==8.5.0
69 | iniconfig==2.0.0
70 | invoke==2.2.0
71 | iopath==0.1.10
72 | ipython==8.29.0
73 | jedi==0.19.2
74 | Jinja2==3.1.4
75 | jmespath==0.10.0
76 | joblib==1.4.2
77 | jsonschema==4.23.0
78 | jsonschema-specifications==2024.10.1
79 | kiwisolver==1.4.7
80 | lazy_loader==0.4
81 | lightning-utilities==0.11.8
82 | llvmlite==0.43.0
83 | lmdb==1.5.1
84 | loguru==0.7.2
85 | lpips==0.1.4
86 | Markdown==3.7
87 | markdown-it-py==3.0.0
88 | MarkupSafe==2.1.5
89 | matplotlib==3.9.2
90 | matplotlib-inline==0.1.7
91 | mdurl==0.1.2
92 | mediapy==1.2.2
93 | memory-profiler==0.61.0
94 | mmengine==0.10.5
95 | mpmath==1.3.0
96 | msgpack==1.1.0
97 | multidict==6.1.0
98 | mypy-extensions==1.0.0
99 | networkx==3.4.2
100 | ninja==1.11.1.1
101 | nodeenv==1.9.1
102 | numba==0.60.0
103 | numpy==1.26.4
104 | nvidia-cublas-cu12==12.1.3.1
105 | nvidia-cuda-cupti-cu12==12.1.105
106 | nvidia-cuda-nvrtc-cu12==12.1.105
107 | nvidia-cuda-runtime-cu12==12.1.105
108 | nvidia-cudnn-cu12==8.9.2.26
109 | nvidia-cufft-cu12==11.0.2.54
110 | nvidia-curand-cu12==10.3.2.106
111 | nvidia-cusolver-cu12==11.4.5.107
112 | nvidia-cusparse-cu12==12.1.0.106
113 | nvidia-nccl-cu12==2.20.5
114 | nvidia-nvjitlink-cu12==12.4.127
115 | nvidia-nvtx-cu12==12.1.105
116 | omegaconf==2.3.0
117 | onnx==1.17.0
118 | openai-clip==1.0.1
119 | opencv-python==4.10.0.84
120 | opencv-python-headless==4.10.0.84
121 | opencv-transforms==0.0.6
122 | orjson==3.10.15
123 | oss2==2.19.1
124 | packaging==24.1
125 | pandas==2.2.3
126 | paramiko==3.5.0
127 | parso==0.8.4
128 | peft==0.13.2
129 | pexpect==4.9.0
130 | pillow==11.0.0
131 | platformdirs==4.3.6
132 | pluggy==1.5.0
133 | plumbum==1.9.0
134 | portalocker==3.0.0
135 | pre_commit==4.0.1
136 | prompt_toolkit==3.0.48
137 | propcache==0.2.0
138 | protobuf==5.28.3
139 | psutil==6.1.0
140 | ptyprocess==0.7.0
141 | pure_eval==0.2.3
142 | pyarrow==18.1.0
143 | pycparser==2.22
144 | pycryptodome==3.21.0
145 | pydantic==2.9.2
146 | pydantic_core==2.23.4
147 | pydub==0.25.1
148 | Pygments==2.18.0
149 | pyiqa==0.1.13
150 | PyNaCl==1.5.0
151 | pyparsing==3.2.0
152 | pytest==8.3.4
153 | python-dateutil==2.9.0.post0
154 | python-multipart==0.0.20
155 | pytorch-fid==0.3.0
156 | pytorch-lightning==2.4.0
157 | pytz==2024.2
158 | PyYAML==6.0.2
159 | ray==2.38.0
160 | referencing==0.35.1
161 | regex==2024.9.11
162 | requests==2.32.3
163 | rich==13.9.4
164 | rotary-embedding-torch==0.8.4
165 | rpds-py==0.20.1
166 | rpyc==6.0.0
167 | ruff==0.9.2
168 | safehttpx==0.1.6
169 | safetensors==0.4.5
170 | scikit-image==0.24.0
171 | scikit-learn==1.5.2
172 | scipy==1.14.1
173 | semantic-version==2.10.0
174 | sentencepiece==0.2.0
175 | shellingham==1.5.4
176 | six==1.16.0
177 | sniffio==1.3.1
178 | soupsieve==2.6
179 | stack-data==0.6.3
180 | starlette==0.41.2
181 | sympy==1.13.1
182 | tabulate==0.9.0
183 | tensorboard==2.18.0
184 | tensorboard-data-server==0.7.2
185 | termcolor==2.5.0
186 | threadpoolctl==3.5.0
187 | tifffile==2024.9.20
188 | timm==1.0.11
189 | tokenizers==0.15.2
190 | tomli==2.0.2
191 | tomlkit==0.13.2
192 | torch==2.3.0
193 | torch-dist @ http://visionai-wlcb.oss-cn-wulanchabu.aliyuncs.com/library/whl/torch_dist-1.2.9-py3-none-any.whl#sha256=92ae7dfce3b1adf1c02f041394c382aadc649247f397561db048a111ed7d7a81
194 | torch-fidelity==0.3.0
195 | torchaudio==2.3.0
196 | torchmetrics==1.5.1
197 | torchvision==0.18.0
198 | tqdm==4.66.5
199 | traitlets==5.14.3
200 | transformers==4.37.2
201 | triton==2.3.0
202 | typed-argument-parser==1.10.1
203 | typer==0.15.1
204 | typing-inspect==0.9.0
205 | typing_extensions==4.12.2
206 | tzdata==2024.2
207 | urllib3==2.2.3
208 | uuid==1.30
209 | uvicorn==0.29.0
210 | virtualenv==20.27.1
211 | wcwidth==0.2.13
212 | websockets==14.2
213 | Werkzeug==3.1.3
214 | wrapt==1.16.0
215 | xformers==0.0.26.post1
216 | yacs==0.1.8
217 | yapf==0.40.2
218 | yarl==1.17.1
219 | zipp==3.20.2
220 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditioned Diffusion-based Video Tokenizer (CDT)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | Official implementation for our paper:
10 |
11 | **Rethinking Video Tokenization: A Conditioned Diffusion-based Approach**
12 |
13 | Author List: Nianzu Yang†, Pandeng Li†, Liming Zhao, Yang Li, Chen-Wei Xie, Yehui Tang, Xudong Lu, Zhihang Liu, Yun Zheng, Yu Liu, Junchi Yan*
14 |
15 | † Equal contribution; * Corresponding author
16 |
17 |
18 | # Content
19 |
20 | - [Folder Specification](#folder-specification)
21 | - [Preparation](#preparation)
22 | - [Environment Setup](#environment-setup)
23 | - [Download Pre-trained Models](#download-pre-trained-models)
24 | - [Prepare Data](#prepare-data)
25 | - [Evaluation](#evaluation)
26 | - [Citation](#citation)
27 | - [Acknowledgement](#acknowledgement)
28 | - [Contact](#contact)
29 |
30 |
31 | ## Folder Specification
32 | ```bash
33 | ├── evaluate.py # script for evaluating the performance of CDT on reconstruction task
34 | ├── model # directory of CDT model
35 | │ └── cdt.py # definition of CDT model
36 | ├── opensora_evaluate # scripts for metrics calculation
37 | │ ├── cal_lpips.py # calculate LPIPS
38 | │ ├── cal_psnr.py # calculate PSNR
39 | │ └── cal_ssim.py # calculate SSIM
40 | ├── pretrained # directory of pretrained models, which you need to create yourself
41 | │ ├── cdt_base.ckpt # CDT-base
42 | │ └── cdt_small.ckpt # CDT-small
43 | ├── README.md # README
44 | ├── rec_image_eval.py # script for evaluating the performance of CDT on image reconstruction task
45 | ├── rec_video_eval.py # script for evaluating the performance of CDT on video reconstruction task
46 | ├── requirements.txt # dependencies
47 | └── utils.py # utility functions
48 | ```
49 |
50 | ## Preparation
51 |
52 | ### Environment Setup
53 |
54 | You can create a new environment and install the dependencies by running the following command:
55 | ```shell
56 | conda create -n cdt python=3.10
57 | conda activate cdt
58 | pip install -r requirements.txt
59 | ```
60 |
61 | ### Download Pre-trained Models
62 |
63 | We provide the pre-trained models, i.e., CDT-base and CDT-small, on [Hugging Face](https://huggingface.co/yangnianzu/CDT). You can download them and put them in the `pretrained` folder.
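
If you prefer to script the download, a minimal sketch using `huggingface_hub` is shown below; it assumes the two checkpoint files are hosted directly in the `yangnianzu/CDT` repository:

```python
# Sketch: fetch the CDT checkpoints into ./pretrained.
# Assumes the .ckpt files sit at the top level of the yangnianzu/CDT repo on Hugging Face.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="yangnianzu/CDT", local_dir="./pretrained")
```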
64 |
65 | ### Prepare Data
66 |
67 | In our paper, we use two datasets for benchmarking the reconstruction performance:
68 |
69 | - `COCO2017-val` for image reconstruction
70 | - `Webvid-val` for video reconstruction
71 |
72 | You can download these two datasets and put them in a local folder of your choice. Then, specify the paths of these two datasets in `evaluate.py`.
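
Concretely, `evaluate.py` reads the dataset locations from two module-level constants; point them at your local copies (the paths below are placeholders):

```python
# In evaluate.py -- dataset paths used for evaluation (edit to match your setup).
WEBVID_PATH = '/path/to/webvid/val'      # WebVid validation videos
COCO17_PATH = '/path/to/coco/val2017'    # COCO 2017 validation images
```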
73 |
74 | ## Evaluation
75 |
76 | Evaluate the performance of CDT-base on image reconstruction:
77 | ```shell
78 | python evaluate.py --method CDT-base --dataset coco17 --mode image
79 | ```
80 |
81 | Evaluate the performance of CDT-base on video reconstruction at the 256x256 resolution:
82 |
83 | ```shell
84 | python evaluate.py --method CDT-base --dataset webvid --mode video
85 | ```
86 |
87 | Evaluate the performance of CDT-base on video reconstruction at the 720x720 resolution:
88 |
89 | ```shell
90 | python evaluate.py --method CDT-base --dataset webvid --mode video --size 720
91 | ```
92 |
93 | ---
94 |
95 | Evaluate the performance of CDT-small on image reconstruction:
96 | ```shell
97 | python evaluate.py --method CDT-small --dataset coco17 --mode image
98 | ```
99 |
100 | Evaluate the performance of CDT-small on video reconstruction at the 256x256 resolution:
101 | ```shell
102 | python evaluate.py --method CDT-small --dataset webvid --mode video
103 | ```
104 |
105 | Evaluate the performance of CDT-small on video reconstruction at the 720x720 resolution:
106 | ```shell
107 | python evaluate.py --method CDT-small --dataset webvid --mode video --size 720
108 | ```
109 |
110 | The reconstructed images or videos and the evaluation results will be saved in the `reconstructed_results` folder.
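
By default, the whole validation set is evaluated. For a quick sanity check, the `--subset_size` flag restricts evaluation to the first N samples (0 means all); for example:

```shell
python evaluate.py --method CDT-base --dataset webvid --mode video --subset_size 1
```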
111 |
112 | ## Citation
113 | If you find this work useful in your research, please consider citing:
114 |
115 | ```bibtex
116 | @article{yang2025rethinking,
117 | title={Rethinking Video Tokenization: A Conditioned Diffusion-based Approach},
118 | author={Yang, Nianzu and Li, Pandeng and Zhao, Liming and Li, Yang and Xie, Chen-Wei and Tang, Yehui and Lu, Xudong and Liu, Zhihang and Zheng, Yun and Liu, Yu and Yan, Junchi},
119 | journal={arXiv preprint arXiv:2503.03708},
120 | year={2025}
121 | }
122 | ```
123 |
124 | ## Acknowledgement
125 |
126 | We would like to thank the authors of [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) for their excellent work, which provides the code for the evaluation metrics.
127 |
128 | ## Contact
129 | Feel free to contact us at [yangnianzu@sjtu.edu.cn](mailto:yangnianzu@sjtu.edu.cn) with any questions.
--------------------------------------------------------------------------------
/rec_image_eval.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 | import sys
8 | from torch.utils.data import DataLoader, Subset
9 | import os
10 | sys.path.append(".")
11 | from utils import AverageMeter, custom_to_images, SimpleImageDataset
12 | from opensora_evaluate.cal_lpips import calculate_lpips
13 | from opensora_evaluate.cal_psnr import calculate_psnr
14 | from opensora_evaluate.cal_ssim import calculate_ssim
15 | import time
16 | from model.cdt import load_cdt
17 |
18 |
19 | @torch.no_grad()
20 | def main(args: argparse.Namespace):
21 | real_data_dir = args.real_data_dir
22 | dataset = args.dataset
23 | device = args.device
24 | batch_size = args.batch_size
25 | num_workers = 4
26 | subset_size = args.subset_size
27 |
28 | if args.data_type == "bfloat16":
29 | data_type = torch.bfloat16
30 | elif args.data_type == "float32":
31 | data_type = torch.float32
32 | else:
33 | raise ValueError(f"Invalid data type: {args.data_type}")
34 |
35 | folder_name = f"{args.method}_{args.data_type}"
36 |
37 |
38 | generated_images_dir = os.path.join('./reconstructed_results/image_results/', dataset, folder_name)
39 | metrics_results = os.path.join('./reconstructed_results/image_results/', dataset, 'results.txt')
40 |
41 |
42 | if not os.path.exists(generated_images_dir):
43 | os.makedirs(generated_images_dir)
44 |
45 | # ---- Load Model ----
46 | device = args.device
47 | assert 'CDT' in args.method, f"method must be CDT, but got {args.method}"
48 | if 'base' in args.method:
49 | print(f"Loading CDT-base")
50 | vae = load_cdt('base')
51 | print(f"CDT-base Loaded")
52 | elif 'small' in args.method:
53 | print(f"Loading CDT-small")
54 | vae = load_cdt('small')
55 | print(f"CDT-small Loaded")
56 | vae = vae.to(device).to(data_type).eval()
57 | model_size = sum([p.numel() for p in vae.parameters()]) / 1e6
58 | print(f'Successfully loaded {args.method} model with {model_size:.3f} million parameters')
59 | # ---- Load Model ----
60 |
61 | # ---- Prepare Dataset ----
62 | dataset = SimpleImageDataset(image_dir=real_data_dir)
63 | print(f"Total images found: {len(dataset)}")
64 | if subset_size:
65 | indices = range(subset_size)
66 | dataset = Subset(dataset, indices=indices)
67 | dataloader = DataLoader(
68 | dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers
69 | )
70 | # ---- Prepare Dataset
71 |
72 | # ---- Inference ----
73 | avg_ssim = AverageMeter()
74 | avg_psnr = AverageMeter()
75 | avg_lpips = AverageMeter()
76 |
77 | log_txt = os.path.join(generated_images_dir, 'results.txt')
78 |
79 | step = 0
80 | total_time = 0
81 | total_images = 0
82 |
83 | with open(log_txt, 'a+') as f:
84 | for batch in tqdm(dataloader):
85 | step += 1
86 | x, file_names = batch['image'], batch['file_name']
87 | original_width = batch['original_width'][0]
88 | original_height = batch['original_height'][0]
89 | torch.cuda.empty_cache()
90 |
91 | x = x.to(device=device, dtype=data_type)
92 | x=x.unsqueeze(2)
93 | start_time = time.time()
94 | video_recon = vae(x)
95 | torch.cuda.synchronize()
96 | end_time = time.time()
97 | total_time += end_time - start_time
98 | total_images += 1
99 |
100 | x, video_recon = x.data.cpu().float(), video_recon.data.cpu().float()
101 |
102 |
103 | if not os.path.exists(generated_images_dir):
104 | os.makedirs(generated_images_dir, exist_ok=True)
105 |
106 | video_recon = video_recon.squeeze(2)
107 | for idx, image_recon in enumerate(video_recon):
108 | output_file = os.path.join(generated_images_dir, file_names[idx])
109 | custom_to_images(image_recon,output_file,original_height,original_width)
110 |
111 |
112 | video_recon = video_recon.unsqueeze(2)
113 | x = torch.clamp(x, -1, 1)
114 | x = (x + 1) / 2
115 | video_recon = torch.clamp(video_recon, -1, 1)
116 | video_recon = (video_recon + 1) / 2
117 | x = x.permute(0,2,1,3,4).float()
118 | video_recon = video_recon.permute(0,2,1,3,4).float()
119 |
120 | # SSIM
121 | tmp_list = list(calculate_ssim(x, video_recon)['value'].values())
122 | avg_ssim.update(np.mean(tmp_list))
123 |
124 | # PSNR
125 | tmp_list = list(calculate_psnr(x, video_recon)['value'].values())
126 | avg_psnr.update(np.mean(tmp_list))
127 |
128 | # LPIPS
129 | tmp_list = list(calculate_lpips(x, video_recon, args.device)['value'].values())
130 | avg_lpips.update(np.mean(tmp_list))
131 |
132 | if step % args.log_every_steps ==0:
133 | result = (
134 | f'Step: {step}, PSNR: {avg_psnr.avg}\n'
135 | f'Step: {step}, SSIM: {avg_ssim.avg}\n'
136 | f'Step: {step}, LPIPS: {avg_lpips.avg}\n')
137 | print(result, flush=True)
138 | f.write("="*20+'\n')
139 | f.write(result)
140 |
141 | final_result = (f'psnr: {avg_psnr.avg}\n'
142 | f'ssim: {avg_ssim.avg}\n'
143 | f'lpips: {avg_lpips.avg}')
144 |
145 | print("="*20)
146 | print("Final Results:")
147 | print(final_result)
148 | print("="*20)
149 | print(f'Eval Info:\nmethod: {args.method}\nreal_data_dir: {args.real_data_dir}')
150 | print("="*20)
151 |
152 |
153 |
154 | with open(metrics_results, 'a') as f:
155 | f.write("="*20+'\n')
156 | f.write(f'PSNR: {avg_psnr.avg}\n')
157 | f.write(f'SSIM: {avg_ssim.avg}\n')
158 | f.write(f'LPIPS: {avg_lpips.avg}\n')
159 | f.write(f'Time: {total_time}\n')
160 | f.write(f'Images Number: {total_images}\n')
161 | f.write(f'Avg Time: {total_time/total_images:.4f}\n')
162 | f.write(f'Method: {args.method}\n')
163 | f.write(f'Real Data Dir: {args.real_data_dir}\n')
164 | f.write(f'Data Type: {data_type}\n')
165 | f.write("="*20+'\n\n')
166 | # ---- Inference ----
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("--real_data_dir", type=str)
171 | parser.add_argument("--dataset", type=str, default='coco17')
172 | parser.add_argument("--method", type=str)
173 | parser.add_argument("--batch_size", type=int, default=1)
174 | parser.add_argument("--num_workers", type=int, default=8)
175 | parser.add_argument("--subset_size", type=int, default=100)
176 | parser.add_argument("--device", type=str, default="cuda")
177 | parser.add_argument("--data_type", type=str, default="float32", choices=["float32", "bfloat16"])
178 | parser.add_argument("--log_every_steps", type=int, default=50)
179 | args = parser.parse_args()
180 | main(args)
181 |
182 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import numpy.typing as npt
6 | import torchvision.transforms as T
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from decord import VideoReader, cpu
10 |
11 |
12 | class AverageMeter(object):
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0.0
18 | self.avg = 0.0
19 | self.sum = 0.0
20 | self.cnt = 0.0
21 |
22 | def update(self, val, n=1.0):
23 | self.val = val
24 | self.sum += val * n
25 | self.cnt += n
26 | if self.cnt == 0:
27 | self.avg = 1
28 | else:
29 | self.avg = self.sum / self.cnt
30 |
31 |
32 | # =============================== Image ===============================
33 |
34 | class Image_Crop(object):
35 | def __init__(self):
36 | pass
37 | def __call__(self, img):
38 | iw, ih = img.size
39 | ow = (iw // 8) * 8
40 | oh = (ih // 8) * 8
41 | return img.crop((0, 0, ow, oh))
42 |
43 | def custom_to_images(x, output_dir, h_i, h_w):
44 | x = x.detach().cpu()
45 | x = torch.clamp(x, -1, 1)
46 | x = (x + 1) / 2
47 | x = x.permute(1, 2, 0).float()
48 | x = (x.numpy() * 255).round().astype(np.uint8)
49 |
50 | img = Image.fromarray(x, mode='RGB')
51 | img = T.Resize((h_i, h_w))(img)
52 | img.save(output_dir)
53 |
54 | class SimpleImageDataset(Dataset):
55 | def __init__(self, image_dir):
56 |
57 | self.image_dir = image_dir
58 | self.image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
59 | self.transform = T.Compose([
60 | Image_Crop(),
61 | T.ToTensor(),
62 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
63 | ])
64 |
65 | def __len__(self):
66 | return len(self.image_files)
67 |
68 | def __getitem__(self, idx):
69 | img_name = self.image_files[idx]
70 | img_path = os.path.join(self.image_dir, img_name)
71 |
72 | image = Image.open(img_path).convert('RGB')
73 | original_width, original_height = image.size
74 | image = self.transform(image)
75 |
76 | return {
77 | 'image': image,
78 | 'file_name': img_name,
79 | 'original_width': original_width,
80 | 'original_height': original_height
81 | }
82 |
83 | # =============================== Video ===============================
84 |
85 | class CenterCrop(object):
86 |
87 | def __init__(self, size):
88 | self.size = size
89 |
90 | def __call__(self, img):
91 | # resize
92 | iw, ih, ow, oh = *img.size, *self.size
93 | scale = max(ow / iw, oh / ih)
94 | img = img.resize((round(scale * iw), round(scale * ih)), Image.LANCZOS)
95 |
96 | # center crop
97 | w, h = img.size
98 | if w > ow:
99 | x1 = (w - ow) // 2
100 | img = img.crop((x1, 0, x1 + ow, oh))
101 | elif h > oh:
102 | y1 = (h - oh) // 2
103 | img = img.crop((0, y1, ow, y1 + oh))
104 | return img
105 |
106 |
107 | def array_to_video(
108 | image_array: npt.NDArray, fps: float = 30.0, path: str = "output_video.mp4"
109 | ) -> None:
110 | frame_dir = path.replace('&', '_').replace('.mp4', '_temp')
111 | os.makedirs(frame_dir, exist_ok=True)
112 | for fid, frame in enumerate(image_array):
113 | tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
114 | cv2.imwrite(tpth, frame[:,:,::-1])
115 |
116 |
117 | cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {path}'
118 | os.system(cmd)
119 | os.system(f'rm -rf {frame_dir}')
120 |
121 |
122 | def custom_to_video(
123 | x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
124 | ) -> None:
125 | x = x.detach().cpu()
126 | x = torch.clamp(x, -1, 1)
127 | x = (x + 1) / 2
128 | x = x.permute(1, 2, 3, 0).float().numpy()
129 | x = (255 * x).astype(np.uint8)
130 | array_to_video(x, fps=fps, path=output_file)
131 | # breakpoint()
132 | return
133 |
134 |
135 | class RealVideoDataset(Dataset):
136 | def __init__(
137 | self,
138 | real_data_dir,
139 | num_frames,
140 | sample_rate=1,
141 | crop_size=None,
142 | resolution=128,
143 | ) -> None:
144 | super().__init__()
145 | self.real_video_files = self._combine_without_prefix(real_data_dir)
146 | self.num_frames = num_frames
147 | self.sample_rate = sample_rate
148 | self.crop_size = crop_size
149 |
150 | def __len__(self):
151 | return len(self.real_video_files)
152 |
153 | def __getitem__(self, index):
154 | if index >= len(self):
155 | raise IndexError
156 | real_video_file = self.real_video_files[index]
157 | real_video_tensor = self._load_video(real_video_file)
158 | video_name = os.path.basename(real_video_file)
159 | return {'video': real_video_tensor, 'file_name': video_name }
160 |
161 | def _load_video(self, video_path):
162 | num_frames = self.num_frames
163 | sample_rate = self.sample_rate
164 | decord_vr = VideoReader(video_path, ctx=cpu(0))
165 | total_frames = len(decord_vr)
166 | sample_frames_len = sample_rate * num_frames
167 |
168 | if total_frames > sample_frames_len:
169 | s = 0
170 | e = s + sample_frames_len
171 | num_frames = num_frames
172 | else:
173 | s = 0
174 | e = total_frames
175 | num_frames = int(total_frames / sample_frames_len * num_frames)
176 | print(
177 | f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
178 | video_path,
179 | total_frames,
180 | )
181 |
182 | frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
183 | video_data = decord_vr.get_batch(frame_id_list).asnumpy()
184 | frames = []
185 | for frame_ in range(video_data.shape[0]):
186 | frame_i = video_data[frame_,:,:,:]
187 | frame_i = Image.fromarray(frame_i)
188 | if frame_i.mode != 'RGB':
189 | frame_i = frame_i.convert('RGB')
190 | frame_i = _preprocess(frame_i, crop_size=self.crop_size)
191 | frames.append(frame_i)
192 | frames = torch.stack(frames, dim=1)
193 | return frames
194 |
195 | def _combine_without_prefix(self, folder_path, prefix="."):
196 | folder = []
197 | for name in os.listdir(folder_path):
198 | if name[0] == prefix:
199 | continue
200 | folder.append(os.path.join(folder_path, name))
201 | folder.sort()
202 | return folder
203 |
204 |
205 | def _preprocess(video_data, crop_size=None):
206 | transform=T.Compose([
207 | CenterCrop([crop_size,crop_size]),
208 | T.ToTensor(),
209 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
210 | ])
211 | video_outputs = transform(video_data)
212 | return video_outputs
--------------------------------------------------------------------------------
/rec_video_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import argparse
5 | from tqdm import tqdm
6 | import numpy as np
7 | import numpy.typing as npt
8 | import torch
9 | from torch.utils.data import DataLoader, Subset
10 | sys.path.append(".")
11 | from utils import *
12 | from opensora_evaluate.cal_lpips import calculate_lpips
13 | from opensora_evaluate.cal_psnr import calculate_psnr
14 | from opensora_evaluate.cal_ssim import calculate_ssim
15 | import time
16 | from model.cdt import load_cdt
17 |
18 |
19 | @torch.no_grad()
20 | def main(args: argparse.Namespace):
21 | real_data_dir = args.real_data_dir
22 | dataset = args.dataset
23 | sample_rate = args.sample_rate
24 | resolution = args.resolution
25 | crop_size = args.crop_size
26 | num_frames = args.num_frames
27 | sample_rate = args.sample_rate
28 | device = args.device
29 | sample_fps = args.sample_fps
30 | batch_size = args.batch_size
31 | num_workers = args.num_workers
32 | subset_size = args.subset_size
33 |
34 | if args.data_type == "bfloat16":
35 | data_type = torch.bfloat16
36 | elif args.data_type == "float32":
37 | data_type = torch.float32
38 | else:
39 | raise ValueError(f"Invalid data type: {args.data_type}")
40 |
41 |
42 | folder_name = f"{args.method}_{args.resolution}_{args.data_type}"
43 |
44 |
45 | generated_video_dir = os.path.join('./reconstructed_results/video_results/', dataset, folder_name)
46 | metrics_results = os.path.join('./reconstructed_results/video_results/', dataset, 'results.txt')
47 |
48 |
49 | if not os.path.exists(generated_video_dir):
50 | os.makedirs(generated_video_dir)
51 |
52 |
53 | # ---- Load Model ----
54 | device = args.device
55 | assert 'CDT' in args.method, f"method must be CDT, but got {args.method}"
56 | if 'base' in args.method:
57 | print(f"Loading CDT-base")
58 | vae = load_cdt('base')
59 | print(f"CDT-base Loaded")
60 | elif 'small' in args.method:
61 | print(f"Loading CDT-small")
62 | vae = load_cdt('small')
63 | print(f"CDT-small Loaded")
64 | vae = vae.to(device).to(data_type).eval()
65 | model_size = sum([p.numel() for p in vae.parameters()]) / 1e6
66 | print(f'Successfully loaded {args.method} model with {model_size:.3f} million parameters')
67 | # ---- Load Model ----
68 |
69 |
70 | # ---- Prepare Dataset ----
71 | dataset = RealVideoDataset(
72 | real_data_dir=real_data_dir,
73 | num_frames=num_frames,
74 | sample_rate=sample_rate,
75 | crop_size=crop_size,
76 | resolution=resolution,
77 | )
78 | if subset_size:
79 | indices = range(subset_size)
80 | dataset = Subset(dataset, indices=indices)
81 | dataloader = DataLoader(
82 | dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
83 | )
84 | # ---- Prepare Dataset
85 |
86 |
87 |
88 | # ---- Inference ----
89 |
90 | avg_ssim = AverageMeter()
91 | avg_psnr = AverageMeter()
92 | avg_lpips = AverageMeter()
93 |
94 | log_txt = os.path.join(generated_video_dir, 'results.txt')
95 |
96 | total_time = 0
97 | total_videos = 0
98 | step = 0
99 |
100 | with open(log_txt, 'a+') as f:
101 | for batch in tqdm(dataloader):
102 | step += 1
103 | x, file_names = batch['video'], batch['file_name']
104 | if x.size(2) < args.num_frames:
105 | print(file_names)
106 | continue
107 | torch.cuda.empty_cache()
108 | x = x.to(device=device, dtype=data_type)
109 |
110 | start_time = time.time()
111 | video_recon = vae(x)
112 | torch.cuda.synchronize()
113 | end_time = time.time()
114 | total_time += end_time - start_time
115 | total_videos += 1
116 |
117 |
118 | x, video_recon = x.data.cpu().float(), video_recon.data.cpu().float()
119 |
120 | # save reconstructed video
121 | if not os.path.exists(generated_video_dir):
122 | os.makedirs(generated_video_dir, exist_ok=True)
123 | for idx, video in enumerate(video_recon):
124 | output_path = os.path.join(generated_video_dir, file_names[idx])
125 | custom_to_video(
126 | video, fps=sample_fps / sample_rate, output_file=output_path
127 | )
128 |
129 | x = torch.clamp(x, -1, 1)
130 | x = (x + 1) / 2
131 | video_recon = torch.clamp(video_recon, -1, 1)
132 | video_recon = (video_recon + 1) / 2
133 |
134 | x = x.permute(0,2,1,3,4).float()
135 | video_recon = video_recon.permute(0,2,1,3,4).float()
136 |
137 | # SSIM
138 | tmp_list = list(calculate_ssim(x, video_recon)['value'].values())
139 | avg_ssim.update(np.mean(tmp_list))
140 |
141 | # PSNR
142 | tmp_list = list(calculate_psnr(x, video_recon)['value'].values())
143 | avg_psnr.update(np.mean(tmp_list))
144 |
145 | # LPIPS
146 | tmp_list = list(calculate_lpips(x, video_recon, args.device)['value'].values())
147 | avg_lpips.update(np.mean(tmp_list))
148 |
149 | if step % args.log_every_steps ==0:
150 | result = (
151 | f'Step: {step}, PSNR: {avg_psnr.avg}\n'
152 | f'Step: {step}, SSIM: {avg_ssim.avg}\n'
153 | f'Step: {step}, LPIPS: {avg_lpips.avg}\n')
154 | print(result, flush=True)
155 | f.write("="*20+'\n')
156 | f.write(result)
157 |
158 |
159 | final_result = (f'psnr: {avg_psnr.avg}\n'
160 | f'ssim: {avg_ssim.avg}\n'
161 | f'lpips: {avg_lpips.avg}')
162 |
163 | print("="*20)
164 | print("Final Results:")
165 | print(final_result)
166 | print("="*20)
167 | print(f'Eval Info:\nmethod: {args.method}\nresolution: {args.resolution}\nnum_frames: {args.num_frames}\nreal_data_dir: {args.real_data_dir}')
168 | print("="*20)
169 |
170 | with open(metrics_results, 'a') as f:
171 | f.write("="*20+'\n')
172 | f.write(f'PSNR: {avg_psnr.avg}\n')
173 | f.write(f'SSIM: {avg_ssim.avg}\n')
174 | f.write(f'LPIPS: {avg_lpips.avg}\n')
175 | f.write(f'Time: {total_time}\n')
176 | f.write(f'Videos Number: {total_videos}\n')
177 | f.write(f'Avg Time: {total_time/total_videos:.4f}\n')
178 | f.write(f'Method: {args.method}\n')
179 | f.write(f'Resolution: {args.resolution}\n')
180 | f.write(f'Num Frames: {args.num_frames}\n')
181 | f.write(f'Real Data Dir: {args.real_data_dir}\n')
182 | f.write(f'Data Type: {data_type}\n')
183 | f.write("="*20+'\n\n')
184 | # ---- Inference ----
185 |
186 |
187 | if __name__ == "__main__":
188 |
189 | parser = argparse.ArgumentParser()
190 | parser.add_argument("--real_data_dir", type=str)
191 | parser.add_argument("--dataset", type=str, default='webvid')
192 | parser.add_argument("--method", type=str)
193 | parser.add_argument("--sample_fps", type=int, default=10)
194 | parser.add_argument("--resolution", type=int, default=256)
195 | parser.add_argument("--crop_size", type=int, default=256)
196 | parser.add_argument("--log_every_steps", type=int, default=50)
197 | parser.add_argument("--num_frames", type=int, default=17) # number of frames for video evaluation
198 | parser.add_argument("--sample_rate", type=int, default=1)
199 | parser.add_argument("--batch_size", type=int, default=1)
200 | parser.add_argument("--num_workers", type=int, default=8)
201 | parser.add_argument("--subset_size", type=int, default=0)
202 | parser.add_argument("--device", type=str, default="cuda")
203 | parser.add_argument("--data_type", type=str, default="float32", choices=["float32", "bfloat16"])
204 |
205 | args = parser.parse_args()
206 | main(args)
207 |
208 |
--------------------------------------------------------------------------------
/model/cdt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import logging
4 | import numpy as np
5 | from functools import partial
6 | from einops import rearrange, repeat, pack, unpack
7 | from typing import Any, Dict, Tuple, Union
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | CLIP_FIRST_CHUNK = False
15 | CACHE_T = 2
16 | INT_MAX = 2**31
17 |
18 |
19 | class CausalConv3d(nn.Conv3d):
20 | """
21 | Causal 3D convolution.
22 | """
23 | PAD_MODE = "replicate"
24 |
25 | def __init__(self, *args, **kwargs):
26 | super().__init__(*args, **kwargs)
27 | self._padding = (
28 | self.padding[2],
29 | self.padding[2],
30 | self.padding[1],
31 | self.padding[1],
32 | 2 * self.padding[0],
33 | 0
34 | )
35 | arg_list = list(args)
36 | if len(arg_list)>=5:
37 | arg_list[4] = 0
38 | elif 'padding' in kwargs:
39 | kwargs['padding']= 0
40 | super().__init__(*arg_list, **kwargs)
41 | nn.init.zeros_(self.weight)
42 | if self.bias is not None:
43 | nn.init.zeros_(self.bias)
44 |
45 | def forward(self, x, cache_x=None):
46 | padding = list(self._padding)
47 | if cache_x is not None and self._padding[4] > 0:
48 | cache_x = cache_x.to(x.device)
49 | x = torch.cat([cache_x, x], dim=2)
50 | padding[4] -= cache_x.shape[2]
51 | x = F.pad(x, padding, mode=self.PAD_MODE)
52 |
53 | if x.numel() > INT_MAX:
54 | t = x.shape[2]
55 | kernel_t = self.kernel_size[0]
56 | num_split = max(1, t - kernel_t + 1)
57 | out_list = []
58 | for i in range(num_split):
59 | x_s = x[:, :, i:i+kernel_t, :, :]
60 | out_list.append(super().forward(x_s))
61 | out = torch.cat(out_list, dim=2)
62 | del out_list
63 | else:
64 | out = super().forward(x)
65 |
66 | return out
67 |
68 |
69 | class Residual(nn.Module):
70 | def __init__(self, fn):
71 | super().__init__()
72 | self.fn = fn
73 |
74 | def forward(self, x, *args, **kwargs):
75 | return self.fn(x, *args, **kwargs) + x
76 |
77 |
78 | class Fp32Upsample(nn.Upsample):
79 |
80 | def forward(self, x):
81 | """
82 | Fix bfloat16 support for nearest neighbor interpolation.
83 | """
84 | return super().forward(x.float()).type_as(x)
85 |
86 |
87 | class Upsample(nn.Module):
88 | def __init__(self, dim_in, dim_out=None, new_upsample=False):
89 | super().__init__()
90 | if dim_out is None:
91 | dim_out = dim_in
92 | if new_upsample:
93 | self.up = Fp32Upsample(scale_factor=(1.0, 2.0, 2.0), mode='nearest-exact')
94 | self.conv = CausalConv3d(dim_in, dim_out, (1, 5, 5), padding=(0, 2, 2))
95 | else:
96 | self.up = None
97 | self.conv = nn.ConvTranspose3d(dim_in, dim_out, (1, 5, 5), (1, 2, 2), (0, 2, 2),
98 | output_padding=(0, 1, 1))
99 |
100 | def forward(self, x):
101 | if self.up is not None:
102 | x = self.up(x)
103 | x = self.conv(x)
104 | return x
105 |
106 |
107 | class Downsample(nn.Module):
108 | def __init__(self, dim_in, dim_out=None):
109 | super().__init__()
110 | if dim_out is None:
111 | dim_out = dim_in
112 | self.conv = CausalConv3d(dim_in, dim_out, (1, 3, 3), (1, 2, 2), (0, 1, 1))
113 |
114 | def forward(self, x, feat_cache=None, feat_idx=[0]):
115 | x = self.conv(x)
116 | return x
117 |
118 |
119 | class TimeDownsample(nn.Module):
120 | def __init__(
121 | self,
122 | dim,
123 | dim_out = None,
124 | kernel_size = 3,
125 | antialias = False
126 | ):
127 | super().__init__()
128 | dim_out = dim_out or dim
129 | self.time_causal_padding = (kernel_size - 1, 0)
130 | self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride = 2)
131 | nn.init.zeros_(self.conv.bias)
132 | nn.init.zeros_(self.conv.weight)
133 | self.conv.weight.data[range(dim_out), range(dim), -2] = 1.0
134 | self.conv.weight.data[range(dim_out), range(dim), -1] = 1.0
135 | self.cache_t = 1
136 |
137 | def forward(self, x, feat_cache=None, feat_idx=[0]):
138 | x = rearrange(x, 'b c t h w -> b h w c t').contiguous()
139 | x, ps = pack([x], '* c t')
140 |
141 | if feat_cache is not None:
142 | idx = feat_idx[0]
143 | cache_x = x[..., -self.cache_t:].clone()
144 |
145 | if feat_cache[idx] is not None and self.time_causal_padding[0] > 0:
146 | x = torch.cat([feat_cache[idx], x], dim=-1)
147 | else:
148 | x = F.pad(x, self.time_causal_padding, mode=CausalConv3d.PAD_MODE)
149 | out = self.conv(x)
150 | feat_cache[idx] = cache_x
151 | feat_idx[0] += 1
152 | else:
153 | x = F.pad(x, self.time_causal_padding, mode=CausalConv3d.PAD_MODE)
154 | out = self.conv(x)
155 |
156 | out = unpack(out, ps, '* c t')[0]
157 | out = rearrange(out, 'b h w c t -> b c t h w').contiguous()
158 | return out
159 |
160 |
161 | class TimeUpsample(nn.Module):
162 | def __init__(
163 | self,
164 | dim
165 | ):
166 | super().__init__()
167 | self.conv = nn.Conv1d(dim, dim*2, 1)
168 | self.init_conv_(self.conv)
169 |
170 | def init_conv_(self, conv):
171 | o, i, t = conv.weight.shape
172 | conv_weight = torch.zeros(o // 2, i, t)
173 | conv_weight[range(o//2), range(i), :] = 1.0
174 | conv_weight = repeat(conv_weight, 'o ... -> (o 2) ...')
175 |
176 | conv.weight.data.copy_(conv_weight)
177 | nn.init.zeros_(conv.bias.data)
178 |
179 | def forward(self, x):
180 | x = rearrange(x, 'b c t h w -> b h w c t').contiguous()
181 | x, ps = pack([x], '* c t')
182 |
183 | out = self.conv(x)
184 |
185 | out = rearrange(out, 'b (c p) t -> b c (t p)', p = 2).contiguous()
186 | out = unpack(out, ps, '* c t')[0]
187 | out = rearrange(out, 'b h w c t -> b c t h w').contiguous()
188 | if CLIP_FIRST_CHUNK:
189 | out = out[:, :, 1:, :, :]
190 | return out
191 |
192 |
193 | class LayerNorm(nn.Module):
194 | def __init__(self, dim, eps=1e-5):
195 | super().__init__()
196 | self.eps = eps
197 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
198 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
199 |
200 | def forward(self, x):
201 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
202 | mean = torch.mean(x, dim=1, keepdim=True)
203 | return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
204 |
205 |
206 | class PreNorm(nn.Module):
207 | def __init__(self, dim, fn):
208 | super().__init__()
209 | self.fn = fn
210 | self.norm = LayerNorm(dim)
211 |
212 | def forward(self, x, feat_cache=None, feat_idx=[0]):
213 | x = self.norm(x)
214 | return self.fn(x, feat_cache=feat_cache, feat_idx=feat_idx)
215 |
216 |
217 | class Block(nn.Module):
218 | def __init__(self, dim, dim_out, large_filter=False):
219 | super().__init__()
220 | self.block = nn.ModuleList([
221 | CausalConv3d(dim, dim_out, (3, 7, 7) if large_filter else (3, 3, 3),
222 | padding=(1, 3, 3) if large_filter else (1, 1, 1)),
223 | LayerNorm(dim_out), nn.ReLU()]
224 | )
225 |
226 | def forward(self, x, feat_cache=None, feat_idx=[0]):
227 | for layer in self.block:
228 | if isinstance(layer, CausalConv3d) and feat_cache is not None:
229 | idx = feat_idx[0]
230 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
231 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
232 | # cache the last frame of the last two chunks
233 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
234 | x = layer(x, feat_cache[idx])
235 | feat_cache[idx] = cache_x
236 | feat_idx[0] += 1
237 | else:
238 | x = layer(x)
239 | return x
240 |
241 |
242 |
243 | def module_forward(backbone, x, t=None, feat_cache=None, feat_idx=[0]):
244 | if t is None:
245 | if isinstance(backbone, CausalConv3d):
246 | if feat_cache is not None:
247 | idx = feat_idx[0]
248 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
249 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
250 | # cache the last frame of the last two chunks
251 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2)
252 | x = backbone(x, feat_cache[idx])
253 | feat_cache[idx] = cache_x
254 | feat_idx[0] += 1
255 | else:
256 | x = backbone(x)
257 | elif type(backbone) in [Upsample, TimeUpsample, nn.Identity]:
258 | x = backbone(x)
259 | else:
260 | x = backbone(x, feat_cache=feat_cache, feat_idx=feat_idx)
261 | else:
262 | x = backbone(x, t, feat_cache=feat_cache, feat_idx=feat_idx)
263 | return x
264 |
265 | class ResnetBlock(nn.Module):
266 | def __init__(self, dim, dim_out, time_emb_dim=None, large_filter=False):
267 | super().__init__()
268 | self.mlp = (
269 | nn.Sequential(nn.LeakyReLU(0.2), nn.Linear(time_emb_dim, dim_out))
270 | if time_emb_dim is not None
271 | else None
272 | )
273 |
274 | self.block1 = Block(dim, dim_out, large_filter)
275 | self.block2 = Block(dim_out, dim_out)
276 | self.res_conv = CausalConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
277 |
278 | def forward(self, x, time_emb=None, feat_cache=None, feat_idx=[0]):
279 | h = self.block1(x, feat_cache=feat_cache, feat_idx=feat_idx)
280 |
281 | if time_emb is not None:
282 | h = h + self.mlp(time_emb)[:, :, None, None, None]
283 |
284 | h = self.block2(h, feat_cache=feat_cache, feat_idx=feat_idx)
285 |
286 | return h + self.res_conv(x)
287 |
288 |
289 | class LinearAttention(nn.Module):
290 | def __init__(self, dim, heads=1, dim_head=None):
291 | super().__init__()
292 | if dim_head is None:
293 | dim_head = dim
294 | self.scale = dim_head ** -0.5
295 | self.heads = heads
296 | hidden_dim = dim_head * heads
297 | self.to_qkv = CausalConv3d(dim, hidden_dim * 3, 1, bias=False)
298 | self.to_out = CausalConv3d(hidden_dim, dim, 1)
299 |
300 | def forward(self, x, feat_cache=None, feat_idx=[0]):
301 | b, c, t, h, w = x.shape
302 | qkv = self.to_qkv(x).chunk(3, dim=1)
303 | q, k, v = map(lambda t: rearrange(t, "b (h c) t x y -> (b t) h c (x y)", h=self.heads), qkv)
304 |
305 | q = q.softmax(dim=-2)
306 | k = k.softmax(dim=-1)
307 |
308 | q = q * self.scale
309 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
310 |
311 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
312 | out = rearrange(out, "(b t) h c (x y) -> b (h c) t x y", h=self.heads, x=h, y=w, t=t)
313 | return self.to_out(out)
314 |
315 |
316 | def count_conv3d(model):
317 | count = 0
318 | for m in model.modules():
319 | if isinstance(m, CausalConv3d):
320 | count += 1
321 | return count
322 |
323 | def count_time_down_sample(model):
324 | count = 0
325 | for m in model.modules():
326 | if isinstance(m, TimeDownsample):
327 | count += 1
328 | return count
329 |
330 |
331 | class DiagonalGaussianDistribution(object):
332 | def __init__(self, parameters, deterministic=False):
333 | self.parameters = parameters
334 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
335 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
336 | self.deterministic = deterministic
337 | self.std = torch.exp(0.5 * self.logvar)
338 | self.var = torch.exp(self.logvar)
339 | if self.deterministic:
340 | self.var = self.std = torch.zeros_like(self.mean).to(
341 | device=self.parameters.device
342 | )
343 |
344 | def sample(self):
345 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
346 | device=self.parameters.device
347 | )
348 | return x
349 |
350 | def mode(self):
351 | return self.mean
352 |
353 |
354 | class Compressor(nn.Module):
355 | def __init__(
356 | self,
357 | dim=64,
358 | dim_mults=(1, 2, 3, 4),
359 | reverse_dim_mults=(4, 3, 2, 1),
360 | space_down=(1, 1, 1, 1),
361 | time_down=(0, 0, 1, 1),
362 | new_upsample=False,
363 | channels=3,
364 | out_channels=None,
365 | latent_dim=64,
366 | ):
367 | super().__init__()
368 | self.channels = channels
369 | out_channels = out_channels or channels
370 | self.space_down = space_down
371 | self.new_upsample = new_upsample
372 | self.reversed_space_down = list(reversed(self.space_down))
373 | self.time_down = time_down
374 | self.reversed_time_down = list(reversed(self.time_down))
375 | self.dims = [channels, *map(lambda m: dim * m, dim_mults)]
376 | self.in_out = list(zip(self.dims[:-1], self.dims[1:]))
377 | self.reversed_dims = [*map(lambda m: dim * m, reverse_dim_mults), out_channels]
378 | self.reversed_in_out = list(zip(self.reversed_dims[:-1], self.reversed_dims[1:]))
379 | assert self.dims[-1] == self.reversed_dims[0]
380 | latent_dim = latent_dim or out_channels
381 | self.quant_conv = torch.nn.Conv3d(self.dims[-1], 2 * latent_dim, 1)
382 | self.post_quant_conv = torch.nn.Conv3d(latent_dim, self.dims[-1], 1)
383 | self.quant_conv_res = nn.ModuleList(
384 | [ResnetBlock(dim_in, dim_out) for dim_in, dim_out in self.reversed_in_out[:-1]]+
385 | [torch.nn.Conv3d(self.reversed_dims[-2], 2 * latent_dim, 1)])
386 | self.post_quant_conv_res = nn.ModuleList(
387 | [torch.nn.Conv3d(latent_dim, self.dims[1], 1)]+
388 | [ResnetBlock(dim_in, dim_out) for dim_in, dim_out in self.in_out[1:]])
389 | # build network
390 | self.build_network()
391 |
392 | # cache the last two frame of feature map
393 | self._conv_num = count_conv3d(self.post_quant_conv_res) + count_conv3d(self.dec)
394 | self._conv_idx = [0]
395 | self._feat_map = [None] * self._conv_num
396 | # cache encode
397 | self._enc_conv_num = count_conv3d(self.quant_conv_res) + count_conv3d(self.enc) + count_time_down_sample(self.enc)
398 | self._enc_conv_idx = [0]
399 | self._enc_feat_map = [None] * self._enc_conv_num
400 |
401 |
402 | @property
403 | def dtype(self):
404 | return self.enc[0][0].block1.block[0].weight.dtype
405 |
406 | def build_network(self):
407 | self.enc = nn.ModuleList([])
408 | self.dec = nn.ModuleList([])
409 |
410 | for ind, (dim_in, dim_out) in enumerate(self.in_out):
411 | self.enc.append(
412 | nn.ModuleList(
413 | [
414 | ResnetBlock(dim_in, dim_out, None, True if ind == 0 else False),
415 | Downsample(dim_out) if self.space_down[ind] else nn.Identity(),
416 | TimeDownsample(dim_out) if self.time_down[ind] else nn.Identity(),
417 | ]
418 | )
419 | )
420 |
421 |
422 |
423 | for ind, (dim_in, dim_out) in enumerate(self.reversed_in_out):
424 | is_last = ind >= (len(self.reversed_in_out) - 1)
425 | mapping_or_identity = ResnetBlock(dim_in, dim_out) if not self.reversed_space_down[ind] and is_last else nn.Identity()
426 | upsample_layer = Upsample(dim_out if not is_last else dim_in, dim_out, self.new_upsample) if self.reversed_space_down[ind] else mapping_or_identity
427 | self.dec.append(
428 | nn.ModuleList(
429 | [
430 | ResnetBlock(dim_in, dim_out if not is_last else dim_in),
431 | upsample_layer,
432 | TimeUpsample(dim_out) if self.reversed_time_down[ind] else nn.Identity(),
433 | ]
434 | )
435 | )
436 |
437 | def encode(self, input, deterministic=True):
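    | """Encode one video chunk into a latent.
    |
    | Runs the downsampling stages (updating the encoder's causal cache), fuses the output of
    | `quant_conv` with the residual `quant_conv_res` path, and builds a diagonal Gaussian
    | posterior. Returns `(z, mean_logvar)`, where `z` is the posterior mode if `deterministic`
    | else a sample.
    | """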
438 | self._enc_conv_idx = [0]
439 |
440 | input = input.to(self.dtype)
441 | for i, (resnet, down, time_down) in enumerate(self.enc):
442 | input = resnet(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
443 | input = down(input)
444 | if isinstance(time_down, TimeDownsample):
445 | input = time_down(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
446 | else:
447 | input = time_down(input)
448 | input = input.float()
449 | quant_conv = self.quant_conv.float()
450 | quant_conv_res = self.quant_conv_res.float()
451 |
452 | conv1_out = quant_conv(input)
453 | for layer in quant_conv_res:
454 | if not isinstance(layer, ResnetBlock):
455 | input = layer(input)
456 | else: # ResnetBlock
457 | input = layer(input, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
458 | input += conv1_out
459 |
460 | posterior = DiagonalGaussianDistribution(input)
461 | if deterministic:
462 | z = posterior.mode()
463 | else:
464 | z = posterior.sample()
465 | return z, input
466 |
467 | def decode(self, input):
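    | """Decode a latent chunk into multi-scale context features.
    |
    | Applies `post_quant_conv` plus its residual path, then runs the upsampling stages and
    | collects every stage's output. The list is returned finest-resolution first, so that
    | `context[idx]` matches the idx-th downsampling stage of the U-Net.
    | """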
468 | self._conv_idx = [0]
469 |
470 | input = input.float()
471 | post_quant_conv = self.post_quant_conv.float()
472 | post_quant_conv_res = self.post_quant_conv_res.float()
473 |
474 | post1 = post_quant_conv(input)
475 | for layer in post_quant_conv_res:
476 | if not isinstance(layer, ResnetBlock):
477 | input = layer(input)
478 | else: # ResnetBlock
479 | input = layer(input, feat_cache=self._feat_map, feat_idx=self._conv_idx)
480 | input += post1
481 |
482 | input = input.to(self.dtype)
483 | output = []
484 | for i, (resnet, up, time_up) in enumerate(self.dec):
485 | input = resnet(input, feat_cache=self._feat_map, feat_idx=self._conv_idx)
486 | input = up(input)
487 | input = time_up(input)
488 | output.append(input)
489 | return output[::-1]
490 |
491 | def clear_cache(self):
492 | self._feat_map = [None] * self._conv_num
493 | self._enc_feat_map = [None] * self._enc_conv_num
494 | self._conv_idx = [0]
495 | self._enc_conv_idx = [0]
496 |
497 |
498 | class Unet(nn.Module):
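    | """Conditional denoising U-Net operating on video chunks.
    |
    | Context features produced by the Compressor decoder are concatenated to the input of the
    | first `condition_times` downsampling stages. Each stage uses two ResnetBlocks and linear
    | attention; causal 3D convolutions cache their trailing frames so chunks can be denoised
    | sequentially. With `embd_type="01"` the normalized timestep is fed through a small MLP to
    | produce the time embedding.
    | """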
499 | def __init__(
500 | self,
501 | dim,
502 | out_dim=None,
503 | dim_mults=(1, 2, 4, 8),
504 | context_dim=None,
505 | context_out_channels=None,
506 | context_dim_mults=(1, 2, 3, 3),
507 | space_down=(1, 1, 1, 1),
508 | time_down=(0, 0, 0, 1),
509 | channels=3,
510 | with_time_emb=True,
511 | new_upsample=False,
512 | embd_type="01",
513 | condition_times=4,
514 | ):
515 | super().__init__()
516 | self.channels = channels
517 |
518 | dims = [channels, *map(lambda m: dim * m, dim_mults)]
519 | context_dim = context_dim or dim
520 | context_out_channels = context_out_channels or context_dim
521 | context_dims = [context_out_channels, *map(lambda m: context_dim * m, context_dim_mults)]
522 | self.space_down = list(space_down) + [1] * (len(dim_mults) - len(space_down) - 1) + [0]  # list() so tuple defaults also work
523 | self.reversed_space_down = list(reversed(self.space_down[:-1]))
524 | self.time_down = list(time_down) + [0] * (len(dim_mults) - len(time_down))
525 | self.reversed_time_down = list(reversed(self.time_down[:-1]))
526 | in_out = list(zip(dims[:-1], dims[1:]))
527 | self.embd_type = embd_type
528 | self.condition_times = condition_times
529 |
530 | if with_time_emb:
531 | if embd_type == "01":
532 | time_dim = dim
533 | self.time_mlp = nn.Sequential(nn.Linear(1, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
534 | else:
535 | raise NotImplementedError
536 | else:
537 | time_dim = None
538 | self.time_mlp = None
539 |
540 | self.downs = nn.ModuleList([])
541 | self.ups = nn.ModuleList([])
542 | num_resolutions = len(in_out)
543 |
544 | for ind, (dim_in, dim_out) in enumerate(in_out):
545 | is_last = ind >= (num_resolutions - 1)
546 | self.downs.append(
547 | nn.ModuleList(
548 | [
549 | ResnetBlock(
550 | dim_in + context_dims[ind]
551 | if (not is_last) and (ind < self.condition_times)
552 | else dim_in,
553 | dim_out,
554 | time_dim,
555 | True if ind == 0 else False
556 | ),
557 | ResnetBlock(dim_out, dim_out, time_dim),
558 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
559 | Downsample(dim_out) if self.space_down[ind] else nn.Identity(),
560 | TimeDownsample(dim_out) if self.time_down[ind] else nn.Identity(),
561 | ]
562 | )
563 | )
564 |
565 |
566 | mid_dim = dims[-1]
567 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_dim)
568 | self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
569 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_dim)
570 |
571 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
572 |
573 | is_last = ind >= (num_resolutions - 1)
574 | self.ups.append(
575 | nn.ModuleList(
576 | [
577 | ResnetBlock(dim_out * 2, dim_in, time_dim),
578 | ResnetBlock(dim_in, dim_in, time_dim),
579 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
580 | Upsample(dim_in, new_upsample=new_upsample) if self.reversed_space_down[ind] else nn.Identity(),
581 | TimeUpsample(dim_in) if self.reversed_time_down[ind] else nn.Identity(),
582 | ]
583 | )
584 | )
585 |
586 | out_dim = out_dim or channels
587 | self.final_conv = nn.ModuleList([LayerNorm(dim), CausalConv3d(dim, out_dim, (3, 7, 7), padding=(1, 3, 3))])
588 |
589 | # cache the last two frames of each feature map (for chunked causal inference)
590 | self._conv_num = count_conv3d(self.ups) + count_conv3d(self.downs) + \
591 | count_conv3d(self.mid_block1) + count_conv3d(self.mid_block2) + \
592 | count_conv3d(self.final_conv) + count_time_down_sample(self.downs)
593 | self._conv_idx = [0]
594 | self._feat_map = [None] * self._conv_num
595 |
596 | @property
597 | def dtype(self):
598 | return self.final_conv[1].weight.dtype
599 |
600 | def encode(self, x, t, context):
601 | h = []
602 | for idx, (backbone, backbone2, attn, downsample, time_downsample) in enumerate(self.downs):
603 | x = torch.cat([x, context[idx]], dim=1) if idx < self.condition_times else x
604 | x = backbone(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
605 | x = backbone2(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
606 | x = attn(x, feat_cache=self._feat_map, feat_idx=self._conv_idx)
607 | h.append(x)
608 | x = downsample(x)
609 | if isinstance(time_downsample, TimeDownsample):
610 | x = time_downsample(x, feat_cache=self._feat_map, feat_idx=self._conv_idx)
611 | else:
612 | x = time_downsample(x)
613 |
614 | x = self.mid_block1(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
615 | return x, h
616 |
617 | def decode(self, x, h, t):
618 | device = x.device
619 | dtype = x.dtype
620 | x = self.mid_attn(x)
621 | x = self.mid_block2(x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
622 | for backbone, backbone2, attn, upsample, time_upsample in self.ups:
623 | reference = h.pop()
624 | if x.shape[2:] != reference.shape[2:]:
625 | x = F.interpolate(
626 | x.float(), size=reference.shape[2:], mode='nearest'
627 | ).type_as(x)
628 | x = torch.cat((x, reference), dim=1)
629 | x = module_forward(backbone, x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
630 | x = module_forward(backbone2, x, t, feat_cache=self._feat_map, feat_idx=self._conv_idx)
631 | x = module_forward(attn, x)
632 | x = module_forward(upsample, x)
633 | x = module_forward(time_upsample, x)
634 | x = x.to(device).to(dtype)
635 |
636 | x = self.final_conv[0](x)
637 | if self._feat_map is not None:
638 | idx = self._conv_idx[0]
639 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
640 | if cache_x.shape[2] < 2 and self._feat_map[idx] is not None:
641 | # prepend the cached last frame of the previous chunk
642 | cache_x = torch.cat([self._feat_map[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2)
643 | x = self.final_conv[1](x, self._feat_map[idx])
644 | self._feat_map[idx] = cache_x
645 | self._conv_idx[0] += 1
646 | else:
647 | x = self.final_conv[1](x)
648 | return x
649 |
650 | def forward(self, x, time=None, context=None):
651 | self._conv_idx = [0]
652 | t = None
653 | if self.time_mlp is not None:
654 | time_mlp = self.time_mlp.float()
655 | t = time_mlp(time).to(self.dtype)
656 |
657 | x = x.to(self.dtype)
658 | x, h = self.encode(x, t, context)
659 | return self.decode(x, h, t)
660 |
661 | def clear_cache(self):
662 | self._feat_map = [None] * self._conv_num
663 | self._conv_idx = [0]
664 |
665 |
666 | def extract(a, t, x_shape):
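    | """Gather `a[t]` per batch element and reshape to (b, 1, ..., 1) so it broadcasts over `x_shape`."""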
667 | a = a.to(t.device)
668 | b, *_ = t.shape
669 | out = a.gather(-1, t)
670 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
671 |
672 |
673 | def cosine_beta_schedule(timesteps, s=0.008):
674 | """
675 | cosine schedule
676 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
677 | """
678 | steps = timesteps + 1
679 | x = np.linspace(0, steps, steps)
680 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
681 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
682 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
683 | return np.clip(betas, a_min=0, a_max=0.999)
684 |
685 |
686 | class GaussianDiffusion(nn.Module):
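    | """Diffusion wrapper tying the context Compressor and the denoising U-Net together.
    |
    | A cosine beta schedule over `num_timesteps` steps is precomputed at construction time;
    | `set_sample_schedule` subsamples it to `sample_steps` indices for DDIM-style sampling at
    | inference. `pred_mode` selects whether the network predicts the noise, x0 ("x"), or v.
    | """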
687 | def __init__(
688 | self,
689 | denoise_fn=None,
690 | context_fn=None,
691 | ae_fn=None,
692 | num_timesteps=8192,
693 | pred_mode="x",
694 | var_schedule="cosine",
695 | ):
696 | super().__init__()
697 | self.denoise_fn = denoise_fn
698 | self.context_fn = context_fn
699 | self.ae_fn = ae_fn
700 | self.otherlogs = {}
701 | self.var_schedule = var_schedule
702 | self.sample_steps = None
703 | assert pred_mode in ["noise", "x", "v"]
704 | self.pred_mode = pred_mode
705 | to_torch = partial(torch.tensor, dtype=torch.float32)
706 |
707 | train_betas = cosine_beta_schedule(num_timesteps)
708 | train_alphas = 1.0 - train_betas
709 | train_alphas_cumprod = np.cumprod(train_alphas, axis=0)
710 | (num_timesteps,) = train_betas.shape
711 | self.num_timesteps = int(num_timesteps)
712 |
713 | self.train_snr = to_torch(train_alphas_cumprod / (1 - train_alphas_cumprod))
714 | self.train_betas = to_torch(train_betas)
715 | self.train_alphas_cumprod = to_torch(train_alphas_cumprod)
716 | self.train_sqrt_alphas_cumprod = to_torch(np.sqrt(train_alphas_cumprod))
717 | self.train_sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1.0 - train_alphas_cumprod))
718 | self.train_sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / train_alphas_cumprod))
719 | self.train_sqrt_recipm1_alphas_cumprod = to_torch(np.sqrt(1.0 / train_alphas_cumprod - 1))
720 |
721 | def set_sample_schedule(self, sample_steps, device):
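    | """Pick `sample_steps` evenly spaced timesteps from the training schedule and derive the
    | cumulative-alpha terms and sigmas needed by the DDIM update."""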
722 | self.sample_steps = sample_steps
723 | if sample_steps != 1:
724 | indice = torch.linspace(0, self.num_timesteps - 1, sample_steps, device=device).long()
725 | else:
726 | indice = torch.tensor([self.num_timesteps - 1], device=device).long()
727 | self.train_alphas_cumprod = self.train_alphas_cumprod.to(device)
728 | self.train_snr = self.train_snr.to(device)
729 | self.alphas_cumprod = self.train_alphas_cumprod[indice]
730 | self.snr = self.train_snr[indice]
731 | self.index = torch.arange(self.num_timesteps, device=device)[indice]
732 | self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
733 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
734 | self.sqrt_alphas_cumprod_prev = torch.sqrt(self.alphas_cumprod_prev)
735 | self.one_minus_alphas_cumprod = 1.0 - self.alphas_cumprod
736 | self.one_minus_alphas_cumprod_prev = 1.0 - self.alphas_cumprod_prev
737 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
738 | self.sqrt_one_minus_alphas_cumprod_prev = torch.sqrt(1.0 - self.alphas_cumprod_prev)
739 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
740 | self.sqrt_recip_alphas_cumprod_prev = torch.sqrt(1.0 / self.alphas_cumprod_prev)
741 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
742 | self.sigma = self.sqrt_one_minus_alphas_cumprod_prev / self.sqrt_one_minus_alphas_cumprod * torch.sqrt(1.0 - self.alphas_cumprod / self.alphas_cumprod_prev)
743 |
744 | def predict_noise_from_start(self, x_t, t, x0):
745 | return (
746 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
747 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
748 | )
749 |
750 | def predict_v(self, x_start, t, noise):
751 | if self.training:
752 | return (
753 | extract(self.train_sqrt_alphas_cumprod, t, x_start.shape) * noise -
754 | extract(self.train_sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
755 | )
756 | else:
757 | return (
758 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
759 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
760 | )
761 |
762 | def predict_start_from_v(self, x_t, t, v):
763 | if self.training:
764 | return (
765 | extract(self.train_sqrt_alphas_cumprod, t, x_t.shape) * x_t -
766 | extract(self.train_sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
767 | )
768 | else:
769 | return (
770 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
771 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
772 | )
773 |
774 | def predict_start_from_noise(self, x_t, t, noise):
775 | if self.training:
776 | return (
777 | extract(self.train_sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
778 | - extract(self.train_sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
779 | )
780 | else:
781 | return (
782 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
783 | - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
784 | )
785 |
786 | def ddim(self, x, t, context, clip_denoised, eta=0):
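    | """One DDIM step: predict x0 from the network output (optionally clamped to [-1, 1]),
    | recover the implied noise, and form
    | x_prev = sqrt(acp_prev) * x0 + sqrt(1 - acp_prev - (eta*sigma)^2) * noise + eta*sigma*z.
    | """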
787 | if self.denoise_fn.embd_type == "01":
788 | fx = self.denoise_fn(x, self.index[t].float().unsqueeze(-1) / self.num_timesteps, context=context)
789 | else:
790 | fx = self.denoise_fn(x, self.index[t], context=context)
791 | fx = fx.float()
792 | if self.pred_mode == "noise":
793 | x_recon = self.predict_start_from_noise(x, t=t, noise=fx)
794 | elif self.pred_mode == "x":
795 | x_recon = fx
796 | elif self.pred_mode == "v":
797 | x_recon = self.predict_start_from_v(x, t=t, v=fx)
798 | if clip_denoised:
799 | x_recon.clamp_(-1.0, 1.0)
800 | noise = fx if self.pred_mode == "noise" else self.predict_noise_from_start(x, t=t, x0=x_recon)
801 | x_next = (
802 | extract(self.sqrt_alphas_cumprod_prev, t, x.shape) * x_recon
803 | + torch.sqrt(
804 | (extract(self.one_minus_alphas_cumprod_prev, t, x.shape)
805 | - (eta * extract(self.sigma, t, x.shape)) ** 2).clamp(min=0)
806 | )
807 | * noise + eta * extract(self.sigma, t, x.shape) * torch.randn_like(noise)
808 | )
809 | return x_next
810 |
811 | def p_sample(self, x, t, context, clip_denoised, eta=0):
812 | return self.ddim(x=x, t=t, context=context, clip_denoised=clip_denoised, eta=eta)
813 |
814 | def p_sample_loop(self, shape, context, clip_denoised=False, init=None, eta=0):
815 | device = self.alphas_cumprod.device
816 |
817 | b = shape[0]
818 | img = torch.zeros(shape, device=device) if init is None else init
819 | for count, i in enumerate(reversed(range(0, self.sample_steps))):
820 | time = torch.full((b,), i, device=device, dtype=torch.long)
821 | img = self.p_sample(
822 | img,
823 | time,
824 | context=context,
825 | clip_denoised=clip_denoised,
826 | eta=eta,
827 | )
828 | return img
829 |
830 | @torch.no_grad()
831 | def compress(
832 | self,
833 | images,
834 | sample_steps=10,
835 | init=None,
836 | clip_denoised=True,
837 | eta=0,
838 | ):
839 | context_dict = self.context_fn(images, 'test')
840 | self.set_sample_schedule(
841 | self.num_timesteps if (sample_steps is None) else sample_steps,
842 | context_dict["output"][0].device,
843 | )
844 | return self.p_sample_loop(
845 | images.shape, context_dict["output"],
846 | clip_denoised=clip_denoised, init=init, eta=eta
847 | )
848 |
849 | @torch.no_grad()
850 | def diffusion_decode(
851 | self,
852 | latent,
853 | sample_steps=30,
854 | init=None,
855 | time=None,
856 | clip_denoised=False,
857 | eta=0,
858 | ):
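    | """Run one denoising step for a single latent chunk.
    |
    | Decodes the latent into context features, (re)builds the sampling schedule, and applies
    | `p_sample` once at timestep `time`, starting from `init`. The tokenizer's `decode` loops
    | this over timesteps and chunks so the causal caches stay consistent.
    | """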
859 | context = self.context_fn.decode(latent)
860 |
862 | self.set_sample_schedule(
863 | self.num_timesteps if (sample_steps is None) else sample_steps,
864 | context[0].device,
865 | )
866 |
867 | img = init
868 | img = self.p_sample(
869 | img,
870 | time,
871 | context=context,
872 | clip_denoised=clip_denoised,
873 | eta=eta,
874 | )
875 | return img
876 |
877 |
878 | class ConditionedDiffusionTokenizer(nn.Module):
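    | """Conditioned Diffusion Tokenizer (CDT) for video.
    |
    | `encode` compresses a video of shape [N, C, T, H, W] chunk by chunk with the Compressor and
    | returns the latent scaled by `scale_factor`; `decode` inverts it by running the conditional
    | diffusion sampler (`sample_steps` DDIM steps, initial noise scaled by `sample_gamma`) over
    | latent chunks. `forward` is the full encode/decode round trip.
    | """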
879 |
880 | def __init__(self,
881 | pretrained=None, # pretrained ckpt path
882 | enc_bs=1, # mini-batch size for encoding (also the default for decoding)
883 | enc_frames=13, # frames per encoding chunk (also the default for decoding)
884 | dec_bs=1, # mini-batch size for decoding
885 | dec_frames=4, # latent frames per decoding chunk
886 | z_overlap=1, # overlap between chunks during mini-batch inference
887 | sample_steps=10, # number of diffusion steps at decode time
888 | sample_gamma=0.8, # scale of the initial decoding noise
889 | num_timesteps=8192,
890 | out_channels=3,
891 | context_dim=64,
892 | unet_dim=64,
893 | new_upsample=False,
894 | latent_dim=16,
895 | context_dim_mults = [1, 2, 3, 4],
896 | space_down = [1, 1, 1, 1],
897 | time_down = [0, 0, 1, 1],
898 | condition_times=4,
899 | **kwargs
900 | ):
901 | super().__init__()
902 | self.latent_scale = None
903 | self.enc_bs = enc_bs
904 | self.enc_frames = enc_frames
905 | self.dec_bs = dec_bs or enc_bs
906 | self.dec_frames = dec_frames or enc_frames
907 | self.space_factor = 2 ** sum(space_down)
908 | self.time_factor = 2 ** sum(time_down)
909 | self.sample_steps = sample_steps
910 | self.sample_gamma = sample_gamma
911 | self.z_overlap = z_overlap
912 | self.num_timesteps = num_timesteps
913 | self.condition_times = condition_times
914 | CausalConv3d.PAD_MODE = 'constant' if new_upsample else 'replicate'
915 | context_model = Compressor(
916 | dim=context_dim,
917 | latent_dim=latent_dim,
918 | new_upsample=new_upsample,
919 | channels=3,
920 | out_channels=out_channels,
921 | dim_mults=context_dim_mults,
922 | reverse_dim_mults=reversed(context_dim_mults),
923 | space_down=space_down,
924 | time_down=time_down)
925 | denoise_model = Unet(
926 | dim=unet_dim,
927 | channels=3,
928 | new_upsample=new_upsample,
929 | context_dim=context_dim,
930 | context_out_channels=out_channels,
931 | dim_mults=[1, 2, 3, 4, 5, 6],
932 | context_dim_mults=context_dim_mults,
933 | space_down=space_down,
934 | time_down=time_down,
935 | condition_times=condition_times
936 | )
937 | self.model = GaussianDiffusion(
938 | context_fn=context_model,
939 | denoise_fn=denoise_model,
940 | num_timesteps=self.num_timesteps,
941 | )
942 | self.z_dim = latent_dim
943 | if pretrained is not None:
944 | self.init_from_ckpt(pretrained)
945 |
946 | def init_from_ckpt(
947 | self,
948 | path: str,
949 | ) -> None:
950 | if path.endswith("safetensors"):
951 | sd = load_safetensors(path)
952 | else:
954 | sd = torch.load(path, map_location="cpu", weights_only=False)
955 | sd = sd.get("state_dict", sd)
956 | sd = sd.get("module", sd)
957 | sd = {k:v.to(torch.float32) for k,v in sd.items()}
958 | missing, unexpected = self.load_state_dict(sd, strict=False)
959 | logging.info(
960 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
961 | )
962 | if len(missing) > 0:
963 | logging.info(f"Missing Keys: {missing}")
964 | if len(unexpected) > 0:
965 | logging.info(f"Unexpected Keys: {unexpected}")
966 |
967 | @torch.no_grad()
968 | def forward(self, x, **kwargs):
969 | latent = self.encode(x)
970 | x_recon = self.decode(latent)
971 | return x_recon
972 |
973 | @torch.no_grad()
974 | def forward_with_info(self, x, **kwargs):
975 | latent, posterior = self.encode(x, return_posterior=True)
976 | x_recon = self.decode(latent)
977 | mu = posterior.mean
978 | logvar = posterior.logvar
979 | z_std = mu.std()
980 | return x_recon, mu, logvar, z_std
981 |
982 | @torch.no_grad()
983 | def encode(self, x, scale_factor=1.0, return_posterior=False, **kwargs):
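    | """Encode a video tensor of shape [N, C, T, H, W] into latents.
    |
    | The batch is split into mini-batches of size `enc_bs` and into frame chunks whose length is
    | derived from `enc_frames` and the spatial resolution (higher resolution -> shorter chunks);
    | any remainder frames form the first, shorter chunk. Each chunk is encoded deterministically
    | (posterior mode) while the Compressor's causal cache carries state between chunks; the cache
    | is cleared after every mini-batch.
    | """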
984 | N, C, T, H, W = x.shape
985 | mini_batch_size = min(self.enc_bs, N) if self.enc_bs else N
986 | enc_frames_ratio = (1024/H * 1024/W) ** 1.0
987 | enc_frames = int(self.enc_frames * enc_frames_ratio/4) * 4
988 | enc_frames = max(enc_frames, 4)
989 | logging.debug(f"enc: {x.shape}, {self.enc_frames}, {enc_frames_ratio}, {enc_frames}")
990 | z_overlap = 0
991 | frame_overlap = 0
992 |
993 | mini_frames = min(enc_frames, T) if enc_frames else T
994 | n_batches = int(math.ceil(N / mini_batch_size))
995 | n_frame_batches = math.ceil((T-frame_overlap) / max(1, mini_frames-frame_overlap))
996 | n_frame_batches = max(int(n_frame_batches), 1)
997 | remainder = T % mini_frames
998 | z = list()
999 | mean_std_list = list()
1000 | for i_batch in range(n_batches):
1001 | z_batch = []
1002 | mean_std_batch = []
1003 | x_batch = x[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size]
1004 | frame_end = 0
1005 | for i_frames in range(n_frame_batches):
1006 | frame_start = frame_end
1007 | if i_frames == 0 and remainder > 0:
1008 | frame_end = frame_start + remainder
1009 | else:
1010 | frame_end = frame_start + mini_frames
1011 | i_batch_input = x_batch[:, :, frame_start:frame_end, ...]
1012 | logging.debug(f'enc: {i_frames}, {i_batch_input.shape}')
1013 | i_batch_input = i_batch_input.to(next(self.parameters()).device)
1014 | z_frames, mean_std = self.model.context_fn.encode(i_batch_input, deterministic=True)
1015 | z_batch.append(z_frames[:, :, (z_overlap if i_frames>0 else 0):, ...])
1016 | mean_std_batch.append(mean_std[:, :, (z_overlap if i_frames>0 else 0):, ...])
1017 | z_batch = torch.cat(z_batch, dim=2)
1018 | z.append(z_batch)
1019 | mean_std_batch = torch.cat(mean_std_batch, dim=2)
1020 | mean_std_list.append(mean_std_batch)
1021 | self.model.context_fn.clear_cache()
1022 | latent = torch.cat(z, 0)
1023 | mean_std = torch.cat(mean_std_list, 0)
1024 | posterior = DiagonalGaussianDistribution(mean_std)
1025 | latent *= scale_factor
1026 | if return_posterior:
1027 | return latent, posterior
1028 | else:
1029 | return latent
1030 |
1031 | @torch.no_grad()
1032 | def decode(self, latent, scale_factor=1.0, **kwargs):
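    | """Decode latents back to video with chunked conditional diffusion sampling.
    |
    | A shared initial noise of shape [N, 3, T*time_factor, H*space_factor, W*space_factor],
    | scaled by `sample_gamma`, is drawn per mini-batch. The outer loop walks the DDIM timesteps
    | and the inner loop walks latent-frame chunks, feeding each chunk's previous estimate into
    | its next step; the module-level CLIP_FIRST_CHUNK flag marks the first chunk of every pass
    | for the causal layers defined earlier in this file.
    | """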
1033 | N, C, T, H, W = latent.shape
1034 | latent = latent/scale_factor
1035 | mini_batch_size = min(self.dec_bs, N) if self.dec_bs else N
1036 | n_batches = int(math.ceil(N / mini_batch_size))
1037 | dec_frames_ratio = (1024/self.space_factor/H * 1024/self.space_factor/W) ** 1.0
1038 | dec_frames = int(self.dec_frames * dec_frames_ratio)
1039 | dec_frames = max(1, int(dec_frames))
1040 | logging.debug(f"dec: {latent.shape}, {self.dec_frames}, {dec_frames_ratio}, {dec_frames}")
1041 | z_overlap = 0
1042 | frame_overlap = 0
1043 |
1044 | assert dec_frames >= z_overlap + 1, f"dec_frames {dec_frames} too small"
1045 | mini_frames = min(dec_frames, T) if dec_frames else T
1046 | n_frame_batches = math.ceil((T-z_overlap) / max(1, (mini_frames-z_overlap)))
1047 | n_frame_batches = max(int(n_frame_batches), 1)
1048 | dec = list()
1049 | target_shape = [N, 3, T*self.time_factor, H*self.space_factor, W*self.space_factor]
1050 | init_noise = self.sample_gamma*torch.randn(target_shape, dtype=latent.dtype, device=latent.device)
1051 |
1052 | for i_batch in range(n_batches):
1053 | x_batch = latent[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size]
1054 | noise_batch = init_noise[i_batch * mini_batch_size : (i_batch + 1) * mini_batch_size]
1055 | z_batch_list = [None] * n_frame_batches
1056 | for count, t_idx in enumerate(reversed(range(0, self.sample_steps))):
1057 | for i_frames in range(n_frame_batches):
1058 | global CLIP_FIRST_CHUNK
1059 | CLIP_FIRST_CHUNK = (i_frames == 0)
1060 |
1061 | latent_frame_start = i_frames * (mini_frames-z_overlap)
1062 | latent_frame_end = latent_frame_start+mini_frames
1063 | i_batch_input = x_batch[:, :, latent_frame_start:latent_frame_end, ...]
1064 | if count == 0:
1065 | frame_start = latent_frame_start*self.time_factor
1066 | frame_end = frame_start + (mini_frames)*self.time_factor
1067 | if CLIP_FIRST_CHUNK:
1068 | frame_start += 3
1069 | cur_noise = noise_batch[:, :, frame_start:frame_end, ...]
1070 | else:
1071 | cur_noise = z_batch_list[i_frames]
1072 | time = torch.full((cur_noise.shape[0],), t_idx, device=latent.device, dtype=torch.long)
1073 | logging.debug(f'dec: {i_frames}, {i_batch_input.shape}, {cur_noise.shape}')
1074 | x_rec = self.model.diffusion_decode(i_batch_input,
1075 | sample_steps=self.sample_steps, init=cur_noise, time=time)
1076 | z_batch_list[i_frames] = x_rec[:, :, (0 if i_frames==0 else frame_overlap):, ...]
1077 | # clear cache
1078 | self.model.context_fn.clear_cache()
1079 | self.model.denoise_fn.clear_cache()
1080 |
1081 | dec.append(torch.cat(z_batch_list, dim=2))
1082 | del z_batch_list, x_batch, noise_batch
1083 |
1084 | dec = torch.cat(dec, 0)
1085 | return dec
1086 |
1087 |
1088 | def load_cdt_base(
1089 | ckpt = None,
1090 | device='cpu',
1091 | eval=True,
1092 | sampling_step=1,
1093 | **kwargs
1094 | ):
1095 | ckpt = ckpt or './pretrained/cdt_base.ckpt'
1096 | print(f"Loading CDT-base from {ckpt}")
1097 |
1098 |
1099 | model = ConditionedDiffusionTokenizer(pretrained=ckpt,
1100 | enc_frames=kwargs.pop('enc_frames', 4),
1101 | dec_frames=kwargs.pop('dec_frames', 1),
1102 | space_down=[0,1,1,1],
1103 | out_channels=16,
1104 | latent_dim=16,
1105 | context_dim=128,
1106 | new_upsample=True,
1107 | sample_steps=sampling_step,
1108 | sample_gamma=kwargs.pop('sample_gamma', 0.8),
1109 | **kwargs)
1110 | model = model.to(device)
1111 | if eval:
1112 | model = model.eval()
1113 | return model
1114 |
1115 | def load_cdt_small(
1116 | ckpt = None,
1117 | device='cpu',
1118 | eval=True,
1119 | latent_dim=16,
1120 | diffusion_step=8192,
1121 | sampling_step=1,
1122 | condition_times=4,
1123 | **kwargs
1124 | ):
1125 | ckpt = ckpt or "./pretrained/cdt_small.ckpt"
1126 |
1127 | model = ConditionedDiffusionTokenizer(pretrained=ckpt,
1128 | enc_frames=kwargs.pop('enc_frames', 4),
1129 | dec_frames=kwargs.pop('dec_frames', 1),
1130 | space_down=[1, 1, 1, 0],
1131 | out_channels=3,
1132 | latent_dim=latent_dim,
1133 | context_dim=64,
1134 | num_timesteps=diffusion_step,
1135 | condition_times=condition_times,
1136 | sample_steps=sampling_step,
1137 | sample_gamma=kwargs.pop('sample_gamma', 0.8),
1138 | **kwargs)
1139 | model = model.to(device)
1140 | if eval:
1141 | model = model.eval()
1142 | return model
1143 |
1144 | def load_cdt(version='base', dtype=torch.float16, *args, **kwargs):
1145 | VERSIONS = {
1146 | 'base': (load_cdt_base, 1.3),
1147 | 'small': (load_cdt_small, 1.3),
1148 | }
1149 | if version not in VERSIONS:
1150 | print(f"ERROR: wrong version '{version}' of CDT, not in '{VERSIONS.keys()}'")
1151 | return None
1152 | model_func, latent_scale = VERSIONS[version]
1153 | model = model_func(*args, **kwargs)
1154 | model.latent_scale = latent_scale
1155 | model = model.to(dtype)
1156 | return model
1157 |
1158 |
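    | # Example usage (illustrative sketch, not taken from the repo's evaluation scripts): load the
    | # base tokenizer and round-trip a clip. Assumes the default checkpoint ./pretrained/cdt_base.ckpt
    | # exists and `video` is a float tensor of shape [N, 3, T, H, W], typically normalized to [-1, 1].
    | #
    | #   import torch
    | #   model = load_cdt('base', dtype=torch.float16, device='cuda', sampling_step=1)
    | #   with torch.no_grad():
    | #       latent = model.encode(video)   # ~[N, 16, T/4, H/8, W/8] for the base config
    | #       recon = model.decode(latent)   # reconstructed video in pixel space
    | #
    | # The loader also sets `model.latent_scale`; whether to pass it as `scale_factor` to
    | # encode/decode is up to the calling script (both default to 1.0).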
--------------------------------------------------------------------------------