├── video_model
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── fp16_util.py
│   │   ├── losses.py
│   │   ├── dist_util.py
│   │   ├── respace.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── script_util.py
│   │   ├── logger.py
│   │   ├── train_util.py
│   │   └── make_a_video.py
│   ├── datasets
│   │   ├── clevrer.py
│   │   ├── MovingMNIST.py
│   │   └── video_datasets.py
│   ├── video_train.py
│   └── video_sample.py
├── image_model
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── respace.py
│   │   └── timestep_sampler.py
│   ├── datasets.py
│   ├── sample.py
│   ├── train_JPDVT.py
│   └── models.py
└── README.md
/video_model/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 |
4 | Raw copy of [1] at commit 32b43b6c677df7642c5408c8ef4a09272787eb50 from Feb 22, 2021.
5 |
6 | [1] https://github.com/openai/improved-diffusion/tree/main/improved_diffusion
7 | """
8 | from .gaussian_diffusion import GaussianDiffusion
9 | from .unet import UNetModel
10 | from .resample import ScheduleSampler, UniformSampler, HarmonicSampler, LossAwareSampler, LossSecondMomentResampler, create_named_schedule_sampler
11 | from .respace import SpacedDiffusion, space_timesteps
12 |
--------------------------------------------------------------------------------
/image_model/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=True,
15 | predict_xstart=True,
16 | learn_sigma=False,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
--------------------------------------------------------------------------------
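For reference, a minimal sketch of how `create_diffusion` can be driven (the respacing strings are illustrative; `sample.py` below calls it as `create_diffusion(str(args.num_sampling_steps))`):

```python
from diffusion import create_diffusion  # run from image_model/ so the local package resolves

# Full 1000-step process (an empty respacing string keeps all steps).
train_diffusion = create_diffusion(timestep_respacing="")

# Respaced 250-step process, as used for sampling.
sample_diffusion = create_diffusion(timestep_respacing="250")

print(train_diffusion.num_timesteps, sample_diffusion.num_timesteps)  # 1000 250
```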
/video_model/datasets/clevrer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import glob
6 | from PIL import Image
7 | import torchvision.transforms as T
8 | class Clevrer(Dataset):
9 | def __init__(self, npy_dir,seq_len=32, downsample_ratio=4,downsample_mode='Notuniform'):
10 | self.npy_files = sorted(glob.glob(os.path.join(npy_dir, '*','*.npy')))
11 | self.dsr = downsample_ratio
12 | self.down_mode = downsample_mode
13 | self.seq_len = seq_len
14 | self.clip_num = 128 // self.dsr // self.seq_len
15 |
16 | def __len__(self):
17 | return len(self.npy_files) * self.clip_num
18 |
19 | def __getitem__(self, idx):
20 | npy_file = self.npy_files[idx // self.clip_num]
21 | video_data = np.load(npy_file)
22 | video_data = np.transpose(video_data, (0, 3, 1, 2)) # move channel axis to the first dimension
23 | if self.down_mode == 'uniform':
24 | video_data = video_data[::self.dsr]
25 | video_data = video_data[
26 | idx % self.clip_num * self.seq_len:
27 | idx % self.clip_num * self.seq_len + min(self.seq_len,video_data.shape[0])]
28 | video_data = torch.from_numpy(video_data).float()
29 | return video_data
30 | else:
31 | frame_num = np.asarray(range(0,128//self.dsr))
32 | frame_num_uneven = [min(x*self.dsr + np.random.randint(-self.dsr//2,self.dsr//2+1),127) for x in frame_num]
33 | frame_num_uneven[0] = 0
34 | uneven_video_data = video_data[frame_num_uneven]
35 | uneven_video_data = uneven_video_data[
36 | idx % self.clip_num * self.seq_len:
37 | idx % self.clip_num * self.seq_len + min(self.seq_len, video_data.shape[0])]
38 | video_data = video_data[::self.dsr]
39 | video_data = video_data[
40 | idx % self.clip_num * self.seq_len:
41 | idx % self.clip_num * self.seq_len + min(self.seq_len, video_data.shape[0])]
42 | uneven_video_data = torch.from_numpy(uneven_video_data).float()
43 | video_data = torch.from_numpy(video_data).float()
44 | return video_data,uneven_video_data,idx // self.clip_num
45 |
46 |
--------------------------------------------------------------------------------
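A minimal usage sketch for the `Clevrer` dataset above (the directory path is a placeholder; with the default `downsample_mode='Notuniform'` each item is a `(video, uneven_video, video_index)` tuple):

```python
from torch.utils.data import DataLoader
from datasets.clevrer import Clevrer  # assumes video_model/ is the working directory

# Placeholder directory of per-video .npy arrays with shape (128, H, W, 3).
dataset = Clevrer("path/to/clevrer_npy", seq_len=32, downsample_ratio=4)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

video, uneven_video, video_idx = next(iter(loader))
print(video.shape, uneven_video.shape)  # both (4, 32, 3, H, W)
```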
/video_model/diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 |
9 | def convert_module_to_f16(l):
10 | """
11 | Convert primitive modules to float16.
12 | """
13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14 | l.weight.data = l.weight.data.half()
15 | l.bias.data = l.bias.data.half()
16 |
17 |
18 | def convert_module_to_f32(l):
19 | """
20 | Convert primitive modules to float32, undoing convert_module_to_f16().
21 | """
22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23 | l.weight.data = l.weight.data.float()
24 | l.bias.data = l.bias.data.float()
25 |
26 |
27 | def make_master_params(model_params):
28 | """
29 | Copy model parameters into a (differently-shaped) list of full-precision
30 | parameters.
31 | """
32 | master_params = _flatten_dense_tensors(
33 | [param.detach().float() for param in model_params]
34 | )
35 | master_params = nn.Parameter(master_params)
36 | master_params.requires_grad = True
37 | return [master_params]
38 |
39 |
40 | def model_grads_to_master_grads(model_params, master_params):
41 | """
42 | Copy the gradients from the model parameters into the master parameters
43 | from make_master_params().
44 | """
45 | master_params[0].grad = _flatten_dense_tensors(
46 | [param.grad.data.detach().float() for param in model_params]
47 | )
48 |
49 |
50 | def master_params_to_model_params(model_params, master_params):
51 | """
52 | Copy the master parameter data back into the model parameters.
53 | """
54 | # Without copying to a list, if a generator is passed, this will
55 | # silently not copy any parameters.
56 | model_params = list(model_params)
57 |
58 | for param, master_param in zip(
59 | model_params, unflatten_master_params(model_params, master_params)
60 | ):
61 | param.detach().copy_(master_param)
62 |
63 |
64 | def unflatten_master_params(model_params, master_params):
65 | """
66 | Unflatten the master parameters to look like model_params.
67 | """
68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params)
69 |
70 |
71 | def zero_grad(model_params):
72 | for param in model_params:
73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
74 | if param.grad is not None:
75 | param.grad.detach_()
76 | param.grad.zero_()
77 |
--------------------------------------------------------------------------------
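A condensed sketch of how these helpers fit together in one mixed-precision update (this mirrors the pattern used by improved-diffusion's training loop; the toy module, optimizer, and the absence of loss scaling are simplifications):

```python
import torch
import torch.nn as nn
from fp16_util import (  # assumes video_model/diffusion/ is on sys.path
    convert_module_to_f16, make_master_params, zero_grad,
    model_grads_to_master_grads, master_params_to_model_params,
)

model = nn.Conv2d(3, 8, 3, padding=1).cuda()
convert_module_to_f16(model)                      # conv weights/biases -> fp16
model_params = list(model.parameters())
master_params = make_master_params(model_params)  # flattened fp32 master copy
opt = torch.optim.AdamW(master_params, lr=1e-4)

x = torch.randn(2, 3, 16, 16, device="cuda", dtype=torch.float16)
loss = model(x).float().mean()

zero_grad(model_params)
loss.backward()                                             # fp16 grads on the model
model_grads_to_master_grads(model_params, master_params)    # copy grads into fp32
opt.step()                                                  # update fp32 master params
master_params_to_model_params(model_params, master_params)  # write back to fp16 weights
```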
/README.md:
--------------------------------------------------------------------------------
1 | ## Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers (JPDVT)
2 | Official PyTorch implementation of [CVPR 2024] "Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers"
3 |
4 | ### [[Paper]](https://openaccess.thecvf.com/content/CVPR2024/papers/Liu_Solving_Masked_Jigsaw_Puzzles_with_Diffusion_Vision_Transformers_CVPR_2024_paper.pdf) [[Arxiv]](https://arxiv.org/abs/2404.07292v1)
5 |
6 | **This GitHub repository is currently undergoing organization.** Stay tuned for the upcoming release of fully functional code!
7 |
8 |
9 |
10 | ## Setup
11 |     git clone https://github.com/JinyangMarkLiu/JPDVT.git
12 |     cd JPDVT
13 |
14 | ## Preparing Data
15 | Download the datasets you need. Below are brief instructions for setting up some of the datasets we used.
16 |
17 | #### _ImageNet_
18 | You can use this [script](https://gist.github.com/bonlime/4e0d236cf98cd5b15d977dfa03a63643) to download and prepare the _ImageNet_ dataset. If you need to download the dataset, please uncomment the first part of the script.
19 |
20 | #### _JPwLEG-3_
21 | Download the _JPwLEG-3_ dataset from this [Google Drive](https://drive.google.com/drive/folders/1MjPm7ar-u6H5WX6Bw2qshPiYPT_eQCZE). Only the [select_image](https://drive.google.com/drive/folders/1MjPm7ar-u6H5WX6Bw2qshPiYPT_eQCZE) part is used in our experiments.
22 |
23 | ## Training
24 | We provide scripts for training both image models and video models.
25 |
26 | ### Training image models
27 | On the ImageNet dataset (replace the data-path placeholder with your ImageNet directory):
28 |
29 |     torchrun --nnodes=1 --nproc_per_node=4 train_JPDVT.py --dataset imagenet --data-path <path-to-imagenet> --image-size 192 --crop
30 |
31 | On the MET dataset (replace the data-path placeholder with your MET directory):
32 |
33 |     torchrun --nnodes=1 --nproc_per_node=4 train_JPDVT.py --dataset met --data-path <path-to-met> --image-size 288 --epochs 1000
34 |
35 | ## Testing
36 |
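Use `image_model/sample.py` to evaluate a trained image model on 3x3 jigsaw puzzles. A sketch of the command for the MET dataset (the data and checkpoint paths are placeholders):

    python sample.py --dataset met --data-path <path-to-met-images> --image-size 288 --ckpt <path-to-checkpoint>

For ImageNet, pass `--dataset imagenet --image-size 192` (optionally with `--crop`) instead.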
37 |
38 | ## BibTeX
39 | If you find our work useful, please consider citing our paper:
40 |
41 | ```bibtex
42 | @InProceedings{Liu_2024_CVPR,
43 | author = {Liu, Jinyang and Teshome, Wondmgezahu and Ghimire, Sandesh and Sznaier, Mario and Camps, Octavia},
44 | title = {Solving Masked Jigsaw Puzzles with Diffusion Vision Transformers},
45 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
46 | month = {June},
47 | year = {2024},
48 | pages = {23009-23018}
49 | }
50 | ```
51 |
52 | ## Acknowledgments
53 | Our codebase is mainly based on [improved diffusion](https://github.com/openai/improved-diffusion), [make a video](https://github.com/lucidrains/make-a-video-pytorch), and [DiT](https://github.com/facebookresearch/DiT).
54 |
--------------------------------------------------------------------------------
/video_model/diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
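A small self-contained check of the two likelihood helpers above (the import path assumes video_model/diffusion/ is on sys.path, as the training scripts arrange):

```python
import torch as th
from losses import normal_kl, discretized_gaussian_log_likelihood

# Element-wise KL between N(0, 1) and N(0.5, exp(0.2)), in nats.
kl = normal_kl(th.zeros(4), th.zeros(4), 0.5 * th.ones(4), 0.2 * th.ones(4))

# Log-likelihood of pixels in [-1, 1] under a per-pixel discretized Gaussian.
x = th.rand(2, 3, 8, 8) * 2 - 1
ll = discretized_gaussian_log_likelihood(x, means=th.zeros_like(x), log_scales=th.full_like(x, -2.0))
print(kl.shape, ll.shape)  # torch.Size([4]) torch.Size([2, 3, 8, 8])
```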
/video_model/diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 |
12 | import torch as th
13 | import torch.distributed as dist
14 |
15 | # Change this to reflect your cluster layout.
16 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
17 | GPUS_PER_NODE = 8
18 |
19 | SETUP_RETRY_COUNT = 3
20 |
21 | if th.cuda.is_available(): th.cuda.set_device(0)  # avoid failing at import time on CPU-only machines
22 |
23 | def setup_dist():
24 | """
25 | Setup a distributed process group.
26 | """
27 | if dist.is_initialized():
28 | return
29 |
30 | comm = MPI.COMM_WORLD
31 | backend = "gloo" if not th.cuda.is_available() else "nccl"
32 |
33 | if backend == "gloo":
34 | hostname = "localhost"
35 | else:
36 | hostname = socket.gethostbyname(socket.getfqdn())
37 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
38 | os.environ["RANK"] = str(comm.rank)
39 | os.environ["WORLD_SIZE"] = str(comm.size)
40 | print(comm.size)
41 |
42 | port = comm.bcast(_find_free_port(), root=0)
43 | os.environ["MASTER_PORT"] = str(port)
44 | dist.init_process_group(backend=backend, world_size=comm.size, init_method="env://")
45 |
46 |
47 | # def dev():
48 | # """
49 | # Get the device to use for torch.distributed.
50 | # """
51 | # if th.cuda.is_available():
52 | # return th.device(f"cuda:{3}")
53 | # return th.device("cpu")
54 |
55 |
56 | def dev():
57 | """
58 | Get the device to use for torch.distributed.
59 | """
60 | if th.cuda.is_available():
61 | cuda_device = MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE # for other except skeleton
62 | # cuda_device=3
63 | # print('this is device',os.environ.get('CUDA_VISIBLE_DEVICES', ''))
64 | # import pdb
65 | # pdb.set_trace()
66 | # print(f"cuda:{cuda_device}")
67 | return th.device(f"cuda:{0}")
68 | return th.device("cpu")
69 |
70 |
71 | def load_state_dict(path, **kwargs):
72 | """
73 | Load a PyTorch file without redundant fetches across MPI ranks.
74 | """
75 | if MPI.COMM_WORLD.Get_rank() == 0:
76 | with bf.BlobFile(path, "rb") as f:
77 | data = f.read()
78 | else:
79 | data = None
80 | data = MPI.COMM_WORLD.bcast(data)
81 | return th.load(io.BytesIO(data), **kwargs)
82 |
83 | def load_opt_state_dict(path, **kwargs):
84 | """
85 | Load a PyTorch file without redundant fetches across MPI ranks.
86 | """
87 |
88 | with bf.BlobFile(path, "rb") as f:
89 | data = f.read()
90 |
91 | return th.load(io.BytesIO(data), **kwargs)
92 |
93 |
94 | def sync_params(params):
95 | """
96 | Synchronize a sequence of Tensors across ranks from rank 0.
97 | """
98 | for p in params:
99 | with th.no_grad():
100 | dist.broadcast(p, 0)
101 |
102 |
103 | def _find_free_port():
104 | try:
105 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
106 | s.bind(("", 0))
107 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
108 | return s.getsockname()[1]
109 | finally:
110 | s.close()
111 |
--------------------------------------------------------------------------------
/image_model/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/image_model/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import glob
6 | from PIL import Image
7 | import torchvision.transforms as T
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 |
11 | from torchvision import transforms
12 | from sklearn.model_selection import train_test_split
13 | import cv2
14 |
15 | class MET(Dataset):
16 | def __init__(self, image_dir,split):
17 | seed = 42
18 | torch.manual_seed(seed)
19 | self.split = split
20 |
21 | all_files = os.listdir(image_dir)
22 | self.image_files = [os.path.join(image_dir,all_files[0])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[0]))]
23 | self.image_files += [os.path.join(image_dir,all_files[1])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[1]))]
24 | self.image_files += [os.path.join(image_dir,all_files[2])+'/' + k for k in os.listdir(os.path.join(image_dir,all_files[2]))]
25 | # +os.listdir(os.path.join(image_dir,all_files[1]))+os.listdir(os.path.join(image_dir,all_files[2]))
26 |         # Keep only .jpg files; filtering into a new list avoids removing items
27 |         # from the list while iterating over it, which silently skips entries.
28 |         self.image_files = [f for f in self.image_files if '.jpg' in f]
29 | dataset_indices = list(range(len( self.image_files)))
30 |
31 | train_indices, test_indices = train_test_split(dataset_indices, test_size=2000, random_state=seed)
32 | train_indices, val_indices = train_test_split(train_indices, test_size=1000, random_state=seed)
33 | self.train_indices = train_indices
34 | self.test_indices = test_indices
35 | self.val_indices = val_indices
36 |
37 | # Define the color jitter parameters
38 | brightness = 0.4 # Randomly adjust brightness with a maximum factor of 0.4
39 | contrast = 0.4 # Randomly adjust contrast with a maximum factor of 0.4
40 | saturation = 0.4 # Randomly adjust saturation with a maximum factor of 0.4
41 | hue = 0.1 # Randomly adjust hue with a maximum factor of 0.1
42 |
43 |         # Because MET is a much smaller dataset, we use more complex data augmentation here
44 | flip_probability = 0.5
45 | self.transform1 = transforms.Compose([
46 | transforms.Resize(398),
47 | transforms.RandomCrop((398,398)),
48 | transforms.RandomHorizontalFlip(p=flip_probability), # Horizontal flipping with 0.5 probability
49 | transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue),
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
52 | ])
53 |
54 | self.transform2 = transforms.Compose([
55 | transforms.Resize(398),
56 | transforms.CenterCrop((398,398)),
57 | transforms.ToTensor(),
58 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
59 | ])
60 |
61 | def __len__(self):
62 | if self.split == 'train':
63 | return len(self.train_indices)
64 | elif self.split == 'val':
65 | return len(self.val_indices)
66 | elif self.split == 'test':
67 | return len(self.test_indices)
68 |
69 | def rand_erode(self,image,n_patches):
70 | output = torch.zeros(3,96*3,96*3)
71 | crop = transforms.RandomCrop((96,96))
72 | gap = 48
73 | patch_size = 100
74 | for i in range(n_patches):
75 | for j in range(n_patches):
76 | left = i * (patch_size + gap)
77 | upper = j * (patch_size + gap)
78 | right = left + patch_size
79 | lower = upper + patch_size
80 |
81 | patch = crop(image[:,left:right, upper:lower])
82 | output[:,i*96:i*96+96,j*96:j*96+96] = patch
83 |
84 | return output
85 |
86 | def __getitem__(self, idx):
87 | if self.split == 'train':
88 | index = self.train_indices[idx]
89 | image = self.transform1(Image.open(self.image_files[index]))
90 | image = self.rand_erode(image,3)
91 | elif self.split == 'val':
92 | index = self.val_indices[idx]
93 | image = self.transform2(Image.open(self.image_files[index]))
94 | image = self.rand_erode(image,3)
95 | elif self.split == 'test':
96 | index = self.test_indices[idx]
97 | image = self.transform2(Image.open(self.image_files[index]))
98 | image = self.rand_erode(image,3)
99 |
100 | return image
101 |
--------------------------------------------------------------------------------
/video_model/video_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 | # import torch.distributed as dist
5 | import numpy as np
6 | import os
7 | import gc
8 |
9 | import torch
10 |
11 | torch.cuda.empty_cache()
12 | gc.collect()
13 |
14 | import argparse
15 | import sys
16 |
17 | sys.path.insert(1, os.getcwd())
18 | # sys.path.insert(1, '/diffusion_openai')
19 | import dist_util, logger
20 | # from diffusion_openai import dist_util, logger
21 | from video_datasets import load_data
22 | from resample import create_named_schedule_sampler
23 | from script_util import (
24 | model_and_diffusion_defaults,
25 | create_model_and_diffusion,
26 | args_to_dict,
27 | add_dict_to_argparser,
28 | )
29 |
30 | from train_util import TrainLoop
31 |
32 |
33 | def main():
34 |
35 | parser, defaults = create_argparser()
36 | args = parser.parse_args()
37 | parameters = args_to_dict(args, defaults.keys())
38 | # th.manual_seed(args.seed)
39 | # np.random.seed(args.seed)
40 |
41 | dist_util.setup_dist()
42 | logger.configure()
43 | for key, item in parameters.items():
44 | logger.logkv(key, item)
45 | logger.dumpkvs()
46 |
47 | logger.log("creating model and diffusion...")
48 | model, diffusion = create_model_and_diffusion(
49 | **args_to_dict(args, model_and_diffusion_defaults().keys())
50 | )
51 |
52 | if len(args.model_path)>0:
53 | print("load ",args.model_path," to the model.")
54 | model.load_state_dict(
55 | dist_util.load_state_dict(args.model_path, map_location="cpu")
56 | )
57 |
58 | model.to(dist_util.dev())
59 | # model.summary()
60 | # breakpoint()
61 | # import torch
62 | # import torchviz
63 |
64 | # model = ... # create or load your PyTorch model
65 |
66 |
67 | print("device ",dist_util.dev())
68 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
69 |
70 | # print('model arch',model)
71 | # Save the model architecture to a file
72 | # with open("./figures/model_architecture.txt", "w") as f:
73 | # f.write(str(model))
74 |
75 | logger.log("creating data loader...")
76 |
77 | data = load_data(
78 | data_dir=args.data_dir,
79 | batch_size=args.batch_size,
80 | image_size=args.image_size,
81 | class_cond=args.class_cond,
82 | deterministic=False,
83 | rgb=args.rgb,
84 | seq_len=args.seq_len
85 | )
86 |
87 | # item=next(data)
88 | # breakpoint()
89 |
90 | if args.mask_range is None:
91 | mask_range = [0, args.seq_len]
92 | else:
93 |         mask_range = [int(i) for i in args.mask_range.split(",")]  # parse a comma-separated string such as "0,32"
94 | logger.log("training...")
95 |
96 | TrainLoop(
97 | model=model,
98 | diffusion=diffusion,
99 | data=data,
100 | batch_size=args.batch_size,
101 | microbatch=args.microbatch,
102 | lr=args.lr,
103 | ema_rate=args.ema_rate,
104 | log_interval=args.log_interval,
105 | save_interval=args.save_interval,
106 | resume_checkpoint=args.resume_checkpoint,
107 | use_fp16=args.use_fp16,
108 | fp16_scale_growth=args.fp16_scale_growth,
109 | schedule_sampler=schedule_sampler,
110 | weight_decay=args.weight_decay,
111 | lr_anneal_steps=args.lr_anneal_steps,
112 | clip=args.clip,
113 | anneal_type=args.anneal_type,
114 | steps_drop=args.steps_drop,
115 | drop=args.drop,
116 | decay=args.decay,
117 | max_num_mask_frames=args.max_num_mask_frames,
118 | mask_range=mask_range,
119 | uncondition_rate=args.uncondition_rate,
120 | exclude_conditional=args.exclude_conditional,
121 | ).run_loop()
122 |
123 |
124 | def create_argparser():
125 | defaults = dict(
126 | data_dir="",
127 | schedule_sampler="uniform",
128 | lr=1e-4, # -4
129 | weight_decay=0.0,
130 | lr_anneal_steps=0,
131 | batch_size=16, # 8 for something
132 | microbatch=16, # 32, # -1 disables microbatches 32
133 | ema_rate="0.9999", # comma-separated list of EMA values
134 | log_interval=32, # 10
135 | save_interval=10000, # 2000 100
136 | resume_checkpoint="",
137 | model_path="",
138 | use_fp16=False,
139 | fp16_scale_growth=1e-3,
140 | clip=1,
141 | seed=123,
142 | anneal_type=None,
143 | steps_drop=0.0,
144 | drop=0.0,
145 | decay=0.0,
146 | seq_len=32, # 20
147 | max_num_mask_frames=6, # 4
148 | mask_range=None,
149 | uncondition_rate=0,
150 | exclude_conditional=True,
151 | model_name='two_distribution',
152 | )
153 |
154 | defaults.update(model_and_diffusion_defaults())
155 | parser = argparse.ArgumentParser()
156 | add_dict_to_argparser(parser, defaults)
157 | return parser, defaults
158 |
159 | if __name__ == "__main__":
160 | main()
163 |
--------------------------------------------------------------------------------
/image_model/sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Solve Jigsaw Puzzles with JPDVT
3 | """
4 | import torch
5 | torch.backends.cuda.matmul.allow_tf32 = True
6 | torch.backends.cudnn.allow_tf32 = True
7 | from torchvision.utils import save_image
8 | from diffusion import create_diffusion
9 | from diffusers.models import AutoencoderKL
10 | from models import DiT_models
11 | import argparse
12 | from torch.utils.data import DataLoader
13 | from models import get_2d_sincos_pos_embed
14 | from datasets import MET
15 | import numpy as np
16 | from einops import rearrange
17 | import matplotlib.pyplot as plt
18 | from sklearn.metrics import pairwise_distances
19 | from torchvision import transforms  # needed for transforms.CenterCrop in the imagenet --crop branch
20 | def main(args):
21 | # Setup PyTorch:
22 | torch.manual_seed(args.seed)
23 | torch.set_grad_enabled(False)
24 | device = "cuda" if torch.cuda.is_available() else "cpu"
25 |
26 | template = np.zeros((6,6))
27 |
28 | for i in range(6):
29 | for j in range(6):
30 | template[i,j] = 18 * i + j
31 |
32 | template = np.concatenate((template,template,template),axis=0)
33 | template = np.concatenate((template,template,template),axis=1)
34 |
35 | # Load model:
36 | model = DiT_models[args.model](
37 | input_size=args.image_size,
38 | ).to(device)
39 | print("Load model from:", args.ckpt )
40 | ckpt_path = args.ckpt
41 | state_dict = torch.load(ckpt_path)
42 | model.load_state_dict(state_dict)
43 |
44 |     # BatchNorm does not behave normally when the batch size is 1,
45 |     # so we keep the model in train mode while sampling.
46 |     model.train()
47 |
48 | diffusion = create_diffusion(str(args.num_sampling_steps))
49 | if args.dataset == "met":
50 | # MET dataloader give out cropped and stitched back images
51 | dataset = MET(args.data_path,'test')
52 | elif args.dataset == "imagenet":
53 | dataset = ImageFolder(args.data_path, transform=transform)
54 |
55 | loader = DataLoader(
56 | dataset,
57 | batch_size=1,
58 | shuffle=False,
59 | num_workers=2,
60 | pin_memory=False,
61 | drop_last=True
62 | )
63 |
64 | time_emb = torch.tensor(get_2d_sincos_pos_embed(8, 3)).unsqueeze(0).float().to(device)
65 | time_emb_noise = torch.tensor(get_2d_sincos_pos_embed(8, 18)).unsqueeze(0).float().to(device)
66 | time_emb_noise = torch.randn_like(time_emb_noise)
67 | time_emb_noise = time_emb_noise.repeat(1,1,1)
68 | model_kwargs = None
69 |
70 | # find the order with a greedy algorithm
71 | def find_permutation(distance_matrix):
72 | sort_list = []
73 | for m in range(distance_matrix.shape[1]):
74 | order = distance_matrix[:,0].argmin()
75 | sort_list.append(order)
76 | distance_matrix = distance_matrix[:,1:]
77 | distance_matrix[order,:] = 2024
78 | return sort_list
79 |
80 | abs_results = []
81 | for x in loader:
82 | if args.dataset == 'imagenet':
83 | x, _ = x
84 | x = x.to(device)
85 | if args.dataset == 'imagenet' and args.crop:
86 | centercrop = transforms.CenterCrop((64,64))
87 | patchs = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=96,w1=96)
88 | patchs = centercrop(patchs)
89 | x = rearrange(patchs, 'b c (p1 p2) h1 w1-> b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=64,w1=64)
90 |
91 | # Generate the Puzzles
92 | indices = np.random.permutation(9)
93 | x = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=args.image_size//3,w1=args.image_size//3)
94 | x = x[:,:,indices,:,:]
95 | x = rearrange(x, ' b c (p1 p2) h1 w1->b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=args.image_size//3,w1=args.image_size//3)
96 |
97 | samples = diffusion.p_sample_loop(
98 | model.forward, x, time_emb_noise.shape, time_emb_noise, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
99 | )
100 | for sample,img in zip(samples,x):
101 | sample = rearrange(sample, '(p1 h1 p2 w1) d-> (p1 p2) (h1 w1) d',p1=3,p2=3,h1=args.image_size//48,w1=args.image_size//48)
102 | sample = sample.mean(1)
103 | dist = pairwise_distances(sample.cpu().numpy(), time_emb[0].cpu().numpy(), metric='manhattan')
104 | order = find_permutation(dist)
105 | pred = np.asarray(order).argsort()
106 | abs_results.append(int((pred == indices).all()))
107 |
108 | print("test result on ",len(abs_results), "samples is :", np.asarray(abs_results).sum()/len(abs_results))
109 |
110 | if len(abs_results)>=2000 and args.dataset == "met":
111 | break
112 | if len(abs_results)>=50000 and args.dataset == "imagenet":
113 | break
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--model", type=str, default="JPDVT")
118 | parser.add_argument("--dataset", type=str, choices=["imagenet", "met"], default="imagenet")
119 | parser.add_argument("--data-path", type=str,required=True)
120 | parser.add_argument("--crop", action='store_true', default=False)
121 | parser.add_argument("--image-size", type=int, choices=[192, 288], default=288)
122 | parser.add_argument("--num-sampling-steps", type=int, default=250)
123 | parser.add_argument("--seed", type=int, default=0)
124 | parser.add_argument("--ckpt", type=str, required=True)
125 | args = parser.parse_args()
126 | main(args)
127 |
--------------------------------------------------------------------------------
/video_model/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | # from video_sample import create_argparser
5 | # from video_datasets import load_data
6 | from gaussian_diffusion import GaussianDiffusion
7 |
8 | def space_timesteps(num_timesteps, section_counts):
9 | """
10 | Create a list of timesteps to use from an original diffusion process,
11 | given the number of timesteps we want to take from equally-sized portions
12 | of the original process.
13 |
14 | For example, if there's 300 timesteps and the section counts are [10,15,20]
15 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
16 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
17 |
18 | If the stride is a string starting with "ddim", then the fixed striding
19 | from the DDIM paper is used, and only one section is allowed.
20 |
21 | :param num_timesteps: the number of diffusion steps in the original
22 | process to divide up.
23 | :param section_counts: either a list of numbers, or a string containing
24 | comma-separated numbers, indicating the step count
25 | per section. As a special case, use "ddimN" where N
26 | is a number of steps to use the striding from the
27 | DDIM paper.
28 | :return: a set of diffusion steps from the original process to use.
29 | """
30 | if isinstance(section_counts, str):
31 | if section_counts.startswith("ddim"):
32 | desired_count = int(section_counts[len("ddim"):])
33 | for i in range(1, num_timesteps):
34 | if len(range(0, num_timesteps, i)) == desired_count:
35 | return set(range(0, num_timesteps, i))
36 | raise ValueError(
37 | f"cannot create exactly {num_timesteps} steps with an integer stride"
38 | )
39 | section_counts = [int(x) for x in section_counts.split(",")]
40 | size_per = num_timesteps // len(section_counts)
41 | extra = num_timesteps % len(section_counts)
42 | start_idx = 0
43 | all_steps = []
44 | for i, section_count in enumerate(section_counts):
45 | size = size_per + (1 if i < extra else 0)
46 | if size < section_count:
47 | raise ValueError(
48 | f"cannot divide section of {size} steps into {section_count}"
49 | )
50 | if section_count <= 1:
51 | frac_stride = 1
52 | else:
53 | frac_stride = (size - 1) / (section_count - 1)
54 | cur_idx = 0.0
55 | taken_steps = []
56 | for _ in range(section_count):
57 | taken_steps.append(start_idx + round(cur_idx))
58 | cur_idx += frac_stride
59 | all_steps += taken_steps
60 | start_idx += size
61 | return set(all_steps)
62 |
63 |
64 | class SpacedDiffusion(GaussianDiffusion):
65 | """
66 | A diffusion process which can skip steps in a base diffusion process.
67 |
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def _wrap_model(self, model):
100 | if isinstance(model, _WrappedModel):
101 | return model
102 | return _WrappedModel(
103 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
104 | )
105 |
106 | def _scale_timesteps(self, t):
107 | # Scaling is done by the wrapped model.
108 | return t
109 |
110 |
111 | class _WrappedModel:
112 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
113 | self.model = model
114 | self.timestep_map = timestep_map
115 | self.rescale_timesteps = rescale_timesteps
116 | self.original_num_steps = original_num_steps
117 |
118 |
119 | def __call__(self, x, ts,condition_args):
120 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
121 | new_ts = map_tensor[ts]
122 |
123 | if self.rescale_timesteps:
124 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
125 |
126 | # breakpoint()
127 |
128 | return self.model(x, new_ts,condition_args)
129 |
--------------------------------------------------------------------------------
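A quick worked example of `space_timesteps`, following its docstring (the import assumes video_model/diffusion/ is on sys.path):

```python
from respace import space_timesteps

# The docstring example: 300 original steps, three sections of 100 strided to 10/15/20 steps.
steps = space_timesteps(300, [10, 15, 20])
print(len(steps))  # 45

# DDIM-style fixed striding: 25 evenly spaced steps out of 1000.
print(sorted(space_timesteps(1000, "ddim25"))[:3])  # [0, 40, 80]
```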
/video_model/diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 | def sum_flat(tensor):
93 | """
94 | Take the sum over all non-batch dimensions.
95 | """
96 |     return tensor.sum(dim=list(range(1, len(tensor.shape))))
97 |
98 | def normalization(channels):
99 | """
100 | Make a standard normalization layer.
101 |
102 | :param channels: number of input channels.
103 | :return: an nn.Module for normalization.
104 | """
105 | return GroupNorm32(32, channels)
106 |
107 |
108 | def timestep_embedding(timesteps, dim, max_period=10000):
109 | """
110 | Create sinusoidal timestep embeddings.
111 |
112 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
113 | These may be fractional.
114 | :param dim: the dimension of the output.
115 | :param max_period: controls the minimum frequency of the embeddings.
116 | :return: an [N x dim] Tensor of positional embeddings.
117 | """
118 | half = dim // 2
119 | freqs = th.exp(
120 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
121 | ).to(device=timesteps.device)
122 | args = timesteps[:, None].float() * freqs[None]
123 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
124 | if dim % 2:
125 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
126 | return embedding
127 |
128 |
129 | def checkpoint(func, inputs, params, flag):
130 | """
131 | Evaluate a function without caching intermediate activations, allowing for
132 | reduced memory at the expense of extra compute in the backward pass.
133 |
134 | :param func: the function to evaluate.
135 | :param inputs: the argument sequence to pass to `func`.
136 | :param params: a sequence of parameters `func` depends on but does not
137 | explicitly take as arguments.
138 | :param flag: if False, disable gradient checkpointing.
139 | """
140 | if flag:
141 | args = tuple(inputs) + tuple(params)
142 | return CheckpointFunction.apply(func, len(inputs), *args)
143 | else:
144 | return func(*inputs)
145 |
146 |
147 | class CheckpointFunction(th.autograd.Function):
148 | @staticmethod
149 | def forward(ctx, run_function, length, *args):
150 | ctx.run_function = run_function
151 | ctx.input_tensors = list(args[:length])
152 | ctx.input_params = list(args[length:])
153 | with th.no_grad():
154 | output_tensors = ctx.run_function(*ctx.input_tensors)
155 | return output_tensors
156 |
157 | @staticmethod
158 | def backward(ctx, *output_grads):
159 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
160 | with th.enable_grad():
161 | # Fixes a bug where the first op in run_function modifies the
162 | # Tensor storage in place, which is not allowed for detach()'d
163 | # Tensors.
164 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
165 | output_tensors = ctx.run_function(*shallow_copies)
166 | input_grads = th.autograd.grad(
167 | output_tensors,
168 | ctx.input_tensors + ctx.input_params,
169 | output_grads,
170 | allow_unused=True,
171 | )
172 | del ctx.input_tensors
173 | del ctx.input_params
174 | del output_tensors
175 | return (None, None) + input_grads
176 |
--------------------------------------------------------------------------------
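Two helpers from nn.py that are used throughout the model code, shown on toy inputs (import path assumed as above):

```python
import torch as th
from nn import timestep_embedding, mean_flat

t = th.tensor([0, 10, 999])       # one timestep index per batch element
emb = timestep_embedding(t, dim=128)
print(emb.shape)                  # torch.Size([3, 128])

x = th.randn(3, 4, 8, 8)
print(mean_flat(x).shape)         # torch.Size([3]); mean over all non-batch dims
```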
/image_model/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, time_emb, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, time_emb, **kwargs)
130 |
--------------------------------------------------------------------------------
/image_model/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
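A minimal sketch of wiring a schedule sampler to a diffusion process (run from image_model/ so the local `diffusion` package resolves; the "uniform" sampler draws timesteps uniformly with unit loss weights):

```python
import torch
from diffusion import create_diffusion
from diffusion.timestep_sampler import create_named_schedule_sampler

diffusion = create_diffusion(timestep_respacing="")  # full 1000-step process
sampler = create_named_schedule_sampler("uniform", diffusion)

t, weights = sampler.sample(batch_size=8, device=torch.device("cpu"))
print(t.shape, weights.shape)  # torch.Size([8]) torch.Size([8])
```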
/video_model/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion, k=1):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | :param k: the series order. Default k=1.
15 | """
16 | if name == "uniform":
17 | return UniformSampler(diffusion)
18 | elif name == "harmonic":
19 | return HarmonicSampler(diffusion, k)
20 | elif name == "loss-second-moment":
21 | return LossSecondMomentResampler(diffusion)
22 | else:
23 | raise NotImplementedError(f"unknown schedule sampler: {name}")
24 |
25 |
26 | class ScheduleSampler(ABC):
27 | """
28 | A distribution over timesteps in the diffusion process, intended to reduce
29 | variance of the objective.
30 |
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 |
42 | The weights needn't be normalized, but must be positive.
43 | """
44 |
45 | def sample(self, batch_size, device):
46 | """
47 | Importance-sample timesteps for a batch.
48 |
49 | :param batch_size: the number of timesteps.
50 | :param device: the torch device to save to.
51 | :return: a tuple (timesteps, weights):
52 | - timesteps: a tensor of timestep indices.
53 | - weights: a tensor of weights to scale the resulting losses.
54 | """
55 | w = self.weights()
56 | p = w / np.sum(w)
57 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
58 | indices = th.from_numpy(indices_np).long().to(device)
59 | weights_np = 1 / (len(p) * p[indices_np])
60 | weights = th.from_numpy(weights_np).float().to(device)
61 | return indices, weights
62 |
63 |
64 | class UniformSampler(ScheduleSampler):
65 | def __init__(self, diffusion):
66 | self.diffusion = diffusion
67 | self._weights = np.ones([diffusion.num_timesteps])
68 |
69 | def weights(self):
70 | return self._weights
71 |
72 | class HarmonicSampler(ScheduleSampler):
73 | def __init__(self, diffusion, k=1):
74 | self.diffusion = diffusion
75 | w = 1. / np.array([t+1 for t in range(diffusion.num_timesteps)])
76 | w = w**k
77 | self._weights = w
78 |
79 | def weights(self):
80 | return self._weights
81 |
82 |
83 | class LossAwareSampler(ScheduleSampler):
84 | def update_with_local_losses(self, local_ts, local_losses):
85 | """
86 | Update the reweighting using losses from a model.
87 |
88 | Call this method from each rank with a batch of timesteps and the
89 | corresponding losses for each of those timesteps.
90 | This method will perform synchronization to make sure all of the ranks
91 | maintain the exact same reweighting.
92 |
93 | :param local_ts: an integer Tensor of timesteps.
94 | :param local_losses: a 1D Tensor of losses.
95 | """
96 | batch_sizes = [
97 | th.tensor([0], dtype=th.int32, device=local_ts.device)
98 | for _ in range(dist.get_world_size())
99 | ]
100 | dist.all_gather(
101 | batch_sizes,
102 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
103 | )
104 |
105 | # Pad all_gather batches to be the maximum batch size.
106 | batch_sizes = [x.item() for x in batch_sizes]
107 | max_bs = max(batch_sizes)
108 |
109 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
110 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
111 | dist.all_gather(timestep_batches, local_ts)
112 | dist.all_gather(loss_batches, local_losses)
113 | timesteps = [
114 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
115 | ]
116 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
117 | self.update_with_all_losses(timesteps, losses)
118 |
119 | @abstractmethod
120 | def update_with_all_losses(self, ts, losses):
121 | """
122 | Update the reweighting using losses from a model.
123 |
124 | Sub-classes should override this method to update the reweighting
125 | using losses from the model.
126 |
127 | This method directly updates the reweighting without synchronizing
128 | between workers. It is called by update_with_local_losses from all
129 | ranks with identical arguments. Thus, it should have deterministic
130 | behavior to maintain state across workers.
131 |
132 | :param ts: a list of int timesteps.
133 | :param losses: a list of float losses, one per timestep.
134 | """
135 |
136 |
137 | class LossSecondMomentResampler(LossAwareSampler):
138 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
139 | self.diffusion = diffusion
140 | self.history_per_term = history_per_term
141 | self.uniform_prob = uniform_prob
142 | self._loss_history = np.zeros(
143 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
144 | )
145 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
146 |
147 | def weights(self):
148 | if not self._warmed_up():
149 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
150 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
151 | weights /= np.sum(weights)
152 | weights *= 1 - self.uniform_prob
153 | weights += self.uniform_prob / len(weights)
154 | return weights
155 |
156 | def update_with_all_losses(self, ts, losses):
157 | for t, loss in zip(ts, losses):
158 | if self._loss_counts[t] == self.history_per_term:
159 | # Shift out the oldest loss term.
160 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
161 | self._loss_history[t, -1] = loss
162 | else:
163 | self._loss_history[t, self._loss_counts[t]] = loss
164 | self._loss_counts[t] += 1
165 |
166 | def _warmed_up(self):
167 | return (self._loss_counts == self.history_per_term).all()
168 |
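# A minimal, hypothetical usage sketch (not part of the original module): build a
# sampler by name and draw importance-sampled timesteps together with the weights
# used to rescale the per-sample losses. `_ToyDiffusion` only supplies the
# `num_timesteps` attribute that the samplers rely on.
if __name__ == "__main__":
    class _ToyDiffusion:
        num_timesteps = 1000

    sampler = create_named_schedule_sampler("uniform", _ToyDiffusion())
    ts, weights = sampler.sample(batch_size=4, device="cpu")
    print(ts, weights)                  # 4 timestep indices and 4 weights (all 1.0 for "uniform")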
--------------------------------------------------------------------------------
/video_model/datasets/MovingMNIST.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import errno
7 | import numpy as np
8 | import torch
9 | import codecs
10 | import torchvision.transforms as T
11 |
12 | class MovingMNIST(data.Dataset):
13 |     """`MovingMNIST <https://github.com/tychovdo/MovingMNIST>`_ Dataset.
14 |
15 | Args:
16 | root (string): Root directory of dataset where ``processed/training.pt``
17 | and ``processed/test.pt`` exist.
18 | train (bool, optional): If True, creates dataset from ``training.pt``,
19 | otherwise from ``test.pt``.
20 | split (int, optional): Train/test split size. Number defines how many samples
21 | belong to test set.
22 | download (bool, optional): If true, downloads the dataset from the internet and
23 | puts it in root directory. If dataset is already downloaded, it is not
24 | downloaded again.
25 |         transform (callable, optional): A function/transform that takes in a PIL image
26 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
27 |         target_transform (callable, optional): A function/transform that takes in a PIL
28 |             image and returns a transformed version. E.g., ``transforms.RandomCrop``
29 | """
30 | urls = [
31 | 'https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
32 | ]
33 | raw_folder = 'raw'
34 | processed_folder = 'processed'
35 | training_file = 'moving_mnist_train.pt'
36 | test_file = 'moving_mnist_test.pt'
37 |
38 | def __init__(self, root, train=True, split=1000, transform=None, target_transform=None, download=False,image_size=32,seq_len = 20):
39 | self.root = os.path.expanduser(root)
40 | self.transform = transform
41 | self.target_transform = target_transform
42 | self.split = split
43 | self.train = train # training set or test set
44 | self.image_size = image_size
45 | self.seq_len = seq_len
46 | if download:
47 | self.download()
48 |
49 | if not self._check_exists():
50 | raise RuntimeError('Dataset not found.' +
51 | ' You can use download=True to download it')
52 |
53 | if self.train:
54 | self.train_data = torch.load(
55 | os.path.join(self.root, self.processed_folder, self.training_file))
56 | else:
57 | self.test_data = torch.load(
58 | os.path.join(self.root, self.processed_folder, self.test_file))
59 |
60 | def __getitem__(self, index):
61 | """
62 | Args:
63 | index (int): Index
64 |
65 | Returns:
66 |             tuple: (seq, target) where sampled sequences are split into a seq
67 |                 and target part
68 | """
69 |
70 | # need to iterate over time
71 | def _transform_time(data):
72 | new_data = None
73 | for i in range(data.size(0)):
74 | img = Image.fromarray(data[i].numpy(), mode='L')
75 | new_data = self.transform(img) if new_data is None else torch.cat([self.transform(img), new_data], dim=0)
76 | return new_data
77 |
78 | if self.train:
79 | seq, target = self.train_data[index,:self.seq_len], {}
80 | else:
81 | seq, target = self.test_data[index,:self.seq_len], {}
82 |
83 | if self.transform is not None:
84 | seq = _transform_time(seq)
85 | if self.target_transform is not None:
86 | target = _transform_time(target)
87 | seq = (seq.type(torch.float32) / 127.5 - 1).unsqueeze(0)
88 | resize = T.Resize((self.image_size,self.image_size))
89 | seq = resize(seq)
90 | return seq, target
91 |
92 | def __len__(self):
93 | if self.train:
94 | return len(self.train_data)
95 | else:
96 | return len(self.test_data)
97 |
98 | def _check_exists(self):
99 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
100 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
101 |
102 | def download(self):
103 | """Download the Moving MNIST data if it doesn't exist in processed_folder already."""
104 | from six.moves import urllib
105 | import gzip
106 |
107 | if self._check_exists():
108 | return
109 |
110 | # download files
111 | try:
112 | os.makedirs(os.path.join(self.root, self.raw_folder))
113 | os.makedirs(os.path.join(self.root, self.processed_folder))
114 | except OSError as e:
115 | if e.errno == errno.EEXIST:
116 | pass
117 | else:
118 | raise
119 |
120 | for url in self.urls:
121 | print('Downloading ' + url)
122 | data = urllib.request.urlopen(url)
123 | filename = url.rpartition('/')[2]
124 | file_path = os.path.join(self.root, self.raw_folder, filename)
125 | with open(file_path, 'wb') as f:
126 | f.write(data.read())
127 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \
128 | gzip.GzipFile(file_path) as zip_f:
129 | out_f.write(zip_f.read())
130 | os.unlink(file_path)
131 |
132 | # process and save as torch files
133 | print('Processing...')
134 |
135 | training_set = torch.from_numpy(
136 | np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[:-self.split]
137 | )
138 | test_set = torch.from_numpy(
139 | np.load(os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)[-self.split:]
140 | )
141 |
142 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
143 | torch.save(training_set, f)
144 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
145 | torch.save(test_set, f)
146 |
147 | print('Done!')
148 |
149 | def __repr__(self):
150 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
151 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
152 | tmp = 'train' if self.train is True else 'test'
153 | fmt_str += ' Train/test: {}\n'.format(tmp)
154 | fmt_str += ' Root Location: {}\n'.format(self.root)
155 | tmp = ' Transforms (if any): '
156 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
157 | tmp = ' Target Transforms (if any): '
158 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
159 | return fmt_str
160 |
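# A minimal, hypothetical usage sketch (not part of the original module). The root
# path below is a placeholder; with download=True the mnist_test_seq.npy archive is
# fetched on first use from the URL listed in `urls` above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MovingMNIST(root="./data/mnist", train=True, download=True,
                          image_size=32, seq_len=20)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    seq, _ = next(iter(loader))
    print(seq.shape)  # torch.Size([4, 1, 20, 32, 32]): batch, channel, time, height, width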
--------------------------------------------------------------------------------
/video_model/diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | import gaussian_diffusion as gd
5 | from respace import SpacedDiffusion, space_timesteps
6 | from unet import SuperResModel, UNetModel, UnshuffleModel,IdxPromptUnet
7 | from make_a_video import IdxPromptSTUnet
8 |
9 | NUM_CLASSES = 1000
10 |
11 |
12 | def model_and_diffusion_defaults():
13 | """
14 | Defaults for image training.
15 | """
16 | return dict(
17 | image_size=32, # 32 something
18 | num_channels=128,#128
19 | num_res_blocks=2,
20 | num_heads=4,
21 | num_heads_upsample=-1,
22 | attention_resolutions="8,8",
23 | dropout=0.0,
24 | learn_sigma=False,
25 | sigma_small=False,
26 | class_cond=False,
27 | diffusion_steps=1000, # 1000
28 | noise_schedule="linear",
29 | timestep_respacing="",
30 | use_kl=False,
31 | predict_xstart=True,
32 | rescale_timesteps=True,
33 | rescale_learned_sigmas=True,
34 | use_checkpoint=False,
35 | use_scale_shift_norm=True,
36 | scale_time_dim=0,
37 |         # 20 for all frames of Moving MNIST; 4 for 4 frames, 8 for 8 frames. The sequence length must match the seq_len set in the MovingMNIST dataset file.
38 | rgb=True # True
39 | )
40 |
41 |
42 | def create_model_and_diffusion(
43 | image_size,
44 | class_cond,
45 | learn_sigma,
46 | sigma_small,
47 | num_channels,
48 | num_res_blocks,
49 | scale_time_dim,
50 | num_heads,
51 | num_heads_upsample,
52 | attention_resolutions,
53 | dropout,
54 | diffusion_steps,
55 | noise_schedule,
56 | timestep_respacing,
57 | use_kl,
58 | predict_xstart,
59 | rescale_timesteps,
60 | rescale_learned_sigmas,
61 | use_checkpoint,
62 | use_scale_shift_norm,
63 | rgb=True, # True
64 | model_name='two_condition'
65 | ):
66 | model = create_model(
67 | image_size,
68 | num_channels,
69 | num_res_blocks,
70 | scale_time_dim=scale_time_dim,
71 | learn_sigma=learn_sigma,
72 | class_cond=class_cond,
73 | use_checkpoint=use_checkpoint,
74 | attention_resolutions=attention_resolutions,
75 | num_heads=num_heads,
76 | num_heads_upsample=num_heads_upsample,
77 | use_scale_shift_norm=use_scale_shift_norm,
78 | dropout=dropout,
79 | rgb=rgb,
80 | model_name=model_name
81 | )
82 | print('predict_xstart: ',predict_xstart)
83 | diffusion = create_gaussian_diffusion(
84 | steps=diffusion_steps,
85 | learn_sigma=learn_sigma,
86 | sigma_small=sigma_small,
87 | noise_schedule=noise_schedule,
88 | use_kl=use_kl,
89 | predict_xstart=predict_xstart,
90 | rescale_timesteps=rescale_timesteps,
91 | rescale_learned_sigmas=rescale_learned_sigmas,
92 | timestep_respacing=timestep_respacing,
93 | )
94 | return model, diffusion
95 |
96 |
97 | def create_model(
98 | image_size,
99 | num_channels,
100 | num_res_blocks,
101 | scale_time_dim,
102 | learn_sigma,
103 | class_cond,
104 | use_checkpoint,
105 | attention_resolutions,
106 | num_heads,
107 | num_heads_upsample,
108 | use_scale_shift_norm,
109 | dropout,
110 | rgb,
111 | model_name
112 | ):
113 | if image_size == 256:
114 | channel_mult = (1, 1, 2, 2, 4, 4)
115 | elif image_size == 128:
116 | channel_mult = (1, 2, 3, 4)
117 | elif image_size == 64:
118 | channel_mult = (1, 2, 3, 4)
119 | elif image_size == 32:
120 | channel_mult = (1, 2, 2, 2)# 1 2 2 2
121 | elif image_size in [16, 8]: # added
122 | channel_mult = (1, 1, 1, 1)
123 | # elif image_size == 8:
124 | # channel_mult = (1, 1, 1, 1)
125 |
126 | else:
127 | raise ValueError(f"unsupported image size: {image_size}")
128 |
129 | attention_ds = []
130 | for res in attention_resolutions.split(","):
131 | attention_ds.append(image_size // int(res))
132 |
133 |
134 | channels = 4
135 |
136 | if model_name == 'unshuffle':
137 |
138 | return UnshuffleModel(
139 | in_channels=channels,
140 | model_channels=num_channels,
141 | out_channels=(channels if not learn_sigma else 2 * channels),
142 | num_res_blocks=num_res_blocks,
143 | scale_time_dim=scale_time_dim,
144 | attention_resolutions=tuple(attention_ds),
145 | dropout=dropout,
146 | channel_mult=channel_mult,
147 | num_classes=(NUM_CLASSES if class_cond else None),
148 | use_checkpoint=use_checkpoint,
149 | num_heads=num_heads,
150 | num_heads_upsample=num_heads_upsample,
151 | use_scale_shift_norm=use_scale_shift_norm,
152 | )
153 | else:
154 | print('Use transformers')
155 | return IdxPromptSTUnet(
156 | channels = channels,
157 | dim = 16*channels,
158 | dim_mult = (1, 2, 4, 8),#1, 2, 4, 8)
159 | temporal_compression = (False, False, False, False),
160 | self_attns = (False, True, True, True),
161 | condition_on_timestep = True,
162 | # resnet_block_depths = (2, 2, 2, 2)
163 | )
164 | # return IdxPromptSTUnet(
165 | # in_channels=channels, # 1
166 | # model_channels=num_channels,
167 | # out_channels=channels+1,
168 | # num_res_blocks=num_res_blocks,
169 | # scale_time_dim=scale_time_dim,
170 | # attention_resolutions=tuple(attention_ds),
171 | # dropout=dropout,
172 | # channel_mult=channel_mult,
173 | # num_classes=(NUM_CLASSES if class_cond else None),
174 | # use_checkpoint=use_checkpoint,
175 | # num_heads=num_heads,
176 | # num_heads_upsample=num_heads_upsample,
177 | # use_scale_shift_norm=use_scale_shift_norm,
178 | # )
179 |
180 |
181 | def sr_model_and_diffusion_defaults():
182 | res = model_and_diffusion_defaults()
183 | res["large_size"] = 256
184 | res["small_size"] = 64
185 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
186 | for k in res.copy().keys():
187 | if k not in arg_names:
188 | del res[k]
189 | return res
190 |
191 |
192 | def sr_create_model_and_diffusion(
193 | large_size,
194 | small_size,
195 | class_cond,
196 | learn_sigma,
197 | num_channels,
198 | num_res_blocks,
199 | scale_time_dim,
200 | num_heads,
201 | num_heads_upsample,
202 | attention_resolutions,
203 | dropout,
204 | diffusion_steps,
205 | noise_schedule,
206 | timestep_respacing,
207 | use_kl,
208 | predict_xstart,
209 | rescale_timesteps,
210 | rescale_learned_sigmas,
211 | use_checkpoint,
212 | use_scale_shift_norm,
213 | ):
214 | model = sr_create_model(
215 | large_size,
216 | small_size,
217 | num_channels,
218 | num_res_blocks,
219 | scale_time_dim,
220 | learn_sigma=learn_sigma,
221 | class_cond=class_cond,
222 | use_checkpoint=use_checkpoint,
223 | attention_resolutions=attention_resolutions,
224 | num_heads=num_heads,
225 | num_heads_upsample=num_heads_upsample,
226 | use_scale_shift_norm=use_scale_shift_norm,
227 | dropout=dropout,
228 | )
229 | diffusion = create_gaussian_diffusion(
230 | steps=diffusion_steps,
231 | learn_sigma=learn_sigma,
232 | noise_schedule=noise_schedule,
233 | use_kl=use_kl,
234 | predict_xstart=predict_xstart,
235 | rescale_timesteps=rescale_timesteps,
236 | rescale_learned_sigmas=rescale_learned_sigmas,
237 | timestep_respacing=timestep_respacing,
238 | )
239 | return model, diffusion
240 |
241 |
242 | def sr_create_model(
243 | large_size,
244 | small_size,
245 | num_channels,
246 | num_res_blocks,
247 | scale_time_dim,
248 | learn_sigma,
249 | class_cond,
250 | use_checkpoint,
251 | attention_resolutions,
252 | num_heads,
253 | num_heads_upsample,
254 | use_scale_shift_norm,
255 | dropout,
256 | ):
257 | _ = small_size # hack to prevent unused variable
258 |
259 | if large_size == 256:
260 | channel_mult = (1, 1, 2, 2, 4, 4)
261 | elif large_size == 64:
262 | channel_mult = (1, 2, 3, 4)
263 | else:
264 | raise ValueError(f"unsupported large size: {large_size}")
265 |
266 | attention_ds = []
267 | for res in attention_resolutions.split(","):
268 | attention_ds.append(large_size // int(res))
269 |
270 | return SuperResModel(
271 | in_channels=3,
272 | model_channels=num_channels,
273 | out_channels=(3 if not learn_sigma else 6),
274 | num_res_blocks=num_res_blocks,
275 | scale_time_dim=scale_time_dim,
276 | attention_resolutions=tuple(attention_ds),
277 | dropout=dropout,
278 | channel_mult=channel_mult,
279 | num_classes=(NUM_CLASSES if class_cond else None),
280 | use_checkpoint=use_checkpoint,
281 | num_heads=num_heads,
282 | num_heads_upsample=num_heads_upsample,
283 | use_scale_shift_norm=use_scale_shift_norm,
284 | )
285 |
286 |
287 | def create_gaussian_diffusion(
288 | *,
289 | steps=1000, #
290 | learn_sigma=False,
291 | sigma_small=False,
292 | noise_schedule="linear",
293 | use_kl=False,
294 | predict_xstart=False,
295 | rescale_timesteps=False,
296 | rescale_learned_sigmas=False,
297 | timestep_respacing="",
298 | ):
299 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
300 | if use_kl:
301 | loss_type = gd.LossType.RESCALED_KL
302 | elif rescale_learned_sigmas:
303 | loss_type = gd.LossType.RESCALED_MSE
304 | else:
305 | loss_type = gd.LossType.MSE
306 | if not timestep_respacing:
307 | timestep_respacing = [steps]
308 | return SpacedDiffusion(
309 | use_timesteps=space_timesteps(steps, timestep_respacing),
310 | betas=betas,
311 | model_mean_type=(
312 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
313 | ),
314 | model_var_type=(
315 | (
316 | gd.ModelVarType.FIXED_LARGE
317 | if not sigma_small
318 | else gd.ModelVarType.FIXED_SMALL
319 | )
320 | if not learn_sigma
321 | else gd.ModelVarType.LEARNED_RANGE
322 | ),
323 | loss_type=loss_type,
324 | rescale_timesteps=rescale_timesteps,
325 | )
326 |
327 |
328 | def add_dict_to_argparser(parser, default_dict):
329 | for k, v in default_dict.items():
330 | v_type = type(v)
331 | if v is None:
332 | v_type = str
333 | elif isinstance(v, bool):
334 | v_type = str2bool
335 | parser.add_argument(f"--{k}", default=v, type=v_type)
336 |
337 |
338 | def args_to_dict(args, keys):
339 | return {k: getattr(args, k) for k in keys}
340 |
341 |
342 | def str2bool(v):
343 | """
344 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
345 | """
346 | if isinstance(v, bool):
347 | return v
348 | if v.lower() in ("yes", "true", "t", "y", "1"):
349 | return True
350 | elif v.lower() in ("no", "false", "f", "n", "0"):
351 | return False
352 | else:
353 | raise argparse.ArgumentTypeError("boolean value expected")
354 |
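# A minimal, hypothetical usage sketch (not part of the original module) showing how
# these helpers are typically chained: expose the defaults as CLI flags, then build
# the model and diffusion from the parsed values. Running it requires the sibling
# modules (gaussian_diffusion, unet, make_a_video) to be importable from the working
# directory, as in video_sample.py.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, model_and_diffusion_defaults())
    args = parser.parse_args([])  # [] -> just use the defaults defined above
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    print(type(model).__name__, diffusion.num_timesteps)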
--------------------------------------------------------------------------------
/image_model/train_JPDVT.py:
--------------------------------------------------------------------------------
1 | """
2 | A minimal training script for JPDVT using PyTorch DDP.
3 | """
4 | import torch
5 | torch.backends.cuda.matmul.allow_tf32 = True
6 | torch.backends.cudnn.allow_tf32 = True
7 | import torch.distributed as dist
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torchvision.datasets import ImageFolder
12 | from torchvision import transforms
13 | import numpy as np
14 | from collections import OrderedDict
15 | from PIL import Image
16 | from copy import deepcopy
17 | from glob import glob
18 | from time import time
19 | import argparse
20 | import logging
21 | import os
22 |
23 | from models import DiT_models
24 | from models import get_2d_sincos_pos_embed
25 | from diffusion import create_diffusion
26 | from diffusers.models import AutoencoderKL
27 |
28 | from datasets import MET
29 | from einops import rearrange
30 |
31 | #################################################################################
32 | # Training Helper Functions #
33 | #################################################################################
34 |
35 | @torch.no_grad()
36 | def update_ema(ema_model, model, decay=0.9999):
37 | """
38 | Step the EMA model towards the current model.
39 | """
40 | ema_params = OrderedDict(ema_model.named_parameters())
41 | model_params = OrderedDict(model.named_parameters())
42 |
43 | for name, param in model_params.items():
44 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
45 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
46 |
47 | def requires_grad(model, flag=True):
48 | """
49 | Set requires_grad flag for all parameters in a model.
50 | """
51 | for p in model.parameters():
52 | p.requires_grad = flag
53 |
54 | def cleanup():
55 | """
56 | End DDP training.
57 | """
58 | dist.destroy_process_group()
59 |
60 | def create_logger(logging_dir):
61 | """
62 | Create a logger that writes to a log file and stdout.
63 | """
64 | if dist.get_rank() == 0: # real logger
65 | logging.basicConfig(
66 | level=logging.INFO,
67 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
68 | datefmt='%Y-%m-%d %H:%M:%S',
69 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
70 | )
71 | logger = logging.getLogger(__name__)
72 | else: # dummy logger (does nothing)
73 | logger = logging.getLogger(__name__)
74 | logger.addHandler(logging.NullHandler())
75 | return logger
76 |
77 |
78 | def center_crop_arr(pil_image, image_size):
79 | """
80 | Center cropping implementation from ADM.
81 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
82 | """
83 | while min(*pil_image.size) >= 2 * image_size:
84 | pil_image = pil_image.resize(
85 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
86 | )
87 |
88 | scale = image_size / min(*pil_image.size)
89 | pil_image = pil_image.resize(
90 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
91 | )
92 |
93 | arr = np.array(pil_image)
94 | crop_y = (arr.shape[0] - image_size) // 2
95 | crop_x = (arr.shape[1] - image_size) // 2
96 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
97 |
98 |
99 | #################################################################################
100 | # Training Loop #
101 | #################################################################################
102 |
103 | def main(args):
104 | """
105 | Trains a new DiT model.
106 | """
107 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
108 |
109 | # Setup DDP:
110 | dist.init_process_group("nccl")
111 |     assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
112 | rank = dist.get_rank()
113 | device = rank % torch.cuda.device_count()
114 | seed = args.global_seed * dist.get_world_size() + rank
115 | torch.manual_seed(seed)
116 | torch.cuda.set_device(device)
117 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
118 |
119 | # Setup an experiment folder:
120 | if rank == 0:
121 | os.makedirs(args.results_dir, exist_ok=True)
122 | experiment_index = len(glob(f"{args.results_dir}/*"))
123 | model_string_name = args.model.replace("/", "-")
124 | model_string_name = args.dataset+"-" + model_string_name + "-crop" if args.crop else args.dataset+"-" + model_string_name
125 | model_string_name = model_string_name + "-withmask" if args.add_mask else model_string_name
126 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
127 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
128 | os.makedirs(checkpoint_dir, exist_ok=True)
129 | logger = create_logger(experiment_dir)
130 | logger.info(f"Experiment directory created at {experiment_dir}")
131 | else:
132 | logger = create_logger(None)
133 |
134 | # Create model:
135 |     assert args.image_size % 3 == 0, "Image size must be a multiple of 3"
136 | if args.dataset == 'imagenet':
137 |         assert args.image_size == 288 or args.crop, "For imagenet, use image_size 288, or image_size 192 together with --crop for the with-gap setting"
138 | model = DiT_models[args.model](
139 | input_size=args.image_size
140 | )
141 |
142 |     if args.ckpt != "":
143 |         ckpt_path = args.ckpt
144 |         print("Load model from ", ckpt_path)
145 | model_dict = model.state_dict()
146 | state_dict = torch.load(ckpt_path)
147 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
148 | model.load_state_dict(pretrained_dict, strict=False)
149 |
150 | # Note that parameter initialization is done within the DiT constructor
151 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
152 | requires_grad(ema, False)
153 | model = DDP(model.to(device), device_ids=[rank])
154 | diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
155 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
156 |
157 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
158 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
159 |
160 | transform = transforms.Compose([
161 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 288)),
162 | transforms.RandomHorizontalFlip(),
163 | transforms.ToTensor(),
164 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
165 | ])
166 |
167 | # Setup data:
168 | if args.dataset == "met":
169 |         # The MET dataloader yields images that are cropped and stitched back together
170 | dataset = MET(args.data_path,'train')
171 | elif args.dataset == "imagenet":
172 | dataset = ImageFolder(args.data_path, transform=transform)
173 |
174 | sampler = DistributedSampler(
175 | dataset,
176 | num_replicas=dist.get_world_size(),
177 | rank=rank,
178 | shuffle=True,
179 | seed=args.global_seed
180 | )
181 | loader = DataLoader(
182 | dataset,
183 | batch_size=int(args.global_batch_size // dist.get_world_size()),
184 | shuffle=False,
185 | sampler=sampler,
186 | num_workers=args.num_workers,
187 | pin_memory=True,
188 | drop_last=True
189 | )
190 | logger.info(f"Dataset contains {len(dataset):,} images")
191 |
192 | # Prepare models for training:
193 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
194 | model.train() # important! This enables embedding dropout for classifier-free guidance
195 | ema.eval() # EMA model should always be in eval mode
196 |
197 | # Variables for monitoring/logging purposes:
198 | train_steps = 0
199 | log_steps = 0
200 | running_loss = 0
201 | start_time = time()
202 |
203 | logger.info(f"Training for {args.epochs} epochs...")
204 | for epoch in range(args.epochs):
205 | sampler.set_epoch(epoch)
206 | logger.info(f"Beginning epoch {epoch}...")
207 | for x in loader:
208 | if args.dataset == 'imagenet':
209 | x, _ = x
210 | x = x.to(device)
211 | if args.dataset == 'imagenet' and args.crop:
212 | centercrop = transforms.CenterCrop((64,64))
213 | patchs = rearrange(x, 'b c (p1 h1) (p2 w1)-> b c (p1 p2) h1 w1',p1=3,p2=3,h1=96,w1=96)
214 | patchs = centercrop(patchs)
215 | x = rearrange(patchs, 'b c (p1 p2) h1 w1-> b c (p1 h1) (p2 w1)',p1=3,p2=3,h1=64,w1=64)
216 |
217 | # Set up initial positional embedding
218 | time_emb = torch.tensor(get_2d_sincos_pos_embed(8, 3)).unsqueeze(0).float().to(device)
219 |
220 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
221 | model_kwargs = None
222 | loss_dict = diffusion.training_losses(model, x, t, time_emb, model_kwargs, \
223 | block_size=args.image_size//3, patch_size=16, add_mask=args.add_mask)
224 | loss = loss_dict["loss"].mean()
225 | opt.zero_grad()
226 | loss.backward()
227 | opt.step()
228 | update_ema(ema, model.module)
229 |
230 | # Log loss values:
231 | running_loss += loss.item()
232 | log_steps += 1
233 | train_steps += 1
234 | if train_steps % args.log_every == 0:
235 | # Measure training speed:
236 | torch.cuda.synchronize()
237 | end_time = time()
238 | steps_per_sec = log_steps / (end_time - start_time)
239 | # Reduce loss history over all processes:
240 | avg_loss = torch.tensor(running_loss / log_steps, device=device)
241 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
242 | avg_loss = avg_loss.item() / dist.get_world_size()
243 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
244 | # Reset monitoring variables:
245 | running_loss = 0
246 | log_steps = 0
247 | start_time = time()
248 |
249 | # Save JPDVT checkpoint:
250 | if train_steps % args.ckpt_every == 0 and train_steps > 0:
251 | if rank == 0:
252 | checkpoint = {
253 | "model": model.module.state_dict(),
254 | "ema": ema.state_dict(),
255 | "opt": opt.state_dict(),
256 | "args": args
257 | }
258 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
259 | torch.save(checkpoint, checkpoint_path)
260 | logger.info(f"Saved checkpoint to {checkpoint_path}")
261 | dist.barrier()
262 |
263 | model.eval() # important! This disables randomized embedding dropout
264 |
265 | logger.info("Done!")
266 | cleanup()
267 |
268 | if __name__ == "__main__":
269 | parser = argparse.ArgumentParser()
270 | parser.add_argument("--results-dir", type=str, default="results")
271 | parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="JPDVT")
272 | parser.add_argument("--dataset", type=str, choices=["imagenet", "met"], default="imagenet")
273 | parser.add_argument("--data-path", type=str,required=True)
274 | parser.add_argument("--crop", action='store_true', default=False)
275 | parser.add_argument("--add-mask", action='store_true', default=False)
276 | parser.add_argument("--image-size", type=int, choices=[192, 288], default=288)
277 | parser.add_argument("--epochs", type=int, default=500)
278 | parser.add_argument("--global-batch-size", type=int, default=96)
279 | parser.add_argument("--global-seed", type=int, default=0)
280 | parser.add_argument("--num-workers", type=int, default=12)
281 | parser.add_argument("--log-every", type=int, default=100)
282 | parser.add_argument("--ckpt-every", type=int, default=10_000)
283 | parser.add_argument("--ckpt", type=str, default='')
284 | args = parser.parse_args()
285 | main(args)
286 |
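# Example launch (a sketch; GPU count and data path are placeholders):
#   torchrun --nproc_per_node=2 train_JPDVT.py --data-path /path/to/imagenet --dataset imagenet --image-size 288
# For the cropped "with-gap" ImageNet variant, add --crop and use --image-size 192;
# for MET, pass --dataset met with --data-path pointing at the MET data.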
--------------------------------------------------------------------------------
/video_model/datasets/video_datasets.py:
--------------------------------------------------------------------------------
1 | from random import sample
2 | from PIL import Image, ImageSequence
3 | import blobfile as bf
4 | from mpi4py import MPI
5 | import numpy as np
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 | # from MovingMNIST import MovingMNIST
9 | # file: /folder1/folder2/module.py
10 | import os
11 | import sys
12 |
13 | sys.path.insert(1, os.getcwd())
14 | # sys.path.insert(0, '/RaMViD_main/mnist')
15 | # sys.path.insert(0, '/RaMViD_main/nucla_skeleton')
16 | # sys.path.insert(0, '/RaMViD_main/something')
17 |
18 | from MovingMNIST import MovingMNIST
19 | from somethingsomethingv2 import SomethingSomethingDataset
20 | from NUCLAskeleton import NUCLAskeleton
21 | from NTURGBDskeleton import NTURGBDskeleton
22 | from clevrer import Clevrer
23 | import torch
24 | import av
25 | import os
26 | # from nucla_skeleton.NUCLAskeleton import NUCLAskeleton
27 | import torch.utils.data as tudata
28 |
29 |
30 | # from somethng.somethingsomethingv2 import SomethingSomethingDataset
31 |
32 |
33 | def load_data(
34 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, rgb=True, seq_len=8
35 | ):
36 | """
37 | For a dataset, create a generator over (videos, kwargs) pairs.
38 |
39 | Each video is an NCLHW float tensor, and the kwargs dict contains zero or
40 | more keys, each of which map to a batched Tensor of their own.
41 | The kwargs dict can be used for class labels, in which case the key is "y"
42 | and the values are integer tensors of class labels.
43 |
44 | :param data_dir: a dataset directory.
45 | :param batch_size: the batch size of each returned pair.
46 | :param image_size: the size to which frames are resized.
47 | :param class_cond: if True, include a "y" key in returned dicts for class
48 | label. If classes are not available and this is true, an
49 | exception will be raised.
50 | :param deterministic: if True, yield results in a deterministic order.
51 | """
52 | # global dataset
53 | if not data_dir:
54 | raise ValueError("unspecified data directory")
55 |
56 | # if data_dir != "MovingMNIST" and data_dir!="MovingMNIST_Test": # for sampling
57 | if data_dir not in ["MovingMNIST-train", "MovingMNIST-test", "something-train", "something-test",
58 | "NUCLAskeleton-train", "NUCLAskeleton-test", "NTURGBD-train", "NTURGBD-test","clevrer-train","clevrer-test"]:
59 |
60 | # if data_dir == "real_dataset": #
61 | all_files = _list_video_files_recursively(data_dir)
62 | classes = None
63 | if class_cond:
64 | # Assume classes are the first part of the filename,
65 | # before an underscore.
66 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
67 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
68 | classes = [sorted_classes[x] for x in class_names]
69 | entry = all_files[0].split(".")[-1]
70 | # breakpoint()
71 |         if entry in ["avi", "mp4", "webm"]:
72 | dataset = VideoDataset_mp4(
73 | image_size,
74 | all_files,
75 | classes=classes,
76 | shard=MPI.COMM_WORLD.Get_rank(),
77 | num_shards=MPI.COMM_WORLD.Get_size(),
78 | rgb=rgb,
79 | seq_len=seq_len
80 | )
81 | # breakpoint()
82 | elif entry in ["gif"]:
83 | dataset = VideoDataset_gif(
84 | image_size,
85 | all_files,
86 | classes=classes,
87 | shard=MPI.COMM_WORLD.Get_rank(),
88 | num_shards=MPI.COMM_WORLD.Get_size(),
89 | rgb=rgb,
90 | seq_len=seq_len
91 | )
92 | elif data_dir == "MovingMNIST-train":
93 |
94 | root_path = '/data/mark'
95 | if not os.path.exists(root_path):
96 | os.mkdir(root_path)
97 | dataset = MovingMNIST(root='/data/mark/mnist', train=True, download=True, image_size=image_size,
98 | seq_len=seq_len)
99 | elif data_dir == "MovingMNIST-test": # newly added
100 | root_path = '/data/mark'
101 | # breakpoint()
102 | if not os.path.exists(root_path):
103 | os.mkdir(root_path)
104 | dataset = MovingMNIST(root='/data/mark/mnist', train=False, download=True, image_size=image_size,
105 | seq_len=seq_len)
106 | elif data_dir == "something-train":
107 | dataset = SomethingSomethingDataset(video_dir='/data/wondm/somethingsomethingv2/train',
108 | frame_res=image_size, num_frames=seq_len)
109 |
110 | elif data_dir == "something-test": # newly added
111 | dataset = SomethingSomethingDataset(video_dir='/data/wondm/somethingsomethingv2/test',
112 | frame_res=image_size, num_frames=seq_len)
113 |
114 | elif data_dir == "NUCLAskeleton-train": # newly added
115 | dataset = NUCLAskeleton(phase='train', cam='1,2', frame_res=image_size)
116 | elif data_dir == "NUCLAskeleton-test": # newly added
117 | dataset = NUCLAskeleton(phase='test', cam='1,2', frame_res=image_size)
118 | elif data_dir == "NTURGBD-train": # newly added
119 | dataset = NTURGBDskeleton(phase='train', cam='1') # T=8,20,40
120 | elif data_dir == "NTURGBD-test": # newly added
121 | dataset = NTURGBDskeleton(phase='val', cam='1') # T=8,20,40
122 | elif data_dir=="clevrer-train":
123 | dataset = Clevrer('/data/CLEVRER/Numpy/train/')
124 | elif data_dir=="clevrer-test":
125 | dataset = Clevrer('/data/CLEVRER/Numpy/test/')
126 |
127 |
128 |
129 | #
130 | # if deterministic:
131 | # loader = DataLoader(
132 | # dataset, batch_size=batch_size, shuffle=False, num_workers=16, drop_last=True
133 | # )
134 | # else:
135 | # loader = DataLoader(
136 | # dataset, batch_size=batch_size, shuffle=True, num_workers=4,drop_last=True
137 | # )
138 | if data_dir in ["NUCLAskeleton-train", "NUCLAskeleton-test"]:
139 | loader = tudata.DataLoader(dataset, batch_size=1, shuffle=False,
140 | num_workers=1, pin_memory=True)
141 | else:
142 | # breakpoint()
143 | loader = DataLoader(
144 | dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True
145 | )
146 |
147 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
148 | while True:
149 | yield from loader
150 |
151 |
152 | def _list_video_files_recursively(data_dir):
153 | results = []
154 | for entry in sorted(bf.listdir(data_dir)):
155 | full_path = bf.join(data_dir, entry)
156 | ext = entry.split(".")[-1]
157 | if "." in entry and ext.lower() in ["gif", "avi", "mp4", "webm"]:
158 | results.append(full_path)
159 | elif bf.isdir(full_path):
160 | results.extend(_list_video_files_recursively(full_path))
161 | return results
162 |
163 |
164 | class VideoDataset_mp4(Dataset):
165 |
166 | def __init__(self, resolution, video_paths, classes=None, shard=0, num_shards=1, rgb=True, seq_len=20):
167 | super().__init__()
168 | self.resolution = resolution
169 | self.local_videos = video_paths[shard:][::num_shards]
170 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
171 | self.rgb = rgb
172 | self.seq_len = seq_len
173 |
174 | def __len__(self):
175 | return len(self.local_videos)
176 |
177 | def __getitem__(self, idx):
178 | path = self.local_videos[idx]
179 | # path = './data'
180 | arr_list = []
181 | video_container = av.open(path)
182 | n = video_container.streams.video[0].frames
183 | # breakpoint()
184 | frames = [i for i in range(n)]
185 | if n > self.seq_len:
186 | start = np.random.randint(0, n - self.seq_len)
187 | frames = frames[start:start + self.seq_len]
188 | for id, frame_av in enumerate(video_container.decode(video=0)):
189 | # We are not on a new enough PIL to support the `reducing_gap`
190 | # argument, which uses BOX downsampling at powers of two first.
191 | # Thus, we do it by hand to improve downsample quality.
192 | if (id not in frames):
193 | continue
194 | frame = frame_av.to_image()
195 | while min(*frame.size) >= 2 * self.resolution:
196 | frame = frame.resize(
197 | tuple(x // 2 for x in frame.size), resample=Image.BOX
198 | )
199 | scale = self.resolution / min(*frame.size)
200 | frame = frame.resize(
201 | tuple(round(x * scale) for x in frame.size), resample=Image.BICUBIC
202 | )
203 |
204 | if self.rgb:
205 | arr = np.array(frame.convert("RGB"))
206 | else:
207 | arr = np.array(frame.convert("L"))
208 | arr = np.expand_dims(arr, axis=2)
209 | crop_y = (arr.shape[0] - self.resolution) // 2
210 | crop_x = (arr.shape[1] - self.resolution) // 2
211 | arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution]
212 | arr = arr.astype(np.float32) / 127.5 - 1
213 | arr_list.append(arr)
214 | arr_seq = np.array(arr_list)
215 | arr_seq = np.transpose(arr_seq, [3, 0, 1, 2])
216 | # fill in missing frames with 0s
217 | if arr_seq.shape[1] < self.seq_len:
218 | required_dim = self.seq_len - arr_seq.shape[1]
219 |             fill = np.zeros((arr_seq.shape[0], required_dim, self.resolution, self.resolution), dtype=np.float32)  # match the channel count (3 for RGB, 1 for grayscale)
220 | arr_seq = np.concatenate((arr_seq, fill), axis=1)
221 | out_dict = {}
222 | if self.local_classes is not None:
223 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
224 |
225 | # breakpoint()
226 | return arr_seq, out_dict
227 |
228 |
229 | class VideoDataset_gif(Dataset):
230 | def __init__(self, resolution, video_paths, classes=None, shard=0, num_shards=1, rgb=True, seq_len=20):
231 | super().__init__()
232 | self.resolution = resolution
233 | self.local_videos = video_paths[shard:][::num_shards]
234 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
235 | self.rgb = rgb
236 | self.seq_len = seq_len
237 |
238 | def __len__(self):
239 | return len(self.local_videos)
240 |
241 | def __getitem__(self, idx):
242 | path = self.local_videos[idx]
243 | with bf.BlobFile(path, "rb") as f:
244 | pil_videos = Image.open(f)
245 | arr_list = []
246 | for frame in ImageSequence.Iterator(pil_videos):
247 |
248 | # We are not on a new enough PIL to support the `reducing_gap`
249 | # argument, which uses BOX downsampling at powers of two first.
250 | # Thus, we do it by hand to improve downsample quality.
251 | while min(*frame.size) >= 2 * self.resolution:
252 | frame = frame.resize(
253 | tuple(x // 2 for x in frame.size), resample=Image.BOX
254 | )
255 | scale = self.resolution / min(*frame.size)
256 | frame = frame.resize(
257 | tuple(round(x * scale) for x in frame.size), resample=Image.BICUBIC
258 | )
259 |
260 | if self.rgb:
261 | arr = np.array(frame.convert("RGB"))
262 | else:
263 | arr = np.array(frame.convert("L"))
264 | arr = np.expand_dims(arr, axis=2)
265 | crop_y = (arr.shape[0] - self.resolution) // 2
266 | crop_x = (arr.shape[1] - self.resolution) // 2
267 | arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution]
268 | arr = arr.astype(np.float32) / 127.5 - 1
269 | arr_list.append(arr)
270 | arr_seq = np.array(arr_list)
271 | arr_seq = np.transpose(arr_seq, [3, 0, 1, 2])
272 | if arr_seq.shape[1] > self.seq_len:
273 | start = np.random.randint(0, arr_seq.shape[1] - self.seq_len)
274 | arr_seq = arr_seq[:, start:start + self.seq_len]
275 | out_dict = {}
276 | if self.local_classes is not None:
277 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
278 | return arr_seq, out_dict
279 |
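# A minimal, hypothetical usage sketch (not part of the original module). load_data
# returns an infinite generator over (video, kwargs) batches, so it is consumed with
# next() rather than a for-loop over epochs. "MovingMNIST-train" exercises the
# built-in MovingMNIST branch above; note that this branch (like several others here)
# relies on hard-coded dataset paths.
if __name__ == "__main__":
    data = load_data(
        data_dir="MovingMNIST-train",
        batch_size=4,
        image_size=32,
        seq_len=8,
    )
    videos, cond = next(data)
    print(videos.shape)  # e.g. (4, 1, 8, 32, 32) for grayscale Moving MNIST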
--------------------------------------------------------------------------------
/video_model/diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log a value of some diagnostic
215 | Call this once for each diagnostic quantity, each iteration
216 | If called many times, last value will be used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 | The same as logkv(), but if called many times, values averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start)
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir='/home/mark/projects/prompt_2_clvr/results', format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai_unshuffle-%m-%d-%H-%M-%S-%f"),
452 | )
453 | else:
454 |         dir = osp.join(dir,
455 |             datetime.datetime.now().strftime("prompt_mnist_20T4M_Transformer-%m-%d-%H-%M-%S"),
456 | )
457 | assert isinstance(dir, str)
458 | dir = os.path.expanduser(dir)
459 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
460 |
461 | rank = get_rank_without_mpi_import()
462 | if rank > 0:
463 | log_suffix = log_suffix + "-rank%03i" % rank
464 |
465 | if format_strs is None:
466 | if rank == 0:
467 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
468 | else:
469 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
470 | format_strs = filter(None, format_strs)
471 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
472 |
473 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
474 | if output_formats:
475 | log("Logging to %s" % dir)
476 |
477 |
478 | def _configure_default_logger():
479 | configure()
480 | Logger.DEFAULT = Logger.CURRENT
481 |
482 |
483 | def reset():
484 | if Logger.CURRENT is not Logger.DEFAULT:
485 | Logger.CURRENT.close()
486 | Logger.CURRENT = Logger.DEFAULT
487 | log("Reset logger")
488 |
489 |
490 | @contextmanager
491 | def scoped_configure(dir=None, format_strs=None, comm=None):
492 | prevlogger = Logger.CURRENT
493 | configure(dir=dir, format_strs=format_strs, comm=comm)
494 | try:
495 | yield
496 | finally:
497 | Logger.CURRENT.close()
498 | Logger.CURRENT = prevlogger
499 |
500 |
--------------------------------------------------------------------------------
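A note on the logger above: `logkv_mean` keeps a per-key running mean, and `dumpkvs` merges those means across ranks via `mpi_weighted_mean`, weighting each rank's value by its sample count. A minimal sketch of that weighted-mean arithmetic, written without MPI so it runs standalone (the `gathered` list and the `weighted_mean` name are illustrative stand-ins for what `comm.gather` would hand to rank 0):

from collections import defaultdict

def weighted_mean(gathered):
    # gathered: one {key: (value, count)} dict per rank
    name2sum, name2count = defaultdict(float), defaultdict(float)
    for rank_dict in gathered:
        for name, (val, count) in rank_dict.items():
            name2sum[name] += float(val) * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}

# rank 0 averaged "loss" over 3 batches, rank 1 over 1 batch
print(weighted_mean([{"loss": (0.5, 3)}, {"loss": (1.5, 1)}]))  # {'loss': 0.75}
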
/video_model/video_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 | import argparse
6 | import os
7 | import shutil
8 | import sys
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import torch as th
13 | import torch.nn as nn
14 | import torch.distributed as dist
15 | from matplotlib import gridspec
16 | from skimage.metrics import structural_similarity, mean_squared_error
17 | from sklearn.preprocessing import MinMaxScaler
18 | from positional_encodings.torch_encodings import PositionalEncoding1D
19 | sys.path.insert(1, os.getcwd())
20 | import random
21 |
22 | from video_datasets import load_data
23 | import dist_util, logger
24 | from script_util import (
25 | model_and_diffusion_defaults,
26 | create_model_and_diffusion,
27 | add_dict_to_argparser,
28 | args_to_dict,
29 | )
30 |
31 | from functools import reduce
32 |
33 | import torch
34 |
35 | from sklearn.metrics import pairwise_distances
36 |
37 | def save_tensor(file_name, data_tensor):
38 | directory = "/home/wondmgezahu/ppo/latent-diffusion-main/EncodedLatent"
39 | file_path = os.path.join(directory, file_name)
40 | torch.save(data_tensor, file_path)
41 | def plot_samples(generated_data, shuffled_input_data, ground_truth_data):
42 | # Delete the "figures" directory and its contents (if it exists)
43 | if os.path.exists('fig_test'):
44 | shutil.rmtree('fig_test')
45 |
46 | # Create a new "figures" directory
47 | os.makedirs('fig_test')
48 |
49 | for i in range(generated_data.shape[0]): #
50 | # Select the current batch
51 | batch_generated = generated_data[i, :, :, :, :]
52 | batch_shuffled = shuffled_input_data[i, :, :, :, :]
53 | batch_ground_truth = ground_truth_data[i, :, :, :, :]
54 |
55 | # Create a figure with 10 subplots (one for each frame)
56 | fig = plt.figure(figsize=(6, 6)) # (3,3) for 4 frame
57 | gs = gridspec.GridSpec(nrows=3, ncols=generated_data.shape[2], left=0, right=1, top=1, bottom=0)
58 |
59 | # Reduce the space between the subplots
60 | # fig.tight_layout()
61 | plt.subplots_adjust(wspace=0, hspace=0)
62 |
63 | # Loop over the frames
64 | for j in range(batch_generated.shape[1]):
65 | # Select the current frame
66 | frame_generated = batch_generated[:, j, :, :]
67 | frame_shuffled = batch_shuffled[:, j, :, :]
68 | frame_ground_truth = batch_ground_truth[:, j, :, :]
69 |
70 | # Create a subplot at the current position
71 | ax1 = fig.add_subplot(gs[0, j])
72 | ax2 = fig.add_subplot(gs[1, j])
73 | ax3 = fig.add_subplot(gs[2, j])
74 |
75 | # Plot the frame in grayscale
76 | # breakpoint()
77 | ax1.imshow(np.squeeze(frame_shuffled), cmap='gray') # cmap='gray'
78 | ax1.axis('off')
79 | if j == 4:
80 | ax1.set_title('shuffled samples', loc='center', fontsize=10)
81 |
82 | ax2.imshow(np.squeeze(frame_ground_truth), cmap='gray') # cmap='gray'
83 | ax2.axis('off')
84 | if j == 4:
85 | ax2.set_title('ground truth samples', loc='center', fontsize=10)
86 |
87 | ax3.imshow(np.squeeze(frame_generated), cmap='gray') # cmap='gray'
88 | ax3.axis('off')
89 | if j == 4:
90 | ax3.set_title('generated samples', loc='center', fontsize=10)
91 |
92 | plt.savefig('fig_test/batch{}.png'.format(i))
93 |
94 | def main():
95 | args = create_argparser().parse_args()
96 |
97 | dist_util.setup_dist()
98 | # logger.configure()
99 | if args.seed:
100 | th.manual_seed(args.seed)
101 | np.random.seed(args.seed)
102 | random.seed(args.seed)
103 |
104 | # logger.log("creating model and diffusion...")
105 | model, diffusion = create_model_and_diffusion(
106 | **args_to_dict(args, model_and_diffusion_defaults().keys())
107 | )
108 |
109 | model.load_state_dict(
110 | dist_util.load_state_dict(args.model_path, map_location="cpu")
111 | )
112 | model.to(dist_util.dev())
113 | # breakpoint()
114 |     model.eval()  # inference mode for sampling
115 |
116 | cond_kwargs = {}
117 | cond_frames = []
118 | if args.cond_generation:
119 | data = load_data(
120 | data_dir=args.data_dir,
121 | batch_size=args.batch_size,
122 | image_size=args.image_size,
123 | class_cond=args.class_cond,
124 | deterministic=True,
125 | # rgb=args.rgb,
126 | seq_len=args.seq_len
127 | )
128 |
129 |         num = ""
130 |         # parse --cond_frames as a comma-terminated string, e.g. "0,8,16,": an index is appended only when a comma follows it
131 | for i in args.cond_frames:
132 | if i == ",":
133 | cond_frames.append(int(num))
134 | num = ""
135 | else:
136 | num = num + i
137 | ref_frames = list(i for i in range(args.seq_len) if i not in cond_frames)
138 | # logger.log(f"cond_frames: {cond_frames}")
139 | # logger.log(f"ref_frames: {ref_frames}")
140 | # logger.log(f"seq_len: {args.seq_len}")
141 | cond_kwargs["resampling_steps"] = args.resample_steps
142 | cond_kwargs["cond_frames"] = cond_frames
143 |
144 | channels = 4
145 | # breakpoint()
146 | # logger.log("sampling...")
147 | all_videos = []
148 | all_gt = []
149 | shuffled_input = []
150 | ground_truth = []
151 | sample_time=[]
152 | perm_index = []
153 | imputation_index = []
154 | all_idx = []
155 | all_permutation = []
156 | all_names = []
157 | all_sampled_normalized=[]
158 | mask_list=[]
159 | kendall_dis_sum = 0
160 | kendall_dis_list = []
161 | while len(kendall_dis_list) < args.num_samples:
162 | print(len(kendall_dis_list))
163 | min_val, max_val = 0, 0
164 | if args.cond_generation:
165 |             raw_video, _, index = next(data)  # for datasets other than Something-Something and the NTU RGB+D skeletons, use: video, _ = next(data)
166 | # breakpoint()
167 | raw_video = raw_video.permute(0,2,1,3,4)
168 | # all_names.append(name)
169 | # raw_video = raw_video.permute(0, 2, 1, 3, 4) # permuted for something dataset and nturgbd-skeleton only
170 | # original_skeleton = video.clone().detach()
171 | video = raw_video.float() # just for skeleton and nturgbd dataset
172 | # breakpoint()
173 | # normalization for nturgbd
174 | min_val = torch.min(video)
175 | max_val = torch.max(video)
176 | range_val = max_val - min_val
177 |
178 | batch = (video - min_val) / range_val
179 | video = batch * 2 - 1
180 | # cond_kwargs["cond_img"] = video[:, :, cond_frames].to(dist_util.dev())
181 | # cond_kwargs["cond_frames"] = cond_frames
182 | # video = video.to(dist_util.dev())
183 | # breakpoint()
184 | video = video
185 | idx = th.randperm(video.shape[2])
186 | idx=idx[idx!=0]
187 | idx=torch.cat((torch.tensor([0]),idx))
188 | # breakpoint()
189 | perm_index.append(idx)
190 | # original_data = torch.Tensor(video)
191 |
192 | conditional_data = video[:, :, idx, :, :]
193 | true_idx = np.zeros(args.seq_len, )
194 | true_idx[:] = idx
195 |
196 | # Remove some frames
197 | num_frames_to_remove = 12
198 |             print("number of missing frames is:", num_frames_to_remove)
199 | #
200 | # # create an array of indices for all frames
201 | frame_indices = np.arange(conditional_data.shape[2])
202 | #
203 | # # randomly select the indices of the frames to remove for all batches
204 | remove_indices = np.random.choice(frame_indices, num_frames_to_remove, replace=False)
205 | remove_indices = remove_indices[remove_indices!=0]
206 | #
207 | # # create an array of zeros with the same shape as the video frames
208 | dummy_frames = np.random.randn(
209 | conditional_data.shape[0], conditional_data.shape[1], 1, conditional_data.shape[3],
210 | conditional_data.shape[4])
211 | dummy_frames = torch.from_numpy(dummy_frames).to(conditional_data.device)
212 | # breakpoint()
213 | conditional_data[:, :, remove_indices, :, :] = dummy_frames.float()
214 |
215 | mask = np.ones(conditional_data.shape[2])
216 | mask[remove_indices] = 0
217 | mask_list.append(mask)
218 | mask = torch.tensor(mask).cuda()
219 |
220 |
221 | cond_kwargs["cond_frames"] = [i for i in frame_indices if i not in remove_indices]
222 | cond_kwargs["cond_img"] = conditional_data[:, :, cond_kwargs["cond_frames"]].to(dist_util.dev())
223 | # breakpoint()
224 | sample_fn = (
225 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
226 | )
227 | sample = sample_fn(
228 | model,
229 | (args.batch_size, args.seq_len, 16),
230 | clip_denoised=args.clip_denoised,
231 | progress=False,
232 | cond_kwargs=cond_kwargs,
233 | condition_args=conditional_data,
234 | # mask=mask,
235 | original_args=video
236 | # gap=args.diffusion_step_gap
237 | )
238 | # breakpoint()
239 | p_enc_1d_model = PositionalEncoding1D(16)
240 | time_emb_gt = p_enc_1d_model(torch.rand(sample[0].shape[0], sample[0].shape[1], 16))
241 |
242 |
243 |         def find_permutation(distance_matrix):  # greedily match predicted embeddings (rows) to temporal positions (columns)
244 |             sort_list = []
245 |             for m in range(distance_matrix.shape[1]):
246 |                 order = distance_matrix[:,0].argmin()    # row closest to the first remaining column
247 |                 sort_list.append(order)
248 |                 distance_matrix = distance_matrix[:,1:]  # drop the matched column
249 |                 distance_matrix[order,:] = 10**5         # block this row from being matched again
250 |             return sort_list
251 | # conditional_data = ((conditional_data.clamp(-1,1) + 1) * (max_val-min_val) /2+ min_val)
252 | # sample = sample.contiguous()
253 | sample_normalized = sample[0].clamp(-1, 1)
254 | dist = pairwise_distances(time_emb_gt[0], sample[0][0].cpu(), metric='manhattan')
255 | permutation = find_permutation(dist)
256 | permutation = np.array(permutation)
257 | idx = np.array(idx)
258 |
259 | if num_frames_to_remove != 0:
260 | # breakpoint()
261 | permutation_con = permutation[cond_kwargs["cond_frames"]]
262 | idx_con = idx[cond_kwargs["cond_frames"]]
263 | permutation_blank = np.sort(permutation[remove_indices])
264 | idx_blank = np.sort(idx[remove_indices])
265 |
266 | permutation = np.concatenate((permutation_con,permutation_blank))
267 | idx = np.concatenate((idx_con,idx_blank))
268 |
269 |
270 | kendall_dis_1 = 0
271 | for n1 in range(len(permutation)):
272 | for n2 in range(len(permutation)):
273 |                     if permutation[n1] < permutation[n2] and idx[n1] > idx[n2]:
274 |                         kendall_dis_1 += 1
275 | # kendall_dis_1 += np.abs(permutation - idx).sum()
276 | kendall_dis_2 = 0
277 | for n1 in range(len(permutation)):
278 | for n2 in range(len(permutation)):
279 | if permutation[n1] < permutation[n2] and (args.seq_len-1-idx)[n1] > (args.seq_len-1-idx)[n2]:
280 | kendall_dis_2 += 1
281 | # kendall_dis_2 += np.abs(permutation -args.seq_len + 1 + idx).sum()
282 |             kendall_dis = min(kendall_dis_1, kendall_dis_2) / (sample[0].shape[1] * (sample[0].shape[1] - 1) / 2)  # normalize by the number of frame pairs
283 |
284 |
285 | # breakpoint()
286 |
287 | kendall_dis_sum += kendall_dis
288 | kendall_dis_list.append(kendall_dis)
289 | print(len(kendall_dis_list), 'kendall distance:', kendall_dis,kendall_dis_sum/len(kendall_dis_list) )
290 |             if kendall_dis > 0.05:
291 | print('video_index:',index)
292 | print(idx)
293 | print(permutation)
294 |
295 |
296 | print("Total result is:",kendall_dis_sum/args.num_samples)
297 | # # end of normalization
298 | # all_sampled_normalized.append(sample_normalized)
299 | # gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
300 | # # gathered_samples_time=[th.zeros_like(sample_time_out) for _ in range(dist.get_world_size())]
301 | # dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
302 | # all_videos.extend([sample.cpu().numpy() for sample in gathered_samples])
303 | # # sample_time.append(sample_time_out.cpu().numpy()) #for sample_time_out in gathered_samples_time])
304 | # ground_truth.append(video.permute(0, 2, 3, 4, 1).cpu().numpy())
305 | # shuffled_input.append(conditional_data.permute(0, 2, 3, 4, 1).cpu().numpy())
306 | # all_idx.append(true_idx)
307 | # # breakpoint()
308 | # # logger.log(f"created {len(all_videos) * args.batch_size} samples")
309 |
310 | # generated_data = np.concatenate(all_videos, axis=0)
311 | # ground_truth_data = np.asarray(ground_truth)
312 | # shuffled_input_data = np.asarray(shuffled_input)
313 |
314 | # gen_raw = np.concatenate(all_videos, axis=0) # raw BxTxWxHXC 1x20x32x32x4
315 | # gen_raw = torch.from_numpy(gen_raw).float()
316 | # gen_norm = torch.cat(all_sampled_normalized, dim=0)
317 | # gen_time=np.concatenate(sample_time,axis=0)
318 | # breakpoint()
319 | # plot_samples(generated_data, shuffled_input_data, ground_truth_data)
320 | # fig=plt.figure()
321 | # for i in range(gen_time.shape[0]):
322 | # plt.imshow(gen_time[i])
323 | # plt.savefig('fig_test/time{}.png'.format(i))
324 |
325 | # # breakpoint()
326 | # p_enc_1d_model = PositionalEncoding1D(16)
327 | # time_emb = p_enc_1d_model(torch.rand(generated_data.shape[0], generated_data.shape[1], 16))
328 | # # distances = torch.cdist(torch.tensor(generated_data[0]), torch.tensor(time_emb[0]))
329 | # # indices = distances.argmin(dim=-1)
330 | # indice_list=[]
331 | # for i in range(generated_data.shape[0]):
332 | # distances = torch.cdist(torch.tensor(generated_data[i]), torch.tensor(time_emb[0]))
333 | # indices = distances.argmin(dim=-1)
334 | # indice_list.append(indices)
335 | # # breakpoint()
336 | # from datetime import datetime
337 | # currentDateAndTime = datetime.now()
338 | # if args.cond_generation and args.save_gt:
339 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_gt_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),ground_truth_data)
340 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_gen_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),
341 | # generated_data)
342 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_idx_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),
343 | # np.asarray(all_idx))
344 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' + str(datetime.now().day) + '_' + str(
345 | # datetime.now().hour) + '_file_names_' + str(args.resample_steps) + str(args.diffusion_step_gap) + '.npy'),
346 | # np.asarray(all_names))
347 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_condition_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),
348 | # shuffled_input_data)
349 | # np.save(os.path.join('GeneratedLatent', str(args.num_samples) + '_samples_' +str(datetime.now().day)+'_'+str(datetime.now().hour)+ '_maskList_'+str(args.resample_steps)+str(args.diffusion_step_gap)+'.npy'),
350 | # np.asarray(mask_list))
351 | # dist.barrier()
352 | # logger.log("sampling complete")
353 |
354 |
355 | def create_argparser():
356 | defaults = dict(
357 | clip_denoised=True,
358 | num_samples=1000, # 10
359 | batch_size=1, # 10
360 | use_ddim=False,
361 | model_path="",
362 | seq_len=32, # 16
363 | sampling_type="generation",
364 | cond_frames="",
365 | cond_generation=True, # True
366 | resample_steps=1,
367 | data_dir='',
368 | save_gt=True,
369 | seed=0,
370 | diffusion_step_gap=1,
371 | )
372 | defaults.update(model_and_diffusion_defaults())
373 | parser = argparse.ArgumentParser()
374 | add_dict_to_argparser(parser, defaults)
375 | return parser
376 |
377 |
378 | if __name__ == "__main__":
379 | import time
380 |     # reference flag strings only; they are not parsed here, so pass these options on the command line when launching the script
381 |     MODEL_FLAGS = "--image_size 8 --num_channels 128 --num_res_blocks 3 --scale_time_dim 8" # image size 64
382 |     DIFFUSION_FLAGS = "--diffusion_steps 1200 --noise_schedule linear" # 1000
383 | start = time.time()
384 | main()
385 | end = time.time()
386 | print(f"elapsed time: {end - start}")
--------------------------------------------------------------------------------
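A note on the evaluation inside `main()` above: each sampled ordering is scored with a normalized Kendall tau distance against the shuffle that was applied, and the smaller of the distance to the ground-truth order and to its reversal is kept (a fully reversed sequence counts as correct). A small self-contained sketch of that metric; the function names and the toy permutations are illustrative, not part of the repo:

import numpy as np

def normalized_kendall_distance(pred, gt):
    # fraction of frame pairs ordered differently by the two permutations
    pred, gt = np.asarray(pred), np.asarray(gt)
    n = len(pred)
    discordant = sum(
        1
        for i in range(n)
        for j in range(n)
        if pred[i] < pred[j] and gt[i] > gt[j]
    )
    return discordant / (n * (n - 1) / 2)

def order_score(pred, gt, seq_len):
    # as in main(): take the better of the ground-truth order and its reversal
    return min(
        normalized_kendall_distance(pred, gt),
        normalized_kendall_distance(pred, seq_len - 1 - np.asarray(gt)),
    )

print(order_score([0, 1, 2, 3], [0, 1, 2, 3], seq_len=4))  # 0.0, perfect recovery
print(order_score([0, 2, 1, 3], [0, 1, 2, 3], seq_len=4))  # one discordant pair -> 1/6
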
/image_model/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17 |
18 |
19 | def modulate(x, shift, scale):
20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21 |
22 |
23 | #################################################################################
24 | # Embedding Layers for Timesteps and Class Labels #
25 | #################################################################################
26 |
27 | class TimestepEmbedder(nn.Module):
28 | """
29 | Embeds scalar timesteps into vector representations.
30 | """
31 | def __init__(self, hidden_size, frequency_embedding_size=256):
32 | super().__init__()
33 | self.mlp = nn.Sequential(
34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
35 | nn.SiLU(),
36 | nn.Linear(hidden_size, hidden_size, bias=True),
37 | )
38 | self.frequency_embedding_size = frequency_embedding_size
39 |
40 | @staticmethod
41 | def timestep_embedding(t, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 | :param dim: the dimension of the output.
47 | :param max_period: controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 | half = dim // 2
52 | freqs = torch.exp(
53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
54 | ).to(device=t.device)
55 | args = t[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59 | return embedding
60 |
61 | def forward(self, t):
62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63 | t_emb = self.mlp(t_freq)
64 | return t_emb
65 |
66 |
67 | class LabelEmbedder(nn.Module):
68 | """
69 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
70 | """
71 | def __init__(self, num_classes, hidden_size, dropout_prob):
72 | super().__init__()
73 | use_cfg_embedding = dropout_prob > 0
74 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
75 | self.num_classes = num_classes
76 | self.dropout_prob = dropout_prob
77 |
78 | def token_drop(self, labels, force_drop_ids=None):
79 | """
80 | Drops labels to enable classifier-free guidance.
81 | """
82 | if force_drop_ids is None:
83 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
84 | else:
85 | drop_ids = force_drop_ids == 1
86 | labels = torch.where(drop_ids, self.num_classes, labels)
87 | return labels
88 |
89 | def forward(self, labels, train, force_drop_ids=None):
90 | use_dropout = self.dropout_prob > 0
91 | if (train and use_dropout) or (force_drop_ids is not None):
92 | labels = self.token_drop(labels, force_drop_ids)
93 | embeddings = self.embedding_table(labels)
94 | return embeddings
95 |
96 |
97 | #################################################################################
98 | # Core DiT Model #
99 | #################################################################################
100 |
101 | class DiTBlock(nn.Module):
102 | """
103 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
104 | """
105 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
106 | super().__init__()
107 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
108 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
109 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
111 | approx_gelu = lambda: nn.GELU(approximate="tanh")
112 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
113 | self.adaLN_modulation = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
116 | )
117 |
118 | def forward(self, x, c):
119 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
120 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
121 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
122 | return x
123 |
124 |
125 | class FinalLayer(nn.Module):
126 | """
127 | The final layer of DiT.
128 | """
129 | def __init__(self, hidden_size, patch_size, out_channels):
130 | super().__init__()
131 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
132 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
133 | self.adaLN_modulation = nn.Sequential(
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
136 | )
137 |
138 | def forward(self, x, c):
139 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
140 | x = modulate(self.norm_final(x), shift, scale)
141 | x = self.linear(x)
142 | return x
143 |
144 |
145 | class DiT(nn.Module):
146 | """
147 | Diffusion model with a Transformer backbone.
148 | """
149 | def __init__(
150 | self,
151 | input_size=255,
152 | patch_size=2,
153 | in_channels=3,
154 | hidden_size=1152,
155 | depth=28,
156 | num_heads=16,
157 | mlp_ratio=4.0,
158 | class_dropout_prob=0.1,
159 | num_classes=0,
160 | learn_sigma=False,
161 | ):
162 | super().__init__()
163 | self.learn_sigma = learn_sigma
164 | self.in_channels = in_channels
165 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
166 | self.patch_size = patch_size
167 | self.num_heads = num_heads
168 |
169 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
170 | self.t_embedder = TimestepEmbedder(hidden_size)
171 | # self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
172 | num_patches = self.x_embedder.num_patches
173 | # Will use fixed sin-cos embedding:
174 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
175 |
176 |         self.time_emb_in = nn.Linear(8,768)    # maps the 8-dim time/position embedding to the token width (hard-coded to 768)
177 |         self.time_emb_out1 = nn.Linear(768,64)
178 |         self.time_emb_out_silu = nn.SiLU()
179 |         self.time_emb_out2 = nn.Linear(64,8)   # small head that decodes an 8-dim embedding back out per token
180 |
181 | self.blocks = nn.ModuleList([
182 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
183 | ])
184 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
185 | self.initialize_weights()
186 |
187 | def initialize_weights(self):
188 | # Initialize transformer layers:
189 | def _basic_init(module):
190 | if isinstance(module, nn.Linear):
191 | torch.nn.init.xavier_uniform_(module.weight)
192 | if module.bias is not None:
193 | nn.init.constant_(module.bias, 0)
194 | self.apply(_basic_init)
195 |
196 | # Initialize (and freeze) pos_embed by sin-cos embedding:
197 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
198 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
199 |
200 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
201 | w = self.x_embedder.proj.weight.data
202 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
203 | nn.init.constant_(self.x_embedder.proj.bias, 0)
204 |
205 | # Initialize label embedding table:
206 | # nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
207 |
208 | # Initialize timestep embedding MLP:
209 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
210 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
211 |
212 | nn.init.normal_(self.time_emb_in.weight, std=0.02)
213 | nn.init.normal_(self.time_emb_out1.weight, std=0.02)
214 | nn.init.normal_(self.time_emb_out2.weight, std=0.02)
215 |
216 | # Zero-out adaLN modulation layers in DiT blocks:
217 | for block in self.blocks:
218 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
219 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
220 |
221 | # Zero-out output layers:
222 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
223 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
224 | nn.init.constant_(self.final_layer.linear.weight, 0)
225 | nn.init.constant_(self.final_layer.linear.bias, 0)
226 |
227 | def unpatchify(self, x):
228 | """
229 | x: (N, T, patch_size**2 * C)
230 | imgs: (N, H, W, C)
231 | """
232 | c = self.out_channels
233 | p = self.x_embedder.patch_size[0]
234 | h = w = int(x.shape[1] ** 0.5)
235 | assert h * w == x.shape[1]
236 |
237 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
238 | x = torch.einsum('nhwpqc->nchpwq', x)
239 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
240 | return imgs
241 |
242 | def forward(self, x, t, time_emb, y=None):
243 | """
244 | Forward pass of DiT.
245 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
246 | t: (N,) tensor of diffusion timesteps
247 | y: (N,) tensor of class labels
248 | """
249 | time_emb = self.time_emb_in(time_emb)
250 | x = self.x_embedder(x) + time_emb + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
251 | t = self.t_embedder(t) # (N, D)
252 | # y = self.y_embedder(y, self.training) # (N, D)
253 | c = t # (N, D)
254 | for block in self.blocks:
255 | x = block(x, c) # (N, T, D)
256 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
257 | time_emb = self.time_emb_out1(x)
258 | time_emb = self.time_emb_out_silu(time_emb)
259 | time_emb = self.time_emb_out2(time_emb)
260 | x = self.unpatchify(x) # (N, out_channels, H, W)
261 |
262 | return x,time_emb
263 |
264 | def forward_with_cfg(self, x, t, y, cfg_scale):
265 | """
266 |         Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. Note: kept from the original DiT code and not updated for this model's (x, time_emb) forward signature, so adjust before use.
267 | """
268 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
269 | half = x[: len(x) // 2]
270 | combined = torch.cat([half, half], dim=0)
271 | model_out = self.forward(combined, t, y)
272 | # For exact reproducibility reasons, we apply classifier-free guidance on only
273 | # three channels by default. The standard approach to cfg applies it to all channels.
274 | # This can be done by uncommenting the following line and commenting-out the line following that.
275 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
276 | eps, rest = model_out[:, :3], model_out[:, 3:]
277 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
278 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
279 | eps = torch.cat([half_eps, half_eps], dim=0)
280 | return torch.cat([eps, rest], dim=1)
281 |
282 |
283 | #################################################################################
284 | # Sine/Cosine Positional Embedding Functions #
285 | #################################################################################
286 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
287 |
288 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
289 | """
290 | grid_size: int of the grid height and width
291 | return:
292 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
293 | """
294 | grid_h = np.arange(grid_size, dtype=np.float32)
295 | grid_w = np.arange(grid_size, dtype=np.float32)
296 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
297 | grid = np.stack(grid, axis=0)
298 |
299 | grid = grid.reshape([2, 1, grid_size, grid_size])
300 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
301 | if cls_token and extra_tokens > 0:
302 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
303 | return pos_embed
304 |
305 |
306 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
307 | assert embed_dim % 2 == 0
308 |
309 | # use half of dimensions to encode grid_h
310 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
311 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
312 |
313 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
314 | return emb
315 |
316 |
317 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
318 | """
319 | embed_dim: output dimension for each position
320 | pos: a list of positions to be encoded: size (M,)
321 | out: (M, D)
322 | """
323 | assert embed_dim % 2 == 0
324 | omega = np.arange(embed_dim // 2, dtype=np.float64)
325 | omega /= embed_dim / 2.
326 | omega = 1. / 10000**omega # (D/2,)
327 |
328 | pos = pos.reshape(-1) # (M,)
329 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
330 |
331 | emb_sin = np.sin(out) # (M, D/2)
332 | emb_cos = np.cos(out) # (M, D/2)
333 |
334 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
335 | return emb
336 |
337 |
338 | #################################################################################
339 | # DiT Configs #
340 | #################################################################################
341 |
342 | def DiT_XL_2(**kwargs):
343 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
344 |
345 | def DiT_XL_4(**kwargs):
346 | return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
347 |
348 | def DiT_XL_8(**kwargs):
349 | return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
350 |
351 | def DiT_L_2(**kwargs):
352 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
353 |
354 | def DiT_L_4(**kwargs):
355 | return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
356 |
357 | def DiT_L_8(**kwargs):
358 | return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
359 |
360 | def DiT_B_2(**kwargs):
361 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
362 |
363 | def DiT_B_4(**kwargs):
364 | return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
365 |
366 | def DiT_B_8(**kwargs):
367 | return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
368 |
369 | def DiT_S_2(**kwargs):
370 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
371 |
372 | def DiT_S_4(**kwargs):
373 | return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
374 |
375 | def DiT_S_8(**kwargs):
376 | return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
377 |
378 | def JPDVT(**kwargs):
379 | return DiT(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)
380 |
381 | def JPDVT_S(**kwargs):
382 | return DiT(depth=12, hidden_size=768, patch_size=32, num_heads=12, **kwargs)
383 |
384 | def JPDVT_T(**kwargs):
385 | return DiT(depth=12, hidden_size=768, patch_size=64, num_heads=12, **kwargs)
386 |
387 | DiT_models = {
388 | 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
389 | 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
390 | 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
391 | 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
392 | 'JPDVT': JPDVT, 'JPDVT-S': JPDVT_S, 'JPDVT-T': JPDVT_T,
393 | }
394 |
--------------------------------------------------------------------------------
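A note on the adaLN-Zero conditioning used by `DiTBlock` and `FinalLayer` above: `initialize_weights` zeroes the last linear layer of every `adaLN_modulation` MLP, so at initialization the predicted shift, scale, and gate are all zero and each block acts as the identity on its residual path. A minimal standalone sketch of that property, reusing the `modulate` formula from this file (the tensor shapes are illustrative):

import torch

def modulate(x, shift, scale):
    # same formula as in models.py
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

N, T, D = 2, 225, 768
x = torch.randn(N, T, D)
shift = torch.zeros(N, D)   # what a zero-initialized adaLN_modulation produces
scale = torch.zeros(N, D)
gate = torch.zeros(N, D)

branch_out = torch.randn(N, T, D)             # stand-in for attn(...) or mlp(...)
y = x + gate.unsqueeze(1) * branch_out        # residual update from DiTBlock.forward
assert torch.equal(modulate(x, shift, scale), x)  # zero shift/scale leaves x unchanged
assert torch.equal(y, x)                          # zero gate disables the branch at init
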
/video_model/diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import pdb
5 | import blobfile as bf
6 | import numpy as np
7 | import torch as th
8 | import torch.distributed as dist
9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
10 | from torch.optim import AdamW
11 | import dist_util, logger
12 | from fp16_util import (
13 | make_master_params,
14 | master_params_to_model_params,
15 | model_grads_to_master_grads,
16 | unflatten_master_params,
17 | zero_grad,
18 | )
19 | from nn import update_ema
20 | from resample import LossAwareSampler, UniformSampler
21 | from positional_encodings.torch_encodings import PositionalEncoding1D
22 |
23 | # For ImageNet experiments, this was a good default value.
24 | # We found that the lg_loss_scale quickly climbed to
25 | # 20-21 within the first ~1K steps of training.
26 | INITIAL_LOG_LOSS_SCALE = 20.0
27 | import torch
28 |
29 | class TrainLoop:
30 |
31 | def __init__(
32 | self,
33 | *,
34 | model,
35 | diffusion,
36 | data,
37 | batch_size,
38 | microbatch,
39 | lr,
40 | ema_rate,
41 | log_interval,
42 | save_interval,
43 | resume_checkpoint,
44 | use_fp16=False,
45 | fp16_scale_growth=1e-3,
46 | schedule_sampler=None,
47 | weight_decay=0.0,
48 | lr_anneal_steps=0,
49 | clip=1,
50 | anneal_type=None,
51 | steps_drop=None,
52 | drop=None,
53 | decay=None,
54 | max_num_mask_frames=4, #4
55 | mask_range=None,
56 | uncondition_rate=True,
57 | exclude_conditional=True,
58 | ):
59 |
60 | self.model = model
61 | self.diffusion = diffusion
62 | self.data = data
63 | self.batch_size = batch_size
64 | self.microbatch = microbatch if microbatch > 0 else batch_size
65 | self.accumulation_steps = batch_size / microbatch if microbatch > 0 else 1
66 | self.lr = lr
67 | self.current_lr = lr
68 | self.ema_rate = (
69 | [ema_rate]
70 | if isinstance(ema_rate, float)
71 | else [float(x) for x in ema_rate.split(",")]
72 | )
73 | self.log_interval = log_interval
74 | self.save_interval = save_interval
75 | self.resume_checkpoint = resume_checkpoint
76 | self.use_fp16 = use_fp16
77 | self.fp16_scale_growth = fp16_scale_growth
78 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
79 | self.weight_decay = weight_decay
80 | self.anneal_type = anneal_type
81 | if self.anneal_type == 'linear':
82 | assert lr_anneal_steps != 0
83 | self.lr_anneal_steps = lr_anneal_steps
84 | if self.anneal_type == 'step':
85 | assert steps_drop != 0
86 | assert drop != 0
87 | self.steps_drop = steps_drop
88 | self.drop = drop
89 | if self.anneal_type == 'time_based':
90 | assert decay != 0
91 | self.decay = decay
92 |
93 | self.clip = clip
94 | self.max_num_mask_frames = max_num_mask_frames
95 | self.mask_range = mask_range
96 | self.uncondition_rate =uncondition_rate
97 | self.exclude_conditional = exclude_conditional
98 |
99 | self.step = 0
100 | self.resume_step = 0
101 | self.global_batch = self.batch_size * dist.get_world_size()
102 | logger.log(f"global batch size = {self.global_batch}")
103 |
104 | self.model_params = list(self.model.parameters())
105 | self.master_params = self.model_params
106 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
107 | self.sync_cuda = th.cuda.is_available()
108 |
109 | self._load_and_sync_parameters()
110 | if self.use_fp16:
111 | self._setup_fp16()
112 |
113 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
114 | if self.resume_step:
115 | self._load_optimizer_state()
116 | # Model was resumed, either due to a restart or a checkpoint
117 | # being specified at the command line.
118 | self.ema_params = [
119 | self._load_ema_parameters(rate) for rate in self.ema_rate
120 | ]
121 | else:
122 | self.ema_params = [
123 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
124 | ]
125 |
126 | if th.cuda.is_available():
127 | logger.log(f"world_size: {dist.get_world_size()}")
128 | self.use_ddp = True
129 | self.ddp_model = DDP(
130 | self.model,
131 | device_ids=[dist_util.dev()],
132 | output_device=dist_util.dev(),
133 | # device_ids=[device],
134 | # output_device=device,
135 | broadcast_buffers=False,
136 | bucket_cap_mb=128,
137 | find_unused_parameters=False,
138 | )
139 | else:
140 | if dist.get_world_size() > 1:
141 | logger.warn(
142 | "Distributed training requires CUDA. "
143 | "Gradients will not be synchronized properly!"
144 | )
145 | self.use_ddp = False
146 | self.ddp_model = self.model
147 |
148 | def _load_and_sync_parameters(self):
149 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
150 |
151 | if resume_checkpoint:
152 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
153 |
154 | # if dist.get_rank() == 0:
155 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
156 | self.model.load_state_dict(
157 | dist_util.load_state_dict(
158 | resume_checkpoint, map_location=dist_util.dev()
159 | )
160 | )
161 |
162 |
163 | dist_util.sync_params(self.model.parameters())
164 |
165 | def _load_ema_parameters(self, rate):
166 | ema_params = copy.deepcopy(self.master_params)
167 |
168 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
169 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
170 | if ema_checkpoint:
171 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
172 | state_dict = dist_util.load_state_dict(
173 | ema_checkpoint, map_location=dist_util.dev()
174 | )
175 | ema_params = self._state_dict_to_master_params(state_dict)
176 |
177 | dist_util.sync_params(ema_params)
178 | return ema_params
179 |
180 | def _load_optimizer_state(self):
181 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
182 | opt_checkpoint = bf.join(
183 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
184 | )
185 | if bf.exists(opt_checkpoint):
186 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
187 | state_dict = dist_util.load_opt_state_dict(
188 | opt_checkpoint, map_location=dist_util.dev()
189 | )
190 |
191 | self.opt.load_state_dict(state_dict)
192 |
193 | def _setup_fp16(self):
194 | self.master_params = make_master_params(self.model_params)
195 | self.model.convert_to_fp16()
196 |
197 | def run_loop(self):
198 |
199 | while (
200 |             self.current_lr  # run until the annealed learning rate decays to zero (indefinitely if no annealing is configured)
201 | ):
202 | batch,_= next(self.data)
203 | min_val = torch.min(batch) # normalization for the nturgbd skeleton dataset
204 | max_val = torch.max(batch)
205 | range_val = max_val - min_val
206 | batch = (batch - min_val) / range_val
207 | batch = batch * 2 - 1
208 |
209 | # min_val = torch.min(cond_batch) # normalization for the nturgbd skeleton dataset
210 | # max_val = torch.max(cond_batch)
211 | # range_val = max_val - min_val
212 | # cond_batch = (cond_batch - min_val) / range_val
213 | # cond_batch = cond_batch * 2 - 1
214 |
215 | # if(batch.isnan().any()):
216 | # breakpoint()
217 |
218 | # breakpoint()
219 |
220 | self.run_step(batch)
221 | if self.step % self.log_interval == 0:
222 | logger.dumpkvs()
223 | if self.step % self.save_interval == 0:
224 | self.save()
225 | # Run for a finite amount of time in integration tests. Does access an environment variable
226 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
227 | return
228 | self.step += 1
229 | # Save the last checkpoint if it wasn't already saved.
230 | if (self.step - 1) % self.save_interval != 0:
231 | self.save()
232 |
233 | def run_step(self, batch):
234 | self.forward_backward(batch)
235 | if self.clip:
236 | th.nn.utils.clip_grad_norm_(self.ddp_model.parameters(), self.clip)
237 | if self.use_fp16:
238 | self.optimize_fp16()
239 | else:
240 | self.optimize_normal()
241 | self.log_step()
242 |
243 | def forward_backward(self, batch):
244 | zero_grad(self.model_params)
245 | for i in range(0, batch.shape[0], self.microbatch):
246 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
247 | # breakpoint()
248 | # [B, C, T, W, H]
249 | """
250 | uncomment for CLEVRER
251 | """
252 | micro=micro.permute(0,2,1,3,4)
253 |
254 | # micro=micro.permute(0,4,1,2,3) # uncomment for something-something and nturgbd skeletondataset
255 | # micro = (micro - micro.min()) / (micro.max() - micro.min())
256 |
257 | # breakpoint()
258 |
259 |             idx = torch.randperm(micro.shape[2]) # random permutation of the frame indices (frame 0 is moved back to the front below)
260 | idx=idx[idx!=0]
261 | idx=torch.cat((torch.tensor([0]),idx))
262 | # breakpoint()
263 |
264 | micro_cond= micro[:,:,idx,:,:] # shuffled video
265 | micro = micro_cond.clone()
266 | # print('training mode')
267 |             # breakpoint()
268 |
269 | # code for removing any frame
270 | # number of frames to remove
271 | num_frames_to_remove = 6
272 | # #
273 | # # # create an array of indices for all frames
274 | frame_indices = np.arange(micro_cond.shape[2])
275 | #
276 | # # randomly select the indices of the frames to remove for all batches
277 | remove_indices = np.random.choice(frame_indices, num_frames_to_remove, replace=False)
278 | remove_indices = remove_indices[remove_indices!=0]
279 | #
280 | # # create an array of zeros with the same shape as the video frames
281 | dummy_frames = np.zeros(
282 | (micro_cond.shape[0], micro_cond.shape[1], 1, micro_cond.shape[3], micro_cond.shape[4]))
283 |             dummy_frames = torch.from_numpy(dummy_frames).to(micro_cond)  # match the dtype and device of the video batch
284 | #
285 | # # replace the selected frames with the dummy frames for all batches
286 | # breakpoint()
287 | micro_cond[:, :, remove_indices, :, :] = dummy_frames
288 |
289 |
290 | p_enc_1d_model = PositionalEncoding1D(16)
291 | time_emb = p_enc_1d_model(torch.rand(micro.shape[0], micro.shape[2], 16))
292 | time_emb = time_emb[:,idx,:]
293 | last_batch = (i + self.microbatch) >= batch.shape[0]
294 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
295 | # breakpoint()
296 | mask = np.ones(micro.shape[2])
297 | mask[remove_indices] = 0
298 | mask = torch.tensor(mask)
299 | compute_losses = functools.partial(
300 | self.diffusion.training_losses_two_dis,
301 | self.ddp_model,
302 | micro,
303 | t,
304 | condition_data=micro_cond,#model_kwargs
305 | time_emb=time_emb.to(dist_util.dev()),
306 | # mask=mask.to(dist_util.dev()),
307 | max_num_mask_frames=self.max_num_mask_frames,
308 | mask_range=self.mask_range,
309 | uncondition_rate=self.uncondition_rate,
310 | exclude_conditional=self.exclude_conditional,
311 | )
312 | if last_batch or not self.use_ddp:
313 | losses = compute_losses()
314 | else:
315 | with self.ddp_model.no_sync():
316 | losses = compute_losses()
317 | if isinstance(self.schedule_sampler, LossAwareSampler):
318 | self.schedule_sampler.update_with_local_losses(
319 | t, losses["loss"].detach()
320 | )
321 | loss = (losses["loss"] * weights).mean()
322 | log_loss_dict(
323 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
324 | )
325 | loss = loss / self.accumulation_steps
326 | if self.use_fp16:
327 | loss_scale = 2 ** self.lg_loss_scale
328 | (loss * loss_scale).backward()
329 | else:
330 | loss.backward()
331 |
332 | def optimize_fp16(self):
333 | if any(not th.isfinite(p.grad).all() for p in self.model_params):
334 | self.lg_loss_scale -= 1
335 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
336 | return
337 |
338 | model_grads_to_master_grads(self.model_params, self.master_params)
339 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
340 | self._log_grad_norm()
341 | self._anneal_lr()
342 | self.opt.step()
343 | for rate, params in zip(self.ema_rate, self.ema_params):
344 | update_ema(params, self.master_params, rate=rate)
345 | master_params_to_model_params(self.model_params, self.master_params)
346 | self.lg_loss_scale += self.fp16_scale_growth
347 |
348 | def optimize_normal(self):
349 | self._log_grad_norm()
350 | self._anneal_lr()
351 | self.opt.step()
352 | for rate, params in zip(self.ema_rate, self.ema_params):
353 | update_ema(params, self.master_params, rate=rate)
354 |
355 | def _log_grad_norm(self):
356 | sqsum = 0.0
357 | for p in self.master_params:
358 | sqsum += (p.grad ** 2).sum().item()
359 | logger.logkv_mean("grad_norm", np.sqrt(sqsum))
360 |
361 | def _anneal_lr(self):
362 | if self.anneal_type is None:
363 | return
364 | if self.anneal_type == "linear":
365 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
366 | lr = self.lr * (1 - frac_done)
367 | elif self.anneal_type == "step":
368 | lr = self.lr * self.drop**(np.floor((self.step + self.resume_step)/self.steps_drop))
369 | elif self.anneal_type == "time_based":
370 | lr = self.lr / (1 + self.decay * (self.step + self.resume_step))
371 | else:
372 | raise ValueError(f"unsupported anneal type: {self.anneal_type}")
373 | for param_group in self.opt.param_groups:
374 | param_group["lr"] = lr
375 | self.current_lr = lr
376 |
377 | def log_step(self):
378 | logger.logkv("step", self.step + self.resume_step)
379 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
380 | if self.use_fp16:
381 | logger.logkv("lg_loss_scale", self.lg_loss_scale)
382 |
383 | def save(self):
384 | def save_checkpoint(rate, params):
385 | state_dict = self._master_params_to_state_dict(params)
386 | if dist.get_rank() == 0:
387 | logger.log(f"saving model {rate}...")
388 | if not rate:
389 | filename = f"model{(self.step+self.resume_step):06d}.pt"
390 | else:
391 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
392 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
393 | th.save(state_dict, f)
394 |
395 | save_checkpoint(0, self.master_params)
396 | for rate, params in zip(self.ema_rate, self.ema_params):
397 | save_checkpoint(rate, params)
398 |
399 | if dist.get_rank() == 0:
400 | with bf.BlobFile(
401 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
402 | "wb",
403 | ) as f:
404 | th.save(self.opt.state_dict(), f)
405 |
406 | dist.barrier()
407 |
408 | def _master_params_to_state_dict(self, master_params):
409 | if self.use_fp16:
410 | master_params = unflatten_master_params(
411 | list(self.model.parameters()), master_params
412 | )
413 | state_dict = self.model.state_dict()
414 | for i, (name, _value) in enumerate(self.model.named_parameters()):
415 | assert name in state_dict
416 | state_dict[name] = master_params[i]
417 | return state_dict
418 |
419 | def _state_dict_to_master_params(self, state_dict):
420 | params = [state_dict[name] for name, _ in self.model.named_parameters()]
421 | if self.use_fp16:
422 | return make_master_params(params)
423 | else:
424 | return params
425 |
426 |
427 | def parse_resume_step_from_filename(filename):
428 | """
429 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
430 | checkpoint's number of steps.
431 | """
432 | split = filename.split("model")
433 | if len(split) < 2:
434 | return 0
435 | split1 = split[-1].split(".")[0]
436 | try:
437 | return int(split1)
438 | except ValueError:
439 | return 0
440 |
441 |
442 | def get_blob_logdir():
443 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
444 |
445 |
446 | def find_resume_checkpoint():
447 | # On your infrastructure, you may want to override this to automatically
448 | # discover the latest checkpoint on your blob storage, etc.
449 | return None
450 |
451 |
452 | def find_ema_checkpoint(main_checkpoint, step, rate):
453 | if main_checkpoint is None:
454 | return None
455 | filename = f"ema_{rate}_{(step):06d}.pt"
456 | path = bf.join(bf.dirname(main_checkpoint), filename)
457 | if bf.exists(path):
458 | return path
459 | return None
460 |
461 |
462 | def log_loss_dict(diffusion, ts, losses):
463 | for key, values in losses.items():
464 | logger.logkv_mean(key, values.mean().item())
465 | # Log the quantiles (four quartiles, in particular).
466 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
467 | quartile = int(4 * sub_t / diffusion.num_timesteps)
468 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
469 |
--------------------------------------------------------------------------------
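A note on `_anneal_lr` above: it implements three schedules selected by `anneal_type` (linear decay to zero over `lr_anneal_steps`, stepwise decay by a factor `drop` every `steps_drop` steps, and time-based 1 / (1 + decay * step) decay). A small sketch of the same formulas as plain functions; the base learning rate and hyperparameter values below are illustrative, not defaults taken from this repo:

import numpy as np

def linear_lr(base_lr, step, lr_anneal_steps):
    return base_lr * (1 - step / lr_anneal_steps)

def step_lr(base_lr, step, drop, steps_drop):
    return base_lr * drop ** np.floor(step / steps_drop)

def time_based_lr(base_lr, step, decay):
    return base_lr / (1 + decay * step)

base_lr = 1e-4
print(linear_lr(base_lr, step=50_000, lr_anneal_steps=100_000))    # 5e-05, halfway through the schedule
print(step_lr(base_lr, step=30_000, drop=0.5, steps_drop=10_000))  # ~1.25e-05 after three drops
print(time_based_lr(base_lr, step=10_000, decay=1e-4))             # ~5e-05
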
/video_model/diffusion/make_a_video.py:
--------------------------------------------------------------------------------
1 | import math
2 | import functools
3 | from operator import mul
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F  # used by GEGLU below
8 | from einops import rearrange, repeat, pack, unpack
9 | from einops.layers.torch import Rearrange
10 | from nn import (
11 | SiLU,
12 | conv_nd,
13 | linear,
14 | avg_pool_nd,
15 | zero_module,
16 | normalization,
17 | timestep_embedding,
18 | checkpoint,
19 | )
20 |
21 | # helper functions
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 |
27 | def default(val, d):
28 | return val if exists(val) else d
29 |
30 |
31 | def mul_reduce(tup):
32 | return functools.reduce(mul, tup)
33 |
34 |
35 | def divisible_by(numer, denom):
36 | return (numer % denom) == 0
37 |
38 |
39 | mlist = nn.ModuleList
40 |
41 |
42 | # for time conditioning
43 |
44 | class SinusoidalPosEmb(nn.Module):
45 | def __init__(self, dim, theta=10000):
46 | super().__init__()
47 | self.theta = theta
48 | self.dim = dim
49 |
50 | def forward(self, x):
51 | dtype, device = x.dtype, x.device
52 | assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'
53 |
54 | half_dim = self.dim // 2
55 | emb = math.log(self.theta) / (half_dim - 1)
56 | emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb)
57 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
58 | return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype)
59 |
60 |
61 | # layernorm 3d
62 |
63 | class LayerNorm(nn.Module):
64 | def __init__(self, dim):
65 | super().__init__()
66 | self.g = nn.Parameter(torch.ones(dim))
67 |
68 | def forward(self, x):
69 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
70 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
71 | mean = torch.mean(x, dim=1, keepdim=True)
72 | return (x - mean) * var.clamp(min=eps).rsqrt() * self.g
73 |
74 |
75 | # feedforward
76 |
77 | class GEGLU(nn.Module):
78 | def forward(self, x):
79 | x, gate = x.chunk(2, dim=-1)
80 | return x * F.gelu(gate)
81 |
82 |
83 | def FeedForward(dim, mult=4):
84 | inner_dim = int(dim * mult * 2 / 3)
85 | return nn.Sequential(
86 |         nn.Linear(dim, inner_dim * 2, bias=False),  # GEGLU halves the last dimension, so project to 2 * inner_dim
87 |         GEGLU(),
88 |         nn.Linear(inner_dim, dim, bias=False)
89 | )
90 |
91 |
92 | # best relative positional encoding
93 |
94 | class ContinuousPositionBias(nn.Module):
95 | """ from https://arxiv.org/abs/2111.09883 """
96 |
97 | def __init__(
98 | self,
99 | *,
100 | dim,
101 | heads,
102 | num_dims=1,
103 | layers=2,
104 | log_dist=True,
105 | cache_rel_pos=False
106 | ):
107 | super().__init__()
108 | self.num_dims = num_dims
109 | self.log_dist = log_dist
110 |
111 | self.net = nn.ModuleList([])
112 | self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
113 |
114 | for _ in range(layers - 1):
115 | self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
116 |
117 | self.net.append(nn.Linear(dim, heads))
118 |
119 | self.cache_rel_pos = cache_rel_pos
120 | self.register_buffer('rel_pos', None, persistent=False)
121 |
122 | @property
123 | def device(self):
124 | return next(self.parameters()).device
125 |
126 | def forward(self, *dimensions):
127 | device = self.device
128 |
129 | if not exists(self.rel_pos) or not self.cache_rel_pos:
130 | positions = [torch.arange(d, device=device) for d in dimensions]
131 | grid = torch.stack(torch.meshgrid(*positions, indexing='ij'))
132 | grid = rearrange(grid, 'c ... -> (...) c')
133 | rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
134 |
135 | if self.log_dist:
136 | rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
137 |
138 | self.register_buffer('rel_pos', rel_pos, persistent=False)
139 |
140 | rel_pos = self.rel_pos.float()
141 |
142 | for layer in self.net:
143 | rel_pos = layer(rel_pos)
144 |
145 | return rearrange(rel_pos, 'i j h -> h i j')
146 |
147 |
148 | # helper classes
149 |
150 | class Attention(nn.Module):
151 | def __init__(
152 | self,
153 | dim,
154 | dim_head=64,
155 | heads=8
156 | ):
157 | super().__init__()
158 | self.heads = heads
159 | self.scale = dim_head ** -0.5
160 | inner_dim = dim_head * heads
161 |
162 | self.norm = LayerNorm(dim)
163 |
164 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
165 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
166 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
167 |
168 | nn.init.zeros_(self.to_out.weight.data) # identity with skip connection
169 |
170 | def forward(
171 | self,
172 | x,
173 | rel_pos_bias=None
174 | ):
175 | x = self.norm(x)
176 |
177 | q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)
178 |
179 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
180 |
181 | q = q * self.scale
182 |
183 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
184 |
185 | if exists(rel_pos_bias):
186 | sim = sim + rel_pos_bias
187 |
188 | attn = sim.softmax(dim=-1)
189 |
190 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
191 |
192 | out = rearrange(out, 'b h n d -> b n (h d)')
193 | return self.to_out(out)
194 |
195 |
196 | # main contribution - pseudo 3d conv
197 |
198 | class PseudoConv3d(nn.Module):
199 | def __init__(
200 | self,
201 | dim,
202 | dim_out=None,
203 | kernel_size=3,
204 | *,
205 | temporal_kernel_size=None,
206 | **kwargs
207 | ):
208 | super().__init__()
209 | dim_out = default(dim_out, dim)
210 | temporal_kernel_size = default(temporal_kernel_size, kernel_size)
211 |
212 | self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2)
213 | self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size=temporal_kernel_size,
214 | padding=temporal_kernel_size // 2) if kernel_size > 1 else None
215 |
216 | if exists(self.temporal_conv):
217 | nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
218 | nn.init.zeros_(self.temporal_conv.bias.data)
219 |
220 | def forward(
221 | self,
222 | x,
223 | enable_time=True
224 | ):
225 | b, c, *_, h, w = x.shape
226 |
227 | is_video = x.ndim == 5
228 | enable_time &= is_video
229 |
230 | if is_video:
231 | x = rearrange(x, 'b c f h w -> (b f) c h w')
232 |
233 | x = self.spatial_conv(x)
234 |
235 | if is_video:
236 | x = rearrange(x, '(b f) c h w -> b c f h w', b=b)
237 |
238 | if not enable_time or not exists(self.temporal_conv):
239 | return x
240 |
241 | x = rearrange(x, 'b c f h w -> (b h w) c f')
242 |
243 | x = self.temporal_conv(x)
244 |
245 | x = rearrange(x, '(b h w) c f -> b c f h w', h=h, w=w)
246 |
247 | return x
248 |
249 |
250 | # factorized spatial temporal attention from Ho et al.
251 | # todo - take care of relative positional biases + rotary embeddings
252 |
253 | class SpatioTemporalAttention(nn.Module):
254 | def __init__(
255 | self,
256 | dim,
257 | *,
258 | dim_head=64,
259 | heads=8
260 | ):
261 | super().__init__()
262 | self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
263 | self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2)
264 |
265 | self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
266 | self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1)
267 |
268 | def forward(
269 | self,
270 | x,
271 | enable_time=True
272 | ):
273 | b, c, *_, h, w = x.shape
274 | is_video = x.ndim == 5
275 | enable_time &= is_video
276 |
277 | if is_video:
278 | x = rearrange(x, 'b c f h w -> (b f) (h w) c')
279 | else:
280 | x = rearrange(x, 'b c h w -> b (h w) c')
281 |
282 | space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
283 |
284 | x = self.spatial_attn(x, rel_pos_bias=space_rel_pos_bias) + x
285 |
286 | if is_video:
287 | x = rearrange(x, '(b f) (h w) c -> b c f h w', b=b, h=h, w=w)
288 | else:
289 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
290 |
291 | if not enable_time:
292 | return x
293 |
294 | x = rearrange(x, 'b c f h w -> (b h w) f c')
295 |
296 | time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
297 |
298 | x = self.temporal_attn(x, rel_pos_bias=time_rel_pos_bias) + x
299 |
300 | x = rearrange(x, '(b h w) f c -> b c f h w', w=w, h=h)
301 |
302 | return x
303 |
304 |
305 | # resnet block
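# two-block residual unit built from PseudoConv3d, with an optional FiLM-style scale/shift
# computed from the diffusion timestep embedding and a 1x1 PseudoConv3d on the residual path
# when the channel count changes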
306 |
307 | class Block(nn.Module):
308 | def __init__(
309 | self,
310 | dim,
311 | dim_out,
312 | kernel_size=3,
313 | temporal_kernel_size=None,
314 | groups=8
315 | ):
316 | super().__init__()
317 |         self.project = PseudoConv3d(dim, dim_out, kernel_size, temporal_kernel_size=temporal_kernel_size)
318 | self.norm = nn.GroupNorm(groups, dim_out)
319 | self.act = nn.SiLU()
320 |
321 | def forward(
322 | self,
323 | x,
324 | scale_shift=None,
325 | enable_time=False
326 | ):
327 | x = self.project(x, enable_time=enable_time)
328 | x = self.norm(x)
329 |
330 | if exists(scale_shift):
331 | scale, shift = scale_shift
332 | x = x * (scale + 1) + shift
333 |
334 | return self.act(x)
335 |
336 |
337 | class ResnetBlock(nn.Module):
338 | def __init__(
339 | self,
340 | dim,
341 | dim_out,
342 | *,
343 | timestep_cond_dim=None,
344 | groups=8
345 | ):
346 | super().__init__()
347 |
348 | self.timestep_mlp = None
349 |
350 | if exists(timestep_cond_dim):
351 | self.timestep_mlp = nn.Sequential(
352 | nn.SiLU(),
353 | nn.Linear(timestep_cond_dim, dim_out * 2)
354 | )
355 |
356 | self.block1 = Block(dim, dim_out, groups=groups)
357 | self.block2 = Block(dim_out, dim_out, groups=groups)
358 | self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
359 |
360 | def forward(
361 | self,
362 | x,
363 | timestep_emb=None,
364 | enable_time=True
365 | ):
366 | assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))
367 |
368 | scale_shift = None
369 |
370 | if exists(self.timestep_mlp) and exists(timestep_emb):
371 | time_emb = self.timestep_mlp(timestep_emb)
372 | to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
373 | time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
374 | scale_shift = time_emb.chunk(2, dim=1)
375 |
376 | h = self.block1(x, scale_shift=scale_shift, enable_time=enable_time)
377 |
378 | h = self.block2(h, enable_time=enable_time)
379 |
380 | return h + self.res_conv(x)
381 |
382 |
383 | # pixelshuffle upsamples and downsamples
384 | # where time dimension can be configured
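# down/upsampling uses pixel-shuffle style rearranges (factor 2 per spatial side, optionally
# factor 2 in frames) followed by 1x1 convolutions; the upsample convs are initialized so that
# every upsampled position starts from the same learned projection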
385 |
386 | class Downsample(nn.Module):
387 | def __init__(
388 | self,
389 | dim,
390 | downsample_space=True,
391 | downsample_time=False,
392 | nonlin=False
393 | ):
394 | super().__init__()
395 | assert downsample_space or downsample_time
396 |
397 | self.down_space = nn.Sequential(
398 | Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
399 | nn.Conv2d(dim * 4, dim, 1, bias=False),
400 | nn.SiLU() if nonlin else nn.Identity()
401 | ) if downsample_space else None
402 |
403 | self.down_time = nn.Sequential(
404 | Rearrange('b c (f p) h w -> b (c p) f h w', p=2),
405 | nn.Conv3d(dim * 2, dim, 1, bias=False),
406 | nn.SiLU() if nonlin else nn.Identity()
407 | ) if downsample_time else None
408 |
409 | def forward(
410 | self,
411 | x,
412 | enable_time=True
413 | ):
414 | is_video = x.ndim == 5
415 |
416 | if is_video:
417 | x = rearrange(x, 'b c f h w -> b f c h w')
418 | x, ps = pack([x], '* c h w')
419 |
420 | if exists(self.down_space):
421 | x = self.down_space(x)
422 |
423 | if is_video:
424 | x, = unpack(x, ps, '* c h w')
425 | x = rearrange(x, 'b f c h w -> b c f h w')
426 |
427 | if not is_video or not exists(self.down_time) or not enable_time:
428 | return x
429 |
430 | x = self.down_time(x)
431 |
432 | return x
433 |
434 |
435 | class Upsample(nn.Module):
436 | def __init__(
437 | self,
438 | dim,
439 | upsample_space=True,
440 | upsample_time=False,
441 | nonlin=False
442 | ):
443 | super().__init__()
444 | assert upsample_space or upsample_time
445 |
446 | self.up_space = nn.Sequential(
447 | nn.Conv2d(dim, dim * 4, 1),
448 | nn.SiLU() if nonlin else nn.Identity(),
449 | Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2)
450 | ) if upsample_space else None
451 |
452 | self.up_time = nn.Sequential(
453 | nn.Conv3d(dim, dim * 2, 1),
454 | nn.SiLU() if nonlin else nn.Identity(),
455 | Rearrange('b (c p) f h w -> b c (f p) h w', p=2)
456 | ) if upsample_time else None
457 |
458 | self.init_()
459 |
460 | def init_(self):
461 | if exists(self.up_space):
462 | self.init_conv_(self.up_space[0], 4)
463 |
464 | if exists(self.up_time):
465 | self.init_conv_(self.up_time[0], 2)
466 |
467 | def init_conv_(self, conv, factor):
468 | o, *remain_dims = conv.weight.shape
469 | conv_weight = torch.empty(o // factor, *remain_dims)
470 | nn.init.kaiming_uniform_(conv_weight)
471 | conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r=factor)
472 |
473 | conv.weight.data.copy_(conv_weight)
474 | nn.init.zeros_(conv.bias.data)
475 |
476 | def forward(
477 | self,
478 | x,
479 | enable_time=True
480 | ):
481 | is_video = x.ndim == 5
482 |
483 | if is_video:
484 | x = rearrange(x, 'b c f h w -> b f c h w')
485 | x, ps = pack([x], '* c h w')
486 |
487 | if exists(self.up_space):
488 | x = self.up_space(x)
489 |
490 | if is_video:
491 | x, = unpack(x, ps, '* c h w')
492 | x = rearrange(x, 'b f c h w -> b c f h w')
493 |
494 | if not is_video or not exists(self.up_time) or not enable_time:
495 | return x
496 |
497 | x = self.up_time(x)
498 |
499 | return x
500 |
501 |
502 | # space time factorized 3d unet
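# U-Net over the factorized space-time blocks above: conv_in, a down path of ResnetBlocks with
# optional SpatioTemporalAttention and Downsample, two mid blocks around a mid attention, and a
# symmetric up path that concatenates skip connections scaled by 2 ** -0.5; with enable_time=False
# (or a 4D input) the same weights act as a plain 2D image U-Net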
503 |
504 | class SpaceTimeUnet(nn.Module):
505 | def __init__(
506 | self,
507 | channels=3,
508 | dim=64,
509 | dim_mult=(1, 2, 4, 8),
510 | self_attns=(False, False, False, True),
511 | temporal_compression=(False, True, True, True),
512 | resnet_block_depths=(2, 2, 2, 2),
513 | attn_dim_head=64,
514 | attn_heads=8,
515 | condition_on_timestep=True
516 | ):
517 | super().__init__()
518 | assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
519 | num_layers = len(dim_mult)
520 |
521 | dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
522 | dim_in_out = zip(dims[:-1], dims[1:])
523 |
524 | # determine the valid multiples of the image size and frames of the video
525 |
526 | self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
527 | self.image_size_multiple = 2 ** num_layers
528 |
529 | # timestep conditioning for DDPM, not to be confused with the time dimension of the video
530 |
531 | self.to_timestep_cond = None
532 | timestep_cond_dim = (dim * 4) if condition_on_timestep else None
533 |
534 | if condition_on_timestep:
535 | self.to_timestep_cond = nn.Sequential(
536 | SinusoidalPosEmb(dim),
537 | nn.Linear(dim, timestep_cond_dim),
538 | nn.SiLU()
539 | )
540 |
541 | # layers
542 |
543 | self.downs = mlist([])
544 | self.ups = mlist([])
545 |
546 | attn_kwargs = dict(
547 | dim_head=attn_dim_head,
548 | heads=attn_heads
549 | )
550 |
551 | mid_dim = dims[-1]
552 |
553 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
554 | self.mid_attn = SpatioTemporalAttention(dim=mid_dim)
555 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
556 |
557 |         for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(
558 |                 range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths
559 |         ):
560 | assert resnet_block_depth >= 1
561 |
562 | self.downs.append(mlist([
563 | ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim),
564 | mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
565 | SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None,
566 | Downsample(dim_out, downsample_time=compress_time)
567 | ]))
568 |
569 | self.ups.append(mlist([
570 | ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim),
571 | mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
572 | SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None,
573 | Upsample(dim_out, upsample_time=compress_time)
574 |
575 | ]))
576 |
577 | self.skip_scale = 2 ** -0.5 # paper shows faster convergence
578 |
579 | self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
580 | self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)
581 |
582 | def forward(
583 | self,
584 | x,
585 | timestep=None,
586 | enable_time=True
587 | ):
588 |
589 | # some asserts
590 |
591 | assert not (exists(self.to_timestep_cond) ^ exists(timestep))
592 | is_video = x.ndim == 5
593 |
594 | if enable_time and is_video:
595 | frames = x.shape[2]
596 |             assert divisible_by(frames, self.frame_multiple), \
597 |                 f'number of frames in the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'
598 | 
599 |         height, width = x.shape[-2:]
600 |         assert divisible_by(height, self.image_size_multiple) and divisible_by(width, self.image_size_multiple), \
601 |             f'height and width of the image or video must be a multiple of {self.image_size_multiple}'
602 |
603 | # main logic
604 |
605 | t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None
606 |
607 | x = self.conv_in(x, enable_time=enable_time)
608 |
609 | hiddens = []
610 |
611 | for init_block, blocks, maybe_attention, downsample in self.downs:
612 | x = init_block(x, t, enable_time=enable_time)
613 |
614 | hiddens.append(x.clone())
615 |
616 | for block in blocks:
617 | x = block(x, enable_time=enable_time)
618 |
619 | if exists(maybe_attention):
620 | x = maybe_attention(x, enable_time=enable_time)
621 |
622 | hiddens.append(x.clone())
623 |
624 | x = downsample(x, enable_time=enable_time)
625 |
626 | x = self.mid_block1(x, t, enable_time=enable_time)
627 | x = self.mid_attn(x, enable_time=enable_time)
628 | x = self.mid_block2(x, t, enable_time=enable_time)
629 |
630 | for init_block, blocks, maybe_attention, upsample in reversed(self.ups):
631 | x = upsample(x, enable_time=enable_time)
632 |
633 | x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
634 |
635 | x = init_block(x, t, enable_time=enable_time)
636 |
637 | x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
638 |
639 | for block in blocks:
640 | x = block(x, enable_time=enable_time)
641 |
642 | if exists(maybe_attention):
643 | x = maybe_attention(x, enable_time=enable_time)
644 |
645 | x = self.conv_out(x, enable_time=enable_time)
646 | return x
647 |
648 | class IdxPromptSTUnet(SpaceTimeUnet):
649 | """
650 | A UNetModel that performs super-resolution.
651 |
652 | Expects an extra kwarg `low_res` to condition on a low-resolution image.
653 | """
654 |
655 |     def __init__(self, channels, *args, **kwargs):
656 |         super().__init__(channels + 1, *args, **kwargs)  # one extra input channel for the index prompt
657 |         self.in_mlp = linear(16, 1024)
658 |         self.out_mlp = linear(1024, 16)
659 |
660 |
661 |     def forward(self, x, timesteps, time_emb):
662 |         # project the per-frame index embedding into an extra spatial channel (the "index prompt")
663 |         b, n = time_emb.shape[0], time_emb.shape[1]
664 |         prompt = self.in_mlp(time_emb).reshape(b, n, 32, 32)
665 |         prompt = prompt.to(x.device).unsqueeze(1)  # [B, 1, T, 32, 32]
666 | 
667 |         # x: [B, C, T, H, W] -> concatenate the prompt as an additional channel
668 |         out = torch.cat([x.float(), prompt.float()], dim=1)
669 |         output = super().forward(out, timesteps.float())
670 | 
671 |         # split the denoised video channels from the predicted index channel
672 |         video_output = output[:, :-1]
673 |         time_emb_output = output[:, -1:]
674 | 
675 |         # flatten the 32x32 index channel per frame and map it back to a 16-dim index embedding
676 |         time_emb_output = time_emb_output.reshape(b, time_emb_output.shape[2], 1024)
677 |         time_emb_output = self.out_mlp(time_emb_output)
678 |         return video_output, time_emb_output
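# usage sketch for IdxPromptSTUnet (shapes inferred from the 32x32 reshape above; the exact
# training-time shapes are an assumption):
#   model = IdxPromptSTUnet(channels=3)
#   x        : [B, 3, T, 32, 32]   noisy video, T divisible by model.frame_multiple
#   time_emb : [B, T, 16]          per-frame index embedding
#   video_out, idx_out = model(x, timesteps, time_emb)   # timesteps: [B] diffusion steps
#   -> video_out: [B, 3, T, 32, 32], idx_out: [B, T, 16]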
--------------------------------------------------------------------------------