├── .gitignore ├── README.md ├── __pycache__ ├── convert_from_ckpt.cpython-310.pyc ├── convert_lora_safetensor_to_diffusers.cpython-310.pyc ├── filter_utils.cpython-310.pyc └── utils.cpython-310.pyc ├── animatediff ├── data │ ├── __pycache__ │ │ └── dataset.cpython-310.pyc │ └── dataset.py ├── models │ ├── __pycache__ │ │ ├── attention.cpython-310.pyc │ │ ├── attention.cpython-311.pyc │ │ ├── motion_module.cpython-310.pyc │ │ ├── motion_module.cpython-311.pyc │ │ ├── resnet.cpython-310.pyc │ │ ├── resnet.cpython-311.pyc │ │ ├── sparse_controlnet.cpython-310.pyc │ │ ├── unet.cpython-310.pyc │ │ ├── unet.cpython-311.pyc │ │ ├── unet_blocks.cpython-310.pyc │ │ └── unet_blocks.cpython-311.pyc │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── sparse_controlnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ ├── __pycache__ │ │ ├── pipeline_animation.cpython-310.pyc │ │ ├── pipeline_attention.cpython-310.pyc │ │ ├── pipeline_sd.cpython-310.pyc │ │ └── pipeline_temporal_cfg.cpython-310.pyc │ ├── pipeline_animation.py │ ├── pipeline_attention.py │ ├── pipeline_sd.py │ └── pipeline_temporal_cfg.py └── utils │ ├── __pycache__ │ ├── convert_from_ckpt.cpython-310.pyc │ ├── convert_from_ckpt.cpython-311.pyc │ ├── convert_lora_safetensor_to_diffusers.cpython-310.pyc │ ├── convert_lora_safetensor_to_diffusers.cpython-311.pyc │ ├── util.cpython-310.pyc │ └── util.cpython-311.pyc │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── animatediff_configs ├── inference │ ├── inference-v1.yaml │ ├── inference-v2.yaml │ ├── inference-v3.yaml │ └── sparsectrl │ │ ├── image_condition.yaml │ │ └── latent_condition.yaml ├── prompts │ └── v2 │ │ ├── v2-1-Film.yaml │ │ ├── v2-1-RCNZcartoon.yaml │ │ ├── v2-1-RealisticVision.yaml │ │ ├── v2-1-ToonYou.yaml │ │ ├── v2-1-w-dreambooth.yaml │ │ └── v2-1-w.o-dreambooth.yaml └── training │ └── v1 │ ├── image_finetune.yaml │ └── training.yaml ├── assets └── main_fig.png ├── convert_from_ckpt.py ├── convert_lora_safetensor_to_diffusers.py ├── filter_utils.py ├── inference.sh ├── lvdm ├── __pycache__ │ ├── basics.cpython-310.pyc │ ├── basics.cpython-38.pyc │ ├── common.cpython-310.pyc │ ├── common.cpython-311.pyc │ ├── common.cpython-38.pyc │ ├── distributions.cpython-310.pyc │ ├── distributions.cpython-38.pyc │ ├── ema.cpython-310.pyc │ └── ema.cpython-38.pyc ├── basics.py ├── common.py ├── distributions.py ├── ema.py ├── models │ ├── __pycache__ │ │ ├── autoencoder.cpython-310.pyc │ │ ├── autoencoder.cpython-38.pyc │ │ ├── ddpm3d.cpython-310.pyc │ │ ├── ddpm3d.cpython-38.pyc │ │ ├── utils_diffusion.cpython-310.pyc │ │ ├── utils_diffusion.cpython-311.pyc │ │ └── utils_diffusion.cpython-38.pyc │ ├── autoencoder.py │ ├── ddpm3d.py │ ├── samplers │ │ ├── __pycache__ │ │ │ ├── ddim.cpython-310.pyc │ │ │ ├── ddim.cpython-311.pyc │ │ │ └── ddim.cpython-38.pyc │ │ └── ddim.py │ └── utils_diffusion.py └── modules │ ├── __pycache__ │ ├── attention.cpython-310.pyc │ └── attention.cpython-38.pyc │ ├── attention.py │ ├── encoders │ ├── __pycache__ │ │ ├── condition.cpython-310.pyc │ │ ├── condition.cpython-38.pyc │ │ ├── ip_resampler.cpython-310.pyc │ │ └── ip_resampler.cpython-38.pyc │ ├── condition.py │ └── ip_resampler.py │ ├── networks │ ├── __pycache__ │ │ ├── ae_modules.cpython-310.pyc │ │ ├── ae_modules.cpython-38.pyc │ │ ├── openaimodel3d.cpython-310.pyc │ │ └── openaimodel3d.cpython-38.pyc │ ├── ae_modules.py │ └── openaimodel3d.py │ └── x_transformer.py ├── prompts └── prompt.txt ├── requirements.txt ├── scripts ├── 
evaluation │ ├── __pycache__ │ │ ├── funcs.cpython-310.pyc │ │ ├── funcs.cpython-311.pyc │ │ └── funcs.cpython-38.pyc │ ├── ddp_wrapper.py │ ├── funcs.py │ └── inference.py ├── gradio │ ├── i2v_test.py │ └── t2v_test.py ├── run_image2video.sh └── run_text2video.sh ├── t2v_vc_guide.py ├── utils.py ├── vc_configs └── inference_t2v_512_v2.0.yaml └── vc_utils ├── __pycache__ ├── utils.cpython-310.pyc ├── utils.cpython-311.pyc └── utils.cpython-38.pyc └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | result/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR2025] VideoGuide: Improving Video Diffusion Models without Training Through a Teacher's Guide 2 | 3 | This repository is the official implementation of [VideoGuide: Improving Video Diffusion Models without Training Through a Teacher's Guide](https://arxiv.org/abs/2410.04364), led by 4 | 5 | [Dohun Lee*](https://github.com/DoHunLee1), [Bryan Sangwoo Kim*](https://scholar.google.com/citations?user=ndWU-84AAAAJ&hl=en), [Geon Yeong Park](https://geonyeong-park.github.io/), [Jong Chul Ye](https://bispl.weebly.com/professor.html) 6 | 7 | ![main figure](assets/main_fig.png) 8 | 9 | [![Project Website](https://img.shields.io/badge/Project-Website-blue)](https://dohunlee1.github.io/videoguide.github.io/) 10 | [![arXiv](https://img.shields.io/badge/arXiv-2410.04364-b31b1b.svg)](https://arxiv.org/abs/2410.04364) 11 | 12 | --- 13 | ## 🔥 Summary 14 | 15 | **VideoGuide** 🚀 enhances the temporal quality of video diffusion models *without additional training or fine-tuning* by leveraging a pretrained model as a guide. During inference, it uses the guiding model to provide a temporally consistent sample, which is interpolated with the sampling model's output to improve consistency. VideoGuide offers the following advantages: 16 | 17 | 1. **Improved temporal consistency** with preserved imaging quality and motion smoothness 18 | 2. **Fast inference**, as applying guidance only to the early denoising steps proves sufficient 19 | 3. **Prior distillation** of the guiding model 20 | 21 | ## 🗓️ News 22 | - [8 Oct 2024] Code and paper are released. 23 | 24 | ## 🛠️ Setup 25 | First, create your environment. We recommend using the following commands. 
26 | 27 | ``` 28 | git clone https://github.com/DoHunLee1/VideoGuide.git 29 | cd VideoGuide 30 | 31 | conda create -n videoguide python=3.10 32 | conda activate videoguide 33 | conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia 34 | pip install -r requirements.txt 35 | pip install xformers==0.0.22.post4 --index-url https://download.pytorch.org/whl/cu118 36 | ``` 37 | 38 | ## ⏳ Models 39 | 40 | |Models|Checkpoints| 41 | |:---------|:--------| 42 | |VideoCrafter2|[Hugging Face](https://huggingface.co/VideoCrafter/VideoCrafter2/blob/main/model.ckpt)| 43 | |AnimateDiff|[Hugging Face](https://huggingface.co/guoyww/animatediff/tree/main)| 44 | |RealisticVision|[Hugging Face](https://huggingface.co/ckpt/realistic-vision-v20/blob/main/realisticVisionV20_v20.safetensors)| 45 | |Stable Diffusion v1.5|[Hugging Face](https://huggingface.co/benjamin-paine/stable-diffusion-v1-5/tree/main)| 46 | 47 | Please refer to the official repositories of [AnimateDiff](https://github.com/guoyww/AnimateDiff) and [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter/tree/main) for detailed explanations and setup guides for each model. We thank them for sharing their impressive work! 48 | 49 | ## 🌄 Example 50 | An example of using **VideoGuide** is provided in the `inference.sh` script. 51 | 52 | ## 📝 Citation 53 | If you find our method useful, please cite it as below or leave a star on this repository. 54 | 55 | ``` 56 | @article{lee2024videoguide, 57 | title={VideoGuide: Improving Video Diffusion Models without Training Through a Teacher's Guide}, 58 | author={Lee, Dohun and Kim, Bryan S and Park, Geon Yeong and Ye, Jong Chul}, 59 | journal={arXiv preprint arXiv:2410.04364}, 60 | year={2024} 61 | } 62 | ``` 63 | 64 | ## 🤗 Acknowledgements 65 | We thank the authors of [AnimateDiff](https://github.com/guoyww/AnimateDiff), [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter/tree/main), and [Stable Diffusion](https://github.com/Stability-AI/stablediffusion) for sharing their awesome work. We also thank the [CivitAI](https://civitai.com/) community for sharing their impressive T2I models! 66 | 67 | > [!NOTE] 68 | > This work is currently in the preprint stage, and there may be some changes to the code. 
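The sketch below illustrates the guidance step described in the Summary: at an early denoising step, the denoised estimate from the base sampling model is interpolated with the estimate from a temporally consistent guiding model, and the blended estimate is then handed back to the scheduler. This is only a minimal illustration; the function name `videoguide_step` and the parameters `tau` and `guide_steps` are hypothetical, and the actual implementation lives in `t2v_vc_guide.py` and the pipelines under `animatediff/pipelines/`.

```python
import torch


@torch.no_grad()
def videoguide_step(x_t, t, step_idx, sampling_model, guiding_model, alphas_cumprod,
                    tau=0.5, guide_steps=5):
    """Illustrative VideoGuide-style denoising step (not the repository implementation).

    x_t:            noisy video latent at timestep t
    sampling_model: base video diffusion model (e.g. AnimateDiff UNet), predicts noise
    guiding_model:  temporally consistent teacher (e.g. VideoCrafter2), predicts noise
    tau:            interpolation weight between the two denoised estimates (assumed name)
    guide_steps:    number of early steps during which guidance is applied (assumed name)
    """
    a_t = alphas_cumprod[t]

    # DDIM-style denoised (x0) estimate from the sampling model.
    eps = sampling_model(x_t, t)
    x0_sample = (x_t - (1.0 - a_t).sqrt() * eps) / a_t.sqrt()

    if step_idx < guide_steps:
        # Denoised estimate from the guiding model, which is more temporally consistent.
        eps_guide = guiding_model(x_t, t)
        x0_guide = (x_t - (1.0 - a_t).sqrt() * eps_guide) / a_t.sqrt()

        # Interpolate the two estimates; the blended x0 is then used by the scheduler
        # to form x_{t-1}, improving the temporal consistency of the base model.
        x0_sample = tau * x0_sample + (1.0 - tau) * x0_guide

    return x0_sample
```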
69 | -------------------------------------------------------------------------------- /__pycache__/convert_from_ckpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/__pycache__/convert_from_ckpt.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/filter_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/__pycache__/filter_utils.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/data/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/data/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random 2 | import numpy as np 3 | from einops import rearrange 4 | from decord import VideoReader 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torch.utils.data.dataset import Dataset 9 | from animatediff.utils.util import zero_rank_print 10 | from datasets import load_dataset 11 | import pandas as pd 12 | import glob 13 | 14 | class Laion_2B(Dataset): 15 | def __init__(self, data_path="../laion_2b"): 16 | data_file = glob.glob(f"{data_path}/*.parquet", recursive=True) 17 | self.data_path = data_file 18 | self.data = [] 19 | self.__read_data__() 20 | 21 | def __read_data__(self): 22 | for data_path in sorted(self.data_path): 23 | df = pd.read_parquet(data_path) 24 | parquet_list = df.TEXT.to_list() 25 | self.data = self.data + parquet_list 26 | 27 | def __getitem__(self, idx): 28 | text = self.data[idx] 29 | text_dict = dict(text=text) 30 | 31 | return text_dict 32 | 33 | 34 | def __len__(self): 35 | return len(self.data) 36 | 37 | 38 | class WebVid10M(Dataset): 39 | def __init__( 40 | self, 41 | csv_path, video_folder, 42 | sample_size=256, sample_stride=4, sample_n_frames=16, 43 | is_image=False, 44 | ): 45 | zero_rank_print(f"loading annotations from {csv_path} ...") 46 | with open(csv_path, 'r') as csvfile: 47 | self.dataset = list(csv.DictReader(csvfile)) 48 | self.length = len(self.dataset) 49 | zero_rank_print(f"data scale: {self.length}") 50 | 51 | self.video_folder = video_folder 52 | self.sample_stride = sample_stride 53 | self.sample_n_frames = sample_n_frames 54 | self.is_image = 
is_image 55 | 56 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 57 | self.pixel_transforms = transforms.Compose([ 58 | transforms.RandomHorizontalFlip(), 59 | transforms.Resize(sample_size[0]), 60 | transforms.CenterCrop(sample_size), 61 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 62 | ]) 63 | 64 | def get_batch(self, idx): 65 | video_dict = self.dataset[idx] 66 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] 67 | 68 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 69 | video_reader = VideoReader(video_dir) 70 | video_length = len(video_reader) 71 | 72 | if not self.is_image: 73 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) 74 | start_idx = random.randint(0, video_length - clip_length) 75 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 76 | else: 77 | batch_index = [random.randint(0, video_length - 1)] 78 | 79 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 80 | pixel_values = pixel_values / 255. 81 | del video_reader 82 | 83 | if self.is_image: 84 | pixel_values = pixel_values[0] 85 | 86 | 87 | return pixel_values, name 88 | 89 | def __len__(self): 90 | return self.length 91 | 92 | def __getitem__(self, idx): 93 | while True: 94 | try: 95 | pixel_values, name = self.get_batch(idx) 96 | break 97 | 98 | except Exception as e: 99 | idx = random.randint(0, self.length-1) 100 | 101 | pixel_values = self.pixel_transforms(pixel_values) 102 | sample = dict(pixel_values=pixel_values, text=name) 103 | return sample 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | from animatediff.utils.util import save_videos_grid 109 | 110 | dataset = WebVid10M( 111 | csv_path="/home/leedh3726/video-diffusion/video-data/Webvid2M/results_2M_val.csv", 112 | video_folder="/home/leedh3726/video-diffusion/video-data/Webvid2M/video", 113 | sample_size=256, 114 | sample_stride=4, sample_n_frames=16, 115 | is_image=True, 116 | ) 117 | # import pdb 118 | # pdb.set_trace() 119 | 120 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=16) 121 | print(len(dataloader)) 122 | for idx, batch in enumerate(dataloader): 123 | print(idx) 124 | # print(batch["pixel_values"].shape, len(batch["text"])) 125 | # for i in range(batch["pixel_values"].shape[0]): 126 | # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) 127 | -------------------------------------------------------------------------------- /animatediff/models/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/attention.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/attention.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/motion_module.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/motion_module.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/motion_module.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/motion_module.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/resnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/resnet.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/sparse_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/sparse_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/unet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/unet.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/unet_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/unet_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/models/__pycache__/unet_blocks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/models/__pycache__/unet_blocks.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/models/motion_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import 
numpy as np 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import torchvision 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward 15 | 16 | from einops import rearrange, repeat 17 | import math 18 | 19 | 20 | def get_views(video_length, window_size=16, stride=4): 21 | num_blocks_time = (video_length - window_size) // stride + 1 22 | views = [] 23 | for i in range(num_blocks_time): 24 | t_start = int(i * stride) 25 | t_end = t_start + window_size 26 | views.append((t_start,t_end)) 27 | return views 28 | 29 | 30 | def generate_weight_sequence(n): 31 | if n % 2 == 0: 32 | max_weight = n // 2 33 | weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)) 34 | else: 35 | max_weight = (n + 1) // 2 36 | weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) 37 | return weight_sequence 38 | 39 | 40 | def zero_module(module): 41 | # Zero out the parameters of a module and return it. 42 | for p in module.parameters(): 43 | p.detach().zero_() 44 | return module 45 | 46 | 47 | @dataclass 48 | class TemporalTransformer3DModelOutput(BaseOutput): 49 | sample: torch.FloatTensor 50 | 51 | 52 | if is_xformers_available(): 53 | import xformers 54 | import xformers.ops 55 | else: 56 | xformers = None 57 | 58 | 59 | def get_motion_module( 60 | in_channels, 61 | motion_module_type: str, 62 | motion_module_kwargs: dict 63 | ): 64 | if motion_module_type == "Vanilla": 65 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 66 | else: 67 | raise ValueError 68 | 69 | def get_window_motion_module( 70 | in_channels, 71 | motion_module_type: str, 72 | motion_module_kwargs: dict 73 | ): 74 | if motion_module_type == "Vanilla": 75 | return VanillaTemporalModule(in_channels=in_channels, local_window=True, **motion_module_kwargs,) 76 | else: 77 | raise ValueError 78 | 79 | 80 | class VanillaTemporalModule(nn.Module): 81 | def __init__( 82 | self, 83 | in_channels, 84 | num_attention_heads = 8, 85 | num_transformer_block = 2, 86 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 87 | cross_frame_attention_mode = None, 88 | temporal_position_encoding = False, 89 | temporal_position_encoding_max_len = 24, 90 | temporal_attention_dim_div = 1, 91 | zero_initialize = True, 92 | **kwargs, 93 | ): 94 | super().__init__() 95 | 96 | self.temporal_transformer = TemporalTransformer3DModel( 97 | in_channels=in_channels, 98 | num_attention_heads=num_attention_heads, 99 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 100 | num_layers=num_transformer_block, 101 | attention_block_types=attention_block_types, 102 | cross_frame_attention_mode=cross_frame_attention_mode, 103 | temporal_position_encoding=temporal_position_encoding, 104 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 105 | **kwargs, 106 | ) 107 | 108 | if zero_initialize: 109 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 110 | 111 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None): 112 | hidden_states = input_tensor 113 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 114 | 115 | output = hidden_states 
116 | return output 117 | 118 | 119 | class TemporalTransformer3DModel(nn.Module): 120 | def __init__( 121 | self, 122 | in_channels, 123 | num_attention_heads, 124 | attention_head_dim, 125 | 126 | num_layers, 127 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 128 | dropout = 0.0, 129 | norm_num_groups = 32, 130 | cross_attention_dim = 768, 131 | activation_fn = "geglu", 132 | attention_bias = False, 133 | upcast_attention = False, 134 | 135 | cross_frame_attention_mode = None, 136 | temporal_position_encoding = False, 137 | temporal_position_encoding_max_len = 24, 138 | **kwargs, 139 | ): 140 | super().__init__() 141 | 142 | inner_dim = num_attention_heads * attention_head_dim 143 | 144 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 145 | self.proj_in = nn.Linear(in_channels, inner_dim) 146 | 147 | self.transformer_blocks = nn.ModuleList( 148 | [ 149 | TemporalTransformerBlock( 150 | dim=inner_dim, 151 | num_attention_heads=num_attention_heads, 152 | attention_head_dim=attention_head_dim, 153 | attention_block_types=attention_block_types, 154 | dropout=dropout, 155 | norm_num_groups=norm_num_groups, 156 | cross_attention_dim=cross_attention_dim, 157 | activation_fn=activation_fn, 158 | attention_bias=attention_bias, 159 | upcast_attention=upcast_attention, 160 | cross_frame_attention_mode=cross_frame_attention_mode, 161 | temporal_position_encoding=temporal_position_encoding, 162 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 163 | **kwargs, 164 | ) 165 | for d in range(num_layers) 166 | ] 167 | ) 168 | self.proj_out = nn.Linear(inner_dim, in_channels) 169 | 170 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 171 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
172 | video_length = hidden_states.shape[2] 173 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 174 | 175 | batch, channel, height, weight = hidden_states.shape 176 | residual = hidden_states 177 | 178 | hidden_states = self.norm(hidden_states) 179 | inner_dim = hidden_states.shape[1] 180 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 181 | hidden_states = self.proj_in(hidden_states) 182 | 183 | # Transformer Blocks 184 | for block in self.transformer_blocks: 185 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 186 | 187 | # output 188 | hidden_states = self.proj_out(hidden_states) 189 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 190 | 191 | output = hidden_states + residual 192 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 193 | 194 | return output 195 | 196 | 197 | class TemporalTransformerBlock(nn.Module): 198 | def __init__( 199 | self, 200 | dim, 201 | num_attention_heads, 202 | attention_head_dim, 203 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 204 | dropout = 0.0, 205 | norm_num_groups = 32, 206 | cross_attention_dim = 768, 207 | activation_fn = "geglu", 208 | attention_bias = False, 209 | upcast_attention = False, 210 | cross_frame_attention_mode = None, 211 | temporal_position_encoding = False, 212 | temporal_position_encoding_max_len = 24, 213 | local_window = False, 214 | ): 215 | super().__init__() 216 | 217 | attention_blocks = [] 218 | norms = [] 219 | 220 | for block_name in attention_block_types: 221 | attention_blocks.append( 222 | VersatileAttention( 223 | attention_mode=block_name.split("_")[0], 224 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 225 | 226 | query_dim=dim, 227 | heads=num_attention_heads, 228 | dim_head=attention_head_dim, 229 | dropout=dropout, 230 | bias=attention_bias, 231 | upcast_attention=upcast_attention, 232 | 233 | cross_frame_attention_mode=cross_frame_attention_mode, 234 | temporal_position_encoding=temporal_position_encoding, 235 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 236 | ) 237 | ) 238 | norms.append(nn.LayerNorm(dim)) 239 | 240 | self.attention_blocks = nn.ModuleList(attention_blocks) 241 | self.norms = nn.ModuleList(norms) 242 | 243 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 244 | self.ff_norm = nn.LayerNorm(dim) 245 | 246 | self.local_window = local_window 247 | 248 | 249 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 250 | 251 | if not self.local_window: 252 | for attention_block, norm in zip(self.attention_blocks, self.norms): 253 | norm_hidden_states = norm(hidden_states) 254 | hidden_states = attention_block( 255 | norm_hidden_states, 256 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 257 | video_length=video_length, 258 | ) + hidden_states 259 | else: 260 | views = get_views(video_length, stride=8) 261 | hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length) 262 | count = torch.zeros_like(hidden_states) 263 | value = torch.zeros_like(hidden_states) 264 | for t_start, t_end in views: 265 | # weight_sequence = generate_weight_sequence(t_end - t_start) 266 | weight_tensor = torch.ones_like(count[:, t_start:t_end]) 267 | # weight_tensor = weight_tensor * 
torch.Tensor(weight_sequence).to(hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 268 | weight_tensor = weight_tensor.to(hidden_states.device) 269 | 270 | sub_hidden_states = rearrange(hidden_states[:, t_start:t_end], "b f d c -> (b f) d c") 271 | for attention_block, norm in zip(self.attention_blocks, self.norms): 272 | norm_hidden_states = norm(sub_hidden_states) 273 | sub_hidden_states = attention_block( 274 | norm_hidden_states, 275 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 276 | video_length=t_end-t_start, 277 | ) + sub_hidden_states 278 | sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=t_end-t_start) 279 | 280 | value[:,t_start:t_end] += sub_hidden_states * weight_tensor 281 | count[:,t_start:t_end] += weight_tensor 282 | 283 | hidden_states = torch.where(count>0, value/count, value) 284 | hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c") 285 | 286 | # for attention_block, norm in zip(self.attention_blocks, self.norms): 287 | # norm_hidden_states = norm(hidden_states) 288 | # hidden_states = attention_block( 289 | # norm_hidden_states, 290 | # encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 291 | # video_length=video_length, 292 | # ) + hidden_states 293 | 294 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 295 | 296 | output = hidden_states 297 | return output 298 | 299 | 300 | class PositionalEncoding(nn.Module): 301 | def __init__( 302 | self, 303 | d_model, 304 | dropout = 0., 305 | max_len = 24 306 | ): 307 | super().__init__() 308 | self.dropout = nn.Dropout(p=dropout) 309 | position = torch.arange(max_len).unsqueeze(1) 310 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 311 | pe = torch.zeros(1, max_len, d_model) 312 | pe[0, :, 0::2] = torch.sin(position * div_term) 313 | pe[0, :, 1::2] = torch.cos(position * div_term) 314 | self.register_buffer('pe', pe) 315 | 316 | def forward(self, x): 317 | x = x + self.pe[:, :x.size(1)] 318 | return self.dropout(x) 319 | 320 | 321 | class VersatileAttention(CrossAttention): 322 | def __init__( 323 | self, 324 | attention_mode = None, 325 | cross_frame_attention_mode = None, 326 | temporal_position_encoding = False, 327 | temporal_position_encoding_max_len = 24, 328 | *args, **kwargs 329 | ): 330 | super().__init__(*args, **kwargs) 331 | assert attention_mode == "Temporal" 332 | 333 | self.attention_mode = attention_mode 334 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 335 | 336 | self.pos_encoder = PositionalEncoding( 337 | kwargs["query_dim"], 338 | dropout=0., 339 | max_len=temporal_position_encoding_max_len 340 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 341 | 342 | def extra_repr(self): 343 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 344 | 345 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 346 | batch_size, sequence_length, _ = hidden_states.shape 347 | 348 | if self.attention_mode == "Temporal": 349 | d = hidden_states.shape[1] 350 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 351 | 352 | if self.pos_encoder is not None: 353 | hidden_states = self.pos_encoder(hidden_states) 354 | 355 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None 
else encoder_hidden_states 356 | else: 357 | raise NotImplementedError 358 | 359 | encoder_hidden_states = encoder_hidden_states 360 | 361 | if self.group_norm is not None: 362 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 363 | 364 | query = self.to_q(hidden_states) 365 | dim = query.shape[-1] 366 | query = self.reshape_heads_to_batch_dim(query) 367 | 368 | if self.added_kv_proj_dim is not None: 369 | raise NotImplementedError 370 | 371 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 372 | key = self.to_k(encoder_hidden_states) 373 | value = self.to_v(encoder_hidden_states) 374 | 375 | key = self.reshape_heads_to_batch_dim(key) 376 | value = self.reshape_heads_to_batch_dim(value) 377 | 378 | if attention_mask is not None: 379 | if attention_mask.shape[-1] != query.shape[1]: 380 | target_length = query.shape[1] 381 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 382 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 383 | 384 | # attention, what we cannot get enough of 385 | if self._use_memory_efficient_attention_xformers: 386 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 387 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 388 | hidden_states = hidden_states.to(query.dtype) 389 | else: 390 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 391 | hidden_states = self._attention(query, key, value, attention_mask) 392 | else: 393 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 394 | 395 | # linear proj 396 | hidden_states = self.to_out[0](hidden_states) 397 | 398 | # dropout 399 | hidden_states = self.to_out[1](hidden_states) 400 | 401 | if self.attention_mode == "Temporal": 402 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 403 | 404 | return hidden_states 405 | 406 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 34 | super().__init__() 35 | self.channels = channels 36 | self.out_channels = out_channels or channels 37 | self.use_conv = use_conv 38 | self.use_conv_transpose = use_conv_transpose 39 | self.name = name 40 | 41 | conv = None 42 | if use_conv_transpose: 43 | raise NotImplementedError 44 | elif use_conv: 45 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 46 | 47 | def 
forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 66 | else: 67 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 68 | 69 | # If the input is bfloat16, we cast back to bfloat16 70 | if dtype == torch.bfloat16: 71 | hidden_states = hidden_states.to(dtype) 72 | 73 | # if self.use_conv: 74 | # if self.name == "conv": 75 | # hidden_states = self.conv(hidden_states) 76 | # else: 77 | # hidden_states = self.Conv2d_0(hidden_states) 78 | hidden_states = self.conv(hidden_states) 79 | 80 | return hidden_states 81 | 82 | 83 | class Downsample3D(nn.Module): 84 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 85 | super().__init__() 86 | self.channels = channels 87 | self.out_channels = out_channels or channels 88 | self.use_conv = use_conv 89 | self.padding = padding 90 | stride = 2 91 | self.name = name 92 | 93 | if use_conv: 94 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 95 | else: 96 | raise NotImplementedError 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | use_inflated_groupnorm=False, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | assert use_inflated_groupnorm != None 142 | if use_inflated_groupnorm: # True, 2D group norm 143 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | else: 145 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 146 | 147 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 148 | 149 | if temb_channels is not None: 150 | if self.time_embedding_norm == 
"default": 151 | time_emb_proj_out_channels = out_channels 152 | elif self.time_embedding_norm == "scale_shift": 153 | time_emb_proj_out_channels = out_channels * 2 154 | else: 155 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 156 | 157 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 158 | else: 159 | self.time_emb_proj = None 160 | 161 | if use_inflated_groupnorm: 162 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | else: 164 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 165 | 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | # self.in_layers_features = None 183 | # self.out_layers_features = None 184 | 185 | def forward(self, input_tensor, temb, out_layers_injected=None): 186 | hidden_states = input_tensor 187 | 188 | hidden_states = self.norm1(hidden_states) 189 | hidden_states = self.nonlinearity(hidden_states) 190 | 191 | hidden_states = self.conv1(hidden_states) 192 | 193 | # self.in_layers_features = hidden_states 194 | 195 | if temb is not None: 196 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 197 | if temb.shape[0] % 16 == 0: 198 | temb = rearrange(temb, '(f b) c x y z -> b c (f x) y z', f=16) 199 | 200 | if temb is not None and self.time_embedding_norm == "default": 201 | hidden_states = hidden_states + temb 202 | 203 | hidden_states = self.norm2(hidden_states) 204 | 205 | if temb is not None and self.time_embedding_norm == "scale_shift": 206 | scale, shift = torch.chunk(temb, 2, dim=1) 207 | hidden_states = hidden_states * (1 + scale) + shift 208 | 209 | hidden_states = self.nonlinearity(hidden_states) 210 | 211 | hidden_states = self.dropout(hidden_states) 212 | hidden_states = self.conv2(hidden_states) 213 | 214 | if self.conv_shortcut is not None: 215 | input_tensor = self.conv_shortcut(input_tensor) 216 | 217 | if out_layers_injected is not None: 218 | hidden_states = out_layers_injected.to(hidden_states.device) 219 | 220 | # self.out_layers_features = hidden_states 221 | 222 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 223 | 224 | return output_tensor 225 | 226 | 227 | class Mish(torch.nn.Module): 228 | def forward(self, hidden_states): 229 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc -------------------------------------------------------------------------------- 
/animatediff/pipelines/__pycache__/pipeline_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/pipelines/__pycache__/pipeline_attention.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/pipelines/__pycache__/pipeline_sd.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/pipelines/__pycache__/pipeline_sd.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/pipelines/__pycache__/pipeline_temporal_cfg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/pipelines/__pycache__/pipeline_temporal_cfg.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/convert_from_ckpt.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/convert_from_ckpt.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-311.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/utils/__pycache__/util.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/animatediff/utils/__pycache__/util.cpython-311.pyc -------------------------------------------------------------------------------- 
/animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Changes were made to this source code by Yuwei Guo. 17 | """ Conversion script for the LoRA's safetensors checkpoints. """ 18 | 19 | import argparse 20 | 21 | import torch 22 | from safetensors.torch import load_file 23 | 24 | from diffusers import StableDiffusionPipeline 25 | 26 | 27 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 28 | # directly update weight in diffusers model 29 | for key in state_dict: 30 | # only process lora down key 31 | if "up." in key: continue 32 | 33 | up_key = key.replace(".down.", ".up.") 34 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 35 | model_key = model_key.replace("to_out.", "to_out.0.") 36 | layer_infos = model_key.split(".")[:-1] 37 | 38 | curr_layer = pipeline.unet 39 | while len(layer_infos) > 0: 40 | temp_name = layer_infos.pop(0) 41 | curr_layer = curr_layer.__getattr__(temp_name) 42 | 43 | weight_down = state_dict[key] 44 | weight_up = state_dict[up_key] 45 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 46 | 47 | return pipeline 48 | 49 | def load_diffusers_lora_unet(unet, state_dict, alpha=1.0): 50 | for key in state_dict: 51 | # only process lora down key 52 | if "up." 
in key: continue 53 | 54 | up_key = key.replace(".down.", ".up.") 55 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 56 | model_key = model_key.replace("to_out.", "to_out.0.") 57 | layer_infos = model_key.split(".")[:-1] 58 | 59 | curr_layer = unet 60 | while len(layer_infos) > 0: 61 | temp_name = layer_infos.pop(0) 62 | curr_layer = curr_layer.__getattr__(temp_name) 63 | 64 | weight_down = state_dict[key] 65 | weight_up = state_dict[up_key] 66 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 67 | 68 | return unet 69 | 70 | 71 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 72 | # load base model 73 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 74 | 75 | # load LoRA weight from .safetensors 76 | # state_dict = load_file(checkpoint_path) 77 | 78 | visited = [] 79 | 80 | # directly update weight in diffusers model 81 | for key in state_dict: 82 | # it is suggested to print out the key, it usually will be something like below 83 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 84 | 85 | # as we have set the alpha beforehand, so just skip 86 | if ".alpha" in key or key in visited: 87 | continue 88 | 89 | if "text" in key: 90 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 91 | curr_layer = pipeline.text_encoder 92 | else: 93 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 94 | curr_layer = pipeline.unet 95 | 96 | # find the target layer 97 | temp_name = layer_infos.pop(0) 98 | while len(layer_infos) > -1: 99 | try: 100 | curr_layer = curr_layer.__getattr__(temp_name) 101 | if len(layer_infos) > 0: 102 | temp_name = layer_infos.pop(0) 103 | elif len(layer_infos) == 0: 104 | break 105 | except Exception: 106 | if len(temp_name) > 0: 107 | temp_name += "_" + layer_infos.pop(0) 108 | else: 109 | temp_name = layer_infos.pop(0) 110 | 111 | pair_keys = [] 112 | if "lora_down" in key: 113 | pair_keys.append(key.replace("lora_down", "lora_up")) 114 | pair_keys.append(key) 115 | else: 116 | pair_keys.append(key) 117 | pair_keys.append(key.replace("lora_up", "lora_down")) 118 | 119 | # update weight 120 | if len(state_dict[pair_keys[0]].shape) == 4: 121 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 122 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 123 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 124 | else: 125 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 126 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 127 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 128 | 129 | # update visited list 130 | for item in pair_keys: 131 | visited.append(item) 132 | 133 | return pipeline 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | 139 | parser.add_argument( 140 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 141 | ) 142 | parser.add_argument( 143 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
144 | ) 145 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 146 | parser.add_argument( 147 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 148 | ) 149 | parser.add_argument( 150 | "--lora_prefix_text_encoder", 151 | default="lora_te", 152 | type=str, 153 | help="The prefix of text encoder weight in safetensors", 154 | ) 155 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 156 | parser.add_argument( 157 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 158 | ) 159 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 160 | 161 | args = parser.parse_args() 162 | 163 | base_model_path = args.base_model_path 164 | checkpoint_path = args.checkpoint_path 165 | dump_path = args.dump_path 166 | lora_prefix_unet = args.lora_prefix_unet 167 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 168 | alpha = args.alpha 169 | 170 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 171 | 172 | pipe = pipe.to(args.device) 173 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 174 | -------------------------------------------------------------------------------- /animatediff/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 14 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora, load_diffusers_lora_unet 15 | 16 | 17 | def zero_rank_print(s): 18 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 19 | 20 | 21 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 22 | videos = rearrange(videos, "b c t h w -> t b c h w") 23 | outputs = [] 24 | for x in videos: 25 | x = torchvision.utils.make_grid(x, nrow=n_rows) 26 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 27 | if rescale: 28 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 29 | x = (x * 255).numpy().astype(np.uint8) 30 | outputs.append(x) 31 | 32 | os.makedirs(os.path.dirname(path), exist_ok=True) 33 | imageio.mimsave(path, outputs, fps=fps) 34 | 35 | 36 | # DDIM Inversion 37 | @torch.no_grad() 38 | def init_prompt(prompt, pipeline): 39 | uncond_input = pipeline.tokenizer( 40 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 41 | return_tensors="pt" 42 | ) 43 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 44 | text_input = pipeline.tokenizer( 45 | [prompt], 46 | padding="max_length", 47 | max_length=pipeline.tokenizer.model_max_length, 48 | truncation=True, 49 | return_tensors="pt", 50 | ) 51 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 52 | context = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) 53 | 54 | return context 55 | 56 | 57 | def next_step(model_output: 
Union[torch.FloatTensor, np.ndarray], timestep: int, 58 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 59 | timestep, next_timestep = min( 60 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 61 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 62 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 63 | beta_prod_t = 1 - alpha_prod_t 64 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 65 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 66 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 67 | return next_sample 68 | 69 | 70 | def get_noise_pred_single(latents, t, context, unet): 71 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 72 | return noise_pred 73 | 74 | 75 | @torch.no_grad() 76 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 77 | context = init_prompt(prompt, pipeline) 78 | uncond_embeddings, cond_embeddings = context.chunk(2) 79 | all_latent = [latent] 80 | latent = latent.clone().detach() 81 | for i in tqdm(range(num_inv_steps)): 82 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 83 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 84 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 85 | all_latent.append(latent) 86 | return all_latent 87 | 88 | 89 | @torch.no_grad() 90 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 91 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 92 | return ddim_latents 93 | 94 | def load_weights( 95 | animation_pipeline, 96 | # motion module 97 | motion_module_path = "", 98 | motion_module_lora_configs = [], 99 | # domain adapter 100 | adapter_lora_path = "", 101 | adapter_lora_scale = 1.0, 102 | # image layers 103 | dreambooth_model_path = "", 104 | lora_model_path = "", 105 | lora_alpha = 0.8, 106 | ): 107 | # motion module 108 | unet_state_dict = {} 109 | if motion_module_path != "": 110 | print(f"load motion module from {motion_module_path}") 111 | if motion_module_path.endswith(".safetensors"): 112 | motion_module_state_dict = {} 113 | with safe_open(motion_module_path, framework="pt", device="cpu") as f: 114 | for key in f.keys(): 115 | motion_module_state_dict[key] = f.get_tensor(key) 116 | else: 117 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 118 | 119 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 120 | motion_module_state_dict = {k.replace("module.", ""):v for k, v in motion_module_state_dict.items()} 121 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." 
in name}) 122 | unet_state_dict.pop("animatediff_config", "") 123 | 124 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 125 | print("missing: ", len(missing), " unexpected: ", len(unexpected)) 126 | assert len(unexpected) == 0 127 | del unet_state_dict 128 | 129 | # base model 130 | if dreambooth_model_path != "": 131 | print(f"load dreambooth model from {dreambooth_model_path}") 132 | if dreambooth_model_path.endswith(".safetensors"): 133 | dreambooth_state_dict = {} 134 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 135 | for key in f.keys(): 136 | dreambooth_state_dict[key] = f.get_tensor(key) 137 | elif dreambooth_model_path.endswith(".ckpt"): 138 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 139 | 140 | # 1. vae 141 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 142 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 143 | # 2. unet 144 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 145 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 146 | # 3. text_model 147 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 148 | del dreambooth_state_dict 149 | 150 | # lora layers 151 | if lora_model_path != "": 152 | print(f"load lora model from {lora_model_path}") 153 | assert lora_model_path.endswith(".safetensors") 154 | lora_state_dict = {} 155 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 156 | for key in f.keys(): 157 | lora_state_dict[key] = f.get_tensor(key) 158 | 159 | # convert_lora merges the LoRA delta weights into the matching layers of the pipeline 160 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 161 | del lora_state_dict 162 | 163 | # domain adapter lora 164 | if adapter_lora_path != "": 165 | print(f"load domain lora from {adapter_lora_path}") 166 | 167 | if adapter_lora_path.endswith(".safetensors"): 168 | domain_lora_state_dict = {} 169 | with safe_open(adapter_lora_path, framework="pt", device="cpu") as f: 170 | for key in f.keys(): 171 | domain_lora_state_dict[key] = f.get_tensor(key) 172 | else: 173 | domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 174 | 175 | domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 176 | domain_lora_state_dict.pop("animatediff_config", "") 177 | 178 | animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) 179 | 180 | # motion module lora 181 | for motion_module_lora_config in motion_module_lora_configs: 182 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 183 | print(f"load motion LoRA from {path}") 184 | motion_lora_state_dict = torch.load(path, map_location="cpu") 185 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 186 | motion_lora_state_dict.pop("animatediff_config", "") 187 | 188 | animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) 189 | 190 | return animation_pipeline 191 | 192 | def load_motion_module_weights( 193 | unet, 194 | # motion module 195 | motion_module_path = "", 196 | # domain adapter 197 | adapter_lora_path = "", 198 | adapter_lora_scale = 1.0, 199
| ): 200 | # motion module 201 | unet_state_dict = {} 202 | if motion_module_path != "": 203 | print(f"load motion module from {motion_module_path}") 204 | if motion_module_path.endswith(".safetensors"): 205 | motion_module_state_dict = {} 206 | with safe_open(motion_module_path, framework="pt", device="cpu") as f: 207 | for key in f.keys(): 208 | motion_module_state_dict[key] = f.get_tensor(key) 209 | else: 210 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 211 | 212 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 213 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 214 | unet_state_dict.pop("animatediff_config", "") 215 | 216 | missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) 217 | assert len(unexpected) == 0 218 | del unet_state_dict 219 | 220 | if adapter_lora_path != "": 221 | print(f"load domain lora from {adapter_lora_path}") 222 | 223 | if adapter_lora_path.endswith(".safetensors"): 224 | domain_lora_state_dict = {} 225 | with safe_open(adapter_lora_path, framework="pt", device="cpu") as f: 226 | for key in f.keys(): 227 | domain_lora_state_dict[key] = f.get_tensor(key) 228 | else: 229 | domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 230 | 231 | domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 232 | domain_lora_state_dict.pop("animatediff_config", "") 233 | 234 | unet = load_diffusers_lora_unet(unet, domain_lora_state_dict, alpha=adapter_lora_scale) 235 | 236 | return unet -------------------------------------------------------------------------------- /animatediff_configs/inference/inference-v1.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | unet_use_cross_frame_attention: false 3 | unet_use_temporal_attention: false 4 | use_motion_module: true 5 | motion_module_resolutions: [1,2,4,8] 6 | motion_module_mid_block: false 7 | motion_module_decoder_only: false 8 | motion_module_type: "Vanilla" 9 | 10 | motion_module_kwargs: 11 | num_attention_heads: 8 12 | num_transformer_block: 1 13 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 14 | temporal_position_encoding: true 15 | temporal_position_encoding_max_len: 24 16 | temporal_attention_dim_div: 1 17 | 18 | noise_scheduler_kwargs: 19 | beta_start: 0.00085 20 | beta_end: 0.012 21 | beta_schedule: "linear" 22 | steps_offset: 1 23 | clip_sample: False 24 | -------------------------------------------------------------------------------- /animatediff_configs/inference/inference-v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | use_attn_fusion: true 7 | motion_module_resolutions: [1,2,4,8] 8 | motion_module_mid_block: true 9 | motion_module_decoder_only: false 10 | motion_module_type: "Vanilla" 11 | 12 | motion_module_kwargs: 13 | num_attention_heads: 8 14 | num_transformer_block: 1 15 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 16 | temporal_position_encoding: true 17 | temporal_position_encoding_max_len: 32 18 | temporal_attention_dim_div: 1 19 | 20 | noise_scheduler_kwargs: 21 | beta_start: 0.00085 
22 | beta_end: 0.012 23 | beta_schedule: "linear" 24 | steps_offset: 1 25 | clip_sample: False 26 | -------------------------------------------------------------------------------- /animatediff_configs/inference/inference-v3.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | use_motion_module: true 4 | motion_module_resolutions: [1,2,4,8] 5 | motion_module_mid_block: false 6 | motion_module_type: Vanilla 7 | 8 | motion_module_kwargs: 9 | num_attention_heads: 8 10 | num_transformer_block: 1 11 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 12 | temporal_position_encoding: true 13 | temporal_position_encoding_max_len: 32 14 | temporal_attention_dim_div: 1 15 | zero_initialize: true 16 | 17 | noise_scheduler_kwargs: 18 | beta_start: 0.00085 19 | beta_end: 0.012 20 | beta_schedule: "linear" 21 | steps_offset: 1 22 | clip_sample: False 23 | -------------------------------------------------------------------------------- /animatediff_configs/inference/sparsectrl/image_condition.yaml: -------------------------------------------------------------------------------- 1 | controlnet_additional_kwargs: 2 | set_noisy_sample_input_to_zero: true 3 | use_simplified_condition_embedding: false 4 | conditioning_channels: 3 5 | 6 | use_motion_module: true 7 | motion_module_resolutions: [1,2,4,8] 8 | motion_module_mid_block: false 9 | motion_module_type: "Vanilla" 10 | 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 32 17 | temporal_attention_dim_div: 1 18 | -------------------------------------------------------------------------------- /animatediff_configs/inference/sparsectrl/latent_condition.yaml: -------------------------------------------------------------------------------- 1 | controlnet_additional_kwargs: 2 | set_noisy_sample_input_to_zero: true 3 | use_simplified_condition_embedding: true 4 | conditioning_channels: 4 5 | 6 | use_motion_module: true 7 | motion_module_resolutions: [1,2,4,8] 8 | motion_module_mid_block: false 9 | motion_module_type: "Vanilla" 10 | 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 32 17 | temporal_attention_dim_div: 1 18 | -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-Film.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: "../animatediff/models/DreamBooth_LoRA/leosamsFilmgirlUltra_ultraBaseModel.safetensors" # Change into your dreambooth model path 5 | lora_model_path: "" -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-RCNZcartoon.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: 
"../animatediff/models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" # Change into your dreambooth model path 5 | lora_model_path: "" 6 | -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: "../animatediff/models/DreamBooth_LoRA/realisticVisionV51_v51VAE.safetensors" # Change into your dreambooth model path 5 | lora_model_path: "" -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-ToonYou.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: "../animatediff/models/DreamBooth_LoRA/toonyou_beta3.safetensors" # Change into your dreambooth model path 5 | lora_model_path: "" 6 | -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-w-dreambooth.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: "../animatediff/models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" # Change into your dreambooth model path 5 | lora_model_path: "" -------------------------------------------------------------------------------- /animatediff_configs/prompts/v2/v2-1-w.o-dreambooth.yaml: -------------------------------------------------------------------------------- 1 | - inference_config: "animatediff_configs/inference/inference-v2.yaml" 2 | motion_module: "../animatediff/models/Motion_Module/mm_sd_v15_v2.ckpt" # Change into your motion module path 3 | 4 | dreambooth_path: "" 5 | lora_model_path: "" 6 | -------------------------------------------------------------------------------- /animatediff_configs/training/v1/image_finetune.yaml: -------------------------------------------------------------------------------- 1 | image_finetune: true 2 | 3 | output_dir: "outputs" 4 | pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" 5 | 6 | noise_scheduler_kwargs: 7 | num_train_timesteps: 1000 8 | beta_start: 0.00085 9 | beta_end: 0.012 10 | beta_schedule: "scaled_linear" 11 | steps_offset: 1 12 | clip_sample: false 13 | 14 | train_data: 15 | csv_path: "/home/leedh3726/video-diffusion/video-data/Webvid2M/results_2M_val.csv" 16 | video_folder: "/home/leedh3726/video-diffusion/video-data/Webvid2M/video" 17 | sample_size: 256 18 | 19 | validation_data: 20 | prompts: 21 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." 22 | - "A drone view of celebration with Christma tree and fireworks, starry sky - background." 23 | - "Robot dancing in times square." 24 | - "Pacific coast, carmel by the sea ocean and waves." 25 | num_inference_steps: 25 26 | guidance_scale: 8. 27 | 28 | trainable_modules: 29 | - "." 
30 | 31 | unet_checkpoint_path: "" 32 | 33 | learning_rate: 1.e-5 34 | train_batch_size: 10 35 | 36 | max_train_epoch: -1 37 | max_train_steps: 1000 38 | checkpointing_epochs: -1 39 | checkpointing_steps: 100 40 | 41 | validation_steps: 100 42 | validation_steps_tuple: [1] 43 | 44 | global_seed: 42 45 | mixed_precision_training: true 46 | enable_xformers_memory_efficient_attention: True 47 | 48 | is_debug: False 49 | -------------------------------------------------------------------------------- /animatediff_configs/training/v1/training.yaml: -------------------------------------------------------------------------------- 1 | image_finetune: false 2 | 3 | output_dir: "outputs" 4 | pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" 5 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 6 | 7 | unet_additional_kwargs: 8 | use_motion_module : true 9 | motion_module_resolutions : [ 1,2,4,8 ] 10 | unet_use_cross_frame_attention : false 11 | unet_use_temporal_attention : false 12 | motion_module_mid_block : true 13 | 14 | motion_module_type: Vanilla 15 | motion_module_kwargs: 16 | num_attention_heads : 8 17 | num_transformer_block : 1 18 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 19 | temporal_position_encoding : true 20 | temporal_position_encoding_max_len : 32 21 | temporal_attention_dim_div : 1 22 | zero_initialize : true 23 | 24 | noise_scheduler_kwargs: 25 | num_train_timesteps: 1000 26 | beta_start: 0.00085 27 | beta_end: 0.012 28 | beta_schedule: "linear" 29 | steps_offset: 1 30 | clip_sample: false 31 | 32 | train_data: 33 | csv_path: "/home/leedh3726/video-diffusion/video-data/Webvid2M/results_2M_val.csv" 34 | video_folder: "/home/leedh3726/video-diffusion/video-data/Webvid2M/video" 35 | sample_size: 256 36 | sample_stride: 4 37 | sample_n_frames: 16 38 | 39 | validation_data: 40 | prompts: 41 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." 42 | - "A drone view of celebration with Christma tree and fireworks, starry sky - background." 43 | - "Robot dancing in times square." 44 | - "Pacific coast, carmel by the sea ocean and waves." 45 | num_inference_steps: 16 46 | guidance_scale: 7.5 47 | 48 | trainable_modules: 49 | - "motion_modules." 50 | 51 | unet_checkpoint_path: "" 52 | 53 | learning_rate: 1.e-4 54 | train_batch_size: 1 55 | 56 | max_train_epoch: -1 57 | max_train_steps: 3000 58 | checkpointing_epochs: -1 59 | checkpointing_steps: 100 60 | 61 | validation_steps: 100 62 | validation_steps_tuple: [1] 63 | 64 | global_seed: 42 65 | mixed_precision_training: true 66 | enable_xformers_memory_efficient_attention: True 67 | 68 | is_debug: False 69 | -------------------------------------------------------------------------------- /assets/main_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/assets/main_fig.png -------------------------------------------------------------------------------- /convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Changes were made to this source code by Yuwei Guo. 17 | """ Conversion script for the LoRA's safetensors checkpoints. """ 18 | 19 | import argparse 20 | 21 | import torch 22 | from safetensors.torch import load_file 23 | 24 | from diffusers import StableDiffusionPipeline 25 | 26 | 27 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 28 | # directly update weight in diffusers model 29 | for key in state_dict: 30 | # only process lora down key 31 | if "up." in key: continue 32 | 33 | up_key = key.replace(".down.", ".up.") 34 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 35 | model_key = model_key.replace("to_out.", "to_out.0.") 36 | layer_infos = model_key.split(".")[:-1] 37 | 38 | curr_layer = pipeline.unet 39 | while len(layer_infos) > 0: 40 | temp_name = layer_infos.pop(0) 41 | curr_layer = curr_layer.__getattr__(temp_name) 42 | 43 | weight_down = state_dict[key] 44 | weight_up = state_dict[up_key] 45 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 46 | 47 | return pipeline 48 | 49 | def load_diffusers_lora_unet(unet, state_dict, alpha=1.0): 50 | for key in state_dict: 51 | # only process lora down key 52 | if "up." in key: continue 53 | 54 | up_key = key.replace(".down.", ".up.") 55 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 56 | model_key = model_key.replace("to_out.", "to_out.0.") 57 | layer_infos = model_key.split(".")[:-1] 58 | 59 | curr_layer = unet 60 | while len(layer_infos) > 0: 61 | temp_name = layer_infos.pop(0) 62 | curr_layer = curr_layer.__getattr__(temp_name) 63 | 64 | weight_down = state_dict[key] 65 | weight_up = state_dict[up_key] 66 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 67 | 68 | return unet 69 | 70 | 71 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 72 | # load base model 73 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 74 | 75 | # load LoRA weight from .safetensors 76 | # state_dict = load_file(checkpoint_path) 77 | 78 | visited = [] 79 | 80 | # directly update weight in diffusers model 81 | for key in state_dict: 82 | # it is suggested to print out the key, it usually will be something like below 83 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 84 | 85 | # as we have set the alpha beforehand, so just skip 86 | if ".alpha" in key or key in visited: 87 | continue 88 | 89 | if "text" in key: 90 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 91 | curr_layer = pipeline.text_encoder 92 | else: 93 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 94 | curr_layer = pipeline.unet 95 | 96 | # find the target layer 97 | temp_name = layer_infos.pop(0) 98 | while len(layer_infos) > -1: 99 | try: 100 | curr_layer = curr_layer.__getattr__(temp_name) 101 | if 
len(layer_infos) > 0: 102 | temp_name = layer_infos.pop(0) 103 | elif len(layer_infos) == 0: 104 | break 105 | except Exception: 106 | if len(temp_name) > 0: 107 | temp_name += "_" + layer_infos.pop(0) 108 | else: 109 | temp_name = layer_infos.pop(0) 110 | 111 | pair_keys = [] 112 | if "lora_down" in key: 113 | pair_keys.append(key.replace("lora_down", "lora_up")) 114 | pair_keys.append(key) 115 | else: 116 | pair_keys.append(key) 117 | pair_keys.append(key.replace("lora_up", "lora_down")) 118 | 119 | # update weight 120 | if len(state_dict[pair_keys[0]].shape) == 4: 121 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 122 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 123 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 124 | else: 125 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 126 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 127 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 128 | 129 | # update visited list 130 | for item in pair_keys: 131 | visited.append(item) 132 | 133 | return pipeline 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | 139 | parser.add_argument( 140 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 141 | ) 142 | parser.add_argument( 143 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 144 | ) 145 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 146 | parser.add_argument( 147 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 148 | ) 149 | parser.add_argument( 150 | "--lora_prefix_text_encoder", 151 | default="lora_te", 152 | type=str, 153 | help="The prefix of text encoder weight in safetensors", 154 | ) 155 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 156 | parser.add_argument( 157 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 158 | ) 159 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 160 | 161 | args = parser.parse_args() 162 | 163 | base_model_path = args.base_model_path 164 | checkpoint_path = args.checkpoint_path 165 | dump_path = args.dump_path 166 | lora_prefix_unet = args.lora_prefix_unet 167 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 168 | alpha = args.alpha 169 | 170 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 171 | 172 | pipe = pipe.to(args.device) 173 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 174 | -------------------------------------------------------------------------------- /filter_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.fft as fft 3 | import math 4 | 5 | 6 | def freq_mix_3d(x, noise, LPF): 7 | """ 8 | Noise reinitialization. 
9 | 10 | Args: 11 | x: diffused latent 12 | noise: randomly sampled noise 13 | LPF: low pass filter 14 | """ 15 | if LPF is None: 16 | return x 17 | 18 | # FFT 19 | x_freq = fft.fftn(x, dim=(-3, -2, -1)) 20 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) 21 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) 22 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) 23 | 24 | # frequency mix 25 | HPF = 1 - LPF 26 | x_freq_low = x_freq * LPF 27 | noise_freq_high = noise_freq * HPF 28 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain 29 | 30 | # IFFT 31 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) 32 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real 33 | 34 | return x_mixed 35 | 36 | 37 | def get_freq_filter(shape, device, params: dict): 38 | """ 39 | Form the frequency filter for noise reinitialization. 40 | 41 | Args: 42 | shape: shape of latent (B, C, T, H, W) 43 | params: filter parameters 44 | """ 45 | if params.method == "gaussian": 46 | return gaussian_low_pass_filter(shape=shape, d_s=params.d_s, d_t=params.d_t).to(device) 47 | elif params.method == "ideal": 48 | return ideal_low_pass_filter(shape=shape, d_s=params.d_s, d_t=params.d_t).to(device) 49 | elif params.method == "box": 50 | return box_low_pass_filter(shape=shape, d_s=params.d_s, d_t=params.d_t).to(device) 51 | elif params.method == "butterworth": 52 | return butterworth_low_pass_filter(shape=shape, n=params.n, d_s=params.d_s, d_t=params.d_t).to(device) 53 | else: 54 | raise NotImplementedError 55 | 56 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): 57 | """ 58 | Compute the gaussian low pass filter mask. 59 | 60 | Args: 61 | shape: shape of the filter (volume) 62 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 63 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 64 | """ 65 | T, H, W = shape[-3], shape[-2], shape[-1] 66 | mask = torch.zeros(shape) 67 | if d_s==0 or d_t==0: 68 | return mask 69 | for t in range(T): 70 | for h in range(H): 71 | for w in range(W): 72 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 73 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) 74 | return mask 75 | 76 | 77 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): 78 | """ 79 | Compute the butterworth low pass filter mask. 80 | 81 | Args: 82 | shape: shape of the filter (volume) 83 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian 84 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 85 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 86 | """ 87 | T, H, W = shape[-3], shape[-2], shape[-1] 88 | mask = torch.zeros(shape) 89 | if d_s==0 or d_t==0: 90 | return mask 91 | for t in range(T): 92 | for h in range(H): 93 | for w in range(W): 94 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 95 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) 96 | return mask 97 | 98 | 99 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): 100 | """ 101 | Compute the ideal low pass filter mask. 
102 | 103 | Args: 104 | shape: shape of the filter (volume) 105 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 106 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 107 | """ 108 | T, H, W = shape[-3], shape[-2], shape[-1] 109 | mask = torch.zeros(shape) 110 | if d_s==0 or d_t==0: 111 | return mask 112 | for t in range(T): 113 | for h in range(H): 114 | for w in range(W): 115 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 116 | mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 117 | return mask 118 | 119 | 120 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): 121 | """ 122 | Compute the ideal low pass filter mask (approximated version). 123 | 124 | Args: 125 | shape: shape of the filter (volume) 126 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 127 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 128 | """ 129 | T, H, W = shape[-3], shape[-2], shape[-1] 130 | mask = torch.zeros(shape) 131 | if d_s==0 or d_t==0: 132 | return mask 133 | 134 | threshold_s = round(int(H // 2) * d_s) 135 | threshold_t = round(T // 2 * d_t) 136 | 137 | cframe, crow, ccol = T // 2, H // 2, W //2 138 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 139 | 140 | return mask -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | name="videoguide" 2 | 3 | animatediff_ckpt='../animatediff/models/StableDiffusion/stable-diffusion-v1-5' # Change into your stable-diffusion-v1-5 path 4 | ckpt='../checkpoints/videocrafter/base_512_v2/model.ckpt' # Change into your videocrafter path 5 | # config='./animatediff_configs/prompts/v2/v2-1-Film.yaml' 6 | # config='./animatediff_configs/prompts/v2/v2-1-ToonYou.yaml' 7 | # config='./animatediff_configs/prompts/v2/v2-1-RealisticVision.yaml' 8 | config='./animatediff_configs/prompts/v2/v2-1-w-dreambooth.yaml' 9 | vc_config='./vc_configs/inference_t2v_512_v2.0.yaml' 10 | 11 | prompts="./prompts/prompt.txt" 12 | 13 | python t2v_vc_guide.py \ 14 | --seed 42 \ 15 | --video_length 16 \ 16 | --fps 8 \ 17 | --cfg_scale 0.8 \ 18 | --animatediff_model_path $animatediff_ckpt \ 19 | --vc_model_path $ckpt \ 20 | --config $config \ 21 | --vc_config $vc_config \ 22 | --num_step 50 \ 23 | --savedir "./result" \ 24 | --precision 'float16' \ 25 | --prompt "$prompts" \ 26 | --mode 1 \ 27 | --cfg_plus 28 | 29 | 30 | -------------------------------------------------------------------------------- /lvdm/__pycache__/basics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/basics.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/basics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/basics.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/common.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/common.cpython-311.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from vc_utils.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 
30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is 
not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs) 93 | else: 94 | return func(*inputs) 95 | 96 | -------------------------------------------------------------------------------- /lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def 
mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) 96 | -------------------------------------------------------------------------------- /lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. 
Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /lvdm/models/__pycache__/autoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/ddpm3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/ddpm3d.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/ddpm3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/ddpm3d.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/utils_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/utils_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/utils_diffusion.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/utils_diffusion.cpython-311.pyc -------------------------------------------------------------------------------- /lvdm/models/__pycache__/utils_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/__pycache__/utils_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from lvdm.modules.networks.ae_modules import Encoder, Decoder 9 | from lvdm.distributions import DiagonalGaussianDistribution 10 | from vc_utils.utils import instantiate_from_config 11 | 12 | 13 | class 
AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | test=False, 24 | logdir=None, 25 | input_dim=4, 26 | test_args=None, 27 | ): 28 | super().__init__() 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | self.loss = instantiate_from_config(lossconfig) 33 | assert ddconfig["double_z"] 34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 36 | self.embed_dim = embed_dim 37 | self.input_dim = input_dim 38 | self.test = test 39 | self.test_args = test_args 40 | self.logdir = logdir 41 | if colorize_nlabels is not None: 42 | assert type(colorize_nlabels)==int 43 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 44 | if monitor is not None: 45 | self.monitor = monitor 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | if self.test: 49 | self.init_test() 50 | 51 | def init_test(self,): 52 | self.test = True 53 | save_dir = os.path.join(self.logdir, "test") 54 | if 'ckpt' in self.test_args: 55 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 56 | self.root = os.path.join(save_dir, ckpt_name) 57 | else: 58 | self.root = save_dir 59 | if 'test_subdir' in self.test_args: 60 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 61 | 62 | self.root_zs = os.path.join(self.root, "zs") 63 | self.root_dec = os.path.join(self.root, "reconstructions") 64 | self.root_inputs = os.path.join(self.root, "inputs") 65 | os.makedirs(self.root, exist_ok=True) 66 | 67 | if self.test_args.save_z: 68 | os.makedirs(self.root_zs, exist_ok=True) 69 | if self.test_args.save_reconstruction: 70 | os.makedirs(self.root_dec, exist_ok=True) 71 | if self.test_args.save_input: 72 | os.makedirs(self.root_inputs, exist_ok=True) 73 | assert(self.test_args is not None) 74 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) 75 | self.count = 0 76 | self.eval_metrics = {} 77 | self.decodes = [] 78 | self.save_decode_samples = 2048 79 | 80 | def init_from_ckpt(self, path, ignore_keys=list()): 81 | sd = torch.load(path, map_location="cpu") 82 | try: 83 | self._cur_epoch = sd['epoch'] 84 | sd = sd["state_dict"] 85 | except: 86 | self._cur_epoch = 'null' 87 | keys = list(sd.keys()) 88 | for k in keys: 89 | for ik in ignore_keys: 90 | if k.startswith(ik): 91 | print("Deleting key {} from state_dict.".format(k)) 92 | del sd[k] 93 | self.load_state_dict(sd, strict=False) 94 | # self.load_state_dict(sd, strict=True) 95 | print(f"Restored from {path}") 96 | 97 | def encode(self, x, **kwargs): 98 | 99 | h = self.encoder(x) 100 | moments = self.quant_conv(h) 101 | posterior = DiagonalGaussianDistribution(moments) 102 | return posterior 103 | 104 | def decode(self, z, **kwargs): 105 | z = self.post_quant_conv(z) 106 | dec = self.decoder(z) 107 | return dec 108 | 109 | def forward(self, input, sample_posterior=True): 110 | posterior = self.encode(input) 111 | if sample_posterior: 112 | z = posterior.sample() 113 | else: 114 | z = posterior.mode() 115 | dec = self.decode(z) 116 | return dec, posterior 117 | 118 | def get_input(self, batch, k): 119 | x = batch[k] 120 | if x.dim() == 5 and self.input_dim == 4: 121 | b,c,t,h,w = x.shape 122 | self.b = b 123 | self.t 
= t 124 | x = rearrange(x, 'b c t h w -> (b t) c h w') 125 | 126 | return x 127 | 128 | def training_step(self, batch, batch_idx, optimizer_idx): 129 | inputs = self.get_input(batch, self.image_key) 130 | reconstructions, posterior = self(inputs) 131 | 132 | if optimizer_idx == 0: 133 | # train encoder+decoder+logvar 134 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 135 | last_layer=self.get_last_layer(), split="train") 136 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 137 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 138 | return aeloss 139 | 140 | if optimizer_idx == 1: 141 | # train the discriminator 142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 143 | last_layer=self.get_last_layer(), split="train") 144 | 145 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 146 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 147 | return discloss 148 | 149 | def validation_step(self, batch, batch_idx): 150 | inputs = self.get_input(batch, self.image_key) 151 | reconstructions, posterior = self(inputs) 152 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 153 | last_layer=self.get_last_layer(), split="val") 154 | 155 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 156 | last_layer=self.get_last_layer(), split="val") 157 | 158 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 159 | self.log_dict(log_dict_ae) 160 | self.log_dict(log_dict_disc) 161 | return self.log_dict 162 | 163 | def configure_optimizers(self): 164 | lr = self.learning_rate 165 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 166 | list(self.decoder.parameters())+ 167 | list(self.quant_conv.parameters())+ 168 | list(self.post_quant_conv.parameters()), 169 | lr=lr, betas=(0.5, 0.9)) 170 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 171 | lr=lr, betas=(0.5, 0.9)) 172 | return [opt_ae, opt_disc], [] 173 | 174 | def get_last_layer(self): 175 | return self.decoder.conv_out.weight 176 | 177 | @torch.no_grad() 178 | def log_images(self, batch, only_inputs=False, **kwargs): 179 | log = dict() 180 | x = self.get_input(batch, self.image_key) 181 | x = x.to(self.device) 182 | if not only_inputs: 183 | xrec, posterior = self(x) 184 | if x.shape[1] > 3: 185 | # colorize with random projection 186 | assert xrec.shape[1] > 3 187 | x = self.to_rgb(x) 188 | xrec = self.to_rgb(xrec) 189 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 190 | log["reconstructions"] = xrec 191 | log["inputs"] = x 192 | return log 193 | 194 | def to_rgb(self, x): 195 | assert self.image_key == "segmentation" 196 | if not hasattr(self, "colorize"): 197 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 198 | x = F.conv2d(x, weight=self.colorize) 199 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
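        # at this point x is a random 3-channel projection of the segmentation map, rescaled to [-1, 1] for image logging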
200 | return x 201 | 202 | class IdentityFirstStage(torch.nn.Module): 203 | def __init__(self, *args, vq_interface=False, **kwargs): 204 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 205 | super().__init__() 206 | 207 | def encode(self, x, *args, **kwargs): 208 | return x 209 | 210 | def decode(self, x, *args, **kwargs): 211 | return x 212 | 213 | def quantize(self, x, *args, **kwargs): 214 | if self.vq_interface: 215 | return x, None, [None, None, None] 216 | return x 217 | 218 | def forward(self, x, *args, **kwargs): 219 | return x 220 | -------------------------------------------------------------------------------- /lvdm/models/samplers/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/samplers/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/models/samplers/__pycache__/ddim.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/samplers/__pycache__/ddim.cpython-311.pyc -------------------------------------------------------------------------------- /lvdm/models/samplers/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/models/samplers/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps 5 | from lvdm.common import noise_like 6 | 7 | 8 | class DDIMSampler(object): 9 | def __init__(self, model, schedule="linear", **kwargs): 10 | super().__init__() 11 | self.model = model 12 | self.ddpm_num_timesteps = model.num_timesteps 13 | self.schedule = schedule 14 | self.counter = 0 15 | 16 | def register_buffer(self, name, attr): 17 | if type(attr) == torch.Tensor: 18 | if attr.device != torch.device("cuda"): 19 | attr = attr.to(torch.device("cuda")) 20 | setattr(self, name, attr) 21 | 22 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 23 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 24 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 25 | alphas_cumprod = self.model.alphas_cumprod 26 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 27 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 28 | 29 | self.register_buffer('betas', to_torch(self.model.betas)) 30 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 31 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 32 | self.use_scale = self.model.use_scale 33 | print('DDIM scale', self.use_scale) 34 | 35 | if self.use_scale: 36 | self.register_buffer('scale_arr', to_torch(self.model.scale_arr)) 37 | ddim_scale_arr = 
self.scale_arr.cpu()[self.ddim_timesteps] 38 | self.register_buffer('ddim_scale_arr', ddim_scale_arr) 39 | ddim_scale_arr = np.asarray([self.scale_arr.cpu()[0]] + self.scale_arr.cpu()[self.ddim_timesteps[:-1]].tolist()) 40 | self.register_buffer('ddim_scale_arr_prev', ddim_scale_arr) 41 | 42 | # calculations for diffusion q(x_t | x_{t-1}) and others 43 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 45 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 46 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 47 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 48 | 49 | # ddim sampling parameters 50 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 51 | ddim_timesteps=self.ddim_timesteps, 52 | eta=ddim_eta,verbose=verbose) 53 | self.register_buffer('ddim_sigmas', ddim_sigmas) 54 | self.register_buffer('ddim_alphas', ddim_alphas) 55 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 56 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 57 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 58 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 59 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 60 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 61 | 62 | @torch.no_grad() 63 | def sample(self, 64 | S, 65 | batch_size, 66 | shape, 67 | conditioning=None, 68 | callback=None, 69 | normals_sequence=None, 70 | img_callback=None, 71 | quantize_x0=False, 72 | eta=0., 73 | mask=None, 74 | x0=None, 75 | temperature=1., 76 | noise_dropout=0., 77 | score_corrector=None, 78 | corrector_kwargs=None, 79 | verbose=True, 80 | schedule_verbose=False, 81 | x_T=None, 82 | log_every_t=100, 83 | unconditional_guidance_scale=1., 84 | unconditional_conditioning=None, 85 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
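# (in this codebase typically a tensor of empty-prompt text embeddings, or a dict carrying a 'c_crossattn' entry)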
86 | **kwargs 87 | ): 88 | 89 | # check condition bs 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | try: 93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 94 | except: 95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 96 | 97 | if cbs != batch_size: 98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose) 104 | 105 | # make shape 106 | if len(shape) == 3: 107 | C, H, W = shape 108 | size = (batch_size, C, H, W) 109 | elif len(shape) == 4: 110 | C, T, H, W = shape 111 | size = (batch_size, C, T, H, W) 112 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 113 | 114 | samples, intermediates = self.ddim_sampling(conditioning, size, 115 | callback=callback, 116 | img_callback=img_callback, 117 | quantize_denoised=quantize_x0, 118 | mask=mask, x0=x0, 119 | ddim_use_original_steps=False, 120 | noise_dropout=noise_dropout, 121 | temperature=temperature, 122 | score_corrector=score_corrector, 123 | corrector_kwargs=corrector_kwargs, 124 | x_T=x_T, 125 | log_every_t=log_every_t, 126 | unconditional_guidance_scale=unconditional_guidance_scale, 127 | unconditional_conditioning=unconditional_conditioning, 128 | verbose=verbose, 129 | **kwargs) 130 | return samples, intermediates 131 | 132 | @torch.no_grad() 133 | def ddim_sampling(self, cond, shape, 134 | x_T=None, ddim_use_original_steps=False, 135 | callback=None, timesteps=None, quantize_denoised=False, 136 | mask=None, x0=None, img_callback=None, log_every_t=100, 137 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 138 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True, 139 | cond_tau=1., target_size=None, start_timesteps=None, 140 | **kwargs): 141 | device = self.model.betas.device 142 | print('ddim device', device) 143 | b = shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | 149 | if timesteps is None: 150 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 151 | elif timesteps is not None and not ddim_use_original_steps: 152 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 153 | timesteps = self.ddim_timesteps[:subset_end] 154 | 155 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 156 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 157 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 158 | if verbose: 159 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 160 | else: 161 | iterator = time_range 162 | 163 | init_x0 = False 164 | clean_cond = kwargs.pop("clean_cond", False) 165 | for i, step in enumerate(iterator): 166 | index = total_steps - i - 1 167 | ts = torch.full((b,), step, device=device, dtype=torch.long) 168 | if start_timesteps is not None: 169 | assert x0 is not None 170 | if step > start_timesteps*time_range[0]: 171 | continue 172 | elif not init_x0: 173 | img = self.model.q_sample(x0, ts) 174 | init_x0 = True 175 | 176 | # use mask to blend noised original latent (img_orig) & new sampled latent (img) 177 | if mask is not None: 178 | assert x0 is not None 179 | if clean_cond: 180 | img_orig 
= x0 181 | else: 182 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 183 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 184 | 185 | index_clip = int((1 - cond_tau) * total_steps) 186 | if index <= index_clip and target_size is not None: 187 | target_size_ = [target_size[0], target_size[1]//8, target_size[2]//8] 188 | img = torch.nn.functional.interpolate( 189 | img, 190 | size=target_size_, 191 | mode="nearest", 192 | ) 193 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 194 | quantize_denoised=quantize_denoised, temperature=temperature, 195 | noise_dropout=noise_dropout, score_corrector=score_corrector, 196 | corrector_kwargs=corrector_kwargs, 197 | unconditional_guidance_scale=unconditional_guidance_scale, 198 | unconditional_conditioning=unconditional_conditioning, 199 | x0=x0, 200 | **kwargs) 201 | 202 | img, pred_x0 = outs 203 | if callback: callback(i) 204 | if img_callback: img_callback(pred_x0, i) 205 | 206 | if index % log_every_t == 0 or index == total_steps - 1: 207 | intermediates['x_inter'].append(img) 208 | intermediates['pred_x0'].append(pred_x0) 209 | 210 | return img, intermediates 211 | 212 | @torch.no_grad() 213 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 214 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 215 | unconditional_guidance_scale=1., unconditional_conditioning=None, 216 | uc_type=None, conditional_guidance_scale_temporal=None, **kwargs): 217 | b, *_, device = *x.shape, x.device 218 | if x.dim() == 5: 219 | is_video = True 220 | else: 221 | is_video = False 222 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 223 | e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 224 | else: 225 | # with unconditional condition 226 | if isinstance(c, torch.Tensor): 227 | e_t = self.model.apply_model(x, t, c, **kwargs) 228 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 229 | elif isinstance(c, dict): 230 | e_t = self.model.apply_model(x, t, c, **kwargs) 231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 232 | else: 233 | raise NotImplementedError 234 | # text cfg 235 | if uc_type is None: 236 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 237 | else: 238 | if uc_type == 'cfg_original': 239 | e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond) 240 | elif uc_type == 'cfg_ours': 241 | e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t) 242 | else: 243 | raise NotImplementedError 244 | # temporal guidance 245 | if conditional_guidance_scale_temporal is not None: 246 | e_t_temporal = self.model.apply_model(x, t, c, **kwargs) 247 | e_t_image = self.model.apply_model(x, t, c, no_temporal_attn=True, **kwargs) 248 | e_t = e_t + conditional_guidance_scale_temporal * (e_t_temporal - e_t_image) 249 | 250 | if score_corrector is not None: 251 | assert self.model.parameterization == "eps" 252 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 253 | 254 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 255 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 256 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 257 | sigmas = 
self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 258 | # select parameters corresponding to the currently considered timestep 259 | 260 | if is_video: 261 | size = (b, 1, 1, 1, 1) 262 | else: 263 | size = (b, 1, 1, 1) 264 | a_t = torch.full(size, alphas[index], device=device) 265 | a_prev = torch.full(size, alphas_prev[index], device=device) 266 | sigma_t = torch.full(size, sigmas[index], device=device) 267 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 268 | 269 | # current prediction for x_0 270 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 271 | if quantize_denoised: 272 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 273 | # direction pointing to x_t 274 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 275 | 276 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 277 | if noise_dropout > 0.: 278 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 279 | 280 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 281 | if self.use_scale: 282 | scale_arr = self.model.scale_arr if use_original_steps else self.ddim_scale_arr 283 | scale_t = torch.full(size, scale_arr[index], device=device) 284 | scale_arr_prev = self.model.scale_arr_prev if use_original_steps else self.ddim_scale_arr_prev 285 | scale_t_prev = torch.full(size, scale_arr_prev[index], device=device) 286 | pred_x0 /= scale_t 287 | x_prev = a_prev.sqrt() * scale_t_prev * pred_x0 + dir_xt + noise 288 | else: 289 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 290 | 291 | return x_prev, pred_x0 292 | 293 | 294 | @torch.no_grad() 295 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 296 | # fast, but does not allow for exact reconstruction 297 | # t serves as an index to gather the correct alphas 298 | if use_original_steps: 299 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 300 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 301 | else: 302 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 303 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 304 | 305 | if noise is None: 306 | noise = torch.randn_like(x0) 307 | 308 | def extract_into_tensor(a, t, x_shape): 309 | b, *_ = t.shape 310 | out = a.gather(-1, t) 311 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 312 | 313 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 314 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 315 | 316 | @torch.no_grad() 317 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 318 | use_original_steps=False): 319 | 320 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 321 | timesteps = timesteps[:t_start] 322 | 323 | time_range = np.flip(timesteps) 324 | total_steps = timesteps.shape[0] 325 | print(f"Running DDIM Sampling with {total_steps} timesteps") 326 | 327 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 328 | x_dec = x_latent 329 | for i, step in enumerate(iterator): 330 | index = total_steps - i - 1 331 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 332 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 333 | unconditional_guidance_scale=unconditional_guidance_scale, 334 | unconditional_conditioning=unconditional_conditioning) 335 | return x_dec 336 
| 337 | -------------------------------------------------------------------------------- /lvdm/models/utils_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from einops import repeat 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 12 | These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 16 | """ 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | else: 27 | embedding = repeat(timesteps, 'b -> b d', d=dim) 28 | return embedding 29 | 30 | 31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 32 | if schedule == "linear": 33 | betas = ( 34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 35 | ) 36 | 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 40 | ) 41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 42 | alphas = torch.cos(alphas).pow(2) 43 | alphas = alphas / alphas[0] 44 | betas = 1 - alphas[1:] / alphas[:-1] 45 | betas = np.clip(betas, a_min=0, a_max=0.999) 46 | 47 | elif schedule == "sqrt_linear": 48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | elif schedule == "sqrt": 50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 57 | if ddim_discr_method == 'uniform': 58 | c = num_ddpm_timesteps // num_ddim_timesteps 59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 60 | elif ddim_discr_method == 'quad': 61 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 62 | else: 63 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 64 | 65 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 66 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 67 | steps_out = ddim_timesteps + 1 68 | if verbose: 69 | print(f'Selected timesteps for ddim sampler: {steps_out}') 70 | return steps_out 71 | 72 | 73 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 74 | # select alphas for computing the variance schedule 75 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') 76 | alphas = alphacums[ddim_timesteps] 77 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 78 | 79 | # according the the formula provided in https://arxiv.org/abs/2010.02502 80 | sigmas = eta * np.sqrt((1 
- alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 81 | if verbose: 82 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 83 | print(f'For the chosen value of eta, which is {eta}, ' 84 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 85 | return sigmas, alphas, alphas_prev 86 | 87 | 88 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 89 | """ 90 | Create a beta schedule that discretizes the given alpha_t_bar function, 91 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 92 | :param num_diffusion_timesteps: the number of betas to produce. 93 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 94 | produces the cumulative product of (1-beta) up to that 95 | part of the diffusion process. 96 | :param max_beta: the maximum beta to use; use values lower than 1 to 97 | prevent singularities. 98 | """ 99 | betas = [] 100 | for i in range(num_diffusion_timesteps): 101 | t1 = i / num_diffusion_timesteps 102 | t2 = (i + 1) / num_diffusion_timesteps 103 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 104 | return np.array(betas) -------------------------------------------------------------------------------- /lvdm/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | try: 7 | import xformers 8 | import xformers.ops 9 | XFORMERS_IS_AVAILBLE = True 10 | except: 11 | XFORMERS_IS_AVAILBLE = False 12 | from lvdm.common import ( 13 | checkpoint, 14 | exists, 15 | default, 16 | ) 17 | from lvdm.basics import ( 18 | zero_module, 19 | ) 20 | 21 | class RelativePosition(nn.Module): 22 | """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ 23 | 24 | def __init__(self, num_units, max_relative_position): 25 | super().__init__() 26 | self.num_units = num_units 27 | self.max_relative_position = max_relative_position 28 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) 29 | nn.init.xavier_uniform_(self.embeddings_table) 30 | 31 | def forward(self, length_q, length_k): 32 | device = self.embeddings_table.device 33 | range_vec_q = torch.arange(length_q, device=device) 34 | range_vec_k = torch.arange(length_k, device=device) 35 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 36 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 37 | final_mat = distance_mat_clipped + self.max_relative_position 38 | final_mat = final_mat.long() 39 | embeddings = 
self.embeddings_table[final_mat] 40 | return embeddings 41 | 42 | 43 | class CrossAttention(nn.Module): 44 | 45 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., 46 | relative_position=False, temporal_length=None, img_cross_attention=False): 47 | super().__init__() 48 | inner_dim = dim_head * heads 49 | context_dim = default(context_dim, query_dim) 50 | 51 | self.scale = dim_head**-0.5 52 | self.heads = heads 53 | self.dim_head = dim_head 54 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 55 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 56 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 57 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 58 | 59 | self.image_cross_attention_scale = 1.0 60 | self.text_context_len = 77 61 | self.img_cross_attention = img_cross_attention 62 | if self.img_cross_attention: 63 | self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) 64 | self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) 65 | 66 | self.relative_position = relative_position 67 | if self.relative_position: 68 | assert(temporal_length is not None) 69 | self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 70 | self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 71 | else: 72 | ## only used for spatial attention, while NOT for temporal attention 73 | if XFORMERS_IS_AVAILBLE and temporal_length is None: 74 | self.forward = self.efficient_forward 75 | 76 | def forward(self, x, context=None, mask=None): 77 | h = self.heads 78 | 79 | q = self.to_q(x) 80 | context = default(context, x) 81 | ## considering image token additionally 82 | if context is not None and self.img_cross_attention: 83 | context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 84 | k = self.to_k(context) 85 | v = self.to_v(context) 86 | k_ip = self.to_k_ip(context_img) 87 | v_ip = self.to_v_ip(context_img) 88 | else: 89 | k = self.to_k(context) 90 | v = self.to_v(context) 91 | 92 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 93 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 94 | if self.relative_position: 95 | len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] 96 | k2 = self.relative_position_k(len_q, len_k) 97 | sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check 98 | sim += sim2 99 | del k 100 | 101 | if exists(mask): 102 | ## feasible for causal attention mask only 103 | max_neg_value = -torch.finfo(sim.dtype).max 104 | mask = repeat(mask, 'b i j -> (b h) i j', h=h) 105 | sim.masked_fill_(~(mask>0.5), max_neg_value) 106 | 107 | # attention, what we cannot get enough of 108 | sim = sim.softmax(dim=-1) 109 | out = torch.einsum('b i j, b j d -> b i d', sim, v) 110 | if self.relative_position: 111 | v2 = self.relative_position_v(len_q, len_v) 112 | out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check 113 | out += out2 114 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 115 | 116 | ## considering image token additionally 117 | if context is not None and self.img_cross_attention: 118 | k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) 119 | sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale 120 | del k_ip 121 | sim_ip = sim_ip.softmax(dim=-1) 122 | out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) 123 | out_ip = 
rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) 124 | out = out + self.image_cross_attention_scale * out_ip 125 | del q 126 | 127 | return self.to_out(out) 128 | 129 | def efficient_forward(self, x, context=None, mask=None): 130 | q = self.to_q(x) 131 | context = default(context, x) 132 | 133 | ## considering image token additionally 134 | if context is not None and self.img_cross_attention: 135 | context, context_img = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 136 | k = self.to_k(context) 137 | v = self.to_v(context) 138 | k_ip = self.to_k_ip(context_img) 139 | v_ip = self.to_v_ip(context_img) 140 | else: 141 | k = self.to_k(context) 142 | v = self.to_v(context) 143 | 144 | b, _, _ = q.shape 145 | q, k, v = map( 146 | lambda t: t.unsqueeze(3) 147 | .reshape(b, t.shape[1], self.heads, self.dim_head) 148 | .permute(0, 2, 1, 3) 149 | .reshape(b * self.heads, t.shape[1], self.dim_head) 150 | .contiguous(), 151 | (q, k, v), 152 | ) 153 | # actually compute the attention, what we cannot get enough of 154 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) 155 | 156 | ## considering image token additionally 157 | if context is not None and self.img_cross_attention: 158 | k_ip, v_ip = map( 159 | lambda t: t.unsqueeze(3) 160 | .reshape(b, t.shape[1], self.heads, self.dim_head) 161 | .permute(0, 2, 1, 3) 162 | .reshape(b * self.heads, t.shape[1], self.dim_head) 163 | .contiguous(), 164 | (k_ip, v_ip), 165 | ) 166 | out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) 167 | out_ip = ( 168 | out_ip.unsqueeze(0) 169 | .reshape(b, self.heads, out.shape[1], self.dim_head) 170 | .permute(0, 2, 1, 3) 171 | .reshape(b, out.shape[1], self.heads * self.dim_head) 172 | ) 173 | 174 | if exists(mask): 175 | raise NotImplementedError 176 | out = ( 177 | out.unsqueeze(0) 178 | .reshape(b, self.heads, out.shape[1], self.dim_head) 179 | .permute(0, 2, 1, 3) 180 | .reshape(b, out.shape[1], self.heads * self.dim_head) 181 | ) 182 | if context is not None and self.img_cross_attention: 183 | out = out + self.image_cross_attention_scale * out_ip 184 | return self.to_out(out) 185 | 186 | 187 | class BasicTransformerBlock(nn.Module): 188 | 189 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 190 | disable_self_attn=False, attention_cls=None, img_cross_attention=False): 191 | super().__init__() 192 | attn_cls = CrossAttention if attention_cls is None else attention_cls 193 | self.disable_self_attn = disable_self_attn 194 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 195 | context_dim=context_dim if self.disable_self_attn else None) 196 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 197 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, 198 | img_cross_attention=img_cross_attention) 199 | self.norm1 = nn.LayerNorm(dim) 200 | self.norm2 = nn.LayerNorm(dim) 201 | self.norm3 = nn.LayerNorm(dim) 202 | self.checkpoint = checkpoint 203 | 204 | def forward(self, x, context=None, mask=None): 205 | ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. 
None or scalar) arguments 206 | input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments 207 | if context is not None: 208 | input_tuple = (x, context) 209 | if mask is not None: 210 | forward_mask = partial(self._forward, mask=mask) 211 | return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) 212 | if context is not None and mask is not None: 213 | input_tuple = (x, context, mask) 214 | return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) 215 | 216 | def _forward(self, x, context=None, mask=None): 217 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x 218 | x = self.attn2(self.norm2(x), context=context, mask=mask) + x 219 | x = self.ff(self.norm3(x)) + x 220 | return x 221 | 222 | 223 | class SpatialTransformer(nn.Module): 224 | """ 225 | Transformer block for image-like data in spatial axis. 226 | First, project the input (aka embedding) 227 | and reshape to b, t, d. 228 | Then apply standard transformer action. 229 | Finally, reshape to image 230 | NEW: use_linear for more efficiency instead of the 1x1 convs 231 | """ 232 | 233 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 234 | use_checkpoint=True, disable_self_attn=False, use_linear=False, img_cross_attention=False): 235 | super().__init__() 236 | self.in_channels = in_channels 237 | inner_dim = n_heads * d_head 238 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 239 | if not use_linear: 240 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 241 | else: 242 | self.proj_in = nn.Linear(in_channels, inner_dim) 243 | 244 | self.transformer_blocks = nn.ModuleList([ 245 | BasicTransformerBlock( 246 | inner_dim, 247 | n_heads, 248 | d_head, 249 | dropout=dropout, 250 | context_dim=context_dim, 251 | img_cross_attention=img_cross_attention, 252 | disable_self_attn=disable_self_attn, 253 | checkpoint=use_checkpoint) for d in range(depth) 254 | ]) 255 | if not use_linear: 256 | self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 257 | else: 258 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 259 | self.use_linear = use_linear 260 | 261 | 262 | def forward(self, x, context=None): 263 | b, c, h, w = x.shape 264 | x_in = x 265 | x = self.norm(x) 266 | if not self.use_linear: 267 | x = self.proj_in(x) 268 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 269 | if self.use_linear: 270 | x = self.proj_in(x) 271 | for i, block in enumerate(self.transformer_blocks): 272 | x = block(x, context=context) 273 | if self.use_linear: 274 | x = self.proj_out(x) 275 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 276 | if not self.use_linear: 277 | x = self.proj_out(x) 278 | return x + x_in 279 | 280 | 281 | class TemporalTransformer(nn.Module): 282 | """ 283 | Transformer block for image-like data in temporal axis. 284 | First, reshape to b, t, d. 285 | Then apply standard transformer action. 
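Attention runs along the frame (t) axis: every spatial location is treated as an independent batch element ('b c t h w' -> '(b h w) c t').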
286 | Finally, reshape to image 287 | """ 288 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 289 | use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, 290 | relative_position=False, temporal_length=None): 291 | super().__init__() 292 | self.only_self_att = only_self_att 293 | self.relative_position = relative_position 294 | self.causal_attention = causal_attention 295 | self.in_channels = in_channels 296 | inner_dim = n_heads * d_head 297 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 298 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 299 | if not use_linear: 300 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 301 | else: 302 | self.proj_in = nn.Linear(in_channels, inner_dim) 303 | 304 | if relative_position: 305 | assert(temporal_length is not None) 306 | attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length) 307 | else: 308 | attention_cls = None 309 | if self.causal_attention: 310 | assert(temporal_length is not None) 311 | self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) 312 | 313 | if self.only_self_att: 314 | context_dim = None 315 | self.transformer_blocks = nn.ModuleList([ 316 | BasicTransformerBlock( 317 | inner_dim, 318 | n_heads, 319 | d_head, 320 | dropout=dropout, 321 | context_dim=context_dim, 322 | attention_cls=attention_cls, 323 | checkpoint=use_checkpoint) for d in range(depth) 324 | ]) 325 | if not use_linear: 326 | self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 327 | else: 328 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 329 | self.use_linear = use_linear 330 | 331 | def forward(self, x, context=None): 332 | b, c, t, h, w = x.shape 333 | x_in = x 334 | x = self.norm(x) 335 | x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() 336 | if not self.use_linear: 337 | x = self.proj_in(x) 338 | x = rearrange(x, 'bhw c t -> bhw t c').contiguous() 339 | if self.use_linear: 340 | x = self.proj_in(x) 341 | 342 | if self.causal_attention: 343 | mask = self.mask.to(x.device) 344 | mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) 345 | else: 346 | mask = None 347 | 348 | if self.only_self_att: 349 | ## note: if no context is given, cross-attention defaults to self-attention 350 | for i, block in enumerate(self.transformer_blocks): 351 | x = block(x, mask=mask) 352 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 353 | else: 354 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 355 | context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() 356 | for i, block in enumerate(self.transformer_blocks): 357 | # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) 358 | for j in range(b): 359 | context_j = repeat( 360 | context[j], 361 | 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() 362 | ## note: causal mask will not applied in cross-attention case 363 | x[j] = block(x[j], context=context_j) 364 | 365 | if self.use_linear: 366 | x = self.proj_out(x) 367 | x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous() 368 | if not self.use_linear: 369 | x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() 370 | x = self.proj_out(x) 371 | x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() 372 | 373 | return x + 
x_in 374 | 375 | 376 | class GEGLU(nn.Module): 377 | def __init__(self, dim_in, dim_out): 378 | super().__init__() 379 | self.proj = nn.Linear(dim_in, dim_out * 2) 380 | 381 | def forward(self, x): 382 | x, gate = self.proj(x).chunk(2, dim=-1) 383 | return x * F.gelu(gate) 384 | 385 | 386 | class FeedForward(nn.Module): 387 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 388 | super().__init__() 389 | inner_dim = int(dim * mult) 390 | dim_out = default(dim_out, dim) 391 | project_in = nn.Sequential( 392 | nn.Linear(dim, inner_dim), 393 | nn.GELU() 394 | ) if not glu else GEGLU(dim, inner_dim) 395 | 396 | self.net = nn.Sequential( 397 | project_in, 398 | nn.Dropout(dropout), 399 | nn.Linear(inner_dim, dim_out) 400 | ) 401 | 402 | def forward(self, x): 403 | return self.net(x) 404 | 405 | 406 | class LinearAttention(nn.Module): 407 | def __init__(self, dim, heads=4, dim_head=32): 408 | super().__init__() 409 | self.heads = heads 410 | hidden_dim = dim_head * heads 411 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 412 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 413 | 414 | def forward(self, x): 415 | b, c, h, w = x.shape 416 | qkv = self.to_qkv(x) 417 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 418 | k = k.softmax(dim=-1) 419 | context = torch.einsum('bhdn,bhen->bhde', k, v) 420 | out = torch.einsum('bhde,bhdn->bhen', context, q) 421 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 422 | return self.to_out(out) 423 | 424 | 425 | class SpatialSelfAttention(nn.Module): 426 | def __init__(self, in_channels): 427 | super().__init__() 428 | self.in_channels = in_channels 429 | 430 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 431 | self.q = torch.nn.Conv2d(in_channels, 432 | in_channels, 433 | kernel_size=1, 434 | stride=1, 435 | padding=0) 436 | self.k = torch.nn.Conv2d(in_channels, 437 | in_channels, 438 | kernel_size=1, 439 | stride=1, 440 | padding=0) 441 | self.v = torch.nn.Conv2d(in_channels, 442 | in_channels, 443 | kernel_size=1, 444 | stride=1, 445 | padding=0) 446 | self.proj_out = torch.nn.Conv2d(in_channels, 447 | in_channels, 448 | kernel_size=1, 449 | stride=1, 450 | padding=0) 451 | 452 | def forward(self, x): 453 | h_ = x 454 | h_ = self.norm(h_) 455 | q = self.q(h_) 456 | k = self.k(h_) 457 | v = self.v(h_) 458 | 459 | # compute attention 460 | b,c,h,w = q.shape 461 | q = rearrange(q, 'b c h w -> b (h w) c') 462 | k = rearrange(k, 'b c h w -> b c (h w)') 463 | w_ = torch.einsum('bij,bjk->bik', q, k) 464 | 465 | w_ = w_ * (int(c)**(-0.5)) 466 | w_ = torch.nn.functional.softmax(w_, dim=2) 467 | 468 | # attend to values 469 | v = rearrange(v, 'b c h w -> b c (h w)') 470 | w_ = rearrange(w_, 'b i j -> b j i') 471 | h_ = torch.einsum('bij,bjk->bik', v, w_) 472 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 473 | h_ = self.proj_out(h_) 474 | 475 | return x+h_ 476 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/__pycache__/condition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/encoders/__pycache__/condition.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/modules/encoders/__pycache__/condition.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/encoders/__pycache__/condition.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/modules/encoders/__pycache__/ip_resampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/encoders/__pycache__/ip_resampler.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/modules/encoders/__pycache__/ip_resampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/encoders/__pycache__/ip_resampler.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/modules/encoders/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import kornia 5 | import open_clip 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | from lvdm.common import autocast 8 | from vc_utils.utils import count_params 9 | 10 | class AbstractEncoder(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def encode(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | 18 | class IdentityEncoder(AbstractEncoder): 19 | 20 | def encode(self, x): 21 | return x 22 | 23 | 24 | class ClassEmbedder(nn.Module): 25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 26 | super().__init__() 27 | self.key = key 28 | self.embedding = nn.Embedding(n_classes, embed_dim) 29 | self.n_classes = n_classes 30 | self.ucg_rate = ucg_rate 31 | 32 | def forward(self, batch, key=None, disable_dropout=False): 33 | if key is None: 34 | key = self.key 35 | # this is for use in crossattn 36 | c = batch[key][:, None] 37 | if self.ucg_rate > 0. and not disable_dropout: 38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 39 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 40 | c = c.long() 41 | c = self.embedding(c) 42 | return c 43 | 44 | def get_unconditional_conditioning(self, bs, device="cuda"): 45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 46 | uc = torch.ones((bs,), device=device) * uc_class 47 | uc = {self.key: uc} 48 | return uc 49 | 50 | 51 | def disabled_train(self, mode=True): 52 | """Overwrite model.train with this function to make sure train/eval mode 53 | does not change anymore.""" 54 | return self 55 | 56 | 57 | class FrozenT5Embedder(AbstractEncoder): 58 | """Uses the T5 transformer encoder for text""" 59 | 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 61 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 62 | super().__init__() 63 | self.tokenizer = T5Tokenizer.from_pretrained(version) 64 | self.transformer = T5EncoderModel.from_pretrained(version) 65 | self.device = device 66 | self.max_length = max_length # TODO: typical value? 
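# 77 mirrors the CLIP text context length used by the other encoders in this file
# hypothetical usage sketch: FrozenT5Embedder().to('cuda').encode(['a koala bear playing piano'])  ->  [1, 77, 1024] for t5-v1_1-large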
67 | if freeze: 68 | self.freeze() 69 | 70 | def freeze(self): 71 | self.transformer = self.transformer.eval() 72 | # self.train = disabled_train 73 | for param in self.parameters(): 74 | param.requires_grad = False 75 | 76 | def forward(self, text): 77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 79 | tokens = batch_encoding["input_ids"].to(self.device) 80 | outputs = self.transformer(input_ids=tokens) 81 | 82 | z = outputs.last_hidden_state 83 | return z 84 | 85 | def encode(self, text): 86 | return self(text) 87 | 88 | 89 | class FrozenCLIPEmbedder(AbstractEncoder): 90 | """Uses the CLIP transformer encoder for text (from huggingface)""" 91 | LAYERS = [ 92 | "last", 93 | "pooled", 94 | "hidden" 95 | ] 96 | 97 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 98 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 99 | super().__init__() 100 | assert layer in self.LAYERS 101 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 102 | self.transformer = CLIPTextModel.from_pretrained(version) 103 | self.device = device 104 | self.max_length = max_length 105 | if freeze: 106 | self.freeze() 107 | self.layer = layer 108 | self.layer_idx = layer_idx 109 | if layer == "hidden": 110 | assert layer_idx is not None 111 | assert 0 <= abs(layer_idx) <= 12 112 | 113 | def freeze(self): 114 | self.transformer = self.transformer.eval() 115 | # self.train = disabled_train 116 | for param in self.parameters(): 117 | param.requires_grad = False 118 | 119 | def forward(self, text): 120 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 121 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 122 | tokens = batch_encoding["input_ids"].to(self.device) 123 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 124 | if self.layer == "last": 125 | z = outputs.last_hidden_state 126 | elif self.layer == "pooled": 127 | z = outputs.pooler_output[:, None, :] 128 | else: 129 | z = outputs.hidden_states[self.layer_idx] 130 | return z 131 | 132 | def encode(self, text): 133 | return self(text) 134 | 135 | 136 | class ClipImageEmbedder(nn.Module): 137 | def __init__( 138 | self, 139 | model, 140 | jit=False, 141 | device='cuda' if torch.cuda.is_available() else 'cpu', 142 | antialias=True, 143 | ucg_rate=0. 144 | ): 145 | super().__init__() 146 | from clip import load as load_clip 147 | self.model, _ = load_clip(name=model, device=device, jit=jit) 148 | 149 | self.antialias = antialias 150 | 151 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 152 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 153 | self.ucg_rate = ucg_rate 154 | 155 | def preprocess(self, x): 156 | # normalize to [0,1] 157 | x = kornia.geometry.resize(x, (224, 224), 158 | interpolation='bicubic', align_corners=True, 159 | antialias=self.antialias) 160 | x = (x + 1.) / 2. 161 | # re-normalize according to clip 162 | x = kornia.enhance.normalize(x, self.mean, self.std) 163 | return x 164 | 165 | def forward(self, x, no_dropout=False): 166 | # x is assumed to be in range [-1,1] 167 | out = self.model.encode_image(self.preprocess(x)) 168 | out = out.to(x.dtype) 169 | if self.ucg_rate > 0. 
and not no_dropout: 170 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 171 | return out 172 | 173 | 174 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 175 | """ 176 | Uses the OpenCLIP transformer encoder for text 177 | """ 178 | LAYERS = [ 179 | # "pooled", 180 | "last", 181 | "penultimate" 182 | ] 183 | 184 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 185 | freeze=True, layer="last"): 186 | super().__init__() 187 | assert layer in self.LAYERS 188 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu')) 189 | del model.visual 190 | self.model = model 191 | 192 | self.device = device 193 | self.max_length = max_length 194 | if freeze: 195 | self.freeze() 196 | self.layer = layer 197 | if self.layer == "last": 198 | self.layer_idx = 0 199 | elif self.layer == "penultimate": 200 | self.layer_idx = 1 201 | else: 202 | raise NotImplementedError() 203 | 204 | def freeze(self): 205 | self.model = self.model.eval() 206 | for param in self.parameters(): 207 | param.requires_grad = False 208 | 209 | def forward(self, text): 210 | self.device = self.model.positional_embedding.device 211 | tokens = open_clip.tokenize(text) 212 | z = self.encode_with_transformer(tokens.to(self.device)) 213 | return z 214 | 215 | def encode_with_transformer(self, text): 216 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 217 | x = x + self.model.positional_embedding 218 | x = x.permute(1, 0, 2) # NLD -> LND 219 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 220 | x = x.permute(1, 0, 2) # LND -> NLD 221 | x = self.model.ln_final(x) 222 | return x 223 | 224 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 225 | for i, r in enumerate(self.model.transformer.resblocks): 226 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 227 | break 228 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 229 | x = checkpoint(r, x, attn_mask) 230 | else: 231 | x = r(x, attn_mask=attn_mask) 232 | return x 233 | 234 | def encode(self, text): 235 | return self(text) 236 | 237 | 238 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 239 | """ 240 | Uses the OpenCLIP vision transformer encoder for images 241 | """ 242 | 243 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 244 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 245 | super().__init__() 246 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 247 | pretrained=version, ) 248 | del model.transformer 249 | self.model = model 250 | 251 | self.device = device 252 | self.max_length = max_length 253 | if freeze: 254 | self.freeze() 255 | self.layer = layer 256 | if self.layer == "penultimate": 257 | raise NotImplementedError() 258 | self.layer_idx = 1 259 | 260 | self.antialias = antialias 261 | 262 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 263 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 264 | self.ucg_rate = ucg_rate 265 | 266 | def preprocess(self, x): 267 | # normalize to [0,1] 268 | x = kornia.geometry.resize(x, (224, 224), 269 | interpolation='bicubic', align_corners=True, 270 | antialias=self.antialias) 271 | x = (x + 1.) / 2. 
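# input is expected in [-1, 1]; self.mean / self.std registered above are the standard CLIP normalization statistics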
272 | # renormalize according to clip 273 | x = kornia.enhance.normalize(x, self.mean, self.std) 274 | return x 275 | 276 | def freeze(self): 277 | self.model = self.model.eval() 278 | for param in self.parameters(): 279 | param.requires_grad = False 280 | 281 | @autocast 282 | def forward(self, image, no_dropout=False): 283 | z = self.encode_with_vision_transformer(image) 284 | if self.ucg_rate > 0. and not no_dropout: 285 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 286 | return z 287 | 288 | def encode_with_vision_transformer(self, img): 289 | img = self.preprocess(img) 290 | x = self.model.visual(img) 291 | return x 292 | 293 | def encode(self, text): 294 | return self(text) 295 | 296 | 297 | 298 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 299 | """ 300 | Uses the OpenCLIP vision transformer encoder for images 301 | """ 302 | 303 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 304 | freeze=True, layer="pooled", antialias=True): 305 | super().__init__() 306 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 307 | pretrained=version, ) 308 | del model.transformer 309 | self.model = model 310 | self.device = device 311 | 312 | if freeze: 313 | self.freeze() 314 | self.layer = layer 315 | if self.layer == "penultimate": 316 | raise NotImplementedError() 317 | self.layer_idx = 1 318 | 319 | self.antialias = antialias 320 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 321 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 322 | 323 | 324 | def preprocess(self, x): 325 | # normalize to [0,1] 326 | x = kornia.geometry.resize(x, (224, 224), 327 | interpolation='bicubic', align_corners=True, 328 | antialias=self.antialias) 329 | x = (x + 1.) / 2. 
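# same preprocessing as FrozenOpenCLIPImageEmbedder above; V2 differs by returning per-token features instead of a pooled embedding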
330 | # renormalize according to clip 331 | x = kornia.enhance.normalize(x, self.mean, self.std) 332 | return x 333 | 334 | def freeze(self): 335 | self.model = self.model.eval() 336 | for param in self.model.parameters(): 337 | param.requires_grad = False 338 | 339 | def forward(self, image, no_dropout=False): 340 | ## image: b c h w 341 | z = self.encode_with_vision_transformer(image) 342 | return z 343 | 344 | def encode_with_vision_transformer(self, x): 345 | x = self.preprocess(x) 346 | 347 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 348 | if self.model.visual.input_patchnorm: 349 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 350 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 351 | x = x.permute(0, 2, 4, 1, 3, 5) 352 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 353 | x = self.model.visual.patchnorm_pre_ln(x) 354 | x = self.model.visual.conv1(x) 355 | else: 356 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 357 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 358 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 359 | 360 | # class embeddings and positional embeddings 361 | x = torch.cat( 362 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 363 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 364 | x = x + self.model.visual.positional_embedding.to(x.dtype) 365 | 366 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 367 | x = self.model.visual.patch_dropout(x) 368 | x = self.model.visual.ln_pre(x) 369 | 370 | x = x.permute(1, 0, 2) # NLD -> LND 371 | x = self.model.visual.transformer(x) 372 | x = x.permute(1, 0, 2) # LND -> NLD 373 | 374 | return x 375 | 376 | 377 | class FrozenCLIPT5Encoder(AbstractEncoder): 378 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 379 | clip_max_length=77, t5_max_length=77): 380 | super().__init__() 381 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 382 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 383 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 384 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 385 | 386 | def encode(self, text): 387 | return self(text) 388 | 389 | def forward(self, text): 390 | clip_z = self.clip_encoder.encode(text) 391 | t5_z = self.t5_encoder.encode(text) 392 | return [clip_z, t5_z] -------------------------------------------------------------------------------- /lvdm/modules/encoders/ip_resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ImageProjModel(nn.Module): 8 | """Projection Model""" 9 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 10 | super().__init__() 11 | self.cross_attention_dim = cross_attention_dim 12 | self.clip_extra_context_tokens = 
clip_extra_context_tokens 13 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 14 | self.norm = nn.LayerNorm(cross_attention_dim) 15 | 16 | def forward(self, image_embeds): 17 | #embeds = image_embeds 18 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 19 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 20 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 21 | return clip_extra_context_tokens 22 | 23 | # FFN 24 | def FeedForward(dim, mult=4): 25 | inner_dim = int(dim * mult) 26 | return nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, inner_dim, bias=False), 29 | nn.GELU(), 30 | nn.Linear(inner_dim, dim, bias=False), 31 | ) 32 | 33 | 34 | def reshape_tensor(x, heads): 35 | bs, length, width = x.shape 36 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 37 | x = x.view(bs, length, heads, -1) 38 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 39 | x = x.transpose(1, 2) 40 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 41 | x = x.reshape(bs, heads, length, -1) 42 | return x 43 | 44 | 45 | class PerceiverAttention(nn.Module): 46 | def __init__(self, *, dim, dim_head=64, heads=8): 47 | super().__init__() 48 | self.scale = dim_head**-0.5 49 | self.dim_head = dim_head 50 | self.heads = heads 51 | inner_dim = dim_head * heads 52 | 53 | self.norm1 = nn.LayerNorm(dim) 54 | self.norm2 = nn.LayerNorm(dim) 55 | 56 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 57 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 58 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 59 | 60 | 61 | def forward(self, x, latents): 62 | """ 63 | Args: 64 | x (torch.Tensor): image features 65 | shape (b, n1, D) 66 | latent (torch.Tensor): latent features 67 | shape (b, n2, D) 68 | """ 69 | x = self.norm1(x) 70 | latents = self.norm2(latents) 71 | 72 | b, l, _ = latents.shape 73 | 74 | q = self.to_q(latents) 75 | kv_input = torch.cat((x, latents), dim=-2) 76 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 77 | 78 | q = reshape_tensor(q, self.heads) 79 | k = reshape_tensor(k, self.heads) 80 | v = reshape_tensor(v, self.heads) 81 | 82 | # attention 83 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 84 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 85 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 86 | out = weight @ v 87 | 88 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 89 | 90 | return self.to_out(out) 91 | 92 | 93 | class Resampler(nn.Module): 94 | def __init__( 95 | self, 96 | dim=1024, 97 | depth=8, 98 | dim_head=64, 99 | heads=16, 100 | num_queries=8, 101 | embedding_dim=768, 102 | output_dim=1024, 103 | ff_mult=4, 104 | ): 105 | super().__init__() 106 | 107 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 108 | 109 | self.proj_in = nn.Linear(embedding_dim, dim) 110 | 111 | self.proj_out = nn.Linear(dim, output_dim) 112 | self.norm_out = nn.LayerNorm(output_dim) 113 | 114 | self.layers = nn.ModuleList([]) 115 | for _ in range(depth): 116 | self.layers.append( 117 | nn.ModuleList( 118 | [ 119 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 120 | FeedForward(dim=dim, mult=ff_mult), 121 | ] 122 | ) 123 | ) 124 | 125 | def forward(self, x): 126 | 127 | latents = self.latents.repeat(x.size(0), 1, 1) 128 | 129 | x = self.proj_in(x) 130 | 131 | for attn, ff 
in self.layers: 132 | latents = attn(x, latents) + latents 133 | latents = ff(latents) + latents 134 | 135 | latents = self.proj_out(latents) 136 | return self.norm_out(latents) -------------------------------------------------------------------------------- /lvdm/modules/networks/__pycache__/ae_modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/networks/__pycache__/ae_modules.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/modules/networks/__pycache__/ae_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/networks/__pycache__/ae_modules.cpython-38.pyc -------------------------------------------------------------------------------- /lvdm/modules/networks/__pycache__/openaimodel3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/networks/__pycache__/openaimodel3d.cpython-310.pyc -------------------------------------------------------------------------------- /lvdm/modules/networks/__pycache__/openaimodel3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/lvdm/modules/networks/__pycache__/openaimodel3d.cpython-38.pyc -------------------------------------------------------------------------------- /prompts/prompt.txt: -------------------------------------------------------------------------------- 1 | A drone view of celebration with Christmas tree and fireworks 2 | Slow motion footage of a racing car 3 | A dog drinking water 4 | A koala bear playing piano in the forest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.11.1 2 | numpy==1.26.4 3 | transformers==4.25.1 4 | imageio==2.27.0 5 | safetensors==0.4.4 6 | pytorch-lightning==2.4.0 7 | open-clip-torch==2.24.0 8 | einops 9 | opencv-python 10 | decord 11 | omegaconf 12 | kornia 13 | # xformers==0.0.22.post4 --index-url https://download.pytorch.org/whl/cu118 -------------------------------------------------------------------------------- /scripts/evaluation/__pycache__/funcs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/scripts/evaluation/__pycache__/funcs.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/evaluation/__pycache__/funcs.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/scripts/evaluation/__pycache__/funcs.cpython-311.pyc -------------------------------------------------------------------------------- /scripts/evaluation/__pycache__/funcs.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/scripts/evaluation/__pycache__/funcs.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/evaluation/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def setup_dist(local_rank): 9 | if dist.is_initialized(): 10 | return 11 | torch.cuda.set_device(local_rank) 12 | torch.distributed.init_process_group('nccl', init_method='env://') 13 | 14 | 15 | def get_dist_info(): 16 | if dist.is_available(): 17 | initialized = dist.is_initialized() 18 | else: 19 | initialized = False 20 | if initialized: 21 | rank = dist.get_rank() 22 | world_size = dist.get_world_size() 23 | else: 24 | rank = 0 25 | world_size = 1 26 | return rank, world_size 27 | 28 | 29 | if __name__ == '__main__': 30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--module", type=str, help="module name", default="inference") 33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 34 | args, unknown = parser.parse_known_args() 35 | inference_api = importlib.import_module(args.module, package=None) 36 | 37 | inference_parser = inference_api.get_parser() 38 | inference_args, unknown = inference_parser.parse_known_args() 39 | 40 | seed_everything(inference_args.seed) 41 | setup_dist(args.local_rank) 42 | torch.backends.cudnn.benchmark = True 43 | rank, gpu_num = get_dist_info() 44 | 45 | print("@CoLVDM Inference [rank%d]: %s"%(rank, now)) 46 | inference_api.run_inference(inference_args, gpu_num, rank) -------------------------------------------------------------------------------- /scripts/evaluation/funcs.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | import numpy as np 3 | from collections import OrderedDict 4 | from decord import VideoReader, cpu 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) 10 | from lvdm.models.samplers.ddim import DDIMSampler 11 | 12 | 13 | def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\ 14 | cfg_scale=1.0, temporal_cfg_scale=None, **kwargs): 15 | ddim_sampler = DDIMSampler(model) 16 | uncond_type = model.uncond_type 17 | batch_size = noise_shape[0] 18 | 19 | ## construct unconditional guidance 20 | if cfg_scale != 1.0: 21 | if uncond_type == "empty_seq": 22 | prompts = batch_size * [""] 23 | #prompts = N * T * [""] ## if is_imgbatch=True 24 | uc_emb = model.get_learned_conditioning(prompts) 25 | elif uncond_type == "zero_embed": 26 | c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond 27 | uc_emb = torch.zeros_like(c_emb) 28 | 29 | ## process image embedding token 30 | if hasattr(model, 'embedder'): 31 | uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device) 32 | ## img: b c h w >> b l c 33 | uc_img = model.get_image_embeds(uc_img) 34 | uc_emb = torch.cat([uc_emb, uc_img], dim=1) 35 | 36 | if isinstance(cond, dict): 37 | uc = {key:cond[key] for key in cond.keys()} 38 | uc.update({'c_crossattn': [uc_emb]}) 39 | else: 40 | uc = uc_emb 41 | else: 42 | uc = None 43 | 44 | x_T = None 45 | batch_variants = [] 46 | #batch_variants1, batch_variants2 
= [], [] 47 | for _ in range(n_samples): 48 | if ddim_sampler is not None: 49 | kwargs.update({"clean_cond": True}) 50 | samples, _ = ddim_sampler.sample(S=ddim_steps, 51 | conditioning=cond, 52 | batch_size=noise_shape[0], 53 | shape=noise_shape[1:], 54 | verbose=False, 55 | unconditional_guidance_scale=cfg_scale, 56 | unconditional_conditioning=uc, 57 | eta=ddim_eta, 58 | temporal_length=noise_shape[2], 59 | conditional_guidance_scale_temporal=temporal_cfg_scale, 60 | x_T=x_T, 61 | **kwargs 62 | ) 63 | ## reconstruct from latent to pixel space 64 | batch_images = model.decode_first_stage_2DAE(samples) 65 | batch_variants.append(batch_images) 66 | ## batch, , c, t, h, w 67 | batch_variants = torch.stack(batch_variants, dim=1) 68 | return batch_variants 69 | 70 | 71 | def get_filelist(data_dir, ext='*'): 72 | file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext)) 73 | file_list.sort() 74 | return file_list 75 | 76 | def get_dirlist(path): 77 | list = [] 78 | if (os.path.exists(path)): 79 | files = os.listdir(path) 80 | for file in files: 81 | m = os.path.join(path,file) 82 | if (os.path.isdir(m)): 83 | list.append(m) 84 | list.sort() 85 | return list 86 | 87 | 88 | def load_model_checkpoint(model, ckpt): 89 | def load_checkpoint(model, ckpt, full_strict): 90 | state_dict = torch.load(ckpt, map_location="cpu") 91 | try: 92 | ## deepspeed 93 | new_pl_sd = OrderedDict() 94 | for key in state_dict['module'].keys(): 95 | new_pl_sd[key[16:]]=state_dict['module'][key] 96 | model.load_state_dict(new_pl_sd, strict=full_strict) 97 | except: 98 | if "state_dict" in list(state_dict.keys()): 99 | state_dict = state_dict["state_dict"] 100 | model.load_state_dict(state_dict, strict=full_strict) 101 | return model 102 | load_checkpoint(model, ckpt, full_strict=True) 103 | print('>>> model checkpoint loaded.') 104 | return model 105 | 106 | 107 | def load_prompts(prompt_file): 108 | f = open(prompt_file, 'r') 109 | prompt_list = [] 110 | for idx, line in enumerate(f.readlines()): 111 | l = line.strip() 112 | if len(l) != 0: 113 | prompt_list.append(l) 114 | f.close() 115 | return prompt_list 116 | 117 | 118 | def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16): 119 | ''' 120 | Notice about some special cases: 121 | 1. video_frames=-1 means to take all the frames (with fs=1) 122 | 2. when the total video frames is less than required, padding strategy will be used (repreated last frame) 123 | ''' 124 | fps_list = [] 125 | batch_tensor = [] 126 | assert frame_stride > 0, "valid frame stride should be a positive interge!" 127 | for filepath in filepath_list: 128 | padding_num = 0 129 | vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0]) 130 | fps = vidreader.get_avg_fps() 131 | total_frames = len(vidreader) 132 | max_valid_frames = (total_frames-1) // frame_stride + 1 133 | if video_frames < 0: 134 | ## all frames are collected: fs=1 is a must 135 | required_frames = total_frames 136 | frame_stride = 1 137 | else: 138 | required_frames = video_frames 139 | query_frames = min(required_frames, max_valid_frames) 140 | frame_indices = [frame_stride*i for i in range(query_frames)] 141 | 142 | ## [t,h,w,c] -> [c,t,h,w] 143 | frames = vidreader.get_batch(frame_indices) 144 | frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 145 | frame_tensor = (frame_tensor / 255. 
- 0.5) * 2 146 | if max_valid_frames < required_frames: 147 | padding_num = required_frames - max_valid_frames 148 | frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1) 149 | print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.') 150 | batch_tensor.append(frame_tensor) 151 | sample_fps = int(fps/frame_stride) 152 | fps_list.append(sample_fps) 153 | 154 | return torch.stack(batch_tensor, dim=0) 155 | 156 | from PIL import Image 157 | def load_image_batch(filepath_list, image_size=(256,256)): 158 | batch_tensor = [] 159 | for filepath in filepath_list: 160 | _, filename = os.path.split(filepath) 161 | _, ext = os.path.splitext(filename) 162 | if ext == '.mp4': 163 | vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0]) 164 | frame = vidreader.get_batch([0]) 165 | img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float() 166 | elif ext == '.png' or ext == '.jpg': 167 | img = Image.open(filepath).convert("RGB") 168 | rgb_img = np.array(img, np.float32) 169 | #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR) 170 | #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) 171 | rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR) 172 | img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float() 173 | else: 174 | print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]') 175 | raise NotImplementedError 176 | img_tensor = (img_tensor / 255. - 0.5) * 2 177 | batch_tensor.append(img_tensor) 178 | return torch.stack(batch_tensor, dim=0) 179 | 180 | 181 | def save_videos(batch_tensors, savedir, filenames, fps=10): 182 | # b,samples,c,t,h,w 183 | n_samples = batch_tensors.shape[1] 184 | for idx, vid_tensor in enumerate(batch_tensors): 185 | video = vid_tensor.detach().cpu() 186 | video = torch.clamp(video.float(), -1., 1.) 
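# vid_tensor is (n_samples, c, t, h, w); the lines below reorder it to (t, n_samples, c, h, w),
# tile the samples of each timestep into a single grid frame, rescale [-1, 1] to uint8 [0, 255],
# and encode the result as an H.264 mp4 (crf 10).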
187 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 188 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w] 189 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 190 | grid = (grid + 1.0) / 2.0 191 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 192 | savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") 193 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 194 | 195 | -------------------------------------------------------------------------------- /scripts/evaluation/inference.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob, yaml, math, random 2 | import datetime, time 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from collections import OrderedDict 6 | from tqdm import trange, tqdm 7 | from einops import repeat 8 | from einops import rearrange, repeat 9 | from functools import partial 10 | import torch 11 | from pytorch_lightning import seed_everything 12 | 13 | from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos 14 | from funcs import batch_ddim_sampling 15 | from utils.utils import instantiate_from_config 16 | 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything") 21 | parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}") 22 | parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") 23 | parser.add_argument("--config", type=str, help="config (yaml) path") 24 | parser.add_argument("--prompt_file", type=str, default=None, help="a text file containing many prompts") 25 | parser.add_argument("--savedir", type=str, default=None, help="results saving path") 26 | parser.add_argument("--savefps", type=str, default=8, help="video fps to generate") 27 | parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) 28 | parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) 29 | parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) 30 | parser.add_argument("--bs", type=int, default=1, help="batch size for inference") 31 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") 32 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") 33 | parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") 34 | parser.add_argument("--fps", type=int, default=24) 35 | parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance") 36 | parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance") 37 | ## for conditional i2v only 38 | parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input") 39 | return parser 40 | 41 | 42 | def run_inference(args, gpu_num, gpu_no, **kwargs): 43 | ## step 1: model config 44 | ## ----------------------------------------------------------------- 45 | config = OmegaConf.load(args.config) 46 | #data_config = config.pop("data", OmegaConf.create()) 47 | model_config = config.pop("model", 
OmegaConf.create()) 48 | model = instantiate_from_config(model_config) 49 | model = model.cuda(gpu_no) 50 | assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!" 51 | model = load_model_checkpoint(model, args.ckpt_path) 52 | model.eval() 53 | 54 | ## sample shape 55 | assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" 56 | ## latent noise shape 57 | h, w = args.height // 8, args.width // 8 58 | frames = model.temporal_length if args.frames < 0 else args.frames 59 | channels = model.channels 60 | 61 | ## saving folders 62 | os.makedirs(args.savedir, exist_ok=True) 63 | 64 | ## step 2: load data 65 | ## ----------------------------------------------------------------- 66 | assert os.path.exists(args.prompt_file), "Error: prompt file NOT Found!" 67 | prompt_list = load_prompts(args.prompt_file) 68 | num_samples = len(prompt_list) 69 | filename_list = [f"{id+1:04d}" for id in range(num_samples)] 70 | 71 | samples_split = num_samples // gpu_num 72 | residual_tail = num_samples % gpu_num 73 | print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.') 74 | indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1))) 75 | if gpu_no == 0 and residual_tail != 0: 76 | indices = indices + list(range(num_samples-residual_tail, num_samples)) 77 | prompt_list_rank = [prompt_list[i] for i in indices] 78 | 79 | ## conditional input 80 | if args.mode == "i2v": 81 | ## each video or frames dir per prompt 82 | cond_inputs = get_filelist(args.cond_input, ext='[mpj][pn][4gj]') # '[mpj][pn][4gj]' 83 | assert len(cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!" 84 | filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)] 85 | cond_inputs_rank = [cond_inputs[i] for i in indices] 86 | 87 | filename_list_rank = [filename_list[i] for i in indices] 88 | 89 | ## step 3: run over samples 90 | ## ----------------------------------------------------------------- 91 | start = time.time() 92 | n_rounds = len(prompt_list_rank) // args.bs 93 | n_rounds = n_rounds+1 if len(prompt_list_rank) % args.bs != 0 else n_rounds 94 | for idx in range(0, n_rounds): 95 | print(f'[rank:{gpu_no}] batch-{idx+1} ({args.bs})x{args.n_samples} ...') 96 | idx_s = idx*args.bs 97 | idx_e = min(idx_s+args.bs, len(prompt_list_rank)) 98 | batch_size = idx_e - idx_s 99 | filenames = filename_list_rank[idx_s:idx_e] 100 | noise_shape = [batch_size, channels, frames, h, w] 101 | fps = torch.tensor([args.fps]*batch_size).to(model.device).long() 102 | 103 | prompts = prompt_list_rank[idx_s:idx_e] 104 | if isinstance(prompts, str): 105 | prompts = [prompts] 106 | #prompts = batch_size * [""] 107 | text_emb = model.get_learned_conditioning(prompts) 108 | 109 | if args.mode == 'base': 110 | cond = {"c_crossattn": [text_emb], "fps": fps} 111 | elif args.mode == 'i2v': 112 | #cond_images = torch.zeros(noise_shape[0],3,224,224).to(model.device) 113 | cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (args.height, args.width)) 114 | cond_images = cond_images.to(model.device) 115 | img_emb = model.get_image_embeds(cond_images) 116 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 117 | cond = {"c_crossattn": [imtext_cond], "fps": fps} 118 | else: 119 | raise NotImplementedError 120 | 121 | ## inference 122 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, args.n_samples, \ 123 | args.ddim_steps, args.ddim_eta, 
args.unconditional_guidance_scale, **kwargs) 124 | ## b,samples,c,t,h,w 125 | save_videos(batch_samples, args.savedir, filenames, fps=args.savefps) 126 | 127 | print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds") 128 | 129 | 130 | if __name__ == '__main__': 131 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 132 | print("@CoLVDM Inference: %s"%now) 133 | parser = get_parser() 134 | args = parser.parse_args() 135 | seed_everything(args.seed) 136 | rank, gpu_num = 0, 1 137 | run_inference(args, gpu_num, rank) -------------------------------------------------------------------------------- /scripts/gradio/i2v_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from scripts.evaluation.funcs import load_model_checkpoint, load_image_batch, save_videos, batch_ddim_sampling 6 | from utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | 9 | class Image2Video(): 10 | def __init__(self,result_dir='./tmp/',gpu_num=1) -> None: 11 | self.download_model() 12 | self.result_dir = result_dir 13 | if not os.path.exists(self.result_dir): 14 | os.mkdir(self.result_dir) 15 | ckpt_path='checkpoints/i2v_512_v1/model.ckpt' 16 | config_file='configs/inference_i2v_512_v1.0.yaml' 17 | config = OmegaConf.load(config_file) 18 | model_config = config.pop("model", OmegaConf.create()) 19 | model_config['params']['unet_config']['params']['use_checkpoint']=False 20 | model_list = [] 21 | for gpu_id in range(gpu_num): 22 | model = instantiate_from_config(model_config) 23 | # model = model.cuda(gpu_id) 24 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 25 | model = load_model_checkpoint(model, ckpt_path) 26 | model.eval() 27 | model_list.append(model) 28 | self.model_list = model_list 29 | self.save_fps = 8 30 | 31 | def get_image(self, image, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16): 32 | torch.cuda.empty_cache() 33 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 34 | start = time.time() 35 | gpu_id=0 36 | if steps > 60: 37 | steps = 60 38 | model = self.model_list[gpu_id] 39 | model = model.cuda() 40 | batch_size=1 41 | channels = model.model.diffusion_model.in_channels 42 | frames = model.temporal_length 43 | h, w = 320 // 8, 512 // 8 44 | noise_shape = [batch_size, channels, frames, h, w] 45 | 46 | # text cond 47 | text_emb = model.get_learned_conditioning([prompt]) 48 | 49 | # img cond 50 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() 51 | img_tensor = (img_tensor / 255. - 0.5) * 2 52 | img_tensor = img_tensor.unsqueeze(0) 53 | cond_images = img_tensor.to(model.device) 54 | img_emb = model.get_image_embeds(cond_images) 55 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 56 | cond = {"c_crossattn": [imtext_cond], "fps": fps} 57 | 58 | ## inference 59 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale) 60 | ## b,samples,c,t,h,w 61 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 62 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 63 | prompt_str=prompt_str[:30] 64 | 65 | save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 66 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 67 | model = model.cpu() 68 | return os.path.join(self.result_dir, f"{prompt_str}.mp4") 69 | 70 | def download_model(self): 71 | REPO_ID = 'VideoCrafter/Image2Video-512' 72 | filename_list = ['model.ckpt'] 73 | if not os.path.exists('./checkpoints/i2v_512_v1/'): 74 | os.makedirs('./checkpoints/i2v_512_v1/') 75 | for filename in filename_list: 76 | local_file = os.path.join('./checkpoints/i2v_512_v1/', filename) 77 | if not os.path.exists(local_file): 78 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/i2v_512_v1/', local_dir_use_symlinks=False) 79 | 80 | if __name__ == '__main__': 81 | i2v = Image2Video() 82 | video_path = i2v.get_image('prompts/i2v_prompts/horse.png','horses are walking on the grassland') 83 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/gradio/t2v_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling 6 | from utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | 9 | class Text2Video(): 10 | def __init__(self,result_dir='./tmp/',gpu_num=1) -> None: 11 | self.download_model() 12 | self.result_dir = result_dir 13 | if not os.path.exists(self.result_dir): 14 | os.mkdir(self.result_dir) 15 | ckpt_path='checkpoints/base_512_v2/model.ckpt' 16 | config_file='configs/inference_t2v_512_v2.0.yaml' 17 | config = OmegaConf.load(config_file) 18 | model_config = config.pop("model", OmegaConf.create()) 19 | model_config['params']['unet_config']['params']['use_checkpoint']=False 20 | model_list = [] 21 | for gpu_id in range(gpu_num): 22 | model = instantiate_from_config(model_config) 23 | # model = model.cuda(gpu_id) 24 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 25 | model = load_model_checkpoint(model, ckpt_path) 26 | model.eval() 27 | model_list.append(model) 28 | self.model_list = model_list 29 | self.save_fps = 8 30 | 31 | def get_prompt(self, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16): 32 | torch.cuda.empty_cache() 33 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 34 | start = time.time() 35 | gpu_id=0 36 | if steps > 60: 37 | steps = 60 38 | model = self.model_list[gpu_id] 39 | model = model.cuda() 40 | batch_size=1 41 | channels = model.model.diffusion_model.in_channels 42 | frames = model.temporal_length 43 | h, w = 320 // 8, 512 // 8 44 | noise_shape = [batch_size, channels, frames, h, w] 45 | 46 | # text cond 47 | text_emb = model.get_learned_conditioning([prompt]) 48 | cond = {"c_crossattn": [text_emb], "fps": fps} 49 | 50 | ## inference 51 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale) 52 | ## b,samples,c,t,h,w 53 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 54 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 55 | prompt_str=prompt_str[:30] 56 | 57 | save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 58 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 59 | model=model.cpu() 60 | return os.path.join(self.result_dir, f"{prompt_str}.mp4") 61 | 62 | def download_model(self): 63 | REPO_ID = 'VideoCrafter/VideoCrafter2' 64 | filename_list = ['model.ckpt'] 65 | if not os.path.exists('./checkpoints/base_512_v2/'): 66 | os.makedirs('./checkpoints/base_512_v2/') 67 | for filename in filename_list: 68 | local_file = os.path.join('./checkpoints/base_512_v2/', filename) 69 | 70 | if not os.path.exists(local_file): 71 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_512_v2/', local_dir_use_symlinks=False) 72 | 73 | 74 | if __name__ == '__main__': 75 | t2v = Text2Video() 76 | video_path = t2v.get_prompt('a black swan swims on the pond') 77 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/run_image2video.sh: -------------------------------------------------------------------------------- 1 | name="i2v_512_test" 2 | 3 | ckpt='checkpoints/i2v_512_v1/model.ckpt' 4 | config='configs/inference_i2v_512_v1.0.yaml' 5 | 6 | prompt_file="prompts/i2v_prompts/test_prompts.txt" 7 | condimage_dir="prompts/i2v_prompts" 8 | res_dir="results" 9 | 10 | python3 scripts/evaluation/inference.py \ 11 | --seed 123 \ 12 | --mode 'i2v' \ 13 | --ckpt_path $ckpt \ 14 | --config $config \ 15 | --savedir $res_dir/$name \ 16 | --n_samples 1 \ 17 | --bs 1 --height 320 --width 512 \ 18 | --unconditional_guidance_scale 12.0 \ 19 | --ddim_steps 50 \ 20 | --ddim_eta 1.0 \ 21 | --prompt_file $prompt_file \ 22 | --cond_input $condimage_dir \ 23 | --fps 8 24 | 25 | -------------------------------------------------------------------------------- /scripts/run_text2video.sh: -------------------------------------------------------------------------------- 1 | name="base_512_v2" 2 | 3 | ckpt='checkpoints/base_512_v2/model.ckpt' 4 | config='configs/inference_t2v_512_v2.0.yaml' 5 | 6 | prompt_file="prompts/test_prompts.txt" 7 | res_dir="results_512" 8 | 9 | CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/inference.py \ 10 | --seed 123 \ 11 | --mode 'base' \ 12 | --ckpt_path $ckpt \ 13 | --config $config \ 14 | --savedir $res_dir/$name \ 15 | --n_samples 1 \ 16 | --bs 1 --height 512 --width 512 \ 17 | --unconditional_guidance_scale 12.0 \ 18 | --ddim_steps 50 \ 19 | --ddim_eta 1.0 \ 20 | --prompt_file $prompt_file \ 21 | --fps 8 22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | from convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 14 | from convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora, load_diffusers_lora_unet 15 | 16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 17 | videos = rearrange(videos, "b c t h w -> t b c h w") 18 | outputs = [] 19 | for x in videos: 20 | x = torchvision.utils.make_grid(x, nrow=n_rows) 21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 22 | if rescale: 23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 24 | x = (x * 255).numpy().astype(np.uint8) 25 | outputs.append(x) 26 | 27 | 
os.makedirs(os.path.dirname(path), exist_ok=True) 28 | imageio.mimsave(path, outputs, fps=fps) 29 | 30 | def load_weights( 31 | unet, 32 | vae, 33 | text_encoder, 34 | # motion module 35 | motion_module_path = "", 36 | motion_module_lora_configs = [], 37 | # domain adapter 38 | adapter_lora_path = "", 39 | adapter_lora_scale = 1.0, 40 | # image layers 41 | dreambooth_model_path = "", 42 | lora_model_path = "", 43 | lora_alpha = 0.8, 44 | ): 45 | # motion module 46 | unet_state_dict = {} 47 | if motion_module_path != "": 48 | print(f"load motion module from {motion_module_path}") 49 | if motion_module_path.endswith(".safetensors"): 50 | motion_module_state_dict = {} 51 | with safe_open(motion_module_path, framework="pt", device="cpu") as f: 52 | for key in f.keys(): 53 | motion_module_state_dict[key] = f.get_tensor(key) 54 | else: 55 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 56 | 57 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 58 | motion_module_state_dict = {k.replace("module.", ""):v for k, v in motion_module_state_dict.items()} 59 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 60 | unet_state_dict.pop("animatediff_config", "") 61 | 62 | missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) 63 | print("missing: ", len(missing), " unexpected: ", len(unexpected)) 64 | assert len(unexpected) == 0 65 | del unet_state_dict 66 | 67 | # base model 68 | if dreambooth_model_path != "": 69 | print(f"load dreambooth model from {dreambooth_model_path}") 70 | if dreambooth_model_path.endswith(".safetensors"): 71 | dreambooth_state_dict = {} 72 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 73 | for key in f.keys(): 74 | dreambooth_state_dict[key] = f.get_tensor(key) 75 | elif dreambooth_model_path.endswith(".ckpt"): 76 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 77 | 78 | # 1. vae 79 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, vae.config) 80 | vae.load_state_dict(converted_vae_checkpoint) 81 | # 2. unet 82 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, unet.config) 83 | unet.load_state_dict(converted_unet_checkpoint, strict=False) 84 | # 3. 
text_model 85 | text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 86 | del dreambooth_state_dict 87 | 88 | # # lora layers 89 | # if lora_model_path != "": 90 | # print(f"load lora model from {lora_model_path}") 91 | # assert lora_model_path.endswith(".safetensors") 92 | # lora_state_dict = {} 93 | # with safe_open(lora_model_path, framework="pt", device="cpu") as f: 94 | # for key in f.keys(): 95 | # lora_state_dict[key] = f.get_tensor(key) 96 | 97 | # # convert lora function은 각 layer에 맞춰서 weight를 더해주는 작업을 함 98 | # animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 99 | # del lora_state_dict 100 | 101 | # # domain adapter lora 102 | # if adapter_lora_path != "": 103 | # print(f"load domain lora from {adapter_lora_path}") 104 | 105 | # if adapter_lora_path.endswith(".safetensors"): 106 | # domain_lora_state_dict = {} 107 | # with safe_open(adapter_lora_path, framework="pt", device="cpu") as f: 108 | # for key in f.keys(): 109 | # domain_lora_state_dict[key] = f.get_tensor(key) 110 | # else: 111 | # domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 112 | 113 | # domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 114 | # domain_lora_state_dict.pop("animatediff_config", "") 115 | 116 | # animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) 117 | 118 | # # motion module lora 119 | # for motion_module_lora_config in motion_module_lora_configs: 120 | # path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 121 | # print(f"load motion LoRA from {path}") 122 | # motion_lora_state_dict = torch.load(path, map_location="cpu") 123 | # motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 124 | # motion_lora_state_dict.pop("animatediff_config", "") 125 | 126 | # animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) 127 | 128 | return unet -------------------------------------------------------------------------------- /vc_configs/inference_t2v_512_v2.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | timesteps: 1000 8 | first_stage_key: video 9 | cond_stage_key: caption 10 | cond_stage_trainable: false 11 | conditioning_key: crossattn 12 | image_size: 13 | - 40 14 | - 64 15 | channels: 4 16 | scale_by_std: false 17 | scale_factor: 0.18215 18 | use_ema: false 19 | uncond_type: empty_seq 20 | use_scale: true 21 | scale_b: 0.7 22 | unet_config: 23 | target: lvdm.modules.networks.openaimodel3d.UNetModel 24 | params: 25 | in_channels: 4 26 | out_channels: 4 27 | model_channels: 320 28 | attention_resolutions: 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | - 4 38 | num_head_channels: 64 39 | transformer_depth: 1 40 | context_dim: 1024 41 | use_linear: true 42 | use_checkpoint: true 43 | temporal_conv: true 44 | temporal_attention: true 45 | temporal_selfatt_only: true 46 | use_relative_position: false 47 | use_causal_attention: false 48 | temporal_length: 16 49 | addition_attention: true 50 | fps_cond: true 51 | first_stage_config: 52 | target: lvdm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | 
monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 512 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | cond_stage_config: 74 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 75 | params: 76 | freeze: true 77 | layer: penultimate 78 | 79 | noise_scheduler_kwargs: 80 | beta_start: 0.00085 81 | beta_end: 0.012 82 | beta_schedule: "linear" 83 | steps_offset: 1 84 | clip_sample: False 85 | -------------------------------------------------------------------------------- /vc_utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/vc_utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /vc_utils/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/vc_utils/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /vc_utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DoHunLee1/VideoGuide/adfe4eeff279a611c76f948ec9cc2829ed3f123e/vc_utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /vc_utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of source para 18 | para_list: partial name of target para 19 | """ 20 | istarget=False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if not "target" in config: 29 | if config == '__is_first_stage__': 30 | return None 31 | elif config == "__is_unconditional__": 32 | return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 35 | 36 | 37 | def get_obj_from_str(string, reload=False): 38 | module, cls = string.rsplit(".", 1) 39 | if reload: 40 | module_imp = importlib.import_module(module) 41 | importlib.reload(module_imp) 42 | return getattr(importlib.import_module(module, package=None), cls) 43 | 44 | 45 | def load_npz_from_dir(data_dir): 46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 47 | data = np.concatenate(data, axis=0) 48 | return data 49 | 50 | 51 | def load_npz_from_paths(data_paths): 52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 53 | data = np.concatenate(data, axis=0) 54 | return data 55 | 56 | 57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 58 | h, w = 
image.shape[:2] 59 | if resize_short_edge is not None: 60 | k = resize_short_edge / min(h, w) 61 | else: 62 | k = max_resolution / (h * w) 63 | k = k**0.5 64 | h = int(np.round(h * k / 64)) * 64 65 | w = int(np.round(w * k / 64)) * 64 66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 67 | return image 68 | 69 | 70 | def setup_dist(args): 71 | if dist.is_initialized(): 72 | return 73 | torch.cuda.set_device(args.local_rank) 74 | torch.distributed.init_process_group( 75 | 'nccl', 76 | init_method='env://' 77 | ) --------------------------------------------------------------------------------
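Note: the helpers collected above (instantiate_from_config, load_model_checkpoint, batch_ddim_sampling, save_videos) are wired together by scripts/evaluation/inference.py. The following is a minimal sketch of that text-to-video flow, not part of the repository: it reuses the import layout of scripts/evaluation/inference.py (so it assumes the same sys.path setup), the prompt from prompts/prompt.txt, the config vc_configs/inference_t2v_512_v2.0.yaml, and sampling settings mirroring scripts/run_text2video.sh; the checkpoint path assumes the VideoCrafter2 weights were downloaded as in scripts/gradio/t2v_test.py, and the output directory and filename are arbitrary placeholders.

# Minimal text-to-video sketch (illustration only; paths and settings as noted above).
import os
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from funcs import load_model_checkpoint, batch_ddim_sampling, save_videos
from utils.utils import instantiate_from_config

seed_everything(123)
config = OmegaConf.load("vc_configs/inference_t2v_512_v2.0.yaml")
model = instantiate_from_config(config.pop("model", OmegaConf.create())).cuda()
model = load_model_checkpoint(model, "checkpoints/base_512_v2/model.ckpt")  # assumed checkpoint location
model.eval()

prompts = ["A koala bear playing piano in the forest"]  # from prompts/prompt.txt
fps = torch.tensor([8] * len(prompts)).to(model.device).long()
cond = {"c_crossattn": [model.get_learned_conditioning(prompts)], "fps": fps}
# latent noise shape: [batch, channels, frames, height/8, width/8]
noise_shape = [len(prompts), model.channels, model.temporal_length, 512 // 8, 512 // 8]

# batch_ddim_sampling returns decoded videos shaped (batch, n_samples, c, t, h, w)
batch_samples = batch_ddim_sampling(
    model, cond, noise_shape, n_samples=1,
    ddim_steps=50, ddim_eta=1.0, cfg_scale=12.0,
)
os.makedirs("results_512", exist_ok=True)  # save_videos expects the directory to exist
save_videos(batch_samples, "results_512", filenames=["0001"], fps=8)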