├── .gitignore ├── LICENSE ├── README.md ├── capture_frames.py ├── config ├── mnist.yaml └── ucf.yaml ├── dataset ├── __init__.py ├── image_dataset.py ├── ucf_filter.txt └── video_dataset.py ├── model ├── __init__.py ├── attention.py ├── blocks.py ├── discriminator.py ├── lpips.py ├── patch_embed.py ├── transformer.py ├── transformer_layer.py ├── vae.py └── weights │ └── v0.1 │ └── .gitkeep ├── requirements.txt ├── scheduler ├── __init__.py └── linear_scheduler.py └── tools ├── __init__.py ├── sample_vae_ditv.py ├── save_latents.py ├── train_vae.py └── train_vae_ditv.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all image files 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | 6 | # Ignore pycharm and system files 7 | .DS_Store 8 | *.idea 9 | __pycache__ 10 | *.zip 11 | 12 | # Ignore dataset files 13 | *.csv 14 | *.json 15 | 16 | # Ignore checkpoints 17 | *.pth 18 | 19 | # Ignore pickle files 20 | *.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ExplainingAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Video Generation using Diffusion Transformers in PyTorch 2 | ======== 3 | 4 | ## Building Video Generation Model Tutorial 5 | 6 | Video generation with Diffusion Transformers 8 | 9 | 10 | 11 | ## Sample Output for Latte on moving mnist easy videos 12 | Trained for 300 epochs 13 | 14 | ![mnist0-ezgif com-video-to-gif-converter](https://github.com/user-attachments/assets/a71397ae-5848-439a-94f6-4a73bc35bd4e) 15 | ![mnist1-ezgif com-video-to-gif-converter](https://github.com/user-attachments/assets/5c535116-95b1-46e3-86ef-0cec4b1e56c2) 16 | ![mnist3-ezgif com-video-to-gif-converter](https://github.com/user-attachments/assets/5c4dfb2b-82b1-4bba-ac03-023ddbcf58ab) 17 | ![mnist4-ezgif com-video-to-gif-converter](https://github.com/user-attachments/assets/7d549275-b6bd-4af3-8607-f8214831eb18) 18 | 19 | ## Sample Output for Latte on UCF101 videos 20 | Trained for 500 epochs(needs more training) 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | ___ 30 | This repository implements Latent Diffusion Transformer for Video Generation Paper. 
It provides code for the following: 31 | * Training and inference of a VAE on Moving MNIST and UCF101 frames 32 | * Training and inference of the Latte video model, using the trained VAE, on 16-frame video clips of both datasets 33 | * Configurable code for training all models from Latte-S to Latte-XL 34 | 35 | This repo closely follows the [official Latte implementation](https://github.com/Vchitect/Latte), with the following changes: 36 | * Current code is for unconditional generation 37 | * Variance is fixed during training and not learned (like the original DDPM) 38 | * No EMA 39 | * Ability to train the VAE 40 | * Ability to save latents of video frames for faster training 41 | 42 | 43 | ## Setup 44 | * Create a new conda environment with Python 3.10, then run the commands below 45 | * `conda activate ` 46 | * ```git clone https://github.com/explainingai-code/VideoGeneration-PyTorch.git``` 47 | * ```cd VideoGeneration-PyTorch``` 48 | * ```pip install -r requirements.txt``` 49 | * Download the LPIPS weights by opening this link in a browser (don't use wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```model/weights/v0.1/vgg.pth``` 50 | ___ 51 | 52 | ## Data Preparation 53 | 54 | ### Moving Mnist Easy Videos 55 | For Moving MNIST, the easy category of videos is used, where a single digit moves across frames. 56 | Download the videos from [this](https://www.kaggle.com/datasets/yichengs/captioned-moving-mnist-dataset-easy-version/) page. The dataset also includes captions, so one can also experiment with text-to-video models. 57 | Create a `data` directory in the repo root and add the downloaded `mmnist-easy` folder there. 58 | Ensure the directory structure is the following 59 | ``` 60 | VideoGeneration-PyTorch 61 | -> data 62 | -> mmnist-easy 63 | -> *.mp4 64 | ``` 65 | 66 | For setting up UCF101, simply download the videos from the official page [here](https://www.crcv.ucf.edu/data/UCF101.php) 67 | and add them to the `data` directory. 68 | Ensure the directory structure is the following 69 | ``` 70 | VideoGeneration-PyTorch 71 | -> data 72 | -> UCF101 73 | -> *.avi 74 | 75 | ``` 76 | --- 77 | ## Configuration 78 | The config files allow you to play with different components of Latte and the autoencoder 79 | * ```config/mnist.yaml``` - Configuration used for the moving mnist dataset 80 | * ```config/ucf.yaml``` - Configuration used for the ucf dataset 81 | Important configuration parameters 82 | * `autoencoder_acc_steps` : For accumulating gradients when the video size is too large and a large batch size can't be used. 83 | 84 | ___ 85 | ## Training 86 | The repo provides training and inference for Moving MNIST (unconditional Latte model) 87 | 88 | For working on your own dataset: 89 | * Create your own config and ensure the following config parameters are correctly set (a short snippet showing how these values are read follows this list) 90 | * `im_path` - Folder name where extracted video frames will be saved (when the frame extraction script is run) 91 | * `video_path` - Path to the videos 92 | * `video_ext` - Extension of the videos. All videos are assumed to have the same extension 93 | * `frame_height`, `frame_width`, `frame_channels` - Frame size and number of channels that frames will be resized and converted to 94 | * `centre_square_crop` - Whether center cropping should be applied to frames 95 | * `video_filter_path` - null / location of a txt file listing the video names to be used for training. If `null` then all videos will be used (as for moving mnist). To see how to construct this filter file, look at `dataset/ucf_filter.txt` 96 | * The existing `video_dataset.py` should require minimal modifications to adapt to your dataset requirements 97 | 
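As a quick sanity check for a new config, the dataset parameters above can be loaded and inspected the same way `capture_frames.py` reads them. A minimal sketch using the keys shipped in `config/mnist.yaml` (the values in the comments are the ones from that config):

```
import yaml

# Read the config the same way capture_frames.py does
with open('config/mnist.yaml', 'r') as f:
    config = yaml.safe_load(f)

dataset_cfg = config['dataset_params']
print(dataset_cfg['video_path'])           # 'data/mmnist-easy' - where the videos live
print(dataset_cfg['video_ext'])            # 'mp4' - all videos share this extension
print(dataset_cfg['frame_height'],
      dataset_cfg['frame_width'],
      dataset_cfg['frame_channels'])       # 72 128 1 - resize / channel targets for frames
print(dataset_cfg['centre_square_crop'])   # False - no center cropping for moving mnist
print(dataset_cfg['video_filter_path'])    # None - use all videos
```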
98 | Once the config and dataset are set up: 99 | * First train the autoencoder on your dataset using [this section](#training-autoencoder-for-latte) 100 | * For training and inference of unconditional Latte follow [this section](#training-unconditional-latte) 101 | 102 | ## Training AutoEncoder for Latte 103 | * We first need to extract frames for training the autoencoder 104 | * By default, we extract only 10% of the frames from our videos. Change `ae_frame_sample_prob` in `dataset_params` of the config if you want to train on a larger number of frames. For both moving mnist and ucf, 10% works fine. 105 | * If you need to train the Latte model on a subset of videos then ensure `video_filter_path` is correctly set. Look at the ucf config and `dataset/ucf_filter.txt` for guidance. 106 | * Extract the frames by running `python -m capture_frames --config config/mnist.yaml` with the right config value 107 | * For training the autoencoder 108 | * Minimal modifications might be needed to `dataset/image_dataset.py` to adapt to your own dataset. 109 | * Make sure the `frame_height` and `frame_width` parameters are correctly set for resizing frames (if needed) 110 | * Run ```python -m tools.train_vae --config config/mnist.yaml``` for training the autoencoder with the right config file 111 | * In case your GPU memory is limited, I would suggest running `python -m tools.save_latents --config config/mnist.yaml` with the correct config. This script saves the latent frames for all of your dataset videos. Otherwise, the VAE would also have to be loaded during diffusion model training. 112 | 113 | ## Training Unconditional Latte 114 | Train the autoencoder first and make changes to the config and `video_dataset.py` (if needed) to adapt to your requirements. 115 | 116 | * ```python -m tools.train_vae_ditv --config config/mnist.yaml``` for training unconditional Latte with the right config 117 | * ```python -m tools.sample_vae_ditv --config config/mnist.yaml``` for generating videos using the trained Latte model 118 | 119 | 120 | ## Output 121 | Outputs will be saved according to the configuration present in the yaml files. 122 | 123 | For every run, a folder named after the ```task_name``` key in the config will be created. 124 | 125 | During frame extraction, a folder named after the `im_path` key will be created inside the `task_name` directory and frames will be saved there. 126 | 127 | During training of the autoencoder the following output will be saved 128 | * Latest autoencoder and discriminator checkpoints in the ```task_name``` directory 129 | * Sample reconstructions in ```task_name/vae_autoencoder_samples``` 130 | 131 | If the `save_latents` script is run, latents will be saved in ```task_name/save_video_latent_dir``` if mentioned in the config 132 | 133 | During training and inference of unconditional Latte the following output will be saved: 134 | * During training we will save the latest checkpoint in the ```task_name``` directory 135 | * During sampling, the unconditionally sampled video generated for all timesteps will be saved in ```task_name/samples/*.mp4```. The final decoded generated video will be `sample_video_0.mp4`. Videos from `sample_output_999.mp4` to `sample_output_1.mp4` will be the latent video predictions of the denoising process from T=999 to T=1. 
Final Generated Video is at T=0 136 | 137 | 138 | 139 | ## Citations 140 | ``` 141 | @misc{ma2024lattelatentdiffusiontransformer, 142 | title={Latte: Latent Diffusion Transformer for Video Generation}, 143 | author={Xin Ma and Yaohui Wang and Gengyun Jia and Xinyuan Chen and Ziwei Liu and Yuan-Fang Li and Cunjian Chen and Yu Qiao}, 144 | year={2024}, 145 | eprint={2401.03048}, 146 | archivePrefix={arXiv}, 147 | primaryClass={cs.CV}, 148 | url={https://arxiv.org/abs/2401.03048}, 149 | } 150 | ``` 151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /capture_frames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import argparse 4 | import glob 5 | import yaml 6 | import os 7 | from tqdm import tqdm 8 | 9 | 10 | def extract_frames_from_video(video_path, im_path, frame_sample_prob, count): 11 | video_obj = cv2.VideoCapture(video_path) 12 | 13 | success = 1 14 | while success: 15 | success, image = video_obj.read() 16 | if not success: 17 | break 18 | if random.random() > frame_sample_prob: 19 | continue 20 | 21 | cv2.imwrite('{}/frame_{}.png'.format(im_path, count), image) 22 | count += 1 23 | return count 24 | 25 | 26 | def extract_frames(args): 27 | # Read the config file # 28 | with open(args.config_path, 'r') as file: 29 | try: 30 | config = yaml.safe_load(file) 31 | except yaml.YAMLError as exc: 32 | print(exc) 33 | print(config) 34 | ######################## 35 | 36 | dataset_config = config['dataset_params'] 37 | task_name = config['train_params']['task_name'] 38 | video_path = dataset_config['video_path'] 39 | assert os.path.exists(video_path), "video path {} does not exist".format(video_path) 40 | 41 | im_path = os.path.join(task_name, dataset_config['im_path']) 42 | if not os.path.exists(im_path): 43 | os.mkdir(im_path) 44 | filter_fpath = dataset_config['video_filter_path'] 45 | video_ext = dataset_config['video_ext'] 46 | frame_sample_prob = dataset_config['ae_frame_sample_prob'] 47 | 48 | # Create frames directory if not present 49 | if not os.path.exists(im_path): 50 | os.mkdir(im_path) 51 | 52 | video_paths = [] 53 | if filter_fpath is None: 54 | filters = ['*'] 55 | else: 56 | filters = [] 57 | assert os.path.exists(filter_fpath), "Filter file not present" 58 | with open(filter_fpath, 'r') as f: 59 | for line in f.readlines(): 60 | filters.append(line.strip()) 61 | for filter_i in filters: 62 | for fname in glob.glob(os.path.join(video_path, '{}.{}'.format(filter_i, 63 | video_ext))): 64 | video_paths.append(fname) 65 | 66 | print('Found {} videos'.format(len(video_paths))) 67 | print('Extracting frames....') 68 | count = 0 69 | for video_path in tqdm(video_paths): 70 | count = extract_frames_from_video(video_path, im_path, frame_sample_prob, count) 71 | 72 | print('Extracted total {} frames'.format(count)) 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser(description='Arguments for frame extraction ' 77 | 'for autoencoder training') 78 | parser.add_argument('--config', dest='config_path', 79 | default='config/mnist.yaml', type=str) 80 | args = parser.parse_args() 81 | extract_frames(args) 82 | -------------------------------------------------------------------------------- /config/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'mmnist-easy-images' 3 | video_path: 'data/mmnist-easy' 4 | video_ext: 'mp4' 5 | num_images_train: 8 6 | 
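  # num_images_train above: number of extra random frames sampled per clip for
  # joint image+video training in dataset/video_dataset.py (set to 0 to disable).
  # The frame_height / frame_width / frame_channels values below are the size and
  # channel count that frames are resized and converted to before training.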
frame_height : 72 7 | frame_width : 128 8 | frame_channels : 1 9 | num_frames: 16 10 | frame_interval: 2 11 | centre_square_crop: False 12 | video_filter_path: null 13 | ae_frame_sample_prob : 0.1 14 | 15 | diffusion_params: 16 | num_timesteps : 1000 17 | beta_start : 0.0001 18 | beta_end : 0.02 19 | 20 | ditv_params: 21 | patch_size : 2 22 | num_layers : 12 23 | hidden_size : 768 24 | num_heads : 12 25 | head_dim : 64 26 | timestep_emb_dim : 256 27 | 28 | autoencoder_params: 29 | z_channels: 4 30 | codebook_size : 20 31 | down_channels : [32, 64, 128] 32 | mid_channels : [128] 33 | down_sample : [True, True] 34 | attn_down : [False, False] 35 | norm_channels: 32 36 | num_heads: 16 37 | num_down_layers : 1 38 | num_mid_layers : 1 39 | num_up_layers : 1 40 | 41 | 42 | train_params: 43 | seed : 1111 44 | task_name: 'mmnist' 45 | autoencoder_batch_size: 64 46 | autoencoder_epochs: 20 47 | autoencoder_lr: 0.0001 48 | autoencoder_acc_steps: 1 49 | disc_start: 500 50 | disc_weight: 0.5 51 | codebook_weight: 1 52 | commitment_beta: 0.2 53 | perceptual_weight: 1 54 | kl_weight: 0.000005 55 | autoencoder_img_save_steps: 64 56 | save_latents: False 57 | ditv_batch_size: 4 58 | ditv_epochs: 300 59 | num_samples: 1 60 | ditv_lr: 0.0001 61 | ditv_acc_steps: 1 62 | save_video_latent_dir: 'video_latents' 63 | vae_latent_dir_name: 'vae_latents' 64 | ditv_ckpt_name: 'dit_ckpt.pth' 65 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 66 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 67 | -------------------------------------------------------------------------------- /config/ucf.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | im_path: 'ucf-images' 3 | video_path: 'data/UCF101' 4 | video_ext: 'avi' 5 | num_images_train : 8 6 | frame_height: 256 7 | frame_width: 256 8 | frame_channels: 3 9 | num_frames: 16 10 | frame_interval: 3 11 | centre_square_crop : True 12 | video_filter_path: 'dataset/ucf_filter.txt' 13 | ae_frame_sample_prob : 0.1 14 | 15 | diffusion_params: 16 | num_timesteps : 1000 17 | beta_start : 0.0001 18 | beta_end : 0.02 19 | 20 | ditv_params: 21 | patch_size : 2 22 | num_layers : 12 23 | hidden_size : 768 24 | num_heads : 12 25 | head_dim : 64 26 | timestep_emb_dim : 256 27 | 28 | autoencoder_params: 29 | z_channels: 4 30 | codebook_size : 8192 31 | down_channels : [128, 256, 384, 512] 32 | mid_channels : [512] 33 | down_sample : [True, True, True] 34 | attn_down : [False, False, False] 35 | norm_channels: 32 36 | num_heads: 4 37 | num_down_layers : 2 38 | num_mid_layers : 2 39 | num_up_layers : 2 40 | 41 | 42 | train_params: 43 | seed : 1111 44 | task_name: 'ucf' 45 | autoencoder_batch_size: 4 46 | autoencoder_epochs: 30 47 | autoencoder_lr: 0.00001 48 | autoencoder_acc_steps: 1 49 | disc_start: 5000 50 | disc_weight: 0.5 51 | codebook_weight: 1 52 | commitment_beta: 0.2 53 | perceptual_weight: 1 54 | kl_weight: 0.000005 55 | autoencoder_img_save_steps: 64 56 | save_latents: False 57 | ditv_batch_size: 4 58 | ditv_epochs: 1000 59 | num_samples: 1 60 | ditv_lr: 0.0001 61 | ditv_acc_steps: 1 62 | save_video_latent_dir: 'video_latents' 63 | vae_latent_dir_name: 'vae_latents' 64 | ditv_ckpt_name: 'dit_ckpt.pth' 65 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth' 66 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth' 67 | -------------------------------------------------------------------------------- /dataset/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/image_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import torchvision 4 | from PIL import Image 5 | from torch.utils.data.dataset import Dataset 6 | 7 | 8 | class ImageDataset(Dataset): 9 | r""" 10 | Simple Image Dataset class for training autoencoder on frames 11 | of the videos 12 | """ 13 | def __init__(self, split, dataset_config, task_name, im_ext='png'): 14 | r""" 15 | Init method for initializing the dataset properties 16 | :param split: train/test to locate the image files 17 | :param dataset_config: config parameters for dataset(mnist/ucf) 18 | :param im_ext: image extension. assumes all 19 | images would be this type. 20 | """ 21 | self.split = split 22 | self.im_ext = im_ext 23 | self.frame_height = dataset_config['frame_height'] 24 | self.frame_width = dataset_config['frame_width'] 25 | self.frame_channels = dataset_config['frame_channels'] 26 | self.center_square_crop = dataset_config['centre_square_crop'] 27 | self.images = self.load_images(os.path.join(task_name, dataset_config['im_path'])) 28 | 29 | def load_images(self, im_path): 30 | r""" 31 | Gets all images from the path specified 32 | and stacks them all up 33 | :param im_path: 34 | :return: 35 | """ 36 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path) 37 | ims = [] 38 | for fname in glob.glob(os.path.join(im_path, '*.{}'.format(self.im_ext))): 39 | ims.append(fname) 40 | print('Found {} images for split {}'.format(len(ims), self.split)) 41 | return ims 42 | 43 | def __len__(self): 44 | return len(self.images) 45 | 46 | def __getitem__(self, index): 47 | assert self.frame_channels in (1, 3), "Frame channels can only be 1/3" 48 | if self.frame_channels == 1: 49 | im = Image.open(self.images[index]).convert('L') 50 | else: 51 | im = Image.open(self.images[index]).convert('RGB') 52 | if self.center_square_crop: 53 | assert self.frame_height == self.frame_width, \ 54 | ('For centre square crop frame_height ' 55 | 'and frame_width should be same') 56 | im_tensor = torchvision.transforms.Compose([ 57 | torchvision.transforms.Resize(self.frame_height), 58 | torchvision.transforms.CenterCrop(self.frame_height), 59 | torchvision.transforms.ToTensor(), 60 | ])(im) 61 | else: 62 | im_tensor = torchvision.transforms.Compose([ 63 | torchvision.transforms.Resize((self.frame_height, self.frame_width)), 64 | torchvision.transforms.ToTensor(), 65 | ])(im) 66 | 67 | im.close() 68 | im_tensor = (2 * im_tensor) - 1 69 | return im_tensor 70 | -------------------------------------------------------------------------------- /dataset/ucf_filter.txt: -------------------------------------------------------------------------------- 1 | *WallPushups* 2 | *_PushUps* 3 | *_PullUps* 4 | *_TaiChi* -------------------------------------------------------------------------------- /dataset/video_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import torch 5 | import numpy as np 6 | import random 7 | import torchvision 8 | import pickle 9 | import torchvision.transforms.v2 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from 
torch.utils.data.dataset import Dataset 13 | 14 | 15 | class VideoDataset(Dataset): 16 | r""" 17 | Simple Video Dataset class for training diffusion transformer 18 | for video generation. 19 | If latents are present, the dataset uses the saved latents for the videos, 20 | else it reads the video and extracts frames from it. 21 | """ 22 | def __init__(self, split, dataset_config, latent_path=None, im_ext='png'): 23 | r""" 24 | Initialize all parameters and also check 25 | if latents are present or not 26 | :param split: for now this is always train 27 | :param dataset_config: config parameters for dataset(mnist/ucf) 28 | :param latent_path: Path for saved latents 29 | :param im_ext: assumes all images are of this extension. Used only 30 | if latents are not present 31 | """ 32 | self.split = split 33 | self.video_ext = dataset_config['video_ext'] 34 | self.num_images = dataset_config['num_images_train'] 35 | self.use_images = self.num_images > 0 36 | self.num_frames = dataset_config['num_frames'] 37 | self.frame_interval = dataset_config['frame_interval'] 38 | self.frame_height = dataset_config['frame_height'] 39 | self.frame_width = dataset_config['frame_width'] 40 | self.frame_channels = dataset_config['frame_channels'] 41 | self.center_square_crop = dataset_config['centre_square_crop'] 42 | self.filter_fpath = dataset_config['video_filter_path'] 43 | if self.center_square_crop: 44 | assert self.frame_height == self.frame_width, \ 45 | ('For centre square crop frame_height ' 46 | 'and frame_width should be same') 47 | self.transforms = torchvision.transforms.v2.Compose([ 48 | torchvision.transforms.v2.Resize(self.frame_height), 49 | torchvision.transforms.v2.CenterCrop(self.frame_height), 50 | torchvision.transforms.v2.ToPureTensor(), 51 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True), 52 | torchvision.transforms.v2.Normalize(mean=[0.5, 0.5, 0.5], 53 | std=[0.5, 0.5, 0.5]) 54 | ]) 55 | else: 56 | self.transforms = torchvision.transforms.v2.Compose([ 57 | torchvision.transforms.v2.Resize((self.frame_height, 58 | self.frame_width)), 59 | torchvision.transforms.v2.ToPureTensor(), 60 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True), 61 | torchvision.transforms.v2.Normalize(mean=[0.5, 0.5, 0.5], 62 | std=[0.5, 0.5, 0.5]) 63 | ]) 64 | 65 | # Load video paths for this dataset 66 | self.video_paths = self.load_videos(dataset_config['video_path'], 67 | self.filter_fpath) 68 | 69 | # Validate if latents are present and if they are present for all 70 | # videos. And only upon validation set `use_latents` as True 71 | self.use_latents = False 72 | if latent_path is not None and os.path.exists(latent_path): 73 | num_latents = len(glob.glob(os.path.join(latent_path, '*.pkl'))) 74 | self.latents = glob.glob(os.path.join(latent_path, '*.pkl')) 75 | if num_latents == len(self.video_paths): 76 | self.use_latents = True 77 | print('Will use latents') 78 | 79 | def load_videos(self, video_path, filter_fpath=None): 80 | r""" 81 | Method to load all video names for training. 82 | This uses the filter file to use only selective videos 83 | for training. 
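        Each line of the filter file is a glob pattern (e.g. *_TaiChi* in
        dataset/ucf_filter.txt) that is combined with the video extension and
        matched against files under video_path.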
84 | :param video_path: Path for all videos in the dataset 85 | :param filter_fpath: Path for file containing filters for relevant videos 86 | :return: 87 | """ 88 | assert os.path.exists(video_path), ( 89 | "video path {} does not exist".format(video_path)) 90 | video_paths = [] 91 | 92 | if filter_fpath is None: 93 | filters = ['*'] 94 | else: 95 | filters = [] 96 | assert os.path.exists(filter_fpath), "Filter file not present" 97 | with open(filter_fpath, 'r') as f: 98 | for line in f.readlines(): 99 | filters.append(line.strip()) 100 | for filter in filters: 101 | for fname in glob.glob(os.path.join(video_path, 102 | '{}.{}'.format(filter, 103 | self.video_ext))): 104 | video_paths.append(fname) 105 | print('Found {} videos'.format(len(video_paths))) 106 | return video_paths 107 | 108 | def __len__(self): 109 | return len(self.video_paths) 110 | 111 | def __getitem__(self, index): 112 | # We do things differently whether we are working with latents or not 113 | if self.use_latents: 114 | # Load the latent corresponding to this item 115 | latent_path = self.latents[index] 116 | latent = pickle.load(open(latent_path, 'rb')).cpu() 117 | 118 | # Sample (self.frame_interval * self.num_frames) frames 119 | # and from that take num_frames(16) equally spaced frames 120 | # Keep only the latents of these sampled frames 121 | num_frames = len(latent) 122 | total_frames = self.frame_interval * self.num_frames 123 | max_end = max(0, num_frames - total_frames - 1) 124 | start_index = random.randint(0, max_end) 125 | end_index = min(start_index + total_frames, num_frames) 126 | frame_idxs = np.linspace(start_index, end_index - 1, self.num_frames, 127 | dtype=int) 128 | latent = latent[frame_idxs] 129 | 130 | # From the latent extract the mean and std 131 | # and reparametrization to sample according to this 132 | # mean and std 133 | mean, logvar = torch.chunk(latent, 2, dim=1) 134 | std = torch.exp(0.5 * logvar) 135 | frames_tensor = mean + std * torch.randn(mean.shape) 136 | 137 | # If we are doing image+video joint training 138 | # then sample random num_images(8) videos 139 | # and then sample a random latent frame from that video 140 | if self.use_images: 141 | im_latents = [] 142 | for _ in range(self.num_images): 143 | # Sample a random video 144 | # From this video sample a random saved latent 145 | # Extract mean, std from this latent and then 146 | # use that to get a im_latent 147 | video_idx = random.randint(0, len(self.latents)-1) 148 | random_video_latent = pickle.load(open(self.latents[video_idx], 149 | 'rb')).cpu() 150 | frame_idx = random.randint(0, len(random_video_latent) - 1) 151 | im_latent = random_video_latent[frame_idx] 152 | mean, logvar = torch.chunk(im_latent, 2, dim=0) 153 | std = torch.exp(0.5 * logvar) 154 | im_latent = mean + std * torch.randn(mean.shape) 155 | im_latents.append(im_latent.unsqueeze(0)) 156 | 157 | im_latents = torch.cat(im_latents) 158 | 159 | # Concat video latents and image latents together for training 160 | frames_tensor = torch.cat([frames_tensor, im_latents], dim=0) 161 | return frames_tensor 162 | else: 163 | # Read the video corresponding to this item 164 | path = self.video_paths[index] 165 | frames, _, _ = torchvision.io.read_video(filename=path, 166 | pts_unit='sec', 167 | output_format='TCHW') 168 | 169 | # Sample (self.frame_interval * self.num_frames) frames 170 | # and from that take num_frames(16) equally spaced frames 171 | # Keep only these sampled frames 172 | num_frames = len(frames) 173 | max_end = max(0, num_frames - 
(self.num_frames * self.frame_interval) - 1) 174 | start_index = random.randint(0, max_end) 175 | end_index = min(start_index + (self.num_frames * self.frame_interval), 176 | num_frames) 177 | frame_idxs = np.linspace(start_index, end_index - 1, self.num_frames, 178 | dtype=int) 179 | frames = frames[frame_idxs] 180 | 181 | # Resize frames according to transformations 182 | # desired based on config 183 | frames_tensor = self.transforms(frames) 184 | 185 | # For grayscale keep only the first channel 186 | if self.frame_channels == 1: 187 | frames_tensor = frames_tensor[:, 0:1, :, :] 188 | 189 | # If we are doing image+video joint training 190 | # then sample random num_images(8) videos 191 | # and then sample a random frame from that video 192 | if self.use_images: 193 | im_tensors = [] 194 | for _ in range(self.num_images): 195 | # Sample a random video 196 | # From this video sample a random image frame 197 | video_idx = random.randint(0, len(self.video_paths)-1) 198 | path = self.video_paths[video_idx] 199 | ims, _, _ = torchvision.io.read_video(filename=path, 200 | pts_unit='sec', 201 | output_format='TCHW') 202 | im_idx = random.randint(0, len(ims) - 1) 203 | ims = ims[im_idx] 204 | 205 | # Resize this sampled image according to transformations 206 | # desired based on config 207 | im_tensor = self.transforms(ims) 208 | 209 | # For grayscale keep only the first channel 210 | if self.frame_channels == 1: 211 | im_tensor = im_tensor[0:1, :, :] 212 | im_tensors.append(im_tensor.unsqueeze(0)) 213 | 214 | im_tensors = torch.cat(im_tensors) 215 | frames_tensor = torch.cat([frames_tensor, im_tensors], dim=0) 216 | return frames_tensor 217 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/model/__init__.py -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class Attention(nn.Module): 7 | r""" 8 | Attention Module for DiT. 9 | This is same as VIT code and does not have any changes 10 | from it. 
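    Expects input of shape (Batch, Number of patch tokens, hidden_size)
    and returns an output of the same shape.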
11 | """ 12 | def __init__(self, config): 13 | super().__init__() 14 | self.n_heads = config['num_heads'] 15 | self.hidden_size = config['hidden_size'] 16 | self.head_dim = config['head_dim'] 17 | 18 | self.att_dim = self.n_heads * self.head_dim 19 | 20 | # QKV projection for the input 21 | self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.att_dim, bias=True) 22 | self.output_proj = nn.Sequential( 23 | nn.Linear(self.att_dim, self.hidden_size)) 24 | 25 | ############################ 26 | # DiT Layer Initialization # 27 | ############################ 28 | nn.init.xavier_uniform_(self.qkv_proj.weight) 29 | nn.init.constant_(self.qkv_proj.bias, 0) 30 | nn.init.xavier_uniform_(self.output_proj[0].weight) 31 | nn.init.constant_(self.output_proj[0].bias, 0) 32 | 33 | def forward(self, x): 34 | # Converting to Attention Dimension 35 | ###################################################### 36 | # Batch Size x Number of Patches x Dimension 37 | B, N = x.shape[:2] 38 | # Projecting to 3*att_dim and then splitting to get q, k v(each of att_dim) 39 | # qkv -> Batch Size x Number of Patches x (3* Attention Dimension) 40 | # q(as well as k and v) -> Batch Size x Number of Patches x Attention Dimension 41 | q, k, v = self.qkv_proj(x).split(self.att_dim, dim=-1) 42 | # Batch Size x Number of Patches x Attention Dimension 43 | # -> Batch Size x Number of Patches x (Heads * Head Dimension) 44 | # -> Batch Size x Number of Patches x (Heads * Head Dimension) 45 | # -> Batch Size x Heads x Number of Patches x Head Dimension 46 | # -> B x H x N x Head Dimension 47 | q = rearrange(q, 'b n (n_h h_dim) -> b n_h n h_dim', 48 | n_h=self.n_heads, h_dim=self.head_dim) 49 | k = rearrange(k, 'b n (n_h h_dim) -> b n_h n h_dim', 50 | n_h=self.n_heads, h_dim=self.head_dim) 51 | v = rearrange(v, 'b n (n_h h_dim) -> b n_h n h_dim', 52 | n_h=self.n_heads, h_dim=self.head_dim) 53 | ######################################################### 54 | 55 | # Compute Attention Weights 56 | ######################################################### 57 | # B x H x N x Head Dimension @ B x H x Head Dimension x N 58 | # -> B x H x N x N 59 | att = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** (-0.5)) 60 | att = torch.nn.functional.softmax(att, dim=-1) 61 | ######################################################### 62 | 63 | # Weighted Value Computation 64 | ######################################################### 65 | # B x H x N x N @ B x H x N x Head Dimension 66 | # -> B x H x N x Head Dimension 67 | out = torch.matmul(att, v) 68 | ######################################################### 69 | 70 | # Converting to Transformer Dimension 71 | ######################################################### 72 | # B x N x (Heads * Head Dimension) -> B x N x (Attention Dimension) 73 | out = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)') 74 | # B x N x Dimension 75 | out = self.output_proj(out) 76 | ########################################################## 77 | 78 | return out 79 | -------------------------------------------------------------------------------- /model/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_time_embedding(time_steps, temb_dim): 6 | r""" 7 | Convert time steps tensor into an embedding using the 8 | sinusoidal time embedding formula 9 | :param time_steps: 1D tensor of length batch size 10 | :param temb_dim: Dimension of the embedding 11 | :return: BxD embedding representation of B time steps 12 | """ 13 
| assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 14 | 15 | # factor = 10000^(2i/d_model) 16 | factor = 10000 ** ((torch.arange( 17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) 18 | ) 19 | 20 | # pos / factor 21 | # timesteps B -> B, 1 -> B, temb_dim 22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 24 | return t_emb 25 | 26 | 27 | class DownBlock(nn.Module): 28 | r""" 29 | Down conv block with attention. 30 | Sequence of following block 31 | 1. Resnet block with time embedding 32 | 2. Attention block 33 | 3. Downsample 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, t_emb_dim, 37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None): 38 | super().__init__() 39 | self.num_layers = num_layers 40 | self.down_sample = down_sample 41 | self.attn = attn 42 | self.context_dim = context_dim 43 | self.cross_attn = cross_attn 44 | self.t_emb_dim = t_emb_dim 45 | self.resnet_conv_first = nn.ModuleList( 46 | [ 47 | nn.Sequential( 48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 49 | nn.SiLU(), 50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, 51 | kernel_size=3, stride=1, padding=1), 52 | ) 53 | for i in range(num_layers) 54 | ] 55 | ) 56 | if self.t_emb_dim is not None: 57 | self.t_emb_layers = nn.ModuleList([ 58 | nn.Sequential( 59 | nn.SiLU(), 60 | nn.Linear(self.t_emb_dim, out_channels) 61 | ) 62 | for _ in range(num_layers) 63 | ]) 64 | self.resnet_conv_second = nn.ModuleList( 65 | [ 66 | nn.Sequential( 67 | nn.GroupNorm(norm_channels, out_channels), 68 | nn.SiLU(), 69 | nn.Conv2d(out_channels, out_channels, 70 | kernel_size=3, stride=1, padding=1), 71 | ) 72 | for _ in range(num_layers) 73 | ] 74 | ) 75 | 76 | if self.attn: 77 | self.attention_norms = nn.ModuleList( 78 | [nn.GroupNorm(norm_channels, out_channels) 79 | for _ in range(num_layers)] 80 | ) 81 | 82 | self.attentions = nn.ModuleList( 83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 84 | for _ in range(num_layers)] 85 | ) 86 | 87 | if self.cross_attn: 88 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 89 | self.cross_attention_norms = nn.ModuleList( 90 | [nn.GroupNorm(norm_channels, out_channels) 91 | for _ in range(num_layers)] 92 | ) 93 | self.cross_attentions = nn.ModuleList( 94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 95 | for _ in range(num_layers)] 96 | ) 97 | self.context_proj = nn.ModuleList( 98 | [nn.Linear(context_dim, out_channels) 99 | for _ in range(num_layers)] 100 | ) 101 | 102 | self.residual_input_conv = nn.ModuleList( 103 | [ 104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 105 | for i in range(num_layers) 106 | ] 107 | ) 108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 109 | 4, 2, 1) if self.down_sample else nn.Identity() 110 | 111 | def forward(self, x, t_emb=None, context=None): 112 | out = x 113 | for i in range(self.num_layers): 114 | # Resnet block of Unet 115 | resnet_input = out 116 | out = self.resnet_conv_first[i](out) 117 | if self.t_emb_dim is not None: 118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 119 | out = self.resnet_conv_second[i](out) 120 | out = out + self.residual_input_conv[i](resnet_input) 121 | 122 | if self.attn: 123 | # Attention block of Unet 124 | batch_size, 
channels, h, w = out.shape 125 | in_attn = out.reshape(batch_size, channels, h * w) 126 | in_attn = self.attention_norms[i](in_attn) 127 | in_attn = in_attn.transpose(1, 2) 128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 130 | out = out + out_attn 131 | 132 | if self.cross_attn: 133 | assert context is not None, "context cannot be None if cross attention layers are used" 134 | batch_size, channels, h, w = out.shape 135 | in_attn = out.reshape(batch_size, channels, h * w) 136 | in_attn = self.cross_attention_norms[i](in_attn) 137 | in_attn = in_attn.transpose(1, 2) 138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 139 | context_proj = self.context_proj[i](context) 140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 142 | out = out + out_attn 143 | 144 | # Downsample 145 | out = self.down_sample_conv(out) 146 | return out 147 | 148 | 149 | class MidBlock(nn.Module): 150 | r""" 151 | Mid conv block with attention. 152 | Sequence of following blocks 153 | 1. Resnet block with time embedding 154 | 2. Attention block 155 | 3. Resnet block with time embedding 156 | """ 157 | 158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, 159 | context_dim=None): 160 | super().__init__() 161 | self.num_layers = num_layers 162 | self.t_emb_dim = t_emb_dim 163 | self.context_dim = context_dim 164 | self.cross_attn = cross_attn 165 | self.resnet_conv_first = nn.ModuleList( 166 | [ 167 | nn.Sequential( 168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 169 | nn.SiLU(), 170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 171 | padding=1), 172 | ) 173 | for i in range(num_layers + 1) 174 | ] 175 | ) 176 | 177 | if self.t_emb_dim is not None: 178 | self.t_emb_layers = nn.ModuleList([ 179 | nn.Sequential( 180 | nn.SiLU(), 181 | nn.Linear(t_emb_dim, out_channels) 182 | ) 183 | for _ in range(num_layers + 1) 184 | ]) 185 | self.resnet_conv_second = nn.ModuleList( 186 | [ 187 | nn.Sequential( 188 | nn.GroupNorm(norm_channels, out_channels), 189 | nn.SiLU(), 190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 191 | ) 192 | for _ in range(num_layers + 1) 193 | ] 194 | ) 195 | 196 | self.attention_norms = nn.ModuleList( 197 | [nn.GroupNorm(norm_channels, out_channels) 198 | for _ in range(num_layers)] 199 | ) 200 | 201 | self.attentions = nn.ModuleList( 202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 203 | for _ in range(num_layers)] 204 | ) 205 | if self.cross_attn: 206 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 207 | self.cross_attention_norms = nn.ModuleList( 208 | [nn.GroupNorm(norm_channels, out_channels) 209 | for _ in range(num_layers)] 210 | ) 211 | self.cross_attentions = nn.ModuleList( 212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 213 | for _ in range(num_layers)] 214 | ) 215 | self.context_proj = nn.ModuleList( 216 | [nn.Linear(context_dim, out_channels) 217 | for _ in range(num_layers)] 218 | ) 219 | self.residual_input_conv = nn.ModuleList( 220 | [ 221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 222 | for i in range(num_layers + 1) 223 | ] 224 | ) 225 | 226 | def forward(self, x, 
t_emb=None, context=None): 227 | out = x 228 | 229 | # First resnet block 230 | resnet_input = out 231 | out = self.resnet_conv_first[0](out) 232 | if self.t_emb_dim is not None: 233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] 234 | out = self.resnet_conv_second[0](out) 235 | out = out + self.residual_input_conv[0](resnet_input) 236 | 237 | for i in range(self.num_layers): 238 | # Attention Block 239 | batch_size, channels, h, w = out.shape 240 | in_attn = out.reshape(batch_size, channels, h * w) 241 | in_attn = self.attention_norms[i](in_attn) 242 | in_attn = in_attn.transpose(1, 2) 243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 244 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 245 | out = out + out_attn 246 | 247 | if self.cross_attn: 248 | assert context is not None, "context cannot be None if cross attention layers are used" 249 | batch_size, channels, h, w = out.shape 250 | in_attn = out.reshape(batch_size, channels, h * w) 251 | in_attn = self.cross_attention_norms[i](in_attn) 252 | in_attn = in_attn.transpose(1, 2) 253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim 254 | context_proj = self.context_proj[i](context) 255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 257 | out = out + out_attn 258 | 259 | # Resnet Block 260 | resnet_input = out 261 | out = self.resnet_conv_first[i + 1](out) 262 | if self.t_emb_dim is not None: 263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] 264 | out = self.resnet_conv_second[i + 1](out) 265 | out = out + self.residual_input_conv[i + 1](resnet_input) 266 | 267 | return out 268 | 269 | 270 | class UpBlock(nn.Module): 271 | r""" 272 | Up conv block with attention. 273 | Sequence of following blocks 274 | 1. Upsample 275 | 1. Concatenate Down block output 276 | 2. Resnet block with time embedding 277 | 3. 
Attention Block 278 | """ 279 | 280 | def __init__(self, in_channels, out_channels, t_emb_dim, 281 | up_sample, num_heads, num_layers, attn, norm_channels): 282 | super().__init__() 283 | self.num_layers = num_layers 284 | self.up_sample = up_sample 285 | self.t_emb_dim = t_emb_dim 286 | self.attn = attn 287 | self.resnet_conv_first = nn.ModuleList( 288 | [ 289 | nn.Sequential( 290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 291 | nn.SiLU(), 292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 293 | padding=1), 294 | ) 295 | for i in range(num_layers) 296 | ] 297 | ) 298 | 299 | if self.t_emb_dim is not None: 300 | self.t_emb_layers = nn.ModuleList([ 301 | nn.Sequential( 302 | nn.SiLU(), 303 | nn.Linear(t_emb_dim, out_channels) 304 | ) 305 | for _ in range(num_layers) 306 | ]) 307 | 308 | self.resnet_conv_second = nn.ModuleList( 309 | [ 310 | nn.Sequential( 311 | nn.GroupNorm(norm_channels, out_channels), 312 | nn.SiLU(), 313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 314 | ) 315 | for _ in range(num_layers) 316 | ] 317 | ) 318 | if self.attn: 319 | self.attention_norms = nn.ModuleList( 320 | [ 321 | nn.GroupNorm(norm_channels, out_channels) 322 | for _ in range(num_layers) 323 | ] 324 | ) 325 | 326 | self.attentions = nn.ModuleList( 327 | [ 328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 329 | for _ in range(num_layers) 330 | ] 331 | ) 332 | 333 | self.residual_input_conv = nn.ModuleList( 334 | [ 335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 336 | for i in range(num_layers) 337 | ] 338 | ) 339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 340 | 4, 2, 1) \ 341 | if self.up_sample else nn.Identity() 342 | 343 | def forward(self, x, out_down=None, t_emb=None): 344 | # Upsample 345 | x = self.up_sample_conv(x) 346 | 347 | # Concat with Downblock output 348 | if out_down is not None: 349 | x = torch.cat([x, out_down], dim=1) 350 | 351 | out = x 352 | for i in range(self.num_layers): 353 | # Resnet Block 354 | resnet_input = out 355 | out = self.resnet_conv_first[i](out) 356 | if self.t_emb_dim is not None: 357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 358 | out = self.resnet_conv_second[i](out) 359 | out = out + self.residual_input_conv[i](resnet_input) 360 | 361 | # Self Attention 362 | if self.attn: 363 | batch_size, channels, h, w = out.shape 364 | in_attn = out.reshape(batch_size, channels, h * w) 365 | in_attn = self.attention_norms[i](in_attn) 366 | in_attn = in_attn.transpose(1, 2) 367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 369 | out = out + out_attn 370 | return out 371 | 372 | 373 | class UpBlockUnet(nn.Module): 374 | r""" 375 | Up conv block with attention. 376 | Sequence of following blocks 377 | 1. Upsample 378 | 1. Concatenate Down block output 379 | 2. Resnet block with time embedding 380 | 3. 
Attention Block 381 | """ 382 | 383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, 384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): 385 | super().__init__() 386 | self.num_layers = num_layers 387 | self.up_sample = up_sample 388 | self.t_emb_dim = t_emb_dim 389 | self.cross_attn = cross_attn 390 | self.context_dim = context_dim 391 | self.resnet_conv_first = nn.ModuleList( 392 | [ 393 | nn.Sequential( 394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), 395 | nn.SiLU(), 396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, 397 | padding=1), 398 | ) 399 | for i in range(num_layers) 400 | ] 401 | ) 402 | 403 | if self.t_emb_dim is not None: 404 | self.t_emb_layers = nn.ModuleList([ 405 | nn.Sequential( 406 | nn.SiLU(), 407 | nn.Linear(t_emb_dim, out_channels) 408 | ) 409 | for _ in range(num_layers) 410 | ]) 411 | 412 | self.resnet_conv_second = nn.ModuleList( 413 | [ 414 | nn.Sequential( 415 | nn.GroupNorm(norm_channels, out_channels), 416 | nn.SiLU(), 417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 418 | ) 419 | for _ in range(num_layers) 420 | ] 421 | ) 422 | 423 | self.attention_norms = nn.ModuleList( 424 | [ 425 | nn.GroupNorm(norm_channels, out_channels) 426 | for _ in range(num_layers) 427 | ] 428 | ) 429 | 430 | self.attentions = nn.ModuleList( 431 | [ 432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 433 | for _ in range(num_layers) 434 | ] 435 | ) 436 | 437 | if self.cross_attn: 438 | assert context_dim is not None, "Context Dimension must be passed for cross attention" 439 | self.cross_attention_norms = nn.ModuleList( 440 | [nn.GroupNorm(norm_channels, out_channels) 441 | for _ in range(num_layers)] 442 | ) 443 | self.cross_attentions = nn.ModuleList( 444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) 445 | for _ in range(num_layers)] 446 | ) 447 | self.context_proj = nn.ModuleList( 448 | [nn.Linear(context_dim, out_channels) 449 | for _ in range(num_layers)] 450 | ) 451 | self.residual_input_conv = nn.ModuleList( 452 | [ 453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) 454 | for i in range(num_layers) 455 | ] 456 | ) 457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 458 | 4, 2, 1) \ 459 | if self.up_sample else nn.Identity() 460 | 461 | def forward(self, x, out_down=None, t_emb=None, context=None): 462 | x = self.up_sample_conv(x) 463 | if out_down is not None: 464 | x = torch.cat([x, out_down], dim=1) 465 | 466 | out = x 467 | for i in range(self.num_layers): 468 | # Resnet 469 | resnet_input = out 470 | out = self.resnet_conv_first[i](out) 471 | if self.t_emb_dim is not None: 472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] 473 | out = self.resnet_conv_second[i](out) 474 | out = out + self.residual_input_conv[i](resnet_input) 475 | # Self Attention 476 | batch_size, channels, h, w = out.shape 477 | in_attn = out.reshape(batch_size, channels, h * w) 478 | in_attn = self.attention_norms[i](in_attn) 479 | in_attn = in_attn.transpose(1, 2) 480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) 481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 482 | out = out + out_attn 483 | # Cross Attention 484 | if self.cross_attn: 485 | assert context is not None, "context cannot be None if cross attention layers are used" 486 | batch_size, channels, h, w = out.shape 487 | in_attn 
= out.reshape(batch_size, channels, h * w) 488 | in_attn = self.cross_attention_norms[i](in_attn) 489 | in_attn = in_attn.transpose(1, 2) 490 | assert len(context.shape) == 3, \ 491 | "Context shape does not match B,_,CONTEXT_DIM" 492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \ 493 | "Context shape does not match B,_,CONTEXT_DIM" 494 | context_proj = self.context_proj[i](context) 495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) 496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) 497 | out = out + out_attn 498 | 499 | return out 500 | 501 | 502 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | r""" 7 | PatchGAN Discriminator. 8 | Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 9 | 1 scalar value , we instead predict grid of values. 10 | Where each grid is prediction of how likely 11 | the discriminator thinks that the image patch corresponding 12 | to the grid cell is real 13 | """ 14 | 15 | def __init__(self, im_channels=3, 16 | conv_channels=[64, 128, 256], 17 | kernels=[4,4,4,4], 18 | strides=[2,2,2,1], 19 | paddings=[1,1,1,1]): 20 | super().__init__() 21 | self.im_channels = im_channels 22 | activation = nn.LeakyReLU(0.2) 23 | layers_dim = [self.im_channels] + conv_channels + [1] 24 | self.layers = nn.ModuleList([ 25 | nn.Sequential( 26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1], 27 | kernel_size=kernels[i], 28 | stride=strides[i], 29 | padding=paddings[i], 30 | bias=False if i !=0 else True), 31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(), 32 | activation if i != len(layers_dim) - 2 else nn.Identity() 33 | ) 34 | for i in range(len(layers_dim) - 1) 35 | ]) 36 | 37 | def forward(self, x): 38 | out = x 39 | for layer in self.layers: 40 | out = layer(out) 41 | return out 42 | 43 | 44 | if __name__ == '__main__': 45 | x = torch.randn((2,3, 256, 256)) 46 | prob = Discriminator(im_channels=3)(x) 47 | print(prob.shape) 48 | -------------------------------------------------------------------------------- /model/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn 9 | import torchvision 10 | 11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | if torch.backends.mps.is_available(): 15 | device = torch.device('mps') 16 | print('Using mps') 17 | 18 | def spatial_average(in_tens, keepdim=True): 19 | return in_tens.mean([2, 3], keepdim=keepdim) 20 | 21 | 22 | class vgg16(torch.nn.Module): 23 | def __init__(self, requires_grad=False, pretrained=True): 24 | super(vgg16, self).__init__() 25 | # Load pretrained vgg model from torchvision 26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = 
torch.nn.Sequential() 32 | self.N_slices = 5 33 | for x in range(4): 34 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 35 | for x in range(4, 9): 36 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 37 | for x in range(9, 16): 38 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 39 | for x in range(16, 23): 40 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 41 | for x in range(23, 30): 42 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 43 | 44 | # Freeze vgg model 45 | if not requires_grad: 46 | for param in self.parameters(): 47 | param.requires_grad = False 48 | 49 | def forward(self, X): 50 | # Return output of vgg features 51 | h = self.slice1(X) 52 | h_relu1_2 = h 53 | h = self.slice2(h) 54 | h_relu2_2 = h 55 | h = self.slice3(h) 56 | h_relu3_3 = h 57 | h = self.slice4(h) 58 | h_relu4_3 = h 59 | h = self.slice5(h) 60 | h_relu5_3 = h 61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 63 | return out 64 | 65 | 66 | # Learned perceptual metric 67 | class LPIPS(nn.Module): 68 | def __init__(self, net='vgg', version='0.1', use_dropout=True): 69 | super(LPIPS, self).__init__() 70 | self.version = version 71 | # Imagenet normalization 72 | self.scaling_layer = ScalingLayer() 73 | ######################## 74 | 75 | # Instantiate vgg model 76 | self.chns = [64, 128, 256, 512, 512] 77 | self.L = len(self.chns) 78 | self.net = vgg16(pretrained=True, requires_grad=False) 79 | 80 | # Add 1x1 convolutional Layers 81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 87 | self.lins = nn.ModuleList(self.lins) 88 | ######################## 89 | 90 | # Load the weights of trained LPIPS model 91 | import inspect 92 | import os 93 | model_path = os.path.abspath( 94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 95 | print('Loading model from: %s' % model_path) 96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False) 97 | ######################## 98 | 99 | # Freeze all parameters 100 | self.eval() 101 | for param in self.parameters(): 102 | param.requires_grad = False 103 | ######################## 104 | 105 | def forward(self, in0, in1, normalize=False): 106 | # Scale the inputs to -1 to +1 range if needed 107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 108 | in0 = 2 * in0 - 1 109 | in1 = 2 * in1 - 1 110 | ######################## 111 | 112 | # Normalize the inputs according to imagenet normalization 113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) 114 | ######################## 115 | 116 | # Get VGG outputs for image0 and image1 117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 118 | feats0, feats1, diffs = {}, {}, {} 119 | ######################## 120 | 121 | # Compute Square of Difference for each layer output 122 | for kk in range(self.L): 123 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize( 124 | outs1[kk]) 125 | diffs[kk] = 
(feats0[kk] - feats1[kk]) ** 2 126 | ######################## 127 | 128 | # 1x1 convolution followed by spatial average on the square differences 129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 130 | val = 0 131 | 132 | # Aggregate the results of each layer 133 | for l in range(self.L): 134 | val += res[l] 135 | return val 136 | 137 | 138 | class ScalingLayer(nn.Module): 139 | def __init__(self): 140 | super(ScalingLayer, self).__init__() 141 | # Imagnet normalization for (0-1) 142 | # mean = [0.485, 0.456, 0.406] 143 | # std = [0.229, 0.224, 0.225] 144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 146 | 147 | def forward(self, inp): 148 | return (inp - self.shift) / self.scale 149 | 150 | 151 | class NetLinLayer(nn.Module): 152 | ''' A single linear layer which does a 1x1 conv ''' 153 | 154 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 155 | super(NetLinLayer, self).__init__() 156 | 157 | layers = [nn.Dropout(), ] if (use_dropout) else [] 158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 159 | self.model = nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | out = self.model(x) 163 | return out 164 | -------------------------------------------------------------------------------- /model/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | def get_patch_position_embedding(pos_emb_dim, grid_size, device): 9 | assert pos_emb_dim % 4 == 0, 'Position embedding dimension must be divisible by 4' 10 | grid_size_h, grid_size_w = grid_size 11 | grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device) 12 | grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device) 13 | grid = torch.meshgrid(grid_h, grid_w, indexing='ij') 14 | grid = torch.stack(grid, dim=0) 15 | 16 | # grid_h_positions -> (Number of patch tokens,) 17 | grid_h_positions = grid[0].reshape(-1) 18 | grid_w_positions = grid[1].reshape(-1) 19 | 20 | # factor = 10000^(2i/d_model) 21 | factor = 10000 ** ((torch.arange( 22 | start=0, 23 | end=pos_emb_dim // 4, 24 | dtype=torch.float32, 25 | device=device) / (pos_emb_dim // 4)) 26 | ) 27 | 28 | grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor 29 | grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1) 30 | # grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2) 31 | 32 | grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor 33 | grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1) 34 | pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1) 35 | 36 | # pos_emb -> (Number of patch tokens, pos_emb_dim) 37 | return pos_emb 38 | 39 | 40 | class PatchEmbedding(nn.Module): 41 | r""" 42 | Layer to take in the input image and do the following: 43 | 1. Transform grid of image patches into a sequence of patches. 44 | Number of patches are decided based on image height,width and 45 | patch height, width. 46 | 2. 
Add positional embedding to the above sequence 47 | """ 48 | 49 | def __init__(self, 50 | image_height, 51 | image_width, 52 | im_channels, 53 | patch_height, 54 | patch_width, 55 | hidden_size): 56 | super().__init__() 57 | self.image_height = image_height 58 | self.image_width = image_width 59 | self.im_channels = im_channels 60 | 61 | self.hidden_size = hidden_size 62 | 63 | self.patch_height = patch_height 64 | self.patch_width = patch_width 65 | 66 | # Input dimension for Patch Embedding FC Layer 67 | patch_dim = self.im_channels * self.patch_height * self.patch_width 68 | self.patch_embed = nn.Sequential( 69 | nn.Linear(patch_dim, self.hidden_size) 70 | ) 71 | 72 | ############################ 73 | # DiT Layer Initialization # 74 | ############################ 75 | nn.init.xavier_uniform_(self.patch_embed[0].weight) 76 | nn.init.constant_(self.patch_embed[0].bias, 0) 77 | 78 | def forward(self, x): 79 | grid_size_h = self.image_height // self.patch_height 80 | grid_size_w = self.image_width // self.patch_width 81 | 82 | # B, C, H, W -> B, (Patches along height * Patches along width), Patch Dimension 83 | # Number of tokens = Patches along height * Patches along width 84 | out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)', 85 | ph=self.patch_height, 86 | pw=self.patch_width) 87 | 88 | # BxNumber of tokens x Patch Dimension -> B x Number of tokens x Transformer Dimension 89 | out = self.patch_embed(out) 90 | 91 | # Add 2d sinusoidal position embeddings 92 | pos_embed = get_patch_position_embedding(pos_emb_dim=self.hidden_size, 93 | grid_size=(grid_size_h, grid_size_w), 94 | device=x.device) 95 | out += pos_embed 96 | return out 97 | 98 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.patch_embed import PatchEmbedding 4 | from model.transformer_layer import TransformerLayer 5 | from einops import rearrange, repeat 6 | 7 | 8 | def get_time_embedding(time_steps, temb_dim): 9 | r""" 10 | Convert time steps tensor into an embedding using the 11 | sinusoidal time embedding formula 12 | :param time_steps: 1D tensor of length batch size 13 | :param temb_dim: Dimension of the embedding 14 | :return: BxD embedding representation of B time steps 15 | """ 16 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" 17 | 18 | # factor = 10000^(2i/d_model) 19 | factor = 10000 ** ((torch.arange( 20 | start=0, 21 | end=temb_dim // 2, 22 | dtype=torch.float32, 23 | device=time_steps.device) / (temb_dim // 2)) 24 | ) 25 | 26 | # pos / factor 27 | # timesteps B -> B, 1 -> B, temb_dim 28 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor 29 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) 30 | return t_emb 31 | 32 | 33 | class DITVideo(nn.Module): 34 | r""" 35 | Class for the Latte Model, which makes use of alternate spatial 36 | and temporal encoder layers, each of which are DiT blocks. 
37 | """ 38 | def __init__(self, frame_height, frame_width, im_channels, num_frames, config): 39 | super().__init__() 40 | 41 | num_layers = config['num_layers'] 42 | self.image_height = frame_height 43 | self.image_width = frame_width 44 | self.im_channels = im_channels 45 | self.hidden_size = config['hidden_size'] 46 | self.patch_height = config['patch_size'] 47 | self.patch_width = config['patch_size'] 48 | self.num_frames = num_frames 49 | 50 | self.timestep_emb_dim = config['timestep_emb_dim'] 51 | 52 | # Number of patches along height and width 53 | self.nh = self.image_height // self.patch_height 54 | self.nw = self.image_width // self.patch_width 55 | 56 | # Patch Embedding Block 57 | self.patch_embed_layer = PatchEmbedding(image_height=self.image_height, 58 | image_width=self.image_width, 59 | im_channels=self.im_channels, 60 | patch_height=self.patch_height, 61 | patch_width=self.patch_width, 62 | hidden_size=self.hidden_size) 63 | 64 | # Initial projection from sinusoidal time embedding 65 | self.t_proj = nn.Sequential( 66 | nn.Linear(self.timestep_emb_dim, self.hidden_size), 67 | nn.SiLU(), 68 | nn.Linear(self.hidden_size, self.hidden_size) 69 | ) 70 | 71 | # All Transformer Layers 72 | self.layers = nn.ModuleList([ 73 | TransformerLayer(config) for _ in range(num_layers) 74 | ]) 75 | 76 | # Final normalization for unpatchify block 77 | self.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 78 | 79 | # Scale and Shift parameters for the norm 80 | self.adaptive_norm_layer = nn.Sequential( 81 | nn.SiLU(), 82 | nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=True) 83 | ) 84 | 85 | # Final Linear Layer 86 | self.proj_out = nn.Linear(self.hidden_size, 87 | self.patch_height * self.patch_width * self.im_channels) 88 | 89 | ############################ 90 | # DiT Layer Initialization # 91 | ############################ 92 | nn.init.normal_(self.t_proj[0].weight, std=0.02) 93 | nn.init.normal_(self.t_proj[2].weight, std=0.02) 94 | 95 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0) 96 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0) 97 | 98 | nn.init.constant_(self.proj_out.weight, 0) 99 | nn.init.constant_(self.proj_out.bias, 0) 100 | 101 | def forward(self, x, t, num_images=0): 102 | r""" 103 | Forward method of our ditv model which predicts the noise 104 | :param x: input noisy image 105 | :param t: timestep of noise 106 | :param num_images: if joint training then this is number 107 | of images appended to video frames 108 | :return: 109 | """ 110 | # Shape of x is Batch_size x (num_frames + num_images) x Channels x H x W 111 | B, F, C, H, W = x.shape 112 | 113 | ################## 114 | # Patchify Block # 115 | ################## 116 | # rearrange to (Batch_size * (num_frames + num_images)) x Channels x H x W 117 | x = rearrange(x, 'b f c h w -> (b f) c h w') 118 | out = self.patch_embed_layer(x) 119 | 120 | # out->(Batch_size * (num_frames + num_images)) x num_patch_tokens x hidden_size 121 | num_patch_tokens = out.shape[1] 122 | 123 | # Compute Timestep representation 124 | # t_emb -> (Batch, timestep_emb_dim) 125 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.timestep_emb_dim) 126 | # (Batch, timestep_emb_dim) -> (Batch, hidden_size) 127 | t_emb = self.t_proj(t_emb) 128 | 129 | # Timestep embedding will be Batch_size x hidden_size 130 | # We repeat it to get different timestep shapes for spatial and temporal layers 131 | # For spatial -> (Batch size * (num_frames + num_images)) x hidden_size 132 | # For temporal -> 
(Batch size * num_patch_tokens) x hidden_size 133 | t_emb_spatial = repeat(t_emb, 'b d -> (b f) d', 134 | f=self.num_frames+num_images) 135 | t_emb_temporal = repeat(t_emb, 'b d -> (b p) d', p=num_patch_tokens) 136 | 137 | # get temporal embedding from 0-num_frames(16) 138 | frame_pos = torch.arange(self.num_frames, dtype=torch.float32, device=x.device) 139 | frame_emb = get_time_embedding(frame_pos, self.hidden_size) 140 | # frame_emb -> (16 x hidden_size) 141 | 142 | # Loop over all transformer layers 143 | for layer_idx in range(0, len(self.layers), 2): 144 | spatial_layer = self.layers[layer_idx] 145 | temporal_layer = self.layers[layer_idx+1] 146 | 147 | # out->(Batch_size * (num_frames+num_images)) x num_patch_tokens x hidden_size 148 | 149 | ################# 150 | # Spatial Layer # 151 | ################# 152 | # position embedding is already added in patch embedding layer 153 | out = spatial_layer(out, t_emb_spatial) 154 | 155 | ################## 156 | # Temporal Layer # 157 | ################## 158 | # rearrange to (B * patch_tokens) x (num_frames+num_images) x hidden_size 159 | out = rearrange(out, '(b f) p d -> (b p) f d', b=B) 160 | 161 | # Separate the video tokens and image tokens 162 | out_video = out[:, :self.num_frames, :] 163 | out_images = out[:, self.num_frames:, :] 164 | 165 | # Add temporal embedding to video tokens 166 | # but only if first temporal layer 167 | if layer_idx == 0: 168 | out_video = out_video + frame_emb 169 | # Call temporal layer 170 | out_video = temporal_layer(out_video, t_emb_temporal) 171 | 172 | # Concatenate the image tokens back to the new video output 173 | out = torch.cat([out_video, out_images], dim=1) 174 | 175 | # Rearrange to (B * (num_frames+num_images)) x num_patch_tokens x hidden_size 176 | out = rearrange(out, '(b p) f d -> (b f) p d', 177 | f=self.num_frames+num_images, p=num_patch_tokens) 178 | 179 | # Shift and scale predictions for output normalization 180 | pre_mlp_shift, pre_mlp_scale = (self.adaptive_norm_layer(t_emb_spatial). 181 | chunk(2, dim=1)) 182 | out = (self.norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) + 183 | pre_mlp_shift.unsqueeze(1)) 184 | 185 | # Unpatchify 186 | # Batch_size * (num_frames+num_images)) x patches x hidden_size 187 | # -> (B * (num_frames+num_images)) x patches x (patch height*patch width*channels) 188 | out = self.proj_out(out) 189 | out = rearrange(out, 'b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)', 190 | ph=self.patch_height, 191 | pw=self.patch_width, 192 | nw=self.nw, 193 | nh=self.nh) 194 | # out -> (Batch_size * (num_frames+num_images)) x channels x h x w 195 | out = out.reshape((B, F, C, H, W)) 196 | return out 197 | -------------------------------------------------------------------------------- /model/transformer_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from model.attention import Attention 3 | 4 | 5 | class TransformerLayer(nn.Module): 6 | r""" 7 | Transformer block which is just doing the following based on VIT 8 | 1. LayerNorm followed by Attention 9 | 2. LayerNorm followed by Feed forward Block 10 | Both these also have residuals added to them 11 | 12 | For DiT we additionally have 13 | 1. Layernorm mlp to predict layernorm affine parameters from 14 | 2. Same Layernorm mlp to also predict scale parameters for outputs 15 | of both mlp/attention prior to residual connection. 
16 | """ 17 | 18 | def __init__(self, config): 19 | super().__init__() 20 | self.hidden_size = config['hidden_size'] 21 | 22 | ff_hidden_dim = 4 * self.hidden_size 23 | 24 | # Layer norm for attention block 25 | self.att_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 26 | 27 | self.attn_block = Attention(config) 28 | 29 | # Layer norm for mlp block 30 | self.ff_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6) 31 | 32 | self.mlp_block = nn.Sequential( 33 | nn.Linear(self.hidden_size, ff_hidden_dim), 34 | nn.GELU(approximate='tanh'), 35 | nn.Linear(ff_hidden_dim, self.hidden_size), 36 | ) 37 | 38 | # Scale Shift Parameter predictions for this layer 39 | # 1. Scale and shift parameters for layernorm of attention (2 * hidden_size) 40 | # 2. Scale and shift parameters for layernorm of mlp (2 * hidden_size) 41 | # 3. Scale for output of attention prior to residual connection (hidden_size) 42 | # 4. Scale for output of mlp prior to residual connection (hidden_size) 43 | # Total 6 * hidden_size 44 | self.adaptive_norm_layer = nn.Sequential( 45 | nn.SiLU(), 46 | nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True) 47 | ) 48 | 49 | ############################ 50 | # DiT Layer Initialization # 51 | ############################ 52 | nn.init.xavier_uniform_(self.mlp_block[0].weight) 53 | nn.init.constant_(self.mlp_block[0].bias, 0) 54 | nn.init.xavier_uniform_(self.mlp_block[-1].weight) 55 | nn.init.constant_(self.mlp_block[-1].bias, 0) 56 | 57 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0) 58 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0) 59 | 60 | def forward(self, x, condition): 61 | scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1) 62 | (pre_attn_shift, pre_attn_scale, post_attn_scale, 63 | pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params 64 | out = x 65 | attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1)) 66 | + pre_attn_shift.unsqueeze(1)) 67 | out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output) 68 | mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) + 69 | pre_mlp_shift.unsqueeze(1)) 70 | out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output) 71 | return out 72 | -------------------------------------------------------------------------------- /model/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.blocks import DownBlock, MidBlock, UpBlock 4 | 5 | 6 | class VAE(nn.Module): 7 | def __init__(self, im_channels, model_config): 8 | super().__init__() 9 | self.down_channels = model_config['down_channels'] 10 | self.mid_channels = model_config['mid_channels'] 11 | self.down_sample = model_config['down_sample'] 12 | self.num_down_layers = model_config['num_down_layers'] 13 | self.num_mid_layers = model_config['num_mid_layers'] 14 | self.num_up_layers = model_config['num_up_layers'] 15 | 16 | # To disable attention in Downblock of Encoder and Upblock of Decoder 17 | self.attns = model_config['attn_down'] 18 | 19 | # Latent Dimension 20 | self.z_channels = model_config['z_channels'] 21 | self.norm_channels = model_config['norm_channels'] 22 | self.num_heads = model_config['num_heads'] 23 | 24 | # Assertion to validate the channel information 25 | assert self.mid_channels[0] == self.down_channels[-1] 26 | assert self.mid_channels[-1] == self.down_channels[-1] 27 | assert len(self.down_sample) == 
len(self.down_channels) - 1 28 | assert len(self.attns) == len(self.down_channels) - 1 29 | 30 | # Wherever we use downsampling in encoder correspondingly use 31 | # upsampling in decoder 32 | self.up_sample = list(reversed(self.down_sample)) 33 | 34 | ##################### Encoder ###################### 35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) 36 | 37 | # Downblock + Midblock 38 | self.encoder_layers = nn.ModuleList([]) 39 | for i in range(len(self.down_channels) - 1): 40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], 41 | t_emb_dim=None, down_sample=self.down_sample[i], 42 | num_heads=self.num_heads, 43 | num_layers=self.num_down_layers, 44 | attn=self.attns[i], 45 | norm_channels=self.norm_channels)) 46 | 47 | self.encoder_mids = nn.ModuleList([]) 48 | for i in range(len(self.mid_channels) - 1): 49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], 50 | t_emb_dim=None, 51 | num_heads=self.num_heads, 52 | num_layers=self.num_mid_layers, 53 | norm_channels=self.norm_channels)) 54 | 55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) 56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1) 57 | 58 | # Latent Dimension is 2*Latent because we are predicting mean & variance 59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1) 60 | #################################################### 61 | 62 | ##################### Decoder ###################### 63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) 64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) 65 | 66 | # Midblock + Upblock 67 | self.decoder_mids = nn.ModuleList([]) 68 | for i in reversed(range(1, len(self.mid_channels))): 69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], 70 | t_emb_dim=None, 71 | num_heads=self.num_heads, 72 | num_layers=self.num_mid_layers, 73 | norm_channels=self.norm_channels)) 74 | 75 | self.decoder_layers = nn.ModuleList([]) 76 | for i in reversed(range(1, len(self.down_channels))): 77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], 78 | t_emb_dim=None, up_sample=self.down_sample[i - 1], 79 | num_heads=self.num_heads, 80 | num_layers=self.num_up_layers, 81 | attn=self.attns[i - 1], 82 | norm_channels=self.norm_channels)) 83 | 84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) 85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) 86 | 87 | def encode(self, x): 88 | out = self.encoder_conv_in(x) 89 | for idx, down in enumerate(self.encoder_layers): 90 | out = down(out) 91 | for mid in self.encoder_mids: 92 | out = mid(out) 93 | out = self.encoder_norm_out(out) 94 | out = nn.SiLU()(out) 95 | out = self.encoder_conv_out(out) 96 | out = self.pre_quant_conv(out) 97 | mean, logvar = torch.chunk(out, 2, dim=1) 98 | std = torch.exp(0.5 * logvar) 99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device) 100 | return sample, out 101 | 102 | def decode(self, z): 103 | out = z 104 | out = self.post_quant_conv(out) 105 | out = self.decoder_conv_in(out) 106 | for mid in self.decoder_mids: 107 | out = mid(out) 108 | for idx, up in enumerate(self.decoder_layers): 109 | out = up(out) 110 | 111 | out = 
self.decoder_norm_out(out) 112 | out = nn.SiLU()(out) 113 | out = self.decoder_conv_out(out) 114 | return out 115 | 116 | def forward(self, x): 117 | z, encoder_output = self.encode(x) 118 | out = self.decode(z) 119 | return out, encoder_output 120 | 121 | 122 | 123 | if __name__ == '__main__': 124 | a = torch.randn((2, 32, 16, 16)) 125 | net = nn.Sequential( 126 | nn.Flatten(), 127 | nn.Linear(16* 16 * 32, 2 * 32) 128 | ) 129 | net = torch.nn.Conv2d(32, 64, kernel_size=1) 130 | out = net(a) 131 | print(out.shape) -------------------------------------------------------------------------------- /model/weights/v0.1/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/model/weights/v0.1/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | numpy==2.0.1 3 | opencv_python==4.10.0.84 4 | Pillow==10.4.0 5 | PyYAML==6.0.1 6 | torch==2.3.1 7 | torchvision==0.18.1 8 | tqdm==4.66.4 9 | -------------------------------------------------------------------------------- /scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/scheduler/__init__.py -------------------------------------------------------------------------------- /scheduler/linear_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LinearNoiseScheduler: 5 | r""" 6 | Class for the linear noise scheduler that is used in DDPM. 7 | """ 8 | 9 | def __init__(self, num_timesteps, beta_start, beta_end): 10 | self.num_timesteps = num_timesteps 11 | self.beta_start = beta_start 12 | self.beta_end = beta_end 13 | 14 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps) 15 | self.alphas = 1. 
- self.betas 16 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) 17 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) 18 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) 19 | 20 | def add_noise(self, original, noise, t): 21 | r""" 22 | Forward method for diffusion 23 | :param original: Image on which noise is to be applied 24 | :param noise: Random Noise Tensor (from normal dist) 25 | :param t: timestep of the forward process of shape -> (B,) 26 | :return: 27 | """ 28 | original_shape = original.shape 29 | batch_size = original_shape[0] 30 | 31 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 32 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) 33 | 34 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) 35 | for _ in range(len(original_shape) - 1): 36 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) 37 | for _ in range(len(original_shape) - 1): 38 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) 39 | 40 | # Apply and Return Forward process equation 41 | return (sqrt_alpha_cum_prod.to(original.device) * original 42 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) 43 | 44 | def sample_prev_timestep(self, xt, pred, t): 45 | r""" 46 | Use the noise prediction by model to get 47 | xt-1 using xt and the noise predicted 48 | :param xt: current timestep sample 49 | :param pred: model noise prediction 50 | :param t: current timestep we are at 51 | :return: 52 | """ 53 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * pred)) / 54 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) 55 | x0 = torch.clamp(x0, -1., 1.) 56 | 57 | mean = xt - ((self.betas.to(xt.device)[t]) * pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) 58 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) 59 | 60 | if t == 0: 61 | return mean, x0 62 | else: 63 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) 64 | variance = variance * self.betas.to(xt.device)[t] 65 | sigma = variance ** 0.5 66 | z = torch.randn(xt.shape).to(xt.device) 67 | return mean + sigma * z, x0 68 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/tools/__init__.py -------------------------------------------------------------------------------- /tools/sample_vae_ditv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import yaml 5 | import os 6 | from torchvision.utils import make_grid 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import torchvision.transforms.v2 as v2 10 | from model.vae import VAE 11 | from model.transformer import DITVideo 12 | from scheduler.linear_scheduler import LinearNoiseScheduler 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | if torch.backends.mps.is_available(): 16 | device = torch.device('mps') 17 | print('Using mps') 18 | 19 | 20 | def sample(model, scheduler, train_config, ditv_model_config, 21 | autoencoder_model_config, diffusion_config, dataset_config, vae): 22 | r""" 23 | Sample stepwise by going backward one timestep at a time. 
24 | We save the x0 predictions 25 | """ 26 | latent_frame_height = (dataset_config['frame_height'] 27 | // 2 ** sum(autoencoder_model_config['down_sample'])) 28 | latent_frame_width = (dataset_config['frame_width'] 29 | // 2 ** sum(autoencoder_model_config['down_sample'])) 30 | 31 | xt = torch.randn((1, dataset_config['num_frames'], 32 | autoencoder_model_config['z_channels'], 33 | latent_frame_height, latent_frame_width)).to(device) 34 | 35 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))): 36 | # Get prediction of noise 37 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device)) 38 | 39 | # Use scheduler to get x0 and xt-1 40 | xt, x0_pred = scheduler.sample_prev_timestep(xt, 41 | noise_pred, 42 | torch.as_tensor(i).to(device)) 43 | 44 | # Save x0 45 | if i == 0: 46 | # Decode ONLY the final video to save time 47 | ims = vae.to(device).decode(xt[0]) 48 | else: 49 | ims = xt 50 | 51 | ims = torch.clamp(ims, -1., 1.).detach().cpu() 52 | ims = (ims + 1) / 2 53 | tv_frames = ims * 255 54 | 55 | if i == 0: 56 | tv_frames = tv_frames.permute((0, 2, 3, 1)) 57 | if tv_frames.shape[-1] == 1: 58 | tv_frames = tv_frames.repeat((1, 1, 1, 3)) 59 | else: 60 | tv_frames = v2.Compose([ 61 | v2.Resize((dataset_config['frame_height'], 62 | dataset_config['frame_width']), 63 | interpolation=v2.InterpolationMode.NEAREST), 64 | ])(tv_frames) 65 | tv_frames = tv_frames[0].permute((0, 2, 3, 1))[:, :, :, :3] 66 | 67 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')): 68 | os.mkdir(os.path.join(train_config['task_name'], 'samples')) 69 | torchvision.io.write_video(os.path.join(train_config['task_name'], 70 | 'samples/sample_output_{}.mp4'.format(i)), 71 | tv_frames, 72 | fps=8) 73 | 74 | 75 | def infer(args): 76 | # Read the config file # 77 | with open(args.config_path, 'r') as file: 78 | try: 79 | config = yaml.safe_load(file) 80 | except yaml.YAMLError as exc: 81 | print(exc) 82 | print(config) 83 | ######################## 84 | 85 | diffusion_config = config['diffusion_params'] 86 | dataset_config = config['dataset_params'] 87 | ditv_model_config = config['ditv_params'] 88 | autoencoder_model_config = config['autoencoder_params'] 89 | train_config = config['train_params'] 90 | 91 | # Create the noise scheduler 92 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 93 | beta_start=diffusion_config['beta_start'], 94 | beta_end=diffusion_config['beta_end']) 95 | 96 | # Get latent image size 97 | frame_height = (dataset_config['frame_height'] 98 | // 2 ** sum(autoencoder_model_config['down_sample'])) 99 | frame_width = (dataset_config['frame_width'] 100 | // 2 ** sum(autoencoder_model_config['down_sample'])) 101 | num_frames = dataset_config['num_frames'] 102 | model = DITVideo(frame_height=frame_height, 103 | frame_width=frame_width, 104 | im_channels=autoencoder_model_config['z_channels'], 105 | num_frames=num_frames, 106 | config=ditv_model_config).to(device) 107 | model.eval() 108 | 109 | assert os.path.exists(os.path.join(train_config['task_name'], 110 | train_config['ditv_ckpt_name'])), \ 111 | "Train DiT Video Model first" 112 | 113 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 114 | train_config['ditv_ckpt_name']), 115 | map_location=device)) 116 | print('Loaded dit video checkpoint') 117 | 118 | # Create output directories 119 | if not os.path.exists(train_config['task_name']): 120 | os.mkdir(train_config['task_name']) 121 | 122 | vae = VAE(im_channels=dataset_config['frame_channels'], 
123 | model_config=autoencoder_model_config) 124 | vae.eval() 125 | 126 | # Load vae if found 127 | assert os.path.exists(os.path.join(train_config['task_name'], 128 | train_config['vae_autoencoder_ckpt_name'])), \ 129 | "VAE checkpoint not present. Train VAE first." 130 | vae.load_state_dict(torch.load( 131 | os.path.join(train_config['task_name'], 132 | train_config['vae_autoencoder_ckpt_name']), 133 | map_location=device), strict=True) 134 | print('Loaded vae checkpoint') 135 | 136 | with torch.no_grad(): 137 | sample(model, scheduler, train_config, ditv_model_config, 138 | autoencoder_model_config, diffusion_config, dataset_config, vae) 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser(description='Arguments for latte video generation') 143 | parser.add_argument('--config', dest='config_path', 144 | default='config/mnist.yaml', type=str) 145 | args = parser.parse_args() 146 | infer(args) 147 | -------------------------------------------------------------------------------- /tools/save_latents.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pickle 5 | 6 | import torch 7 | import torchvision 8 | import yaml 9 | 10 | 11 | from dataset.video_dataset import VideoDataset 12 | 13 | from torch.utils.data.dataloader import DataLoader 14 | import torchvision.transforms.v2 15 | from torchvision.utils import make_grid 16 | from tqdm import tqdm 17 | 18 | 19 | from model.vae import VAE 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | if torch.backends.mps.is_available(): 23 | device = torch.device('mps') 24 | print('Using mps') 25 | 26 | 27 | def save_vae_latents(args): 28 | ######## Read the config file ####### 29 | with open(args.config_path, 'r') as file: 30 | try: 31 | config = yaml.safe_load(file) 32 | except yaml.YAMLError as exc: 33 | print(exc) 34 | print(config) 35 | 36 | dataset_config = config['dataset_params'] 37 | autoencoder_config = config['autoencoder_params'] 38 | train_config = config['train_params'] 39 | 40 | model = VAE(im_channels=dataset_config['frame_channels'], 41 | model_config=autoencoder_config).to(device) 42 | model.load_state_dict(torch.load(os.path.join( 43 | train_config['task_name'], 44 | train_config['vae_autoencoder_ckpt_name']), 45 | map_location=device)) 46 | model.eval() 47 | 48 | dataset = VideoDataset('train', 49 | dataset_config) 50 | 51 | print('Will be generating latents for {} videos'.format(len(dataset.video_paths))) 52 | with torch.no_grad(): 53 | if not os.path.exists(os.path.join(train_config['task_name'], 54 | train_config['save_video_latent_dir'])): 55 | os.mkdir(os.path.join(train_config['task_name'], 56 | train_config['save_video_latent_dir'])) 57 | for path in tqdm(dataset.video_paths): 58 | # Read the video 59 | frames, _, _ = torchvision.io.read_video(filename=path, 60 | pts_unit='sec', 61 | output_format='TCHW') 62 | 63 | # Transform all frames 64 | frames_tensor = dataset.transforms(frames) 65 | if dataset_config['frame_channels'] == 1: 66 | frames_tensor = frames_tensor[:, 0:1, :, :] 67 | 68 | encoded_outputs = [] 69 | for frame_tensor in frames_tensor: 70 | _, encoded_output = model.encode( 71 | frame_tensor.float().unsqueeze(0).to(device) 72 | ) 73 | encoded_outputs.append(encoded_output) 74 | encoded_outputs = torch.cat(encoded_outputs, dim=0) 75 | pickle.dump(encoded_outputs, open( 76
| os.path.join(train_config['task_name'], 77 | train_config['save_video_latent_dir'], 78 | '{}.pkl'.format(os.path.basename(path))), 'wb')) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser(description='Arguments for vae inference and ' 83 | 'saving latents') 84 | parser.add_argument('--config', dest='config_path', 85 | default='config/ucf.yaml', type=str) 86 | args = parser.parse_args() 87 | save_vae_latents(args) 88 | -------------------------------------------------------------------------------- /tools/train_vae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import torchvision 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from model.vae import VAE 10 | from model.lpips import LPIPS 11 | from model.discriminator import Discriminator 12 | from torch.utils.data.dataloader import DataLoader 13 | from torch.optim import Adam 14 | from torchvision.utils import make_grid 15 | from dataset.image_dataset import ImageDataset 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | if torch.backends.mps.is_available(): 19 | device = torch.device('mps') 20 | print('Using mps') 21 | 22 | 23 | def train(args): 24 | # Read the config file # 25 | with open(args.config_path, 'r') as file: 26 | try: 27 | config = yaml.safe_load(file) 28 | except yaml.YAMLError as exc: 29 | print(exc) 30 | print(config) 31 | 32 | dataset_config = config['dataset_params'] 33 | autoencoder_config = config['autoencoder_params'] 34 | train_config = config['train_params'] 35 | 36 | # Set the desired seed value # 37 | seed = train_config['seed'] 38 | torch.manual_seed(seed) 39 | np.random.seed(seed) 40 | random.seed(seed) 41 | if device == 'cuda': 42 | torch.cuda.manual_seed_all(seed) 43 | ############################# 44 | 45 | # Create the model and dataset # 46 | model = VAE(im_channels=dataset_config['frame_channels'], 47 | model_config=autoencoder_config).to(device) 48 | 49 | # Create the dataset 50 | im_dataset = ImageDataset(split='train', 51 | dataset_config=dataset_config, 52 | task_name=train_config['task_name']) 53 | 54 | data_loader = DataLoader(im_dataset, 55 | batch_size=train_config['autoencoder_batch_size'], 56 | shuffle=True) 57 | 58 | # Create output directories 59 | if not os.path.exists(train_config['task_name']): 60 | os.mkdir(train_config['task_name']) 61 | 62 | num_epochs = train_config['autoencoder_epochs'] 63 | 64 | # L1/L2 loss for Reconstruction 65 | recon_criterion = torch.nn.MSELoss() 66 | # Disc Loss can even be BCEWithLogits 67 | disc_criterion = torch.nn.MSELoss() 68 | 69 | # No need to freeze lpips as lpips.py takes care of that 70 | lpips_model = LPIPS().eval().to(device) 71 | discriminator = Discriminator(im_channels=dataset_config['frame_channels']).to(device) 72 | 73 | if os.path.exists(os.path.join(train_config['task_name'], 74 | train_config['vae_autoencoder_ckpt_name'])): 75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 76 | train_config['vae_autoencoder_ckpt_name']), 77 | map_location=device)) 78 | print('Loaded autoencoder from checkpoint') 79 | 80 | if os.path.exists(os.path.join(train_config['task_name'], 81 | train_config['vae_discriminator_ckpt_name'])): 82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'], 83 | train_config['vae_discriminator_ckpt_name']), 84 | map_location=device)) 85 | print('Loaded discriminator from checkpoint') 
86 | 87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999)) 89 | 90 | disc_step_start = train_config['disc_start'] 91 | step_count = 0 92 | 93 | # This is for accumulating gradients incase the images are huge 94 | # And one cant afford higher batch sizes 95 | acc_steps = train_config['autoencoder_acc_steps'] 96 | image_save_steps = train_config['autoencoder_img_save_steps'] 97 | img_save_count = 0 98 | 99 | for epoch_idx in range(num_epochs): 100 | recon_losses = [] 101 | perceptual_losses = [] 102 | disc_losses = [] 103 | gen_losses = [] 104 | losses = [] 105 | 106 | optimizer_g.zero_grad() 107 | optimizer_d.zero_grad() 108 | 109 | for im in tqdm(data_loader): 110 | step_count += 1 111 | im = im.float().to(device) 112 | 113 | # Fetch autoencoders output(reconstructions) 114 | model_output = model(im) 115 | output, encoder_output = model_output 116 | 117 | # Image Saving Logic 118 | if step_count % image_save_steps == 0 or step_count == 1: 119 | sample_size = min(8, im.shape[0]) 120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu() 121 | save_output = ((save_output + 1) / 2) 122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu() 123 | 124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) 125 | img = torchvision.transforms.ToPILImage()(grid) 126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')): 127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')) 128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples', 129 | 'current_autoencoder_sample_{}.png'.format(img_save_count))) 130 | img_save_count += 1 131 | img.close() 132 | 133 | ######### Optimize Generator ########## 134 | # L2 Loss 135 | recon_loss = recon_criterion(output, im) 136 | recon_losses.append(recon_loss.item()) 137 | recon_loss = recon_loss / acc_steps 138 | 139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1) 140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3])) 141 | 142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps) 143 | 144 | # Adversarial loss only if disc_step_start steps passed 145 | if step_count > disc_step_start: 146 | disc_fake_pred = discriminator(model_output[0]) 147 | disc_fake_loss = disc_criterion(disc_fake_pred, 148 | torch.ones(disc_fake_pred.shape, 149 | device=disc_fake_pred.device)) 150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item()) 151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps 152 | lpips_loss = torch.mean(lpips_model(output, im)) 153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item()) 154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps 155 | losses.append(g_loss.item()) 156 | g_loss.backward() 157 | ##################################### 158 | 159 | ######### Optimize Discriminator ####### 160 | if step_count > disc_step_start: 161 | fake = output 162 | disc_fake_pred = discriminator(fake.detach()) 163 | disc_real_pred = discriminator(im) 164 | disc_fake_loss = disc_criterion(disc_fake_pred, 165 | torch.zeros(disc_fake_pred.shape, 166 | device=disc_fake_pred.device)) 167 | disc_real_loss = disc_criterion(disc_real_pred, 168 | torch.ones(disc_real_pred.shape, 169 | device=disc_real_pred.device)) 170 | disc_loss = 
train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2 171 | disc_losses.append(disc_loss.item()) 172 | disc_loss = disc_loss / acc_steps 173 | disc_loss.backward() 174 | if step_count % acc_steps == 0: 175 | optimizer_d.step() 176 | optimizer_d.zero_grad() 177 | ##################################### 178 | 179 | if step_count % acc_steps == 0: 180 | optimizer_g.step() 181 | optimizer_g.zero_grad() 182 | optimizer_d.step() 183 | optimizer_d.zero_grad() 184 | optimizer_g.step() 185 | optimizer_g.zero_grad() 186 | if len(disc_losses) > 0: 187 | print( 188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | ' 189 | 'G Loss : {:.4f} | D Loss {:.4f}'. 190 | format(epoch_idx + 1, 191 | np.mean(recon_losses), 192 | np.mean(perceptual_losses), 193 | np.mean(gen_losses), 194 | np.mean(disc_losses))) 195 | else: 196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'. 197 | format(epoch_idx + 1, 198 | np.mean(recon_losses), 199 | np.mean(perceptual_losses))) 200 | 201 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 202 | train_config['vae_autoencoder_ckpt_name'])) 203 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], 204 | train_config['vae_discriminator_ckpt_name'])) 205 | print('Done Training...') 206 | 207 | 208 | if __name__ == '__main__': 209 | parser = argparse.ArgumentParser(description='Arguments for vae training') 210 | parser.add_argument('--config', dest='config_path', 211 | default='config/ucf.yaml', type=str) 212 | args = parser.parse_args() 213 | train(args) 214 | -------------------------------------------------------------------------------- /tools/train_vae_ditv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import argparse 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from einops import rearrange 8 | from torch.optim import AdamW 9 | from torch.utils.data import DataLoader 10 | from dataset.video_dataset import VideoDataset 11 | from model.transformer import DITVideo 12 | from model.vae import VAE 13 | from scheduler.linear_scheduler import LinearNoiseScheduler 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | if torch.backends.mps.is_available(): 17 | device = torch.device('mps') 18 | print('Using mps') 19 | 20 | 21 | def train(args): 22 | # Read the config file # 23 | with open(args.config_path, 'r') as file: 24 | try: 25 | config = yaml.safe_load(file) 26 | except yaml.YAMLError as exc: 27 | print(exc) 28 | print(config) 29 | ######################## 30 | 31 | diffusion_config = config['diffusion_params'] 32 | dataset_config = config['dataset_params'] 33 | ditv_model_config = config['ditv_params'] 34 | autoencoder_model_config = config['autoencoder_params'] 35 | train_config = config['train_params'] 36 | 37 | # Create the noise scheduler 38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], 39 | beta_start=diffusion_config['beta_start'], 40 | beta_end=diffusion_config['beta_end']) 41 | 42 | dataset = VideoDataset('train', 43 | dataset_config=dataset_config, 44 | latent_path=os.path.join( 45 | train_config['task_name'], 46 | train_config['save_video_latent_dir'])) 47 | 48 | data_loader = DataLoader(dataset, 49 | batch_size=train_config['ditv_batch_size'], 50 | shuffle=True) 51 | 52 | # Instantiate the model 53 | frame_height = (dataset_config['frame_height'] // 54 | 2 ** 
sum(autoencoder_model_config['down_sample'])) 55 | frame_width = (dataset_config['frame_width'] // 56 | 2 ** sum(autoencoder_model_config['down_sample'])) 57 | num_frames = dataset_config['num_frames'] 58 | model = DITVideo(frame_height=frame_height, 59 | frame_width=frame_width, 60 | im_channels=autoencoder_model_config['z_channels'], 61 | num_frames=num_frames, 62 | config=ditv_model_config).to(device) 63 | model.train() 64 | 65 | if os.path.exists(os.path.join(train_config['task_name'], 66 | train_config['ditv_ckpt_name'])): 67 | print('Loaded DiT Video checkpoint') 68 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'], 69 | train_config['ditv_ckpt_name']), 70 | map_location=device)) 71 | 72 | # Load VAE 73 | if not dataset.use_latents: 74 | print('Loading vae model as latents not present') 75 | vae = VAE(im_channels=dataset_config['frame_channels'], 76 | model_config=autoencoder_model_config).to(device) 77 | vae.eval() 78 | # Load vae if found 79 | assert os.path.exists(os.path.join(train_config['task_name'], 80 | train_config['vae_autoencoder_ckpt_name'])), \ 81 | "VAE checkpoint not found" 82 | vae.load_state_dict(torch.load(os.path.join( 83 | train_config['task_name'], 84 | train_config['vae_autoencoder_ckpt_name']), 85 | map_location=device)) 86 | print('Loaded vae checkpoint') 87 | for param in vae.parameters(): 88 | param.requires_grad = False 89 | 90 | # Specify training parameters 91 | num_epochs = train_config['ditv_epochs'] 92 | optimizer = AdamW(model.parameters(), lr=1E-4, weight_decay=0) 93 | criterion = torch.nn.MSELoss() 94 | 95 | acc_steps = train_config['ditv_acc_steps'] 96 | for epoch_idx in range(num_epochs): 97 | losses = [] 98 | step_count = 0 99 | for ims in tqdm(data_loader): 100 | step_count += 1 101 | ims = ims.float().to(device) 102 | B, F, C, H, W = ims.shape 103 | if not dataset.use_latents: 104 | with torch.no_grad(): 105 | ims, _ = vae.encode(ims.reshape(-1, C, H, W)) 106 | ims = rearrange(ims, '(b f) c h w -> b f c h w', b=B, f=F) 107 | 108 | # Sample random noise 109 | noise = torch.randn_like(ims).to(device) 110 | 111 | # Sample timestep 112 | t = torch.randint(0, diffusion_config['num_timesteps'], 113 | (ims.shape[0],)).to(device) 114 | 115 | # Add noise to video according to timestep 116 | noisy_im = scheduler.add_noise(ims, noise, t) 117 | pred = model(noisy_im, t, num_images=dataset_config['num_images_train']) 118 | loss = criterion(pred, noise) 119 | losses.append(loss.item()) 120 | loss = loss / acc_steps 121 | loss.backward() 122 | if step_count % acc_steps == 0: 123 | optimizer.step() 124 | optimizer.zero_grad() 125 | optimizer.step() 126 | optimizer.zero_grad() 127 | print('Finished epoch:{} | Loss : {:.4f}'.format( 128 | epoch_idx + 1, 129 | np.mean(losses))) 130 | torch.save(model.state_dict(), os.path.join(train_config['task_name'], 131 | train_config['ditv_ckpt_name'])) 132 | 133 | print('Done Training ...') 134 | 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser(description='Arguments for latte training') 138 | parser.add_argument('--config', dest='config_path', 139 | default='config/mnist.yaml', type=str) 140 | args = parser.parse_args() 141 | train(args) 142 | --------------------------------------------------------------------------------
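Both `tools/train_vae_ditv.py` and `tools/sample_vae_ditv.py` above rely on the `LinearNoiseScheduler` from `scheduler/linear_scheduler.py` for the DDPM forward and reverse steps. Below is a minimal sketch that checks `add_noise` against the closed-form forward process q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. It assumes the repository root is on `PYTHONPATH`; the timestep count, beta range and latent shape are placeholder values, since the real ones come from `diffusion_params` and `autoencoder_params` in `config/mnist.yaml` / `config/ucf.yaml`.

```
import torch

from scheduler.linear_scheduler import LinearNoiseScheduler

# Placeholder diffusion hyperparameters; the actual values live under
# diffusion_params in config/mnist.yaml or config/ucf.yaml.
scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=1e-4, beta_end=0.02)

# Toy batch of video latents: (B, num_frames, z_channels, latent_h, latent_w)
x0 = torch.randn(2, 16, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (x0.shape[0],))

xt = scheduler.add_noise(x0, noise, t)

# add_noise should equal the closed-form q(x_t | x_0):
# sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
alpha_bar = torch.cumprod(1. - torch.linspace(1e-4, 0.02, 1000), dim=0)[t]
alpha_bar = alpha_bar.view(-1, 1, 1, 1, 1)
expected = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise
assert torch.allclose(xt, expected, atol=1e-5)
print(xt.shape)  # torch.Size([2, 16, 4, 8, 8])
```

Note that `add_noise` unsqueezes the per-sample sqrt(alpha_bar_t) once per trailing dimension of the input, so the same scheduler works unchanged for image latents shaped (B, C, H, W) and for the video latents shaped (B, F, C, H, W) used during DiT training.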