├── PVDM ├── .gitignore ├── assets │ ├── sky_long.gif │ └── ucf101_long.gif ├── configs │ ├── autoencoder │ │ ├── base.yaml │ │ └── base_gan.yaml │ └── latent-diffusion │ │ └── base.yaml ├── evals │ ├── __init__.py │ ├── eval.py │ └── fvd │ │ ├── __init__.py │ │ ├── convert_tf_pretrained.py │ │ ├── download.py │ │ ├── fvd.py │ │ └── pytorch_i3d.py ├── exps │ ├── __init__.py │ ├── diffusion.py │ └── first_stage.py ├── losses │ ├── ddpm.py │ ├── diffaugment.py │ ├── lpips.py │ ├── perceptual.py │ └── vgg.pth ├── main.py ├── metric_utils.py ├── models │ ├── __init__.py │ ├── autoencoder │ │ ├── __init__.py │ │ ├── autoencoder_vit.py │ │ └── vit_modules.py │ ├── ddpm │ │ ├── __init__.py │ │ ├── diffusionmodules.py │ │ └── unet.py │ └── ema.py ├── tools │ ├── __init__.py │ ├── data_utils.py │ ├── dataloader.py │ ├── scheduler.py │ ├── trainer.py │ └── video_utils.py └── utils.py ├── README.md ├── images └── teaser.png ├── inv_dyn └── inv_dyn_ft.py └── task_subgoal_consistency ├── __init__.py ├── arguments.py ├── datasets.py ├── networks.py ├── train.py └── utils.py /PVDM/.gitignore: -------------------------------------------------------------------------------- 1 | sftp-config.json 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /PVDM/assets/sky_long.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/assets/sky_long.gif -------------------------------------------------------------------------------- /PVDM/assets/ucf101_long.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/assets/ucf101_long.gif -------------------------------------------------------------------------------- /PVDM/configs/autoencoder/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | resume: False 3 | amp: True 4 | base_learning_rate: 1.0e-4 5 | params: 6 | embed_dim: 4 7 | lossconfig: 8 | params: 9 | disc_start: 100000000 10 | 11 | ddconfig: 12 | double_z: False 13 | channels: 384 14 | resolution: 64 15 | timesteps: 16 16 | skip: 1 17 | in_channels: 3 18 | out_ch: 3 19 | num_res_blocks: 2 20 | attn_resolutions: [] 21 | splits: 1 22 | -------------------------------------------------------------------------------- /PVDM/configs/autoencoder/base_gan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | resume: True 3 | amp: True 4 | base_learning_rate: 1.0e-4 5 | params: 6 | embed_dim: 4 7 | lossconfig: 8 | params: 9 | disc_start: -1 10 | 11 | ddconfig: 12 | double_z: False 13 | channels: 384 14 | resolution: 64 15 | timesteps: 16 16 | skip: 1 17 | in_channels: 3 18 | out_ch: 3 19 | num_res_blocks: 2 20 | attn_resolutions: [] 21 | splits: 1 22 | -------------------------------------------------------------------------------- /PVDM/configs/latent-diffusion/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 # set to target_lr by starting main.py with '--scale_lr False' 3 | cond_model: True 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | 
cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | w: 0. 19 | 20 | scheduler_config: # 10000 warmup steps 21 | warm_up_steps: [10000] 22 | cycle_lengths: [10000000000000] 23 | f_start: [1.e-6] 24 | f_max: [1.] 25 | f_min: [ 1.] 26 | 27 | unet_config: 28 | image_size: 32 29 | in_channels: 4 30 | out_channels: 4 31 | model_channels: 128 32 | attention_resolutions: [4,2,1] # 32, 16, 8, 4, 33 | num_res_blocks: 2 34 | channel_mult: [1,2,4] # 32, 16, 8, 4, 2 35 | num_heads: 8 36 | use_scale_shift_norm: True 37 | resblock_updown: True 38 | cond_model: True -------------------------------------------------------------------------------- /PVDM/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/evals/__init__.py -------------------------------------------------------------------------------- /PVDM/evals/eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys; sys.path.extend(['.', 'src']) 3 | import numpy as np 4 | import torch 5 | from utils import AverageMeter 6 | from torchvision.utils import save_image, make_grid 7 | from einops import rearrange 8 | from losses.ddpm import DDPM 9 | from torch.cuda.amp import GradScaler, autocast 10 | 11 | from evals.fvd.fvd import get_fvd_logits, frechet_distance 12 | from evals.fvd.download import load_i3d_pretrained 13 | import os 14 | 15 | import torchvision 16 | import PIL 17 | 18 | def save_image_grid(img, fname, drange, grid_size, normalize=True): 19 | if normalize: 20 | lo, hi = drange 21 | img = np.asarray(img, dtype=np.float32) 22 | img = (img - lo) * (255 / (hi - lo)) 23 | img = np.rint(img).clip(0, 255).astype(np.uint8) 24 | 25 | gw, gh = grid_size 26 | _N, C, T, H, W = img.shape 27 | img = img.reshape(gh, gw, C, T, H, W) 28 | img = img.transpose(3, 0, 4, 1, 5, 2) 29 | img = img.reshape(T, gh * H, gw * W, C) 30 | 31 | print (f'Saving Video with {T} frames, img shape {H}, {W}') 32 | 33 | assert C in [3] 34 | 35 | if C == 3: 36 | torchvision.io.write_video(f'{fname[:-3]}mp4', torch.from_numpy(img), fps=16) 37 | imgs = [PIL.Image.fromarray(img[i], 'RGB') for i in range(len(img))] 38 | imgs[0].save(fname, quality=95, save_all=True, append_images=imgs[1:], duration=100, loop=0) 39 | 40 | return img 41 | 42 | def test_psnr(rank, model, loader, it, logger=None): 43 | device = torch.device('cuda', rank) 44 | 45 | losses = dict() 46 | losses['psnr'] = AverageMeter() 47 | check = time.time() 48 | 49 | model.eval() 50 | with torch.no_grad(): 51 | for n, (x, _) in enumerate(loader): 52 | if n > 100: 53 | break 54 | x = x.permute(0, 1, 4, 2, 3) 55 | x = x.contiguous() 56 | 57 | batch_size = x.size(0) 58 | clip_length = x.size(1) 59 | x = x.to(device) / 127.5 - 1 60 | recon, _ = model(rearrange(x, 'b t c h w -> b c t h w')) 61 | 62 | x = x.view(batch_size, -1) 63 | recon = recon.view(batch_size, -1) 64 | 65 | mse = ((x * 0.5 - recon * 0.5) ** 2).mean(dim=-1) 66 | psnr = (-10 * torch.log10(mse)).mean() 67 | 68 | losses['psnr'].update(psnr.item(), batch_size) 69 | 70 | 71 | model.train() 72 | return losses['psnr'].average 73 | 74 | def test_ifvd(rank, model, loader, it, logger=None): 75 | device = torch.device('cuda', rank) 76 | 77 | losses = dict() 78 | losses['fvd'] = AverageMeter() 79 | check = time.time() 80 | 81 | real_embeddings = [] 82 | fake_embeddings = [] 83 | fakes = [] 84 | reals = [] 85 | 86 | model.eval() 87 | i3d = 
load_i3d_pretrained(device) 88 | 89 | with torch.no_grad(): 90 | for n, (real, idx) in enumerate(loader): 91 | if n > 512: 92 | break 93 | real = real.permute(0, 1, 4, 2, 3) 94 | real = real.contiguous() 95 | 96 | batch_size = real.size(0) 97 | clip_length = real.size(1) 98 | real = real.to(device) 99 | fake, _ = model(rearrange(real / 127.5 - 1, 'b t c h w -> b c t h w')) 100 | 101 | real = rearrange(real, 'b t c h w -> b t h w c') # videos 102 | fake = rearrange((fake.clamp(-1,1) + 1) * 127.5, '(b t) c h w -> b t h w c', b=real.size(0)) 103 | 104 | real = real.type(torch.uint8).cpu() 105 | fake = fake.type(torch.uint8) 106 | 107 | real_embeddings.append(get_fvd_logits(real.numpy(), i3d=i3d, device=device)) 108 | fake_embeddings.append(get_fvd_logits(fake.cpu().numpy(), i3d=i3d, device=device)) 109 | if len(fakes) < 16: 110 | reals.append(rearrange(real[0:1], 'b t h w c -> b c t h w')) 111 | fakes.append(rearrange(fake[0:1], 'b t h w c -> b c t h w')) 112 | 113 | model.train() 114 | 115 | reals = torch.cat(reals) 116 | fakes = torch.cat(fakes) 117 | 118 | if rank == 0: 119 | real_vid = save_image_grid(reals.cpu().numpy(), os.path.join(logger.logdir, "real.gif"), drange=[0, 255], grid_size=(4,4)) 120 | fake_vid = save_image_grid(fakes.cpu().numpy(), os.path.join(logger.logdir, f'generated_{it}.gif'), drange=[0, 255], grid_size=(4,4)) 121 | 122 | if it == 0: 123 | real_vid = np.expand_dims(real_vid,0).transpose(0, 1, 4, 2, 3) 124 | logger.video_summary('real', real_vid, it) 125 | 126 | fake_vid = np.expand_dims(fake_vid,0).transpose(0, 1, 4, 2, 3) 127 | logger.video_summary('recon', fake_vid, it) 128 | 129 | real_embeddings = torch.cat(real_embeddings) 130 | fake_embeddings = torch.cat(fake_embeddings) 131 | 132 | fvd = frechet_distance(fake_embeddings.clone().detach(), real_embeddings.clone().detach()) 133 | return fvd.item() 134 | 135 | 136 | def test_fvd_ddpm(rank, ema_model, decoder, loader, it, tokenizer, text_model, uncond_latents, logger=None): 137 | device = torch.device('cuda', rank) 138 | 139 | losses = dict() 140 | losses['fvd'] = AverageMeter() 141 | check = time.time() 142 | 143 | cond_model = ema_model.diffusion_model.cond_model 144 | 145 | diffusion_model = DDPM(ema_model, 146 | channels=ema_model.diffusion_model.in_channels, 147 | image_size=ema_model.diffusion_model.image_size, 148 | sampling_timesteps=1000, 149 | w=0.).to(device) 150 | real_embeddings = [] 151 | pred_embeddings = [] 152 | 153 | reals = [] 154 | predictions = [] 155 | 156 | batch_size = loader.batch_size 157 | 158 | i3d = load_i3d_pretrained(device) 159 | 160 | if cond_model: 161 | with torch.no_grad(): 162 | for n, (x, text) in enumerate(loader): 163 | x = x.to(device) 164 | x = rearrange(x / 127.5 - 1, 'b t h w c -> b c t h w') # videos 165 | 166 | k = min(4, x.size(0)) 167 | if n >= 4: 168 | break 169 | 170 | tokens = torch.LongTensor([tokenizer(text[i].tobytes().decode('ascii'), padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 171 | text_latents = text_model(tokens).last_hidden_state.detach() 172 | text_latents = text_latents[:k] 173 | 174 | real = x[:k,:,:,:,:] 175 | c = x[:k,:,0:1,:,:].repeat(1,1,x.shape[2],1,1) 176 | 177 | with autocast(): 178 | c = decoder.extract(c).detach() 179 | 180 | z = diffusion_model.sample(batch_size=k, cond=c, context=text_latents, uncond_latents=uncond_latents[:k]) 181 | pred = decoder.decode_from_sample(z).clamp(-1,1).cpu() 182 | 183 | pred = (1 + rearrange(pred, '(b t) c h w -> b t h w c', b=k)) * 127.5 184 | pred = 
pred.type(torch.uint8) 185 | pred_embeddings.append(get_fvd_logits(pred.numpy(), i3d=i3d, device=device)) 186 | 187 | real = (1 + rearrange(real, 'b c t h w -> b t h w c')) * 127.5 188 | real = real.type(torch.uint8) 189 | real_embeddings.append(get_fvd_logits(real.cpu().numpy(), i3d=i3d, device=device)) 190 | 191 | if len(predictions) < 4: 192 | reals.append(rearrange(real, 'b t h w c -> b c t h w')) 193 | predictions.append(rearrange(pred, 'b t h w c -> b c t h w')) 194 | 195 | reals = torch.cat(reals) 196 | predictions = torch.cat(predictions) 197 | 198 | real_embeddings = torch.cat(real_embeddings) 199 | pred_embeddings = torch.cat(pred_embeddings) 200 | 201 | if rank == 0: 202 | real_vid = save_image_grid(reals.cpu().numpy(), os.path.join(logger.logdir, f'real_{it}.gif'), drange=[0, 255], grid_size=(k,4)) 203 | real_vid = np.expand_dims(real_vid,0).transpose(0, 1, 4, 2, 3) 204 | pred_vid = save_image_grid(predictions.cpu().numpy(), os.path.join(logger.logdir, f'predicted_{it}.gif'), drange=[0, 255], grid_size=(k,4)) 205 | pred_vid = np.expand_dims(pred_vid,0).transpose(0, 1, 4, 2, 3) 206 | 207 | logger.video_summary('real', real_vid, it) 208 | logger.video_summary('prediction', pred_vid, it) 209 | else: 210 | raise NotImplementedError 211 | 212 | fvd = frechet_distance(pred_embeddings.clone().detach(), real_embeddings.clone().detach()) 213 | return fvd.item() 214 | 215 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/evals/fvd/__init__.py -------------------------------------------------------------------------------- /PVDM/evals/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
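# Illustrative example (added for clarity; not part of the original script):
# under these rules a stripped TF variable name such as
# 'Mixed_3b/Branch_1/Conv3d_0a_1x1/conv_3d/w:0' converts to the PyTorch key
# 'Mixed_3b.b1a.conv3d.weight', matching the module names used in pytorch_i3d.py.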
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from tqdm import tqdm 3 | import os 4 | import torch 5 | 6 | from utils import download 7 | 8 | def get_confirm_token(response): 9 | for key, value in response.cookies.items(): 10 | if key.startswith('download_warning'): 11 | return value 12 | return None 13 | 14 | 15 | def save_response_content(response, destination): 16 | CHUNK_SIZE = 8192 17 | 18 | pbar = tqdm(total=0, unit='iB', unit_scale=True) 19 | with open(destination, 'wb') as f: 20 | for chunk in response.iter_content(CHUNK_SIZE): 21 | if chunk: 22 | f.write(chunk) 23 | pbar.update(len(chunk)) 24 | pbar.close() 25 | 26 | 27 | _I3D_PRETRAINED_ID = '1fBNl3TS0LA5FEhZv5nMGJs2_7qQmvTmh' 28 | 29 | def load_i3d_pretrained(device=torch.device('cpu')): 30 | from evals.fvd.pytorch_i3d import InceptionI3d 31 | i3d = InceptionI3d(400, in_channels=3).to(device) 32 | 
filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 33 | i3d.load_state_dict(torch.load(filepath, map_location=device)) 34 | i3d.eval() 35 | return i3d 36 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | def preprocess_single(video, resolution, sequence_length=None): 6 | # video: THWC, {0, ..., 255} 7 | video = video.permute(0, 3, 1, 2).float() / 255. # TCHW 8 | t, c, h, w = video.shape 9 | 10 | # temporal crop 11 | if sequence_length is not None: 12 | assert sequence_length <= t 13 | video = video[:sequence_length] 14 | 15 | # scale shorter side to resolution 16 | scale = resolution / min(h, w) 17 | if h < w: 18 | target_size = (resolution, math.ceil(w * scale)) 19 | else: 20 | target_size = (math.ceil(h * scale), resolution) 21 | video = F.interpolate(video, size=target_size, mode='bilinear', 22 | align_corners=False) 23 | 24 | # center crop 25 | t, c, h, w = video.shape 26 | w_start = (w - resolution) // 2 27 | h_start = (h - resolution) // 2 28 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 29 | video = video.permute(1, 0, 2, 3).contiguous() # CTHW 30 | 31 | video -= 0.5 32 | 33 | return video 34 | 35 | def preprocess(videos, target_resolution=224): 36 | # videos in {0, ..., 255} as np.uint8 array 37 | b, t, h, w, c = videos.shape 38 | videos = torch.from_numpy(videos) 39 | videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) 40 | return videos * 2 # [-0.5, 0.5] -> [-1, 1] 41 | 42 | def get_fvd_logits(videos, i3d, device): 43 | videos = preprocess(videos) 44 | embeddings = get_logits(i3d, videos, device) 45 | return embeddings 46 | 47 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 48 | def _symmetric_matrix_square_root(mat, eps=1e-10): 49 | u, s, v = torch.svd(mat) 50 | si = torch.where(s < eps, s, torch.sqrt(s)) 51 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 52 | 53 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 54 | def trace_sqrt_product(sigma, sigma_v): 55 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 56 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 57 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 58 | 59 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 60 | def cov(m, rowvar=False): 61 | '''Estimate a covariance matrix given data. 62 | 63 | Covariance indicates the level to which two variables vary together. 64 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 65 | then the covariance matrix element `C_{ij}` is the covariance of 66 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 67 | 68 | Args: 69 | m: A 1-D or 2-D array containing multiple variables and observations. 70 | Each row of `m` represents a variable, and each column a single 71 | observation of all those variables. 72 | rowvar: If `rowvar` is True, then each row represents a 73 | variable, with observations in the columns. Otherwise, the 74 | relationship is transposed: each column represents a variable, 75 | while the rows contain observations. 76 | 77 | Returns: 78 | The covariance matrix of the variables. 
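Illustrative example (added for clarity; not part of the original docstring):
in this file, frechet_distance() calls cov(x, rowvar=False) on I3D embeddings
of shape (N, 400), which yields a (400, 400) covariance matrix. Note that the
input is mean-centered in place (via `m -= ...`).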
79 | ''' 80 | if m.dim() > 2: 81 | raise ValueError('m has more than 2 dimensions') 82 | if m.dim() < 2: 83 | m = m.view(1, -1) 84 | if not rowvar and m.size(0) != 1: 85 | m = m.t() 86 | 87 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 88 | m -= torch.mean(m, dim=1, keepdim=True) 89 | mt = m.t() # if complex: mt = m.t().conj() 90 | return fact * m.matmul(mt).squeeze() 91 | 92 | 93 | def frechet_distance(x1, x2): 94 | x1 = x1.flatten(start_dim=1) 95 | x2 = x2.flatten(start_dim=1) 96 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 97 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 98 | 99 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 100 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 101 | 102 | mean = torch.sum((m - m_w) ** 2) 103 | fd = trace + mean 104 | return fd 105 | 106 | 107 | def get_logits(i3d, videos, device): 108 | """ 109 | assert videos.shape[0] % 16 == 0 110 | with torch.no_grad(): 111 | logits = [] 112 | for i in range(0, videos.shape[0], 16): 113 | batch = videos[i:i + 16].to(device) 114 | logits.append(i3d(batch)) 115 | logits = torch.cat(logits, dim=0) 116 | return logits 117 | """ 118 | 119 | with torch.no_grad(): 120 | logits = i3d(videos.to(device)) 121 | return logits 122 | -------------------------------------------------------------------------------- /PVDM/evals/fvd/pytorch_i3d.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/piergiaj/pytorch-i3d 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class MaxPool3dSamePadding(nn.MaxPool3d): 8 | 9 | def compute_pad(self, dim, s): 10 | if s % self.stride[dim] == 0: 11 | return max(self.kernel_size[dim] - self.stride[dim], 0) 12 | else: 13 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 14 | 15 | def forward(self, x): 16 | # compute 'same' padding 17 | (batch, channel, t, h, w) = x.size() 18 | out_t = np.ceil(float(t) / float(self.stride[0])) 19 | out_h = np.ceil(float(h) / float(self.stride[1])) 20 | out_w = np.ceil(float(w) / float(self.stride[2])) 21 | pad_t = self.compute_pad(0, t) 22 | pad_h = self.compute_pad(1, h) 23 | pad_w = self.compute_pad(2, w) 24 | 25 | pad_t_f = pad_t // 2 26 | pad_t_b = pad_t - pad_t_f 27 | pad_h_f = pad_h // 2 28 | pad_h_b = pad_h - pad_h_f 29 | pad_w_f = pad_w // 2 30 | pad_w_b = pad_w - pad_w_f 31 | 32 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 33 | x = F.pad(x, pad) 34 | return super(MaxPool3dSamePadding, self).forward(x) 35 | 36 | 37 | class Unit3D(nn.Module): 38 | 39 | def __init__(self, in_channels, 40 | output_channels, 41 | kernel_shape=(1, 1, 1), 42 | stride=(1, 1, 1), 43 | padding=0, 44 | activation_fn=F.relu, 45 | use_batch_norm=True, 46 | use_bias=False, 47 | name='unit_3d'): 48 | 49 | """Initializes Unit3D module.""" 50 | super(Unit3D, self).__init__() 51 | 52 | self._output_channels = output_channels 53 | self._kernel_shape = kernel_shape 54 | self._stride = stride 55 | self._use_batch_norm = use_batch_norm 56 | self._activation_fn = activation_fn 57 | self._use_bias = use_bias 58 | self.name = name 59 | self.padding = padding 60 | 61 | self.conv3d = nn.Conv3d(in_channels=in_channels, 62 | out_channels=self._output_channels, 63 | kernel_size=self._kernel_shape, 64 | stride=self._stride, 65 | padding=0, # we always want padding to be 0 here. 
We will dynamically pad based on input size in forward function 66 | bias=self._use_bias) 67 | 68 | if self._use_batch_norm: 69 | self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) 70 | 71 | def compute_pad(self, dim, s): 72 | if s % self._stride[dim] == 0: 73 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 74 | else: 75 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 76 | 77 | 78 | def forward(self, x): 79 | # compute 'same' padding 80 | (batch, channel, t, h, w) = x.size() 81 | out_t = np.ceil(float(t) / float(self._stride[0])) 82 | out_h = np.ceil(float(h) / float(self._stride[1])) 83 | out_w = np.ceil(float(w) / float(self._stride[2])) 84 | pad_t = self.compute_pad(0, t) 85 | pad_h = self.compute_pad(1, h) 86 | pad_w = self.compute_pad(2, w) 87 | 88 | pad_t_f = pad_t // 2 89 | pad_t_b = pad_t - pad_t_f 90 | pad_h_f = pad_h // 2 91 | pad_h_b = pad_h - pad_h_f 92 | pad_w_f = pad_w // 2 93 | pad_w_b = pad_w - pad_w_f 94 | 95 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 96 | x = F.pad(x, pad) 97 | 98 | x = self.conv3d(x) 99 | if self._use_batch_norm: 100 | x = self.bn(x) 101 | if self._activation_fn is not None: 102 | x = self._activation_fn(x) 103 | return x 104 | 105 | 106 | 107 | class InceptionModule(nn.Module): 108 | def __init__(self, in_channels, out_channels, name): 109 | super(InceptionModule, self).__init__() 110 | 111 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 112 | name=name+'/Branch_0/Conv3d_0a_1x1') 113 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 114 | name=name+'/Branch_1/Conv3d_0a_1x1') 115 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 116 | name=name+'/Branch_1/Conv3d_0b_3x3') 117 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 118 | name=name+'/Branch_2/Conv3d_0a_1x1') 119 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 120 | name=name+'/Branch_2/Conv3d_0b_3x3') 121 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 122 | stride=(1, 1, 1), padding=0) 123 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 124 | name=name+'/Branch_3/Conv3d_0b_1x1') 125 | self.name = name 126 | 127 | def forward(self, x): 128 | b0 = self.b0(x) 129 | b1 = self.b1b(self.b1a(x)) 130 | b2 = self.b2b(self.b2a(x)) 131 | b3 = self.b3b(self.b3a(x)) 132 | return torch.cat([b0,b1,b2,b3], dim=1) 133 | 134 | 135 | class InceptionI3d(nn.Module): 136 | """Inception-v1 I3D architecture. 137 | The model is introduced in: 138 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 139 | Joao Carreira, Andrew Zisserman 140 | https://arxiv.org/pdf/1705.07750v1.pdf. 141 | See also the Inception architecture, introduced in: 142 | Going deeper with convolutions 143 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 144 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 145 | http://arxiv.org/pdf/1409.4842v1.pdf. 146 | """ 147 | 148 | # Endpoints of the model in order. During construction, all the endpoints up 149 | # to a designated `final_endpoint` are returned in a dictionary as the 150 | # second return value. 
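# Note (added for clarity; not in the original file): the FVD pipeline in this
# repo builds the model with its defaults, i.e. InceptionI3d(400, in_channels=3)
# and final_endpoint='Logits' (see evals/fvd/download.py), so every endpoint
# listed below is constructed, and forward() returns per-class logits that are
# then used as the FVD embeddings.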
151 | VALID_ENDPOINTS = ( 152 | 'Conv3d_1a_7x7', 153 | 'MaxPool3d_2a_3x3', 154 | 'Conv3d_2b_1x1', 155 | 'Conv3d_2c_3x3', 156 | 'MaxPool3d_3a_3x3', 157 | 'Mixed_3b', 158 | 'Mixed_3c', 159 | 'MaxPool3d_4a_3x3', 160 | 'Mixed_4b', 161 | 'Mixed_4c', 162 | 'Mixed_4d', 163 | 'Mixed_4e', 164 | 'Mixed_4f', 165 | 'MaxPool3d_5a_2x2', 166 | 'Mixed_5b', 167 | 'Mixed_5c', 168 | 'Logits', 169 | 'Predictions', 170 | ) 171 | 172 | def __init__(self, num_classes=400, spatial_squeeze=True, 173 | final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): 174 | """Initializes I3D model instance. 175 | Args: 176 | num_classes: The number of outputs in the logit layer (default 400, which 177 | matches the Kinetics dataset). 178 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 179 | before returning (default True). 180 | final_endpoint: The model contains many possible endpoints. 181 | `final_endpoint` specifies the last endpoint for the model to be built 182 | up to. In addition to the output at `final_endpoint`, all the outputs 183 | at endpoints up to `final_endpoint` will also be returned, in a 184 | dictionary. `final_endpoint` must be one of 185 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 186 | name: A string (optional). The name of this module. 187 | Raises: 188 | ValueError: if `final_endpoint` is not recognized. 189 | """ 190 | 191 | if final_endpoint not in self.VALID_ENDPOINTS: 192 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 193 | 194 | super(InceptionI3d, self).__init__() 195 | self._num_classes = num_classes 196 | self._spatial_squeeze = spatial_squeeze 197 | self._final_endpoint = final_endpoint 198 | self.logits = None 199 | 200 | if self._final_endpoint not in self.VALID_ENDPOINTS: 201 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 202 | 203 | self.end_points = {} 204 | end_point = 'Conv3d_1a_7x7' 205 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 206 | stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) 207 | if self._final_endpoint == end_point: return 208 | 209 | end_point = 'MaxPool3d_2a_3x3' 210 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 211 | padding=0) 212 | if self._final_endpoint == end_point: return 213 | 214 | end_point = 'Conv3d_2b_1x1' 215 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 216 | name=name+end_point) 217 | if self._final_endpoint == end_point: return 218 | 219 | end_point = 'Conv3d_2c_3x3' 220 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 221 | name=name+end_point) 222 | if self._final_endpoint == end_point: return 223 | 224 | end_point = 'MaxPool3d_3a_3x3' 225 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 226 | padding=0) 227 | if self._final_endpoint == end_point: return 228 | 229 | end_point = 'Mixed_3b' 230 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) 231 | if self._final_endpoint == end_point: return 232 | 233 | end_point = 'Mixed_3c' 234 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) 235 | if self._final_endpoint == end_point: return 236 | 237 | end_point = 'MaxPool3d_4a_3x3' 238 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), 239 | padding=0) 240 | 
if self._final_endpoint == end_point: return 241 | 242 | end_point = 'Mixed_4b' 243 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) 244 | if self._final_endpoint == end_point: return 245 | 246 | end_point = 'Mixed_4c' 247 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) 248 | if self._final_endpoint == end_point: return 249 | 250 | end_point = 'Mixed_4d' 251 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) 252 | if self._final_endpoint == end_point: return 253 | 254 | end_point = 'Mixed_4e' 255 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) 256 | if self._final_endpoint == end_point: return 257 | 258 | end_point = 'Mixed_4f' 259 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) 260 | if self._final_endpoint == end_point: return 261 | 262 | end_point = 'MaxPool3d_5a_2x2' 263 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), 264 | padding=0) 265 | if self._final_endpoint == end_point: return 266 | 267 | end_point = 'Mixed_5b' 268 | self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) 269 | if self._final_endpoint == end_point: return 270 | 271 | end_point = 'Mixed_5c' 272 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'Logits' 276 | self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], 277 | stride=(1, 1, 1)) 278 | self.dropout = nn.Dropout(dropout_keep_prob) 279 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 280 | kernel_shape=[1, 1, 1], 281 | padding=0, 282 | activation_fn=None, 283 | use_batch_norm=False, 284 | use_bias=True, 285 | name='logits') 286 | 287 | self.build() 288 | 289 | 290 | def replace_logits(self, num_classes): 291 | self._num_classes = num_classes 292 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 293 | kernel_shape=[1, 1, 1], 294 | padding=0, 295 | activation_fn=None, 296 | use_batch_norm=False, 297 | use_bias=True, 298 | name='logits') 299 | 300 | 301 | def build(self): 302 | for k in self.end_points.keys(): 303 | self.add_module(k, self.end_points[k]) 304 | 305 | def forward(self, x): 306 | for end_point in self.VALID_ENDPOINTS: 307 | if end_point in self.end_points: 308 | x = self._modules[end_point](x) # use _modules to work with dataparallel 309 | 310 | x = self.logits(self.dropout(self.avg_pool(x))) 311 | if self._spatial_squeeze: 312 | logits = x.squeeze(3).squeeze(3) 313 | logits = logits.mean(dim=2) 314 | # logits is batch X time X classes, which is what we want to work with 315 | return logits 316 | 317 | 318 | def extract_features(self, x): 319 | for end_point in self.VALID_ENDPOINTS: 320 | if end_point in self.end_points: 321 | x = self._modules[end_point](x) 322 | return self.avg_pool(x) 323 | -------------------------------------------------------------------------------- /PVDM/exps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/exps/__init__.py -------------------------------------------------------------------------------- /PVDM/exps/diffusion.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | 6 | from tools.trainer import latentDDPM 7 | from tools.dataloader import get_loaders 8 | from tools.scheduler import LambdaLinearScheduler 9 | from models.autoencoder.autoencoder_vit import ViTAutoencoder 10 | from models.ddpm.unet import UNetModel, DiffusionWrapper 11 | from losses.ddpm import DDPM 12 | 13 | import copy 14 | from utils import file_name, Logger, download 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 19 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 20 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 21 | _rank = 0 # Rank of the current process. 22 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 23 | _sync_called = False # Has _sync() been called yet? 24 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 25 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def init_multiprocessing(rank, sync_device): 30 | r"""Initializes `torch_utils.training_stats` for collecting statistics 31 | across multiple processes. 32 | This function must be called after 33 | `torch.distributed.init_process_group()` and before `Collector.update()`. 34 | The call is not necessary if multi-process collection is not needed. 35 | Args: 36 | rank: Rank of the current process. 37 | sync_device: PyTorch device to use for inter-process 38 | communication, or None to disable multi-process 39 | collection. Typically `torch.device('cuda', rank)`. 40 | """ 41 | global _rank, _sync_device 42 | assert not _sync_called 43 | _rank = rank 44 | _sync_device = sync_device 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def diffusion(rank, args): 49 | device = torch.device('cuda', rank) 50 | 51 | temp_dir = './' 52 | # if args.n_gpus > 1: 53 | # init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 54 | # if os.name == 'nt': 55 | # init_method = 'file:///' + init_file.replace('\\', '/') 56 | # torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.n_gpus) 57 | # else: 58 | # init_method = f'file://{init_file}' 59 | # torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.n_gpus) 60 | 61 | # # Init torch_utils. 
62 | # sync_device = torch.device('cuda', rank) if args.n_gpus > 1 else None 63 | # init_multiprocessing(rank=rank, sync_device=sync_device) 64 | 65 | """ ROOT DIRECTORY """ 66 | if rank == 0: 67 | fn = file_name(args) 68 | logger = Logger(fn) 69 | logger.log(args) 70 | logger.log(f'Log path: {logger.logdir}') 71 | rootdir = logger.logdir 72 | else: 73 | logger = None 74 | 75 | if logger is None: 76 | log_ = print 77 | else: 78 | log_ = logger.log 79 | 80 | """ Get Image """ 81 | if rank == 0: 82 | log_(f"Loading dataset {args.data} with resolution {args.res}") 83 | train_loader, test_loader, total_vid = get_loaders(rank, args.data, args.res, args.timesteps, args.skip, args.batch_size, args.n_gpus, args.seed, args.cond_model) 84 | 85 | if args.data == 'cliport': 86 | cond_prob = 0.9 87 | elif args.data == 'SKY': 88 | cond_prob = 0.2 89 | else: 90 | cond_prob = 0.3 91 | 92 | """ Get Model """ 93 | if rank == 0: 94 | log_(f"Generating model") 95 | 96 | torch.cuda.set_device(rank) 97 | first_stage_model = ViTAutoencoder(args.embed_dim, args.ddconfig).to(device) 98 | 99 | if rank == 0: 100 | first_stage_model_ckpt = torch.load(args.first_model) 101 | first_stage_model.load_state_dict(first_stage_model_ckpt) 102 | 103 | unet = UNetModel(**args.unetconfig) 104 | model = DiffusionWrapper(unet).to(device) 105 | 106 | if rank == 0: 107 | torch.save(model.state_dict(), rootdir + f'net_init.pth') 108 | 109 | ema_model = None 110 | if args.n_gpus > 1: 111 | first_stage_model = torch.nn.parallel.DataParallel(first_stage_model) 112 | model = torch.nn.parallel.DataParallel(model) 113 | model = model.module 114 | first_stage_model = first_stage_model.module 115 | 116 | criterion = DDPM(model, channels=args.unetconfig.in_channels, 117 | image_size=args.unetconfig.image_size, 118 | linear_start=args.ddpmconfig.linear_start, 119 | linear_end=args.ddpmconfig.linear_end, 120 | log_every_t=args.ddpmconfig.log_every_t, 121 | w=args.ddpmconfig.w, 122 | ).to(device) 123 | 124 | if args.n_gpus > 1: 125 | criterion = torch.nn.parallel.DataParallel(criterion) 126 | criterion = criterion.module 127 | 128 | if args.scale_lr: 129 | args.lr *= args.batch_size 130 | 131 | opt = torch.optim.AdamW(model.parameters(), lr=args.lr) 132 | lr_scheduler = LambdaLinearScheduler(**args.scheduler) 133 | 134 | latentDDPM(rank, first_stage_model, model, opt, criterion, train_loader, test_loader, lr_scheduler, ema_model, cond_prob, logger) 135 | 136 | if rank == 0: 137 | torch.save(model.state_dict(), rootdir + f'net_meta.pth') 138 | -------------------------------------------------------------------------------- /PVDM/exps/first_stage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | 6 | from tools.trainer import first_stage_train 7 | from tools.dataloader import get_loaders 8 | from models.autoencoder.autoencoder_vit import ViTAutoencoder 9 | from losses.perceptual import LPIPSWithDiscriminator 10 | 11 | from utils import file_name, Logger 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 16 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 17 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 18 | _rank = 0 # Rank of the current process. 19 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 
20 | _sync_called = False # Has _sync() been called yet? 21 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 22 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def init_multiprocessing(rank, sync_device): 27 | r"""Initializes `torch_utils.training_stats` for collecting statistics 28 | across multiple processes. 29 | This function must be called after 30 | `torch.distributed.init_process_group()` and before `Collector.update()`. 31 | The call is not necessary if multi-process collection is not needed. 32 | Args: 33 | rank: Rank of the current process. 34 | sync_device: PyTorch device to use for inter-process 35 | communication, or None to disable multi-process 36 | collection. Typically `torch.device('cuda', rank)`. 37 | """ 38 | global _rank, _sync_device 39 | assert not _sync_called 40 | _rank = rank 41 | _sync_device = sync_device 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def first_stage(rank, args): 46 | device = torch.device('cuda', rank) 47 | 48 | temp_dir = './' 49 | # if args.n_gpus > 1: 50 | # init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 51 | # if os.name == 'nt': 52 | # init_method = 'file:///' + init_file.replace('\\', '/') 53 | # torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.n_gpus) 54 | # else: 55 | # init_method = f'file://{init_file}' 56 | # torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.n_gpus) 57 | 58 | # # Init torch_utils. 59 | # sync_device = torch.device('cuda', rank) if args.n_gpus > 1 else None 60 | # init_multiprocessing(rank=rank, sync_device=sync_device) 61 | 62 | """ ROOT DIRECTORY """ 63 | if rank == 0: 64 | fn = file_name(args) 65 | logger = Logger(fn) 66 | logger.log(args) 67 | logger.log(f'Log path: {logger.logdir}') 68 | rootdir = logger.logdir 69 | else: 70 | logger = None 71 | 72 | if logger is None: 73 | log_ = print 74 | else: 75 | log_ = logger.log 76 | 77 | """ Get Image """ 78 | if rank == 0: 79 | log_(f"Loading dataset {args.data} with resolution {args.res}") 80 | train_loader, test_loader, total_vid = get_loaders(rank, args.data, args.res, args.timesteps, args.skip, args.batch_size, args.n_gpus, args.seed, cond=False) 81 | 82 | """ Get Model """ 83 | if rank == 0: 84 | log_(f"Generating model") 85 | 86 | torch.cuda.set_device(rank) 87 | model = ViTAutoencoder(args.embed_dim, args.ddconfig) 88 | model = model.to(device) 89 | 90 | criterion = LPIPSWithDiscriminator(disc_start = args.lossconfig.params.disc_start, 91 | timesteps = args.ddconfig.timesteps).to(device) 92 | 93 | 94 | opt = torch.optim.AdamW(model.parameters(), 95 | lr=args.lr, 96 | betas=(0.5, 0.9) 97 | ) 98 | 99 | d_opt = torch.optim.AdamW(list(criterion.discriminator_2d.parameters()) + list(criterion.discriminator_3d.parameters()), 100 | lr=args.lr, 101 | betas=(0.5, 0.9)) 102 | 103 | if args.resume and rank == 0: 104 | model_ckpt = torch.load(os.path.join(args.first_stage_folder, 'model_last.pth')) 105 | model.load_state_dict(model_ckpt) 106 | opt_ckpt = torch.load(os.path.join(args.first_stage_folder, 'opt.pth')) 107 | opt.load_state_dict(opt_ckpt) 108 | 109 | del model_ckpt 110 | del opt_ckpt 111 | 112 | if rank == 0: 113 | torch.save(model.state_dict(), rootdir + 
f'net_init.pth') 114 | 115 | if args.n_gpus > 1: 116 | model = torch.nn.parallel.DataParallel(model) 117 | criterion = torch.nn.parallel.DataParallel(criterion) 118 | model = model.module 119 | criterion = criterion.module 120 | 121 | fp = args.amp 122 | first_stage_train(rank, model, opt, d_opt, criterion, train_loader, test_loader, args.first_model, fp, logger) 123 | 124 | if rank == 0: 125 | torch.save(model.state_dict(), rootdir + f'net_meta.pth') 126 | -------------------------------------------------------------------------------- /PVDM/losses/diffaugment.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='color,translation,cutout', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.5): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | 
mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } -------------------------------------------------------------------------------- /PVDM/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | import os 3 | import requests 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | from collections import namedtuple 9 | 10 | from tqdm import tqdm 11 | import hashlib 12 | 13 | 14 | URL_MAP = { 15 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 16 | } 17 | 18 | CKPT_MAP = { 19 | "vgg_lpips": "vgg.pth" 20 | } 21 | 22 | MD5_MAP = { 23 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 24 | } 25 | 26 | def download(url, local_path, chunk_size=1024): 27 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 28 | with requests.get(url, stream=True) as r: 29 | total_size = int(r.headers.get("content-length", 0)) 30 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 31 | with open(local_path, "wb") as f: 32 | for data in r.iter_content(chunk_size=chunk_size): 33 | if data: 34 | f.write(data) 35 | pbar.update(chunk_size) 36 | 37 | 38 | def md5_hash(path): 39 | with open(path, "rb") as f: 40 | content = f.read() 41 | return hashlib.md5(content).hexdigest() 42 | 43 | 44 | def get_ckpt_path(name, root, check=False): 45 | assert name in URL_MAP 46 | path = os.path.join(root, CKPT_MAP[name]) 47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 49 | download(URL_MAP[name], path) 50 | md5 = md5_hash(path) 51 | assert md5 == MD5_MAP[name], md5 52 | return path 53 | 54 | 55 | class LPIPS(nn.Module): 56 | # Learned perceptual metric 57 | def __init__(self, use_dropout=True): 58 | super().__init__() 59 | self.scaling_layer = ScalingLayer() 60 | self.chns = [64, 128, 256, 512, 512] # vg16 features 61 | self.net = vgg16(pretrained=True, requires_grad=False) 62 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 63 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 64 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 65 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 66 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 67 | self.load_from_pretrained() 68 | for param in self.parameters(): 69 | param.requires_grad = False 70 | 71 | def load_from_pretrained(self, name="vgg_lpips"): 72 | ckpt = get_ckpt_path(name, "./losses") 73 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 74 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 75 | 76 | @classmethod 77 | def from_pretrained(cls, name="vgg_lpips"): 78 | if name != "vgg_lpips": 79 | raise NotImplementedError 80 | model = cls() 81 | ckpt = get_ckpt_path(name) 82 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 83 | return model 84 | 85 | def forward(self, input, target): 86 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 87 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | lins = [self.lin0, 
self.lin1, self.lin2, self.lin3, self.lin4] 90 | for kk in range(len(self.chns)): 91 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 93 | 94 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 95 | val = res[0] 96 | for l in range(1, len(self.chns)): 97 | val += res[l] 98 | return val 99 | 100 | 101 | class ScalingLayer(nn.Module): 102 | def __init__(self): 103 | super(ScalingLayer, self).__init__() 104 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 105 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 106 | 107 | def forward(self, inp): 108 | return (inp - self.shift) / self.scale 109 | 110 | 111 | class NetLinLayer(nn.Module): 112 | """ A single linear layer which does a 1x1 conv """ 113 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 114 | super(NetLinLayer, self).__init__() 115 | layers = [nn.Dropout(), ] if (use_dropout) else [] 116 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 117 | self.model = nn.Sequential(*layers) 118 | 119 | 120 | class vgg16(torch.nn.Module): 121 | def __init__(self, requires_grad=False, pretrained=True): 122 | super(vgg16, self).__init__() 123 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 124 | self.slice1 = torch.nn.Sequential() 125 | self.slice2 = torch.nn.Sequential() 126 | self.slice3 = torch.nn.Sequential() 127 | self.slice4 = torch.nn.Sequential() 128 | self.slice5 = torch.nn.Sequential() 129 | self.N_slices = 5 130 | for x in range(4): 131 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 132 | for x in range(4, 9): 133 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(9, 16): 135 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(16, 23): 137 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(23, 30): 139 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 140 | if not requires_grad: 141 | for param in self.parameters(): 142 | param.requires_grad = False 143 | 144 | def forward(self, X): 145 | h = self.slice1(X) 146 | h_relu1_2 = h 147 | h = self.slice2(h) 148 | h_relu2_2 = h 149 | h = self.slice3(h) 150 | h_relu3_3 = h 151 | h = self.slice4(h) 152 | h_relu4_3 = h 153 | h = self.slice5(h) 154 | h_relu5_3 = h 155 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 156 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 157 | return out 158 | 159 | 160 | def normalize_tensor(x,eps=1e-10): 161 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 162 | return x/(norm_factor+eps) 163 | 164 | 165 | def spatial_average(x, keepdim=True): 166 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /PVDM/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from einops import repeat, rearrange 6 | 7 | import functools 8 | from losses.lpips import LPIPS 9 | from utils import make_pairs 10 | from losses.diffaugment import DiffAugment 11 | 12 | class DummyLoss(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | 17 | def adopt_weight(global_step, threshold=0, 
value=0.): 18 | weight = 1.0 19 | if global_step < threshold: 20 | weight = value 21 | return weight 22 | 23 | 24 | def hinge_d_loss(logits_real, logits_fake): 25 | loss_real = torch.mean(F.relu(1. - logits_real)) 26 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 27 | d_loss = 0.5 * (loss_real + loss_fake) 28 | return d_loss 29 | 30 | 31 | def vanilla_d_loss(logits_real, logits_fake): 32 | d_loss = 0.5 * ( 33 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 34 | torch.mean(torch.nn.functional.softplus(logits_fake))) 35 | return d_loss 36 | 37 | 38 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 39 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 40 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 41 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 42 | loss_real = (weights * loss_real).sum() / weights.sum() 43 | loss_fake = (weights * loss_fake).sum() / weights.sum() 44 | d_loss = 0.5 * (loss_real + loss_fake) 45 | return d_loss 46 | 47 | 48 | def measure_perplexity(predicted_indices, n_embed): 49 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 50 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 51 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 52 | avg_probs = encodings.mean(0) 53 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 54 | cluster_use = torch.sum(avg_probs > 0) 55 | return perplexity, cluster_use 56 | 57 | def l1(x, y): 58 | return torch.abs(x-y) 59 | 60 | 61 | def l2(x, y): 62 | return torch.pow((x-y), 2) 63 | 64 | 65 | class LPIPSWithDiscriminator(nn.Module): 66 | def __init__(self, disc_start, disc_num_layers=3, disc_in_channels=3, 67 | pixelloss_weight=4.0, disc_weight=1.0, 68 | perceptual_weight=4.0, feature_weight=4.0, 69 | disc_ndf=64, disc_loss="hinge", timesteps=16): 70 | super().__init__() 71 | assert disc_loss in ["hinge", "vanilla"] 72 | self.s = timesteps 73 | self.perceptual_loss = LPIPS().eval() 74 | self.pixel_loss = l1 75 | 76 | self.discriminator_2d = NLayerDiscriminator(input_nc=disc_in_channels, 77 | n_layers=disc_num_layers, 78 | ndf=disc_ndf 79 | ).apply(weights_init) 80 | self.discriminator_3d = NLayerDiscriminator3D(input_nc=disc_in_channels, 81 | n_layers=disc_num_layers, 82 | ndf=disc_ndf 83 | ).apply(weights_init) 84 | 85 | self.discriminator_iter_start = disc_start 86 | if disc_loss == "hinge": 87 | self.disc_loss = hinge_d_loss 88 | elif disc_loss == "vanilla": 89 | self.disc_loss = vanilla_d_loss 90 | 91 | self.pixel_weight = pixelloss_weight 92 | self.gan_weight = disc_weight 93 | self.perceptual_weight = perceptual_weight 94 | self.gan_feat_weight = feature_weight 95 | 96 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 97 | global_step, cond=None, split="train", predicted_indices=None): 98 | 99 | b, c, _, h, w = inputs.size() 100 | rec_loss = self.pixel_weight * F.l1_loss(inputs.contiguous(), reconstructions.contiguous()) 101 | 102 | frame_idx = torch.randint(0,self.s, [b]).cuda() 103 | frame_idx_selected = frame_idx.reshape(-1, 1, 1, 1, 1).repeat(1, c, 1, h, w) 104 | inputs_2d = torch.gather(inputs, 2, frame_idx_selected).squeeze(2) 105 | reconstructions_2d = torch.gather(reconstructions, 2, frame_idx_selected).squeeze(2) 106 | 107 | if optimizer_idx == 0: 108 | if self.perceptual_weight > 0: 109 | """ 110 | p_loss = self.perceptual_weight * 
self.perceptual_loss(rearrange(inputs, 'b c t h w -> (b t) c h w').contiguous(), 111 | rearrange(reconstructions, 'b c t h w -> (b t) c h w').contiguous()).mean() 112 | """ 113 | p_loss = self.perceptual_weight * self.perceptual_loss(inputs_2d.contiguous(), reconstructions_2d.contiguous()).mean() 114 | else: 115 | p_loss = torch.tensor([0.0]) 116 | 117 | disc_factor = adopt_weight(global_step, threshold=self.discriminator_iter_start) 118 | logits_real_2d, pred_real_2d = self.discriminator_2d(inputs_2d) 119 | logits_real_3d, pred_real_3d = self.discriminator_3d(inputs.contiguous()) 120 | logits_fake_2d, pred_fake_2d = self.discriminator_2d(reconstructions_2d) 121 | logits_fake_3d, pred_fake_3d = self.discriminator_3d(reconstructions.contiguous()) 122 | g_loss = -disc_factor * self.gan_weight * (torch.mean(logits_fake_2d) + torch.mean(logits_fake_3d)) 123 | 124 | image_gan_feat_loss = 0. 125 | video_gan_feat_loss = 0. 126 | 127 | for i in range(len(pred_real_2d)-1): 128 | image_gan_feat_loss += F.l1_loss(pred_fake_2d[i], pred_real_2d[i].detach()) 129 | for i in range(len(pred_real_3d)-1): 130 | video_gan_feat_loss += F.l1_loss(pred_fake_3d[i], pred_real_3d[i].detach()) 131 | 132 | gan_feat_loss = disc_factor * self.gan_feat_weight * (image_gan_feat_loss + video_gan_feat_loss) 133 | return rec_loss + p_loss + g_loss + gan_feat_loss 134 | 135 | if optimizer_idx == 1: 136 | 137 | # second pass for discriminator update 138 | logits_real_2d, _ = self.discriminator_2d(inputs_2d) 139 | logits_real_3d, _ = self.discriminator_3d(inputs.contiguous()) 140 | logits_fake_2d, _ = self.discriminator_2d(reconstructions_2d) 141 | logits_fake_3d, _ = self.discriminator_3d(reconstructions.contiguous()) 142 | 143 | disc_factor = adopt_weight(global_step, threshold=self.discriminator_iter_start) 144 | d_loss = disc_factor * self.gan_weight * (self.disc_loss(logits_real_2d, logits_fake_2d) + self.disc_loss(logits_real_3d, logits_fake_3d)) 145 | 146 | return d_loss 147 | 148 | 149 | def weights_init(m): 150 | classname = m.__class__.__name__ 151 | if classname.find('Conv') != -1: 152 | nn.init.normal_(m.weight.data, 0.0, 0.02) 153 | elif classname.find('BatchNorm') != -1: 154 | nn.init.normal_(m.weight.data, 1.0, 0.02) 155 | nn.init.constant_(m.bias.data, 0) 156 | 157 | class NLayerDiscriminator(nn.Module): 158 | """Defines a PatchGAN discriminator as in Pix2Pix 159 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 160 | """ 161 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): 162 | # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): 163 | super(NLayerDiscriminator, self).__init__() 164 | self.getIntermFeat = getIntermFeat 165 | self.n_layers = n_layers 166 | 167 | kw = 4 168 | padw = int(np.ceil((kw-1.0)/2)) 169 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 170 | 171 | nf = ndf 172 | for n in range(1, n_layers): 173 | nf_prev = nf 174 | nf = min(nf * 2, 512) 175 | sequence += [[ 176 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 177 | norm_layer(nf), nn.LeakyReLU(0.2, True) 178 | ]] 179 | 180 | nf_prev = nf 181 | nf = min(nf * 2, 512) 182 | sequence += [[ 183 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 184 | norm_layer(nf), 185 | nn.LeakyReLU(0.2, True) 186 | ]] 187 | 188 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, 
padding=padw)]] 189 | 190 | if use_sigmoid: 191 | sequence += [[nn.Sigmoid()]] 192 | 193 | if getIntermFeat: 194 | for n in range(len(sequence)): 195 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 196 | else: 197 | sequence_stream = [] 198 | for n in range(len(sequence)): 199 | sequence_stream += sequence[n] 200 | self.model = nn.Sequential(*sequence_stream) 201 | 202 | def forward(self, input): 203 | if self.getIntermFeat: 204 | res = [input] 205 | for n in range(self.n_layers+2): 206 | model = getattr(self, 'model'+str(n)) 207 | res.append(model(res[-1])) 208 | return res[-1], res[1:] 209 | else: 210 | return self.model(input), _ 211 | 212 | class NLayerDiscriminator3D(nn.Module): 213 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False, getIntermFeat=True): 214 | super(NLayerDiscriminator3D, self).__init__() 215 | self.getIntermFeat = getIntermFeat 216 | self.n_layers = n_layers 217 | 218 | kw = 4 219 | padw = int(np.ceil((kw-1.0)/2)) 220 | sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 221 | 222 | nf = ndf 223 | for n in range(1, n_layers): 224 | nf_prev = nf 225 | nf = min(nf * 2, 512) 226 | sequence += [[ 227 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 228 | norm_layer(nf), nn.LeakyReLU(0.2, True) 229 | ]] 230 | 231 | nf_prev = nf 232 | nf = min(nf * 2, 512) 233 | sequence += [[ 234 | nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 235 | norm_layer(nf), 236 | nn.LeakyReLU(0.2, True) 237 | ]] 238 | 239 | sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 240 | 241 | if use_sigmoid: 242 | sequence += [[nn.Sigmoid()]] 243 | 244 | if getIntermFeat: 245 | for n in range(len(sequence)): 246 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 247 | else: 248 | sequence_stream = [] 249 | for n in range(len(sequence)): 250 | sequence_stream += sequence[n] 251 | self.model = nn.Sequential(*sequence_stream) 252 | 253 | def forward(self, input): 254 | if self.getIntermFeat: 255 | res = [input] 256 | for n in range(self.n_layers+2): 257 | model = getattr(self, 'model'+str(n)) 258 | res.append(model(res[-1])) 259 | return res[-1], res[1:] 260 | else: 261 | return self.model(input), _ 262 | 263 | 264 | class ActNorm(nn.Module): 265 | def __init__(self, num_features, logdet=False, affine=True, 266 | allow_reverse_init=False): 267 | assert affine 268 | super().__init__() 269 | self.logdet = logdet 270 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 271 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 272 | self.allow_reverse_init = allow_reverse_init 273 | 274 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 275 | 276 | def initialize(self, input): 277 | with torch.no_grad(): 278 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 279 | mean = ( 280 | flatten.mean(1) 281 | .unsqueeze(1) 282 | .unsqueeze(2) 283 | .unsqueeze(3) 284 | .permute(1, 0, 2, 3) 285 | ) 286 | std = ( 287 | flatten.std(1) 288 | .unsqueeze(1) 289 | .unsqueeze(2) 290 | .unsqueeze(3) 291 | .permute(1, 0, 2, 3) 292 | ) 293 | 294 | self.loc.data.copy_(-mean) 295 | self.scale.data.copy_(1 / (std + 1e-6)) 296 | 297 | def forward(self, input, reverse=False): 298 | if reverse: 299 | return self.reverse(input) 300 | if len(input.shape) == 2: 301 | input = input[:,:,None,None] 302 | squeeze = True 303 | else: 304 | squeeze = False 305 | 306 | _, _, height, width = input.shape 307 | 
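# ActNorm initializes its per-channel shift/scale lazily from the statistics of the first
# training batch (data-dependent initialization); `initialized` is a registered buffer, so
# the flag persists in checkpoints and the initialization is not repeated after resuming.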
308 | if self.training and self.initialized.item() == 0: 309 | self.initialize(input) 310 | self.initialized.fill_(1) 311 | 312 | h = self.scale * (input + self.loc) 313 | 314 | if squeeze: 315 | h = h.squeeze(-1).squeeze(-1) 316 | 317 | if self.logdet: 318 | log_abs = torch.log(torch.abs(self.scale)) 319 | logdet = height*width*torch.sum(log_abs) 320 | logdet = logdet * torch.ones(input.shape[0]).to(input) 321 | return h, logdet 322 | 323 | return h 324 | 325 | def reverse(self, output): 326 | if self.training and self.initialized.item() == 0: 327 | if not self.allow_reverse_init: 328 | raise RuntimeError( 329 | "Initializing ActNorm in reverse direction is " 330 | "disabled by default. Use allow_reverse_init=True to enable." 331 | ) 332 | else: 333 | self.initialize(output) 334 | self.initialized.fill_(1) 335 | 336 | if len(output.shape) == 2: 337 | output = output[:,:,None,None] 338 | squeeze = True 339 | else: 340 | squeeze = False 341 | 342 | h = output / self.scale - self.loc 343 | 344 | if squeeze: 345 | h = h.squeeze(-1).squeeze(-1) 346 | return h 347 | -------------------------------------------------------------------------------- /PVDM/losses/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/losses/vgg.pth -------------------------------------------------------------------------------- /PVDM/main.py: -------------------------------------------------------------------------------- 1 | import sys; sys.path.extend(['.']) 2 | 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from omegaconf import OmegaConf 8 | 9 | from exps.diffusion import diffusion 10 | from exps.first_stage import first_stage 11 | 12 | from utils import set_random_seed 13 | 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--exp', type=str, required=True, help='experiment name to run') 18 | parser.add_argument('--seed', type=int, default=42, help='random seed') 19 | parser.add_argument('--id', type=str, default='main', help='experiment identifier') 20 | 21 | """ Args about Data """ 22 | parser.add_argument('--data', type=str, default='cliport') 23 | parser.add_argument('--batch_size', type=int, default=24) 24 | parser.add_argument('--ds', type=int, default=4) 25 | 26 | """ Args about Model """ 27 | parser.add_argument('--pretrain_config', type=str, default='configs/autoencoder/base.yaml') 28 | parser.add_argument('--diffusion_config', type=str, default='configs/latent-diffusion/base.yaml') 29 | 30 | # for GAN resume 31 | parser.add_argument('--first_stage_folder', type=str, default='', help='the folder of first stage experiment before GAN') 32 | 33 | # for diffusion model path specification 34 | parser.add_argument('--first_model', type=str, default='', help='the path of pretrained model') 35 | parser.add_argument('--scale_lr', action='store_true') 36 | 37 | 38 | def main(): 39 | """ Additional args ends here. 
""" 40 | args = parser.parse_args() 41 | """ FIX THE RANDOMNESS """ 42 | set_random_seed(args.seed) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | args.n_gpus = torch.cuda.device_count() 47 | 48 | # init and save configs 49 | 50 | """ RUN THE EXP """ 51 | if args.exp == 'ddpm': 52 | config = OmegaConf.load(args.diffusion_config) 53 | first_stage_config = OmegaConf.load(args.pretrain_config) 54 | 55 | args.unetconfig = config.model.params.unet_config 56 | args.lr = config.model.base_learning_rate 57 | args.scheduler = config.model.params.scheduler_config 58 | args.res = first_stage_config.model.params.ddconfig.resolution 59 | args.timesteps = first_stage_config.model.params.ddconfig.timesteps 60 | args.skip = first_stage_config.model.params.ddconfig.skip 61 | args.ddconfig = first_stage_config.model.params.ddconfig 62 | args.embed_dim = first_stage_config.model.params.embed_dim 63 | args.ddpmconfig = config.model.params 64 | args.cond_model = config.model.cond_model 65 | 66 | # if args.n_gpus == 1: 67 | # diffusion(rank=0, args=args) 68 | # else: 69 | # torch.multiprocessing.spawn(fn=diffusion, args=(args, ), nprocs=args.n_gpus) 70 | diffusion(rank=0, args=args) 71 | 72 | elif args.exp == 'first_stage': 73 | config = OmegaConf.load(args.pretrain_config) 74 | args.ddconfig = config.model.params.ddconfig 75 | args.embed_dim = config.model.params.embed_dim 76 | args.lossconfig = config.model.params.lossconfig 77 | args.lr = config.model.base_learning_rate 78 | args.res = config.model.params.ddconfig.resolution 79 | args.timesteps = config.model.params.ddconfig.timesteps 80 | args.skip = config.model.params.ddconfig.skip 81 | args.resume = config.model.resume 82 | args.amp = config.model.amp 83 | # if args.n_gpus == 1: 84 | # first_stage(rank=0, args=args) 85 | # else: 86 | # torch.multiprocessing.spawn(fn=first_stage, args=(args, ), nprocs=args.n_gpus) 87 | first_stage(rank=0, args=args) 88 | 89 | else: 90 | raise ValueError("Unknown experiment.") 91 | 92 | if __name__ == '__main__': 93 | main() -------------------------------------------------------------------------------- /PVDM/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import os 10 | import time 11 | import hashlib 12 | import pickle 13 | import copy 14 | import uuid 15 | from urllib.parse import urlparse 16 | import numpy as np 17 | import torch 18 | 19 | import ctypes 20 | import fnmatch 21 | import importlib 22 | import inspect 23 | import numpy as np 24 | import os 25 | import shutil 26 | import sys 27 | import types 28 | import io 29 | import pickle 30 | import re 31 | import requests 32 | import html 33 | import hashlib 34 | import glob 35 | import tempfile 36 | import urllib 37 | import urllib.request 38 | import uuid 39 | 40 | 41 | _dnnlib_cache_dir = None 42 | 43 | def make_cache_dir_path(*paths: str) -> str: 44 | if _dnnlib_cache_dir is not None: 45 | return os.path.join(_dnnlib_cache_dir, *paths) 46 | if 'DNNLIB_CACHE_DIR' in os.environ: 47 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 48 | if 'HOME' in os.environ: 49 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 50 | if 'USERPROFILE' in os.environ: 51 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 52 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 53 | #---------------------------------------------------------------------------- 54 | _feature_detector_cache = dict() 55 | 56 | def get_feature_detector_name(url): 57 | return os.path.splitext(url.split('/')[-1])[0] 58 | 59 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 60 | assert 0 <= rank < num_gpus 61 | key = (url, device) 62 | if key not in _feature_detector_cache: 63 | is_leader = (rank == 0) 64 | if not is_leader and num_gpus > 1: 65 | torch.distributed.barrier() # leader goes first 66 | with open_url(url, verbose=(verbose and is_leader)) as f: 67 | if urlparse(url).path.endswith('.pkl'): 68 | _feature_detector_cache[key] = pickle.load(f).to(device) 69 | else: 70 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 71 | if is_leader and num_gpus > 1: 72 | torch.distributed.barrier() # others follow 73 | return _feature_detector_cache[key] 74 | 75 | 76 | def open_url(url: str, cache_dir=None, num_attempts= 10, verbose= True, return_filename = False, cache= True): 77 | """Download the given URL and return a binary-mode file object to access the data.""" 78 | assert num_attempts >= 1 79 | assert not (return_filename and (not cache)) 80 | 81 | # Doesn't look like an URL scheme so interpret it as a local filename. 82 | if not re.match('^[a-z]+://', url): 83 | return url if return_filename else open(url, "rb") 84 | 85 | # Handle file URLs. This code handles unusual file:// patterns that 86 | # arise on Windows: 87 | # 88 | # file:///c:/foo.txt 89 | # 90 | # which would translate to a local '/c:/foo.txt' filename that's 91 | # invalid. Drop the forward slash for such pathnames. 92 | # 93 | # If you touch this code path, you should test it on both Linux and 94 | # Windows. 95 | # 96 | # Some internet resources suggest using urllib.request.url2pathname() but 97 | # but that converts forward slashes to backslashes and this causes 98 | # its own set of problems. 99 | if url.startswith('file://'): 100 | filename = urllib.parse.urlparse(url).path 101 | if re.match(r'^/[a-zA-Z]:', filename): 102 | filename = filename[1:] 103 | return filename if return_filename else open(filename, "rb") 104 | 105 | # Lookup from cache. 
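# Cached downloads are keyed by the MD5 hash of the URL plus a sanitized copy of the
# remote filename, so repeated calls with the same URL reuse the local file.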
106 | if cache_dir is None: 107 | cache_dir = make_cache_dir_path('downloads') 108 | 109 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 110 | if cache: 111 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 112 | if len(cache_files) == 1: 113 | filename = cache_files[0] 114 | return filename if return_filename else open(filename, "rb") 115 | 116 | # Download. 117 | url_name = None 118 | url_data = None 119 | with requests.Session() as session: 120 | if verbose: 121 | print("Downloading %s ..." % url, end="", flush=True) 122 | for attempts_left in reversed(range(num_attempts)): 123 | try: 124 | with session.get(url) as res: 125 | res.raise_for_status() 126 | if len(res.content) == 0: 127 | raise IOError("No data received") 128 | 129 | if len(res.content) < 8192: 130 | content_str = res.content.decode("utf-8") 131 | if "download_warning" in res.headers.get("Set-Cookie", ""): 132 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 133 | if len(links) == 1: 134 | url = requests.compat.urljoin(url, links[0]) 135 | raise IOError("Google Drive virus checker nag") 136 | if "Google Drive - Quota exceeded" in content_str: 137 | raise IOError("Google Drive download quota exceeded -- please try again later") 138 | 139 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 140 | url_name = match[1] if match else url 141 | url_data = res.content 142 | if verbose: 143 | print(" done") 144 | break 145 | except KeyboardInterrupt: 146 | raise 147 | except: 148 | if not attempts_left: 149 | if verbose: 150 | print(" failed") 151 | raise 152 | if verbose: 153 | print(".", end="", flush=True) 154 | 155 | # Save to cache. 156 | if cache: 157 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 158 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 159 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 160 | os.makedirs(cache_dir, exist_ok=True) 161 | with open(temp_file, "wb") as f: 162 | f.write(url_data) 163 | os.replace(temp_file, cache_file) # atomic 164 | if return_filename: 165 | return cache_file 166 | 167 | # Return data as file object. 
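# Reached whenever return_filename is False: the payload is served from memory as a
# BytesIO object, even if a copy was also just written to the cache.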
168 | assert not return_filename 169 | return io.BytesIO(url_data) 170 | 171 | 172 | 173 | 174 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /PVDM/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/__init__.py -------------------------------------------------------------------------------- /PVDM/models/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/autoencoder/__init__.py -------------------------------------------------------------------------------- /PVDM/models/autoencoder/autoencoder_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from models.autoencoder.vit_modules import TimeSformerEncoder, TimeSformerDecoder 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | from time import sleep 11 | 12 | # siren layer 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | def forward(self, x, **kwargs): 20 | return self.fn(self.norm(x), **kwargs) 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout = 0.): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.Linear(dim, hidden_dim), 27 | nn.GELU(), 28 | nn.Dropout(dropout), 29 | nn.Linear(hidden_dim, dim), 30 | nn.Dropout(dropout) 31 | ) 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | project_out = not (heads == 1 and dim_head == dim) 40 | 41 | self.heads = heads 42 | self.scale = dim_head ** -0.5 43 | 44 | self.attend = nn.Softmax(dim = -1) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | 49 | self.to_out = nn.Sequential( 50 | nn.Linear(inner_dim, dim), 51 | nn.Dropout(dropout) 52 | ) if project_out else nn.Identity() 53 | 54 | def forward(self, x): 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | attn = self.dropout(attn) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | 82 | # =========================================================================================== 83 | 84 | 85 | class 
ViTAutoencoder(nn.Module): 86 | def __init__(self, 87 | embed_dim, 88 | ddconfig, 89 | ckpt_path=None, 90 | ignore_keys=[], 91 | image_key="image", 92 | colorize_nlabels=None, 93 | monitor=None, 94 | ): 95 | super().__init__() 96 | self.splits = ddconfig["splits"] 97 | self.s = ddconfig["timesteps"] // self.splits 98 | 99 | self.res = ddconfig["resolution"] 100 | 101 | self.res_h = int(0.75*self.res) 102 | self.res_w = self.res 103 | 104 | self.embed_dim = embed_dim 105 | self.image_key = image_key 106 | 107 | patch_size = 8 108 | self.down = 3 109 | if self.res <= 128: 110 | patch_size = 4 111 | self.down = 2 112 | 113 | self.encoder = TimeSformerEncoder(dim=ddconfig["channels"], 114 | image_size=ddconfig["resolution"], 115 | num_frames=ddconfig["timesteps"], 116 | depth=8, 117 | patch_size=patch_size) 118 | 119 | self.decoder = TimeSformerDecoder(dim=ddconfig["channels"], 120 | image_size=ddconfig["resolution"], 121 | num_frames=ddconfig["timesteps"], 122 | depth=8, 123 | patch_size=patch_size) 124 | 125 | self.to_pixel = nn.Sequential( 126 | Rearrange('b (t h w) c -> (b t) c h w', h=self.res_h // patch_size, w=self.res_w // patch_size), 127 | nn.ConvTranspose2d(ddconfig["channels"], 3, kernel_size=(patch_size, patch_size), stride=patch_size), 128 | ) 129 | 130 | self.act = nn.Sigmoid() 131 | ts = torch.linspace(-1, 1, steps=self.s).unsqueeze(-1) 132 | self.register_buffer('coords', ts) 133 | 134 | self.xy_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 135 | self.xt_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 136 | self.yt_token = nn.Parameter(torch.randn(1, 1, ddconfig["channels"])) 137 | 138 | self.xy_pos_embedding = nn.Parameter(torch.randn(1, self.s + 1, ddconfig["channels"])) 139 | self.xt_pos_embedding = nn.Parameter(torch.randn(1, self.res_w//(2**self.down) + 1, ddconfig["channels"])) 140 | self.yt_pos_embedding = nn.Parameter(torch.randn(1, self.res_h//(2**self.down) + 1, ddconfig["channels"])) 141 | 142 | self.xy_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 143 | self.yt_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 144 | self.xt_quant_attn = Transformer(ddconfig["channels"], 4, self.embed_dim, ddconfig["channels"] // 8, 512) 145 | 146 | self.pre_xy = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 147 | self.pre_xt = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 148 | self.pre_yt = torch.nn.Conv2d(ddconfig["channels"], self.embed_dim, 1) 149 | 150 | self.post_xy = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 151 | self.post_xt = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 152 | self.post_yt = torch.nn.Conv2d(self.embed_dim, ddconfig["channels"], 1) 153 | 154 | def encode(self, x): 155 | # x: b c t h w 156 | b = x.size(0) 157 | x = rearrange(x, 'b c t h w -> b t c h w') 158 | h = self.encoder(x) 159 | h = rearrange(h, 'b (t h w) c -> b c t h w', t=self.s, h=self.res_h//(2**self.down)) 160 | 161 | h_xy = rearrange(h, 'b c t h w -> (b h w) t c') 162 | n = h_xy.size(1) 163 | xy_token = repeat(self.xy_token, '1 1 d -> bhw 1 d', bhw = h_xy.size(0)) 164 | h_xy = torch.cat([h_xy, xy_token], dim=1) 165 | h_xy += self.xy_pos_embedding[:, :(n+1)] 166 | h_xy = self.xy_quant_attn(h_xy)[:, 0] 167 | h_xy = rearrange(h_xy, '(b h w) c -> b c h w', b=b, h=self.res_h//(2**self.down)) 168 | 169 | h_yt = rearrange(h, 'b c t h w -> (b t w) h c') 170 | n = h_yt.size(1) 171 | yt_token = repeat(self.yt_token, '1 1 d -> btw 
1 d', btw = h_yt.size(0)) 172 | h_yt = torch.cat([h_yt, yt_token], dim=1) 173 | h_yt += self.yt_pos_embedding[:, :(n+1)] 174 | h_yt = self.yt_quant_attn(h_yt)[:, 0] 175 | h_yt = rearrange(h_yt, '(b t w) c -> b c t w', b=b, w=self.res_w//(2**self.down)) 176 | 177 | h_xt = rearrange(h, 'b c t h w -> (b t h) w c') 178 | n = h_xt.size(1) 179 | xt_token = repeat(self.xt_token, '1 1 d -> bth 1 d', bth = h_xt.size(0)) 180 | h_xt = torch.cat([h_xt, xt_token], dim=1) 181 | h_xt += self.xt_pos_embedding[:, :(n+1)] 182 | h_xt = self.xt_quant_attn(h_xt)[:, 0] 183 | h_xt = rearrange(h_xt, '(b t h) c -> b c t h', b=b, h=self.res_h//(2**self.down)) 184 | 185 | h_xy = self.pre_xy(h_xy) 186 | h_yt = self.pre_yt(h_yt) 187 | h_xt = self.pre_xt(h_xt) 188 | 189 | h_xy = torch.tanh(h_xy) 190 | h_yt = torch.tanh(h_yt) 191 | h_xt = torch.tanh(h_xt) 192 | 193 | h_xy = self.post_xy(h_xy) 194 | h_yt = self.post_yt(h_yt) 195 | h_xt = self.post_xt(h_xt) 196 | 197 | h_xy = h_xy.unsqueeze(-3).expand(-1,-1,self.s,-1, -1) 198 | h_yt = h_yt.unsqueeze(-2).expand(-1,-1,-1,self.res_h//(2**self.down), -1) 199 | h_xt = h_xt.unsqueeze(-1).expand(-1,-1,-1,-1,self.res_w//(2**self.down)) 200 | return h_xy + h_yt + h_xt #torch.cat([h_xy, h_yt, h_xt], dim=1) 201 | 202 | def decode(self, z): 203 | b = z.size(0) 204 | dec = self.decoder(z) 205 | return 2*self.act(self.to_pixel(dec)).contiguous() -1 206 | 207 | def forward(self, input): 208 | input = rearrange(input, 'b c (n t) h w -> (b n) c t h w', n=self.splits) 209 | z = self.encode(input) 210 | dec = self.decode(z) 211 | return dec, 0. 212 | 213 | def extract(self, x): 214 | b = x.size(0) 215 | x = rearrange(x, 'b c t h w -> b t c h w') 216 | h = self.encoder(x) 217 | h = rearrange(h, 'b (t h w) c -> b c t h w', t=self.s, h=self.res_h//(2**self.down)) 218 | 219 | h_xy = rearrange(h, 'b c t h w -> (b h w) t c') 220 | n = h_xy.size(1) 221 | xy_token = repeat(self.xy_token, '1 1 d -> bhw 1 d', bhw = h_xy.size(0)) 222 | h_xy = torch.cat([h_xy, xy_token], dim=1) 223 | h_xy += self.xy_pos_embedding[:, :(n+1)] 224 | h_xy = self.xy_quant_attn(h_xy)[:, 0] 225 | h_xy = rearrange(h_xy, '(b h w) c -> b c h w', b=b, h=self.res_h//(2**self.down)) 226 | 227 | h_yt = rearrange(h, 'b c t h w -> (b t w) h c') 228 | n = h_yt.size(1) 229 | yt_token = repeat(self.yt_token, '1 1 d -> btw 1 d', btw = h_yt.size(0)) 230 | h_yt = torch.cat([h_yt, yt_token], dim=1) 231 | h_yt += self.yt_pos_embedding[:, :(n+1)] 232 | h_yt = self.yt_quant_attn(h_yt)[:, 0] 233 | h_yt = rearrange(h_yt, '(b t w) c -> b c t w', b=b, w=self.res_w//(2**self.down)) 234 | 235 | h_xt = rearrange(h, 'b c t h w -> (b t h) w c') 236 | n = h_xt.size(1) 237 | xt_token = repeat(self.xt_token, '1 1 d -> bth 1 d', bth = h_xt.size(0)) 238 | h_xt = torch.cat([h_xt, xt_token], dim=1) 239 | h_xt += self.xt_pos_embedding[:, :(n+1)] 240 | h_xt = self.xt_quant_attn(h_xt)[:, 0] 241 | h_xt = rearrange(h_xt, '(b t h) c -> b c t h', b=b, h=self.res_h//(2**self.down)) 242 | 243 | h_xy = self.pre_xy(h_xy) 244 | h_yt = self.pre_yt(h_yt) 245 | h_xt = self.pre_xt(h_xt) 246 | 247 | h_xy = torch.tanh(h_xy) 248 | h_yt = torch.tanh(h_yt) 249 | h_xt = torch.tanh(h_xt) 250 | 251 | h_xy = h_xy.view(h_xy.size(0), h_xy.size(1), -1) 252 | h_yt = h_yt.view(h_yt.size(0), h_yt.size(1), -1) 253 | h_xt = h_xt.view(h_xt.size(0), h_xt.size(1), -1) 254 | 255 | ret = torch.cat([h_xy, h_yt, h_xt], dim=-1) 256 | return ret 257 | 258 | def decode_from_sample(self, h): 259 | latent_res_h = self.res_h // (2**self.down) 260 | latent_res_w = self.res_w // (2**self.down) 261 | h_xy = 
h[:, :, 0:latent_res_h*latent_res_w].view(h.size(0), h.size(1), latent_res_h, latent_res_w) 262 | h_yt = h[:, :, latent_res_h*latent_res_w:latent_res_w*(latent_res_h+16)].view(h.size(0), h.size(1), 16, latent_res_w) 263 | h_xt = h[:, :, latent_res_w*(latent_res_h+16):latent_res_w*latent_res_h+16*latent_res_w+16*latent_res_h].view(h.size(0), h.size(1), 16, latent_res_h) 264 | 265 | h_xy = self.post_xy(h_xy) 266 | h_yt = self.post_yt(h_yt) 267 | h_xt = self.post_xt(h_xt) 268 | 269 | h_xy = h_xy.unsqueeze(-3).expand(-1,-1,self.s,-1, -1) 270 | h_yt = h_yt.unsqueeze(-2).expand(-1,-1,-1,self.res_h//(2**self.down), -1) 271 | h_xt = h_xt.unsqueeze(-1).expand(-1,-1,-1,-1,self.res_w//(2**self.down)) 272 | 273 | z = h_xy + h_yt + h_xt 274 | 275 | b = z.size(0) 276 | dec = self.decoder(z) 277 | return 2*self.act(self.to_pixel(dec)).contiguous()-1 278 | -------------------------------------------------------------------------------- /PVDM/models/autoencoder/vit_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | from math import log, pi 7 | 8 | def rotate_every_two(x): 9 | x = rearrange(x, '... (d j) -> ... d j', j = 2) 10 | x1, x2 = x.unbind(dim = -1) 11 | x = torch.stack((-x2, x1), dim = -1) 12 | return rearrange(x, '... d j -> ... (d j)') 13 | 14 | def apply_rot_emb(q, k, rot_emb): 15 | sin, cos = rot_emb 16 | rot_dim = sin.shape[-1] 17 | (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k)) 18 | q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k)) 19 | q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass))) 20 | return q, k 21 | 22 | class AxialRotaryEmbedding(nn.Module): 23 | def __init__(self, dim, max_freq = 10): 24 | super().__init__() 25 | self.dim = dim 26 | scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2) 27 | self.register_buffer('scales', scales) 28 | 29 | def forward(self, h, w, device): 30 | scales = rearrange(self.scales, '... -> () ...') 31 | scales = scales.to(device) 32 | 33 | h_seq = torch.linspace(-1., 1., steps = h, device = device) 34 | h_seq = h_seq.unsqueeze(-1) 35 | 36 | w_seq = torch.linspace(-1., 1., steps = w, device = device) 37 | w_seq = w_seq.unsqueeze(-1) 38 | 39 | h_seq = h_seq * scales * pi 40 | w_seq = w_seq * scales * pi 41 | 42 | x_sinu = repeat(h_seq, 'i d -> i j d', j = w) 43 | y_sinu = repeat(w_seq, 'j d -> i j d', i = h) 44 | 45 | sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1) 46 | cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1) 47 | 48 | sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos)) 49 | sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos)) 50 | return sin, cos 51 | 52 | class RotaryEmbedding(nn.Module): 53 | def __init__(self, dim): 54 | super().__init__() 55 | inv_freqs = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 56 | self.register_buffer('inv_freqs', inv_freqs) 57 | 58 | def forward(self, n, device): 59 | seq = torch.arange(n, device = device) 60 | freqs = einsum('i, j -> i j', seq, self.inv_freqs) 61 | freqs = torch.cat((freqs, freqs), dim = -1) 62 | freqs = rearrange(freqs, 'n d -> () n d') 63 | return freqs.sin(), freqs.cos() 64 | 65 | def exists(val): 66 | return val is not None 67 | 68 | # classes 69 | 70 | class PreNorm(nn.Module): 71 | def __init__(self, dim, fn): 72 | super().__init__() 73 | self.fn = fn 74 | self.norm = nn.LayerNorm(dim) 75 | 76 | def forward(self, x, *args, **kwargs): 77 | x = self.norm(x) 78 | return self.fn(x, *args, **kwargs) 79 | 80 | # time token shift 81 | 82 | def shift(t, amt): 83 | if amt == 0: 84 | return t 85 | return F.pad(t, (0, 0, 0, 0, amt, -amt)) 86 | 87 | # feedforward 88 | 89 | class GEGLU(nn.Module): 90 | def forward(self, x): 91 | x, gates = x.chunk(2, dim = -1) 92 | return x * F.gelu(gates) 93 | 94 | class FeedForward(nn.Module): 95 | def __init__(self, dim, mult = 4, dropout = 0.): 96 | super().__init__() 97 | self.net = nn.Sequential( 98 | nn.Linear(dim, dim * mult * 2), 99 | GEGLU(), 100 | nn.Dropout(dropout), 101 | nn.Linear(dim * mult, dim) 102 | ) 103 | 104 | def forward(self, x): 105 | return self.net(x) 106 | 107 | # attention 108 | 109 | def attn(q, k, v, mask = None): 110 | sim = einsum('b i d, b j d -> b i j', q, k) 111 | 112 | if exists(mask): 113 | max_neg_value = -torch.finfo(sim.dtype).max 114 | sim.masked_fill_(~mask, max_neg_value) 115 | 116 | attn = sim.softmax(dim = -1) 117 | out = einsum('b i j, b j d -> b i d', attn, v) 118 | return out 119 | 120 | class Attention(nn.Module): 121 | def __init__( 122 | self, 123 | dim, 124 | dim_head = 64, 125 | heads = 8, 126 | dropout = 0. 
127 | ): 128 | super().__init__() 129 | self.heads = heads 130 | self.scale = dim_head ** -0.5 131 | inner_dim = dim_head * heads 132 | 133 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 134 | self.to_out = nn.Sequential( 135 | nn.Linear(inner_dim, dim), 136 | nn.Dropout(dropout) 137 | ) 138 | 139 | def forward(self, x, einops_from, einops_to, mask = None, rot_emb = None, **einops_dims): 140 | h = self.heads 141 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 142 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) 143 | 144 | q = q * self.scale 145 | 146 | # rearrange across time or space 147 | q, k, v = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q, k, v)) 148 | 149 | # add rotary embeddings, if applicable 150 | if exists(rot_emb): 151 | q, k = apply_rot_emb(q, k, rot_emb) 152 | 153 | # expand cls token keys and values across time or space and concat 154 | # attention 155 | out = attn(q, k, v, mask = mask) 156 | out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) 157 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 158 | 159 | # combine heads out 160 | return self.to_out(out) 161 | 162 | # main classes 163 | 164 | class TimeSformerEncoder(nn.Module): 165 | def __init__( 166 | self, 167 | *, 168 | dim = 512, 169 | num_frames = 16, 170 | image_size = 64, 171 | patch_size = 8, 172 | channels = 3, 173 | depth = 8, 174 | heads = 8, 175 | dim_head = 64, 176 | attn_dropout = 0., 177 | ff_dropout = 0., 178 | rotary_emb = True, 179 | shift_tokens = False, 180 | ): 181 | super().__init__() 182 | image_size_h = int(0.75*image_size) 183 | image_size_w = image_size 184 | 185 | assert image_size_h % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 186 | assert image_size_w % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
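# Frames use a 3:4 aspect ratio (height = 0.75 * image_size), so each frame is split into
# (image_size_h * image_size_w) / patch_size**2 non-overlapping patches; e.g. image_size=64
# with patch_size=4 gives 48x64 frames and 192 patches per frame.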
187 | 188 | num_patches = (image_size_h*image_size_w) // (patch_size**2) 189 | num_positions = num_frames * num_patches 190 | patch_dim = channels * patch_size ** 2 191 | 192 | self.heads = heads 193 | self.patch_size = patch_size 194 | self.to_patch_embedding = nn.Linear(patch_dim, dim) 195 | 196 | self.use_rotary_emb = rotary_emb 197 | if rotary_emb: 198 | self.frame_rot_emb = RotaryEmbedding(dim_head) 199 | self.image_rot_emb = AxialRotaryEmbedding(dim_head) 200 | else: 201 | self.pos_emb = nn.Embedding(num_positions, dim) 202 | 203 | self.layers = nn.ModuleList([]) 204 | for _ in range(depth): 205 | ff = FeedForward(dim, dropout = ff_dropout) 206 | time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 207 | spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 208 | 209 | time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff)) 210 | 211 | self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff])) 212 | 213 | def forward(self, video, frame_mask = None): 214 | b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size 215 | assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}' 216 | 217 | # calculate num patches in height and width dimension, and number of total patches (n) 218 | hp, wp = (h // p), (w // p) 219 | n = hp * wp 220 | 221 | # video to patch embeddings 222 | video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p) 223 | x = self.to_patch_embedding(video) 224 | 225 | # positional embedding 226 | frame_pos_emb = None 227 | image_pos_emb = None 228 | if not self.use_rotary_emb: 229 | x += self.pos_emb(torch.arange(x.shape[1], device = device)) 230 | else: 231 | frame_pos_emb = self.frame_rot_emb(f, device = device) 232 | image_pos_emb = self.image_rot_emb(hp, wp, device = device) 233 | 234 | # time and space attention 235 | for (time_attn, spatial_attn, ff) in self.layers: 236 | x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, rot_emb = frame_pos_emb) + x 237 | x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, rot_emb = image_pos_emb) + x 238 | x = ff(x) + x 239 | 240 | return x 241 | 242 | class TimeSformerDecoder(nn.Module): 243 | def __init__( 244 | self, 245 | *, 246 | dim = 512, 247 | num_frames = 16, 248 | image_size = 64, 249 | patch_size = 8, 250 | channels = 3, 251 | depth = 8, 252 | heads = 8, 253 | dim_head = 64, 254 | attn_dropout = 0., 255 | ff_dropout = 0., 256 | rotary_emb = True, 257 | shift_tokens = False, 258 | ): 259 | super().__init__() 260 | image_size_h = int(0.75*image_size) 261 | image_size_w = image_size 262 | 263 | assert image_size_h % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 264 | assert image_size_w % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
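# Unlike the encoder, the decoder receives an already-patchified latent grid of shape
# (b, c, t, h/patch_size, w/patch_size), so there is no patch-embedding layer here; it only
# applies the same alternating temporal/spatial attention with rotary position embeddings.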
265 | 266 | num_patches = (image_size_h*image_size_w) // (patch_size**2) 267 | num_positions = num_frames * num_patches 268 | patch_dim = channels * patch_size ** 2 269 | 270 | 271 | self.heads = heads 272 | self.patch_size = patch_size 273 | 274 | self.use_rotary_emb = rotary_emb 275 | if rotary_emb: 276 | self.frame_rot_emb = RotaryEmbedding(dim_head) 277 | self.image_rot_emb = AxialRotaryEmbedding(dim_head) 278 | else: 279 | self.pos_emb = nn.Embedding(num_positions, dim) 280 | 281 | self.layers = nn.ModuleList([]) 282 | for _ in range(depth): 283 | ff = FeedForward(dim, dropout = ff_dropout) 284 | time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 285 | spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) 286 | 287 | time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff)) 288 | 289 | self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff])) 290 | 291 | def forward(self, x, frame_mask = None): 292 | device = x.device 293 | f, hp, wp = x.size(2), x.size(3), x.size(4) 294 | n = hp * wp 295 | x = rearrange(x, 'b c f h w -> b (f h w) c') 296 | 297 | # positional embedding 298 | frame_pos_emb = None 299 | image_pos_emb = None 300 | if not self.use_rotary_emb: 301 | x += self.pos_emb(torch.arange(x.shape[1], device = device)) 302 | else: 303 | frame_pos_emb = self.frame_rot_emb(f, device = device) 304 | image_pos_emb = self.image_rot_emb(hp, wp, device = device) 305 | 306 | # time and space attention 307 | for (time_attn, spatial_attn, ff) in self.layers: 308 | x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, rot_emb = frame_pos_emb) + x 309 | x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, rot_emb = image_pos_emb) + x 310 | x = ff(x) + x 311 | 312 | return x -------------------------------------------------------------------------------- /PVDM/models/ddpm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/models/ddpm/__init__.py -------------------------------------------------------------------------------- /PVDM/models/ddpm/diffusionmodules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | 8 | 9 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 10 | if ddim_discr_method == 'uniform': 11 | c = num_ddpm_timesteps // num_ddim_timesteps 12 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 13 | elif ddim_discr_method == 'quad': 14 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 15 | else: 16 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 17 | 18 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 19 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 20 | steps_out = ddim_timesteps + 1 21 | if verbose: 22 | print(f'Selected timesteps for ddim sampler: {steps_out}') 23 | return steps_out 24 | 25 | 26 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 27 | # select alphas for computing the variance schedule 28 | alphas = alphacums[ddim_timesteps] 29 | alphas_prev = np.asarray([alphacums[0]] + 
alphacums[ddim_timesteps[:-1]].tolist()) 30 | 31 | # according the the formula provided in https://arxiv.org/abs/2010.02502 32 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 33 | if verbose: 34 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 35 | print(f'For the chosen value of eta, which is {eta}, ' 36 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 37 | return sigmas, alphas, alphas_prev 38 | 39 | 40 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 41 | """ 42 | Create a beta schedule that discretizes the given alpha_t_bar function, 43 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 44 | :param num_diffusion_timesteps: the number of betas to produce. 45 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 46 | produces the cumulative product of (1-beta) up to that 47 | part of the diffusion process. 48 | :param max_beta: the maximum beta to use; use values lower than 1 to 49 | prevent singularities. 50 | """ 51 | betas = [] 52 | for i in range(num_diffusion_timesteps): 53 | t1 = i / num_diffusion_timesteps 54 | t2 = (i + 1) / num_diffusion_timesteps 55 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 56 | return np.array(betas) 57 | 58 | 59 | def checkpoint(func, inputs, params, flag): 60 | """ 61 | Evaluate a function without caching intermediate activations, allowing for 62 | reduced memory at the expense of extra compute in the backward pass. 63 | :param func: the function to evaluate. 64 | :param inputs: the argument sequence to pass to `func`. 65 | :param params: a sequence of parameters `func` depends on but does not 66 | explicitly take as arguments. 67 | :param flag: if False, disable gradient checkpointing. 68 | """ 69 | if flag: 70 | args = tuple(inputs) + tuple(params) 71 | return CheckpointFunction.apply(func, len(inputs), *args) 72 | else: 73 | return func(*inputs) 74 | 75 | 76 | class CheckpointFunction(torch.autograd.Function): 77 | @staticmethod 78 | def forward(ctx, run_function, length, *args): 79 | ctx.run_function = run_function 80 | ctx.input_tensors = list(args[:length]) 81 | ctx.input_params = list(args[length:]) 82 | 83 | with torch.no_grad(): 84 | output_tensors = ctx.run_function(*ctx.input_tensors) 85 | return output_tensors 86 | 87 | @staticmethod 88 | def backward(ctx, *output_grads): 89 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 90 | with torch.enable_grad(): 91 | # Fixes a bug where the first op in run_function modifies the 92 | # Tensor storage in place, which is not allowed for detach()'d 93 | # Tensors. 94 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 95 | output_tensors = ctx.run_function(*shallow_copies) 96 | input_grads = torch.autograd.grad( 97 | output_tensors, 98 | ctx.input_tensors + ctx.input_params, 99 | output_grads, 100 | allow_unused=True, 101 | ) 102 | del ctx.input_tensors 103 | del ctx.input_params 104 | del output_tensors 105 | return (None, None) + input_grads 106 | 107 | 108 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 109 | """ 110 | Create sinusoidal timestep embeddings. 111 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 112 | These may be fractional. 113 | :param dim: the dimension of the output. 114 | :param max_period: controls the minimum frequency of the embeddings. 
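:param repeat_only: if True, skip the sinusoidal projection and simply broadcast each raw timestep value across all `dim` channels.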
115 | :return: an [N x dim] Tensor of positional embeddings. 116 | """ 117 | if not repeat_only: 118 | half = dim // 2 119 | freqs = torch.exp( 120 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 121 | ).to(device=timesteps.device) 122 | args = timesteps[:, None].float() * freqs[None] 123 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 124 | if dim % 2: 125 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 126 | else: 127 | embedding = repeat(timesteps, 'b -> b d', d=dim) 128 | return embedding 129 | 130 | 131 | def zero_module(module): 132 | """ 133 | Zero out the parameters of a module and return it. 134 | """ 135 | for p in module.parameters(): 136 | p.detach().zero_() 137 | return module 138 | 139 | 140 | def scale_module(module, scale): 141 | """ 142 | Scale the parameters of a module and return it. 143 | """ 144 | for p in module.parameters(): 145 | p.detach().mul_(scale) 146 | return module 147 | 148 | 149 | def mean_flat(tensor): 150 | """ 151 | Take the mean over all non-batch dimensions. 152 | """ 153 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 154 | 155 | 156 | def normalization(channels): 157 | """ 158 | Make a standard normalization layer. 159 | :param channels: number of input channels. 160 | :return: an nn.Module for normalization. 161 | """ 162 | return GroupNorm32(32, channels) 163 | 164 | 165 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 166 | class SiLU(nn.Module): 167 | def forward(self, x): 168 | return x * torch.sigmoid(x) 169 | 170 | 171 | class GroupNorm32(nn.GroupNorm): 172 | def forward(self, x): 173 | return super().forward(x.float()).type(x.dtype) 174 | 175 | def conv_nd(dims, *args, **kwargs): 176 | """ 177 | Create a 1D, 2D, or 3D convolution module. 178 | """ 179 | if dims == 1: 180 | return nn.Conv1d(*args, **kwargs) 181 | elif dims == 2: 182 | return nn.Conv2d(*args, **kwargs) 183 | elif dims == 3: 184 | return nn.Conv3d(*args, **kwargs) 185 | raise ValueError(f"unsupported dimensions: {dims}") 186 | 187 | 188 | def linear(*args, **kwargs): 189 | """ 190 | Create a linear module. 191 | """ 192 | return nn.Linear(*args, **kwargs) 193 | 194 | 195 | def avg_pool_nd(dims, *args, **kwargs): 196 | """ 197 | Create a 1D, 2D, or 3D average pooling module. 
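:param dims: 1, 2, or 3, selecting nn.AvgPool1d, nn.AvgPool2d, or nn.AvgPool3d respectively.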
198 | """ 199 | if dims == 1: 200 | return nn.AvgPool1d(*args, **kwargs) 201 | elif dims == 2: 202 | return nn.AvgPool2d(*args, **kwargs) 203 | elif dims == 3: 204 | return nn.AvgPool3d(*args, **kwargs) 205 | raise ValueError(f"unsupported dimensions: {dims}") 206 | 207 | 208 | 209 | def noise_like(shape, device, repeat=False): 210 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 211 | noise = lambda: torch.randn(shape, device=device) 212 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /PVDM/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /PVDM/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/PVDM/tools/__init__.py -------------------------------------------------------------------------------- /PVDM/tools/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import math 4 | import random 5 | import pickle 6 | import warnings 7 | import glob 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import zipfile 12 | import PIL.Image 13 | from PIL import Image 14 | from PIL import ImageFile 15 | from einops import rearrange 16 | from torchvision import transforms 17 | import json 18 | import numpy as np 19 | import pyspng 20 | 21 | from natsort import natsorted 22 | 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | IMG_EXTENSIONS = [ 25 | '.jpg', '.JPG', '.jpeg', '.JPEG', 26 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 27 | ] 28 | 29 | def is_image_file(filename): 30 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 31 | 32 | 33 | def pil_loader(path): 34 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 35 | ''' 36 | with open(path, 'rb') as f: 37 | with Image.open(f) as img: 38 | return img.convert('RGB') 39 | ''' 40 | Im = Image.open(path) 41 | return Im.convert('RGB') 42 | 43 | 44 | def default_loader(path): 45 | ''' 46 | from torchvision import get_image_backend 47 | if get_image_backend() == 'accimage': 48 | return accimage_loader(path) 49 | else: 50 | ''' 51 | return pil_loader(path) 52 | 53 | 54 | def find_classes(dir): 55 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 56 | classes.sort() 57 | class_to_idx = {classes[i]: i for i in range(len(classes))} 58 | return classes, class_to_idx 59 | 60 | def resize_crop(video, resolution): 61 | """ Resizes video with smallest axis to `resolution * extra_scale` 62 | and then crops a `resolution` x `resolution` bock. 
If `crop_mode == "center"` 63 | do a center crop, if `crop_mode == "random"`, does a random crop 64 | Args 65 | video: a tensor of shape [t, c, h, w] in {0, ..., 255} 66 | resolution: an int 67 | crop_mode: 'center', 'random' 68 | Returns 69 | a processed video of shape [c, t, h, w] 70 | """ 71 | _, _, h, w = video.shape 72 | 73 | if h > w: 74 | half = (h - w) // 2 75 | cropsize = (0, half, w, half + w) # left, upper, right, lower 76 | elif w >= h: 77 | half = (w - h) // 2 78 | cropsize = (half, 0, half + h, h) 79 | 80 | video = video[:, :, cropsize[1]:cropsize[3], cropsize[0]:cropsize[2]] 81 | video = F.interpolate(video, size=resolution, mode='bilinear', align_corners=False) 82 | 83 | video = video.permute(1, 0, 2, 3).contiguous() # [c, t, h, w] 84 | return video 85 | 86 | def make_imageclip_dataset(dir, nframes, class_to_idx, vid_diverse_sampling, split='all'): 87 | """ 88 | TODO: add xflip 89 | """ 90 | def _sort(path): 91 | return natsorted(os.listdir(path)) 92 | 93 | images = [] 94 | n_video = 0 95 | n_clip = 0 96 | 97 | 98 | dir_list = natsorted(os.listdir(dir)) 99 | for target in dir_list: 100 | if split == 'train': 101 | if 'val' in target: dir_list.remove(target) 102 | elif split == 'val' or split == 'test': 103 | if 'train' in target: dir_list.remove(target) 104 | 105 | for target in dir_list: 106 | if os.path.isdir(os.path.join(dir,target))==True: 107 | n_video +=1 108 | subfolder_path = os.path.join(dir, target) 109 | for subsubfold in natsorted(os.listdir(subfolder_path) ): 110 | if os.path.isdir(os.path.join(subfolder_path, subsubfold) ): 111 | subsubfolder_path = os.path.join(subfolder_path, subsubfold) 112 | i = 1 113 | 114 | if nframes > 0 and vid_diverse_sampling: 115 | n_clip += 1 116 | 117 | item_frames_0 = [] 118 | item_frames_1 = [] 119 | item_frames_2 = [] 120 | item_frames_3 = [] 121 | 122 | for fi in _sort(subsubfolder_path): 123 | if is_image_file(fi): 124 | file_name = fi 125 | file_path = os.path.join(subsubfolder_path, file_name) 126 | item = (file_path, class_to_idx[target]) 127 | 128 | if i % 4 == 0: 129 | item_frames_0.append(item) 130 | elif i % 4 == 1: 131 | item_frames_1.append(item) 132 | elif i % 4 == 2: 133 | item_frames_2.append(item) 134 | else: 135 | item_frames_3.append(item) 136 | 137 | if i %nframes == 0 and i > 0: 138 | images.append(item_frames_0) # item_frames is a list containing n frames. 139 | images.append(item_frames_1) # item_frames is a list containing n frames. 140 | images.append(item_frames_2) # item_frames is a list containing n frames. 141 | images.append(item_frames_3) # item_frames is a list containing n frames. 142 | item_frames_0 = [] 143 | item_frames_1 = [] 144 | item_frames_2 = [] 145 | item_frames_3 = [] 146 | 147 | i = i+1 148 | else: 149 | item_frames = [] 150 | for fi in _sort(subsubfolder_path): 151 | if is_image_file(fi): 152 | # fi is an image in the subsubfolder 153 | file_name = fi 154 | file_path = os.path.join(subsubfolder_path, file_name) 155 | item = (file_path, class_to_idx[target]) 156 | item_frames.append(item) 157 | if i % nframes == 0 and i > 0: 158 | images.append(item_frames) # item_frames is a list containing 32 frames. 
159 | item_frames = [] 160 | i = i + 1 161 | 162 | return images 163 | 164 | 165 | 166 | def make_imagefolder_dataset(dir, nframes, class_to_idx, vid_diverse_sampling, split='all'): 167 | """ 168 | TODO: add xflip 169 | """ 170 | def _sort(path): 171 | return natsorted(os.listdir(path)) 172 | 173 | images = [] 174 | n_video = 0 175 | n_clip = 0 176 | 177 | 178 | dir_list = natsorted(os.listdir(dir)) 179 | for target in dir_list: 180 | if split == 'train': 181 | if 'val' in target: dir_list.remove(target) 182 | elif split == 'val' or split == 'test': 183 | if 'train' in target: dir_list.remove(target) 184 | 185 | dataset_list = [] 186 | for target in dir_list: 187 | if os.path.isdir(os.path.join(dir,target))==True: 188 | n_video +=1 189 | subfolder_path = os.path.join(dir, target) 190 | for subsubfold in natsorted(os.listdir(subfolder_path) ): 191 | if os.path.isdir(os.path.join(subfolder_path, subsubfold) ): 192 | subsubfolder_path = os.path.join(subfolder_path, subsubfold) 193 | 194 | count = 0 195 | valid = False 196 | for fi in _sort(subsubfolder_path): 197 | if is_image_file(fi): 198 | valid = True 199 | count += 1 200 | else: 201 | valid = False 202 | break 203 | """ 204 | valid = True 205 | """ 206 | if valid and count >= nframes: 207 | valid = True 208 | else: 209 | valid = False 210 | 211 | if valid == True: 212 | dataset_list.append((subsubfolder_path, count)) 213 | 214 | return dataset_list 215 | 216 | 217 | class InfiniteSampler(torch.utils.data.Sampler): 218 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 219 | assert len(dataset) > 0 220 | assert num_replicas > 0 221 | assert 0 <= rank < num_replicas 222 | assert 0 <= window_size <= 1 223 | super().__init__(dataset) 224 | self.dataset = dataset 225 | self.rank = rank 226 | self.num_replicas = num_replicas 227 | self.shuffle = shuffle 228 | self.seed = seed 229 | self.window_size = window_size 230 | 231 | def __iter__(self): 232 | order = np.arange(len(self.dataset)) 233 | rnd = None 234 | window = 0 235 | if self.shuffle: 236 | rnd = np.random.RandomState(self.seed) 237 | rnd.shuffle(order) 238 | window = int(np.rint(order.size * self.window_size)) 239 | 240 | idx = 0 241 | while True: 242 | i = idx % order.size 243 | if idx % self.num_replicas == self.rank: 244 | yield order[i] 245 | if window >= 2: 246 | j = (i - rnd.randint(window)) % order.size 247 | order[i], order[j] = order[j], order[i] 248 | idx += 1 -------------------------------------------------------------------------------- /PVDM/tools/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import math 4 | import random 5 | import pickle 6 | import warnings 7 | import glob 8 | 9 | import imageio 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.utils.data import Dataset, DataLoader 15 | 16 | import torchvision.transforms as T 17 | import torchvision.datasets as datasets 18 | from torchvision.datasets import UCF101 19 | from torchvision.datasets.folder import make_dataset 20 | from tools.video_utils import VideoClips 21 | from torchvision.io import read_video 22 | 23 | data_location = '/data' 24 | from utils import set_random_seed 25 | from tools.data_utils import * 26 | 27 | import av 28 | 29 | from ffcv.fields.decoders import NDArrayDecoder 30 | from ffcv.transforms import ToTensor, Squeeze, ToDevice 31 | from ffcv.loader import Loader, OrderOption 32 | 33 | class 
VideoFolderDataset(Dataset): 34 | def __init__(self, 35 | root, 36 | train, 37 | resolution, 38 | path=None, 39 | n_frames=16, 40 | skip=1, 41 | fold=1, 42 | max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 43 | use_labels=False, # Enable conditioning labels? False = label dimension is zero. 44 | return_vid=False, # True for evaluating FVD 45 | time_saliency=False, 46 | sub=False, 47 | seed=42, 48 | **super_kwargs, # Additional arguments for the Dataset base class. 49 | ): 50 | 51 | video_root = osp.join(os.path.join(root)) 52 | if not 1 <= fold <= 3: 53 | raise ValueError("fold should be between 1 and 3, got {}".format(fold)) 54 | 55 | self.path = video_root 56 | name = video_root.split('/')[-1] 57 | self.name = name 58 | self.train = train 59 | self.fold = fold 60 | self.resolution = resolution 61 | self.nframes = n_frames 62 | self.annotation_path = os.path.join(video_root, 'ucfTrainTestlist') 63 | self.classes = list(natsorted(p for p in os.listdir(video_root) if osp.isdir(osp.join(video_root, p)))) 64 | self.classes.remove('ucfTrainTestlist') 65 | class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} 66 | self.samples = make_dataset(video_root, class_to_idx, ('avi',), is_valid_file=None) 67 | video_list = [x[0] for x in self.samples] 68 | 69 | self.video_list = video_list 70 | """ 71 | if train: 72 | self.video_list = video_list 73 | else: 74 | self.video_list = [] 75 | for p in video_list: 76 | video = read_video(p)[0] 77 | if len(video) >= 128: 78 | self.video_list.append(p) 79 | """ 80 | 81 | 82 | self._use_labels = use_labels 83 | self._label_shape = None 84 | self._raw_labels = None 85 | self._raw_shape = [len(self.video_list)] + [3, resolution, resolution] 86 | self.num_channels = 3 87 | self.return_vid = return_vid 88 | 89 | frames_between_clips = skip 90 | print(root, frames_between_clips, n_frames) 91 | self.indices = self._select_fold(self.video_list, self.annotation_path, 92 | fold, train) 93 | 94 | self.size = len(self.indices) 95 | print(self.size) 96 | random.seed(seed) 97 | self.shuffle_indices = [i for i in range(self.size)] 98 | random.shuffle(self.shuffle_indices) 99 | 100 | self._need_init = True 101 | 102 | def _select_fold(self, video_list, annotation_path, fold, train): 103 | name = "train" if train else "test" 104 | name = "{}list{:02d}.txt".format(name, fold) 105 | f = os.path.join(annotation_path, name) 106 | selected_files = [] 107 | with open(f, "r") as fid: 108 | data = fid.readlines() 109 | data = [x.strip().split(" ") for x in data] 110 | data = [os.path.join(self.path, x[0]) for x in data] 111 | 112 | """ 113 | for p in data: 114 | if p in video_list == False: 115 | data.remove(p) 116 | """ 117 | 118 | selected_files.extend(data) 119 | 120 | """ 121 | name = "train" if not train else "test" 122 | name = "{}list{:02d}.txt".format(name, fold) 123 | f = os.path.join(annotation_path, name) 124 | with open(f, "r") as fid: 125 | data = fid.readlines() 126 | data = [x.strip().split(" ") for x in data] 127 | data = [os.path.join(self.path, x[0]) for x in data] 128 | selected_files.extend(data) 129 | """ 130 | 131 | selected_files = set(selected_files) 132 | indices = [i for i in range(len(video_list)) if video_list[i] in selected_files] 133 | return indices 134 | 135 | def __len__(self): 136 | return self.size 137 | 138 | def _preprocess(self, video): 139 | video = resize_crop(video, self.resolution) 140 | return video 141 | 142 | def __getitem__(self, idx): 143 | idx = self.shuffle_indices[idx] 
144 | idx = self.indices[idx] 145 | video = read_video(self.video_list[idx])[0] 146 | prefix = np.random.randint(len(video)-self.nframes+1) 147 | video = video[prefix:prefix+self.nframes].float().permute(3,0,1,2) 148 | 149 | return self._preprocess(video), idx 150 | 151 | class ImageFolderDataset(Dataset): 152 | def __init__(self, 153 | path, # Path to directory or zip. 154 | resolution=None, 155 | nframes=16, # number of frames for each video. 156 | train=True, 157 | interpolate=False, 158 | loader=default_loader, # loader for "sequence" of images 159 | return_vid=True, # True for evaluating FVD 160 | cond=False, 161 | **super_kwargs, # Additional arguments for the Dataset base class. 162 | ): 163 | 164 | self._path = path 165 | self._zipfile = None 166 | self.apply_resize = True 167 | 168 | # classes, class_to_idx = find_classes(path) 169 | if 'taichi' in path and not interpolate: 170 | classes, class_to_idx = find_classes(path) 171 | imgs = make_imagefolder_dataset(path, nframes * 4, class_to_idx, True) 172 | elif 'kinetics' in path or 'KINETICS' in path: 173 | if train: 174 | split = 'train' 175 | else: 176 | split = 'val' 177 | classes, class_to_idx = find_classes(path) 178 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False, split) 179 | elif 'SKY' in path: 180 | if train: 181 | split = 'train' 182 | else: 183 | split = 'test' 184 | path = os.path.join(path, split) 185 | classes, class_to_idx = find_classes(path) 186 | if cond: 187 | imgs = make_imagefolder_dataset(path, nframes // 2, class_to_idx, False, split) 188 | else: 189 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False, split) 190 | else: 191 | classes, class_to_idx = find_classes(path) 192 | imgs = make_imagefolder_dataset(path, nframes, class_to_idx, False) 193 | 194 | if len(imgs) == 0: 195 | raise(RuntimeError("Found 0 images in subfolders of: " + path + "\n" 196 | "Supported image extensions are: " + 197 | ",".join(IMG_EXTENSIONS))) 198 | 199 | self.imgs = imgs 200 | self.classes = classes 201 | self.class_to_idx = class_to_idx 202 | self.nframes = nframes 203 | self.loader = loader 204 | self.img_resolution = resolution 205 | self._path = path 206 | self._total_size = len(self.imgs) 207 | self._raw_shape = [self._total_size] + [3, resolution, resolution] 208 | self.xflip = False 209 | self.return_vid = return_vid 210 | self.shuffle_indices = [i for i in range(self._total_size)] 211 | self.to_tensor = transforms.ToTensor() 212 | random.shuffle(self.shuffle_indices) 213 | self._type = "dir" 214 | 215 | def _file_ext(self, fname): 216 | return os.path.splitext(fname)[1].lower() 217 | 218 | def _get_zipfile(self): 219 | assert self._type == 'zip' 220 | if self._zipfile is None: 221 | self._zipfile = zipfile.ZipFile(self._path) 222 | return self._zipfile 223 | 224 | def _open_file(self, fname): 225 | if self._type == 'dir': 226 | return open(os.path.join(fname), 'rb') 227 | if self._type == 'zip': 228 | return self._get_zipfile().open(fname, 'r') 229 | return None 230 | 231 | def close(self): 232 | try: 233 | if self._zipfile is not None: 234 | self._zipfile.close() 235 | finally: 236 | self._zipfile = None 237 | 238 | def _load_img_from_path(self, folder, fname): 239 | path = os.path.join(folder, fname) 240 | with self._open_file(path) as f: 241 | if pyspng is not None and self._file_ext(path) == '.png': 242 | img = pyspng.load(f.read()) 243 | img = rearrange(img, 'h w c -> c h w') 244 | else: 245 | img = self.to_tensor(PIL.Image.open(f)).numpy() * 255 # c h w 246 | return img 247 | 248 | 
def __getitem__(self, index): 249 | index = self.shuffle_indices[index] 250 | path = self.imgs[index] 251 | 252 | # clip is a list of 32 frames 253 | video = natsorted(os.listdir(path[0])) 254 | 255 | # zero padding. only unconditional modeling 256 | if len(video) < self.nframes: 257 | prefix = np.random.randint(len(video)-self.nframes//2+1) 258 | clip = video[prefix:prefix+self.nframes//2] 259 | else: 260 | prefix = np.random.randint(len(video)-self.nframes+1) 261 | clip = video[prefix:prefix+self.nframes] 262 | 263 | assert (len(clip) == self.nframes or len(clip)*2 == self.nframes) 264 | 265 | vid = np.stack([self._load_img_from_path(folder=path[0], fname=clip[i]) for i in range(len(clip))], axis=0) 266 | vid = resize_crop(torch.from_numpy(vid).float(), resolution=self.img_resolution) # c t h w 267 | if vid.size(1) == self.nframes//2: 268 | vid = torch.cat([torch.zeros_like(vid).to(vid.device), vid], dim=1) 269 | 270 | return rearrange(vid, 'c t h w -> t c h w'), index 271 | 272 | 273 | def __len__(self): 274 | return self._total_size 275 | 276 | def get_loaders(rank, imgstr, resolution, timesteps, skip, batch_size=1, n_gpus=1, seed=42, cond=False, use_train_set=False): 277 | """ 278 | Load dataloaders for an image dataset, center-cropped to a resolution. 279 | """ 280 | if imgstr == 'cliport': 281 | base_path = '/path/to/cliport' 282 | num_workers = 20 283 | 284 | trainloader = Loader(f'{base_path}/video_subgoal_train_text.beton', batch_size=batch_size, 285 | num_workers=num_workers, order=OrderOption.RANDOM, 286 | pipelines={ 287 | 'video': [NDArrayDecoder(), ToTensor()], 288 | 'text': [NDArrayDecoder(),] 289 | }) 290 | 291 | testloader = Loader(f'{base_path}/video_subgoal_train_text.beton', 292 | batch_size=batch_size, 293 | num_workers=num_workers, order=OrderOption.RANDOM, 294 | pipelines={ 295 | 'video': [NDArrayDecoder(), ToTensor()], 296 | 'text': [NDArrayDecoder(),] 297 | }) 298 | 299 | return trainloader, testloader, len(trainloader) 300 | 301 | if imgstr == 'UCF101': 302 | train_dir = os.path.join(data_location, 'UCF-101') 303 | test_dir = os.path.join(data_location, 'UCF-101') # We use all 304 | if cond: 305 | print("here") 306 | timesteps *= 2 # for long generation 307 | trainset = VideoFolderDataset(train_dir, train=True, resolution=resolution, n_frames=timesteps, skip=skip, seed=seed) 308 | print(len(trainset)) 309 | testset = VideoFolderDataset(train_dir, train=False, resolution=resolution, n_frames=timesteps, skip=skip, seed=seed) 310 | print(len(testset)) 311 | 312 | elif imgstr == 'SKY': 313 | train_dir = os.path.join(data_location, 'SKY') 314 | test_dir = os.path.join(data_location, 'SKY') 315 | if cond: 316 | print("here") 317 | timesteps *= 2 # for long generation 318 | trainset = ImageFolderDataset(train_dir, train=True, resolution=resolution, nframes=timesteps, cond=cond) 319 | print(len(trainset)) 320 | testset = ImageFolderDataset(test_dir, train=False, resolution=resolution, nframes=timesteps, cond=cond) 321 | print(len(testset)) 322 | 323 | else: 324 | raise NotImplementedError() 325 | 326 | shuffle = False if use_train_set else True 327 | 328 | kwargs = {'pin_memory': True, 'num_workers': 3} 329 | 330 | trainset_sampler = InfiniteSampler(dataset=trainset, rank=rank, num_replicas=n_gpus, seed=seed) 331 | trainloader = DataLoader(trainset, sampler=trainset_sampler, batch_size=batch_size // n_gpus, pin_memory=False, num_workers=4, prefetch_factor=2) 332 | 333 | testset_sampler = InfiniteSampler(testset, num_replicas=n_gpus, rank=rank, seed=seed) 334 | testloader 
= DataLoader(testset, sampler=testset_sampler, batch_size=batch_size // n_gpus, pin_memory=False, num_workers=4, prefetch_factor=2) 335 | 336 | return trainloader, trainloader, testloader 337 | 338 | 339 | -------------------------------------------------------------------------------- /PVDM/tools/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f -------------------------------------------------------------------------------- /PVDM/tools/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import sys; sys.path.extend([sys.path[0][:-4], '/app']) 5 | 6 | import time 7 | import tqdm 8 | import copy 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.cuda.amp import GradScaler, autocast 12 | 13 | 14 | from utils import AverageMeter 15 | from evals.eval import test_psnr, test_ifvd, test_fvd_ddpm 16 | from models.ema import LitEma 17 | from einops import rearrange 18 | from torch.optim.lr_scheduler import LambdaLR 19 | from tqdm import tqdm 20 | from transformers import T5Tokenizer, T5EncoderModel 21 | from torch.distributions import Bernoulli 22 | 23 | 24 | def latentDDPM(rank, first_stage_model, model, opt, criterion, train_loader, test_loader, scheduler, ema_model=None, cond_prob=0.9, logger=None): 25 | scaler = GradScaler() 26 | 27 | if logger is None: 28 | log_ = print 29 | else: 30 | log_ = logger.log 31 | 32 | if rank == 0: 33 | rootdir = logger.logdir 34 | 35 | device = torch.device('cuda', rank) 36 | 37 | tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base") 38 | text_model = T5EncoderModel.from_pretrained("google/flan-t5-base") 39 | text_model = text_model.to(device) 40 | 41 | mask_dist = Bernoulli(probs=cond_prob) 42 | batch_size = train_loader.batch_size 43 | 44 | with torch.no_grad(): 45 | uncond_tokens = torch.LongTensor([tokenizer('', padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 46 | uncond_latents = text_model(uncond_tokens).last_hidden_state.detach() 47 | 48 | losses = dict() 49 | losses['diffusion_loss'] = 
AverageMeter() 50 | check = time.time() 51 | 52 | if ema_model == None: 53 | ema_model = copy.deepcopy(model) 54 | ema = LitEma(ema_model) 55 | ema_model.eval() 56 | else: 57 | ema = LitEma(ema_model) 58 | ema.num_updates = torch.tensor(11200,dtype=torch.int) 59 | ema_model.eval() 60 | 61 | first_stage_model.eval() 62 | model.train() 63 | 64 | num_iters = 400000 # 25*len(train_loader) 65 | len_train_loader = len(train_loader) 66 | num_epochs = num_iters // len_train_loader 67 | print(f'Training for {num_epochs} epochs') 68 | 69 | for it_epoch in tqdm(range(num_epochs)): 70 | for it_loader, (x, text) in tqdm(enumerate(train_loader)): 71 | it = it_epoch*len_train_loader + it_loader 72 | x = x.to(device) 73 | x = rearrange(x / 127.5 - 1, 'b t h w c -> b c t h w') # videos 74 | 75 | text_masks = mask_dist.sample((batch_size,)).unsqueeze(1).unsqueeze(2).to(device) 76 | 77 | with torch.no_grad(): 78 | tokens = torch.LongTensor([tokenizer(text[i].tobytes().decode('ascii'), padding='max_length', max_length=15).input_ids for i in range(batch_size)]).to(device) 79 | text_latents = text_model(tokens).last_hidden_state.detach() 80 | text_latents = text_latents*text_masks + uncond_latents*(1-text_masks) 81 | 82 | # conditional free guidance training 83 | model.zero_grad() 84 | 85 | x = x[:,:,:,:,:] 86 | c = x[:,:,0:1,:,:].repeat(1,1,x.shape[2],1,1) 87 | 88 | with autocast(): 89 | with torch.no_grad(): 90 | z = first_stage_model.extract(x).detach() 91 | c = first_stage_model.extract(c).detach() 92 | 93 | (loss, t), loss_dict = criterion(x = z.float(), cond = c.float(), context=text_latents) 94 | 95 | """ 96 | scaler.scale(loss).backward() 97 | scaler.step(opt) 98 | scaler.update() 99 | """ 100 | loss.backward() 101 | opt.step() 102 | 103 | losses['diffusion_loss'].update(loss.item(), 1) 104 | 105 | # ema model 106 | if it % 25 == 0 and it > 0: 107 | ema(model) 108 | 109 | if it % 500 == 0: 110 | if logger is not None and rank == 0: 111 | logger.scalar_summary('train/diffusion_loss', losses['diffusion_loss'].average, it) 112 | 113 | log_('[Time %.3f] [Diffusion %f]' % 114 | (time.time() - check, losses['diffusion_loss'].average)) 115 | 116 | losses = dict() 117 | losses['diffusion_loss'] = AverageMeter() 118 | 119 | 120 | if it % 2000 == 0 and rank == 0: 121 | torch.save(model.state_dict(), rootdir + f'model_{it}.pth') 122 | ema.copy_to(ema_model) 123 | torch.save(ema_model.state_dict(), rootdir + f'ema_model_{it}.pth') 124 | fvd = test_fvd_ddpm(rank, ema_model, first_stage_model, test_loader, it, tokenizer, text_model, uncond_latents, logger) 125 | 126 | if logger is not None and rank == 0: 127 | logger.scalar_summary('test/fvd', fvd, it) 128 | log_('[Time %.3f] [FVD %f]' % 129 | (time.time() - check, fvd)) 130 | 131 | def first_stage_train(rank, model, opt, d_opt, criterion, train_loader, test_loader, first_model, fp, logger=None): 132 | if logger is None: 133 | log_ = print 134 | else: 135 | log_ = logger.log 136 | 137 | if rank == 0: 138 | rootdir = logger.logdir 139 | 140 | device = torch.device('cuda', rank) 141 | 142 | losses = dict() 143 | losses['ae_loss'] = AverageMeter() 144 | losses['d_loss'] = AverageMeter() 145 | check = time.time() 146 | 147 | accum_iter = 3 148 | disc_opt = False 149 | 150 | if fp: 151 | scaler = GradScaler() 152 | scaler_d = GradScaler() 153 | 154 | try: 155 | scaler.load_state_dict(torch.load(os.path.join(first_model, 'scaler.pth'))) 156 | scaler_d.load_state_dict(torch.load(os.path.join(first_model, 'scaler_d.pth'))) 157 | except: 158 | print("Fail to load scalers. 
Start from initial point.") 159 | 160 | model.train() 161 | disc_start = criterion.discriminator_iter_start 162 | num_iters = 25*len(train_loader) 163 | len_train_loader = len(train_loader) 164 | num_epochs = num_iters // len_train_loader 165 | 166 | for it_epoch in range(num_epochs): 167 | for it_loader, (x, _) in enumerate(train_loader): 168 | # x is (b, t, h, w, c) 169 | it = it_epoch*len_train_loader + it_loader 170 | batch_size = x.size(0) 171 | x = x.permute(0, 1, 4, 2, 3) 172 | x = x.contiguous() 173 | 174 | x = x.to(device) 175 | x = rearrange(x / 127.5 - 1, 'b t c h w -> b c t h w') # videos 176 | 177 | if not disc_opt: 178 | with autocast(): 179 | x_tilde, vq_loss = model(x) 180 | 181 | if it % accum_iter == 0: 182 | model.zero_grad() 183 | ae_loss = criterion(vq_loss, x, 184 | rearrange(x_tilde, '(b t) c h w -> b c t h w', b=batch_size), 185 | optimizer_idx=0, 186 | global_step=it) 187 | 188 | ae_loss = ae_loss / accum_iter 189 | 190 | scaler.scale(ae_loss).backward() 191 | 192 | if it % accum_iter == accum_iter - 1: 193 | scaler.step(opt) 194 | scaler.update() 195 | 196 | losses['ae_loss'].update(ae_loss.item(), 1) 197 | 198 | else: 199 | if it % accum_iter == 0: 200 | criterion.zero_grad() 201 | 202 | with autocast(): 203 | with torch.no_grad(): 204 | x_tilde, vq_loss = model(x) 205 | d_loss = criterion(vq_loss, x, 206 | rearrange(x_tilde, '(b t) c h w -> b c t h w', b=batch_size), 207 | optimizer_idx=1, 208 | global_step=it) 209 | d_loss = d_loss / accum_iter 210 | 211 | scaler_d.scale(d_loss).backward() 212 | 213 | if it % accum_iter == accum_iter - 1: 214 | # Unscales the gradients of optimizer's assigned params in-place 215 | scaler_d.unscale_(d_opt) 216 | 217 | # Since the gradients of optimizer's assigned params are unscaled, clips as usual: 218 | torch.nn.utils.clip_grad_norm_(criterion.discriminator_2d.parameters(), 1.0) 219 | torch.nn.utils.clip_grad_norm_(criterion.discriminator_3d.parameters(), 1.0) 220 | 221 | scaler_d.step(d_opt) 222 | scaler_d.update() 223 | 224 | losses['d_loss'].update(d_loss.item() * 3, 1) 225 | 226 | if it % accum_iter == accum_iter - 1 and it // accum_iter >= disc_start: 227 | if disc_opt: 228 | disc_opt = False 229 | else: 230 | disc_opt = True 231 | 232 | if it % 2000 == 0: 233 | fvd = test_ifvd(rank, model, test_loader, it, logger) 234 | psnr = test_psnr(rank, model, test_loader, it, logger) 235 | if logger is not None and rank == 0: 236 | logger.scalar_summary('train/ae_loss', losses['ae_loss'].average, it) 237 | logger.scalar_summary('train/d_loss', losses['d_loss'].average, it) 238 | logger.scalar_summary('test/psnr', psnr, it) 239 | logger.scalar_summary('test/fvd', fvd, it) 240 | 241 | log_('[Time %.3f] [AELoss %f] [DLoss %f] [PSNR %f] [FVD %f]' % 242 | (time.time() - check, losses['ae_loss'].average, losses['d_loss'].average, psnr, fvd)) 243 | # print('[Time %.3f] [AELoss %f] [DLoss %f] [PSNR %f]' % 244 | # (time.time() - check, losses['ae_loss'].average, losses['d_loss'].average, psnr)) 245 | 246 | torch.save(model.state_dict(), rootdir + f'model_last.pth') 247 | torch.save(criterion.state_dict(), rootdir + f'loss_last.pth') 248 | torch.save(opt.state_dict(), rootdir + f'opt.pth') 249 | torch.save(d_opt.state_dict(), rootdir + f'd_opt.pth') 250 | torch.save(scaler.state_dict(), rootdir + f'scaler.pth') 251 | torch.save(scaler_d.state_dict(), rootdir + f'scaler_d.pth') 252 | 253 | losses = dict() 254 | losses['ae_loss'] = AverageMeter() 255 | losses['d_loss'] = AverageMeter() 256 | 257 | if it % 2000 == 0 and rank == 0: 258 | 
torch.save(model.state_dict(), rootdir + f'model_{it}.pth') 259 | 260 | -------------------------------------------------------------------------------- /PVDM/tools/video_utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import math 3 | import warnings 4 | from fractions import Fraction 5 | from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast 6 | 7 | import torch 8 | from torchvision.io import ( 9 | _probe_video_from_file, 10 | _read_video_from_file, 11 | read_video, 12 | read_video_timestamps, 13 | ) 14 | 15 | from tqdm import tqdm 16 | 17 | T = TypeVar("T") 18 | 19 | 20 | def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int: 21 | """convert pts between different time bases 22 | Args: 23 | pts: presentation timestamp, float 24 | timebase_from: original timebase. Fraction 25 | timebase_to: new timebase. Fraction 26 | round_func: rounding function. 27 | """ 28 | new_pts = Fraction(pts, 1) * timebase_from / timebase_to 29 | return round_func(new_pts) 30 | 31 | 32 | def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor: 33 | """ 34 | similar to tensor.unfold, but with the dilation 35 | and specialized for 1d tensors 36 | Returns all consecutive windows of `size` elements, with 37 | `step` between windows. The distance between each element 38 | in a window is given by `dilation`. 39 | """ 40 | if tensor.dim() != 1: 41 | raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}") 42 | o_stride = tensor.stride(0) 43 | numel = tensor.numel() 44 | new_stride = (step * o_stride, dilation * o_stride) 45 | new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size) 46 | if new_size[0] < 1: 47 | new_size = (0, size) 48 | return torch.as_strided(tensor, new_size, new_stride) 49 | 50 | 51 | class _VideoTimestampsDataset: 52 | """ 53 | Dataset used to parallelize the reading of the timestamps 54 | of a list of videos, given their paths in the filesystem. 55 | Used in VideoClips and defined at top level so it can be 56 | pickled when forking. 57 | """ 58 | 59 | def __init__(self, video_paths: List[str]) -> None: 60 | self.video_paths = video_paths 61 | 62 | def __len__(self) -> int: 63 | return len(self.video_paths) 64 | 65 | def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]: 66 | return read_video_timestamps(self.video_paths[idx]) 67 | 68 | 69 | def _collate_fn(x: T) -> T: 70 | """ 71 | Dummy collate function to be used with _VideoTimestampsDataset 72 | """ 73 | return x 74 | 75 | 76 | class VideoClips: 77 | """ 78 | Given a list of video files, computes all consecutive subvideos of size 79 | `clip_length_in_frames`, where the distance between each subvideo in the 80 | same video is defined by `frames_between_clips`. 81 | If `frame_rate` is specified, it will also resample all the videos to have 82 | the same frame rate, and the clips will refer to this frame rate. 83 | Creating this instance the first time is time-consuming, as it needs to 84 | decode all the videos in `video_paths`. It is recommended that you 85 | cache the results after instantiation of the class. 86 | Recreating the clips for different clip lengths is fast, and can be done 87 | with the `compute_clips` method. 
88 | Args: 89 | video_paths (List[str]): paths to the video files 90 | clip_length_in_frames (int): size of a clip in number of frames 91 | frames_between_clips (int): step (in frames) between each clip 92 | frame_rate (int, optional): if specified, it will resample the video 93 | so that it has `frame_rate`, and then the clips will be defined 94 | on the resampled video 95 | num_workers (int): how many subprocesses to use for data loading. 96 | 0 means that the data will be loaded in the main process. (default: 0) 97 | """ 98 | 99 | def __init__( 100 | self, 101 | video_paths: List[str], 102 | clip_length_in_frames: int = 16, 103 | frames_between_clips: int = 1, 104 | frame_rate: Optional[int] = None, 105 | _precomputed_metadata: Optional[Dict[str, Any]] = None, 106 | num_workers: int = 0, 107 | _video_width: int = 0, 108 | _video_height: int = 0, 109 | _video_min_dimension: int = 0, 110 | _video_max_dimension: int = 0, 111 | _audio_samples: int = 0, 112 | _audio_channels: int = 0, 113 | ) -> None: 114 | 115 | self.video_paths = video_paths 116 | self.num_workers = num_workers 117 | 118 | # these options are not valid for pyav backend 119 | self._video_width = _video_width 120 | self._video_height = _video_height 121 | self._video_min_dimension = _video_min_dimension 122 | self._video_max_dimension = _video_max_dimension 123 | self._audio_samples = _audio_samples 124 | self._audio_channels = _audio_channels 125 | 126 | if _precomputed_metadata is None: 127 | self._compute_frame_pts() 128 | else: 129 | self._init_from_metadata(_precomputed_metadata) 130 | self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) 131 | 132 | def _compute_frame_pts(self) -> None: 133 | self.video_pts = [] 134 | self.video_fps = [] 135 | 136 | # strategy: use a DataLoader to parallelize read_video_timestamps 137 | # so need to create a dummy dataset first 138 | import torch.utils.data 139 | 140 | dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader( 141 | _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type] 142 | batch_size=16, 143 | num_workers=self.num_workers, 144 | collate_fn=_collate_fn, 145 | ) 146 | 147 | with tqdm(total=len(dl)) as pbar: 148 | for batch in dl: 149 | pbar.update(1) 150 | clips, fps = list(zip(*batch)) 151 | # we need to specify dtype=torch.long because for empty list, 152 | # torch.as_tensor will use torch.float as default dtype. This 153 | # happens when decoding fails and no pts is returned in the list. 
154 | clips = [torch.as_tensor(c, dtype=torch.long) for c in clips] 155 | self.video_pts.extend(clips) 156 | self.video_fps.extend(fps) 157 | 158 | def _init_from_metadata(self, metadata: Dict[str, Any]) -> None: 159 | self.video_paths = metadata["video_paths"] 160 | assert len(self.video_paths) == len(metadata["video_pts"]) 161 | self.video_pts = metadata["video_pts"] 162 | assert len(self.video_paths) == len(metadata["video_fps"]) 163 | self.video_fps = metadata["video_fps"] 164 | 165 | @property 166 | def metadata(self) -> Dict[str, Any]: 167 | _metadata = { 168 | "video_paths": self.video_paths, 169 | "video_pts": self.video_pts, 170 | "video_fps": self.video_fps, 171 | } 172 | return _metadata 173 | 174 | def subset(self, indices: List[int]) -> "VideoClips": 175 | video_paths = [self.video_paths[i] for i in indices] 176 | video_pts = [self.video_pts[i] for i in indices] 177 | video_fps = [self.video_fps[i] for i in indices] 178 | metadata = { 179 | "video_paths": video_paths, 180 | "video_pts": video_pts, 181 | "video_fps": video_fps, 182 | } 183 | return type(self)( 184 | video_paths, 185 | self.num_frames, 186 | self.step, 187 | self.frame_rate, 188 | _precomputed_metadata=metadata, 189 | num_workers=self.num_workers, 190 | _video_width=self._video_width, 191 | _video_height=self._video_height, 192 | _video_min_dimension=self._video_min_dimension, 193 | _video_max_dimension=self._video_max_dimension, 194 | _audio_samples=self._audio_samples, 195 | _audio_channels=self._audio_channels, 196 | ) 197 | 198 | @staticmethod 199 | def compute_clips_for_video( 200 | video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None 201 | ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]: 202 | if fps is None: 203 | # if for some reason the video doesn't have fps (because doesn't have a video stream) 204 | # set the fps to 1. The value doesn't matter, because video_pts is empty anyway 205 | fps = 1 206 | if frame_rate is None: 207 | frame_rate = fps 208 | total_frames = len(video_pts) * (float(frame_rate) / fps) 209 | _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) 210 | video_pts = video_pts[_idxs] 211 | clips = unfold(video_pts, num_frames, step) 212 | if not clips.numel(): 213 | warnings.warn( 214 | "There aren't enough frames in the current video to get a clip for the given clip length and " 215 | "frames between clips. The video (and potentially others) will be skipped." 216 | ) 217 | idxs: Union[List[slice], torch.Tensor] 218 | if isinstance(_idxs, slice): 219 | idxs = [_idxs] * len(clips) 220 | else: 221 | idxs = unfold(_idxs, num_frames, step) 222 | return clips, idxs 223 | 224 | def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None: 225 | """ 226 | Compute all consecutive sequences of clips from video_pts. 227 | Always returns clips of size `num_frames`, meaning that the 228 | last few frames in a video can potentially be dropped. 
229 | Args: 230 | num_frames (int): number of frames for the clip 231 | step (int): distance between two clips 232 | frame_rate (int, optional): The frame rate 233 | """ 234 | self.num_frames = num_frames 235 | self.step = step 236 | self.frame_rate = frame_rate 237 | self.clips = [] 238 | self.resampling_idxs = [] 239 | for video_pts, fps in zip(self.video_pts, self.video_fps): 240 | clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) 241 | self.clips.append(clips) 242 | self.resampling_idxs.append(idxs) 243 | clip_lengths = torch.as_tensor([len(v) for v in self.clips]) 244 | self.cumulative_sizes = clip_lengths.cumsum(0).tolist() 245 | 246 | def __len__(self) -> int: 247 | return self.num_clips() 248 | 249 | def num_videos(self) -> int: 250 | return len(self.video_paths) 251 | 252 | def num_clips(self) -> int: 253 | """ 254 | Number of subclips that are available in the video list. 255 | """ 256 | return self.cumulative_sizes[-1] 257 | 258 | def get_clip_location(self, idx: int) -> Tuple[int, int]: 259 | """ 260 | Converts a flattened representation of the indices into a video_idx, clip_idx 261 | representation. 262 | """ 263 | video_idx = bisect.bisect_right(self.cumulative_sizes, idx) 264 | if video_idx == 0: 265 | clip_idx = idx 266 | else: 267 | clip_idx = idx - self.cumulative_sizes[video_idx - 1] 268 | return video_idx, clip_idx 269 | 270 | @staticmethod 271 | def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]: 272 | step = float(original_fps) / new_fps 273 | if step.is_integer(): 274 | # optimization: if step is integer, don't need to perform 275 | # advanced indexing 276 | step = int(step) 277 | return slice(None, None, step) 278 | idxs = torch.arange(num_frames, dtype=torch.float32) * step 279 | idxs = idxs.floor().to(torch.int64) 280 | return idxs 281 | 282 | def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]: 283 | """ 284 | Gets a subclip from a list of videos. 285 | Args: 286 | idx (int): index of the subclip. Must be between 0 and num_clips(). 
287 | Returns: 288 | video (Tensor) 289 | audio (Tensor) 290 | info (Dict) 291 | video_idx (int): index of the video in `video_paths` 292 | """ 293 | if idx >= self.num_clips(): 294 | raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)") 295 | video_idx, clip_idx = self.get_clip_location(idx) 296 | video_path = self.video_paths[video_idx] 297 | clip_pts = self.clips[video_idx][clip_idx] 298 | 299 | from torchvision import get_video_backend 300 | 301 | backend = get_video_backend() 302 | 303 | if backend == "pyav": 304 | # check for invalid options 305 | if self._video_width != 0: 306 | raise ValueError("pyav backend doesn't support _video_width != 0") 307 | if self._video_height != 0: 308 | raise ValueError("pyav backend doesn't support _video_height != 0") 309 | if self._video_min_dimension != 0: 310 | raise ValueError("pyav backend doesn't support _video_min_dimension != 0") 311 | if self._video_max_dimension != 0: 312 | raise ValueError("pyav backend doesn't support _video_max_dimension != 0") 313 | if self._audio_samples != 0: 314 | raise ValueError("pyav backend doesn't support _audio_samples != 0") 315 | 316 | if backend == "pyav": 317 | start_pts = clip_pts[0].item() 318 | end_pts = clip_pts[-1].item() 319 | video, audio, info = read_video(video_path, start_pts, end_pts) 320 | else: 321 | _info = _probe_video_from_file(video_path) 322 | video_fps = _info.video_fps 323 | audio_fps = None 324 | 325 | video_start_pts = cast(int, clip_pts[0].item()) 326 | video_end_pts = cast(int, clip_pts[-1].item()) 327 | 328 | audio_start_pts, audio_end_pts = 0, -1 329 | audio_timebase = Fraction(0, 1) 330 | video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator) 331 | if _info.has_audio: 332 | audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator) 333 | audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) 334 | audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) 335 | audio_fps = _info.audio_sample_rate 336 | video, audio, _ = _read_video_from_file( 337 | video_path, 338 | video_width=self._video_width, 339 | video_height=self._video_height, 340 | video_min_dimension=self._video_min_dimension, 341 | video_max_dimension=self._video_max_dimension, 342 | video_pts_range=(video_start_pts, video_end_pts), 343 | video_timebase=video_timebase, 344 | audio_samples=self._audio_samples, 345 | audio_channels=self._audio_channels, 346 | audio_pts_range=(audio_start_pts, audio_end_pts), 347 | audio_timebase=audio_timebase, 348 | ) 349 | 350 | info = {"video_fps": video_fps} 351 | if audio_fps is not None: 352 | info["audio_fps"] = audio_fps 353 | 354 | if self.frame_rate is not None: 355 | resampling_idx = self.resampling_idxs[video_idx][clip_idx] 356 | if isinstance(resampling_idx, torch.Tensor): 357 | resampling_idx = resampling_idx - resampling_idx[0] 358 | video = video[resampling_idx] 359 | info["video_fps"] = self.frame_rate 360 | assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" 361 | return video, audio, info, video_idx 362 | 363 | def __getstate__(self) -> Dict[str, Any]: 364 | video_pts_sizes = [len(v) for v in self.video_pts] 365 | # To be back-compatible, we convert data to dtype torch.long as needed 366 | # because for empty list, in legacy implementation, torch.as_tensor will 367 | # use torch.float as default dtype. This happens when decoding fails and 368 | # no pts is returned in the list. 
369 | video_pts = [x.to(torch.int64) for x in self.video_pts] 370 | # video_pts can be an empty list if no frames have been decoded 371 | if video_pts: 372 | video_pts = torch.cat(video_pts) # type: ignore[assignment] 373 | # avoid bug in https://github.com/pytorch/pytorch/issues/32351 374 | # TODO: Revert it once the bug is fixed. 375 | video_pts = video_pts.numpy() # type: ignore[attr-defined] 376 | 377 | # make a copy of the fields of self 378 | d = self.__dict__.copy() 379 | d["video_pts_sizes"] = video_pts_sizes 380 | d["video_pts"] = video_pts 381 | # delete the following attributes to reduce the size of dictionary. They 382 | # will be re-computed in "__setstate__()" 383 | del d["clips"] 384 | del d["resampling_idxs"] 385 | del d["cumulative_sizes"] 386 | 387 | # for backwards-compatibility 388 | d["_version"] = 2 389 | return d 390 | 391 | def __setstate__(self, d: Dict[str, Any]) -> None: 392 | # for backwards-compatibility 393 | if "_version" not in d: 394 | self.__dict__ = d 395 | return 396 | 397 | video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64) 398 | video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0) 399 | # don't need this info anymore 400 | del d["video_pts_sizes"] 401 | 402 | d["video_pts"] = video_pts 403 | self.__dict__ = d 404 | # recompute attributes "clips", "resampling_idxs" and other derivative ones 405 | self.compute_clips(self.num_frames, self.step, self.frame_rate) -------------------------------------------------------------------------------- /PVDM/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import sys 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import gdown 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, fn, ask=True): 16 | if not os.path.exists("./results/"): 17 | os.mkdir("./results/") 18 | 19 | logdir = self._make_dir(fn) 20 | if not os.path.exists(logdir): 21 | os.mkdir(logdir) 22 | 23 | if len(os.listdir(logdir)) != 0 and ask: 24 | exit(1) 25 | 26 | self.set_dir(logdir) 27 | 28 | def _make_dir(self, fn): 29 | # today = datetime.today().strftime("%y%m%d") 30 | logdir = f'./results/{fn}/' 31 | return logdir 32 | 33 | def set_dir(self, logdir, log_fn='log.txt'): 34 | self.logdir = logdir 35 | if not os.path.exists(logdir): 36 | os.mkdir(logdir) 37 | self.writer = SummaryWriter(logdir) 38 | self.log_file = open(os.path.join(logdir, log_fn), 'a') 39 | 40 | def log(self, string): 41 | self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n') 42 | self.log_file.flush() 43 | 44 | print('[%s] %s' % (datetime.now(), string)) 45 | sys.stdout.flush() 46 | 47 | def log_dirname(self, string): 48 | self.log_file.write('%s (%s)' % (string, self.logdir) + '\n') 49 | self.log_file.flush() 50 | 51 | print('%s (%s)' % (string, self.logdir)) 52 | sys.stdout.flush() 53 | 54 | def scalar_summary(self, tag, value, step): 55 | """Log a scalar variable.""" 56 | self.writer.add_scalar(tag, value, step) 57 | 58 | def image_summary(self, tag, images, step): 59 | """Log a list of images.""" 60 | self.writer.add_image(tag, images, step) 61 | 62 | def video_summary(self, tag, videos, step): 63 | self.writer.add_video(tag, videos, step, fps=16) 64 | 65 | def histo_summary(self, tag, values, step): 66 | """Log a histogram of the tensor of values.""" 67 | self.writer.add_histogram(tag, values, step, bins='auto') 68 | 69 | 70 | class 
AverageMeter(object): 71 | """Computes and stores the average and current value""" 72 | 73 | def __init__(self): 74 | self.value = 0 75 | self.average = 0 76 | self.sum = 0 77 | self.count = 0 78 | 79 | def reset(self): 80 | self.value = 0 81 | self.average = 0 82 | self.sum = 0 83 | self.count = 0 84 | 85 | def update(self, value, n=1): 86 | self.value = value 87 | self.sum += value * n 88 | self.count += n 89 | self.average = self.sum / self.count 90 | 91 | 92 | def set_random_seed(seed): 93 | random.seed(seed) 94 | np.random.seed(seed) 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed(seed) 97 | torch.cuda.manual_seed_all(seed) 98 | 99 | 100 | def file_name(args): 101 | fn = f'{args.exp}_{args.id}_{args.data}' 102 | fn += f'_{args.seed}' 103 | return fn 104 | 105 | 106 | def psnr(mse): 107 | """ 108 | Computes PSNR from MSE. 109 | """ 110 | return -10.0 * mse.log10() 111 | 112 | def download(id, fname, root=os.path.expanduser('~/.cache/video-diffusion')): 113 | os.makedirs(root, exist_ok=True) 114 | destination = os.path.join(root, fname) 115 | 116 | if os.path.exists(destination): 117 | return destination 118 | 119 | gdown.download(id=id, output=destination, quiet=False) 120 | return destination 121 | 122 | 123 | def make_pairs(l, t1, t2, num_pairs, given_vid): 124 | B, T, C, H, W = given_vid.size() 125 | idx1 = t1.view(B, num_pairs, 1, 1, 1, 1).expand(B, num_pairs, 1, C, H, W).type(torch.int64) 126 | frame1 = torch.gather(given_vid.unsqueeze(1).repeat(1,num_pairs, 1,1,1,1), 2, idx1).squeeze() 127 | t1 = t1.float() / (l - 1) 128 | 129 | idx2 = t2.view(B, num_pairs, 1, 1, 1, 1).expand(B, num_pairs, 1, C, H, W).type(torch.int64) 130 | frame2 = torch.gather(given_vid.unsqueeze(1).repeat(1,num_pairs,1,1,1,1), 2, idx2).squeeze() 131 | t2 = t2.float() / (l - 1) 132 | 133 | frame1 = frame1.view(-1, C, H, W) 134 | frame2 = frame2.view(-1, C, H, W) 135 | 136 | # sort by t 137 | t1 = t1.view(-1, 1, 1, 1).repeat(1, C, H, W) 138 | t2 = t2.view(-1, 1, 1, 1).repeat(1, C, H, W) 139 | 140 | ret_frame1 = torch.where(t1 < t2, frame1, frame2) 141 | ret_frame2 = torch.where(t1 < t2, frame2 ,frame1) 142 | 143 | t1 = t1[:, 0:1] 144 | t2 = t2[:, 0:1] 145 | 146 | ret_t1 = torch.where(t1 < t2, t1, t2) 147 | ret_t2 = torch.where(t1 < t2, t2, t1) 148 | 149 | dt = ret_t2 - ret_t1 150 | 151 | return torch.cat([ret_frame1, ret_frame2, dt], dim=1) 152 | 153 | def make_mixed_pairs(l, t1, t2, given_vid_real, given_vid_fake): 154 | B, T, C, H, W = given_vid_real.size() 155 | idx1 = t1.view(-1, 1, 1, 1, 1).expand(B, 1, C, H, W).type(torch.int64) 156 | frame1 = torch.gather(given_vid_real, 1, idx1).squeeze() 157 | t1 = t1.float() / (l - 1) 158 | 159 | idx2 = t2.view(-1, 1, 1, 1, 1).expand(B, 1, C, H, W).type(torch.int64) 160 | frame2 = torch.gather(given_vid_fake, 1, idx2).squeeze() 161 | t2 = t2.float() / (l - 1) 162 | 163 | 164 | # sort by t 165 | t1 = t1.view(-1, 1, 1, 1).repeat(1, C, H, W) 166 | t2 = t2.view(-1, 1, 1, 1).repeat(1, C, H, W) 167 | 168 | ret_frame1 = torch.where(t1 < t2, frame1, frame2) 169 | ret_frame2 = torch.where(t1 < t2, frame2 ,frame1) 170 | 171 | t1 = t1[:, 0:1] 172 | t2 = t2[:, 0:1] 173 | 174 | ret_t1 = torch.where(t1 < t2, t1, t2) 175 | ret_t2 = torch.where(t1 < t2, t2, t1) 176 | 177 | dt = ret_t2 - ret_t1 178 | 179 | return torch.cat([ret_frame1, ret_frame2, dt], dim=1) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compositional Foundation Models for 
Hierarchical Planning 2 |
3 | 4 | [[Website]](https://hierarchical-planning-foundation-model.github.io/) 5 | [[arXiv]](https://arxiv.org/abs/2309.08587) 6 | [[PDF]](https://arxiv.org/pdf/2309.08587.pdf) 7 | 8 | 9 | ![](images/teaser.png) 10 |
11 | 12 | To make effective decisions in novel environments with long-horizon goals, it is crucial to engage in hierarchical reasoning across spatial and temporal scales. This entails planning abstract subgoal sequences, visually reasoning about the underlying plans, and executing actions in accordance with the devised plan through visual-motor control. We propose Compositional Foundation Models for Hierarchical Planning (HiP), which composes multiple expert foundation models, each trained separately on language, vision, and action data, and uses them jointly to solve long-horizon tasks. We use a large language model to construct symbolic plans that are grounded in the environment through a large video diffusion model. The generated video plans are then grounded in visual-motor control through an inverse dynamics model that infers actions from the generated videos. To enable effective reasoning within this hierarchy, we enforce consistency between the models via iterative refinement. We illustrate the efficacy and adaptability of our approach on three different long-horizon table-top manipulation tasks. 13 | 14 | **NOTE:** This is a preliminary version of the code. We are cleaning it up and will release a complete version in the coming weeks. 15 | 16 | 17 | # Training Information 18 | 19 | 1. The training script for the inverse dynamics model is contained in `inv_dyn/inv_dyn_ft.py`. 20 | 2. The training script for the video diffusion model is contained in `PVDM/main.py`. 21 | 3. The training script for the subgoal classifier is contained in `task_subgoal_consistency/train.py`. 22 | 23 | # Reference 24 | 25 | ```bibtex 26 | @article{ajay2023compositional, 27 | title={Compositional Foundation Models for Hierarchical Planning}, 28 | author={Ajay, Anurag and Han, Seungwook and Du, Yilun and Li, Shuang and Gupta, Abhi and Jaakkola, Tommi and Tenenbaum, Josh and Kaelbling, Leslie and Srivastava, Akash and Agrawal, Pulkit}, 29 | journal={arXiv preprint arXiv:2309.08587}, 30 | year={2023} 31 | } 32 | ``` 33 | 34 | # Acknowledgements 35 | 36 | The codebase is derived from the [PVDM repo](https://github.com/sihyun-yu/PVDM). 
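
# Planning Loop (Illustrative Sketch)

The pipeline described at the top of this README composes three levels: a language model that proposes subgoal sequences, a video diffusion model that turns each subgoal into a visual plan, and an inverse dynamics model that decodes actions from the generated frames, with a task-subgoal consistency model tying the levels together. The sketch below is only a minimal illustration of how these pieces fit together at inference time; every function name, shape, and stand-in lambda here is a placeholder and not part of this codebase. The actual components are the trained PVDM diffusion model (`PVDM/`), the inverse dynamics network trained by `inv_dyn/inv_dyn_ft.py`, and the task-subgoal consistency classifier trained by `task_subgoal_consistency/train.py`.

```python
# Illustrative sketch only: hypothetical interfaces, not the trained HiP models.
from dataclasses import dataclass
from typing import Callable, List

import numpy as np


@dataclass
class HiP:
    propose_subgoals: Callable[[str, int], List[List[str]]]  # LLM: task -> candidate subgoal sequences
    score_consistency: Callable[[str, List[str]], float]     # task-subgoal consistency model
    generate_video: Callable[[np.ndarray, str], np.ndarray]  # video diffusion: (obs, subgoal) -> video plan
    infer_actions: Callable[[np.ndarray], np.ndarray]        # inverse dynamics: video -> actions

    def plan(self, task: str, obs: np.ndarray, n_candidates: int = 4) -> List[np.ndarray]:
        # 1) Symbolic level: sample several subgoal decompositions from the LLM
        #    and keep the one the consistency model scores highest.
        candidates = self.propose_subgoals(task, n_candidates)
        subgoals = max(candidates, key=lambda sg: self.score_consistency(task, sg))

        actions = []
        for subgoal in subgoals:
            # 2) Visual level: roll out a short video plan for this subgoal,
            #    conditioned on the current observation.
            video = self.generate_video(obs, subgoal)
            # 3) Control level: decode low-level actions from the generated frames.
            actions.append(self.infer_actions(video))
            # Treat the last generated frame as the predicted next observation.
            obs = video[-1]
        return actions


if __name__ == "__main__":
    # Stand-in components so the sketch runs end-to-end without any checkpoints.
    rng = np.random.default_rng(0)
    hip = HiP(
        propose_subgoals=lambda task, n: [[f"{task}: step {i}" for i in range(3)] for _ in range(n)],
        score_consistency=lambda task, sg: float(len(sg)),
        generate_video=lambda obs, sg: rng.random((16, *obs.shape)),      # 16-frame "plan"
        infer_actions=lambda video: rng.random((video.shape[0] - 1, 7)),  # 7-dim action per transition
    )
    plans = hip.plan("pack the blocks into the box", obs=rng.random((64, 64, 3)))
    print(len(plans), plans[0].shape)
```

In the released training code, the diffusion model is conditioned on both the first frame and the Flan-T5-encoded subgoal text (see `PVDM/tools/trainer.py`); the consistency model is what the paper uses to enforce agreement between the language and video levels via iterative refinement.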
37 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/images/teaser.png -------------------------------------------------------------------------------- /inv_dyn/inv_dyn_ft.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from io import BytesIO 3 | from queue import Queue 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | import torch.nn as nn 8 | from ml_collections import ConfigDict 9 | from PIL import Image 10 | from scipy.spatial.transform import Rotation 11 | import torchvision.transforms as transforms 12 | from os import listdir 13 | from os.path import isfile, join 14 | from torch.utils.data import DataLoader 15 | from transformers import AutoImageProcessor, ViTMAEForPreTraining, ViTConfig, ViTModel, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup 16 | from transformers import AutoFeatureExtractor, ResNetForImageClassification, ResNetConfig, ResNetModel 17 | from transformers import ViTFeatureExtractor, ViTModel 18 | from transformers.image_utils import ChannelDimension 19 | from inv_dynamics.action_decoder import ActionDecoder, MultiCategoricalNet, CategoricalNet 20 | from inv_dynamics.dists import Categorical, MultiCategorical 21 | from matplotlib import pyplot as plt 22 | from tqdm import tqdm 23 | from .resnet import ResNetSmall 24 | from ffcv.fields.decoders import NDArrayDecoder 25 | from ffcv.transforms import ToTensor, Squeeze, ToDevice 26 | from ffcv.loader import Loader, OrderOption 27 | import vc_models 28 | from vc_models.models.vit import model_utils 29 | 30 | class InvDynamics(nn.Module): 31 | def __init__(self, state_dim=7): 32 | super(InvDynamics, self).__init__() 33 | self.state_dim = state_dim 34 | 35 | self.visual_model, self.embd_size, self.model_transforms, self.model_info = model_utils.load_model(model_utils.VC1_BASE_NAME) 36 | 37 | self.inv_model = nn.Linear(self.embd_size, self.state_dim) 38 | 39 | def forward(self, obs): 40 | # Get Action 41 | obs = self.model_transforms(obs) 42 | embed = self.visual_model(obs) 43 | return self.inv_model(embed) 44 | 45 | def calculate_loss(self, obs, state): 46 | pred_state = self.forward(obs) 47 | mse = F.mse_loss(pred_state, state) 48 | return mse 49 | 50 | @torch.no_grad() 51 | def calculate_test_loss(self, obs, state): 52 | pred_state = self.forward(obs) 53 | mse = F.mse_loss(pred_state, state) 54 | return mse 55 | 56 | 57 | def main(**deps): 58 | from tqdm import tqdm 59 | import wandb 60 | import numpy as np 61 | import torch 62 | import random 63 | import os 64 | 65 | device=torch.device('cuda:0') 66 | log_every=25 67 | 68 | # Create Loaders 69 | wandb.init( 70 | project='llm_diffusion', 71 | config={"lr": 3e-5, "batch_size":256}, 72 | group='inv_model_ft', 73 | ) 74 | 75 | batch_size = wandb.config.batch_size 76 | num_workers = 20 77 | parent_path = '/path/to/data' 78 | 79 | train_dataloader = Loader(f'{parent_path}/inv_dyn_train.beton', batch_size=batch_size, 80 | num_workers=num_workers, order=OrderOption.RANDOM, 81 | pipelines={ 82 | 'image': [NDArrayDecoder(), ToTensor(), ToDevice(device)], 83 | 'state': [NDArrayDecoder(), ToTensor(), ToDevice(device)], 84 | }) 85 | 86 | test_dataloader = Loader(f'{parent_path}/inv_dyn_test.beton', batch_size=batch_size, 87 | 
num_workers=num_workers, order=OrderOption.RANDOM, 88 | pipelines={ 89 | 'image': [NDArrayDecoder(), ToTensor(), ToDevice(device)], 90 | 'state': [NDArrayDecoder(), ToTensor(), ToDevice(device)], 91 | }) 92 | 93 | train_steps_per_epoch = len(train_dataloader) 94 | num_epochs = 20 95 | 96 | # Define model and optimizer 97 | inv_model = InvDynamics() 98 | inv_model = inv_model.to(device) 99 | 100 | optimizer = torch.optim.AdamW(inv_model.parameters(), lr=wandb.config.lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=5e-4) 101 | 102 | # Run optimization 103 | for epoch_num in tqdm(range(num_epochs)): 104 | # Run train epoch 105 | inv_model.train() 106 | running_training_loss = 0 107 | for (idx, (obs, state)) in enumerate(train_dataloader): 108 | obs = (obs.permute(0,3,1,2))/255.0 109 | obs = obs.contiguous() 110 | loss = inv_model.calculate_loss(obs, state) 111 | optimizer.zero_grad() 112 | loss.backward() 113 | nn.utils.clip_grad_norm_(inv_model.parameters(), 1.0) 114 | optimizer.step() 115 | running_training_loss += loss.detach().item() 116 | del obs, state 117 | if idx > 0 and idx % log_every == 0: 118 | wandb.log({"epoch_num": epoch_num, "train_loss": running_training_loss/log_every, "itr": idx + epoch_num*train_steps_per_epoch}) 119 | print({"epoch_num": epoch_num, "train_loss": running_training_loss/log_every, "itr": idx + epoch_num*train_steps_per_epoch}) 120 | running_training_loss = 0 121 | 122 | # Run test epoch 123 | inv_model.eval() 124 | running_test_loss = 0 125 | with torch.no_grad(): 126 | for (idx, (obs, state)) in enumerate(test_dataloader): 127 | obs = (obs.permute(0,3,1,2))/255.0 128 | obs = obs.contiguous() 129 | sin_cos_mse = inv_model.calculate_test_loss(obs, state) 130 | running_test_loss += sin_cos_mse 131 | del obs, state 132 | 133 | wandb.log({"epoch_num": epoch_num, "test_loss": running_test_loss/(idx+1), "itr": (epoch_num+1)*train_steps_per_epoch}) 134 | print({"epoch_num": epoch_num, "test_loss": running_test_loss/(idx+1), "itr": (epoch_num+1)*train_steps_per_epoch}) 135 | 136 | # Save the current progress 137 | torch.save({'inv_model':inv_model.state_dict(), 'opt': optimizer.state_dict()}, 'inv.pt') 138 | 139 | wandb.finish() -------------------------------------------------------------------------------- /task_subgoal_consistency/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragajay/hip/f83186af477bf6d8c5f8f3c1fdef43ccb89db25d/task_subgoal_consistency/__init__.py -------------------------------------------------------------------------------- /task_subgoal_consistency/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_arguments(): 4 | parser = argparse.ArgumentParser(description='Train a classifier to check task subgoal consistency') 5 | parser.add_argument('--data', type=str, default='/path/to/data/', metavar='DIR', help='path to dataset') 6 | parser.add_argument('--log-dir', type=str, default='./logs/', help='path to log directory') 7 | parser.add_argument('--checkpoint-dir', type=str, default='./checkpoints/', help='path to log directory') 8 | 9 | # model hyperparams 10 | parser.add_argument('--model-type', default='classifier', type=str, choices=['scorer', 'classifier'], 11 | help='type of model to train') 12 | parser.add_argument('--img-feature-extractor', default='resnet18', type=str, 13 | choices=('clip', 'conv', 'resnet18', 'resnet34'), 14 | help='pretrained image feature extractor to use 
(default: resnet18)') 15 | parser.add_argument('--text-feature-extractor', default='flan-t5', type=str, 16 | choices=('clip', 'bert', 'gpt-2', 'flan-t5'), 17 | help='pretrained text feature extractor to use (default: flan-t5)') 18 | parser.add_argument('--classifier-arch', default='mlp', type=str, 19 | help='classifier architecture') 20 | parser.add_argument('--hidden-dims', help='comma-delimited hidden dimensions of the classifier', type=str, default='512,256,128') 21 | parser.add_argument('--output-dim', type=int, default=1, help='output dimension of the classifier') 22 | parser.add_argument('--vocab-size', type=int, default=22, 23 | help='vocab size of the data') 24 | parser.add_argument('--embedding-dim', type=int, default=128, 25 | help='embedding dimension for scorer') # maybe smaller? 26 | parser.add_argument('--scorer-arch', default='rnn', type=str, choices=['rnn', 'transformer'], 27 | help='scorer architecture') 28 | parser.add_argument('--dropout', type=float, default=0.0, 29 | help='dropout applied to outputs of each rnn layer') 30 | 31 | # dataset hyperparameters 32 | parser.add_argument('--dataset-type', default='single', choices=['single', 'subset', 'all'], type=str, 33 | help='whether to classify a single subgoal, a subset of subgoals, or all subgoals') 34 | parser.add_argument('--task', default='paint', choices=['paint', 'cliport'], type=str, 35 | help='which task to run on') 36 | parser.add_argument('--train-ratio', default=0.9, type=float, 37 | help='ratio of data to use for training vs validation') 38 | parser.add_argument('--sample-ratio', default=1.0, type=float, 39 | help='sample complexity ratio (proportion of the full training data to use)') 40 | parser.add_argument('--negative-sample-prob', default=0.5, type=float, 41 | help='probability of sampling a negative example') 42 | 43 | # training hyperparameters 44 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 45 | help='number of total epochs to run') 46 | parser.add_argument('--lr', default=0.0, type=float, metavar='LR', 47 | help='base learning rate') 48 | parser.add_argument('--lr-scheduler', default=None, type=str, metavar='LR-SCH', 49 | choices=('cosine',), 50 | help='scheduler for learning rate') 51 | parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W', 52 | help='weight decay') 53 | parser.add_argument('--batch-size', default=256, type=int, metavar='N', 54 | help='mini-batch size') 55 | parser.add_argument('--val-batch-size', default=256, type=int, metavar='N', 56 | help='mini-batch size for validation (uses only 1 gpu)') 57 | parser.add_argument('--concat-before', default=False, action='store_true', 58 | help='whether to concat the task & subgoals before passing to encoders') 59 | 60 | # training environment configs 61 | parser.add_argument('--workers', default=32, type=int, metavar='N', 62 | help='number of data loader workers') 63 | 64 | # submit configs 65 | parser.add_argument('--server', type=str, default='sc') 66 | parser.add_argument('--arg_str', default='--', type=str) 67 | parser.add_argument('--add_prefix', default='', type=str) 68 | parser.add_argument('--submit', action='store_true', default=False) 69 | 70 | # misc configs 71 | parser.add_argument('--gpu', default=-1, type=int, metavar='G', 72 | help='gpu to use (default: -1 => use cpu)') 73 | parser.add_argument('--seed', default=None, type=int, metavar='S', 74 | help='random seed') 75 | parser.add_argument('--print-freq', default=100, type=int, metavar='N', 76
| help='print frequency in # of iterations') 77 | parser.add_argument('--save-freq', default=5, type=int, metavar='N', 78 | help='save frequency in # of epochs') 79 | 80 | args = parser.parse_args() 81 | 82 | return args 83 | 84 | -------------------------------------------------------------------------------- /task_subgoal_consistency/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | import os 4 | import getpass 5 | import sys 6 | import operator 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torch.utils.tensorboard import SummaryWriter 13 | import torch.backends.cudnn as cudnn 14 | from torchmetrics import Accuracy 15 | from tqdm import tqdm 16 | import wandb 17 | 18 | from networks import ConsistencyClassifier, ConsistencyScorer 19 | from datasets import create_dataloaders 20 | from utils import process_hparams, AverageMeter, save_checkpoint, pad_list_of_strings, convert_token_to_label 21 | from arguments import parse_arguments 22 | 23 | 24 | def main(): 25 | args = parse_arguments() 26 | print('args:', args, flush=True) 27 | 28 | if args.submit: 29 | make_sh_and_submit(args) 30 | return 31 | 32 | # set deterministic seed if specified 33 | if args.seed is not None: 34 | random.seed(args.seed) 35 | torch.manual_seed(args.seed) 36 | cudnn.deterministic = True 37 | cudnn.benchmark = False 38 | warnings.warn('You have chosen to seed training. ' 39 | 'This will turn on the CUDNN deterministic setting, ' 40 | 'which can slow down your training considerably! ' 41 | 'You may see unexpected behavior when restarting ' 42 | 'from checkpoints.') 43 | 44 | torch.cuda.set_device(args.gpu) 45 | torch.backends.cudnn.benchmark = True 46 | 47 | # set up tensorboard logger and log hyperparameters 48 | if args.model_type == 'classifier': 49 | args.exp_id = f'task{args.task}_dataset{args.dataset_type}_sample{args.sample_ratio}_img{args.img_feature_extractor}_text{args.text_feature_extractor}_{args.classifier_arch}_{args.hidden_dims}_lr{args.lr}_scheduler{args.lr_scheduler}_bs{args.batch_size}_seed{args.seed}' 50 | elif args.model_type == 'scorer': 51 | args.exp_id = f'arch{args.scorer_arch}_img{args.img_feature_extractor}_text{args.text_feature_extractor}_hidden{args.hidden_dims}_{args.concat_before}_lr{args.lr}_scheduler{args.lr_scheduler}_bs{args.batch_size}_seed{args.seed}' 52 | args.log_dir = os.path.join(args.log_dir, args.exp_id) 53 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.exp_id) 54 | os.makedirs(args.checkpoint_dir, exist_ok=True) 55 | 56 | wandb.init(project=f'llm_diffusion', name=args.exp_id, config=args, save_code=True) 57 | wandb.define_metric('train/step') 58 | wandb.define_metric("train/*", step_metric="train/step") 59 | wandb.define_metric('epoch') 60 | wandb.define_metric("val/*", step_metric="epoch") 61 | wandb.define_metric("best/*", step_metric="epoch") 62 | 63 | # process hyperparameters 64 | args = process_hparams(args) 65 | 66 | # set up model 67 | print('=> creating model...', flush=True) 68 | if args.model_type == 'classifier': 69 | model = ConsistencyClassifier(args.img_feature_extractor, args.text_feature_extractor, args.classifier_arch, args.hidden_dims, args.output_dim, args.concat_before, args.dataset_type, args.task, args.gpu) 70 | train = train_clf 71 | validate = validate_clf 72 | metric_op = operator.gt # for higher is better 73 | 74 | # loss & metric 75 | if args.output_dim == 1: 76 | criterion = 
nn.BCEWithLogitsLoss().cuda() 77 | metric = Accuracy('binary').cuda() 78 | elif args.output_dim > 1: 79 | criterion = nn.CrossEntropyLoss().cuda() 80 | metric = Accuracy('multiclass', num_classes=args.output_dim).cuda() 81 | elif args.model_type == 'scorer': 82 | args.vocab_size += 3 # add , , token, but no gradients will be propagated along this label 83 | model = ConsistencyScorer(args.vocab_size, args.img_feature_extractor, args.text_feature_extractor, args.scorer_arch, args.hidden_dims, args.dropout, args.concat_before, args.gpu) 84 | train = train_scorer 85 | validate = validate_scorer 86 | metric_op = operator.lt # for lower is better 87 | 88 | # loss & metric 89 | criterion = nn.functional.cross_entropy 90 | 91 | metric = token_to_label = {} # including token_to_label for clarity of naming 92 | token_to_label[''] = 0 93 | token_to_label[''] = 1 94 | token_to_label[''] = 2 95 | 96 | model = model.cuda() 97 | 98 | # optimizer 99 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), args.lr, weight_decay=args.weight_decay) 100 | 101 | # load checkpoint 102 | if os.path.isfile(args.checkpoint_dir / 'checkpoint.pth.tar'): 103 | print("=> loading checkpoint '{}'".format(args.checkpoint_dir / 'checkpoint.pth.tar')) 104 | checkpoint = torch.load(args.checkpoint_dir / 'checkpoint.pth.tar', map_location='cpu') 105 | args.start_epoch = checkpoint['epoch'] 106 | start_epoch = checkpoint['epoch'] 107 | best_metric = checkpoint['best_metric'] 108 | best_epoch = checkpoint['best_epoch'] 109 | 110 | model.load_state_dict(checkpoint['state_dict']) 111 | optimizer.load_state_dict(checkpoint['optimizer']) 112 | print("=> loaded checkpoint '{}' (epoch {})" 113 | .format(args.checkpoint_dir / 'checkpoint.pth.tar', checkpoint['epoch'])) 114 | 115 | if args.start_epoch >= args.epochs: 116 | print('=> already trained for {} epochs'.format(args.epochs)) 117 | return 118 | else: 119 | start_epoch = 0 120 | best_metric = 0. 
if metric_op == operator.gt else float('inf') 121 | 122 | # dataloader 123 | print('=> creating dataloader...', flush=True) 124 | train_loader, test_loader = create_dataloaders(args) 125 | 126 | # scheduler 127 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) if args.lr_scheduler == 'cosine' else None 128 | 129 | best_epoch = 0 130 | 131 | for epoch in range(start_epoch, args.epochs): 132 | train(model, optimizer, train_loader, scheduler, criterion, metric, epoch, args) 133 | val_metric = validate(model, test_loader, criterion, metric, epoch, args) 134 | 135 | is_best = metric_op(val_metric, best_metric) 136 | best_epoch = epoch if is_best else best_epoch 137 | best_metric = val_metric if is_best else best_metric 138 | 139 | if (epoch+1) % args.save_freq == 0 or is_best: 140 | save_checkpoint({ 141 | 'epoch': epoch + 1, 142 | 'state_dict': model.state_dict(), 143 | 'best_metric': best_metric, 144 | 'best_epoch': best_epoch, 145 | 'optimizer' : optimizer.state_dict(), 146 | 'args': vars(args), 147 | }, is_best, args.checkpoint_dir, f'checkpoint_{epoch}.pth.tar') 148 | 149 | wandb.log({'best/best_metric': best_metric, 'best/best_epoch': best_epoch, 'epoch': epoch}) 150 | 151 | # train function for rnn 152 | def train_scorer(model, optimizer, train_loader, scheduler, criterion, token_to_label, epoch, args): 153 | print(f'=> training epoch {epoch}...', flush=True) 154 | model.train() 155 | 156 | for step, (task, subgoals, obs, obs_lens, obs_idxs) in enumerate(tqdm(train_loader), start=epoch * len(train_loader)): 157 | # add bos token to beginning of subgoals 158 | 159 | if torch.cuda.is_available(): 160 | # task = task.cuda(non_blocking=True) 161 | # subgoals = subgoals.cuda(non_blocking=True) 162 | obs = obs.cuda(non_blocking=True) 163 | 164 | # initialize hidden state 165 | hidden = None 166 | 167 | subgoals = [s.split() for s in subgoals] # split subgoals into list of tokens 168 | subgoals = pad_list_of_strings(subgoals, obs_idxs) # pad with to max seq len 169 | seq_len = len(subgoals[0]) 170 | subgoals = list(zip(*subgoals)) # transpose to list of seq_len x batch_size 171 | obs_idxs = torch.tensor(list(zip(*obs_idxs))) # transpose to list of seq_len x batch_size 172 | 173 | # using cumulative index calculated from obs_lens and relative idx from obs_idxs to get actual idx 174 | cum_idxs = torch.cumsum(obs_lens, dim=0) 175 | cum_idxs[1:] = cum_idxs[:-1].clone() 176 | cum_idxs[0] = 0 177 | 178 | loss = 0.0 179 | for s in range(seq_len-1): 180 | cur_subgoals = list(zip(*subgoals[:s+1])) 181 | cur_subgoals = [' '.join(c_s) for c_s in cur_subgoals] 182 | next_subgoals = list(subgoals[s+1]) 183 | next_subgoals = torch.tensor(convert_token_to_label(next_subgoals, token_to_label)).long().cuda(non_blocking=True) # convert on the fly 184 | # for loop over subgoals 185 | cur_obs_idxs = cum_idxs + obs_idxs[s] 186 | outputs, hidden = model(task, cur_subgoals, obs[cur_obs_idxs], hidden) 187 | loss += criterion(outputs, next_subgoals, ignore_index=token_to_label['']) # ignore the label for token 188 | 189 | optimizer.zero_grad() 190 | loss.backward() 191 | optimizer.step() 192 | 193 | if step % args.print_freq == 0: 194 | pp = np.exp(loss.item() / seq_len) 195 | print(f'epoch={epoch} step={step} loss={loss.item():.4f} perplexity={pp:.2f}', flush=True) 196 | wandb.log({'train/loss': loss.item(), 'train/perplexity': pp, 'train/step': step}) 197 | 198 | if scheduler is not None: 199 | scheduler.step() 200 | 201 | def train_clf(model, optimizer, train_loader, scheduler, 
criterion, metric, epoch, args): 202 | print(f'=> training epoch {epoch}...', flush=True) 203 | model.train() 204 | 205 | for step, (task, subgoals, obs, label) in enumerate(tqdm(train_loader), start=epoch * len(train_loader)): 206 | if torch.cuda.is_available(): 207 | # task = task.cuda(non_blocking=True) 208 | # subgoals = subgoals.cuda(non_blocking=True) 209 | obs = obs.cuda(non_blocking=True) 210 | label = label.view(-1, 1).cuda(non_blocking=True) 211 | 212 | output = model(task, subgoals, obs, all_subgoals=True if args.dataset_type == 'all' else False) 213 | loss = criterion(output, label) if args.output_dim == 1 else criterion(output, label.squeeze().long()) 214 | 215 | optimizer.zero_grad() 216 | loss.backward() 217 | optimizer.step() 218 | 219 | if step % args.print_freq == 0: 220 | if args.output_dim == 1: 221 | acc = metric(torch.sigmoid(output), label) 222 | else: 223 | acc = metric(output.argmax(dim=-1), label.squeeze()) 224 | print(f'epoch={epoch}, step={step}, loss={loss.item()}, acc={acc.item()}', flush=True) 225 | 226 | log_dict = { 227 | 'train/loss': loss.item(), 228 | 'train/acc': acc.item(), 229 | 'train/lr': optimizer.param_groups[0]['lr'], 230 | } 231 | 232 | wandb.log(log_dict) 233 | 234 | if scheduler is not None: 235 | scheduler.step() 236 | 237 | 238 | def validate_clf(model, test_loader, criterion, metric, epoch, args): 239 | print(f'=> validating epoch {epoch}...', flush=True) 240 | model.eval() 241 | 242 | val_loss = AverageMeter('ValLoss') 243 | val_acc = AverageMeter('ValAcc') 244 | 245 | with torch.no_grad(): 246 | for step, (task, subgoals, obs, label) in enumerate(tqdm(test_loader), start=epoch * len(test_loader)): 247 | if torch.cuda.is_available(): 248 | # task = task.cuda(non_blocking=True) 249 | # subgoals = subgoals.cuda(non_blocking=True) 250 | obs = obs.cuda(non_blocking=True) 251 | label = label.view(-1, 1).cuda(non_blocking=True) 252 | output = model(task, subgoals, obs, all_subgoals=True if args.dataset_type == 'all' else False) 253 | loss = criterion(output, label) if args.output_dim == 1 else criterion(output, label.squeeze().long()) 254 | if args.output_dim == 1: 255 | acc = metric(torch.sigmoid(output), label) 256 | else: 257 | acc = metric(output.argmax(dim=-1), label.squeeze()) 258 | 259 | val_acc.update(acc.item(), label.size(0)) 260 | val_loss.update(loss.item(), label.size(0)) 261 | 262 | # log to wandb 263 | log_dict = { 264 | 'val/loss': val_loss.avg, 265 | 'val/acc': val_acc.avg, 266 | 'epoch': epoch 267 | } 268 | 269 | wandb.log(log_dict) 270 | print(f'val epoch={epoch}, loss={val_loss.avg}, acc={val_acc.avg}', flush=True) 271 | 272 | return val_acc.avg 273 | 274 | def validate_scorer(model, test_loader, criterion, token_to_label, epoch, args): 275 | print(f'=> validating epoch {epoch}...', flush=True) 276 | model.eval() 277 | 278 | val_loss = AverageMeter('ValLoss') 279 | val_pp = AverageMeter('ValPerplexity') 280 | 281 | with torch.no_grad(): 282 | for step, (task, subgoals, obs, obs_lens, obs_idxs) in enumerate(tqdm(test_loader), start=epoch * len(test_loader)): 283 | if torch.cuda.is_available(): 284 | obs = obs.cuda(non_blocking=True) 285 | 286 | # initialize hidden state 287 | hidden = None 288 | 289 | subgoals = [s.split() for s in subgoals] # split subgoals into list of tokens 290 | subgoals = pad_list_of_strings(subgoals, obs_idxs) # pad with to max seq len 291 | seq_len = len(subgoals[0]) 292 | subgoals = list(zip(*subgoals)) # transpose to list of seq_len x batch_size 293 | obs_idxs = torch.tensor(list(zip(*obs_idxs))) # 
transpose to list of seq_len x batch_size 294 | 295 | # using cumulative index calculated from obs_lens and relative idx from obs_idxs to get actual idx 296 | cum_idxs = torch.cumsum(obs_lens, dim=0) 297 | cum_idxs[1:] = cum_idxs[:-1].clone() 298 | cum_idxs[0] = 0 299 | 300 | loss = 0.0 301 | for s in range(seq_len-1): 302 | cur_subgoals = list(zip(*subgoals[:s+1])) 303 | cur_subgoals = [' '.join(c_s) for c_s in cur_subgoals] 304 | next_subgoals = list(subgoals[s+1]) 305 | next_subgoals = torch.tensor(convert_token_to_label(next_subgoals, token_to_label)).long().cuda(non_blocking=True) # convert on the fly 306 | # for loop over subgoals 307 | cur_obs_idxs = cum_idxs + obs_idxs[s] 308 | outputs, hidden = model(task, cur_subgoals, obs[cur_obs_idxs], hidden) 309 | loss += criterion(outputs, next_subgoals, ignore_index=token_to_label['']) # ignore the label for token 310 | 311 | if step % args.print_freq == 0: 312 | pp = np.exp(loss.item() / seq_len) 313 | print(f'val epoch={epoch} step={step} loss={loss.item():.4f} perplexity={pp:.2f}', flush=True) 314 | # wandb.log({'val/loss': loss.item(), 'val/perplexity': pp}) 315 | 316 | val_loss.update(loss.item(), len(task)) 317 | val_pp.update(pp, len(task)) 318 | 319 | # log to wandb 320 | log_dict = { 321 | 'val/loss': val_loss.avg, 322 | 'val/perpelxity': val_pp.avg, 323 | } 324 | wandb.log(log_dict) 325 | 326 | return val_pp.avg 327 | 328 | 329 | 330 | 331 | def make_sh_and_submit(args, delay=0): 332 | os.makedirs('./scripts/submit_scripts/', exist_ok=True) 333 | os.makedirs(args.log_dir, exist_ok=True) 334 | options = args.arg_str 335 | 336 | if delay == 0: 337 | options_split = options.split(" ") 338 | name = ''.join([opt1.replace("--","").replace("=","").replace('gpu', '').replace('print-freq', '') for opt1 in options_split]) 339 | name = args.add_prefix + name 340 | 341 | else: # log_id should be already defined 342 | name = args.log_id 343 | print('Submitting the job with options: ') 344 | # print(options) 345 | print(f"experiment name: {name}") 346 | 347 | if args.server == 'insert server name': 348 | options += f' --server= --arg_str=\"{args.arg_str}\" ' 349 | preamble = ( 350 | f'#!/bin/sh\n#SBATCH --gres=gpu:2\n#SBATCH --exclusive\n#SBATCH --cpus-per-task=20\n#SBATCH ' 351 | f'-N 1\n#SBATCH -t 360\n#SBATCH ') 352 | preamble += f'--begin=now+{delay}hour\n#SBATCH ' 353 | preamble += (f'-o ./logs/{name}.out\n#SBATCH ' 354 | f'--job-name={name}_{delay}\n#SBATCH ' 355 | f'--open-mode=append\n\n') 356 | 357 | else: 358 | username = getpass.getuser() 359 | options += f' --server={args.server} ' 360 | preamble = ( 361 | f'#!/bin/sh\n#SBATCH --gres=gpu:volta:1\n#SBATCH --cpus-per-task=20\n#SBATCH ' 362 | f'-o ./logs/{name}.out\n#SBATCH ' 363 | f'--job-name={name}\n#SBATCH ' 364 | f'--open-mode=append\n\n' 365 | ) 366 | with open(f'./scripts/submit_scripts/{name}_{delay}.sh', 'w') as file: 367 | file.write(preamble) 368 | file.write("echo \"current time: $(date)\";\n") 369 | port = random.randrange(10000, 20000) 370 | file.write(f'wandb offline\n') 371 | file.write( 372 | f'python {sys.argv[0]} {options} ' 373 | ) 374 | 375 | if args.server == 'sc': 376 | if args.task == 'paint': 377 | file.write(f'--data /path/to/data') 378 | elif args.task == 'cliport': 379 | file.write(f'--data /path/to/data') 380 | 381 | os.system(f'sbatch ./scripts/submit_scripts/{name}_{delay}.sh') 382 | 383 | if __name__ == '__main__': 384 | main() 385 | -------------------------------------------------------------------------------- /task_subgoal_consistency/utils.py: 
-------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import collections 3 | from pathlib import Path 4 | import shutil 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch.utils.data.dataloader import default_collate 10 | 11 | 12 | # processing applied to hyperparameters 13 | def process_hparams(args): 14 | args.data = Path(args.data) 15 | args.log_dir = Path(args.log_dir) 16 | args.checkpoint_dir = Path(args.checkpoint_dir) 17 | 18 | args.hidden_dims = [int(x) for x in args.hidden_dims.split(',')] 19 | 20 | return args 21 | 22 | # 23 | def convert_token_to_label(tokens, token_to_label): 24 | idxs = [] 25 | for i in range(len(tokens)): 26 | if tokens[i] not in token_to_label: 27 | token_to_label[tokens[i]] = len(token_to_label) 28 | 29 | idxs.append(token_to_label[tokens[i]]) 30 | 31 | return idxs 32 | 33 | # pad list of strings to same length, pad obs_idxs with the last state too 34 | def pad_list_of_strings(list_of_strings, obs_idxs, pad_token=''): 35 | max_len = max([len(s) for s in list_of_strings]) 36 | 37 | for l in list_of_strings: 38 | l += [f'{pad_token}'] * (max_len - len(l)) 39 | 40 | return list_of_strings 41 | 42 | # # custom collate designed for new dataset without obs idxs (classifier) 43 | def custom_collate_clf(batch): 44 | elem = batch[0] 45 | 46 | if isinstance(elem, collections.abc.Sequence): # custom collate for subgoals (assuming it is index 1 in the batch) 47 | # check to make sure that the elements in batch have consistent size 48 | # it = iter(batch) 49 | # elem_size = len(next(it)) 50 | # if not all(len(elem) == elem_size for elem in it): 51 | # raise RuntimeError('each element in list of batch should be of equal size') 52 | transposed = list(zip(*batch)) 53 | 54 | return [default_collate(transposed[0]), transposed[1], default_collate(transposed[2]), default_collate(transposed[3])] # task, subgoals, obs, label 55 | 56 | else: # Fall back to `default_collate` 57 | return default_collate(batch) 58 | 59 | # custom collate designed for new dataset (scorer) 60 | def custom_collate_scorer(batch): 61 | elem = batch[0] 62 | 63 | if isinstance(elem, collections.abc.Sequence): # custom collate for subgoals (assuming it is index 1 in the batch) 64 | # check to make sure that the elements in batch have consistent size 65 | # it = iter(batch) 66 | # elem_size = len(next(it)) 67 | # if not all(len(elem) == elem_size for elem in it): 68 | # raise RuntimeError('each element in list of batch should be of equal size') 69 | transposed = list(zip(*batch)) 70 | obs_lens = torch.tensor([len(t) for t in transposed[2]]) 71 | obs = torch.cat(transposed[2], dim=0) 72 | # padding obs_idxs here 73 | obs_idxs = list(transposed[3]) 74 | max_len = max([len(obs_idx) for obs_idx in obs_idxs]) 75 | 76 | for i in range(len(obs_idxs)): 77 | if len(obs_idxs[i]) < max_len: 78 | obs_idxs[i] += [obs_idxs[i][-1]] * (max_len - len(obs_idxs[i])) 79 | obs_idxs = tuple(obs_idxs) 80 | 81 | return [default_collate(transposed[0]), list(chain.from_iterable(transposed[1])), obs, obs_lens, obs_idxs] # task (B x 1), subgoals (B x 1), obs (total number of frames x C x H x W), obs_idxs (B x seq_len) 82 | 83 | else: # Fall back to `default_collate` 84 | return default_collate(batch) 85 | 86 | # Average metric meter 87 | class AverageMeter(object): 88 | """Computes and stores the average and current value""" 89 | def __init__(self, name, fmt=':f'): 90 | self.name = name 91 | self.fmt = fmt 92 | self.reset() 93 | 94 | def reset(self): 95 | 
self.val = 0 96 | self.avg = 0 97 | self.sum = 0 98 | self.count = 0 99 | 100 | def update(self, val, n=1): 101 | self.val = val 102 | self.sum += val * n 103 | self.count += n 104 | self.avg = self.sum / self.count 105 | 106 | def __str__(self): 107 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 108 | return fmtstr.format(**self.__dict__) 109 | 110 | # save checkpoint 111 | def save_checkpoint(state, is_best, filedir, filename='checkpoint.pth.tar'): 112 | torch.save(state, filedir / filename) 113 | shutil.copyfile(filedir / filename, filedir / 'checkpoint.pth.tar') 114 | if is_best: 115 | shutil.copyfile(filedir / filename, filedir / 'checkpoint_best.pth.tar') 116 | --------------------------------------------------------------------------------
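Usage note (illustrative, not part of the repository): the sketch below shows how the InvDynamics module from inv_dyn/inv_dyn_ft.py might be exercised on its own, assuming the dependencies it imports (ffcv, vc_models for the VC-1 backbone, the inv_dynamics package, and the local resnet module, none of which are included above) are available. The input batch is random data purely for illustration; in the training loop the observations are uint8 HWC frames permuted to CHW and scaled to [0, 1] before the forward pass.

import torch
from inv_dyn.inv_dyn_ft import InvDynamics

# VC-1 visual encoder followed by a single linear head that regresses the
# 7-dimensional robot state.
inv_model = InvDynamics(state_dim=7).cuda().eval()

# Hypothetical batch of 4 RGB observations in [0, 1]; model_transforms inside
# forward() resizes and normalizes them for the encoder.
obs = torch.rand(4, 3, 224, 224, device='cuda')

with torch.no_grad():
    pred_state = inv_model(obs)  # -> tensor of shape (4, 7)
print(pred_state.shape)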
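A second illustrative sketch (also not part of the repository) exercises the two subgoal helpers from task_subgoal_consistency/utils.py on toy, hypothetical token sequences: pad_list_of_strings pads every sequence to the length of the longest one with its default pad token, and convert_token_to_label builds the token-to-integer-label vocabulary on the fly.

from task_subgoal_consistency.utils import pad_list_of_strings, convert_token_to_label

# Two tokenized subgoal sequences of unequal length (hypothetical strings).
subgoals = [['pick', 'red', 'block'], ['place', 'block']]
obs_idxs = [[0, 1, 2], [0, 1]]

padded = pad_list_of_strings(subgoals, obs_idxs)  # shorter sequence is padded to length 3

token_to_label = {}
labels = convert_token_to_label(padded[0], token_to_label)  # unseen tokens get new integer ids

print(labels)          # [0, 1, 2]
print(token_to_label)  # {'pick': 0, 'red': 1, 'block': 2}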