├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── framework.jpeg ├── framework.jpg ├── sky-long-001.gif ├── sky-long-002.gif ├── sky-long-003.gif ├── t2v-001.gif ├── t2v-002.gif ├── t2v-003.gif ├── t2v-004.gif ├── t2v-005.gif ├── t2v-006.gif ├── t2v-007.gif ├── t2v-008.gif ├── ucf-long-001.gif ├── ucf-long-002.gif └── ucf-long-003.gif ├── configs ├── lvdm_long │ ├── sky_interp.yaml │ └── sky_pred.yaml ├── lvdm_short │ ├── sky.yaml │ ├── taichi.yaml │ ├── text2video.yaml │ └── ucf.yaml └── videoae │ ├── sky.yaml │ ├── taichi.yaml │ ├── ucf.yaml │ └── ucf_videodata.yaml ├── input └── prompts.txt ├── lvdm ├── data │ ├── frame_dataset.py │ ├── split_ucf101.py │ ├── taichi.py │ └── ucf.py ├── models │ ├── autoencoder.py │ ├── autoencoder3d.py │ ├── ddpm3d.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── modules │ │ ├── aemodules.py │ │ ├── aemodules3d.py │ │ ├── attention_temporal.py │ │ ├── condition_modules.py │ │ ├── distributions.py │ │ ├── ema.py │ │ ├── openaimodel3d.py │ │ └── util.py ├── samplers │ └── ddim.py └── utils │ ├── callbacks.py │ ├── common_utils.py │ ├── dist_utils.py │ ├── log.py │ └── saving_utils.py ├── main.py ├── requirements.txt ├── requirements_h800_gpu.txt ├── scripts ├── eval_cal_fvd_kvd.py ├── fvd_utils │ ├── fvd_utils.py │ └── pytorch_i3d.py ├── sample_long_videos_utils.py ├── sample_text2video.py ├── sample_uncond.py ├── sample_uncond_long_videos.py └── sample_utils.py ├── setup.py └── shellscripts ├── eval_lvdm_short.sh ├── sample_lvdm_long.sh ├── sample_lvdm_short.sh ├── sample_lvdm_text2video.sh ├── train_lvdm_interp_sky.sh ├── train_lvdm_pred_sky.sh ├── train_lvdm_short.sh ├── train_lvdm_videoae.sh └── train_lvdm_videoae_ucf.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gif filter=lfs diff=lfs merge=lfs -text 2 | assets/*.gif filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyc* 3 | __pycache__ 4 | .vscode/* 5 | *.egg-info/ 6 | results/ 7 | *.pt 8 | *.ckpt 9 | clip-vit-large-patch14 10 | results 11 | temp -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cassie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |

LVDM: Latent Video Diffusion Models for High-Fidelity Long Video Generation

5 | 6 |       7 | 8 | 9 |
10 | Yingqing He 1   11 | Tianyu Yang 2  12 | Yong Zhang 2  13 | Ying Shan 2  14 | Qifeng Chen 1
15 |
16 |
17 |
18 | 1 The Hong Kong University of Science and Technology   2 Tencent AI Lab   19 |
20 |
21 |
22 | 23 | TL;DR: An efficient video diffusion model that can: 24 | 1️⃣ conditionally generate videos based on input text; 25 | 2️⃣ unconditionally generate videos with thousands of frames. 26 | 27 |
28 | 29 |
30 | 31 | 32 | ## 🍻 Results 33 | ### ☝️ Text-to-Video Generation 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 |
"A corgi is swimming fastly""astronaut riding a horse""A glass bead falling into water with a huge splash. Sunset in the background""A beautiful sunrise on mars. High definition, timelapse, dramaticcolors.""A bear dancing and jumping to upbeat music, moving his whole body.""An iron man surfing in the sea. cartoon style"
52 | 53 | ### ✌️ Unconditional Long Video Generation (40 seconds) 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
64 | 65 | ## ⏳ TODO 66 | - [x] Release pretrained text-to-video generation models and inference code 67 | - [x] Release unconditional video generation models 68 | - [x] Release training code 69 | - [ ] Update training and sampling for long video generation 70 |
71 | 72 | --- 73 | ## ⚙️ Setup 74 | 75 | ### Install Environment via Anaconda 76 | ```bash 77 | conda create -n lvdm python=3.8.5 78 | conda activate lvdm 79 | pip install -r requirements.txt 80 | ``` 81 | ### Pretrained Models and Used Datasets 82 | 83 | 84 | 85 | Download the pretrained checkpoints via the following commands in Linux terminal: 86 | ``` 87 | mkdir -p models/ae 88 | mkdir -p models/lvdm_short 89 | mkdir -p models/t2v 90 | 91 | # sky timelapse 92 | wget -O models/ae/ae_sky.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/ae/ae_sky.ckpt 93 | wget -O models/lvdm_short/short_sky.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/short_sky.ckpt 94 | 95 | # taichi 96 | wget -O models/ae/ae_taichi.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/ae/ae_taichi.ckpt 97 | wget -O models/lvdm_short/short_taichi.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/short_taichi.ckpt 98 | 99 | # text2video 100 | wget -O models/t2v/model.ckpt https://huggingface.co/Yingqing/LVDM/resolve/main/lvdm_short/t2v.ckpt 101 | 102 | ``` 103 | 104 | Prepare UCF-101 dataset 105 | ``` 106 | mkdir temp; cd temp 107 | 108 | # Download UCF-101 from the official website https://www.crcv.ucf.edu/data/UCF101.php (The UCF101 data ) 109 | 110 | wget https://www.crcv.ucf.edu/data/UCF101/UCF101.rar --no-check-certificate 111 | unrar x UCF101.rar 112 | 113 | # Download annotations from https://www.crcv.ucf.edu/data/UCF101.php (The Train/Test Splits for Action Recognition on UCF101 data set) 114 | 115 | wget https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip --no-check-certificate 116 | unzip UCF101TrainTestSplits-RecognitionTask.zip 117 | 118 | # Split the train and test split 119 | cd .. 120 | python lvdm/data/split_ucf101.py # please check this script 121 | 122 | ``` 123 | 125 | 126 | 127 | 128 | Download manually: 129 | - Sky Timelapse: [VideoAE](https://huggingface.co/Yingqing/LVDM/blob/main/ae/ae_sky.ckpt), [LVDM_short](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/short_sky.ckpt), [LVDM_pred](TBD), [LVDM_interp](TBD), [dataset](https://github.com/weixiong-ur/mdgan) 130 | - Taichi: [VideoAE](https://huggingface.co/Yingqing/LVDM/blob/main/ae/ae_taichi.ckpt), [LVDM_short](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/short_taichi.ckpt), [dataset](https://github.com/AliaksandrSiarohin/first-order-model/blob/master/data/taichi-loading/README.md) 131 | - Text2Video: [model](https://huggingface.co/Yingqing/LVDM/blob/main/lvdm_short/t2v.ckpt) 132 | 133 | --- 134 | ## 💫 Inference 135 | ### Sample Short Videos 136 | - unconditional generation 137 | 138 | ``` 139 | bash shellscripts/sample_lvdm_short.sh 140 | ``` 141 | - text to video generation 142 | ``` 143 | bash shellscripts/sample_lvdm_text2video.sh 144 | ``` 145 | 146 | ### Sample Long Videos 147 | ``` 148 | bash shellscripts/sample_lvdm_long.sh 149 | ``` 150 | 151 | --- 152 | ## 💫 Training 153 | 154 | ### Train video autoencoder 155 | ``` 156 | bash shellscripts/train_lvdm_videoae.sh 157 | ``` 158 | - remember to set `PROJ_ROOT`, `EXPNAME`, `DATADIR`, and `CONFIG`. 159 | 160 | ### Train unconditional lvdm for short video generation 161 | ``` 162 | bash shellscripts/train_lvdm_short.sh 163 | ``` 164 | - remember to set `PROJ_ROOT`, `EXPNAME`, `DATADIR`, `AEPATH` and `CONFIG`. 
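
The YAML configs above reference placeholders such as `${data_root}` and `${ckpt_path}`, which the shell variables (`DATADIR`, `AEPATH`, ...) are expected to fill in. Below is a minimal sketch of doing that resolution with OmegaConf; the shell-variable-to-config mapping is an assumption for illustration, and the authoritative wiring is in `shellscripts/train_lvdm_short.sh` and `main.py`.

```python
# Sketch: fill the ${data_root} / ${ckpt_path} placeholders of a training config.
# The key paths mirror configs/lvdm_short/sky.yaml; the DATADIR/AEPATH mapping
# shown in the comments is assumed, not taken from the shell scripts.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/lvdm_short/sky.yaml")
overrides = OmegaConf.from_dotlist([
    "data.params.train.params.data_root=/path/to/sky_timelapse",              # DATADIR
    "data.params.validation.params.data_root=/path/to/sky_timelapse",
    "model.params.first_stage_config.params.ckpt_path=models/ae/ae_sky.ckpt", # AEPATH
])
config = OmegaConf.merge(config, overrides)
print(OmegaConf.to_yaml(config.data, resolve=False))
```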
165 | 166 | ### Train unconditional lvdm for long video generation 167 | ``` 168 | # TBD 169 | ``` 170 | 171 | --- 172 | ## 💫 Evaluation 173 | ``` 174 | bash shellscripts/eval_lvdm_short.sh 175 | ``` 176 | - remember to set `DATACONFIG`, `FAKEPATH`, `REALPATH`, and `RESDIR`. 177 | --- 178 | 179 | ## 📃 Abstract 180 | AI-generated content has attracted lots of attention recently, but photo-realistic video synthesis is still challenging. Although many attempts using GANs and autoregressive models have been made in this area, the visual quality and length of generated videos are far from satisfactory. Diffusion models have shown remarkable results recently but require significant computational resources. To address this, we introduce lightweight video diffusion models by leveraging a low-dimensional 3D latent space, significantly outperforming previous pixel-space video diffusion models under a limited computational budget. In addition, we propose hierarchical diffusion in the latent space such that longer videos with more than one thousand frames can be produced. To further overcome the performance degradation issue for long video generation, we propose conditional latent perturbation and unconditional guidance that effectively mitigate the accumulated errors during the extension of video length. Extensive experiments on small domain datasets of different categories suggest that our framework generates more realistic and longer videos than previous strong baselines. We additionally provide an extension to large-scale text-to-video generation to demonstrate the superiority of our work. Our code and models will be made publicly available. 181 |
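
To make the "unconditional guidance" idea from the abstract concrete, here is an illustrative classifier-free-style combination of conditional and unconditional noise predictions. The `model(x, t, c)` signature is hypothetical and used only for illustration; the exact formulation applied during long-video extension lives in `lvdm/samplers/ddim.py` and `lvdm/models/ddpm3d.py`.

```python
import torch

@torch.no_grad()
def guided_eps(model, x_t, t, cond, guidance_scale=1.0):
    # Hypothetical model(x, t, c) call signature, for illustration only.
    eps_cond = model(x_t, t, cond)    # prediction given the conditioning (previous latents / text)
    eps_uncond = model(x_t, t, None)  # prediction with the conditioning dropped
    # Push the estimate away from the unconditional prediction to suppress
    # accumulated errors when autoregressively extending the video.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```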
182 | 183 | ## 🔮 Pipeline 184 | 185 |

186 | ![LVDM framework](assets/framework.jpg) 187 |

188 | 189 | --- 190 | ## 😉 Citation 191 | 192 | ``` 193 | @article{he2022lvdm, 194 | title={Latent Video Diffusion Models for High-Fidelity Long Video Generation}, 195 | author={Yingqing He and Tianyu Yang and Yong Zhang and Ying Shan and Qifeng Chen}, 196 | year={2022}, 197 | eprint={2211.13221}, 198 | archivePrefix={arXiv}, 199 | primaryClass={cs.CV} 200 | } 201 | ``` 202 | 203 | ## 🤗 Acknowledgements 204 | We built our code partially based on [latent diffusion models](https://github.com/CompVis/latent-diffusion) and [TATS](https://github.com/SongweiGe/TATS). Thanks the authors for sharing their awesome codebases! We aslo adopt Xintao Wang's [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling our text-to-video generation results. Thanks for their wonderful work! -------------------------------------------------------------------------------- /assets/framework.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingqingHe/LVDM/d251dccfbf6352826f5c5681abd86e87ed7e6371/assets/framework.jpeg -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingqingHe/LVDM/d251dccfbf6352826f5c5681abd86e87ed7e6371/assets/framework.jpg -------------------------------------------------------------------------------- /assets/sky-long-001.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:25bc40a8cfcfaf2ab8e1de96f136a14d61099f93319627056156ad265df3d649 3 | size 9214020 4 | -------------------------------------------------------------------------------- /assets/sky-long-002.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e513b06e4cc56afc1626539553297412cfaec4a5e246f2b9eff51171597aeb29 3 | size 9616345 4 | -------------------------------------------------------------------------------- /assets/sky-long-003.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:491a5072a19828fd5aa0bc0094c0b95fdfb51e29b91bf896be16a8f704b30e7b 3 | size 9907437 4 | -------------------------------------------------------------------------------- /assets/t2v-001.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8e59eebb8da31a1fc1e2f38ab0042bf6e2b0db5b671c8ffaec603f8dc9643a29 3 | size 4873384 4 | -------------------------------------------------------------------------------- /assets/t2v-002.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:98d7e6dd8b0841d74e96c695f62e1a0194748d7566f993b80409477f37124e15 3 | size 5200960 4 | -------------------------------------------------------------------------------- /assets/t2v-003.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7c8e8c0fbd617bdda67160d99b5f07b1de85991073e21a7e0833ac99ffb18fd1 3 | size 4522552 4 | -------------------------------------------------------------------------------- /assets/t2v-004.gif: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ef0f32351de73248f4394f0dd101218399b7e9f47da73b88d69140416a47b6f5 3 | size 2402692 4 | -------------------------------------------------------------------------------- /assets/t2v-005.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fc9b8a763ead3dff469de42805bd0fec3df3b545213acba4ce0b5a42b56e4a22 3 | size 2754672 4 | -------------------------------------------------------------------------------- /assets/t2v-006.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:00ffc6692300905413327e0982368e0774b7abd22635523f4b66d128db02e405 3 | size 4607083 4 | -------------------------------------------------------------------------------- /assets/t2v-007.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:44ce784dca47e1695c115440bdf797d220e7b70bc243df1ad216ba0c7c2ba1a4 3 | size 5966789 4 | -------------------------------------------------------------------------------- /assets/t2v-008.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fd881c7daa09503c00331c23124dac8242f09676daa3eee1b329f374b3dd112b 3 | size 3926517 4 | -------------------------------------------------------------------------------- /assets/ucf-long-001.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0d788762c3d8edc1fb4ef35cee358fd094fe951ebec5a7fd244b8a770e2389d3 3 | size 8980149 4 | -------------------------------------------------------------------------------- /assets/ucf-long-002.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7a08a0794cd1df65904692f3712cc2f297b1319ebba2a716236f13f56691ccda 3 | size 7676461 4 | -------------------------------------------------------------------------------- /assets/ucf-long-003.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c08c0815540fc78dcd198415416984a26d37946dcad5a02ac8906ec94c0a3140 3 | size 8280966 4 | -------------------------------------------------------------------------------- /configs/lvdm_long/sky_interp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 8.0e-5 #1.5e-04 3 | scale_lr: False 4 | target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion 5 | params: 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | monitor: val/loss_simple_ema 16 | conditioning_key: concat-adm-mask 17 | cond_stage_config: null 18 | noisy_cond: True 19 | max_noise_level: 250 20 | cond_stage_trainable: False 21 | concat_mode: False 22 | scale_by_std: False 23 | scale_factor: 0.33422927 24 | shift_factor: 1.4606637 25 | encoder_type: 3d 26 | rand_temporal_mask: true 27 | p_interp: 0.9 28 | p_pred: 0.0 29 | n_prevs: null 30 | split_clips: False 31 | downfactor_t: null # 
used for split video frames to clips before encoding 32 | clip_length: null 33 | 34 | unet_config: 35 | target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet 36 | params: 37 | num_classes: 251 # timesteps for noise conditoining 38 | image_size: 32 39 | in_channels: 5 40 | out_channels: 4 41 | model_channels: 256 42 | attention_resolutions: 43 | - 8 44 | - 4 45 | - 2 46 | num_res_blocks: 3 47 | channel_mult: 48 | - 1 49 | - 2 50 | - 3 51 | - 4 52 | num_heads: 4 53 | use_temporal_transformer: False 54 | use_checkpoint: true 55 | legacy: False 56 | # temporal 57 | kernel_size_t: 1 58 | padding_t: 0 59 | temporal_length: 5 60 | use_relative_position: True 61 | use_scale_shift_norm: True 62 | first_stage_config: 63 | target: lvdm.models.autoencoder3d.AutoencoderKL 64 | params: 65 | monitor: "val/rec_loss" 66 | embed_dim: 4 67 | lossconfig: __is_first_stage__ 68 | ddconfig: 69 | double_z: True 70 | z_channels: 4 71 | encoder: 72 | target: lvdm.models.modules.aemodules3d.Encoder 73 | params: 74 | n_hiddens: 32 75 | downsample: [4, 8, 8] 76 | image_channel: 3 77 | norm_type: group 78 | padding_type: replicate 79 | double_z: True 80 | z_channels: 4 81 | decoder: 82 | target: lvdm.models.modules.aemodules3d.Decoder 83 | params: 84 | n_hiddens: 32 85 | upsample: [4, 8, 8] 86 | z_channels: 4 87 | image_channel: 3 88 | norm_type: group 89 | 90 | data: 91 | target: main.DataModuleFromConfig 92 | params: 93 | batch_size: 2 94 | num_workers: 0 95 | wrap: false 96 | train: 97 | target: lvdm.data.frame_dataset.VideoFrameDataset 98 | params: 99 | data_root: /dockerdata/sky_timelapse 100 | resolution: 256 101 | video_length: 20 102 | dataset_name: sky 103 | subset_split: train 104 | spatial_transform: center_crop_resize 105 | clip_step: 1 106 | temporal_transform: rand_clips 107 | validation: 108 | target: lvdm.data.frame_dataset.VideoFrameDataset 109 | params: 110 | data_root: /dockerdata/sky_timelapse 111 | resolution: 256 112 | video_length: 20 113 | dataset_name: sky 114 | subset_split: test 115 | spatial_transform: center_crop_resize 116 | clip_step: 1 117 | temporal_transform: rand_clips 118 | 119 | lightning: 120 | callbacks: 121 | image_logger: 122 | target: lvdm.utils.callbacks.ImageLogger 123 | params: 124 | batch_frequency: 2000 125 | max_images: 8 126 | increase_log_steps: False 127 | metrics_over_trainsteps_checkpoint: 128 | target: pytorch_lightning.callbacks.ModelCheckpoint 129 | params: 130 | filename: "{epoch:06}-{step:09}" 131 | save_weights_only: False 132 | every_n_epochs: 200 133 | every_n_train_steps: null 134 | trainer: 135 | benchmark: True 136 | batch_size: 2 137 | num_workers: 0 138 | num_nodes: 1 139 | max_epochs: 2000 140 | modelcheckpoint: 141 | target: pytorch_lightning.callbacks.ModelCheckpoint 142 | params: 143 | every_n_epochs: 1 144 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/lvdm_long/sky_pred.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 8.0e-5 # 1.5e-04 3 | scale_lr: False 4 | target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion 5 | params: 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | monitor: val/loss_simple_ema 16 | conditioning_key: concat-adm-mask 17 | cond_stage_config: null 18 | noisy_cond: True 19 | max_noise_level: 250 20 | 
cond_stage_trainable: False 21 | concat_mode: False 22 | scale_by_std: False 23 | scale_factor: 0.33422927 24 | shift_factor: 1.4606637 25 | encoder_type: 3d 26 | rand_temporal_mask: true 27 | p_interp: 0.0 28 | p_pred: 0.5 29 | n_prevs: [1,] 30 | split_clips: False 31 | downfactor_t: null # used for split video frames to clips before encoding 32 | clip_length: null 33 | latent_frame_strde: 4 34 | 35 | unet_config: 36 | target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet 37 | params: 38 | num_classes: 251 # timesteps for noise conditoining 39 | image_size: 32 40 | in_channels: 5 41 | out_channels: 4 42 | model_channels: 256 43 | attention_resolutions: 44 | - 8 45 | - 4 46 | - 2 47 | num_res_blocks: 3 48 | channel_mult: 49 | - 1 50 | - 2 51 | - 3 52 | - 4 53 | num_heads: 4 54 | use_temporal_transformer: False 55 | use_checkpoint: true 56 | legacy: False 57 | # temporal 58 | kernel_size_t: 1 59 | padding_t: 0 60 | temporal_length: 4 61 | use_relative_position: True 62 | use_scale_shift_norm: True 63 | first_stage_config: 64 | target: lvdm.models.autoencoder3d.AutoencoderKL 65 | params: 66 | monitor: "val/rec_loss" 67 | embed_dim: 4 68 | lossconfig: __is_first_stage__ 69 | ddconfig: 70 | double_z: True 71 | z_channels: 4 72 | encoder: 73 | target: lvdm.models.modules.aemodules3d.Encoder 74 | params: 75 | n_hiddens: 32 76 | downsample: [4, 8, 8] 77 | image_channel: 3 78 | norm_type: group 79 | padding_type: replicate 80 | double_z: True 81 | z_channels: 4 82 | decoder: 83 | target: lvdm.models.modules.aemodules3d.Decoder 84 | params: 85 | n_hiddens: 32 86 | upsample: [4, 8, 8] 87 | z_channels: 4 88 | image_channel: 3 89 | norm_type: group 90 | 91 | data: 92 | target: main.DataModuleFromConfig 93 | params: 94 | batch_size: 2 95 | num_workers: 0 96 | wrap: false 97 | train: 98 | target: lvdm.data.frame_dataset.VideoFrameDataset 99 | params: 100 | data_root: /dockerdata/sky_timelapse 101 | resolution: 256 102 | video_length: 64 103 | dataset_name: sky 104 | subset_split: train 105 | spatial_transform: center_crop_resize 106 | clip_step: 1 107 | temporal_transform: rand_clips 108 | validation: 109 | target: lvdm.data.frame_dataset.VideoFrameDataset 110 | params: 111 | data_root: /dockerdata/sky_timelapse 112 | resolution: 256 113 | video_length: 64 114 | dataset_name: sky 115 | subset_split: test 116 | spatial_transform: center_crop_resize 117 | clip_step: 1 118 | temporal_transform: rand_clips 119 | 120 | lightning: 121 | callbacks: 122 | image_logger: 123 | target: lvdm.utils.callbacks.ImageLogger 124 | params: 125 | batch_frequency: 2000 126 | max_images: 8 127 | increase_log_steps: False 128 | metrics_over_trainsteps_checkpoint: 129 | target: pytorch_lightning.callbacks.ModelCheckpoint 130 | params: 131 | filename: "{epoch:06}-{step:09}" 132 | save_weights_only: False 133 | every_n_epochs: 100 134 | every_n_train_steps: null 135 | trainer: 136 | benchmark: True 137 | batch_size: 2 138 | num_workers: 0 139 | num_nodes: 1 140 | max_epochs: 2000 141 | modelcheckpoint: 142 | target: pytorch_lightning.callbacks.ModelCheckpoint 143 | params: 144 | every_n_epochs: 1 145 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/lvdm_short/sky.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 6.0e-05 # 1.5e-04 3 | scale_lr: False 4 | target: lvdm.models.ddpm3d.LatentDiffusion 5 | params: 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | log_every_t: 200 
9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: "image" 13 | image_size: 32 14 | video_length: 4 15 | channels: 4 16 | monitor: val/loss_simple_ema 17 | cond_stage_trainable: False 18 | concat_mode: False 19 | scale_by_std: False 20 | scale_factor: 0.33422927 21 | shift_factor: 1.4606637 22 | cond_stage_config: __is_unconditional__ 23 | encoder_type: 3d 24 | 25 | unet_config: 26 | target: lvdm.models.modules.openaimodel3d.UNetModel 27 | params: 28 | image_size: 32 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 3 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 4 43 | use_temporal_transformer: False 44 | use_checkpoint: true 45 | legacy: False 46 | # temporal 47 | kernel_size_t: 1 48 | padding_t: 0 49 | temporal_length: 4 50 | use_relative_position: True 51 | use_scale_shift_norm: True 52 | first_stage_config: 53 | target: lvdm.models.autoencoder3d.AutoencoderKL 54 | params: 55 | ckpt_path: ${ckpt_path} 56 | monitor: "val/rec_loss" 57 | embed_dim: 4 58 | lossconfig: __is_first_stage__ 59 | ddconfig: 60 | double_z: True 61 | z_channels: 4 62 | encoder: 63 | target: lvdm.models.modules.aemodules3d.Encoder 64 | params: 65 | n_hiddens: 32 66 | downsample: [4, 8, 8] 67 | image_channel: 3 68 | norm_type: group 69 | padding_type: replicate 70 | double_z: True 71 | z_channels: 4 72 | decoder: 73 | target: lvdm.models.modules.aemodules3d.Decoder 74 | params: 75 | n_hiddens: 32 76 | upsample: [4, 8, 8] 77 | z_channels: 4 78 | image_channel: 3 79 | norm_type: group 80 | 81 | data: 82 | target: main.DataModuleFromConfig 83 | params: 84 | batch_size: 3 85 | num_workers: 0 86 | wrap: false 87 | train: 88 | target: lvdm.data.frame_dataset.VideoFrameDataset 89 | params: 90 | data_root: ${data_root} 91 | resolution: 256 92 | video_length: 16 93 | dataset_name: sky 94 | subset_split: train 95 | spatial_transform: center_crop_resize 96 | temporal_transform: rand_clips 97 | validation: 98 | target: lvdm.data.frame_dataset.VideoFrameDataset 99 | params: 100 | data_root: ${data_root} 101 | resolution: 256 102 | video_length: 16 103 | dataset_name: sky 104 | subset_split: test 105 | spatial_transform: center_crop_resize 106 | temporal_transform: rand_clips 107 | 108 | lightning: 109 | callbacks: 110 | image_logger: 111 | target: lvdm.utils.callbacks.ImageLogger 112 | params: 113 | batch_frequency: 1000 114 | max_images: 8 115 | increase_log_steps: False 116 | metrics_over_trainsteps_checkpoint: 117 | target: pytorch_lightning.callbacks.ModelCheckpoint 118 | params: 119 | filename: "{epoch:06}-{step:09}" 120 | save_weights_only: False 121 | every_n_epochs: 300 122 | every_n_train_steps: null 123 | trainer: 124 | benchmark: True 125 | batch_size: 3 126 | num_workers: 0 127 | num_nodes: 4 128 | accumulate_grad_batches: 2 129 | max_epochs: 2000 130 | modelcheckpoint: 131 | target: pytorch_lightning.callbacks.ModelCheckpoint 132 | params: 133 | every_n_epochs: 1 134 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/lvdm_short/taichi.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 6.0e-05 # 1.5e-04 3 | scale_lr: False 4 | target: lvdm.models.ddpm3d.LatentDiffusion 5 | params: 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: video 12 | 
cond_stage_key: "image" 13 | image_size: 32 14 | video_length: 4 15 | channels: 4 16 | monitor: val/loss_simple_ema 17 | # use_ema: False # default =True 18 | cond_stage_trainable: False 19 | concat_mode: False 20 | scale_by_std: False 21 | scale_factor: 0.175733933 22 | shift_factor: 0.03291025 23 | cond_stage_config: __is_unconditional__ 24 | encoder_type: 3d 25 | 26 | unet_config: 27 | target: lvdm.models.modules.openaimodel3d.UNetModel 28 | params: 29 | image_size: 32 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 256 33 | attention_resolutions: 34 | - 8 35 | - 4 36 | - 2 37 | num_res_blocks: 3 38 | channel_mult: 39 | - 1 40 | - 2 41 | - 3 42 | - 4 43 | num_heads: 4 44 | use_temporal_transformer: False 45 | use_checkpoint: true 46 | legacy: False 47 | # temporal 48 | kernel_size_t: 1 49 | padding_t: 0 50 | temporal_length: 4 51 | use_relative_position: True 52 | use_scale_shift_norm: True 53 | first_stage_config: 54 | target: lvdm.models.autoencoder3d.AutoencoderKL 55 | params: 56 | ckpt_path: ${ckpt_path} 57 | monitor: "val/rec_loss" 58 | embed_dim: 4 59 | lossconfig: __is_first_stage__ 60 | ddconfig: 61 | double_z: True 62 | z_channels: 4 63 | encoder: 64 | target: lvdm.models.modules.aemodules3d.Encoder 65 | params: 66 | n_hiddens: 32 67 | downsample: [4, 8, 8] 68 | image_channel: 3 69 | norm_type: group 70 | padding_type: replicate 71 | double_z: True 72 | z_channels: 4 73 | decoder: 74 | target: lvdm.models.modules.aemodules3d.Decoder 75 | params: 76 | n_hiddens: 32 77 | upsample: [4, 8, 8] 78 | z_channels: 4 79 | image_channel: 3 80 | norm_type: group 81 | 82 | data: 83 | target: main.DataModuleFromConfig 84 | params: 85 | batch_size: 3 86 | num_workers: 0 87 | wrap: false 88 | train: 89 | target: lvdm.data.taichi.Taichi 90 | params: 91 | data_root: data_root 92 | resolution: 256 93 | video_length: 16 94 | subset_split: all 95 | frame_stride: 4 96 | validation: 97 | target: lvdm.data.taichi.Taichi 98 | params: 99 | data_root: data_root 100 | resolution: 256 101 | video_length: 16 102 | subset_split: test 103 | frame_stride: 4 104 | 105 | lightning: 106 | callbacks: 107 | image_logger: 108 | target: lvdm.utils.callbacks.ImageLogger 109 | params: 110 | batch_frequency: 1000 111 | max_images: 8 112 | increase_log_steps: False 113 | metrics_over_trainsteps_checkpoint: 114 | target: pytorch_lightning.callbacks.ModelCheckpoint 115 | params: 116 | filename: "{epoch:06}-{step:09}" 117 | save_weights_only: False 118 | every_n_epochs: 300 119 | every_n_train_steps: null 120 | trainer: 121 | benchmark: True 122 | batch_size: 3 123 | num_workers: 0 124 | num_nodes: 4 125 | accumulate_grad_batches: 2 126 | max_epochs: 4000 127 | modelcheckpoint: 128 | target: pytorch_lightning.callbacks.ModelCheckpoint 129 | params: 130 | every_n_epochs: 1 131 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/lvdm_short/text2video.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: video 10 | cond_stage_key: caption 11 | image_size: 12 | - 32 13 | - 32 14 | video_length: 16 15 | channels: 4 16 | cond_stage_trainable: false 17 | conditioning_key: crossattn 18 | scale_by_std: false 19 | scale_factor: 0.18215 20 | use_ema: false 21 | 22 | unet_config: 23 | target: 
lvdm.models.modules.openaimodel3d.UNetModel 24 | params: 25 | image_size: 32 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: 30 | - 4 31 | - 2 32 | - 1 33 | num_res_blocks: 2 34 | channel_mult: 35 | - 1 36 | - 2 37 | - 4 38 | - 4 39 | num_heads: 8 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | legacy: false 44 | kernel_size_t: 1 45 | padding_t: 0 46 | use_temporal_transformer: true 47 | temporal_length: 16 48 | use_relative_position: true 49 | 50 | first_stage_config: 51 | target: lvdm.models.autoencoder.AutoencoderKL 52 | params: 53 | embed_dim: 4 54 | monitor: val/rec_loss 55 | ddconfig: 56 | double_z: true 57 | z_channels: 4 58 | resolution: 256 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: 63 | - 1 64 | - 2 65 | - 4 66 | - 4 67 | num_res_blocks: 2 68 | attn_resolutions: [] 69 | dropout: 0.0 70 | lossconfig: 71 | target: torch.nn.Identity 72 | 73 | cond_stage_config: 74 | target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder -------------------------------------------------------------------------------- /configs/lvdm_short/ucf.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 6.0e-05 # 1.5e-04 3 | scale_lr: False 4 | target: lvdm.models.ddpm3d.LatentDiffusion 5 | params: 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: "image" 13 | image_size: 32 14 | video_length: 4 15 | channels: 4 16 | monitor: val/loss_simple_ema 17 | cond_stage_trainable: False 18 | concat_mode: False 19 | scale_by_std: False 20 | scale_factor: 0.220142075 21 | shift_factor: 0.5837740898 22 | cond_stage_config: __is_unconditional__ 23 | encoder_type: 3d 24 | 25 | unet_config: 26 | target: lvdm.models.modules.openaimodel3d.UNetModel 27 | params: 28 | image_size: 32 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 3 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 4 43 | use_temporal_transformer: False 44 | use_checkpoint: true 45 | legacy: False 46 | # temporal 47 | kernel_size_t: 1 48 | padding_t: 0 49 | temporal_length: 4 50 | use_relative_position: True 51 | use_scale_shift_norm: True 52 | first_stage_config: 53 | target: lvdm.models.autoencoder3d.AutoencoderKL 54 | params: 55 | ckpt_path: ${ckpt_path} 56 | monitor: "val/rec_loss" 57 | embed_dim: 4 58 | lossconfig: __is_first_stage__ 59 | ddconfig: 60 | double_z: True 61 | z_channels: 4 62 | encoder: 63 | target: lvdm.models.modules.aemodules3d.Encoder 64 | params: 65 | n_hiddens: 32 66 | downsample: [4, 8, 8] 67 | image_channel: 3 68 | norm_type: group 69 | padding_type: replicate 70 | double_z: True 71 | z_channels: 4 72 | decoder: 73 | target: lvdm.models.modules.aemodules3d.Decoder 74 | params: 75 | n_hiddens: 32 76 | upsample: [4, 8, 8] 77 | z_channels: 4 78 | image_channel: 3 79 | norm_type: group 80 | 81 | data: 82 | target: main.DataModuleFromConfig 83 | params: 84 | batch_size: 2 85 | num_workers: 0 86 | wrap: false 87 | train: 88 | target: lvdm.data.frame_dataset.VideoFrameDataset 89 | params: 90 | data_root: ${data_root} 91 | resolution: 256 92 | video_length: 16 93 | dataset_name: UCF-101 94 | subset_split: all 95 | spatial_transform: center_crop_resize 96 | clip_step: 1 97 | temporal_transform: rand_clips 98 | validation: 99 | target: 
lvdm.data.frame_dataset.VideoFrameDataset 100 | params: 101 | data_root: ${data_root} 102 | resolution: 256 103 | video_length: 16 104 | dataset_name: UCF-101 105 | subset_split: test 106 | spatial_transform: center_crop_resize 107 | clip_step: 1 108 | temporal_transform: rand_clips 109 | 110 | lightning: 111 | callbacks: 112 | image_logger: 113 | target: lvdm.utils.callbacks.ImageLogger 114 | params: 115 | batch_frequency: 1000 116 | max_images: 8 117 | increase_log_steps: False 118 | metrics_over_trainsteps_checkpoint: 119 | target: pytorch_lightning.callbacks.ModelCheckpoint 120 | params: 121 | filename: "{epoch:06}-{step:09}" 122 | save_weights_only: False 123 | every_n_epochs: 300 124 | every_n_train_steps: null 125 | trainer: 126 | benchmark: True 127 | batch_size: 2 128 | num_workers: 0 129 | num_nodes: 4 130 | accumulate_grad_batches: 2 131 | max_epochs: 2500 #2000 132 | modelcheckpoint: 133 | target: pytorch_lightning.callbacks.ModelCheckpoint 134 | params: 135 | every_n_epochs: 1 136 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/videoae/sky.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.60e-05 3 | scale_lr: False 4 | target: lvdm.models.autoencoder3d.AutoencoderKL 5 | params: 6 | monitor: "val/rec_loss" 7 | embed_dim: 4 8 | lossconfig: 9 | target: lvdm.models.losses.LPIPSWithDiscriminator 10 | params: 11 | disc_start: 50001 12 | kl_weight: 0.0 13 | disc_weight: 0.5 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | encoder: 18 | target: lvdm.models.modules.aemodules3d.Encoder 19 | params: 20 | n_hiddens: 32 21 | downsample: [4, 8, 8] 22 | image_channel: 3 23 | norm_type: group 24 | padding_type: replicate 25 | double_z: True 26 | z_channels: 4 27 | 28 | decoder: 29 | target: lvdm.models.modules.aemodules3d.Decoder 30 | params: 31 | n_hiddens: 32 32 | upsample: [4, 8, 8] 33 | z_channels: 4 34 | image_channel: 3 35 | norm_type: group 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 1 41 | num_workers: 0 42 | wrap: false 43 | train: 44 | target: lvdm.data.frame_dataset.VideoFrameDataset 45 | params: 46 | data_root: ${data_root} 47 | resolution: 256 48 | video_length: 16 49 | dataset_name: sky 50 | subset_split: train 51 | spatial_transform: center_crop_resize 52 | clip_step: 1 53 | temporal_transform: rand_clips 54 | validation: 55 | target: lvdm.data.frame_dataset.VideoFrameDataset 56 | params: 57 | data_root: ${data_root} 58 | resolution: 256 59 | video_length: 16 60 | dataset_name: sky 61 | subset_split: test 62 | spatial_transform: center_crop_resize 63 | clip_step: 1 64 | temporal_transform: rand_clips 65 | lightning: 66 | find_unused_parameters: True 67 | callbacks: 68 | image_logger: 69 | target: lvdm.utils.callbacks.ImageLogger 70 | params: 71 | batch_frequency: 300 72 | max_images: 8 73 | increase_log_steps: False 74 | log_to_tblogger: False 75 | metrics_over_trainsteps_checkpoint: 76 | target: pytorch_lightning.callbacks.ModelCheckpoint 77 | params: 78 | filename: "{epoch:06}-{step:09}" 79 | save_weights_only: False 80 | every_n_epochs: 100 81 | every_n_train_steps: null 82 | trainer: 83 | benchmark: True 84 | accumulate_grad_batches: 2 85 | batch_size: 1 86 | num_workers: 0 87 | max_epochs: 1000 88 | modelcheckpoint: 89 | target: pytorch_lightning.callbacks.ModelCheckpoint 90 | params: 91 | every_n_epochs: 1 92 | filename: "{epoch:04}-{step:06}" 93 | 
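
Each `target:`/`params:` block in these configs follows the latent-diffusion convention: import the class named by the dotted path and construct it with the listed keyword arguments. The repository defines its own helper for this (used from `main.py`); the stand-in below is only a generic sketch of the pattern, not the project's implementation.

```python
# Generic sketch of the target/params instantiation convention used by these YAMLs.
import importlib
from omegaconf import OmegaConf

def get_obj_from_str(path: str):
    module, cls = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

cfg = OmegaConf.load("configs/videoae/sky.yaml")
autoencoder = instantiate_from_config(cfg.model)  # -> lvdm.models.autoencoder3d.AutoencoderKL
```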
-------------------------------------------------------------------------------- /configs/videoae/taichi.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.6e-05 # 8.0e-05 3 | target: lvdm.models.autoencoder3d.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | image_key: video 8 | lossconfig: 9 | target: lvdm.models.losses.LPIPSWithDiscriminator 10 | params: 11 | disc_start: 50001 12 | kl_weight: 0.0 13 | disc_weight: 0.5 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | encoder: 18 | target: lvdm.models.modules.aemodules3d.Encoder 19 | params: 20 | n_hiddens: 32 21 | downsample: [4, 8, 8] 22 | image_channel: 3 23 | norm_type: group 24 | padding_type: replicate 25 | double_z: True 26 | z_channels: 4 27 | 28 | decoder: 29 | target: lvdm.models.modules.aemodules3d.Decoder 30 | params: 31 | n_hiddens: 32 32 | upsample: [4, 8, 8] 33 | z_channels: 4 34 | image_channel: 3 35 | norm_type: group 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 1 41 | num_workers: 0 42 | wrap: false 43 | train: 44 | target: lvdm.data.taichi.Taichi 45 | params: 46 | data_dir: /apdcephfs_cq2/share_1290939/yingqinghe/taichi 47 | resolution: 256 48 | video_length: 16 49 | subset_split: all 50 | frame_stride: 4 51 | validation: 52 | target: lvdm.data.taichi.Taichi 53 | params: 54 | data_dir: /apdcephfs_cq2/share_1290939/yingqinghe/taichi 55 | resolution: 256 56 | video_length: 16 57 | subset_split: test 58 | frame_stride: 4 59 | lightning: 60 | find_unused_parameters: True 61 | callbacks: 62 | image_logger: 63 | target: lvdm.utils.callbacks.ImageLogger 64 | params: 65 | batch_frequency: 300 66 | max_images: 8 67 | increase_log_steps: False 68 | log_to_tblogger: False 69 | metrics_over_trainsteps_checkpoint: 70 | target: pytorch_lightning.callbacks.ModelCheckpoint 71 | params: 72 | filename: "{epoch:06}-{step:09}" 73 | save_weights_only: False 74 | every_n_epochs: 100 75 | every_n_train_steps: null 76 | trainer: 77 | benchmark: True 78 | accumulate_grad_batches: 2 79 | batch_size: 1 80 | num_workers: 0 81 | modelcheckpoint: 82 | target: pytorch_lightning.callbacks.ModelCheckpoint 83 | params: 84 | every_n_epochs: 1 85 | filename: "{epoch:04}-{step:06}" 86 | -------------------------------------------------------------------------------- /configs/videoae/ucf.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.60e-05 3 | scale_lr: False 4 | target: lvdm.models.autoencoder3d.AutoencoderKL 5 | params: 6 | monitor: "val/rec_loss" 7 | embed_dim: 4 8 | lossconfig: 9 | target: lvdm.models.losses.LPIPSWithDiscriminator 10 | params: 11 | disc_start: 50001 12 | kl_weight: 0.0 13 | disc_weight: 0.5 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | encoder: 18 | target: lvdm.models.modules.aemodules3d.Encoder 19 | params: 20 | n_hiddens: 32 21 | downsample: [4, 8, 8] 22 | image_channel: 3 23 | norm_type: group 24 | padding_type: replicate 25 | double_z: True 26 | z_channels: 4 27 | 28 | decoder: 29 | target: lvdm.models.modules.aemodules3d.Decoder 30 | params: 31 | n_hiddens: 32 32 | upsample: [4, 8, 8] 33 | z_channels: 4 34 | image_channel: 3 35 | norm_type: group 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 1 41 | num_workers: 0 42 | wrap: false 43 | train: 44 | target: lvdm.data.frame_dataset.VideoFrameDataset 45 | params: 46 | data_root: ${data_root} 47 | resolution: 256 
48 | video_length: 16 49 | dataset_name: UCF-101 50 | subset_split: train 51 | spatial_transform: center_crop_resize 52 | clip_step: 1 53 | temporal_transform: rand_clips 54 | validation: 55 | target: lvdm.data.frame_dataset.VideoFrameDataset 56 | params: 57 | data_root: ${data_root} 58 | resolution: 256 59 | video_length: 16 60 | dataset_name: UCF-101 61 | subset_split: test 62 | spatial_transform: center_crop_resize 63 | clip_step: 1 64 | temporal_transform: rand_clips 65 | lightning: 66 | find_unused_parameters: True 67 | callbacks: 68 | image_logger: 69 | target: lvdm.utils.callbacks.ImageLogger 70 | params: 71 | batch_frequency: 1000 72 | max_images: 8 73 | increase_log_steps: False 74 | log_to_tblogger: False 75 | trainer: 76 | benchmark: True 77 | accumulate_grad_batches: 2 78 | batch_size: 1 79 | num_workers: 0 80 | modelcheckpoint: 81 | target: pytorch_lightning.callbacks.ModelCheckpoint 82 | params: 83 | every_n_epochs: 1 84 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/videoae/ucf_videodata.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.60e-05 3 | scale_lr: False 4 | target: lvdm.models.autoencoder3d.AutoencoderKL 5 | params: 6 | image_key: video 7 | monitor: "val/rec_loss" 8 | embed_dim: 4 9 | lossconfig: 10 | target: lvdm.models.losses.LPIPSWithDiscriminator 11 | params: 12 | disc_start: 50001 13 | kl_weight: 0.0 14 | disc_weight: 0.5 15 | ddconfig: 16 | double_z: True 17 | z_channels: 4 18 | encoder: 19 | target: lvdm.models.modules.aemodules3d.Encoder 20 | params: 21 | n_hiddens: 32 22 | downsample: [4, 8, 8] 23 | image_channel: 3 24 | norm_type: group 25 | padding_type: replicate 26 | double_z: True 27 | z_channels: 4 28 | 29 | decoder: 30 | target: lvdm.models.modules.aemodules3d.Decoder 31 | params: 32 | n_hiddens: 32 33 | upsample: [4, 8, 8] 34 | z_channels: 4 35 | image_channel: 3 36 | norm_type: group 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 1 41 | num_workers: 0 42 | wrap: false 43 | train: 44 | target: lvdm.data.ucf.UCF101 45 | params: 46 | data_root: ${data_root} 47 | resolution: 256 48 | video_length: 16 49 | subset_split: all 50 | frame_stride: 1 51 | validation: 52 | target: lvdm.data.ucf.UCF101 53 | params: 54 | data_root: ${data_root} 55 | resolution: 256 56 | video_length: 16 57 | subset_split: test 58 | frame_stride: 1 59 | 60 | lightning: 61 | find_unused_parameters: True 62 | callbacks: 63 | image_logger: 64 | target: lvdm.utils.callbacks.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: False 69 | log_to_tblogger: False 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 2 73 | batch_size: 1 74 | num_workers: 0 75 | modelcheckpoint: 76 | target: pytorch_lightning.callbacks.ModelCheckpoint 77 | params: 78 | every_n_epochs: 1 79 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /input/prompts.txt: -------------------------------------------------------------------------------- 1 | astronaut riding a horse 2 | Flying through an intense battle between pirate ships in a stormy ocean -------------------------------------------------------------------------------- /lvdm/data/frame_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | from PIL import ImageFile 5 
| from PIL import Image 6 | 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | import torchvision.transforms._transforms_video as transforms_video 11 | 12 | """ VideoFrameDataset """ 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | IMG_EXTENSIONS = [ 16 | '.jpg', '.JPG', '.jpeg', '.JPEG', 17 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 18 | ] 19 | 20 | 21 | def pil_loader(path): 22 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 23 | ''' 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | ''' 28 | Im = Image.open(path) 29 | return Im.convert('RGB') 30 | 31 | def accimage_loader(path): 32 | import accimage 33 | try: 34 | return accimage.Image(path) 35 | except IOError: 36 | # Potentially a decoding problem, fall back to PIL.Image 37 | return pil_loader(path) 38 | 39 | def default_loader(path): 40 | ''' 41 | from torchvision import get_image_backend 42 | if get_image_backend() == 'accimage': 43 | return accimage_loader(path) 44 | else: 45 | ''' 46 | return pil_loader(path) 47 | 48 | def is_image_file(filename): 49 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 50 | 51 | def find_classes(dir): 52 | assert(os.path.exists(dir)), f'{dir} does not exist' 53 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 54 | classes.sort() 55 | class_to_idx = {classes[i]: i for i in range(len(classes))} 56 | return classes, class_to_idx 57 | 58 | def class_name_to_idx(annotation_dir): 59 | """ 60 | return class indices from 0 ~ num_classes-1 61 | """ 62 | fpath = os.path.join(annotation_dir, "classInd.txt") 63 | with open(fpath, "r") as f: 64 | data = f.readlines() 65 | class_to_idx = {x.strip().split(" ")[1].lower():int(x.strip().split(" ")[0]) - 1 for x in data} 66 | return class_to_idx 67 | 68 | def make_dataset(dir, nframes, class_to_idx, frame_stride=1, **kwargs): 69 | """ 70 | videos are saved in second-level directory: 71 | dir: video dir. Format: 72 | videoxxx 73 | videoxxx_1 74 | frame1.jpg 75 | frame2.jpg 76 | videoxxx_2 77 | frame1.jpg 78 | ... 79 | videoxxx 80 | 81 | nframes: num of frames of every video clips 82 | class_to_idx: for mapping video name to video id 83 | """ 84 | if frame_stride != 1: 85 | raise NotImplementedError 86 | 87 | clips = [] 88 | videos = [] 89 | n_clip = 0 90 | video_frames = [] 91 | for video_name in sorted(os.listdir(dir)): 92 | if os.path.isdir(os.path.join(dir,video_name)): 93 | 94 | # eg: dir + '/rM7aPu9WV2Q' 95 | subfolder_path = os.path.join(dir, video_name) # video_name: rM7aPu9WV2Q 96 | for subsubfold in sorted(os.listdir(subfolder_path)): 97 | subsubfolder_path = os.path.join(subfolder_path, subsubfold) 98 | if os.path.isdir(subsubfolder_path): # eg: dir/rM7aPu9WV2Q/1' 99 | clip_frames = [] 100 | i = 1 101 | # traverse frames in one video 102 | for fname in sorted(os.listdir(subsubfolder_path)): 103 | if is_image_file(fname): 104 | img_path = os.path.join(subsubfolder_path, fname) # eg: dir + '/rM7aPu9WV2Q/rM7aPu9WV2Q_1/rM7aPu9WV2Q_frames_00086552.jpg' 105 | frame_info = (img_path, class_to_idx[video_name]) #(img_path, video_id) 106 | clip_frames.append(frame_info) 107 | video_frames.append(frame_info) 108 | 109 | # append clips, clip_step=n_frames (no frame overlap between clips). 
110 | if i % nframes == 0 and i >0: 111 | clips.append(clip_frames) 112 | n_clip += 1 113 | clip_frames = [] 114 | i = i+1 115 | 116 | if len(video_frames) >= nframes: 117 | videos.append(video_frames) 118 | video_frames = [] 119 | 120 | print('number of long videos:', len(videos)) 121 | print('number of short videos', len(clips)) 122 | return clips, videos 123 | 124 | def split_by_captical(s): 125 | s_list = re.sub( r"([A-Z])", r" \1", s).split() 126 | string = "" 127 | for s in s_list: 128 | string += s + " " 129 | return string.rstrip(" ").lower() 130 | 131 | def make_dataset_ucf(dir, nframes, class_to_idx, frame_stride=1, clip_step=None): 132 | """ 133 | Load consecutive clips and consecutive frames from `dir`. 134 | 135 | args: 136 | nframes: num of frames of every video clips 137 | class_to_idx: for mapping video name to video id 138 | frame_stride: select frames with a stride. 139 | clip_step: select clips with a step. if clip_step< nframes, 140 | there will be overlapped frames among two consecutive clips. 141 | 142 | assert videos are saved in first-level directory: 143 | dir: 144 | videoxxx1 145 | frame1.jpg 146 | frame2.jpg 147 | videoxxx2 148 | """ 149 | if clip_step is None: 150 | # consecutive clips with no frame overlap 151 | clip_step = nframes 152 | # make videos 153 | clips = [] # 2d list 154 | videos = [] # 2d list 155 | for video_name in sorted(os.listdir(dir)): 156 | if video_name != '_broken_clips': 157 | video_path = os.path.join(dir, video_name) 158 | assert(os.path.isdir(video_path)) 159 | 160 | frames = [] 161 | for i, fname in enumerate(sorted(os.listdir(video_path))): 162 | assert(is_image_file(fname)),f'fname={fname},video_path={video_path},dir={dir}' 163 | 164 | # get frame info 165 | img_path = os.path.join(video_path, fname) 166 | class_name = video_name.split("_")[1].lower() # v_BoxingSpeedBag_g12_c05 -> boxingspeedbag 167 | class_caption = split_by_captical(video_name.split("_")[1]) # v_BoxingSpeedBag_g12_c05 -> BoxingSpeedBag -> boxing speed bag 168 | frame_info = { 169 | "img_path": img_path, 170 | "class_index": class_to_idx[class_name], 171 | "class_name": class_name, #boxingspeedbag 172 | "class_caption": class_caption #boxing speed bag 173 | } 174 | frames.append(frame_info) 175 | frames = frames[::frame_stride] 176 | 177 | # make videos 178 | if len(frames) >= nframes: 179 | videos.append(frames) 180 | 181 | # make clips 182 | start_indices = list(range(len(frames)))[::clip_step] 183 | for i in start_indices: 184 | clip = frames[i:i+nframes] 185 | if len(clip) == nframes: 186 | clips.append(clip) 187 | return clips, videos 188 | 189 | def load_and_transform_frames(frame_list, loader, img_transform=None): 190 | assert(isinstance(frame_list, list)) 191 | clip = [] 192 | labels = [] 193 | for frame in frame_list: 194 | 195 | if isinstance(frame, tuple): 196 | fpath, label = frame 197 | elif isinstance(frame, dict): 198 | fpath = frame["img_path"] 199 | label = { 200 | "class_index": frame["class_index"], 201 | "class_name": frame["class_name"], 202 | "class_caption": frame["class_caption"], 203 | } 204 | 205 | labels.append(label) 206 | img = loader(fpath) 207 | if img_transform is not None: 208 | img = img_transform(img) 209 | img = img.view(img.size(0),1, img.size(1), img.size(2)) 210 | clip.append(img) 211 | return clip, labels[0] # all frames have same label.. 
212 | 213 | class VideoFrameDataset(data.Dataset): 214 | def __init__(self, 215 | data_root, 216 | resolution, 217 | video_length, # clip length 218 | dataset_name="", 219 | subset_split="", 220 | annotation_dir=None, 221 | spatial_transform="", 222 | temporal_transform="", 223 | frame_stride=1, 224 | clip_step=None, 225 | ): 226 | 227 | self.loader = default_loader 228 | self.video_length = video_length 229 | self.subset_split = subset_split 230 | self.temporal_transform = temporal_transform 231 | self.spatial_transform = spatial_transform 232 | self.frame_stride = frame_stride 233 | self.dataset_name = dataset_name 234 | 235 | assert(subset_split in ["train", "test", "all", ""]) # "" means no subset_split directory. 236 | assert(self.temporal_transform in ["", "rand_clips"]) 237 | 238 | if subset_split == 'all': 239 | video_dir = os.path.join(data_root, "train") 240 | else: 241 | video_dir = os.path.join(data_root, subset_split) 242 | 243 | if dataset_name == 'UCF-101': 244 | if annotation_dir is None: 245 | annotation_dir = os.path.join(data_root, "ucfTrainTestlist") 246 | class_to_idx = class_name_to_idx(annotation_dir) 247 | assert(len(class_to_idx) == 101), f'num of classes = {len(class_to_idx)}, not 101' 248 | elif dataset_name == 'sky': 249 | classes, class_to_idx = find_classes(video_dir) 250 | else: 251 | class_to_idx = None 252 | 253 | # make dataset 254 | if dataset_name == 'UCF-101': 255 | func = make_dataset_ucf 256 | else: 257 | func = make_dataset 258 | self.clips, self.videos = func(video_dir, video_length, class_to_idx, frame_stride=frame_stride, clip_step=clip_step) 259 | assert(len(self.clips[0]) == video_length), f"Invalid clip length = {len(self.clips[0])}" 260 | if self.temporal_transform == 'rand_clips': 261 | self.clips = self.videos 262 | 263 | if subset_split == 'all': 264 | # add test videos 265 | video_dir = video_dir.rstrip('/train')+'/test' 266 | cs, vs = func(video_dir, video_length, class_to_idx) 267 | if self.temporal_transform == 'rand_clips': 268 | self.clips += vs 269 | else: 270 | self.clips += cs 271 | 272 | print('[VideoFrameDataset] number of videos:', len(self.videos)) 273 | print('[VideoFrameDataset] number of clips', len(self.clips)) 274 | 275 | # check data 276 | if len(self.clips) == 0: 277 | raise(RuntimeError(f"Found 0 clips in {video_dir}. 
\n" 278 | "Supported image extensions are: " + 279 | ",".join(IMG_EXTENSIONS))) 280 | 281 | # data transform 282 | self.img_transform = transforms.Compose([ 283 | transforms.ToTensor(), 284 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 285 | ]) 286 | if self.spatial_transform == "center_crop_resize": 287 | print('Spatial transform: center crop and then resize') 288 | self.video_transform = transforms.Compose([ 289 | transforms.Resize(resolution), 290 | transforms_video.CenterCropVideo(resolution), 291 | ]) 292 | elif self.spatial_transform == "resize": 293 | print('Spatial transform: resize with no crop') 294 | self.video_transform = transforms.Resize((resolution, resolution)) 295 | elif self.spatial_transform == "random_crop": 296 | self.video_transform = transforms.Compose([ 297 | transforms_video.RandomCropVideo(resolution), 298 | ]) 299 | elif self.spatial_transform == "": 300 | self.video_transform = None 301 | else: 302 | raise NotImplementedError 303 | 304 | def __getitem__(self, index): 305 | # get clip info 306 | if self.temporal_transform == 'rand_clips': 307 | raw_video = self.clips[index] 308 | rand_idx = random.randint(0, len(raw_video) - self.video_length) 309 | clip = raw_video[rand_idx:rand_idx+self.video_length] 310 | else: 311 | clip = self.clips[index] 312 | assert(len(clip) == self.video_length), f'current clip_length={len(clip)}, target clip_length={self.video_length}, {clip}' 313 | 314 | # make clip tensor 315 | frames, labels = load_and_transform_frames(clip, self.loader, self.img_transform) 316 | assert(len(frames) == self.video_length), f'current clip_length={len(frames)}, target clip_length={self.video_length}, {clip}' 317 | frames = torch.cat(frames, 1) # c,t,h,w 318 | if self.video_transform is not None: 319 | frames = self.video_transform(frames) 320 | 321 | example = dict() 322 | example["image"] = frames 323 | if labels is not None and self.dataset_name == 'UCF-101': 324 | example["caption"] = labels["class_caption"] 325 | example["class_label"] = labels["class_index"] 326 | example["class_name"] = labels["class_name"] 327 | example["frame_stride"] = self.frame_stride 328 | return example 329 | 330 | def __len__(self): 331 | return len(self.clips) 332 | -------------------------------------------------------------------------------- /lvdm/data/split_ucf101.py: -------------------------------------------------------------------------------- 1 | # Split the UCF-101 official dataset to train and test splits 2 | # The output data formate: 3 | # UCF-101/ 4 | # ├── train/ 5 | # │ ├── ApplyEyeMakeup/ 6 | # │ │ ├── v_ApplyEyeMakeup_g08_c01.avi 7 | # │ │ ├── v_ApplyEyeMakeup_g08_c02.avi 8 | # │ │ ├── ... 9 | # │ ├── ApplyLipstick/ 10 | # │ │ ├── v_ApplyLipstick_g01_c01.avi 11 | # ├── test/ 12 | # │ ├── ApplyEyeMakeup/ 13 | # │ │ ├── v_ApplyEyeMakeup_g01_c01.avi 14 | # │ │ ├── v_ApplyEyeMakeup_g01_c02.avi 15 | # │ │ ├── ... 16 | # │ ├── ApplyLipstick/ 17 | # │ │ ├── v_ApplyLipstick_g01_c01.avi 18 | # ├── ucfTrainTestlist/ 19 | # │ ├── classInd.txt 20 | # │ ├── testlist01.txt 21 | # │ ├── trainlist01.txt 22 | # │ ├── ... 
23 | 24 | input_dir = "temp/UCF-101" # the root directory of the UCF-101 dataset 25 | input_annotation = "temp/ucfTrainTestlist" # the annotation file of the UCF-101 dataset 26 | output_dir_tmp = f"{input_dir}_split" # the temporary directory to store the split dataset 27 | 28 | remove_original_dir = False # The output directory will be created in the same directory as the input directory 29 | 30 | import os 31 | import random 32 | import shutil 33 | 34 | split_idx = 1 # the split index 35 | # read the annotation file 36 | 37 | # make the train and test directories 38 | os.makedirs(os.path.join(output_dir_tmp, f"train"), exist_ok=True) 39 | os.makedirs(os.path.join(output_dir_tmp, f"test"), exist_ok=True) 40 | 41 | def extract_split(subset="train"): 42 | if subset not in ["train", "test"]: 43 | raise ValueError("subset must be either 'train' or 'test'") 44 | 45 | with open(os.path.join(input_annotation, f"{subset}list0{split_idx}.txt")) as f: 46 | train_list = f.readlines() 47 | train_list = [x.strip() for x in train_list] 48 | for item in train_list: 49 | if subset == "test": 50 | class_name = item.split("/")[0] 51 | video_name = item.split("/")[1] 52 | elif subset == "train": 53 | class_name = item.split("/")[0] 54 | video_name = item.split(" ")[0].split("/")[1] 55 | video_path = os.path.join(input_dir, class_name, video_name) 56 | print(f"input_dir: {input_dir}, class_name: {class_name}, video_name: {video_name}") 57 | 58 | class_dir = os.path.join(output_dir_tmp, f"{subset}", class_name) 59 | os.makedirs(class_dir, exist_ok=True) 60 | shutil.copy(video_path, class_dir) 61 | print(f"Copy {video_path} to {class_dir}") 62 | 63 | # split the dataset into the output directory 64 | extract_split(subset="train") 65 | extract_split(subset="test") 66 | 67 | # copy the annotation files to the output directory 68 | shutil.copytree(input_annotation, os.path.join(output_dir_tmp, "ucfTrainTestlist")) 69 | 70 | if remove_original_dir: 71 | shutil.rmtree(input_dir) 72 | shutil.move(output_dir_tmp, input_dir) 73 | -------------------------------------------------------------------------------- /lvdm/data/taichi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | from decord import VideoReader, cpu 6 | import glob 7 | 8 | class Taichi(Dataset): 9 | """ 10 | Taichi Dataset. 11 | Assumes data is structured as follows. 12 | Taichi/ 13 | train/ 14 | xxx.mp4 15 | ... 16 | test/ 17 | xxx.mp4 18 | ... 
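    Each item is a dict {'video': frames} where `frames` holds `video_length`
    consecutive frames decoded with decord, resized to `resolution`, scaled
    to [-1, 1], and arranged as [c, t, h, w].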
19 | """ 20 | def __init__(self, 21 | data_root, 22 | resolution, 23 | video_length, 24 | subset_split, 25 | frame_stride, 26 | ): 27 | self.data_root = data_root 28 | self.resolution = resolution 29 | self.video_length = video_length 30 | self.subset_split = subset_split 31 | self.frame_stride = frame_stride 32 | assert(self.subset_split in ['train', 'test', 'all']) 33 | self.exts = ['avi', 'mp4', 'webm'] 34 | 35 | if isinstance(self.resolution, int): 36 | self.resolution = [self.resolution, self.resolution] 37 | assert(isinstance(self.resolution, list) and len(self.resolution) == 2) 38 | 39 | self._make_dataset() 40 | 41 | def _make_dataset(self): 42 | if self.subset_split == 'all': 43 | data_folder = self.data_root 44 | else: 45 | data_folder = os.path.join(self.data_root, self.subset_split) 46 | self.videos = sum([glob.glob(os.path.join(data_folder, '**', f'*.{ext}'), recursive=True) 47 | for ext in self.exts], []) 48 | print(f'Number of videos = {len(self.videos)}') 49 | 50 | def __getitem__(self, index): 51 | while True: 52 | video_path = self.videos[index] 53 | 54 | try: 55 | video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0]) 56 | if len(video_reader) < self.video_length: 57 | index += 1 58 | continue 59 | else: 60 | break 61 | except: 62 | index += 1 63 | print(f"Load video failed! path = {video_path}") 64 | 65 | all_frames = list(range(0, len(video_reader), self.frame_stride)) 66 | if len(all_frames) < self.video_length: 67 | all_frames = list(range(0, len(video_reader), 1)) 68 | 69 | # select random clip 70 | rand_idx = random.randint(0, len(all_frames) - self.video_length) 71 | frame_indices = list(range(rand_idx, rand_idx+self.video_length)) 72 | frames = video_reader.get_batch(frame_indices) 73 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}' 74 | 75 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] 76 | assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}' 77 | frames = (frames / 255 - 0.5) * 2 78 | data = {'video': frames} 79 | return data 80 | 81 | def __len__(self): 82 | return len(self.videos) -------------------------------------------------------------------------------- /lvdm/data/ucf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset 7 | from decord import VideoReader, cpu 8 | 9 | def split_by_captical(s): 10 | s_list = re.sub( r"([A-Z])", r" \1", s).split() 11 | string = "" 12 | for s in s_list: 13 | string += s + " " 14 | return string.rstrip(" ").lower() 15 | 16 | def sample_strided_frames(vid_len, frame_stride, target_vid_len): 17 | frame_indices = list(range(0, vid_len, frame_stride)) 18 | if len(frame_indices) < target_vid_len: 19 | frame_stride = vid_len // target_vid_len # recalculate a max fs 20 | assert(frame_stride != 0) 21 | frame_indices = list(range(0, vid_len, frame_stride)) 22 | return frame_indices, frame_stride 23 | 24 | def class_name_to_idx(annotation_dir): 25 | """ 26 | return class indices from 0 ~ num_classes-1 27 | """ 28 | fpath = os.path.join(annotation_dir, "classInd.txt") 29 | with open(fpath, "r") as f: 30 | data = f.readlines() 31 | class_to_idx = {x.strip().split(" ")[1].lower():int(x.strip().split(" ")[0]) - 1 for x in data} 32 | return 
class_to_idx 33 | 34 | def class_idx_to_caption(caption_path): 35 | """ 36 | return class captions 37 | """ 38 | with open(caption_path, "r") as f: 39 | data = f.readlines() 40 | idx_to_cap = {i: line.strip() for i, line in enumerate(data)} 41 | return idx_to_cap 42 | 43 | class UCF101(Dataset): 44 | """ 45 | UCF101 Dataset. Assumes data is structured as follows. 46 | 47 | UCF101/ (data_root) 48 | train/ 49 | classname 50 | xxx.avi 51 | xxx.avi 52 | test/ 53 | classname 54 | xxx.avi 55 | xxx.avi 56 | ucfTrainTestlist/ 57 | """ 58 | def __init__(self, 59 | data_root, 60 | resolution, 61 | video_length, 62 | subset_split, 63 | frame_stride, 64 | annotation_dir=None, 65 | caption_file=None, 66 | ): 67 | self.data_root = data_root 68 | self.resolution = resolution 69 | self.video_length = video_length 70 | self.subset_split = subset_split 71 | self.frame_stride = frame_stride 72 | self.annotation_dir = annotation_dir if annotation_dir is not None else os.path.join(data_root, "ucfTrainTestlist") 73 | self.caption_file = caption_file 74 | 75 | assert(self.subset_split in ['train', 'test', 'all']) 76 | self.exts = ['avi', 'mp4', 'webm'] 77 | if isinstance(self.resolution, int): 78 | self.resolution = [self.resolution, self.resolution] 79 | assert(isinstance(self.resolution, list) and len(self.resolution) == 2) 80 | 81 | self._make_dataset() 82 | 83 | def _make_dataset(self): 84 | if self.subset_split == 'all': 85 | data_folder = self.data_root 86 | else: 87 | data_folder = os.path.join(self.data_root, self.subset_split) 88 | video_paths = sum([glob.glob(os.path.join(data_folder, '**', f'*.{ext}'), recursive=True) 89 | for ext in self.exts], []) 90 | # ucf class_to_idx 91 | class_to_idx = class_name_to_idx(self.annotation_dir) 92 | idx_to_cap = class_idx_to_caption(self.caption_path) if self.caption_file is not None else None 93 | 94 | self.videos = video_paths 95 | self.class_to_idx = class_to_idx 96 | self.idx_to_cap = idx_to_cap 97 | print(f'Number of videos = {len(self.videos)}') 98 | 99 | def _get_ucf_classinfo(self, videopath): 100 | video_name = os.path.basename(videopath) 101 | class_name = video_name.split("_")[1].lower() # v_BoxingSpeedBag_g12_c05 -> boxingspeedbag 102 | class_index = self.class_to_idx[class_name] # 0-100 103 | class_caption = self.idx_to_cap[class_index] if self.caption_file is not None else \ 104 | split_by_captical(video_name.split("_")[1]) # v_BoxingSpeedBag_g12_c05 -> boxing speed bag 105 | return class_name, class_index, class_caption 106 | 107 | def __getitem__(self, index): 108 | while True: 109 | video_path = self.videos[index] 110 | 111 | try: 112 | video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0]) 113 | vid_len = len(video_reader) 114 | if vid_len < self.video_length: 115 | index += 1 116 | continue 117 | else: 118 | break 119 | except: 120 | index += 1 121 | print(f"Load video failed! 
path = {video_path}") 122 | 123 | # sample strided frames 124 | all_frames, fs = sample_strided_frames(vid_len, self.frame_stride, self.video_length) 125 | 126 | # select random clip 127 | rand_idx = random.randint(0, len(all_frames) - self.video_length) 128 | frame_indices = list(range(rand_idx, rand_idx+self.video_length)) 129 | frames = video_reader.get_batch(frame_indices) 130 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}' 131 | 132 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] 133 | assert(frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}' 134 | frames = (frames / 255 - 0.5) * 2 135 | 136 | class_name, class_index, class_caption = self._get_ucf_classinfo(videopath=video_path) 137 | data = {'video': frames, 138 | 'class_name': class_name, 139 | 'class_index': class_index, 140 | 'class_caption': class_caption 141 | } 142 | return data 143 | 144 | def __len__(self): 145 | return len(self.videos) -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from einops import rearrange 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | import torch.nn.functional as F 7 | 8 | from lvdm.models.modules.aemodules import Encoder, Decoder 9 | from lvdm.models.modules.distributions import DiagonalGaussianDistribution 10 | from lvdm.utils.common_utils import instantiate_from_config 11 | 12 | class AutoencoderKL(pl.LightningModule): 13 | def __init__(self, 14 | ddconfig, 15 | lossconfig, 16 | embed_dim, 17 | ckpt_path=None, 18 | ignore_keys=[], 19 | image_key="image", 20 | colorize_nlabels=None, 21 | monitor=None, 22 | test=False, 23 | logdir=None, 24 | input_dim=4, 25 | test_args=None, 26 | ): 27 | super().__init__() 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | self.input_dim = input_dim 37 | self.test = test 38 | self.test_args = test_args 39 | self.logdir = logdir 40 | if colorize_nlabels is not None: 41 | assert type(colorize_nlabels)==int 42 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 43 | if monitor is not None: 44 | self.monitor = monitor 45 | if ckpt_path is not None: 46 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 47 | if self.test: 48 | self.init_test() 49 | 50 | def init_test(self,): 51 | self.test = True 52 | save_dir = os.path.join(self.logdir, "test") 53 | if 'ckpt' in self.test_args: 54 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 55 | self.root = os.path.join(save_dir, ckpt_name) 56 | else: 57 | self.root = save_dir 58 | if 'test_subdir' in self.test_args: 59 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 60 | 61 | self.root_zs = os.path.join(self.root, "zs") 62 | self.root_dec = os.path.join(self.root, "reconstructions") 63 | self.root_inputs = os.path.join(self.root, "inputs") 64 | os.makedirs(self.root, exist_ok=True) 65 | 66 | if self.test_args.save_z: 67 | 
os.makedirs(self.root_zs, exist_ok=True) 68 | if self.test_args.save_reconstruction: 69 | os.makedirs(self.root_dec, exist_ok=True) 70 | if self.test_args.save_input: 71 | os.makedirs(self.root_inputs, exist_ok=True) 72 | assert(self.test_args is not None) 73 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) #1500 # 12000/8 74 | self.count = 0 75 | self.eval_metrics = {} 76 | self.decodes = [] 77 | self.save_decode_samples = 2048 78 | 79 | def init_from_ckpt(self, path, ignore_keys=list()): 80 | sd = torch.load(path, map_location="cpu") 81 | try: 82 | self._cur_epoch = sd['epoch'] 83 | sd = sd["state_dict"] 84 | except: 85 | self._cur_epoch = 'null' 86 | keys = list(sd.keys()) 87 | for k in keys: 88 | for ik in ignore_keys: 89 | if k.startswith(ik): 90 | print("Deleting key {} from state_dict.".format(k)) 91 | del sd[k] 92 | self.load_state_dict(sd, strict=False) 93 | # self.load_state_dict(sd, strict=True) 94 | print(f"Restored from {path}") 95 | 96 | def encode(self, x, **kwargs): 97 | 98 | h = self.encoder(x) 99 | moments = self.quant_conv(h) 100 | posterior = DiagonalGaussianDistribution(moments) 101 | return posterior 102 | 103 | def decode(self, z, **kwargs): 104 | z = self.post_quant_conv(z) 105 | dec = self.decoder(z) 106 | return dec 107 | 108 | def forward(self, input, sample_posterior=True): 109 | posterior = self.encode(input) 110 | if sample_posterior: 111 | z = posterior.sample() 112 | else: 113 | z = posterior.mode() 114 | dec = self.decode(z) 115 | return dec, posterior 116 | 117 | def get_input(self, batch, k): 118 | x = batch[k] 119 | # if len(x.shape) == 3: 120 | # x = x[..., None] 121 | # if x.dim() == 4: 122 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 123 | if x.dim() == 5 and self.input_dim == 4: 124 | b,c,t,h,w = x.shape 125 | self.b = b 126 | self.t = t 127 | x = rearrange(x, 'b c t h w -> (b t) c h w') 128 | 129 | return x 130 | 131 | def training_step(self, batch, batch_idx, optimizer_idx): 132 | inputs = self.get_input(batch, self.image_key) 133 | reconstructions, posterior = self(inputs) 134 | 135 | if optimizer_idx == 0: 136 | # train encoder+decoder+logvar 137 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 138 | last_layer=self.get_last_layer(), split="train") 139 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 140 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 141 | return aeloss 142 | 143 | if optimizer_idx == 1: 144 | # train the discriminator 145 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 146 | last_layer=self.get_last_layer(), split="train") 147 | 148 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 149 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 150 | return discloss 151 | 152 | def validation_step(self, batch, batch_idx): 153 | inputs = self.get_input(batch, self.image_key) 154 | reconstructions, posterior = self(inputs) 155 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 156 | last_layer=self.get_last_layer(), split="val") 157 | 158 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 159 | last_layer=self.get_last_layer(), split="val") 160 | 161 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 162 | self.log_dict(log_dict_ae) 
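        # validation_step scores the same batch from the generator side
        # (optimizer_idx=0) and the discriminator side (optimizer_idx=1),
        # logging the autoencoder metrics just above and the discriminator
        # metrics just below. "val/rec_loss" is also logged on its own so it
        # can serve as the `monitor` key accepted by __init__ (hypothetically,
        # a config entry such as `monitor: val/rec_loss` for checkpointing).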
163 | self.log_dict(log_dict_disc) 164 | return self.log_dict 165 | 166 | def configure_optimizers(self): 167 | lr = self.learning_rate 168 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 169 | list(self.decoder.parameters())+ 170 | list(self.quant_conv.parameters())+ 171 | list(self.post_quant_conv.parameters()), 172 | lr=lr, betas=(0.5, 0.9)) 173 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 174 | lr=lr, betas=(0.5, 0.9)) 175 | return [opt_ae, opt_disc], [] 176 | 177 | def get_last_layer(self): 178 | return self.decoder.conv_out.weight 179 | 180 | @torch.no_grad() 181 | def log_images(self, batch, only_inputs=False, **kwargs): 182 | log = dict() 183 | x = self.get_input(batch, self.image_key) 184 | x = x.to(self.device) 185 | if not only_inputs: 186 | xrec, posterior = self(x) 187 | if x.shape[1] > 3: 188 | # colorize with random projection 189 | assert xrec.shape[1] > 3 190 | x = self.to_rgb(x) 191 | xrec = self.to_rgb(xrec) 192 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 193 | log["reconstructions"] = xrec 194 | log["inputs"] = x 195 | return log 196 | 197 | def to_rgb(self, x): 198 | assert self.image_key == "segmentation" 199 | if not hasattr(self, "colorize"): 200 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 201 | x = F.conv2d(x, weight=self.colorize) 202 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 203 | return x -------------------------------------------------------------------------------- /lvdm/models/autoencoder3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | 5 | from lvdm.models.modules.aemodules3d import SamePadConv3d 6 | from lvdm.utils.common_utils import instantiate_from_config 7 | from lvdm.models.modules.distributions import DiagonalGaussianDistribution 8 | 9 | 10 | def conv3d(in_channels, out_channels, kernel_size, conv3d_type='SamePadConv3d'): 11 | if conv3d_type == 'SamePadConv3d': 12 | return SamePadConv3d(in_channels, out_channels, kernel_size=kernel_size, padding_type='replicate') 13 | else: 14 | raise NotImplementedError 15 | 16 | class AutoencoderKL(pl.LightningModule): 17 | def __init__(self, 18 | ddconfig, 19 | lossconfig, 20 | embed_dim, 21 | ckpt_path=None, 22 | ignore_keys=[], 23 | image_key="image", 24 | monitor=None, 25 | std=1., 26 | mean=0., 27 | prob=0.2, 28 | **kwargs, 29 | ): 30 | super().__init__() 31 | self.image_key = image_key 32 | self.encoder = instantiate_from_config(ddconfig['encoder']) 33 | self.decoder = instantiate_from_config(ddconfig['decoder']) 34 | self.loss = instantiate_from_config(lossconfig) 35 | assert ddconfig["double_z"] 36 | self.quant_conv = conv3d(2*ddconfig["z_channels"], 2*embed_dim, 1) 37 | self.post_quant_conv = conv3d(embed_dim, ddconfig["z_channels"], 1) 38 | self.embed_dim = embed_dim 39 | self.std = std 40 | self.mean = mean 41 | self.prob = prob 42 | if monitor is not None: 43 | self.monitor = monitor 44 | if ckpt_path is not None: 45 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 46 | 47 | def init_from_ckpt(self, path, ignore_keys=list()): 48 | sd = torch.load(path, map_location="cpu") 49 | try: 50 | self._cur_epoch = sd['epoch'] 51 | sd = sd["state_dict"] 52 | except: 53 | pass 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, 
strict=False) 61 | print(f"Restored from {path}") 62 | 63 | def encode(self, x, **kwargs): 64 | h = self.encoder(x) 65 | moments = self.quant_conv(h) 66 | posterior = DiagonalGaussianDistribution(moments) 67 | return posterior 68 | 69 | def decode(self, z, **kwargs): 70 | z = self.post_quant_conv(z) 71 | dec = self.decoder(z) 72 | return dec 73 | 74 | def forward(self, input, sample_posterior=True, **kwargs): 75 | posterior = self.encode(input) 76 | if sample_posterior: 77 | z = posterior.sample() 78 | else: 79 | z = posterior.mode() 80 | dec = self.decode(z) 81 | return dec, posterior 82 | 83 | def get_input(self, batch, k): 84 | x = batch[k] 85 | if len(x.shape) == 4: 86 | x = x[..., None] 87 | x = x.to(memory_format=torch.contiguous_format).float() 88 | return x 89 | 90 | def training_step(self, batch, batch_idx, optimizer_idx): 91 | inputs = self.get_input(batch, self.image_key) 92 | reconstructions, posterior = self(inputs) 93 | 94 | if optimizer_idx == 0: 95 | # train encoder+decoder+logvar 96 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 97 | last_layer=self.get_last_layer(), split="train") 98 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 99 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 100 | return aeloss 101 | 102 | if optimizer_idx == 1: 103 | # train the discriminator 104 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 105 | last_layer=self.get_last_layer(), split="train") 106 | 107 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 108 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 109 | return discloss 110 | 111 | def validation_step(self, batch, batch_idx): 112 | inputs = self.get_input(batch, self.image_key) 113 | reconstructions, posterior = self(inputs) 114 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 115 | last_layer=self.get_last_layer(), split="val") 116 | 117 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 118 | last_layer=self.get_last_layer(), split="val") 119 | 120 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 121 | self.log_dict(log_dict_ae) 122 | self.log_dict(log_dict_disc) 123 | return self.log_dict 124 | 125 | def configure_optimizers(self): 126 | lr = self.learning_rate 127 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 128 | list(self.decoder.parameters())+ 129 | list(self.quant_conv.parameters())+ 130 | list(self.post_quant_conv.parameters()), 131 | lr=lr, betas=(0.5, 0.9)) 132 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 133 | lr=lr, betas=(0.5, 0.9)) 134 | return [opt_ae, opt_disc], [] 135 | 136 | def get_last_layer(self): 137 | return self.decoder.conv_out.weight 138 | 139 | @torch.no_grad() 140 | def log_images(self, batch, only_inputs=False, **kwargs): 141 | log = dict() 142 | x = self.get_input(batch, self.image_key) 143 | x = x.to(self.device) 144 | if not only_inputs: 145 | xrec, posterior = self(x) 146 | if x.shape[1] > 3: 147 | # colorize with random projection 148 | assert xrec.shape[1] > 3 149 | x = self.to_rgb(x) 150 | xrec = self.to_rgb(xrec) 151 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 152 | log["reconstructions"] = xrec 153 | log["inputs"] = x 154 | return log 155 | 156 | def to_rgb(self, x): 157 | 
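        # Only meaningful for the "segmentation" image_key: project the
        # multi-channel input to 3 channels with a fixed random projection so
        # log_images() can visualize it. The projection matrix is created
        # lazily and stored as the "colorize" buffer; the output is rescaled
        # to [-1, 1].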
assert self.image_key == "segmentation" 158 | if not hasattr(self, "colorize"): 159 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 160 | x = F.conv2d(x, weight=self.colorize) 161 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 162 | return x 163 | -------------------------------------------------------------------------------- /lvdm/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from lvdm.models.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /lvdm/models/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge", max_bs=None,): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | self.max_bs = max_bs 32 | 33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 34 | if last_layer is not None: 35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 37 | else: 38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 40 | 41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 43 | d_weight = d_weight * self.discriminator_weight 44 | return d_weight 45 | 46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 47 | global_step, last_layer=None, cond=None, split="train", 48 | weights=None,): 49 | if inputs.dim() == 5: 50 | inputs = rearrange(inputs, 'b c t h w -> (b t) c h w') 51 | if reconstructions.dim() == 5: 52 | reconstructions = rearrange(reconstructions, 'b c t h w -> (b t) c h w') 53 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 54 | if self.perceptual_weight > 0: 55 | if self.max_bs is not None and self.max_bs < inputs.shape[0] : 56 | input_list = torch.split(inputs, self.max_bs, dim=0) 57 | reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0) 58 | p_losses = [self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) for inputs, reconstructions in zip(input_list, reconstruction_list)] 59 | p_loss=torch.cat(p_losses,dim=0) 60 | 
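                # The chunked path above bounds peak memory by running LPIPS on
                # at most `max_bs` frames at a time and concatenating the
                # per-chunk losses; the else-branch below feeds the whole
                # flattened (b*t) batch through LPIPS in a single call.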
else: 61 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 62 | rec_loss = rec_loss + self.perceptual_weight * p_loss 63 | 64 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 65 | weighted_nll_loss = nll_loss 66 | if weights is not None: 67 | weighted_nll_loss = weights*nll_loss 68 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 69 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 70 | 71 | kl_loss = posteriors.kl() 72 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 73 | 74 | # now the GAN part 75 | if optimizer_idx == 0: 76 | # generator update 77 | if cond is None: 78 | assert not self.disc_conditional 79 | logits_fake = self.discriminator(reconstructions.contiguous()) 80 | else: 81 | assert self.disc_conditional 82 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 83 | g_loss = -torch.mean(logits_fake) 84 | 85 | if self.disc_factor > 0.0: 86 | try: 87 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 88 | except RuntimeError: 89 | assert not self.training 90 | d_weight = torch.tensor(0.0) 91 | else: 92 | d_weight = torch.tensor(0.0) 93 | 94 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 95 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 96 | 97 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 98 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 99 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 100 | "{}/d_weight".format(split): d_weight.detach(), 101 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 102 | "{}/g_loss".format(split): g_loss.detach().mean(), 103 | } 104 | return loss, log 105 | 106 | if optimizer_idx == 1: 107 | # second pass for discriminator update 108 | if cond is None: 109 | logits_real = self.discriminator(inputs.contiguous().detach()) 110 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 111 | else: 112 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 113 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 114 | 115 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 116 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 117 | 118 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 119 | "{}/logits_real".format(split): logits_real.detach().mean(), 120 | "{}/logits_fake".format(split): logits_fake.detach().mean() 121 | } 122 | return d_loss, log 123 | 124 | -------------------------------------------------------------------------------- /lvdm/models/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. 
- logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, 
split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /lvdm/models/modules/aemodules3d.py: -------------------------------------------------------------------------------- 1 | # TATS 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | def silu(x): 12 | return x*torch.sigmoid(x) 13 | 14 | class SiLU(nn.Module): 15 | def __init__(self): 16 | super(SiLU, self).__init__() 17 | 18 | def forward(self, x): 19 | return silu(x) 20 | 21 | def hinge_d_loss(logits_real, logits_fake): 22 | loss_real = torch.mean(F.relu(1. - logits_real)) 23 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 24 | d_loss = 0.5 * (loss_real + loss_fake) 25 | return d_loss 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | def Normalize(in_channels, norm_type='group'): 34 | assert norm_type in ['group', 'batch'] 35 | if norm_type == 'group': 36 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 37 | elif norm_type == 'batch': 38 | return torch.nn.SyncBatchNorm(in_channels) 39 | 40 | 41 | class ResBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate'): 43 | super().__init__() 44 | self.in_channels = in_channels 45 | out_channels = in_channels if out_channels is None else out_channels 46 | self.out_channels = out_channels 47 | self.use_conv_shortcut = conv_shortcut 48 | 49 | self.norm1 = Normalize(in_channels, norm_type) 50 | self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type) 51 | self.dropout = torch.nn.Dropout(dropout) 52 | self.norm2 = Normalize(in_channels, norm_type) 53 | self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type) 54 | if self.in_channels != self.out_channels: 55 | self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type) 56 | 57 | def forward(self, x): 58 | h = x 59 | h = self.norm1(h) 60 | h = silu(h) 61 | h = self.conv1(h) 62 | h = self.norm2(h) 63 | h = silu(h) 64 | h = self.conv2(h) 65 | 66 | if self.in_channels != self.out_channels: 67 | x = self.conv_shortcut(x) 68 | 69 | return x+h 70 | 71 | 72 | # Does not support dilation 73 | class SamePadConv3d(nn.Module): 74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): 75 | super().__init__() 76 | if isinstance(kernel_size, int): 77 | kernel_size = (kernel_size,) * 3 78 | if isinstance(stride, int): 79 | stride = (stride,) * 3 80 | 81 | # assumes that the input shape is divisible by stride 82 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) 83 | pad_input = [] 84 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim 85 | pad_input.append((p // 2 + p % 2, p // 2)) 86 | pad_input = sum(pad_input, tuple()) 87 | 88 | self.pad_input = pad_input 89 | self.padding_type = padding_type 90 | 91 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, 92 | stride=stride, padding=0, bias=bias) 93 | self.weight = self.conv.weight 94 | 95 | def forward(self, x): 96 | return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) 97 | 98 | 99 | class SamePadConvTranspose3d(nn.Module): 100 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): 101 | super().__init__() 102 | if isinstance(kernel_size, int): 103 | kernel_size = (kernel_size,) * 3 104 | if 
isinstance(stride, int): 105 | stride = (stride,) * 3 106 | 107 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) 108 | pad_input = [] 109 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim 110 | pad_input.append((p // 2 + p % 2, p // 2)) 111 | pad_input = sum(pad_input, tuple()) 112 | self.pad_input = pad_input 113 | self.padding_type = padding_type 114 | self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, 115 | stride=stride, bias=bias, 116 | padding=tuple([k - 1 for k in kernel_size])) 117 | 118 | def forward(self, x): 119 | return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) 120 | 121 | 122 | class Encoder(nn.Module): 123 | def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group', padding_type='replicate'): 124 | super().__init__() 125 | n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) 126 | self.conv_blocks = nn.ModuleList() 127 | max_ds = n_times_downsample.max() 128 | self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) 129 | 130 | for i in range(max_ds): 131 | block = nn.Module() 132 | in_channels = n_hiddens * 2**i 133 | out_channels = n_hiddens * 2**(i+1) 134 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) 135 | block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type) 136 | block.res = ResBlock(out_channels, out_channels, norm_type=norm_type) 137 | self.conv_blocks.append(block) 138 | n_times_downsample -= 1 139 | 140 | self.final_block = nn.Sequential( 141 | Normalize(out_channels, norm_type), 142 | SiLU(), 143 | SamePadConv3d(out_channels, 2*z_channels if double_z else z_channels, 144 | kernel_size=3, 145 | stride=1, 146 | padding_type=padding_type) 147 | ) 148 | self.out_channels = out_channels 149 | 150 | def forward(self, x): 151 | h = self.conv_first(x) 152 | for block in self.conv_blocks: 153 | h = block.down(h) 154 | h = block.res(h) 155 | h = self.final_block(h) 156 | return h 157 | 158 | 159 | class Decoder(nn.Module): 160 | def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'): 161 | super().__init__() 162 | 163 | n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) 164 | max_us = n_times_upsample.max() 165 | in_channels = z_channels 166 | self.conv_blocks = nn.ModuleList() 167 | for i in range(max_us): 168 | block = nn.Module() 169 | in_channels = in_channels if i ==0 else n_hiddens*2**(max_us-i+1) 170 | out_channels = n_hiddens*2**(max_us-i) 171 | us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) 172 | block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us) 173 | block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type) 174 | block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type) 175 | self.conv_blocks.append(block) 176 | n_times_upsample -= 1 177 | 178 | self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3) 179 | 180 | def forward(self, x): 181 | h = x 182 | for i, block in enumerate(self.conv_blocks): 183 | h = block.up(h) 184 | h = block.res1(h) 185 | h = block.res2(h) 186 | h = self.conv_out(h) 187 | return h 188 | -------------------------------------------------------------------------------- /lvdm/models/modules/condition_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import logging 3 | from transformers 
import CLIPTokenizer, CLIPTextModel 4 | logging.set_verbosity_error() 5 | 6 | 7 | class AbstractEncoder(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def encode(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class FrozenCLIPEmbedder(AbstractEncoder): 16 | """Uses the CLIP transformer encoder for text (from huggingface)""" 17 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 18 | super().__init__() 19 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 20 | self.transformer = CLIPTextModel.from_pretrained(version) 21 | self.device = device 22 | self.max_length = max_length 23 | self.freeze() 24 | 25 | def freeze(self): 26 | self.transformer = self.transformer.eval() 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, text): 31 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 32 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 33 | tokens = batch_encoding["input_ids"].to(self.device) 34 | outputs = self.transformer(input_ids=tokens) 35 | 36 | z = outputs.last_hidden_state 37 | return z 38 | 39 | def encode(self, text): 40 | return self(text) -------------------------------------------------------------------------------- /lvdm/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL 
divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) 96 | -------------------------------------------------------------------------------- /lvdm/models/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 
71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /lvdm/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from lvdm.models.modules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | self.register_buffer('betas', to_torch(self.model.betas)) 32 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 33 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 34 | 35 | # calculations for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 37 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 38 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 41 | 42 | # ddim sampling parameters 43 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 44 | ddim_timesteps=self.ddim_timesteps, 45 | eta=ddim_eta,verbose=verbose) 46 | self.register_buffer('ddim_sigmas', ddim_sigmas) 47 | self.register_buffer('ddim_alphas', ddim_alphas) 48 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 49 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 50 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 51 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 52 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 53 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 54 | 55 | @torch.no_grad() 56 | def sample(self, 57 | S, 58 | batch_size, 59 | shape, 60 | conditioning=None, 61 | callback=None, 62 | img_callback=None, 63 | quantize_x0=False, 64 | eta=0., 65 | mask=None, 66 | x0=None, 67 | temperature=1., 68 | noise_dropout=0., 69 | score_corrector=None, 70 | corrector_kwargs=None, 71 | verbose=True, 72 | schedule_verbose=False, 73 | x_T=None, 74 | log_every_t=100, 75 | unconditional_guidance_scale=1., 76 | unconditional_conditioning=None, 77 | postprocess_fn=None, 78 | sample_noise=None, 79 | cond_fn=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | **kwargs 82 | ): 83 | 84 | # check condition bs 85 | if conditioning is not None: 86 | if isinstance(conditioning, dict): 87 | try: 88 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 89 | if cbs != batch_size: 90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 91 | except: 92 | # cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 93 | pass 94 | else: 95 | if conditioning.shape[0] != batch_size: 96 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 97 | 98 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose) 99 | 100 | # make shape 101 | if len(shape) == 3: 102 | C, H, W = shape 103 | size = (batch_size, C, H, W) 104 | elif len(shape) == 4: 105 | C, T, H, W = shape 106 | size = (batch_size, C, T, H, W) 107 | 108 | samples, intermediates = self.ddim_sampling(conditioning, size, 109 | callback=callback, 110 | img_callback=img_callback, 111 | quantize_denoised=quantize_x0, 112 | mask=mask, x0=x0, 113 | ddim_use_original_steps=False, 114 | noise_dropout=noise_dropout, 115 | temperature=temperature, 116 | score_corrector=score_corrector, 117 | corrector_kwargs=corrector_kwargs, 118 | x_T=x_T, 119 | log_every_t=log_every_t, 120 | unconditional_guidance_scale=unconditional_guidance_scale, 121 | unconditional_conditioning=unconditional_conditioning, 122 | postprocess_fn=postprocess_fn, 123 | sample_noise=sample_noise, 124 | cond_fn=cond_fn, 125 | verbose=verbose, 126 | **kwargs 127 | ) 128 | return samples, intermediates 129 | 130 | @torch.no_grad() 131 | def ddim_sampling(self, cond, shape, 132 | x_T=None, ddim_use_original_steps=False, 133 | callback=None, timesteps=None, quantize_denoised=False, 134 | mask=None, x0=None, img_callback=None, log_every_t=100, 135 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 136 | unconditional_guidance_scale=1., unconditional_conditioning=None, 137 | postprocess_fn=None,sample_noise=None,cond_fn=None, 138 | uc_type=None, verbose=True, **kwargs, 139 | ): 140 | 141 | device = self.model.betas.device 142 | 143 | b = shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | 149 | if timesteps is None: 150 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 151 | elif timesteps is not None and not ddim_use_original_steps: 152 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 153 | timesteps = self.ddim_timesteps[:subset_end] 154 | 
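        # From here the sampler walks the (strided) DDIM timesteps back-to-front:
        # time_range reverses the schedule and, at iteration i, index =
        # total_steps - i - 1 picks the matching entries of the ddim_alphas /
        # ddim_alphas_prev / ddim_sigmas buffers registered in make_schedule().
        # A minimal caller (hypothetical latent shape; S DDIM steps) looks like:
        #     sampler = DDIMSampler(model)
        #     samples, inter = sampler.sample(S=50, batch_size=2,
        #                                     shape=(4, 16, 32, 32), eta=1.0)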
intermediates = {'x_inter': [img], 'pred_x0': [img]} 155 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 156 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 157 | if verbose: 158 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 159 | else: 160 | iterator = time_range 161 | 162 | for i, step in enumerate(iterator): 163 | index = total_steps - i - 1 164 | ts = torch.full((b,), step, device=device, dtype=torch.long) 165 | 166 | if postprocess_fn is not None: 167 | img = postprocess_fn(img, ts) 168 | 169 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 170 | quantize_denoised=quantize_denoised, temperature=temperature, 171 | noise_dropout=noise_dropout, score_corrector=score_corrector, 172 | corrector_kwargs=corrector_kwargs, 173 | unconditional_guidance_scale=unconditional_guidance_scale, 174 | unconditional_conditioning=unconditional_conditioning, 175 | sample_noise=sample_noise,cond_fn=cond_fn,uc_type=uc_type, **kwargs,) 176 | img, pred_x0 = outs 177 | 178 | if mask is not None: 179 | # use mask to blend x_known_t-1 & x_sample_t-1 180 | assert x0 is not None 181 | x0 = x0.to(img.device) 182 | mask = mask.to(img.device) 183 | t = torch.tensor([step-1]*x0.shape[0], dtype=torch.long, device=img.device) 184 | img_known = self.model.q_sample(x0, t) 185 | img = img_known * mask + (1. - mask) * img 186 | 187 | if callback: callback(i) 188 | if img_callback: img_callback(pred_x0, i) 189 | 190 | if index % log_every_t == 0 or index == total_steps - 1: 191 | intermediates['x_inter'].append(img) 192 | intermediates['pred_x0'].append(pred_x0) 193 | 194 | return img, intermediates 195 | 196 | @torch.no_grad() 197 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 198 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 199 | unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None, 200 | cond_fn=None, uc_type=None, 201 | **kwargs, 202 | ): 203 | b, *_, device = *x.shape, x.device 204 | if x.dim() == 5: 205 | is_video = True 206 | else: 207 | is_video = False 208 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 209 | e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 210 | else: 211 | # with unconditional condition 212 | if isinstance(c, torch.Tensor): 213 | e_t = self.model.apply_model(x, t, c, **kwargs) 214 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 215 | elif isinstance(c, dict): 216 | e_t = self.model.apply_model(x, t, c, **kwargs) 217 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 218 | else: 219 | raise NotImplementedError 220 | 221 | if uc_type is None: 222 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 223 | elif uc_type == 'cfg_original': 224 | e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond) 225 | else: 226 | raise NotImplementedError 227 | 228 | if score_corrector is not None: 229 | assert self.model.parameterization == "eps" 230 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 231 | 232 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 233 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 234 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if 
use_original_steps else self.ddim_sqrt_one_minus_alphas 235 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 236 | # select parameters corresponding to the currently considered timestep 237 | 238 | if is_video: 239 | size = (b, 1, 1, 1, 1) 240 | else: 241 | size = (b, 1, 1, 1) 242 | a_t = torch.full(size, alphas[index], device=device) 243 | a_prev = torch.full(size, alphas_prev[index], device=device) 244 | sigma_t = torch.full(size, sigmas[index], device=device) 245 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 246 | 247 | # current prediction for x_0 248 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 249 | # print(f't={t}, pred_x0, min={torch.min(pred_x0)}, max={torch.max(pred_x0)}',file=f) 250 | if quantize_denoised: 251 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 252 | # direction pointing to x_t 253 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 254 | if sample_noise is None: 255 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 256 | if noise_dropout > 0.: 257 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 258 | else: 259 | noise = sigma_t * sample_noise * temperature 260 | 261 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 262 | 263 | return x_prev, pred_x0 264 | -------------------------------------------------------------------------------- /lvdm/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import numpy as np 4 | from inspect import isfunction 5 | 6 | import torch 7 | 8 | 9 | def shape_to_str(x): 10 | shape_str = "x".join([str(x) for x in x.shape]) 11 | return shape_str 12 | 13 | 14 | def str2bool(v): 15 | if isinstance(v, bool): 16 | return v 17 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 18 | return True 19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 20 | return False 21 | else: 22 | raise ValueError('Boolean value expected.') 23 | 24 | 25 | def get_obj_from_str(string, reload=False): 26 | module, cls = string.rsplit(".", 1) 27 | if reload: 28 | module_imp = importlib.import_module(module) 29 | importlib.reload(module_imp) 30 | return getattr(importlib.import_module(module, package=None), cls) 31 | 32 | 33 | def instantiate_from_config(config): 34 | if not "target" in config: 35 | if config == '__is_first_stage__': 36 | return None 37 | elif config == "__is_unconditional__": 38 | return None 39 | raise KeyError("Expected key `target` to instantiate.") 40 | 41 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 42 | 43 | 44 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 45 | """ Shifts src_tf dim to dest dim 46 | i.e. 
shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) 47 | """ 48 | n_dims = len(x.shape) 49 | if src_dim < 0: 50 | src_dim = n_dims + src_dim 51 | if dest_dim < 0: 52 | dest_dim = n_dims + dest_dim 53 | 54 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 55 | 56 | dims = list(range(n_dims)) 57 | del dims[src_dim] 58 | 59 | permutation = [] 60 | ctr = 0 61 | for i in range(n_dims): 62 | if i == dest_dim: 63 | permutation.append(src_dim) 64 | else: 65 | permutation.append(dims[ctr]) 66 | ctr += 1 67 | x = x.permute(permutation) 68 | if make_contiguous: 69 | x = x.contiguous() 70 | return x 71 | 72 | 73 | def torch_to_np(x): 74 | sample = x.detach().cpu() 75 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 76 | if sample.dim() == 5: 77 | sample = sample.permute(0, 2, 3, 4, 1) 78 | else: 79 | sample = sample.permute(0, 2, 3, 1) 80 | sample = sample.contiguous().numpy() 81 | return sample 82 | 83 | 84 | def np_to_torch_video(x): 85 | x = torch.tensor(x).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] 86 | x = (x / 255 - 0.5) * 2 87 | return x 88 | 89 | 90 | def load_npz_from_dir(data_dir): 91 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 92 | data = np.concatenate(data, axis=0) 93 | return data 94 | 95 | 96 | def load_npz_from_paths(data_paths): 97 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 98 | data = np.concatenate(data, axis=0) 99 | return data 100 | 101 | 102 | def ismap(x): 103 | if not isinstance(x, torch.Tensor): 104 | return False 105 | return (len(x.shape) == 4) and (x.shape[1] > 3) 106 | 107 | 108 | def isimage(x): 109 | if not isinstance(x,torch.Tensor): 110 | return False 111 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 112 | 113 | 114 | def exists(x): 115 | return x is not None 116 | 117 | 118 | def default(val, d): 119 | if exists(val): 120 | return val 121 | return d() if isfunction(d) else d 122 | 123 | 124 | def mean_flat(tensor): 125 | """ 126 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 127 | Take the mean over all non-batch dimensions. 
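E.g. a tensor of shape [b, c, t, h, w] is reduced to a tensor of shape [b].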
128 | """ 129 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 130 | 131 | 132 | def count_params(model, verbose=False): 133 | total_params = sum(p.numel() for p in model.parameters()) 134 | if verbose: 135 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 136 | return total_params 137 | 138 | 139 | def check_istarget(name, para_list): 140 | """ 141 | name: full name of source para 142 | para_list: partial name of target para 143 | """ 144 | istarget=False 145 | for para in para_list: 146 | if para in name: 147 | return True 148 | return istarget -------------------------------------------------------------------------------- /lvdm/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | def setup_dist(local_rank): 5 | if dist.is_initialized(): 6 | return 7 | torch.cuda.set_device(local_rank) 8 | torch.distributed.init_process_group( 9 | 'nccl', 10 | init_method='env://' 11 | ) 12 | 13 | def gather_data(data, return_np=True): 14 | ''' gather data from multiple processes to one list ''' 15 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 16 | dist.all_gather(data_list, data) 17 | if return_np: 18 | data_list = [data.cpu().numpy() for data in data_list] 19 | return data_list 20 | -------------------------------------------------------------------------------- /lvdm/utils/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import multiprocessing as mproc 4 | import types 5 | 6 | #: number of available CPUs on this computer 7 | CPU_COUNT = int(mproc.cpu_count()) 8 | #: default date-time format 9 | FORMAT_DATE_TIME = '%Y%m%d-%H%M%S' 10 | #: default logging tile 11 | FILE_LOGS = 'logging.log' 12 | #: default logging template - log location/source for logging to file 13 | STR_LOG_FORMAT = '%(asctime)s:%(levelname)s@%(filename)s:%(processName)s - %(message)s' 14 | #: default logging template - date-time for logging to file 15 | LOG_FILE_FORMAT = logging.Formatter(STR_LOG_FORMAT, datefmt="%H:%M:%S") 16 | #: define all types to be assume list like 17 | ITERABLE_TYPES = (list, tuple, types.GeneratorType) 18 | def release_logger_files(): 19 | """ close all handlers to a file 20 | >>> release_logger_files() 21 | >>> len([1 for lh in logging.getLogger().handlers 22 | ... if type(lh) is logging.FileHandler]) 23 | 0 24 | """ 25 | for hl in logging.getLogger().handlers: 26 | if isinstance(hl, logging.FileHandler): 27 | hl.close() 28 | logging.getLogger().removeHandler(hl) 29 | 30 | def set_experiment_logger(path_out, file_name=FILE_LOGS, reset=True): 31 | """ set the logger to file 32 | :param str path_out: path to the output folder 33 | :param str file_name: log file name 34 | :param bool reset: reset all previous logging into a file 35 | >>> set_experiment_logger('.') 36 | >>> len([1 for lh in logging.getLogger().handlers 37 | ... 
if type(lh) is logging.FileHandler]) 38 | 1 39 | >>> release_logger_files() 40 | >>> os.remove(FILE_LOGS) 41 | """ 42 | log = logging.getLogger() 43 | log.setLevel(logging.DEBUG) 44 | 45 | if reset: 46 | release_logger_files() 47 | path_logger = os.path.join(path_out, file_name) 48 | fh = logging.FileHandler(path_logger) 49 | fh.setLevel(logging.DEBUG) 50 | fh.setFormatter(LOG_FILE_FORMAT) 51 | log.addHandler(fh) 52 | 53 | def set_ptl_logger(path_out, phase, file_name="ptl.log", reset=True): 54 | """ set the logger to file 55 | :param str path_out: path to the output folder 56 | :param str file_name: log file name 57 | :param bool reset: reset all previous logging into a file 58 | >>> set_experiment_logger('.') 59 | >>> len([1 for lh in logging.getLogger().handlers 60 | ... if type(lh) is logging.FileHandler]) 61 | 1 62 | >>> release_logger_files() 63 | >>> os.remove(FILE_LOGS) 64 | """ 65 | file_name = f"ptl_{phase}.log" 66 | level = logging.INFO 67 | log = logging.getLogger("pytorch_lightning") 68 | log.setLevel(level) 69 | 70 | if reset: 71 | release_logger_files() 72 | 73 | path_logger = os.path.join(path_out, file_name) 74 | fh = logging.FileHandler(path_logger) 75 | fh.setLevel(level) 76 | fh.setFormatter(LOG_FILE_FORMAT) 77 | 78 | log.addHandler(fh) 79 | -------------------------------------------------------------------------------- /lvdm/utils/saving_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | import cv2 5 | import os 6 | import time 7 | import imageio 8 | import numpy as np 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from PIL import Image, ImageDraw, ImageFont 12 | 13 | import torch 14 | import torchvision 15 | from torch import Tensor 16 | from torchvision.utils import make_grid 17 | from torchvision.transforms.functional import to_tensor 18 | 19 | 20 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None): 21 | """ 22 | video: torch.Tensor, b,c,t,h,w, 0-1 23 | if -1~1, enable rescale=True 24 | """ 25 | n = video.shape[0] 26 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 27 | nrow = int(np.sqrt(n)) if nrow is None else nrow 28 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w] 29 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] 30 | grid = torch.clamp(grid.float(), -1., 1.) 31 | if rescale: 32 | grid = (grid + 1.0) / 2.0 33 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 34 | #print(f'Save video to {savepath}') 35 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 36 | 37 | # ---------------------------------------------------------------------------------------------- 38 | def savenp2sheet(imgs, savepath, nrow=None): 39 | """ save multiple imgs (in numpy array type) to a img sheet. 40 | img sheet is one row. 
41 | 42 | imgs: 43 | np array of size [N, H, W, 3] or List[array] with array size = [H,W,3] 44 | """ 45 | if imgs.ndim == 4: 46 | img_list = [imgs[i] for i in range(imgs.shape[0])] 47 | imgs = img_list 48 | 49 | imgs_new = [] 50 | for i, img in enumerate(imgs): 51 | if img.ndim == 3 and img.shape[0] == 3: 52 | img = np.transpose(img,(1,2,0)) 53 | 54 | assert(img.ndim == 3 and img.shape[-1] == 3), img.shape # h,w,3 55 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 56 | imgs_new.append(img) 57 | n = len(imgs) 58 | if nrow is not None: 59 | n_cols = nrow 60 | else: 61 | n_cols=int(n**0.5) 62 | n_rows=int(np.ceil(n/n_cols)) 63 | print(n_cols) 64 | print(n_rows) 65 | 66 | imgsheet = cv2.vconcat([cv2.hconcat(imgs_new[i*n_cols:(i+1)*n_cols]) for i in range(n_rows)]) 67 | cv2.imwrite(savepath, imgsheet) 68 | print(f'saved in {savepath}') 69 | 70 | # ---------------------------------------------------------------------------------------------- 71 | def save_np_to_img(img, path, norm=True): 72 | if norm: 73 | img = (img + 1) / 2 * 255 74 | img = img.astype(np.uint8) 75 | image = Image.fromarray(img) 76 | image.save(path, q=95) 77 | 78 | # ---------------------------------------------------------------------------------------------- 79 | def npz_to_imgsheet_5d(data_path, res_dir, nrow=None,): 80 | if isinstance(data_path, str): 81 | imgs = np.load(data_path)['arr_0'] # NTHWC 82 | elif isinstance(data_path, np.ndarray): 83 | imgs = data_path 84 | else: 85 | raise Exception 86 | 87 | if os.path.isdir(res_dir): 88 | res_path = os.path.join(res_dir, f'samples.jpg') 89 | else: 90 | assert(res_dir.endswith('.jpg')) 91 | res_path = res_dir 92 | imgs = np.concatenate([imgs[i] for i in range(imgs.shape[0])], axis=0) 93 | savenp2sheet(imgs, res_path, nrow=nrow) 94 | 95 | # ---------------------------------------------------------------------------------------------- 96 | def npz_to_imgsheet_4d(data_path, res_path, nrow=None,): 97 | if isinstance(data_path, str): 98 | imgs = np.load(data_path)['arr_0'] # NHWC 99 | elif isinstance(data_path, np.ndarray): 100 | imgs = data_path 101 | else: 102 | raise Exception 103 | print(imgs.shape) 104 | savenp2sheet(imgs, res_path, nrow=nrow) 105 | 106 | 107 | # ---------------------------------------------------------------------------------------------- 108 | def tensor_to_imgsheet(tensor, save_path): 109 | """ 110 | save a batch of videos in one image sheet with shape of [batch_size * num_frames]. 
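Each row of the sheet holds the frames of one video (save_image is called with nrow=t below).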
111 | data: [b,c,t,h,w] 112 | """ 113 | assert(tensor.dim() == 5) 114 | b,c,t,h,w = tensor.shape 115 | imgs = [tensor[bi,:,ti, :, :] for bi in range(b) for ti in range(t)] 116 | torchvision.utils.save_image(imgs, save_path, normalize=True, nrow=t) 117 | 118 | 119 | # ---------------------------------------------------------------------------------------------- 120 | def npz_to_frames(data_path, res_dir, norm, num_frames=None, num_samples=None): 121 | start = time.time() 122 | arr = np.load(data_path) 123 | imgs = arr['arr_0'] # [N, T, H, W, 3] 124 | print('original data shape: ', imgs.shape) 125 | 126 | if num_samples is not None: 127 | imgs = imgs[:num_samples, :, :, :, :] 128 | print('after sample selection: ', imgs.shape) 129 | 130 | if num_frames is not None: 131 | imgs = imgs[:, :num_frames, :, :, :] 132 | print('after frame selection: ', imgs.shape) 133 | 134 | for vid in tqdm(range(imgs.shape[0]), desc='Video'): 135 | video_dir = os.path.join(res_dir, f'video{vid:04d}') 136 | os.makedirs(video_dir, exist_ok=True) 137 | for fid in range(imgs.shape[1]): 138 | frame = imgs[vid, fid, :, :, :] #HW3 139 | save_np_to_img(frame, os.path.join(video_dir, f'frame{fid:04d}.jpg'), norm=norm) 140 | print('Finish') 141 | print(f'Total time = {time.time()- start}') 142 | 143 | # ---------------------------------------------------------------------------------------------- 144 | def npz_to_gifs(data_path, res_dir, duration=0.2, start_idx=0, num_videos=None, mode='gif'): 145 | os.makedirs(res_dir, exist_ok=True) 146 | if isinstance(data_path, str): 147 | imgs = np.load(data_path)['arr_0'] # NTHWC 148 | elif isinstance(data_path, np.ndarray): 149 | imgs = data_path 150 | else: 151 | raise Exception 152 | 153 | for i in range(imgs.shape[0]): 154 | frames = [imgs[i,j,:,:,:] for j in range(imgs[i].shape[0])] # [(h,w,3)] 155 | if mode == 'gif': 156 | imageio.mimwrite(os.path.join(res_dir, f'samples_{start_idx+i}.gif'), frames, format='GIF', duration=duration) 157 | elif mode == 'mp4': 158 | frames = [torch.from_numpy(frame) for frame in frames] 159 | frames = torch.stack(frames, dim=0).to(torch.uint8) # [T, H, W, C] 160 | torchvision.io.write_video(os.path.join(res_dir, f'samples_{start_idx+i}.mp4'), 161 | frames, fps=0.5, video_codec='h264', options={'crf': '10'}) 162 | if i+ 1 == num_videos: 163 | break 164 | 165 | # ---------------------------------------------------------------------------------------------- 166 | def fill_with_black_squares(video, desired_len: int) -> Tensor: 167 | if len(video) >= desired_len: 168 | return video 169 | 170 | return torch.cat([ 171 | video, 172 | torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1), 173 | ], dim=0) 174 | 175 | # ---------------------------------------------------------------------------------------------- 176 | def load_num_videos(data_path, num_videos): 177 | # data_path can be either data_path of np array 178 | if isinstance(data_path, str): 179 | videos = np.load(data_path)['arr_0'] # NTHWC 180 | elif isinstance(data_path, np.ndarray): 181 | videos = data_path 182 | else: 183 | raise Exception 184 | 185 | if num_videos is not None: 186 | videos = videos[:num_videos, :, :, :, :] 187 | return videos 188 | 189 | # ---------------------------------------------------------------------------------------------- 190 | def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True): 191 | if isinstance(data_path, str): 192 | videos = load_num_videos(data_path, num_videos) 193 | elif 
isinstance(data_path, np.ndarray): 194 | videos = data_path 195 | else: 196 | raise Exception 197 | n,t,h,w,c = videos.shape 198 | 199 | videos_th = [] 200 | for i in range(n): 201 | video = videos[i, :,:,:,:] 202 | images = [video[j, :,:,:] for j in range(t)] 203 | images = [to_tensor(img) for img in images] 204 | video = torch.stack(images) 205 | videos_th.append(video) 206 | 207 | if num_frames is None: 208 | num_frames = videos.shape[1] 209 | if verbose: 210 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW 211 | else: 212 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 213 | 214 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 215 | if nrow is None: 216 | nrow = int(np.ceil(np.sqrt(n))) 217 | if verbose: 218 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] 219 | else: 220 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 221 | 222 | if os.path.dirname(out_path) != "": 223 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 224 | frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] 225 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) 226 | 227 | # ---------------------------------------------------------------------------------------------- 228 | def npz_to_gif_grid(data_path, out_path, n_cols=None, num_videos=20): 229 | arr = np.load(data_path) 230 | imgs = arr['arr_0'] # [N, T, H, W, 3] 231 | imgs = imgs[:num_videos] 232 | n, t, h, w, c = imgs.shape 233 | assert(n == num_videos) 234 | n_cols = n_cols if n_cols else imgs.shape[0] 235 | n_rows = np.ceil(imgs.shape[0] / n_cols).astype(np.int8) 236 | H, W = h * n_rows, w * n_cols 237 | grid = np.zeros((t, H, W, c), dtype=np.uint8) 238 | 239 | for i in range(n_rows): 240 | for j in range(n_cols): 241 | if i*n_cols+j < imgs.shape[0]: 242 | grid[:, i*h:(i+1)*h, j*w:(j+1)*w, :] = imgs[i*n_cols+j, :, :, :, :] 243 | 244 | videos = [grid[i] for i in range(grid.shape[0])] # grid: TH'W'C 245 | imageio.mimwrite(out_path, videos, format='GIF', duration=0.5,palettesize=256) 246 | 247 | 248 | # ---------------------------------------------------------------------------------------------- 249 | def torch_to_video_grid(videos, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True): 250 | """ 251 | videos: -1 ~ 1, torch.Tensor, BCTHW 252 | """ 253 | n,t,h,w,c = videos.shape 254 | videos_th = [videos[i, ...] 
for i in range(n)] 255 | if verbose: 256 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW 257 | else: 258 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 259 | 260 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 261 | if nrow is None: 262 | nrow = int(np.ceil(np.sqrt(n))) 263 | if verbose: 264 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] 265 | else: 266 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 267 | 268 | if os.path.dirname(out_path) != "": 269 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 270 | frame_grids = ((torch.stack(frame_grids) + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] 271 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) 272 | 273 | 274 | def log_txt_as_img(wh, xc, size=10): 275 | # wh a tuple of (width, height) 276 | # xc a list of captions to plot 277 | b = len(xc) 278 | txts = list() 279 | for bi in range(b): 280 | txt = Image.new("RGB", wh, color="white") 281 | draw = ImageDraw.Draw(txt) 282 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 283 | nc = int(40 * (wh[0] / 256)) 284 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 285 | 286 | try: 287 | draw.text((0, 0), lines, fill="black", font=font) 288 | except UnicodeEncodeError: 289 | print("Cant encode string for logging. Skipping.") 290 | 291 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 292 | txts.append(txt) 293 | txts = np.stack(txts) 294 | txts = torch.tensor(txts) 295 | return txts 296 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord==0.6.0 2 | einops==0.3.0 3 | eval==0.0.5 4 | imageio==2.9.0 5 | numpy==1.19.2 6 | omegaconf==2.3.0 7 | opencv_python_headless==4.6.0.66 8 | packaging==21.3 9 | Pillow==9.5.0 10 | pudb==2019.2 11 | pytorch_lightning==1.4.2 12 | PyYAML==6.0 13 | scikit_learn==1.2.2 14 | setuptools==59.5.0 15 | torch==1.9.0+cu111 16 | torchvision==0.10.0+cu111 17 | -f https://download.pytorch.org/whl/torch_stable.html 18 | tqdm==4.64.0 19 | transformers==4.24.0 20 | torchmetrics==0.6.0 21 | moviepy 22 | av 23 | six 24 | timm 25 | test-tube 26 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 27 | -e . -------------------------------------------------------------------------------- /requirements_h800_gpu.txt: -------------------------------------------------------------------------------- 1 | decord==0.6.0 2 | einops==0.3.0 3 | eval==0.0.5 4 | imageio==2.9.0 5 | numpy==1.19.2 6 | omegaconf==2.3.0 7 | opencv_python_headless==4.6.0.66 8 | packaging==21.3 9 | Pillow==9.5.0 10 | pudb==2019.2 11 | pytorch_lightning==1.4.2 12 | PyYAML==6.0 13 | scikit_learn==1.2.2 14 | setuptools==59.5.0 15 | torch==2.2.2 16 | torchvision==0.17.2 17 | tqdm==4.64.0 18 | transformers==4.24.0 19 | torchmetrics==0.6.0 20 | moviepy 21 | av 22 | six 23 | timm 24 | test-tube 25 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 26 | -e . 
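A minimal sketch of how the FVD/KVD helpers defined in scripts/fvd_utils/fvd_utils.py (listed further below) fit together, assuming the real and fake clips are already stored as uint8 .npz arrays of shape (N, T, H, W, 3); the file names and batch size here are illustrative placeholders, and scripts/eval_cal_fvd_kvd.py below implements the full pipeline with a proper dataloader:

import numpy as np
import torch

from scripts.fvd_utils.fvd_utils import load_fvd_model, get_fvd_logits, frechet_distance, polynomial_mmd

device = torch.device('cuda')
i3d = load_fvd_model(device)  # pretrained I3D feature extractor (expects i3d_pretrained_400.pt next to fvd_utils.py)

# real and fake clips as uint8 arrays in [0, 255], shape (N, T, H, W, 3); the .npz paths are placeholders
real_videos = np.load('real_samples.npz')['arr_0']
fake_videos = np.load('fake_samples.npz')['arr_0']

real_emb = get_fvd_logits(real_videos, i3d=i3d, device=device, batch_size=8)
fake_emb = get_fvd_logits(fake_videos, i3d=i3d, device=device, batch_size=8)

fvd = frechet_distance(fake_emb, real_emb)            # Frechet Video Distance
kvd = polynomial_mmd(fake_emb.cpu(), real_emb.cpu())  # Kernel Video Distance (polynomial MMD)
print(f'FVD = {float(fvd):.2f}, KVD = {float(kvd):.2f}')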
-------------------------------------------------------------------------------- /scripts/eval_cal_fvd_kvd.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import math 6 | import json 7 | import argparse 8 | import numpy as np 9 | from tqdm import tqdm 10 | from omegaconf import OmegaConf 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from lvdm.utils.common_utils import instantiate_from_config, shift_dim 16 | from scripts.fvd_utils.fvd_utils import get_fvd_logits, frechet_distance, load_fvd_model, polynomial_mmd 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--fake_path', type=str, default='fake data path') 21 | parser.add_argument('--real_path', type=str, default='real data dir') 22 | parser.add_argument('--yaml', type=str, default=None, help="training config file for construct dataloader") 23 | parser.add_argument('--batch_size', type=int, default=32) 24 | parser.add_argument('--num_workers', type=int, default=8) 25 | parser.add_argument('--n_runs', type=int, default=1, help='calculate multiple times') 26 | parser.add_argument('--gpu_id', type=int, default=None) 27 | parser.add_argument('--res_dir', type=str, default='') 28 | parser.add_argument('--n_sample', type=int, default=2048) 29 | parser.add_argument('--start_clip_id', type=int, default=None, help="for the evaluation of long video generation") 30 | parser.add_argument('--end_clip_id', type=int, default=None, help="for the evaluation of long video generation") 31 | args = parser.parse_args() 32 | return args 33 | 34 | def run(args, run_id, fake_path, i3d, loader, device): 35 | start = time.time() 36 | 37 | print('load fake videos in numpy ...') 38 | if fake_path.endswith('npz'): 39 | fake_data = np.load(fake_path)['arr_0'] 40 | elif fake_path.endswith('npy'): 41 | fake_data = np.load(fake_path) 42 | else: 43 | print(fake_path) 44 | raise NotImplementedError 45 | 46 | s = time.time() 47 | real_embeddings = [] 48 | for batch in tqdm(loader, desc="Extract Real Embedding", total=math.ceil(args.n_sample / args.batch_size)): 49 | if len(real_embeddings)*args.batch_size >=args.n_sample: break 50 | videos = shift_dim((batch[args.image_key]+1)*255/2, 1, -1).int().data.numpy() # convert to 0-255 51 | real_embeddings.append(get_fvd_logits(videos, i3d=i3d, device=device, batch_size=args.batch_size)) 52 | real_embeddings = torch.cat(real_embeddings, 0)[:args.n_sample] 53 | t = time.time() - s 54 | s = time.time() 55 | fake_embeddings = [] 56 | n_batch = fake_data.shape[0]//args.batch_size 57 | for i in tqdm(range(n_batch), desc="Extract Fake Embedding"): 58 | fake_embeddings.append(get_fvd_logits(fake_data[i*args.batch_size:(i+1)*args.batch_size], i3d=i3d, device=device, batch_size=args.batch_size)) 59 | fake_embeddings = torch.cat(fake_embeddings, 0)[:args.n_sample] 60 | t = time.time() - s 61 | 62 | print('calculate fvd ...') 63 | fvd = frechet_distance(fake_embeddings, real_embeddings) 64 | fvd = fvd.cpu().numpy() # np float32 65 | 66 | print('calculate kvd ...') 67 | kvd = polynomial_mmd(fake_embeddings.cpu(), real_embeddings.cpu()) # np float 64 68 | total = time.time() - start 69 | 70 | print(f'Run_id = {run_id}') 71 | print(f'FVD = {fvd:.2f}') 72 | print(f'KVD = {kvd:.2f}') 73 | print(f'Time = {total:.2f}') 74 | return [fvd, kvd, total] 75 | 76 | def run_multitimes(args, i3d, loader, device): 77 | res_all = [] 78 | for i in range(args.n_runs): 79 | run_id = i 80 | res = 
run(args, run_id, fake_path=args.fake_path, i3d=i3d, loader=loader, device=device) 81 | res_all.append(np.array(res)) 82 | res_avg = np.mean(np.stack(res_all, axis=0), axis=0) 83 | res_std = np.std(np.stack(res_all, axis=0), axis=0) 84 | 85 | print(f'Results of {args.n_runs} runs:') 86 | print(f'FVD = {res_avg[0]} ({res_std[0]})') 87 | print(f'KVD = {res_avg[1]} ({res_std[1]})') 88 | print(f'Time = {res_avg[2]} ({res_std[2]})') 89 | 90 | # dump results 91 | res={'FVD': f'{res_avg[0]} ({res_std[0]})', 92 | 'KVD': f'{res_avg[1]} ({res_std[1]})', 93 | 'Time': f'{res_avg[2]} ({res_std[2]})', 94 | 'Clip_path': f'{args.fake_path}' 95 | } 96 | f = open(os.path.join(args.res_dir, f'{args.n_runs}runs_fvd_stat.json'), 'w') 97 | json.dump(res, f) 98 | f.close() 99 | 100 | def run_multitimes_dir(args, i3d, loader, device): 101 | datalist = sorted(os.listdir(args.fake_path)) 102 | for i in range(args.n_runs): 103 | 104 | run_id = i 105 | 106 | for idx, path in tqdm(enumerate(datalist)): 107 | if args.start_clip_id is not None and idx < args.start_clip_id: 108 | continue 109 | 110 | print(f'Cal metrics for clip: {idx}, data: {path}') 111 | fvd, kvd, total = run(args, run_id, fake_path=os.path.join(args.fake_path, path), 112 | i3d=i3d, loader=loader, device=device) 113 | print(f"Run id {run_id}, Clip {idx}, FVD={fvd}, KVD={kvd}, Time={total}") 114 | 115 | # dump results 116 | fvd = float(fvd) if isinstance(fvd, np.ndarray) else fvd 117 | kvd = float(kvd) if isinstance(kvd, np.ndarray) else kvd 118 | 119 | res={'FVD': fvd, 'KVD': kvd, 'Time': total, 'Clip_path': path} 120 | f = open(os.path.join(args.res_dir, f'run{run_id}_clip{idx}_{device}.json'), 'w') 121 | json.dump(res, f) 122 | f.close() 123 | 124 | if args.end_clip_id is not None and idx == args.end_clip_id: 125 | break 126 | 127 | 128 | if __name__ == '__main__': 129 | args = get_args() 130 | print(args) 131 | os.makedirs(args.res_dir, exist_ok=True) 132 | 133 | print('load i3d ...') 134 | if args.gpu_id is not None: 135 | device = torch.device(f'cuda:{args.gpu_id}') 136 | else: 137 | device = torch.device('cuda') 138 | i3d = load_fvd_model(device) 139 | 140 | print('prepare dataset and dataloader ...') 141 | config = OmegaConf.load(args.yaml) 142 | config.data.params.train.params.data_root=args.real_path 143 | dataset = instantiate_from_config(config.data.params.train) 144 | if 'first_stage_key' in config.model.params: 145 | args.image_key = config.model.params.first_stage_key 146 | loader = DataLoader(dataset, 147 | batch_size=args.batch_size, 148 | num_workers=args.num_workers, 149 | shuffle=True) 150 | 151 | # run 152 | if os.path.isdir(args.fake_path): 153 | run_multitimes_dir(args, i3d, loader, device) 154 | else: 155 | run_multitimes(args, i3d, loader, device) 156 | 157 | -------------------------------------------------------------------------------- /scripts/fvd_utils/fvd_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from sklearn.metrics.pairwise import polynomial_kernel 6 | 7 | from scripts.fvd_utils.pytorch_i3d import InceptionI3d 8 | 9 | 10 | MAX_BATCH = 8 11 | TARGET_RESOLUTION = (224, 224) 12 | 13 | 14 | def preprocess(videos, target_resolution): 15 | # videos in {0, ..., 255} as np.uint8 array 16 | b, t, h, w, c = videos.shape 17 | all_frames = torch.FloatTensor(videos).flatten(end_dim=1) # (b * t, h, w, c) 18 | all_frames = all_frames.permute(0, 3, 1, 2).contiguous() # (b * t, c, h, w) 19 | 
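# The frames are then bilinearly resized to TARGET_RESOLUTION (224x224, the input
# size the pretrained I3D expects), regrouped into clips of shape (b, c, t, 224, 224),
# and rescaled from [0, 255] to [-1, 1] before being fed to the feature extractor.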
resized_videos = F.interpolate(all_frames, size=target_resolution, 20 | mode='bilinear', align_corners=False) 21 | resized_videos = resized_videos.view(b, t, c, *target_resolution) 22 | output_videos = resized_videos.transpose(1, 2).contiguous() # (b, c, t, *) 23 | scaled_videos = 2. * output_videos / 255. - 1 # [-1, 1] 24 | return scaled_videos 25 | 26 | 27 | def get_logits(i3d, videos, device, batch_size=None): 28 | if batch_size is None: 29 | batch_size = MAX_BATCH 30 | with torch.no_grad(): 31 | logits = [] 32 | for i in range(0, videos.shape[0], batch_size): 33 | batch = videos[i:i + batch_size].to(device) 34 | logits.append(i3d(batch)) 35 | logits = torch.cat(logits, dim=0) 36 | return logits 37 | 38 | 39 | def get_fvd_logits(videos, i3d, device, batch_size=None): 40 | videos = preprocess(videos, TARGET_RESOLUTION) 41 | embeddings = get_logits(i3d, videos, device, batch_size=batch_size) 42 | return embeddings 43 | 44 | 45 | def load_fvd_model(device): 46 | i3d = InceptionI3d(400, in_channels=3).to(device) 47 | current_dir = os.path.dirname(os.path.abspath(__file__)) 48 | i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt') 49 | i3d.load_state_dict(torch.load(i3d_path, map_location=device)) 50 | i3d.eval() 51 | return i3d 52 | 53 | 54 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 55 | def _symmetric_matrix_square_root(mat, eps=1e-10): 56 | u, s, v = torch.svd(mat) 57 | si = torch.where(s < eps, s, torch.sqrt(s)) 58 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 59 | 60 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 61 | def trace_sqrt_product(sigma, sigma_v): 62 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 63 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 64 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 65 | 66 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 67 | def cov(m, rowvar=False): 68 | '''Estimate a covariance matrix given data. 69 | 70 | Covariance indicates the level to which two variables vary together. 71 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 72 | then the covariance matrix element `C_{ij}` is the covariance of 73 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 74 | 75 | Args: 76 | m: A 1-D or 2-D array containing multiple variables and observations. 77 | Each row of `m` represents a variable, and each column a single 78 | observation of all those variables. 79 | rowvar: If `rowvar` is True, then each row represents a 80 | variable, with observations in the columns. Otherwise, the 81 | relationship is transposed: each column represents a variable, 82 | while the rows contain observations. 83 | 84 | Returns: 85 | The covariance matrix of the variables. 
86 | ''' 87 | if m.dim() > 2: 88 | raise ValueError('m has more than 2 dimensions') 89 | if m.dim() < 2: 90 | m = m.view(1, -1) 91 | if not rowvar and m.size(0) != 1: 92 | m = m.t() 93 | 94 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 95 | m_center = m - torch.mean(m, dim=1, keepdim=True) 96 | mt = m_center.t() # if complex: mt = m.t().conj() 97 | return fact * m_center.matmul(mt).squeeze() 98 | 99 | 100 | def frechet_distance(x1, x2): 101 | x1 = x1.flatten(start_dim=1) 102 | x2 = x2.flatten(start_dim=1) 103 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 104 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 105 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 106 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 107 | mean = torch.sum((m - m_w) ** 2) 108 | fd = trace + mean 109 | return fd 110 | 111 | 112 | 113 | def polynomial_mmd(X, Y): 114 | m = X.shape[0] 115 | n = Y.shape[0] 116 | # compute kernels 117 | K_XX = polynomial_kernel(X) 118 | K_YY = polynomial_kernel(Y) 119 | K_XY = polynomial_kernel(X, Y) 120 | # compute mmd distance 121 | K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1)) 122 | K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1)) 123 | K_XY_sum = K_XY.sum() / (m * n) 124 | mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum 125 | return mmd 126 | -------------------------------------------------------------------------------- /scripts/sample_long_videos_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from tqdm import trange 4 | from einops import repeat 5 | from lvdm.samplers.ddim import DDIMSampler 6 | from lvdm.models.ddpm3d import FrameInterpPredLatentDiffusion 7 | 8 | # ------------------------------------------------------------------------------------------ 9 | def add_conditions_long_video(model, cond, batch_size, 10 | sample_cond_noise_level=None, 11 | mask=None, 12 | input_shape=None, 13 | T_cond_pred=None, 14 | cond_frames=None, 15 | ): 16 | assert(isinstance(cond, dict)) 17 | try: 18 | device = model.device 19 | except: 20 | device = next(model.parameters()).device 21 | 22 | # add cond noisy level 23 | if sample_cond_noise_level is not None: 24 | if getattr(model, "noisy_cond", False): 25 | assert(sample_cond_noise_level is not None) 26 | s = sample_cond_noise_level 27 | s = repeat(torch.tensor([s]), '1 -> b', b=batch_size) 28 | s = s.to(device).long() 29 | else: 30 | s = None 31 | cond['s'] =s 32 | 33 | # add cond mask 34 | if mask is not None: 35 | if mask == "uncond": 36 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device) 37 | elif mask == "pred": 38 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device) 39 | mask[:, :, :T_cond_pred, :, :] = 1 40 | elif mask == "interp": 41 | mask = torch.zeros(batch_size, 1, *input_shape[2:], device=device) 42 | mask[:, :, 0, :, :] = 1 43 | mask[:, :, -1, :, :] = 1 44 | else: 45 | raise NotImplementedError 46 | cond['mask'] = mask 47 | 48 | # add cond_frames 49 | if cond_frames is not None: 50 | if sample_cond_noise_level is not None: 51 | noise = torch.randn_like(cond_frames) 52 | noisy_cond = model.q_sample(x_start=cond_frames, t=s, noise=noise) 53 | else: 54 | noisy_cond = cond_frames 55 | cond['noisy_cond'] = noisy_cond 56 | 57 | return cond 58 | 59 | # ------------------------------------------------------------------------------------------ 60 | @torch.no_grad() 61 | def sample_clip(model, shape, sample_type="ddpm", ddim_steps=None, eta=1.0, 
cond=None, 62 | uc=None, uncond_scale=1., uc_type=None, **kwargs): 63 | 64 | log = dict() 65 | 66 | # sample batch 67 | with model.ema_scope("EMA"): 68 | t0 = time.time() 69 | if sample_type == "ddpm": 70 | samples = model.p_sample_loop(cond, shape, return_intermediates=False, verbose=False, 71 | unconditional_guidance_scale=uncond_scale, 72 | unconditional_conditioning=uc, uc_type=uc_type, ) 73 | elif sample_type == "ddim": 74 | ddim = DDIMSampler(model) 75 | samples, intermediates = ddim.sample(S=ddim_steps, batch_size=shape[0], shape=shape[1:], 76 | conditioning=cond, eta=eta, verbose=False, 77 | unconditional_guidance_scale=uncond_scale, 78 | unconditional_conditioning=uc, uc_type=uc_type,) 79 | 80 | t1 = time.time() 81 | log["sample"] = samples 82 | log["time"] = t1 - t0 83 | log['throughput'] = samples.shape[0] / (t1 - t0) 84 | return log 85 | 86 | # ------------------------------------------------------------------------------------------ 87 | @torch.no_grad() 88 | def autoregressive_pred(model, batch_size, *args, 89 | T_cond=1, 90 | n_pred_steps=3, 91 | sample_cond_noise_level=None, 92 | decode_single_video_allframes=False, 93 | uncond_scale=1.0, 94 | uc_type=None, 95 | max_z_t=None, 96 | overlap_t=0, 97 | **kwargs): 98 | 99 | model.sample_cond_noise_level = sample_cond_noise_level 100 | 101 | image_size = model.image_size 102 | image_size = [image_size, image_size] if isinstance(image_size, int) else image_size 103 | C = model.model.diffusion_model.in_channels-1 if isinstance(model, FrameInterpPredLatentDiffusion) \ 104 | else model.model.diffusion_model.in_channels 105 | T = model.model.diffusion_model.temporal_length 106 | shape = [batch_size, C, T, *image_size] 107 | 108 | t0 = time.time() 109 | long_samples = [] 110 | 111 | # uncond sample 112 | log = dict() 113 | 114 | # -------------------------------------------------------------------- 115 | # make condition 116 | cond = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape) 117 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0): 118 | print('Use Uncondition guidance') 119 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape) 120 | else: 121 | print('NO Uncondition guidance') 122 | uc=None 123 | 124 | # sample an initial clip (unconditional) 125 | sample = sample_clip(model, shape, cond=cond, 126 | uc=uc, uncond_scale=uncond_scale, uc_type=uc_type, 127 | **kwargs)['sample'] 128 | long_samples.append(sample.cpu()) 129 | 130 | # extend 131 | for i in range(n_pred_steps): 132 | T = sample.shape[2] 133 | cond_z0 = sample[:, :, T-T_cond:, :, :] 134 | assert(cond_z0.shape[2] == T_cond) 135 | 136 | # make prediction model's condition 137 | cond = add_conditions_long_video(model, {}, batch_size, mask="pred", input_shape=shape, cond_frames=cond_z0, sample_cond_noise_level=sample_cond_noise_level) 138 | ## unconditional_guidance's condition 139 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0): 140 | print('Use Uncondition guidance') 141 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape) 142 | else: 143 | print('NO Uncondition guidance') 144 | uc=None 145 | 146 | # sample a short clip (condition on previous latents) 147 | sample = sample_clip(model, shape, *args, cond=cond, uc=uc, uncond_scale=uncond_scale, uc_type=uc_type, 148 | **kwargs)['sample'] 149 | ext = sample[:, :, T_cond:, :, :] 150 | assert(ext.dim() == 5) 151 | 
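# Keep only the T - T_cond newly generated latent frames: the first T_cond frames
# of `sample` correspond to the conditioning frames copied from the end of the
# previous clip, so appending them again would duplicate content. The initial clip
# plus these extensions are concatenated along the temporal axis below.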
long_samples.append(ext.cpu()) 152 | progress = (i+1)/n_pred_steps * 100 153 | print(f"Finish pred step {int(progress)}% [{i+1}/{n_pred_steps}]") 154 | torch.cuda.empty_cache() 155 | long_samples = torch.cat(long_samples, dim=2) 156 | 157 | t1 = time.time() 158 | log["sample"] = long_samples 159 | log["time"] = t1 - t0 160 | log['throughput'] = long_samples.shape[0] / (t1 - t0) 161 | return log 162 | 163 | 164 | # ------------------------------------------------------------------------------------------ 165 | @torch.no_grad() 166 | def interpolate(base_samples, 167 | model, batch_size, *args, sample_video=False, 168 | sample_cond_noise_level=None, decode_single_video_allframes=False, 169 | uncond_scale=1.0, uc_type=None, max_z_t=None, 170 | overlap_t=0, prompt=None, config=None, 171 | interpolate_cond_fps=None, 172 | **kwargs): 173 | 174 | model.sample_cond_noise_level = sample_cond_noise_level 175 | 176 | N, c, t, h, w = base_samples.shape 177 | n_steps = len(range(0, t-3, 3)) 178 | device = next(model.parameters()).device 179 | if N < batch_size: 180 | batch_size = N 181 | elif N > batch_size: 182 | raise ValueError 183 | assert(N == batch_size) 184 | 185 | C = model.model.diffusion_model.in_channels-1 if isinstance(model, FrameInterpPredLatentDiffusion) and model.concat_mask_on_input \ 186 | else model.model.diffusion_model.in_channels 187 | image_size = model.image_size 188 | image_size = [image_size, image_size] if isinstance(image_size, int) else image_size 189 | T = model.model.diffusion_model.temporal_length 190 | shape = [batch_size, C, T, *image_size ] 191 | 192 | t0 = time.time() 193 | long_samples = [] 194 | cond = {} 195 | for i in trange(n_steps, desc='Interpolation Steps'): 196 | cond_z0 = base_samples[:, :, i:i+2, :, :].cuda() 197 | # make prediction model's condition 198 | cond = add_conditions_long_video(model, {}, batch_size, mask="interp", input_shape=shape, cond_frames=cond_z0, sample_cond_noise_level=sample_cond_noise_level) 199 | ## unconditional_guidance's condition 200 | if (uc_type is None and uncond_scale != 1.0) or (uc_type == 'cfg_original' and uncond_scale != 0.0): 201 | print('Use Uncondition guidance') 202 | uc = add_conditions_long_video(model, {}, batch_size, mask="uncond", input_shape=shape) 203 | else: 204 | print('NO Uncondition guidance') 205 | uc=None 206 | 207 | # sample an interpolation clip 208 | sample = sample_clip(model, shape, *args, cond=cond, uc=uc, 209 | uncond_scale=uncond_scale, 210 | uc_type=uc_type, 211 | **kwargs)['sample'] 212 | ext = sample[:, :, 1:-1, :, :] 213 | assert(ext.dim() == 5) 214 | assert(ext.shape[2] == T - 2) 215 | # ----------------------------------------------- 216 | if i != n_steps - 1: 217 | long_samples.extend([cond_z0[:, :, 0:1, :, :].cpu(), ext.cpu()]) 218 | else: 219 | long_samples.extend([cond_z0[:, :, 0:1, :, :].cpu(), ext.cpu(), cond_z0[:, :, 1:, :, :].cpu()]) 220 | # ----------------------------------------------- 221 | 222 | torch.cuda.empty_cache() 223 | long_samples = torch.cat(long_samples, dim=2) 224 | 225 | # decode 226 | long_samples_decoded = [] 227 | print('Decoding ...') 228 | for i in trange(long_samples.shape[0]): 229 | torch.cuda.empty_cache() 230 | 231 | long_sample = long_samples[i].unsqueeze(0).cuda() 232 | if overlap_t != 0: 233 | print('Use overlapped decoding') 234 | res = model.overlapped_decode(long_sample, max_z_t=max_z_t, 235 | overlap_t=overlap_t).cpu() 236 | else: 237 | res = model.decode_first_stage(long_sample, 238 | bs=None, 239 | 
decode_single_video_allframes=decode_single_video_allframes, 240 | max_z_t=max_z_t).cpu() 241 | long_samples_decoded.append(res) 242 | torch.cuda.empty_cache() 243 | 244 | long_samples_decoded = torch.cat(long_samples_decoded, dim=0) 245 | 246 | # log 247 | t1 = time.time() 248 | log = {} 249 | log["sample"] = long_samples_decoded 250 | log["sample_z"] = long_samples 251 | log["time"] = t1 - t0 252 | torch.cuda.empty_cache() 253 | return log 254 | -------------------------------------------------------------------------------- /scripts/sample_text2video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import yaml, math 5 | from tqdm import trange 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | import torch.distributed as dist 10 | from pytorch_lightning import seed_everything 11 | 12 | from lvdm.samplers.ddim import DDIMSampler 13 | from lvdm.utils.common_utils import str2bool 14 | from lvdm.utils.dist_utils import setup_dist, gather_data 15 | from scripts.sample_utils import (load_model, 16 | get_conditions, make_model_input_shape, torch_to_np, sample_batch, 17 | save_results, 18 | save_args 19 | ) 20 | 21 | 22 | # ------------------------------------------------------------------------------------------ 23 | def get_parser(): 24 | parser = argparse.ArgumentParser() 25 | # basic args 26 | parser.add_argument("--ckpt_path", type=str, help="model checkpoint path") 27 | parser.add_argument("--config_path", type=str, help="model config path (a yaml file)") 28 | parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).") 29 | parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/") 30 | # device args 31 | parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for multi-gpu case)", default=False) 32 | parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0) 33 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0) 34 | # sampling args 35 | parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2) 36 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1) 37 | parser.add_argument("--decode_frame_bs", type=int, help="frame batch size for framewise decoding", default=1) 38 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddim", choices=["ddpm", "ddim"]) 39 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50) 40 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0) 41 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)") 42 | parser.add_argument("--num_frames", type=int, default=16, help="number of input frames") 43 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",) 44 | parser.add_argument("--cfg_scale", type=float, default=15.0, help="classifier-free guidance scale") 45 | # saving args 46 | parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files", 
choices=["True", "true", "False", "false"]) 47 | parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",) 48 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",) 49 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",) 50 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",) 51 | return parser 52 | 53 | # ------------------------------------------------------------------------------------------ 54 | @torch.no_grad() 55 | def sample_text2video(model, prompt, n_samples, batch_size, 56 | sample_type="ddim", sampler=None, 57 | ddim_steps=50, eta=1.0, cfg_scale=7.5, 58 | decode_frame_bs=1, 59 | ddp=False, all_gather=True, 60 | batch_progress=True, show_denoising_progress=False, 61 | num_frames=None, 62 | ): 63 | # get cond vector 64 | assert(model.cond_stage_model is not None) 65 | cond_embd = get_conditions(prompt, model, batch_size) 66 | uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None 67 | 68 | # sample batches 69 | all_videos = [] 70 | n_iter = math.ceil(n_samples / batch_size) 71 | iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter) 72 | for _ in iterator: 73 | noise_shape = make_model_input_shape(model, batch_size, T=num_frames) 74 | samples_latent = sample_batch(model, noise_shape, cond_embd, 75 | sample_type=sample_type, 76 | sampler=sampler, 77 | ddim_steps=ddim_steps, 78 | eta=eta, 79 | unconditional_guidance_scale=cfg_scale, 80 | uc=uncond_embd, 81 | denoising_progress=show_denoising_progress, 82 | ) 83 | samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False) 84 | 85 | # gather samples from multiple gpus 86 | if ddp and all_gather: 87 | data_list = gather_data(samples, return_np=False) 88 | all_videos.extend([torch_to_np(data) for data in data_list]) 89 | else: 90 | all_videos.append(torch_to_np(samples)) 91 | 92 | all_videos = np.concatenate(all_videos, axis=0) 93 | assert(all_videos.shape[0] >= n_samples) 94 | return all_videos 95 | 96 | 97 | 98 | # ------------------------------------------------------------------------------------------ 99 | def main(): 100 | """ 101 | text-to-video generation 102 | """ 103 | parser = get_parser() 104 | opt, unknown = parser.parse_known_args() 105 | os.makedirs(opt.save_dir, exist_ok=True) 106 | save_args(opt.save_dir, opt) 107 | 108 | # set device 109 | if opt.ddp: 110 | setup_dist(opt.local_rank) 111 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size()) 112 | gpu_id = None 113 | else: 114 | gpu_id = opt.gpu_id 115 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" 116 | 117 | # set random seed 118 | if opt.seed is not None: 119 | if opt.ddp: 120 | seed = opt.local_rank + opt.seed 121 | else: 122 | seed = opt.seed 123 | seed_everything(seed) 124 | 125 | # load & merge config 126 | config = OmegaConf.load(opt.config_path) 127 | cli = OmegaConf.from_dotlist(unknown) 128 | config = OmegaConf.merge(config, cli) 129 | print("config: \n", config) 130 | 131 | # get model & sampler 132 | model, _, _ = load_model(config, opt.ckpt_path) 133 | ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None 134 | 135 | # prepare prompt 136 | if opt.prompt.endswith(".txt"): 137 | opt.prompt_file = opt.prompt 138 | opt.prompt = None 139 | else: 140 | opt.prompt_file = None 141 | 142 | if 
opt.prompt_file is not None: 143 | f = open(opt.prompt_file, 'r') 144 | prompts, line_idx = [], [] 145 | for idx, line in enumerate(f.readlines()): 146 | l = line.strip() 147 | if len(l) != 0: 148 | prompts.append(l) 149 | line_idx.append(idx) 150 | f.close() 151 | cmd = f"cp {opt.prompt_file} {opt.save_dir}" 152 | os.system(cmd) 153 | else: 154 | prompts = [opt.prompt] 155 | line_idx = [None] 156 | 157 | # go 158 | start = time.time() 159 | for prompt in prompts: 160 | # sample 161 | samples = sample_text2video(model, prompt, opt.n_samples, opt.batch_size, 162 | sample_type=opt.sample_type, sampler=ddim_sampler, 163 | ddim_steps=opt.ddim_steps, eta=opt.eta, 164 | cfg_scale=opt.cfg_scale, 165 | decode_frame_bs=opt.decode_frame_bs, 166 | ddp=opt.ddp, show_denoising_progress=opt.show_denoising_progress, 167 | num_frames=opt.num_frames, 168 | ) 169 | # save 170 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp): 171 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 172 | save_name = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 173 | if opt.seed is not None: 174 | save_name = save_name + f"_seed{seed:05d}" 175 | save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps) 176 | print("Finish sampling!") 177 | print(f"Run time = {(time.time() - start):.2f} seconds") 178 | 179 | if opt.ddp: 180 | dist.destroy_process_group() 181 | 182 | 183 | # ------------------------------------------------------------------------------------------ 184 | if __name__ == "__main__": 185 | main() -------------------------------------------------------------------------------- /scripts/sample_uncond.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 3 | import time 4 | import math 5 | import argparse 6 | import numpy as np 7 | from tqdm import trange 8 | from omegaconf import OmegaConf 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from pytorch_lightning import seed_everything 13 | 14 | from lvdm.samplers.ddim import DDIMSampler 15 | from lvdm.utils.common_utils import torch_to_np, str2bool 16 | from lvdm.utils.dist_utils import setup_dist, gather_data 17 | from scripts.sample_utils import ( 18 | load_model, 19 | save_args, 20 | make_model_input_shape, 21 | sample_batch, 22 | save_results, 23 | ) 24 | 25 | # ------------------------------------------------------------------------------------------ 26 | def get_parser(): 27 | parser = argparse.ArgumentParser() 28 | # basic args 29 | parser.add_argument("--ckpt_path", type=str, help="model checkpoint path") 30 | parser.add_argument("--config_path", type=str, help="model config path (a yaml file)") 31 | parser.add_argument("--prompt", type=str, help="input text prompts for text2video (a sentence OR a txt file).") 32 | parser.add_argument("--save_dir", type=str, help="results saving dir", default="results/") 33 | # device args 34 | parser.add_argument("--ddp", action='store_true', help="whether use pytorch ddp mode for parallel sampling (recommend for multi-gpu case)", default=False) 35 | parser.add_argument("--local_rank", type=int, help="is used for pytorch ddp mode", default=0) 36 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0) 37 | # sampling args 38 | parser.add_argument("--n_samples", type=int, help="how many samples for each text prompt", default=2) 39 | parser.add_argument("--batch_size", type=int, help="video batch size for 
sampling", default=1) 40 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddpm", choices=["ddpm", "ddim"]) 41 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50) 42 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0) 43 | parser.add_argument("--seed", type=int, default=None, help="fix a seed for randomness (If you want to reproduce the sample results)") 44 | parser.add_argument("--num_frames", type=int, default=None, help="number of input frames") 45 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether show denoising progress during sampling one batch",) 46 | parser.add_argument("--uncond_scale", type=float, default=15.0, help="uncondition guidance scale") 47 | # saving args 48 | parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files", choices=["True", "true", "False", "false"]) 49 | parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",) 50 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",) 51 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",) 52 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",) 53 | return parser 54 | 55 | # ------------------------------------------------------------------------------------------ 56 | @torch.no_grad() 57 | def sample(model, noise_shape, n_iters, ddp=False, **kwargs): 58 | all_videos = [] 59 | for _ in trange(n_iters, desc="Sampling Batches (unconditional)"): 60 | samples = sample_batch(model, noise_shape, condition=None, **kwargs) 61 | samples = model.decode_first_stage(samples) 62 | if ddp: # gather samples from multiple gpus 63 | data_list = gather_data(samples, return_np=False) 64 | all_videos.extend([torch_to_np(data) for data in data_list]) 65 | else: 66 | all_videos.append(torch_to_np(samples)) 67 | all_videos = np.concatenate(all_videos, axis=0) 68 | return all_videos 69 | 70 | # ------------------------------------------------------------------------------------------ 71 | def main(): 72 | """ 73 | unconditional generation of short videos 74 | """ 75 | parser = get_parser() 76 | opt, unknown = parser.parse_known_args() 77 | os.makedirs(opt.save_dir, exist_ok=True) 78 | save_args(opt.save_dir, opt) 79 | 80 | # set device 81 | if opt.ddp: 82 | setup_dist(opt.local_rank) 83 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size()) 84 | else: 85 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{opt.gpu_id}" 86 | 87 | # set random seed 88 | if opt.seed is not None: 89 | seed = opt.local_rank + opt.seed if opt.ddp else opt.seed 90 | seed_everything(seed) 91 | 92 | # load & merge config 93 | config = OmegaConf.load(opt.config_path) 94 | cli = OmegaConf.from_dotlist(unknown) 95 | config = OmegaConf.merge(config, cli) 96 | print("config: \n", config) 97 | 98 | # get model & sampler 99 | model, _, _ = load_model(config, opt.ckpt_path) 100 | ddim_sampler = DDIMSampler(model) if opt.sample_type == "ddim" else None 101 | 102 | # sample 103 | start = time.time() 104 | noise_shape = make_model_input_shape(model, opt.batch_size, T=opt.num_frames) 105 | ngpus = 1 if not opt.ddp else dist.get_world_size() 106 | n_iters = math.ceil(opt.n_samples 
/ (ngpus * opt.batch_size)) 107 | samples = sample(model, noise_shape, n_iters, sampler=ddim_sampler, **vars(opt)) 108 | assert(samples.shape[0] >= opt.n_samples) 109 | 110 | # save 111 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp): 112 | # default to 'results' as the file name when no seed is given 113 | save_name = f"seed{seed:05d}" if opt.seed is not None else "results" 114 | save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps, save_mp4=opt.save_mp4, save_mp4_sheet=opt.save_mp4_sheet, save_npz=opt.save_npz, save_jpg=opt.save_jpg) 115 | print("Finish sampling!") 116 | print(f"Run time = {(time.time() - start):.2f} seconds") 117 | 118 | if opt.ddp: 119 | dist.destroy_process_group() 120 | 121 | # ------------------------------------------------------------------------------------------ 122 | if __name__ == "__main__": 123 | main() -------------------------------------------------------------------------------- /scripts/sample_uncond_long_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | 5 | import math 6 | import torch 7 | import time 8 | import argparse 9 | import numpy as np 10 | from tqdm import trange 11 | from omegaconf import OmegaConf 12 | import torch.distributed as dist 13 | from pytorch_lightning import seed_everything 14 | 15 | from lvdm.utils.dist_utils import setup_dist, gather_data 16 | from lvdm.utils.common_utils import torch_to_np, str2bool 17 | from scripts.sample_long_videos_utils import autoregressive_pred, interpolate 18 | from scripts.sample_utils import load_model, save_args, save_results 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser() 22 | # basic args 23 | parser.add_argument("--ckpt_pred", type=str, default="", help="checkpoint path of the prediction model") 24 | parser.add_argument("--ckpt_interp", type=str, default=None, help="checkpoint path of the interpolation model") 25 | parser.add_argument("--config_pred", type=str, default="", help="config yaml of the prediction model") 26 | parser.add_argument("--config_interp", type=str, default=None, help="config yaml of the interpolation model") 27 | parser.add_argument("--save_dir", type=str, default="results/longvideos", help="results saving dir") 28 | # device args 29 | parser.add_argument("--ddp", action='store_true', help="whether to use pytorch DDP mode for parallel sampling (recommended for the multi-gpu case)", default=False) 30 | parser.add_argument("--local_rank", type=int, help="local rank used in pytorch DDP mode", default=0) 31 | parser.add_argument("--gpu_id", type=int, help="choose a specific gpu", default=0) 32 | # sampling args 33 | parser.add_argument("--n_samples", type=int, help="how many samples to generate", default=2) 34 | parser.add_argument("--batch_size", type=int, help="video batch size for sampling", default=1) 35 | parser.add_argument("--sample_type", type=str, help="ddpm or ddim", default="ddpm", choices=["ddpm", "ddim"]) 36 | parser.add_argument("--ddim_steps", type=int, help="ddim sampling -- number of ddim denoising timesteps", default=50) 37 | parser.add_argument("--eta", type=float, help="ddim sampling -- eta (0.0 yields deterministic sampling, 1.0 yields random sampling)", default=1.0) 38 | parser.add_argument("--seed", type=int, default=None, help="random seed (set it to reproduce the sampling results)") 39 | parser.add_argument("--num_frames", type=int, default=None, help="total number of video frames to generate") 40 | parser.add_argument("--show_denoising_progress", action='store_true', default=False, help="whether to show the denoising progress while sampling a batch",) 41 | parser.add_argument("--uncond_scale", type=float, default=1.0, help="unconditional guidance scale") 42 | 
parser.add_argument("--uc_type", type=str, help="unconditional guidance scale", default="cfg_original", choices=["cfg_original", None]) 43 | # prediction & interpolation args 44 | parser.add_argument("--T_cond", type=int, default=1, help="temporal length of condition frames") 45 | parser.add_argument("--n_pred_steps", type=int, default=None, help="") 46 | parser.add_argument("--sample_cond_noise_level", type=int, default=None, help="") 47 | parser.add_argument("--overlap_t", type=int, default=0, help="") 48 | # saving args 49 | parser.add_argument("--save_mp4", type=str2bool, default=True, help="whether save samples in separate mp4 files", choices=["True", "true", "False", "false"]) 50 | parser.add_argument("--save_mp4_sheet", action='store_true', default=False, help="whether save samples in mp4 file",) 51 | parser.add_argument("--save_npz", action='store_true', default=False, help="whether save samples in npz file",) 52 | parser.add_argument("--save_jpg", action='store_true', default=False, help="whether save samples in jpg file",) 53 | parser.add_argument("--save_fps", type=int, default=8, help="fps of saved mp4 videos",) 54 | return parser 55 | 56 | # ------------------------------------------------------------------------------------------ 57 | @torch.no_grad() 58 | def sample(model_pred, model_interp, n_iters, ddp=False, all_gather=False, **kwargs): 59 | all_videos = [] 60 | for _ in trange(n_iters, desc="Sampling Batches (unconditional)"): 61 | 62 | # autoregressive predict latents 63 | logs = autoregressive_pred(model_pred, **kwargs) 64 | samples_z = logs['sample'] if isinstance(logs, dict) else logs 65 | 66 | # interpolate latents 67 | logs = interpolate(samples_z, model_interp, **kwargs) 68 | samples=logs['sample'] 69 | 70 | if ddp and all_gather: # gather samples from multiple gpus 71 | data_list = gather_data(samples, return_np=False) 72 | all_videos.extend([torch_to_np(data) for data in data_list]) 73 | else: 74 | all_videos.append(torch_to_np(samples)) 75 | all_videos = np.concatenate(all_videos, axis=0) 76 | return all_videos 77 | 78 | # ------------------------------------------------------------------------------------------ 79 | def main(): 80 | """ 81 | unconditional generation of long videos 82 | """ 83 | parser = get_parser() 84 | opt, unknown = parser.parse_known_args() 85 | os.makedirs(opt.save_dir, exist_ok=True) 86 | save_args(opt.save_dir, opt) 87 | 88 | # set device 89 | if opt.ddp: 90 | setup_dist(opt.local_rank) 91 | opt.n_samples = math.ceil(opt.n_samples / dist.get_world_size()) 92 | else: 93 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{opt.gpu_id}" 94 | 95 | # set random seed 96 | if opt.seed is not None: 97 | seed = opt.local_rank + opt.seed if opt.ddp else opt.seed 98 | seed_everything(seed) 99 | 100 | # load & merge config 101 | config_pred = OmegaConf.load(opt.config_pred) 102 | cli = OmegaConf.from_dotlist(unknown) 103 | config_pred = OmegaConf.merge(config_pred, cli) 104 | if opt.config_interp is not None: 105 | config_interp = OmegaConf.load(opt.config_interp) 106 | cli = OmegaConf.from_dotlist(unknown) 107 | config_interp = OmegaConf.merge(config_interp, cli) 108 | 109 | # calculate n_pred_steps 110 | if opt.num_frames is not None: 111 | temporal_down = config_pred.model.params.first_stage_config.params.ddconfig.encoder.params.downsample[0] 112 | temporal_length_z = opt.num_frames // temporal_down 113 | model_pred_length = config_pred.model.params.unet_config.params.temporal_length 114 | if opt.config_interp is not None: 115 | model_interp_length = 
config_interp.model.params.unet_config.params.temporal_length - 2 116 | pred_length = math.ceil((temporal_length_z + model_interp_length) / (model_interp_length + 1)) 117 | else: 118 | pred_length = temporal_length_z 119 | n_pred_steps = math.ceil((pred_length - model_pred_length) / (model_pred_length - 1)) 120 | opt.n_pred_steps = n_pred_steps 121 | print(f'Temporal length {opt.num_frames} needs latent length = {temporal_length_z}; \n \ 122 | pred_length = {pred_length}; \n \ 123 | prediction steps = {n_pred_steps}') 124 | else: 125 | assert(opt.n_pred_steps is not None) 126 | 127 | # model 128 | model_pred, _, _ = load_model(config_pred, opt.ckpt_pred) 129 | model_interp, _, _ = load_model(config_interp, opt.ckpt_interp) 130 | 131 | # sample 132 | start = time.time() 133 | ngpus = 1 if not opt.ddp else dist.get_world_size() 134 | n_iters = math.ceil(opt.n_samples / (ngpus * opt.batch_size)) 135 | samples = sample(model_pred, model_interp, n_iters, **vars(opt)) 136 | assert(samples.shape[0] >= opt.n_samples) 137 | 138 | # save 139 | if (opt.ddp and dist.get_rank() == 0) or (not opt.ddp): 140 | # default to 'results' as the file name when no seed is given 141 | save_name = f"seed{seed:05d}" if opt.seed is not None else "results" 142 | save_results(samples, opt.save_dir, save_name=save_name, save_fps=opt.save_fps, save_mp4=opt.save_mp4, save_mp4_sheet=opt.save_mp4_sheet, save_npz=opt.save_npz, save_jpg=opt.save_jpg) 143 | print("Finish sampling!") 144 | print(f"total time = {int(time.time() - start)} seconds, \ 145 | num of iters = {n_iters}; \ 146 | num of samples = {ngpus * opt.batch_size * n_iters}; \ 147 | temporal length = {opt.num_frames}") 148 | 149 | if opt.ddp: 150 | dist.destroy_process_group() 151 | 152 | # ------------------------------------------------------------------------------------------ 153 | if __name__ == "__main__": 154 | main() -------------------------------------------------------------------------------- /scripts/sample_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | 8 | from lvdm.utils.common_utils import instantiate_from_config 9 | from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d 10 | 11 | def custom_to_pil(x): 12 | x = x.detach().cpu() 13 | x = torch.clamp(x, -1., 1.) 14 | x = (x + 1.) / 2. 
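# scale from [-1, 1] to [0, 1]; the lines below then convert the CHW float tensor to an HWC uint8 array for PIL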
15 | x = x.permute(1, 2, 0).numpy() 16 | x = (255 * x).astype(np.uint8) 17 | x = Image.fromarray(x) 18 | if not x.mode == "RGB": 19 | x = x.convert("RGB") 20 | return x 21 | 22 | def custom_to_np(x): 23 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py 24 | sample = x.detach().cpu() 25 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 26 | if sample.dim() == 5: 27 | sample = sample.permute(0, 2, 3, 4, 1) 28 | else: 29 | sample = sample.permute(0, 2, 3, 1) 30 | sample = sample.contiguous() 31 | return sample 32 | 33 | def save_args(save_dir, args): 34 | fpath = os.path.join(save_dir, "sampling_args.yaml") 35 | with open(fpath, 'w') as f: 36 | yaml.dump(vars(args), f, default_flow_style=False) 37 | 38 | # ------------------------------------------------------------------------------------------ 39 | def load_model(config, ckpt_path, gpu_id=None): 40 | print(f"Loading model from {ckpt_path}") 41 | 42 | # load sd 43 | pl_sd = torch.load(ckpt_path, map_location="cpu") 44 | try: 45 | global_step = pl_sd["global_step"] 46 | epoch = pl_sd["epoch"] 47 | except: 48 | global_step = -1 49 | epoch = -1 50 | 51 | # load sd to model 52 | try: 53 | sd = pl_sd["state_dict"] 54 | except: 55 | sd = pl_sd 56 | model = instantiate_from_config(config.model) 57 | model.load_state_dict(sd, strict=True) 58 | 59 | # move to device & eval 60 | if gpu_id is not None: 61 | model.to(f"cuda:{gpu_id}") 62 | else: 63 | model.cuda() 64 | model.eval() 65 | 66 | return model, global_step, epoch 67 | 68 | def make_sample_dir(opt, global_step, epoch): 69 | if not getattr(opt, 'not_automatic_logdir', False): 70 | gs_str = f"globalstep{global_step:09}" if global_step is not None else "None" 71 | e_str = f"epoch{epoch:06}" if epoch is not None else "None" 72 | ckpt_dir = os.path.join(opt.logdir, f"{gs_str}_{e_str}") 73 | 74 | # subdir name 75 | if opt.prompt_file is not None: 76 | subdir = f"prompts_{os.path.splitext(os.path.basename(opt.prompt_file))[0]}" 77 | else: 78 | subdir = f"prompt_{opt.prompt[:10]}" 79 | subdir += "_DDPM" if opt.vanilla_sample else f"_DDIM{opt.custom_steps}steps" 80 | subdir += f"_CfgScale{opt.scale}" 81 | if opt.cond_fps is not None: 82 | subdir += f"_fps{opt.cond_fps}" 83 | if opt.seed is not None: 84 | subdir += f"_seed{opt.seed}" 85 | 86 | return os.path.join(ckpt_dir, subdir) 87 | else: 88 | return opt.logdir 89 | 90 | # ------------------------------------------------------------------------------------------ 91 | @torch.no_grad() 92 | def get_conditions(prompts, model, batch_size, cond_fps=None,): 93 | 94 | if isinstance(prompts, str) or isinstance(prompts, int): 95 | prompts = [prompts] 96 | if isinstance(prompts, list): 97 | if len(prompts) == 1: 98 | prompts = prompts * batch_size 99 | elif len(prompts) == batch_size: 100 | pass 101 | else: 102 | raise ValueError(f"invalid prompts length: {len(prompts)}") 103 | else: 104 | raise ValueError(f"invalid prompts: {prompts}") 105 | assert(len(prompts) == batch_size) 106 | 107 | # content condition: text / class label 108 | c = model.get_learned_conditioning(prompts) 109 | key = 'c_concat' if model.conditioning_key == 'concat' else 'c_crossattn' 110 | c = {key: [c]} 111 | 112 | # temporal condition: fps 113 | if getattr(model, 'cond_stage2_config', None) is not None: 114 | if model.cond_stage2_key == "temporal_context": 115 | assert(cond_fps is not None) 116 | batch = {'fps': torch.tensor([cond_fps] * batch_size).long().to(model.device)} 117 | fps_embd = 
model.cond_stage2_model(batch) 118 | c[model.cond_stage2_key] = fps_embd 119 | 120 | return c 121 | 122 | # ------------------------------------------------------------------------------------------ 123 | def make_model_input_shape(model, batch_size, T=None): 124 | image_size = [model.image_size, model.image_size] if isinstance(model.image_size, int) else model.image_size 125 | C = model.model.diffusion_model.in_channels 126 | if T is None: 127 | T = model.model.diffusion_model.temporal_length 128 | shape = [batch_size, C, T, *image_size] 129 | return shape 130 | 131 | # ------------------------------------------------------------------------------------------ 132 | def sample_batch(model, noise_shape, condition, 133 | sample_type="ddim", 134 | sampler=None, 135 | ddim_steps=None, 136 | eta=None, 137 | unconditional_guidance_scale=1.0, 138 | uc=None, 139 | denoising_progress=False, 140 | **kwargs, 141 | ): 142 | 143 | if sample_type == "ddpm": 144 | samples = model.p_sample_loop(cond=condition, shape=noise_shape, 145 | return_intermediates=False, 146 | verbose=denoising_progress, 147 | ) 148 | elif sample_type == "ddim": 149 | assert(sampler is not None) 150 | assert(ddim_steps is not None) 151 | assert(eta is not None) 152 | ddim_sampler = sampler 153 | samples, _ = ddim_sampler.sample(S=ddim_steps, 154 | conditioning=condition, 155 | batch_size=noise_shape[0], 156 | shape=noise_shape[1:], 157 | verbose=denoising_progress, 158 | unconditional_guidance_scale=unconditional_guidance_scale, 159 | unconditional_conditioning=uc, 160 | eta=eta, 161 | **kwargs, 162 | ) 163 | else: 164 | raise ValueError 165 | return samples 166 | 167 | # ------------------------------------------------------------------------------------------ 168 | def torch_to_np(x): 169 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py 170 | sample = x.detach().cpu() 171 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 172 | if sample.dim() == 5: 173 | sample = sample.permute(0, 2, 3, 4, 1) 174 | else: 175 | sample = sample.permute(0, 2, 3, 1) 176 | sample = sample.contiguous() 177 | return sample 178 | # ------------------------------------------------------------------------------------------ 179 | def save_results(videos, save_dir, 180 | save_name="results", save_fps=8, save_mp4=True, 181 | save_npz=False, save_mp4_sheet=False, save_jpg=False 182 | ): 183 | if save_mp4: 184 | save_subdir = os.path.join(save_dir, "videos") 185 | os.makedirs(save_subdir, exist_ok=True) 186 | shape_str = "x".join([str(x) for x in videos[0:1,...].shape]) 187 | for i in range(videos.shape[0]): 188 | npz_to_video_grid(videos[i:i+1,...], 189 | os.path.join(save_subdir, f"{save_name}_{i:03d}_{shape_str}.mp4"), 190 | fps=save_fps) 191 | print(f'Successfully saved videos in {save_subdir}') 192 | 193 | shape_str = "x".join([str(x) for x in videos.shape]) 194 | if save_npz: 195 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.npz") 196 | np.savez(save_path, videos) 197 | print(f'Successfully saved npz in {save_path}') 198 | 199 | if save_mp4_sheet: 200 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.mp4") 201 | npz_to_video_grid(videos, save_path, fps=save_fps) 202 | print(f'Successfully saved mp4 sheet in {save_path}') 203 | 204 | if save_jpg: 205 | save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.jpg") 206 | npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1]) 207 | print(f'Successfully saved jpg sheet in {save_path}') 208 
| 209 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='lvdm', 5 | version='0.0.1', 6 | description='generic video generation models', 7 | packages=find_packages(), 8 | install_requires=[], 9 | ) -------------------------------------------------------------------------------- /shellscripts/eval_lvdm_short.sh: -------------------------------------------------------------------------------- 1 | 2 | DATACONFIG="configs/lvdm_short/sky.yaml" 3 | FAKEPATH='${Your_Path}/2048x16x256x256x3-samples.npz' 4 | REALPATH='/dataset/sky_timelapse' 5 | RESDIR='results/fvd' 6 | 7 | mkdir -p $RESDIR 8 | python scripts/eval_cal_fvd_kvd.py \ 9 | --yaml ${DATACONFIG} \ 10 | --real_path ${REALPATH} \ 11 | --fake_path ${FAKEPATH} \ 12 | --batch_size 32 \ 13 | --num_workers 4 \ 14 | --n_runs 10 \ 15 | --res_dir ${RESDIR} \ 16 | --n_sample 2048 17 | -------------------------------------------------------------------------------- /shellscripts/sample_lvdm_long.sh: -------------------------------------------------------------------------------- 1 | 2 | CKPT_PRED="models/lvdm_long/sky_pred.ckpt" 3 | CKPT_INTERP="models/lvdm_long/sky_interp.ckpt" 4 | AEPATH="models/ae/ae_sky.ckpt" 5 | CONFIG_PRED="configs/lvdm_long/sky_pred.yaml" 6 | CONFIG_INTERP="configs/lvdm_long/sky_interp.yaml" 7 | OUTDIR="results/longvideos/" 8 | 9 | python scripts/sample_uncond_long_videos.py \ 10 | --ckpt_pred $CKPT_PRED \ 11 | --config_pred $CONFIG_PRED \ 12 | --ckpt_interp $CKPT_INTERP \ 13 | --config_interp $CONFIG_INTERP \ 14 | --save_dir $OUTDIR \ 15 | --n_samples 1 \ 16 | --batch_size 1 \ 17 | --seed 1000 \ 18 | --show_denoising_progress \ 19 | model.params.first_stage_config.params.ckpt_path=$AEPATH \ 20 | --sample_cond_noise_level 100 \ 21 | --uncond_scale 0.1 \ 22 | --n_pred_steps 2 \ 23 | --sample_type ddim --ddim_steps 50 24 | 25 | # to use DDPM sampling instead, remove `--sample_type ddim --ddim_steps 50` -------------------------------------------------------------------------------- /shellscripts/sample_lvdm_short.sh: -------------------------------------------------------------------------------- 1 | 2 | CONFIG_PATH="configs/lvdm_short/sky.yaml" 3 | BASE_PATH="models/lvdm_short/short_sky.ckpt" 4 | AEPATH="models/ae/ae_sky.ckpt" 5 | OUTDIR="results/uncond_short/" 6 | 7 | python scripts/sample_uncond.py \ 8 | --ckpt_path $BASE_PATH \ 9 | --config_path $CONFIG_PATH \ 10 | --save_dir $OUTDIR \ 11 | --n_samples 1 \ 12 | --batch_size 1 \ 13 | --seed 1000 \ 14 | --show_denoising_progress \ 15 | model.params.first_stage_config.params.ckpt_path=$AEPATH 16 | 17 | # to use DDIM sampling, add `--sample_type ddim --ddim_steps 50` 18 | -------------------------------------------------------------------------------- /shellscripts/sample_lvdm_text2video.sh: -------------------------------------------------------------------------------- 1 | 2 | PROMPT="astronaut riding a horse" # OR: PROMPT="input/prompts.txt" for sampling multiple prompts 3 | OUTDIR="results/t2v" 4 | 5 | CKPT_PATH="models/t2v/model.ckpt" 6 | CONFIG_PATH="configs/lvdm_short/text2video.yaml" 7 | 8 | python scripts/sample_text2video.py \ 9 | --ckpt_path $CKPT_PATH \ 10 | --config_path $CONFIG_PATH \ 11 | --prompt "$PROMPT" \ 12 | --save_dir $OUTDIR \ 13 | --n_samples 1 \ 14 | --batch_size 1 \ 15 | --seed 1000 \ 16 | --show_denoising_progress \ 17 | --save_jpg 18 | 
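19 | # Optional flags (a sketch based on the argparse options read by scripts/sample_text2video.py; the values below are illustrative examples, not repo defaults): 20 | #   --sample_type ddim --ddim_steps 50 --eta 1.0   # choose the sampler, as in sample_lvdm_short.sh 21 | #   --cfg_scale 7.5                                 # classifier-free guidance scale 22 | #   --num_frames 16                                 # override the generated temporal length (example value) 23 | #   --save_npz --save_mp4_sheet --save_fps 8        # extra output formats handled by save_results() 24 | 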
-------------------------------------------------------------------------------- /shellscripts/train_lvdm_interp_sky.sh: -------------------------------------------------------------------------------- 1 | export PATH=/apdcephfs_cq2/share_1290939/yingqinghe/anaconda/envs/ldmA100/bin/:$PATH 2 | export http_proxy="http://star-proxy.oa.com:3128" 3 | export https_proxy="http://star-proxy.oa.com:3128" 4 | export ftp_proxy="http://star-proxy.oa.com:3128" 5 | export no_proxy=".woa.com,mirrors.cloud.tencent.com,tlinux-mirror.tencent-cloud.com,tlinux-mirrorlist.tencent-cloud.com,localhost,127.0.0.1,mirrors-tlinux.tencentyun.com,.oa.com,.local,.3gqq.com,.7700.org,.ad.com,.ada_sixjoy.com,.addev.com,.app.local,.apps.local,.aurora.com,.autotest123.com,.bocaiwawa.com,.boss.com,.cdc.com,.cdn.com,.cds.com,.cf.com,.cjgc.local,.cm.com,.code.com,.datamine.com,.dvas.com,.dyndns.tv,.ecc.com,.expochart.cn,.expovideo.cn,.fms.com,.great.com,.hadoop.sec,.heme.com,.home.com,.hotbar.com,.ibg.com,.ied.com,.ieg.local,.ierd.com,.imd.com,.imoss.com,.isd.com,.isoso.com,.itil.com,.kao5.com,.kf.com,.kitty.com,.lpptp.com,.m.com,.matrix.cloud,.matrix.net,.mickey.com,.mig.local,.mqq.com,.oiweb.com,.okbuy.isddev.com,.oss.com,.otaworld.com,.paipaioa.com,.qqbrowser.local,.qqinternal.com,.qqwork.com,.rtpre.com,.sc.oa.com,.sec.com,.server.com,.service.com,.sjkxinternal.com,.sllwrnm5.cn,.sng.local,.soc.com,.t.km,.tcna.com,.teg.local,.tencentvoip.com,.tenpayoa.com,.test.air.tenpay.com,.tr.com,.tr_autotest123.com,.vpn.com,.wb.local,.webdev.com,.webdev2.com,.wizard.com,.wqq.com,.wsd.com,.sng.com,.music.lan,.mnet2.com,.tencentb2.com,.tmeoa.com,.pcg.com,www.wip3.adobe.com,www-mm.wip3.adobe.com,mirrors.tencent.com,csighub.tencentyun.com" 6 | 7 | export TOKENIZERS_PARALLELISM=false 8 | 9 | PROJ_ROOT="/apdcephfs_cq2/share_1290939/yingqinghe/results/latent_diffusion" 10 | EXPNAME="test_sky_train_interp" 11 | CONFIG="configs/lvdm_long/sky_interp.yaml" 12 | DATADIR="/dockerdata/sky_timelapse" 13 | AEPATH="/apdcephfs/share_1290939/yingqinghe/results/latent_diffusion/LVDM/ae_013_sky256_basedon003_4nodes_e0/checkpoints/trainstep_checkpoints/epoch=000299-step=000010199.ckpt" 14 | 15 | # run 16 | python main.py \ 17 | --base $CONFIG \ 18 | -t --gpus 0, \ 19 | --name $EXPNAME \ 20 | --logdir $PROJ_ROOT \ 21 | --auto_resume True \ 22 | lightning.trainer.num_nodes=1 \ 23 | data.params.train.params.data_root=$DATADIR \ 24 | data.params.validation.params.data_root=$DATADIR \ 25 | model.params.first_stage_config.params.ckpt_path=$AEPATH 26 | 27 | 28 | # commands for multi nodes training 29 | # --------------------------------------------------------------------------------------------------- 30 | # python -m torch.distributed.run \ 31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \ 32 | # main.py \ 33 | # --base $CONFIG \ 34 | # -t --gpus 0,1,2,3,4,5,6,7 \ 35 | # --name $EXPNAME \ 36 | # --logdir $PROJ_ROOT \ 37 | # --auto_resume True \ 38 | # lightning.trainer.num_nodes=$NHOST \ 39 | # data.params.train.params.data_root=$DATADIR \ 40 | # data.params.validation.params.data_root=$DATADIR 41 | -------------------------------------------------------------------------------- /shellscripts/train_lvdm_pred_sky.sh: -------------------------------------------------------------------------------- 1 | export PATH=/apdcephfs_cq2/share_1290939/yingqinghe/anaconda/envs/ldmA100/bin/:$PATH 2 | export http_proxy="http://star-proxy.oa.com:3128" 3 | export https_proxy="http://star-proxy.oa.com:3128" 4 | export 
ftp_proxy="http://star-proxy.oa.com:3128" 5 | export no_proxy=".woa.com,mirrors.cloud.tencent.com,tlinux-mirror.tencent-cloud.com,tlinux-mirrorlist.tencent-cloud.com,localhost,127.0.0.1,mirrors-tlinux.tencentyun.com,.oa.com,.local,.3gqq.com,.7700.org,.ad.com,.ada_sixjoy.com,.addev.com,.app.local,.apps.local,.aurora.com,.autotest123.com,.bocaiwawa.com,.boss.com,.cdc.com,.cdn.com,.cds.com,.cf.com,.cjgc.local,.cm.com,.code.com,.datamine.com,.dvas.com,.dyndns.tv,.ecc.com,.expochart.cn,.expovideo.cn,.fms.com,.great.com,.hadoop.sec,.heme.com,.home.com,.hotbar.com,.ibg.com,.ied.com,.ieg.local,.ierd.com,.imd.com,.imoss.com,.isd.com,.isoso.com,.itil.com,.kao5.com,.kf.com,.kitty.com,.lpptp.com,.m.com,.matrix.cloud,.matrix.net,.mickey.com,.mig.local,.mqq.com,.oiweb.com,.okbuy.isddev.com,.oss.com,.otaworld.com,.paipaioa.com,.qqbrowser.local,.qqinternal.com,.qqwork.com,.rtpre.com,.sc.oa.com,.sec.com,.server.com,.service.com,.sjkxinternal.com,.sllwrnm5.cn,.sng.local,.soc.com,.t.km,.tcna.com,.teg.local,.tencentvoip.com,.tenpayoa.com,.test.air.tenpay.com,.tr.com,.tr_autotest123.com,.vpn.com,.wb.local,.webdev.com,.webdev2.com,.wizard.com,.wqq.com,.wsd.com,.sng.com,.music.lan,.mnet2.com,.tencentb2.com,.tmeoa.com,.pcg.com,www.wip3.adobe.com,www-mm.wip3.adobe.com,mirrors.tencent.com,csighub.tencentyun.com" 6 | 7 | export TOKENIZERS_PARALLELISM=false 8 | 9 | PROJ_ROOT="/apdcephfs_cq2/share_1290939/yingqinghe/results/latent_diffusion" 10 | EXPNAME="test_sky_train_pred" 11 | CONFIG="configs/lvdm_long/sky_pred.yaml" 12 | DATADIR="/dockerdata/sky_timelapse" 13 | AEPATH="/apdcephfs/share_1290939/yingqinghe/results/latent_diffusion/LVDM/ae_013_sky256_basedon003_4nodes_e0/checkpoints/trainstep_checkpoints/epoch=000299-step=000010199.ckpt" 14 | 15 | # run 16 | python main.py \ 17 | --base $CONFIG \ 18 | -t --gpus 0, \ 19 | --name $EXPNAME \ 20 | --logdir $PROJ_ROOT \ 21 | --auto_resume True \ 22 | lightning.trainer.num_nodes=1 \ 23 | data.params.train.params.data_root=$DATADIR \ 24 | data.params.validation.params.data_root=$DATADIR \ 25 | model.params.first_stage_config.params.ckpt_path=$AEPATH 26 | 27 | 28 | # commands for multi nodes training 29 | # --------------------------------------------------------------------------------------------------- 30 | # python -m torch.distributed.run \ 31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \ 32 | # main.py \ 33 | # --base $CONFIG \ 34 | # -t --gpus 0,1,2,3,4,5,6,7 \ 35 | # --name $EXPNAME \ 36 | # --logdir $PROJ_ROOT \ 37 | # --auto_resume True \ 38 | # lightning.trainer.num_nodes=$NHOST \ 39 | # data.params.train.params.data_root=$DATADIR \ 40 | # data.params.validation.params.data_root=$DATADIR 41 | -------------------------------------------------------------------------------- /shellscripts/train_lvdm_short.sh: -------------------------------------------------------------------------------- 1 | 2 | PROJ_ROOT="" # root directory for saving experiment logs 3 | EXPNAME="lvdm_short_sky" # experiment name 4 | DATADIR="/dataset/sky_timelapse" # dataset directory 5 | AEPATH="models/ae/ae_sky.ckpt" # pretrained video autoencoder checkpoint 6 | 7 | CONFIG="configs/lvdm_short/sky.yaml" 8 | # OR CONFIG="configs/videoae/ucf.yaml" 9 | # OR CONFIG="configs/videoae/taichi.yaml" 10 | 11 | # run 12 | export TOKENIZERS_PARALLELISM=false 13 | python main.py \ 14 | --base $CONFIG \ 15 | -t --gpus 0, \ 16 | --name $EXPNAME \ 17 | --logdir $PROJ_ROOT \ 18 | --auto_resume True \ 19 | lightning.trainer.num_nodes=1 \ 20 | 
data.params.train.params.data_root=$DATADIR \ 21 | data.params.validation.params.data_root=$DATADIR \ 22 | model.params.first_stage_config.params.ckpt_path=$AEPATH 23 | 24 | # ------------------------------------------------------------------------------------------------- 25 | # commands for multi nodes training 26 | # - use torch.distributed.run to launch main.py 27 | # - set `gpus` and `lightning.trainer.num_nodes` 28 | 29 | # For example: 30 | 31 | # python -m torch.distributed.run \ 32 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \ 33 | # main.py \ 34 | # --base $CONFIG \ 35 | # -t --gpus 0,1,2,3,4,5,6,7 \ 36 | # --name $EXPNAME \ 37 | # --logdir $PROJ_ROOT \ 38 | # --auto_resume True \ 39 | # lightning.trainer.num_nodes=$NHOST \ 40 | # data.params.train.params.data_root=$DATADIR \ 41 | # data.params.validation.params.data_root=$DATADIR 42 | -------------------------------------------------------------------------------- /shellscripts/train_lvdm_videoae.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | PROJ_ROOT="results/" # root directory for saving experiment logs 4 | EXPNAME="lvdm_videoae_sky" # experiment name 5 | DATADIR="/dataset/sky_timelapse" # dataset directory 6 | 7 | CONFIG="configs/videoae/sky.yaml" 8 | # OR CONFIG="configs/videoae/ucf.yaml" 9 | # OR CONFIG="configs/videoae/taichi.yaml" 10 | 11 | # run 12 | export TOKENIZERS_PARALLELISM=false 13 | python main.py \ 14 | --base $CONFIG \ 15 | -t --gpus 0, \ 16 | --name $EXPNAME \ 17 | --logdir $PROJ_ROOT \ 18 | --auto_resume True \ 19 | lightning.trainer.num_nodes=1 \ 20 | data.params.train.params.data_root=$DATADIR \ 21 | data.params.validation.params.data_root=$DATADIR 22 | 23 | # ------------------------------------------------------------------------------------------------- 24 | # commands for multi nodes training 25 | # - use torch.distributed.run to launch main.py 26 | # - set `gpus` and `lightning.trainer.num_nodes` 27 | 28 | # For example: 29 | 30 | # python -m torch.distributed.run \ 31 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \ 32 | # main.py \ 33 | # --base $CONFIG \ 34 | # -t --gpus 0,1,2,3,4,5,6,7 \ 35 | # --name $EXPNAME \ 36 | # --logdir $PROJ_ROOT \ 37 | # --auto_resume True \ 38 | # lightning.trainer.num_nodes=$NHOST \ 39 | # data.params.train.params.data_root=$DATADIR \ 40 | # data.params.validation.params.data_root=$DATADIR 41 | -------------------------------------------------------------------------------- /shellscripts/train_lvdm_videoae_ucf.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | PROJ_ROOT="results/" # input the root directory for saving experiment logs 4 | EXPNAME="lvdm_videoae_ucf101" # experiment name 5 | DATADIR="UCF-101/" # input the dataset directory 6 | 7 | CONFIG="configs/videoae/ucf_videodata.yaml" 8 | 9 | # run 10 | export TOKENIZERS_PARALLELISM=false 11 | python main.py \ 12 | --base $CONFIG \ 13 | -t --gpus 0, \ 14 | --name $EXPNAME \ 15 | --logdir $PROJ_ROOT \ 16 | --auto_resume True \ 17 | lightning.trainer.num_nodes=1 \ 18 | data.params.train.params.data_root=$DATADIR \ 19 | data.params.validation.params.data_root=$DATADIR 20 | 21 | # ------------------------------------------------------------------------------------------------- 22 | # commands for multi nodes training 23 | # - use torch.distributed.run to launch main.py 24 | # - set `gpus` and 
`lightning.trainer.num_nodes` 25 | 26 | # For example: 27 | 28 | # python -m torch.distributed.run \ 29 | # --nproc_per_node=8 --nnodes=$NHOST --master_addr=$MASTER_ADDR --master_port=1234 --node_rank=$INDEX \ 30 | # main.py \ 31 | # --base $CONFIG \ 32 | # -t --gpus 0,1,2,3,4,5,6,7 \ 33 | # --name $EXPNAME \ 34 | # --logdir $PROJ_ROOT \ 35 | # --auto_resume True \ 36 | # lightning.trainer.num_nodes=$NHOST \ 37 | # data.params.train.params.data_root=$DATADIR \ 38 | # data.params.validation.params.data_root=$DATADIR 39 | --------------------------------------------------------------------------------