├── PVDM ├── .gitignore ├── assets │ ├── sky_long.gif │ └── ucf101_long.gif ├── configs │ ├── autoencoder │ │ ├── base.yaml │ │ └── base_gan.yaml │ └── latent-diffusion │ │ └── base.yaml ├── evals │ ├── __init__.py │ ├── eval.py │ └── fvd │ │ ├── __init__.py │ │ ├── convert_tf_pretrained.py │ │ ├── download.py │ │ ├── fvd.py │ │ └── pytorch_i3d.py ├── exps │ ├── __init__.py │ ├── diffusion.py │ └── first_stage.py ├── losses │ ├── ddpm.py │ ├── diffaugment.py │ ├── lpips.py │ ├── perceptual.py │ └── vgg.pth ├── main.py ├── metric_utils.py ├── models │ ├── __init__.py │ ├── autoencoder │ │ ├── __init__.py │ │ ├── autoencoder_vit.py │ │ └── vit_modules.py │ ├── ddpm │ │ ├── __init__.py │ │ ├── diffusionmodules.py │ │ └── unet.py │ └── ema.py ├── tools │ ├── __init__.py │ ├── data_utils.py │ ├── dataloader.py │ ├── scheduler.py │ ├── trainer.py │ └── video_utils.py └── utils.py ├── README.md ├── images └── teaser.png ├── inv_dyn └── inv_dyn_ft.py └── task_subgoal_consistency ├── __init__.py ├── arguments.py ├── datasets.py ├── networks.py ├── train.py └── utils.py /PVDM/.gitignore: -------------------------------------------------------------------------------- 1 | sftp-config.json 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /PVDM/assets/sky_long.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/assets/sky_long.gif -------------------------------------------------------------------------------- /PVDM/assets/ucf101_long.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/assets/ucf101_long.gif -------------------------------------------------------------------------------- /PVDM/configs/autoencoder/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | resume: False 3 | amp: True 4 | base_learning_rate: 1.0e-4 5 | params: 6 | embed_dim: 4 7 | lossconfig: 8 | params: 9 | disc_start: 100000000 10 | 11 | ddconfig: 12 | double_z: False 13 | channels: 384 14 | resolution: 64 15 | timesteps: 16 16 | skip: 1 17 | in_channels: 3 18 | out_ch: 3 19 | num_res_blocks: 2 20 | attn_resolutions: [] 21 | splits: 1 22 | -------------------------------------------------------------------------------- /PVDM/configs/autoencoder/base_gan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | resume: True 3 | amp: True 4 | base_learning_rate: 1.0e-4 5 | params: 6 | embed_dim: 4 7 | lossconfig: 8 | params: 9 | disc_start: -1 10 | 11 | ddconfig: 12 | double_z: False 13 | channels: 384 14 | resolution: 64 15 | timesteps: 16 16 | skip: 1 17 | in_channels: 3 18 | out_ch: 3 19 | num_res_blocks: 2 20 | attn_resolutions: [] 21 | splits: 1 22 | -------------------------------------------------------------------------------- /PVDM/configs/latent-diffusion/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 # set to target_lr by starting main.py with '--scale_lr False' 3 | cond_model: True 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | 
cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | w: 0. 19 | 20 | scheduler_config: # 10000 warmup steps 21 | warm_up_steps: [10000] 22 | cycle_lengths: [10000000000000] 23 | f_start: [1.e-6] 24 | f_max: [1.] 25 | f_min: [ 1.] 26 | 27 | unet_config: 28 | image_size: 32 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 128 32 | attention_resolutions: [4,2,1] # 32, 16, 8, 4, 33 | num_res_blocks: 2 34 | channel_mult: [1,2,4] # 32, 16, 8, 4, 2 35 | num_heads: 8 36 | use_scale_shift_norm: True 37 | resblock_updown: True 38 | cond_model: True -------------------------------------------------------------------------------- /PVDM/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/evals/__init__.py -------------------------------------------------------------------------------- /PVDM/evals/eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys; sys.path.extend(['.', 'src']) 3 | import numpy as np 4 | import torch 5 | from utils import AverageMeter 6 | from torchvision.utils import save_image, make_grid 7 | from einops import rearrange 8 | from losses.ddpm import DDPM 9 | from torch.cuda.amp import GradScaler, autocast 10 | 11 | from evals.fvd.fvd import get_fvd_logits, frechet_distance 12 | from evals.fvd.download import load_i3d_pretrained 13 | import os 14 | 15 | import torchvision 16 | import PIL 17 | 18 | def save_image_grid(img, fname, drange, grid_size, normalize=True): 19 | if normalize: 20 | lo, hi = drange 21 | img = np.asarray(img, dtype=np.float32) 22 | img = (img - lo) * (255 / (hi - lo)) 23 | img = np.rint(img).clip(0, 255).astype(np.uint8) 24 | 25 | gw, gh = grid_size 26 | _N, C, T, H, W = img.shape 27 | img = img.reshape(gh, gw, C, T, H, W) 28 | img = img.transpose(3, 0, 4, 1, 5, 2) 29 | img = img.reshape(T, gh * H, gw * W, C) 30 | 31 | print (f'Saving Video with {T} frames, img shape {H}, {W}') 32 | 33 | assert C in [3] 34 | 35 | if C == 3: 36 | torchvision.io.write_video(f'{fname[:-3]}mp4', torch.from_numpy(img), fps=16) 37 | imgs = [PIL.Image.fromarray(img[i], 'RGB') for i in range(len(img))] 38 | imgs[0].save(fname, quality=95, save_all=True, append_images=imgs[1:], duration=100, loop=0) 39 | 40 | return img 41 | 42 | def test_psnr(rank, model, loader, it, logger=None): 43 | device = torch.device('cuda', rank) 44 | 45 | losses = dict() 46 | losses['psnr'] = AverageMeter() 47 | check = time.time() 48 | 49 | model.eval() 50 | with torch.no_grad(): 51 | for n, (x, _) in enumerate(loader): 52 | if n > 100: 53 | break 54 | x = x.permute(0, 1, 4, 2, 3) 55 | x = x.contiguous() 56 | 57 | batch_size = x.size(0) 58 | clip_length = x.size(1) 59 | x = x.to(device) / 127.5 - 1 60 | recon, _ = model(rearrange(x, 'b t c h w -> b c t h w')) 61 | 62 | x = x.view(batch_size, -1) 63 | recon = recon.view(batch_size, -1) 64 | 65 | mse = ((x * 0.5 - recon * 0.5) ** 2).mean(dim=-1) 66 | psnr = (-10 * torch.log10(mse)).mean() 67 | 68 | losses['psnr'].update(psnr.item(), batch_size) 69 | 70 | 71 | model.train() 72 | return losses['psnr'].average 73 | 74 | def test_ifvd(rank, model, loader, it, logger=None): 75 | device = torch.device('cuda', rank) 76 | 77 | losses = dict() 78 | losses['fvd'] = AverageMeter() 79 | check = time.time() 80 | 81 | real_embeddings = [] 82 | fake_embeddings = [] 83 | fakes = [] 84 | reals = [] 85 | 86 | model.eval() 87 | i3d = 
load_i3d_pretrained(device) 88 | 89 | with torch.no_grad(): 90 | for n, (real, idx) in enumerate(loader): 91 | if n > 512: 92 | break 93 | real = real.permute(0, 1, 4, 2, 3) 94 | real = real.contiguous() 95 | 96 | batch_size = real.size(0) 97 | clip_length = real.size(1) 98 | real = real.to(device) 99 | fake, _ = model(rearrange(real / 127.5 - 1, 'b t c h w -> b c t h w')) 100 | 101 | real = rearrange(real, 'b t c h w -> b t h w c') # videos 102 | fake = rearrange((fake.clamp(-1,1) + 1) * 127.5, '(b t) c h w -> b t h w c', b=real.size(0)) 103 | 104 | real = real.type(torch.uint8).cpu() 105 | fake = fake.type(torch.uint8) 106 | 107 | real_embeddings.append(get_fvd_logits(real.numpy(), i3d=i3d, device=device)) 108 | fake_embeddings.append(get_fvd_logits(fake.cpu().numpy(), i3d=i3d, device=device)) 109 | if len(fakes) < 16: 110 | reals.append(rearrange(real[0:1], 'b t h w c -> b c t h w')) 111 | fakes.append(rearrange(fake[0:1], 'b t h w c -> b c t h w')) 112 | 113 | model.train() 114 | 115 | reals = torch.cat(reals) 116 | fakes = torch.cat(fakes) 117 | 118 | if rank == 0: 119 | real_vid = save_image_grid(reals.cpu().numpy(), os.path.join(logger.logdir, "real.gif"), drange=[0, 255], grid_size=(4,4)) 120 | fake_vid = save_image_grid(fakes.cpu().numpy(), os.path.join(logger.logdir, f'generated_{it}.gif'), drange=[0, 255], grid_size=(4,4)) 121 | 122 | if it == 0: 123 | real_vid = np.expand_dims(real_vid,0).transpose(0, 1, 4, 2, 3) 124 | logger.video_summary('real', real_vid, it) 125 | 126 | fake_vid = np.expand_dims(fake_vid,0).transpose(0, 1, 4, 2, 3) 127 | logger.video_summary('recon', fake_vid, it) 128 | 129 | real_embeddings = torch.cat(real_embeddings) 130 | fake_embeddings = torch.cat(fake_embeddings) 131 | 132 | fvd = frechet_distance(fake_embeddings.clone().detach(), real_embeddings.clone().detach()) 133 | return fvd.item() 134 | 135 | 136 | def test_fvd_ddpm(rank, ema_model, decoder, loader, it, tokenizer, text_model, uncond_latents, logger=None): 137 | device = torch.device('cuda', rank) 138 | 139 | losses = dict() 140 | losses['fvd'] = AverageMeter() 141 | check = time.time() 142 | 143 | cond_model = ema_model.diffusion_model.cond_model 144 | 145 | diffusion_model = DDPM(ema_model, 146 | channels=ema_model.diffusion_model.in_channels, 147 | image_size=ema_model.diffusion_model.image_size, 148 | sampling_timesteps=1000, 149 | w=0.).to(device) 150 | real_embeddings = [] 151 | pred_embeddings = [] 152 | 153 | reals = [] 154 | predictions = [] 155 | 156 | batch_size = loader.batch_size 157 | 158 | i3d = load_i3d_pretrained(device) 159 | 160 | if cond_model: 161 | with torch.no_grad(): 162 | for n, (x, text) in enumerate(loader): 163 | x = x.to(device) 164 | x = rearrange(x / 127.5 - 1, 'b t h w c -> b c t h w') # videos 165 | 166 | k = min(4, x.size(0)) 167 | if n >= 4: 168 | break 169 | 170 | tokens = torch.LongTensor([tokenizer(text[i].tobytes().decode('ascii'), padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 171 | text_latents = text_model(tokens).last_hidden_state.detach() 172 | text_latents = text_latents[:k] 173 | 174 | real = x[:k,:,:,:,:] 175 | c = x[:k,:,0:1,:,:].repeat(1,1,x.shape[2],1,1) 176 | 177 | with autocast(): 178 | c = decoder.extract(c).detach() 179 | 180 | z = diffusion_model.sample(batch_size=k, cond=c, context=text_latents, uncond_latents=uncond_latents[:k]) 181 | pred = decoder.decode_from_sample(z).clamp(-1,1).cpu() 182 | 183 | pred = (1 + rearrange(pred, '(b t) c h w -> b t h w c', b=k)) * 127.5 184 | pred = 
pred.type(torch.uint8) 185 | pred_embeddings.append(get_fvd_logits(pred.numpy(), i3d=i3d, device=device)) 186 | 187 | real = (1 + rearrange(real, 'b c t h w -> b t h w c')) * 127.5 188 | real = real.type(torch.uint8) 189 | real_embeddings.append(get_fvd_logits(real.cpu().numpy(), i3d=i3d, device=device)) 190 | 191 | if len(predictions) < 4: 192 | reals.append(rearrange(real, 'b t h w c -> b c t h w')) 193 | predictions.append(rearrange(pred, 'b t h w c -> b c t h w')) 194 | 195 | reals = torch.cat(reals) 196 | predictions = torch.cat(predictions) 197 | 198 | real_embeddings = torch.cat(real_embeddings) 199 | pred_embeddings = torch.cat(pred_embeddings) 200 | 201 | if rank == 0: 202 | real_vid = save_image_grid(reals.cpu().numpy(), os.path.join(logger.logdir, f'real_{it}.gif'), drange=[0, 255], grid_size=(k,4)) 203 | real_vid = np.expand_dims(real_vid,0).transpose(0, 1, 4, 2, 3) 204 | pred_vid = save_image_grid(predictions.cpu().numpy(), os.path.join(logger.logdir, f'predicted_{it}.gif'), drange=[0, 255], grid_size=(k,4)) 205 | pred_vid = np.expand_dims(pred_vid,0).transpose(0, 1, 4, 2, 3) 206 | 207 | logger.video_summary('real', real_vid, it) 208 | logger.video_summary('prediction', pred_vid, it) 209 | else: 210 | raise NotImplementedError 211 | 212 | fvd = frechet_distance(pred_embeddings.clone().detach(), real_embeddings.clone().detach()) 213 | return fvd.item() 214 | 215 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/evals/fvd/__init__.py -------------------------------------------------------------------------------- /PVDM/evals/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
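# Editor's note -- worked example of this mapping (not in the original file):
# a TF variable name such as 'Mixed_4b/Branch_1/Conv3d_0a_1x1/conv_3d/w:0'
# becomes 'Mixed_4b.b1a.conv3d.weight' ('Branch_1' plus the '...a...' conv -> 'b1a',
# 'conv_3d' -> 'conv3d', 'w:0' -> 'weight'), and convert_tensor() below then
# permutes the 5-D kernel from TF's (kT, kH, kW, C_in, C_out) layout to
# PyTorch's (C_out, C_in, kT, kH, kW).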
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from tqdm import tqdm 3 | import os 4 | import torch 5 | 6 | from utils import download 7 | 8 | def get_confirm_token(response): 9 | for key, value in response.cookies.items(): 10 | if key.startswith('download_warning'): 11 | return value 12 | return None 13 | 14 | 15 | def save_response_content(response, destination): 16 | CHUNK_SIZE = 8192 17 | 18 | pbar = tqdm(total=0, unit='iB', unit_scale=True) 19 | with open(destination, 'wb') as f: 20 | for chunk in response.iter_content(CHUNK_SIZE): 21 | if chunk: 22 | f.write(chunk) 23 | pbar.update(len(chunk)) 24 | pbar.close() 25 | 26 | 27 | _I3D_PRETRAINED_ID = '1fBNl3TS0LA5FEhZv5nMGJs2_7qQmvTmh' 28 | 29 | def load_i3d_pretrained(device=torch.device('cpu')): 30 | from evals.fvd.pytorch_i3d import InceptionI3d 31 | i3d = InceptionI3d(400, in_channels=3).to(device) 32 | 
filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 33 | i3d.load_state_dict(torch.load(filepath, map_location=device)) 34 | i3d.eval() 35 | return i3d 36 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | def preprocess_single(video, resolution, sequence_length=None): 6 | # video: THWC, {0, ..., 255} 7 | video = video.permute(0, 3, 1, 2).float() / 255. # TCHW 8 | t, c, h, w = video.shape 9 | 10 | # temporal crop 11 | if sequence_length is not None: 12 | assert sequence_length <= t 13 | video = video[:sequence_length] 14 | 15 | # scale shorter side to resolution 16 | scale = resolution / min(h, w) 17 | if h < w: 18 | target_size = (resolution, math.ceil(w * scale)) 19 | else: 20 | target_size = (math.ceil(h * scale), resolution) 21 | video = F.interpolate(video, size=target_size, mode='bilinear', 22 | align_corners=False) 23 | 24 | # center crop 25 | t, c, h, w = video.shape 26 | w_start = (w - resolution) // 2 27 | h_start = (h - resolution) // 2 28 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 29 | video = video.permute(1, 0, 2, 3).contiguous() # CTHW 30 | 31 | video -= 0.5 32 | 33 | return video 34 | 35 | def preprocess(videos, target_resolution=224): 36 | # videos in {0, ..., 255} as np.uint8 array 37 | b, t, h, w, c = videos.shape 38 | videos = torch.from_numpy(videos) 39 | videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) 40 | return videos * 2 # [-0.5, 0.5] -> [-1, 1] 41 | 42 | def get_fvd_logits(videos, i3d, device): 43 | videos = preprocess(videos) 44 | embeddings = get_logits(i3d, videos, device) 45 | return embeddings 46 | 47 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 48 | def _symmetric_matrix_square_root(mat, eps=1e-10): 49 | u, s, v = torch.svd(mat) 50 | si = torch.where(s < eps, s, torch.sqrt(s)) 51 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 52 | 53 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 54 | def trace_sqrt_product(sigma, sigma_v): 55 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 56 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 57 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 58 | 59 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 60 | def cov(m, rowvar=False): 61 | '''Estimate a covariance matrix given data. 62 | 63 | Covariance indicates the level to which two variables vary together. 64 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 65 | then the covariance matrix element `C_{ij}` is the covariance of 66 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 67 | 68 | Args: 69 | m: A 1-D or 2-D array containing multiple variables and observations. 70 | Each row of `m` represents a variable, and each column a single 71 | observation of all those variables. 72 | rowvar: If `rowvar` is True, then each row represents a 73 | variable, with observations in the columns. Otherwise, the 74 | relationship is transposed: each column represents a variable, 75 | while the rows contain observations. 76 | 77 | Returns: 78 | The covariance matrix of the variables. 
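    Example (editor's sketch of the call pattern used by frechet_distance
    below; the shapes are illustrative and not part of the original docstring):

        x = torch.randn(100, 5)        # 100 observations of 5 variables
        sigma = cov(x, rowvar=False)   # -> (5, 5) covariance matrix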
79 | ''' 80 | if m.dim() > 2: 81 | raise ValueError('m has more than 2 dimensions') 82 | if m.dim() < 2: 83 | m = m.view(1, -1) 84 | if not rowvar and m.size(0) != 1: 85 | m = m.t() 86 | 87 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 88 | m -= torch.mean(m, dim=1, keepdim=True) 89 | mt = m.t() # if complex: mt = m.t().conj() 90 | return fact * m.matmul(mt).squeeze() 91 | 92 | 93 | def frechet_distance(x1, x2): 94 | x1 = x1.flatten(start_dim=1) 95 | x2 = x2.flatten(start_dim=1) 96 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 97 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 98 | 99 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 100 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 101 | 102 | mean = torch.sum((m - m_w) ** 2) 103 | fd = trace + mean 104 | return fd 105 | 106 | 107 | def get_logits(i3d, videos, device): 108 | """ 109 | assert videos.shape[0] % 16 == 0 110 | with torch.no_grad(): 111 | logits = [] 112 | for i in range(0, videos.shape[0], 16): 113 | batch = videos[i:i + 16].to(device) 114 | logits.append(i3d(batch)) 115 | logits = torch.cat(logits, dim=0) 116 | return logits 117 | """ 118 | 119 | with torch.no_grad(): 120 | logits = i3d(videos.to(device)) 121 | return logits 122 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/pytorch_i3d.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/piergiaj/pytorch-i3d 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class MaxPool3dSamePadding(nn.MaxPool3d): 8 | 9 | def compute_pad(self, dim, s): 10 | if s % self.stride[dim] == 0: 11 | return max(self.kernel_size[dim] - self.stride[dim], 0) 12 | else: 13 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 14 | 15 | def forward(self, x): 16 | # compute 'same' padding 17 | (batch, channel, t, h, w) = x.size() 18 | out_t = np.ceil(float(t) / float(self.stride[0])) 19 | out_h = np.ceil(float(h) / float(self.stride[1])) 20 | out_w = np.ceil(float(w) / float(self.stride[2])) 21 | pad_t = self.compute_pad(0, t) 22 | pad_h = self.compute_pad(1, h) 23 | pad_w = self.compute_pad(2, w) 24 | 25 | pad_t_f = pad_t // 2 26 | pad_t_b = pad_t - pad_t_f 27 | pad_h_f = pad_h // 2 28 | pad_h_b = pad_h - pad_h_f 29 | pad_w_f = pad_w // 2 30 | pad_w_b = pad_w - pad_w_f 31 | 32 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 33 | x = F.pad(x, pad) 34 | return super(MaxPool3dSamePadding, self).forward(x) 35 | 36 | 37 | class Unit3D(nn.Module): 38 | 39 | def __init__(self, in_channels, 40 | output_channels, 41 | kernel_shape=(1, 1, 1), 42 | stride=(1, 1, 1), 43 | padding=0, 44 | activation_fn=F.relu, 45 | use_batch_norm=True, 46 | use_bias=False, 47 | name='unit_3d'): 48 | 49 | """Initializes Unit3D module.""" 50 | super(Unit3D, self).__init__() 51 | 52 | self._output_channels = output_channels 53 | self._kernel_shape = kernel_shape 54 | self._stride = stride 55 | self._use_batch_norm = use_batch_norm 56 | self._activation_fn = activation_fn 57 | self._use_bias = use_bias 58 | self.name = name 59 | self.padding = padding 60 | 61 | self.conv3d = nn.Conv3d(in_channels=in_channels, 62 | out_channels=self._output_channels, 63 | kernel_size=self._kernel_shape, 64 | stride=self._stride, 65 | padding=0, # we always want padding to be 0 here. 
We will dynamically pad based on input size in forward function 66 | bias=self._use_bias) 67 | 68 | if self._use_batch_norm: 69 | self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) 70 | 71 | def compute_pad(self, dim, s): 72 | if s % self._stride[dim] == 0: 73 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 74 | else: 75 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 76 | 77 | 78 | def forward(self, x): 79 | # compute 'same' padding 80 | (batch, channel, t, h, w) = x.size() 81 | out_t = np.ceil(float(t) / float(self._stride[0])) 82 | out_h = np.ceil(float(h) / float(self._stride[1])) 83 | out_w = np.ceil(float(w) / float(self._stride[2])) 84 | pad_t = self.compute_pad(0, t) 85 | pad_h = self.compute_pad(1, h) 86 | pad_w = self.compute_pad(2, w) 87 | 88 | pad_t_f = pad_t // 2 89 | pad_t_b = pad_t - pad_t_f 90 | pad_h_f = pad_h // 2 91 | pad_h_b = pad_h - pad_h_f 92 | pad_w_f = pad_w // 2 93 | pad_w_b = pad_w - pad_w_f 94 | 95 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 96 | x = F.pad(x, pad) 97 | 98 | x = self.conv3d(x) 99 | if self._use_batch_norm: 100 | x = self.bn(x) 101 | if self._activation_fn is not None: 102 | x = self._activation_fn(x) 103 | return x 104 | 105 | 106 | 107 | class InceptionModule(nn.Module): 108 | def __init__(self, in_channels, out_channels, name): 109 | super(InceptionModule, self).__init__() 110 | 111 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 112 | name=name+'/Branch_0/Conv3d_0a_1x1') 113 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 114 | name=name+'/Branch_1/Conv3d_0a_1x1') 115 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 116 | name=name+'/Branch_1/Conv3d_0b_3x3') 117 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 118 | name=name+'/Branch_2/Conv3d_0a_1x1') 119 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 120 | name=name+'/Branch_2/Conv3d_0b_3x3') 121 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 122 | stride=(1, 1, 1), padding=0) 123 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 124 | name=name+'/Branch_3/Conv3d_0b_1x1') 125 | self.name = name 126 | 127 | def forward(self, x): 128 | b0 = self.b0(x) 129 | b1 = self.b1b(self.b1a(x)) 130 | b2 = self.b2b(self.b2a(x)) 131 | b3 = self.b3b(self.b3a(x)) 132 | return torch.cat([b0,b1,b2,b3], dim=1) 133 | 134 | 135 | class InceptionI3d(nn.Module): 136 | """Inception-v1 I3D architecture. 137 | The model is introduced in: 138 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 139 | Joao Carreira, Andrew Zisserman 140 | https://arxiv.org/pdf/1705.07750v1.pdf. 141 | See also the Inception architecture, introduced in: 142 | Going deeper with convolutions 143 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 144 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 145 | http://arxiv.org/pdf/1409.4842v1.pdf. 146 | """ 147 | 148 | # Endpoints of the model in order. During construction, all the endpoints up 149 | # to a designated `final_endpoint` are returned in a dictionary as the 150 | # second return value. 
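    # Editor's note (usage sketch based on evals/fvd/download.py and fvd.py in
    # this repo, not part of the original file): the FVD code instantiates
    # InceptionI3d(400, in_channels=3), loads 'i3d_pretrained_400.pt', and feeds
    # it clips preprocessed to shape (B, 3, T, 224, 224) with values in [-1, 1];
    # forward() then returns (B, 400) logits (spatially squeezed and averaged
    # over time), which serve as the embeddings passed to frechet_distance().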
151 | VALID_ENDPOINTS = ( 152 | 'Conv3d_1a_7x7', 153 | 'MaxPool3d_2a_3x3', 154 | 'Conv3d_2b_1x1', 155 | 'Conv3d_2c_3x3', 156 | 'MaxPool3d_3a_3x3', 157 | 'Mixed_3b', 158 | 'Mixed_3c', 159 | 'MaxPool3d_4a_3x3', 160 | 'Mixed_4b', 161 | 'Mixed_4c', 162 | 'Mixed_4d', 163 | 'Mixed_4e', 164 | 'Mixed_4f', 165 | 'MaxPool3d_5a_2x2', 166 | 'Mixed_5b', 167 | 'Mixed_5c', 168 | 'Logits', 169 | 'Predictions', 170 | ) 171 | 172 | def __init__(self, num_classes=400, spatial_squeeze=True, 173 | final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): 174 | """Initializes I3D model instance. 175 | Args: 176 | num_classes: The number of outputs in the logit layer (default 400, which 177 | matches the Kinetics dataset). 178 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 179 | before returning (default True). 180 | final_endpoint: The model contains many possible endpoints. 181 | `final_endpoint` specifies the last endpoint for the model to be built 182 | up to. In addition to the output at `final_endpoint`, all the outputs 183 | at endpoints up to `final_endpoint` will also be returned, in a 184 | dictionary. `final_endpoint` must be one of 185 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 186 | name: A string (optional). The name of this module. 187 | Raises: 188 | ValueError: if `final_endpoint` is not recognized. 189 | """ 190 | 191 | if final_endpoint not in self.VALID_ENDPOINTS: 192 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 193 | 194 | super(InceptionI3d, self).__init__() 195 | self._num_classes = num_classes 196 | self._spatial_squeeze = spatial_squeeze 197 | self._final_endpoint = final_endpoint 198 | self.logits = None 199 | 200 | if self._final_endpoint not in self.VALID_ENDPOINTS: 201 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 202 | 203 | self.end_points = {} 204 | end_point = 'Conv3d_1a_7x7' 205 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 206 | stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) 207 | if self._final_endpoint == end_point: return 208 | 209 | end_point = 'MaxPool3d_2a_3x3' 210 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 211 | padding=0) 212 | if self._final_endpoint == end_point: return 213 | 214 | end_point = 'Conv3d_2b_1x1' 215 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 216 | name=name+end_point) 217 | if self._final_endpoint == end_point: return 218 | 219 | end_point = 'Conv3d_2c_3x3' 220 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 221 | name=name+end_point) 222 | if self._final_endpoint == end_point: return 223 | 224 | end_point = 'MaxPool3d_3a_3x3' 225 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 226 | padding=0) 227 | if self._final_endpoint == end_point: return 228 | 229 | end_point = 'Mixed_3b' 230 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) 231 | if self._final_endpoint == end_point: return 232 | 233 | end_point = 'Mixed_3c' 234 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) 235 | if self._final_endpoint == end_point: return 236 | 237 | end_point = 'MaxPool3d_4a_3x3' 238 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), 239 | padding=0) 240 | 
if self._final_endpoint == end_point: return 241 | 242 | end_point = 'Mixed_4b' 243 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) 244 | if self._final_endpoint == end_point: return 245 | 246 | end_point = 'Mixed_4c' 247 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) 248 | if self._final_endpoint == end_point: return 249 | 250 | end_point = 'Mixed_4d' 251 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) 252 | if self._final_endpoint == end_point: return 253 | 254 | end_point = 'Mixed_4e' 255 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) 256 | if self._final_endpoint == end_point: return 257 | 258 | end_point = 'Mixed_4f' 259 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) 260 | if self._final_endpoint == end_point: return 261 | 262 | end_point = 'MaxPool3d_5a_2x2' 263 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), 264 | padding=0) 265 | if self._final_endpoint == end_point: return 266 | 267 | end_point = 'Mixed_5b' 268 | self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) 269 | if self._final_endpoint == end_point: return 270 | 271 | end_point = 'Mixed_5c' 272 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'Logits' 276 | self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], 277 | stride=(1, 1, 1)) 278 | self.dropout = nn.Dropout(dropout_keep_prob) 279 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 280 | kernel_shape=[1, 1, 1], 281 | padding=0, 282 | activation_fn=None, 283 | use_batch_norm=False, 284 | use_bias=True, 285 | name='logits') 286 | 287 | self.build() 288 | 289 | 290 | def replace_logits(self, num_classes): 291 | self._num_classes = num_classes 292 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 293 | kernel_shape=[1, 1, 1], 294 | padding=0, 295 | activation_fn=None, 296 | use_batch_norm=False, 297 | use_bias=True, 298 | name='logits') 299 | 300 | 301 | def build(self): 302 | for k in self.end_points.keys(): 303 | self.add_module(k, self.end_points[k]) 304 | 305 | def forward(self, x): 306 | for end_point in self.VALID_ENDPOINTS: 307 | if end_point in self.end_points: 308 | x = self._modules[end_point](x) # use _modules to work with dataparallel 309 | 310 | x = self.logits(self.dropout(self.avg_pool(x))) 311 | if self._spatial_squeeze: 312 | logits = x.squeeze(3).squeeze(3) 313 | logits = logits.mean(dim=2) 314 | # logits is batch X time X classes, which is what we want to work with 315 | return logits 316 | 317 | 318 | def extract_features(self, x): 319 | for end_point in self.VALID_ENDPOINTS: 320 | if end_point in self.end_points: 321 | x = self._modules[end_point](x) 322 | return self.avg_pool(x) 323 | -------------------------------------------------------------------------------- /PVDM/exps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/exps/__init__.py -------------------------------------------------------------------------------- /PVDM/exps/diffusion.py: 
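# Editor's note on the file below (summary comments, not part of the original
# source): exps/diffusion.py trains the second, latent-diffusion stage. It
# loads a frozen ViTAutoencoder checkpoint (args.first_model) produced by
# exps/first_stage.py, builds UNetModel(**args.unetconfig) wrapped in a
# DiffusionWrapper, trains it with the DDPM criterion and a
# LambdaLinearScheduler via tools.trainer.latentDDPM, and passes a
# dataset-dependent cond_prob (0.9 for 'cliport', 0.2 for 'SKY', 0.3 otherwise)
# to the trainer.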
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | 6 | from tools.trainer import latentDDPM 7 | from tools.dataloader import get_loaders 8 | from tools.scheduler import LambdaLinearScheduler 9 | from models.autoencoder.autoencoder_vit import ViTAutoencoder 10 | from models.ddpm.unet import UNetModel, DiffusionWrapper 11 | from losses.ddpm import DDPM 12 | 13 | import copy 14 | from utils import file_name, Logger, download 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 19 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 20 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 21 | _rank = 0 # Rank of the current process. 22 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 23 | _sync_called = False # Has _sync() been called yet? 24 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 25 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def init_multiprocessing(rank, sync_device): 30 | r"""Initializes `torch_utils.training_stats` for collecting statistics 31 | across multiple processes. 32 | This function must be called after 33 | `torch.distributed.init_process_group()` and before `Collector.update()`. 34 | The call is not necessary if multi-process collection is not needed. 35 | Args: 36 | rank: Rank of the current process. 37 | sync_device: PyTorch device to use for inter-process 38 | communication, or None to disable multi-process 39 | collection. Typically `torch.device('cuda', rank)`. 40 | """ 41 | global _rank, _sync_device 42 | assert not _sync_called 43 | _rank = rank 44 | _sync_device = sync_device 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def diffusion(rank, args): 49 | device = torch.device('cuda', rank) 50 | 51 | temp_dir = './' 52 | # if args.n_gpus > 1: 53 | # init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 54 | # if os.name == 'nt': 55 | # init_method = 'file:///' + init_file.replace('\\', '/') 56 | # torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.n_gpus) 57 | # else: 58 | # init_method = f'file://{init_file}' 59 | # torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.n_gpus) 60 | 61 | # # Init torch_utils. 
62 | # sync_device = torch.device('cuda', rank) if args.n_gpus > 1 else None 63 | # init_multiprocessing(rank=rank, sync_device=sync_device) 64 | 65 | """ ROOT DIRECTORY """ 66 | if rank == 0: 67 | fn = file_name(args) 68 | logger = Logger(fn) 69 | logger.log(args) 70 | logger.log(f'Log path: {logger.logdir}') 71 | rootdir = logger.logdir 72 | else: 73 | logger = None 74 | 75 | if logger is None: 76 | log_ = print 77 | else: 78 | log_ = logger.log 79 | 80 | """ Get Image """ 81 | if rank == 0: 82 | log_(f"Loading dataset {args.data} with resolution {args.res}") 83 | train_loader, test_loader, total_vid = get_loaders(rank, args.data, args.res, args.timesteps, args.skip, args.batch_size, args.n_gpus, args.seed, args.cond_model) 84 | 85 | if args.data == 'cliport': 86 | cond_prob = 0.9 87 | elif args.data == 'SKY': 88 | cond_prob = 0.2 89 | else: 90 | cond_prob = 0.3 91 | 92 | """ Get Model """ 93 | if rank == 0: 94 | log_(f"Generating model") 95 | 96 | torch.cuda.set_device(rank) 97 | first_stage_model = ViTAutoencoder(args.embed_dim, args.ddconfig).to(device) 98 | 99 | if rank == 0: 100 | first_stage_model_ckpt = torch.load(args.first_model) 101 | first_stage_model.load_state_dict(first_stage_model_ckpt) 102 | 103 | unet = UNetModel(**args.unetconfig) 104 | model = DiffusionWrapper(unet).to(device) 105 | 106 | if rank == 0: 107 | torch.save(model.state_dict(), rootdir + f'net_init.pth') 108 | 109 | ema_model = None 110 | if args.n_gpus > 1: 111 | first_stage_model = torch.nn.parallel.DataParallel(first_stage_model) 112 | model = torch.nn.parallel.DataParallel(model) 113 | model = model.module 114 | first_stage_model = first_stage_model.module 115 | 116 | criterion = DDPM(model, channels=args.unetconfig.in_channels, 117 | image_size=args.unetconfig.image_size, 118 | linear_start=args.ddpmconfig.linear_start, 119 | linear_end=args.ddpmconfig.linear_end, 120 | log_every_t=args.ddpmconfig.log_every_t, 121 | w=args.ddpmconfig.w, 122 | ).to(device) 123 | 124 | if args.n_gpus > 1: 125 | criterion = torch.nn.parallel.DataParallel(criterion) 126 | criterion = criterion.module 127 | 128 | if args.scale_lr: 129 | args.lr *= args.batch_size 130 | 131 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr) 132 | lr_scheduler = LambdaLinearScheduler(**args.scheduler) 133 | 134 | latentDDPM(rank, first_stage_model, model, opt, criterion, train_loader, test_loader, lr_scheduler, ema_model, cond_prob, logger) 135 | 136 | if rank == 0: 137 | torch.save(model.state_dict(), rootdir + f'net_meta.pth') 138 | -------------------------------------------------------------------------------- /PVDM/exps/first_stage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | 6 | from tools.trainer import first_stage_train 7 | from tools.dataloader import get_loaders 8 | from models.autoencoder.autoencoder_vit import ViTAutoencoder 9 | from losses.perceptual import LPIPSWithDiscriminator 10 | 11 | from utils import file_name, Logger 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 16 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 17 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 18 | _rank = 0 # Rank of the current process. 19 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 
20 | _sync_called = False # Has _sync() been called yet? 21 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 22 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def init_multiprocessing(rank, sync_device): 27 | r"""Initializes `torch_utils.training_stats` for collecting statistics 28 | across multiple processes. 29 | This function must be called after 30 | `torch.distributed.init_process_group()` and before `Collector.update()`. 31 | The call is not necessary if multi-process collection is not needed. 32 | Args: 33 | rank: Rank of the current process. 34 | sync_device: PyTorch device to use for inter-process 35 | communication, or None to disable multi-process 36 | collection. Typically `torch.device('cuda', rank)`. 37 | """ 38 | global _rank, _sync_device 39 | assert not _sync_called 40 | _rank = rank 41 | _sync_device = sync_device 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def first_stage(rank, args): 46 | device = torch.device('cuda', rank) 47 | 48 | temp_dir = './' 49 | # if args.n_gpus > 1: 50 | # init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 51 | # if os.name == 'nt': 52 | # init_method = 'file:///' + init_file.replace('\\', '/') 53 | # torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.n_gpus) 54 | # else: 55 | # init_method = f'file://{init_file}' 56 | # torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.n_gpus) 57 | 58 | # # Init torch_utils. 59 | # sync_device = torch.device('cuda', rank) if args.n_gpus > 1 else None 60 | # init_multiprocessing(rank=rank, sync_device=sync_device) 61 | 62 | """ ROOT DIRECTORY """ 63 | if rank == 0: 64 | fn = file_name(args) 65 | logger = Logger(fn) 66 | logger.log(args) 67 | logger.log(f'Log path: {logger.logdir}') 68 | rootdir = logger.logdir 69 | else: 70 | logger = None 71 | 72 | if logger is None: 73 | log_ = print 74 | else: 75 | log_ = logger.log 76 | 77 | """ Get Image """ 78 | if rank == 0: 79 | log_(f"Loading dataset {args.data} with resolution {args.res}") 80 | train_loader, test_loader, total_vid = get_loaders(rank, args.data, args.res, args.timesteps, args.skip, args.batch_size, args.n_gpus, args.seed, cond=False) 81 | 82 | """ Get Model """ 83 | if rank == 0: 84 | log_(f"Generating model") 85 | 86 | torch.cuda.set_device(rank) 87 | model = ViTAutoencoder(args.embed_dim, args.ddconfig) 88 | model = model.to(device) 89 | 90 | criterion = LPIPSWithDiscriminator(disc_start = args.lossconfig.params.disc_start, 91 | timesteps = args.ddconfig.timesteps).to(device) 92 | 93 | 94 | opt = torch.optim.AdamW(model.parameters(), 95 | lr=args.lr, 96 | betas=(0.5, 0.9) 97 | ) 98 | 99 | d_opt = torch.optim.AdamW(list(criterion.discriminator_2d.parameters()) + list(criterion.discriminator_3d.parameters()), 100 | lr=args.lr, 101 | betas=(0.5, 0.9)) 102 | 103 | if args.resume and rank == 0: 104 | model_ckpt = torch.load(os.path.join(args.first_stage_folder, 'model_last.pth')) 105 | model.load_state_dict(model_ckpt) 106 | opt_ckpt = torch.load(os.path.join(args.first_stage_folder, 'opt.pth')) 107 | opt.load_state_dict(opt_ckpt) 108 | 109 | del model_ckpt 110 | del opt_ckpt 111 | 112 | if rank == 0: 113 | torch.save(model.state_dict(), rootdir + 
f'net_init.pth') 114 | 115 | if args.n_gpus > 1: 116 | model = torch.nn.parallel.DataParallel(model) 117 | criterion = torch.nn.parallel.DataParallel(criterion) 118 | model = model.module 119 | criterion = criterion.module 120 | 121 | fp = args.amp 122 | first_stage_train(rank, model, opt, d_opt, criterion, train_loader, test_loader, args.first_model, fp, logger) 123 | 124 | if rank == 0: 125 | torch.save(model.state_dict(), rootdir + f'net_meta.pth') 126 | -------------------------------------------------------------------------------- /PVDM/losses/diffaugment.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='color,translation,cutout', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.5): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | 
mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } -------------------------------------------------------------------------------- /PVDM/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | import os 3 | import requests 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | from collections import namedtuple 9 | 10 | from tqdm import tqdm 11 | import hashlib 12 | 13 | 14 | URL_MAP = { 15 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 16 | } 17 | 18 | CKPT_MAP = { 19 | "vgg_lpips": "vgg.pth" 20 | } 21 | 22 | MD5_MAP = { 23 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 24 | } 25 | 26 | def download(url, local_path, chunk_size=1024): 27 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 28 | with requests.get(url, stream=True) as r: 29 | total_size = int(r.headers.get("content-length", 0)) 30 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 31 | with open(local_path, "wb") as f: 32 | for data in r.iter_content(chunk_size=chunk_size): 33 | if data: 34 | f.write(data) 35 | pbar.update(chunk_size) 36 | 37 | 38 | def md5_hash(path): 39 | with open(path, "rb") as f: 40 | content = f.read() 41 | return hashlib.md5(content).hexdigest() 42 | 43 | 44 | def get_ckpt_path(name, root, check=False): 45 | assert name in URL_MAP 46 | path = os.path.join(root, CKPT_MAP[name]) 47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 49 | download(URL_MAP[name], path) 50 | md5 = md5_hash(path) 51 | assert md5 == MD5_MAP[name], md5 52 | return path 53 | 54 | 55 | class LPIPS(nn.Module): 56 | # Learned perceptual metric 57 | def __init__(self, use_dropout=True): 58 | super().__init__() 59 | self.scaling_layer = ScalingLayer() 60 | self.chns = [64, 128, 256, 512, 512] # vg16 features 61 | self.net = vgg16(pretrained=True, requires_grad=False) 62 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 63 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 64 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 65 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 66 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 67 | self.load_from_pretrained() 68 | for param in self.parameters(): 69 | param.requires_grad = False 70 | 71 | def load_from_pretrained(self, name="vgg_lpips"): 72 | ckpt = get_ckpt_path(name, "./losses") 73 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 74 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 75 | 76 | @classmethod 77 | def from_pretrained(cls, name="vgg_lpips"): 78 | if name != "vgg_lpips": 79 | raise NotImplementedError 80 | model = cls() 81 | ckpt = get_ckpt_path(name) 82 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 83 | return model 84 | 85 | def forward(self, input, target): 86 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 87 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | lins = [self.lin0, 
self.lin1, self.lin2, self.lin3, self.lin4] 90 | for kk in range(len(self.chns)): 91 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 93 | 94 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 95 | val = res[0] 96 | for l in range(1, len(self.chns)): 97 | val += res[l] 98 | return val 99 | 100 | 101 | class ScalingLayer(nn.Module): 102 | def __init__(self): 103 | super(ScalingLayer, self).__init__() 104 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 105 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 106 | 107 | def forward(self, inp): 108 | return (inp - self.shift) / self.scale 109 | 110 | 111 | class NetLinLayer(nn.Module): 112 | """ A single linear layer which does a 1x1 conv """ 113 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 114 | super(NetLinLayer, self).__init__() 115 | layers = [nn.Dropout(), ] if (use_dropout) else [] 116 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 117 | self.model = nn.Sequential(*layers) 118 | 119 | 120 | class vgg16(torch.nn.Module): 121 | def __init__(self, requires_grad=False, pretrained=True): 122 | super(vgg16, self).__init__() 123 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 124 | self.slice1 = torch.nn.Sequential() 125 | self.slice2 = torch.nn.Sequential() 126 | self.slice3 = torch.nn.Sequential() 127 | self.slice4 = torch.nn.Sequential() 128 | self.slice5 = torch.nn.Sequential() 129 | self.N_slices = 5 130 | for x in range(4): 131 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 132 | for x in range(4, 9): 133 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(9, 16): 135 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(16, 23): 137 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(23, 30): 139 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 140 | if not requires_grad: 141 | for param in self.parameters(): 142 | param.requires_grad = False 143 | 144 | def forward(self, X): 145 | h = self.slice1(X) 146 | h_relu1_2 = h 147 | h = self.slice2(h) 148 | h_relu2_2 = h 149 | h = self.slice3(h) 150 | h_relu3_3 = h 151 | h = self.slice4(h) 152 | h_relu4_3 = h 153 | h = self.slice5(h) 154 | h_relu5_3 = h 155 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 156 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 157 | return out 158 | 159 | 160 | def normalize_tensor(x,eps=1e-10): 161 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 162 | return x/(norm_factor+eps) 163 | 164 | 165 | def spatial_average(x, keepdim=True): 166 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /PVDM/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from einops import repeat, rearrange 6 | 7 | import functools 8 | from losses.lpips import LPIPS 9 | from utils import make_pairs 10 | from losses.diffaugment import DiffAugment 11 | 12 | class DummyLoss(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | 17 | def adopt_weight(global_step, threshold=0, 
value=0.): 18 | weight = 1.0 19 | if global_step < threshold: 20 | weight = value 21 | return weight 22 | 23 | 24 | def hinge_d_loss(logits_real, logits_fake): 25 | loss_real = torch.mean(F.relu(1. - logits_real)) 26 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 27 | d_loss = 0.5 * (loss_real + loss_fake) 28 | return d_loss 29 | 30 | 31 | def vanilla_d_loss(logits_real, logits_fake): 32 | d_loss = 0.5 * ( 33 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 34 | torch.mean(torch.nn.functional.softplus(logits_fake))) 35 | return d_loss 36 | 37 | 38 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 39 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 40 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 41 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 42 | loss_real = (weights * loss_real).sum() / weights.sum() 43 | loss_fake = (weights * loss_fake).sum() / weights.sum() 44 | d_loss = 0.5 * (loss_real + loss_fake) 45 | return d_loss 46 | 47 | 48 | def measure_perplexity(predicted_indices, n_embed): 49 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 50 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 51 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 52 | avg_probs = encodings.mean(0) 53 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 54 | cluster_use = torch.sum(avg_probs > 0) 55 | return perplexity, cluster_use 56 | 57 | def l1(x, y): 58 | return torch.abs(x-y) 59 | 60 | 61 | def l2(x, y): 62 | return torch.pow((x-y), 2) 63 | 64 | 65 | class LPIPSWithDiscriminator(nn.Module): 66 | def __init__(self, disc_start, disc_num_layers=3, disc_in_channels=3, 67 | pixelloss_weight=4.0, disc_weight=1.0, 68 | perceptual_weight=4.0, feature_weight=4.0, 69 | disc_ndf=64, disc_loss="hinge", timesteps=16): 70 | super().__init__() 71 | assert disc_loss in ["hinge", "vanilla"] 72 | self.s = timesteps 73 | self.perceptual_loss = LPIPS().eval() 74 | self.pixel_loss = l1 75 | 76 | self.discriminator_2d = NLayerDiscriminator(input_nc=disc_in_channels, 77 | n_layers=disc_num_layers, 78 | ndf=disc_ndf 79 | ).apply(weights_init) 80 | self.discriminator_3d = NLayerDiscriminator3D(input_nc=disc_in_channels, 81 | n_layers=disc_num_layers, 82 | ndf=disc_ndf 83 | ).apply(weights_init) 84 | 85 | self.discriminator_iter_start = disc_start 86 | if disc_loss == "hinge": 87 | self.disc_loss = hinge_d_loss 88 | elif disc_loss == "vanilla": 89 | self.disc_loss = vanilla_d_loss 90 | 91 | self.pixel_weight = pixelloss_weight 92 | self.gan_weight = disc_weight 93 | self.perceptual_weight = perceptual_weight 94 | self.gan_feat_weight = feature_weight 95 | 96 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 97 | global_step, cond=None, split="train", predicted_indices=None): 98 | 99 | b, c, _, h, w = inputs.size() 100 | rec_loss = self.pixel_weight * F.l1_loss(inputs.contiguous(), reconstructions.contiguous()) 101 | 102 | frame_idx = torch.randint(0,self.s, [b]).cuda() 103 | frame_idx_selected = frame_idx.reshape(-1, 1, 1, 1, 1).repeat(1, c, 1, h, w) 104 | inputs_2d = torch.gather(inputs, 2, frame_idx_selected).squeeze(2) 105 | reconstructions_2d = torch.gather(reconstructions, 2, frame_idx_selected).squeeze(2) 106 | 107 | if optimizer_idx == 0: 108 | if self.perceptual_weight > 0: 109 | """ 110 | p_loss = self.perceptual_weight * 
self.perceptual_loss(rearrange(inputs, 'b c t h w -> (b t) c h w').contiguous(), 111 | rearrange(reconstructions, 'b c t h w -> (b t) c h w').contiguous()).mean() 112 | """ 113 | p_loss = self.perceptual_weight * self.perceptual_loss(inputs_2d.contiguous(), reconstructions_2d.contiguous()).mean() 114 | else: 115 | p_loss = torch.tensor([0.0]) 116 | 117 | disc_factor = adopt_weight(global_step, threshold=self.discriminator_iter_start) 118 | logits_real_2d, pred_real_2d = self.discriminator_2d(inputs_2d) 119 | logits_real_3d, pred_real_3d = self.discriminator_3d(inputs.contiguous()) 120 | logits_fake_2d, pred_fake_2d = self.discriminator_2d(reconstructions_2d) 121 | logits_fake_3d, pred_fake_3d = self.discriminator_3d(reconstructions.contiguous()) 122 | g_loss = -disc_factor * self.gan_weight * (torch.mean(logits_fake_2d) + torch.mean(logits_fake_3d)) 123 | 124 | image_gan_feat_loss = 0. 125 | video_gan_feat_loss = 0. 126 | 127 | for i in range(len(pred_real_2d)-1): 128 | image_gan_feat_loss += F.l1_loss(pred_fake_2d[i], pred_real_2d[i].detach()) 129 | for i in range(len(pred_real_3d)-1): 130 | video_gan_feat_loss += F.l1_loss(pred_fake_3d[i], pred_real_3d[i].detach()) 131 | 132 | gan_feat_loss = disc_factor * self.gan_feat_weight * (image_gan_feat_loss + video_gan_feat_loss) 133 | return rec_loss + p_loss + g_loss + gan_feat_loss 134 | 135 | if optimizer_idx == 1: 136 | 137 | # second pass for discriminator update 138 | logits_real_2d, _ = self.discriminator_2d(inputs_2d) 139 | logits_real_3d, _ = self.discriminator_3d(inputs.contiguous()) 140 | logits_fake_2d, _ = self.discriminator_2d(reconstructions_2d) 141 | logits_fake_3d, _ = self.discriminator_3d(reconstructions.contiguous()) 142 | 143 | disc_factor = adopt_weight(global_step, threshold=self.discriminator_iter_start) 144 | d_loss = disc_factor * self.gan_weight * (self.disc_loss(logits_real_2d, logits_fake_2d) + self.disc_loss(logits_real_3d, logits_fake_3d)) 145 | 146 | return d_loss 147 | 148 | 149 | def weights_init(m): 150 | classname = m.__class__.__name__ 151 | if classname.find('Conv') != -1: 152 | nn.init.normal_(m.weight.data, 0.0, 0.02) 153 | elif classname.find('BatchNorm') != -1: 154 | nn.init.normal_(m.weight.data, 1.0, 0.02) 155 | nn.init.constant_(m.bias.data, 0) 156 | 157 | class NLayerDiscriminator(nn.Module): 158 | """Defines a PatchGAN discriminator as in Pix2Pix 159 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 160 | """ 161 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): 162 | # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): 163 | super(NLayerDiscriminator, self).__init__() 164 | self.getIntermFeat = getIntermFeat 165 | self.n_layers = n_layers 166 | 167 | kw = 4 168 | padw = int(np.ceil((kw-1.0)/2)) 169 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 170 | 171 | nf = ndf 172 | for n in range(1, n_layers): 173 | nf_prev = nf 174 | nf = min(nf * 2, 512) 175 | sequence += [[ 176 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 177 | norm_layer(nf), nn.LeakyReLU(0.2, True) 178 | ]] 179 | 180 | nf_prev = nf 181 | nf = min(nf * 2, 512) 182 | sequence += [[ 183 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 184 | norm_layer(nf), 185 | nn.LeakyReLU(0.2, True) 186 | ]] 187 | 188 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, 
padding=padw)]] 189 | 190 | if use_sigmoid: 191 | sequence += [[nn.Sigmoid()]] 192 | 193 | if getIntermFeat: 194 | for n in range(len(sequence)): 195 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 196 | else: 197 | sequence_stream = [] 198 | for n in range(len(sequence)): 199 | sequence_stream += sequence[n] 200 | self.model = nn.Sequential(*sequence_stream) 201 | 202 | def forward(self, input): 203 | if self.getIntermFeat: 204 | res = [input] 205 | for n in range(self.n_layers+2): 206 | model = getattr(self, 'model'+str(n)) 207 | res.append(model(res[-1])) 208 | return res[-1], res[1:] 209 | else: 210 | return self.model(input), None 211 | 212 | class NLayerDiscriminator3D(nn.Module): 213 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False, getIntermFeat=True): 214 | super(NLayerDiscriminator3D, self).__init__() 215 | self.getIntermFeat = getIntermFeat 216 | self.n_layers = n_layers 217 | 218 | kw = 4 219 | padw = int(np.ceil((kw-1.0)/2)) 220 | sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 221 | 222 | nf = ndf 223 | for n in range(1, n_layers): 224 | nf_prev = nf 225 | nf = min(nf * 2, 512) 226 | sequence += [[ 227 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 228 | norm_layer(nf), nn.LeakyReLU(0.2, True) 229 | ]] 230 | 231 | nf_prev = nf 232 | nf = min(nf * 2, 512) 233 | sequence += [[ 234 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 235 | norm_layer(nf), 236 | nn.LeakyReLU(0.2, True) 237 | ]] 238 | 239 | sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 240 | 241 | if use_sigmoid: 242 | sequence += [[nn.Sigmoid()]] 243 | 244 | if getIntermFeat: 245 | for n in range(len(sequence)): 246 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 247 | else: 248 | sequence_stream = [] 249 | for n in range(len(sequence)): 250 | sequence_stream += sequence[n] 251 | self.model = nn.Sequential(*sequence_stream) 252 | 253 | def forward(self, input): 254 | if self.getIntermFeat: 255 | res = [input] 256 | for n in range(self.n_layers+2): 257 | model = getattr(self, 'model'+str(n)) 258 | res.append(model(res[-1])) 259 | return res[-1], res[1:] 260 | else: 261 | return self.model(input), None 262 | 263 | 264 | class ActNorm(nn.Module): 265 | def __init__(self, num_features, logdet=False, affine=True, 266 | allow_reverse_init=False): 267 | assert affine 268 | super().__init__() 269 | self.logdet = logdet 270 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 271 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 272 | self.allow_reverse_init = allow_reverse_init 273 | 274 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 275 | 276 | def initialize(self, input): 277 | with torch.no_grad(): 278 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 279 | mean = ( 280 | flatten.mean(1) 281 | .unsqueeze(1) 282 | .unsqueeze(2) 283 | .unsqueeze(3) 284 | .permute(1, 0, 2, 3) 285 | ) 286 | std = ( 287 | flatten.std(1) 288 | .unsqueeze(1) 289 | .unsqueeze(2) 290 | .unsqueeze(3) 291 | .permute(1, 0, 2, 3) 292 | ) 293 | 294 | self.loc.data.copy_(-mean) 295 | self.scale.data.copy_(1 / (std + 1e-6)) 296 | 297 | def forward(self, input, reverse=False): 298 | if reverse: 299 | return self.reverse(input) 300 | if len(input.shape) == 2: 301 | input = input[:,:,None,None] 302 | squeeze = True 303 | else: 304 | squeeze = False 305 | 306 | _, _, height, width = input.shape 307 |
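        # Data-dependent initialization (Glow-style ActNorm): on the first training
        # batch, initialize() sets `loc` to the negative per-channel mean and `scale`
        # to the inverse per-channel std, so the initial outputs are roughly
        # zero-mean and unit-variance.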
308 | if self.training and self.initialized.item() == 0: 309 | self.initialize(input) 310 | self.initialized.fill_(1) 311 | 312 | h = self.scale * (input + self.loc) 313 | 314 | if squeeze: 315 | h = h.squeeze(-1).squeeze(-1) 316 | 317 | if self.logdet: 318 | log_abs = torch.log(torch.abs(self.scale)) 319 | logdet = height*width*torch.sum(log_abs) 320 | logdet = logdet * torch.ones(input.shape[0]).to(input) 321 | return h, logdet 322 | 323 | return h 324 | 325 | def reverse(self, output): 326 | if self.training and self.initialized.item() == 0: 327 | if not self.allow_reverse_init: 328 | raise RuntimeError( 329 | "Initializing ActNorm in reverse direction is " 330 | "disabled by default. Use allow_reverse_init=True to enable." 331 | ) 332 | else: 333 | self.initialize(output) 334 | self.initialized.fill_(1) 335 | 336 | if len(output.shape) == 2: 337 | output = output[:,:,None,None] 338 | squeeze = True 339 | else: 340 | squeeze = False 341 | 342 | h = output / self.scale - self.loc 343 | 344 | if squeeze: 345 | h = h.squeeze(-1).squeeze(-1) 346 | return h 347 | -------------------------------------------------------------------------------- /PVDM/losses/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/losses/vgg.pth -------------------------------------------------------------------------------- /PVDM/main.py: -------------------------------------------------------------------------------- 1 | import sys; sys.path.extend(['.']) 2 | 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from omegaconf import OmegaConf 8 | 9 | from exps.diffusion import diffusion 10 | from exps.first_stage import first_stage 11 | 12 | from utils import set_random_seed 13 | 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--exp', type=str, required=True, help='experiment name to run') 18 | parser.add_argument('--seed', type=int, default=42, help='random seed') 19 | parser.add_argument('--id', type=str, default='main', help='experiment identifier') 20 | 21 | """ Args about Data """ 22 | parser.add_argument('--data', type=str, default='cliport') 23 | parser.add_argument('--batch_size', type=int, default=24) 24 | parser.add_argument('--ds', type=int, default=4) 25 | 26 | """ Args about Model """ 27 | parser.add_argument('--pretrain_config', type=str, default='configs/autoencoder/base.yaml') 28 | parser.add_argument('--diffusion_config', type=str, default='configs/latent-diffusion/base.yaml') 29 | 30 | # for GAN resume 31 | parser.add_argument('--first_stage_folder', type=str, default='', help='the folder of first stage experiment before GAN') 32 | 33 | # for diffusion model path specification 34 | parser.add_argument('--first_model', type=str, default='', help='the path of pretrained model') 35 | parser.add_argument('--scale_lr', action='store_true') 36 | 37 | 38 | def main(): 39 | """ Additional args ends here. 
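    Illustrative invocations (defaults shown where available; the checkpoint path is a placeholder):
        python main.py --exp first_stage --id main --data cliport --batch_size 24
        python main.py --exp ddpm --id main --data cliport --first_model <path/to/first_stage_checkpoint>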
""" 40 | args = parser.parse_args() 41 | """ FIX THE RANDOMNESS """ 42 | set_random_seed(args.seed) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | args.n_gpus = torch.cuda.device_count() 47 | 48 | # init and save configs 49 | 50 | """ RUN THE EXP """ 51 | if args.exp == 'ddpm': 52 | config = OmegaConf.load(args.diffusion_config) 53 | first_stage_config = OmegaConf.load(args.pretrain_config) 54 | 55 | args.unetconfig = config.model.params.unet_config 56 | args.lr = config.model.base_learning_rate 57 | args.scheduler = config.model.params.scheduler_config 58 | args.res = first_stage_config.model.params.ddconfig.resolution 59 | args.timesteps = first_stage_config.model.params.ddconfig.timesteps 60 | args.skip = first_stage_config.model.params.ddconfig.skip 61 | args.ddconfig = first_stage_config.model.params.ddconfig 62 | args.embed_dim = first_stage_config.model.params.embed_dim 63 | args.ddpmconfig = config.model.params 64 | args.cond_model = config.model.cond_model 65 | 66 | # if args.n_gpus == 1: 67 | # diffusion(rank=0, args=args) 68 | # else: 69 | # torch.multiprocessing.spawn(fn=diffusion, args=(args, ), nprocs=args.n_gpus) 70 | diffusion(rank=0, args=args) 71 | 72 | elif args.exp == 'first_stage': 73 | config = OmegaConf.load(args.pretrain_config) 74 | args.ddconfig = config.model.params.ddconfig 75 | args.embed_dim = config.model.params.embed_dim 76 | args.lossconfig = config.model.params.lossconfig 77 | args.lr = config.model.base_learning_rate 78 | args.res = config.model.params.ddconfig.resolution 79 | args.timesteps = config.model.params.ddconfig.timesteps 80 | args.skip = config.model.params.ddconfig.skip 81 | args.resume = config.model.resume 82 | args.amp = config.model.amp 83 | # if args.n_gpus == 1: 84 | # first_stage(rank=0, args=args) 85 | # else: 86 | # torch.multiprocessing.spawn(fn=first_stage, args=(args, ), nprocs=args.n_gpus) 87 | first_stage(rank=0, args=args) 88 | 89 | else: 90 | raise ValueError("Unknown experiment.") 91 | 92 | if __name__ == '__main__': 93 | main() -------------------------------------------------------------------------------- /PVDM/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import os 10 | import time 11 | import hashlib 12 | import pickle 13 | import copy 14 | import uuid 15 | from urllib.parse import urlparse 16 | import numpy as np 17 | import torch 18 | 19 | import ctypes 20 | import fnmatch 21 | import importlib 22 | import inspect 23 | import numpy as np 24 | import os 25 | import shutil 26 | import sys 27 | import types 28 | import io 29 | import pickle 30 | import re 31 | import requests 32 | import html 33 | import hashlib 34 | import glob 35 | import tempfile 36 | import urllib 37 | import urllib.request 38 | import uuid 39 | 40 | 41 | _dnnlib_cache_dir = None 42 | 43 | def make_cache_dir_path(*paths: str) -> str: 44 | if _dnnlib_cache_dir is not None: 45 | return os.path.join(_dnnlib_cache_dir, *paths) 46 | if 'DNNLIB_CACHE_DIR' in os.environ: 47 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 48 | if 'HOME' in os.environ: 49 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 50 | if 'USERPROFILE' in os.environ: 51 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 52 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 53 | #---------------------------------------------------------------------------- 54 | _feature_detector_cache = dict() 55 | 56 | def get_feature_detector_name(url): 57 | return os.path.splitext(url.split('/')[-1])[0] 58 | 59 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 60 | assert 0 <= rank < num_gpus 61 | key = (url, device) 62 | if key not in _feature_detector_cache: 63 | is_leader = (rank == 0) 64 | if not is_leader and num_gpus > 1: 65 | torch.distributed.barrier() # leader goes first 66 | with open_url(url, verbose=(verbose and is_leader)) as f: 67 | if urlparse(url).path.endswith('.pkl'): 68 | _feature_detector_cache[key] = pickle.load(f).to(device) 69 | else: 70 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 71 | if is_leader and num_gpus > 1: 72 | torch.distributed.barrier() # others follow 73 | return _feature_detector_cache[key] 74 | 75 | 76 | def open_url(url: str, cache_dir=None, num_attempts= 10, verbose= True, return_filename = False, cache= True): 77 | """Download the given URL and return a binary-mode file object to access the data.""" 78 | assert num_attempts >= 1 79 | assert not (return_filename and (not cache)) 80 | 81 | # Doesn't look like an URL scheme so interpret it as a local filename. 82 | if not re.match('^[a-z]+://', url): 83 | return url if return_filename else open(url, "rb") 84 | 85 | # Handle file URLs. This code handles unusual file:// patterns that 86 | # arise on Windows: 87 | # 88 | # file:///c:/foo.txt 89 | # 90 | # which would translate to a local '/c:/foo.txt' filename that's 91 | # invalid. Drop the forward slash for such pathnames. 92 | # 93 | # If you touch this code path, you should test it on both Linux and 94 | # Windows. 95 | # 96 | # Some internet resources suggest using urllib.request.url2pathname() but 97 | # but that converts forward slashes to backslashes and this causes 98 | # its own set of problems. 99 | if url.startswith('file://'): 100 | filename = urllib.parse.urlparse(url).path 101 | if re.match(r'^/[a-zA-Z]:', filename): 102 | filename = filename[1:] 103 | return filename if return_filename else open(filename, "rb") 104 | 105 | # Lookup from cache. 
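    # Cached downloads are stored as '<md5(url)>_<sanitized filename>' inside the
    # cache directory, so a unique glob match below can be reused without re-downloading.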
106 | if cache_dir is None: 107 | cache_dir = make_cache_dir_path('downloads') 108 | 109 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 110 | if cache: 111 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 112 | if len(cache_files) == 1: 113 | filename = cache_files[0] 114 | return filename if return_filename else open(filename, "rb") 115 | 116 | # Download. 117 | url_name = None 118 | url_data = None 119 | with requests.Session() as session: 120 | if verbose: 121 | print("Downloading %s ..." % url, end="", flush=True) 122 | for attempts_left in reversed(range(num_attempts)): 123 | try: 124 | with session.get(url) as res: 125 | res.raise_for_status() 126 | if len(res.content) == 0: 127 | raise IOError("No data received") 128 | 129 | if len(res.content) < 8192: 130 | content_str = res.content.decode("utf-8") 131 | if "download_warning" in res.headers.get("Set-Cookie", ""): 132 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 133 | if len(links) == 1: 134 | url = requests.compat.urljoin(url, links[0]) 135 | raise IOError("Google Drive virus checker nag") 136 | if "Google Drive - Quota exceeded" in content_str: 137 | raise IOError("Google Drive download quota exceeded -- please try again later") 138 | 139 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 140 | url_name = match[1] if match else url 141 | url_data = res.content 142 | if verbose: 143 | print(" done") 144 | break 145 | except KeyboardInterrupt: 146 | raise 147 | except: 148 | if not attempts_left: 149 | if verbose: 150 | print(" failed") 151 | raise 152 | if verbose: 153 | print(".", end="", flush=True) 154 | 155 | # Save to cache. 156 | if cache: 157 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 158 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 159 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 160 | os.makedirs(cache_dir, exist_ok=True) 161 | with open(temp_file, "wb") as f: 162 | f.write(url_data) 163 | os.replace(temp_file, cache_file) # atomic 164 | if return_filename: 165 | return cache_file 166 | 167 | # Return data as file object. 
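    # Reached only when no cached filename was returned above: the freshly
    # downloaded bytes are handed back wrapped in an in-memory buffer.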
168 | assert not return_filename 169 | return io.BytesIO(url_data) 170 | 171 | 172 | 173 | 174 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /PVDM/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/__init__.py -------------------------------------------------------------------------------- /PVDM/models/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/autoencoder/__init__.py -------------------------------------------------------------------------------- /PVDM/models/autoencoder/autoencoder_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from models.autoencoder.vit_modules import TimeSformerEncoder, TimeSformerDecoder 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | from time import sleep 11 | 12 | # siren layer 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | def forward(self, x, **kwargs): 20 | return self.fn(self.norm(x), **kwargs) 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout = 0.): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.Linear(dim, hidden_dim), 27 | nn.GELU(), 28 | nn.Dropout(dropout), 29 | nn.Linear(hidden_dim, dim), 30 | nn.Dropout(dropout) 31 | ) 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | project_out = not (heads == 1 and dim_head == dim) 40 | 41 | self.heads = heads 42 | self.scale = dim_head ** -0.5 43 | 44 | self.attend = nn.Softmax(dim = -1) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | 49 | self.to_out = nn.Sequential( 50 | nn.Linear(inner_dim, dim), 51 | nn.Dropout(dropout) 52 | ) if project_out else nn.Identity() 53 | 54 | def forward(self, x): 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | attn = self.dropout(attn) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | 82 | # =========================================================================================== 83 | 84 | 85 | class 
ViTAutoencoder(nn.Module): 86 | def __init__(self, 87 | embed_dim, 88 | ddconfig, 89 | ckpt_path=None, 90 | ignore_keys=[], 91 | image_key="image", 92 | colorize_nlabels=None, 93 | monitor=None, 94 | ): 95 | super().__init__() 96 | self.splits = ddconfig["splits"] 97 | self.s = ddconfig["timesteps"] // self.splits 98 | 99 | self.res = ddconfig["resolution"] 100 | 101 | self.res_h = int(0.75*self.res) 102 | self.res_w = self.res 103 | 104 | self.embed_dim = embed_dim 105 | self.image_key = image_key 106 | 107 | patch_size = 8 108 | self.down = 3 109 | if self.res <= 128: 110 | patch_size = 4 111 | self.down = 2 112 | 113 | self.encoder = TimeSformerEncoder(dim=ddconfig["channels"], 114 | image_size=ddconfig["resolution"], 115 | num_frames=ddconfig["timesteps"], 116 | depth=8, 117 | patch_size=patch_size) 118 | 119 | self.decoder = TimeSformerDecoder(dim=ddconfig["channels"], 120 | image_size=ddconfig["resolution"], 121 | num_frames=ddconfig["timesteps"], 122 | depth=8, 123 | patch_size=patch_size) 124 | 125 | self.to_pixel = nn.Sequential( 126 | Rearrange('b (t h w) c -> (b t) c h w', h=self.res_h // patch_size, w=self.res_w // patch_size), 127 | nn.ConvTranspose2d(ddconfig["channels"], 3, kernel_size=(patch_size, patch_size), stride=patch_size), 128 | ) 129 | 130 | self.act = nn.Sigmoid() 131 | ts = torch.linspace(-1, 1, steps=self.s).unsqueeze(-1) 132 | self.register_buffer('coords', ts) 133 | 134 | self.xy_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 135 | self.xt_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 136 | self.yt_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 137 | 138 | self.xy_pos_embedding = nn.Parameter(torch.randn(1, self.s + 1, ddconfig["channels"])) 139 | self.xt_pos_embedding = nn.Parameter(torch.randn(1, self.res_w//(2**self.down) + 1, ddconfig["channels"])) 140 | self.yt_pos_embedding = nn.Parameter(torch.randn(1, self.res_h//(2**self.down) + 1, ddconfig["channels"])) 141 | 142 | self.xy_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 143 | self.yt_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 144 | self.xt_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 145 | 146 | self.pre_xy = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 147 | self.pre_xt = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 148 | self.pre_yt = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 149 | 150 | self.post_xy = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 151 | self.post_xt = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 152 | self.post_yt = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 153 | 154 | def encode(self, x): 155 | # x: b c t h w 156 | b = x.size(0) 157 | x = rearrange(x, 'b c t h w -> b t c h w') 158 | h = self.encoder(x) 159 | h = rearrange(h, 'b (t h w) c -> b c t h w', t=self.s, h=self.res_h//(2**self.down)) 160 | 161 | h_xy = rearrange(h, 'b c t h w -> (b h w) t c') 162 | n = h_xy.size(1) 163 | xy_token = repeat(self.xy_token, '1 1 d -> bhw 1 d', bhw = h_xy.size(0)) 164 | h_xy = torch.cat([h_xy, xy_token], dim=1) 165 | h_xy += self.xy_pos_embedding[:, :(n+1)] 166 | h_xy = self.xy_quant_attn(h_xy)[:, 0] 167 | h_xy = rearrange(h_xy, '(b h w) c -> b c h w', b=b, h=self.res_h//(2**self.down)) 168 | 169 | h_yt = rearrange(h, 'b c t h w -> (b t w) h c') 170 | n = h_yt.size(1) 171 | yt_token = repeat(self.yt_token, '1 1 d -> btw 
1 d', btw = h_yt.size(0)) 172 | h_yt = torch.cat([h_yt, yt_token], dim=1) 173 | h_yt += self.yt_pos_embedding[:, :(n+1)] 174 | h_yt = self.yt_quant_attn(h_yt)[:, 0] 175 | h_yt = rearrange(h_yt, '(b t w) c -> b c t w', b=b, w=self.res_w//(2**self.down)) 176 | 177 | h_xt = rearrange(h, 'b c t h w -> (b t h) w c') 178 | n = h_xt.size(1) 179 | xt_token = repeat(self.xt_token, '1 1 d -> bth 1 d', bth = h_xt.size(0)) 180 | h_xt = torch.cat([h_xt, xt_token], dim=1) 181 | h_xt += self.xt_pos_embedding[:, :(n+1)] 182 | h_xt = self.xt_quant_attn(h_xt)[:, 0] 183 | h_xt = rearrange(h_xt, '(b t h) c -> b c t h', b=b, h=self.res_h//(2**self.down)) 184 | 185 | h_xy = self.pre_xy(h_xy) 186 | h_yt = self.pre_yt(h_yt) 187 | h_xt = self.pre_xt(h_xt) 188 | 189 | h_xy = torch.tanh(h_xy) 190 | h_yt = torch.tanh(h_yt) 191 | h_xt = torch.tanh(h_xt) 192 | 193 | h_xy = self.post_xy(h_xy) 194 | h_yt = self.post_yt(h_yt) 195 | h_xt = self.post_xt(h_xt) 196 | 197 | h_xy = h_xy.unsqueeze(-3).expand(-1,-1,self.s,-1, -1) 198 | h_yt = h_yt.unsqueeze(-2).expand(-1,-1,-1,self.res_h//(2**self.down), -1) 199 | h_xt = h_xt.unsqueeze(-1).expand(-1,-1,-1,-1,self.res_w//(2**self.down)) 200 | return h_xy + h_yt + h_xt #torch.cat([h_xy, h_yt, h_xt], dim=1) 201 | 202 | def decode(self, z): 203 | b = z.size(0) 204 | dec = self.decoder(z) 205 | return 2*self.act(self.to_pixel(dec)).contiguous() -1 206 | 207 | def forward(self, input): 208 | input = rearrange(input, 'b c (n t) h w -> (b n) c t h w', n=self.splits) 209 | z = self.encode(input) 210 | dec = self.decode(z) 211 | return dec, 0. 212 | 213 | def extract(self, x): 214 | b = x.size(0) 215 | x = rearrange(x, 'b c t h w -> b t c h w') 216 | h = self.encoder(x) 217 | h = rearrange(h, 'b (t h w) c -> b c t h w', t=self.s, h=self.res_h//(2**self.down)) 218 | 219 | h_xy = rearrange(h, 'b c t h w -> (b h w) t c') 220 | n = h_xy.size(1) 221 | xy_token = repeat(self.xy_token, '1 1 d -> bhw 1 d', bhw = h_xy.size(0)) 222 | h_xy = torch.cat([h_xy, xy_token], dim=1) 223 | h_xy += self.xy_pos_embedding[:, :(n+1)] 224 | h_xy = self.xy_quant_attn(h_xy)[:, 0] 225 | h_xy = rearrange(h_xy, '(b h w) c -> b c h w', b=b, h=self.res_h//(2**self.down)) 226 | 227 | h_yt = rearrange(h, 'b c t h w -> (b t w) h c') 228 | n = h_yt.size(1) 229 | yt_token = repeat(self.yt_token, '1 1 d -> btw 1 d', btw = h_yt.size(0)) 230 | h_yt = torch.cat([h_yt, yt_token], dim=1) 231 | h_yt += self.yt_pos_embedding[:, :(n+1)] 232 | h_yt = self.yt_quant_attn(h_yt)[:, 0] 233 | h_yt = rearrange(h_yt, '(b t w) c -> b c t w', b=b, w=self.res_w//(2**self.down)) 234 | 235 | h_xt = rearrange(h, 'b c t h w -> (b t h) w c') 236 | n = h_xt.size(1) 237 | xt_token = repeat(self.xt_token, '1 1 d -> bth 1 d', bth = h_xt.size(0)) 238 | h_xt = torch.cat([h_xt, xt_token], dim=1) 239 | h_xt += self.xt_pos_embedding[:, :(n+1)] 240 | h_xt = self.xt_quant_attn(h_xt)[:, 0] 241 | h_xt = rearrange(h_xt, '(b t h) c -> b c t h', b=b, h=self.res_h//(2**self.down)) 242 | 243 | h_xy = self.pre_xy(h_xy) 244 | h_yt = self.pre_yt(h_yt) 245 | h_xt = self.pre_xt(h_xt) 246 | 247 | h_xy = torch.tanh(h_xy) 248 | h_yt = torch.tanh(h_yt) 249 | h_xt = torch.tanh(h_xt) 250 | 251 | h_xy = h_xy.view(h_xy.size(0), h_xy.size(1), -1) 252 | h_yt = h_yt.view(h_yt.size(0), h_yt.size(1), -1) 253 | h_xt = h_xt.view(h_xt.size(0), h_xt.size(1), -1) 254 | 255 | ret = torch.cat([h_xy, h_yt, h_xt], dim=-1) 256 | return ret 257 | 258 | def decode_from_sample(self, h): 259 | latent_res_h = self.res_h // (2**self.down) 260 | latent_res_w = self.res_w // (2**self.down) 261 | h_xy = 
h[:, :, 0:latent_res_h*latent_res_w].view(h.size(0), h.size(1), latent_res_h, latent_res_w) 262 | h_yt = h[:, :, latent_res_h*latent_res_w:latent_res_w*(latent_res_h+16)].view(h.size(0), h.size(1), 16, latent_res_w) 263 | h_xt = h[:, :, latent_res_w*(latent_res_h+16):latent_res_w*latent_res_h+16*latent_res_w+16*latent_res_h].view(h.size(0), h.size(1), 16, latent_res_h) 264 | 265 | h_xy = self.post_xy(h_xy) 266 | h_yt = self.post_yt(h_yt) 267 | h_xt = self.post_xt(h_xt) 268 | 269 | h_xy = h_xy.unsqueeze(-3).expand(-1,-1,self.s,-1, -1) 270 | h_yt = h_yt.unsqueeze(-2).expand(-1,-1,-1,self.res_h//(2**self.down), -1) 271 | h_xt = h_xt.unsqueeze(-1).expand(-1,-1,-1,-1,self.res_w//(2**self.down)) 272 | 273 | z = h_xy + h_yt + h_xt 274 | 275 | b = z.size(0) 276 | dec = self.decoder(z) 277 | return 2*self.act(self.to_pixel(dec)).contiguous()-1 278 | -------------------------------------------------------------------------------- /PVDM/models/autoencoder/vit_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | from math import log, pi 7 | 8 | def rotate_every_two(x): 9 | x = rearrange(x, '... (d j) -> ... d j', j = 2) 10 | x1, x2 = x.unbind(dim = -1) 11 | x = torch.stack((-x2, x1), dim = -1) 12 | return rearrange(x, '... d j -> ... (d j)') 13 | 14 | def apply_rot_emb(q, k, rot_emb): 15 | sin, cos = rot_emb 16 | rot_dim = sin.shape[-1] 17 | (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k)) 18 | q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k)) 19 | q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass))) 20 | return q, k 21 | 22 | class AxialRotaryEmbedding(nn.Module): 23 | def __init__(self, dim, max_freq = 10): 24 | super().__init__() 25 | self.dim = dim 26 | scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2) 27 | self.register_buffer('scales', scales) 28 | 29 | def forward(self, h, w, device): 30 | scales = rearrange(self.scales, '... -> () ...') 31 | scales = scales.to(device) 32 | 33 | h_seq = torch.linspace(-1., 1., steps = h, device = device) 34 | h_seq = h_seq.unsqueeze(-1) 35 | 36 | w_seq = torch.linspace(-1., 1., steps = w, device = device) 37 | w_seq = w_seq.unsqueeze(-1) 38 | 39 | h_seq = h_seq * scales * pi 40 | w_seq = w_seq * scales * pi 41 | 42 | x_sinu = repeat(h_seq, 'i d -> i j d', j = w) 43 | y_sinu = repeat(w_seq, 'j d -> i j d', i = h) 44 | 45 | sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1) 46 | cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1) 47 | 48 | sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos)) 49 | sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos)) 50 | return sin, cos 51 | 52 | class RotaryEmbedding(nn.Module): 53 | def __init__(self, dim): 54 | super().__init__() 55 | inv_freqs = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 56 | self.register_buffer('inv_freqs', inv_freqs) 57 | 58 | def forward(self, n, device): 59 | seq = torch.arange(n, device = device) 60 | freqs = einsum('i, j -> i j', seq, self.inv_freqs) 61 | freqs = torch.cat((freqs, freqs), dim = -1) 62 | freqs = rearrange(freqs, 'n d -> () n d') 63 | return freqs.sin(), freqs.cos() 64 | 65 | def exists(val): 66 | return val is not None 67 | 68 | # classes 69 | 70 | class PreNorm(nn.Module): 71 | def __init__(self, dim, fn): 72 | super().__init__() 73 | self.fn = fn 74 | self.norm = nn.LayerNorm(dim) 75 | 76 | def forward(self, x, *args, **kwargs): 77 | x = self.norm(x) 78 | return self.fn(x, *args, **kwargs) 79 | 80 | # time token shift 81 | 82 | def shift(t, amt): 83 | if amt == 0: 84 | return t 85 | return F.pad(t, (0, 0, 0, 0, amt, -amt)) 86 | 87 | # feedforward 88 | 89 | class GEGLU(nn.Module): 90 | def forward(self, x): 91 | x, gates = x.chunk(2, dim = -1) 92 | return x * F.gelu(gates) 93 | 94 | class FeedForward(nn.Module): 95 | def __init__(self, dim, mult = 4, dropout = 0.): 96 | super().__init__() 97 | self.net = nn.Sequential( 98 | nn.Linear(dim, dim * mult * 2), 99 | GEGLU(), 100 | nn.Dropout(dropout), 101 | nn.Linear(dim * mult, dim) 102 | ) 103 | 104 | def forward(self, x): 105 | return self.net(x) 106 | 107 | # attention 108 | 109 | def attn(q, k, v, mask = None): 110 | sim = einsum('b i d, b j d -> b i j', q, k) 111 | 112 | if exists(mask): 113 | max_neg_value = -torch.finfo(sim.dtype).max 114 | sim.masked_fill_(~mask, max_neg_value) 115 | 116 | attn = sim.softmax(dim = -1) 117 | out = einsum('b i j, b j d -> b i d', attn, v) 118 | return out 119 | 120 | class Attention(nn.Module): 121 | def __init__( 122 | self, 123 | dim, 124 | dim_head = 64, 125 | heads = 8, 126 | dropout = 0. 
127 | ): 128 | super().__init__() 129 | self.heads = heads 130 | self.scale = dim_head ** -0.5 131 | inner_dim = dim_head * heads 132 | 133 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 134 | self.to_out = nn.Sequential( 135 | nn.Linear(inner_dim, dim), 136 | nn.Dropout(dropout) 137 | ) 138 | 139 | def forward(self, x, einops_from, einops_to, mask = None, rot_emb = None, **einops_dims): 140 | h = self.heads 141 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 142 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) 143 | 144 | q = q * self.scale 145 | 146 | # rearrange across time or space 147 | q, k, v = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q, k, v)) 148 | 149 | # add rotary embeddings, if applicable 150 | if exists(rot_emb): 151 | q, k = apply_rot_emb(q, k, rot_emb) 152 | 153 | # expand cls token keys and values across time or space and concat 154 | # attention 155 | out = attn(q, k, v, mask = mask) 156 | out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) 157 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 158 | 159 | # combine heads out 160 | return self.to_out(out) 161 | 162 | # main classes 163 | 164 | class TimeSformerEncoder(nn.Module): 165 | def __init__( 166 | self, 167 | *, 168 | dim = 512, 169 | num_frames = 16, 170 | image_size = 64, 171 | patch_size = 8, 172 | channels = 3, 173 | depth = 8, 174 | heads = 8, 175 | dim_head = 64, 176 | attn_dropout = 0., 177 | ff_dropout = 0., 178 | rotary_emb = True, 179 | shift_tokens = False, 180 | ): 181 | super().__init__() 182 | image_size_h = int(0.75*image_size) 183 | image_size_w = image_size 184 | 185 | assert image_size_h % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 186 | assert image_size_w % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
187 | 188 | num_patches = (image_size_h*image_size_w) // (patch_size**2) 189 | num_positions = num_frames * num_patches 190 | patch_dim = channels * patch_size ** 2 191 | 192 | self.heads = heads 193 | self.patch_size = patch_size 194 | self.to_patch_embedding = nn.Linear(patch_dim, dim) 195 | 196 | self.use_rotary_emb = rotary_emb 197 | if rotary_emb: 198 | self.frame_rot_emb = RotaryEmbedding(dim_head) 199 | self.image_rot_emb = AxialRotaryEmbedding(dim_head) 200 | else: 201 | self.pos_emb = nn.Embedding(num_positions, dim) 202 | 203 | self.layers = nn.ModuleList([]) 204 | for _ in range(depth): 205 | ff = FeedForward(dim, dropout = ff_dropout) 206 | time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 207 | spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 208 | 209 | time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff)) 210 | 211 | self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff])) 212 | 213 | def forward(self, video, frame_mask = None): 214 | b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size 215 | assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}' 216 | 217 | # calculate num patches in height and width dimension, and number of total patches (n) 218 | hp, wp = (h // p), (w // p) 219 | n = hp * wp 220 | 221 | # video to patch embeddings 222 | video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p) 223 | x = self.to_patch_embedding(video) 224 | 225 | # positional embedding 226 | frame_pos_emb = None 227 | image_pos_emb = None 228 | if not self.use_rotary_emb: 229 | x += self.pos_emb(torch.arange(x.shape[1], device = device)) 230 | else: 231 | frame_pos_emb = self.frame_rot_emb(f, device = device) 232 | image_pos_emb = self.image_rot_emb(hp, wp, device = device) 233 | 234 | # time and space attention 235 | for (time_attn, spatial_attn, ff) in self.layers: 236 | x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, rot_emb = frame_pos_emb) + x 237 | x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, rot_emb = image_pos_emb) + x 238 | x = ff(x) + x 239 | 240 | return x 241 | 242 | class TimeSformerDecoder(nn.Module): 243 | def __init__( 244 | self, 245 | *, 246 | dim = 512, 247 | num_frames = 16, 248 | image_size = 64, 249 | patch_size = 8, 250 | channels = 3, 251 | depth = 8, 252 | heads = 8, 253 | dim_head = 64, 254 | attn_dropout = 0., 255 | ff_dropout = 0., 256 | rotary_emb = True, 257 | shift_tokens = False, 258 | ): 259 | super().__init__() 260 | image_size_h = int(0.75*image_size) 261 | image_size_w = image_size 262 | 263 | assert image_size_h % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 264 | assert image_size_w % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
265 | 266 | num_patches = (image_size_h*image_size_w) // (patch_size**2) 267 | num_positions = num_frames * num_patches 268 | patch_dim = channels * patch_size ** 2 269 | 270 | 271 | self.heads = heads 272 | self.patch_size = patch_size 273 | 274 | self.use_rotary_emb = rotary_emb 275 | if rotary_emb: 276 | self.frame_rot_emb = RotaryEmbedding(dim_head) 277 | self.image_rot_emb = AxialRotaryEmbedding(dim_head) 278 | else: 279 | self.pos_emb = nn.Embedding(num_positions, dim) 280 | 281 | self.layers = nn.ModuleList([]) 282 | for _ in range(depth): 283 | ff = FeedForward(dim, dropout = ff_dropout) 284 | time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 285 | spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 286 | 287 | time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff)) 288 | 289 | self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff])) 290 | 291 | def forward(self, x, frame_mask = None): 292 | device = x.device 293 | f, hp, wp = x.size(2), x.size(3), x.size(4) 294 | n = hp * wp 295 | x = rearrange(x, 'b c f h w -> b (f h w) c') 296 | 297 | # positional embedding 298 | frame_pos_emb = None 299 | image_pos_emb = None 300 | if not self.use_rotary_emb: 301 | x += self.pos_emb(torch.arange(x.shape[1], device = device)) 302 | else: 303 | frame_pos_emb = self.frame_rot_emb(f, device = device) 304 | image_pos_emb = self.image_rot_emb(hp, wp, device = device) 305 | 306 | # time and space attention 307 | for (time_attn, spatial_attn, ff) in self.layers: 308 | x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, rot_emb = frame_pos_emb) + x 309 | x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, rot_emb = image_pos_emb) + x 310 | x = ff(x) + x 311 | 312 | return x -------------------------------------------------------------------------------- /PVDM/models/ddpm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/ddpm/__init__.py -------------------------------------------------------------------------------- /PVDM/models/ddpm/diffusionmodules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | 8 | 9 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 10 | if ddim_discr_method == 'uniform': 11 | c = num_ddpm_timesteps // num_ddim_timesteps 12 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 13 | elif ddim_discr_method == 'quad': 14 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 15 | else: 16 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 17 | 18 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 19 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 20 | steps_out = ddim_timesteps + 1 21 | if verbose: 22 | print(f'Selected timesteps for ddim sampler: {steps_out}') 23 | return steps_out 24 | 25 | 26 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 27 | # select alphas for computing the variance schedule 28 | alphas = alphacums[ddim_timesteps] 29 | alphas_prev = np.asarray([alphacums[0]] + 
alphacums[ddim_timesteps[:-1]].tolist()) 30 | 31 | # according the the formula provided in https://arxiv.org/abs/2010.02502 32 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 33 | if verbose: 34 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 35 | print(f'For the chosen value of eta, which is {eta}, ' 36 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 37 | return sigmas, alphas, alphas_prev 38 | 39 | 40 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 41 | """ 42 | Create a beta schedule that discretizes the given alpha_t_bar function, 43 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 44 | :param num_diffusion_timesteps: the number of betas to produce. 45 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 46 | produces the cumulative product of (1-beta) up to that 47 | part of the diffusion process. 48 | :param max_beta: the maximum beta to use; use values lower than 1 to 49 | prevent singularities. 50 | """ 51 | betas = [] 52 | for i in range(num_diffusion_timesteps): 53 | t1 = i / num_diffusion_timesteps 54 | t2 = (i + 1) / num_diffusion_timesteps 55 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 56 | return np.array(betas) 57 | 58 | 59 | def checkpoint(func, inputs, params, flag): 60 | """ 61 | Evaluate a function without caching intermediate activations, allowing for 62 | reduced memory at the expense of extra compute in the backward pass. 63 | :param func: the function to evaluate. 64 | :param inputs: the argument sequence to pass to `func`. 65 | :param params: a sequence of parameters `func` depends on but does not 66 | explicitly take as arguments. 67 | :param flag: if False, disable gradient checkpointing. 68 | """ 69 | if flag: 70 | args = tuple(inputs) + tuple(params) 71 | return CheckpointFunction.apply(func, len(inputs), *args) 72 | else: 73 | return func(*inputs) 74 | 75 | 76 | class CheckpointFunction(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx, run_function, length, *args): 79 | ctx.run_function = run_function 80 | ctx.input_tensors = list(args[:length]) 81 | ctx.input_params = list(args[length:]) 82 | 83 | with torch.no_grad(): 84 | output_tensors = ctx.run_function(*ctx.input_tensors) 85 | return output_tensors 86 | 87 | @staticmethod 88 | def backward(ctx, *output_grads): 89 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 90 | with torch.enable_grad(): 91 | # Fixes a bug where the first op in run_function modifies the 92 | # Tensor storage in place, which is not allowed for detach()'d 93 | # Tensors. 94 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 95 | output_tensors = ctx.run_function(*shallow_copies) 96 | input_grads = torch.autograd.grad( 97 | output_tensors, 98 | ctx.input_tensors + ctx.input_params, 99 | output_grads, 100 | allow_unused=True, 101 | ) 102 | del ctx.input_tensors 103 | del ctx.input_params 104 | del output_tensors 105 | return (None, None) + input_grads 106 | 107 | 108 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 109 | """ 110 | Create sinusoidal timestep embeddings. 111 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 112 | These may be fractional. 113 | :param dim: the dimension of the output. 114 | :param max_period: controls the minimum frequency of the embeddings. 
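    :param repeat_only: if True, skip the sinusoidal features and simply repeat the
                        raw timestep values along the embedding dimension.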
115 | :return: an [N x dim] Tensor of positional embeddings. 116 | """ 117 | if not repeat_only: 118 | half = dim // 2 119 | freqs = torch.exp( 120 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 121 | ).to(device=timesteps.device) 122 | args = timesteps[:, None].float() * freqs[None] 123 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 124 | if dim % 2: 125 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 126 | else: 127 | embedding = repeat(timesteps, 'b -> b d', d=dim) 128 | return embedding 129 | 130 | 131 | def zero_module(module): 132 | """ 133 | Zero out the parameters of a module and return it. 134 | """ 135 | for p in module.parameters(): 136 | p.detach().zero_() 137 | return module 138 | 139 | 140 | def scale_module(module, scale): 141 | """ 142 | Scale the parameters of a module and return it. 143 | """ 144 | for p in module.parameters(): 145 | p.detach().mul_(scale) 146 | return module 147 | 148 | 149 | def mean_flat(tensor): 150 | """ 151 | Take the mean over all non-batch dimensions. 152 | """ 153 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 154 | 155 | 156 | def normalization(channels): 157 | """ 158 | Make a standard normalization layer. 159 | :param channels: number of input channels. 160 | :return: an nn.Module for normalization. 161 | """ 162 | return GroupNorm32(32, channels) 163 | 164 | 165 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 166 | class SiLU(nn.Module): 167 | def forward(self, x): 168 | return x * torch.sigmoid(x) 169 | 170 | 171 | class GroupNorm32(nn.GroupNorm): 172 | def forward(self, x): 173 | return super().forward(x.float()).type(x.dtype) 174 | 175 | def conv_nd(dims, *args, **kwargs): 176 | """ 177 | Create a 1D, 2D, or 3D convolution module. 178 | """ 179 | if dims == 1: 180 | return nn.Conv1d(*args, **kwargs) 181 | elif dims == 2: 182 | return nn.Conv2d(*args, **kwargs) 183 | elif dims == 3: 184 | return nn.Conv3d(*args, **kwargs) 185 | raise ValueError(f"unsupported dimensions: {dims}") 186 | 187 | 188 | def linear(*args, **kwargs): 189 | """ 190 | Create a linear module. 191 | """ 192 | return nn.Linear(*args, **kwargs) 193 | 194 | 195 | def avg_pool_nd(dims, *args, **kwargs): 196 | """ 197 | Create a 1D, 2D, or 3D average pooling module. 
198 | """ 199 | if dims == 1: 200 | return nn.AvgPool1d(*args, **kwargs) 201 | elif dims == 2: 202 | return nn.AvgPool2d(*args, **kwargs) 203 | elif dims == 3: 204 | return nn.AvgPool3d(*args, **kwargs) 205 | raise ValueError(f"unsupported dimensions: {dims}") 206 | 207 | 208 | 209 | def noise_like(shape, device, repeat=False): 210 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 211 | noise = lambda: torch.randn(shape, device=device) 212 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /PVDM/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
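        A typical usage sketch (the `model` / `ema` names are illustrative):
            ema.store(model.parameters())
            ema.copy_to(model)      # evaluate / save with EMA weights
            ema.restore(model.parameters())  # switch back to the training weights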
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /PVDM/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/tools/__init__.py -------------------------------------------------------------------------------- /PVDM/tools/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import math 4 | import random 5 | import pickle 6 | import warnings 7 | import glob 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import zipfile 12 | import PIL.Image 13 | from PIL import Image 14 | from PIL import ImageFile 15 | from einops import rearrange 16 | from torchvision import transforms 17 | import json 18 | import numpy as np 19 | import pyspng 20 | 21 | from natsort import natsorted 22 | 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | IMG_EXTENSIONS = [ 25 | '.jpg', '.JPG', '.jpeg', '.JPEG', 26 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 27 | ] 28 | 29 | def is_image_file(filename): 30 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 31 | 32 | 33 | def pil_loader(path): 34 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 35 | ''' 36 | with open(path, 'rb') as f: 37 | with Image.open(f) as img: 38 | return img.convert('RGB') 39 | ''' 40 | Im = Image.open(path) 41 | return Im.convert('RGB') 42 | 43 | 44 | def default_loader(path): 45 | ''' 46 | from torchvision import get_image_backend 47 | if get_image_backend() == 'accimage': 48 | return accimage_loader(path) 49 | else: 50 | ''' 51 | return pil_loader(path) 52 | 53 | 54 | def find_classes(dir): 55 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 56 | classes.sort() 57 | class_to_idx = {classes[i]: i for i in range(len(classes))} 58 | return classes, class_to_idx 59 | 60 | def resize_crop(video, resolution): 61 | """ Resizes video with smallest axis to `resolution * extra_scale` 62 | and then crops a `resolution` x `resolution` bock. 
If `crop_mode == "center"` 63 | do a center crop, if `crop_mode == "random"`, does a random crop 64 | Args 65 | video: a tensor of shape [t, c, h, w] in {0, ..., 255} 66 | resolution: an int 67 | crop_mode: 'center', 'random' 68 | Returns 69 | a processed video of shape [c, t, h, w] 70 | """ 71 | _, _, h, w = video.shape 72 | 73 | if h > w: 74 | half = (h - w) // 2 75 | cropsize = (0, half, w, half + w) # left, upper, right, lower 76 | elif w >= h: 77 | half = (w - h) // 2 78 | cropsize = (half, 0, half + h, h) 79 | 80 | video = video[:, :, cropsize[1]:cropsize[3], cropsize[0]:cropsize[2]] 81 | video = F.interpolate(video, size=resolution, mode='bilinear', align_corners=False) 82 | 83 | video = video.permute(1, 0, 2, 3).contiguous() # [c, t, h, w] 84 | return video 85 | 86 | def make_imageclip_dataset(dir, nframes, class_to_idx, vid_diverse_sampling, split='all'): 87 | """ 88 | TODO: add xflip 89 | """ 90 | def _sort(path): 91 | return natsorted(os.listdir(path)) 92 | 93 | images = [] 94 | n_video = 0 95 | n_clip = 0 96 | 97 | 98 | dir_list = natsorted(os.listdir(dir)) 99 | for target in dir_list: 100 | if split == 'train': 101 | if 'val' in target: dir_list.remove(target) 102 | elif split == 'val' or split == 'test': 103 | if 'train' in target: dir_list.remove(target) 104 | 105 | for target in dir_list: 106 | if os.path.isdir(os.path.join(dir,target))==True: 107 | n_video +=1 108 | subfolder_path = os.path.join(dir, target) 109 | for subsubfold in natsorted(os.listdir(subfolder_path) ): 110 | if os.path.isdir(os.path.join(subfolder_path, subsubfold) ): 111 | subsubfolder_path = os.path.join(subfolder_path, subsubfold) 112 | i = 1 113 | 114 | if nframes > 0 and vid_diverse_sampling: 115 | n_clip += 1 116 | 117 | item_frames_0 = [] 118 | item_frames_1 = [] 119 | item_frames_2 = [] 120 | item_frames_3 = [] 121 | 122 | for fi in _sort(subsubfolder_path): 123 | if is_image_file(fi): 124 | file_name = fi 125 | file_path = os.path.join(subsubfolder_path, file_name) 126 | item = (file_path, class_to_idx[target]) 127 | 128 | if i % 4 == 0: 129 | item_frames_0.append(item) 130 | elif i % 4 == 1: 131 | item_frames_1.append(item) 132 | elif i % 4 == 2: 133 | item_frames_2.append(item) 134 | else: 135 | item_frames_3.append(item) 136 | 137 | if i %nframes == 0 and i > 0: 138 | images.append(item_frames_0) # item_frames is a list containing n frames. 139 | images.append(item_frames_1) # item_frames is a list containing n frames. 140 | images.append(item_frames_2) # item_frames is a list containing n frames. 141 | images.append(item_frames_3) # item_frames is a list containing n frames. 142 | item_frames_0 = [] 143 | item_frames_1 = [] 144 | item_frames_2 = [] 145 | item_frames_3 = [] 146 | 147 | i = i+1 148 | else: 149 | item_frames = [] 150 | for fi in _sort(subsubfolder_path): 151 | if is_image_file(fi): 152 | # fi is an image in the subsubfolder 153 | file_name = fi 154 | file_path = os.path.join(subsubfolder_path, file_name) 155 | item = (file_path, class_to_idx[target]) 156 | item_frames.append(item) 157 | if i % nframes == 0 and i > 0: 158 | images.append(item_frames) # item_frames is a list containing 32 frames. 
159 | item_frames = [] 160 | i = i + 1 161 | 162 | return images 163 | 164 | 165 | 166 | def make_imagefolder_dataset(dir, nframes, class_to_idx, vid_diverse_sampling, split='all'): 167 | """ 168 | TODO: add xflip 169 | """ 170 | def _sort(path): 171 | return natsorted(os.listdir(path)) 172 | 173 | images = [] 174 | n_video = 0 175 | n_clip = 0 176 | 177 | 178 | dir_list = natsorted(os.listdir(dir)) 179 | for target in dir_list: 180 | if split == 'train': 181 | if 'val' in target: dir_list.remove(target) 182 | elif split == 'val' or split == 'test': 183 | if 'train' in target: dir_list.remove(target) 184 | 185 | dataset_list = [] 186 | for target in dir_list: 187 | if os.path.isdir(os.path.join(dir,target))==True: 188 | n_video +=1 189 | subfolder_path = os.path.join(dir, target) 190 | for subsubfold in natsorted(os.listdir(subfolder_path) ): 191 | if os.path.isdir(os.path.join(subfolder_path, subsubfold) ): 192 | subsubfolder_path = os.path.join(subfolder_path, subsubfold) 193 | 194 | count = 0 195 | valid = False 196 | for fi in _sort(subsubfolder_path): 197 | if is_image_file(fi): 198 | valid = True 199 | count += 1 200 | else: 201 | valid = False 202 | break 203 | """ 204 | valid = True 205 | """ 206 | if valid and count >= nframes: 207 | valid = True 208 | else: 209 | valid = False 210 | 211 | if valid == True: 212 | dataset_list.append((subsubfolder_path, count)) 213 | 214 | return dataset_list 215 | 216 | 217 | class InfiniteSampler(torch.utils.data.Sampler): 218 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 219 | assert len(dataset) > 0 220 | assert num_replicas > 0 221 | assert 0 <= rank < num_replicas 222 | assert 0 <= window_size <= 1 223 | super().__init__(dataset) 224 | self.dataset = dataset 225 | self.rank = rank 226 | self.num_replicas = num_replicas 227 | self.shuffle = shuffle 228 | self.seed = seed 229 | self.window_size = window_size 230 | 231 | def __iter__(self): 232 | order = np.arange(len(self.dataset)) 233 | rnd = None 234 | window = 0 235 | if self.shuffle: 236 | rnd = np.random.RandomState(self.seed) 237 | rnd.shuffle(order) 238 | window = int(np.rint(order.size * self.window_size)) 239 | 240 | idx = 0 241 | while True: 242 | i = idx % order.size 243 | if idx % self.num_replicas == self.rank: 244 | yield order[i] 245 | if window >= 2: 246 | j = (i - rnd.randint(window)) % order.size 247 | order[i], order[j] = order[j], order[i] 248 | idx += 1 -------------------------------------------------------------------------------- /PVDM/tools/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import math 4 | import random 5 | import pickle 6 | import warnings 7 | import glob 8 | 9 | import imageio 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.utils.data import Dataset, DataLoader 15 | 16 | import torchvision.transforms as T 17 | import torchvision.datasets as datasets 18 | from torchvision.datasets import UCF101 19 | from torchvision.datasets.folder import make_dataset 20 | from tools.video_utils import VideoClips 21 | from torchvision.io import read_video 22 | 23 | data_location = '/data' 24 | from utils import set_random_seed 25 | from tools.data_utils import * 26 | 27 | import av 28 | 29 | from ffcv.fields.decoders import NDArrayDecoder 30 | from ffcv.transforms import ToTensor, Squeeze, ToDevice 31 | from ffcv.loader import Loader, OrderOption 32 | 33 | class 
VideoFolderDataset(Dataset): 34 | def __init__(self, 35 | root, 36 | train, 37 | resolution, 38 | path=None, 39 | n_frames=16, 40 | skip=1, 41 | fold=1, 42 | max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 43 | use_labels=False, # Enable conditioning labels? False = label dimension is zero. 44 | return_vid=False, # True for evaluating FVD 45 | time_saliency=False, 46 | sub=False, 47 | seed=42, 48 | **super_kwargs, # Additional arguments for the Dataset base class. 49 | ): 50 | 51 | video_root = osp.join(os.path.join(root)) 52 | if not 1 <= fold <= 3: 53 | raise ValueError("fold should be between 1 and 3, got {}".format(fold)) 54 | 55 | self.path = video_root 56 | name = video_root.split('/')[-1] 57 | self.name = name 58 | self.train = train 59 | self.fold = fold 60 | self.resolution = resolution 61 | self.nframes = n_frames 62 | self.annotation_path = os.path.join(video_root, 'ucfTrainTestlist') 63 | self.classes = list(natsorted(p for p in os.listdir(video_root) if osp.isdir(osp.join(video_root, p)))) 64 | self.classes.remove('ucfTrainTestlist') 65 | class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} 66 | self.samples = make_dataset(video_root, class_to_idx, ('avi',), is_valid_file=None) 67 | video_list = [x[0] for x in self.samples] 68 | 69 | self.video_list = video_list 70 | """ 71 | if train: 72 | self.video_list = video_list 73 | else: 74 | self.video_list = [] 75 | for p in video_list: 76 | video = read_video(p)[0] 77 | if len(video) >= 128: 78 | self.video_list.append(p) 79 | """ 80 | 81 | 82 | self._use_labels = use_labels 83 | self._label_shape = None 84 | self._raw_labels = None 85 | self._raw_shape = [len(self.video_list)] + [3, resolution, resolution] 86 | self.num_channels = 3 87 | self.return_vid = return_vid 88 | 89 | frames_between_clips = skip 90 | print(root, frames_between_clips, n_frames) 91 | self.indices = self._select_fold(self.video_list, self.annotation_path, 92 | fold, train) 93 | 94 | self.size = len(self.indices) 95 | print(self.size) 96 | random.seed(seed) 97 | self.shuffle_indices = [i for i in range(self.size)] 98 | random.shuffle(self.shuffle_indices) 99 | 100 | self._need_init = True 101 | 102 | def _select_fold(self, video_list, annotation_path, fold, train): 103 | name = "train" if train else "test" 104 | name = "{}list{:02d}.txt".format(name, fold) 105 | f = os.path.join(annotation_path, name) 106 | selected_files = [] 107 | with open(f, "r") as fid: 108 | data = fid.readlines() 109 | data = [x.strip().split(" ") for x in data] 110 | data = [os.path.join(self.path, x[0]) for x in data] 111 | 112 | """ 113 | for p in data: 114 | if p in video_list == False: 115 | data.remove(p) 116 | """ 117 | 118 | selected_files.extend(data) 119 | 120 | """ 121 | name = "train" if not train else "test" 122 | name = "{}list{:02d}.txt".format(name, fold) 123 | f = os.path.join(annotation_path, name) 124 | with open(f, "r") as fid: 125 | data = fid.readlines() 126 | data = [x.strip().split(" ") for x in data] 127 | data = [os.path.join(self.path, x[0]) for x in data] 128 | selected_files.extend(data) 129 | """ 130 | 131 | selected_files = set(selected_files) 132 | indices = [i for i in range(len(video_list)) if video_list[i] in selected_files] 133 | return indices 134 | 135 | def __len__(self): 136 | return self.size 137 | 138 | def _preprocess(self, video): 139 | video = resize_crop(video, self.resolution) 140 | return video 141 | 142 | def __getitem__(self, idx): 143 | idx = self.shuffle_indices[idx] 
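        # the shuffled position is mapped back into self.indices (the selected
        # train/test fold) below; the chosen video is decoded in full and a random
        # temporal window of self.nframes consecutive frames is taken before the
        # spatial center-crop/resize performed by self._preprocess()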
144 | idx = self.indices[idx] 145 | video = read_video(self.video_list[idx])[0] 146 | prefix = np.random.randint(len(video)-self.nframes+1) 147 | video = video[prefix:prefix+self.nframes].float().permute(3,0,1,2) 148 | 149 | return self._preprocess(video), idx 150 | 151 | class ImageFolderDataset(Dataset): 152 | def __init__(self, 153 | path, # Path to directory or zip. 154 | resolution=None, 155 | nframes=16, # number of frames for each video. 156 | train=True, 157 | interpolate=False, 158 | loader=default_loader, # loader for "sequence" of images 159 | return_vid=True, # True for evaluating FVD 160 | cond=False, 161 | **super_kwargs, # Additional arguments for the Dataset base class. 162 | ): 163 | 164 | self._path = path 165 | self._zipfile = None 166 | self.apply_resize = True 167 | 168 | # classes, class_to_idx = find_classes(path) 169 | if 'taichi' in path and not interpolate: 170 | classes, class_to_idx = find_classes(path) 171 | imgs = make_imagefolder_dataset(path, nframes * 4, class_to_idx, True) 172 | elif 'kinetics' in path or 'KINETICS' in path: 173 | if train: 174 | split = 'train' 175 | else: 176 | split = 'val' 177 | classes, class_to_idx = find_classes(path) 178 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False, split) 179 | elif 'SKY' in path: 180 | if train: 181 | split = 'train' 182 | else: 183 | split = 'test' 184 | path = os.path.join(path, split) 185 | classes, class_to_idx = find_classes(path) 186 | if cond: 187 | imgs = make_imagefolder_dataset(path, nframes // 2, class_to_idx, False, split) 188 | else: 189 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False, split) 190 | else: 191 | classes, class_to_idx = find_classes(path) 192 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False) 193 | 194 | if len(imgs) == 0: 195 | raise(RuntimeError("Found 0 images in subfolders of: " + path + "\n" 196 | "Supported image extensions are: " + 197 | ",".join(IMG_EXTENSIONS))) 198 | 199 | self.imgs = imgs 200 | self.classes = classes 201 | self.class_to_idx = class_to_idx 202 | self.nframes = nframes 203 | self.loader = loader 204 | self.img_resolution = resolution 205 | self._path = path 206 | self._total_size = len(self.imgs) 207 | self._raw_shape = [self._total_size] + [3, resolution, resolution] 208 | self.xflip = False 209 | self.return_vid = return_vid 210 | self.shuffle_indices = [i for i in range(self._total_size)] 211 | self.to_tensor = transforms.ToTensor() 212 | random.shuffle(self.shuffle_indices) 213 | self._type = "dir" 214 | 215 | def _file_ext(self, fname): 216 | return os.path.splitext(fname)[1].lower() 217 | 218 | def _get_zipfile(self): 219 | assert self._type == 'zip' 220 | if self._zipfile is None: 221 | self._zipfile = zipfile.ZipFile(self._path) 222 | return self._zipfile 223 | 224 | def _open_file(self, fname): 225 | if self._type == 'dir': 226 | return open(os.path.join(fname), 'rb') 227 | if self._type == 'zip': 228 | return self._get_zipfile().open(fname, 'r') 229 | return None 230 | 231 | def close(self): 232 | try: 233 | if self._zipfile is not None: 234 | self._zipfile.close() 235 | finally: 236 | self._zipfile = None 237 | 238 | def _load_img_from_path(self, folder, fname): 239 | path = os.path.join(folder, fname) 240 | with self._open_file(path) as f: 241 | if pyspng is not None and self._file_ext(path) == '.png': 242 | img = pyspng.load(f.read()) 243 | img = rearrange(img, 'h w c -> c h w') 244 | else: 245 | img = self.to_tensor(PIL.Image.open(f)).numpy() * 255 # c h w 246 | return img 247 | 248 | 
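    # __getitem__ below samples a random window of frames from the chosen clip
    # folder; when the folder holds fewer than self.nframes images, only
    # self.nframes // 2 frames are loaded and the front of the clip is zero-padded,
    # so the returned tensor always has shape (self.nframes, 3, resolution, resolution)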
def __getitem__(self, index): 249 | index = self.shuffle_indices[index] 250 | path = self.imgs[index] 251 | 252 | # clip is a list of 32 frames 253 | video = natsorted(os.listdir(path[0])) 254 | 255 | # zero padding. only unconditional modeling 256 | if len(video) < self.nframes: 257 | prefix = np.random.randint(len(video)-self.nframes//2+1) 258 | clip = video[prefix:prefix+self.nframes//2] 259 | else: 260 | prefix = np.random.randint(len(video)-self.nframes+1) 261 | clip = video[prefix:prefix+self.nframes] 262 | 263 | assert (len(clip) == self.nframes or len(clip)*2 == self.nframes) 264 | 265 | vid = np.stack([self._load_img_from_path(folder=path[0], fname=clip[i]) for i in range(len(clip))], axis=0) 266 | vid = resize_crop(torch.from_numpy(vid).float(), resolution=self.img_resolution) # c t h w 267 | if vid.size(1) == self.nframes//2: 268 | vid = torch.cat([torch.zeros_like(vid).to(vid.device), vid], dim=1) 269 | 270 | return rearrange(vid, 'c t h w -> t c h w'), index 271 | 272 | 273 | def __len__(self): 274 | return self._total_size 275 | 276 | def get_loaders(rank, imgstr, resolution, timesteps, skip, batch_size=1, n_gpus=1, seed=42, cond=False, use_train_set=False): 277 | """ 278 | Load dataloaders for an image dataset, center-cropped to a resolution. 279 | """ 280 | if imgstr == 'cliport': 281 | base_path = '/path/to/cliport' 282 | num_workers = 20 283 | 284 | trainloader = Loader(f'{base_path}/video_subgoal_train_text.beton', batch_size=batch_size, 285 | num_workers=num_workers, order=OrderOption.RANDOM, 286 | pipelines={ 287 | 'video': [NDArrayDecoder(), ToTensor()], 288 | 'text': [NDArrayDecoder(),] 289 | }) 290 | 291 | testloader = Loader(f'{base_path}/video_subgoal_train_text.beton', 292 | batch_size=batch_size, 293 | num_workers=num_workers, order=OrderOption.RANDOM, 294 | pipelines={ 295 | 'video': [NDArrayDecoder(), ToTensor()], 296 | 'text': [NDArrayDecoder(),] 297 | }) 298 | 299 | return trainloader, testloader, len(trainloader) 300 | 301 | if imgstr == 'UCF101': 302 | train_dir = os.path.join(data_location, 'UCF-101') 303 | test_dir = os.path.join(data_location, 'UCF-101') # We use all 304 | if cond: 305 | print("here") 306 | timesteps *= 2 # for long generation 307 | trainset = VideoFolderDataset(train_dir, train=True, resolution=resolution, n_frames=timesteps, skip=skip, seed=seed) 308 | print(len(trainset)) 309 | testset = VideoFolderDataset(train_dir, train=False, resolution=resolution, n_frames=timesteps, skip=skip, seed=seed) 310 | print(len(testset)) 311 | 312 | elif imgstr == 'SKY': 313 | train_dir = os.path.join(data_location, 'SKY') 314 | test_dir = os.path.join(data_location, 'SKY') 315 | if cond: 316 | print("here") 317 | timesteps *= 2 # for long generation 318 | trainset = ImageFolderDataset(train_dir, train=True, resolution=resolution, nframes=timesteps, cond=cond) 319 | print(len(trainset)) 320 | testset = ImageFolderDataset(test_dir, train=False, resolution=resolution, nframes=timesteps, cond=cond) 321 | print(len(testset)) 322 | 323 | else: 324 | raise NotImplementedError() 325 | 326 | shuffle = False if use_train_set else True 327 | 328 | kwargs = {'pin_memory': True, 'num_workers': 3} 329 | 330 | trainset_sampler = InfiniteSampler(dataset=trainset, rank=rank, num_replicas=n_gpus, seed=seed) 331 | trainloader = DataLoader(trainset, sampler=trainset_sampler, batch_size=batch_size // n_gpus, pin_memory=False, num_workers=4, prefetch_factor=2) 332 | 333 | testset_sampler = InfiniteSampler(testset, num_replicas=n_gpus, rank=rank, seed=seed) 334 | testloader 
= DataLoader(testset, sampler=testset_sampler, batch_size=batch_size // n_gpus, pin_memory=False, num_workers=4, prefetch_factor=2) 335 | 336 | return trainloader, trainloader, testloader 337 | 338 | 339 | -------------------------------------------------------------------------------- /PVDM/tools/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
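        # cum_cycles[k] is the global step at which cycle k starts; find_in_interval()
        # below maps a step n to its cycle, and schedule() applies a linear warm-up
        # followed by a cosine decay from f_max[k] down to f_min[k] within that cycle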
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f -------------------------------------------------------------------------------- /PVDM/tools/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import sys; sys.path.extend([sys.path[0][:-4], '/app']) 5 | 6 | import time 7 | import tqdm 8 | import copy 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.cuda.amp import GradScaler, autocast 12 | 13 | 14 | from utils import AverageMeter 15 | from evals.eval import test_psnr, test_ifvd, test_fvd_ddpm 16 | from models.ema import LitEma 17 | from einops import rearrange 18 | from torch.optim.lr_scheduler import LambdaLR 19 | from tqdm import tqdm 20 | from transformers import T5Tokenizer, T5EncoderModel 21 | from torch.distributions import Bernoulli 22 | 23 | 24 | def latentDDPM(rank, first_stage_model, model, opt, criterion, train_loader, test_loader, scheduler, ema_model=None, cond_prob=0.9, logger=None): 25 | scaler = GradScaler() 26 | 27 | if logger is None: 28 | log_ = print 29 | else: 30 | log_ = logger.log 31 | 32 | if rank == 0: 33 | rootdir = logger.logdir 34 | 35 | device = torch.device('cuda', rank) 36 | 37 | tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base") 38 | text_model = T5EncoderModel.from_pretrained("google/flan-t5-base") 39 | text_model = text_model.to(device) 40 | 41 | mask_dist = Bernoulli(probs=cond_prob) 42 | batch_size = train_loader.batch_size 43 | 44 | with torch.no_grad(): 45 | uncond_tokens = torch.LongTensor([tokenizer('', padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 46 | uncond_latents = text_model(uncond_tokens).last_hidden_state.detach() 47 | 48 | losses = dict() 49 | losses['diffusion_loss'] = 
AverageMeter() 50 | check = time.time() 51 | 52 | if ema_model == None: 53 | ema_model = copy.deepcopy(model) 54 | ema = LitEma(ema_model) 55 | ema_model.eval() 56 | else: 57 | ema = LitEma(ema_model) 58 | ema.num_updates = torch.tensor(11200,dtype=torch.int) 59 | ema_model.eval() 60 | 61 | first_stage_model.eval() 62 | model.train() 63 | 64 | num_iters = 400000 # 25*len(train_loader) 65 | len_train_loader = len(train_loader) 66 | num_epochs = num_iters // len_train_loader 67 | print(f'Training for {num_epochs} epochs') 68 | 69 | for it_epoch in tqdm(range(num_epochs)): 70 | for it_loader, (x, text) in tqdm(enumerate(train_loader)): 71 | it = it_epoch*len_train_loader + it_loader 72 | x = x.to(device) 73 | x = rearrange(x / 127.5 - 1, 'b t h w c -> b c t h w') # videos 74 | 75 | text_masks = mask_dist.sample((batch_size,)).unsqueeze(1).unsqueeze(2).to(device) 76 | 77 | with torch.no_grad(): 78 | tokens = torch.LongTensor([tokenizer(text[i].tobytes().decode('ascii'), padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 79 | text_latents = text_model(tokens).last_hidden_state.detach() 80 | text_latents = text_latents*text_masks + uncond_latents*(1-text_masks) 81 | 82 | # conditional free guidance training 83 | model.zero_grad() 84 | 85 | x = x[:,:,:,:,:] 86 | c = x[:,:,0:1,:,:].repeat(1,1,x.shape[2],1,1) 87 | 88 | with autocast(): 89 | with torch.no_grad(): 90 | z = first_stage_model.extract(x).detach() 91 | c = first_stage_model.extract(c).detach() 92 | 93 | (loss, t), loss_dict = criterion(x = z.float(), cond = c.float(), context=text_latents) 94 | 95 | """ 96 | scaler.scale(loss).backward() 97 | scaler.step(opt) 98 | scaler.update() 99 | """ 100 | loss.backward() 101 | opt.step() 102 | 103 | losses['diffusion_loss'].update(loss.item(), 1) 104 | 105 | # ema model 106 | if it % 25 == 0 and it > 0: 107 | ema(model) 108 | 109 | if it % 500 == 0: 110 | if logger is not None and rank == 0: 111 | logger.scalar_summary('train/diffusion_loss', losses['diffusion_loss'].average, it) 112 | 113 | log_('[Time %.3f] [Diffusion %f]' % 114 | (time.time() - check, losses['diffusion_loss'].average)) 115 | 116 | losses = dict() 117 | losses['diffusion_loss'] = AverageMeter() 118 | 119 | 120 | if it % 2000 == 0 and rank == 0: 121 | torch.save(model.state_dict(), rootdir + f'model_{it}.pth') 122 | ema.copy_to(ema_model) 123 | torch.save(ema_model.state_dict(), rootdir + f'ema_model_{it}.pth') 124 | fvd = test_fvd_ddpm(rank, ema_model, first_stage_model, test_loader, it, tokenizer, text_model, uncond_latents, logger) 125 | 126 | if logger is not None and rank == 0: 127 | logger.scalar_summary('test/fvd', fvd, it) 128 | log_('[Time %.3f] [FVD %f]' % 129 | (time.time() - check, fvd)) 130 | 131 | def first_stage_train(rank, model, opt, d_opt, criterion, train_loader, test_loader, first_model, fp, logger=None): 132 | if logger is None: 133 | log_ = print 134 | else: 135 | log_ = logger.log 136 | 137 | if rank == 0: 138 | rootdir = logger.logdir 139 | 140 | device = torch.device('cuda', rank) 141 | 142 | losses = dict() 143 | losses['ae_loss'] = AverageMeter() 144 | losses['d_loss'] = AverageMeter() 145 | check = time.time() 146 | 147 | accum_iter = 3 148 | disc_opt = False 149 | 150 | if fp: 151 | scaler = GradScaler() 152 | scaler_d = GradScaler() 153 | 154 | try: 155 | scaler.load_state_dict(torch.load(os.path.join(first_model, 'scaler.pth'))) 156 | scaler_d.load_state_dict(torch.load(os.path.join(first_model, 'scaler_d.pth'))) 157 | except: 158 | print("Fail to load scalers. 
Start from initial point.") 159 | 160 | model.train() 161 | disc_start = criterion.discriminator_iter_start 162 | num_iters = 25*len(train_loader) 163 | len_train_loader = len(train_loader) 164 | num_epochs = num_iters // len_train_loader 165 | 166 | for it_epoch in range(num_epochs): 167 | for it_loader, (x, _) in enumerate(train_loader): 168 | # x is (b, t, h, w, c) 169 | it = it_epoch*len_train_loader + it_loader 170 | batch_size = x.size(0) 171 | x = x.permute(0, 1, 4, 2, 3) 172 | x = x.contiguous() 173 | 174 | x = x.to(device) 175 | x = rearrange(x / 127.5 - 1, 'b t c h w -> b c t h w') # videos 176 | 177 | if not disc_opt: 178 | with autocast(): 179 | x_tilde, vq_loss = model(x) 180 | 181 | if it % accum_iter == 0: 182 | model.zero_grad() 183 | ae_loss = criterion(vq_loss, x, 184 | rearrange(x_tilde, '(b t) c h w -> b c t h w', b=batch_size), 185 | optimizer_idx=0, 186 | global_step=it) 187 | 188 | ae_loss = ae_loss / accum_iter 189 | 190 | scaler.scale(ae_loss).backward() 191 | 192 | if it % accum_iter == accum_iter - 1: 193 | scaler.step(opt) 194 | scaler.update() 195 | 196 | losses['ae_loss'].update(ae_loss.item(), 1) 197 | 198 | else: 199 | if it % accum_iter == 0: 200 | criterion.zero_grad() 201 | 202 | with autocast(): 203 | with torch.no_grad(): 204 | x_tilde, vq_loss = model(x) 205 | d_loss = criterion(vq_loss, x, 206 | rearrange(x_tilde, '(b t) c h w -> b c t h w', b=batch_size), 207 | optimizer_idx=1, 208 | global_step=it) 209 | d_loss = d_loss / accum_iter 210 | 211 | scaler_d.scale(d_loss).backward() 212 | 213 | if it % accum_iter == accum_iter - 1: 214 | # Unscales the gradients of optimizer's assigned params in-place 215 | scaler_d.unscale_(d_opt) 216 | 217 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual: 218 | torch.nn.utils.clip_grad_norm_(criterion.discriminator_2d.parameters(), 1.0) 219 | torch.nn.utils.clip_grad_norm_(criterion.discriminator_3d.parameters(), 1.0) 220 | 221 | scaler_d.step(d_opt) 222 | scaler_d.update() 223 | 224 | losses['d_loss'].update(d_loss.item() * 3, 1) 225 | 226 | if it % accum_iter == accum_iter - 1 and it // accum_iter >= disc_start: 227 | if disc_opt: 228 | disc_opt = False 229 | else: 230 | disc_opt = True 231 | 232 | if it % 2000 == 0: 233 | fvd = test_ifvd(rank, model, test_loader, it, logger) 234 | psnr = test_psnr(rank, model, test_loader, it, logger) 235 | if logger is not None and rank == 0: 236 | logger.scalar_summary('train/ae_loss', losses['ae_loss'].average, it) 237 | logger.scalar_summary('train/d_loss', losses['d_loss'].average, it) 238 | logger.scalar_summary('test/psnr', psnr, it) 239 | logger.scalar_summary('test/fvd', fvd, it) 240 | 241 | log_('[Time %.3f] [AELoss %f] [DLoss %f] [PSNR %f] [FVD %f]' % 242 | (time.time() - check, losses['ae_loss'].average, losses['d_loss'].average, psnr, fvd)) 243 | # print('[Time %.3f] [AELoss %f] [DLoss %f] [PSNR %f]' % 244 | # (time.time() - check, losses['ae_loss'].average, losses['d_loss'].average, psnr)) 245 | 246 | torch.save(model.state_dict(), rootdir + f'model_last.pth') 247 | torch.save(criterion.state_dict(), rootdir + f'loss_last.pth') 248 | torch.save(opt.state_dict(), rootdir + f'opt.pth') 249 | torch.save(d_opt.state_dict(), rootdir + f'd_opt.pth') 250 | torch.save(scaler.state_dict(), rootdir + f'scaler.pth') 251 | torch.save(scaler_d.state_dict(), rootdir + f'scaler_d.pth') 252 | 253 | losses = dict() 254 | losses['ae_loss'] = AverageMeter() 255 | losses['d_loss'] = AverageMeter() 256 | 257 | if it % 2000 == 0 and rank == 0: 258 | 
torch.save(model.state_dict(), rootdir + f'model_{it}.pth') 259 | 260 | -------------------------------------------------------------------------------- /PVDM/tools/video_utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import math 3 | import warnings 4 | from fractions import Fraction 5 | from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast 6 | 7 | import torch 8 | from torchvision.io import ( 9 | _probe_video_from_file, 10 | _read_video_from_file, 11 | read_video, 12 | read_video_timestamps, 13 | ) 14 | 15 | from tqdm import tqdm 16 | 17 | T = TypeVar("T") 18 | 19 | 20 | def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int: 21 | """convert pts between different time bases 22 | Args: 23 | pts: presentation timestamp, float 24 | timebase_from: original timebase. Fraction 25 | timebase_to: new timebase. Fraction 26 | round_func: rounding function. 27 | """ 28 | new_pts = Fraction(pts, 1) * timebase_from / timebase_to 29 | return round_func(new_pts) 30 | 31 | 32 | def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor: 33 | """ 34 | similar to tensor.unfold, but with the dilation 35 | and specialized for 1d tensors 36 | Returns all consecutive windows of `size` elements, with 37 | `step` between windows. The distance between each element 38 | in a window is given by `dilation`. 39 | """ 40 | if tensor.dim() != 1: 41 | raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}") 42 | o_stride = tensor.stride(0) 43 | numel = tensor.numel() 44 | new_stride = (step * o_stride, dilation * o_stride) 45 | new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size) 46 | if new_size[0] < 1: 47 | new_size = (0, size) 48 | return torch.as_strided(tensor, new_size, new_stride) 49 | 50 | 51 | class _VideoTimestampsDataset: 52 | """ 53 | Dataset used to parallelize the reading of the timestamps 54 | of a list of videos, given their paths in the filesystem. 55 | Used in VideoClips and defined at top level so it can be 56 | pickled when forking. 57 | """ 58 | 59 | def __init__(self, video_paths: List[str]) -> None: 60 | self.video_paths = video_paths 61 | 62 | def __len__(self) -> int: 63 | return len(self.video_paths) 64 | 65 | def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]: 66 | return read_video_timestamps(self.video_paths[idx]) 67 | 68 | 69 | def _collate_fn(x: T) -> T: 70 | """ 71 | Dummy collate function to be used with _VideoTimestampsDataset 72 | """ 73 | return x 74 | 75 | 76 | class VideoClips: 77 | """ 78 | Given a list of video files, computes all consecutive subvideos of size 79 | `clip_length_in_frames`, where the distance between each subvideo in the 80 | same video is defined by `frames_between_clips`. 81 | If `frame_rate` is specified, it will also resample all the videos to have 82 | the same frame rate, and the clips will refer to this frame rate. 83 | Creating this instance the first time is time-consuming, as it needs to 84 | decode all the videos in `video_paths`. It is recommended that you 85 | cache the results after instantiation of the class. 86 | Recreating the clips for different clip lengths is fast, and can be done 87 | with the `compute_clips` method. 
88 | Args: 89 | video_paths (List[str]): paths to the video files 90 | clip_length_in_frames (int): size of a clip in number of frames 91 | frames_between_clips (int): step (in frames) between each clip 92 | frame_rate (int, optional): if specified, it will resample the video 93 | so that it has `frame_rate`, and then the clips will be defined 94 | on the resampled video 95 | num_workers (int): how many subprocesses to use for data loading. 96 | 0 means that the data will be loaded in the main process. (default: 0) 97 | """ 98 | 99 | def __init__( 100 | self, 101 | video_paths: List[str], 102 | clip_length_in_frames: int = 16, 103 | frames_between_clips: int = 1, 104 | frame_rate: Optional[int] = None, 105 | _precomputed_metadata: Optional[Dict[str, Any]] = None, 106 | num_workers: int = 0, 107 | _video_width: int = 0, 108 | _video_height: int = 0, 109 | _video_min_dimension: int = 0, 110 | _video_max_dimension: int = 0, 111 | _audio_samples: int = 0, 112 | _audio_channels: int = 0, 113 | ) -> None: 114 | 115 | self.video_paths = video_paths 116 | self.num_workers = num_workers 117 | 118 | # these options are not valid for pyav backend 119 | self._video_width = _video_width 120 | self._video_height = _video_height 121 | self._video_min_dimension = _video_min_dimension 122 | self._video_max_dimension = _video_max_dimension 123 | self._audio_samples = _audio_samples 124 | self._audio_channels = _audio_channels 125 | 126 | if _precomputed_metadata is None: 127 | self._compute_frame_pts() 128 | else: 129 | self._init_from_metadata(_precomputed_metadata) 130 | self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) 131 | 132 | def _compute_frame_pts(self) -> None: 133 | self.video_pts = [] 134 | self.video_fps = [] 135 | 136 | # strategy: use a DataLoader to parallelize read_video_timestamps 137 | # so need to create a dummy dataset first 138 | import torch.utils.data 139 | 140 | dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader( 141 | _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type] 142 | batch_size=16, 143 | num_workers=self.num_workers, 144 | collate_fn=_collate_fn, 145 | ) 146 | 147 | with tqdm(total=len(dl)) as pbar: 148 | for batch in dl: 149 | pbar.update(1) 150 | clips, fps = list(zip(*batch)) 151 | # we need to specify dtype=torch.long because for empty list, 152 | # torch.as_tensor will use torch.float as default dtype. This 153 | # happens when decoding fails and no pts is returned in the list. 
154 | clips = [torch.as_tensor(c, dtype=torch.long) for c in clips] 155 | self.video_pts.extend(clips) 156 | self.video_fps.extend(fps) 157 | 158 | def _init_from_metadata(self, metadata: Dict[str, Any]) -> None: 159 | self.video_paths = metadata["video_paths"] 160 | assert len(self.video_paths) == len(metadata["video_pts"]) 161 | self.video_pts = metadata["video_pts"] 162 | assert len(self.video_paths) == len(metadata["video_fps"]) 163 | self.video_fps = metadata["video_fps"] 164 | 165 | @property 166 | def metadata(self) -> Dict[str, Any]: 167 | _metadata = { 168 | "video_paths": self.video_paths, 169 | "video_pts": self.video_pts, 170 | "video_fps": self.video_fps, 171 | } 172 | return _metadata 173 | 174 | def subset(self, indices: List[int]) -> "VideoClips": 175 | video_paths = [self.video_paths[i] for i in indices] 176 | video_pts = [self.video_pts[i] for i in indices] 177 | video_fps = [self.video_fps[i] for i in indices] 178 | metadata = { 179 | "video_paths": video_paths, 180 | "video_pts": video_pts, 181 | "video_fps": video_fps, 182 | } 183 | return type(self)( 184 | video_paths, 185 | self.num_frames, 186 | self.step, 187 | self.frame_rate, 188 | _precomputed_metadata=metadata, 189 | num_workers=self.num_workers, 190 | _video_width=self._video_width, 191 | _video_height=self._video_height, 192 | _video_min_dimension=self._video_min_dimension, 193 | _video_max_dimension=self._video_max_dimension, 194 | _audio_samples=self._audio_samples, 195 | _audio_channels=self._audio_channels, 196 | ) 197 | 198 | @staticmethod 199 | def compute_clips_for_video( 200 | video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None 201 | ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]: 202 | if fps is None: 203 | # if for some reason the video doesn't have fps (because doesn't have a video stream) 204 | # set the fps to 1. The value doesn't matter, because video_pts is empty anyway 205 | fps = 1 206 | if frame_rate is None: 207 | frame_rate = fps 208 | total_frames = len(video_pts) * (float(frame_rate) / fps) 209 | _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) 210 | video_pts = video_pts[_idxs] 211 | clips = unfold(video_pts, num_frames, step) 212 | if not clips.numel(): 213 | warnings.warn( 214 | "There aren't enough frames in the current video to get a clip for the given clip length and " 215 | "frames between clips. The video (and potentially others) will be skipped." 216 | ) 217 | idxs: Union[List[slice], torch.Tensor] 218 | if isinstance(_idxs, slice): 219 | idxs = [_idxs] * len(clips) 220 | else: 221 | idxs = unfold(_idxs, num_frames, step) 222 | return clips, idxs 223 | 224 | def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None: 225 | """ 226 | Compute all consecutive sequences of clips from video_pts. 227 | Always returns clips of size `num_frames`, meaning that the 228 | last few frames in a video can potentially be dropped. 
229 | Args: 230 | num_frames (int): number of frames for the clip 231 | step (int): distance between two clips 232 | frame_rate (int, optional): The frame rate 233 | """ 234 | self.num_frames = num_frames 235 | self.step = step 236 | self.frame_rate = frame_rate 237 | self.clips = [] 238 | self.resampling_idxs = [] 239 | for video_pts, fps in zip(self.video_pts, self.video_fps): 240 | clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) 241 | self.clips.append(clips) 242 | self.resampling_idxs.append(idxs) 243 | clip_lengths = torch.as_tensor([len(v) for v in self.clips]) 244 | self.cumulative_sizes = clip_lengths.cumsum(0).tolist() 245 | 246 | def __len__(self) -> int: 247 | return self.num_clips() 248 | 249 | def num_videos(self) -> int: 250 | return len(self.video_paths) 251 | 252 | def num_clips(self) -> int: 253 | """ 254 | Number of subclips that are available in the video list. 255 | """ 256 | return self.cumulative_sizes[-1] 257 | 258 | def get_clip_location(self, idx: int) -> Tuple[int, int]: 259 | """ 260 | Converts a flattened representation of the indices into a video_idx, clip_idx 261 | representation. 262 | """ 263 | video_idx = bisect.bisect_right(self.cumulative_sizes, idx) 264 | if video_idx == 0: 265 | clip_idx = idx 266 | else: 267 | clip_idx = idx - self.cumulative_sizes[video_idx - 1] 268 | return video_idx, clip_idx 269 | 270 | @staticmethod 271 | def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]: 272 | step = float(original_fps) / new_fps 273 | if step.is_integer(): 274 | # optimization: if step is integer, don't need to perform 275 | # advanced indexing 276 | step = int(step) 277 | return slice(None, None, step) 278 | idxs = torch.arange(num_frames, dtype=torch.float32) * step 279 | idxs = idxs.floor().to(torch.int64) 280 | return idxs 281 | 282 | def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]: 283 | """ 284 | Gets a subclip from a list of videos. 285 | Args: 286 | idx (int): index of the subclip. Must be between 0 and num_clips(). 
287 | Returns: 288 | video (Tensor) 289 | audio (Tensor) 290 | info (Dict) 291 | video_idx (int): index of the video in `video_paths` 292 | """ 293 | if idx >= self.num_clips(): 294 | raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)") 295 | video_idx, clip_idx = self.get_clip_location(idx) 296 | video_path = self.video_paths[video_idx] 297 | clip_pts = self.clips[video_idx][clip_idx] 298 | 299 | from torchvision import get_video_backend 300 | 301 | backend = get_video_backend() 302 | 303 | if backend == "pyav": 304 | # check for invalid options 305 | if self._video_width != 0: 306 | raise ValueError("pyav backend doesn't support _video_width != 0") 307 | if self._video_height != 0: 308 | raise ValueError("pyav backend doesn't support _video_height != 0") 309 | if self._video_min_dimension != 0: 310 | raise ValueError("pyav backend doesn't support _video_min_dimension != 0") 311 | if self._video_max_dimension != 0: 312 | raise ValueError("pyav backend doesn't support _video_max_dimension != 0") 313 | if self._audio_samples != 0: 314 | raise ValueError("pyav backend doesn't support _audio_samples != 0") 315 | 316 | if backend == "pyav": 317 | start_pts = clip_pts[0].item() 318 | end_pts = clip_pts[-1].item() 319 | video, audio, info = read_video(video_path, start_pts, end_pts) 320 | else: 321 | _info = _probe_video_from_file(video_path) 322 | video_fps = _info.video_fps 323 | audio_fps = None 324 | 325 | video_start_pts = cast(int, clip_pts[0].item()) 326 | video_end_pts = cast(int, clip_pts[-1].item()) 327 | 328 | audio_start_pts, audio_end_pts = 0, -1 329 | audio_timebase = Fraction(0, 1) 330 | video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator) 331 | if _info.has_audio: 332 | audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator) 333 | audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) 334 | audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) 335 | audio_fps = _info.audio_sample_rate 336 | video, audio, _ = _read_video_from_file( 337 | video_path, 338 | video_width=self._video_width, 339 | video_height=self._video_height, 340 | video_min_dimension=self._video_min_dimension, 341 | video_max_dimension=self._video_max_dimension, 342 | video_pts_range=(video_start_pts, video_end_pts), 343 | video_timebase=video_timebase, 344 | audio_samples=self._audio_samples, 345 | audio_channels=self._audio_channels, 346 | audio_pts_range=(audio_start_pts, audio_end_pts), 347 | audio_timebase=audio_timebase, 348 | ) 349 | 350 | info = {"video_fps": video_fps} 351 | if audio_fps is not None: 352 | info["audio_fps"] = audio_fps 353 | 354 | if self.frame_rate is not None: 355 | resampling_idx = self.resampling_idxs[video_idx][clip_idx] 356 | if isinstance(resampling_idx, torch.Tensor): 357 | resampling_idx = resampling_idx - resampling_idx[0] 358 | video = video[resampling_idx] 359 | info["video_fps"] = self.frame_rate 360 | assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" 361 | return video, audio, info, video_idx 362 | 363 | def __getstate__(self) -> Dict[str, Any]: 364 | video_pts_sizes = [len(v) for v in self.video_pts] 365 | # To be back-compatible, we convert data to dtype torch.long as needed 366 | # because for empty list, in legacy implementation, torch.as_tensor will 367 | # use torch.float as default dtype. This happens when decoding fails and 368 | # no pts is returned in the list. 
369 | video_pts = [x.to(torch.int64) for x in self.video_pts] 370 | # video_pts can be an empty list if no frames have been decoded 371 | if video_pts: 372 | video_pts = torch.cat(video_pts) # type: ignore[assignment] 373 | # avoid bug in https://github.com/pytorch/pytorch/issues/32351 374 | # TODO: Revert it once the bug is fixed. 375 | video_pts = video_pts.numpy() # type: ignore[attr-defined] 376 | 377 | # make a copy of the fields of self 378 | d = self.__dict__.copy() 379 | d["video_pts_sizes"] = video_pts_sizes 380 | d["video_pts"] = video_pts 381 | # delete the following attributes to reduce the size of dictionary. They 382 | # will be re-computed in "__setstate__()" 383 | del d["clips"] 384 | del d["resampling_idxs"] 385 | del d["cumulative_sizes"] 386 | 387 | # for backwards-compatibility 388 | d["_version"] = 2 389 | return d 390 | 391 | def __setstate__(self, d: Dict[str, Any]) -> None: 392 | # for backwards-compatibility 393 | if "_version" not in d: 394 | self.__dict__ = d 395 | return 396 | 397 | video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64) 398 | video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0) 399 | # don't need this info anymore 400 | del d["video_pts_sizes"] 401 | 402 | d["video_pts"] = video_pts 403 | self.__dict__ = d 404 | # recompute attributes "clips", "resampling_idxs" and other derivative ones 405 | self.compute_clips(self.num_frames, self.step, self.frame_rate) -------------------------------------------------------------------------------- /PVDM/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import sys 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import gdown 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, fn, ask=True): 16 | if not os.path.exists("./results/"): 17 | os.mkdir("./results/") 18 | 19 | logdir = self._make_dir(fn) 20 | if not os.path.exists(logdir): 21 | os.mkdir(logdir) 22 | 23 | if len(os.listdir(logdir)) != 0 and ask: 24 | exit(1) 25 | 26 | self.set_dir(logdir) 27 | 28 | def _make_dir(self, fn): 29 | # today = datetime.today().strftime("%y%m%d") 30 | logdir = f'./results/{fn}/' 31 | return logdir 32 | 33 | def set_dir(self, logdir, log_fn='log.txt'): 34 | self.logdir = logdir 35 | if not os.path.exists(logdir): 36 | os.mkdir(logdir) 37 | self.writer = SummaryWriter(logdir) 38 | self.log_file = open(os.path.join(logdir, log_fn), 'a') 39 | 40 | def log(self, string): 41 | self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n') 42 | self.log_file.flush() 43 | 44 | print('[%s] %s' % (datetime.now(), string)) 45 | sys.stdout.flush() 46 | 47 | def log_dirname(self, string): 48 | self.log_file.write('%s (%s)' % (string, self.logdir) + '\n') 49 | self.log_file.flush() 50 | 51 | print('%s (%s)' % (string, self.logdir)) 52 | sys.stdout.flush() 53 | 54 | def scalar_summary(self, tag, value, step): 55 | """Log a scalar variable.""" 56 | self.writer.add_scalar(tag, value, step) 57 | 58 | def image_summary(self, tag, images, step): 59 | """Log a list of images.""" 60 | self.writer.add_image(tag, images, step) 61 | 62 | def video_summary(self, tag, videos, step): 63 | self.writer.add_video(tag, videos, step, fps=16) 64 | 65 | def histo_summary(self, tag, values, step): 66 | """Log a histogram of the tensor of values.""" 67 | self.writer.add_histogram(tag, values, step, bins='auto') 68 | 69 | 70 | class 
AverageMeter(object): 71 | """Computes and stores the average and current value""" 72 | 73 | def __init__(self): 74 | self.value = 0 75 | self.average = 0 76 | self.sum = 0 77 | self.count = 0 78 | 79 | def reset(self): 80 | self.value = 0 81 | self.average = 0 82 | self.sum = 0 83 | self.count = 0 84 | 85 | def update(self, value, n=1): 86 | self.value = value 87 | self.sum += value * n 88 | self.count += n 89 | self.average = self.sum / self.count 90 | 91 | 92 | def set_random_seed(seed): 93 | random.seed(seed) 94 | np.random.seed(seed) 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed(seed) 97 | torch.cuda.manual_seed_all(seed) 98 | 99 | 100 | def file_name(args): 101 | fn = f'{args.exp}_{args.id}_{args.data}' 102 | fn += f'_{args.seed}' 103 | return fn 104 | 105 | 106 | def psnr(mse): 107 | """ 108 | Computes PSNR from MSE. 109 | """ 110 | return -10.0 * mse.log10() 111 | 112 | def download(id, fname, root=os.path.expanduser('~/.cache/video-diffusion')): 113 | os.makedirs(root, exist_ok=True) 114 | destination = os.path.join(root, fname) 115 | 116 | if os.path.exists(destination): 117 | return destination 118 | 119 | gdown.download(id=id, output=destination, quiet=False) 120 | return destination 121 | 122 | 123 | def make_pairs(l, t1, t2, num_pairs, given_vid): 124 | B, T, C, H, W = given_vid.size() 125 | idx1 = t1.view(B, num_pairs, 1, 1, 1, 1).expand(B, num_pairs, 1, C, H, W).type(torch.int64) 126 | frame1 = torch.gather(given_vid.unsqueeze(1).repeat(1,num_pairs, 1,1,1,1), 2, idx1).squeeze() 127 | t1 = t1.float() / (l - 1) 128 | 129 | idx2 = t2.view(B, num_pairs, 1, 1, 1, 1).expand(B, num_pairs, 1, C, H, W).type(torch.int64) 130 | frame2 = torch.gather(given_vid.unsqueeze(1).repeat(1,num_pairs,1,1,1,1), 2, idx2).squeeze() 131 | t2 = t2.float() / (l - 1) 132 | 133 | frame1 = frame1.view(-1, C, H, W) 134 | frame2 = frame2.view(-1, C, H, W) 135 | 136 | # sort by t 137 | t1 = t1.view(-1, 1, 1, 1).repeat(1, C, H, W) 138 | t2 = t2.view(-1, 1, 1, 1).repeat(1, C, H, W) 139 | 140 | ret_frame1 = torch.where(t1 < t2, frame1, frame2) 141 | ret_frame2 = torch.where(t1 < t2, frame2 ,frame1) 142 | 143 | t1 = t1[:, 0:1] 144 | t2 = t2[:, 0:1] 145 | 146 | ret_t1 = torch.where(t1 < t2, t1, t2) 147 | ret_t2 = torch.where(t1 < t2, t2, t1) 148 | 149 | dt = ret_t2 - ret_t1 150 | 151 | return torch.cat([ret_frame1, ret_frame2, dt], dim=1) 152 | 153 | def make_mixed_pairs(l, t1, t2, given_vid_real, given_vid_fake): 154 | B, T, C, H, W = given_vid_real.size() 155 | idx1 = t1.view(-1, 1, 1, 1, 1).expand(B, 1, C, H, W).type(torch.int64) 156 | frame1 = torch.gather(given_vid_real, 1, idx1).squeeze() 157 | t1 = t1.float() / (l - 1) 158 | 159 | idx2 = t2.view(-1, 1, 1, 1, 1).expand(B, 1, C, H, W).type(torch.int64) 160 | frame2 = torch.gather(given_vid_fake, 1, idx2).squeeze() 161 | t2 = t2.float() / (l - 1) 162 | 163 | 164 | # sort by t 165 | t1 = t1.view(-1, 1, 1, 1).repeat(1, C, H, W) 166 | t2 = t2.view(-1, 1, 1, 1).repeat(1, C, H, W) 167 | 168 | ret_frame1 = torch.where(t1 < t2, frame1, frame2) 169 | ret_frame2 = torch.where(t1 < t2, frame2 ,frame1) 170 | 171 | t1 = t1[:, 0:1] 172 | t2 = t2[:, 0:1] 173 | 174 | ret_t1 = torch.where(t1 < t2, t1, t2) 175 | ret_t2 = torch.where(t1 < t2, t2, t1) 176 | 177 | dt = ret_t2 - ret_t1 178 | 179 | return torch.cat([ret_frame1, ret_frame2, dt], dim=1) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compositional Foundation Models for 
Hierarchical Planning 2 |