├── .gitignore
├── LICENSE
├── README.md
├── capture_frames.py
├── config
│   ├── mnist.yaml
│   └── ucf.yaml
├── dataset
│   ├── __init__.py
│   ├── image_dataset.py
│   ├── ucf_filter.txt
│   └── video_dataset.py
├── model
│   ├── __init__.py
│   ├── attention.py
│   ├── blocks.py
│   ├── discriminator.py
│   ├── lpips.py
│   ├── patch_embed.py
│   ├── transformer.py
│   ├── transformer_layer.py
│   ├── vae.py
│   └── weights
│       └── v0.1
│           └── .gitkeep
├── requirements.txt
├── scheduler
│   ├── __init__.py
│   └── linear_scheduler.py
└── tools
├── __init__.py
├── sample_vae_ditv.py
├── save_latents.py
├── train_vae.py
└── train_vae_ditv.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 |
16 | # Ignore checkpoints
17 | *.pth
18 |
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ExplainingAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Video Generation using Diffusion Transformers in PyTorch
2 | ========
3 |
4 | ## Building Video Generation Model Tutorial
5 |
6 |
8 |
9 |
10 |
11 | ## Sample Output for Latte on moving mnist easy videos
12 | Trained for 300 epochs
13 |
14 | 
15 | 
16 | 
17 | 
18 |
19 | ## Sample Output for Latte on UCF101 videos
20 | Trained for 500 epochs (needs more training)
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | ___
30 | This repository implements the paper Latte: Latent Diffusion Transformer for Video Generation. It provides code for the following:
31 | * Training and inference of VAE on Moving Mnist and UCF101 frames
32 | * Training and Inference of Latte Video Model using trained VAE on 16 frame video clips of both datasets
33 | * Configurable code for training all models from Latte-S to Latte-XL
34 |
35 | This repo closely follows the [official Latte implementation](https://github.com/Vchitect/Latte), except for the following changes:
36 | * Current code is for unconditional generation
37 | * Variance is fixed during training and not learned (like original DDPM)
38 | * No EMA
39 | * Ability to train VAE
40 | * Ability to save latents of video frames for faster training
41 |
42 |
43 | ## Setup
44 | * Create a new conda environment with Python 3.10, then run the commands below
45 | * `conda activate <environment_name>`
46 | * ```git clone https://github.com/explainingai-code/VideoGeneration-PyTorch.git```
47 | * ```cd VideoGeneration-PyTorch```
48 | * ```pip install -r requirements.txt```
49 | * Download the LPIPS weights by opening this link in a browser (don't use wget) https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and downloading the raw file. Place the downloaded weights file at ```model/weights/v0.1/vgg.pth```
50 | ___
51 |
52 | ## Data Preparation
53 |
54 | ### Moving Mnist Easy Videos
55 | For moving mnist, I have used the easy category of videos, which have one digit moving across frames.
56 | Download the videos from [this](https://www.kaggle.com/datasets/yichengs/captioned-moving-mnist-dataset-easy-version/) page. This dataset includes captions as well, so one can also try text-to-video models using it.
57 | Create a `data` directory in the repo root and add the downloaded `mmnist-easy` folder there.
58 | Ensure directory structure is the following
59 | ```
60 | VideoGeneration-PyTorch
61 | -> data
62 | -> mmnist-easy
63 | -> *.mp4
64 | ```
65 |
66 | For setting up UCF101, simply download the videos from the official page [here](https://www.crcv.ucf.edu/data/UCF101.php)
67 | and add them to the `data` directory.
68 | Ensure directory structure is the following
69 | ```
70 | VideoGeneration-PyTorch
71 | -> data
72 | -> UCF101
73 | -> *.avi
74 |
75 | ```
76 | ---
77 | ## Configuration
78 | Allows you to play with different components of Latte and the autoencoder
79 | * ```config/mnist.yaml``` - Configuration used for moving mnist dataset
80 | * ```config/ucf.yaml``` - Configuration used for ucf dataset
81 | Important configuration parameters
82 | * `autoencoder_acc_steps` : For accumulating gradients if the video size is too large and a large batch size can't be used (a rough sketch of gradient accumulation is shown below).
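
As a rough illustration of what gradient accumulation does, here is a minimal, self-contained sketch with toy stand-ins (`nn.Linear`, random data). It is not the actual training loop in `tools/train_vae.py`; `acc_steps` simply plays the role of `autoencoder_acc_steps`.
```
import torch
import torch.nn as nn

# Toy stand-ins: in the repo these would be the VAE, its losses and the frame dataloader
model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loader = [torch.randn(4, 8) for _ in range(8)]
acc_steps = 4  # plays the role of autoencoder_acc_steps

optimizer.zero_grad()
for step, batch in enumerate(loader):
    loss = nn.functional.mse_loss(model(batch), batch) / acc_steps  # scale so accumulated grads average out
    loss.backward()                                                 # gradients add up across mini-batches
    if (step + 1) % acc_steps == 0:
        optimizer.step()       # one optimizer update every acc_steps mini-batches
        optimizer.zero_grad()
```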
83 |
84 | ___
85 | ## Training
86 | The repo provides training and inference for Moving Mnist (Unconditional Latte Model)
87 |
88 | For working on your own dataset:
89 | * Create your own config and ensure the following config parameters are correctly set (a sample `dataset_params` sketch follows this list)
90 | * `im_path` - Folder name (created inside `task_name`) where the extracted video frames will be saved when the frame extraction script is run
91 | * `video_path` - Path to the videos
92 | * `video_ext` - Extension of the videos. All videos are assumed to have the same extension
93 | * `frame_height`, `frame_width`, `frame_channels` - Dimensions to which the frames will be resized
94 | * `centre_square_crop` - Whether or not center cropping should be applied to the frames
95 | * `video_filter_path` - null / location of a txt file containing the video names that should be used for training. If `null` then all videos will be used (as for moving mnist). To see how to construct this filter file, look at the `dataset/ucf_filter.txt` file
96 | * The existing `video_dataset.py` should require minimal modifications to adapt to your dataset requirements
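
As a rough example, the `dataset_params` section of a custom config could look like the sketch below. The values are placeholders for a hypothetical dataset; see `config/mnist.yaml` and `config/ucf.yaml` for complete working configs.
```
dataset_params:
  im_path: 'my-dataset-images'      # frames extracted by capture_frames end up here (inside task_name)
  video_path: 'data/my-dataset'
  video_ext: 'mp4'
  num_images_train: 8
  frame_height: 128
  frame_width: 128
  frame_channels: 3
  num_frames: 16
  frame_interval: 2
  centre_square_crop: True
  video_filter_path: null           # or a txt file like dataset/ucf_filter.txt
  ae_frame_sample_prob: 0.1
```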
97 |
98 | Once the config and dataset are set up (a typical end-to-end command sequence is sketched after this list):
99 | * First train the autoencoder on your dataset following [this section](#training-autoencoder-for-latte)
100 | * For training and inference of Unconditional Latte follow [this section](#training-unconditional-latte)
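
For the moving mnist config, the typical end-to-end sequence then looks roughly like this (each step is described in detail in the sections below):
```
python -m capture_frames --config config/mnist.yaml         # extract frames for autoencoder training
python -m tools.train_vae --config config/mnist.yaml        # train the VAE on the extracted frames
python -m tools.save_latents --config config/mnist.yaml     # (optional) cache video latents for faster training
python -m tools.train_vae_ditv --config config/mnist.yaml   # train unconditional Latte
python -m tools.sample_vae_ditv --config config/mnist.yaml  # generate videos with the trained model
```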
101 |
102 | ## Training AutoEncoder for Latte
103 | * We first need to extract frames for training the autoencoder
104 | * By default, we extract only 10% of the frames from our videos. Change `ae_frame_sample_prob` in `dataset_params` of the config if you want to train on a larger number of frames. For both moving mnist and ucf, 10% works fine.
105 | * If you need to train the Latte model on a subset of videos then ensure `video_filter_path` is correctly set. Look at the ucf config and `dataset/ucf_filter.txt` for guidance.
106 | * Extract the frames by running `python -m capture_frames --config config/mnist.yaml` with the right config
107 | * For training the autoencoder
108 | * Minimal modifications might be needed to `dataset/image_dataset.py` to adapt it to your own dataset.
109 | * Make sure the `frame_height` and `frame_width` parameters are correctly set for resizing frames (if needed)
110 | * Run ```python -m tools.train_vae --config config/mnist.yaml``` for training the autoencoder with the right config file
111 | * In case your GPU memory is limited, I would suggest running `python -m tools.save_latents --config config/mnist.yaml` with the correct config. This script saves the latent frames for all your dataset videos; otherwise, the VAE would also have to be loaded during diffusion model training. A minimal sketch of how these saved latents are later consumed is shown below.
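
Below is a minimal sketch of how such a saved latent is sampled from, mirroring what `dataset/video_dataset.py` does. It assumes the mnist config values (`z_channels: 4`, two downsampling stages, so a 72x128 frame would give an 18x32 latent) and uses a random tensor in place of a real pickled latent.
```
import torch

# A real latent would be loaded from a .pkl file saved by tools.save_latents;
# here we fake one with 16 frames purely for illustration.
latent = torch.randn(16, 2 * 4, 18, 32)              # (frames, 2 * z_channels, latent_h, latent_w)
mean, logvar = torch.chunk(latent, 2, dim=1)         # first half is the mean, second half the log-variance
std = torch.exp(0.5 * logvar)
frames_latent = mean + std * torch.randn_like(mean)  # reparameterized sample fed to the Latte model
print(frames_latent.shape)                           # torch.Size([16, 4, 18, 32])
```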
112 |
113 | ## Training Unconditional Latte
114 | Train the autoencoder first and make changes to the config and `video_dataset.py` (if needed) to adapt it to your requirements.
115 |
116 | * ```python -m tools.train_vae_ditv --config config/mnist.yaml``` for training unconditional Latte using the right config
117 | * ```python -m tools.sample_vae_ditv --config config/mnist.yaml``` for generating videos using the trained Latte model
118 |
119 |
120 | ## Output
121 | Outputs will be saved according to the configuration present in the yaml files.
122 |
123 | For every run, a folder named after the ```task_name``` key in the config will be created.
124 |
125 | During frame extraction, a folder named after the `im_path` key will be created inside the `task_name` directory and the extracted frames will be saved there.
126 |
127 | During training of the autoencoder, the following outputs will be saved
128 | * Latest autoencoder and discriminator checkpoints in the ```task_name``` directory
129 | * Sample reconstructions in ```task_name/vae_autoencoder_samples```
130 |
131 | If the `save_latents` script is run, then latents will be saved in ```task_name/save_video_latent_dir```, if it is mentioned in the config
132 |
133 | During training and inference of unconditional Latte, the following outputs will be saved:
134 | * During training we will save the latest checkpoint in ```task_name``` directory
135 | * During sampling, the unconditionally sampled video generated for all timesteps will be saved in ```task_name/samples/*.mp4```. The final decoded generated video will be `sample_video_0.mp4`. Videos from `sample_output_999.mp4` to `sample_output_1.mp4` will be the latent video predictions of the denoising process from T=999 to T=1, with the final generated video at T=0. A rough sketch of the resulting output layout is shown below.
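
Putting it together for the mnist config, the `task_name` directory would end up looking roughly like this (the exact names depend on the config values, so treat this as an approximate sketch):
```
mmnist
-> mmnist-easy-images           (frames extracted by capture_frames)
-> vae_autoencoder_samples      (reconstruction samples saved during VAE training)
-> video_latents                (only if the save_latents script is run)
-> samples
   -> sample_video_0.mp4        (final decoded generated video)
   -> sample_output_999.mp4 ... sample_output_1.mp4   (latent predictions during denoising)
-> vae_autoencoder_ckpt.pth
-> vae_discriminator_ckpt.pth
-> dit_ckpt.pth
```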
136 |
137 |
138 |
139 | ## Citations
140 | ```
141 | @misc{ma2024lattelatentdiffusiontransformer,
142 | title={Latte: Latent Diffusion Transformer for Video Generation},
143 | author={Xin Ma and Yaohui Wang and Gengyun Jia and Xinyuan Chen and Ziwei Liu and Yuan-Fang Li and Cunjian Chen and Yu Qiao},
144 | year={2024},
145 | eprint={2401.03048},
146 | archivePrefix={arXiv},
147 | primaryClass={cs.CV},
148 | url={https://arxiv.org/abs/2401.03048},
149 | }
150 | ```
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/capture_frames.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import argparse
4 | import glob
5 | import yaml
6 | import os
7 | from tqdm import tqdm
8 |
9 |
10 | def extract_frames_from_video(video_path, im_path, frame_sample_prob, count):
11 | video_obj = cv2.VideoCapture(video_path)
12 |
13 | success = 1
14 | while success:
15 | success, image = video_obj.read()
16 | if not success:
17 | break
18 | if random.random() > frame_sample_prob:
19 | continue
20 |
21 | cv2.imwrite('{}/frame_{}.png'.format(im_path, count), image)
22 | count += 1
23 | return count
24 |
25 |
26 | def extract_frames(args):
27 | # Read the config file #
28 | with open(args.config_path, 'r') as file:
29 | try:
30 | config = yaml.safe_load(file)
31 | except yaml.YAMLError as exc:
32 | print(exc)
33 | print(config)
34 | ########################
35 |
36 | dataset_config = config['dataset_params']
37 | task_name = config['train_params']['task_name']
38 | video_path = dataset_config['video_path']
39 | assert os.path.exists(video_path), "video path {} does not exist".format(video_path)
40 |
41 |     im_path = os.path.join(task_name, dataset_config['im_path'])
42 |     # Create the frames directory (and the parent task_name directory) if not present
43 |     os.makedirs(im_path, exist_ok=True)
44 | filter_fpath = dataset_config['video_filter_path']
45 | video_ext = dataset_config['video_ext']
46 | frame_sample_prob = dataset_config['ae_frame_sample_prob']
47 |
51 |
52 | video_paths = []
53 | if filter_fpath is None:
54 | filters = ['*']
55 | else:
56 | filters = []
57 | assert os.path.exists(filter_fpath), "Filter file not present"
58 | with open(filter_fpath, 'r') as f:
59 | for line in f.readlines():
60 | filters.append(line.strip())
61 | for filter_i in filters:
62 | for fname in glob.glob(os.path.join(video_path, '{}.{}'.format(filter_i,
63 | video_ext))):
64 | video_paths.append(fname)
65 |
66 | print('Found {} videos'.format(len(video_paths)))
67 | print('Extracting frames....')
68 | count = 0
69 | for video_path in tqdm(video_paths):
70 | count = extract_frames_from_video(video_path, im_path, frame_sample_prob, count)
71 |
72 | print('Extracted total {} frames'.format(count))
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser(description='Arguments for frame extraction '
77 | 'for autoencoder training')
78 | parser.add_argument('--config', dest='config_path',
79 | default='config/mnist.yaml', type=str)
80 | args = parser.parse_args()
81 | extract_frames(args)
82 |
--------------------------------------------------------------------------------
/config/mnist.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'mmnist-easy-images'
3 | video_path: 'data/mmnist-easy'
4 | video_ext: 'mp4'
5 | num_images_train: 8
6 | frame_height : 72
7 | frame_width : 128
8 | frame_channels : 1
9 | num_frames: 16
10 | frame_interval: 2
11 | centre_square_crop: False
12 | video_filter_path: null
13 | ae_frame_sample_prob : 0.1
14 |
15 | diffusion_params:
16 | num_timesteps : 1000
17 | beta_start : 0.0001
18 | beta_end : 0.02
19 |
20 | ditv_params:
21 | patch_size : 2
22 | num_layers : 12
23 | hidden_size : 768
24 | num_heads : 12
25 | head_dim : 64
26 | timestep_emb_dim : 256
27 |
28 | autoencoder_params:
29 | z_channels: 4
30 | codebook_size : 20
31 | down_channels : [32, 64, 128]
32 | mid_channels : [128]
33 | down_sample : [True, True]
34 | attn_down : [False, False]
35 | norm_channels: 32
36 | num_heads: 16
37 | num_down_layers : 1
38 | num_mid_layers : 1
39 | num_up_layers : 1
40 |
41 |
42 | train_params:
43 | seed : 1111
44 | task_name: 'mmnist'
45 | autoencoder_batch_size: 64
46 | autoencoder_epochs: 20
47 | autoencoder_lr: 0.0001
48 | autoencoder_acc_steps: 1
49 | disc_start: 500
50 | disc_weight: 0.5
51 | codebook_weight: 1
52 | commitment_beta: 0.2
53 | perceptual_weight: 1
54 | kl_weight: 0.000005
55 | autoencoder_img_save_steps: 64
56 | save_latents: False
57 | ditv_batch_size: 4
58 | ditv_epochs: 300
59 | num_samples: 1
60 | ditv_lr: 0.0001
61 | ditv_acc_steps: 1
62 | save_video_latent_dir: 'video_latents'
63 | vae_latent_dir_name: 'vae_latents'
64 | ditv_ckpt_name: 'dit_ckpt.pth'
65 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
66 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
67 |
--------------------------------------------------------------------------------
/config/ucf.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | im_path: 'ucf-images'
3 | video_path: 'data/UCF101'
4 | video_ext: 'avi'
5 | num_images_train : 8
6 | frame_height: 256
7 | frame_width: 256
8 | frame_channels: 3
9 | num_frames: 16
10 | frame_interval: 3
11 | centre_square_crop : True
12 | video_filter_path: 'dataset/ucf_filter.txt'
13 | ae_frame_sample_prob : 0.1
14 |
15 | diffusion_params:
16 | num_timesteps : 1000
17 | beta_start : 0.0001
18 | beta_end : 0.02
19 |
20 | ditv_params:
21 | patch_size : 2
22 | num_layers : 12
23 | hidden_size : 768
24 | num_heads : 12
25 | head_dim : 64
26 | timestep_emb_dim : 256
27 |
28 | autoencoder_params:
29 | z_channels: 4
30 | codebook_size : 8192
31 | down_channels : [128, 256, 384, 512]
32 | mid_channels : [512]
33 | down_sample : [True, True, True]
34 | attn_down : [False, False, False]
35 | norm_channels: 32
36 | num_heads: 4
37 | num_down_layers : 2
38 | num_mid_layers : 2
39 | num_up_layers : 2
40 |
41 |
42 | train_params:
43 | seed : 1111
44 | task_name: 'ucf'
45 | autoencoder_batch_size: 4
46 | autoencoder_epochs: 30
47 | autoencoder_lr: 0.00001
48 | autoencoder_acc_steps: 1
49 | disc_start: 5000
50 | disc_weight: 0.5
51 | codebook_weight: 1
52 | commitment_beta: 0.2
53 | perceptual_weight: 1
54 | kl_weight: 0.000005
55 | autoencoder_img_save_steps: 64
56 | save_latents: False
57 | ditv_batch_size: 4
58 | ditv_epochs: 1000
59 | num_samples: 1
60 | ditv_lr: 0.0001
61 | ditv_acc_steps: 1
62 | save_video_latent_dir: 'video_latents'
63 | vae_latent_dir_name: 'vae_latents'
64 | ditv_ckpt_name: 'dit_ckpt.pth'
65 | vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
66 | vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
67 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/image_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import torchvision
4 | from PIL import Image
5 | from torch.utils.data.dataset import Dataset
6 |
7 |
8 | class ImageDataset(Dataset):
9 | r"""
10 | Simple Image Dataset class for training autoencoder on frames
11 | of the videos
12 | """
13 | def __init__(self, split, dataset_config, task_name, im_ext='png'):
14 | r"""
15 | Init method for initializing the dataset properties
16 | :param split: train/test to locate the image files
17 | :param dataset_config: config parameters for dataset(mnist/ucf)
18 | :param im_ext: image extension. assumes all
19 | images would be this type.
20 | """
21 | self.split = split
22 | self.im_ext = im_ext
23 | self.frame_height = dataset_config['frame_height']
24 | self.frame_width = dataset_config['frame_width']
25 | self.frame_channels = dataset_config['frame_channels']
26 | self.center_square_crop = dataset_config['centre_square_crop']
27 | self.images = self.load_images(os.path.join(task_name, dataset_config['im_path']))
28 |
29 | def load_images(self, im_path):
30 | r"""
31 | Gets all images from the path specified
32 | and stacks them all up
33 | :param im_path:
34 | :return:
35 | """
36 | assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
37 | ims = []
38 | for fname in glob.glob(os.path.join(im_path, '*.{}'.format(self.im_ext))):
39 | ims.append(fname)
40 | print('Found {} images for split {}'.format(len(ims), self.split))
41 | return ims
42 |
43 | def __len__(self):
44 | return len(self.images)
45 |
46 | def __getitem__(self, index):
47 | assert self.frame_channels in (1, 3), "Frame channels can only be 1/3"
48 | if self.frame_channels == 1:
49 | im = Image.open(self.images[index]).convert('L')
50 | else:
51 | im = Image.open(self.images[index]).convert('RGB')
52 | if self.center_square_crop:
53 | assert self.frame_height == self.frame_width, \
54 | ('For centre square crop frame_height '
55 | 'and frame_width should be same')
56 | im_tensor = torchvision.transforms.Compose([
57 | torchvision.transforms.Resize(self.frame_height),
58 | torchvision.transforms.CenterCrop(self.frame_height),
59 | torchvision.transforms.ToTensor(),
60 | ])(im)
61 | else:
62 | im_tensor = torchvision.transforms.Compose([
63 | torchvision.transforms.Resize((self.frame_height, self.frame_width)),
64 | torchvision.transforms.ToTensor(),
65 | ])(im)
66 |
67 | im.close()
68 | im_tensor = (2 * im_tensor) - 1
69 | return im_tensor
70 |
--------------------------------------------------------------------------------
/dataset/ucf_filter.txt:
--------------------------------------------------------------------------------
1 | *WallPushups*
2 | *_PushUps*
3 | *_PullUps*
4 | *_TaiChi*
--------------------------------------------------------------------------------
/dataset/video_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | import torch
5 | import numpy as np
6 | import random
7 | import torchvision
8 | import pickle
9 | import torchvision.transforms.v2
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils.data.dataset import Dataset
13 |
14 |
15 | class VideoDataset(Dataset):
16 | r"""
17 | Simple Video Dataset class for training diffusion transformer
18 | for video generation.
19 | If latents are present, the dataset uses the saved latents for the videos,
20 | else it reads the video and extracts frames from it.
21 | """
22 | def __init__(self, split, dataset_config, latent_path=None, im_ext='png'):
23 | r"""
24 | Initialize all parameters and also check
25 | if latents are present or not
26 | :param split: for now this is always train
27 | :param dataset_config: config parameters for dataset(mnist/ucf)
28 | :param latent_path: Path for saved latents
29 | :param im_ext: assumes all images are of this extension. Used only
30 | if latents are not present
31 | """
32 | self.split = split
33 | self.video_ext = dataset_config['video_ext']
34 | self.num_images = dataset_config['num_images_train']
35 | self.use_images = self.num_images > 0
36 | self.num_frames = dataset_config['num_frames']
37 | self.frame_interval = dataset_config['frame_interval']
38 | self.frame_height = dataset_config['frame_height']
39 | self.frame_width = dataset_config['frame_width']
40 | self.frame_channels = dataset_config['frame_channels']
41 | self.center_square_crop = dataset_config['centre_square_crop']
42 | self.filter_fpath = dataset_config['video_filter_path']
43 | if self.center_square_crop:
44 | assert self.frame_height == self.frame_width, \
45 | ('For centre square crop frame_height '
46 | 'and frame_width should be same')
47 | self.transforms = torchvision.transforms.v2.Compose([
48 | torchvision.transforms.v2.Resize(self.frame_height),
49 | torchvision.transforms.v2.CenterCrop(self.frame_height),
50 | torchvision.transforms.v2.ToPureTensor(),
51 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
52 | torchvision.transforms.v2.Normalize(mean=[0.5, 0.5, 0.5],
53 | std=[0.5, 0.5, 0.5])
54 | ])
55 | else:
56 | self.transforms = torchvision.transforms.v2.Compose([
57 | torchvision.transforms.v2.Resize((self.frame_height,
58 | self.frame_width)),
59 | torchvision.transforms.v2.ToPureTensor(),
60 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
61 | torchvision.transforms.v2.Normalize(mean=[0.5, 0.5, 0.5],
62 | std=[0.5, 0.5, 0.5])
63 | ])
64 |
65 | # Load video paths for this dataset
66 | self.video_paths = self.load_videos(dataset_config['video_path'],
67 | self.filter_fpath)
68 |
69 | # Validate if latents are present and if they are present for all
70 | # videos. And only upon validation set `use_latents` as True
71 | self.use_latents = False
72 | if latent_path is not None and os.path.exists(latent_path):
73 | num_latents = len(glob.glob(os.path.join(latent_path, '*.pkl')))
74 | self.latents = glob.glob(os.path.join(latent_path, '*.pkl'))
75 | if num_latents == len(self.video_paths):
76 | self.use_latents = True
77 | print('Will use latents')
78 |
79 | def load_videos(self, video_path, filter_fpath=None):
80 | r"""
81 | Method to load all video names for training.
82 | This uses the filter file to use only selective videos
83 | for training.
84 | :param video_path: Path for all videos in the dataset
85 | :param filter_fpath: Path for file containing filters for relevant videos
86 | :return:
87 | """
88 | assert os.path.exists(video_path), (
89 | "video path {} does not exist".format(video_path))
90 | video_paths = []
91 |
92 | if filter_fpath is None:
93 | filters = ['*']
94 | else:
95 | filters = []
96 | assert os.path.exists(filter_fpath), "Filter file not present"
97 | with open(filter_fpath, 'r') as f:
98 | for line in f.readlines():
99 | filters.append(line.strip())
100 | for filter in filters:
101 | for fname in glob.glob(os.path.join(video_path,
102 | '{}.{}'.format(filter,
103 | self.video_ext))):
104 | video_paths.append(fname)
105 | print('Found {} videos'.format(len(video_paths)))
106 | return video_paths
107 |
108 | def __len__(self):
109 | return len(self.video_paths)
110 |
111 | def __getitem__(self, index):
112 | # We do things differently whether we are working with latents or not
113 | if self.use_latents:
114 | # Load the latent corresponding to this item
115 | latent_path = self.latents[index]
116 | latent = pickle.load(open(latent_path, 'rb')).cpu()
117 |
118 | # Sample (self.frame_interval * self.num_frames) frames
119 | # and from that take num_frames(16) equally spaced frames
120 | # Keep only the latents of these sampled frames
121 | num_frames = len(latent)
122 | total_frames = self.frame_interval * self.num_frames
123 | max_end = max(0, num_frames - total_frames - 1)
124 | start_index = random.randint(0, max_end)
125 | end_index = min(start_index + total_frames, num_frames)
126 | frame_idxs = np.linspace(start_index, end_index - 1, self.num_frames,
127 | dtype=int)
128 | latent = latent[frame_idxs]
129 |
130 | # From the latent extract the mean and std
131 | # and reparametrization to sample according to this
132 | # mean and std
133 | mean, logvar = torch.chunk(latent, 2, dim=1)
134 | std = torch.exp(0.5 * logvar)
135 | frames_tensor = mean + std * torch.randn(mean.shape)
136 |
137 | # If we are doing image+video joint training
138 | # then sample random num_images(8) videos
139 | # and then sample a random latent frame from that video
140 | if self.use_images:
141 | im_latents = []
142 | for _ in range(self.num_images):
143 | # Sample a random video
144 | # From this video sample a random saved latent
145 | # Extract mean, std from this latent and then
146 | # use that to get a im_latent
147 | video_idx = random.randint(0, len(self.latents)-1)
148 | random_video_latent = pickle.load(open(self.latents[video_idx],
149 | 'rb')).cpu()
150 | frame_idx = random.randint(0, len(random_video_latent) - 1)
151 | im_latent = random_video_latent[frame_idx]
152 | mean, logvar = torch.chunk(im_latent, 2, dim=0)
153 | std = torch.exp(0.5 * logvar)
154 | im_latent = mean + std * torch.randn(mean.shape)
155 | im_latents.append(im_latent.unsqueeze(0))
156 |
157 | im_latents = torch.cat(im_latents)
158 |
159 | # Concat video latents and image latents together for training
160 | frames_tensor = torch.cat([frames_tensor, im_latents], dim=0)
161 | return frames_tensor
162 | else:
163 | # Read the video corresponding to this item
164 | path = self.video_paths[index]
165 | frames, _, _ = torchvision.io.read_video(filename=path,
166 | pts_unit='sec',
167 | output_format='TCHW')
168 |
169 | # Sample (self.frame_interval * self.num_frames) frames
170 | # and from that take num_frames(16) equally spaced frames
171 | # Keep only these sampled frames
172 | num_frames = len(frames)
173 | max_end = max(0, num_frames - (self.num_frames * self.frame_interval) - 1)
174 | start_index = random.randint(0, max_end)
175 | end_index = min(start_index + (self.num_frames * self.frame_interval),
176 | num_frames)
177 | frame_idxs = np.linspace(start_index, end_index - 1, self.num_frames,
178 | dtype=int)
179 | frames = frames[frame_idxs]
180 |
181 | # Resize frames according to transformations
182 | # desired based on config
183 | frames_tensor = self.transforms(frames)
184 |
185 | # For grayscale keep only the first channel
186 | if self.frame_channels == 1:
187 | frames_tensor = frames_tensor[:, 0:1, :, :]
188 |
189 | # If we are doing image+video joint training
190 | # then sample random num_images(8) videos
191 | # and then sample a random frame from that video
192 | if self.use_images:
193 | im_tensors = []
194 | for _ in range(self.num_images):
195 | # Sample a random video
196 | # From this video sample a random image frame
197 | video_idx = random.randint(0, len(self.video_paths)-1)
198 | path = self.video_paths[video_idx]
199 | ims, _, _ = torchvision.io.read_video(filename=path,
200 | pts_unit='sec',
201 | output_format='TCHW')
202 | im_idx = random.randint(0, len(ims) - 1)
203 | ims = ims[im_idx]
204 |
205 | # Resize this sampled image according to transformations
206 | # desired based on config
207 | im_tensor = self.transforms(ims)
208 |
209 | # For grayscale keep only the first channel
210 | if self.frame_channels == 1:
211 | im_tensor = im_tensor[0:1, :, :]
212 | im_tensors.append(im_tensor.unsqueeze(0))
213 |
214 | im_tensors = torch.cat(im_tensors)
215 | frames_tensor = torch.cat([frames_tensor, im_tensors], dim=0)
216 | return frames_tensor
217 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/model/__init__.py
--------------------------------------------------------------------------------
/model/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 |
6 | class Attention(nn.Module):
7 | r"""
8 | Attention Module for DiT.
9 | This is same as VIT code and does not have any changes
10 | from it.
11 | """
12 | def __init__(self, config):
13 | super().__init__()
14 | self.n_heads = config['num_heads']
15 | self.hidden_size = config['hidden_size']
16 | self.head_dim = config['head_dim']
17 |
18 | self.att_dim = self.n_heads * self.head_dim
19 |
20 | # QKV projection for the input
21 | self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.att_dim, bias=True)
22 | self.output_proj = nn.Sequential(
23 | nn.Linear(self.att_dim, self.hidden_size))
24 |
25 | ############################
26 | # DiT Layer Initialization #
27 | ############################
28 | nn.init.xavier_uniform_(self.qkv_proj.weight)
29 | nn.init.constant_(self.qkv_proj.bias, 0)
30 | nn.init.xavier_uniform_(self.output_proj[0].weight)
31 | nn.init.constant_(self.output_proj[0].bias, 0)
32 |
33 | def forward(self, x):
34 | # Converting to Attention Dimension
35 | ######################################################
36 | # Batch Size x Number of Patches x Dimension
37 | B, N = x.shape[:2]
38 | # Projecting to 3*att_dim and then splitting to get q, k v(each of att_dim)
39 | # qkv -> Batch Size x Number of Patches x (3* Attention Dimension)
40 | # q(as well as k and v) -> Batch Size x Number of Patches x Attention Dimension
41 | q, k, v = self.qkv_proj(x).split(self.att_dim, dim=-1)
42 | # Batch Size x Number of Patches x Attention Dimension
43 | # -> Batch Size x Number of Patches x (Heads * Head Dimension)
44 | # -> Batch Size x Number of Patches x (Heads * Head Dimension)
45 | # -> Batch Size x Heads x Number of Patches x Head Dimension
46 | # -> B x H x N x Head Dimension
47 | q = rearrange(q, 'b n (n_h h_dim) -> b n_h n h_dim',
48 | n_h=self.n_heads, h_dim=self.head_dim)
49 | k = rearrange(k, 'b n (n_h h_dim) -> b n_h n h_dim',
50 | n_h=self.n_heads, h_dim=self.head_dim)
51 | v = rearrange(v, 'b n (n_h h_dim) -> b n_h n h_dim',
52 | n_h=self.n_heads, h_dim=self.head_dim)
53 | #########################################################
54 |
55 | # Compute Attention Weights
56 | #########################################################
57 | # B x H x N x Head Dimension @ B x H x Head Dimension x N
58 | # -> B x H x N x N
59 | att = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** (-0.5))
60 | att = torch.nn.functional.softmax(att, dim=-1)
61 | #########################################################
62 |
63 | # Weighted Value Computation
64 | #########################################################
65 | # B x H x N x N @ B x H x N x Head Dimension
66 | # -> B x H x N x Head Dimension
67 | out = torch.matmul(att, v)
68 | #########################################################
69 |
70 | # Converting to Transformer Dimension
71 | #########################################################
72 |         # B x H x N x Head Dimension -> B x N x (Heads * Head Dimension) = B x N x Attention Dimension
73 | out = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)')
74 | # B x N x Dimension
75 | out = self.output_proj(out)
76 | ##########################################################
77 |
78 | return out
79 |
--------------------------------------------------------------------------------
/model/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_time_embedding(time_steps, temb_dim):
6 | r"""
7 | Convert time steps tensor into an embedding using the
8 | sinusoidal time embedding formula
9 | :param time_steps: 1D tensor of length batch size
10 | :param temb_dim: Dimension of the embedding
11 | :return: BxD embedding representation of B time steps
12 | """
13 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
14 |
15 | # factor = 10000^(2i/d_model)
16 | factor = 10000 ** ((torch.arange(
17 | start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
18 | )
19 |
20 | # pos / factor
21 | # timesteps B -> B, 1 -> B, temb_dim
22 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
23 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
24 | return t_emb
25 |
26 |
27 | class DownBlock(nn.Module):
28 | r"""
29 | Down conv block with attention.
30 | Sequence of following block
31 | 1. Resnet block with time embedding
32 | 2. Attention block
33 | 3. Downsample
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, t_emb_dim,
37 | down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
38 | super().__init__()
39 | self.num_layers = num_layers
40 | self.down_sample = down_sample
41 | self.attn = attn
42 | self.context_dim = context_dim
43 | self.cross_attn = cross_attn
44 | self.t_emb_dim = t_emb_dim
45 | self.resnet_conv_first = nn.ModuleList(
46 | [
47 | nn.Sequential(
48 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
49 | nn.SiLU(),
50 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
51 | kernel_size=3, stride=1, padding=1),
52 | )
53 | for i in range(num_layers)
54 | ]
55 | )
56 | if self.t_emb_dim is not None:
57 | self.t_emb_layers = nn.ModuleList([
58 | nn.Sequential(
59 | nn.SiLU(),
60 | nn.Linear(self.t_emb_dim, out_channels)
61 | )
62 | for _ in range(num_layers)
63 | ])
64 | self.resnet_conv_second = nn.ModuleList(
65 | [
66 | nn.Sequential(
67 | nn.GroupNorm(norm_channels, out_channels),
68 | nn.SiLU(),
69 | nn.Conv2d(out_channels, out_channels,
70 | kernel_size=3, stride=1, padding=1),
71 | )
72 | for _ in range(num_layers)
73 | ]
74 | )
75 |
76 | if self.attn:
77 | self.attention_norms = nn.ModuleList(
78 | [nn.GroupNorm(norm_channels, out_channels)
79 | for _ in range(num_layers)]
80 | )
81 |
82 | self.attentions = nn.ModuleList(
83 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
84 | for _ in range(num_layers)]
85 | )
86 |
87 | if self.cross_attn:
88 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
89 | self.cross_attention_norms = nn.ModuleList(
90 | [nn.GroupNorm(norm_channels, out_channels)
91 | for _ in range(num_layers)]
92 | )
93 | self.cross_attentions = nn.ModuleList(
94 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
95 | for _ in range(num_layers)]
96 | )
97 | self.context_proj = nn.ModuleList(
98 | [nn.Linear(context_dim, out_channels)
99 | for _ in range(num_layers)]
100 | )
101 |
102 | self.residual_input_conv = nn.ModuleList(
103 | [
104 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
105 | for i in range(num_layers)
106 | ]
107 | )
108 | self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
109 | 4, 2, 1) if self.down_sample else nn.Identity()
110 |
111 | def forward(self, x, t_emb=None, context=None):
112 | out = x
113 | for i in range(self.num_layers):
114 | # Resnet block of Unet
115 | resnet_input = out
116 | out = self.resnet_conv_first[i](out)
117 | if self.t_emb_dim is not None:
118 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
119 | out = self.resnet_conv_second[i](out)
120 | out = out + self.residual_input_conv[i](resnet_input)
121 |
122 | if self.attn:
123 | # Attention block of Unet
124 | batch_size, channels, h, w = out.shape
125 | in_attn = out.reshape(batch_size, channels, h * w)
126 | in_attn = self.attention_norms[i](in_attn)
127 | in_attn = in_attn.transpose(1, 2)
128 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
129 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
130 | out = out + out_attn
131 |
132 | if self.cross_attn:
133 | assert context is not None, "context cannot be None if cross attention layers are used"
134 | batch_size, channels, h, w = out.shape
135 | in_attn = out.reshape(batch_size, channels, h * w)
136 | in_attn = self.cross_attention_norms[i](in_attn)
137 | in_attn = in_attn.transpose(1, 2)
138 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
139 | context_proj = self.context_proj[i](context)
140 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
141 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
142 | out = out + out_attn
143 |
144 | # Downsample
145 | out = self.down_sample_conv(out)
146 | return out
147 |
148 |
149 | class MidBlock(nn.Module):
150 | r"""
151 | Mid conv block with attention.
152 | Sequence of following blocks
153 | 1. Resnet block with time embedding
154 | 2. Attention block
155 | 3. Resnet block with time embedding
156 | """
157 |
158 | def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None,
159 | context_dim=None):
160 | super().__init__()
161 | self.num_layers = num_layers
162 | self.t_emb_dim = t_emb_dim
163 | self.context_dim = context_dim
164 | self.cross_attn = cross_attn
165 | self.resnet_conv_first = nn.ModuleList(
166 | [
167 | nn.Sequential(
168 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
169 | nn.SiLU(),
170 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
171 | padding=1),
172 | )
173 | for i in range(num_layers + 1)
174 | ]
175 | )
176 |
177 | if self.t_emb_dim is not None:
178 | self.t_emb_layers = nn.ModuleList([
179 | nn.Sequential(
180 | nn.SiLU(),
181 | nn.Linear(t_emb_dim, out_channels)
182 | )
183 | for _ in range(num_layers + 1)
184 | ])
185 | self.resnet_conv_second = nn.ModuleList(
186 | [
187 | nn.Sequential(
188 | nn.GroupNorm(norm_channels, out_channels),
189 | nn.SiLU(),
190 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
191 | )
192 | for _ in range(num_layers + 1)
193 | ]
194 | )
195 |
196 | self.attention_norms = nn.ModuleList(
197 | [nn.GroupNorm(norm_channels, out_channels)
198 | for _ in range(num_layers)]
199 | )
200 |
201 | self.attentions = nn.ModuleList(
202 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
203 | for _ in range(num_layers)]
204 | )
205 | if self.cross_attn:
206 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
207 | self.cross_attention_norms = nn.ModuleList(
208 | [nn.GroupNorm(norm_channels, out_channels)
209 | for _ in range(num_layers)]
210 | )
211 | self.cross_attentions = nn.ModuleList(
212 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
213 | for _ in range(num_layers)]
214 | )
215 | self.context_proj = nn.ModuleList(
216 | [nn.Linear(context_dim, out_channels)
217 | for _ in range(num_layers)]
218 | )
219 | self.residual_input_conv = nn.ModuleList(
220 | [
221 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
222 | for i in range(num_layers + 1)
223 | ]
224 | )
225 |
226 | def forward(self, x, t_emb=None, context=None):
227 | out = x
228 |
229 | # First resnet block
230 | resnet_input = out
231 | out = self.resnet_conv_first[0](out)
232 | if self.t_emb_dim is not None:
233 | out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
234 | out = self.resnet_conv_second[0](out)
235 | out = out + self.residual_input_conv[0](resnet_input)
236 |
237 | for i in range(self.num_layers):
238 | # Attention Block
239 | batch_size, channels, h, w = out.shape
240 | in_attn = out.reshape(batch_size, channels, h * w)
241 | in_attn = self.attention_norms[i](in_attn)
242 | in_attn = in_attn.transpose(1, 2)
243 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
244 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
245 | out = out + out_attn
246 |
247 | if self.cross_attn:
248 | assert context is not None, "context cannot be None if cross attention layers are used"
249 | batch_size, channels, h, w = out.shape
250 | in_attn = out.reshape(batch_size, channels, h * w)
251 | in_attn = self.cross_attention_norms[i](in_attn)
252 | in_attn = in_attn.transpose(1, 2)
253 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
254 | context_proj = self.context_proj[i](context)
255 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
256 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
257 | out = out + out_attn
258 |
259 | # Resnet Block
260 | resnet_input = out
261 | out = self.resnet_conv_first[i + 1](out)
262 | if self.t_emb_dim is not None:
263 | out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
264 | out = self.resnet_conv_second[i + 1](out)
265 | out = out + self.residual_input_conv[i + 1](resnet_input)
266 |
267 | return out
268 |
269 |
270 | class UpBlock(nn.Module):
271 | r"""
272 | Up conv block with attention.
273 | Sequence of following blocks
274 | 1. Upsample
275 |     2. Concatenate Down block output
276 |     3. Resnet block with time embedding
277 |     4. Attention Block
278 | """
279 |
280 | def __init__(self, in_channels, out_channels, t_emb_dim,
281 | up_sample, num_heads, num_layers, attn, norm_channels):
282 | super().__init__()
283 | self.num_layers = num_layers
284 | self.up_sample = up_sample
285 | self.t_emb_dim = t_emb_dim
286 | self.attn = attn
287 | self.resnet_conv_first = nn.ModuleList(
288 | [
289 | nn.Sequential(
290 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
291 | nn.SiLU(),
292 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
293 | padding=1),
294 | )
295 | for i in range(num_layers)
296 | ]
297 | )
298 |
299 | if self.t_emb_dim is not None:
300 | self.t_emb_layers = nn.ModuleList([
301 | nn.Sequential(
302 | nn.SiLU(),
303 | nn.Linear(t_emb_dim, out_channels)
304 | )
305 | for _ in range(num_layers)
306 | ])
307 |
308 | self.resnet_conv_second = nn.ModuleList(
309 | [
310 | nn.Sequential(
311 | nn.GroupNorm(norm_channels, out_channels),
312 | nn.SiLU(),
313 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
314 | )
315 | for _ in range(num_layers)
316 | ]
317 | )
318 | if self.attn:
319 | self.attention_norms = nn.ModuleList(
320 | [
321 | nn.GroupNorm(norm_channels, out_channels)
322 | for _ in range(num_layers)
323 | ]
324 | )
325 |
326 | self.attentions = nn.ModuleList(
327 | [
328 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
329 | for _ in range(num_layers)
330 | ]
331 | )
332 |
333 | self.residual_input_conv = nn.ModuleList(
334 | [
335 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
336 | for i in range(num_layers)
337 | ]
338 | )
339 | self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
340 | 4, 2, 1) \
341 | if self.up_sample else nn.Identity()
342 |
343 | def forward(self, x, out_down=None, t_emb=None):
344 | # Upsample
345 | x = self.up_sample_conv(x)
346 |
347 | # Concat with Downblock output
348 | if out_down is not None:
349 | x = torch.cat([x, out_down], dim=1)
350 |
351 | out = x
352 | for i in range(self.num_layers):
353 | # Resnet Block
354 | resnet_input = out
355 | out = self.resnet_conv_first[i](out)
356 | if self.t_emb_dim is not None:
357 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
358 | out = self.resnet_conv_second[i](out)
359 | out = out + self.residual_input_conv[i](resnet_input)
360 |
361 | # Self Attention
362 | if self.attn:
363 | batch_size, channels, h, w = out.shape
364 | in_attn = out.reshape(batch_size, channels, h * w)
365 | in_attn = self.attention_norms[i](in_attn)
366 | in_attn = in_attn.transpose(1, 2)
367 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
368 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
369 | out = out + out_attn
370 | return out
371 |
372 |
373 | class UpBlockUnet(nn.Module):
374 | r"""
375 | Up conv block with attention.
376 | Sequence of following blocks
377 | 1. Upsample
378 |     2. Concatenate Down block output
379 |     3. Resnet block with time embedding
380 |     4. Attention Block
381 | """
382 |
383 | def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
384 | num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
385 | super().__init__()
386 | self.num_layers = num_layers
387 | self.up_sample = up_sample
388 | self.t_emb_dim = t_emb_dim
389 | self.cross_attn = cross_attn
390 | self.context_dim = context_dim
391 | self.resnet_conv_first = nn.ModuleList(
392 | [
393 | nn.Sequential(
394 | nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
395 | nn.SiLU(),
396 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
397 | padding=1),
398 | )
399 | for i in range(num_layers)
400 | ]
401 | )
402 |
403 | if self.t_emb_dim is not None:
404 | self.t_emb_layers = nn.ModuleList([
405 | nn.Sequential(
406 | nn.SiLU(),
407 | nn.Linear(t_emb_dim, out_channels)
408 | )
409 | for _ in range(num_layers)
410 | ])
411 |
412 | self.resnet_conv_second = nn.ModuleList(
413 | [
414 | nn.Sequential(
415 | nn.GroupNorm(norm_channels, out_channels),
416 | nn.SiLU(),
417 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
418 | )
419 | for _ in range(num_layers)
420 | ]
421 | )
422 |
423 | self.attention_norms = nn.ModuleList(
424 | [
425 | nn.GroupNorm(norm_channels, out_channels)
426 | for _ in range(num_layers)
427 | ]
428 | )
429 |
430 | self.attentions = nn.ModuleList(
431 | [
432 | nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
433 | for _ in range(num_layers)
434 | ]
435 | )
436 |
437 | if self.cross_attn:
438 | assert context_dim is not None, "Context Dimension must be passed for cross attention"
439 | self.cross_attention_norms = nn.ModuleList(
440 | [nn.GroupNorm(norm_channels, out_channels)
441 | for _ in range(num_layers)]
442 | )
443 | self.cross_attentions = nn.ModuleList(
444 | [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
445 | for _ in range(num_layers)]
446 | )
447 | self.context_proj = nn.ModuleList(
448 | [nn.Linear(context_dim, out_channels)
449 | for _ in range(num_layers)]
450 | )
451 | self.residual_input_conv = nn.ModuleList(
452 | [
453 | nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
454 | for i in range(num_layers)
455 | ]
456 | )
457 | self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
458 | 4, 2, 1) \
459 | if self.up_sample else nn.Identity()
460 |
461 | def forward(self, x, out_down=None, t_emb=None, context=None):
462 | x = self.up_sample_conv(x)
463 | if out_down is not None:
464 | x = torch.cat([x, out_down], dim=1)
465 |
466 | out = x
467 | for i in range(self.num_layers):
468 | # Resnet
469 | resnet_input = out
470 | out = self.resnet_conv_first[i](out)
471 | if self.t_emb_dim is not None:
472 | out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
473 | out = self.resnet_conv_second[i](out)
474 | out = out + self.residual_input_conv[i](resnet_input)
475 | # Self Attention
476 | batch_size, channels, h, w = out.shape
477 | in_attn = out.reshape(batch_size, channels, h * w)
478 | in_attn = self.attention_norms[i](in_attn)
479 | in_attn = in_attn.transpose(1, 2)
480 | out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
481 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
482 | out = out + out_attn
483 | # Cross Attention
484 | if self.cross_attn:
485 | assert context is not None, "context cannot be None if cross attention layers are used"
486 | batch_size, channels, h, w = out.shape
487 | in_attn = out.reshape(batch_size, channels, h * w)
488 | in_attn = self.cross_attention_norms[i](in_attn)
489 | in_attn = in_attn.transpose(1, 2)
490 | assert len(context.shape) == 3, \
491 | "Context shape does not match B,_,CONTEXT_DIM"
492 | assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
493 | "Context shape does not match B,_,CONTEXT_DIM"
494 | context_proj = self.context_proj[i](context)
495 | out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
496 | out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
497 | out = out + out_attn
498 |
499 | return out
500 |
501 |
502 |
--------------------------------------------------------------------------------
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | r"""
7 | PatchGAN Discriminator.
8 | Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to
9 |     one scalar value, we instead predict a grid of values,
10 |     where each grid cell is a prediction of how likely
11 |     the discriminator thinks the image patch corresponding
12 |     to that cell is real.
13 | """
14 |
15 | def __init__(self, im_channels=3,
16 | conv_channels=[64, 128, 256],
17 | kernels=[4,4,4,4],
18 | strides=[2,2,2,1],
19 | paddings=[1,1,1,1]):
20 | super().__init__()
21 | self.im_channels = im_channels
22 | activation = nn.LeakyReLU(0.2)
23 | layers_dim = [self.im_channels] + conv_channels + [1]
24 | self.layers = nn.ModuleList([
25 | nn.Sequential(
26 | nn.Conv2d(layers_dim[i], layers_dim[i + 1],
27 | kernel_size=kernels[i],
28 | stride=strides[i],
29 | padding=paddings[i],
30 | bias=False if i !=0 else True),
31 | nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
32 | activation if i != len(layers_dim) - 2 else nn.Identity()
33 | )
34 | for i in range(len(layers_dim) - 1)
35 | ])
36 |
37 | def forward(self, x):
38 | out = x
39 | for layer in self.layers:
40 | out = layer(out)
41 | return out
42 |
43 |
44 | if __name__ == '__main__':
45 | x = torch.randn((2,3, 256, 256))
46 | prob = Discriminator(im_channels=3)(x)
47 | print(prob.shape)
48 |
--------------------------------------------------------------------------------
/model/lpips.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import namedtuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn
9 | import torchvision
10 |
11 | # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 | def spatial_average(in_tens, keepdim=True):
19 | return in_tens.mean([2, 3], keepdim=keepdim)
20 |
21 |
22 | class vgg16(torch.nn.Module):
23 | def __init__(self, requires_grad=False, pretrained=True):
24 | super(vgg16, self).__init__()
25 | # Load pretrained vgg model from torchvision
26 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=pretrained).features
27 | self.slice1 = torch.nn.Sequential()
28 | self.slice2 = torch.nn.Sequential()
29 | self.slice3 = torch.nn.Sequential()
30 | self.slice4 = torch.nn.Sequential()
31 | self.slice5 = torch.nn.Sequential()
32 | self.N_slices = 5
33 | for x in range(4):
34 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
35 | for x in range(4, 9):
36 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
37 | for x in range(9, 16):
38 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
39 | for x in range(16, 23):
40 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
41 | for x in range(23, 30):
42 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
43 |
44 | # Freeze vgg model
45 | if not requires_grad:
46 | for param in self.parameters():
47 | param.requires_grad = False
48 |
49 | def forward(self, X):
50 | # Return output of vgg features
51 | h = self.slice1(X)
52 | h_relu1_2 = h
53 | h = self.slice2(h)
54 | h_relu2_2 = h
55 | h = self.slice3(h)
56 | h_relu3_3 = h
57 | h = self.slice4(h)
58 | h_relu4_3 = h
59 | h = self.slice5(h)
60 | h_relu5_3 = h
61 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
62 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
63 | return out
64 |
65 |
66 | # Learned perceptual metric
67 | class LPIPS(nn.Module):
68 | def __init__(self, net='vgg', version='0.1', use_dropout=True):
69 | super(LPIPS, self).__init__()
70 | self.version = version
71 | # Imagenet normalization
72 | self.scaling_layer = ScalingLayer()
73 | ########################
74 |
75 | # Instantiate vgg model
76 | self.chns = [64, 128, 256, 512, 512]
77 | self.L = len(self.chns)
78 | self.net = vgg16(pretrained=True, requires_grad=False)
79 |
80 | # Add 1x1 convolutional Layers
81 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
82 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
83 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
84 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
85 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
86 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
87 | self.lins = nn.ModuleList(self.lins)
88 | ########################
89 |
90 | # Load the weights of trained LPIPS model
91 | import inspect
92 | import os
93 | model_path = os.path.abspath(
94 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
95 | print('Loading model from: %s' % model_path)
96 | self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
97 | ########################
98 |
99 | # Freeze all parameters
100 | self.eval()
101 | for param in self.parameters():
102 | param.requires_grad = False
103 | ########################
104 |
105 | def forward(self, in0, in1, normalize=False):
106 | # Scale the inputs to -1 to +1 range if needed
107 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
108 | in0 = 2 * in0 - 1
109 | in1 = 2 * in1 - 1
110 | ########################
111 |
112 | # Normalize the inputs according to imagenet normalization
113 | in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
114 | ########################
115 |
116 | # Get VGG outputs for image0 and image1
117 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
118 | feats0, feats1, diffs = {}, {}, {}
119 | ########################
120 |
121 | # Compute Square of Difference for each layer output
122 | for kk in range(self.L):
123 | feats0[kk], feats1[kk] = torch.nn.functional.normalize(outs0[kk], dim=1), torch.nn.functional.normalize(
124 | outs1[kk])
125 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
126 | ########################
127 |
128 | # 1x1 convolution followed by spatial average on the square differences
129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
130 | val = 0
131 |
132 | # Aggregate the results of each layer
133 | for l in range(self.L):
134 | val += res[l]
135 | return val
136 |
137 |
138 | class ScalingLayer(nn.Module):
139 | def __init__(self):
140 | super(ScalingLayer, self).__init__()
141 |         # Imagenet normalization for (0-1)
142 | # mean = [0.485, 0.456, 0.406]
143 | # std = [0.229, 0.224, 0.225]
144 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
145 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
146 |
147 | def forward(self, inp):
148 | return (inp - self.shift) / self.scale
149 |
150 |
151 | class NetLinLayer(nn.Module):
152 | ''' A single linear layer which does a 1x1 conv '''
153 |
154 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
155 | super(NetLinLayer, self).__init__()
156 |
157 | layers = [nn.Dropout(), ] if (use_dropout) else []
158 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
159 | self.model = nn.Sequential(*layers)
160 |
161 | def forward(self, x):
162 | out = self.model(x)
163 | return out
164 |
--------------------------------------------------------------------------------
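A minimal usage sketch of the LPIPS module above. It assumes the pre-trained linear-layer weights have been downloaded to model/weights/v0.1/vgg.pth (the path the constructor builds); the call returns one perceptual distance per image, which train_vae.py averages into its loss.

import torch
from model.lpips import LPIPS

# Illustrative only: requires model/weights/v0.1/vgg.pth to be present
lpips_model = LPIPS().eval()
img0 = torch.rand(4, 3, 64, 64) * 2 - 1   # inputs already in [-1, 1]
img1 = torch.rand(4, 3, 64, 64) * 2 - 1
with torch.no_grad():
    dist = lpips_model(img0, img1)        # shape (4, 1, 1, 1)
print(dist.squeeze())
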
/model/patch_embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 |
8 | def get_patch_position_embedding(pos_emb_dim, grid_size, device):
9 | assert pos_emb_dim % 4 == 0, 'Position embedding dimension must be divisible by 4'
10 | grid_size_h, grid_size_w = grid_size
11 | grid_h = torch.arange(grid_size_h, dtype=torch.float32, device=device)
12 | grid_w = torch.arange(grid_size_w, dtype=torch.float32, device=device)
13 | grid = torch.meshgrid(grid_h, grid_w, indexing='ij')
14 | grid = torch.stack(grid, dim=0)
15 |
16 | # grid_h_positions -> (Number of patch tokens,)
17 | grid_h_positions = grid[0].reshape(-1)
18 | grid_w_positions = grid[1].reshape(-1)
19 |
20 | # factor = 10000^(2i/d_model)
21 | factor = 10000 ** ((torch.arange(
22 | start=0,
23 | end=pos_emb_dim // 4,
24 | dtype=torch.float32,
25 | device=device) / (pos_emb_dim // 4))
26 | )
27 |
28 | grid_h_emb = grid_h_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
29 | grid_h_emb = torch.cat([torch.sin(grid_h_emb), torch.cos(grid_h_emb)], dim=-1)
30 | # grid_h_emb -> (Number of patch tokens, pos_emb_dim // 2)
31 |
32 | grid_w_emb = grid_w_positions[:, None].repeat(1, pos_emb_dim // 4) / factor
33 | grid_w_emb = torch.cat([torch.sin(grid_w_emb), torch.cos(grid_w_emb)], dim=-1)
34 | pos_emb = torch.cat([grid_h_emb, grid_w_emb], dim=-1)
35 |
36 | # pos_emb -> (Number of patch tokens, pos_emb_dim)
37 | return pos_emb
38 |
39 |
40 | class PatchEmbedding(nn.Module):
41 | r"""
42 | Layer to take in the input image and do the following:
43 |     1. Transform the grid of image patches into a sequence of patches.
44 |        The number of patches is determined by the image height/width and the
45 |        patch height/width.
46 |     2. Add a positional embedding to the above sequence.
47 | """
48 |
49 | def __init__(self,
50 | image_height,
51 | image_width,
52 | im_channels,
53 | patch_height,
54 | patch_width,
55 | hidden_size):
56 | super().__init__()
57 | self.image_height = image_height
58 | self.image_width = image_width
59 | self.im_channels = im_channels
60 |
61 | self.hidden_size = hidden_size
62 |
63 | self.patch_height = patch_height
64 | self.patch_width = patch_width
65 |
66 | # Input dimension for Patch Embedding FC Layer
67 | patch_dim = self.im_channels * self.patch_height * self.patch_width
68 | self.patch_embed = nn.Sequential(
69 | nn.Linear(patch_dim, self.hidden_size)
70 | )
71 |
72 | ############################
73 | # DiT Layer Initialization #
74 | ############################
75 | nn.init.xavier_uniform_(self.patch_embed[0].weight)
76 | nn.init.constant_(self.patch_embed[0].bias, 0)
77 |
78 | def forward(self, x):
79 | grid_size_h = self.image_height // self.patch_height
80 | grid_size_w = self.image_width // self.patch_width
81 |
82 | # B, C, H, W -> B, (Patches along height * Patches along width), Patch Dimension
83 | # Number of tokens = Patches along height * Patches along width
84 | out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
85 | ph=self.patch_height,
86 | pw=self.patch_width)
87 |
88 | # BxNumber of tokens x Patch Dimension -> B x Number of tokens x Transformer Dimension
89 | out = self.patch_embed(out)
90 |
91 | # Add 2d sinusoidal position embeddings
92 | pos_embed = get_patch_position_embedding(pos_emb_dim=self.hidden_size,
93 | grid_size=(grid_size_h, grid_size_w),
94 | device=x.device)
95 | out += pos_embed
96 | return out
97 |
98 |
--------------------------------------------------------------------------------
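A quick shape-check sketch for PatchEmbedding with hypothetical sizes (roughly latent-frame scale): a 16x16 input with 2x2 patches gives 8 x 8 = 64 tokens, each projected to the hidden size, with the 2D sinusoidal position embedding already added inside forward().

import torch
from model.patch_embed import PatchEmbedding

# Hypothetical latent-frame dimensions, for illustration only
patch_embed = PatchEmbedding(image_height=16, image_width=16, im_channels=4,
                             patch_height=2, patch_width=2, hidden_size=768)
latent_frames = torch.randn(2, 4, 16, 16)   # (B, C, H, W)
tokens = patch_embed(latent_frames)
print(tokens.shape)                         # torch.Size([2, 64, 768])
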
/model/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.patch_embed import PatchEmbedding
4 | from model.transformer_layer import TransformerLayer
5 | from einops import rearrange, repeat
6 |
7 |
8 | def get_time_embedding(time_steps, temb_dim):
9 | r"""
10 | Convert time steps tensor into an embedding using the
11 | sinusoidal time embedding formula
12 | :param time_steps: 1D tensor of length batch size
13 | :param temb_dim: Dimension of the embedding
14 | :return: BxD embedding representation of B time steps
15 | """
16 | assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
17 |
18 | # factor = 10000^(2i/d_model)
19 | factor = 10000 ** ((torch.arange(
20 | start=0,
21 | end=temb_dim // 2,
22 | dtype=torch.float32,
23 | device=time_steps.device) / (temb_dim // 2))
24 | )
25 |
26 | # pos / factor
27 | # timesteps B -> B, 1 -> B, temb_dim
28 | t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
29 | t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
30 | return t_emb
31 |
32 |
33 | class DITVideo(nn.Module):
34 | r"""
35 |     Class for the Latte model, which alternates spatial and temporal
36 |     encoder layers, each of which is a DiT block.
37 | """
38 | def __init__(self, frame_height, frame_width, im_channels, num_frames, config):
39 | super().__init__()
40 |
41 | num_layers = config['num_layers']
42 | self.image_height = frame_height
43 | self.image_width = frame_width
44 | self.im_channels = im_channels
45 | self.hidden_size = config['hidden_size']
46 | self.patch_height = config['patch_size']
47 | self.patch_width = config['patch_size']
48 | self.num_frames = num_frames
49 |
50 | self.timestep_emb_dim = config['timestep_emb_dim']
51 |
52 | # Number of patches along height and width
53 | self.nh = self.image_height // self.patch_height
54 | self.nw = self.image_width // self.patch_width
55 |
56 | # Patch Embedding Block
57 | self.patch_embed_layer = PatchEmbedding(image_height=self.image_height,
58 | image_width=self.image_width,
59 | im_channels=self.im_channels,
60 | patch_height=self.patch_height,
61 | patch_width=self.patch_width,
62 | hidden_size=self.hidden_size)
63 |
64 | # Initial projection from sinusoidal time embedding
65 | self.t_proj = nn.Sequential(
66 | nn.Linear(self.timestep_emb_dim, self.hidden_size),
67 | nn.SiLU(),
68 | nn.Linear(self.hidden_size, self.hidden_size)
69 | )
70 |
71 | # All Transformer Layers
72 | self.layers = nn.ModuleList([
73 | TransformerLayer(config) for _ in range(num_layers)
74 | ])
75 |
76 | # Final normalization for unpatchify block
77 | self.norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
78 |
79 | # Scale and Shift parameters for the norm
80 | self.adaptive_norm_layer = nn.Sequential(
81 | nn.SiLU(),
82 | nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=True)
83 | )
84 |
85 | # Final Linear Layer
86 | self.proj_out = nn.Linear(self.hidden_size,
87 | self.patch_height * self.patch_width * self.im_channels)
88 |
89 | ############################
90 | # DiT Layer Initialization #
91 | ############################
92 | nn.init.normal_(self.t_proj[0].weight, std=0.02)
93 | nn.init.normal_(self.t_proj[2].weight, std=0.02)
94 |
95 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
96 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)
97 |
98 | nn.init.constant_(self.proj_out.weight, 0)
99 | nn.init.constant_(self.proj_out.bias, 0)
100 |
101 | def forward(self, x, t, num_images=0):
102 | r"""
103 |         Forward method of the DiT video model, which predicts the noise added to x
104 |         :param x: input noisy latent frames
105 |         :param t: diffusion timestep(s) for the batch
106 |         :param num_images: number of image tokens appended to the video
107 |             frames when doing joint image-video training
108 | :return:
109 | """
110 | # Shape of x is Batch_size x (num_frames + num_images) x Channels x H x W
111 | B, F, C, H, W = x.shape
112 |
113 | ##################
114 | # Patchify Block #
115 | ##################
116 | # rearrange to (Batch_size * (num_frames + num_images)) x Channels x H x W
117 | x = rearrange(x, 'b f c h w -> (b f) c h w')
118 | out = self.patch_embed_layer(x)
119 |
120 | # out->(Batch_size * (num_frames + num_images)) x num_patch_tokens x hidden_size
121 | num_patch_tokens = out.shape[1]
122 |
123 | # Compute Timestep representation
124 | # t_emb -> (Batch, timestep_emb_dim)
125 | t_emb = get_time_embedding(torch.as_tensor(t).long(), self.timestep_emb_dim)
126 | # (Batch, timestep_emb_dim) -> (Batch, hidden_size)
127 | t_emb = self.t_proj(t_emb)
128 |
129 | # Timestep embedding will be Batch_size x hidden_size
130 | # We repeat it to get different timestep shapes for spatial and temporal layers
131 | # For spatial -> (Batch size * (num_frames + num_images)) x hidden_size
132 | # For temporal -> (Batch size * num_patch_tokens) x hidden_size
133 | t_emb_spatial = repeat(t_emb, 'b d -> (b f) d',
134 | f=self.num_frames+num_images)
135 | t_emb_temporal = repeat(t_emb, 'b d -> (b p) d', p=num_patch_tokens)
136 |
137 |         # Frame position embedding for positions 0 to num_frames - 1
138 |         frame_pos = torch.arange(self.num_frames, dtype=torch.float32, device=x.device)
139 |         frame_emb = get_time_embedding(frame_pos, self.hidden_size)
140 |         # frame_emb -> (num_frames x hidden_size)
141 |
142 | # Loop over all transformer layers
143 | for layer_idx in range(0, len(self.layers), 2):
144 | spatial_layer = self.layers[layer_idx]
145 | temporal_layer = self.layers[layer_idx+1]
146 |
147 | # out->(Batch_size * (num_frames+num_images)) x num_patch_tokens x hidden_size
148 |
149 | #################
150 | # Spatial Layer #
151 | #################
152 | # position embedding is already added in patch embedding layer
153 | out = spatial_layer(out, t_emb_spatial)
154 |
155 | ##################
156 | # Temporal Layer #
157 | ##################
158 | # rearrange to (B * patch_tokens) x (num_frames+num_images) x hidden_size
159 | out = rearrange(out, '(b f) p d -> (b p) f d', b=B)
160 |
161 | # Separate the video tokens and image tokens
162 | out_video = out[:, :self.num_frames, :]
163 | out_images = out[:, self.num_frames:, :]
164 |
165 | # Add temporal embedding to video tokens
166 | # but only if first temporal layer
167 | if layer_idx == 0:
168 | out_video = out_video + frame_emb
169 | # Call temporal layer
170 | out_video = temporal_layer(out_video, t_emb_temporal)
171 |
172 | # Concatenate the image tokens back to the new video output
173 | out = torch.cat([out_video, out_images], dim=1)
174 |
175 | # Rearrange to (B * (num_frames+num_images)) x num_patch_tokens x hidden_size
176 | out = rearrange(out, '(b p) f d -> (b f) p d',
177 | f=self.num_frames+num_images, p=num_patch_tokens)
178 |
179 | # Shift and scale predictions for output normalization
180 | pre_mlp_shift, pre_mlp_scale = (self.adaptive_norm_layer(t_emb_spatial).
181 | chunk(2, dim=1))
182 | out = (self.norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) +
183 | pre_mlp_shift.unsqueeze(1))
184 |
185 | # Unpatchify
186 |         # (Batch_size * (num_frames+num_images)) x patches x hidden_size
187 | # -> (B * (num_frames+num_images)) x patches x (patch height*patch width*channels)
188 | out = self.proj_out(out)
189 | out = rearrange(out, 'b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)',
190 | ph=self.patch_height,
191 | pw=self.patch_width,
192 | nw=self.nw,
193 | nh=self.nh)
194 | # out -> (Batch_size * (num_frames+num_images)) x channels x h x w
195 | out = out.reshape((B, F, C, H, W))
196 | return out
197 |
--------------------------------------------------------------------------------
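The same sinusoidal formula serves both the diffusion timestep embedding and the frame-position embedding added before the first temporal layer. A standalone check of get_time_embedding; the dimensions below are placeholders, not the repo's config values.

import torch
from model.transformer import get_time_embedding

timesteps = torch.tensor([0., 250., 999.])          # (B,)
t_emb = get_time_embedding(timesteps, temb_dim=256)
print(t_emb.shape)                                  # torch.Size([3, 256])

# Frame positions 0..15, embedded at a hypothetical hidden size of 768
frame_pos = torch.arange(16, dtype=torch.float32)
frame_emb = get_time_embedding(frame_pos, temb_dim=768)
print(frame_emb.shape)                              # torch.Size([16, 768])
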
/model/transformer_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from model.attention import Attention
3 |
4 |
5 | class TransformerLayer(nn.Module):
6 | r"""
7 |     Transformer block which does the following, based on ViT:
8 |     1. LayerNorm followed by Attention
9 |     2. LayerNorm followed by a Feed-forward block
10 |     Both of these have residual connections added to them.
11 | 
12 |     For DiT we additionally have
13 |     1. An adaptive-norm mlp that predicts the layernorm scale/shift parameters
14 |     2. The same mlp also predicts scale parameters for the outputs of both
15 |        mlp/attention prior to the residual connection.
16 | """
17 |
18 | def __init__(self, config):
19 | super().__init__()
20 | self.hidden_size = config['hidden_size']
21 |
22 | ff_hidden_dim = 4 * self.hidden_size
23 |
24 | # Layer norm for attention block
25 | self.att_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
26 |
27 | self.attn_block = Attention(config)
28 |
29 | # Layer norm for mlp block
30 | self.ff_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1E-6)
31 |
32 | self.mlp_block = nn.Sequential(
33 | nn.Linear(self.hidden_size, ff_hidden_dim),
34 | nn.GELU(approximate='tanh'),
35 | nn.Linear(ff_hidden_dim, self.hidden_size),
36 | )
37 |
38 | # Scale Shift Parameter predictions for this layer
39 | # 1. Scale and shift parameters for layernorm of attention (2 * hidden_size)
40 | # 2. Scale and shift parameters for layernorm of mlp (2 * hidden_size)
41 | # 3. Scale for output of attention prior to residual connection (hidden_size)
42 | # 4. Scale for output of mlp prior to residual connection (hidden_size)
43 | # Total 6 * hidden_size
44 | self.adaptive_norm_layer = nn.Sequential(
45 | nn.SiLU(),
46 | nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
47 | )
48 |
49 | ############################
50 | # DiT Layer Initialization #
51 | ############################
52 | nn.init.xavier_uniform_(self.mlp_block[0].weight)
53 | nn.init.constant_(self.mlp_block[0].bias, 0)
54 | nn.init.xavier_uniform_(self.mlp_block[-1].weight)
55 | nn.init.constant_(self.mlp_block[-1].bias, 0)
56 |
57 | nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
58 | nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)
59 |
60 | def forward(self, x, condition):
61 | scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
62 | (pre_attn_shift, pre_attn_scale, post_attn_scale,
63 | pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = scale_shift_params
64 | out = x
65 | attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1))
66 | + pre_attn_shift.unsqueeze(1))
67 | out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output)
68 | mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1)) +
69 | pre_mlp_shift.unsqueeze(1))
70 | out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output)
71 | return out
72 |
--------------------------------------------------------------------------------
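A self-contained sketch of the adaLN-Zero modulation this layer performs, with toy sizes that are independent of the repo's config. Because the modulation MLP is zero-initialised, every scale and shift starts at zero, so the block behaves as the identity at initialisation.

import torch
import torch.nn as nn

hidden_size = 8
adaptive_norm = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
nn.init.constant_(adaptive_norm[-1].weight, 0)
nn.init.constant_(adaptive_norm[-1].bias, 0)

cond = torch.randn(2, hidden_size)                  # e.g. projected timestep embedding
(pre_attn_shift, pre_attn_scale, post_attn_scale,
 pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = adaptive_norm(cond).chunk(6, dim=1)

x = torch.randn(2, 5, hidden_size)                  # (batch, tokens, hidden)
norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
modulated = norm(x) * (1 + pre_attn_scale.unsqueeze(1)) + pre_attn_shift.unsqueeze(1)

# post_attn_scale is zero at init, so the residual branch contributes nothing yet
out = x + post_attn_scale.unsqueeze(1) * modulated  # attention sub-block omitted here
print(torch.allclose(out, x))                       # True
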
/model/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.blocks import DownBlock, MidBlock, UpBlock
4 |
5 |
6 | class VAE(nn.Module):
7 | def __init__(self, im_channels, model_config):
8 | super().__init__()
9 | self.down_channels = model_config['down_channels']
10 | self.mid_channels = model_config['mid_channels']
11 | self.down_sample = model_config['down_sample']
12 | self.num_down_layers = model_config['num_down_layers']
13 | self.num_mid_layers = model_config['num_mid_layers']
14 | self.num_up_layers = model_config['num_up_layers']
15 |
16 | # To disable attention in Downblock of Encoder and Upblock of Decoder
17 | self.attns = model_config['attn_down']
18 |
19 | # Latent Dimension
20 | self.z_channels = model_config['z_channels']
21 | self.norm_channels = model_config['norm_channels']
22 | self.num_heads = model_config['num_heads']
23 |
24 | # Assertion to validate the channel information
25 | assert self.mid_channels[0] == self.down_channels[-1]
26 | assert self.mid_channels[-1] == self.down_channels[-1]
27 | assert len(self.down_sample) == len(self.down_channels) - 1
28 | assert len(self.attns) == len(self.down_channels) - 1
29 |
30 | # Wherever we use downsampling in encoder correspondingly use
31 | # upsampling in decoder
32 | self.up_sample = list(reversed(self.down_sample))
33 |
34 | ##################### Encoder ######################
35 | self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
36 |
37 | # Downblock + Midblock
38 | self.encoder_layers = nn.ModuleList([])
39 | for i in range(len(self.down_channels) - 1):
40 | self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
41 | t_emb_dim=None, down_sample=self.down_sample[i],
42 | num_heads=self.num_heads,
43 | num_layers=self.num_down_layers,
44 | attn=self.attns[i],
45 | norm_channels=self.norm_channels))
46 |
47 | self.encoder_mids = nn.ModuleList([])
48 | for i in range(len(self.mid_channels) - 1):
49 | self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
50 | t_emb_dim=None,
51 | num_heads=self.num_heads,
52 | num_layers=self.num_mid_layers,
53 | norm_channels=self.norm_channels))
54 |
55 | self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
56 | self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2 * self.z_channels, kernel_size=3, padding=1)
57 |
58 |         # Channel dimension is 2 * z_channels because we predict both mean and log-variance
59 | self.pre_quant_conv = nn.Conv2d(2 * self.z_channels, 2 * self.z_channels, kernel_size=1)
60 | ####################################################
61 |
62 | ##################### Decoder ######################
63 | self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
64 | self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
65 |
66 | # Midblock + Upblock
67 | self.decoder_mids = nn.ModuleList([])
68 | for i in reversed(range(1, len(self.mid_channels))):
69 | self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
70 | t_emb_dim=None,
71 | num_heads=self.num_heads,
72 | num_layers=self.num_mid_layers,
73 | norm_channels=self.norm_channels))
74 |
75 | self.decoder_layers = nn.ModuleList([])
76 | for i in reversed(range(1, len(self.down_channels))):
77 | self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
78 | t_emb_dim=None, up_sample=self.down_sample[i - 1],
79 | num_heads=self.num_heads,
80 | num_layers=self.num_up_layers,
81 | attn=self.attns[i - 1],
82 | norm_channels=self.norm_channels))
83 |
84 | self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
85 | self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
86 |
87 | def encode(self, x):
88 | out = self.encoder_conv_in(x)
89 | for idx, down in enumerate(self.encoder_layers):
90 | out = down(out)
91 | for mid in self.encoder_mids:
92 | out = mid(out)
93 | out = self.encoder_norm_out(out)
94 | out = nn.SiLU()(out)
95 | out = self.encoder_conv_out(out)
96 | out = self.pre_quant_conv(out)
97 | mean, logvar = torch.chunk(out, 2, dim=1)
98 | std = torch.exp(0.5 * logvar)
99 | sample = mean + std * torch.randn(mean.shape).to(device=x.device)
100 | return sample, out
101 |
102 | def decode(self, z):
103 | out = z
104 | out = self.post_quant_conv(out)
105 | out = self.decoder_conv_in(out)
106 | for mid in self.decoder_mids:
107 | out = mid(out)
108 | for idx, up in enumerate(self.decoder_layers):
109 | out = up(out)
110 |
111 | out = self.decoder_norm_out(out)
112 | out = nn.SiLU()(out)
113 | out = self.decoder_conv_out(out)
114 | return out
115 |
116 | def forward(self, x):
117 | z, encoder_output = self.encode(x)
118 | out = self.decode(z)
119 | return out, encoder_output
120 |
121 |
122 |
123 | if __name__ == '__main__':
124 |     # Quick standalone shape check: a 1x1 convolution on a latent-sized
125 |     # tensor, mirroring the pre/post quant convolutions used above
126 |     a = torch.randn((2, 32, 16, 16))
127 |     net = torch.nn.Conv2d(32, 64, kernel_size=1)
128 |     out = net(a)
129 |     print(out.shape)
--------------------------------------------------------------------------------
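The reparameterisation inside encode() can be exercised in isolation (toy shapes assumed): the encoder emits 2 * z_channels channels per position, which are split into mean and log-variance before sampling the latent that goes to the decoder.

import torch

encoder_output = torch.randn(1, 8, 7, 7)          # toy tensor with 2 * z_channels = 8
mean, logvar = torch.chunk(encoder_output, 2, dim=1)
std = torch.exp(0.5 * logvar)
sample = mean + std * torch.randn_like(mean)      # z passed on to decode()
print(sample.shape)                               # torch.Size([1, 4, 7, 7])
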
/model/weights/v0.1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/model/weights/v0.1/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.8.0
2 | numpy==2.0.1
3 | opencv_python==4.10.0.84
4 | Pillow==10.4.0
5 | PyYAML==6.0.1
6 | torch==2.3.1
7 | torchvision==0.18.1
8 | tqdm==4.66.4
9 |
--------------------------------------------------------------------------------
/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/scheduler/__init__.py
--------------------------------------------------------------------------------
/scheduler/linear_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LinearNoiseScheduler:
5 | r"""
6 | Class for the linear noise scheduler that is used in DDPM.
7 | """
8 |
9 | def __init__(self, num_timesteps, beta_start, beta_end):
10 | self.num_timesteps = num_timesteps
11 | self.beta_start = beta_start
12 | self.beta_end = beta_end
13 |
14 | self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
15 | self.alphas = 1. - self.betas
16 | self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
17 | self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
18 | self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
19 |
20 | def add_noise(self, original, noise, t):
21 | r"""
22 | Forward method for diffusion
23 | :param original: Image on which noise is to be applied
24 | :param noise: Random Noise Tensor (from normal dist)
25 | :param t: timestep of the forward process of shape -> (B,)
26 | :return:
27 | """
28 | original_shape = original.shape
29 | batch_size = original_shape[0]
30 |
31 | sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
32 | sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
33 |
34 | # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
35 | for _ in range(len(original_shape) - 1):
36 | sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
37 | for _ in range(len(original_shape) - 1):
38 | sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
39 |
40 | # Apply and Return Forward process equation
41 | return (sqrt_alpha_cum_prod.to(original.device) * original
42 | + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
43 |
44 | def sample_prev_timestep(self, xt, pred, t):
45 | r"""
46 | Use the noise prediction by model to get
47 | xt-1 using xt and the noise predicted
48 | :param xt: current timestep sample
49 | :param pred: model noise prediction
50 | :param t: current timestep we are at
51 | :return:
52 | """
53 | x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * pred)) /
54 | torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
55 | x0 = torch.clamp(x0, -1., 1.)
56 |
57 | mean = xt - ((self.betas.to(xt.device)[t]) * pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
58 | mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
59 |
60 | if t == 0:
61 | return mean, x0
62 | else:
63 | variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
64 | variance = variance * self.betas.to(xt.device)[t]
65 | sigma = variance ** 0.5
66 | z = torch.randn(xt.shape).to(xt.device)
67 | return mean + sigma * z, x0
68 |
--------------------------------------------------------------------------------
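A short sanity check for the scheduler. The beta values below are typical DDPM defaults, not necessarily the values in the repo's config yamls; add_noise broadcasts the per-sample schedule terms across a 5D video-latent tensor.

import torch
from scheduler.linear_scheduler import LinearNoiseScheduler

scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=1e-4, beta_end=0.02)

x0 = torch.randn(2, 16, 4, 7, 7)           # (B, frames, z_channels, h, w) latents
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (2,))
xt = scheduler.add_noise(x0, noise, t)
print(xt.shape)                            # torch.Size([2, 16, 4, 7, 7])
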
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/VideoGeneration-PyTorch/9d40d42a8f5a8919ce0d356ad54ebfce4c3089b8/tools/__init__.py
--------------------------------------------------------------------------------
/tools/sample_vae_ditv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import yaml
5 | import os
6 | from torchvision.utils import make_grid
7 | from PIL import Image
8 | from tqdm import tqdm
9 | import torchvision.transforms.v2 as v2
10 | from model.vae import VAE
11 | from model.transformer import DITVideo
12 | from scheduler.linear_scheduler import LinearNoiseScheduler
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | if torch.backends.mps.is_available():
16 | device = torch.device('mps')
17 | print('Using mps')
18 |
19 |
20 | def sample(model, scheduler, train_config, ditv_model_config,
21 | autoencoder_model_config, diffusion_config, dataset_config, vae):
22 | r"""
23 |     Sample stepwise by going backward one timestep at a time.
24 |     Intermediate latents are saved as videos at every step; only the final output is decoded through the VAE.
25 | """
26 | latent_frame_height = (dataset_config['frame_height']
27 | // 2 ** sum(autoencoder_model_config['down_sample']))
28 | latent_frame_width = (dataset_config['frame_width']
29 | // 2 ** sum(autoencoder_model_config['down_sample']))
30 |
31 | xt = torch.randn((1, dataset_config['num_frames'],
32 | autoencoder_model_config['z_channels'],
33 | latent_frame_height, latent_frame_width)).to(device)
34 |
35 | for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
36 | # Get prediction of noise
37 | noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
38 |
39 | # Use scheduler to get x0 and xt-1
40 | xt, x0_pred = scheduler.sample_prev_timestep(xt,
41 | noise_pred,
42 | torch.as_tensor(i).to(device))
43 |
44 | # Save x0
45 | if i == 0:
46 | # Decode ONLY the final video to save time
47 | ims = vae.to(device).decode(xt[0])
48 | else:
49 | ims = xt
50 |
51 | ims = torch.clamp(ims, -1., 1.).detach().cpu()
52 | ims = (ims + 1) / 2
53 | tv_frames = ims * 255
54 |
55 | if i == 0:
56 | tv_frames = tv_frames.permute((0, 2, 3, 1))
57 | if tv_frames.shape[-1] == 1:
58 | tv_frames = tv_frames.repeat((1, 1, 1, 3))
59 | else:
60 | tv_frames = v2.Compose([
61 | v2.Resize((dataset_config['frame_height'],
62 | dataset_config['frame_width']),
63 | interpolation=v2.InterpolationMode.NEAREST),
64 | ])(tv_frames)
65 | tv_frames = tv_frames[0].permute((0, 2, 3, 1))[:, :, :, :3]
66 |
67 | if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
68 | os.mkdir(os.path.join(train_config['task_name'], 'samples'))
69 | torchvision.io.write_video(os.path.join(train_config['task_name'],
70 | 'samples/sample_output_{}.mp4'.format(i)),
71 | tv_frames,
72 | fps=8)
73 |
74 |
75 | def infer(args):
76 | # Read the config file #
77 | with open(args.config_path, 'r') as file:
78 | try:
79 | config = yaml.safe_load(file)
80 | except yaml.YAMLError as exc:
81 | print(exc)
82 | print(config)
83 | ########################
84 |
85 | diffusion_config = config['diffusion_params']
86 | dataset_config = config['dataset_params']
87 | ditv_model_config = config['ditv_params']
88 | autoencoder_model_config = config['autoencoder_params']
89 | train_config = config['train_params']
90 |
91 | # Create the noise scheduler
92 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
93 | beta_start=diffusion_config['beta_start'],
94 | beta_end=diffusion_config['beta_end'])
95 |
96 | # Get latent image size
97 | frame_height = (dataset_config['frame_height']
98 | // 2 ** sum(autoencoder_model_config['down_sample']))
99 | frame_width = (dataset_config['frame_width']
100 | // 2 ** sum(autoencoder_model_config['down_sample']))
101 | num_frames = dataset_config['num_frames']
102 | model = DITVideo(frame_height=frame_height,
103 | frame_width=frame_width,
104 | im_channels=autoencoder_model_config['z_channels'],
105 | num_frames=num_frames,
106 | config=ditv_model_config).to(device)
107 | model.eval()
108 |
109 | assert os.path.exists(os.path.join(train_config['task_name'],
110 | train_config['ditv_ckpt_name'])), \
111 | "Train DiT Video Model first"
112 |
113 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
114 | train_config['ditv_ckpt_name']),
115 | map_location=device))
116 | print('Loaded dit video checkpoint')
117 |
118 | # Create output directories
119 | if not os.path.exists(train_config['task_name']):
120 | os.mkdir(train_config['task_name'])
121 |
122 | vae = VAE(im_channels=dataset_config['frame_channels'],
123 | model_config=autoencoder_model_config)
124 | vae.eval()
125 |
126 | # Load vae if found
127 | assert os.path.exists(os.path.join(train_config['task_name'],
128 | train_config['vae_autoencoder_ckpt_name'])), \
129 | "VAE checkpoint not present. Train VAE first."
130 | vae.load_state_dict(torch.load(
131 | os.path.join(train_config['task_name'],
132 | train_config['vae_autoencoder_ckpt_name']),
133 | map_location=device), strict=True)
134 | print('Loaded vae checkpoint')
135 |
136 | with torch.no_grad():
137 | sample(model, scheduler, train_config, ditv_model_config,
138 | autoencoder_model_config, diffusion_config, dataset_config, vae)
139 |
140 |
141 | if __name__ == '__main__':
142 | parser = argparse.ArgumentParser(description='Arguments for latte video generation')
143 | parser.add_argument('--config', dest='config_path',
144 | default='config/mnist.yaml', type=str)
145 | args = parser.parse_args()
146 | infer(args)
147 |
--------------------------------------------------------------------------------
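A stripped-down skeleton of the reverse loop in sample(), handy for smoke-testing the scheduler without a trained DITVideo or VAE. All shapes and hyperparameters below are placeholders, and the zero tensor merely stands in for model(xt, t).

import torch
from scheduler.linear_scheduler import LinearNoiseScheduler

scheduler = LinearNoiseScheduler(num_timesteps=50, beta_start=1e-4, beta_end=0.02)
xt = torch.randn(1, 16, 4, 7, 7)                    # random video latents
for i in reversed(range(50)):
    noise_pred = torch.zeros_like(xt)               # placeholder for model(xt, t)
    xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i))
print(xt.shape, x0_pred.shape)
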
/tools/save_latents.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import pickle
5 |
6 | import torch
7 | import torchvision
8 | import yaml
9 | # Note: the dataset package only provides ImageDataset and VideoDataset;
10 | # VideoDataset (imported below) is used for both the UCF101 and Moving MNIST configs.
11 | from dataset.video_dataset import VideoDataset
12 |
13 | from torch.utils.data.dataloader import DataLoader
14 | import torchvision.transforms.v2
15 | from torchvision.utils import make_grid
16 | from tqdm import tqdm
17 |
18 |
19 | from model.vae import VAE
20 |
21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 | if torch.backends.mps.is_available():
23 | device = torch.device('mps')
24 | print('Using mps')
25 |
26 |
27 | def save_vae_latents(args):
28 | ######## Read the config file #######
29 | with open(args.config_path, 'r') as file:
30 | try:
31 | config = yaml.safe_load(file)
32 | except yaml.YAMLError as exc:
33 | print(exc)
34 | print(config)
35 |
36 | dataset_config = config['dataset_params']
37 | autoencoder_config = config['autoencoder_params']
38 | train_config = config['train_params']
39 |
40 | model = VAE(im_channels=dataset_config['frame_channels'],
41 | model_config=autoencoder_config).to(device)
42 | model.load_state_dict(torch.load(os.path.join(
43 | train_config['task_name'],
44 | train_config['vae_autoencoder_ckpt_name']),
45 | map_location=device))
46 | model.eval()
47 |
48 | dataset = VideoDataset('train',
49 | dataset_config)
50 |
51 | print('Will be generating latents for {} videos'.format(len(dataset.video_paths)))
52 | with torch.no_grad():
53 | if not os.path.exists(os.path.join(train_config['task_name'],
54 | train_config['save_video_latent_dir'])):
55 | os.mkdir(os.path.join(train_config['task_name'],
56 | train_config['save_video_latent_dir']))
57 | for path in tqdm(dataset.video_paths):
58 | # Read the video
59 | frames, _, _ = torchvision.io.read_video(filename=path,
60 | pts_unit='sec',
61 | output_format='TCHW')
62 |
63 | # Transform all frames
64 | frames_tensor = dataset.transforms(frames)
65 | if dataset_config['frame_channels'] == 1:
66 | frames_tensor = frames_tensor[:, 0:1, :, :]
67 |
68 | encoded_outputs = []
69 | for frame_tensor in frames_tensor:
70 | _, encoded_output = model.encode(
71 | frame_tensor.float().unsqueeze(0).to(device)
72 | )
73 | encoded_outputs.append(encoded_output)
74 | encoded_outputs = torch.cat(encoded_outputs, dim=0)
75 | pickle.dump(encoded_outputs, open(
76 | os.path.join(train_config['task_name'],
77 | train_config['save_video_latent_dir'],
78 | '{}.pkl'.format(os.path.basename(path))), 'wb'))
79 |
80 |
81 | if __name__ == '__main__':
82 | parser = argparse.ArgumentParser(description='Arguments for vae inference and '
83 | 'saving latents')
84 | parser.add_argument('--config', dest='config_path',
85 | default='config/ucf.yaml', type=str)
86 | args = parser.parse_args()
87 | save_vae_latents(args)
88 |
--------------------------------------------------------------------------------
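The pickle files written above can be read back directly; each one holds the full encoder output (after pre_quant_conv) for a single video, i.e. 2 * z_channels channels per frame (mean and log-variance stacked). A minimal sketch with placeholder paths; the real ones come from train_params in the config yaml.

import glob
import os
import pickle

latent_dir = os.path.join('task_name', 'save_video_latent_dir')   # placeholder path
latent_files = sorted(glob.glob(os.path.join(latent_dir, '*.pkl')))

with open(latent_files[0], 'rb') as f:
    encoded = pickle.load(f)
print(encoded.shape)   # (num_frames, 2 * z_channels, latent_h, latent_w)
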
/tools/train_vae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import torchvision
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from model.vae import VAE
10 | from model.lpips import LPIPS
11 | from model.discriminator import Discriminator
12 | from torch.utils.data.dataloader import DataLoader
13 | from torch.optim import Adam
14 | from torchvision.utils import make_grid
15 | from dataset.image_dataset import ImageDataset
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | if torch.backends.mps.is_available():
19 | device = torch.device('mps')
20 | print('Using mps')
21 |
22 |
23 | def train(args):
24 | # Read the config file #
25 | with open(args.config_path, 'r') as file:
26 | try:
27 | config = yaml.safe_load(file)
28 | except yaml.YAMLError as exc:
29 | print(exc)
30 | print(config)
31 |
32 | dataset_config = config['dataset_params']
33 | autoencoder_config = config['autoencoder_params']
34 | train_config = config['train_params']
35 |
36 | # Set the desired seed value #
37 | seed = train_config['seed']
38 | torch.manual_seed(seed)
39 | np.random.seed(seed)
40 | random.seed(seed)
41 |     if device.type == 'cuda':
42 | torch.cuda.manual_seed_all(seed)
43 | #############################
44 |
45 | # Create the model and dataset #
46 | model = VAE(im_channels=dataset_config['frame_channels'],
47 | model_config=autoencoder_config).to(device)
48 |
49 | # Create the dataset
50 | im_dataset = ImageDataset(split='train',
51 | dataset_config=dataset_config,
52 | task_name=train_config['task_name'])
53 |
54 | data_loader = DataLoader(im_dataset,
55 | batch_size=train_config['autoencoder_batch_size'],
56 | shuffle=True)
57 |
58 | # Create output directories
59 | if not os.path.exists(train_config['task_name']):
60 | os.mkdir(train_config['task_name'])
61 |
62 | num_epochs = train_config['autoencoder_epochs']
63 |
64 | # L1/L2 loss for Reconstruction
65 | recon_criterion = torch.nn.MSELoss()
66 | # Disc Loss can even be BCEWithLogits
67 | disc_criterion = torch.nn.MSELoss()
68 |
69 | # No need to freeze lpips as lpips.py takes care of that
70 | lpips_model = LPIPS().eval().to(device)
71 | discriminator = Discriminator(im_channels=dataset_config['frame_channels']).to(device)
72 |
73 | if os.path.exists(os.path.join(train_config['task_name'],
74 | train_config['vae_autoencoder_ckpt_name'])):
75 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
76 | train_config['vae_autoencoder_ckpt_name']),
77 | map_location=device))
78 | print('Loaded autoencoder from checkpoint')
79 |
80 | if os.path.exists(os.path.join(train_config['task_name'],
81 | train_config['vae_discriminator_ckpt_name'])):
82 | discriminator.load_state_dict(torch.load(os.path.join(train_config['task_name'],
83 | train_config['vae_discriminator_ckpt_name']),
84 | map_location=device))
85 | print('Loaded discriminator from checkpoint')
86 |
87 | optimizer_d = Adam(discriminator.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
88 | optimizer_g = Adam(model.parameters(), lr=train_config['autoencoder_lr'], betas=(0.5, 0.999))
89 |
90 | disc_step_start = train_config['disc_start']
91 | step_count = 0
92 |
93 |     # This is for accumulating gradients in case the images are huge
94 |     # and one can't afford higher batch sizes
95 | acc_steps = train_config['autoencoder_acc_steps']
96 | image_save_steps = train_config['autoencoder_img_save_steps']
97 | img_save_count = 0
98 |
99 | for epoch_idx in range(num_epochs):
100 | recon_losses = []
101 | perceptual_losses = []
102 | disc_losses = []
103 | gen_losses = []
104 | losses = []
105 |
106 | optimizer_g.zero_grad()
107 | optimizer_d.zero_grad()
108 |
109 | for im in tqdm(data_loader):
110 | step_count += 1
111 | im = im.float().to(device)
112 |
113 |             # Fetch autoencoder's output (reconstructions)
114 | model_output = model(im)
115 | output, encoder_output = model_output
116 |
117 | # Image Saving Logic
118 | if step_count % image_save_steps == 0 or step_count == 1:
119 | sample_size = min(8, im.shape[0])
120 | save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
121 | save_output = ((save_output + 1) / 2)
122 | save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
123 |
124 | grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
125 | img = torchvision.transforms.ToPILImage()(grid)
126 | if not os.path.exists(os.path.join(train_config['task_name'], 'vae_autoencoder_samples')):
127 | os.mkdir(os.path.join(train_config['task_name'], 'vae_autoencoder_samples'))
128 | img.save(os.path.join(train_config['task_name'], 'vae_autoencoder_samples',
129 | 'current_autoencoder_sample_{}.png'.format(img_save_count)))
130 | img_save_count += 1
131 | img.close()
132 |
133 | ######### Optimize Generator ##########
134 | # L2 Loss
135 | recon_loss = recon_criterion(output, im)
136 | recon_losses.append(recon_loss.item())
137 | recon_loss = recon_loss / acc_steps
138 |
139 | mean, logvar = torch.chunk(encoder_output, 2, dim=1)
140 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mean ** 2 - 1 - logvar, dim=[1, 2, 3]))
141 |
142 | g_loss = recon_loss + (train_config['kl_weight'] * kl_loss / acc_steps)
143 |
144 | # Adversarial loss only if disc_step_start steps passed
145 | if step_count > disc_step_start:
146 | disc_fake_pred = discriminator(model_output[0])
147 | disc_fake_loss = disc_criterion(disc_fake_pred,
148 | torch.ones(disc_fake_pred.shape,
149 | device=disc_fake_pred.device))
150 | gen_losses.append(train_config['disc_weight'] * disc_fake_loss.item())
151 | g_loss += train_config['disc_weight'] * disc_fake_loss / acc_steps
152 | lpips_loss = torch.mean(lpips_model(output, im))
153 | perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item())
154 | g_loss += train_config['perceptual_weight'] * lpips_loss / acc_steps
155 | losses.append(g_loss.item())
156 | g_loss.backward()
157 | #####################################
158 |
159 | ######### Optimize Discriminator #######
160 | if step_count > disc_step_start:
161 | fake = output
162 | disc_fake_pred = discriminator(fake.detach())
163 | disc_real_pred = discriminator(im)
164 | disc_fake_loss = disc_criterion(disc_fake_pred,
165 | torch.zeros(disc_fake_pred.shape,
166 | device=disc_fake_pred.device))
167 | disc_real_loss = disc_criterion(disc_real_pred,
168 | torch.ones(disc_real_pred.shape,
169 | device=disc_real_pred.device))
170 | disc_loss = train_config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2
171 | disc_losses.append(disc_loss.item())
172 | disc_loss = disc_loss / acc_steps
173 | disc_loss.backward()
174 | if step_count % acc_steps == 0:
175 | optimizer_d.step()
176 | optimizer_d.zero_grad()
177 | #####################################
178 |
179 | if step_count % acc_steps == 0:
180 | optimizer_g.step()
181 | optimizer_g.zero_grad()
182 | optimizer_d.step()
183 | optimizer_d.zero_grad()
184 | optimizer_g.step()
185 | optimizer_g.zero_grad()
186 | if len(disc_losses) > 0:
187 | print(
188 | 'Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
189 | 'G Loss : {:.4f} | D Loss {:.4f}'.
190 | format(epoch_idx + 1,
191 | np.mean(recon_losses),
192 | np.mean(perceptual_losses),
193 | np.mean(gen_losses),
194 | np.mean(disc_losses)))
195 | else:
196 | print('Finished epoch: {} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f}'.
197 | format(epoch_idx + 1,
198 | np.mean(recon_losses),
199 | np.mean(perceptual_losses)))
200 |
201 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
202 | train_config['vae_autoencoder_ckpt_name']))
203 | torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'],
204 | train_config['vae_discriminator_ckpt_name']))
205 | print('Done Training...')
206 |
207 |
208 | if __name__ == '__main__':
209 | parser = argparse.ArgumentParser(description='Arguments for vae training')
210 | parser.add_argument('--config', dest='config_path',
211 | default='config/ucf.yaml', type=str)
212 | args = parser.parse_args()
213 | train(args)
214 |
--------------------------------------------------------------------------------
/tools/train_vae_ditv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 | import argparse
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from einops import rearrange
8 | from torch.optim import AdamW
9 | from torch.utils.data import DataLoader
10 | from dataset.video_dataset import VideoDataset
11 | from model.transformer import DITVideo
12 | from model.vae import VAE
13 | from scheduler.linear_scheduler import LinearNoiseScheduler
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | if torch.backends.mps.is_available():
17 | device = torch.device('mps')
18 | print('Using mps')
19 |
20 |
21 | def train(args):
22 | # Read the config file #
23 | with open(args.config_path, 'r') as file:
24 | try:
25 | config = yaml.safe_load(file)
26 | except yaml.YAMLError as exc:
27 | print(exc)
28 | print(config)
29 | ########################
30 |
31 | diffusion_config = config['diffusion_params']
32 | dataset_config = config['dataset_params']
33 | ditv_model_config = config['ditv_params']
34 | autoencoder_model_config = config['autoencoder_params']
35 | train_config = config['train_params']
36 |
37 | # Create the noise scheduler
38 | scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
39 | beta_start=diffusion_config['beta_start'],
40 | beta_end=diffusion_config['beta_end'])
41 |
42 | dataset = VideoDataset('train',
43 | dataset_config=dataset_config,
44 | latent_path=os.path.join(
45 | train_config['task_name'],
46 | train_config['save_video_latent_dir']))
47 |
48 | data_loader = DataLoader(dataset,
49 | batch_size=train_config['ditv_batch_size'],
50 | shuffle=True)
51 |
52 | # Instantiate the model
53 | frame_height = (dataset_config['frame_height'] //
54 | 2 ** sum(autoencoder_model_config['down_sample']))
55 | frame_width = (dataset_config['frame_width'] //
56 | 2 ** sum(autoencoder_model_config['down_sample']))
57 | num_frames = dataset_config['num_frames']
58 | model = DITVideo(frame_height=frame_height,
59 | frame_width=frame_width,
60 | im_channels=autoencoder_model_config['z_channels'],
61 | num_frames=num_frames,
62 | config=ditv_model_config).to(device)
63 | model.train()
64 |
65 | if os.path.exists(os.path.join(train_config['task_name'],
66 | train_config['ditv_ckpt_name'])):
67 | print('Loaded DiT Video checkpoint')
68 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
69 | train_config['ditv_ckpt_name']),
70 | map_location=device))
71 |
72 | # Load VAE
73 | if not dataset.use_latents:
74 | print('Loading vae model as latents not present')
75 | vae = VAE(im_channels=dataset_config['frame_channels'],
76 | model_config=autoencoder_model_config).to(device)
77 | vae.eval()
78 | # Load vae if found
79 | assert os.path.exists(os.path.join(train_config['task_name'],
80 | train_config['vae_autoencoder_ckpt_name'])), \
81 | "VAE checkpoint not found"
82 | vae.load_state_dict(torch.load(os.path.join(
83 | train_config['task_name'],
84 | train_config['vae_autoencoder_ckpt_name']),
85 | map_location=device))
86 | print('Loaded vae checkpoint')
87 | for param in vae.parameters():
88 | param.requires_grad = False
89 |
90 | # Specify training parameters
91 | num_epochs = train_config['ditv_epochs']
92 | optimizer = AdamW(model.parameters(), lr=1E-4, weight_decay=0)
93 | criterion = torch.nn.MSELoss()
94 |
95 | acc_steps = train_config['ditv_acc_steps']
96 | for epoch_idx in range(num_epochs):
97 | losses = []
98 | step_count = 0
99 | for ims in tqdm(data_loader):
100 | step_count += 1
101 | ims = ims.float().to(device)
102 | B, F, C, H, W = ims.shape
103 | if not dataset.use_latents:
104 | with torch.no_grad():
105 | ims, _ = vae.encode(ims.reshape(-1, C, H, W))
106 | ims = rearrange(ims, '(b f) c h w -> b f c h w', b=B, f=F)
107 |
108 | # Sample random noise
109 | noise = torch.randn_like(ims).to(device)
110 |
111 | # Sample timestep
112 | t = torch.randint(0, diffusion_config['num_timesteps'],
113 | (ims.shape[0],)).to(device)
114 |
115 | # Add noise to video according to timestep
116 | noisy_im = scheduler.add_noise(ims, noise, t)
117 | pred = model(noisy_im, t, num_images=dataset_config['num_images_train'])
118 | loss = criterion(pred, noise)
119 | losses.append(loss.item())
120 | loss = loss / acc_steps
121 | loss.backward()
122 | if step_count % acc_steps == 0:
123 | optimizer.step()
124 | optimizer.zero_grad()
125 |         optimizer.step()  # flush leftover accumulated gradients at epoch end
126 |         optimizer.zero_grad()
127 | print('Finished epoch:{} | Loss : {:.4f}'.format(
128 | epoch_idx + 1,
129 | np.mean(losses)))
130 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
131 | train_config['ditv_ckpt_name']))
132 |
133 | print('Done Training ...')
134 |
135 |
136 | if __name__ == '__main__':
137 | parser = argparse.ArgumentParser(description='Arguments for latte training')
138 | parser.add_argument('--config', dest='config_path',
139 | default='config/mnist.yaml', type=str)
140 | args = parser.parse_args()
141 | train(args)
142 |
--------------------------------------------------------------------------------