├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── framework.jpeg
│   ├── framework.jpg
│   ├── sky-long-001.gif
│   ├── sky-long-002.gif
│   ├── sky-long-003.gif
│   ├── t2v-001.gif
│   ├── t2v-002.gif
│   ├── t2v-003.gif
│   ├── t2v-004.gif
│   ├── t2v-005.gif
│   ├── t2v-006.gif
│   ├── t2v-007.gif
│   ├── t2v-008.gif
│   ├── ucf-long-001.gif
│   ├── ucf-long-002.gif
│   └── ucf-long-003.gif
├── configs
│   ├── lvdm_long
│   │   ├── sky_interp.yaml
│   │   └── sky_pred.yaml
│   ├── lvdm_short
│   │   ├── sky.yaml
│   │   ├── taichi.yaml
│   │   ├── text2video.yaml
│   │   └── ucf.yaml
│   └── videoae
│       ├── sky.yaml
│       ├── taichi.yaml
│       ├── ucf.yaml
│       └── ucf_videodata.yaml
├── input
│   └── prompts.txt
├── lvdm
│   ├── data
│   │   ├── frame_dataset.py
│   │   ├── split_ucf101.py
│   │   ├── taichi.py
│   │   └── ucf.py
│   ├── models
│   │   ├── autoencoder.py
│   │   ├── autoencoder3d.py
│   │   ├── ddpm3d.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── modules
│   │       ├── aemodules.py
│   │       ├── aemodules3d.py
│   │       ├── attention_temporal.py
│   │       ├── condition_modules.py
│   │       ├── distributions.py
│   │       ├── ema.py
│   │       ├── openaimodel3d.py
│   │       └── util.py
│   ├── samplers
│   │   └── ddim.py
│   └── utils
│       ├── callbacks.py
│       ├── common_utils.py
│       ├── dist_utils.py
│       ├── log.py
│       └── saving_utils.py
├── main.py
├── requirements.txt
├── requirements_h800_gpu.txt
├── scripts
│   ├── eval_cal_fvd_kvd.py
│   ├── fvd_utils
│   │   ├── fvd_utils.py
│   │   └── pytorch_i3d.py
│   ├── sample_long_videos_utils.py
│   ├── sample_text2video.py
│   ├── sample_uncond.py
│   ├── sample_uncond_long_videos.py
│   └── sample_utils.py
├── setup.py
└── shellscripts
    ├── eval_lvdm_short.sh
    ├── sample_lvdm_long.sh
    ├── sample_lvdm_short.sh
    ├── sample_lvdm_text2video.sh
    ├── train_lvdm_interp_sky.sh
    ├── train_lvdm_pred_sky.sh
    ├── train_lvdm_short.sh
    ├── train_lvdm_videoae.sh
    └── train_lvdm_videoae_ucf.sh
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.gif filter=lfs diff=lfs merge=lfs -text
2 | assets/*.gif filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pyc*
3 | __pycache__
4 | .vscode/*
5 | *.egg-info/
6 | results/
7 | *.pt
8 | *.ckpt
9 | clip-vit-large-patch14
10 | results
11 | temp
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Cassie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
LVDM: Latent Video Diffusion Models for High-Fidelity Long Video Generation
5 |
6 |

7 |
8 |
9 |
16 |
17 |
18 | 1 The Hong Kong University of Science and Technology 2 Tencent AI Lab
19 |
20 |
21 |
22 |
23 |
TL;DR: An efficient video diffusion model that can:
24 | 1️⃣ conditionally generate videos based on input text;
25 | 2️⃣ unconditionally generate videos with thousands of frames.
26 |
27 |
28 |
29 |
30 |
31 |
32 | ## 🍻 Results
33 | ### ☝️ Text-to-Video Generation
34 |
35 |
36 |
37 | "A corgi is swimming fastly" |
38 | "astronaut riding a horse" |
39 | "A glass bead falling into water with a huge splash. Sunset in the background" |
40 | "A beautiful sunrise on mars. High definition, timelapse, dramaticcolors." |
41 | "A bear dancing and jumping to upbeat music, moving his whole body." |
42 | "An iron man surfing in the sea. cartoon style" |
43 |
 44 | (generated GIF samples for the prompts above; see the t2v-*.gif files in assets/)
50 |
51 |
52 |
53 | ### ✌️ Unconditional Long Video Generation (40 seconds)
54 |
 55 | (generated long-video GIF samples; see the sky-long-*.gif and ucf-long-*.gif files in assets/)
61 |
62 |
63 |
64 |
65 | ## ⏳ TODO
66 | - [x] Release pretrained text-to-video generation models and inference code
67 | - [x] Release unconditional video generation models
68 | - [x] Release training code
69 | - [ ] Update training and sampling for long video generation
70 |
71 |
72 | ---
73 | ## ⚙️ Setup
74 |
75 | ### Install Environment via Anaconda
76 | ```bash
77 | conda create -n lvdm python=3.8.5
78 | conda activate lvdm
79 | pip install -r requirements.txt
80 | ```
81 | ### Pretrained Models and Used Datasets
82 |
83 |
84 |
 85 | Download the pretrained checkpoints via the following commands in a Linux terminal:
86 | ```
87 | mkdir -p models/ae
88 | mkdir -p models/lvdm_short
89 | mkdir -p models/t2v
90 |
91 | # sky timelapse
92 | wget -O models/ae/ae_sky.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/ae/ae_sky.ckpt
93 | wget -O models/lvdm_short/short_sky.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/short_sky.ckpt
94 |
95 | # taichi
96 | wget -O models/ae/ae_taichi.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/ae/ae_taichi.ckpt
97 | wget -O models/lvdm_short/short_taichi.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/short_taichi.ckpt
98 |
99 | # text2video
100 | wget -O models/t2v/model.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/t2v.ckpt
101 |
102 | ```
103 |
104 | Prepare UCF-101 dataset
105 | ```
106 | mkdir temp; cd temp
107 |
 108 | # Download UCF-101 from the official website https://www.crcv.ucf.edu/data/UCF101.php (the UCF101 data set)
109 |
110 | wget https://www.crcv.ucf.edu/data/UCF101/UCF101.rar --no-check-certificate
111 | unrar x UCF101.rar
112 |
113 | # Download annotations from https://www.crcv.ucf.edu/data/UCF101.php (The Train/Test Splits for Action Recognition on UCF101 data set)
114 |
115 | wget https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip --no-check-certificate
116 | unzip UCF101TrainTestSplits-RecognitionTask.zip
117 |
 118 | # Split the dataset into train and test sets
119 | cd ..
120 | python lvdm/data/split_ucf101.py # please check this script
121 |
122 | ```
123 |
125 |
126 |
127 |
128 | Download manually:
129 | - Sky Timelapse: [VideoAE](https://huggingface.co/Yingqing/LVDM/blob/main/ae/ae_sky.ckpt), [LVDM_short](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/short_sky.ckpt), [LVDM_pred](TBD), [LVDM_interp](TBD), [dataset](https://github.com/weixiong-ur/mdgan)
130 | - Taichi: [VideoAE](https://huggingface.co/Yingqing/LVDM/blob/main/ae/ae_taichi.ckpt), [LVDM_short](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/short_taichi.ckpt), [dataset](https://github.com/AliaksandrSiarohin/first-order-model/blob/master/data/taichi-loading/README.md)
131 | - Text2Video: [model](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/t2v.ckpt)
132 |
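The checkpoints linked above can also be fetched programmatically instead of with `wget`. A minimal sketch, assuming `huggingface_hub` is installed; the repo id `Yingqing/LVDM`, the remote file names, and the local `models/` layout are taken from the commands and links above:
```python
import shutil
from huggingface_hub import hf_hub_download

# (file name in the Hugging Face repo, local destination), mirroring the wget commands above
checkpoints = [
    ("ae/ae_sky.ckpt", "models/ae/ae_sky.ckpt"),
    ("lvdm_short/short_sky.ckpt", "models/lvdm_short/short_sky.ckpt"),
    ("lvdm_short/t2v.ckpt", "models/t2v/model.ckpt"),
]
for remote_name, local_path in checkpoints:
    cached = hf_hub_download(repo_id="Yingqing/LVDM", filename=remote_name)  # returns a cached local path
    shutil.copy(cached, local_path)  # assumes the models/ subdirectories were created as above
```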
133 | ---
134 | ## 💫 Inference
135 | ### Sample Short Videos
136 | - unconditional generation
137 |
138 | ```
139 | bash shellscripts/sample_lvdm_short.sh
140 | ```
141 | - text to video generation
142 | ```
143 | bash shellscripts/sample_lvdm_text2video.sh
144 | ```
145 |
146 | ### Sample Long Videos
147 | ```
148 | bash shellscripts/sample_lvdm_long.sh
149 | ```
150 |
151 | ---
152 | ## 💫 Training
153 |
154 | ### Train video autoencoder
155 | ```
156 | bash shellscripts/train_lvdm_videoae.sh
157 | ```
158 | - remember to set `PROJ_ROOT`, `EXPNAME`, `DATADIR`, and `CONFIG`.
159 |
160 | ### Train unconditional lvdm for short video generation
161 | ```
162 | bash shellscripts/train_lvdm_short.sh
163 | ```
164 | - remember to set `PROJ_ROOT`, `EXPNAME`, `DATADIR`, `AEPATH` and `CONFIG`.
165 |
166 | ### Train unconditional lvdm for long video generation
167 | ```
168 | # TBD
169 | ```
170 |
171 | ---
172 | ## 💫 Evaluation
173 | ```
174 | bash shellscripts/eval_lvdm_short.sh
175 | ```
176 | - remember to set `DATACONFIG`, `FAKEPATH`, `REALPATH`, and `RESDIR`.
177 | ---
178 |
179 | ## 📃 Abstract
180 | AI-generated content has attracted lots of attention recently, but photo-realistic video synthesis is still challenging. Although many attempts using GANs and autoregressive models have been made in this area, the visual quality and length of generated videos are far from satisfactory. Diffusion models have shown remarkable results recently but require significant computational resources. To address this, we introduce lightweight video diffusion models by leveraging a low-dimensional 3D latent space, significantly outperforming previous pixel-space video diffusion models under a limited computational budget. In addition, we propose hierarchical diffusion in the latent space such that longer videos with more than one thousand frames can be produced. To further overcome the performance degradation issue for long video generation, we propose conditional latent perturbation and unconditional guidance that effectively mitigate the accumulated errors during the extension of video length. Extensive experiments on small domain datasets of different categories suggest that our framework generates more realistic and longer videos than previous strong baselines. We additionally provide an extension to large-scale text-to-video generation to demonstrate the superiority of our work. Our code and models will be made publicly available.
181 |
182 |
183 | ## 🔮 Pipeline
184 |
185 |
186 |
187 |
188 |
189 | ---
190 | ## 😉 Citation
191 |
192 | ```
193 | @article{he2022lvdm,
194 | title={Latent Video Diffusion Models for High-Fidelity Long Video Generation},
195 | author={Yingqing He and Tianyu Yang and Yong Zhang and Ying Shan and Qifeng Chen},
196 | year={2022},
197 | eprint={2211.13221},
198 | archivePrefix={arXiv},
199 | primaryClass={cs.CV}
200 | }
201 | ```
202 |
203 | ## 🤗 Acknowledgements
 204 | We built our code partially based on [latent diffusion models](https://github.com/CompVis/latent-diffusion) and [TATS](https://github.com/SongweiGe/TATS). Thanks to the authors for sharing their awesome codebases! We also adopt Xintao Wang's [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling our text-to-video generation results. Thanks for their wonderful work!
--------------------------------------------------------------------------------
/assets/framework.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingqingHe/LVDM/d251dccfbf6352826f5c5681abd86e87ed7e6371/assets/framework.jpeg
--------------------------------------------------------------------------------
/assets/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingqingHe/LVDM/d251dccfbf6352826f5c5681abd86e87ed7e6371/assets/framework.jpg
--------------------------------------------------------------------------------
/assets/sky-long-001.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:25bc40a8cfcfaf2ab8e1de96f136a14d61099f93319627056156ad265df3d649
3 | size 9214020
4 |
--------------------------------------------------------------------------------
/assets/sky-long-002.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e513b06e4cc56afc1626539553297412cfaec4a5e246f2b9eff51171597aeb29
3 | size 9616345
4 |
--------------------------------------------------------------------------------
/assets/sky-long-003.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:491a5072a19828fd5aa0bc0094c0b95fdfb51e29b91bf896be16a8f704b30e7b
3 | size 9907437
4 |
--------------------------------------------------------------------------------
/assets/t2v-001.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:8e59eebb8da31a1fc1e2f38ab0042bf6e2b0db5b671c8ffaec603f8dc9643a29
3 | size 4873384
4 |
--------------------------------------------------------------------------------
/assets/t2v-002.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:98d7e6dd8b0841d74e96c695f62e1a0194748d7566f993b80409477f37124e15
3 | size 5200960
4 |
--------------------------------------------------------------------------------
/assets/t2v-003.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7c8e8c0fbd617bdda67160d99b5f07b1de85991073e21a7e0833ac99ffb18fd1
3 | size 4522552
4 |
--------------------------------------------------------------------------------
/assets/t2v-004.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ef0f32351de73248f4394f0dd101218399b7e9f47da73b88d69140416a47b6f5
3 | size 2402692
4 |
--------------------------------------------------------------------------------
/assets/t2v-005.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fc9b8a763ead3dff469de42805bd0fec3df3b545213acba4ce0b5a42b56e4a22
3 | size 2754672
4 |
--------------------------------------------------------------------------------
/assets/t2v-006.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:00ffc6692300905413327e0982368e0774b7abd22635523f4b66d128db02e405
3 | size 4607083
4 |
--------------------------------------------------------------------------------
/assets/t2v-007.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:44ce784dca47e1695c115440bdf797d220e7b70bc243df1ad216ba0c7c2ba1a4
3 | size 5966789
4 |
--------------------------------------------------------------------------------
/assets/t2v-008.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fd881c7daa09503c00331c23124dac8242f09676daa3eee1b329f374b3dd112b
3 | size 3926517
4 |
--------------------------------------------------------------------------------
/assets/ucf-long-001.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0d788762c3d8edc1fb4ef35cee358fd094fe951ebec5a7fd244b8a770e2389d3
3 | size 8980149
4 |
--------------------------------------------------------------------------------
/assets/ucf-long-002.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7a08a0794cd1df65904692f3712cc2f297b1319ebba2a716236f13f56691ccda
3 | size 7676461
4 |
--------------------------------------------------------------------------------
/assets/ucf-long-003.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c08c0815540fc78dcd198415416984a26d37946dcad5a02ac8906ec94c0a3140
3 | size 8280966
4 |
--------------------------------------------------------------------------------
/configs/lvdm_long/sky_interp.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 8.0e-5 #1.5e-04
3 | scale_lr: False
4 | target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion
5 | params:
6 | linear_start: 0.0015
7 | linear_end: 0.0155
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | monitor: val/loss_simple_ema
16 | conditioning_key: concat-adm-mask
17 | cond_stage_config: null
18 | noisy_cond: True
19 | max_noise_level: 250
20 | cond_stage_trainable: False
21 | concat_mode: False
22 | scale_by_std: False
23 | scale_factor: 0.33422927
24 | shift_factor: 1.4606637
25 | encoder_type: 3d
26 | rand_temporal_mask: true
27 | p_interp: 0.9
28 | p_pred: 0.0
29 | n_prevs: null
30 | split_clips: False
 31 | downfactor_t: null # used for splitting video frames into clips before encoding
32 | clip_length: null
33 |
34 | unet_config:
35 | target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet
36 | params:
 37 | num_classes: 251 # timesteps for noise conditioning
38 | image_size: 32
39 | in_channels: 5
40 | out_channels: 4
41 | model_channels: 256
42 | attention_resolutions:
43 | - 8
44 | - 4
45 | - 2
46 | num_res_blocks: 3
47 | channel_mult:
48 | - 1
49 | - 2
50 | - 3
51 | - 4
52 | num_heads: 4
53 | use_temporal_transformer: False
54 | use_checkpoint: true
55 | legacy: False
56 | # temporal
57 | kernel_size_t: 1
58 | padding_t: 0
59 | temporal_length: 5
60 | use_relative_position: True
61 | use_scale_shift_norm: True
62 | first_stage_config:
63 | target: lvdm.models.autoencoder3d.AutoencoderKL
64 | params:
65 | monitor: "val/rec_loss"
66 | embed_dim: 4
67 | lossconfig: __is_first_stage__
68 | ddconfig:
69 | double_z: True
70 | z_channels: 4
71 | encoder:
72 | target: lvdm.models.modules.aemodules3d.Encoder
73 | params:
74 | n_hiddens: 32
75 | downsample: [4, 8, 8]
76 | image_channel: 3
77 | norm_type: group
78 | padding_type: replicate
79 | double_z: True
80 | z_channels: 4
81 | decoder:
82 | target: lvdm.models.modules.aemodules3d.Decoder
83 | params:
84 | n_hiddens: 32
85 | upsample: [4, 8, 8]
86 | z_channels: 4
87 | image_channel: 3
88 | norm_type: group
89 |
90 | data:
91 | target: main.DataModuleFromConfig
92 | params:
93 | batch_size: 2
94 | num_workers: 0
95 | wrap: false
96 | train:
97 | target: lvdm.data.frame_dataset.VideoFrameDataset
98 | params:
99 | data_root: /dockerdata/sky_timelapse
100 | resolution: 256
101 | video_length: 20
102 | dataset_name: sky
103 | subset_split: train
104 | spatial_transform: center_crop_resize
105 | clip_step: 1
106 | temporal_transform: rand_clips
107 | validation:
108 | target: lvdm.data.frame_dataset.VideoFrameDataset
109 | params:
110 | data_root: /dockerdata/sky_timelapse
111 | resolution: 256
112 | video_length: 20
113 | dataset_name: sky
114 | subset_split: test
115 | spatial_transform: center_crop_resize
116 | clip_step: 1
117 | temporal_transform: rand_clips
118 |
119 | lightning:
120 | callbacks:
121 | image_logger:
122 | target: lvdm.utils.callbacks.ImageLogger
123 | params:
124 | batch_frequency: 2000
125 | max_images: 8
126 | increase_log_steps: False
127 | metrics_over_trainsteps_checkpoint:
128 | target: pytorch_lightning.callbacks.ModelCheckpoint
129 | params:
130 | filename: "{epoch:06}-{step:09}"
131 | save_weights_only: False
132 | every_n_epochs: 200
133 | every_n_train_steps: null
134 | trainer:
135 | benchmark: True
136 | batch_size: 2
137 | num_workers: 0
138 | num_nodes: 1
139 | max_epochs: 2000
140 | modelcheckpoint:
141 | target: pytorch_lightning.callbacks.ModelCheckpoint
142 | params:
143 | every_n_epochs: 1
144 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
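The configs in this repository follow the `target`/`params` convention inherited from the latent-diffusion codebase: each node names a class by its dotted import path plus its constructor arguments. A minimal sketch of how such a node can be turned into an object; `instantiate_from_config` here is a hypothetical helper for illustration only, not the repository's own loader (the repo's entry points handle this internally):
```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(node):
    """Build the object described by a {target: ..., params: ...} config node."""
    module_name, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get("params", {}))

config = OmegaConf.load("configs/lvdm_long/sky_interp.yaml")
# e.g. build the training dataset described under data.params.train
train_dataset = instantiate_from_config(config.data.params.train)
```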
/configs/lvdm_long/sky_pred.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 8.0e-5 # 1.5e-04
3 | scale_lr: False
4 | target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion
5 | params:
6 | linear_start: 0.0015
7 | linear_end: 0.0155
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | monitor: val/loss_simple_ema
16 | conditioning_key: concat-adm-mask
17 | cond_stage_config: null
18 | noisy_cond: True
19 | max_noise_level: 250
20 | cond_stage_trainable: False
21 | concat_mode: False
22 | scale_by_std: False
23 | scale_factor: 0.33422927
24 | shift_factor: 1.4606637
25 | encoder_type: 3d
26 | rand_temporal_mask: true
27 | p_interp: 0.0
28 | p_pred: 0.5
29 | n_prevs: [1,]
30 | split_clips: False
 31 | downfactor_t: null # used for splitting video frames into clips before encoding
32 | clip_length: null
33 | latent_frame_strde: 4
34 |
35 | unet_config:
36 | target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet
37 | params:
 38 | num_classes: 251 # timesteps for noise conditioning
39 | image_size: 32
40 | in_channels: 5
41 | out_channels: 4
42 | model_channels: 256
43 | attention_resolutions:
44 | - 8
45 | - 4
46 | - 2
47 | num_res_blocks: 3
48 | channel_mult:
49 | - 1
50 | - 2
51 | - 3
52 | - 4
53 | num_heads: 4
54 | use_temporal_transformer: False
55 | use_checkpoint: true
56 | legacy: False
57 | # temporal
58 | kernel_size_t: 1
59 | padding_t: 0
60 | temporal_length: 4
61 | use_relative_position: True
62 | use_scale_shift_norm: True
63 | first_stage_config:
64 | target: lvdm.models.autoencoder3d.AutoencoderKL
65 | params:
66 | monitor: "val/rec_loss"
67 | embed_dim: 4
68 | lossconfig: __is_first_stage__
69 | ddconfig:
70 | double_z: True
71 | z_channels: 4
72 | encoder:
73 | target: lvdm.models.modules.aemodules3d.Encoder
74 | params:
75 | n_hiddens: 32
76 | downsample: [4, 8, 8]
77 | image_channel: 3
78 | norm_type: group
79 | padding_type: replicate
80 | double_z: True
81 | z_channels: 4
82 | decoder:
83 | target: lvdm.models.modules.aemodules3d.Decoder
84 | params:
85 | n_hiddens: 32
86 | upsample: [4, 8, 8]
87 | z_channels: 4
88 | image_channel: 3
89 | norm_type: group
90 |
91 | data:
92 | target: main.DataModuleFromConfig
93 | params:
94 | batch_size: 2
95 | num_workers: 0
96 | wrap: false
97 | train:
98 | target: lvdm.data.frame_dataset.VideoFrameDataset
99 | params:
100 | data_root: /dockerdata/sky_timelapse
101 | resolution: 256
102 | video_length: 64
103 | dataset_name: sky
104 | subset_split: train
105 | spatial_transform: center_crop_resize
106 | clip_step: 1
107 | temporal_transform: rand_clips
108 | validation:
109 | target: lvdm.data.frame_dataset.VideoFrameDataset
110 | params:
111 | data_root: /dockerdata/sky_timelapse
112 | resolution: 256
113 | video_length: 64
114 | dataset_name: sky
115 | subset_split: test
116 | spatial_transform: center_crop_resize
117 | clip_step: 1
118 | temporal_transform: rand_clips
119 |
120 | lightning:
121 | callbacks:
122 | image_logger:
123 | target: lvdm.utils.callbacks.ImageLogger
124 | params:
125 | batch_frequency: 2000
126 | max_images: 8
127 | increase_log_steps: False
128 | metrics_over_trainsteps_checkpoint:
129 | target: pytorch_lightning.callbacks.ModelCheckpoint
130 | params:
131 | filename: "{epoch:06}-{step:09}"
132 | save_weights_only: False
133 | every_n_epochs: 100
134 | every_n_train_steps: null
135 | trainer:
136 | benchmark: True
137 | batch_size: 2
138 | num_workers: 0
139 | num_nodes: 1
140 | max_epochs: 2000
141 | modelcheckpoint:
142 | target: pytorch_lightning.callbacks.ModelCheckpoint
143 | params:
144 | every_n_epochs: 1
145 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/lvdm_short/sky.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 6.0e-05 # 1.5e-04
3 | scale_lr: False
4 | target: lvdm.models.ddpm3d.LatentDiffusion
5 | params:
6 | linear_start: 0.0015
7 | linear_end: 0.0155
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: "image"
13 | image_size: 32
14 | video_length: 4
15 | channels: 4
16 | monitor: val/loss_simple_ema
17 | cond_stage_trainable: False
18 | concat_mode: False
19 | scale_by_std: False
20 | scale_factor: 0.33422927
21 | shift_factor: 1.4606637
22 | cond_stage_config: __is_unconditional__
23 | encoder_type: 3d
24 |
25 | unet_config:
26 | target: lvdm.models.modules.openaimodel3d.UNetModel
27 | params:
28 | image_size: 32
29 | in_channels: 4
30 | out_channels: 4
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 3
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 4
43 | use_temporal_transformer: False
44 | use_checkpoint: true
45 | legacy: False
46 | # temporal
47 | kernel_size_t: 1
48 | padding_t: 0
49 | temporal_length: 4
50 | use_relative_position: True
51 | use_scale_shift_norm: True
52 | first_stage_config:
53 | target: lvdm.models.autoencoder3d.AutoencoderKL
54 | params:
55 | ckpt_path: ${ckpt_path}
56 | monitor: "val/rec_loss"
57 | embed_dim: 4
58 | lossconfig: __is_first_stage__
59 | ddconfig:
60 | double_z: True
61 | z_channels: 4
62 | encoder:
63 | target: lvdm.models.modules.aemodules3d.Encoder
64 | params:
65 | n_hiddens: 32
66 | downsample: [4, 8, 8]
67 | image_channel: 3
68 | norm_type: group
69 | padding_type: replicate
70 | double_z: True
71 | z_channels: 4
72 | decoder:
73 | target: lvdm.models.modules.aemodules3d.Decoder
74 | params:
75 | n_hiddens: 32
76 | upsample: [4, 8, 8]
77 | z_channels: 4
78 | image_channel: 3
79 | norm_type: group
80 |
81 | data:
82 | target: main.DataModuleFromConfig
83 | params:
84 | batch_size: 3
85 | num_workers: 0
86 | wrap: false
87 | train:
88 | target: lvdm.data.frame_dataset.VideoFrameDataset
89 | params:
90 | data_root: ${data_root}
91 | resolution: 256
92 | video_length: 16
93 | dataset_name: sky
94 | subset_split: train
95 | spatial_transform: center_crop_resize
96 | temporal_transform: rand_clips
97 | validation:
98 | target: lvdm.data.frame_dataset.VideoFrameDataset
99 | params:
100 | data_root: ${data_root}
101 | resolution: 256
102 | video_length: 16
103 | dataset_name: sky
104 | subset_split: test
105 | spatial_transform: center_crop_resize
106 | temporal_transform: rand_clips
107 |
108 | lightning:
109 | callbacks:
110 | image_logger:
111 | target: lvdm.utils.callbacks.ImageLogger
112 | params:
113 | batch_frequency: 1000
114 | max_images: 8
115 | increase_log_steps: False
116 | metrics_over_trainsteps_checkpoint:
117 | target: pytorch_lightning.callbacks.ModelCheckpoint
118 | params:
119 | filename: "{epoch:06}-{step:09}"
120 | save_weights_only: False
121 | every_n_epochs: 300
122 | every_n_train_steps: null
123 | trainer:
124 | benchmark: True
125 | batch_size: 3
126 | num_workers: 0
127 | num_nodes: 4
128 | accumulate_grad_batches: 2
129 | max_epochs: 2000
130 | modelcheckpoint:
131 | target: pytorch_lightning.callbacks.ModelCheckpoint
132 | params:
133 | every_n_epochs: 1
134 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
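A quick consistency check of the shapes implied by this config: the autoencoder's `downsample: [4, 8, 8]` maps the 16 x 256 x 256 input clips from the data section to the latent that the diffusion model's `image_size: 32`, `video_length: 4`, and `channels: 4` expect. Plain arithmetic, not repo code:
```python
# input clip from the data section: c=3, t=16, h=256, w=256
t, h, w = 16, 256, 256
dt, dh, dw = 4, 8, 8       # downsample factors of the 3D autoencoder
z_channels = 4
latent_shape = (z_channels, t // dt, h // dh, w // dw)
print(latent_shape)        # (4, 4, 32, 32) -> channels, video_length, image_size, image_size
```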
/configs/lvdm_short/taichi.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 6.0e-05 # 1.5e-04
3 | scale_lr: False
4 | target: lvdm.models.ddpm3d.LatentDiffusion
5 | params:
6 | linear_start: 0.0015
7 | linear_end: 0.0155
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: video
12 | cond_stage_key: "image"
13 | image_size: 32
14 | video_length: 4
15 | channels: 4
16 | monitor: val/loss_simple_ema
17 | # use_ema: False # default =True
18 | cond_stage_trainable: False
19 | concat_mode: False
20 | scale_by_std: False
21 | scale_factor: 0.175733933
22 | shift_factor: 0.03291025
23 | cond_stage_config: __is_unconditional__
24 | encoder_type: 3d
25 |
26 | unet_config:
27 | target: lvdm.models.modules.openaimodel3d.UNetModel
28 | params:
29 | image_size: 32
30 | in_channels: 4
31 | out_channels: 4
32 | model_channels: 256
33 | attention_resolutions:
34 | - 8
35 | - 4
36 | - 2
37 | num_res_blocks: 3
38 | channel_mult:
39 | - 1
40 | - 2
41 | - 3
42 | - 4
43 | num_heads: 4
44 | use_temporal_transformer: False
45 | use_checkpoint: true
46 | legacy: False
47 | # temporal
48 | kernel_size_t: 1
49 | padding_t: 0
50 | temporal_length: 4
51 | use_relative_position: True
52 | use_scale_shift_norm: True
53 | first_stage_config:
54 | target: lvdm.models.autoencoder3d.AutoencoderKL
55 | params:
56 | ckpt_path: ${ckpt_path}
57 | monitor: "val/rec_loss"
58 | embed_dim: 4
59 | lossconfig: __is_first_stage__
60 | ddconfig:
61 | double_z: True
62 | z_channels: 4
63 | encoder:
64 | target: lvdm.models.modules.aemodules3d.Encoder
65 | params:
66 | n_hiddens: 32
67 | downsample: [4, 8, 8]
68 | image_channel: 3
69 | norm_type: group
70 | padding_type: replicate
71 | double_z: True
72 | z_channels: 4
73 | decoder:
74 | target: lvdm.models.modules.aemodules3d.Decoder
75 | params:
76 | n_hiddens: 32
77 | upsample: [4, 8, 8]
78 | z_channels: 4
79 | image_channel: 3
80 | norm_type: group
81 |
82 | data:
83 | target: main.DataModuleFromConfig
84 | params:
85 | batch_size: 3
86 | num_workers: 0
87 | wrap: false
88 | train:
89 | target: lvdm.data.taichi.Taichi
90 | params:
91 | data_root: data_root
92 | resolution: 256
93 | video_length: 16
94 | subset_split: all
95 | frame_stride: 4
96 | validation:
97 | target: lvdm.data.taichi.Taichi
98 | params:
99 | data_root: data_root
100 | resolution: 256
101 | video_length: 16
102 | subset_split: test
103 | frame_stride: 4
104 |
105 | lightning:
106 | callbacks:
107 | image_logger:
108 | target: lvdm.utils.callbacks.ImageLogger
109 | params:
110 | batch_frequency: 1000
111 | max_images: 8
112 | increase_log_steps: False
113 | metrics_over_trainsteps_checkpoint:
114 | target: pytorch_lightning.callbacks.ModelCheckpoint
115 | params:
116 | filename: "{epoch:06}-{step:09}"
117 | save_weights_only: False
118 | every_n_epochs: 300
119 | every_n_train_steps: null
120 | trainer:
121 | benchmark: True
122 | batch_size: 3
123 | num_workers: 0
124 | num_nodes: 4
125 | accumulate_grad_batches: 2
126 | max_epochs: 4000
127 | modelcheckpoint:
128 | target: pytorch_lightning.callbacks.ModelCheckpoint
129 | params:
130 | every_n_epochs: 1
131 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/lvdm_short/text2video.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: lvdm.models.ddpm3d.LatentDiffusion
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.012
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: video
10 | cond_stage_key: caption
11 | image_size:
12 | - 32
13 | - 32
14 | video_length: 16
15 | channels: 4
16 | cond_stage_trainable: false
17 | conditioning_key: crossattn
18 | scale_by_std: false
19 | scale_factor: 0.18215
20 | use_ema: false
21 |
22 | unet_config:
23 | target: lvdm.models.modules.openaimodel3d.UNetModel
24 | params:
25 | image_size: 32
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions:
30 | - 4
31 | - 2
32 | - 1
33 | num_res_blocks: 2
34 | channel_mult:
35 | - 1
36 | - 2
37 | - 4
38 | - 4
39 | num_heads: 8
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | legacy: false
44 | kernel_size_t: 1
45 | padding_t: 0
46 | use_temporal_transformer: true
47 | temporal_length: 16
48 | use_relative_position: true
49 |
50 | first_stage_config:
51 | target: lvdm.models.autoencoder.AutoencoderKL
52 | params:
53 | embed_dim: 4
54 | monitor: val/rec_loss
55 | ddconfig:
56 | double_z: true
57 | z_channels: 4
58 | resolution: 256
59 | in_channels: 3
60 | out_ch: 3
61 | ch: 128
62 | ch_mult:
63 | - 1
64 | - 2
65 | - 4
66 | - 4
67 | num_res_blocks: 2
68 | attn_resolutions: []
69 | dropout: 0.0
70 | lossconfig:
71 | target: torch.nn.Identity
72 |
73 | cond_stage_config:
74 | target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder
--------------------------------------------------------------------------------
/configs/lvdm_short/ucf.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 6.0e-05 # 1.5e-04
3 | scale_lr: False
4 | target: lvdm.models.ddpm3d.LatentDiffusion
5 | params:
6 | linear_start: 0.0015
7 | linear_end: 0.0155
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: "image"
13 | image_size: 32
14 | video_length: 4
15 | channels: 4
16 | monitor: val/loss_simple_ema
17 | cond_stage_trainable: False
18 | concat_mode: False
19 | scale_by_std: False
20 | scale_factor: 0.220142075
21 | shift_factor: 0.5837740898
22 | cond_stage_config: __is_unconditional__
23 | encoder_type: 3d
24 |
25 | unet_config:
26 | target: lvdm.models.modules.openaimodel3d.UNetModel
27 | params:
28 | image_size: 32
29 | in_channels: 4
30 | out_channels: 4
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 3
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 4
43 | use_temporal_transformer: False
44 | use_checkpoint: true
45 | legacy: False
46 | # temporal
47 | kernel_size_t: 1
48 | padding_t: 0
49 | temporal_length: 4
50 | use_relative_position: True
51 | use_scale_shift_norm: True
52 | first_stage_config:
53 | target: lvdm.models.autoencoder3d.AutoencoderKL
54 | params:
55 | ckpt_path: ${ckpt_path}
56 | monitor: "val/rec_loss"
57 | embed_dim: 4
58 | lossconfig: __is_first_stage__
59 | ddconfig:
60 | double_z: True
61 | z_channels: 4
62 | encoder:
63 | target: lvdm.models.modules.aemodules3d.Encoder
64 | params:
65 | n_hiddens: 32
66 | downsample: [4, 8, 8]
67 | image_channel: 3
68 | norm_type: group
69 | padding_type: replicate
70 | double_z: True
71 | z_channels: 4
72 | decoder:
73 | target: lvdm.models.modules.aemodules3d.Decoder
74 | params:
75 | n_hiddens: 32
76 | upsample: [4, 8, 8]
77 | z_channels: 4
78 | image_channel: 3
79 | norm_type: group
80 |
81 | data:
82 | target: main.DataModuleFromConfig
83 | params:
84 | batch_size: 2
85 | num_workers: 0
86 | wrap: false
87 | train:
88 | target: lvdm.data.frame_dataset.VideoFrameDataset
89 | params:
90 | data_root: ${data_root}
91 | resolution: 256
92 | video_length: 16
93 | dataset_name: UCF-101
94 | subset_split: all
95 | spatial_transform: center_crop_resize
96 | clip_step: 1
97 | temporal_transform: rand_clips
98 | validation:
99 | target: lvdm.data.frame_dataset.VideoFrameDataset
100 | params:
101 | data_root: ${data_root}
102 | resolution: 256
103 | video_length: 16
104 | dataset_name: UCF-101
105 | subset_split: test
106 | spatial_transform: center_crop_resize
107 | clip_step: 1
108 | temporal_transform: rand_clips
109 |
110 | lightning:
111 | callbacks:
112 | image_logger:
113 | target: lvdm.utils.callbacks.ImageLogger
114 | params:
115 | batch_frequency: 1000
116 | max_images: 8
117 | increase_log_steps: False
118 | metrics_over_trainsteps_checkpoint:
119 | target: pytorch_lightning.callbacks.ModelCheckpoint
120 | params:
121 | filename: "{epoch:06}-{step:09}"
122 | save_weights_only: False
123 | every_n_epochs: 300
124 | every_n_train_steps: null
125 | trainer:
126 | benchmark: True
127 | batch_size: 2
128 | num_workers: 0
129 | num_nodes: 4
130 | accumulate_grad_batches: 2
131 | max_epochs: 2500 #2000
132 | modelcheckpoint:
133 | target: pytorch_lightning.callbacks.ModelCheckpoint
134 | params:
135 | every_n_epochs: 1
136 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/videoae/sky.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 3.60e-05
3 | scale_lr: False
4 | target: lvdm.models.autoencoder3d.AutoencoderKL
5 | params:
6 | monitor: "val/rec_loss"
7 | embed_dim: 4
8 | lossconfig:
9 | target: lvdm.models.losses.LPIPSWithDiscriminator
10 | params:
11 | disc_start: 50001
12 | kl_weight: 0.0
13 | disc_weight: 0.5
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | encoder:
18 | target: lvdm.models.modules.aemodules3d.Encoder
19 | params:
20 | n_hiddens: 32
21 | downsample: [4, 8, 8]
22 | image_channel: 3
23 | norm_type: group
24 | padding_type: replicate
25 | double_z: True
26 | z_channels: 4
27 |
28 | decoder:
29 | target: lvdm.models.modules.aemodules3d.Decoder
30 | params:
31 | n_hiddens: 32
32 | upsample: [4, 8, 8]
33 | z_channels: 4
34 | image_channel: 3
35 | norm_type: group
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 1
41 | num_workers: 0
42 | wrap: false
43 | train:
44 | target: lvdm.data.frame_dataset.VideoFrameDataset
45 | params:
46 | data_root: ${data_root}
47 | resolution: 256
48 | video_length: 16
49 | dataset_name: sky
50 | subset_split: train
51 | spatial_transform: center_crop_resize
52 | clip_step: 1
53 | temporal_transform: rand_clips
54 | validation:
55 | target: lvdm.data.frame_dataset.VideoFrameDataset
56 | params:
57 | data_root: ${data_root}
58 | resolution: 256
59 | video_length: 16
60 | dataset_name: sky
61 | subset_split: test
62 | spatial_transform: center_crop_resize
63 | clip_step: 1
64 | temporal_transform: rand_clips
65 | lightning:
66 | find_unused_parameters: True
67 | callbacks:
68 | image_logger:
69 | target: lvdm.utils.callbacks.ImageLogger
70 | params:
71 | batch_frequency: 300
72 | max_images: 8
73 | increase_log_steps: False
74 | log_to_tblogger: False
75 | metrics_over_trainsteps_checkpoint:
76 | target: pytorch_lightning.callbacks.ModelCheckpoint
77 | params:
78 | filename: "{epoch:06}-{step:09}"
79 | save_weights_only: False
80 | every_n_epochs: 100
81 | every_n_train_steps: null
82 | trainer:
83 | benchmark: True
84 | accumulate_grad_batches: 2
85 | batch_size: 1
86 | num_workers: 0
87 | max_epochs: 1000
88 | modelcheckpoint:
89 | target: pytorch_lightning.callbacks.ModelCheckpoint
90 | params:
91 | every_n_epochs: 1
92 | filename: "{epoch:04}-{step:06}"
93 |
--------------------------------------------------------------------------------
/configs/videoae/taichi.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.6e-05 # 8.0e-05
3 | target: lvdm.models.autoencoder3d.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | image_key: video
8 | lossconfig:
9 | target: lvdm.models.losses.LPIPSWithDiscriminator
10 | params:
11 | disc_start: 50001
12 | kl_weight: 0.0
13 | disc_weight: 0.5
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | encoder:
18 | target: lvdm.models.modules.aemodules3d.Encoder
19 | params:
20 | n_hiddens: 32
21 | downsample: [4, 8, 8]
22 | image_channel: 3
23 | norm_type: group
24 | padding_type: replicate
25 | double_z: True
26 | z_channels: 4
27 |
28 | decoder:
29 | target: lvdm.models.modules.aemodules3d.Decoder
30 | params:
31 | n_hiddens: 32
32 | upsample: [4, 8, 8]
33 | z_channels: 4
34 | image_channel: 3
35 | norm_type: group
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 1
41 | num_workers: 0
42 | wrap: false
43 | train:
44 | target: lvdm.data.taichi.Taichi
45 | params:
 46 | data_root: /apdcephfs_cq2/share_1290939/yingqinghe/taichi # keyword matches the data_root argument of lvdm.data.taichi.Taichi
47 | resolution: 256
48 | video_length: 16
49 | subset_split: all
50 | frame_stride: 4
51 | validation:
52 | target: lvdm.data.taichi.Taichi
53 | params:
 54 | data_root: /apdcephfs_cq2/share_1290939/yingqinghe/taichi
55 | resolution: 256
56 | video_length: 16
57 | subset_split: test
58 | frame_stride: 4
59 | lightning:
60 | find_unused_parameters: True
61 | callbacks:
62 | image_logger:
63 | target: lvdm.utils.callbacks.ImageLogger
64 | params:
65 | batch_frequency: 300
66 | max_images: 8
67 | increase_log_steps: False
68 | log_to_tblogger: False
69 | metrics_over_trainsteps_checkpoint:
70 | target: pytorch_lightning.callbacks.ModelCheckpoint
71 | params:
72 | filename: "{epoch:06}-{step:09}"
73 | save_weights_only: False
74 | every_n_epochs: 100
75 | every_n_train_steps: null
76 | trainer:
77 | benchmark: True
78 | accumulate_grad_batches: 2
79 | batch_size: 1
80 | num_workers: 0
81 | modelcheckpoint:
82 | target: pytorch_lightning.callbacks.ModelCheckpoint
83 | params:
84 | every_n_epochs: 1
85 | filename: "{epoch:04}-{step:06}"
86 |
--------------------------------------------------------------------------------
/configs/videoae/ucf.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 3.60e-05
3 | scale_lr: False
4 | target: lvdm.models.autoencoder3d.AutoencoderKL
5 | params:
6 | monitor: "val/rec_loss"
7 | embed_dim: 4
8 | lossconfig:
9 | target: lvdm.models.losses.LPIPSWithDiscriminator
10 | params:
11 | disc_start: 50001
12 | kl_weight: 0.0
13 | disc_weight: 0.5
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | encoder:
18 | target: lvdm.models.modules.aemodules3d.Encoder
19 | params:
20 | n_hiddens: 32
21 | downsample: [4, 8, 8]
22 | image_channel: 3
23 | norm_type: group
24 | padding_type: replicate
25 | double_z: True
26 | z_channels: 4
27 |
28 | decoder:
29 | target: lvdm.models.modules.aemodules3d.Decoder
30 | params:
31 | n_hiddens: 32
32 | upsample: [4, 8, 8]
33 | z_channels: 4
34 | image_channel: 3
35 | norm_type: group
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 1
41 | num_workers: 0
42 | wrap: false
43 | train:
44 | target: lvdm.data.frame_dataset.VideoFrameDataset
45 | params:
46 | data_root: ${data_root}
47 | resolution: 256
48 | video_length: 16
49 | dataset_name: UCF-101
50 | subset_split: train
51 | spatial_transform: center_crop_resize
52 | clip_step: 1
53 | temporal_transform: rand_clips
54 | validation:
55 | target: lvdm.data.frame_dataset.VideoFrameDataset
56 | params:
57 | data_root: ${data_root}
58 | resolution: 256
59 | video_length: 16
60 | dataset_name: UCF-101
61 | subset_split: test
62 | spatial_transform: center_crop_resize
63 | clip_step: 1
64 | temporal_transform: rand_clips
65 | lightning:
66 | find_unused_parameters: True
67 | callbacks:
68 | image_logger:
69 | target: lvdm.utils.callbacks.ImageLogger
70 | params:
71 | batch_frequency: 1000
72 | max_images: 8
73 | increase_log_steps: False
74 | log_to_tblogger: False
75 | trainer:
76 | benchmark: True
77 | accumulate_grad_batches: 2
78 | batch_size: 1
79 | num_workers: 0
80 | modelcheckpoint:
81 | target: pytorch_lightning.callbacks.ModelCheckpoint
82 | params:
83 | every_n_epochs: 1
84 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/videoae/ucf_videodata.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 3.60e-05
3 | scale_lr: False
4 | target: lvdm.models.autoencoder3d.AutoencoderKL
5 | params:
6 | image_key: video
7 | monitor: "val/rec_loss"
8 | embed_dim: 4
9 | lossconfig:
10 | target: lvdm.models.losses.LPIPSWithDiscriminator
11 | params:
12 | disc_start: 50001
13 | kl_weight: 0.0
14 | disc_weight: 0.5
15 | ddconfig:
16 | double_z: True
17 | z_channels: 4
18 | encoder:
19 | target: lvdm.models.modules.aemodules3d.Encoder
20 | params:
21 | n_hiddens: 32
22 | downsample: [4, 8, 8]
23 | image_channel: 3
24 | norm_type: group
25 | padding_type: replicate
26 | double_z: True
27 | z_channels: 4
28 |
29 | decoder:
30 | target: lvdm.models.modules.aemodules3d.Decoder
31 | params:
32 | n_hiddens: 32
33 | upsample: [4, 8, 8]
34 | z_channels: 4
35 | image_channel: 3
36 | norm_type: group
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 1
41 | num_workers: 0
42 | wrap: false
43 | train:
44 | target: lvdm.data.ucf.UCF101
45 | params:
46 | data_root: ${data_root}
47 | resolution: 256
48 | video_length: 16
49 | subset_split: all
50 | frame_stride: 1
51 | validation:
52 | target: lvdm.data.ucf.UCF101
53 | params:
54 | data_root: ${data_root}
55 | resolution: 256
56 | video_length: 16
57 | subset_split: test
58 | frame_stride: 1
59 |
60 | lightning:
61 | find_unused_parameters: True
62 | callbacks:
63 | image_logger:
64 | target: lvdm.utils.callbacks.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: False
69 | log_to_tblogger: False
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 2
73 | batch_size: 1
74 | num_workers: 0
75 | modelcheckpoint:
76 | target: pytorch_lightning.callbacks.ModelCheckpoint
77 | params:
78 | every_n_epochs: 1
79 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/input/prompts.txt:
--------------------------------------------------------------------------------
1 | astronaut riding a horse
2 | Flying through an intense battle between pirate ships in a stormy ocean
--------------------------------------------------------------------------------
/lvdm/data/frame_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 | from PIL import ImageFile
5 | from PIL import Image
6 |
7 | import torch
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | import torchvision.transforms._transforms_video as transforms_video
11 |
12 | """ VideoFrameDataset """
13 |
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 | IMG_EXTENSIONS = [
16 | '.jpg', '.JPG', '.jpeg', '.JPEG',
17 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
18 | ]
19 |
20 |
21 | def pil_loader(path):
22 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
23 | '''
24 | with open(path, 'rb') as f:
25 | with Image.open(f) as img:
26 | return img.convert('RGB')
27 | '''
28 | Im = Image.open(path)
29 | return Im.convert('RGB')
30 |
31 | def accimage_loader(path):
32 | import accimage
33 | try:
34 | return accimage.Image(path)
35 | except IOError:
36 | # Potentially a decoding problem, fall back to PIL.Image
37 | return pil_loader(path)
38 |
39 | def default_loader(path):
40 | '''
41 | from torchvision import get_image_backend
42 | if get_image_backend() == 'accimage':
43 | return accimage_loader(path)
44 | else:
45 | '''
46 | return pil_loader(path)
47 |
48 | def is_image_file(filename):
49 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
50 |
51 | def find_classes(dir):
52 | assert(os.path.exists(dir)), f'{dir} does not exist'
53 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
54 | classes.sort()
55 | class_to_idx = {classes[i]: i for i in range(len(classes))}
56 | return classes, class_to_idx
57 |
58 | def class_name_to_idx(annotation_dir):
59 | """
60 | return class indices from 0 ~ num_classes-1
61 | """
62 | fpath = os.path.join(annotation_dir, "classInd.txt")
63 | with open(fpath, "r") as f:
64 | data = f.readlines()
65 | class_to_idx = {x.strip().split(" ")[1].lower():int(x.strip().split(" ")[0]) - 1 for x in data}
66 | return class_to_idx
67 |
68 | def make_dataset(dir, nframes, class_to_idx, frame_stride=1, **kwargs):
69 | """
70 | videos are saved in second-level directory:
71 | dir: video dir. Format:
72 | videoxxx
73 | videoxxx_1
74 | frame1.jpg
75 | frame2.jpg
76 | videoxxx_2
77 | frame1.jpg
78 | ...
79 | videoxxx
80 |
 81 | nframes: number of frames in every video clip
82 | class_to_idx: for mapping video name to video id
83 | """
84 | if frame_stride != 1:
85 | raise NotImplementedError
86 |
87 | clips = []
88 | videos = []
89 | n_clip = 0
90 | video_frames = []
91 | for video_name in sorted(os.listdir(dir)):
92 | if os.path.isdir(os.path.join(dir,video_name)):
93 |
94 | # eg: dir + '/rM7aPu9WV2Q'
95 | subfolder_path = os.path.join(dir, video_name) # video_name: rM7aPu9WV2Q
96 | for subsubfold in sorted(os.listdir(subfolder_path)):
97 | subsubfolder_path = os.path.join(subfolder_path, subsubfold)
98 | if os.path.isdir(subsubfolder_path): # eg: dir/rM7aPu9WV2Q/1'
99 | clip_frames = []
100 | i = 1
101 | # traverse frames in one video
102 | for fname in sorted(os.listdir(subsubfolder_path)):
103 | if is_image_file(fname):
104 | img_path = os.path.join(subsubfolder_path, fname) # eg: dir + '/rM7aPu9WV2Q/rM7aPu9WV2Q_1/rM7aPu9WV2Q_frames_00086552.jpg'
105 | frame_info = (img_path, class_to_idx[video_name]) #(img_path, video_id)
106 | clip_frames.append(frame_info)
107 | video_frames.append(frame_info)
108 |
109 | # append clips, clip_step=n_frames (no frame overlap between clips).
110 | if i % nframes == 0 and i >0:
111 | clips.append(clip_frames)
112 | n_clip += 1
113 | clip_frames = []
114 | i = i+1
115 |
116 | if len(video_frames) >= nframes:
117 | videos.append(video_frames)
118 | video_frames = []
119 |
120 | print('number of long videos:', len(videos))
121 | print('number of short videos', len(clips))
122 | return clips, videos
123 |
124 | def split_by_captical(s):
125 | s_list = re.sub( r"([A-Z])", r" \1", s).split()
126 | string = ""
127 | for s in s_list:
128 | string += s + " "
129 | return string.rstrip(" ").lower()
130 |
131 | def make_dataset_ucf(dir, nframes, class_to_idx, frame_stride=1, clip_step=None):
132 | """
133 | Load consecutive clips and consecutive frames from `dir`.
134 |
135 | args:
 136 | nframes: number of frames in every video clip
137 | class_to_idx: for mapping video name to video id
138 | frame_stride: select frames with a stride.
139 | clip_step: select clips with a step. if clip_step< nframes,
140 | there will be overlapped frames among two consecutive clips.
141 |
142 | assert videos are saved in first-level directory:
143 | dir:
144 | videoxxx1
145 | frame1.jpg
146 | frame2.jpg
147 | videoxxx2
148 | """
149 | if clip_step is None:
150 | # consecutive clips with no frame overlap
151 | clip_step = nframes
152 | # make videos
153 | clips = [] # 2d list
154 | videos = [] # 2d list
155 | for video_name in sorted(os.listdir(dir)):
156 | if video_name != '_broken_clips':
157 | video_path = os.path.join(dir, video_name)
158 | assert(os.path.isdir(video_path))
159 |
160 | frames = []
161 | for i, fname in enumerate(sorted(os.listdir(video_path))):
162 | assert(is_image_file(fname)),f'fname={fname},video_path={video_path},dir={dir}'
163 |
164 | # get frame info
165 | img_path = os.path.join(video_path, fname)
166 | class_name = video_name.split("_")[1].lower() # v_BoxingSpeedBag_g12_c05 -> boxingspeedbag
167 | class_caption = split_by_captical(video_name.split("_")[1]) # v_BoxingSpeedBag_g12_c05 -> BoxingSpeedBag -> boxing speed bag
168 | frame_info = {
169 | "img_path": img_path,
170 | "class_index": class_to_idx[class_name],
171 | "class_name": class_name, #boxingspeedbag
172 | "class_caption": class_caption #boxing speed bag
173 | }
174 | frames.append(frame_info)
175 | frames = frames[::frame_stride]
176 |
177 | # make videos
178 | if len(frames) >= nframes:
179 | videos.append(frames)
180 |
181 | # make clips
182 | start_indices = list(range(len(frames)))[::clip_step]
183 | for i in start_indices:
184 | clip = frames[i:i+nframes]
185 | if len(clip) == nframes:
186 | clips.append(clip)
187 | return clips, videos
188 |
189 | def load_and_transform_frames(frame_list, loader, img_transform=None):
190 | assert(isinstance(frame_list, list))
191 | clip = []
192 | labels = []
193 | for frame in frame_list:
194 |
195 | if isinstance(frame, tuple):
196 | fpath, label = frame
197 | elif isinstance(frame, dict):
198 | fpath = frame["img_path"]
199 | label = {
200 | "class_index": frame["class_index"],
201 | "class_name": frame["class_name"],
202 | "class_caption": frame["class_caption"],
203 | }
204 |
205 | labels.append(label)
206 | img = loader(fpath)
207 | if img_transform is not None:
208 | img = img_transform(img)
209 | img = img.view(img.size(0),1, img.size(1), img.size(2))
210 | clip.append(img)
211 | return clip, labels[0] # all frames have same label..
212 |
213 | class VideoFrameDataset(data.Dataset):
214 | def __init__(self,
215 | data_root,
216 | resolution,
217 | video_length, # clip length
218 | dataset_name="",
219 | subset_split="",
220 | annotation_dir=None,
221 | spatial_transform="",
222 | temporal_transform="",
223 | frame_stride=1,
224 | clip_step=None,
225 | ):
226 |
227 | self.loader = default_loader
228 | self.video_length = video_length
229 | self.subset_split = subset_split
230 | self.temporal_transform = temporal_transform
231 | self.spatial_transform = spatial_transform
232 | self.frame_stride = frame_stride
233 | self.dataset_name = dataset_name
234 |
235 | assert(subset_split in ["train", "test", "all", ""]) # "" means no subset_split directory.
236 | assert(self.temporal_transform in ["", "rand_clips"])
237 |
238 | if subset_split == 'all':
239 | video_dir = os.path.join(data_root, "train")
240 | else:
241 | video_dir = os.path.join(data_root, subset_split)
242 |
243 | if dataset_name == 'UCF-101':
244 | if annotation_dir is None:
245 | annotation_dir = os.path.join(data_root, "ucfTrainTestlist")
246 | class_to_idx = class_name_to_idx(annotation_dir)
247 | assert(len(class_to_idx) == 101), f'num of classes = {len(class_to_idx)}, not 101'
248 | elif dataset_name == 'sky':
249 | classes, class_to_idx = find_classes(video_dir)
250 | else:
251 | class_to_idx = None
252 |
253 | # make dataset
254 | if dataset_name == 'UCF-101':
255 | func = make_dataset_ucf
256 | else:
257 | func = make_dataset
258 | self.clips, self.videos = func(video_dir, video_length, class_to_idx, frame_stride=frame_stride, clip_step=clip_step)
259 | assert(len(self.clips[0]) == video_length), f"Invalid clip length = {len(self.clips[0])}"
260 | if self.temporal_transform == 'rand_clips':
261 | self.clips = self.videos
262 |
263 | if subset_split == 'all':
264 | # add test videos
 265 | video_dir = os.path.join(data_root, 'test') # str.rstrip('/train') strips characters, not the suffix, so build the test path from data_root
266 | cs, vs = func(video_dir, video_length, class_to_idx)
267 | if self.temporal_transform == 'rand_clips':
268 | self.clips += vs
269 | else:
270 | self.clips += cs
271 |
272 | print('[VideoFrameDataset] number of videos:', len(self.videos))
273 | print('[VideoFrameDataset] number of clips', len(self.clips))
274 |
275 | # check data
276 | if len(self.clips) == 0:
277 | raise(RuntimeError(f"Found 0 clips in {video_dir}. \n"
278 | "Supported image extensions are: " +
279 | ",".join(IMG_EXTENSIONS)))
280 |
281 | # data transform
282 | self.img_transform = transforms.Compose([
283 | transforms.ToTensor(),
284 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
285 | ])
286 | if self.spatial_transform == "center_crop_resize":
287 | print('Spatial transform: center crop and then resize')
288 | self.video_transform = transforms.Compose([
289 | transforms.Resize(resolution),
290 | transforms_video.CenterCropVideo(resolution),
291 | ])
292 | elif self.spatial_transform == "resize":
293 | print('Spatial transform: resize with no crop')
294 | self.video_transform = transforms.Resize((resolution, resolution))
295 | elif self.spatial_transform == "random_crop":
296 | self.video_transform = transforms.Compose([
297 | transforms_video.RandomCropVideo(resolution),
298 | ])
299 | elif self.spatial_transform == "":
300 | self.video_transform = None
301 | else:
302 | raise NotImplementedError
303 |
304 | def __getitem__(self, index):
305 | # get clip info
306 | if self.temporal_transform == 'rand_clips':
307 | raw_video = self.clips[index]
308 | rand_idx = random.randint(0, len(raw_video) - self.video_length)
309 | clip = raw_video[rand_idx:rand_idx+self.video_length]
310 | else:
311 | clip = self.clips[index]
312 | assert(len(clip) == self.video_length), f'current clip_length={len(clip)}, target clip_length={self.video_length}, {clip}'
313 |
314 | # make clip tensor
315 | frames, labels = load_and_transform_frames(clip, self.loader, self.img_transform)
316 | assert(len(frames) == self.video_length), f'current clip_length={len(frames)}, target clip_length={self.video_length}, {clip}'
317 | frames = torch.cat(frames, 1) # c,t,h,w
318 | if self.video_transform is not None:
319 | frames = self.video_transform(frames)
320 |
321 | example = dict()
322 | example["image"] = frames
323 | if labels is not None and self.dataset_name == 'UCF-101':
324 | example["caption"] = labels["class_caption"]
325 | example["class_label"] = labels["class_index"]
326 | example["class_name"] = labels["class_name"]
327 | example["frame_stride"] = self.frame_stride
328 | return example
329 |
330 | def __len__(self):
331 | return len(self.clips)
332 |
--------------------------------------------------------------------------------
/lvdm/data/split_ucf101.py:
--------------------------------------------------------------------------------
1 | # Split the official UCF-101 dataset into train and test splits
2 | # The output data format:
3 | # UCF-101/
4 | # ├── train/
5 | # │ ├── ApplyEyeMakeup/
6 | # │ │ ├── v_ApplyEyeMakeup_g08_c01.avi
7 | # │ │ ├── v_ApplyEyeMakeup_g08_c02.avi
8 | # │ │ ├── ...
9 | # │ ├── ApplyLipstick/
10 | # │ │ ├── v_ApplyLipstick_g01_c01.avi
11 | # ├── test/
12 | # │ ├── ApplyEyeMakeup/
13 | # │ │ ├── v_ApplyEyeMakeup_g01_c01.avi
14 | # │ │ ├── v_ApplyEyeMakeup_g01_c02.avi
15 | # │ │ ├── ...
16 | # │ ├── ApplyLipstick/
17 | # │ │ ├── v_ApplyLipstick_g01_c01.avi
18 | # ├── ucfTrainTestlist/
19 | # │ ├── classInd.txt
20 | # │ ├── testlist01.txt
21 | # │ ├── trainlist01.txt
22 | # │ ├── ...
23 |
24 | input_dir = "temp/UCF-101" # the root directory of the UCF-101 dataset
25 | input_annotation = "temp/ucfTrainTestlist" # the annotation directory (train/test list files) of the UCF-101 dataset
26 | output_dir_tmp = f"{input_dir}_split" # the temporary directory to store the split dataset
27 |
28 | remove_original_dir = False # if True, delete the original directory and move the split dataset to its path
29 |
30 | import os
31 | import random
32 | import shutil
33 |
34 | split_idx = 1 # the split index
35 | # the train/test list files are read inside extract_split()
36 |
37 | # make the train and test directories
38 | os.makedirs(os.path.join(output_dir_tmp, "train"), exist_ok=True)
39 | os.makedirs(os.path.join(output_dir_tmp, "test"), exist_ok=True)
40 |
41 | def extract_split(subset="train"):
42 | if subset not in ["train", "test"]:
43 | raise ValueError("subset must be either 'train' or 'test'")
44 |
45 | with open(os.path.join(input_annotation, f"{subset}list0{split_idx}.txt")) as f:
46 |         video_list = f.readlines()
47 |     video_list = [x.strip() for x in video_list]
48 |     for item in video_list:
49 | if subset == "test":
50 | class_name = item.split("/")[0]
51 | video_name = item.split("/")[1]
52 | elif subset == "train":
53 | class_name = item.split("/")[0]
54 | video_name = item.split(" ")[0].split("/")[1]
55 | video_path = os.path.join(input_dir, class_name, video_name)
56 | print(f"input_dir: {input_dir}, class_name: {class_name}, video_name: {video_name}")
57 |
58 | class_dir = os.path.join(output_dir_tmp, f"{subset}", class_name)
59 | os.makedirs(class_dir, exist_ok=True)
60 | shutil.copy(video_path, class_dir)
61 | print(f"Copy {video_path} to {class_dir}")
62 |
63 | # split the dataset into the output directory
64 | extract_split(subset="train")
65 | extract_split(subset="test")
66 |
67 | # copy the annotation files to the output directory
68 | shutil.copytree(input_annotation, os.path.join(output_dir_tmp, "ucfTrainTestlist"))
69 |
70 | if remove_original_dir:
71 | shutil.rmtree(input_dir)
72 | shutil.move(output_dir_tmp, input_dir)
73 |
--------------------------------------------------------------------------------
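Sanity-check sketch (not a repository file): a quick way to verify the split produced by split_ucf101.py is to count videos per subset. It assumes the default output path used above (temp/UCF-101_split) and needs only the standard library; for the official split 1 it should report 9537 train and 3783 test videos over 101 classes.

import glob
import os

split_root = "temp/UCF-101_split"   # output_dir_tmp produced by split_ucf101.py
for subset in ("train", "test"):
    videos = glob.glob(os.path.join(split_root, subset, "*", "*.avi"))
    classes = {os.path.basename(os.path.dirname(v)) for v in videos}
    print(f"{subset}: {len(videos)} videos across {len(classes)} classes")
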
/lvdm/data/taichi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset
5 | from decord import VideoReader, cpu
6 | import glob
7 |
8 | class Taichi(Dataset):
9 | """
10 | Taichi Dataset.
11 | Assumes data is structured as follows.
12 | Taichi/
13 | train/
14 | xxx.mp4
15 | ...
16 | test/
17 | xxx.mp4
18 | ...
19 | """
20 | def __init__(self,
21 | data_root,
22 | resolution,
23 | video_length,
24 | subset_split,
25 | frame_stride,
26 | ):
27 | self.data_root = data_root
28 | self.resolution = resolution
29 | self.video_length = video_length
30 | self.subset_split = subset_split
31 | self.frame_stride = frame_stride
32 | assert(self.subset_split in ['train', 'test', 'all'])
33 | self.exts = ['avi', 'mp4', 'webm']
34 |
35 | if isinstance(self.resolution, int):
36 | self.resolution = [self.resolution, self.resolution]
37 | assert(isinstance(self.resolution, list) and len(self.resolution) == 2)
38 |
39 | self._make_dataset()
40 |
41 | def _make_dataset(self):
42 | if self.subset_split == 'all':
43 | data_folder = self.data_root
44 | else:
45 | data_folder = os.path.join(self.data_root, self.subset_split)
46 | self.videos = sum([glob.glob(os.path.join(data_folder, '**', f'*.{ext}'), recursive=True)
47 | for ext in self.exts], [])
48 | print(f'Number of videos = {len(self.videos)}')
49 |
50 | def __getitem__(self, index):
51 | while True:
52 | video_path = self.videos[index]
53 |
54 | try:
55 | video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
56 | if len(video_reader) < self.video_length:
57 | index += 1
58 | continue
59 | else:
60 | break
61 | except:
62 | index += 1
63 | print(f"Load video failed! path = {video_path}")
64 |
65 | all_frames = list(range(0, len(video_reader), self.frame_stride))
66 | if len(all_frames) < self.video_length:
67 | all_frames = list(range(0, len(video_reader), 1))
68 |
69 | # select random clip
70 | rand_idx = random.randint(0, len(all_frames) - self.video_length)
71 |         frame_indices = all_frames[rand_idx:rand_idx+self.video_length]  # strided indices; range(rand_idx, ...) would ignore frame_stride
72 | frames = video_reader.get_batch(frame_indices)
73 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}'
74 |
75 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w]
76 | assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
77 | frames = (frames / 255 - 0.5) * 2
78 | data = {'video': frames}
79 | return data
80 |
81 | def __len__(self):
82 | return len(self.videos)
--------------------------------------------------------------------------------
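Usage sketch for the Taichi dataset above (not a repository file; data_root, resolution and the other arguments are illustrative): it wraps the dataset in a DataLoader and checks the tensor layout returned by __getitem__.

import torch
from torch.utils.data import DataLoader
from lvdm.data.taichi import Taichi

dataset = Taichi(data_root="datasets/taichi", resolution=64,
                 video_length=16, subset_split="train", frame_stride=1)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch["video"].shape)   # torch.Size([2, 3, 16, 64, 64]); pixel values scaled to [-1, 1]
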
/lvdm/data/ucf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import glob
4 | import random
5 | import torch
6 | from torch.utils.data import Dataset
7 | from decord import VideoReader, cpu
8 |
9 | def split_by_captical(s):
10 | s_list = re.sub( r"([A-Z])", r" \1", s).split()
11 | string = ""
12 | for s in s_list:
13 | string += s + " "
14 | return string.rstrip(" ").lower()
15 |
16 | def sample_strided_frames(vid_len, frame_stride, target_vid_len):
17 | frame_indices = list(range(0, vid_len, frame_stride))
18 | if len(frame_indices) < target_vid_len:
19 | frame_stride = vid_len // target_vid_len # recalculate a max fs
20 | assert(frame_stride != 0)
21 | frame_indices = list(range(0, vid_len, frame_stride))
22 | return frame_indices, frame_stride
23 |
24 | def class_name_to_idx(annotation_dir):
25 | """
26 |     Return a dict mapping class name (lower-case) to class index in [0, num_classes-1].
27 | """
28 | fpath = os.path.join(annotation_dir, "classInd.txt")
29 | with open(fpath, "r") as f:
30 | data = f.readlines()
31 | class_to_idx = {x.strip().split(" ")[1].lower():int(x.strip().split(" ")[0]) - 1 for x in data}
32 | return class_to_idx
33 |
34 | def class_idx_to_caption(caption_path):
35 | """
36 |     Return a dict mapping class index to its caption (one caption per line in the file).
37 | """
38 | with open(caption_path, "r") as f:
39 | data = f.readlines()
40 | idx_to_cap = {i: line.strip() for i, line in enumerate(data)}
41 | return idx_to_cap
42 |
43 | class UCF101(Dataset):
44 | """
45 | UCF101 Dataset. Assumes data is structured as follows.
46 |
47 | UCF101/ (data_root)
48 | train/
49 | classname
50 | xxx.avi
51 | xxx.avi
52 | test/
53 | classname
54 | xxx.avi
55 | xxx.avi
56 | ucfTrainTestlist/
57 | """
58 | def __init__(self,
59 | data_root,
60 | resolution,
61 | video_length,
62 | subset_split,
63 | frame_stride,
64 | annotation_dir=None,
65 | caption_file=None,
66 | ):
67 | self.data_root = data_root
68 | self.resolution = resolution
69 | self.video_length = video_length
70 | self.subset_split = subset_split
71 | self.frame_stride = frame_stride
72 | self.annotation_dir = annotation_dir if annotation_dir is not None else os.path.join(data_root, "ucfTrainTestlist")
73 | self.caption_file = caption_file
74 |
75 | assert(self.subset_split in ['train', 'test', 'all'])
76 | self.exts = ['avi', 'mp4', 'webm']
77 | if isinstance(self.resolution, int):
78 | self.resolution = [self.resolution, self.resolution]
79 | assert(isinstance(self.resolution, list) and len(self.resolution) == 2)
80 |
81 | self._make_dataset()
82 |
83 | def _make_dataset(self):
84 | if self.subset_split == 'all':
85 | data_folder = self.data_root
86 | else:
87 | data_folder = os.path.join(self.data_root, self.subset_split)
88 | video_paths = sum([glob.glob(os.path.join(data_folder, '**', f'*.{ext}'), recursive=True)
89 | for ext in self.exts], [])
90 | # ucf class_to_idx
91 | class_to_idx = class_name_to_idx(self.annotation_dir)
92 |         idx_to_cap = class_idx_to_caption(self.caption_file) if self.caption_file is not None else None
93 |
94 | self.videos = video_paths
95 | self.class_to_idx = class_to_idx
96 | self.idx_to_cap = idx_to_cap
97 | print(f'Number of videos = {len(self.videos)}')
98 |
99 | def _get_ucf_classinfo(self, videopath):
100 | video_name = os.path.basename(videopath)
101 | class_name = video_name.split("_")[1].lower() # v_BoxingSpeedBag_g12_c05 -> boxingspeedbag
102 | class_index = self.class_to_idx[class_name] # 0-100
103 | class_caption = self.idx_to_cap[class_index] if self.caption_file is not None else \
104 | split_by_captical(video_name.split("_")[1]) # v_BoxingSpeedBag_g12_c05 -> boxing speed bag
105 | return class_name, class_index, class_caption
106 |
107 | def __getitem__(self, index):
108 | while True:
109 | video_path = self.videos[index]
110 |
111 | try:
112 | video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
113 | vid_len = len(video_reader)
114 | if vid_len < self.video_length:
115 | index += 1
116 | continue
117 | else:
118 | break
119 | except:
120 | index += 1
121 | print(f"Load video failed! path = {video_path}")
122 |
123 | # sample strided frames
124 | all_frames, fs = sample_strided_frames(vid_len, self.frame_stride, self.video_length)
125 |
126 | # select random clip
127 | rand_idx = random.randint(0, len(all_frames) - self.video_length)
128 |         frame_indices = all_frames[rand_idx:rand_idx+self.video_length]  # strided indices; range(rand_idx, ...) would ignore frame_stride
129 | frames = video_reader.get_batch(frame_indices)
130 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}'
131 |
132 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w]
133 | assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
134 | frames = (frames / 255 - 0.5) * 2
135 |
136 | class_name, class_index, class_caption = self._get_ucf_classinfo(videopath=video_path)
137 | data = {'video': frames,
138 | 'class_name': class_name,
139 | 'class_index': class_index,
140 | 'class_caption': class_caption
141 | }
142 | return data
143 |
144 | def __len__(self):
145 | return len(self.videos)
--------------------------------------------------------------------------------
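Caption-derivation sketch for UCF101 above (not a repository file; the filename is just an example). When no caption_file is given, the caption is built from the CamelCase class name embedded in the video filename via split_by_captical; importing lvdm.data.ucf requires decord to be installed.

from lvdm.data.ucf import split_by_captical

video_name = "v_BoxingSpeedBag_g12_c05.avi"
class_token = video_name.split("_")[1]      # 'BoxingSpeedBag'
print(class_token.lower())                  # 'boxingspeedbag' -> key into class_to_idx
print(split_by_captical(class_token))       # 'boxing speed bag' -> used as the text caption
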
/lvdm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from einops import rearrange
3 |
4 | import torch
5 | import pytorch_lightning as pl
6 | import torch.nn.functional as F
7 |
8 | from lvdm.models.modules.aemodules import Encoder, Decoder
9 | from lvdm.models.modules.distributions import DiagonalGaussianDistribution
10 | from lvdm.utils.common_utils import instantiate_from_config
11 |
12 | class AutoencoderKL(pl.LightningModule):
13 | def __init__(self,
14 | ddconfig,
15 | lossconfig,
16 | embed_dim,
17 | ckpt_path=None,
18 | ignore_keys=[],
19 | image_key="image",
20 | colorize_nlabels=None,
21 | monitor=None,
22 | test=False,
23 | logdir=None,
24 | input_dim=4,
25 | test_args=None,
26 | ):
27 | super().__init__()
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | self.input_dim = input_dim
37 | self.test = test
38 | self.test_args = test_args
39 | self.logdir = logdir
40 | if colorize_nlabels is not None:
41 | assert type(colorize_nlabels)==int
42 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
43 | if monitor is not None:
44 | self.monitor = monitor
45 | if ckpt_path is not None:
46 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
47 | if self.test:
48 | self.init_test()
49 |
50 | def init_test(self,):
51 | self.test = True
52 | save_dir = os.path.join(self.logdir, "test")
53 | if 'ckpt' in self.test_args:
54 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}'
55 | self.root = os.path.join(save_dir, ckpt_name)
56 | else:
57 | self.root = save_dir
58 | if 'test_subdir' in self.test_args:
59 | self.root = os.path.join(save_dir, self.test_args.test_subdir)
60 |
61 | self.root_zs = os.path.join(self.root, "zs")
62 | self.root_dec = os.path.join(self.root, "reconstructions")
63 | self.root_inputs = os.path.join(self.root, "inputs")
64 | os.makedirs(self.root, exist_ok=True)
65 |
66 | if self.test_args.save_z:
67 | os.makedirs(self.root_zs, exist_ok=True)
68 | if self.test_args.save_reconstruction:
69 | os.makedirs(self.root_dec, exist_ok=True)
70 | if self.test_args.save_input:
71 | os.makedirs(self.root_inputs, exist_ok=True)
72 | assert(self.test_args is not None)
73 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) #1500 # 12000/8
74 | self.count = 0
75 | self.eval_metrics = {}
76 | self.decodes = []
77 | self.save_decode_samples = 2048
78 |
79 | def init_from_ckpt(self, path, ignore_keys=list()):
80 | sd = torch.load(path, map_location="cpu")
81 | try:
82 | self._cur_epoch = sd['epoch']
83 | sd = sd["state_dict"]
84 | except:
85 | self._cur_epoch = 'null'
86 | keys = list(sd.keys())
87 | for k in keys:
88 | for ik in ignore_keys:
89 | if k.startswith(ik):
90 | print("Deleting key {} from state_dict.".format(k))
91 | del sd[k]
92 | self.load_state_dict(sd, strict=False)
93 | # self.load_state_dict(sd, strict=True)
94 | print(f"Restored from {path}")
95 |
96 | def encode(self, x, **kwargs):
97 |
98 | h = self.encoder(x)
99 | moments = self.quant_conv(h)
100 | posterior = DiagonalGaussianDistribution(moments)
101 | return posterior
102 |
103 | def decode(self, z, **kwargs):
104 | z = self.post_quant_conv(z)
105 | dec = self.decoder(z)
106 | return dec
107 |
108 | def forward(self, input, sample_posterior=True):
109 | posterior = self.encode(input)
110 | if sample_posterior:
111 | z = posterior.sample()
112 | else:
113 | z = posterior.mode()
114 | dec = self.decode(z)
115 | return dec, posterior
116 |
117 | def get_input(self, batch, k):
118 | x = batch[k]
119 | # if len(x.shape) == 3:
120 | # x = x[..., None]
121 | # if x.dim() == 4:
122 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
123 | if x.dim() == 5 and self.input_dim == 4:
124 | b,c,t,h,w = x.shape
125 | self.b = b
126 | self.t = t
127 | x = rearrange(x, 'b c t h w -> (b t) c h w')
128 |
129 | return x
130 |
131 | def training_step(self, batch, batch_idx, optimizer_idx):
132 | inputs = self.get_input(batch, self.image_key)
133 | reconstructions, posterior = self(inputs)
134 |
135 | if optimizer_idx == 0:
136 | # train encoder+decoder+logvar
137 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
138 | last_layer=self.get_last_layer(), split="train")
139 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
140 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
141 | return aeloss
142 |
143 | if optimizer_idx == 1:
144 | # train the discriminator
145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
146 | last_layer=self.get_last_layer(), split="train")
147 |
148 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
149 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
150 | return discloss
151 |
152 | def validation_step(self, batch, batch_idx):
153 | inputs = self.get_input(batch, self.image_key)
154 | reconstructions, posterior = self(inputs)
155 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
156 | last_layer=self.get_last_layer(), split="val")
157 |
158 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
159 | last_layer=self.get_last_layer(), split="val")
160 |
161 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
162 | self.log_dict(log_dict_ae)
163 | self.log_dict(log_dict_disc)
164 | return self.log_dict
165 |
166 | def configure_optimizers(self):
167 | lr = self.learning_rate
168 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
169 | list(self.decoder.parameters())+
170 | list(self.quant_conv.parameters())+
171 | list(self.post_quant_conv.parameters()),
172 | lr=lr, betas=(0.5, 0.9))
173 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
174 | lr=lr, betas=(0.5, 0.9))
175 | return [opt_ae, opt_disc], []
176 |
177 | def get_last_layer(self):
178 | return self.decoder.conv_out.weight
179 |
180 | @torch.no_grad()
181 | def log_images(self, batch, only_inputs=False, **kwargs):
182 | log = dict()
183 | x = self.get_input(batch, self.image_key)
184 | x = x.to(self.device)
185 | if not only_inputs:
186 | xrec, posterior = self(x)
187 | if x.shape[1] > 3:
188 | # colorize with random projection
189 | assert xrec.shape[1] > 3
190 | x = self.to_rgb(x)
191 | xrec = self.to_rgb(xrec)
192 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
193 | log["reconstructions"] = xrec
194 | log["inputs"] = x
195 | return log
196 |
197 | def to_rgb(self, x):
198 | assert self.image_key == "segmentation"
199 | if not hasattr(self, "colorize"):
200 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
201 | x = F.conv2d(x, weight=self.colorize)
202 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
203 | return x
--------------------------------------------------------------------------------
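Shape sketch (not a repository file; sizes are illustrative) of the einops fold used by AutoencoderKL.get_input: the 2D autoencoder treats a video as a batch of frames by merging the time axis into the batch axis, and the inverse rearrange restores the video layout.

import torch
from einops import rearrange

video = torch.randn(2, 3, 16, 256, 256)                    # [b, c, t, h, w]
frames = rearrange(video, 'b c t h w -> (b t) c h w')
print(frames.shape)                                        # torch.Size([32, 3, 256, 256])
video_back = rearrange(frames, '(b t) c h w -> b c t h w', b=2, t=16)
print(video_back.shape)                                    # torch.Size([2, 3, 16, 256, 256])
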
/lvdm/models/autoencoder3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 |
5 | from lvdm.models.modules.aemodules3d import SamePadConv3d
6 | from lvdm.utils.common_utils import instantiate_from_config
7 | from lvdm.models.modules.distributions import DiagonalGaussianDistribution
8 |
9 |
10 | def conv3d(in_channels, out_channels, kernel_size, conv3d_type='SamePadConv3d'):
11 | if conv3d_type == 'SamePadConv3d':
12 | return SamePadConv3d(in_channels, out_channels, kernel_size=kernel_size, padding_type='replicate')
13 | else:
14 | raise NotImplementedError
15 |
16 | class AutoencoderKL(pl.LightningModule):
17 | def __init__(self,
18 | ddconfig,
19 | lossconfig,
20 | embed_dim,
21 | ckpt_path=None,
22 | ignore_keys=[],
23 | image_key="image",
24 | monitor=None,
25 | std=1.,
26 | mean=0.,
27 | prob=0.2,
28 | **kwargs,
29 | ):
30 | super().__init__()
31 | self.image_key = image_key
32 | self.encoder = instantiate_from_config(ddconfig['encoder'])
33 | self.decoder = instantiate_from_config(ddconfig['decoder'])
34 | self.loss = instantiate_from_config(lossconfig)
35 | assert ddconfig["double_z"]
36 | self.quant_conv = conv3d(2*ddconfig["z_channels"], 2*embed_dim, 1)
37 | self.post_quant_conv = conv3d(embed_dim, ddconfig["z_channels"], 1)
38 | self.embed_dim = embed_dim
39 | self.std = std
40 | self.mean = mean
41 | self.prob = prob
42 | if monitor is not None:
43 | self.monitor = monitor
44 | if ckpt_path is not None:
45 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
46 |
47 | def init_from_ckpt(self, path, ignore_keys=list()):
48 | sd = torch.load(path, map_location="cpu")
49 | try:
50 | self._cur_epoch = sd['epoch']
51 | sd = sd["state_dict"]
52 | except:
53 | pass
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | def encode(self, x, **kwargs):
64 | h = self.encoder(x)
65 | moments = self.quant_conv(h)
66 | posterior = DiagonalGaussianDistribution(moments)
67 | return posterior
68 |
69 | def decode(self, z, **kwargs):
70 | z = self.post_quant_conv(z)
71 | dec = self.decoder(z)
72 | return dec
73 |
74 | def forward(self, input, sample_posterior=True, **kwargs):
75 | posterior = self.encode(input)
76 | if sample_posterior:
77 | z = posterior.sample()
78 | else:
79 | z = posterior.mode()
80 | dec = self.decode(z)
81 | return dec, posterior
82 |
83 | def get_input(self, batch, k):
84 | x = batch[k]
85 | if len(x.shape) == 4:
86 | x = x[..., None]
87 | x = x.to(memory_format=torch.contiguous_format).float()
88 | return x
89 |
90 | def training_step(self, batch, batch_idx, optimizer_idx):
91 | inputs = self.get_input(batch, self.image_key)
92 | reconstructions, posterior = self(inputs)
93 |
94 | if optimizer_idx == 0:
95 | # train encoder+decoder+logvar
96 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
97 | last_layer=self.get_last_layer(), split="train")
98 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
99 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
100 | return aeloss
101 |
102 | if optimizer_idx == 1:
103 | # train the discriminator
104 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
105 | last_layer=self.get_last_layer(), split="train")
106 |
107 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
108 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
109 | return discloss
110 |
111 | def validation_step(self, batch, batch_idx):
112 | inputs = self.get_input(batch, self.image_key)
113 | reconstructions, posterior = self(inputs)
114 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
115 | last_layer=self.get_last_layer(), split="val")
116 |
117 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
118 | last_layer=self.get_last_layer(), split="val")
119 |
120 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
121 | self.log_dict(log_dict_ae)
122 | self.log_dict(log_dict_disc)
123 | return self.log_dict
124 |
125 | def configure_optimizers(self):
126 | lr = self.learning_rate
127 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
128 | list(self.decoder.parameters())+
129 | list(self.quant_conv.parameters())+
130 | list(self.post_quant_conv.parameters()),
131 | lr=lr, betas=(0.5, 0.9))
132 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
133 | lr=lr, betas=(0.5, 0.9))
134 | return [opt_ae, opt_disc], []
135 |
136 | def get_last_layer(self):
137 | return self.decoder.conv_out.weight
138 |
139 | @torch.no_grad()
140 | def log_images(self, batch, only_inputs=False, **kwargs):
141 | log = dict()
142 | x = self.get_input(batch, self.image_key)
143 | x = x.to(self.device)
144 | if not only_inputs:
145 | xrec, posterior = self(x)
146 | if x.shape[1] > 3:
147 | # colorize with random projection
148 | assert xrec.shape[1] > 3
149 | x = self.to_rgb(x)
150 | xrec = self.to_rgb(xrec)
151 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
152 | log["reconstructions"] = xrec
153 | log["inputs"] = x
154 | return log
155 |
156 | def to_rgb(self, x):
157 | assert self.image_key == "segmentation"
158 | if not hasattr(self, "colorize"):
159 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
160 | x = F.conv2d(x, weight=self.colorize)
161 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
162 | return x
163 |
--------------------------------------------------------------------------------
/lvdm/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from lvdm.models.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/lvdm/models/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge", max_bs=None,):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 | self.max_bs = max_bs
32 |
33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
34 | if last_layer is not None:
35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
37 | else:
38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
40 |
41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
43 | d_weight = d_weight * self.discriminator_weight
44 | return d_weight
45 |
46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
47 | global_step, last_layer=None, cond=None, split="train",
48 | weights=None,):
49 | if inputs.dim() == 5:
50 | inputs = rearrange(inputs, 'b c t h w -> (b t) c h w')
51 | if reconstructions.dim() == 5:
52 | reconstructions = rearrange(reconstructions, 'b c t h w -> (b t) c h w')
53 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
54 | if self.perceptual_weight > 0:
55 |             if self.max_bs is not None and self.max_bs < inputs.shape[0]:
56 |                 input_list = torch.split(inputs, self.max_bs, dim=0)
57 |                 reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
58 |                 p_losses = [self.perceptual_loss(x.contiguous(), rec.contiguous()) for x, rec in zip(input_list, reconstruction_list)]
59 |                 p_loss = torch.cat(p_losses, dim=0)
60 | else:
61 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
62 | rec_loss = rec_loss + self.perceptual_weight * p_loss
63 |
64 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
65 | weighted_nll_loss = nll_loss
66 | if weights is not None:
67 | weighted_nll_loss = weights*nll_loss
68 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
69 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
70 |
71 | kl_loss = posteriors.kl()
72 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
73 |
74 | # now the GAN part
75 | if optimizer_idx == 0:
76 | # generator update
77 | if cond is None:
78 | assert not self.disc_conditional
79 | logits_fake = self.discriminator(reconstructions.contiguous())
80 | else:
81 | assert self.disc_conditional
82 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
83 | g_loss = -torch.mean(logits_fake)
84 |
85 | if self.disc_factor > 0.0:
86 | try:
87 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
88 | except RuntimeError:
89 | assert not self.training
90 | d_weight = torch.tensor(0.0)
91 | else:
92 | d_weight = torch.tensor(0.0)
93 |
94 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
95 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
96 |
97 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
98 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
99 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
100 | "{}/d_weight".format(split): d_weight.detach(),
101 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
102 | "{}/g_loss".format(split): g_loss.detach().mean(),
103 | }
104 | return loss, log
105 |
106 | if optimizer_idx == 1:
107 | # second pass for discriminator update
108 | if cond is None:
109 | logits_real = self.discriminator(inputs.contiguous().detach())
110 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
111 | else:
112 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
113 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
114 |
115 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
116 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
117 |
118 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
119 | "{}/logits_real".format(split): logits_real.detach().mean(),
120 | "{}/logits_fake".format(split): logits_fake.detach().mean()
121 | }
122 | return d_loss, log
123 |
124 |
--------------------------------------------------------------------------------
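Toy sketch (not a repository file; the Linear layer stands in for decoder.conv_out) of the adaptive weight computed by calculate_adaptive_weight: d_weight = ||grad_last(nll_loss)|| / (||grad_last(g_loss)|| + 1e-4), clamped to [0, 1e4], so the GAN term never overwhelms the reconstruction term at the decoder's last layer.

import torch
import torch.nn as nn

last_layer = nn.Linear(8, 8)                               # stand-in for decoder.conv_out
out = last_layer(torch.randn(4, 8))
nll_loss = (out - torch.randn_like(out)).abs().mean()      # stand-in reconstruction/NLL term
g_loss = -out.mean()                                       # stand-in generator (GAN) term

nll_grads = torch.autograd.grad(nll_loss, last_layer.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer.weight, retain_graph=True)[0]
d_weight = torch.clamp(torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4).detach()
print(d_weight)                                            # scales the g_loss term in the total loss
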
/lvdm/models/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:  # 'exists' is not defined/imported in this module; check None explicitly
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
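Sketch (not a repository file; requires the repo and its taming-transformers dependency) of two helpers defined above: adopt_weight gates the GAN term until global_step reaches disc_start, and measure_perplexity is maximal (== n_embed) when codebook entries are used uniformly.

import torch
from lvdm.models.losses.vqperceptual import adopt_weight, measure_perplexity

print(adopt_weight(1.0, global_step=10000, threshold=50001))   # 0.0 -> discriminator still off
print(adopt_weight(1.0, global_step=60000, threshold=50001))   # 1.0 -> discriminator active

indices = torch.arange(16).repeat(64)                          # perfectly uniform codebook usage
perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
print(perplexity.item(), cluster_use.item())                   # ~16.0, 16
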
/lvdm/models/modules/aemodules3d.py:
--------------------------------------------------------------------------------
1 | # TATS
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 |
4 | import math
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | def silu(x):
12 | return x*torch.sigmoid(x)
13 |
14 | class SiLU(nn.Module):
15 | def __init__(self):
16 | super(SiLU, self).__init__()
17 |
18 | def forward(self, x):
19 | return silu(x)
20 |
21 | def hinge_d_loss(logits_real, logits_fake):
22 | loss_real = torch.mean(F.relu(1. - logits_real))
23 | loss_fake = torch.mean(F.relu(1. + logits_fake))
24 | d_loss = 0.5 * (loss_real + loss_fake)
25 | return d_loss
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 | def Normalize(in_channels, norm_type='group'):
34 | assert norm_type in ['group', 'batch']
35 | if norm_type == 'group':
36 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
37 | elif norm_type == 'batch':
38 | return torch.nn.SyncBatchNorm(in_channels)
39 |
40 |
41 | class ResBlock(nn.Module):
42 | def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate'):
43 | super().__init__()
44 | self.in_channels = in_channels
45 | out_channels = in_channels if out_channels is None else out_channels
46 | self.out_channels = out_channels
47 | self.use_conv_shortcut = conv_shortcut
48 |
49 | self.norm1 = Normalize(in_channels, norm_type)
50 | self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
51 | self.dropout = torch.nn.Dropout(dropout)
52 | self.norm2 = Normalize(in_channels, norm_type)
53 | self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
54 | if self.in_channels != self.out_channels:
55 | self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
56 |
57 | def forward(self, x):
58 | h = x
59 | h = self.norm1(h)
60 | h = silu(h)
61 | h = self.conv1(h)
62 | h = self.norm2(h)
63 | h = silu(h)
64 | h = self.conv2(h)
65 |
66 | if self.in_channels != self.out_channels:
67 | x = self.conv_shortcut(x)
68 |
69 | return x+h
70 |
71 |
72 | # Does not support dilation
73 | class SamePadConv3d(nn.Module):
74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
75 | super().__init__()
76 | if isinstance(kernel_size, int):
77 | kernel_size = (kernel_size,) * 3
78 | if isinstance(stride, int):
79 | stride = (stride,) * 3
80 |
81 | # assumes that the input shape is divisible by stride
82 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
83 | pad_input = []
84 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim
85 | pad_input.append((p // 2 + p % 2, p // 2))
86 | pad_input = sum(pad_input, tuple())
87 |
88 | self.pad_input = pad_input
89 | self.padding_type = padding_type
90 |
91 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
92 | stride=stride, padding=0, bias=bias)
93 | self.weight = self.conv.weight
94 |
95 | def forward(self, x):
96 | return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
97 |
98 |
99 | class SamePadConvTranspose3d(nn.Module):
100 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
101 | super().__init__()
102 | if isinstance(kernel_size, int):
103 | kernel_size = (kernel_size,) * 3
104 | if isinstance(stride, int):
105 | stride = (stride,) * 3
106 |
107 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
108 | pad_input = []
109 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim
110 | pad_input.append((p // 2 + p % 2, p // 2))
111 | pad_input = sum(pad_input, tuple())
112 | self.pad_input = pad_input
113 | self.padding_type = padding_type
114 | self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
115 | stride=stride, bias=bias,
116 | padding=tuple([k - 1 for k in kernel_size]))
117 |
118 | def forward(self, x):
119 | return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
120 |
121 |
122 | class Encoder(nn.Module):
123 | def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group', padding_type='replicate'):
124 | super().__init__()
125 | n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
126 | self.conv_blocks = nn.ModuleList()
127 | max_ds = n_times_downsample.max()
128 | self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
129 |
130 | for i in range(max_ds):
131 | block = nn.Module()
132 | in_channels = n_hiddens * 2**i
133 | out_channels = n_hiddens * 2**(i+1)
134 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
135 | block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
136 | block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
137 | self.conv_blocks.append(block)
138 | n_times_downsample -= 1
139 |
140 | self.final_block = nn.Sequential(
141 | Normalize(out_channels, norm_type),
142 | SiLU(),
143 | SamePadConv3d(out_channels, 2*z_channels if double_z else z_channels,
144 | kernel_size=3,
145 | stride=1,
146 | padding_type=padding_type)
147 | )
148 | self.out_channels = out_channels
149 |
150 | def forward(self, x):
151 | h = self.conv_first(x)
152 | for block in self.conv_blocks:
153 | h = block.down(h)
154 | h = block.res(h)
155 | h = self.final_block(h)
156 | return h
157 |
158 |
159 | class Decoder(nn.Module):
160 | def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
161 | super().__init__()
162 |
163 | n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
164 | max_us = n_times_upsample.max()
165 | in_channels = z_channels
166 | self.conv_blocks = nn.ModuleList()
167 | for i in range(max_us):
168 | block = nn.Module()
169 | in_channels = in_channels if i ==0 else n_hiddens*2**(max_us-i+1)
170 | out_channels = n_hiddens*2**(max_us-i)
171 | us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
172 | block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
173 | block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
174 | block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
175 | self.conv_blocks.append(block)
176 | n_times_upsample -= 1
177 |
178 | self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)
179 |
180 | def forward(self, x):
181 | h = x
182 | for i, block in enumerate(self.conv_blocks):
183 | h = block.up(h)
184 | h = block.res1(h)
185 | h = block.res2(h)
186 | h = self.conv_out(h)
187 | return h
188 |
--------------------------------------------------------------------------------
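Shape sketch (not a repository file; the hyperparameters are illustrative, not taken from the repo configs) of the 3D Encoder above: downsample gives a per-axis factor (t, h, w), each axis is halved once per block until its log2 budget is spent, and double_z doubles the output channels so they can later be split into mean and log-variance.

import torch
from lvdm.models.modules.aemodules3d import Encoder

enc = Encoder(n_hiddens=16, downsample=(2, 4, 4), z_channels=4,
              double_z=True, image_channel=3)
video = torch.randn(1, 3, 8, 64, 64)       # [b, c, t, h, w]
h = enc(video)
print(h.shape)                             # torch.Size([1, 8, 4, 16, 16]) = [b, 2*z_channels, t/2, h/4, w/4]
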
/lvdm/models/modules/condition_modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import logging
3 | from transformers import CLIPTokenizer, CLIPTextModel
4 | logging.set_verbosity_error()
5 |
6 |
7 | class AbstractEncoder(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def encode(self, *args, **kwargs):
12 | raise NotImplementedError
13 |
14 |
15 | class FrozenCLIPEmbedder(AbstractEncoder):
16 | """Uses the CLIP transformer encoder for text (from huggingface)"""
17 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
18 | super().__init__()
19 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
20 | self.transformer = CLIPTextModel.from_pretrained(version)
21 | self.device = device
22 | self.max_length = max_length
23 | self.freeze()
24 |
25 | def freeze(self):
26 | self.transformer = self.transformer.eval()
27 | for param in self.parameters():
28 | param.requires_grad = False
29 |
30 | def forward(self, text):
31 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
32 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
33 | tokens = batch_encoding["input_ids"].to(self.device)
34 | outputs = self.transformer(input_ids=tokens)
35 |
36 | z = outputs.last_hidden_state
37 | return z
38 |
39 | def encode(self, text):
40 | return self(text)
--------------------------------------------------------------------------------
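Usage sketch (not a repository file) for FrozenCLIPEmbedder: it downloads openai/clip-vit-large-patch14 from HuggingFace on first use; device='cpu' keeps the sketch self-contained.

from lvdm.models.modules.condition_modules import FrozenCLIPEmbedder

text_encoder = FrozenCLIPEmbedder(device="cpu")
z = text_encoder.encode(["a corgi surfing a wave", "timelapse of clouds over mountains"])
print(z.shape)           # torch.Size([2, 77, 768]) -> [batch, max_length, hidden_size]
print(z.requires_grad)   # False: the encoder is frozen
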
/lvdm/models/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self, noise=None):
36 | if noise is None:
37 | noise = torch.randn(self.mean.shape)
38 |
39 | x = self.mean + self.std * noise.to(device=self.parameters.device)
40 | return x
41 |
42 | def kl(self, other=None):
43 | if self.deterministic:
44 | return torch.Tensor([0.])
45 | else:
46 | if other is None:
47 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
48 | + self.var - 1.0 - self.logvar,
49 | dim=[1, 2, 3])
50 | else:
51 | return 0.5 * torch.sum(
52 | torch.pow(self.mean - other.mean, 2) / other.var
53 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
54 | dim=[1, 2, 3])
55 |
56 | def nll(self, sample, dims=[1,2,3]):
57 | if self.deterministic:
58 | return torch.Tensor([0.])
59 | logtwopi = np.log(2.0 * np.pi)
60 | return 0.5 * torch.sum(
61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
62 | dim=dims)
63 |
64 | def mode(self):
65 | return self.mean
66 |
67 |
68 | def normal_kl(mean1, logvar1, mean2, logvar2):
69 | """
70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
71 | Compute the KL divergence between two gaussians.
72 | Shapes are automatically broadcasted, so batches can be compared to
73 | scalars, among other use cases.
74 | """
75 | tensor = None
76 | for obj in (mean1, logvar1, mean2, logvar2):
77 | if isinstance(obj, torch.Tensor):
78 | tensor = obj
79 | break
80 | assert tensor is not None, "at least one argument must be a Tensor"
81 |
82 | # Force variances to be Tensors. Broadcasting helps convert scalars to
83 | # Tensors, but it does not work for torch.exp().
84 | logvar1, logvar2 = [
85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
86 | for x in (logvar1, logvar2)
87 | ]
88 |
89 | return 0.5 * (
90 | -1.0
91 | + logvar2
92 | - logvar1
93 | + torch.exp(logvar1 - logvar2)
94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
95 | )
96 |
--------------------------------------------------------------------------------
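Sketch (not a repository file; sizes are illustrative) of DiagonalGaussianDistribution: the 'moments' tensor carries mean and log-variance stacked along dim=1, which is why the quant_conv layers above output 2 * embed_dim channels.

import torch
from lvdm.models.modules.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 16, 16)        # [b, 2*C, h, w], as produced by quant_conv
posterior = DiagonalGaussianDistribution(moments)
print(posterior.sample().shape)            # torch.Size([2, 4, 16, 16])
print(posterior.kl().shape)                # torch.Size([2]) -> per-sample KL to N(0, I)
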
/lvdm/models/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
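Usage sketch (not a repository file; the tiny Linear model stands in for the diffusion model) of LitEma's update/store/copy_to/restore cycle around validation.

import torch
import torch.nn as nn
from lvdm.models.modules.ema import LitEma

model = nn.Linear(4, 4)                    # stand-in for the diffusion model
ema = LitEma(model, decay=0.999)

with torch.no_grad():
    model.weight.add_(0.1)                 # pretend an optimizer step changed the weights
ema(model)                                 # update the shadow (EMA) parameters

ema.store(model.parameters())              # keep the raw training weights
ema.copy_to(model)                         # load EMA weights for validation / sampling
# ... evaluate here ...
ema.restore(model.parameters())            # switch back to the training weights
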
/lvdm/samplers/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from lvdm.models.modules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
8 |
9 |
10 | class DDIMSampler(object):
11 | def __init__(self, model, schedule="linear", **kwargs):
12 | super().__init__()
13 | self.model = model
14 | self.ddpm_num_timesteps = model.num_timesteps
15 | self.schedule = schedule
16 | self.counter = 0
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | self.register_buffer('betas', to_torch(self.model.betas))
32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34 |
35 | # calculations for diffusion q(x_t | x_{t-1}) and others
36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41 |
42 | # ddim sampling parameters
43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44 | ddim_timesteps=self.ddim_timesteps,
45 | eta=ddim_eta,verbose=verbose)
46 | self.register_buffer('ddim_sigmas', ddim_sigmas)
47 | self.register_buffer('ddim_alphas', ddim_alphas)
48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54 |
55 | @torch.no_grad()
56 | def sample(self,
57 | S,
58 | batch_size,
59 | shape,
60 | conditioning=None,
61 | callback=None,
62 | img_callback=None,
63 | quantize_x0=False,
64 | eta=0.,
65 | mask=None,
66 | x0=None,
67 | temperature=1.,
68 | noise_dropout=0.,
69 | score_corrector=None,
70 | corrector_kwargs=None,
71 | verbose=True,
72 | schedule_verbose=False,
73 | x_T=None,
74 | log_every_t=100,
75 | unconditional_guidance_scale=1.,
76 | unconditional_conditioning=None,
77 | postprocess_fn=None,
78 | sample_noise=None,
79 | cond_fn=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | **kwargs
82 | ):
83 |
84 | # check condition bs
85 | if conditioning is not None:
86 | if isinstance(conditioning, dict):
87 | try:
88 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
89 | if cbs != batch_size:
90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
91 | except:
92 | # cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
93 | pass
94 | else:
95 | if conditioning.shape[0] != batch_size:
96 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
97 |
98 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose)
99 |
100 | # make shape
101 | if len(shape) == 3:
102 | C, H, W = shape
103 | size = (batch_size, C, H, W)
104 | elif len(shape) == 4:
105 | C, T, H, W = shape
106 | size = (batch_size, C, T, H, W)
107 |
108 | samples, intermediates = self.ddim_sampling(conditioning, size,
109 | callback=callback,
110 | img_callback=img_callback,
111 | quantize_denoised=quantize_x0,
112 | mask=mask, x0=x0,
113 | ddim_use_original_steps=False,
114 | noise_dropout=noise_dropout,
115 | temperature=temperature,
116 | score_corrector=score_corrector,
117 | corrector_kwargs=corrector_kwargs,
118 | x_T=x_T,
119 | log_every_t=log_every_t,
120 | unconditional_guidance_scale=unconditional_guidance_scale,
121 | unconditional_conditioning=unconditional_conditioning,
122 | postprocess_fn=postprocess_fn,
123 | sample_noise=sample_noise,
124 | cond_fn=cond_fn,
125 | verbose=verbose,
126 | **kwargs
127 | )
128 | return samples, intermediates
129 |
130 | @torch.no_grad()
131 | def ddim_sampling(self, cond, shape,
132 | x_T=None, ddim_use_original_steps=False,
133 | callback=None, timesteps=None, quantize_denoised=False,
134 | mask=None, x0=None, img_callback=None, log_every_t=100,
135 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
136 | unconditional_guidance_scale=1., unconditional_conditioning=None,
137 | postprocess_fn=None,sample_noise=None,cond_fn=None,
138 | uc_type=None, verbose=True, **kwargs,
139 | ):
140 |
141 | device = self.model.betas.device
142 |
143 | b = shape[0]
144 | if x_T is None:
145 | img = torch.randn(shape, device=device)
146 | else:
147 | img = x_T
148 |
149 | if timesteps is None:
150 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
151 | elif timesteps is not None and not ddim_use_original_steps:
152 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
153 | timesteps = self.ddim_timesteps[:subset_end]
154 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
155 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
156 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
157 | if verbose:
158 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
159 | else:
160 | iterator = time_range
161 |
162 | for i, step in enumerate(iterator):
163 | index = total_steps - i - 1
164 | ts = torch.full((b,), step, device=device, dtype=torch.long)
165 |
166 | if postprocess_fn is not None:
167 | img = postprocess_fn(img, ts)
168 |
169 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
170 | quantize_denoised=quantize_denoised, temperature=temperature,
171 | noise_dropout=noise_dropout, score_corrector=score_corrector,
172 | corrector_kwargs=corrector_kwargs,
173 | unconditional_guidance_scale=unconditional_guidance_scale,
174 | unconditional_conditioning=unconditional_conditioning,
175 | sample_noise=sample_noise,cond_fn=cond_fn,uc_type=uc_type, **kwargs,)
176 | img, pred_x0 = outs
177 |
178 | if mask is not None:
179 | # use mask to blend x_known_t-1 & x_sample_t-1
180 | assert x0 is not None
181 | x0 = x0.to(img.device)
182 | mask = mask.to(img.device)
183 | t = torch.tensor([step-1]*x0.shape[0], dtype=torch.long, device=img.device)
184 | img_known = self.model.q_sample(x0, t)
185 | img = img_known * mask + (1. - mask) * img
186 |
187 | if callback: callback(i)
188 | if img_callback: img_callback(pred_x0, i)
189 |
190 | if index % log_every_t == 0 or index == total_steps - 1:
191 | intermediates['x_inter'].append(img)
192 | intermediates['pred_x0'].append(pred_x0)
193 |
194 | return img, intermediates
195 |
196 | @torch.no_grad()
197 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199 | unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None,
200 | cond_fn=None, uc_type=None,
201 | **kwargs,
202 | ):
203 | b, *_, device = *x.shape, x.device
204 | if x.dim() == 5:
205 | is_video = True
206 | else:
207 | is_video = False
208 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
209 | e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
210 | else:
211 | # with unconditional condition
212 | if isinstance(c, torch.Tensor):
213 | e_t = self.model.apply_model(x, t, c, **kwargs)
214 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
215 | elif isinstance(c, dict):
216 | e_t = self.model.apply_model(x, t, c, **kwargs)
217 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
218 | else:
219 | raise NotImplementedError
220 |
221 | if uc_type is None:
222 |                 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)  # standard CFG: e_uncond + scale * (e_cond - e_uncond)
223 | elif uc_type == 'cfg_original':
224 |                 e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond)  # original CFG formulation: e_cond + scale * (e_cond - e_uncond)
225 | else:
226 | raise NotImplementedError
227 |
228 | if score_corrector is not None:
229 | assert self.model.parameterization == "eps"
230 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
231 |
232 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
233 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
234 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
235 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
236 | # select parameters corresponding to the currently considered timestep
237 |
238 | if is_video:
239 | size = (b, 1, 1, 1, 1)
240 | else:
241 | size = (b, 1, 1, 1)
242 | a_t = torch.full(size, alphas[index], device=device)
243 | a_prev = torch.full(size, alphas_prev[index], device=device)
244 | sigma_t = torch.full(size, sigmas[index], device=device)
245 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device)
246 |
247 | # current prediction for x_0
248 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
249 | # print(f't={t}, pred_x0, min={torch.min(pred_x0)}, max={torch.max(pred_x0)}',file=f)
250 | if quantize_denoised:
251 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
252 | # direction pointing to x_t
253 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
254 | if sample_noise is None:
255 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
256 | if noise_dropout > 0.:
257 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
258 | else:
259 | noise = sigma_t * sample_noise * temperature
260 |
261 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
262 |
263 | return x_prev, pred_x0
264 |
--------------------------------------------------------------------------------
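Note on /lvdm/samplers/ddim.py above: the sampler is driven by the sampling scripts later in this dump (e.g. scripts/sample_long_videos_utils.py and scripts/sample_text2video.py). Below is a minimal, hedged sketch of that call pattern; `model` is assumed to be an already-loaded latent-diffusion model from this repo (exposing `betas`, `alphas_cumprod`, `apply_model`, `image_size`, and a 3D UNet under `model.model.diffusion_model`), and the step count and guidance scale are illustrative only.

import torch

from lvdm.samplers.ddim import DDIMSampler


@torch.no_grad()
def ddim_sample_sketch(model, cond=None, batch_size=2, ddim_steps=50, eta=1.0, cfg_scale=1.0, uc=None):
    # Latent video shape (C, T, H, W); a 5-D noise tensor (B, C, T, H, W) is created inside ddim_sampling.
    C = model.model.diffusion_model.in_channels
    T = model.model.diffusion_model.temporal_length
    image_size = model.image_size
    H, W = (image_size, image_size) if isinstance(image_size, int) else image_size

    sampler = DDIMSampler(model)
    samples, _ = sampler.sample(
        S=ddim_steps,                                   # number of DDIM denoising steps
        batch_size=batch_size,
        shape=(C, T, H, W),
        conditioning=cond,                              # None for unconditional sampling
        eta=eta,                                        # 0.0 -> deterministic DDIM, 1.0 -> DDPM-like noise
        verbose=False,
        unconditional_guidance_scale=cfg_scale,         # classifier-free guidance scale
        unconditional_conditioning=uc,
    )
    return samples                                      # latents; decode with model.decode_first_stage(...)
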
/lvdm/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import numpy as np
4 | from inspect import isfunction
5 |
6 | import torch
7 |
8 |
9 | def shape_to_str(x):
10 |     shape_str = "x".join([str(dim) for dim in x.shape])
11 | return shape_str
12 |
13 |
14 | def str2bool(v):
15 | if isinstance(v, bool):
16 | return v
17 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
18 | return True
19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20 | return False
21 | else:
22 | raise ValueError('Boolean value expected.')
23 |
24 |
25 | def get_obj_from_str(string, reload=False):
26 | module, cls = string.rsplit(".", 1)
27 | if reload:
28 | module_imp = importlib.import_module(module)
29 | importlib.reload(module_imp)
30 | return getattr(importlib.import_module(module, package=None), cls)
31 |
32 |
33 | def instantiate_from_config(config):
34 |     if "target" not in config:
35 | if config == '__is_first_stage__':
36 | return None
37 | elif config == "__is_unconditional__":
38 | return None
39 | raise KeyError("Expected key `target` to instantiate.")
40 |
41 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
42 |
43 |
44 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
45 |     """ Shifts the src dim to the dest dim,
46 |     e.g. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
47 | """
48 | n_dims = len(x.shape)
49 | if src_dim < 0:
50 | src_dim = n_dims + src_dim
51 | if dest_dim < 0:
52 | dest_dim = n_dims + dest_dim
53 |
54 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
55 |
56 | dims = list(range(n_dims))
57 | del dims[src_dim]
58 |
59 | permutation = []
60 | ctr = 0
61 | for i in range(n_dims):
62 | if i == dest_dim:
63 | permutation.append(src_dim)
64 | else:
65 | permutation.append(dims[ctr])
66 | ctr += 1
67 | x = x.permute(permutation)
68 | if make_contiguous:
69 | x = x.contiguous()
70 | return x
71 |
72 |
73 | def torch_to_np(x):
74 | sample = x.detach().cpu()
75 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
76 | if sample.dim() == 5:
77 | sample = sample.permute(0, 2, 3, 4, 1)
78 | else:
79 | sample = sample.permute(0, 2, 3, 1)
80 | sample = sample.contiguous().numpy()
81 | return sample
82 |
83 |
84 | def np_to_torch_video(x):
85 | x = torch.tensor(x).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w]
86 | x = (x / 255 - 0.5) * 2
87 | return x
88 |
89 |
90 | def load_npz_from_dir(data_dir):
91 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)]
92 | data = np.concatenate(data, axis=0)
93 | return data
94 |
95 |
96 | def load_npz_from_paths(data_paths):
97 | data = [np.load(data_path)['arr_0'] for data_path in data_paths]
98 | data = np.concatenate(data, axis=0)
99 | return data
100 |
101 |
102 | def ismap(x):
103 | if not isinstance(x, torch.Tensor):
104 | return False
105 | return (len(x.shape) == 4) and (x.shape[1] > 3)
106 |
107 |
108 | def isimage(x):
109 | if not isinstance(x,torch.Tensor):
110 | return False
111 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
112 |
113 |
114 | def exists(x):
115 | return x is not None
116 |
117 |
118 | def default(val, d):
119 | if exists(val):
120 | return val
121 | return d() if isfunction(d) else d
122 |
123 |
124 | def mean_flat(tensor):
125 | """
126 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
127 | Take the mean over all non-batch dimensions.
128 | """
129 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
130 |
131 |
132 | def count_params(model, verbose=False):
133 | total_params = sum(p.numel() for p in model.parameters())
134 | if verbose:
135 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
136 | return total_params
137 |
138 |
139 | def check_istarget(name, para_list):
140 | """
141 |     name: full name of a source parameter
142 |     para_list: list of partial target parameter names
143 | """
144 | istarget=False
145 | for para in para_list:
146 | if para in name:
147 | return True
148 | return istarget
--------------------------------------------------------------------------------
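Note on /lvdm/utils/common_utils.py above: a small runnable sketch of shift_dim and torch_to_np on dummy data (assumes the package is installed, e.g. via `pip install -e .`).

import torch

from lvdm.utils.common_utils import shift_dim, torch_to_np

x = torch.randn(2, 3, 16, 64, 64)       # dummy video batch: [b, c, t, h, w]
x_last = shift_dim(x, 1, -1)            # move the channel dim last: [b, t, h, w, c]
print(x_last.shape)                     # torch.Size([2, 16, 64, 64, 3])

frames = torch_to_np(torch.tanh(x))     # [-1, 1] tensor -> uint8 numpy; 5-D input becomes [b, t, h, w, c]
print(frames.shape, frames.dtype)       # (2, 16, 64, 64, 3) uint8
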
/lvdm/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 | def setup_dist(local_rank):
5 | if dist.is_initialized():
6 | return
7 | torch.cuda.set_device(local_rank)
8 | torch.distributed.init_process_group(
9 | 'nccl',
10 | init_method='env://'
11 | )
12 |
13 | def gather_data(data, return_np=True):
14 | ''' gather data from multiple processes to one list '''
15 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
16 | dist.all_gather(data_list, data)
17 | if return_np:
18 | data_list = [data.cpu().numpy() for data in data_list]
19 | return data_list
20 |
--------------------------------------------------------------------------------
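Note on /lvdm/utils/dist_utils.py above: setup_dist relies on the environment variables (RANK, WORLD_SIZE, MASTER_ADDR/PORT) set by a launcher such as torchrun, and gather_data assumes every process contributes a tensor of the same shape. A hedged sketch of the intended usage:

import os

import torch

from lvdm.utils.dist_utils import setup_dist, gather_data

# LOCAL_RANK is provided by the launcher; the sampling scripts pass it via --local_rank instead.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
setup_dist(local_rank)                                  # sets the CUDA device and joins the NCCL group

samples = torch.randn(4, 3, 16, 64, 64, device=f"cuda:{local_rank}")
all_samples = gather_data(samples, return_np=True)      # list with one numpy array per process
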
/lvdm/utils/log.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import multiprocessing as mproc
4 | import types
5 |
6 | #: number of available CPUs on this computer
7 | CPU_COUNT = int(mproc.cpu_count())
8 | #: default date-time format
9 | FORMAT_DATE_TIME = '%Y%m%d-%H%M%S'
10 | #: default logging file name
11 | FILE_LOGS = 'logging.log'
12 | #: default logging template - log location/source for logging to file
13 | STR_LOG_FORMAT = '%(asctime)s:%(levelname)s@%(filename)s:%(processName)s - %(message)s'
14 | #: default logging template - date-time for logging to file
15 | LOG_FILE_FORMAT = logging.Formatter(STR_LOG_FORMAT, datefmt="%H:%M:%S")
16 | #: types to be treated as list-like
17 | ITERABLE_TYPES = (list, tuple, types.GeneratorType)
18 | def release_logger_files():
19 | """ close all handlers to a file
20 | >>> release_logger_files()
21 | >>> len([1 for lh in logging.getLogger().handlers
22 | ... if type(lh) is logging.FileHandler])
23 | 0
24 | """
25 | for hl in logging.getLogger().handlers:
26 | if isinstance(hl, logging.FileHandler):
27 | hl.close()
28 | logging.getLogger().removeHandler(hl)
29 |
30 | def set_experiment_logger(path_out, file_name=FILE_LOGS, reset=True):
31 | """ set the logger to file
32 | :param str path_out: path to the output folder
33 | :param str file_name: log file name
34 | :param bool reset: reset all previous logging into a file
35 | >>> set_experiment_logger('.')
36 | >>> len([1 for lh in logging.getLogger().handlers
37 | ... if type(lh) is logging.FileHandler])
38 | 1
39 | >>> release_logger_files()
40 | >>> os.remove(FILE_LOGS)
41 | """
42 | log = logging.getLogger()
43 | log.setLevel(logging.DEBUG)
44 |
45 | if reset:
46 | release_logger_files()
47 | path_logger = os.path.join(path_out, file_name)
48 | fh = logging.FileHandler(path_logger)
49 | fh.setLevel(logging.DEBUG)
50 | fh.setFormatter(LOG_FILE_FORMAT)
51 | log.addHandler(fh)
52 |
53 | def set_ptl_logger(path_out, phase, file_name="ptl.log", reset=True):
54 |     """ set the PyTorch Lightning logger to file
55 |     :param str path_out: path to the output folder
56 |     :param str file_name: default log file name (overridden by ptl_<phase>.log)
57 | :param bool reset: reset all previous logging into a file
58 | >>> set_experiment_logger('.')
59 | >>> len([1 for lh in logging.getLogger().handlers
60 | ... if type(lh) is logging.FileHandler])
61 | 1
62 | >>> release_logger_files()
63 | >>> os.remove(FILE_LOGS)
64 | """
65 | file_name = f"ptl_{phase}.log"
66 | level = logging.INFO
67 | log = logging.getLogger("pytorch_lightning")
68 | log.setLevel(level)
69 |
70 | if reset:
71 | release_logger_files()
72 |
73 | path_logger = os.path.join(path_out, file_name)
74 | fh = logging.FileHandler(path_logger)
75 | fh.setLevel(level)
76 | fh.setFormatter(LOG_FILE_FORMAT)
77 |
78 | log.addHandler(fh)
79 |
--------------------------------------------------------------------------------
/lvdm/utils/saving_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
4 | import cv2
5 | import os
6 | import time
7 | import imageio
8 | import numpy as np
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from PIL import Image, ImageDraw, ImageFont
12 |
13 | import torch
14 | import torchvision
15 | from torch import Tensor
16 | from torchvision.utils import make_grid
17 | from torchvision.transforms.functional import to_tensor
18 |
19 |
20 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
21 | """
22 |     video: torch.Tensor of shape [b, c, t, h, w], values in [0, 1]
23 |     if the values are in [-1, 1], set rescale=True
24 | """
25 | n = video.shape[0]
26 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
27 | nrow = int(np.sqrt(n)) if nrow is None else nrow
28 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
29 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
30 | grid = torch.clamp(grid.float(), -1., 1.)
31 | if rescale:
32 | grid = (grid + 1.0) / 2.0
33 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
34 | #print(f'Save video to {savepath}')
35 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
36 |
37 | # ----------------------------------------------------------------------------------------------
38 | def savenp2sheet(imgs, savepath, nrow=None):
39 |     """ save multiple imgs (numpy arrays) to one image sheet.
40 |     the sheet is a grid with nrow columns (roughly square by default).
41 |
42 | imgs:
43 | np array of size [N, H, W, 3] or List[array] with array size = [H,W,3]
44 | """
45 | if imgs.ndim == 4:
46 | img_list = [imgs[i] for i in range(imgs.shape[0])]
47 | imgs = img_list
48 |
49 | imgs_new = []
50 | for i, img in enumerate(imgs):
51 | if img.ndim == 3 and img.shape[0] == 3:
52 | img = np.transpose(img,(1,2,0))
53 |
54 | assert(img.ndim == 3 and img.shape[-1] == 3), img.shape # h,w,3
55 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
56 | imgs_new.append(img)
57 | n = len(imgs)
58 | if nrow is not None:
59 | n_cols = nrow
60 | else:
61 | n_cols=int(n**0.5)
62 | n_rows=int(np.ceil(n/n_cols))
63 | print(n_cols)
64 | print(n_rows)
65 |
66 | imgsheet = cv2.vconcat([cv2.hconcat(imgs_new[i*n_cols:(i+1)*n_cols]) for i in range(n_rows)])
67 | cv2.imwrite(savepath, imgsheet)
68 | print(f'saved in {savepath}')
69 |
70 | # ----------------------------------------------------------------------------------------------
71 | def save_np_to_img(img, path, norm=True):
72 | if norm:
73 | img = (img + 1) / 2 * 255
74 | img = img.astype(np.uint8)
75 | image = Image.fromarray(img)
76 |     image.save(path, quality=95)
77 |
78 | # ----------------------------------------------------------------------------------------------
79 | def npz_to_imgsheet_5d(data_path, res_dir, nrow=None,):
80 | if isinstance(data_path, str):
81 | imgs = np.load(data_path)['arr_0'] # NTHWC
82 | elif isinstance(data_path, np.ndarray):
83 | imgs = data_path
84 | else:
85 | raise Exception
86 |
87 | if os.path.isdir(res_dir):
88 | res_path = os.path.join(res_dir, f'samples.jpg')
89 | else:
90 | assert(res_dir.endswith('.jpg'))
91 | res_path = res_dir
92 | imgs = np.concatenate([imgs[i] for i in range(imgs.shape[0])], axis=0)
93 | savenp2sheet(imgs, res_path, nrow=nrow)
94 |
95 | # ----------------------------------------------------------------------------------------------
96 | def npz_to_imgsheet_4d(data_path, res_path, nrow=None,):
97 | if isinstance(data_path, str):
98 | imgs = np.load(data_path)['arr_0'] # NHWC
99 | elif isinstance(data_path, np.ndarray):
100 | imgs = data_path
101 | else:
102 | raise Exception
103 | print(imgs.shape)
104 | savenp2sheet(imgs, res_path, nrow=nrow)
105 |
106 |
107 | # ----------------------------------------------------------------------------------------------
108 | def tensor_to_imgsheet(tensor, save_path):
109 | """
110 | save a batch of videos in one image sheet with shape of [batch_size * num_frames].
111 | data: [b,c,t,h,w]
112 | """
113 | assert(tensor.dim() == 5)
114 | b,c,t,h,w = tensor.shape
115 | imgs = [tensor[bi,:,ti, :, :] for bi in range(b) for ti in range(t)]
116 | torchvision.utils.save_image(imgs, save_path, normalize=True, nrow=t)
117 |
118 |
119 | # ----------------------------------------------------------------------------------------------
120 | def npz_to_frames(data_path, res_dir, norm, num_frames=None, num_samples=None):
121 | start = time.time()
122 | arr = np.load(data_path)
123 | imgs = arr['arr_0'] # [N, T, H, W, 3]
124 | print('original data shape: ', imgs.shape)
125 |
126 | if num_samples is not None:
127 | imgs = imgs[:num_samples, :, :, :, :]
128 | print('after sample selection: ', imgs.shape)
129 |
130 | if num_frames is not None:
131 | imgs = imgs[:, :num_frames, :, :, :]
132 | print('after frame selection: ', imgs.shape)
133 |
134 | for vid in tqdm(range(imgs.shape[0]), desc='Video'):
135 | video_dir = os.path.join(res_dir, f'video{vid:04d}')
136 | os.makedirs(video_dir, exist_ok=True)
137 | for fid in range(imgs.shape[1]):
138 | frame = imgs[vid, fid, :, :, :] #HW3
139 | save_np_to_img(frame, os.path.join(video_dir, f'frame{fid:04d}.jpg'), norm=norm)
140 | print('Finish')
141 | print(f'Total time = {time.time()- start}')
142 |
143 | # ----------------------------------------------------------------------------------------------
144 | def npz_to_gifs(data_path, res_dir, duration=0.2, start_idx=0, num_videos=None, mode='gif'):
145 | os.makedirs(res_dir, exist_ok=True)
146 | if isinstance(data_path, str):
147 | imgs = np.load(data_path)['arr_0'] # NTHWC
148 | elif isinstance(data_path, np.ndarray):
149 | imgs = data_path
150 | else:
151 | raise Exception
152 |
153 | for i in range(imgs.shape[0]):
154 | frames = [imgs[i,j,:,:,:] for j in range(imgs[i].shape[0])] # [(h,w,3)]
155 | if mode == 'gif':
156 | imageio.mimwrite(os.path.join(res_dir, f'samples_{start_idx+i}.gif'), frames, format='GIF', duration=duration)
157 | elif mode == 'mp4':
158 | frames = [torch.from_numpy(frame) for frame in frames]
159 | frames = torch.stack(frames, dim=0).to(torch.uint8) # [T, H, W, C]
160 | torchvision.io.write_video(os.path.join(res_dir, f'samples_{start_idx+i}.mp4'),
161 | frames, fps=0.5, video_codec='h264', options={'crf': '10'})
162 |         if i + 1 == num_videos:
163 | break
164 |
165 | # ----------------------------------------------------------------------------------------------
166 | def fill_with_black_squares(video, desired_len: int) -> Tensor:
167 | if len(video) >= desired_len:
168 | return video
169 |
170 | return torch.cat([
171 | video,
172 | torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1),
173 | ], dim=0)
174 |
175 | # ----------------------------------------------------------------------------------------------
176 | def load_num_videos(data_path, num_videos):
177 |     # data_path can be either a data path (str) or an np array
178 | if isinstance(data_path, str):
179 | videos = np.load(data_path)['arr_0'] # NTHWC
180 | elif isinstance(data_path, np.ndarray):
181 | videos = data_path
182 | else:
183 | raise Exception
184 |
185 | if num_videos is not None:
186 | videos = videos[:num_videos, :, :, :, :]
187 | return videos
188 |
189 | # ----------------------------------------------------------------------------------------------
190 | def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True):
191 | if isinstance(data_path, str):
192 | videos = load_num_videos(data_path, num_videos)
193 | elif isinstance(data_path, np.ndarray):
194 | videos = data_path
195 | else:
196 | raise Exception
197 | n,t,h,w,c = videos.shape
198 |
199 | videos_th = []
200 | for i in range(n):
201 | video = videos[i, :,:,:,:]
202 | images = [video[j, :,:,:] for j in range(t)]
203 | images = [to_tensor(img) for img in images]
204 | video = torch.stack(images)
205 | videos_th.append(video)
206 |
207 | if num_frames is None:
208 | num_frames = videos.shape[1]
209 | if verbose:
210 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW
211 | else:
212 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
213 |
214 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
215 | if nrow is None:
216 | nrow = int(np.ceil(np.sqrt(n)))
217 | if verbose:
218 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')]
219 | else:
220 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
221 |
222 | if os.path.dirname(out_path) != "":
223 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
224 | frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
225 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
226 |
227 | # ----------------------------------------------------------------------------------------------
228 | def npz_to_gif_grid(data_path, out_path, n_cols=None, num_videos=20):
229 | arr = np.load(data_path)
230 | imgs = arr['arr_0'] # [N, T, H, W, 3]
231 | imgs = imgs[:num_videos]
232 | n, t, h, w, c = imgs.shape
233 | assert(n == num_videos)
234 | n_cols = n_cols if n_cols else imgs.shape[0]
235 |     n_rows = np.ceil(imgs.shape[0] / n_cols).astype(int)
236 | H, W = h * n_rows, w * n_cols
237 | grid = np.zeros((t, H, W, c), dtype=np.uint8)
238 |
239 | for i in range(n_rows):
240 | for j in range(n_cols):
241 | if i*n_cols+j < imgs.shape[0]:
242 | grid[:, i*h:(i+1)*h, j*w:(j+1)*w, :] = imgs[i*n_cols+j, :, :, :, :]
243 |
244 | videos = [grid[i] for i in range(grid.shape[0])] # grid: TH'W'C
245 | imageio.mimwrite(out_path, videos, format='GIF', duration=0.5,palettesize=256)
246 |
247 |
248 | # ----------------------------------------------------------------------------------------------
249 | def torch_to_video_grid(videos, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True):
250 | """
251 |     videos: torch.Tensor in [-1, 1], shape [b, t, c, h, w]
252 |     """
253 |     n, t, c, h, w = videos.shape
254 | videos_th = [videos[i, ...] for i in range(n)]
255 | if verbose:
256 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW
257 | else:
258 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
259 |
260 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
261 | if nrow is None:
262 | nrow = int(np.ceil(np.sqrt(n)))
263 | if verbose:
264 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')]
265 | else:
266 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
267 |
268 | if os.path.dirname(out_path) != "":
269 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
270 | frame_grids = ((torch.stack(frame_grids) + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
271 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
272 |
273 |
274 | def log_txt_as_img(wh, xc, size=10):
275 | # wh a tuple of (width, height)
276 | # xc a list of captions to plot
277 | b = len(xc)
278 | txts = list()
279 | for bi in range(b):
280 | txt = Image.new("RGB", wh, color="white")
281 | draw = ImageDraw.Draw(txt)
282 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
283 | nc = int(40 * (wh[0] / 256))
284 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
285 |
286 | try:
287 | draw.text((0, 0), lines, fill="black", font=font)
288 | except UnicodeEncodeError:
289 |             print("Can't encode string for logging. Skipping.")
290 |
291 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
292 | txts.append(txt)
293 | txts = np.stack(txts)
294 | txts = torch.tensor(txts)
295 | return txts
296 |
--------------------------------------------------------------------------------
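Note on /lvdm/utils/saving_utils.py above: a minimal sketch of writing a batch of samples with tensor_to_mp4 (h264 writing needs the `av` package from requirements.txt); the tensor here is dummy data in [-1, 1], so rescale=True is kept.

import os

import torch

from lvdm.utils.saving_utils import tensor_to_mp4

os.makedirs("results", exist_ok=True)
videos = torch.rand(4, 3, 16, 64, 64) * 2 - 1           # dummy batch: [b, c, t, h, w] in [-1, 1]
tensor_to_mp4(videos, savepath="results/dummy_grid.mp4", fps=8, rescale=True)   # writes a 2x2 grid video
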
/requirements.txt:
--------------------------------------------------------------------------------
1 | decord==0.6.0
2 | einops==0.3.0
3 | eval==0.0.5
4 | imageio==2.9.0
5 | numpy==1.19.2
6 | omegaconf==2.3.0
7 | opencv_python_headless==4.6.0.66
8 | packaging==21.3
9 | Pillow==9.5.0
10 | pudb==2019.2
11 | pytorch_lightning==1.4.2
12 | PyYAML==6.0
13 | scikit_learn==1.2.2
14 | setuptools==59.5.0
15 | torch==1.9.0+cu111
16 | torchvision==0.10.0+cu111
17 | -f https://download.pytorch.org/whl/torch_stable.html
18 | tqdm==4.64.0
19 | transformers==4.24.0
20 | torchmetrics==0.6.0
21 | moviepy
22 | av
23 | six
24 | timm
25 | test-tube
26 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
27 | -e .
--------------------------------------------------------------------------------
/requirements_h800_gpu.txt:
--------------------------------------------------------------------------------
1 | decord==0.6.0
2 | einops==0.3.0
3 | eval==0.0.5
4 | imageio==2.9.0
5 | numpy==1.19.2
6 | omegaconf==2.3.0
7 | opencv_python_headless==4.6.0.66
8 | packaging==21.3
9 | Pillow==9.5.0
10 | pudb==2019.2
11 | pytorch_lightning==1.4.2
12 | PyYAML==6.0
13 | scikit_learn==1.2.2
14 | setuptools==59.5.0
15 | torch==2.2.2
16 | torchvision==0.17.2
17 | tqdm==4.64.0
18 | transformers==4.24.0
19 | torchmetrics==0.6.0
20 | moviepy
21 | av
22 | six
23 | timm
24 | test-tube
25 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26 | -e .
--------------------------------------------------------------------------------
/scripts/eval_cal_fvd_kvd.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import json
7 | import argparse
8 | import numpy as np
9 | from tqdm import tqdm
10 | from omegaconf import OmegaConf
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 |
15 | from lvdm.utils.common_utils import instantiate_from_config, shift_dim
16 | from scripts.fvd_utils.fvd_utils import get_fvd_logits, frechet_distance, load_fvd_model, polynomial_mmd
17 |
18 | def get_args():
19 | parser = argparse.ArgumentParser()
20 |     parser.add_argument('--fake_path', type=str, default=None, help='path to fake (generated) data: an npz/npy file or a directory of them')
21 |     parser.add_argument('--real_path', type=str, default=None, help='path to the real data directory')
22 | parser.add_argument('--yaml', type=str, default=None, help="training config file for construct dataloader")
23 | parser.add_argument('--batch_size', type=int, default=32)
24 | parser.add_argument('--num_workers', type=int, default=8)
25 | parser.add_argument('--n_runs', type=int, default=1, help='calculate multiple times')
26 | parser.add_argument('--gpu_id', type=int, default=None)
27 | parser.add_argument('--res_dir', type=str, default='')
28 | parser.add_argument('--n_sample', type=int, default=2048)
29 | parser.add_argument('--start_clip_id', type=int, default=None, help="for the evaluation of long video generation")
30 | parser.add_argument('--end_clip_id', type=int, default=None, help="for the evaluation of long video generation")
31 | args = parser.parse_args()
32 | return args
33 |
34 | def run(args, run_id, fake_path, i3d, loader, device):
35 | start = time.time()
36 |
37 | print('load fake videos in numpy ...')
38 | if fake_path.endswith('npz'):
39 | fake_data = np.load(fake_path)['arr_0']
40 | elif fake_path.endswith('npy'):
41 | fake_data = np.load(fake_path)
42 | else:
43 | print(fake_path)
44 | raise NotImplementedError
45 |
46 | s = time.time()
47 | real_embeddings = []
48 | for batch in tqdm(loader, desc="Extract Real Embedding", total=math.ceil(args.n_sample / args.batch_size)):
49 | if len(real_embeddings)*args.batch_size >=args.n_sample: break
50 | videos = shift_dim((batch[args.image_key]+1)*255/2, 1, -1).int().data.numpy() # convert to 0-255
51 | real_embeddings.append(get_fvd_logits(videos, i3d=i3d, device=device, batch_size=args.batch_size))
52 | real_embeddings = torch.cat(real_embeddings, 0)[:args.n_sample]
53 | t = time.time() - s
54 | s = time.time()
55 | fake_embeddings = []
56 | n_batch = fake_data.shape[0]//args.batch_size
57 | for i in tqdm(range(n_batch), desc="Extract Fake Embedding"):
58 | fake_embeddings.append(get_fvd_logits(fake_data[i*args.batch_size:(i+1)*args.batch_size], i3d=i3d, device=device, batch_size=args.batch_size))
59 | fake_embeddings = torch.cat(fake_embeddings, 0)[:args.n_sample]
60 | t = time.time() - s
61 |
62 | print('calculate fvd ...')
63 | fvd = frechet_distance(fake_embeddings, real_embeddings)
64 | fvd = fvd.cpu().numpy() # np float32
65 |
66 | print('calculate kvd ...')
67 | kvd = polynomial_mmd(fake_embeddings.cpu(), real_embeddings.cpu()) # np float 64
68 | total = time.time() - start
69 |
70 | print(f'Run_id = {run_id}')
71 | print(f'FVD = {fvd:.2f}')
72 | print(f'KVD = {kvd:.2f}')
73 | print(f'Time = {total:.2f}')
74 | return [fvd, kvd, total]
75 |
76 | def run_multitimes(args, i3d, loader, device):
77 | res_all = []
78 | for i in range(args.n_runs):
79 | run_id = i
80 | res = run(args, run_id, fake_path=args.fake_path, i3d=i3d, loader=loader, device=device)
81 | res_all.append(np.array(res))
82 | res_avg = np.mean(np.stack(res_all, axis=0), axis=0)
83 | res_std = np.std(np.stack(res_all, axis=0), axis=0)
84 |
85 | print(f'Results of {args.n_runs} runs:')
86 | print(f'FVD = {res_avg[0]} ({res_std[0]})')
87 | print(f'KVD = {res_avg[1]} ({res_std[1]})')
88 | print(f'Time = {res_avg[2]} ({res_std[2]})')
89 |
90 | # dump results
91 | res={'FVD': f'{res_avg[0]} ({res_std[0]})',
92 | 'KVD': f'{res_avg[1]} ({res_std[1]})',
93 | 'Time': f'{res_avg[2]} ({res_std[2]})',
94 | 'Clip_path': f'{args.fake_path}'
95 | }
96 | f = open(os.path.join(args.res_dir, f'{args.n_runs}runs_fvd_stat.json'), 'w')
97 | json.dump(res, f)
98 | f.close()
99 |
100 | def run_multitimes_dir(args, i3d, loader, device):
101 | datalist = sorted(os.listdir(args.fake_path))
102 | for i in range(args.n_runs):
103 |
104 | run_id = i
105 |
106 | for idx, path in tqdm(enumerate(datalist)):
107 | if args.start_clip_id is not None and idx < args.start_clip_id:
108 | continue
109 |
110 | print(f'Cal metrics for clip: {idx}, data: {path}')
111 | fvd, kvd, total = run(args, run_id, fake_path=os.path.join(args.fake_path, path),
112 | i3d=i3d, loader=loader, device=device)
113 | print(f"Run id {run_id}, Clip {idx}, FVD={fvd}, KVD={kvd}, Time={total}")
114 |
115 | # dump results
116 | fvd = float(fvd) if isinstance(fvd, np.ndarray) else fvd
117 | kvd = float(kvd) if isinstance(kvd, np.ndarray) else kvd
118 |
119 | res={'FVD': fvd, 'KVD': kvd, 'Time': total, 'Clip_path': path}
120 | f = open(os.path.join(args.res_dir, f'run{run_id}_clip{idx}_{device}.json'), 'w')
121 | json.dump(res, f)
122 | f.close()
123 |
124 | if args.end_clip_id is not None and idx == args.end_clip_id:
125 | break
126 |
127 |
128 | if __name__ == '__main__':
129 | args = get_args()
130 | print(args)
131 | os.makedirs(args.res_dir, exist_ok=True)
132 |
133 | print('load i3d ...')
134 | if args.gpu_id is not None:
135 | device = torch.device(f'cuda:{args.gpu_id}')
136 | else:
137 | device = torch.device('cuda')
138 | i3d = load_fvd_model(device)
139 |
140 | print('prepare dataset and dataloader ...')
141 | config = OmegaConf.load(args.yaml)
142 | config.data.params.train.params.data_root=args.real_path
143 | dataset = instantiate_from_config(config.data.params.train)
144 | if 'first_stage_key' in config.model.params:
145 | args.image_key = config.model.params.first_stage_key
146 | loader = DataLoader(dataset,
147 | batch_size=args.batch_size,
148 | num_workers=args.num_workers,
149 | shuffle=True)
150 |
151 | # run
152 | if os.path.isdir(args.fake_path):
153 | run_multitimes_dir(args, i3d, loader, device)
154 | else:
155 | run_multitimes(args, i3d, loader, device)
156 |
157 |
--------------------------------------------------------------------------------
/scripts/fvd_utils/fvd_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from sklearn.metrics.pairwise import polynomial_kernel
6 |
7 | from scripts.fvd_utils.pytorch_i3d import InceptionI3d
8 |
9 |
10 | MAX_BATCH = 8
11 | TARGET_RESOLUTION = (224, 224)
12 |
13 |
14 | def preprocess(videos, target_resolution):
15 | # videos in {0, ..., 255} as np.uint8 array
16 | b, t, h, w, c = videos.shape
17 | all_frames = torch.FloatTensor(videos).flatten(end_dim=1) # (b * t, h, w, c)
18 | all_frames = all_frames.permute(0, 3, 1, 2).contiguous() # (b * t, c, h, w)
19 | resized_videos = F.interpolate(all_frames, size=target_resolution,
20 | mode='bilinear', align_corners=False)
21 | resized_videos = resized_videos.view(b, t, c, *target_resolution)
22 | output_videos = resized_videos.transpose(1, 2).contiguous() # (b, c, t, *)
23 | scaled_videos = 2. * output_videos / 255. - 1 # [-1, 1]
24 | return scaled_videos
25 |
26 |
27 | def get_logits(i3d, videos, device, batch_size=None):
28 | if batch_size is None:
29 | batch_size = MAX_BATCH
30 | with torch.no_grad():
31 | logits = []
32 | for i in range(0, videos.shape[0], batch_size):
33 | batch = videos[i:i + batch_size].to(device)
34 | logits.append(i3d(batch))
35 | logits = torch.cat(logits, dim=0)
36 | return logits
37 |
38 |
39 | def get_fvd_logits(videos, i3d, device, batch_size=None):
40 | videos = preprocess(videos, TARGET_RESOLUTION)
41 | embeddings = get_logits(i3d, videos, device, batch_size=batch_size)
42 | return embeddings
43 |
44 |
45 | def load_fvd_model(device):
46 | i3d = InceptionI3d(400, in_channels=3).to(device)
47 | current_dir = os.path.dirname(os.path.abspath(__file__))
48 | i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt')
49 | i3d.load_state_dict(torch.load(i3d_path, map_location=device))
50 | i3d.eval()
51 | return i3d
52 |
53 |
54 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
55 | def _symmetric_matrix_square_root(mat, eps=1e-10):
56 | u, s, v = torch.svd(mat)
57 | si = torch.where(s < eps, s, torch.sqrt(s))
58 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
59 |
60 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
61 | def trace_sqrt_product(sigma, sigma_v):
62 | sqrt_sigma = _symmetric_matrix_square_root(sigma)
63 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
64 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
65 |
66 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
67 | def cov(m, rowvar=False):
68 | '''Estimate a covariance matrix given data.
69 |
70 | Covariance indicates the level to which two variables vary together.
71 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
72 | then the covariance matrix element `C_{ij}` is the covariance of
73 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
74 |
75 | Args:
76 | m: A 1-D or 2-D array containing multiple variables and observations.
77 | Each row of `m` represents a variable, and each column a single
78 | observation of all those variables.
79 | rowvar: If `rowvar` is True, then each row represents a
80 | variable, with observations in the columns. Otherwise, the
81 | relationship is transposed: each column represents a variable,
82 | while the rows contain observations.
83 |
84 | Returns:
85 | The covariance matrix of the variables.
86 | '''
87 | if m.dim() > 2:
88 | raise ValueError('m has more than 2 dimensions')
89 | if m.dim() < 2:
90 | m = m.view(1, -1)
91 | if not rowvar and m.size(0) != 1:
92 | m = m.t()
93 |
94 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate
95 | m_center = m - torch.mean(m, dim=1, keepdim=True)
96 | mt = m_center.t() # if complex: mt = m.t().conj()
97 | return fact * m_center.matmul(mt).squeeze()
98 |
99 |
100 | def frechet_distance(x1, x2):
101 | x1 = x1.flatten(start_dim=1)
102 | x2 = x2.flatten(start_dim=1)
103 | m, m_w = x1.mean(dim=0), x2.mean(dim=0)
104 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
105 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
106 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
107 | mean = torch.sum((m - m_w) ** 2)
108 | fd = trace + mean
109 | return fd
110 |
111 |
112 |
113 | def polynomial_mmd(X, Y):
114 | m = X.shape[0]
115 | n = Y.shape[0]
116 | # compute kernels
117 | K_XX = polynomial_kernel(X)
118 | K_YY = polynomial_kernel(Y)
119 | K_XY = polynomial_kernel(X, Y)
120 | # compute mmd distance
121 | K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1))
122 | K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1))
123 | K_XY_sum = K_XY.sum() / (m * n)
124 | mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum
125 | return mmd
126 |
--------------------------------------------------------------------------------
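Note on /scripts/fvd_utils/fvd_utils.py above: scripts/eval_cal_fvd_kvd.py wires these helpers together roughly as sketched below. load_fvd_model expects the pretrained weights file i3d_pretrained_400.pt next to pytorch_i3d.py, and the random uint8 arrays here merely stand in for real/generated videos of shape [N, T, H, W, 3].

import numpy as np
import torch

from scripts.fvd_utils.fvd_utils import load_fvd_model, get_fvd_logits, frechet_distance, polynomial_mmd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
i3d = load_fvd_model(device)                                 # InceptionI3d(400) with pretrained weights

real = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)   # [N, T, H, W, 3] in {0..255}
fake = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)

real_emb = get_fvd_logits(real, i3d=i3d, device=device, batch_size=8)        # resizes to 224x224, runs I3D
fake_emb = get_fvd_logits(fake, i3d=i3d, device=device, batch_size=8)

fvd = frechet_distance(fake_emb, real_emb)                   # torch scalar
kvd = polynomial_mmd(fake_emb.cpu(), real_emb.cpu())         # numpy float
print(float(fvd), float(kvd))
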
/scripts/sample_long_videos_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from tqdm import trange
4 | from einops import repeat
5 | from lvdm.samplers.ddim import DDIMSampler
6 | from lvdm.models.ddpm3d import FrameInterpPredLatentDiffusion
7 |
8 | # ------------------------------------------------------------------------------------------
9 | def add_conditions_long_video(model, cond, batch_size,
10 | sample_cond_noise_level=None,
11 | mask=None,
12 | input_shape=None,
13 | T_cond_pred=None,
14 | cond_frames=None,
15 | ):
16 | assert(isinstance(cond, dict))
17 | try:
18 | device = model.device
19 | except:
20 | device = next(model.parameters()).device
21 |
22 | # add cond noisy level
23 | if sample_cond_noise_level is not None:
24 | if getattr(model, "noisy_cond", False):
25 | assert(sample_cond_noise_level is not None)
26 | s = sample_cond_noise_level
27 | s = repeat(torch.tensor([s]), '1 -> b', b=batch_size)
28 | s = s.to(device).long()
29 | else:
30 | s = None
31 |     cond['s'] = s
32 |
33 | # add cond mask
34 | if mask is not None:
35 | if mask == "uncond":
36 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device)
37 | elif mask == "pred":
38 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device)
39 | mask[:, :, :T_cond_pred, :, :] = 1
40 | elif mask == "interp":
41 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device)
42 | mask[:, :, 0, :, :] = 1
43 | mask[:, :, -1, :, :] = 1
44 | else:
45 | raise NotImplementedError
46 | cond['mask'] = mask
47 |
48 | # add cond_frames
49 | if cond_frames is not None:
50 | if sample_cond_noise_level is not None:
51 | noise = torch.randn_like(cond_frames)
52 | noisy_cond = model.q_sample(x_start=cond_frames, t=s, noise=noise)
53 | else:
54 | noisy_cond = cond_frames
55 | cond['noisy_cond'] = noisy_cond
56 |
57 | return cond
58 |
59 | # ------------------------------------------------------------------------------------------
60 | @torch.no_grad()
61 | def sample_clip(model, shape, sample_type="ddpm", ddim_steps=None, eta=1.0, cond=None,
62 | uc=None, uncond_scale=1., uc_type=None, **kwargs):
63 |
64 | log = dict()
65 |
66 | # sample batch
67 | with model.ema_scope("EMA"):
68 | t0 = time.time()
69 | if sample_type == "ddpm":
70 | samples = model.p_sample_loop(cond, shape, return_intermediates=False, verbose=False,
71 | unconditional_guidance_scale=uncond_scale,
72 | unconditional_conditioning=uc, uc_type=uc_type, )
73 | elif sample_type == "ddim":
74 | ddim = DDIMSampler(model)
75 | samples, intermediates = ddim.sample(S=ddim_steps, batch_size=shape[0], shape=shape[1:],
76 | conditioning=cond, eta=eta, verbose=False,
77 | unconditional_guidance_scale=uncond_scale,
78 | unconditional_conditioning=uc, uc_type=uc_type,)
79 |
80 | t1 = time.time()
81 | log["sample"] = samples
82 | log["time"] = t1 - t0
83 | log['throughput'] = samples.shape[0] / (t1 - t0)
84 | return log
85 |
86 | # ------------------------------------------------------------------------------------------
87 | @torch.no_grad()
88 | def autoregressive_pred(model, batch_size, *args,
89 | T_cond=1,
90 | n_pred_steps=3,
91 | sample_cond_noise_level=None,
92 | decode_single_video_allframes=False,
93 | uncond_scale=1.0,
94 | uc_type=None,
95 | max_z_t=None,
96 | overlap_t=0,
97 | **kwargs):
98 |
99 | model.sample_cond_noise_level = sample_cond_noise_level
100 |
101 | image_size = model.image_size
102 | image_size = [image_size, image_size] if isinstance(image_size, int) else image_size
103 | C = model.model.diffusion_model.in_channels-1 if isinstance(model, FrameInterpPredLatentDiffusion) \
104 | else model.model.diffusion_model.in_channels
105 | T = model.model.diffusion_model.temporal_length
106 | shape = [batch_size, C, T, *image_size]
107 |
108 | t0 = time.time()
109 | long_samples = []
110 |
111 | # uncond sample
112 | log = dict()
113 |
114 | # --------------------------------------------------------------------
115 | # make condition
116 | cond = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape)
117 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0):
118 |         print('Using unconditional guidance')
119 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape)
120 | else:
121 |         print('No unconditional guidance')
122 |         uc = None
123 |
124 | # sample an initial clip (unconditional)
125 | sample = sample_clip(model, shape, cond=cond,
126 | uc=uc, uncond_scale=uncond_scale, uc_type=uc_type,
127 | **kwargs)['sample']
128 | long_samples.append(sample.cpu())
129 |
130 | # extend
131 | for i in range(n_pred_steps):
132 | T = sample.shape[2]
133 | cond_z0 = sample[:, :, T-T_cond:, :, :]
134 | assert(cond_z0.shape[2] == T_cond)
135 |
136 | # make prediction model's condition
137 | cond = add_conditions_long_video(model, {}, batch_size, mask="pred", input_shape=shape, cond_frames=cond_z0, sample_cond_noise_level=sample_cond_noise_level)
138 | ## unconditional_guidance's condition
139 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0):
140 |             print('Using unconditional guidance')
141 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape)
142 | else:
143 |             print('No unconditional guidance')
144 |             uc = None
145 |
146 | # sample a short clip (condition on previous latents)
147 | sample = sample_clip(model, shape, *args, cond=cond, uc=uc, uncond_scale=uncond_scale, uc_type=uc_type,
148 | **kwargs)['sample']
149 | ext = sample[:, :, T_cond:, :, :]
150 | assert(ext.dim() == 5)
151 | long_samples.append(ext.cpu())
152 | progress = (i+1)/n_pred_steps * 100
153 | print(f"Finish pred step {int(progress)}% [{i+1}/{n_pred_steps}]")
154 | torch.cuda.empty_cache()
155 | long_samples = torch.cat(long_samples, dim=2)
156 |
157 | t1 = time.time()
158 | log["sample"] = long_samples
159 | log["time"] = t1 - t0
160 | log['throughput'] = long_samples.shape[0] / (t1 - t0)
161 | return log
162 |
163 |
164 | # ------------------------------------------------------------------------------------------
165 | @torch.no_grad()
166 | def interpolate(base_samples,
167 | model, batch_size, *args, sample_video=False,
168 | sample_cond_noise_level=None, decode_single_video_allframes=False,
169 | uncond_scale=1.0, uc_type=None, max_z_t=None,
170 | overlap_t=0, prompt=None, config=None,
171 | interpolate_cond_fps=None,
172 | **kwargs):
173 |
174 | model.sample_cond_noise_level = sample_cond_noise_level
175 |
176 | N, c, t, h, w = base_samples.shape
177 | n_steps = len(range(0, t-3, 3))
178 | device = next(model.parameters()).device
179 | if N < batch_size:
180 | batch_size = N
181 | elif N > batch_size:
182 | raise ValueError
183 | assert(N == batch_size)
184 |
185 | C = model.model.diffusion_model.in_channels-1 if isinstance(model, FrameInterpPredLatentDiffusion) and model.concat_mask_on_input \
186 | else model.model.diffusion_model.in_channels
187 | image_size = model.image_size
188 | image_size = [image_size, image_size] if isinstance(image_size, int) else image_size
189 | T = model.model.diffusion_model.temporal_length
190 | shape = [batch_size, C, T, *image_size ]
191 |
192 | t0 = time.time()
193 | long_samples = []
194 | cond = {}
195 | for i in trange(n_steps, desc='Interpolation Steps'):
196 | cond_z0 = base_samples[:, :, i:i+2, :, :].cuda()
197 | # make prediction model's condition
198 | cond = add_conditions_long_video(model, {}, batch_size, mask="interp", input_shape=shape, cond_frames=cond_z0, sample_cond_noise_level=sample_cond_noise_level)
199 | ## unconditional_guidance's condition
200 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0):
201 |             print('Using unconditional guidance')
202 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape)
203 | else:
204 |             print('No unconditional guidance')
205 |             uc = None
206 |
207 | # sample an interpolation clip
208 | sample = sample_clip(model, shape, *args, cond=cond, uc=uc,
209 | uncond_scale=uncond_scale,
210 | uc_type=uc_type,
211 | **kwargs)['sample']
212 | ext = sample[:, :, 1:-1, :, :]
213 | assert(ext.dim() == 5)
214 | assert(ext.shape[2] == T - 2)
215 | # -----------------------------------------------
216 | if i != n_steps - 1:
217 | long_samples.extend([cond_z0[:, :, 0:1, :, :].cpu(), ext.cpu()])
218 | else:
219 | long_samples.extend([cond_z0[:, :, 0:1, :, :].cpu(), ext.cpu(), cond_z0[:, :, 1:, :, :].cpu()])
220 | # -----------------------------------------------
221 |
222 | torch.cuda.empty_cache()
223 | long_samples = torch.cat(long_samples, dim=2)
224 |
225 | # decode
226 | long_samples_decoded = []
227 | print('Decoding ...')
228 | for i in trange(long_samples.shape[0]):
229 | torch.cuda.empty_cache()
230 |
231 | long_sample = long_samples[i].unsqueeze(0).cuda()
232 | if overlap_t != 0:
233 | print('Use overlapped decoding')
234 | res = model.overlapped_decode(long_sample, max_z_t=max_z_t,
235 | overlap_t=overlap_t).cpu()
236 | else:
237 | res = model.decode_first_stage(long_sample,
238 | bs=None,
239 | decode_single_video_allframes=decode_single_video_allframes,
240 | max_z_t=max_z_t).cpu()
241 | long_samples_decoded.append(res)
242 | torch.cuda.empty_cache()
243 |
244 | long_samples_decoded = torch.cat(long_samples_decoded, dim=0)
245 |
246 | # log
247 | t1 = time.time()
248 | log = {}
249 | log["sample"] = long_samples_decoded
250 | log["sample_z"] = long_samples
251 | log["time"] = t1 - t0
252 | torch.cuda.empty_cache()
253 | return log
254 |
--------------------------------------------------------------------------------
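Note on /scripts/sample_long_videos_utils.py above: long videos are produced in two stages, first autoregressive prediction in latent space and then frame interpolation plus decoding, with the actual driver in scripts/sample_uncond_long_videos.py. The sketch below only illustrates the rough flow; the function name and step counts are illustrative, and pred_model / interp_model are assumed to be loaded prediction and interpolation checkpoints of FrameInterpPredLatentDiffusion.

import torch

from scripts.sample_long_videos_utils import autoregressive_pred, interpolate


@torch.no_grad()
def sample_long_video_sketch(pred_model, interp_model, batch_size=1, n_pred_steps=3):
    # 1) unconditionally sample a clip, then autoregressively extend it in latent space
    pred_log = autoregressive_pred(pred_model, batch_size,
                                   T_cond=1, n_pred_steps=n_pred_steps,
                                   sample_type="ddim", ddim_steps=50)
    base_latents = pred_log["sample"]                      # [b, c, T_long, h, w] latents on CPU

    # 2) densify the sparse latents with the interpolation model and decode to pixel space
    interp_log = interpolate(base_latents, interp_model, batch_size,
                             sample_type="ddim", ddim_steps=50)
    return interp_log["sample"]                            # decoded long video tensor
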
/scripts/sample_text2video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import yaml, math
5 | from tqdm import trange
6 | import torch
7 | import numpy as np
8 | from omegaconf import OmegaConf
9 | import torch.distributed as dist
10 | from pytorch_lightning import seed_everything
11 |
12 | from lvdm.samplers.ddim import DDIMSampler
13 | from lvdm.utils.common_utils import str2bool
14 | from lvdm.utils.dist_utils import setup_dist, gather_data
15 | from scripts.sample_utils import (load_model,
16 | get_conditions, make_model_input_shape, torch_to_np, sample_batch,
17 | save_results,
18 | save_args
19 | )
20 |
21 |
22 | # ------------------------------------------------------------------------------------------
23 | def get_parser():
24 | parser = argparse.ArgumentParser()
25 | # basic args
26 | parser.add_argument("--ckpt_path", type=str, help="model checkpoint path")
27 | parser.add_argument("--config_path", type=str, help="model config path (a yaml file)")
28 | parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).")
29 | parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/")
30 | # device args
31 |     parser.add_argument("--ddp", action='store_true', help="whether to use pytorch DDP mode for parallel sampling (recommended for the multi-gpu case)", default=False)
32 |     parser.add_argument("--local_rank", type=int, help="local rank used in pytorch DDP mode", default=0)
33 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0)
34 | # sampling args
35 | parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2)
36 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1)
37 | parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1)
38 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"])
39 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50)
40 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0)
41 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)")
42 | parser.add_argument("--num_frames", type=int, default=16, help="number of input frames")
43 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",)
44 | parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale")
45 | # saving args
46 |     parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether to save samples as separate mp4 files")
47 |     parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether to also save all samples together in one mp4 sheet",)
48 |     parser.add_argument("--save_npz", action='store_true', default=False, help="whether to save samples in an npz file",)
49 |     parser.add_argument("--save_jpg", action='store_true', default=False, help="whether to save samples as jpg files",)
50 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",)
51 | return parser
52 |
53 | # ------------------------------------------------------------------------------------------
54 | @torch.no_grad()
55 | def sample_text2video(model, prompt, n_samples, batch_size,
56 | sample_type="ddim", sampler=None,
57 | ddim_steps=50, eta=1.0, cfg_scale=7.5,
58 | decode_frame_bs=1,
59 | ddp=False, all_gather=True,
60 | batch_progress=True, show_denoising_progress=False,
61 | num_frames=None,
62 | ):
63 | # get cond vector
64 | assert(model.cond_stage_model is not None)
65 | cond_embd = get_conditions(prompt, model, batch_size)
66 | uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None
67 |
68 | # sample batches
69 | all_videos = []
70 | n_iter = math.ceil(n_samples / batch_size)
71 | iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter)
72 | for _ in iterator:
73 | noise_shape = make_model_input_shape(model, batch_size, T=num_frames)
74 | samples_latent = sample_batch(model, noise_shape, cond_embd,
75 | sample_type=sample_type,
76 | sampler=sampler,
77 | ddim_steps=ddim_steps,
78 | eta=eta,
79 | unconditional_guidance_scale=cfg_scale,
80 | uc=uncond_embd,
81 | denoising_progress=show_denoising_progress,
82 | )
83 | samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False)
84 |
85 | # gather samples from multiple gpus
86 | if ddp and all_gather:
87 | data_list = gather_data(samples, return_np=False)
88 | all_videos.extend([torch_to_np(data) for data in data_list])
89 | else:
90 | all_videos.append(torch_to_np(samples))
91 |
92 | all_videos = np.concatenate(all_videos, axis=0)
93 | assert(all_videos.shape[0] >= n_samples)
94 | return all_videos
95 |
96 |
97 |
98 | # ------------------------------------------------------------------------------------------
99 | def main():
100 | """
101 | text-to-video generation
102 | """
103 | parser = get_parser()
104 | opt, unknown = parser.parse_known_args()
105 | os.makedirs(opt.save_dir, exist_ok=True)
106 | save_args(opt.save_dir, opt)
107 |
108 | # set device
109 | if opt.ddp:
110 | setup_dist(opt.local_rank)
111 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size())
112 | gpu_id = None
113 | else:
114 | gpu_id = opt.gpu_id
115 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
116 |
117 | # set random seed
118 | if opt.seed is not None:
119 | if opt.ddp:
120 | seed = opt.local_rank + opt.seed
121 | else:
122 | seed = opt.seed
123 | seed_everything(seed)
124 |
125 | # load & merge config
126 | config = OmegaConf.load(opt.config_path)
127 | cli = OmegaConf.from_dotlist(unknown)
128 | config = OmegaConf.merge(config, cli)
129 | print("config: \n", config)
130 |
131 | # get model & sampler
132 | model, _, _ = load_model(config, opt.ckpt_path)
133 | ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None
134 |
135 | # prepare prompt
136 | if opt.prompt.endswith(".txt"):
137 | opt.prompt_file = opt.prompt
138 | opt.prompt = None
139 | else:
140 | opt.prompt_file = None
141 |
142 | if opt.prompt_file is not None:
143 | f = open(opt.prompt_file, 'r')
144 | prompts, line_idx = [], []
145 | for idx, line in enumerate(f.readlines()):
146 | l = line.strip()
147 | if len(l) != 0:
148 | prompts.append(l)
149 | line_idx.append(idx)
150 | f.close()
151 | cmd = f"cp {opt.prompt_file} {opt.save_dir}"
152 | os.system(cmd)
153 | else:
154 | prompts = [opt.prompt]
155 | line_idx = [None]
156 |
157 | # go
158 | start = time.time()
159 | for prompt in prompts:
160 | # sample
161 | samples = sample_text2video(model, prompt, opt.n_samples, opt.batch_size,
162 | sample_type=opt.sample_type, sampler=ddim_sampler,
163 | ddim_steps=opt.ddim_steps, eta=opt.eta,
164 | cfg_scale=opt.cfg_scale,
165 | decode_frame_bs=opt.decode_frame_bs,
166 | ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress,
167 | num_frames=opt.num_frames,
168 | )
169 | # save
170 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp):
171 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
172 | save_name = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
173 | if opt.seed is not None:
174 | save_name = save_name + f"_seed{seed:05d}"
175 |             save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps, save_jpg=opt.save_jpg)  # forward --save_jpg so the flag used in sample_lvdm_text2video.sh takes effect
176 |         print("Finished sampling!")
177 | print(f"Run time = {(time.time() - start):.2f} seconds")
178 |
179 | if opt.ddp:
180 | dist.destroy_process_group()
181 |
182 |
183 | # ------------------------------------------------------------------------------------------
184 | if __name__ == "__main__":
185 | main()
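186 | 
187 | # Usage note: when --prompt points to a .txt file (e.g. input/prompts.txt), every non-empty
188 | # line is treated as one prompt and the prompt file is copied into --save_dir for reference;
189 | # otherwise the string itself is used as a single prompt.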
--------------------------------------------------------------------------------
/scripts/sample_uncond.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
3 | import time
4 | import math
5 | import argparse
6 | import numpy as np
7 | from tqdm import trange
8 | from omegaconf import OmegaConf
9 |
10 | import torch
11 | import torch.distributed as dist
12 | from pytorch_lightning import seed_everything
13 |
14 | from lvdm.samplers.ddim import DDIMSampler
15 | from lvdm.utils.common_utils import torch_to_np, str2bool
16 | from lvdm.utils.dist_utils import setup_dist, gather_data
17 | from scripts.sample_utils import (
18 | load_model,
19 | save_args,
20 | make_model_input_shape,
21 | sample_batch,
22 | save_results,
23 | )
24 |
25 | # ------------------------------------------------------------------------------------------
26 | def get_parser():
27 | parser = argparse.ArgumentParser()
28 | # basic args
29 | parser.add_argument("--ckpt_path", type=str, help="model checkpoint path")
30 | parser.add_argument("--config_path", type=str, help="model config path (a yaml file)")
31 |     parser.add_argument("--prompt", type=str, help="input text prompt(s); not used by this unconditional sampling script")
32 | parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/")
33 | # device args
34 |     parser.add_argument("--ddp", action='store_true', help="whether to use pytorch ddp mode for parallel sampling (recommended for the multi-gpu case)", default=False)
35 | parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0)
36 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0)
37 | # sampling args
38 |     parser.add_argument("--n_samples", type=int, help="how many samples to generate", default=2)
39 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1)
40 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddpm", choices=["ddpm", "ddim"])
41 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50)
42 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0)
43 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)")
44 | parser.add_argument("--num_frames", type=int, default=None, help="number of input frames")
45 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",)
46 |     parser.add_argument("--uncond_scale", type=float, default=15.0, help="unconditional guidance scale")
47 | # saving args
48 |     parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether to save samples as separate mp4 files")  # str2bool already converts the value to a bool, so string choices would reject it
49 |     parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether to additionally save all samples in a single mp4 grid",)
50 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",)
51 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",)
52 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",)
53 | return parser
54 |
55 | # ------------------------------------------------------------------------------------------
56 | @torch.no_grad()
57 | def sample(model, noise_shape, n_iters, ddp=False, **kwargs):
58 | all_videos = []
59 | for _ in trange(n_iters, desc="Sampling Batches (unconditional)"):
60 | samples = sample_batch(model, noise_shape, condition=None, **kwargs)
61 | samples = model.decode_first_stage(samples)
62 | if ddp: # gather samples from multiple gpus
63 | data_list = gather_data(samples, return_np=False)
64 | all_videos.extend([torch_to_np(data) for data in data_list])
65 | else:
66 | all_videos.append(torch_to_np(samples))
67 | all_videos = np.concatenate(all_videos, axis=0)
68 | return all_videos
69 |
70 | # ------------------------------------------------------------------------------------------
71 | def main():
72 | """
73 | unconditional generation of short videos
74 | """
75 | parser = get_parser()
76 | opt, unknown = parser.parse_known_args()
77 | os.makedirs(opt.save_dir, exist_ok=True)
78 | save_args(opt.save_dir, opt)
79 |
80 | # set device
81 | if opt.ddp:
82 | setup_dist(opt.local_rank)
83 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size())
84 | else:
85 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{opt.gpu_id}"
86 |
87 | # set random seed
88 | if opt.seed is not None:
89 | seed = opt.local_rank + opt.seed if opt.ddp else opt.seed
90 | seed_everything(seed)
91 |
92 | # load & merge config
93 | config = OmegaConf.load(opt.config_path)
94 | cli = OmegaConf.from_dotlist(unknown)
95 | config = OmegaConf.merge(config, cli)
96 | print("config: \n", config)
97 |
98 | # get model & sampler
99 | model, _, _ = load_model(config, opt.ckpt_path)
100 | ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None
101 |
102 | # sample
103 | start = time.time()
104 | noise_shape = make_model_input_shape(model, opt.batch_size, T=opt.num_frames)
105 | ngpus = 1 if not opt.ddp else dist.get_world_size()
106 | n_iters = math.ceil(opt.n_samples / (ngpus * opt.batch_size))
107 |     samples = sample(model, noise_shape, n_iters, ddp=opt.ddp, sampler=ddim_sampler, sample_type=opt.sample_type, ddim_steps=opt.ddim_steps, eta=opt.eta, denoising_progress=opt.show_denoising_progress)
108 | assert(samples.shape[0] >= opt.n_samples)
109 |
110 | # save
111 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp):
112 |         # fall back to a generic name when no seed is given (save_name would otherwise be undefined)
113 |         save_name = f"seed{seed:05d}" if opt.seed is not None else "samples"
114 |         save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps, save_mp4=opt.save_mp4, save_npz=opt.save_npz, save_mp4_sheet=opt.save_mp4_sheet, save_jpg=opt.save_jpg)  # forward the saving flags
115 |         print("Finished sampling!")
116 | print(f"Run time = {(time.time() - start):.2f} seconds")
117 |
118 | if opt.ddp:
119 | dist.destroy_process_group()
120 |
121 | # ------------------------------------------------------------------------------------------
122 | if __name__ == "__main__":
123 | main()
--------------------------------------------------------------------------------
/scripts/sample_uncond_long_videos.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
4 |
5 | import math
6 | import torch
7 | import time
8 | import argparse
9 | import numpy as np
10 | from tqdm import trange
11 | from omegaconf import OmegaConf
12 | import torch.distributed as dist
13 | from pytorch_lightning import seed_everything
14 |
15 | from lvdm.utils.dist_utils import setup_dist, gather_data
16 | from lvdm.utils.common_utils import torch_to_np, str2bool
17 | from scripts.sample_long_videos_utils import autoregressive_pred, interpolate
18 | from scripts.sample_utils import load_model, save_args, save_results
19 |
20 | def get_parser():
21 | parser = argparse.ArgumentParser()
22 | # basic args
23 | parser.add_argument("--ckpt_pred", type=str, default="", help="ckpt path")
24 | parser.add_argument("--ckpt_interp", type=str, default=None, help="ckpt path")
25 | parser.add_argument("--config_pred", type=str, default="", help="config yaml")
26 | parser.add_argument("--config_interp", type=str, default=None, help="config yaml")
27 | parser.add_argument("--save_dir", type=str, default="results/longvideos", help="results saving dir")
28 | # device args
29 |     parser.add_argument("--ddp", action='store_true', help="whether to use pytorch ddp mode for parallel sampling (recommended for the multi-gpu case)", default=False)
30 | parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0)
31 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0)
32 | # sampling args
33 |     parser.add_argument("--n_samples", type=int, help="how many samples to generate", default=2)
34 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1)
35 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddpm", choices=["ddpm", "ddim"])
36 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50)
37 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0)
38 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)")
39 | parser.add_argument("--num_frames", type=int, default=None, help="number of input frames")
40 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",)
41 | parser.add_argument("--uncond_scale", type=float, default=1.0, help="unconditional guidance scale")
42 |     parser.add_argument("--uc_type", type=str, help="type of unconditional guidance", default="cfg_original", choices=["cfg_original", None])
43 | # prediction & interpolation args
44 | parser.add_argument("--T_cond", type=int, default=1, help="temporal length of condition frames")
45 |     parser.add_argument("--n_pred_steps", type=int, default=None, help="number of autoregressive prediction steps (required if --num_frames is not given)")
46 | parser.add_argument("--sample_cond_noise_level", type=int, default=None, help="")
47 | parser.add_argument("--overlap_t", type=int, default=0, help="")
48 | # saving args
49 |     parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether to save samples as separate mp4 files")  # str2bool already converts the value to a bool, so string choices would reject it
50 |     parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether to additionally save all samples in a single mp4 grid",)
51 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",)
52 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",)
53 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",)
54 | return parser
55 |
56 | # ------------------------------------------------------------------------------------------
57 | @torch.no_grad()
58 | def sample(model_pred, model_interp, n_iters, ddp=False, all_gather=False, **kwargs):
59 | all_videos = []
60 | for _ in trange(n_iters, desc="Sampling Batches (unconditional)"):
61 |
62 | # autoregressive predict latents
63 | logs = autoregressive_pred(model_pred, **kwargs)
64 | samples_z = logs['sample'] if isinstance(logs, dict) else logs
65 |
66 | # interpolate latents
67 | logs = interpolate(samples_z, model_interp, **kwargs)
68 |         samples = logs['sample']
69 |
70 | if ddp and all_gather: # gather samples from multiple gpus
71 | data_list = gather_data(samples, return_np=False)
72 | all_videos.extend([torch_to_np(data) for data in data_list])
73 | else:
74 | all_videos.append(torch_to_np(samples))
75 | all_videos = np.concatenate(all_videos, axis=0)
76 | return all_videos
77 |
78 | # ------------------------------------------------------------------------------------------
79 | def main():
80 | """
81 | unconditional generation of long videos
82 | """
83 | parser = get_parser()
84 | opt, unknown = parser.parse_known_args()
85 | os.makedirs(opt.save_dir, exist_ok=True)
86 | save_args(opt.save_dir, opt)
87 |
88 | # set device
89 | if opt.ddp:
90 | setup_dist(opt.local_rank)
91 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size())
92 | else:
93 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{opt.gpu_id}"
94 |
95 | # set random seed
96 | if opt.seed is not None:
97 | seed = opt.local_rank + opt.seed if opt.ddp else opt.seed
98 | seed_everything(seed)
99 |
100 | # load & merge config
101 | config_pred = OmegaConf.load(opt.config_pred)
102 | cli = OmegaConf.from_dotlist(unknown)
103 | config_pred = OmegaConf.merge(config_pred, cli)
104 | if opt.config_interp is not None:
105 | config_interp = OmegaConf.load(opt.config_interp)
106 | cli = OmegaConf.from_dotlist(unknown)
107 | config_interp = OmegaConf.merge(config_interp, cli)
108 |
109 | # calculate n_pred_steps
110 | if opt.num_frames is not None:
111 | temporal_down = config_pred.model.params.first_stage_config.params.ddconfig.encoder.params.downsample[0]
112 | temporal_length_z = opt.num_frames // temporal_down
113 | model_pred_length = config_pred.model.params.unet_config.params.temporal_length
114 | if opt.config_interp is not None:
115 | model_interp_length = config_interp.model.params.unet_config.params.temporal_length - 2
116 | pred_length = math.ceil((temporal_length_z + model_interp_length) / (model_interp_length + 1))
117 | else:
118 | pred_length = temporal_length_z
119 | n_pred_steps = math.ceil((pred_length - model_pred_length) / (model_pred_length - 1))
120 | opt.n_pred_steps = n_pred_steps
121 | print(f'Temporal length {opt.num_frames} needs latent length = {temporal_length_z}; \n \
122 | pred_length = {pred_length}; \n \
123 | prediction steps = {n_pred_steps}')
124 | else:
125 | assert(opt.n_pred_steps is not None)
126 |
127 | # model
128 | model_pred, _, _ = load_model(config_pred, opt.ckpt_pred)
129 | model_interp, _, _ = load_model(config_interp, opt.ckpt_interp)
130 |
131 | # sample
132 | start = time.time()
133 | ngpus = 1 if not opt.ddp else dist.get_world_size()
134 | n_iters = math.ceil(opt.n_samples / (ngpus * opt.batch_size))
135 | samples = sample(model_pred, model_interp, n_iters, **vars(opt))
136 | assert(samples.shape[0] >= opt.n_samples)
137 |
138 | # save
139 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp):
140 |         # fall back to a generic name when no seed is given (save_name would otherwise be undefined)
141 |         save_name = f"seed{seed:05d}" if opt.seed is not None else "samples"
142 |         save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps, save_mp4=opt.save_mp4, save_npz=opt.save_npz, save_mp4_sheet=opt.save_mp4_sheet, save_jpg=opt.save_jpg)  # forward the saving flags
143 |         print("Finished sampling!")
144 |         print(f"total time = {int(time.time() - start)} seconds, \
145 | num of iters = {n_iters}; \
146 | num of samples = {ngpus * opt.batch_size * n_iters}; \
147 | temporal length = {opt.num_frames}")
148 |
149 | if opt.ddp:
150 | dist.destroy_process_group()
151 |
152 | # ------------------------------------------------------------------------------------------
153 | if __name__ == "__main__":
154 | main()
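155 | 
156 | # Worked example of the --num_frames -> n_pred_steps arithmetic in main() (a sketch with
157 | # assumed config values; the shipped yaml files may use different ones):
158 | #   num_frames = 256, temporal downsample = 4  ->  temporal_length_z = 256 // 4 = 64
159 | #   interp temporal_length = 5                 ->  model_interp_length = 5 - 2 = 3
160 | #   pred_length = ceil((64 + 3) / (3 + 1)) = 17
161 | #   pred temporal_length = 16                  ->  n_pred_steps = ceil((17 - 16) / (16 - 1)) = 1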
--------------------------------------------------------------------------------
/scripts/sample_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | from lvdm.utils.common_utils import instantiate_from_config
9 | from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
10 |
11 | def custom_to_pil(x):
12 | x = x.detach().cpu()
13 | x = torch.clamp(x, -1., 1.)
14 | x = (x + 1.) / 2.
15 | x = x.permute(1, 2, 0).numpy()
16 | x = (255 * x).astype(np.uint8)
17 | x = Image.fromarray(x)
18 |     if x.mode != "RGB":
19 | x = x.convert("RGB")
20 | return x
21 |
22 | def custom_to_np(x):
23 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
24 | sample = x.detach().cpu()
25 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
26 | if sample.dim() == 5:
27 | sample = sample.permute(0, 2, 3, 4, 1)
28 | else:
29 | sample = sample.permute(0, 2, 3, 1)
30 | sample = sample.contiguous()
31 | return sample
32 |
33 | def save_args(save_dir, args):
34 | fpath = os.path.join(save_dir, "sampling_args.yaml")
35 | with open(fpath, 'w') as f:
36 | yaml.dump(vars(args), f, default_flow_style=False)
37 |
38 | # ------------------------------------------------------------------------------------------
39 | def load_model(config, ckpt_path, gpu_id=None):
40 | print(f"Loading model from {ckpt_path}")
41 |
42 | # load sd
43 | pl_sd = torch.load(ckpt_path, map_location="cpu")
44 | try:
45 | global_step = pl_sd["global_step"]
46 | epoch = pl_sd["epoch"]
47 |     except KeyError:
48 | global_step = -1
49 | epoch = -1
50 |
51 | # load sd to model
52 | try:
53 | sd = pl_sd["state_dict"]
54 |     except KeyError:
55 | sd = pl_sd
56 | model = instantiate_from_config(config.model)
57 | model.load_state_dict(sd, strict=True)
58 |
59 | # move to device & eval
60 | if gpu_id is not None:
61 | model.to(f"cuda:{gpu_id}")
62 | else:
63 | model.cuda()
64 | model.eval()
65 |
66 | return model, global_step, epoch
67 |
68 | def make_sample_dir(opt, global_step, epoch):
69 | if not getattr(opt, 'not_automatic_logdir', False):
70 | gs_str = f"globalstep{global_step:09}" if global_step is not None else "None"
71 | e_str = f"epoch{epoch:06}" if epoch is not None else "None"
72 | ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}")
73 |
74 | # subdir name
75 | if opt.prompt_file is not None:
76 | subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}"
77 | else:
78 | subdir = f"prompt_{opt.prompt[:10]}"
79 | subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps"
80 | subdir += f"_CfgScale{opt.scale}"
81 | if opt.cond_fps is not None:
82 | subdir += f"_fps{opt.cond_fps}"
83 | if opt.seed is not None:
84 | subdir += f"_seed{opt.seed}"
85 |
86 | return os.path.join(ckpt_dir, subdir)
87 | else:
88 | return opt.logdir
89 |
90 | # ------------------------------------------------------------------------------------------
91 | @torch.no_grad()
92 | def get_conditions(prompts, model, batch_size, cond_fps=None,):
93 |
94 |     if isinstance(prompts, (str, int)):
95 | prompts = [prompts]
96 | if isinstance(prompts, list):
97 | if len(prompts) == 1:
98 | prompts = prompts * batch_size
99 | elif len(prompts) == batch_size:
100 | pass
101 | else:
102 | raise ValueError(f"invalid prompts length: {len(prompts)}")
103 | else:
104 | raise ValueError(f"invalid prompts: {prompts}")
105 | assert(len(prompts) == batch_size)
106 |
107 | # content condition: text / class label
108 | c = model.get_learned_conditioning(prompts)
109 | key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn'
110 | c = {key: [c]}
111 |
112 | # temporal condition: fps
113 | if getattr(model, 'cond_stage2_config', None) is not None:
114 | if model.cond_stage2_key == "temporal_context":
115 | assert(cond_fps is not None)
116 | batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)}
117 | fps_embd = model.cond_stage2_model(batch)
118 | c[model.cond_stage2_key] = fps_embd
119 |
120 | return c
121 |
122 | # ------------------------------------------------------------------------------------------
123 | def make_model_input_shape(model, batch_size, T=None):
124 | image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size
125 | C = model.model.diffusion_model.in_channels
126 | if T is None:
127 | T = model.model.diffusion_model.temporal_length
128 | shape = [batch_size, C, T, *image_size]
129 | return shape
130 |
131 | # ------------------------------------------------------------------------------------------
132 | def sample_batch(model, noise_shape, condition,
133 | sample_type="ddim",
134 | sampler=None,
135 | ddim_steps=None,
136 | eta=None,
137 | unconditional_guidance_scale=1.0,
138 | uc=None,
139 | denoising_progress=False,
140 | **kwargs,
141 | ):
142 |
143 | if sample_type == "ddpm":
144 | samples = model.p_sample_loop(cond=condition, shape=noise_shape,
145 | return_intermediates=False,
146 | verbose=denoising_progress,
147 | )
148 | elif sample_type == "ddim":
149 | assert(sampler is not None)
150 | assert(ddim_steps is not None)
151 | assert(eta is not None)
152 | ddim_sampler = sampler
153 | samples, _ = ddim_sampler.sample(S=ddim_steps,
154 | conditioning=condition,
155 | batch_size=noise_shape[0],
156 | shape=noise_shape[1:],
157 | verbose=denoising_progress,
158 | unconditional_guidance_scale=unconditional_guidance_scale,
159 | unconditional_conditioning=uc,
160 | eta=eta,
161 | **kwargs,
162 | )
163 | else:
164 | raise ValueError
165 | return samples
166 |
167 | # ------------------------------------------------------------------------------------------
168 | def torch_to_np(x):
169 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
170 | sample = x.detach().cpu()
171 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
172 | if sample.dim() == 5:
173 | sample = sample.permute(0, 2, 3, 4, 1)
174 | else:
175 | sample = sample.permute(0, 2, 3, 1)
176 | sample = sample.contiguous()
177 | return sample
178 | # ------------------------------------------------------------------------------------------
179 | def save_results(videos, save_dir,
180 | save_name="results", save_fps=8, save_mp4=True,
181 | save_npz=False, save_mp4_sheet=False, save_jpg=False
182 | ):
183 | if save_mp4:
184 | save_subdir = os.path.join(save_dir, "videos")
185 | os.makedirs(save_subdir, exist_ok=True)
186 | shape_str = "x".join([str(x) for x in videos[0:1,...].shape])
187 | for i in range(videos.shape[0]):
188 | npz_to_video_grid(videos[i:i+1,...],
189 | os.path.join(save_subdir, f"{save_name}_{i:03d}_{shape_str}.mp4"),
190 | fps=save_fps)
191 | print(f'Successfully saved videos in {save_subdir}')
192 |
193 | shape_str = "x".join([str(x) for x in videos.shape])
194 | if save_npz:
195 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.npz")
196 | np.savez(save_path, videos)
197 | print(f'Successfully saved npz in {save_path}')
198 |
199 | if save_mp4_sheet:
200 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.mp4")
201 | npz_to_video_grid(videos, save_path, fps=save_fps)
202 | print(f'Successfully saved mp4 sheet in {save_path}')
203 |
204 | if save_jpg:
205 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.jpg")
206 | npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1])
207 | print(f'Successfully saved jpg sheet in {save_path}')
208 |
209 |
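210 | # ------------------------------------------------------------------------------------------
211 | # Minimal sketch of how these helpers compose (mirrors scripts/sample_uncond.py; `config`,
212 | # `ckpt_path`, and `save_dir` are placeholders supplied by the caller):
213 | #
214 | #   model, _, _ = load_model(config, ckpt_path)
215 | #   noise_shape = make_model_input_shape(model, batch_size=1)
216 | #   samples_z = sample_batch(model, noise_shape, condition=None, sample_type="ddpm")
217 | #   videos = torch_to_np(model.decode_first_stage(samples_z))
218 | #   save_results(videos, save_dir, save_name="demo", save_fps=8)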
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='lvdm',
5 | version='0.0.1',
6 | description='generic video generation models',
7 | packages=find_packages(),
8 | install_requires=[],
9 | )
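10 | 
11 | # The training/sampling scripts import this package as `lvdm`; an editable install
12 | # (`pip install -e .`) from the repository root makes it importable without manual path handling.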
--------------------------------------------------------------------------------
/shellscripts/eval_lvdm_short.sh:
--------------------------------------------------------------------------------
1 |
2 | DATACONFIG="configs/lvdm_short/sky.yaml"
3 | FAKEPATH='${Your_Path}/2048x16x256x256x3-samples.npz'
4 | REALPATH='/dataset/sky_timelapse'
5 | RESDIR='results/fvd'
6 |
7 | mkdir -p $RESDIR
8 | python scripts/eval_cal_fvd_kvd.py \
9 | --yaml ${DATACONFIG} \
10 | --real_path ${REALPATH} \
11 | --fake_path ${FAKEPATH} \
12 | --batch_size 32 \
13 | --num_workers 4 \
14 | --n_runs 10 \
15 | --res_dir ${RESDIR} \
16 | --n_sample 2048
17 |
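18 | # REALPATH points to the real dataset used for the FVD/KVD reference statistics; FAKEPATH
19 | # should point to an .npz of generated samples (the sampling scripts write one when run
20 | # with --save_npz; its file name encodes the sample tensor shape).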
--------------------------------------------------------------------------------
/shellscripts/sample_lvdm_long.sh:
--------------------------------------------------------------------------------
1 |
2 | CKPT_PRED="models/lvdm_long/sky_pred.ckpt"
3 | CKPT_INTERP="models/lvdm_long/sky_interp.ckpt"
4 | AEPATH="models/ae/ae_sky.ckpt"
5 | CONFIG_PRED="configs/lvdm_long/sky_pred.yaml"
6 | CONFIG_INTERP="configs/lvdm_long/sky_interp.yaml"
7 | OUTDIR="results/longvideos/"
8 |
9 | python scripts/sample_uncond_long_videos.py \
10 | --ckpt_pred $CKPT_PRED \
11 | --config_pred $CONFIG_PRED \
12 | --ckpt_interp $CKPT_INTERP \
13 | --config_interp $CONFIG_INTERP \
14 | --save_dir $OUTDIR \
15 | --n_samples 1 \
16 | --batch_size 1 \
17 | --seed 1000 \
18 | --show_denoising_progress \
19 | model.params.first_stage_config.params.ckpt_path=$AEPATH \
20 | --sample_cond_noise_level 100 \
21 | --uncond_scale 0.1 \
22 | --n_pred_steps 2 \
23 | --sample_type ddim --ddim_steps 50
24 |
25 | # To use DDPM sampling instead, remove `--sample_type ddim --ddim_steps 50` (full command sketched below).
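26 | #
27 | # DDPM example (a sketch of the same invocation without the DDIM flags):
28 | # python scripts/sample_uncond_long_videos.py \
29 | #   --ckpt_pred $CKPT_PRED --config_pred $CONFIG_PRED \
30 | #   --ckpt_interp $CKPT_INTERP --config_interp $CONFIG_INTERP \
31 | #   --save_dir $OUTDIR --n_samples 1 --batch_size 1 --seed 1000 --show_denoising_progress \
32 | #   --sample_cond_noise_level 100 --uncond_scale 0.1 --n_pred_steps 2 \
33 | #   model.params.first_stage_config.params.ckpt_path=$AEPATH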
--------------------------------------------------------------------------------
/shellscripts/sample_lvdm_short.sh:
--------------------------------------------------------------------------------
1 |
2 | CONFIG_PATH="configs/lvdm_short/sky.yaml"
3 | BASE_PATH="models/lvdm_short/short_sky.ckpt"
4 | AEPATH="models/ae/ae_sky.ckpt"
5 | OUTDIR="results/uncond_short/"
6 |
7 | python scripts/sample_uncond.py \
8 | --ckpt_path $BASE_PATH \
9 | --config_path $CONFIG_PATH \
10 | --save_dir $OUTDIR \
11 | --n_samples 1 \
12 | --batch_size 1 \
13 | --seed 1000 \
14 | --show_denoising_progress \
15 | model.params.first_stage_config.params.ckpt_path=$AEPATH
16 |
17 | # To use DDIM sampling, add `--sample_type ddim --ddim_steps 50` (full command sketched below).
18 |
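19 | # DDIM example (a sketch of the same invocation with the DDIM flags added):
20 | # python scripts/sample_uncond.py \
21 | #   --ckpt_path $BASE_PATH --config_path $CONFIG_PATH --save_dir $OUTDIR \
22 | #   --n_samples 1 --batch_size 1 --seed 1000 --show_denoising_progress \
23 | #   --sample_type ddim --ddim_steps 50 \
24 | #   model.params.first_stage_config.params.ckpt_path=$AEPATH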
--------------------------------------------------------------------------------
/shellscripts/sample_lvdm_text2video.sh:
--------------------------------------------------------------------------------
1 |
2 | PROMPT="astronaut riding a horse" # OR: PROMPT="input/prompts.txt" for sampling multiple prompts
3 | OUTDIR="results/t2v"
4 |
5 | CKPT_PATH="models/t2v/model.ckpt"
6 | CONFIG_PATH="configs/lvdm_short/text2video.yaml"
7 |
8 | python scripts/sample_text2video.py \
9 | --ckpt_path $CKPT_PATH \
10 | --config_path $CONFIG_PATH \
11 | --prompt "$PROMPT" \
12 | --save_dir $OUTDIR \
13 | --n_samples 1 \
14 | --batch_size 1 \
15 | --seed 1000 \
16 | --show_denoising_progress \
17 | --save_jpg
18 |
--------------------------------------------------------------------------------
/shellscripts/train_lvdm_interp_sky.sh:
--------------------------------------------------------------------------------
1 | export PATH=/apdcephfs_cq2/share_1290939/yingqinghe/anaconda/envs/ldmA100/bin/:$PATH
2 | export http_proxy="http://star-proxy.oa.com:3128"
3 | export https_proxy="http://star-proxy.oa.com:3128"
4 | export ftp_proxy="http://star-proxy.oa.com:3128"
5 | export no_proxy=".woa.com,mirrors.cloud.tencent.com,tlinux-mirror.tencent-cloud.com,tlinux-mirrorlist.tencent-cloud.com,localhost,127.0.0.1,mirrors-tlinux.tencentyun.com,.oa.com,.local,.3gqq.com,.7700.org,.ad.com,.ada_sixjoy.com,.addev.com,.app.local,.apps.local,.aurora.com,.autotest123.com,.bocaiwawa.com,.boss.com,.cdc.com,.cdn.com,.cds.com,.cf.com,.cjgc.local,.cm.com,.code.com,.datamine.com,.dvas.com,.dyndns.tv,.ecc.com,.expochart.cn,.expovideo.cn,.fms.com,.great.com,.hadoop.sec,.heme.com,.home.com,.hotbar.com,.ibg.com,.ied.com,.ieg.local,.ierd.com,.imd.com,.imoss.com,.isd.com,.isoso.com,.itil.com,.kao5.com,.kf.com,.kitty.com,.lpptp.com,.m.com,.matrix.cloud,.matrix.net,.mickey.com,.mig.local,.mqq.com,.oiweb.com,.okbuy.isddev.com,.oss.com,.otaworld.com,.paipaioa.com,.qqbrowser.local,.qqinternal.com,.qqwork.com,.rtpre.com,.sc.oa.com,.sec.com,.server.com,.service.com,.sjkxinternal.com,.sllwrnm5.cn,.sng.local,.soc.com,.t.km,.tcna.com,.teg.local,.tencentvoip.com,.tenpayoa.com,.test.air.tenpay.com,.tr.com,.tr_autotest123.com,.vpn.com,.wb.local,.webdev.com,.webdev2.com,.wizard.com,.wqq.com,.wsd.com,.sng.com,.music.lan,.mnet2.com,.tencentb2.com,.tmeoa.com,.pcg.com,www.wip3.adobe.com,www-mm.wip3.adobe.com,mirrors.tencent.com,csighub.tencentyun.com"
6 |
7 | export TOKENIZERS_PARALLELISM=false
8 |
9 | PROJ_ROOT="/apdcephfs_cq2/share_1290939/yingqinghe/results/latent_diffusion"
10 | EXPNAME="test_sky_train_interp"
11 | CONFIG="configs/lvdm_long/sky_interp.yaml"
12 | DATADIR="/dockerdata/sky_timelapse"
13 | AEPATH="/apdcephfs/share_1290939/yingqinghe/results/latent_diffusion/LVDM/ae_013_sky256_basedon003_4nodes_e0/checkpoints/trainstep_checkpoints/epoch=000299-step=000010199.ckpt"
14 |
15 | # run
16 | python main.py \
17 | --base $CONFIG \
18 | -t --gpus 0, \
19 | --name $EXPNAME \
20 | --logdir $PROJ_ROOT \
21 | --auto_resume True \
22 | lightning.trainer.num_nodes=1 \
23 | data.params.train.params.data_root=$DATADIR \
24 | data.params.validation.params.data_root=$DATADIR \
25 | model.params.first_stage_config.params.ckpt_path=$AEPATH
26 |
27 |
28 | # commands for multi nodes training
29 | # ---------------------------------------------------------------------------------------------------
30 | # python -m torch.distributed.run \
31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \
32 | # main.py \
33 | # --base $CONFIG \
34 | # -t --gpus 0,1,2,3,4,5,6,7 \
35 | # --name $EXPNAME \
36 | # --logdir $PROJ_ROOT \
37 | # --auto_resume True \
38 | # lightning.trainer.num_nodes=$NHOST \
39 | # data.params.train.params.data_root=$DATADIR \
40 | # data.params.validation.params.data_root=$DATADIR
41 |
--------------------------------------------------------------------------------
/shellscripts/train_lvdm_pred_sky.sh:
--------------------------------------------------------------------------------
1 | export PATH=/apdcephfs_cq2/share_1290939/yingqinghe/anaconda/envs/ldmA100/bin/:$PATH
2 | export http_proxy="http://star-proxy.oa.com:3128"
3 | export https_proxy="http://star-proxy.oa.com:3128"
4 | export ftp_proxy="http://star-proxy.oa.com:3128"
5 | export no_proxy=".woa.com,mirrors.cloud.tencent.com,tlinux-mirror.tencent-cloud.com,tlinux-mirrorlist.tencent-cloud.com,localhost,127.0.0.1,mirrors-tlinux.tencentyun.com,.oa.com,.local,.3gqq.com,.7700.org,.ad.com,.ada_sixjoy.com,.addev.com,.app.local,.apps.local,.aurora.com,.autotest123.com,.bocaiwawa.com,.boss.com,.cdc.com,.cdn.com,.cds.com,.cf.com,.cjgc.local,.cm.com,.code.com,.datamine.com,.dvas.com,.dyndns.tv,.ecc.com,.expochart.cn,.expovideo.cn,.fms.com,.great.com,.hadoop.sec,.heme.com,.home.com,.hotbar.com,.ibg.com,.ied.com,.ieg.local,.ierd.com,.imd.com,.imoss.com,.isd.com,.isoso.com,.itil.com,.kao5.com,.kf.com,.kitty.com,.lpptp.com,.m.com,.matrix.cloud,.matrix.net,.mickey.com,.mig.local,.mqq.com,.oiweb.com,.okbuy.isddev.com,.oss.com,.otaworld.com,.paipaioa.com,.qqbrowser.local,.qqinternal.com,.qqwork.com,.rtpre.com,.sc.oa.com,.sec.com,.server.com,.service.com,.sjkxinternal.com,.sllwrnm5.cn,.sng.local,.soc.com,.t.km,.tcna.com,.teg.local,.tencentvoip.com,.tenpayoa.com,.test.air.tenpay.com,.tr.com,.tr_autotest123.com,.vpn.com,.wb.local,.webdev.com,.webdev2.com,.wizard.com,.wqq.com,.wsd.com,.sng.com,.music.lan,.mnet2.com,.tencentb2.com,.tmeoa.com,.pcg.com,www.wip3.adobe.com,www-mm.wip3.adobe.com,mirrors.tencent.com,csighub.tencentyun.com"
6 |
7 | export TOKENIZERS_PARALLELISM=false
8 |
9 | PROJ_ROOT="/apdcephfs_cq2/share_1290939/yingqinghe/results/latent_diffusion"
10 | EXPNAME="test_sky_train_pred"
11 | CONFIG="configs/lvdm_long/sky_pred.yaml"
12 | DATADIR="/dockerdata/sky_timelapse"
13 | AEPATH="/apdcephfs/share_1290939/yingqinghe/results/latent_diffusion/LVDM/ae_013_sky256_basedon003_4nodes_e0/checkpoints/trainstep_checkpoints/epoch=000299-step=000010199.ckpt"
14 |
15 | # run
16 | python main.py \
17 | --base $CONFIG \
18 | -t --gpus 0, \
19 | --name $EXPNAME \
20 | --logdir $PROJ_ROOT \
21 | --auto_resume True \
22 | lightning.trainer.num_nodes=1 \
23 | data.params.train.params.data_root=$DATADIR \
24 | data.params.validation.params.data_root=$DATADIR \
25 | model.params.first_stage_config.params.ckpt_path=$AEPATH
26 |
27 |
28 | # commands for multi nodes training
29 | # ---------------------------------------------------------------------------------------------------
30 | # python -m torch.distributed.run \
31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \
32 | # main.py \
33 | # --base $CONFIG \
34 | # -t --gpus 0,1,2,3,4,5,6,7 \
35 | # --name $EXPNAME \
36 | # --logdir $PROJ_ROOT \
37 | # --auto_resume True \
38 | # lightning.trainer.num_nodes=$NHOST \
39 | # data.params.train.params.data_root=$DATADIR \
40 | # data.params.validation.params.data_root=$DATADIR
41 |
--------------------------------------------------------------------------------
/shellscripts/train_lvdm_short.sh:
--------------------------------------------------------------------------------
1 |
2 | PROJ_ROOT="" # root directory for saving experiment logs
3 | EXPNAME="lvdm_short_sky" # experiment name
4 | DATADIR="/dataset/sky_timelapse" # dataset directory
5 | AEPATH="models/ae/ae_sky.ckpt" # pretrained video autoencoder checkpoint
6 |
7 | CONFIG="configs/lvdm_short/sky.yaml"
8 | # OR CONFIG="configs/lvdm_short/ucf.yaml"
9 | # OR CONFIG="configs/lvdm_short/taichi.yaml"
10 |
11 | # run
12 | export TOKENIZERS_PARALLELISM=false
13 | python main.py \
14 | --base $CONFIG \
15 | -t --gpus 0, \
16 | --name $EXPNAME \
17 | --logdir $PROJ_ROOT \
18 | --auto_resume True \
19 | lightning.trainer.num_nodes=1 \
20 | data.params.train.params.data_root=$DATADIR \
21 | data.params.validation.params.data_root=$DATADIR \
22 | model.params.first_stage_config.params.ckpt_path=$AEPATH
23 |
24 | # -------------------------------------------------------------------------------------------------
25 | # commands for multi nodes training
26 | # - use torch.distributed.run to launch main.py
27 | # - set `gpus` and `lightning.trainer.num_nodes`
28 |
29 | # For example:
30 |
31 | # python -m torch.distributed.run \
32 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \
33 | # main.py \
34 | # --base $CONFIG \
35 | # -t --gpus 0,1,2,3,4,5,6,7 \
36 | # --name $EXPNAME \
37 | # --logdir $PROJ_ROOT \
38 | # --auto_resume True \
39 | # lightning.trainer.num_nodes=$NHOST \
40 | # data.params.train.params.data_root=$DATADIR \
41 | # data.params.validation.params.data_root=$DATADIR
42 |
--------------------------------------------------------------------------------
/shellscripts/train_lvdm_videoae.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | PROJ_ROOT="results/" # root directory for saving experiment logs
4 | EXPNAME="lvdm_videoae_sky" # experiment name
5 | DATADIR="/dataset/sky_timelapse" # dataset directory
6 |
7 | CONFIG="configs/videoae/sky.yaml"
8 | # OR CONFIG="configs/videoae/ucf.yaml"
9 | # OR CONFIG="configs/videoae/taichi.yaml"
10 |
11 | # run
12 | export TOKENIZERS_PARALLELISM=false
13 | python main.py \
14 | --base $CONFIG \
15 | -t --gpus 0, \
16 | --name $EXPNAME \
17 | --logdir $PROJ_ROOT \
18 | --auto_resume True \
19 | lightning.trainer.num_nodes=1 \
20 | data.params.train.params.data_root=$DATADIR \
21 | data.params.validation.params.data_root=$DATADIR
22 |
23 | # -------------------------------------------------------------------------------------------------
24 | # commands for multi nodes training
25 | # - use torch.distributed.run to launch main.py
26 | # - set `gpus` and `lightning.trainer.num_nodes`
27 |
28 | # For example:
29 |
30 | # python -m torch.distributed.run \
31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \
32 | # main.py \
33 | # --base $CONFIG \
34 | # -t --gpus 0,1,2,3,4,5,6,7 \
35 | # --name $EXPNAME \
36 | # --logdir $PROJ_ROOT \
37 | # --auto_resume True \
38 | # lightning.trainer.num_nodes=$NHOST \
39 | # data.params.train.params.data_root=$DATADIR \
40 | # data.params.validation.params.data_root=$DATADIR
41 |
--------------------------------------------------------------------------------
/shellscripts/train_lvdm_videoae_ucf.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | PROJ_ROOT="results/" # input the root directory for saving experiment logs
4 | EXPNAME="lvdm_videoae_ucf101" # experiment name
5 | DATADIR="UCF-101/" # input the dataset directory
6 |
7 | CONFIG="configs/videoae/ucf_videodata.yaml"
8 |
9 | # run
10 | export TOKENIZERS_PARALLELISM=false
11 | python main.py \
12 | --base $CONFIG \
13 | -t --gpus 0, \
14 | --name $EXPNAME \
15 | --logdir $PROJ_ROOT \
16 | --auto_resume True \
17 | lightning.trainer.num_nodes=1 \
18 | data.params.train.params.data_root=$DATADIR \
19 | data.params.validation.params.data_root=$DATADIR
20 |
21 | # -------------------------------------------------------------------------------------------------
22 | # commands for multi nodes training
23 | # - use torch.distributed.run to launch main.py
24 | # - set `gpus` and `lightning.trainer.num_nodes`
25 |
26 | # For example:
27 |
28 | # python -m torch.distributed.run \
29 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \
30 | # main.py \
31 | # --base $CONFIG \
32 | # -t --gpus 0,1,2,3,4,5,6,7 \
33 | # --name $EXPNAME \
34 | # --logdir $PROJ_ROOT \
35 | # --auto_resume True \
36 | # lightning.trainer.num_nodes=$NHOST \
37 | # data.params.train.params.data_root=$DATADIR \
38 | # data.params.validation.params.data_root=$DATADIR
39 |
--------------------------------------------------------------------------------