├── video_model ├── diffusion │ ├── __init__.py │ ├── fp16_util.py │ ├── losses.py │ ├── dist_util.py │ ├── respace.py │ ├── nn.py │ ├── resample.py │ ├── script_util.py │ ├── logger.py │ ├── train_util.py │ └── make_a_video.py ├── datasets │ ├── clevrer.py │ ├── MovingMNIST.py │ └── video_datasets.py ├── video_train.py └── video_sample.py ├── image_model ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── respace.py │ └── timestep_sampler.py ├── datasets.py ├── sample.py ├── train_JPDVT.py └── models.py └── README.md /video_model/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | 4 | Raw copy of [1] at commit 32b43b6c677df7642c5408c8ef4a09272787eb50 from Feb 22, 2021. 5 | 6 | [1] https://github.com/openai/improved-diffusion/tree/main/improved_diffusion 7 | """ 8 | from .gaussian_diffusion import GaussianDiffusion 9 | from .unet import UNetModel 10 | from .resample import ScheduleSampler, UniformSampler, HarmonicSampler, LossAwareSampler, LossSecondMomentResampler, create_named_schedule_sampler 11 | from .respace import SpacedDiffusion, space_timesteps 12 | -------------------------------------------------------------------------------- /image_model/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=True, 15 | predict_xstart=True, 16 | learn_sigma=False, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /video_model/datasets/clevrer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import glob 6 | from PIL import Image 7 | import torchvision.transforms as T 8 | class Clevrer(Dataset): 9 | def __init__(self, npy_dir,seq_len=32, downsample_ratio=4,downsample_mode='Notuniform'): 10 | self.npy_files = sorted(glob.glob(os.path.join(npy_dir, 
'*','*.npy'))) 11 | self.dsr = downsample_ratio 12 | self.down_mode = downsample_mode 13 | self.seq_len = seq_len 14 | self.clip_num = 128 // self.dsr // self.seq_len 15 | 16 | def __len__(self): 17 | return len(self.npy_files) * self.clip_num 18 | 19 | def __getitem__(self, idx): 20 | npy_file = self.npy_files[idx // self.clip_num] 21 | video_data = np.load(npy_file) 22 | video_data = np.transpose(video_data, (0, 3, 1, 2)) # move channel axis to the first dimension 23 | if self.down_mode == 'uniform': 24 | video_data = video_data[::self.dsr] 25 | video_data = video_data[ 26 | idx % self.clip_num * self.seq_len: 27 | idx % self.clip_num * self.seq_len + min(self.seq_len,video_data.shape[0])] 28 | video_data = torch.from_numpy(video_data).float() 29 | return video_data 30 | else: 31 | frame_num = np.asarray(range(0,128//self.dsr)) 32 | frame_num_uneven = [min(x*self.dsr + np.random.randint(-self.dsr//2,self.dsr//2+1),127) for x in frame_num] 33 | frame_num_uneven[0] = 0 34 | uneven_video_data = video_data[frame_num_uneven] 35 | uneven_video_data = uneven_video_data[ 36 | idx % self.clip_num * self.seq_len: 37 | idx % self.clip_num * self.seq_len + min(self.seq_len, video_data.shape[0])] 38 | video_data = video_data[::self.dsr] 39 | video_data = video_data[ 40 | idx % self.clip_num * self.seq_len: 41 | idx % self.clip_num * self.seq_len + min(self.seq_len, video_data.shape[0])] 42 | uneven_video_data = torch.from_numpy(uneven_video_data).float() 43 | video_data = torch.from_numpy(video_data).float() 44 | return video_data,uneven_video_data,idx // self.clip_num 45 | 46 | -------------------------------------------------------------------------------- /video_model/diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 
56 |     model_params = list(model_params)
57 | 
58 |     for param, master_param in zip(
59 |         model_params, unflatten_master_params(model_params, master_params)
60 |     ):
61 |         param.detach().copy_(master_param)
62 | 
63 | 
64 | def unflatten_master_params(model_params, master_params):
65 |     """
66 |     Unflatten the master parameters to look like model_params.
67 |     """
68 |     return _unflatten_dense_tensors(master_params[0].detach(), model_params)
69 | 
70 | 
71 | def zero_grad(model_params):
72 |     for param in model_params:
73 |         # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
74 |         if param.grad is not None:
75 |             param.grad.detach_()
76 |             param.grad.zero_()
77 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers (JPDVT)
Official PyTorch Implementation 2 | [CVPR 2024] Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers 3 | 4 | ### [[Paper]](https://openaccess.thecvf.com/content/CVPR2024/papers/Liu_Solving_Masked_Jigsaw_Puzzles_with_Diffusion_Vision_Transformers_CVPR_2024_paper.pdf) [[Arxiv]](https://arxiv.org/abs/2404.07292v1) 5 | 6 | **This GitHub repository is currently undergoing organization.** Stay tuned for the upcoming release of fully functional code! 7 | 8 | Main Arch 9 | 10 | ## Setup 11 | git clone https://github.com/JinyangMarkLiu/JPDVT.git 12 | cd JPDVT 13 | 14 | ## Preparing Data 15 | Download datasets as you need. Here we give brief instructions for setting up part of the datasets we used. 16 | 17 | #### _ImageNet_ 18 | You can use this [script](https://gist.github.com/bonlime/4e0d236cf98cd5b15d977dfa03a63643) to download and prepare the _ImageNet_ dataset. If you need to download the dataset, please uncomment the first part of the script. 19 | 20 | #### _JPwLEG-3_ 21 | Download the _JPwLEG-3_ from this [Google Drive](https://drive.google.com/drive/folders/1MjPm7ar-u6H5WX6Bw2qshPiYPT_eQCZE). Only [select_image](https://drive.google.com/drive/folders/1MjPm7ar-u6H5WX6Bw2qshPiYPT_eQCZE) part is used in our experiments. 22 | 23 | ## Training 24 | We provide training scripts for training image models and video models. 25 | 26 | ### Training image models 27 | On ImageNet dataset: 28 | 29 | torchrun --nnodes=1 --nproc_per_node=4 train_JPDVT.py --dataset imagenet --data-path --image-size 192 --crop 30 | 31 | On MET dataset: 32 | 33 | torchrun --nnodes=1 --nproc_per_node=4 train_JPDVT.py --dataset met --data-path --image-size 288 --epochs 1000 34 | 35 | ## Testing 36 | 37 | 38 | ## BibTeX 39 | If you find our paper/project useful, please consider citing our paper: 40 | 41 | ```bibtex 42 | @InProceedings{Liu_2024_CVPR, 43 | author = {Liu, Jinyang and Teshome, Wondmgezahu and Ghimire, Sandesh and Sznaier, Mario and Camps, Octavia}, 44 | title = {Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers}, 45 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 46 | month = {June}, 47 | year = {2024}, 48 | pages = {23009-23018} 49 | } 50 | ``` 51 | 52 | ## Acknowledgments 53 | Our codebase is mainly based on [improved diffusion](https://github.com/openai/improved-diffusion), [make a video](https://github.com/lucidrains/make-a-video-pytorch), and [DiT](https://github.com/facebookresearch/DiT). 54 | -------------------------------------------------------------------------------- /video_model/diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. 
Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /video_model/diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | 12 | import torch as th 13 | import torch.distributed as dist 14 | 15 | # Change this to reflect your cluster layout. 16 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 17 | GPUS_PER_NODE = 8 18 | 19 | SETUP_RETRY_COUNT = 3 20 | 21 | th.cuda.set_device(0) 22 | 23 | def setup_dist(): 24 | """ 25 | Setup a distributed process group. 26 | """ 27 | if dist.is_initialized(): 28 | return 29 | 30 | comm = MPI.COMM_WORLD 31 | backend = "gloo" if not th.cuda.is_available() else "nccl" 32 | 33 | if backend == "gloo": 34 | hostname = "localhost" 35 | else: 36 | hostname = socket.gethostbyname(socket.getfqdn()) 37 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 38 | os.environ["RANK"] = str(comm.rank) 39 | os.environ["WORLD_SIZE"] = str(comm.size) 40 | print(comm.size) 41 | 42 | port = comm.bcast(_find_free_port(), root=0) 43 | os.environ["MASTER_PORT"] = str(port) 44 | dist.init_process_group(backend=backend, world_size=comm.size, init_method="env://") 45 | 46 | 47 | # def dev(): 48 | # """ 49 | # Get the device to use for torch.distributed. 50 | # """ 51 | # if th.cuda.is_available(): 52 | # return th.device(f"cuda:{3}") 53 | # return th.device("cpu") 54 | 55 | 56 | def dev(): 57 | """ 58 | Get the device to use for torch.distributed. 
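    (Note: the rank-based index, MPI rank % GPUS_PER_NODE, is computed below, but the
    device that is actually returned is hard-coded to cuda:0.)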
59 | """ 60 | if th.cuda.is_available(): 61 | cuda_device = MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE # for other except skeleton 62 | # cuda_device=3 63 | # print('this is device',os.environ.get('CUDA_VISIBLE_DEVICES', '')) 64 | # import pdb 65 | # pdb.set_trace() 66 | # print(f"cuda:{cuda_device}") 67 | return th.device(f"cuda:{0}") 68 | return th.device("cpu") 69 | 70 | 71 | def load_state_dict(path, **kwargs): 72 | """ 73 | Load a PyTorch file without redundant fetches across MPI ranks. 74 | """ 75 | if MPI.COMM_WORLD.Get_rank() == 0: 76 | with bf.BlobFile(path, "rb") as f: 77 | data = f.read() 78 | else: 79 | data = None 80 | data = MPI.COMM_WORLD.bcast(data) 81 | return th.load(io.BytesIO(data), **kwargs) 82 | 83 | def load_opt_state_dict(path, **kwargs): 84 | """ 85 | Load a PyTorch file without redundant fetches across MPI ranks. 86 | """ 87 | 88 | with bf.BlobFile(path, "rb") as f: 89 | data = f.read() 90 | 91 | return th.load(io.BytesIO(data), **kwargs) 92 | 93 | 94 | def sync_params(params): 95 | """ 96 | Synchronize a sequence of Tensors across ranks from rank 0. 97 | """ 98 | for p in params: 99 | with th.no_grad(): 100 | dist.broadcast(p, 0) 101 | 102 | 103 | def _find_free_port(): 104 | try: 105 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 106 | s.bind(("", 0)) 107 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 108 | return s.getsockname()[1] 109 | finally: 110 | s.close() 111 | -------------------------------------------------------------------------------- /image_model/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /image_model/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import glob 6 | from PIL import Image 7 | import torchvision.transforms as T 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | from torchvision import transforms 12 | from sklearn.model_selection import train_test_split 13 | import cv2 14 | 15 | class MET(Dataset): 16 | def __init__(self, image_dir,split): 17 | seed = 42 18 | torch.manual_seed(seed) 19 | self.split = split 20 | 21 | all_files = os.listdir(image_dir) 22 | self.image_files = [os.path.join(image_dir,all_files[0])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[0]))] 23 | self.image_files += [os.path.join(image_dir,all_files[1])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[1]))] 24 | self.image_files += [os.path.join(image_dir,all_files[2])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[2]))] 25 | # +os.listdir(os.path.join(image_dir,all_files[1]))+os.listdir(os.path.join(image_dir,all_files[2])) 26 | for image in self.image_files: 27 | if '.jpg' not in image: 28 | self.image_files.remove(image) 29 | dataset_indices = list(range(len( self.image_files))) 30 | 31 | train_indices, test_indices = train_test_split(dataset_indices, test_size=2000, random_state=seed) 32 | train_indices, val_indices = train_test_split(train_indices, test_size=1000, random_state=seed) 33 | self.train_indices = train_indices 34 | self.test_indices = test_indices 35 | self.val_indices = val_indices 36 | 37 | # Define the color jitter parameters 38 | brightness = 0.4 # Randomly adjust brightness with a maximum factor of 0.4 39 | contrast = 0.4 # Randomly adjust contrast with a maximum factor of 0.4 40 | saturation = 0.4 # Randomly adjust saturation with a maximum factor of 0.4 41 | hue = 0.1 # Randomly adjust hue with a maximum factor of 0.1 42 | 43 | # Because the Met is a much more smaller 
dataset, so we use more complex data augmentation here 44 | flip_probability = 0.5 45 | self.transform1 = transforms.Compose([ 46 | transforms.Resize(398), 47 | transforms.RandomCrop((398,398)), 48 | transforms.RandomHorizontalFlip(p=flip_probability), # Horizontal flipping with 0.5 probability 49 | transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 52 | ]) 53 | 54 | self.transform2 = transforms.Compose([ 55 | transforms.Resize(398), 56 | transforms.CenterCrop((398,398)), 57 | transforms.ToTensor(), 58 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 59 | ]) 60 | 61 | def __len__(self): 62 | if self.split == 'train': 63 | return len(self.train_indices) 64 | elif self.split == 'val': 65 | return len(self.val_indices) 66 | elif self.split == 'test': 67 | return len(self.test_indices) 68 | 69 | def rand_erode(self,image,n_patches): 70 | output = torch.zeros(3,96*3,96*3) 71 | crop = transforms.RandomCrop((96,96)) 72 | gap = 48 73 | patch_size = 100 74 | for i in range(n_patches): 75 | for j in range(n_patches): 76 | left = i * (patch_size + gap) 77 | upper = j * (patch_size + gap) 78 | right = left + patch_size 79 | lower = upper + patch_size 80 | 81 | patch = crop(image[:,left:right, upper:lower]) 82 | output[:,i*96:i*96+96,j*96:j*96+96] = patch 83 | 84 | return output 85 | 86 | def __getitem__(self, idx): 87 | if self.split == 'train': 88 | index = self.train_indices[idx] 89 | image = self.transform1(Image.open(self.image_files[index])) 90 | image = self.rand_erode(image,3) 91 | elif self.split == 'val': 92 | index = self.val_indices[idx] 93 | image = self.transform2(Image.open(self.image_files[index])) 94 | image = self.rand_erode(image,3) 95 | elif self.split == 'test': 96 | index = self.test_indices[idx] 97 | image = self.transform2(Image.open(self.image_files[index])) 98 | image = self.rand_erode(image,3) 99 | 100 | return image 101 | -------------------------------------------------------------------------------- /video_model/video_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
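A minimal launch sketch (paths and values are placeholders; flag names assume
add_dict_to_argparser exposes each key of create_argparser's defaults as a
same-named command-line flag):

    python video_train.py --data_dir /path/to/videos --batch_size 16 --seq_len 32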
3 | """ 4 | # import torch.distributed as dist 5 | import numpy as np 6 | import os 7 | import gc 8 | 9 | import torch 10 | 11 | torch.cuda.empty_cache() 12 | gc.collect() 13 | 14 | import argparse 15 | import sys 16 | 17 | sys.path.insert(1, os.getcwd()) 18 | # sys.path.insert(1, '/diffusion_openai') 19 | import dist_util, logger 20 | # from diffusion_openai import dist_util, logger 21 | from video_datasets import load_data 22 | from resample import create_named_schedule_sampler 23 | from script_util import ( 24 | model_and_diffusion_defaults, 25 | create_model_and_diffusion, 26 | args_to_dict, 27 | add_dict_to_argparser, 28 | ) 29 | 30 | from train_util import TrainLoop 31 | 32 | 33 | def main(): 34 | 35 | parser, defaults = create_argparser() 36 | args = parser.parse_args() 37 | parameters = args_to_dict(args, defaults.keys()) 38 | # th.manual_seed(args.seed) 39 | # np.random.seed(args.seed) 40 | 41 | dist_util.setup_dist() 42 | logger.configure() 43 | for key, item in parameters.items(): 44 | logger.logkv(key, item) 45 | logger.dumpkvs() 46 | 47 | logger.log("creating model and diffusion...") 48 | model, diffusion = create_model_and_diffusion( 49 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 50 | ) 51 | 52 | if len(args.model_path)>0: 53 | print("load ",args.model_path," to the model.") 54 | model.load_state_dict( 55 | dist_util.load_state_dict(args.model_path, map_location="cpu") 56 | ) 57 | 58 | model.to(dist_util.dev()) 59 | # model.summary() 60 | # breakpoint() 61 | # import torch 62 | # import torchviz 63 | 64 | # model = ... # create or load your PyTorch model 65 | 66 | 67 | print("device ",dist_util.dev()) 68 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 69 | 70 | # print('model arch',model) 71 | # Save the model architecture to a file 72 | # with open("./figures/model_architecture.txt", "w") as f: 73 | # f.write(str(model)) 74 | 75 | logger.log("creating data loader...") 76 | 77 | data = load_data( 78 | data_dir=args.data_dir, 79 | batch_size=args.batch_size, 80 | image_size=args.image_size, 81 | class_cond=args.class_cond, 82 | deterministic=False, 83 | rgb=args.rgb, 84 | seq_len=args.seq_len 85 | ) 86 | 87 | # item=next(data) 88 | # breakpoint() 89 | 90 | if args.mask_range is None: 91 | mask_range = [0, args.seq_len] 92 | else: 93 | mask_range = [int(i) for i in args.mask_range if i != ","] 94 | logger.log("training...") 95 | 96 | TrainLoop( 97 | model=model, 98 | diffusion=diffusion, 99 | data=data, 100 | batch_size=args.batch_size, 101 | microbatch=args.microbatch, 102 | lr=args.lr, 103 | ema_rate=args.ema_rate, 104 | log_interval=args.log_interval, 105 | save_interval=args.save_interval, 106 | resume_checkpoint=args.resume_checkpoint, 107 | use_fp16=args.use_fp16, 108 | fp16_scale_growth=args.fp16_scale_growth, 109 | schedule_sampler=schedule_sampler, 110 | weight_decay=args.weight_decay, 111 | lr_anneal_steps=args.lr_anneal_steps, 112 | clip=args.clip, 113 | anneal_type=args.anneal_type, 114 | steps_drop=args.steps_drop, 115 | drop=args.drop, 116 | decay=args.decay, 117 | max_num_mask_frames=args.max_num_mask_frames, 118 | mask_range=mask_range, 119 | uncondition_rate=args.uncondition_rate, 120 | exclude_conditional=args.exclude_conditional, 121 | ).run_loop() 122 | 123 | 124 | def create_argparser(): 125 | defaults = dict( 126 | data_dir="", 127 | schedule_sampler="uniform", 128 | lr=1e-4, # -4 129 | weight_decay=0.0, 130 | lr_anneal_steps=0, 131 | batch_size=16, # 8 for something 132 | microbatch=16, # 32, # -1 
disables microbatches 32 133 | ema_rate="0.9999", # comma-separated list of EMA values 134 | log_interval=32, # 10 135 | save_interval=10000, # 2000 100 136 | resume_checkpoint="", 137 | model_path="", 138 | use_fp16=False, 139 | fp16_scale_growth=1e-3, 140 | clip=1, 141 | seed=123, 142 | anneal_type=None, 143 | steps_drop=0.0, 144 | drop=0.0, 145 | decay=0.0, 146 | seq_len=32, # 20 147 | max_num_mask_frames=6, # 4 148 | mask_range=None, 149 | uncondition_rate=0, 150 | exclude_conditional=True, 151 | model_name='two_distribution', 152 | ) 153 | 154 | defaults.update(model_and_diffusion_defaults()) 155 | parser = argparse.ArgumentParser() 156 | add_dict_to_argparser(parser, defaults) 157 | return parser, defaults 158 | 159 | if __name__ == "__main__": 160 | main() 161 | import json 162 | import numpy as np 163 | -------------------------------------------------------------------------------- /image_model/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solve Jigsaw Puzzles with JPDVT 3 | """ 4 | import torch 5 | torch.backends.cuda.matmul.allow_tf32 = True 6 | torch.backends.cudnn.allow_tf32 = True 7 | from torchvision.utils import save_image 8 | from diffusion import create_diffusion 9 | from diffusers.models import AutoencoderKL 10 | from models import DiT_models 11 | import argparse 12 | from torch.utils.data import DataLoader 13 | from models import get_2d_sincos_pos_embed 14 | from datasets import MET 15 | import numpy as np 16 | from einops import rearrange 17 | import matplotlib.pyplot as plt 18 | from sklearn.metrics import pairwise_distances 19 | 20 | def main(args): 21 | # Setup PyTorch: 22 | torch.manual_seed(args.seed) 23 | torch.set_grad_enabled(False) 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | template = np.zeros((6,6)) 27 | 28 | for i in range(6): 29 | for j in range(6): 30 | template[i,j] = 18 * i + j 31 | 32 | template = np.concatenate((template,template,template),axis=0) 33 | template = np.concatenate((template,template,template),axis=1) 34 | 35 | # Load model: 36 | model = DiT_models[args.model]( 37 | input_size=args.image_size, 38 | ).to(device) 39 | print("Load model from:", args.ckpt ) 40 | ckpt_path = args.ckpt 41 | state_dict = torch.load(ckpt_path) 42 | model.load_state_dict(state_dict) 43 | 44 | # Because the batchnorm doesn't work normally when batch size is 1 45 | # Thus we set the model to train mode 46 | model.train() 47 | 48 | diffusion = create_diffusion(str(args.num_sampling_steps)) 49 | if args.dataset == "met": 50 | # MET dataloader give out cropped and stitched back images 51 | dataset = MET(args.data_path,'test') 52 | elif args.dataset == "imagenet": 53 | dataset = ImageFolder(args.data_path, transform=transform) 54 | 55 | loader = DataLoader( 56 | dataset, 57 | batch_size=1, 58 | shuffle=False, 59 | num_workers=2, 60 | pin_memory=False, 61 | drop_last=True 62 | ) 63 | 64 | time_emb = torch.tensor(get_2d_sincos_pos_embed(8, 3)).unsqueeze(0).float().to(device) 65 | time_emb_noise = torch.tensor(get_2d_sincos_pos_embed(8, 18)).unsqueeze(0).float().to(device) 66 | time_emb_noise = torch.randn_like(time_emb_noise) 67 | time_emb_noise = time_emb_noise.repeat(1,1,1) 68 | model_kwargs = None 69 | 70 | # find the order with a greedy algorithm 71 | def find_permutation(distance_matrix): 72 | sort_list = [] 73 | for m in range(distance_matrix.shape[1]): 74 | order = distance_matrix[:,0].argmin() 75 | sort_list.append(order) 76 | distance_matrix = distance_matrix[:,1:] 77 | 
distance_matrix[order,:] = 2024 78 | return sort_list 79 | 80 | abs_results = [] 81 | for x in loader: 82 | if args.dataset == 'imagenet': 83 | x, _ = x 84 | x = x.to(device) 85 | if args.dataset == 'imagenet' and args.crop: 86 | centercrop = transforms.CenterCrop((64,64)) 87 | patchs = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=96,w1=96) 88 | patchs = centercrop(patchs) 89 | x = rearrange(patchs, 'b c (p1 p2) h1 w1-> b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=64,w1=64) 90 | 91 | # Generate the Puzzles 92 | indices = np.random.permutation(9) 93 | x = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=args.image_size//3,w1=args.image_size//3) 94 | x = x[:,:,indices,:,:] 95 | x = rearrange(x, ' b c (p1 p2) h1 w1->b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=args.image_size//3,w1=args.image_size//3) 96 | 97 | samples = diffusion.p_sample_loop( 98 | model.forward, x, time_emb_noise.shape, time_emb_noise, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device 99 | ) 100 | for sample,img in zip(samples,x): 101 | sample = rearrange(sample, '(p1 h1 p2 w1) d-> (p1 p2) (h1 w1) d',p1=3,p2=3,h1=args.image_size//48,w1=args.image_size//48) 102 | sample = sample.mean(1) 103 | dist = pairwise_distances(sample.cpu().numpy(), time_emb[0].cpu().numpy(), metric='manhattan') 104 | order = find_permutation(dist) 105 | pred = np.asarray(order).argsort() 106 | abs_results.append(int((pred == indices).all())) 107 | 108 | print("test result on ",len(abs_results), "samples is :", np.asarray(abs_results).sum()/len(abs_results)) 109 | 110 | if len(abs_results)>=2000 and args.dataset == "met": 111 | break 112 | if len(abs_results)>=50000 and args.dataset == "imagenet": 113 | break 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--model", type=str, default="JPDVT") 118 | parser.add_argument("--dataset", type=str, choices=["imagenet", "met"], default="imagenet") 119 | parser.add_argument("--data-path", type=str,required=True) 120 | parser.add_argument("--crop", action='store_true', default=False) 121 | parser.add_argument("--image-size", type=int, choices=[192, 288], default=288) 122 | parser.add_argument("--num-sampling-steps", type=int, default=250) 123 | parser.add_argument("--seed", type=int, default=0) 124 | parser.add_argument("--ckpt", type=str, required=True) 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /video_model/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | # from video_sample import create_argparser 5 | # from video_datasets import load_data 6 | from gaussian_diffusion import GaussianDiffusion 7 | 8 | def space_timesteps(num_timesteps, section_counts): 9 | """ 10 | Create a list of timesteps to use from an original diffusion process, 11 | given the number of timesteps we want to take from equally-sized portions 12 | of the original process. 13 | 14 | For example, if there's 300 timesteps and the section counts are [10,15,20] 15 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 16 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 17 | 18 | If the stride is a string starting with "ddim", then the fixed striding 19 | from the DDIM paper is used, and only one section is allowed. 
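    For example, with 1000 original steps, "ddim50" resolves to a stride of 20 and keeps
    the 50 steps {0, 20, 40, ..., 980}.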
20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim"):]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | section_counts = [int(x) for x in section_counts.split(",")] 40 | size_per = num_timesteps // len(section_counts) 41 | extra = num_timesteps % len(section_counts) 42 | start_idx = 0 43 | all_steps = [] 44 | for i, section_count in enumerate(section_counts): 45 | size = size_per + (1 if i < extra else 0) 46 | if size < section_count: 47 | raise ValueError( 48 | f"cannot divide section of {size} steps into {section_count}" 49 | ) 50 | if section_count <= 1: 51 | frac_stride = 1 52 | else: 53 | frac_stride = (size - 1) / (section_count - 1) 54 | cur_idx = 0.0 55 | taken_steps = [] 56 | for _ in range(section_count): 57 | taken_steps.append(start_idx + round(cur_idx)) 58 | cur_idx += frac_stride 59 | all_steps += taken_steps 60 | start_idx += size 61 | return set(all_steps) 62 | 63 | 64 | class SpacedDiffusion(GaussianDiffusion): 65 | """ 66 | A diffusion process which can skip steps in a base diffusion process. 67 | 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def _wrap_model(self, model): 100 | if isinstance(model, _WrappedModel): 101 | return model 102 | return _WrappedModel( 103 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 104 | ) 105 | 106 | def _scale_timesteps(self, t): 107 | # Scaling is done by the wrapped model. 
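        # (_WrappedModel.__call__ below maps t through timestep_map and, when
        # rescale_timesteps is set, rescales by 1000 / original_num_steps.)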
108 | return t 109 | 110 | 111 | class _WrappedModel: 112 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 113 | self.model = model 114 | self.timestep_map = timestep_map 115 | self.rescale_timesteps = rescale_timesteps 116 | self.original_num_steps = original_num_steps 117 | 118 | 119 | def __call__(self, x, ts,condition_args): 120 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 121 | new_ts = map_tensor[ts] 122 | 123 | if self.rescale_timesteps: 124 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 125 | 126 | # breakpoint() 127 | 128 | return self.model(x, new_ts,condition_args) 129 | -------------------------------------------------------------------------------- /video_model/diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | def sum_flat(tensor): 93 | """ 94 | Take the sum over all non-batch dimensions. 95 | """ 96 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 97 | 98 | def normalization(channels): 99 | """ 100 | Make a standard normalization layer. 101 | 102 | :param channels: number of input channels. 103 | :return: an nn.Module for normalization. 
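    For example, normalization(128) returns GroupNorm32(32, 128): GroupNorm with 32 groups
    that computes in float32 and casts the result back to the input dtype.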
104 | """ 105 | return GroupNorm32(32, channels) 106 | 107 | 108 | def timestep_embedding(timesteps, dim, max_period=10000): 109 | """ 110 | Create sinusoidal timestep embeddings. 111 | 112 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 113 | These may be fractional. 114 | :param dim: the dimension of the output. 115 | :param max_period: controls the minimum frequency of the embeddings. 116 | :return: an [N x dim] Tensor of positional embeddings. 117 | """ 118 | half = dim // 2 119 | freqs = th.exp( 120 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 121 | ).to(device=timesteps.device) 122 | args = timesteps[:, None].float() * freqs[None] 123 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 124 | if dim % 2: 125 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 126 | return embedding 127 | 128 | 129 | def checkpoint(func, inputs, params, flag): 130 | """ 131 | Evaluate a function without caching intermediate activations, allowing for 132 | reduced memory at the expense of extra compute in the backward pass. 133 | 134 | :param func: the function to evaluate. 135 | :param inputs: the argument sequence to pass to `func`. 136 | :param params: a sequence of parameters `func` depends on but does not 137 | explicitly take as arguments. 138 | :param flag: if False, disable gradient checkpointing. 139 | """ 140 | if flag: 141 | args = tuple(inputs) + tuple(params) 142 | return CheckpointFunction.apply(func, len(inputs), *args) 143 | else: 144 | return func(*inputs) 145 | 146 | 147 | class CheckpointFunction(th.autograd.Function): 148 | @staticmethod 149 | def forward(ctx, run_function, length, *args): 150 | ctx.run_function = run_function 151 | ctx.input_tensors = list(args[:length]) 152 | ctx.input_params = list(args[length:]) 153 | with th.no_grad(): 154 | output_tensors = ctx.run_function(*ctx.input_tensors) 155 | return output_tensors 156 | 157 | @staticmethod 158 | def backward(ctx, *output_grads): 159 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 160 | with th.enable_grad(): 161 | # Fixes a bug where the first op in run_function modifies the 162 | # Tensor storage in place, which is not allowed for detach()'d 163 | # Tensors. 
164 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 165 | output_tensors = ctx.run_function(*shallow_copies) 166 | input_grads = th.autograd.grad( 167 | output_tensors, 168 | ctx.input_tensors + ctx.input_params, 169 | output_grads, 170 | allow_unused=True, 171 | ) 172 | del ctx.input_tensors 173 | del ctx.input_params 174 | del output_tensors 175 | return (None, None) + input_grads 176 | -------------------------------------------------------------------------------- /image_model/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 
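    It rebuilds the beta schedule over the retained timesteps only, and _WrappedModel (below)
    remaps the model's timestep indices back onto the original schedule via timestep_map.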
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, time_emb, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, time_emb, **kwargs) 130 | -------------------------------------------------------------------------------- /image_model/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 
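    For example, create_named_schedule_sampler("uniform", diffusion) returns a UniformSampler
    whose sample(batch_size, device) draws timesteps uniformly and yields unit loss weights.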
18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 
111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /video_model/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion, k=1): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | :param k: the series order. Default k=1. 15 | """ 16 | if name == "uniform": 17 | return UniformSampler(diffusion) 18 | elif name == "harmonic": 19 | return HarmonicSampler(diffusion, k) 20 | elif name == "loss-second-moment": 21 | return LossSecondMomentResampler(diffusion) 22 | else: 23 | raise NotImplementedError(f"unknown schedule sampler: {name}") 24 | 25 | 26 | class ScheduleSampler(ABC): 27 | """ 28 | A distribution over timesteps in the diffusion process, intended to reduce 29 | variance of the objective. 30 | 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | 42 | The weights needn't be normalized, but must be positive. 43 | """ 44 | 45 | def sample(self, batch_size, device): 46 | """ 47 | Importance-sample timesteps for a batch. 48 | 49 | :param batch_size: the number of timesteps. 50 | :param device: the torch device to save to. 51 | :return: a tuple (timesteps, weights): 52 | - timesteps: a tensor of timestep indices. 
53 | - weights: a tensor of weights to scale the resulting losses. 54 | """ 55 | w = self.weights() 56 | p = w / np.sum(w) 57 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 58 | indices = th.from_numpy(indices_np).long().to(device) 59 | weights_np = 1 / (len(p) * p[indices_np]) 60 | weights = th.from_numpy(weights_np).float().to(device) 61 | return indices, weights 62 | 63 | 64 | class UniformSampler(ScheduleSampler): 65 | def __init__(self, diffusion): 66 | self.diffusion = diffusion 67 | self._weights = np.ones([diffusion.num_timesteps]) 68 | 69 | def weights(self): 70 | return self._weights 71 | 72 | class HarmonicSampler(ScheduleSampler): 73 | def __init__(self, diffusion, k=1): 74 | self.diffusion = diffusion 75 | w = 1. / np.array([t+1 for t in range(diffusion.num_timesteps)]) 76 | w = w**k 77 | self._weights = w 78 | 79 | def weights(self): 80 | return self._weights 81 | 82 | 83 | class LossAwareSampler(ScheduleSampler): 84 | def update_with_local_losses(self, local_ts, local_losses): 85 | """ 86 | Update the reweighting using losses from a model. 87 | 88 | Call this method from each rank with a batch of timesteps and the 89 | corresponding losses for each of those timesteps. 90 | This method will perform synchronization to make sure all of the ranks 91 | maintain the exact same reweighting. 92 | 93 | :param local_ts: an integer Tensor of timesteps. 94 | :param local_losses: a 1D Tensor of losses. 95 | """ 96 | batch_sizes = [ 97 | th.tensor([0], dtype=th.int32, device=local_ts.device) 98 | for _ in range(dist.get_world_size()) 99 | ] 100 | dist.all_gather( 101 | batch_sizes, 102 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 103 | ) 104 | 105 | # Pad all_gather batches to be the maximum batch size. 106 | batch_sizes = [x.item() for x in batch_sizes] 107 | max_bs = max(batch_sizes) 108 | 109 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 110 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 111 | dist.all_gather(timestep_batches, local_ts) 112 | dist.all_gather(loss_batches, local_losses) 113 | timesteps = [ 114 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 115 | ] 116 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 117 | self.update_with_all_losses(timesteps, losses) 118 | 119 | @abstractmethod 120 | def update_with_all_losses(self, ts, losses): 121 | """ 122 | Update the reweighting using losses from a model. 123 | 124 | Sub-classes should override this method to update the reweighting 125 | using losses from the model. 126 | 127 | This method directly updates the reweighting without synchronizing 128 | between workers. It is called by update_with_local_losses from all 129 | ranks with identical arguments. Thus, it should have deterministic 130 | behavior to maintain state across workers. 131 | 132 | :param ts: a list of int timesteps. 133 | :param losses: a list of float losses, one per timestep. 
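For example, LossSecondMomentResampler below keeps a fixed-length history of recent losses for each timestep and turns the root-mean-square of that history into its sampling weights.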
134 | """ 135 | 136 | 137 | class LossSecondMomentResampler(LossAwareSampler): 138 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 139 | self.diffusion = diffusion 140 | self.history_per_term = history_per_term 141 | self.uniform_prob = uniform_prob 142 | self._loss_history = np.zeros( 143 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 144 | ) 145 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 146 | 147 | def weights(self): 148 | if not self._warmed_up(): 149 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 150 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 151 | weights /= np.sum(weights) 152 | weights *= 1 - self.uniform_prob 153 | weights += self.uniform_prob / len(weights) 154 | return weights 155 | 156 | def update_with_all_losses(self, ts, losses): 157 | for t, loss in zip(ts, losses): 158 | if self._loss_counts[t] == self.history_per_term: 159 | # Shift out the oldest loss term. 160 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 161 | self._loss_history[t, -1] = loss 162 | else: 163 | self._loss_history[t, self._loss_counts[t]] = loss 164 | self._loss_counts[t] += 1 165 | 166 | def _warmed_up(self): 167 | return (self._loss_counts == self.history_per_term).all() 168 | -------------------------------------------------------------------------------- /video_model/datasets/MovingMNIST.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import numpy as np 8 | import torch 9 | import codecs 10 | import torchvision.transforms as T 11 | 12 | class MovingMNIST(data.Dataset): 13 | """`MovingMNIST `_ Dataset. 14 | 15 | Args: 16 | root (string): Root directory of dataset where ``processed/training.pt`` 17 | and ``processed/test.pt`` exist. 18 | train (bool, optional): If True, creates dataset from ``training.pt``, 19 | otherwise from ``test.pt``. 20 | split (int, optional): Train/test split size. Number defines how many samples 21 | belong to test set. 22 | download (bool, optional): If true, downloads the dataset from the internet and 23 | puts it in root directory. If dataset is already downloaded, it is not 24 | downloaded again. 25 | transform (callable, optional): A function/transform that takes in an PIL image 26 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 27 | target_transform (callable, optional): A function/transform that takes in an PIL 28 | image and returns a transformed version. E.g, ``transforms.RandomCrop`` 29 | """ 30 | urls = [ 31 | 'https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz' 32 | ] 33 | raw_folder = 'raw' 34 | processed_folder = 'processed' 35 | training_file = 'moving_mnist_train.pt' 36 | test_file = 'moving_mnist_test.pt' 37 | 38 | def __init__(self, root, train=True, split=1000, transform=None, target_transform=None, download=False,image_size=32,seq_len = 20): 39 | self.root = os.path.expanduser(root) 40 | self.transform = transform 41 | self.target_transform = target_transform 42 | self.split = split 43 | self.train = train # training set or test set 44 | self.image_size = image_size 45 | self.seq_len = seq_len 46 | if download: 47 | self.download() 48 | 49 | if not self._check_exists(): 50 | raise RuntimeError('Dataset not found.' 
+ 51 | ' You can use download=True to download it') 52 | 53 | if self.train: 54 | self.train_data = torch.load( 55 | os.path.join(self.root, self.processed_folder, self.training_file)) 56 | else: 57 | self.test_data = torch.load( 58 | os.path.join(self.root, self.processed_folder, self.test_file)) 59 | 60 | def __getitem__(self, index): 61 | """ 62 | Args: 63 | index (int): Index 64 | 65 | Returns: 66 | tuple: (seq, target) where sampled sequences are splitted into a seq 67 | and target part 68 | """ 69 | 70 | # need to iterate over time 71 | def _transform_time(data): 72 | new_data = None 73 | for i in range(data.size(0)): 74 | img = Image.fromarray(data[i].numpy(), mode='L') 75 | new_data = self.transform(img) if new_data is None else torch.cat([self.transform(img), new_data], dim=0) 76 | return new_data 77 | 78 | if self.train: 79 | seq, target = self.train_data[index,:self.seq_len], {} 80 | else: 81 | seq, target = self.test_data[index,:self.seq_len], {} 82 | 83 | if self.transform is not None: 84 | seq = _transform_time(seq) 85 | if self.target_transform is not None: 86 | target = _transform_time(target) 87 | seq = (seq.type(torch.float32) / 127.5 - 1).unsqueeze(0) 88 | resize = T.Resize((self.image_size,self.image_size)) 89 | seq = resize(seq) 90 | return seq, target 91 | 92 | def __len__(self): 93 | if self.train: 94 | return len(self.train_data) 95 | else: 96 | return len(self.test_data) 97 | 98 | def _check_exists(self): 99 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ 100 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) 101 | 102 | def download(self): 103 | """Download the Moving MNIST data if it doesn't exist in processed_folder already.""" 104 | from six.moves import urllib 105 | import gzip 106 | 107 | if self._check_exists(): 108 | return 109 | 110 | # download files 111 | try: 112 | os.makedirs(os.path.join(self.root, self.raw_folder)) 113 | os.makedirs(os.path.join(self.root, self.processed_folder)) 114 | except OSError as e: 115 | if e.errno == errno.EEXIST: 116 | pass 117 | else: 118 | raise 119 | 120 | for url in self.urls: 121 | print('Downloading ' + url) 122 | data = urllib.request.urlopen(url) 123 | filename = url.rpartition('/')[2] 124 | file_path = os.path.join(self.root, self.raw_folder, filename) 125 | with open(file_path, 'wb') as f: 126 | f.write(data.read()) 127 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 128 | gzip.GzipFile(file_path) as zip_f: 129 | out_f.write(zip_f.read()) 130 | os.unlink(file_path) 131 | 132 | # process and save as torch files 133 | print('Processing...') 134 | 135 | training_set = torch.from_numpy( 136 | np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[:-self.split] 137 | ) 138 | test_set = torch.from_numpy( 139 | np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[-self.split:] 140 | ) 141 | 142 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: 143 | torch.save(training_set, f) 144 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: 145 | torch.save(test_set, f) 146 | 147 | print('Done!') 148 | 149 | def __repr__(self): 150 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 151 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 152 | tmp = 'train' if self.train is True else 'test' 153 | fmt_str += ' Train/test: {}\n'.format(tmp) 154 | fmt_str += ' Root Location: 
{}\n'.format(self.root) 155 | tmp = ' Transforms (if any): ' 156 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 157 | tmp = ' Target Transforms (if any): ' 158 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 159 | return fmt_str 160 | -------------------------------------------------------------------------------- /video_model/diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | import gaussian_diffusion as gd 5 | from respace import SpacedDiffusion, space_timesteps 6 | from unet import SuperResModel, UNetModel, UnshuffleModel,IdxPromptUnet 7 | from make_a_video import IdxPromptSTUnet 8 | 9 | NUM_CLASSES = 1000 10 | 11 | 12 | def model_and_diffusion_defaults(): 13 | """ 14 | Defaults for image training. 15 | """ 16 | return dict( 17 | image_size=32, # 32 something 18 | num_channels=128,#128 19 | num_res_blocks=2, 20 | num_heads=4, 21 | num_heads_upsample=-1, 22 | attention_resolutions="8,8", 23 | dropout=0.0, 24 | learn_sigma=False, 25 | sigma_small=False, 26 | class_cond=False, 27 | diffusion_steps=1000, # 1000 28 | noise_schedule="linear", 29 | timestep_respacing="", 30 | use_kl=False, 31 | predict_xstart=True, 32 | rescale_timesteps=True, 33 | rescale_learned_sigmas=True, 34 | use_checkpoint=False, 35 | use_scale_shift_norm=True, 36 | scale_time_dim=0, 37 | # 20 for all frames of moving mnsit # 4 for 4 frames, 8 for 8 frames. the seq len has to be the same in movingmnist python file 38 | rgb=True # True 39 | ) 40 | 41 | 42 | def create_model_and_diffusion( 43 | image_size, 44 | class_cond, 45 | learn_sigma, 46 | sigma_small, 47 | num_channels, 48 | num_res_blocks, 49 | scale_time_dim, 50 | num_heads, 51 | num_heads_upsample, 52 | attention_resolutions, 53 | dropout, 54 | diffusion_steps, 55 | noise_schedule, 56 | timestep_respacing, 57 | use_kl, 58 | predict_xstart, 59 | rescale_timesteps, 60 | rescale_learned_sigmas, 61 | use_checkpoint, 62 | use_scale_shift_norm, 63 | rgb=True, # True 64 | model_name='two_condition' 65 | ): 66 | model = create_model( 67 | image_size, 68 | num_channels, 69 | num_res_blocks, 70 | scale_time_dim=scale_time_dim, 71 | learn_sigma=learn_sigma, 72 | class_cond=class_cond, 73 | use_checkpoint=use_checkpoint, 74 | attention_resolutions=attention_resolutions, 75 | num_heads=num_heads, 76 | num_heads_upsample=num_heads_upsample, 77 | use_scale_shift_norm=use_scale_shift_norm, 78 | dropout=dropout, 79 | rgb=rgb, 80 | model_name=model_name 81 | ) 82 | print('predict_xstart: ',predict_xstart) 83 | diffusion = create_gaussian_diffusion( 84 | steps=diffusion_steps, 85 | learn_sigma=learn_sigma, 86 | sigma_small=sigma_small, 87 | noise_schedule=noise_schedule, 88 | use_kl=use_kl, 89 | predict_xstart=predict_xstart, 90 | rescale_timesteps=rescale_timesteps, 91 | rescale_learned_sigmas=rescale_learned_sigmas, 92 | timestep_respacing=timestep_respacing, 93 | ) 94 | return model, diffusion 95 | 96 | 97 | def create_model( 98 | image_size, 99 | num_channels, 100 | num_res_blocks, 101 | scale_time_dim, 102 | learn_sigma, 103 | class_cond, 104 | use_checkpoint, 105 | attention_resolutions, 106 | num_heads, 107 | num_heads_upsample, 108 | use_scale_shift_norm, 109 | dropout, 110 | rgb, 111 | model_name 112 | ): 113 | if image_size == 256: 114 | channel_mult = (1, 1, 2, 2, 4, 4) 115 | elif image_size == 128: 116 | channel_mult = (1, 2, 3, 4) 117 | elif image_size == 64: 
118 | channel_mult = (1, 2, 3, 4) 119 | elif image_size == 32: 120 | channel_mult = (1, 2, 2, 2)# 1 2 2 2 121 | elif image_size in [16, 8]: # added 122 | channel_mult = (1, 1, 1, 1) 123 | # elif image_size == 8: 124 | # channel_mult = (1, 1, 1, 1) 125 | 126 | else: 127 | raise ValueError(f"unsupported image size: {image_size}") 128 | 129 | attention_ds = [] 130 | for res in attention_resolutions.split(","): 131 | attention_ds.append(image_size // int(res)) 132 | 133 | 134 | channels = 4 135 | 136 | if model_name == 'unshuffle': 137 | 138 | return UnshuffleModel( 139 | in_channels=channels, 140 | model_channels=num_channels, 141 | out_channels=(channels if not learn_sigma else 2 * channels), 142 | num_res_blocks=num_res_blocks, 143 | scale_time_dim=scale_time_dim, 144 | attention_resolutions=tuple(attention_ds), 145 | dropout=dropout, 146 | channel_mult=channel_mult, 147 | num_classes=(NUM_CLASSES if class_cond else None), 148 | use_checkpoint=use_checkpoint, 149 | num_heads=num_heads, 150 | num_heads_upsample=num_heads_upsample, 151 | use_scale_shift_norm=use_scale_shift_norm, 152 | ) 153 | else: 154 | print('Use transformers') 155 | return IdxPromptSTUnet( 156 | channels = channels, 157 | dim = 16*channels, 158 | dim_mult = (1, 2, 4, 8),#1, 2, 4, 8) 159 | temporal_compression = (False, False, False, False), 160 | self_attns = (False, True, True, True), 161 | condition_on_timestep = True, 162 | # resnet_block_depths = (2, 2, 2, 2) 163 | ) 164 | # return IdxPromptSTUnet( 165 | # in_channels=channels, # 1 166 | # model_channels=num_channels, 167 | # out_channels=channels+1, 168 | # num_res_blocks=num_res_blocks, 169 | # scale_time_dim=scale_time_dim, 170 | # attention_resolutions=tuple(attention_ds), 171 | # dropout=dropout, 172 | # channel_mult=channel_mult, 173 | # num_classes=(NUM_CLASSES if class_cond else None), 174 | # use_checkpoint=use_checkpoint, 175 | # num_heads=num_heads, 176 | # num_heads_upsample=num_heads_upsample, 177 | # use_scale_shift_norm=use_scale_shift_norm, 178 | # ) 179 | 180 | 181 | def sr_model_and_diffusion_defaults(): 182 | res = model_and_diffusion_defaults() 183 | res["large_size"] = 256 184 | res["small_size"] = 64 185 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 186 | for k in res.copy().keys(): 187 | if k not in arg_names: 188 | del res[k] 189 | return res 190 | 191 | 192 | def sr_create_model_and_diffusion( 193 | large_size, 194 | small_size, 195 | class_cond, 196 | learn_sigma, 197 | num_channels, 198 | num_res_blocks, 199 | scale_time_dim, 200 | num_heads, 201 | num_heads_upsample, 202 | attention_resolutions, 203 | dropout, 204 | diffusion_steps, 205 | noise_schedule, 206 | timestep_respacing, 207 | use_kl, 208 | predict_xstart, 209 | rescale_timesteps, 210 | rescale_learned_sigmas, 211 | use_checkpoint, 212 | use_scale_shift_norm, 213 | ): 214 | model = sr_create_model( 215 | large_size, 216 | small_size, 217 | num_channels, 218 | num_res_blocks, 219 | scale_time_dim, 220 | learn_sigma=learn_sigma, 221 | class_cond=class_cond, 222 | use_checkpoint=use_checkpoint, 223 | attention_resolutions=attention_resolutions, 224 | num_heads=num_heads, 225 | num_heads_upsample=num_heads_upsample, 226 | use_scale_shift_norm=use_scale_shift_norm, 227 | dropout=dropout, 228 | ) 229 | diffusion = create_gaussian_diffusion( 230 | steps=diffusion_steps, 231 | learn_sigma=learn_sigma, 232 | noise_schedule=noise_schedule, 233 | use_kl=use_kl, 234 | predict_xstart=predict_xstart, 235 | rescale_timesteps=rescale_timesteps, 236 | 
rescale_learned_sigmas=rescale_learned_sigmas, 237 | timestep_respacing=timestep_respacing, 238 | ) 239 | return model, diffusion 240 | 241 | 242 | def sr_create_model( 243 | large_size, 244 | small_size, 245 | num_channels, 246 | num_res_blocks, 247 | scale_time_dim, 248 | learn_sigma, 249 | class_cond, 250 | use_checkpoint, 251 | attention_resolutions, 252 | num_heads, 253 | num_heads_upsample, 254 | use_scale_shift_norm, 255 | dropout, 256 | ): 257 | _ = small_size # hack to prevent unused variable 258 | 259 | if large_size == 256: 260 | channel_mult = (1, 1, 2, 2, 4, 4) 261 | elif large_size == 64: 262 | channel_mult = (1, 2, 3, 4) 263 | else: 264 | raise ValueError(f"unsupported large size: {large_size}") 265 | 266 | attention_ds = [] 267 | for res in attention_resolutions.split(","): 268 | attention_ds.append(large_size // int(res)) 269 | 270 | return SuperResModel( 271 | in_channels=3, 272 | model_channels=num_channels, 273 | out_channels=(3 if not learn_sigma else 6), 274 | num_res_blocks=num_res_blocks, 275 | scale_time_dim=scale_time_dim, 276 | attention_resolutions=tuple(attention_ds), 277 | dropout=dropout, 278 | channel_mult=channel_mult, 279 | num_classes=(NUM_CLASSES if class_cond else None), 280 | use_checkpoint=use_checkpoint, 281 | num_heads=num_heads, 282 | num_heads_upsample=num_heads_upsample, 283 | use_scale_shift_norm=use_scale_shift_norm, 284 | ) 285 | 286 | 287 | def create_gaussian_diffusion( 288 | *, 289 | steps=1000, # 290 | learn_sigma=False, 291 | sigma_small=False, 292 | noise_schedule="linear", 293 | use_kl=False, 294 | predict_xstart=False, 295 | rescale_timesteps=False, 296 | rescale_learned_sigmas=False, 297 | timestep_respacing="", 298 | ): 299 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 300 | if use_kl: 301 | loss_type = gd.LossType.RESCALED_KL 302 | elif rescale_learned_sigmas: 303 | loss_type = gd.LossType.RESCALED_MSE 304 | else: 305 | loss_type = gd.LossType.MSE 306 | if not timestep_respacing: 307 | timestep_respacing = [steps] 308 | return SpacedDiffusion( 309 | use_timesteps=space_timesteps(steps, timestep_respacing), 310 | betas=betas, 311 | model_mean_type=( 312 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 313 | ), 314 | model_var_type=( 315 | ( 316 | gd.ModelVarType.FIXED_LARGE 317 | if not sigma_small 318 | else gd.ModelVarType.FIXED_SMALL 319 | ) 320 | if not learn_sigma 321 | else gd.ModelVarType.LEARNED_RANGE 322 | ), 323 | loss_type=loss_type, 324 | rescale_timesteps=rescale_timesteps, 325 | ) 326 | 327 | 328 | def add_dict_to_argparser(parser, default_dict): 329 | for k, v in default_dict.items(): 330 | v_type = type(v) 331 | if v is None: 332 | v_type = str 333 | elif isinstance(v, bool): 334 | v_type = str2bool 335 | parser.add_argument(f"--{k}", default=v, type=v_type) 336 | 337 | 338 | def args_to_dict(args, keys): 339 | return {k: getattr(args, k) for k in keys} 340 | 341 | 342 | def str2bool(v): 343 | """ 344 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 345 | """ 346 | if isinstance(v, bool): 347 | return v 348 | if v.lower() in ("yes", "true", "t", "y", "1"): 349 | return True 350 | elif v.lower() in ("no", "false", "f", "n", "0"): 351 | return False 352 | else: 353 | raise argparse.ArgumentTypeError("boolean value expected") 354 | -------------------------------------------------------------------------------- /image_model/train_JPDVT.py: -------------------------------------------------------------------------------- 1 | """ 2 | A minimal 
training script for JPDVT using PyTorch DDP. 3 | """ 4 | import torch 5 | torch.backends.cuda.matmul.allow_tf32 = True 6 | torch.backends.cudnn.allow_tf32 = True 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torchvision.datasets import ImageFolder 12 | from torchvision import transforms 13 | import numpy as np 14 | from collections import OrderedDict 15 | from PIL import Image 16 | from copy import deepcopy 17 | from glob import glob 18 | from time import time 19 | import argparse 20 | import logging 21 | import os 22 | 23 | from models import DiT_models 24 | from models import get_2d_sincos_pos_embed 25 | from diffusion import create_diffusion 26 | from diffusers.models import AutoencoderKL 27 | 28 | from datasets import MET 29 | from einops import rearrange 30 | 31 | ################################################################################# 32 | # Training Helper Functions # 33 | ################################################################################# 34 | 35 | @torch.no_grad() 36 | def update_ema(ema_model, model, decay=0.9999): 37 | """ 38 | Step the EMA model towards the current model. 39 | """ 40 | ema_params = OrderedDict(ema_model.named_parameters()) 41 | model_params = OrderedDict(model.named_parameters()) 42 | 43 | for name, param in model_params.items(): 44 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 45 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 46 | 47 | def requires_grad(model, flag=True): 48 | """ 49 | Set requires_grad flag for all parameters in a model. 50 | """ 51 | for p in model.parameters(): 52 | p.requires_grad = flag 53 | 54 | def cleanup(): 55 | """ 56 | End DDP training. 57 | """ 58 | dist.destroy_process_group() 59 | 60 | def create_logger(logging_dir): 61 | """ 62 | Create a logger that writes to a log file and stdout. 63 | """ 64 | if dist.get_rank() == 0: # real logger 65 | logging.basicConfig( 66 | level=logging.INFO, 67 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 68 | datefmt='%Y-%m-%d %H:%M:%S', 69 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 70 | ) 71 | logger = logging.getLogger(__name__) 72 | else: # dummy logger (does nothing) 73 | logger = logging.getLogger(__name__) 74 | logger.addHandler(logging.NullHandler()) 75 | return logger 76 | 77 | 78 | def center_crop_arr(pil_image, image_size): 79 | """ 80 | Center cropping implementation from ADM. 
81 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 82 | """ 83 | while min(*pil_image.size) >= 2 * image_size: 84 | pil_image = pil_image.resize( 85 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 86 | ) 87 | 88 | scale = image_size / min(*pil_image.size) 89 | pil_image = pil_image.resize( 90 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 91 | ) 92 | 93 | arr = np.array(pil_image) 94 | crop_y = (arr.shape[0] - image_size) // 2 95 | crop_x = (arr.shape[1] - image_size) // 2 96 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 97 | 98 | 99 | ################################################################################# 100 | # Training Loop # 101 | ################################################################################# 102 | 103 | def main(args): 104 | """ 105 | Trains a new DiT model. 106 | """ 107 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 108 | 109 | # Setup DDP: 110 | dist.init_process_group("nccl") 111 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 112 | rank = dist.get_rank() 113 | device = rank % torch.cuda.device_count() 114 | seed = args.global_seed * dist.get_world_size() + rank 115 | torch.manual_seed(seed) 116 | torch.cuda.set_device(device) 117 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 118 | 119 | # Setup an experiment folder: 120 | if rank == 0: 121 | os.makedirs(args.results_dir, exist_ok=True) 122 | experiment_index = len(glob(f"{args.results_dir}/*")) 123 | model_string_name = args.model.replace("/", "-") 124 | model_string_name = args.dataset+"-" + model_string_name + "-crop" if args.crop else args.dataset+"-" + model_string_name 125 | model_string_name = model_string_name + "-withmask" if args.add_mask else model_string_name 126 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder 127 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints 128 | os.makedirs(checkpoint_dir, exist_ok=True) 129 | logger = create_logger(experiment_dir) 130 | logger.info(f"Experiment directory created at {experiment_dir}") 131 | else: 132 | logger = create_logger(None) 133 | 134 | # Create model: 135 | assert args.image_size % 3 == 0, "Image size should be Multiples of 3" 136 | if args.dataset == 'imagenet': 137 | assert args.image_size == 288 or args.crop, "Set imagesize to 192 if run experiment on imagenet with gap" 138 | model = DiT_models[args.model]( 139 | input_size=args.image_size 140 | ) 141 | 142 | if args.ckpt!= "": 143 | ckpt_path = args.ckpt 144 | print("Load model from ",ckpt_path) 145 | model_dict = model.state_dict() 146 | state_dict = torch.load(ckpt_path) 147 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 148 | model.load_state_dict(pretrained_dict, strict=False) 149 | 150 | # Note that parameter initialization is done within the DiT constructor 151 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 152 | requires_grad(ema, False) 153 | model = DDP(model.to(device), device_ids=[rank]) 154 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule 155 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 156 | 157 | # Setup optimizer (we used default 
Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 158 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) 159 | 160 | transform = transforms.Compose([ 161 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 288)), 162 | transforms.RandomHorizontalFlip(), 163 | transforms.ToTensor(), 164 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 165 | ]) 166 | 167 | # Setup data: 168 | if args.dataset == "met": 169 | # MET dataloader give out croped and stitched back images 170 | dataset = MET(args.data_path,'train') 171 | elif args.dataset == "imagenet": 172 | dataset = ImageFolder(args.data_path, transform=transform) 173 | 174 | sampler = DistributedSampler( 175 | dataset, 176 | num_replicas=dist.get_world_size(), 177 | rank=rank, 178 | shuffle=True, 179 | seed=args.global_seed 180 | ) 181 | loader = DataLoader( 182 | dataset, 183 | batch_size=int(args.global_batch_size // dist.get_world_size()), 184 | shuffle=False, 185 | sampler=sampler, 186 | num_workers=args.num_workers, 187 | pin_memory=True, 188 | drop_last=True 189 | ) 190 | logger.info(f"Dataset contains {len(dataset):,} images") 191 | 192 | # Prepare models for training: 193 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights 194 | model.train() # important! This enables embedding dropout for classifier-free guidance 195 | ema.eval() # EMA model should always be in eval mode 196 | 197 | # Variables for monitoring/logging purposes: 198 | train_steps = 0 199 | log_steps = 0 200 | running_loss = 0 201 | start_time = time() 202 | 203 | logger.info(f"Training for {args.epochs} epochs...") 204 | for epoch in range(args.epochs): 205 | sampler.set_epoch(epoch) 206 | logger.info(f"Beginning epoch {epoch}...") 207 | for x in loader: 208 | if args.dataset == 'imagenet': 209 | x, _ = x 210 | x = x.to(device) 211 | if args.dataset == 'imagenet' and args.crop: 212 | centercrop = transforms.CenterCrop((64,64)) 213 | patchs = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=96,w1=96) 214 | patchs = centercrop(patchs) 215 | x = rearrange(patchs, 'b c (p1 p2) h1 w1-> b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=64,w1=64) 216 | 217 | # Set up initial positional embedding 218 | time_emb = torch.tensor(get_2d_sincos_pos_embed(8, 3)).unsqueeze(0).float().to(device) 219 | 220 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) 221 | model_kwargs = None 222 | loss_dict = diffusion.training_losses(model, x, t, time_emb, model_kwargs, \ 223 | block_size=args.image_size//3, patch_size=16, add_mask=args.add_mask) 224 | loss = loss_dict["loss"].mean() 225 | opt.zero_grad() 226 | loss.backward() 227 | opt.step() 228 | update_ema(ema, model.module) 229 | 230 | # Log loss values: 231 | running_loss += loss.item() 232 | log_steps += 1 233 | train_steps += 1 234 | if train_steps % args.log_every == 0: 235 | # Measure training speed: 236 | torch.cuda.synchronize() 237 | end_time = time() 238 | steps_per_sec = log_steps / (end_time - start_time) 239 | # Reduce loss history over all processes: 240 | avg_loss = torch.tensor(running_loss / log_steps, device=device) 241 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) 242 | avg_loss = avg_loss.item() / dist.get_world_size() 243 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") 244 | # Reset monitoring variables: 245 | running_loss = 0 246 | log_steps = 0 247 | start_time = time() 248 | 249 | # Save JPDVT 
checkpoint: 250 | if train_steps % args.ckpt_every == 0 and train_steps > 0: 251 | if rank == 0: 252 | checkpoint = { 253 | "model": model.module.state_dict(), 254 | "ema": ema.state_dict(), 255 | "opt": opt.state_dict(), 256 | "args": args 257 | } 258 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 259 | torch.save(checkpoint, checkpoint_path) 260 | logger.info(f"Saved checkpoint to {checkpoint_path}") 261 | dist.barrier() 262 | 263 | model.eval() # important! This disables randomized embedding dropout 264 | 265 | logger.info("Done!") 266 | cleanup() 267 | 268 | if __name__ == "__main__": 269 | parser = argparse.ArgumentParser() 270 | parser.add_argument("--results-dir", type=str, default="results") 271 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="JPDVT") 272 | parser.add_argument("--dataset", type=str, choices=["imagenet", "met"], default="imagenet") 273 | parser.add_argument("--data-path", type=str,required=True) 274 | parser.add_argument("--crop", action='store_true', default=False) 275 | parser.add_argument("--add-mask", action='store_true', default=False) 276 | parser.add_argument("--image-size", type=int, choices=[192, 288], default=288) 277 | parser.add_argument("--epochs", type=int, default=500) 278 | parser.add_argument("--global-batch-size", type=int, default=96) 279 | parser.add_argument("--global-seed", type=int, default=0) 280 | parser.add_argument("--num-workers", type=int, default=12) 281 | parser.add_argument("--log-every", type=int, default=100) 282 | parser.add_argument("--ckpt-every", type=int, default=10_000) 283 | parser.add_argument("--ckpt", type=str, default='') 284 | args = parser.parse_args() 285 | main(args) 286 | -------------------------------------------------------------------------------- /video_model/datasets/video_datasets.py: -------------------------------------------------------------------------------- 1 | from random import sample 2 | from PIL import Image, ImageSequence 3 | import blobfile as bf 4 | from mpi4py import MPI 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | # from MovingMNIST import MovingMNIST 9 | # file: /folder1/folder2/module.py 10 | import os 11 | import sys 12 | 13 | sys.path.insert(1, os.getcwd()) 14 | # sys.path.insert(0, '/RaMViD_main/mnist') 15 | # sys.path.insert(0, '/RaMViD_main/nucla_skeleton') 16 | # sys.path.insert(0, '/RaMViD_main/something') 17 | 18 | from MovingMNIST import MovingMNIST 19 | from somethingsomethingv2 import SomethingSomethingDataset 20 | from NUCLAskeleton import NUCLAskeleton 21 | from NTURGBDskeleton import NTURGBDskeleton 22 | from clevrer import Clevrer 23 | import torch 24 | import av 25 | import os 26 | # from nucla_skeleton.NUCLAskeleton import NUCLAskeleton 27 | import torch.utils.data as tudata 28 | 29 | 30 | # from somethng.somethingsomethingv2 import SomethingSomethingDataset 31 | 32 | 33 | def load_data( 34 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, rgb=True, seq_len=8 35 | ): 36 | """ 37 | For a dataset, create a generator over (videos, kwargs) pairs. 38 | 39 | Each video is an NCLHW float tensor, and the kwargs dict contains zero or 40 | more keys, each of which map to a batched Tensor of their own. 41 | The kwargs dict can be used for class labels, in which case the key is "y" 42 | and the values are integer tensors of class labels. 43 | 44 | :param data_dir: a dataset directory. 45 | :param batch_size: the batch size of each returned pair. 
46 | :param image_size: the size to which frames are resized. 47 | :param class_cond: if True, include a "y" key in returned dicts for class 48 | label. If classes are not available and this is true, an 49 | exception will be raised. 50 | :param deterministic: if True, yield results in a deterministic order. 51 | """ 52 | # global dataset 53 | if not data_dir: 54 | raise ValueError("unspecified data directory") 55 | 56 | # if data_dir != "MovingMNIST" and data_dir!="MovingMNIST_Test": # for sampling 57 | if data_dir not in ["MovingMNIST-train", "MovingMNIST-test", "something-train", "something-test", 58 | "NUCLAskeleton-train", "NUCLAskeleton-test", "NTURGBD-train", "NTURGBD-test","clevrer-train","clevrer-test"]: 59 | 60 | # if data_dir == "real_dataset": # 61 | all_files = _list_video_files_recursively(data_dir) 62 | classes = None 63 | if class_cond: 64 | # Assume classes are the first part of the filename, 65 | # before an underscore. 66 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 67 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 68 | classes = [sorted_classes[x] for x in class_names] 69 | entry = all_files[0].split(".")[-1] 70 | # breakpoint() 71 | if entry in ["avi", "mp4", "webm"]: 72 | dataset = VideoDataset_mp4( 73 | image_size, 74 | all_files, 75 | classes=classes, 76 | shard=MPI.COMM_WORLD.Get_rank(), 77 | num_shards=MPI.COMM_WORLD.Get_size(), 78 | rgb=rgb, 79 | seq_len=seq_len 80 | ) 81 | # breakpoint() 82 | elif entry in ["gif"]: 83 | dataset = VideoDataset_gif( 84 | image_size, 85 | all_files, 86 | classes=classes, 87 | shard=MPI.COMM_WORLD.Get_rank(), 88 | num_shards=MPI.COMM_WORLD.Get_size(), 89 | rgb=rgb, 90 | seq_len=seq_len 91 | ) 92 | elif data_dir == "MovingMNIST-train": 93 | 94 | root_path = '/data/mark' 95 | if not os.path.exists(root_path): 96 | os.mkdir(root_path) 97 | dataset = MovingMNIST(root='/data/mark/mnist', train=True, download=True, image_size=image_size, 98 | seq_len=seq_len) 99 | elif data_dir == "MovingMNIST-test": # newly added 100 | root_path = '/data/mark' 101 | # breakpoint() 102 | if not os.path.exists(root_path): 103 | os.mkdir(root_path) 104 | dataset = MovingMNIST(root='/data/mark/mnist', train=False, download=True, image_size=image_size, 105 | seq_len=seq_len) 106 | elif data_dir == "something-train": 107 | dataset = SomethingSomethingDataset(video_dir='/data/wondm/somethingsomethingv2/train', 108 | frame_res=image_size, num_frames=seq_len) 109 | 110 | elif data_dir == "something-test": # newly added 111 | dataset = SomethingSomethingDataset(video_dir='/data/wondm/somethingsomethingv2/test', 112 | frame_res=image_size, num_frames=seq_len) 113 | 114 | elif data_dir == "NUCLAskeleton-train": # newly added 115 | dataset = NUCLAskeleton(phase='train', cam='1,2', frame_res=image_size) 116 | elif data_dir == "NUCLAskeleton-test": # newly added 117 | dataset = NUCLAskeleton(phase='test', cam='1,2', frame_res=image_size) 118 | elif data_dir == "NTURGBD-train": # newly added 119 | dataset = NTURGBDskeleton(phase='train', cam='1') # T=8,20,40 120 | elif data_dir == "NTURGBD-test": # newly added 121 | dataset = NTURGBDskeleton(phase='val', cam='1') # T=8,20,40 122 | elif data_dir=="clevrer-train": 123 | dataset = Clevrer('/data/CLEVRER/Numpy/train/') 124 | elif data_dir=="clevrer-test": 125 | dataset = Clevrer('/data/CLEVRER/Numpy/test/') 126 | 127 | 128 | 129 | # 130 | # if deterministic: 131 | # loader = DataLoader( 132 | # dataset, batch_size=batch_size, shuffle=False, num_workers=16,
drop_last=True 133 | # ) 134 | # else: 135 | # loader = DataLoader( 136 | # dataset, batch_size=batch_size, shuffle=True, num_workers=4,drop_last=True 137 | # ) 138 | if data_dir in ["NUCLAskeleton-train", "NUCLAskeleton-test"]: 139 | loader = tudata.DataLoader(dataset, batch_size=1, shuffle=False, 140 | num_workers=1, pin_memory=True) 141 | else: 142 | # breakpoint() 143 | loader = DataLoader( 144 | dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True 145 | ) 146 | 147 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) 148 | while True: 149 | yield from loader 150 | 151 | 152 | def _list_video_files_recursively(data_dir): 153 | results = [] 154 | for entry in sorted(bf.listdir(data_dir)): 155 | full_path = bf.join(data_dir, entry) 156 | ext = entry.split(".")[-1] 157 | if "." in entry and ext.lower() in ["gif", "avi", "mp4", "webm"]: 158 | results.append(full_path) 159 | elif bf.isdir(full_path): 160 | results.extend(_list_video_files_recursively(full_path)) 161 | return results 162 | 163 | 164 | class VideoDataset_mp4(Dataset): 165 | 166 | def __init__(self, resolution, video_paths, classes=None, shard=0, num_shards=1, rgb=True, seq_len=20): 167 | super().__init__() 168 | self.resolution = resolution 169 | self.local_videos = video_paths[shard:][::num_shards] 170 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 171 | self.rgb = rgb 172 | self.seq_len = seq_len 173 | 174 | def __len__(self): 175 | return len(self.local_videos) 176 | 177 | def __getitem__(self, idx): 178 | path = self.local_videos[idx] 179 | # path = './data' 180 | arr_list = [] 181 | video_container = av.open(path) 182 | n = video_container.streams.video[0].frames 183 | # breakpoint() 184 | frames = [i for i in range(n)] 185 | if n > self.seq_len: 186 | start = np.random.randint(0, n - self.seq_len) 187 | frames = frames[start:start + self.seq_len] 188 | for id, frame_av in enumerate(video_container.decode(video=0)): 189 | # We are not on a new enough PIL to support the `reducing_gap` 190 | # argument, which uses BOX downsampling at powers of two first. 191 | # Thus, we do it by hand to improve downsample quality. 
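# For example, with resolution=64 a 640x480 frame is halved twice by the loop below (to 320x240, then 160x120), rescaled with BICUBIC to 85x64, and finally center-cropped to 64x64.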
192 | if (id not in frames): 193 | continue 194 | frame = frame_av.to_image() 195 | while min(*frame.size) >= 2 * self.resolution: 196 | frame = frame.resize( 197 | tuple(x // 2 for x in frame.size), resample=Image.BOX 198 | ) 199 | scale = self.resolution / min(*frame.size) 200 | frame = frame.resize( 201 | tuple(round(x * scale) for x in frame.size), resample=Image.BICUBIC 202 | ) 203 | 204 | if self.rgb: 205 | arr = np.array(frame.convert("RGB")) 206 | else: 207 | arr = np.array(frame.convert("L")) 208 | arr = np.expand_dims(arr, axis=2) 209 | crop_y = (arr.shape[0] - self.resolution) // 2 210 | crop_x = (arr.shape[1] - self.resolution) // 2 211 | arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution] 212 | arr = arr.astype(np.float32) / 127.5 - 1 213 | arr_list.append(arr) 214 | arr_seq = np.array(arr_list) 215 | arr_seq = np.transpose(arr_seq, [3, 0, 1, 2]) 216 | # fill in missing frames with 0s 217 | if arr_seq.shape[1] < self.seq_len: 218 | required_dim = self.seq_len - arr_seq.shape[1] 219 | fill = np.zeros((3, required_dim, self.resolution, self.resolution)) 220 | arr_seq = np.concatenate((arr_seq, fill), axis=1) 221 | out_dict = {} 222 | if self.local_classes is not None: 223 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 224 | 225 | # breakpoint() 226 | return arr_seq, out_dict 227 | 228 | 229 | class VideoDataset_gif(Dataset): 230 | def __init__(self, resolution, video_paths, classes=None, shard=0, num_shards=1, rgb=True, seq_len=20): 231 | super().__init__() 232 | self.resolution = resolution 233 | self.local_videos = video_paths[shard:][::num_shards] 234 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 235 | self.rgb = rgb 236 | self.seq_len = seq_len 237 | 238 | def __len__(self): 239 | return len(self.local_videos) 240 | 241 | def __getitem__(self, idx): 242 | path = self.local_videos[idx] 243 | with bf.BlobFile(path, "rb") as f: 244 | pil_videos = Image.open(f) 245 | arr_list = [] 246 | for frame in ImageSequence.Iterator(pil_videos): 247 | 248 | # We are not on a new enough PIL to support the `reducing_gap` 249 | # argument, which uses BOX downsampling at powers of two first. 250 | # Thus, we do it by hand to improve downsample quality. 
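# Unlike VideoDataset_mp4, which picks a random seq_len window of frame indices before decoding, this GIF loader resizes every decoded frame and only crops a random temporal window after the full array is built.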
251 | while min(*frame.size) >= 2 * self.resolution: 252 | frame = frame.resize( 253 | tuple(x // 2 for x in frame.size), resample=Image.BOX 254 | ) 255 | scale = self.resolution / min(*frame.size) 256 | frame = frame.resize( 257 | tuple(round(x * scale) for x in frame.size), resample=Image.BICUBIC 258 | ) 259 | 260 | if self.rgb: 261 | arr = np.array(frame.convert("RGB")) 262 | else: 263 | arr = np.array(frame.convert("L")) 264 | arr = np.expand_dims(arr, axis=2) 265 | crop_y = (arr.shape[0] - self.resolution) // 2 266 | crop_x = (arr.shape[1] - self.resolution) // 2 267 | arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution] 268 | arr = arr.astype(np.float32) / 127.5 - 1 269 | arr_list.append(arr) 270 | arr_seq = np.array(arr_list) 271 | arr_seq = np.transpose(arr_seq, [3, 0, 1, 2]) 272 | if arr_seq.shape[1] > self.seq_len: 273 | start = np.random.randint(0, arr_seq.shape[1] - self.seq_len) 274 | arr_seq = arr_seq[:, start:start + self.seq_len] 275 | out_dict = {} 276 | if self.local_classes is not None: 277 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 278 | return arr_seq, out_dict 279 | -------------------------------------------------------------------------------- /video_model/diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
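# The step gives TensorBoard an x-axis position for the logged scalars; here it is simply a counter that increments once per writekvs() call.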
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
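# No MPI launcher environment variable was found, so assume a single-process run (rank 0).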
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir='/home/mark/projects/prompt_2_clvr/results', format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai_unshuffle-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | else: 454 | dir=osp.join(dir, 455 | datetime.datetime.now().strftime("prompt_mnist_20T4M_Tansformer-%m-%d-%H-%M-%S"), 456 | ) 457 | assert isinstance(dir, str) 458 | dir = os.path.expanduser(dir) 459 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 460 | 461 | rank = get_rank_without_mpi_import() 462 | if rank > 0: 463 | log_suffix = log_suffix + "-rank%03i" % rank 464 | 465 | if format_strs is None: 466 | if rank == 0: 467 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 468 | else: 469 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 470 | format_strs = filter(None, format_strs) 471 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 472 | 473 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 474 | if output_formats: 475 | log("Logging to %s" % dir) 476 | 477 | 478 | def _configure_default_logger(): 479 | configure() 480 | Logger.DEFAULT = Logger.CURRENT 481 | 482 | 483 | def reset(): 484 | if Logger.CURRENT is not Logger.DEFAULT: 485 | Logger.CURRENT.close() 486 | Logger.CURRENT = Logger.DEFAULT 487 | log("Reset logger") 488 | 489 | 490 | @contextmanager 491 | def scoped_configure(dir=None, format_strs=None, comm=None): 492 | prevlogger = Logger.CURRENT 493 | configure(dir=dir, format_strs=format_strs, comm=comm) 494 | try: 495 | yield 496 | finally: 497 | Logger.CURRENT.close() 498 | Logger.CURRENT = prevlogger 499 | 500 | -------------------------------------------------------------------------------- /video_model/video_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
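For the video JPDVT experiments, main() below also shuffles each clip's frame order and removes a number of frames (filled in with noise) before sampling, accumulating a Kendall-distance score between the recovered and true orderings.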
4 | """ 5 | import argparse 6 | import os 7 | import shutil 8 | import sys 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch as th 13 | import torch.nn as nn 14 | import torch.distributed as dist 15 | from matplotlib import gridspec 16 | from skimage.metrics import structural_similarity, mean_squared_error 17 | from sklearn.preprocessing import MinMaxScaler 18 | from positional_encodings.torch_encodings import PositionalEncoding1D 19 | sys.path.insert(1, os.getcwd()) 20 | import random 21 | 22 | from video_datasets import load_data 23 | import dist_util, logger 24 | from script_util import ( 25 | model_and_diffusion_defaults, 26 | create_model_and_diffusion, 27 | add_dict_to_argparser, 28 | args_to_dict, 29 | ) 30 | 31 | from functools import reduce 32 | 33 | import torch 34 | 35 | from sklearn.metrics import pairwise_distances 36 | 37 | def save_tensor(file_name, data_tensor): 38 | directory = "/home/wondmgezahu/ppo/latent-diffusion-main/EncodedLatent" 39 | file_path = os.path.join(directory, file_name) 40 | torch.save(data_tensor, file_path) 41 | def plot_samples(generated_data, shuffled_input_data, ground_truth_data): 42 | # Delete the "figures" directory and its contents (if it exists) 43 | if os.path.exists('fig_test'): 44 | shutil.rmtree('fig_test') 45 | 46 | # Create a new "figures" directory 47 | os.makedirs('fig_test') 48 | 49 | for i in range(generated_data.shape[0]): # 50 | # Select the current batch 51 | batch_generated = generated_data[i, :, :, :, :] 52 | batch_shuffled = shuffled_input_data[i, :, :, :, :] 53 | batch_ground_truth = ground_truth_data[i, :, :, :, :] 54 | 55 | # Create a figure with 10 subplots (one for each frame) 56 | fig = plt.figure(figsize=(6, 6)) # (3,3) for 4 frame 57 | gs = gridspec.GridSpec(nrows=3, ncols=generated_data.shape[2], left=0, right=1, top=1, bottom=0) 58 | 59 | # Reduce the space between the subplots 60 | # fig.tight_layout() 61 | plt.subplots_adjust(wspace=0, hspace=0) 62 | 63 | # Loop over the frames 64 | for j in range(batch_generated.shape[1]): 65 | # Select the current frame 66 | frame_generated = batch_generated[:, j, :, :] 67 | frame_shuffled = batch_shuffled[:, j, :, :] 68 | frame_ground_truth = batch_ground_truth[:, j, :, :] 69 | 70 | # Create a subplot at the current position 71 | ax1 = fig.add_subplot(gs[0, j]) 72 | ax2 = fig.add_subplot(gs[1, j]) 73 | ax3 = fig.add_subplot(gs[2, j]) 74 | 75 | # Plot the frame in grayscale 76 | # breakpoint() 77 | ax1.imshow(np.squeeze(frame_shuffled), cmap='gray') # cmap='gray' 78 | ax1.axis('off') 79 | if j == 4: 80 | ax1.set_title('shuffled samples', loc='center', fontsize=10) 81 | 82 | ax2.imshow(np.squeeze(frame_ground_truth), cmap='gray') # cmap='gray' 83 | ax2.axis('off') 84 | if j == 4: 85 | ax2.set_title('ground truth samples', loc='center', fontsize=10) 86 | 87 | ax3.imshow(np.squeeze(frame_generated), cmap='gray') # cmap='gray' 88 | ax3.axis('off') 89 | if j == 4: 90 | ax3.set_title('generated samples', loc='center', fontsize=10) 91 | 92 | plt.savefig('fig_test/batch{}.png'.format(i)) 93 | 94 | def main(): 95 | args = create_argparser().parse_args() 96 | 97 | dist_util.setup_dist() 98 | # logger.configure() 99 | if args.seed: 100 | th.manual_seed(args.seed) 101 | np.random.seed(args.seed) 102 | random.seed(args.seed) 103 | 104 | # logger.log("creating model and diffusion...") 105 | model, diffusion = create_model_and_diffusion( 106 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 107 | ) 108 | 109 | model.load_state_dict( 110 | 
dist_util.load_state_dict(args.model_path, map_location="cpu") 111 | ) 112 | model.to(dist_util.dev()) 113 | # breakpoint() 114 | model.train() 115 | 116 | cond_kwargs = {} 117 | cond_frames = [] 118 | if args.cond_generation: 119 | data = load_data( 120 | data_dir=args.data_dir, 121 | batch_size=args.batch_size, 122 | image_size=args.image_size, 123 | class_cond=args.class_cond, 124 | deterministic=True, 125 | # rgb=args.rgb, 126 | seq_len=args.seq_len 127 | ) 128 | 129 | num = "" 130 | 131 | for i in args.cond_frames: 132 | if i == ",": 133 | cond_frames.append(int(num)) 134 | num = "" 135 | else: 136 | num = num + i 137 | ref_frames = list(i for i in range(args.seq_len) if i not in cond_frames) 138 | # logger.log(f"cond_frames: {cond_frames}") 139 | # logger.log(f"ref_frames: {ref_frames}") 140 | # logger.log(f"seq_len: {args.seq_len}") 141 | cond_kwargs["resampling_steps"] = args.resample_steps 142 | cond_kwargs["cond_frames"] = cond_frames 143 | 144 | channels = 4 145 | # breakpoint() 146 | # logger.log("sampling...") 147 | all_videos = [] 148 | all_gt = [] 149 | shuffled_input = [] 150 | ground_truth = [] 151 | sample_time=[] 152 | perm_index = [] 153 | imputation_index = [] 154 | all_idx = [] 155 | all_permutation = [] 156 | all_names = [] 157 | all_sampled_normalized=[] 158 | mask_list=[] 159 | kendall_dis_sum = 0 160 | kendall_dis_list = [] 161 | while len(kendall_dis_list) < args.num_samples: 162 | print(len(kendall_dis_list)) 163 | min_val, max_val = 0, 0 164 | if args.cond_generation: 165 | raw_video, _,index = next(data) # video, _ = next(data) for others except something and nturgbd-skeleton dataset 166 | # breakpoint() 167 | raw_video = raw_video.permute(0,2,1,3,4) 168 | # all_names.append(name) 169 | # raw_video = raw_video.permute(0, 2, 1, 3, 4) # permuted for something dataset and nturgbd-skeleton only 170 | # original_skeleton = video.clone().detach() 171 | video = raw_video.float() # just for skeleton and nturgbd dataset 172 | # breakpoint() 173 | # normalization for nturgbd 174 | min_val = torch.min(video) 175 | max_val = torch.max(video) 176 | range_val = max_val - min_val 177 | 178 | batch = (video - min_val) / range_val 179 | video = batch * 2 - 1 180 | # cond_kwargs["cond_img"] = video[:, :, cond_frames].to(dist_util.dev()) 181 | # cond_kwargs["cond_frames"] = cond_frames 182 | # video = video.to(dist_util.dev()) 183 | # breakpoint() 184 | video = video 185 | idx = th.randperm(video.shape[2]) 186 | idx=idx[idx!=0] 187 | idx=torch.cat((torch.tensor([0]),idx)) 188 | # breakpoint() 189 | perm_index.append(idx) 190 | # original_data = torch.Tensor(video) 191 | 192 | conditional_data = video[:, :, idx, :, :] 193 | true_idx = np.zeros(args.seq_len, ) 194 | true_idx[:] = idx 195 | 196 | # Remove some frames 197 | num_frames_to_remove = 12 198 | print("number of missing frame is: ",num_frames_to_remove) 199 | # 200 | # # create an array of indices for all frames 201 | frame_indices = np.arange(conditional_data.shape[2]) 202 | # 203 | # # randomly select the indices of the frames to remove for all batches 204 | remove_indices = np.random.choice(frame_indices, num_frames_to_remove, replace=False) 205 | remove_indices = remove_indices[remove_indices!=0] 206 | # 207 | # # create an array of zeros with the same shape as the video frames 208 | dummy_frames = np.random.randn( 209 | conditional_data.shape[0], conditional_data.shape[1], 1, conditional_data.shape[3], 210 | conditional_data.shape[4]) 211 | dummy_frames = torch.from_numpy(dummy_frames).to(conditional_data.device) 212 
| # breakpoint() 213 | conditional_data[:, :, remove_indices, :, :] = dummy_frames.float() 214 | 215 | mask = np.ones(conditional_data.shape[2]) 216 | mask[remove_indices] = 0 217 | mask_list.append(mask) 218 | mask = torch.tensor(mask).cuda() 219 | 220 | 221 | cond_kwargs["cond_frames"] = [i for i in frame_indices if i not in remove_indices] 222 | cond_kwargs["cond_img"] = conditional_data[:, :, cond_kwargs["cond_frames"]].to(dist_util.dev()) 223 | # breakpoint() 224 | sample_fn = ( 225 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 226 | ) 227 | sample = sample_fn( 228 | model, 229 | (args.batch_size, args.seq_len, 16), 230 | clip_denoised=args.clip_denoised, 231 | progress=False, 232 | cond_kwargs=cond_kwargs, 233 | condition_args=conditional_data, 234 | # mask=mask, 235 | original_args=video 236 | # gap=args.diffusion_step_gap 237 | ) 238 | # breakpoint() 239 | p_enc_1d_model = PositionalEncoding1D(16) 240 | time_emb_gt = p_enc_1d_model(torch.rand(sample[0].shape[0], sample[0].shape[1], 16)) 241 | 242 | 243 | def find_permutation(distance_matrix): 244 | sort_list = [] 245 | for m in range(distance_matrix.shape[1]): 246 | order = distance_matrix[:,0].argmin() 247 | sort_list.append(order) 248 | distance_matrix = distance_matrix[:,1:] 249 | distance_matrix[order,:] = 10**5 250 | return sort_list 251 | # conditional_data = ((conditional_data.clamp(-1,1) + 1) * (max_val-min_val) /2+ min_val) 252 | # sample = sample.contiguous() 253 | sample_normalized = sample[0].clamp(-1, 1) 254 | dist = pairwise_distances(time_emb_gt[0], sample[0][0].cpu(), metric='manhattan') 255 | permutation = find_permutation(dist) 256 | permutation = np.array(permutation) 257 | idx = np.array(idx) 258 | 259 | if num_frames_to_remove != 0: 260 | # breakpoint() 261 | permutation_con = permutation[cond_kwargs["cond_frames"]] 262 | idx_con = idx[cond_kwargs["cond_frames"]] 263 | permutation_blank = np.sort(permutation[remove_indices]) 264 | idx_blank = np.sort(idx[remove_indices]) 265 | 266 | permutation = np.concatenate((permutation_con,permutation_blank)) 267 | idx = np.concatenate((idx_con,idx_blank)) 268 | 269 | 270 | kendall_dis_1 = 0 271 | for n1 in range(len(permutation)): 272 | for n2 in range(len(permutation)): 273 | if permutation[n1]<permutation[n2] and idx[n1]>idx[n2]: 274 | kendall_dis_1 += 1 275 | # kendall_dis_1 += np.abs(permutation - idx).sum() 276 | kendall_dis_2 = 0 277 | for n1 in range(len(permutation)): 278 | for n2 in range(len(permutation)): 279 | if permutation[n1] < permutation[n2] and (args.seq_len-1-idx)[n1] > (args.seq_len-1-idx)[n2]: 280 | kendall_dis_2 += 1 281 | # kendall_dis_2 += np.abs(permutation -args.seq_len + 1 + idx).sum() 282 | kendall_dis =min(kendall_dis_1,kendall_dis_2)/(sample[0].shape[1]*(sample[0].shape[1]-1)/2) 283 | 284 | 285 | # breakpoint() 286 | 287 | kendall_dis_sum += kendall_dis 288 | kendall_dis_list.append(kendall_dis) 289 | print(len(kendall_dis_list), 'kendall distance:', kendall_dis,kendall_dis_sum/len(kendall_dis_list) ) 290 | if kendall_dis>0.05: 291 | print('video_index:',index) 292 | print(idx) 293 | print(permutation) 294 | 295 | 296 | print("Total result is:",kendall_dis_sum/args.num_samples) 297 | # # end of normalization 298 | # all_sampled_normalized.append(sample_normalized) 299 | # gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 300 | # # gathered_samples_time=[th.zeros_like(sample_time_out) for _ in range(dist.get_world_size())] 301 | # dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 302
| # all_videos.extend([sample.cpu().numpy() for sample in gathered_samples]) 303 | # # sample_time.append(sample_time_out.cpu().numpy()) #for sample_time_out in gathered_samples_time]) 304 | # ground_truth.append(video.permute(0, 2, 3, 4, 1).cpu().numpy()) 305 | # shuffled_input.append(conditional_data.permute(0, 2, 3, 4, 1).cpu().numpy()) 306 | # all_idx.append(true_idx) 307 | # # breakpoint() 308 | # # logger.log(f"created {len(all_videos) * args.batch_size} samples") 309 | 310 | # generated_data = np.concatenate(all_videos, axis=0) 311 | # ground_truth_data = np.asarray(ground_truth) 312 | # shuffled_input_data = np.asarray(shuffled_input) 313 | 314 | # gen_raw = np.concatenate(all_videos, axis=0) # raw BxTxWxHXC 1x20x32x32x4 315 | # gen_raw = torch.from_numpy(gen_raw).float() 316 | # gen_norm = torch.cat(all_sampled_normalized, dim=0) 317 | # gen_time=np.concatenate(sample_time,axis=0) 318 | # breakpoint() 319 | # plot_samples(generated_data, shuffled_input_data, ground_truth_data) 320 | # fig=plt.figure() 321 | # for i in range(gen_time.shape[0]): 322 | # plt.imshow(gen_time[i]) 323 | # plt.savefig('fig_test/time{}.png'.format(i)) 324 | 325 | # # breakpoint() 326 | # p_enc_1d_model = PositionalEncoding1D(16) 327 | # time_emb = p_enc_1d_model(torch.rand(generated_data.shape[0], generated_data.shape[1], 16)) 328 | # # distances = torch.cdist(torch.tensor(generated_data[0]), torch.tensor(time_emb[0])) 329 | # # indices = distances.argmin(dim=-1) 330 | # indice_list=[] 331 | # for i in range(generated_data.shape[0]): 332 | # distances = torch.cdist(torch.tensor(generated_data[i]), torch.tensor(time_emb[0])) 333 | # indices = distances.argmin(dim=-1) 334 | # indice_list.append(indices) 335 | # # breakpoint() 336 | # from datetime import datetime 337 | # currentDateAndTime = datetime.now() 338 | # if args.cond_generation and args.save_gt: 339 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_gt_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),ground_truth_data) 340 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_gen_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'), 341 | # generated_data) 342 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_idx_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'), 343 | # np.asarray(all_idx)) 344 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' + str(datetime.now().day) + '_' + str( 345 | # datetime.now().hour) + '_file_names_' + str(args.resample_steps) + str(args.diffusion_step_gap) + '.npy'), 346 | # np.asarray(all_names)) 347 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_condition_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'), 348 | # shuffled_input_data) 349 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_maskList_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'), 350 | # np.asarray(mask_list)) 351 | # dist.barrier() 352 | # logger.log("sampling complete") 353 | 354 | 355 | def create_argparser(): 356 | defaults = dict( 357 | clip_denoised=True, 358 | num_samples=1000, # 10 359 | batch_size=1, # 10 360 | 
use_ddim=False, 361 | model_path="", 362 | seq_len=32, # 16 363 | sampling_type="generation", 364 | cond_frames="", 365 | cond_generation=True, # True 366 | resample_steps=1, 367 | data_dir='', 368 | save_gt=True, 369 | seed=0, 370 | diffusion_step_gap=1, 371 | ) 372 | defaults.update(model_and_diffusion_defaults()) 373 | parser = argparse.ArgumentParser() 374 | add_dict_to_argparser(parser, defaults) 375 | return parser 376 | 377 | 378 | if __name__ == "__main__": 379 | import time 380 | 381 | MODEL_FLAGS = "--image_size 8 --num_channels 128 --num_res_blocks 3 --scale_time_dim 8" # image size 64 382 | DIFFUSION_FLAGS = "--diffusion_steps 1200 --noise_schedule linear" # 1000 383 | start = time.time() 384 | main() 385 | end = time.time() 386 | print(f"elapsed time: {end - start}") -------------------------------------------------------------------------------- /image_model/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 17 | 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings. 
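
        Example (illustrative shape check):
        >>> t = torch.arange(4)
        >>> TimestepEmbedder.timestep_embedding(t, 128).shape
        torch.Size([4, 128])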
49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | 67 | class LabelEmbedder(nn.Module): 68 | """ 69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 70 | """ 71 | def __init__(self, num_classes, hidden_size, dropout_prob): 72 | super().__init__() 73 | use_cfg_embedding = dropout_prob > 0 74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 75 | self.num_classes = num_classes 76 | self.dropout_prob = dropout_prob 77 | 78 | def token_drop(self, labels, force_drop_ids=None): 79 | """ 80 | Drops labels to enable classifier-free guidance. 81 | """ 82 | if force_drop_ids is None: 83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 84 | else: 85 | drop_ids = force_drop_ids == 1 86 | labels = torch.where(drop_ids, self.num_classes, labels) 87 | return labels 88 | 89 | def forward(self, labels, train, force_drop_ids=None): 90 | use_dropout = self.dropout_prob > 0 91 | if (train and use_dropout) or (force_drop_ids is not None): 92 | labels = self.token_drop(labels, force_drop_ids) 93 | embeddings = self.embedding_table(labels) 94 | return embeddings 95 | 96 | 97 | ################################################################################# 98 | # Core DiT Model # 99 | ################################################################################# 100 | 101 | class DiTBlock(nn.Module): 102 | """ 103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 104 | """ 105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 106 | super().__init__() 107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 110 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 111 | approx_gelu = lambda: nn.GELU(approximate="tanh") 112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 113 | self.adaLN_modulation = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 116 | ) 117 | 118 | def forward(self, x, c): 119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 122 | return x 123 | 124 | 125 | class FinalLayer(nn.Module): 126 | """ 127 | The final layer of DiT. 
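    It applies adaLN-style shift/scale modulation to the final LayerNorm output and then
    projects each token to patch_size * patch_size * out_channels values.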
128 | """ 129 | def __init__(self, hidden_size, patch_size, out_channels): 130 | super().__init__() 131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 133 | self.adaLN_modulation = nn.Sequential( 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 136 | ) 137 | 138 | def forward(self, x, c): 139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 140 | x = modulate(self.norm_final(x), shift, scale) 141 | x = self.linear(x) 142 | return x 143 | 144 | 145 | class DiT(nn.Module): 146 | """ 147 | Diffusion model with a Transformer backbone. 148 | """ 149 | def __init__( 150 | self, 151 | input_size=255, 152 | patch_size=2, 153 | in_channels=3, 154 | hidden_size=1152, 155 | depth=28, 156 | num_heads=16, 157 | mlp_ratio=4.0, 158 | class_dropout_prob=0.1, 159 | num_classes=0, 160 | learn_sigma=False, 161 | ): 162 | super().__init__() 163 | self.learn_sigma = learn_sigma 164 | self.in_channels = in_channels 165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 166 | self.patch_size = patch_size 167 | self.num_heads = num_heads 168 | 169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 170 | self.t_embedder = TimestepEmbedder(hidden_size) 171 | # self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 172 | num_patches = self.x_embedder.num_patches 173 | # Will use fixed sin-cos embedding: 174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 175 | 176 | self.time_emb_in = nn.Linear(8,768) 177 | self.time_emb_out1 = nn.Linear(768,64) 178 | self.time_emb_out_silu = nn.SiLU() 179 | self.time_emb_out2 = nn.Linear(64,8) 180 | 181 | self.blocks = nn.ModuleList([ 182 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 183 | ]) 184 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 185 | self.initialize_weights() 186 | 187 | def initialize_weights(self): 188 | # Initialize transformer layers: 189 | def _basic_init(module): 190 | if isinstance(module, nn.Linear): 191 | torch.nn.init.xavier_uniform_(module.weight) 192 | if module.bias is not None: 193 | nn.init.constant_(module.bias, 0) 194 | self.apply(_basic_init) 195 | 196 | # Initialize (and freeze) pos_embed by sin-cos embedding: 197 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 198 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 199 | 200 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 201 | w = self.x_embedder.proj.weight.data 202 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 203 | nn.init.constant_(self.x_embedder.proj.bias, 0) 204 | 205 | # Initialize label embedding table: 206 | # nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 207 | 208 | # Initialize timestep embedding MLP: 209 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 210 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 211 | 212 | nn.init.normal_(self.time_emb_in.weight, std=0.02) 213 | nn.init.normal_(self.time_emb_out1.weight, std=0.02) 214 | nn.init.normal_(self.time_emb_out2.weight, std=0.02) 215 | 216 | # Zero-out adaLN modulation layers in DiT blocks: 217 | for block in self.blocks: 218 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 219 | nn.init.constant_(block.adaLN_modulation[-1].bias, 
0) 220 | 221 | # Zero-out output layers: 222 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 223 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 224 | nn.init.constant_(self.final_layer.linear.weight, 0) 225 | nn.init.constant_(self.final_layer.linear.bias, 0) 226 | 227 | def unpatchify(self, x): 228 | """ 229 | x: (N, T, patch_size**2 * C) 230 | imgs: (N, H, W, C) 231 | """ 232 | c = self.out_channels 233 | p = self.x_embedder.patch_size[0] 234 | h = w = int(x.shape[1] ** 0.5) 235 | assert h * w == x.shape[1] 236 | 237 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 238 | x = torch.einsum('nhwpqc->nchpwq', x) 239 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 240 | return imgs 241 | 242 | def forward(self, x, t, time_emb, y=None): 243 | """ 244 | Forward pass of DiT. 245 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 246 | t: (N,) tensor of diffusion timesteps 247 | y: (N,) tensor of class labels 248 | """ 249 | time_emb = self.time_emb_in(time_emb) 250 | x = self.x_embedder(x) + time_emb + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 251 | t = self.t_embedder(t) # (N, D) 252 | # y = self.y_embedder(y, self.training) # (N, D) 253 | c = t # (N, D) 254 | for block in self.blocks: 255 | x = block(x, c) # (N, T, D) 256 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 257 | time_emb = self.time_emb_out1(x) 258 | time_emb = self.time_emb_out_silu(time_emb) 259 | time_emb = self.time_emb_out2(time_emb) 260 | x = self.unpatchify(x) # (N, out_channels, H, W) 261 | 262 | return x,time_emb 263 | 264 | def forward_with_cfg(self, x, t, y, cfg_scale): 265 | """ 266 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 267 | """ 268 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 269 | half = x[: len(x) // 2] 270 | combined = torch.cat([half, half], dim=0) 271 | model_out = self.forward(combined, t, y) 272 | # For exact reproducibility reasons, we apply classifier-free guidance on only 273 | # three channels by default. The standard approach to cfg applies it to all channels. 274 | # This can be done by uncommenting the following line and commenting-out the line following that. 
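        # The split below implements the usual guidance combination
        #   eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond),
        # so cfg_scale = 1 recovers the conditional prediction and larger values
        # extrapolate further away from the unconditional one.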
275 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 276 | eps, rest = model_out[:, :3], model_out[:, 3:] 277 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 278 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 279 | eps = torch.cat([half_eps, half_eps], dim=0) 280 | return torch.cat([eps, rest], dim=1) 281 | 282 | 283 | ################################################################################# 284 | # Sine/Cosine Positional Embedding Functions # 285 | ################################################################################# 286 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 287 | 288 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 289 | """ 290 | grid_size: int of the grid height and width 291 | return: 292 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 293 | """ 294 | grid_h = np.arange(grid_size, dtype=np.float32) 295 | grid_w = np.arange(grid_size, dtype=np.float32) 296 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 297 | grid = np.stack(grid, axis=0) 298 | 299 | grid = grid.reshape([2, 1, grid_size, grid_size]) 300 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 301 | if cls_token and extra_tokens > 0: 302 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 303 | return pos_embed 304 | 305 | 306 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 307 | assert embed_dim % 2 == 0 308 | 309 | # use half of dimensions to encode grid_h 310 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 311 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 312 | 313 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 314 | return emb 315 | 316 | 317 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 318 | """ 319 | embed_dim: output dimension for each position 320 | pos: a list of positions to be encoded: size (M,) 321 | out: (M, D) 322 | """ 323 | assert embed_dim % 2 == 0 324 | omega = np.arange(embed_dim // 2, dtype=np.float64) 325 | omega /= embed_dim / 2. 326 | omega = 1. 
/ 10000**omega # (D/2,) 327 | 328 | pos = pos.reshape(-1) # (M,) 329 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 330 | 331 | emb_sin = np.sin(out) # (M, D/2) 332 | emb_cos = np.cos(out) # (M, D/2) 333 | 334 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 335 | return emb 336 | 337 | 338 | ################################################################################# 339 | # DiT Configs # 340 | ################################################################################# 341 | 342 | def DiT_XL_2(**kwargs): 343 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 344 | 345 | def DiT_XL_4(**kwargs): 346 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 347 | 348 | def DiT_XL_8(**kwargs): 349 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 350 | 351 | def DiT_L_2(**kwargs): 352 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 353 | 354 | def DiT_L_4(**kwargs): 355 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 356 | 357 | def DiT_L_8(**kwargs): 358 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 359 | 360 | def DiT_B_2(**kwargs): 361 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 362 | 363 | def DiT_B_4(**kwargs): 364 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 365 | 366 | def DiT_B_8(**kwargs): 367 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 368 | 369 | def DiT_S_2(**kwargs): 370 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 371 | 372 | def DiT_S_4(**kwargs): 373 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) 374 | 375 | def DiT_S_8(**kwargs): 376 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 377 | 378 | def JPDVT(**kwargs): 379 | return DiT(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs) 380 | 381 | def JPDVT_S(**kwargs): 382 | return DiT(depth=12, hidden_size=768, patch_size=32, num_heads=12, **kwargs) 383 | 384 | def JPDVT_T(**kwargs): 385 | return DiT(depth=12, hidden_size=768, patch_size=64, num_heads=12, **kwargs) 386 | 387 | DiT_models = { 388 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, 389 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 390 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 391 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, 392 | 'JPDVT': JPDVT, 'JPDVT-S': JPDVT_S, 'JPDVT-T': JPDVT_T, 393 | } 394 | -------------------------------------------------------------------------------- /video_model/diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | import pdb 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | import dist_util, logger 12 | from fp16_util import ( 13 | make_master_params, 14 | master_params_to_model_params, 15 | model_grads_to_master_grads, 16 | unflatten_master_params, 17 | zero_grad, 18 | ) 19 | from nn import update_ema 20 | from resample import LossAwareSampler, UniformSampler 21 | from positional_encodings.torch_encodings import PositionalEncoding1D 22 | 23 | # For ImageNet experiments, this 
was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | import torch 28 | 29 | class TrainLoop: 30 | 31 | def __init__( 32 | self, 33 | *, 34 | model, 35 | diffusion, 36 | data, 37 | batch_size, 38 | microbatch, 39 | lr, 40 | ema_rate, 41 | log_interval, 42 | save_interval, 43 | resume_checkpoint, 44 | use_fp16=False, 45 | fp16_scale_growth=1e-3, 46 | schedule_sampler=None, 47 | weight_decay=0.0, 48 | lr_anneal_steps=0, 49 | clip=1, 50 | anneal_type=None, 51 | steps_drop=None, 52 | drop=None, 53 | decay=None, 54 | max_num_mask_frames=4, #4 55 | mask_range=None, 56 | uncondition_rate=True, 57 | exclude_conditional=True, 58 | ): 59 | 60 | self.model = model 61 | self.diffusion = diffusion 62 | self.data = data 63 | self.batch_size = batch_size 64 | self.microbatch = microbatch if microbatch > 0 else batch_size 65 | self.accumulation_steps = batch_size / microbatch if microbatch > 0 else 1 66 | self.lr = lr 67 | self.current_lr = lr 68 | self.ema_rate = ( 69 | [ema_rate] 70 | if isinstance(ema_rate, float) 71 | else [float(x) for x in ema_rate.split(",")] 72 | ) 73 | self.log_interval = log_interval 74 | self.save_interval = save_interval 75 | self.resume_checkpoint = resume_checkpoint 76 | self.use_fp16 = use_fp16 77 | self.fp16_scale_growth = fp16_scale_growth 78 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 79 | self.weight_decay = weight_decay 80 | self.anneal_type = anneal_type 81 | if self.anneal_type == 'linear': 82 | assert lr_anneal_steps != 0 83 | self.lr_anneal_steps = lr_anneal_steps 84 | if self.anneal_type == 'step': 85 | assert steps_drop != 0 86 | assert drop != 0 87 | self.steps_drop = steps_drop 88 | self.drop = drop 89 | if self.anneal_type == 'time_based': 90 | assert decay != 0 91 | self.decay = decay 92 | 93 | self.clip = clip 94 | self.max_num_mask_frames = max_num_mask_frames 95 | self.mask_range = mask_range 96 | self.uncondition_rate =uncondition_rate 97 | self.exclude_conditional = exclude_conditional 98 | 99 | self.step = 0 100 | self.resume_step = 0 101 | self.global_batch = self.batch_size * dist.get_world_size() 102 | logger.log(f"global batch size = {self.global_batch}") 103 | 104 | self.model_params = list(self.model.parameters()) 105 | self.master_params = self.model_params 106 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 107 | self.sync_cuda = th.cuda.is_available() 108 | 109 | self._load_and_sync_parameters() 110 | if self.use_fp16: 111 | self._setup_fp16() 112 | 113 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 114 | if self.resume_step: 115 | self._load_optimizer_state() 116 | # Model was resumed, either due to a restart or a checkpoint 117 | # being specified at the command line. 
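            # One EMA copy of the master parameters is kept per rate in self.ema_rate;
            # on resume each copy is restored from its ema_{rate}_{step}.pt checkpoint,
            # otherwise it starts as a deep copy of the current master parameters.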
118 | self.ema_params = [ 119 | self._load_ema_parameters(rate) for rate in self.ema_rate 120 | ] 121 | else: 122 | self.ema_params = [ 123 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 124 | ] 125 | 126 | if th.cuda.is_available(): 127 | logger.log(f"world_size: {dist.get_world_size()}") 128 | self.use_ddp = True 129 | self.ddp_model = DDP( 130 | self.model, 131 | device_ids=[dist_util.dev()], 132 | output_device=dist_util.dev(), 133 | # device_ids=[device], 134 | # output_device=device, 135 | broadcast_buffers=False, 136 | bucket_cap_mb=128, 137 | find_unused_parameters=False, 138 | ) 139 | else: 140 | if dist.get_world_size() > 1: 141 | logger.warn( 142 | "Distributed training requires CUDA. " 143 | "Gradients will not be synchronized properly!" 144 | ) 145 | self.use_ddp = False 146 | self.ddp_model = self.model 147 | 148 | def _load_and_sync_parameters(self): 149 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 150 | 151 | if resume_checkpoint: 152 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 153 | 154 | # if dist.get_rank() == 0: 155 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 156 | self.model.load_state_dict( 157 | dist_util.load_state_dict( 158 | resume_checkpoint, map_location=dist_util.dev() 159 | ) 160 | ) 161 | 162 | 163 | dist_util.sync_params(self.model.parameters()) 164 | 165 | def _load_ema_parameters(self, rate): 166 | ema_params = copy.deepcopy(self.master_params) 167 | 168 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 169 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 170 | if ema_checkpoint: 171 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 172 | state_dict = dist_util.load_state_dict( 173 | ema_checkpoint, map_location=dist_util.dev() 174 | ) 175 | ema_params = self._state_dict_to_master_params(state_dict) 176 | 177 | dist_util.sync_params(ema_params) 178 | return ema_params 179 | 180 | def _load_optimizer_state(self): 181 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 182 | opt_checkpoint = bf.join( 183 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 184 | ) 185 | if bf.exists(opt_checkpoint): 186 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 187 | state_dict = dist_util.load_opt_state_dict( 188 | opt_checkpoint, map_location=dist_util.dev() 189 | ) 190 | 191 | self.opt.load_state_dict(state_dict) 192 | 193 | def _setup_fp16(self): 194 | self.master_params = make_master_params(self.model_params) 195 | self.model.convert_to_fp16() 196 | 197 | def run_loop(self): 198 | 199 | while ( 200 | self.current_lr 201 | ): 202 | batch,_= next(self.data) 203 | min_val = torch.min(batch) # normalization for the nturgbd skeleton dataset 204 | max_val = torch.max(batch) 205 | range_val = max_val - min_val 206 | batch = (batch - min_val) / range_val 207 | batch = batch * 2 - 1 208 | 209 | # min_val = torch.min(cond_batch) # normalization for the nturgbd skeleton dataset 210 | # max_val = torch.max(cond_batch) 211 | # range_val = max_val - min_val 212 | # cond_batch = (cond_batch - min_val) / range_val 213 | # cond_batch = cond_batch * 2 - 1 214 | 215 | # if(batch.isnan().any()): 216 | # breakpoint() 217 | 218 | # breakpoint() 219 | 220 | self.run_step(batch) 221 | if self.step % self.log_interval == 0: 222 | logger.dumpkvs() 223 | if self.step % self.save_interval == 0: 224 | self.save() 225 | # Run for a finite amount of time in integration 
tests. Does access an environment variable 226 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 227 | return 228 | self.step += 1 229 | # Save the last checkpoint if it wasn't already saved. 230 | if (self.step - 1) % self.save_interval != 0: 231 | self.save() 232 | 233 | def run_step(self, batch): 234 | self.forward_backward(batch) 235 | if self.clip: 236 | th.nn.utils.clip_grad_norm_(self.ddp_model.parameters(), self.clip) 237 | if self.use_fp16: 238 | self.optimize_fp16() 239 | else: 240 | self.optimize_normal() 241 | self.log_step() 242 | 243 | def forward_backward(self, batch): 244 | zero_grad(self.model_params) 245 | for i in range(0, batch.shape[0], self.microbatch): 246 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 247 | # breakpoint() 248 | # [B, C, T, W, H] 249 | """ 250 | uncomment for CLEVRER 251 | """ 252 | micro=micro.permute(0,2,1,3,4) 253 | 254 | # micro=micro.permute(0,4,1,2,3) # uncomment for something-something and nturgbd skeletondataset 255 | # micro = (micro - micro.min()) / (micro.max() - micro.min()) 256 | 257 | # breakpoint() 258 | 259 | idx = torch.randperm(micro.shape[2]) # random frames of 5 260 | idx=idx[idx!=0] 261 | idx=torch.cat((torch.tensor([0]),idx)) 262 | # breakpoint() 263 | 264 | micro_cond= micro[:,:,idx,:,:] # shuffled video 265 | micro = micro_cond.clone() 266 | # print('training mode') 267 | breakpoint() 268 | 269 | # code for removing any frame 270 | # number of frames to remove 271 | num_frames_to_remove = 6 272 | # # 273 | # # # create an array of indices for all frames 274 | frame_indices = np.arange(micro_cond.shape[2]) 275 | # 276 | # # randomly select the indices of the frames to remove for all batches 277 | remove_indices = np.random.choice(frame_indices, num_frames_to_remove, replace=False) 278 | remove_indices = remove_indices[remove_indices!=0] 279 | # 280 | # # create an array of zeros with the same shape as the video frames 281 | dummy_frames = np.zeros( 282 | (micro_cond.shape[0], micro_cond.shape[1], 1, micro_cond.shape[3], micro_cond.shape[4])) 283 | dummy_frames = torch.from_numpy(dummy_frames).to(micro_cond.to(dist_util.dev())) # 284 | # 285 | # # replace the selected frames with the dummy frames for all batches 286 | # breakpoint() 287 | micro_cond[:, :, remove_indices, :, :] = dummy_frames 288 | 289 | 290 | p_enc_1d_model = PositionalEncoding1D(16) 291 | time_emb = p_enc_1d_model(torch.rand(micro.shape[0], micro.shape[2], 16)) 292 | time_emb = time_emb[:,idx,:] 293 | last_batch = (i + self.microbatch) >= batch.shape[0] 294 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 295 | # breakpoint() 296 | mask = np.ones(micro.shape[2]) 297 | mask[remove_indices] = 0 298 | mask = torch.tensor(mask) 299 | compute_losses = functools.partial( 300 | self.diffusion.training_losses_two_dis, 301 | self.ddp_model, 302 | micro, 303 | t, 304 | condition_data=micro_cond,#model_kwargs 305 | time_emb=time_emb.to(dist_util.dev()), 306 | # mask=mask.to(dist_util.dev()), 307 | max_num_mask_frames=self.max_num_mask_frames, 308 | mask_range=self.mask_range, 309 | uncondition_rate=self.uncondition_rate, 310 | exclude_conditional=self.exclude_conditional, 311 | ) 312 | if last_batch or not self.use_ddp: 313 | losses = compute_losses() 314 | else: 315 | with self.ddp_model.no_sync(): 316 | losses = compute_losses() 317 | if isinstance(self.schedule_sampler, LossAwareSampler): 318 | self.schedule_sampler.update_with_local_losses( 319 | t, losses["loss"].detach() 320 | ) 321 | loss = 
(losses["loss"] * weights).mean() 322 | log_loss_dict( 323 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 324 | ) 325 | loss = loss / self.accumulation_steps 326 | if self.use_fp16: 327 | loss_scale = 2 ** self.lg_loss_scale 328 | (loss * loss_scale).backward() 329 | else: 330 | loss.backward() 331 | 332 | def optimize_fp16(self): 333 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 334 | self.lg_loss_scale -= 1 335 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 336 | return 337 | 338 | model_grads_to_master_grads(self.model_params, self.master_params) 339 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 340 | self._log_grad_norm() 341 | self._anneal_lr() 342 | self.opt.step() 343 | for rate, params in zip(self.ema_rate, self.ema_params): 344 | update_ema(params, self.master_params, rate=rate) 345 | master_params_to_model_params(self.model_params, self.master_params) 346 | self.lg_loss_scale += self.fp16_scale_growth 347 | 348 | def optimize_normal(self): 349 | self._log_grad_norm() 350 | self._anneal_lr() 351 | self.opt.step() 352 | for rate, params in zip(self.ema_rate, self.ema_params): 353 | update_ema(params, self.master_params, rate=rate) 354 | 355 | def _log_grad_norm(self): 356 | sqsum = 0.0 357 | for p in self.master_params: 358 | sqsum += (p.grad ** 2).sum().item() 359 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 360 | 361 | def _anneal_lr(self): 362 | if self.anneal_type is None: 363 | return 364 | if self.anneal_type == "linear": 365 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 366 | lr = self.lr * (1 - frac_done) 367 | elif self.anneal_type == "step": 368 | lr = self.lr * self.drop**(np.floor((self.step + self.resume_step)/self.steps_drop)) 369 | elif self.anneal_type == "time_based": 370 | lr = self.lr / (1 + self.decay * (self.step + self.resume_step)) 371 | else: 372 | raise ValueError(f"unsupported anneal type: {self.anneal_type}") 373 | for param_group in self.opt.param_groups: 374 | param_group["lr"] = lr 375 | self.current_lr = lr 376 | 377 | def log_step(self): 378 | logger.logkv("step", self.step + self.resume_step) 379 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 380 | if self.use_fp16: 381 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 382 | 383 | def save(self): 384 | def save_checkpoint(rate, params): 385 | state_dict = self._master_params_to_state_dict(params) 386 | if dist.get_rank() == 0: 387 | logger.log(f"saving model {rate}...") 388 | if not rate: 389 | filename = f"model{(self.step+self.resume_step):06d}.pt" 390 | else: 391 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 392 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 393 | th.save(state_dict, f) 394 | 395 | save_checkpoint(0, self.master_params) 396 | for rate, params in zip(self.ema_rate, self.ema_params): 397 | save_checkpoint(rate, params) 398 | 399 | if dist.get_rank() == 0: 400 | with bf.BlobFile( 401 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 402 | "wb", 403 | ) as f: 404 | th.save(self.opt.state_dict(), f) 405 | 406 | dist.barrier() 407 | 408 | def _master_params_to_state_dict(self, master_params): 409 | if self.use_fp16: 410 | master_params = unflatten_master_params( 411 | list(self.model.parameters()), master_params 412 | ) 413 | state_dict = self.model.state_dict() 414 | for i, (name, _value) in enumerate(self.model.named_parameters()): 415 | assert name in state_dict 416 | 
state_dict[name] = master_params[i] 417 | return state_dict 418 | 419 | def _state_dict_to_master_params(self, state_dict): 420 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 421 | if self.use_fp16: 422 | return make_master_params(params) 423 | else: 424 | return params 425 | 426 | 427 | def parse_resume_step_from_filename(filename): 428 | """ 429 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 430 | checkpoint's number of steps. 431 | """ 432 | split = filename.split("model") 433 | if len(split) < 2: 434 | return 0 435 | split1 = split[-1].split(".")[0] 436 | try: 437 | return int(split1) 438 | except ValueError: 439 | return 0 440 | 441 | 442 | def get_blob_logdir(): 443 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 444 | 445 | 446 | def find_resume_checkpoint(): 447 | # On your infrastructure, you may want to override this to automatically 448 | # discover the latest checkpoint on your blob storage, etc. 449 | return None 450 | 451 | 452 | def find_ema_checkpoint(main_checkpoint, step, rate): 453 | if main_checkpoint is None: 454 | return None 455 | filename = f"ema_{rate}_{(step):06d}.pt" 456 | path = bf.join(bf.dirname(main_checkpoint), filename) 457 | if bf.exists(path): 458 | return path 459 | return None 460 | 461 | 462 | def log_loss_dict(diffusion, ts, losses): 463 | for key, values in losses.items(): 464 | logger.logkv_mean(key, values.mean().item()) 465 | # Log the quantiles (four quartiles, in particular). 466 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 467 | quartile = int(4 * sub_t / diffusion.num_timesteps) 468 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 469 | -------------------------------------------------------------------------------- /video_model/diffusion/make_a_video.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | from operator import mul 4 | 5 | import torch 6 | from torch import nn, einsum 7 | 8 | from einops import rearrange, repeat, pack, unpack 9 | from einops.layers.torch import Rearrange 10 | from nn import ( 11 | SiLU, 12 | conv_nd, 13 | linear, 14 | avg_pool_nd, 15 | zero_module, 16 | normalization, 17 | timestep_embedding, 18 | checkpoint, 19 | ) 20 | 21 | # helper functions 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | 27 | def default(val, d): 28 | return val if exists(val) else d 29 | 30 | 31 | def mul_reduce(tup): 32 | return functools.reduce(mul, tup) 33 | 34 | 35 | def divisible_by(numer, denom): 36 | return (numer % denom) == 0 37 | 38 | 39 | mlist = nn.ModuleList 40 | 41 | 42 | # for time conditioning 43 | 44 | class SinusoidalPosEmb(nn.Module): 45 | def __init__(self, dim, theta=10000): 46 | super().__init__() 47 | self.theta = theta 48 | self.dim = dim 49 | 50 | def forward(self, x): 51 | dtype, device = x.dtype, x.device 52 | assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type' 53 | 54 | half_dim = self.dim // 2 55 | emb = math.log(self.theta) / (half_dim - 1) 56 | emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb) 57 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') 58 | return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype) 59 | 60 | 61 | # layernorm 3d 62 | 63 | class LayerNorm(nn.Module): 64 | def __init__(self, dim): 65 | super().__init__() 66 | self.g = nn.Parameter(torch.ones(dim)) 67 | 68 | def forward(self, x): 69 | eps = 1e-5 if x.dtype == torch.float32 
else 1e-3 70 | var = torch.var(x, dim=1, unbiased=False, keepdim=True) 71 | mean = torch.mean(x, dim=1, keepdim=True) 72 | return (x - mean) * var.clamp(min=eps).rsqrt() * self.g 73 | 74 | 75 | # feedforward 76 | 77 | class GEGLU(nn.Module): 78 | def forward(self, x): 79 | x, gate = x.chunk(2, dim=-1) 80 | return x * F.gelu(gate) 81 | 82 | 83 | def FeedForward(dim, mult=4): 84 | inner_dim = int(dim * mult * 2 / 3) 85 | return nn.Sequential( 86 | nn.Linear(dim, inner_dim, bias=False), 87 | GEGLU(), 88 | nn.Linear(inner_dim, bias=False) 89 | ) 90 | 91 | 92 | # best relative positional encoding 93 | 94 | class ContinuousPositionBias(nn.Module): 95 | """ from https://arxiv.org/abs/2111.09883 """ 96 | 97 | def __init__( 98 | self, 99 | *, 100 | dim, 101 | heads, 102 | num_dims=1, 103 | layers=2, 104 | log_dist=True, 105 | cache_rel_pos=False 106 | ): 107 | super().__init__() 108 | self.num_dims = num_dims 109 | self.log_dist = log_dist 110 | 111 | self.net = nn.ModuleList([]) 112 | self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU())) 113 | 114 | for _ in range(layers - 1): 115 | self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU())) 116 | 117 | self.net.append(nn.Linear(dim, heads)) 118 | 119 | self.cache_rel_pos = cache_rel_pos 120 | self.register_buffer('rel_pos', None, persistent=False) 121 | 122 | @property 123 | def device(self): 124 | return next(self.parameters()).device 125 | 126 | def forward(self, *dimensions): 127 | device = self.device 128 | 129 | if not exists(self.rel_pos) or not self.cache_rel_pos: 130 | positions = [torch.arange(d, device=device) for d in dimensions] 131 | grid = torch.stack(torch.meshgrid(*positions, indexing='ij')) 132 | grid = rearrange(grid, 'c ... -> (...) c') 133 | rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c') 134 | 135 | if self.log_dist: 136 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1) 137 | 138 | self.register_buffer('rel_pos', rel_pos, persistent=False) 139 | 140 | rel_pos = self.rel_pos.float() 141 | 142 | for layer in self.net: 143 | rel_pos = layer(rel_pos) 144 | 145 | return rearrange(rel_pos, 'i j h -> h i j') 146 | 147 | 148 | # helper classes 149 | 150 | class Attention(nn.Module): 151 | def __init__( 152 | self, 153 | dim, 154 | dim_head=64, 155 | heads=8 156 | ): 157 | super().__init__() 158 | self.heads = heads 159 | self.scale = dim_head ** -0.5 160 | inner_dim = dim_head * heads 161 | 162 | self.norm = LayerNorm(dim) 163 | 164 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 165 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 166 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 167 | 168 | nn.init.zeros_(self.to_out.weight.data) # identity with skip connection 169 | 170 | def forward( 171 | self, 172 | x, 173 | rel_pos_bias=None 174 | ): 175 | x = self.norm(x) 176 | 177 | q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim=-1) 178 | 179 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v)) 180 | 181 | q = q * self.scale 182 | 183 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 184 | 185 | if exists(rel_pos_bias): 186 | sim = sim + rel_pos_bias 187 | 188 | attn = sim.softmax(dim=-1) 189 | 190 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 191 | 192 | out = rearrange(out, 'b h n d -> b n (h d)') 193 | return self.to_out(out) 194 | 195 | 196 | # main contribution - pseudo 3d conv 197 | 198 | class PseudoConv3d(nn.Module): 199 | def __init__( 200 | self, 201 | dim, 202 | dim_out=None, 203 | kernel_size=3, 
204 | *, 205 | temporal_kernel_size=None, 206 | **kwargs 207 | ): 208 | super().__init__() 209 | dim_out = default(dim_out, dim) 210 | temporal_kernel_size = default(temporal_kernel_size, kernel_size) 211 | 212 | self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2) 213 | self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size=temporal_kernel_size, 214 | padding=temporal_kernel_size // 2) if kernel_size > 1 else None 215 | 216 | if exists(self.temporal_conv): 217 | nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity 218 | nn.init.zeros_(self.temporal_conv.bias.data) 219 | 220 | def forward( 221 | self, 222 | x, 223 | enable_time=True 224 | ): 225 | b, c, *_, h, w = x.shape 226 | 227 | is_video = x.ndim == 5 228 | enable_time &= is_video 229 | 230 | if is_video: 231 | x = rearrange(x, 'b c f h w -> (b f) c h w') 232 | 233 | x = self.spatial_conv(x) 234 | 235 | if is_video: 236 | x = rearrange(x, '(b f) c h w -> b c f h w', b=b) 237 | 238 | if not enable_time or not exists(self.temporal_conv): 239 | return x 240 | 241 | x = rearrange(x, 'b c f h w -> (b h w) c f') 242 | 243 | x = self.temporal_conv(x) 244 | 245 | x = rearrange(x, '(b h w) c f -> b c f h w', h=h, w=w) 246 | 247 | return x 248 | 249 | 250 | # factorized spatial temporal attention from Ho et al. 251 | # todo - take care of relative positional biases + rotary embeddings 252 | 253 | class SpatioTemporalAttention(nn.Module): 254 | def __init__( 255 | self, 256 | dim, 257 | *, 258 | dim_head=64, 259 | heads=8 260 | ): 261 | super().__init__() 262 | self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads) 263 | self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2) 264 | 265 | self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads) 266 | self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1) 267 | 268 | def forward( 269 | self, 270 | x, 271 | enable_time=True 272 | ): 273 | b, c, *_, h, w = x.shape 274 | is_video = x.ndim == 5 275 | enable_time &= is_video 276 | 277 | if is_video: 278 | x = rearrange(x, 'b c f h w -> (b f) (h w) c') 279 | else: 280 | x = rearrange(x, 'b c h w -> b (h w) c') 281 | 282 | space_rel_pos_bias = self.spatial_rel_pos_bias(h, w) 283 | 284 | x = self.spatial_attn(x, rel_pos_bias=space_rel_pos_bias) + x 285 | 286 | if is_video: 287 | x = rearrange(x, '(b f) (h w) c -> b c f h w', b=b, h=h, w=w) 288 | else: 289 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 290 | 291 | if not enable_time: 292 | return x 293 | 294 | x = rearrange(x, 'b c f h w -> (b h w) f c') 295 | 296 | time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1]) 297 | 298 | x = self.temporal_attn(x, rel_pos_bias=time_rel_pos_bias) + x 299 | 300 | x = rearrange(x, '(b h w) f c -> b c f h w', w=w, h=h) 301 | 302 | return x 303 | 304 | 305 | # resnet block 306 | 307 | class Block(nn.Module): 308 | def __init__( 309 | self, 310 | dim, 311 | dim_out, 312 | kernel_size=3, 313 | temporal_kernel_size=None, 314 | groups=8 315 | ): 316 | super().__init__() 317 | self.project = PseudoConv3d(dim, dim_out, 3) 318 | self.norm = nn.GroupNorm(groups, dim_out) 319 | self.act = nn.SiLU() 320 | 321 | def forward( 322 | self, 323 | x, 324 | scale_shift=None, 325 | enable_time=False 326 | ): 327 | x = self.project(x, enable_time=enable_time) 328 | x = self.norm(x) 329 | 330 | if exists(scale_shift): 331 | scale, shift = scale_shift 332 | x = x * (scale + 1) + shift 333 | 334 | return 
self.act(x) 335 | 336 | 337 | class ResnetBlock(nn.Module): 338 | def __init__( 339 | self, 340 | dim, 341 | dim_out, 342 | *, 343 | timestep_cond_dim=None, 344 | groups=8 345 | ): 346 | super().__init__() 347 | 348 | self.timestep_mlp = None 349 | 350 | if exists(timestep_cond_dim): 351 | self.timestep_mlp = nn.Sequential( 352 | nn.SiLU(), 353 | nn.Linear(timestep_cond_dim, dim_out * 2) 354 | ) 355 | 356 | self.block1 = Block(dim, dim_out, groups=groups) 357 | self.block2 = Block(dim_out, dim_out, groups=groups) 358 | self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 359 | 360 | def forward( 361 | self, 362 | x, 363 | timestep_emb=None, 364 | enable_time=True 365 | ): 366 | assert not (exists(timestep_emb) ^ exists(self.timestep_mlp)) 367 | 368 | scale_shift = None 369 | 370 | if exists(self.timestep_mlp) and exists(timestep_emb): 371 | time_emb = self.timestep_mlp(timestep_emb) 372 | to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1' 373 | time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}') 374 | scale_shift = time_emb.chunk(2, dim=1) 375 | 376 | h = self.block1(x, scale_shift=scale_shift, enable_time=enable_time) 377 | 378 | h = self.block2(h, enable_time=enable_time) 379 | 380 | return h + self.res_conv(x) 381 | 382 | 383 | # pixelshuffle upsamples and downsamples 384 | # where time dimension can be configured 385 | 386 | class Downsample(nn.Module): 387 | def __init__( 388 | self, 389 | dim, 390 | downsample_space=True, 391 | downsample_time=False, 392 | nonlin=False 393 | ): 394 | super().__init__() 395 | assert downsample_space or downsample_time 396 | 397 | self.down_space = nn.Sequential( 398 | Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2), 399 | nn.Conv2d(dim * 4, dim, 1, bias=False), 400 | nn.SiLU() if nonlin else nn.Identity() 401 | ) if downsample_space else None 402 | 403 | self.down_time = nn.Sequential( 404 | Rearrange('b c (f p) h w -> b (c p) f h w', p=2), 405 | nn.Conv3d(dim * 2, dim, 1, bias=False), 406 | nn.SiLU() if nonlin else nn.Identity() 407 | ) if downsample_time else None 408 | 409 | def forward( 410 | self, 411 | x, 412 | enable_time=True 413 | ): 414 | is_video = x.ndim == 5 415 | 416 | if is_video: 417 | x = rearrange(x, 'b c f h w -> b f c h w') 418 | x, ps = pack([x], '* c h w') 419 | 420 | if exists(self.down_space): 421 | x = self.down_space(x) 422 | 423 | if is_video: 424 | x, = unpack(x, ps, '* c h w') 425 | x = rearrange(x, 'b f c h w -> b c f h w') 426 | 427 | if not is_video or not exists(self.down_time) or not enable_time: 428 | return x 429 | 430 | x = self.down_time(x) 431 | 432 | return x 433 | 434 | 435 | class Upsample(nn.Module): 436 | def __init__( 437 | self, 438 | dim, 439 | upsample_space=True, 440 | upsample_time=False, 441 | nonlin=False 442 | ): 443 | super().__init__() 444 | assert upsample_space or upsample_time 445 | 446 | self.up_space = nn.Sequential( 447 | nn.Conv2d(dim, dim * 4, 1), 448 | nn.SiLU() if nonlin else nn.Identity(), 449 | Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2) 450 | ) if upsample_space else None 451 | 452 | self.up_time = nn.Sequential( 453 | nn.Conv3d(dim, dim * 2, 1), 454 | nn.SiLU() if nonlin else nn.Identity(), 455 | Rearrange('b (c p) f h w -> b c (f p) h w', p=2) 456 | ) if upsample_time else None 457 | 458 | self.init_() 459 | 460 | def init_(self): 461 | if exists(self.up_space): 462 | self.init_conv_(self.up_space[0], 4) 463 | 464 | if exists(self.up_time): 465 | self.init_conv_(self.up_time[0], 2) 466 | 467 | def 
init_conv_(self, conv, factor): 468 | o, *remain_dims = conv.weight.shape 469 | conv_weight = torch.empty(o // factor, *remain_dims) 470 | nn.init.kaiming_uniform_(conv_weight) 471 | conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r=factor) 472 | 473 | conv.weight.data.copy_(conv_weight) 474 | nn.init.zeros_(conv.bias.data) 475 | 476 | def forward( 477 | self, 478 | x, 479 | enable_time=True 480 | ): 481 | is_video = x.ndim == 5 482 | 483 | if is_video: 484 | x = rearrange(x, 'b c f h w -> b f c h w') 485 | x, ps = pack([x], '* c h w') 486 | 487 | if exists(self.up_space): 488 | x = self.up_space(x) 489 | 490 | if is_video: 491 | x, = unpack(x, ps, '* c h w') 492 | x = rearrange(x, 'b f c h w -> b c f h w') 493 | 494 | if not is_video or not exists(self.up_time) or not enable_time: 495 | return x 496 | 497 | x = self.up_time(x) 498 | 499 | return x 500 | 501 | 502 | # space time factorized 3d unet 503 | 504 | class SpaceTimeUnet(nn.Module): 505 | def __init__( 506 | self, 507 | channels=3, 508 | dim=64, 509 | dim_mult=(1, 2, 4, 8), 510 | self_attns=(False, False, False, True), 511 | temporal_compression=(False, True, True, True), 512 | resnet_block_depths=(2, 2, 2, 2), 513 | attn_dim_head=64, 514 | attn_heads=8, 515 | condition_on_timestep=True 516 | ): 517 | super().__init__() 518 | assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths) 519 | num_layers = len(dim_mult) 520 | 521 | dims = [dim, *map(lambda mult: mult * dim, dim_mult)] 522 | dim_in_out = zip(dims[:-1], dims[1:]) 523 | 524 | # determine the valid multiples of the image size and frames of the video 525 | 526 | self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression))) 527 | self.image_size_multiple = 2 ** num_layers 528 | 529 | # timestep conditioning for DDPM, not to be confused with the time dimension of the video 530 | 531 | self.to_timestep_cond = None 532 | timestep_cond_dim = (dim * 4) if condition_on_timestep else None 533 | 534 | if condition_on_timestep: 535 | self.to_timestep_cond = nn.Sequential( 536 | SinusoidalPosEmb(dim), 537 | nn.Linear(dim, timestep_cond_dim), 538 | nn.SiLU() 539 | ) 540 | 541 | # layers 542 | 543 | self.downs = mlist([]) 544 | self.ups = mlist([]) 545 | 546 | attn_kwargs = dict( 547 | dim_head=attn_dim_head, 548 | heads=attn_heads 549 | ) 550 | 551 | mid_dim = dims[-1] 552 | 553 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim) 554 | self.mid_attn = SpatioTemporalAttention(dim=mid_dim) 555 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim) 556 | 557 | for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, 558 | dim_in_out, temporal_compression, 559 | resnet_block_depths): 560 | assert resnet_block_depth >= 1 561 | 562 | self.downs.append(mlist([ 563 | ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim), 564 | mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]), 565 | SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None, 566 | Downsample(dim_out, downsample_time=compress_time) 567 | ])) 568 | 569 | self.ups.append(mlist([ 570 | ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim), 571 | mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]), 572 | SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None, 573 | Upsample(dim_out, upsample_time=compress_time) 574 | 
575 |             ]))
576 | 
577 |         self.skip_scale = 2 ** -0.5 # paper shows faster convergence
578 | 
579 |         self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
580 |         self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)
581 | 
582 |     def forward(
583 |         self,
584 |         x,
585 |         timestep=None,
586 |         enable_time=True
587 |     ):
588 | 
589 |         # some asserts
590 | 
591 |         assert not (exists(self.to_timestep_cond) ^ exists(timestep))
592 |         is_video = x.ndim == 5
593 | 
594 |         if enable_time and is_video:
595 |             frames = x.shape[2]
596 |             assert divisible_by(frames,
597 |                                 self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'
598 | 
599 |         height, width = x.shape[-2:]
600 |         assert divisible_by(height, self.image_size_multiple) and divisible_by(width,
601 |                             self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'
602 | 
603 |         # main logic
604 | 
605 |         t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None
606 | 
607 |         x = self.conv_in(x, enable_time=enable_time)
608 | 
609 |         hiddens = []
610 | 
611 |         for init_block, blocks, maybe_attention, downsample in self.downs:
612 |             x = init_block(x, t, enable_time=enable_time)
613 | 
614 |             hiddens.append(x.clone())
615 | 
616 |             for block in blocks:
617 |                 x = block(x, enable_time=enable_time)
618 | 
619 |             if exists(maybe_attention):
620 |                 x = maybe_attention(x, enable_time=enable_time)
621 | 
622 |             hiddens.append(x.clone())
623 | 
624 |             x = downsample(x, enable_time=enable_time)
625 | 
626 |         x = self.mid_block1(x, t, enable_time=enable_time)
627 |         x = self.mid_attn(x, enable_time=enable_time)
628 |         x = self.mid_block2(x, t, enable_time=enable_time)
629 | 
630 |         for init_block, blocks, maybe_attention, upsample in reversed(self.ups):
631 |             x = upsample(x, enable_time=enable_time)
632 | 
633 |             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
634 | 
635 |             x = init_block(x, t, enable_time=enable_time)
636 | 
637 |             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
638 | 
639 |             for block in blocks:
640 |                 x = block(x, enable_time=enable_time)
641 | 
642 |             if exists(maybe_attention):
643 |                 x = maybe_attention(x, enable_time=enable_time)
644 | 
645 |         x = self.conv_out(x, enable_time=enable_time)
646 |         return x
647 | 
648 | class IdxPromptSTUnet(SpaceTimeUnet):
649 |     """
650 |     A SpaceTimeUnet conditioned on per-frame index embeddings.
651 | 
652 |     Expects an extra input `time_emb` holding one 16-dim index embedding per frame; it is projected to a 32 x 32 prompt map, concatenated to the video as an additional channel, passed through the UNet together with the frames, and mapped back to 16 dimensions on output.
653 |     """
654 | 
655 |     def __init__(self, channels, *args, **kwargs):
656 |         # reserve one extra channel for the index-prompt plane
657 |         super().__init__(channels + 1, *args, **kwargs)
658 |         self.in_mlp = linear(16, 1024)   # 16-dim index embedding -> 32 x 32 prompt map
659 |         self.out_mlp = linear(1024, 16)  # 32 x 32 prompt map -> 16-dim index embedding
660 | 
661 |     def forward(self, x, timesteps, time_emb):
662 |         # x: [B, C, T, H, W], time_emb: [B, T, 16]
663 |         b, t = time_emb.shape[0], time_emb.shape[1]
664 | 
665 |         # project the per-frame index embeddings to a 32 x 32 plane and
666 |         # concatenate it to the video as one extra channel
667 |         prompt = self.in_mlp(time_emb).reshape(b, t, 32, 32)
668 |         prompt = prompt.to(x.device).unsqueeze(1)
669 |         out = torch.cat([x.float(), prompt.float()], dim=1)
670 | 
671 |         output = super().forward(out, timesteps.float())
672 | 
673 |         # split the denoised video from the denoised index-prompt channel
674 |         video_output = output[:, 0:-1]
675 |         time_emb_output = output[:, -1:]
676 |         time_emb_output = time_emb_output.reshape(b, t, 1024)
677 |         time_emb_output = self.out_mlp(time_emb_output)
678 |         return video_output, time_emb_output
--------------------------------------------------------------------------------
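A note on SpaceTimeUnet above: it only accepts inputs whose frame count is a multiple of `frame_multiple` (2 raised to the number of temporally compressed stages) and whose spatial sides are multiples of `image_size_multiple` (2 raised to the number of stages). The following is a minimal shape-check sketch, not part of the repository; it assumes the classes above (presumably defined in video_model/diffusion/make_a_video.py) are importable, torch and einops are installed, and the sizes are purely illustrative.

import torch

unet = SpaceTimeUnet(
    channels=3,
    dim=64,
    dim_mult=(1, 2, 4, 8),
    self_attns=(False, False, False, True),
    temporal_compression=(False, True, True, True),
)

# 4 stages -> spatial sides must be multiples of 2 ** 4 = 16;
# 3 temporally compressed stages -> frame count must be a multiple of 2 ** 3 = 8
assert unet.image_size_multiple == 16 and unet.frame_multiple == 8

video = torch.randn(1, 3, 8, 32, 32)             # (batch, channels, frames, height, width)
timestep = torch.randint(0, 1000, (1,)).float()  # one DDPM timestep per batch element

with torch.no_grad():
    denoised = unet(video, timestep)             # output keeps the input shape
assert denoised.shape == video.shape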
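IdxPromptSTUnet hard-codes a 32 x 32 prompt plane (in_mlp maps 16 -> 1024 = 32 * 32), so it expects 32 x 32 frames and one 16-dim index embedding per frame. Below is a hypothetical driver, again only a sketch under the same import assumptions and with placeholder tensor names.

import torch

model = IdxPromptSTUnet(3, dim=64)       # internally builds a SpaceTimeUnet with 3 + 1 channels

video = torch.randn(2, 3, 8, 32, 32)     # (batch, channels, frames, 32, 32)
time_emb = torch.randn(2, 8, 16)         # one 16-dim index embedding per frame
timesteps = torch.randint(0, 1000, (2,))

with torch.no_grad():
    video_out, emb_out = model(video, timesteps, time_emb)

assert video_out.shape == video.shape    # denoised video channels
assert emb_out.shape == time_emb.shape   # extra channel mapped back to per-frame 16-dim embeddings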