├── .github └── FUNDING.yml ├── FeedbackPolicy ├── data │ └── data.py ├── enrich_lang_annotations.json ├── eval │ ├── eval_calvin.py │ └── eval_utils.py ├── eval_calvin.sh ├── eval_sequences.json ├── models │ ├── distributed.py │ ├── factory.py │ ├── policy.py │ ├── transformer_utils.py │ └── vit.py ├── train │ ├── distributed.py │ ├── train_calvin.py │ └── train_utils.py └── train_calvin.sh ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── CLOVER_Poster-1.png ├── closed-loop.jpg ├── clover_teaser.png ├── gen_diff_condition.png ├── long-horizon-task.gif └── vis_robustness.jpg ├── requirements.txt ├── setup.py └── visual_planner ├── accelerate_cfg.yaml ├── calvin_data.py ├── diffusion_model ├── __init__.py ├── diffusion_utils.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── imagen.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── metric_utils ├── __init__.py ├── calc_fvd.py ├── calc_lpips.py ├── calc_psnr.py └── calc_ssim.py ├── raft_utils ├── __init__.py ├── corr.py ├── extractor.py ├── raft.py ├── update.py └── utils.py ├── train.py ├── train.sh ├── trainer.py ├── val.sh └── visual_planner.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /FeedbackPolicy/eval/eval_calvin.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import random 5 | from eval_utils import eval_one_epoch_calvin_ddp 6 | from torch.distributed.elastic.multiprocessing.errors import record 7 | 8 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 9 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 10 | import numpy as np 11 | import torch 12 | import wandb 13 | from FeedbackPolicy.models.distributed import init_distributed_device, world_info_from_env 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | from FeedbackPolicy.models.factory import load_model 17 | 18 | 19 | def random_seed(seed=42, rank=0): 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | 24 | 25 | @record 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) 29 | parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) 30 | parser.add_argument("--seed", type=int, default=42) 31 | parser.add_argument( 32 | "--precision", 33 | choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], 34 | default="fp32", 35 | 
help="Floating point precision.", 36 | ) 37 | parser.add_argument( 38 | "--calvin_dataset", 39 | type=str, 40 | help="path to calvin_dataset", 41 | ) 42 | parser.add_argument("--calvin_conf_path", type=str, help="path to calvin configuration file") 43 | parser.add_argument( 44 | "--visual_planner_checkpoint", 45 | type=str, 46 | help="path to checkpoint to evaluate , this should contain model", 47 | default=None, 48 | ) 49 | parser.add_argument( 50 | "--policy_checkpoint", 51 | type=str, 52 | help="path to policy checkpoint to evaluate , this should contain model", 53 | default=None, 54 | ) 55 | parser.add_argument( 56 | "--horovod", 57 | default=False, 58 | action="store_true", 59 | help="Use horovod for distributed training.", 60 | ) 61 | parser.add_argument( 62 | "--dist-url", 63 | default="env://", 64 | type=str, 65 | help="url used to set up distributed training", 66 | ) 67 | parser.add_argument( 68 | "--no-set-device-rank", 69 | default=False, 70 | action="store_true", 71 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", 72 | ) 73 | parser.add_argument( 74 | "--dist-backend", default="nccl", type=str, help="distributed backend" 75 | ) 76 | parser.add_argument( 77 | "--reset", 78 | default=False, 79 | action="store_true" 80 | ) 81 | parser.add_argument( 82 | "--visualize", 83 | default=False, 84 | action="store_true" 85 | ) 86 | parser.add_argument( 87 | "--diverse_inst", 88 | default=False, 89 | action="store_true" 90 | ) 91 | parser.add_argument('--sample_step', type=int, default=20, help="diffusion time steps") 92 | 93 | args = parser.parse_args() 94 | 95 | 96 | args.local_rank, args.rank, args.world_size = world_info_from_env() 97 | 98 | device_id = init_distributed_device(args) 99 | print("device_id: ", device_id) 100 | print("world_size: ", torch.distributed.get_world_size()) 101 | random_seed(args.seed) 102 | 103 | diffusion_model, policy_model, tokenizer, text_encoder = load_model( 104 | args.vision_encoder_path, 105 | args.vision_encoder_pretrained, 106 | sample_steps=args.sample_step, 107 | ) 108 | 109 | checkpoint_path = args.visual_planner_checkpoint 110 | print("Loading checkpoint from ", checkpoint_path) 111 | diffusion_model.load_state_dict(torch.load(checkpoint_path)['ema'], strict=True) 112 | 113 | 114 | diffusion_model = diffusion_model.to(device_id) 115 | diffusion_model.eval() 116 | ddp_diffusion_model = DDP(diffusion_model, device_ids=[device_id]) 117 | 118 | policy_model = policy_model.to(device_id) 119 | policy_model.eval() 120 | ddp_policy_model = DDP(policy_model, device_ids=[device_id]) 121 | 122 | checkpoint_path = args.policy_checkpoint 123 | print("Loading policy checkpoint from ", checkpoint_path) 124 | ddp_policy_model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'], strict=False) 125 | 126 | 127 | ddp_diffusion_model.eval() 128 | eval_log_dir = None 129 | if args.visualize: 130 | eval_log_dir = 'evaluate/{}'.format(args.visual_planner_checkpoint.split('.')[0]) 131 | eval_one_epoch_calvin_ddp( 132 | args=args, 133 | model=ddp_diffusion_model, 134 | policy_model=ddp_policy_model, 135 | text_encoder=text_encoder, 136 | tokenizer=tokenizer, 137 | dataset_path=args.calvin_dataset, 138 | eval_log_dir=eval_log_dir, 139 | debug=args.visualize, 140 | reset=args.reset, 141 | diverse_inst=args.diverse_inst 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | os.environ["NCCL_BLOCKING_WAIT"] = '1' 147 | main() 148 | 
-------------------------------------------------------------------------------- /FeedbackPolicy/eval_calvin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export EVALUTION_ROOT=$(pwd) 3 | 4 | # Set CALVIN path 5 | calvin_dataset_path='path_to_your/calvin/dataset/task_ABC_D' 6 | calvin_conf_path="path_to_your/calvin/calvin_models/conf" 7 | 8 | # Set checkpoints path 9 | visual_planner_checkpoint='path_to_your/visual_planner.pt' 10 | policy_checkpoint='path_to_your/feedback_policy.pth' 11 | 12 | export MESA_GL_VERSION_OVERRIDE=4.1 13 | node_num=4 14 | 15 | 16 | torchrun --nnodes=1 --nproc_per_node=${node_num} --master_port=6600 eval/eval_calvin.py \ 17 | --visual_planner_checkpoint ${visual_planner_checkpoint} \ 18 | --policy_checkpoint ${policy_checkpoint} \ 19 | --calvin_dataset ${calvin_dataset_path} \ 20 | --calvin_conf_path ${calvin_conf_path} \ 21 | --sample_step 20 \ 22 | 23 | 24 | -------------------------------------------------------------------------------- /FeedbackPolicy/models/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | 6 | import os 7 | import torch 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def is_global_master(args): 16 | return args.rank == 0 17 | 18 | 19 | def is_local_master(args): 20 | return args.local_rank == 0 21 | 22 | 23 | def is_master(args, local=False): 24 | return is_local_master(args) if local else is_global_master(args) 25 | 26 | 27 | def is_using_horovod(): 28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 32 | if all([var in os.environ for var in ompi_vars]) or all( 33 | [var in os.environ for var in pmi_vars] 34 | ): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def is_using_distributed(): 41 | if "WORLD_SIZE" in os.environ: 42 | return int(os.environ["WORLD_SIZE"]) > 1 43 | if "SLURM_NTASKS" in os.environ: 44 | return int(os.environ["SLURM_NTASKS"]) > 1 45 | return False 46 | 47 | 48 | def world_info_from_env(): 49 | local_rank = 0 50 | for v in ( 51 | "LOCAL_RANK", 52 | "MPI_LOCALRANKID", 53 | "SLURM_LOCALID", 54 | "OMPI_COMM_WORLD_LOCAL_RANK", 55 | ): 56 | if v in os.environ: 57 | local_rank = int(os.environ[v]) 58 | break 59 | global_rank = 0 60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 61 | if v in os.environ: 62 | global_rank = int(os.environ[v]) 63 | break 64 | world_size = 1 65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 66 | if v in os.environ: 67 | world_size = int(os.environ[v]) 68 | break 69 | 70 | return local_rank, global_rank, world_size 71 | 72 | 73 | def init_distributed_device(args): 74 | # Distributed training = training on more than one GPU. 75 | # Works in both single and multi-node scenarios. 
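    # Branch order below: (1) Horovod when --horovod is passed, (2) SLURM-launched DDP when
    # SLURM_PROCID is set, (3) torchrun / torch.distributed.launch otherwise, and (4) a
    # single-process fallback that still calls init_process_group with world_size=1, rank=0.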
76 | args.distributed = False 77 | args.world_size = 1 78 | args.rank = 0 # global rank 79 | args.local_rank = 0 80 | if args.horovod: 81 | assert hvd is not None, "Horovod is not installed" 82 | hvd.init() 83 | args.local_rank = int(hvd.local_rank()) 84 | args.rank = hvd.rank() 85 | args.world_size = hvd.size() 86 | args.distributed = True 87 | os.environ["LOCAL_RANK"] = str(args.local_rank) 88 | os.environ["RANK"] = str(args.rank) 89 | os.environ["WORLD_SIZE"] = str(args.world_size) 90 | elif is_using_distributed(): 91 | if "SLURM_PROCID" in os.environ: 92 | # DDP via SLURM 93 | args.local_rank, args.rank, args.world_size = world_info_from_env() 94 | # SLURM var -> torch.distributed vars in case needed 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | os.environ["RANK"] = str(args.rank) 97 | os.environ["WORLD_SIZE"] = str(args.world_size) 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | world_size=args.world_size, 102 | rank=args.rank, 103 | ) 104 | 105 | else: 106 | # DDP via torchrun, torch.distributed.launch 107 | args.local_rank, _, _ = world_info_from_env() 108 | from datetime import timedelta 109 | timeout = timedelta(hours=8) 110 | torch.distributed.init_process_group( 111 | backend=args.dist_backend, init_method=args.dist_url, timeout=timeout, 112 | ) 113 | print('is_using_torchrun' ) 114 | args.world_size = torch.distributed.get_world_size() 115 | args.rank = torch.distributed.get_rank() 116 | args.distributed = True 117 | else: 118 | # needed to run on single gpu 119 | torch.distributed.init_process_group( 120 | backend=args.dist_backend, 121 | init_method=args.dist_url, 122 | world_size=1, 123 | rank=0, 124 | ) 125 | 126 | if torch.cuda.is_available(): 127 | if args.distributed and not args.no_set_device_rank: 128 | device = "cuda:%d" % args.local_rank 129 | else: 130 | device = "cuda:0" 131 | torch.cuda.set_device(device) 132 | else: 133 | device = "cpu" 134 | args.device = device 135 | device = torch.device(device) 136 | return device 137 | -------------------------------------------------------------------------------- /FeedbackPolicy/models/factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | 5 | from visual_planner.trainer import GoalGaussianDiffusion 6 | from visual_planner.visual_planner import VisualPlanner 7 | 8 | from ema_pytorch import EMA 9 | 10 | from .policy import FeedbackDrivenPolicy 11 | from .vit import VisionTransformer 12 | 13 | 14 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 15 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 16 | 17 | IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 18 | IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 19 | 20 | 21 | def load_model( 22 | clip_vision_encoder_path: str, 23 | clip_vision_encoder_pretrained: str, 24 | target_size=(128, 128), 25 | use_vae=False, 26 | with_depth=True, 27 | flow_reg=True, 28 | sample_per_seq=8, 29 | diffusion_steps=100, 30 | sample_steps=20, 31 | ): 32 | # CLIP text tokenizer / encoder 33 | pretrained_model = "clip-vit-large-patch14" 34 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path = pretrained_model) 35 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path = pretrained_model) 36 | text_encoder.requires_grad_(False) 37 | text_encoder.eval() 38 | 39 | 40 | # Visual planner 41 | unet = VisualPlanner( 42 | image_size = 
target_size[0], 43 | in_channels = 8, # RGBD 44 | out_channels = 4, 45 | use_vae = use_vae, 46 | decoupled_output = False, 47 | temporal_length = sample_per_seq, 48 | dims = 3, 49 | flow_reg = flow_reg, 50 | with_state_estimate = False, 51 | ) 52 | 53 | visual_planner = GoalGaussianDiffusion( 54 | channels=3, 55 | model=unet, 56 | image_size=target_size, 57 | timesteps=diffusion_steps, 58 | sampling_timesteps=sample_steps, 59 | loss_type='l2', 60 | objective='pred_v', 61 | beta_schedule = 'cosine', 62 | min_snr_loss_weight = True, 63 | auto_normalize = False, 64 | with_depth=with_depth, 65 | use_vae=use_vae, 66 | ) 67 | visual_planner.eval() 68 | visual_planner = EMA(visual_planner, beta = 0.999, update_every = 10) 69 | 70 | 71 | # Policy with VC-1 as RGB encoder 72 | from vc_models.models.vit import model_utils 73 | vision_encoder = model_utils.load_model(model_utils.VC1_BASE_NAME) 74 | embd_size = 768 75 | policy_model = FeedbackDrivenPolicy( 76 | vision_encoder = vision_encoder, 77 | vis_dim = embd_size, # 1024 for Large, 384 for Small 78 | window_size = 5, 79 | sampling_step = 1 80 | ) 81 | policy_model.eval() 82 | 83 | 84 | return visual_planner, policy_model, tokenizer, text_encoder 85 | 86 | 87 | 88 | def create_feedback_policy( 89 | vision_encoder: str = 'vc1-base', #TODO: Support additional visual encoders 90 | resume_from_checkpoint: str = None, 91 | ): 92 | 93 | import torchvision.transforms as transforms 94 | image_processor = transforms.Compose([ 95 | transforms.Resize((192, 192), interpolation = transforms.InterpolationMode.BICUBIC), 96 | transforms.ToTensor(), 97 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 98 | ]) 99 | pretrained_model = "clip-vit-large-patch14" 100 | text_tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path = pretrained_model) 101 | 102 | from vc_models.models.vit import model_utils 103 | vision_encoder = model_utils.load_model(model_utils.VC1_BASE_NAME) 104 | embd_size = 768 105 | 106 | 107 | model = FeedbackDrivenPolicy(vision_encoder = vision_encoder, \ 108 | vis_dim = embd_size, 109 | window_size = 5, 110 | sampling_step = 1) 111 | 112 | model.vision_encoder.requires_grad_(False) 113 | 114 | def check_file_exists(file_path): 115 | if not os.path.isfile(file_path): 116 | raise FileNotFoundError(f"The file '{file_path}' does not exist.") 117 | 118 | 119 | print('Try loading from ckpt') 120 | try: 121 | check_file_exists(resume_from_checkpoint) 122 | old_ckpt = torch.load(resume_from_checkpoint)['model_state_dict'] 123 | 124 | # remove 'module.' 
in original keys 125 | new_ckpt = {} 126 | for k, v in old_ckpt.items(): 127 | new_ckpt[k[7:]] = v 128 | model.load_state_dict(new_ckpt, strict=False) 129 | 130 | except FileNotFoundError as e: 131 | print(e) 132 | 133 | return model, image_processor, text_tokenizer -------------------------------------------------------------------------------- /FeedbackPolicy/models/policy.py: -------------------------------------------------------------------------------- 1 | # import open_clip 2 | 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange 10 | from einops import repeat 11 | 12 | from .transformer_utils import Block, PatchEmbed, get_2D_position_embeddings, RMSNorm, SwishGLU 13 | 14 | 15 | 16 | 17 | class MAPAttention(nn.Module): 18 | def __init__(self, embed_dim: int, n_heads: int) -> None: 19 | """Multi-Input Multi-Headed Attention Operation""" 20 | super().__init__() 21 | assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!" 22 | self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5 23 | 24 | # Projections (no bias) --> separate for Q (seed vector), and KV ("pool" inputs) 25 | self.q, self.kv = nn.Linear(embed_dim, embed_dim, bias=False), nn.Linear(embed_dim, 2 * embed_dim, bias=False) 26 | self.proj = nn.Linear(embed_dim, embed_dim) 27 | 28 | def forward(self, seed: torch.Tensor, x: torch.Tensor, attention_mask = None) -> torch.Tensor: 29 | (B_s, K, C_s), (B_x, N, C_x) = seed.shape, x.shape 30 | assert C_s == C_x, "Seed vectors and pool inputs must have the same embedding dimensionality!" 31 | 32 | # Project Seed Vectors to `queries` 33 | q = self.q(seed).reshape(B_s, K, self.n_heads, C_s // self.n_heads).permute(0, 2, 1, 3) 34 | kv = self.kv(x).reshape(B_x, N, 2, self.n_heads, C_x // self.n_heads).permute(2, 0, 3, 1, 4) 35 | k, v = kv.unbind(0) 36 | 37 | # Attention --> compute weighted sum over values! 38 | scores = q @ (k.transpose(-2, -1) * self.scale) 39 | if attention_mask is not None: 40 | attention_mask = ( 41 | attention_mask[:, None, None, :].repeat(1, self.n_heads, 1, 1) #.flatten(0, 1) 42 | ) 43 | scores.masked_fill_(attention_mask == 0, float("-inf")) 44 | attn = scores.softmax(dim=-1) 45 | vals = (attn @ v).transpose(1, 2).reshape(B_s, K, C_s) 46 | 47 | # Project back to `embed_dim` 48 | return self.proj(vals) 49 | 50 | 51 | ### ======Token Aggregator===== ### 52 | class TokenAggregation(nn.Module): 53 | def __init__( 54 | self, 55 | n_latents: int, 56 | embed_dim: int, 57 | n_heads: int, 58 | mlp_ratio: float = 4.0, 59 | do_rms_norm: bool = True, 60 | do_swish_glu: bool = True, 61 | #add_internal_latents: bool = False, 62 | ) -> None: 63 | """Multiheaded Attention Pooling Block -- note that for MAP, we adopt earlier post-norm conventions.""" 64 | super().__init__() 65 | self.n_latents, self.embed_dim, self.n_heads = n_latents, embed_dim, 2 * n_heads 66 | 67 | # Projection Operator 68 | self.projection = nn.Linear(embed_dim, self.embed_dim) 69 | 70 | # Custom MAP Attention (seed, encoder outputs) -> seed 71 | self.attn_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 72 | self.attn = MAPAttention(self.embed_dim, n_heads=self.n_heads) 73 | 74 | # Position-wise Feed-Forward Components 75 | self.mlp_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 76 | self.mlp = nn.Sequential( 77 | # Handle SwishGLU vs. GELU MLP... 
78 | ( 79 | SwishGLU(self.embed_dim, int(mlp_ratio * self.embed_dim)) 80 | if do_swish_glu 81 | else nn.Sequential(nn.Linear(self.embed_dim, int(mlp_ratio * self.embed_dim)), nn.GELU()) 82 | ), 83 | nn.Linear(int(mlp_ratio * self.embed_dim), self.embed_dim), 84 | ) 85 | 86 | 87 | def forward(self, x: torch.Tensor, latents: torch.Tensor = None, mask = None) -> torch.Tensor: 88 | if len(latents.shape) == 2: 89 | latents = repeat(latents, "n_latents d -> bsz n_latents d", bsz=x.shape[0]) 90 | 91 | latents = self.attn_norm(latents + self.attn(latents, self.projection(x), mask)) 92 | latents = self.mlp_norm(latents + self.mlp(latents)) 93 | return latents.squeeze(dim=1) 94 | 95 | 96 | ### ======Feature Fusion===== ### 97 | class ConvFuser(nn.Module): 98 | def __init__(self, in_channels: int, out_channels: int, patch_num: int) -> None: 99 | super().__init__() 100 | self.in_channels = in_channels 101 | self.out_channels = out_channels 102 | self.patch_num = patch_num 103 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 104 | 105 | self.channel_selection = nn.Sequential( 106 | nn.Linear(out_channels, out_channels), 107 | nn.Sigmoid() 108 | ) 109 | 110 | def forward(self, inputs_rgb: torch.Tensor, inputs_depth: torch.Tensor) -> torch.Tensor: 111 | inputs = torch.cat([inputs_rgb, inputs_depth], dim=-1) 112 | inputs = rearrange(inputs, 'b (h w) d -> b d h w', h=self.patch_num, w=self.patch_num) 113 | feature = self.conv(inputs) 114 | feature = rearrange(feature, 'b d h w -> b (h w) d') 115 | selection_weights = self.channel_selection(feature.mean(dim=1)) 116 | 117 | # channel-wise multiply 118 | feature = feature * selection_weights.unsqueeze(1) 119 | 120 | return feature 121 | 122 | 123 | 124 | ### ======Action Decoder===== ### 125 | class MLPActionVelocityHead_Tanh(torch.nn.Module): 126 | def __init__(self, hidden_size): 127 | super().__init__() 128 | self.hidden_size = hidden_size 129 | # Create a linear layer for each action 130 | self.num_head = nn.Sequential( 131 | nn.Linear(hidden_size, 1024), 132 | nn.ReLU(), 133 | nn.Linear(1024, 512), 134 | nn.ReLU(), 135 | nn.Linear(512, 6), 136 | nn.Tanh(), 137 | ) 138 | 139 | def forward(self, x): 140 | x = self.num_head(x) 141 | return x 142 | 143 | 144 | class MLPActionGripperHead(torch.nn.Module): 145 | def __init__(self, hidden_size): 146 | super().__init__() 147 | self.hidden_size = hidden_size 148 | # Create a linear layer for each action 149 | 150 | self.bin_head = nn.Sequential( 151 | nn.Linear(hidden_size, 1024), 152 | nn.ReLU(), 153 | nn.Linear(1024, 512), 154 | nn.ReLU(), 155 | nn.Linear(512, 1), 156 | # nn.Sigmoid() 157 | ) 158 | 159 | def forward(self, x): 160 | x = self.bin_head(x) 161 | return x 162 | 163 | 164 | 165 | class FeedbackDrivenPolicy(nn.Module): 166 | def __init__(self, vision_encoder, vis_dim, window_size, sampling_step): 167 | super().__init__() 168 | 169 | self.vision_encoder = vision_encoder 170 | self.window_size = window_size // sampling_step 171 | 172 | self.action_embed_cur = nn.Parameter(torch.zeros(1, vis_dim), requires_grad=True) 173 | nn.init.normal_(self.action_embed_cur, std=0.02) 174 | self.action_embed_tgt = nn.Parameter(torch.zeros(1, vis_dim), requires_grad=True) 175 | nn.init.normal_(self.action_embed_tgt, std=0.02) 176 | 177 | # ViT-S as depth encoder 178 | depth_res = 168 179 | patch_size = 14 180 | depth_dim = 384 181 | depth_encoder_layers = 6 182 | 183 | self.depth_patch2embed = PatchEmbed( 184 | resolution = depth_res, patch_size=patch_size, embed_dim=depth_dim, in_channels=1 
185 | ) 186 | self.depth_encoder_pe = nn.Parameter( 187 | torch.zeros(1, self.depth_patch2embed.num_patches, depth_dim), 188 | requires_grad=False, 189 | ) 190 | enc_pe = get_2D_position_embeddings( 191 | depth_dim, int(self.depth_patch2embed.num_patches**0.5) 192 | ) 193 | self.depth_encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0)) 194 | 195 | self.depth_encoder_blocks = nn.ModuleList( 196 | [ 197 | Block( 198 | embed_dim = depth_dim, 199 | n_heads = 6, 200 | mlp_ratio = 4, 201 | do_rms_norm=True, 202 | do_swish_glu=True, 203 | do_layer_scale=True, 204 | ) 205 | for _ in range(depth_encoder_layers) 206 | ] 207 | ) 208 | 209 | # Multimodal feature fusion 210 | self.fuser = ConvFuser(in_channels=vis_dim + depth_dim, out_channels=vis_dim, patch_num = depth_res // patch_size) 211 | 212 | # Token aggregator 213 | self.token_aggregation = TokenAggregation(n_latents = 1, embed_dim = vis_dim, n_heads = 8) 214 | 215 | # Action decoder 216 | self.velo_head = MLPActionVelocityHead_Tanh(hidden_size = vis_dim) 217 | self.gripper_head = MLPActionGripperHead(hidden_size = vis_dim) 218 | 219 | 220 | def _encode_vision(self, vision_x: torch.Tensor): 221 | """ 222 | Encode RGB inputs with VC-1. 223 | Args: 224 | vision_x (torch.Tensor): Vision input 225 | """ 226 | b, T = vision_x.shape[:2] 227 | vision_x = rearrange(vision_x, "b T c h w -> (b T) c h w") 228 | 229 | with torch.no_grad(): 230 | vision_x = self.vision_encoder(vision_x) 231 | 232 | vision_x = rearrange(vision_x, "(b T) d h w -> b T (h w) d", b=b, T=T) 233 | return vision_x 234 | 235 | 236 | def _encode_vision_depth(self, vision_depth: torch.Tensor, state_tensor=None): 237 | """ 238 | Encode depth map with ViT-S. 239 | Args: 240 | vision_depth (torch.Tensor): Depth map input 241 | """ 242 | b, T = vision_depth.shape[:2] 243 | vision_depth = rearrange(vision_depth, "b T c h w -> (b T) c h w") 244 | patch_depth = self.depth_patch2embed(vision_depth) + self.depth_encoder_pe 245 | 246 | for block in self.depth_encoder_blocks: 247 | patch_depth = block(patch_depth) 248 | 249 | patch_depth = rearrange(patch_depth, "(b T) v d -> b T v d", b=b, T=T) 250 | return patch_depth 251 | 252 | 253 | def get_pred_features(self, vision_x, vision_depth): 254 | 255 | vision_rgb = self._encode_vision(vision_x) 256 | vision_depth = self._encode_vision_depth(vision_depth) 257 | 258 | 259 | fused_feature = [] 260 | for i in range(vision_depth.shape[1]): 261 | fused_feature.append(self.fuser(vision_rgb[:,i], vision_depth[:,i])) 262 | fused_feature = torch.stack(fused_feature, dim=1) 263 | 264 | aggregated_tgt = [] 265 | 266 | for i in range(fused_feature.shape[1]): 267 | aggregated_tgt.append(self.token_aggregation(fused_feature[:,i], self.action_embed_tgt)) 268 | 269 | return aggregated_tgt 270 | 271 | 272 | def forward(self, vision_x, vision_depth): 273 | 274 | 275 | ### Multimodal Encoder 276 | vision_rgb = self._encode_vision(vision_x) 277 | vision_depth = self._encode_vision_depth(vision_depth) 278 | 279 | fused_feature = [] 280 | for i in range(vision_rgb.shape[1]): 281 | fused_feature.append(self.fuser(vision_rgb[:,i], vision_depth[:,i])) 282 | fused_feature = torch.stack(fused_feature, dim=1) 283 | 284 | 285 | ### Token Aggregator 286 | aggregated_tgt = self.token_aggregation(fused_feature[:,-1], self.action_embed_tgt) 287 | 288 | stacked_velo_pred = [] 289 | stacked_grip_pred = [] 290 | state_estimation = [] 291 | is_close_pred = [] 292 | 293 | for i in range(fused_feature.shape[1] - 1): 294 | aggregated_cur = 
self.token_aggregation(fused_feature[:,i], self.action_embed_cur) 295 | 296 | ### Error measurement 297 | cos_distance = 1 - F.cosine_similarity(F.normalize(aggregated_tgt), F.normalize(aggregated_cur)) 298 | aggregated_tokens = aggregated_tgt - aggregated_cur 299 | 300 | ### Decode Actions 301 | velo_pred = self.velo_head(aggregated_tokens) 302 | grip_pred = self.gripper_head(aggregated_tokens) 303 | 304 | stacked_velo_pred.append(velo_pred) 305 | stacked_grip_pred.append(grip_pred) 306 | 307 | 308 | stacked_velo_pred = torch.stack(stacked_velo_pred, dim=1) 309 | stacked_grip_pred = torch.stack(stacked_grip_pred, dim=1) 310 | 311 | 312 | return stacked_velo_pred, stacked_grip_pred, cos_distance 313 | 314 | 315 | -------------------------------------------------------------------------------- /FeedbackPolicy/models/transformer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange 7 | from torch import einsum 8 | 9 | 10 | # === Position Encoding Utilities === 11 | 12 | 13 | # Helper/Utility Function -- computes simple 1D sinusoidal position embeddings for both 1D/2D use cases. 14 | # > We'll be combining two 1D sin-cos (traditional) position encodings for height/width of an image (grid features). 15 | def get_1D_sine_cosine(dim: int, pos: np.ndarray) -> np.ndarray: 16 | omega = np.arange(dim // 2, dtype=np.float32) / (dim / 2.0) 17 | omega = 1.0 / (10000**omega) 18 | out = np.einsum("m,d->md", pos.reshape(-1), omega) # [flatten(pos) x omega] -- outer product! 19 | emb_sin, emb_cos = np.sin(out), np.cos(out) 20 | return np.concatenate([emb_sin, emb_cos], axis=1) # [flatten(pos) x D] 21 | 22 | 23 | # 1D Sine-Cosine Position Embedding -- standard from "Attention is all you need!" 24 | def get_1D_position_embeddings(embed_dim: int, length: int) -> np.ndarray: 25 | return get_1D_sine_cosine(embed_dim, np.arange(length)) 26 | 27 | 28 | # 2D Sine-Cosine Position Embedding (from MAE repository) 29 | # > https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 30 | def get_2D_position_embeddings(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray: 31 | # Create 2D Position embeddings by taking cross product of height and width and splicing 1D embeddings... 32 | grid_h, grid_w = np.arange(grid_size, dtype=np.float32), np.arange(grid_size, dtype=np.float32) 33 | grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, grid_size, grid_size) # w goes first? 
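    # np.meshgrid defaults to 'xy' indexing, so grid[0] holds the w-coordinate and grid[1] the
    # h-coordinate at each (h, w) location; each half of embed_dim encodes one axis below.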
34 | 35 | # Use half of dimensions to encode grid_h, other half to encode grid_w 36 | emb_h, emb_w = get_1D_sine_cosine(embed_dim // 2, grid[0]), get_1D_sine_cosine(embed_dim // 2, grid[1]) 37 | pos_embed = np.concatenate([emb_h, emb_w], axis=1) 38 | 39 | # CLS token handling (only for R-MVP) 40 | if cls_token: 41 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 42 | 43 | return pos_embed 44 | 45 | 46 | # === Vision Transformer Building Blocks === # 47 | # Patch Embedding Module 48 | class PatchEmbed(nn.Module): 49 | def __init__( 50 | self, 51 | resolution: int, 52 | patch_size: int, 53 | embed_dim: int, 54 | in_channels: int = 3, 55 | flatten: bool = True, 56 | ): 57 | super().__init__() 58 | self.resolution, self.patch_size = (resolution, resolution), (patch_size, patch_size) 59 | self.grid_size = (self.resolution[0] // self.patch_size[0], self.resolution[1] // self.patch_size[1]) 60 | self.num_patches = self.grid_size[0] * self.grid_size[1] 61 | self.flatten = flatten 62 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size) 63 | 64 | def forward(self, patches: torch.Tensor) -> torch.Tensor: 65 | patch_embeddings = self.proj(patches) 66 | if self.flatten: 67 | return rearrange(patch_embeddings, "bsz embed patch_h patch_w -> bsz (patch_h patch_w) embed") 68 | return patch_embeddings 69 | 70 | 71 | # === Stability Utilities === 72 | 73 | 74 | # LayerScale -- Trainable scaling for residual blocks -- Mistral/CaIT 75 | class LayerScale(nn.Module): 76 | def __init__(self, dim: int, init_values: float = 0.1) -> None: # CaIT :: 0.1 -> lay 12, 1e-5 -> lay 24, 1e-6... 77 | super().__init__() 78 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | return x * self.gamma 82 | 83 | 84 | # RMSNorm -- Better, simpler alternative to LayerNorm 85 | class RMSNorm(nn.Module): 86 | def __init__(self, dim: int, eps: float = 1e-8) -> None: 87 | super().__init__() 88 | self.scale, self.eps = dim**-0.5, eps 89 | self.g = nn.Parameter(torch.ones(dim)) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 93 | return x / norm.clamp(min=self.eps) * self.g 94 | 95 | 96 | # SwishGLU -- A Gated Linear Unit (GLU) with the Swish activation; always better than GELU MLP! 97 | class SwishGLU(nn.Module): 98 | def __init__(self, in_dim: int, out_dim: int) -> None: 99 | super().__init__() 100 | self.act, self.project = nn.SiLU(), nn.Linear(in_dim, 2 * out_dim) 101 | 102 | def forward(self, x: torch.Tensor) -> torch.Tensor: 103 | projected, gate = self.project(x).tensor_split(2, dim=-1) 104 | return projected * self.act(gate) 105 | 106 | 107 | # === Fundamental Transformer Building Blocks === 108 | 109 | 110 | class Attention(nn.Module): 111 | def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.0) -> None: 112 | """Multi-Headed Self-Attention Operation""" 113 | super().__init__() 114 | assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!" 
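        # Per-head dimension is embed_dim // n_heads; the usual 1/sqrt(d) factor is applied to K
        # below (before the Q @ K^T product), which is equivalent to scaling the attention logits.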
115 | self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5 116 | self.attn_softmax = None 117 | 118 | # Projections 119 | self.qkv, self.proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True), nn.Linear(embed_dim, embed_dim) 120 | self.dropout = nn.Dropout(dropout) 121 | 122 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 123 | B, N, C = x.shape 124 | 125 | # Project to Q-K-V 126 | qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4) 127 | q, k, v = qkv.unbind(0) 128 | 129 | # Self-attention -- with masking! 130 | scores = q @ (k.transpose(-2, -1) * self.scale) 131 | if mask is not None: 132 | if mask.ndim == 2: 133 | mask = rearrange(mask, "bsz seq -> bsz 1 seq 1") 134 | elif mask.ndim != 4: 135 | raise NotImplementedError("Attention got `mask` of shape not in {2, 4}!") 136 | 137 | # Mask out by filling indices with negative infinity... 138 | scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min) 139 | 140 | # Compute weighted sum over values 141 | self.attn_softmax = scores.softmax(dim=-1) 142 | vals = (self.attn_softmax @ v).transpose(1, 2).reshape(B, N, C) 143 | 144 | # Project back to `embed_dim` -- with optional dropout 145 | vals = self.dropout(self.proj(vals)) 146 | return vals 147 | 148 | 149 | class Block(nn.Module): 150 | def __init__( 151 | self, 152 | embed_dim: int, 153 | n_heads: int, 154 | mlp_ratio: float = 4.0, 155 | dropout: float = 0.0, 156 | do_rms_norm: bool = False, 157 | do_swish_glu: bool = False, 158 | do_layer_scale: bool = False, 159 | ) -> None: 160 | """ 161 | Transformer Block Implementation (modality-agnostic). 162 | 163 | :param embed_dim: Core embedding/hidden dimension for vision transformer backbone. 164 | :param n_heads: Number of heads for multi-headed self-attention. 165 | :param mlp_ratio: Ratio for embedding size to position-wise feed-forward MLP (gets shrunk back down). 166 | :param dropout: [Optional] dropout for projection layer and MLPs -- for MAEs, always 0.0! 167 | :param do_rms_norm: Boolean whether or not to use RMSNorm in lieu of LayerNorm within block. 168 | :param do_swish_glu: Use the Swish-variant of the Gated Linear Unit for the feed-forward layers. 169 | :param do_layer_scale: Boolean whether or not to use LayerScale from Mistral/CaIT w/ initialization of 0.1. 170 | """ 171 | super().__init__() 172 | self.embed_dim, self.n_heads, self.do_layer_scale = embed_dim, n_heads, do_layer_scale 173 | 174 | # Attention Components 175 | self.pre_norm_attn = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 176 | self.attn = Attention(self.embed_dim, n_heads=n_heads, dropout=dropout) 177 | if do_layer_scale: 178 | self.layer_scale_attn = LayerScale(self.embed_dim) 179 | 180 | # Position-wise Feed-Forward Components 181 | self.pre_norm_mlp = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6) 182 | self.mlp = nn.Sequential( 183 | # Handle SwishGLU vs. GELU MLP... 
184 | ( 185 | SwishGLU(embed_dim, int(mlp_ratio * embed_dim)) 186 | if do_swish_glu 187 | else nn.Sequential(nn.Linear(embed_dim, int(mlp_ratio * embed_dim)), nn.GELU()) 188 | ), 189 | nn.Dropout(dropout), 190 | nn.Linear(int(mlp_ratio * embed_dim), embed_dim), 191 | ) 192 | if self.do_layer_scale: 193 | self.layer_scale_mlp = LayerScale(self.embed_dim) 194 | 195 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 196 | if self.do_layer_scale: 197 | x = x + self.layer_scale_attn(self.attn(self.pre_norm_attn(x), mask)) 198 | x = x + self.layer_scale_mlp(self.mlp(self.pre_norm_mlp(x))) 199 | else: 200 | x = x + self.attn(self.pre_norm_attn(x), mask) 201 | x = x + self.mlp(self.pre_norm_mlp(x)) 202 | return x 203 | 204 | 205 | 206 | 207 | -------------------------------------------------------------------------------- /FeedbackPolicy/models/vit.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | """Subclass torch's LayerNorm to handle fp16.""" 12 | 13 | def forward(self, x: torch.Tensor): 14 | orig_type = x.dtype 15 | ret = super().forward(x.type(torch.float32)) 16 | return ret.type(orig_type) 17 | 18 | 19 | class QuickGELU(nn.Module): 20 | def forward(self, x: torch.Tensor): 21 | return x * torch.sigmoid(1.702 * x) 22 | 23 | 24 | class ResidualAttentionBlock(nn.Module): 25 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 26 | super().__init__() 27 | 28 | self.attn = nn.MultiheadAttention(d_model, n_head) 29 | self.ln_1 = LayerNorm(d_model) 30 | self.mlp = nn.Sequential(OrderedDict([ 31 | ("c_fc", nn.Linear(d_model, d_model * 4)), 32 | ("gelu", QuickGELU()), 33 | ("c_proj", nn.Linear(d_model * 4, d_model)) 34 | ])) 35 | self.ln_2 = LayerNorm(d_model) 36 | self.attn_mask = attn_mask 37 | 38 | def attention(self, x: torch.Tensor): 39 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 40 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 41 | 42 | def forward(self, x: torch.Tensor): 43 | x = x + self.attention(self.ln_1(x)) 44 | x = x + self.mlp(self.ln_2(x)) 45 | return x 46 | 47 | 48 | class Transformer(nn.Module): 49 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 50 | super().__init__() 51 | self.width = width 52 | self.layers = layers 53 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 54 | 55 | def forward(self, x: torch.Tensor): 56 | return self.resblocks(x) 57 | 58 | 59 | class VisionTransformer(nn.Module): 60 | def __init__(self, input_resolution: int, patch_size: int, width: int=768, layers: int=12, heads: int=12, output_dim: int=1): 61 | super().__init__() 62 | self.input_resolution = input_resolution 63 | self.output_dim = output_dim 64 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 65 | 66 | scale = width ** -0.5 67 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 68 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 69 | self.ln_pre = LayerNorm(width) 70 | 71 | self.transformer = Transformer(width, layers, heads) 72 | 73 | 
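        # The CLIP-style projection head is left disabled below, so forward() returns the
        # post-LayerNorm patch tokens with the CLS token dropped, i.e. shape [*, grid ** 2, width].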
self.ln_post = LayerNorm(width) 74 | # self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 75 | 76 | def forward(self, x: torch.Tensor): 77 | x = self.conv1(x) # shape = [*, width, grid, grid] 78 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 79 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 80 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 81 | x = x + self.positional_embedding.to(x.dtype) 82 | x = self.ln_pre(x) 83 | 84 | x = x.permute(1, 0, 2) # NLD -> LND 85 | x = self.transformer(x) 86 | x = x.permute(1, 0, 2) # LND -> NLD 87 | 88 | x = self.ln_post(x[:, 1:, :]) 89 | 90 | # if self.proj is not None: 91 | # x = x @ self.proj 92 | 93 | return x -------------------------------------------------------------------------------- /FeedbackPolicy/train/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | 6 | import os 7 | import torch 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def is_global_master(args): 16 | return args.rank == 0 17 | 18 | 19 | def is_local_master(args): 20 | return args.local_rank == 0 21 | 22 | 23 | def is_master(args, local=False): 24 | return is_local_master(args) if local else is_global_master(args) 25 | 26 | 27 | def is_using_horovod(): 28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 32 | if all([var in os.environ for var in ompi_vars]) or all( 33 | [var in os.environ for var in pmi_vars] 34 | ): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def is_using_distributed(): 41 | if "WORLD_SIZE" in os.environ: 42 | return int(os.environ["WORLD_SIZE"]) > 1 43 | if "SLURM_NTASKS" in os.environ: 44 | return int(os.environ["SLURM_NTASKS"]) > 1 45 | return False 46 | 47 | 48 | def world_info_from_env(): 49 | local_rank = 0 50 | for v in ( 51 | "LOCAL_RANK", 52 | "MPI_LOCALRANKID", 53 | "SLURM_LOCALID", 54 | "OMPI_COMM_WORLD_LOCAL_RANK", 55 | ): 56 | if v in os.environ: 57 | local_rank = int(os.environ[v]) 58 | break 59 | global_rank = 0 60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 61 | if v in os.environ: 62 | global_rank = int(os.environ[v]) 63 | break 64 | world_size = 1 65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 66 | if v in os.environ: 67 | world_size = int(os.environ[v]) 68 | break 69 | 70 | return local_rank, global_rank, world_size 71 | 72 | 73 | def init_distributed_device(args): 74 | # Distributed training = training on more than one GPU. 75 | # Works in both single and multi-node scenarios. 
76 | args.distributed = False 77 | args.world_size = 1 78 | args.rank = 0 # global rank 79 | args.local_rank = 0 80 | if args.horovod: 81 | assert hvd is not None, "Horovod is not installed" 82 | hvd.init() 83 | args.local_rank = int(hvd.local_rank()) 84 | args.rank = hvd.rank() 85 | args.world_size = hvd.size() 86 | args.distributed = True 87 | os.environ["LOCAL_RANK"] = str(args.local_rank) 88 | os.environ["RANK"] = str(args.rank) 89 | os.environ["WORLD_SIZE"] = str(args.world_size) 90 | elif is_using_distributed(): 91 | if "SLURM_PROCID" in os.environ: 92 | # DDP via SLURM 93 | args.local_rank, args.rank, args.world_size = world_info_from_env() 94 | # SLURM var -> torch.distributed vars in case needed 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | os.environ["RANK"] = str(args.rank) 97 | os.environ["WORLD_SIZE"] = str(args.world_size) 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | world_size=args.world_size, 102 | rank=args.rank, 103 | ) 104 | else: 105 | # DDP via torchrun, torch.distributed.launch 106 | args.local_rank, _, _ = world_info_from_env() 107 | torch.distributed.init_process_group( 108 | backend=args.dist_backend, init_method=args.dist_url 109 | ) 110 | args.world_size = torch.distributed.get_world_size() 111 | args.rank = torch.distributed.get_rank() 112 | args.distributed = True 113 | else: 114 | # needed to run on single gpu 115 | torch.distributed.init_process_group( 116 | backend=args.dist_backend, 117 | init_method=args.dist_url, 118 | world_size=1, 119 | rank=0, 120 | ) 121 | 122 | if torch.cuda.is_available(): 123 | if args.distributed and not args.no_set_device_rank: 124 | device = "cuda:%d" % args.local_rank 125 | else: 126 | device = "cuda:0" 127 | torch.cuda.set_device(device) 128 | else: 129 | device = "cpu" 130 | args.device = device 131 | device = torch.device(device) 132 | return device 133 | -------------------------------------------------------------------------------- /FeedbackPolicy/train/train_calvin.py: -------------------------------------------------------------------------------- 1 | """ Main training script """ 2 | 3 | import argparse 4 | import copy 5 | import glob 6 | import os 7 | import random 8 | from collections import OrderedDict 9 | import numpy as np 10 | import torch 11 | import wandb 12 | from huggingface_hub import hf_hub_download 13 | 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | from FeedbackPolicy.data.data import get_data 17 | from FeedbackPolicy.train.distributed import init_distributed_device, world_info_from_env 18 | from train_utils import train_one_epoch_calvin, get_ckpt_name 19 | from torch.distributed.elastic.multiprocessing.errors import record 20 | from transformers import ( 21 | get_constant_schedule_with_warmup, 22 | get_cosine_schedule_with_warmup, 23 | get_linear_schedule_with_warmup, 24 | ) 25 | 26 | from FeedbackPolicy.models.factory import create_feedback_policy 27 | 28 | 29 | def random_seed(seed=42, rank=0): 30 | torch.manual_seed(seed + rank) 31 | np.random.seed(seed + rank) 32 | random.seed(seed + rank) 33 | 34 | def adjust_learning_rate(optimizer, epoch): 35 | lr = optimizer.param_groups[0]['lr'] * 0.1 36 | return lr 37 | 38 | 39 | @record 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--vision_encoder", default="vc1-base", type=str) 43 | parser.add_argument( 44 | "--run_name", 45 | type=str, 46 | default="RobotFlamingo", 47 | help="used to name saving directory and 
wandb run", 48 | ) 49 | parser.add_argument("--num_epochs", type=int, default=1) 50 | parser.add_argument("--window_size", type=int, default=5) 51 | parser.add_argument("--sampling_step", type=int, default=1) 52 | parser.add_argument( 53 | "--logging_steps", type=int, default=100, help="log loss every n steps" 54 | ) 55 | # Sum of gradient optimization batch size 56 | parser.add_argument("--batch_size_calvin", type=int, default=1) 57 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 58 | parser.add_argument( 59 | "--resume_from_checkpoint", 60 | type=str, 61 | help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", 62 | default="feedback_policy.pth", 63 | ) 64 | parser.add_argument("--seed", type=int, default=42) 65 | parser.add_argument("--learning_rate", default=1e-4, type=float) # 1e-4 66 | parser.add_argument( 67 | "--calvin_dataset", 68 | type=str, 69 | help="path to calvin_dataset", 70 | ) 71 | parser.add_argument("--warmup_steps", default=5000, type=int) 72 | parser.add_argument("--local-rank", default=0, type=int) 73 | parser.add_argument("--weight_decay", default=0.1, type=float) 74 | parser.add_argument( 75 | "--precision", 76 | choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], 77 | default="fp32", 78 | help="Floating point precision.", 79 | ) 80 | # data args 81 | parser.add_argument("--workers", type=int, default=1) 82 | parser.add_argument("--train_num_samples_calvin", type=int, default=100) 83 | parser.add_argument("--dataset_resampled", action="store_true") 84 | # distributed training args 85 | parser.add_argument( 86 | "--dist-url", 87 | default="env://", 88 | type=str, 89 | help="url used to set up distributed training", 90 | ) 91 | parser.add_argument( 92 | "--dist-backend", default="nccl", type=str, help="distributed backend" 93 | ) 94 | parser.add_argument( 95 | "--horovod", 96 | default=False, 97 | action="store_true", 98 | help="Use horovod for distributed training.", 99 | ) 100 | parser.add_argument( 101 | "--no-set-device-rank", 102 | default=False, 103 | action="store_true", 104 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", 105 | ) 106 | # wandb args 107 | parser.add_argument("--report_to_wandb", default=False, action="store_true") 108 | parser.add_argument( 109 | "--wandb_project", 110 | type=str, 111 | ) 112 | parser.add_argument( 113 | "--wandb_entity", 114 | type=str, 115 | ) 116 | 117 | args = parser.parse_args() 118 | 119 | args.local_rank, args.rank, args.world_size = world_info_from_env() 120 | 121 | device_id = init_distributed_device(args) 122 | print("device_id: ", device_id) 123 | 124 | 125 | # Prepare models 126 | model, image_processor, tokenizer = create_feedback_policy( 127 | args.vision_encoder, 128 | args.resume_from_checkpoint, 129 | ) 130 | 131 | 132 | calvin_dataset = get_data(args, image_processor, tokenizer, "calvin") 133 | random_seed(args.seed, args.rank) 134 | 135 | print(f"Start running training on rank {args.rank}.") 136 | 137 | if args.rank == 0 and args.report_to_wandb: 138 | wandb.init( 139 | project=args.wandb_project, 140 | entity=args.wandb_entity, 141 | name=args.run_name, 142 | config=vars(args), 143 | ) 144 | 145 | device_id = args.rank % torch.cuda.device_count() 146 | if args.precision == "bf16" or args.precision == "amp_bfloat16" or args.precision == "amp_bf16": 147 | model = model.bfloat16() 148 | elif args.precision == "fp16": 149 | model = model.half() 150 | else: 151 | model = 
model.float() 152 | 153 | 154 | model = model.to(device_id) 155 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=False) 156 | 157 | args.learning_rate = args.learning_rate 158 | optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=args.learning_rate) 159 | 160 | total_training_steps = calvin_dataset.dataloader.num_batches * args.num_epochs 161 | 162 | if args.rank == 0: 163 | print(f"Total training steps: {total_training_steps}") 164 | 165 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = int(total_training_steps * 0.7), gamma=0.1) 166 | 167 | 168 | for epoch in range(args.num_epochs): 169 | calvin_dataset.set_epoch(epoch) 170 | calvin_loader = calvin_dataset.dataloader 171 | 172 | train_one_epoch_calvin( 173 | args=args, 174 | model=ddp_model, 175 | epoch=epoch, 176 | tokenizer=tokenizer, 177 | optimizer=optimizer, 178 | lr_scheduler=lr_scheduler, 179 | calvin_loader=calvin_loader, 180 | device_id=device_id, 181 | wandb=wandb, 182 | window_size = args.window_size 183 | ) 184 | 185 | if args.rank == 0: 186 | # pass 187 | if not os.path.exists(args.run_name): 188 | os.makedirs(args.run_name) 189 | 190 | checkpoint_dict = { 191 | "epoch": epoch, 192 | "model_state_dict": get_checkpoint(ddp_model), 193 | "optimizer_state_dict": optimizer.state_dict(), 194 | "lr_scheduler_state_dict": lr_scheduler.state_dict(), 195 | } 196 | 197 | ckpt_name = get_ckpt_name(args, epoch) 198 | ckpt_path = os.path.join(args.run_name, ckpt_name) 199 | 200 | print(f"Saving checkpoint to {ckpt_path}") 201 | torch.save(checkpoint_dict, ckpt_path) 202 | if args.delete_previous_checkpoint: 203 | if epoch > 0: 204 | os.remove(ckpt_path) 205 | 206 | if args.rank == 0: 207 | if not os.path.exists(args.run_name): 208 | os.makedirs(args.run_name) 209 | 210 | ckpt_name = get_ckpt_name(args,) 211 | torch.save(get_checkpoint(ddp_model), f"{args.run_name}/{ckpt_name}") 212 | if args.report_to_wandb and args.save_checkpoints_to_wandb: 213 | wandb.save(f"{args.run_name}/{ckpt_name}") 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /FeedbackPolicy/train/train_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import suppress 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn.parallel import DistributedDataParallel 8 | from tqdm import tqdm 9 | import itertools 10 | from einops import rearrange 11 | 12 | 13 | def get_cast_dtype(precision: str): 14 | cast_dtype = None 15 | if precision == "bf16" or precision == "amp_bf16": 16 | cast_dtype = torch.bfloat16 17 | elif precision == "fp16": 18 | cast_dtype = torch.float16 19 | return cast_dtype 20 | 21 | 22 | def get_autocast(precision): 23 | if precision == "amp": 24 | return torch.cuda.amp.autocast 25 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 26 | # amp_bfloat16 is more stable than amp float16 for clip training 27 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 28 | else: 29 | return suppress 30 | 31 | 32 | def get_ckpt_name(args, epoch=-1): 33 | if epoch != -1: 34 | if epoch > 1000: 35 | ckpt_name += '{}_iter.pth'.format(epoch) 36 | else: 37 | ckpt_name += '{}.pth'.format(epoch) 38 | else: 39 | ckpt_name += 'final_weights.pth' 40 | return ckpt_name 41 | 42 | 43 | 44 | def train_one_epoch_calvin( 45 | args, 46 | model, 47 | epoch, 48 | calvin_loader, 49 | tokenizer, 50 | optimizer, 51 
| lr_scheduler, 52 | device_id, 53 | wandb, 54 | window_size 55 | ): 56 | 57 | num_batches_per_epoch_calvin = calvin_loader.num_batches 58 | 59 | num_batches_per_epoch = num_batches_per_epoch_calvin 60 | total_training_steps = num_batches_per_epoch * args.num_epochs 61 | 62 | autocast = get_autocast(args.precision) 63 | cast_dtype = get_cast_dtype(args.precision) 64 | 65 | media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 66 | endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[ 67 | "input_ids" 68 | ][-1] 69 | 70 | model.train() 71 | 72 | # setup logging 73 | step_time_m = ( 74 | AverageMeter() 75 | ) # time for one optimizer step (> 1 batch if using gradient accum) 76 | data_time_m = ( 77 | AverageMeter() 78 | ) # avg time to load one batch of both calvin (= 1 batch regardless of gradient accum) 79 | end = time.time() 80 | 81 | # loop through dataloader 82 | t = tqdm( 83 | enumerate(calvin_loader), 84 | disable=args.rank != 0, 85 | total=total_training_steps, 86 | initial=(epoch * num_batches_per_epoch), 87 | ) 88 | t.set_description(f"epoch {epoch+1}/{args.num_epochs}") 89 | mv_avg_loss = [] 90 | loss_record = [] 91 | for num_steps, batch_calvin in t: 92 | data_time_m.update(time.time() - end) 93 | global_step = num_steps + epoch * num_batches_per_epoch 94 | 95 | # put images and labels on device 96 | images = (batch_calvin[0].to(device_id, dtype=cast_dtype, non_blocking=True)) 97 | labels = batch_calvin[2].to(device_id, dtype=cast_dtype, non_blocking=True) 98 | gripper = (batch_calvin[3].to(device_id, dtype=cast_dtype, non_blocking=True)) 99 | depth_images = (batch_calvin[-2].to(device_id, dtype=cast_dtype, non_blocking=True)) 100 | 101 | # get and clip state tensor into 7-DoFs 102 | state_tensor = batch_calvin[4].to(device_id, dtype=cast_dtype, non_blocking=True) 103 | state_tensor = torch.cat([state_tensor[..., :6], state_tensor[..., [-1]]], dim=-1) 104 | state_tensor = state_tensor.unsqueeze(2).unsqueeze(2) 105 | 106 | labels = [labels[..., :6], (labels[..., 6:] + 1) // 2] 107 | 108 | 109 | # run model 110 | with autocast(): 111 | output = model( 112 | vision_x=images, 113 | vision_depth=depth_images, 114 | ) 115 | 116 | ### compute loss 117 | num_actions, bin_actions = output[0], output[1]#, output[2], output[3] 118 | 119 | velo_label = labels[0] 120 | grip_label = labels[1] 121 | 122 | loss_calvin_num = 0 123 | loss_calvin_bin = 0 124 | # communitive loss over time steps 125 | for i in range(velo_label.shape[1] - 1): 126 | loss_calvin_num += torch.nn.functional.huber_loss(num_actions[:, i], velo_label[:, i]) / (velo_label.shape[1] - 1) 127 | loss_calvin_bin += torch.nn.functional.binary_cross_entropy_with_logits(bin_actions[:, i], grip_label[:, i]) / (velo_label.shape[1] - 1) 128 | 129 | 130 | loss_calvin = loss_calvin_num + loss_calvin_bin * 0.1 131 | 132 | divided_loss_calvin = loss_calvin / args.gradient_accumulation_steps 133 | 134 | #### BACKWARD PASS #### 135 | loss = ( 136 | divided_loss_calvin * 1.0 137 | ) 138 | mv_avg_loss.append(loss.item()) 139 | loss_record.append(loss.item()) 140 | loss.backward() 141 | 142 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 143 | 144 | # step optimizer and log 145 | if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or ( 146 | num_steps == num_batches_per_epoch - 1 147 | ): 148 | optimizer.step() 149 | lr_scheduler.step() 150 | optimizer.zero_grad() 151 | 152 | # step time and reset end outside of rank 0 153 | step_time_m.update(time.time() - end) 154 | end = 
time.time() 155 | 156 | if args.rank == 0 and args.report_to_wandb: 157 | # compute within rank 0 158 | calvin_samples_per_second = ( 159 | args.gradient_accumulation_steps 160 | * args.batch_size_calvin 161 | * args.world_size 162 | / step_time_m.val 163 | ) 164 | calvin_samples_per_second_per_gpu = ( 165 | args.gradient_accumulation_steps 166 | * args.batch_size_calvin 167 | / step_time_m.val 168 | ) 169 | 170 | wandb.log( 171 | { 172 | "data_time": data_time_m.avg, 173 | "step_time": step_time_m.avg, 174 | "calvin_samples_per_second": calvin_samples_per_second, 175 | "calvin_samples_per_second_per_gpu": calvin_samples_per_second_per_gpu, 176 | "lr": optimizer.param_groups[0]["lr"], 177 | }, 178 | commit=False, 179 | ) 180 | step_time_m.reset() 181 | data_time_m.reset() 182 | 183 | wandb.log( 184 | { 185 | "loss_calvin": divided_loss_calvin.item(), 186 | "global_step": global_step, 187 | "loss_velo": loss_calvin_num.item(), 188 | "loss_grip": loss_calvin_bin.item(), 189 | }, 190 | commit=True, 191 | ) 192 | 193 | 194 | # Log loss to console 195 | if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0: 196 | print( 197 | f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: (all){loss_calvin.item():.3f} (mse){loss_calvin_num.item():.3f} " + \ 198 | f"(bce){loss_calvin_bin.item():.3f}" 199 | ) 200 | avg_horizon = min(100, len(mv_avg_loss)) 201 | t.set_postfix({"avg loss": sum(mv_avg_loss[-avg_horizon:]) / avg_horizon, "loss": loss_calvin.item(), "Lnum": loss_calvin_num.item(), "Lbin": loss_calvin_bin.item() }) 202 | 203 | 204 | with open(f'D_loss_log_{args.run_name}.txt', 'a', encoding='utf8') as f: 205 | f.write('Average Loss: '+ str(sum(loss_record) / len(loss_record)) + '\n') 206 | 207 | 208 | 209 | def get_checkpoint(model): 210 | state_dict = model.state_dict() 211 | 212 | for name, p in model.named_parameters(): 213 | if not p.requires_grad and 'normalizer' not in name: 214 | del state_dict[name] 215 | 216 | return state_dict 217 | 218 | 219 | class AverageMeter(object): 220 | """Computes and stores the average and current value""" 221 | 222 | def __init__(self): 223 | self.reset() 224 | 225 | def reset(self): 226 | self.val = 0 227 | self.avg = 0 228 | self.sum = 0 229 | self.count = 0 230 | 231 | def update(self, val, n=1): 232 | self.val = val 233 | self.sum += val * n 234 | self.count += n 235 | self.avg = self.sum / self.count 236 | -------------------------------------------------------------------------------- /FeedbackPolicy/train_calvin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # dataset path 3 | calvin_dataset_path='path_to_your/calvin/dataset/task_ABC_D' 4 | 5 | subfix=`date "+%Y%m%d-%H%M"` 6 | log_file="logs/training_"${subfix}".log" 7 | 8 | torchrun --nnodes=1 --nproc_per_node=8 train/train_calvin.py \ 9 | --vision_encoder vc1-base \ 10 | --num_epochs 10 \ 11 | --gradient_accumulation_steps 1 \ 12 | --batch_size_calvin 16 \ 13 | --run_name feedback_policy_calvin_abc \ 14 | --calvin_dataset ${calvin_dataset_path} \ 15 | --workers 4 \ 16 | --learning_rate 1e-4 \ 17 | --window_size 5 \ 18 | 2>&1 | tee ${log_file} 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. 
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# :four_leaf_clover: CLOVER

2 | 3 | The official implementation of our **NeurIPS 2024** paper: \ 4 | **Closed-Loop Visuomotor Control with Generative Expectation for Robotic Manipulation** 5 |
10 | 11 | > [Qingwen Bu](https://scholar.google.com/citations?user=-JCRysgAAAAJ&hl=zh-CN&oi=ao), [Jia Zeng](https://scholar.google.com/citations?hl=zh-CN&user=kYrUfMoAAAAJ), [Li Chen](https://scholar.google.com/citations?user=ulZxvY0AAAAJ&hl=zh-CN), Yanchao Yang, Guyue Zhou, Junchi Yan, Ping Luo, Heming Cui, Yi Ma and Hongyang Li 12 | 13 | > 📜 Preprint: [arXiv:2409.09016](https://arxiv.org/abs/2409.09016) :pushpin: Poster: [CLOVER_Poster-1.png](assets/CLOVER_Poster-1.png) 14 | 15 | > :mailbox_with_mail: If you have any questions, please feel free to contact: *Qingwen Bu* ( qwbu01@sjtu.edu.cn ) 16 | 17 | The full codebase and pretrained checkpoints have been released; see the News and Checkpoints sections below. 🦾 18 | 19 | ## :fire: Highlight 20 | 21 | * :four_leaf_clover: **CLOVER** employs a text-conditioned video diffusion model to generate visual plans as reference inputs; these sub-goals then guide the feedback-driven policy, which produces actions under an error-measurement strategy (an illustrative rollout sketch follows this list). 22 | 23 |
28 | 29 | * Owing to its closed-loop design, **CLOVER** is robust to visual distraction and object variation. 30 |
35 | 36 | * This closed-loop mechanism enables the policy to reach desired states accurately and reliably, thereby facilitating the execution of long-horizon tasks. 37 |
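The closed-loop behavior described above can be sketched in a few lines of Python. The snippet below is purely illustrative: `planner`, `policy`, `error_fn`, `env`, `error_threshold`, and `max_steps_per_goal` are placeholder names, not the actual CLOVER interfaces (the released evaluation loop lives under `FeedbackPolicy/eval/`).

```python
import numpy as np


def closed_loop_rollout(planner, policy, error_fn, env, instruction,
                        error_threshold=0.05, max_steps_per_goal=20):
    """Illustrative closed-loop control loop (placeholder API, not the released code)."""
    obs = env.reset()
    sub_goals = planner(obs, instruction)      # visual plan from the video diffusion model
    for goal in sub_goals:                     # track one generated sub-goal at a time
        for _ in range(max_steps_per_goal):
            if error_fn(obs, goal) < error_threshold:
                break                          # sub-goal reached, move on to the next one
            action = policy(obs, goal)         # feedback-driven policy (inverse dynamics)
            obs = env.step(action)
    return obs


if __name__ == "__main__":
    class DummyEnv:
        """Toy environment so the sketch runs end-to-end."""
        def reset(self):
            return np.zeros(2)

        def step(self, action):
            return action  # the "state" simply becomes the commanded target

    final_state = closed_loop_rollout(
        planner=lambda obs, text: [np.ones(2) * k for k in range(1, 4)],
        policy=lambda obs, goal: obs + 0.5 * (goal - obs),
        error_fn=lambda obs, goal: float(np.linalg.norm(obs - goal)),
        env=DummyEnv(),
        instruction="push the button",
    )
    print("final state:", final_state)
```

The key property is that the policy keeps acting on the *same* sub-goal until the measured error drops below a threshold, rather than switching sub-goals on a fixed schedule.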
42 | 43 | 44 | 45 | 46 | 47 | ## :loudspeaker: News 48 | 49 | - **[2024/09/16]** We released our paper on [arXiv](https://arxiv.org/abs/2409.09016). 50 | - **[2024/12/01]** We have open sourced the entire codebase and will keep it updated, please give it a try! 51 | 52 | ## :pushpin: TODO list 53 | 54 | - [x] Training script for visual planner 55 | - [x] Checkpoints release (*Scheduled Release Date*: **Mid-October, 2024**) 56 | - [x] Evaluation codes on CALVIN (*Scheduled Release Date*: **Mid-October, 2024**) 57 | - [x] Policy training codes on CALVIN (*Estimated Release Period*: **November, 2024**) 58 | 59 | 60 | 61 | ## :video_game: Getting started 62 | 63 | Our training are conducted with **PyTorch 1.13.1**, **CUDA 11.7**, **Ubuntu 22.04**, and **NVIDIA Tesla A100 (80 GB)**. The closed-loop evaluation on CALVIN is run on a system with **NVIDIA RTX 3090**. 64 | 65 | We did further testing with **PyTorch 2.2.0 + CUDA 11.8**, and the training also goes fine. 66 | 67 | 1. (Optional) We use conda to manage the environment. 68 | 69 | ```bash 70 | conda create -n clover python=3.8 71 | conda activate clover 72 | ``` 73 | 74 | 2. Install dependencies. 75 | 76 | ```bash 77 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 78 | pip install git+https://github.com/hassony2/torch_videovision 79 | pip install -e . 80 | ``` 81 | 82 | 3. Installation of CALVIN simulator. 83 | 84 | ```bash 85 | git clone --recurse-submodules https://github.com/mees/calvin.git 86 | export CALVIN_ROOT=$(pwd)/calvin 87 | cd $CALVIN_ROOT 88 | sh install.sh 89 | ``` 90 | 91 | ## :cd: Checkpoints 92 | 93 | We release model weights of our **Visual Planner** and **Feedback-driven Policy** at [HuggingFace](https://huggingface.co/qwbu/CLOVER). 94 | 95 | ## Training of Visual Planner 96 | 97 | - ### Requirement 98 | 99 | The visual planner requires **24 GB** GPU VRAM with a batch size of 4 (per GPU), video length of 8 and image size of 128. 100 | 101 | - ### Preparation 102 | 103 | * We use [OpenAI-CLIP](https://huggingface.co/openai/clip-vit-large-patch14) to encode task instructions for conditioning. 104 | 105 | - ### Initiate training of the visual planner (video diffusion model) on CALVIN 106 | 107 | > Please modify **accelerate_cfg.yaml** first according to your setup. 108 | 109 | ```bash 110 | accelerate launch --config_file accelerate_cfg.yaml train.py \ 111 | --learning_rate 1e-4 \ 112 | --train_num_steps 300000 \ 113 | --save_and_sample_every 10000 \ 114 | --train_batch_size 32 \ 115 | --sample_per_seq 8 \ 116 | --sampling_step 5 \ 117 | --with_text_conditioning \ 118 | --diffusion_steps 100 \ 119 | --sample_steps 10 \ 120 | --with_depth \ 121 | --flow_reg \ 122 | --results_folder *path_to_save_your_ckpts* 123 | ``` 124 | 125 | ## Training of Feedback Policy 126 | 127 | - ### Preparation 128 | 129 | * We only support VC-1 as visual encoder for now, please setup environments and download pre-trained checkpoints according to [eai-vc](https://github.com/facebookresearch/eai-vc) 130 | * Set your **calvin_dataset_path** in ```FeedbackPolicy/train_calvin.sh``` 131 | 132 | - ### Initiate training of the Feedback-driven Policy (Inverse Dynamics Model) on CALVIN 133 | ``` 134 | cd ./FeedbackPolicy 135 | bash train_calvin.sh 136 | ``` 137 | 138 | 139 | ## Evaluation 140 | 141 | - ### Preparation 142 | 143 | 1. Set your CALVIN and checkpoint path at *FeedbackPolicy/eval_calvin.sh* 144 | 2. 
We train our policy with input size of 192*192, please modify the config file correspondingly in [VC-1 Config](https://github.com/facebookresearch/eai-vc/blob/76fe35e87b1937168f1ec4b236e863451883eaf3/vc_models/src/vc_models/conf/model/vc1_vitb.yaml#L7) with `img_size: 192` and `use_cls: False`. 145 | 146 | - ### Initiate evaluation on CALVIN simply with 147 | 148 | ```bash 149 | cd ./FeedbackPolicy 150 | bash eval_calvin.sh 151 | ``` 152 | 153 | 154 | 155 | ## :pencil: Citation 156 | 157 | If you find the project helpful for your research, please consider citing our paper: 158 | 159 | ```bibtex 160 | @article{bu2024clover, 161 | title={Closed-Loop Visuomotor Control with Generative Expectation for Robotic Manipulation}, 162 | author={Bu, Qingwen and Zeng, Jia and Chen, Li and Yang, Yanchao and Zhou, Guyue and Yan, Junchi and Luo, Ping and Cui, Heming and Ma, Yi and Li, Hongyang}, 163 | journal={arXiv preprint arXiv:2409.09016}, 164 | year={2024} 165 | } 166 | ``` 167 | 168 | ## Acknowledgements 169 | 170 | We thank [AVDC](https://github.com/flow-diffusion/AVDC) and [RoboFlamingo](https://github.com/RoboFlamingo/RoboFlamingo) for their open-sourced work! 171 | 172 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """CLOVER: Closed-Loop Visuomotor Control with Generative Expectation for Robotic Manipulation 2 | """ 3 | 4 | __version__ = "0.0.1" 5 | __project__ = "CLOVER" 6 | __author__ = "Qingwen Bu" 7 | __license__ = "Apache License 2.0" 8 | __email__ = "qwbu01@sjtu.edu.cn" -------------------------------------------------------------------------------- /assets/CLOVER_Poster-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/CLOVER_Poster-1.png -------------------------------------------------------------------------------- /assets/closed-loop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/closed-loop.jpg -------------------------------------------------------------------------------- /assets/clover_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/clover_teaser.png -------------------------------------------------------------------------------- /assets/gen_diff_condition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/gen_diff_condition.png -------------------------------------------------------------------------------- /assets/long-horizon-task.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/long-horizon-task.gif -------------------------------------------------------------------------------- /assets/vis_robustness.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/assets/vis_robustness.jpg 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface-hub 2 | diffusers 3 | matplotlib==3.7.5 4 | einops==0.7.0 5 | einops-exts==0.0.4 6 | ema-pytorch==0.2.3 7 | tqdm==4.66.1 8 | accelerate==0.23.0 9 | transformers==4.34.0 10 | pytorch-fid==0.3.0 11 | pynvml==11.5.0 12 | tensorboard 13 | imageio[ffmpeg] 14 | wandb==0.16.6 15 | opencv-python==4.9.0.80 16 | pytorch-lightning==1.8.6 17 | peft==0.6.2 18 | rotary-embedding-torch==0.5.3 19 | git+https://github.com/hassony2/torch_videovision 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Setup CLOVER installation.""" 4 | 5 | from os import path as op 6 | import re 7 | 8 | from setuptools import find_packages, setup 9 | 10 | 11 | def _read(f): 12 | return open(op.join(op.dirname(__file__), f)).read() if op.exists(f) else "" 13 | 14 | 15 | _meta = _read("__init__.py") 16 | 17 | 18 | def find_meta(_meta, string): 19 | l_match = re.search(r"^" + string + r'\s*=\s*"(.*)"', _meta, re.M) 20 | if l_match: 21 | return l_match.group(1) 22 | raise RuntimeError(f"Unable to find {string} string.") 23 | 24 | 25 | install_requires = [ 26 | l for l in _read("requirements.txt").split("\n") if l and not l.startswith("#") and not l.startswith("-") 27 | ] 28 | 29 | meta = dict( 30 | name=find_meta(_meta, "__project__"), 31 | version=find_meta(_meta, "__version__"), 32 | license=find_meta(_meta, "__license__"), 33 | description="CLOVER: Closed-Loop Visuomotor Control with Generative Expectation for Robotic Manipulation", 34 | platforms=("Any"), 35 | zip_safe=False, 36 | author=find_meta(_meta, "__author__"), 37 | author_email=find_meta(_meta, "__email__"), 38 | url="https://github.com/OpenDriveLab/CLOVER", 39 | packages=find_packages(exclude=["tests"]), 40 | install_requires=install_requires, 41 | ) 42 | 43 | if __name__ == "__main__": 44 | print("find_package", find_packages(exclude=["tests"])) 45 | setup(**meta) -------------------------------------------------------------------------------- /visual_planner/accelerate_cfg.yaml: -------------------------------------------------------------------------------- 1 | main_process_port: 29502 2 | distributed_type: MULTI_GPU 3 | num_processes: 8 -------------------------------------------------------------------------------- /visual_planner/diffusion_model/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import gaussian_diffusion as gd 2 | from .respace import SpacedDiffusion, space_timesteps 3 | 4 | 5 | def create_diffusion( 6 | timestep_respacing, 7 | noise_schedule="linear", 8 | use_kl=False, 9 | sigma_small=False, 10 | predict_xstart=False, 11 | learn_sigma=True, 12 | # learn_sigma=False, 13 | rescale_learned_sigmas=False, 14 | diffusion_steps=100 15 | ): 16 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 17 | if use_kl: 18 | loss_type = gd.LossType.RESCALED_KL 19 | elif rescale_learned_sigmas: 20 | loss_type = gd.LossType.RESCALED_MSE 21 | else: 22 | loss_type = gd.LossType.MSE 23 | if timestep_respacing is None or timestep_respacing == "": 24 | timestep_respacing = [diffusion_steps] 25 | return SpacedDiffusion( 26 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 27 | betas=betas, 28 | model_mean_type=( 29 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 30 | ), 31 | model_var_type=( 32 | ( 33 | gd.ModelVarType.FIXED_LARGE 34 | if not sigma_small 35 | else gd.ModelVarType.FIXED_SMALL 36 | ) 37 | if not learn_sigma 38 | else gd.ModelVarType.LEARNED_RANGE 39 | ), 40 | loss_type=loss_type 41 | # rescale_timesteps=rescale_timesteps, 42 | ) -------------------------------------------------------------------------------- /visual_planner/diffusion_model/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def get_1d_sincos_temp_embed(embed_dim, length): 11 | pos = th.arange(0, length).unsqueeze(1) 12 | return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) 13 | 14 | 15 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 16 | """ 17 | embed_dim: output dimension for each position 18 | pos: a list of positions to be encoded: size (M,) 19 | out: (M, D) 20 | """ 21 | assert embed_dim % 2 == 0 22 | omega = np.arange(embed_dim // 2, dtype=np.float64) 23 | omega /= embed_dim / 2. 24 | omega = 1. / 10000**omega 25 | 26 | pos = pos.reshape(-1) 27 | out = np.einsum('m,d->md', pos, omega) 28 | 29 | emb_sin = np.sin(out) 30 | emb_cos = np.cos(out) 31 | 32 | emb = np.concatenate([emb_sin, emb_cos], axis=1) 33 | return emb 34 | 35 | 36 | def normal_kl(mean1, logvar1, mean2, logvar2): 37 | """ 38 | Compute the KL divergence between two gaussians. 39 | Shapes are automatically broadcasted, so batches can be compared to 40 | scalars, among other use cases. 41 | """ 42 | tensor = None 43 | for obj in (mean1, logvar1, mean2, logvar2): 44 | if isinstance(obj, th.Tensor): 45 | tensor = obj 46 | break 47 | assert tensor is not None, "at least one argument must be a Tensor" 48 | 49 | # Force variances to be Tensors. Broadcasting helps convert scalars to 50 | # Tensors, but it does not work for th.exp(). 
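    # Convert any scalar log-variances to tensors that share the reference tensor's device and dtype.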
51 | logvar1, logvar2 = [ 52 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 53 | for x in (logvar1, logvar2) 54 | ] 55 | 56 | return 0.5 * ( 57 | -1.0 58 | + logvar2 59 | - logvar1 60 | + th.exp(logvar1 - logvar2) 61 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 62 | ) 63 | 64 | 65 | def approx_standard_normal_cdf(x): 66 | """ 67 | A fast approximation of the cumulative distribution function of the 68 | standard normal. 69 | """ 70 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 71 | 72 | 73 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 74 | """ 75 | Compute the log-likelihood of a continuous Gaussian distribution. 76 | :param x: the targets 77 | :param means: the Gaussian mean Tensor. 78 | :param log_scales: the Gaussian log stddev Tensor. 79 | :return: a tensor like x of log probabilities (in nats). 80 | """ 81 | centered_x = x - means 82 | inv_stdv = th.exp(-log_scales) 83 | normalized_x = centered_x * inv_stdv 84 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 85 | return log_probs 86 | 87 | 88 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 89 | """ 90 | Compute the log-likelihood of a Gaussian distribution discretizing to a 91 | given image. 92 | :param x: the target images. It is assumed that this was uint8 values, 93 | rescaled to the range [-1, 1]. 94 | :param means: the Gaussian mean Tensor. 95 | :param log_scales: the Gaussian log stddev Tensor. 96 | :return: a tensor like x of log probabilities (in nats). 97 | """ 98 | assert x.shape == means.shape == log_scales.shape 99 | centered_x = x - means 100 | inv_stdv = th.exp(-log_scales) 101 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 102 | cdf_plus = approx_standard_normal_cdf(plus_in) 103 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 104 | cdf_min = approx_standard_normal_cdf(min_in) 105 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 106 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 107 | cdf_delta = cdf_plus - cdf_min 108 | log_probs = th.where( 109 | x < -0.999, 110 | log_cdf_plus, 111 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 112 | ) 113 | assert log_probs.shape == x.shape 114 | return log_probs -------------------------------------------------------------------------------- /visual_planner/diffusion_model/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 
24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 
39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if 
param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/image_datasets.py: -------------------------------------------------------------------------------- 1 | import 
math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 
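    # Repeatedly halve with BOX resampling while the smaller side is at least twice the target size;
    # the BICUBIC resize below then lands on the exact scale.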
154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
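        # TensorBoard orders scalar points by this step value, so a monotonically increasing counter is sufficient.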
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
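    # No MPI-style rank variables found: assume a single-process run and report rank 0.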
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | from einops import rearrange 10 | from torch.nn import functional as F 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | if val is not None: 17 | return val 18 | return d() if callable(d) else d 19 | 20 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
21 | class SiLU(nn.Module): 22 | def forward(self, x): 23 | return x * th.sigmoid(x) 24 | 25 | 26 | class GroupNorm32(nn.GroupNorm): 27 | def forward(self, x): 28 | return super().forward(x.float()).type(x.dtype) 29 | 30 | class Conv3d(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | dim_out = None, 35 | kernel_size = 3, 36 | stride = [1, 1, 1], 37 | *, 38 | temporal_kernel_size = None, 39 | **kwargs 40 | ): 41 | super().__init__() 42 | dim_out = default(dim_out, dim) 43 | temporal_kernel_size = default(temporal_kernel_size, kernel_size) 44 | 45 | self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2, stride = stride[1:]) 46 | self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None 47 | self.kernel_size = kernel_size 48 | 49 | if exists(self.temporal_conv): 50 | nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity 51 | nn.init.zeros_(self.temporal_conv.bias.data) 52 | 53 | def forward( 54 | self, 55 | x, 56 | ignore_time = False 57 | ): 58 | b = x.shape[0] 59 | 60 | is_video = x.ndim == 5 61 | ignore_time &= is_video 62 | 63 | if is_video: 64 | x = rearrange(x, 'b c f h w -> (b f) c h w') 65 | 66 | x = self.spatial_conv(x) 67 | 68 | if is_video: 69 | x = rearrange(x, '(b f) c h w -> b c f h w', b = b) 70 | 71 | if ignore_time or not exists(self.temporal_conv): 72 | return x 73 | 74 | h, w = x.shape[-2:] 75 | 76 | x = rearrange(x, 'b c f h w -> (b h w) c f') 77 | 78 | # causal temporal convolution - time is causal in imagen-video 79 | 80 | if self.kernel_size > 1: 81 | x = F.pad(x, (self.kernel_size//2, self.kernel_size//2)) 82 | 83 | x = self.temporal_conv(x) 84 | 85 | x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w) 86 | 87 | return x 88 | 89 | 90 | def conv_nd(dims, *args, **kwargs): 91 | """ 92 | Create a 1D, 2D, or 3D convolution module. 93 | """ 94 | if dims == 1: 95 | return nn.Conv1d(*args, **kwargs) 96 | elif dims == 2: 97 | return nn.Conv2d(*args, **kwargs) 98 | elif dims == 3: 99 | return Conv3d(*args, **kwargs) 100 | raise ValueError(f"unsupported dimensions: {dims}") 101 | 102 | 103 | def linear(*args, **kwargs): 104 | """ 105 | Create a linear module. 106 | """ 107 | return nn.Linear(*args, **kwargs) 108 | 109 | 110 | def avg_pool_nd(dims, *args, **kwargs): 111 | """ 112 | Create a 1D, 2D, or 3D average pooling module. 113 | """ 114 | if dims == 1: 115 | return nn.AvgPool1d(*args, **kwargs) 116 | elif dims == 2: 117 | return nn.AvgPool2d(*args, **kwargs) 118 | elif dims == 3: 119 | return nn.AvgPool3d(*args, **kwargs) 120 | raise ValueError(f"unsupported dimensions: {dims}") 121 | 122 | 123 | def update_ema(target_params, source_params, rate=0.99): 124 | """ 125 | Update target parameters to be closer to those of source parameters using 126 | an exponential moving average. 127 | 128 | :param target_params: the target parameter sequence. 129 | :param source_params: the source parameter sequence. 130 | :param rate: the EMA rate (closer to 1 means slower). 131 | """ 132 | for targ, src in zip(target_params, source_params): 133 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 134 | 135 | 136 | def zero_module(module): 137 | """ 138 | Zero out the parameters of a module and return it. 139 | """ 140 | for p in module.parameters(): 141 | p.detach().zero_() 142 | return module 143 | 144 | 145 | def scale_module(module, scale): 146 | """ 147 | Scale the parameters of a module and return it. 
148 | """ 149 | for p in module.parameters(): 150 | p.detach().mul_(scale) 151 | return module 152 | 153 | 154 | def mean_flat(tensor): 155 | """ 156 | Take the mean over all non-batch dimensions. 157 | """ 158 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 159 | 160 | 161 | def normalization(channels): 162 | """ 163 | Make a standard normalization layer. 164 | 165 | :param channels: number of input channels. 166 | :return: an nn.Module for normalization. 167 | """ 168 | return GroupNorm32(32, channels) 169 | 170 | 171 | def timestep_embedding(timesteps, dim, max_period=10000): 172 | """ 173 | Create sinusoidal timestep embeddings. 174 | 175 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 176 | These may be fractional. 177 | :param dim: the dimension of the output. 178 | :param max_period: controls the minimum frequency of the embeddings. 179 | :return: an [N x dim] Tensor of positional embeddings. 180 | """ 181 | half = dim // 2 182 | freqs = th.exp( 183 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 184 | ).to(device=timesteps.device) 185 | args = timesteps[:, None].float() * freqs[None] 186 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 187 | if dim % 2: 188 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 189 | return embedding 190 | 191 | 192 | def checkpoint(func, inputs, params, flag): 193 | """ 194 | Evaluate a function without caching intermediate activations, allowing for 195 | reduced memory at the expense of extra compute in the backward pass. 196 | 197 | :param func: the function to evaluate. 198 | :param inputs: the argument sequence to pass to `func`. 199 | :param params: a sequence of parameters `func` depends on but does not 200 | explicitly take as arguments. 201 | :param flag: if False, disable gradient checkpointing. 202 | """ 203 | if flag: 204 | args = tuple(inputs) + tuple(params) 205 | return CheckpointFunction.apply(func, len(inputs), *args) 206 | else: 207 | return func(*inputs) 208 | 209 | 210 | class CheckpointFunction(th.autograd.Function): 211 | @staticmethod 212 | def forward(ctx, run_function, length, *args): 213 | ctx.run_function = run_function 214 | ctx.input_tensors = list(args[:length]) 215 | ctx.input_params = list(args[length:]) 216 | with th.no_grad(): 217 | output_tensors = ctx.run_function(*ctx.input_tensors) 218 | return output_tensors 219 | 220 | @staticmethod 221 | def backward(ctx, *output_grads): 222 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 223 | with th.enable_grad(): 224 | # Fixes a bug where the first op in run_function modifies the 225 | # Tensor storage in place, which is not allowed for detach()'d 226 | # Tensors. 
227 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 228 | output_tensors = ctx.run_function(*shallow_copies) 229 | input_grads = th.autograd.grad( 230 | output_tensors, 231 | ctx.input_tensors + ctx.input_params, 232 | output_grads, 233 | allow_unused=True, 234 | ) 235 | del ctx.input_tensors 236 | del ctx.input_params 237 | del output_tensors 238 | return (None, None) + input_grads 239 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 
82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | import torch 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 
70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | # @torch.compile 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def condition_mean(self, cond_fn, *args, **kwargs): 101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def condition_score(self, cond_fn, *args, **kwargs): 104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def _wrap_model(self, model): 107 | if isinstance(model, _WrappedModel): 108 | return model 109 | return _WrappedModel( 110 | model, self.timestep_map, self.original_num_steps 111 | ) 112 | 113 | def _scale_timesteps(self, t): 114 | # Scaling is done by the wrapped model. 115 | return t 116 | 117 | 118 | class _WrappedModel: 119 | def __init__(self, model, timestep_map, original_num_steps): 120 | self.model = model 121 | self.timestep_map = timestep_map 122 | # self.rescale_timesteps = rescale_timesteps 123 | self.original_num_steps = original_num_steps 124 | 125 | def __call__(self, x, ts, **kwargs): 126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 127 | new_ts = map_tensor[ts] 128 | # if self.rescale_timesteps: 129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 130 | return self.model(x, new_ts, **kwargs) -------------------------------------------------------------------------------- /visual_planner/diffusion_model/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 
30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | ): 99 | model = create_model( 100 | image_size, 101 | num_channels, 102 | num_res_blocks, 103 | channel_mult=channel_mult, 104 | learn_sigma=learn_sigma, 105 | class_cond=class_cond, 106 | use_checkpoint=use_checkpoint, 107 | attention_resolutions=attention_resolutions, 108 | num_heads=num_heads, 109 | num_head_channels=num_head_channels, 110 | num_heads_upsample=num_heads_upsample, 111 | use_scale_shift_norm=use_scale_shift_norm, 112 | dropout=dropout, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | use_new_attention_order=use_new_attention_order, 116 | ) 117 | diffusion = create_gaussian_diffusion( 118 | steps=diffusion_steps, 119 | learn_sigma=learn_sigma, 120 | noise_schedule=noise_schedule, 121 | use_kl=use_kl, 122 | predict_xstart=predict_xstart, 123 | rescale_timesteps=rescale_timesteps, 124 | rescale_learned_sigmas=rescale_learned_sigmas, 125 | timestep_respacing=timestep_respacing, 126 | ) 127 | return model, diffusion 128 | 129 | 130 | def create_model( 131 | image_size, 132 | num_channels, 133 | num_res_blocks, 134 | channel_mult="", 135 | learn_sigma=False, 136 | class_cond=False, 137 | use_checkpoint=False, 138 | attention_resolutions="16", 139 | num_heads=1, 140 | num_head_channels=-1, 141 | num_heads_upsample=-1, 142 | use_scale_shift_norm=False, 143 | dropout=0, 144 | resblock_updown=False, 145 | use_fp16=False, 146 | use_new_attention_order=False, 147 | ): 148 | if channel_mult == "": 149 | if image_size == 512: 150 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 151 | elif image_size == 256: 152 | channel_mult = (1, 1, 2, 2, 4, 4) 153 | elif image_size == 128: 154 | channel_mult = (1, 1, 2, 3, 4) 155 | elif image_size == 64: 156 | channel_mult = (1, 2, 3, 4) 157 | else: 158 | raise ValueError(f"unsupported image size: {image_size}") 159 | else: 160 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 161 | 162 | attention_ds = [] 163 | for res in attention_resolutions.split(","): 164 | 
attention_ds.append(image_size // int(res)) 165 | 166 | return UNetModel( 167 | image_size=image_size, 168 | in_channels=3, 169 | model_channels=num_channels, 170 | out_channels=(3 if not learn_sigma else 6), 171 | num_res_blocks=num_res_blocks, 172 | attention_resolutions=tuple(attention_ds), 173 | dropout=dropout, 174 | channel_mult=channel_mult, 175 | num_classes=(NUM_CLASSES if class_cond else None), 176 | use_checkpoint=use_checkpoint, 177 | use_fp16=use_fp16, 178 | num_heads=num_heads, 179 | num_head_channels=num_head_channels, 180 | num_heads_upsample=num_heads_upsample, 181 | use_scale_shift_norm=use_scale_shift_norm, 182 | resblock_updown=resblock_updown, 183 | use_new_attention_order=use_new_attention_order, 184 | ) 185 | 186 | 187 | def create_classifier_and_diffusion( 188 | image_size, 189 | classifier_use_fp16, 190 | classifier_width, 191 | classifier_depth, 192 | classifier_attention_resolutions, 193 | classifier_use_scale_shift_norm, 194 | classifier_resblock_updown, 195 | classifier_pool, 196 | learn_sigma, 197 | diffusion_steps, 198 | noise_schedule, 199 | timestep_respacing, 200 | use_kl, 201 | predict_xstart, 202 | rescale_timesteps, 203 | rescale_learned_sigmas, 204 | ): 205 | classifier = create_classifier( 206 | image_size, 207 | classifier_use_fp16, 208 | classifier_width, 209 | classifier_depth, 210 | classifier_attention_resolutions, 211 | classifier_use_scale_shift_norm, 212 | classifier_resblock_updown, 213 | classifier_pool, 214 | ) 215 | diffusion = create_gaussian_diffusion( 216 | steps=diffusion_steps, 217 | learn_sigma=learn_sigma, 218 | noise_schedule=noise_schedule, 219 | use_kl=use_kl, 220 | predict_xstart=predict_xstart, 221 | rescale_timesteps=rescale_timesteps, 222 | rescale_learned_sigmas=rescale_learned_sigmas, 223 | timestep_respacing=timestep_respacing, 224 | ) 225 | return classifier, diffusion 226 | 227 | 228 | def create_classifier( 229 | image_size, 230 | classifier_use_fp16, 231 | classifier_width, 232 | classifier_depth, 233 | classifier_attention_resolutions, 234 | classifier_use_scale_shift_norm, 235 | classifier_resblock_updown, 236 | classifier_pool, 237 | ): 238 | if image_size == 512: 239 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 240 | elif image_size == 256: 241 | channel_mult = (1, 1, 2, 2, 4, 4) 242 | elif image_size == 128: 243 | channel_mult = (1, 1, 2, 3, 4) 244 | elif image_size == 64: 245 | channel_mult = (1, 2, 3, 4) 246 | else: 247 | raise ValueError(f"unsupported image size: {image_size}") 248 | 249 | attention_ds = [] 250 | for res in classifier_attention_resolutions.split(","): 251 | attention_ds.append(image_size // int(res)) 252 | 253 | return EncoderUNetModel( 254 | image_size=image_size, 255 | in_channels=3, 256 | model_channels=classifier_width, 257 | out_channels=1000, 258 | num_res_blocks=classifier_depth, 259 | attention_resolutions=tuple(attention_ds), 260 | channel_mult=channel_mult, 261 | use_fp16=classifier_use_fp16, 262 | num_head_channels=64, 263 | use_scale_shift_norm=classifier_use_scale_shift_norm, 264 | resblock_updown=classifier_resblock_updown, 265 | pool=classifier_pool, 266 | ) 267 | 268 | 269 | def sr_model_and_diffusion_defaults(): 270 | res = model_and_diffusion_defaults() 271 | res["large_size"] = 256 272 | res["small_size"] = 64 273 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 274 | for k in res.copy().keys(): 275 | if k not in arg_names: 276 | del res[k] 277 | return res 278 | 279 | 280 | def sr_create_model_and_diffusion( 281 | large_size, 282 | small_size, 283 | 
class_cond, 284 | learn_sigma, 285 | num_channels, 286 | num_res_blocks, 287 | num_heads, 288 | num_head_channels, 289 | num_heads_upsample, 290 | attention_resolutions, 291 | dropout, 292 | diffusion_steps, 293 | noise_schedule, 294 | timestep_respacing, 295 | use_kl, 296 | predict_xstart, 297 | rescale_timesteps, 298 | rescale_learned_sigmas, 299 | use_checkpoint, 300 | use_scale_shift_norm, 301 | resblock_updown, 302 | use_fp16, 303 | ): 304 | model = sr_create_model( 305 | large_size, 306 | small_size, 307 | num_channels, 308 | num_res_blocks, 309 | learn_sigma=learn_sigma, 310 | class_cond=class_cond, 311 | use_checkpoint=use_checkpoint, 312 | attention_resolutions=attention_resolutions, 313 | num_heads=num_heads, 314 | num_head_channels=num_head_channels, 315 | num_heads_upsample=num_heads_upsample, 316 | use_scale_shift_norm=use_scale_shift_norm, 317 | dropout=dropout, 318 | resblock_updown=resblock_updown, 319 | use_fp16=use_fp16, 320 | ) 321 | diffusion = create_gaussian_diffusion( 322 | steps=diffusion_steps, 323 | learn_sigma=learn_sigma, 324 | noise_schedule=noise_schedule, 325 | use_kl=use_kl, 326 | predict_xstart=predict_xstart, 327 | rescale_timesteps=rescale_timesteps, 328 | rescale_learned_sigmas=rescale_learned_sigmas, 329 | timestep_respacing=timestep_respacing, 330 | ) 331 | return model, diffusion 332 | 333 | 334 | def sr_create_model( 335 | large_size, 336 | small_size, 337 | num_channels, 338 | num_res_blocks, 339 | learn_sigma, 340 | class_cond, 341 | use_checkpoint, 342 | attention_resolutions, 343 | num_heads, 344 | num_head_channels, 345 | num_heads_upsample, 346 | use_scale_shift_norm, 347 | dropout, 348 | resblock_updown, 349 | use_fp16, 350 | ): 351 | _ = small_size # hack to prevent unused variable 352 | 353 | if large_size == 512: 354 | channel_mult = (1, 1, 2, 2, 4, 4) 355 | elif large_size == 256: 356 | channel_mult = (1, 1, 2, 2, 4, 4) 357 | elif large_size == 64: 358 | channel_mult = (1, 2, 3, 4) 359 | else: 360 | raise ValueError(f"unsupported large size: {large_size}") 361 | 362 | attention_ds = [] 363 | for res in attention_resolutions.split(","): 364 | attention_ds.append(large_size // int(res)) 365 | 366 | return SuperResModel( 367 | image_size=large_size, 368 | in_channels=3, 369 | model_channels=num_channels, 370 | out_channels=(3 if not learn_sigma else 6), 371 | num_res_blocks=num_res_blocks, 372 | attention_resolutions=tuple(attention_ds), 373 | dropout=dropout, 374 | channel_mult=channel_mult, 375 | num_classes=(NUM_CLASSES if class_cond else None), 376 | use_checkpoint=use_checkpoint, 377 | num_heads=num_heads, 378 | num_head_channels=num_head_channels, 379 | num_heads_upsample=num_heads_upsample, 380 | use_scale_shift_norm=use_scale_shift_norm, 381 | resblock_updown=resblock_updown, 382 | use_fp16=use_fp16, 383 | ) 384 | 385 | 386 | def create_gaussian_diffusion( 387 | *, 388 | steps=1000, 389 | learn_sigma=False, 390 | sigma_small=False, 391 | noise_schedule="linear", 392 | use_kl=False, 393 | predict_xstart=False, 394 | rescale_timesteps=False, 395 | rescale_learned_sigmas=False, 396 | timestep_respacing="", 397 | ): 398 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 399 | if use_kl: 400 | loss_type = gd.LossType.RESCALED_KL 401 | elif rescale_learned_sigmas: 402 | loss_type = gd.LossType.RESCALED_MSE 403 | else: 404 | loss_type = gd.LossType.MSE 405 | if not timestep_respacing: 406 | timestep_respacing = [steps] 407 | return SpacedDiffusion( 408 | use_timesteps=space_timesteps(steps, timestep_respacing), 409 | betas=betas, 
410 | model_mean_type=( 411 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 412 | ), 413 | model_var_type=( 414 | ( 415 | gd.ModelVarType.FIXED_LARGE 416 | if not sigma_small 417 | else gd.ModelVarType.FIXED_SMALL 418 | ) 419 | if not learn_sigma 420 | else gd.ModelVarType.LEARNED_RANGE 421 | ), 422 | loss_type=loss_type, 423 | rescale_timesteps=rescale_timesteps, 424 | ) 425 | 426 | 427 | def add_dict_to_argparser(parser, default_dict): 428 | for k, v in default_dict.items(): 429 | v_type = type(v) 430 | if v is None: 431 | v_type = str 432 | elif isinstance(v, bool): 433 | v_type = str2bool 434 | parser.add_argument(f"--{k}", default=v, type=v_type) 435 | 436 | 437 | def args_to_dict(args, keys): 438 | return {k: getattr(args, k) for k in keys} 439 | 440 | 441 | def str2bool(v): 442 | """ 443 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 444 | """ 445 | if isinstance(v, bool): 446 | return v 447 | if v.lower() in ("yes", "true", "t", "y", "1"): 448 | return True 449 | elif v.lower() in ("no", "false", "f", "n", "0"): 450 | return False 451 | else: 452 | raise argparse.ArgumentTypeError("boolean value expected") 453 | -------------------------------------------------------------------------------- /visual_planner/diffusion_model/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 
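# (See fp16_util.MixedPrecisionTrainer: with use_fp16 the loss is scaled by
# 2 ** lg_loss_scale before backward(), so an initial value of 20 corresponds
# to a loss scale of roughly 2 ** 20 ≈ 1e6.)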
19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 
106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
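        # self.step was incremented after the final run_step, so (self.step - 1)
        # is the index of the last completed step; save now unless that step
        # already triggered the periodic save inside the loop.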
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | for k, v in cond.items() 187 | } 188 | last_batch = (i + self.microbatch) >= batch.shape[0] 189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 190 | 191 | compute_losses = functools.partial( 192 | self.diffusion.training_losses, 193 | self.ddp_model, 194 | micro, 195 | t, 196 | model_kwargs=micro_cond, 197 | ) 198 | 199 | if last_batch or not self.use_ddp: 200 | losses = compute_losses() 201 | else: 202 | with self.ddp_model.no_sync(): 203 | losses = compute_losses() 204 | 205 | if isinstance(self.schedule_sampler, LossAwareSampler): 206 | self.schedule_sampler.update_with_local_losses( 207 | t, losses["loss"].detach() 208 | ) 209 | 210 | loss = (losses["loss"] * weights).mean() 211 | log_loss_dict( 212 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 213 | ) 214 | self.mp_trainer.backward(loss) 215 | 216 | def _update_ema(self): 217 | for rate, params in zip(self.ema_rate, self.ema_params): 218 | update_ema(params, self.mp_trainer.master_params, rate=rate) 219 | 220 | def _anneal_lr(self): 221 | if not self.lr_anneal_steps: 222 | return 223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 224 | lr = self.lr * (1 - frac_done) 225 | for param_group in self.opt.param_groups: 226 | param_group["lr"] = lr 227 | 228 | def log_step(self): 229 | logger.logkv("step", self.step + self.resume_step) 230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 231 | 232 | def save(self): 233 | def save_checkpoint(rate, params): 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 235 | if dist.get_rank() == 0: 236 | logger.log(f"saving model {rate}...") 237 | if not rate: 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" 239 | else: 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 242 | th.save(state_dict, f) 243 | 244 | save_checkpoint(0, self.mp_trainer.master_params) 245 | for rate, params in zip(self.ema_rate, self.ema_params): 246 | save_checkpoint(rate, params) 247 | 248 | if dist.get_rank() == 0: 249 | with bf.BlobFile( 250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 251 | "wb", 252 | ) as f: 253 | th.save(self.opt.state_dict(), f) 254 | 255 | dist.barrier() 256 | 257 | 258 | def parse_resume_step_from_filename(filename): 259 | """ 260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 261 | checkpoint's number of steps. 262 | """ 263 | split = filename.split("model") 264 | if len(split) < 2: 265 | return 0 266 | split1 = split[-1].split(".")[0] 267 | try: 268 | return int(split1) 269 | except ValueError: 270 | return 0 271 | 272 | 273 | def get_blob_logdir(): 274 | # You can change this to be a separate path to save checkpoints to 275 | # a blobstore or some external drive. 
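    # By default, checkpoints go to the logger's output directory, i.e.
    # OPENAI_LOGDIR if set, otherwise the temporary openai-<timestamp>
    # directory created by logger.configure().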
276 | return logger.get_dir() 277 | 278 | 279 | def find_resume_checkpoint(): 280 | # On your infrastructure, you may want to override this to automatically 281 | # discover the latest checkpoint on your blob storage, etc. 282 | return None 283 | 284 | 285 | def find_ema_checkpoint(main_checkpoint, step, rate): 286 | if main_checkpoint is None: 287 | return None 288 | filename = f"ema_{rate}_{(step):06d}.pt" 289 | path = bf.join(bf.dirname(main_checkpoint), filename) 290 | if bf.exists(path): 291 | return path 292 | return None 293 | 294 | 295 | def log_loss_dict(diffusion, ts, losses): 296 | for key, values in losses.items(): 297 | logger.logkv_mean(key, values.mean().item()) 298 | # Log the quantiles (four quartiles, in particular). 299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 300 | quartile = int(4 * sub_t / diffusion.num_timesteps) 301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 302 | -------------------------------------------------------------------------------- /visual_planner/metric_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .calc_fvd import calculate_fvd 2 | from .calc_lpips import calculate_lpips 3 | from .calc_psnr import calculate_psnr 4 | from .calc_ssim import calculate_ssim -------------------------------------------------------------------------------- /visual_planner/metric_utils/calc_fvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | def trans(x): 6 | # if greyscale images add channel 7 | if x.shape[-3] == 1: 8 | x = x.repeat(1, 1, 3, 1, 1) 9 | 10 | # permute BTCHW -> BCTHW 11 | x = x.permute(0, 2, 1, 3, 4) 12 | 13 | return x 14 | 15 | def calculate_fvd(videos1, videos2, device, method='styleganv'): 16 | 17 | if method == 'styleganv': 18 | from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained 19 | elif method == 'videogpt': 20 | from fvd.videogpt.fvd import load_i3d_pretrained 21 | from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats 22 | from fvd.videogpt.fvd import frechet_distance 23 | 24 | print("calculate_fvd...") 25 | 26 | # videos [batch_size, timestamps, channel, h, w] 27 | 28 | assert videos1.shape == videos2.shape 29 | 30 | i3d = load_i3d_pretrained(device=device) 31 | fvd_results = [] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # BTCHW -> BCTHW 35 | # videos -> [batch_size, channel, timestamps, h, w] 36 | 37 | videos1 = trans(videos1) 38 | videos2 = trans(videos2) 39 | 40 | fvd_results = {} 41 | 42 | # for calculate FVD, each clip_timestamp must >= 10 43 | for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): 44 | 45 | # get a video clip 46 | # videos_clip [batch_size, channel, timestamps[:clip], h, w] 47 | videos_clip1 = videos1[:, :, : clip_timestamp] 48 | videos_clip2 = videos2[:, :, : clip_timestamp] 49 | 50 | # get FVD features 51 | feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) 52 | feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) 53 | 54 | # calculate FVD when timestamps[:clip] 55 | fvd_results[clip_timestamp] = frechet_distance(feats1, feats2) 56 | 57 | result = { 58 | "value": fvd_results, 59 | "video_setting": videos1.shape, 60 | "video_setting_name": "batch_size, channel, time, heigth, width", 61 | } 62 | 63 | return result 64 | 65 | # test code / using example 66 | 67 | def main(): 68 | NUMBER_OF_VIDEOS = 8 69 | VIDEO_LENGTH = 50 70 | 
CHANNEL = 3 71 | SIZE = 64 72 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 73 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 74 | device = torch.device("cuda") 75 | # device = torch.device("cpu") 76 | 77 | import json 78 | result = calculate_fvd(videos1, videos2, device, method='videogpt') 79 | print(json.dumps(result, indent=4)) 80 | 81 | result = calculate_fvd(videos1, videos2, device, method='styleganv') 82 | print(json.dumps(result, indent=4)) 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /visual_planner/metric_utils/calc_lpips.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | import torch 7 | import lpips 8 | 9 | spatial = True # Return a spatial map of perceptual distance. 10 | 11 | # Linearly calibrated models (LPIPS) 12 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg' 13 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' 14 | 15 | def trans(x): 16 | # if greyscale images add channel 17 | if x.shape[-3] == 1: 18 | x = x.repeat(1, 1, 3, 1, 1) 19 | 20 | # value range [0, 1] -> [-1, 1] 21 | x = x * 2 - 1 22 | 23 | return x 24 | 25 | def calculate_lpips(videos1, videos2, device): 26 | # image should be RGB, IMPORTANT: normalized to [-1,1] 27 | print("calculate_lpips...") 28 | 29 | assert videos1.shape == videos2.shape 30 | 31 | # videos [batch_size, timestamps, channel, h, w] 32 | 33 | # support grayscale input, if grayscale -> channel*3 34 | # value range [0, 1] -> [-1, 1] 35 | videos1 = trans(videos1) 36 | videos2 = trans(videos2) 37 | 38 | lpips_results = [] 39 | 40 | for video_num in tqdm(range(videos1.shape[0])): 41 | # get a video 42 | # video [timestamps, channel, h, w] 43 | video1 = videos1[video_num] 44 | video2 = videos2[video_num] 45 | 46 | lpips_results_of_a_video = [] 47 | for clip_timestamp in range(len(video1)): 48 | # get a img 49 | # img [timestamps[x], channel, h, w] 50 | # img [channel, h, w] tensor 51 | 52 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 53 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 54 | 55 | loss_fn.to(device) 56 | 57 | # calculate lpips of a video 58 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 59 | lpips_results.append(lpips_results_of_a_video) 60 | 61 | lpips_results = np.array(lpips_results) 62 | 63 | lpips = {} 64 | lpips_std = {} 65 | 66 | for clip_timestamp in range(len(video1)): 67 | lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp]) 68 | lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp]) 69 | 70 | 71 | result = { 72 | "value": lpips, 73 | "value_std": lpips_std, 74 | "video_setting": video1.shape, 75 | "video_setting_name": "time, channel, heigth, width", 76 | } 77 | 78 | return result 79 | 80 | # test code / using example 81 | 82 | def main(): 83 | NUMBER_OF_VIDEOS = 8 84 | VIDEO_LENGTH = 50 85 | CHANNEL = 3 86 | SIZE = 64 87 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 88 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 89 | device = torch.device("cuda") 90 | # device = torch.device("cpu") 91 | 92 | import json 93 | result = 
calculate_lpips(videos1, videos2, device) 94 | print(json.dumps(result, indent=4)) 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /visual_planner/metric_utils/calc_psnr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import math 5 | 6 | def img_psnr(img1, img2): 7 | # [0,1] 8 | # compute mse 9 | # mse = np.mean((img1-img2)**2) 10 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 11 | # compute psnr 12 | if mse < 1e-10: 13 | return 100 14 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 15 | return psnr 16 | 17 | def trans(x): 18 | return x 19 | 20 | def calculate_psnr(videos1, videos2): 21 | print("calculate_psnr...") 22 | 23 | # videos [batch_size, timestamps, channel, h, w] 24 | 25 | assert videos1.shape == videos2.shape 26 | 27 | videos1 = trans(videos1) 28 | videos2 = trans(videos2) 29 | 30 | psnr_results = [] 31 | 32 | for video_num in tqdm(range(videos1.shape[0])): 33 | # get a video 34 | # video [timestamps, channel, h, w] 35 | video1 = videos1[video_num] 36 | video2 = videos2[video_num] 37 | 38 | psnr_results_of_a_video = [] 39 | for clip_timestamp in range(len(video1)): 40 | # get a img 41 | # img [timestamps[x], channel, h, w] 42 | # img [channel, h, w] numpy 43 | 44 | img1 = video1[clip_timestamp].numpy() 45 | img2 = video2[clip_timestamp].numpy() 46 | 47 | # calculate psnr of a video 48 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 49 | 50 | psnr_results.append(psnr_results_of_a_video) 51 | 52 | psnr_results = np.array(psnr_results) 53 | 54 | psnr = {} 55 | psnr_std = {} 56 | 57 | for clip_timestamp in range(len(video1)): 58 | psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp]) 59 | psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp]) 60 | 61 | result = { 62 | "value": psnr, 63 | "value_std": psnr_std, 64 | "video_setting": video1.shape, 65 | "video_setting_name": "time, channel, heigth, width", 66 | } 67 | 68 | return result 69 | 70 | # test code / using example 71 | 72 | def main(): 73 | NUMBER_OF_VIDEOS = 8 74 | VIDEO_LENGTH = 50 75 | CHANNEL = 3 76 | SIZE = 64 77 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 78 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 79 | 80 | import json 81 | result = calculate_psnr(videos1, videos2) 82 | print(json.dumps(result, indent=4)) 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /visual_planner/metric_utils/calc_ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import cv2 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01 ** 2 8 | C2 = 0.03 ** 2 9 | img1 = img1.astype(np.float64) 10 | img2 = img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1 ** 2 16 | mu2_sq = mu2 ** 2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + 
C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 22 | (sigma1_sq + sigma2_sq + C2)) 23 | return ssim_map.mean() 24 | 25 | 26 | def calculate_ssim_function(img1, img2): 27 | # [0,1] 28 | # ssim is the only metric extremely sensitive to gray being compared to b/w 29 | if not img1.shape == img2.shape: 30 | raise ValueError('Input images must have the same dimensions.') 31 | if img1.ndim == 2: 32 | return ssim(img1, img2) 33 | elif img1.ndim == 3: 34 | if img1.shape[0] == 3: 35 | ssims = [] 36 | for i in range(3): 37 | ssims.append(ssim(img1[i], img2[i])) 38 | return np.array(ssims).mean() 39 | elif img1.shape[0] == 1: 40 | return ssim(np.squeeze(img1), np.squeeze(img2)) 41 | else: 42 | raise ValueError('Wrong input image dimensions.') 43 | 44 | def trans(x): 45 | return x 46 | 47 | def calculate_ssim(videos1, videos2): 48 | print("calculate_ssim...") 49 | 50 | # videos [batch_size, timestamps, channel, h, w] 51 | 52 | assert videos1.shape == videos2.shape 53 | 54 | videos1 = trans(videos1) 55 | videos2 = trans(videos2) 56 | 57 | ssim_results = [] 58 | 59 | for video_num in tqdm(range(videos1.shape[0])): 60 | # get a video 61 | # video [timestamps, channel, h, w] 62 | video1 = videos1[video_num] 63 | video2 = videos2[video_num] 64 | 65 | ssim_results_of_a_video = [] 66 | for clip_timestamp in range(len(video1)): 67 | # get a img 68 | # img [timestamps[x], channel, h, w] 69 | # img [channel, h, w] numpy 70 | 71 | img1 = video1[clip_timestamp].numpy() 72 | img2 = video2[clip_timestamp].numpy() 73 | 74 | # calculate ssim of a video 75 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 76 | 77 | ssim_results.append(ssim_results_of_a_video) 78 | 79 | ssim_results = np.array(ssim_results) 80 | 81 | ssim = {} 82 | ssim_std = {} 83 | 84 | for clip_timestamp in range(len(video1)): 85 | ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp]) 86 | ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp]) 87 | 88 | result = { 89 | "value": ssim, 90 | "value_std": ssim_std, 91 | "video_setting": video1.shape, 92 | "video_setting_name": "time, channel, heigth, width", 93 | } 94 | 95 | return result 96 | 97 | # test code / using example 98 | 99 | def main(): 100 | NUMBER_OF_VIDEOS = 8 101 | VIDEO_LENGTH = 50 102 | CHANNEL = 3 103 | SIZE = 64 104 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 105 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 106 | device = torch.device("cuda") 107 | 108 | import json 109 | result = calculate_ssim(videos1, videos2) 110 | print(json.dumps(result, indent=4)) 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /visual_planner/raft_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/CLOVER/92324ed0fbe563fcb272bed18522a34580958bdc/visual_planner/raft_utils/__init__.py -------------------------------------------------------------------------------- /visual_planner/raft_utils/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | 
self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) -------------------------------------------------------------------------------- /visual_planner/raft_utils/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = 
nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = 
isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x -------------------------------------------------------------------------------- /visual_planner/raft_utils/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | def __enter__(self): 20 | pass 21 | def __exit__(self, *args): 22 | pass 23 | 24 | 25 | class RAFT(nn.Module): 26 | def __init__(self, small=True): 27 | super(RAFT, self).__init__() 28 | 29 | 30 | if small: 31 | self.hidden_dim = hdim = 96 32 | self.context_dim = cdim = 64 33 | self.corr_levels = 4 34 | self.corr_radius = 3 35 | 36 | else: 37 | self.hidden_dim = hdim = 128 38 | self.context_dim = cdim = 128 39 | self.corr_levels = 4 40 | self.corr_radius = 4 41 | 42 | 43 | # feature network, context network, and update block 44 | if small: 45 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=0) 46 | self.update_block = SmallUpdateBlock(self.corr_levels, self.corr_radius, hidden_dim=hdim) 47 | else: 48 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0) 49 | self.update_block = BasicUpdateBlock(self.corr_levels, self.corr_radius, hidden_dim=hdim) 50 | 51 | def freeze_bn(self): 52 | for m in self.modules(): 53 | if isinstance(m, nn.BatchNorm2d): 54 | m.eval() 55 | 56 | def initialize_flow(self, img): 57 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 58 | N, C, H, W = img.shape 59 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 60 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 61 | 62 | # optical flow computed as difference: flow = coords1 - coords0 63 | return coords0, coords1 64 | 65 | def upsample_flow(self, flow, mask): 66 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 67 | N, _, H, W = flow.shape 68 | mask = mask.view(N, 1, 9, 8, 8, H, W) 69 | mask = torch.softmax(mask, dim=2) 70 | 71 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 72 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 73 | 74 | up_flow = torch.sum(mask * up_flow, dim=2) 75 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 76 | return up_flow.reshape(N, 2, 8*H, 8*W) 77 | 78 | 79 | def forward(self, image1, fmap1, fmap2, iters=12, flow_init=None, upsample=True, test_mode=False): 80 | """ Estimate optical flow between pair of frames """ 81 | 82 | 83 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius) 84 | 85 | # run the context network 86 | with autocast(enabled=True): 87 | cnet = self.cnet(image1) 88 | net, inp = torch.split(cnet, [self.hidden_dim, self.context_dim], dim=1) 89 | net = 
torch.tanh(net) 90 | inp = torch.relu(inp) 91 | 92 | 93 | coords0, coords1 = self.initialize_flow(image1) 94 | 95 | if flow_init is not None: 96 | coords1 = coords1 + flow_init 97 | 98 | flow_predictions = [] 99 | for itr in range(iters): 100 | coords1 = coords1.detach() 101 | corr = corr_fn(coords1) # index correlation volume 102 | 103 | flow = coords1 - coords0 104 | 105 | with autocast(enabled=True): 106 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 107 | 108 | # F(t+1) = F(t) + \Delta(t) 109 | coords1 = coords1 + delta_flow 110 | 111 | # upsample predictions 112 | if up_mask is None: 113 | flow_up = upflow8(coords1 - coords0) 114 | else: 115 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 116 | 117 | flow_predictions.append(flow_up) 118 | 119 | 120 | return flow_up 121 | # if test_mode: 122 | # return coords1 - coords0, flow_up 123 | 124 | # return flow_predictions 125 | 126 | -------------------------------------------------------------------------------- /visual_planner/raft_utils/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, corr_levels, corr_radius): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = corr_levels * (2*corr_radius + 
1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, corr_levels, corr_radius): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = corr_levels * (2*corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, corr_levels, corr_radius, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(corr_levels, corr_radius) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, corr_levels, corr_radius, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | # self.args = args 118 | self.encoder = BasicMotionEncoder(corr_levels, corr_radius) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | -------------------------------------------------------------------------------- /visual_planner/raft_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | from torch.autograd import Variable 6 | 7 | 8 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 9 | """ Wrapper for grid_sample, uses pixel coordinates """ 10 | H, W = img.shape[-2:] 11 | xgrid, ygrid = coords.split([1,1], dim=-1) 12 | xgrid = 2*xgrid/(W-1) - 1 13 | ygrid = 2*ygrid/(H-1) - 1 14 | 15 | grid = torch.cat([xgrid, ygrid], dim=-1) 16 | img = F.grid_sample(img, grid, 
align_corners=True) 17 | 18 | if mask: 19 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 20 | return img, mask.float() 21 | 22 | return img 23 | 24 | 25 | def coords_grid(batch, ht, wd, device): 26 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 27 | coords = torch.stack(coords[::-1], dim=0).float() 28 | return coords[None].repeat(batch, 1, 1, 1) 29 | 30 | 31 | def upflow8(flow, mode='bilinear'): 32 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 33 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 34 | 35 | 36 | 37 | def flow_warp(img, flow, padding_mode='zeros'): 38 | """ 39 | Inverse warp a source image to the target image plane. 40 | 41 | Args: 42 | img: the source image (where to sample pixels) -- [B, 3, H, W] 43 | flow: flow map of the target image -- [B, 2, H, W] 44 | Returns: 45 | Source image warped to the target image plane 46 | """ 47 | 48 | bs, _, h, w = flow.size() 49 | u = flow[:,0,:,:] 50 | v = flow[:,1,:,:] 51 | 52 | # print(u.max(), v.max()) 53 | 54 | grid_x = Variable(torch.arange(0, w).view(1, 1, w).expand(1,h,w), requires_grad=False).type_as(u).expand_as(u) # [bs, H, W] 55 | grid_y = Variable(torch.arange(0, h).view(1, h, 1).expand(1,h,w), requires_grad=False).type_as(v).expand_as(v) # [bs, H, W] 56 | 57 | X = grid_x + u 58 | Y = grid_y + v 59 | 60 | X = 2*(X/(w-1.0) - 0.5) 61 | Y = 2*(Y/(h-1.0) - 0.5) 62 | grid_tf = torch.stack((X,Y), dim=3) 63 | img_tf = torch.nn.functional.grid_sample(img, grid_tf, padding_mode=padding_mode, align_corners=True) 64 | 65 | return img_tf 66 | 67 | def robust_l1(x, q=0.5, eps=1e-8): 68 | x = torch.pow((x.pow(2) + eps), q) 69 | x = x.mean() 70 | return x -------------------------------------------------------------------------------- /visual_planner/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import wandb 4 | import argparse 5 | 6 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 7 | torch.backends.cuda.matmul.allow_tf32 = True 8 | torch.backends.cudnn.allow_tf32 = True 9 | 10 | from trainer import GoalGaussianDiffusion, Trainer 11 | from visual_planner import VisualPlanner 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from calvin_data import DiskCalvinDataset 15 | from pathlib import Path 16 | 17 | 18 | def world_info_from_env(): 19 | local_rank = 0 20 | for v in ( 21 | "LOCAL_RANK", 22 | "MPI_LOCALRANKID", 23 | "SLURM_LOCALID", 24 | "OMPI_COMM_WORLD_LOCAL_RANK", 25 | ): 26 | if v in os.environ: 27 | local_rank = int(os.environ[v]) 28 | break 29 | global_rank = 0 30 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 31 | if v in os.environ: 32 | global_rank = int(os.environ[v]) 33 | break 34 | world_size = 1 35 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 36 | if v in os.environ: 37 | world_size = int(os.environ[v]) 38 | break 39 | 40 | return local_rank, global_rank, world_size 41 | 42 | 43 | def main(args, wandb = None): 44 | 45 | target_size = (128, 128) 46 | sampling_step = args.sampling_step 47 | window_size = args.sample_per_seq * sampling_step + sampling_step 48 | train_set = DiskCalvinDataset( 49 | datasets_dir=Path('calvin/dataset/task_ABC_D') / "training", 50 | window_size=window_size, 51 | sampling_step=sampling_step, 52 | image_size=target_size[0], 53 | with_depth=args.with_depth, 54 | ) 55 | 56 | valid_set = 
DiskCalvinDataset( 57 | datasets_dir=Path('calvin/dataset/task_ABC_D') / "validation", 58 | window_size=window_size, 59 | sampling_step=sampling_step, 60 | image_size=target_size[0], 61 | with_depth=args.with_depth, 62 | ) 63 | 64 | 65 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path = "openai/clip-vit-large-patch14") 66 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path = "openai/clip-vit-large-patch14") 67 | text_encoder.requires_grad_(False) 68 | text_encoder.eval() 69 | 70 | 71 | # We use a UNet-based diffusion model as the visual planner 72 | raw_input_channels = 8 if args.with_depth else 6 73 | model = VisualPlanner( 74 | image_size = target_size[0] // 8 if args.use_vae else target_size[0], 75 | in_channels = 8 if args.use_vae else raw_input_channels, 76 | out_channels = 8 if args.use_vae else raw_input_channels // 2, 77 | use_vae = args.use_vae, 78 | decoupled_output = False, 79 | temporal_length = args.sample_per_seq, # Number of frames to predict 80 | dims = 3, 81 | flow_reg = args.flow_reg, 82 | with_state_estimate = False, 83 | ) 84 | 85 | print( 86 | f"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" 87 | ) 88 | 89 | if args.use_vae: 90 | from guided_diffusion.guided_diffusion import create_diffusion 91 | gaussian_diffusion = create_diffusion(timestep_respacing="", diffusion_steps=args.diffusion_steps) 92 | else: 93 | diffusion = GoalGaussianDiffusion( 94 | channels=3, 95 | model=model, 96 | image_size=target_size, 97 | timesteps=args.diffusion_steps, 98 | sampling_timesteps=args.sample_steps, 99 | loss_type='l2', 100 | objective='pred_v', 101 | beta_schedule = 'cosine', 102 | min_snr_loss_weight = True, 103 | auto_normalize = False, 104 | with_depth=args.with_depth, 105 | use_vae=args.use_vae, 106 | ) 107 | 108 | if os.path.exists(args.resume_path): 109 | model_ckpt = torch.load(args.resume_path)['model'] 110 | if args.use_vae: 111 | model.load_state_dict(model_ckpt) 112 | else: 113 | diffusion.load_state_dict(model_ckpt) # Model wrapped with diffusion 114 | print('resume ckpt successfully loaded from: ', args.resume_path) 115 | 116 | 117 | trainer = Trainer( 118 | args = args, 119 | model=model if args.use_vae else diffusion, 120 | diffusion = gaussian_diffusion if args.use_vae else None, 121 | latent_size = target_size[0] // 8 if args.use_vae else None, 122 | image_size = target_size, 123 | tokenizer=tokenizer, 124 | text_encoder=text_encoder, 125 | train_set=train_set, 126 | valid_set=valid_set, 127 | train_lr=args.learning_rate, 128 | train_num_steps = args.train_num_steps, 129 | save_and_sample_every = args.save_and_sample_every, 130 | ema_update_every = args.ema_update_every, 131 | ema_decay = args.ema_decay, 132 | train_batch_size = args.train_batch_size, 133 | valid_batch_size = args.val_batch_size, 134 | gradient_accumulate_every = args.gradient_accumulate_every, 135 | num_samples=1, 136 | results_folder = args.results_folder, 137 | fp16 = False, 138 | amp = False, 139 | wandb = wandb, 140 | use_vae=args.use_vae, 141 | with_depth=args.with_depth, 142 | cond_drop_chance = 0., # Classifier-free guidance 143 | ) 144 | 145 | if os.path.exists(args.resume_path) and args.load_trainer: 146 | checkpoint_num = int(args.resume_path.split('/')[-1].split('-')[-1].split('.')[0]) 147 | trainer.load(checkpoint_num) 148 | 149 | if args.mode == 'train': 150 | trainer.train() 151 | else: 152 | model.eval() 153 | diffusion.eval() 154 | trainer.eval() 155 | 156 | 157 | 158 | if
__name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument('-m', '--mode', type=str, default='train', choices=['train', 'val']) # set to 'inference' to generate samples 161 | parser.add_argument('-c', '--checkpoint_num', type=int, default=None) # set to checkpoint number to resume training or generate samples 162 | parser.add_argument('-p', '--inference_path', type=str, default=None) # set to path to generate samples 163 | parser.add_argument('-t', '--text', type=str, default=None) # set to text to generate samples 164 | parser.add_argument('-n', '--sample_steps', type=int, default=100) # set to number of steps to sample 165 | parser.add_argument('-g', '--guidance_weight', type=int, default=0) # set to positive to use guidance 166 | 167 | # Training Config 168 | parser.add_argument('--learning_rate', type=float, default=1e-4) 169 | parser.add_argument('--train_num_steps', type=int, default=60000) 170 | parser.add_argument('--save_and_sample_every', type=int, default=2500) 171 | parser.add_argument('--train_batch_size', type=int, default=16) 172 | parser.add_argument('--val_batch_size', type=int, default=1) 173 | parser.add_argument('--gradient_accumulate_every', type=int, default=1) 174 | parser.add_argument('--sample_per_seq', type=int, default=8) 175 | parser.add_argument('--resume_path', type=str, default='') 176 | parser.add_argument('--load_trainer', default=False, action="store_true") 177 | parser.add_argument('--sampling_step', type=int, default=5) 178 | 179 | # EMA config 180 | parser.add_argument('--ema_update_every', type=int, default=10) 181 | parser.add_argument('--ema_decay', type=float, default=0.999) 182 | 183 | # Model Config 184 | parser.add_argument('--use_vae', default=False, action="store_true") 185 | parser.add_argument('--flow_reg', default=False, action="store_true") 186 | parser.add_argument('--with_depth', default=False, action="store_true") 187 | parser.add_argument('--with_text_conditioning', default=False, action="store_true") 188 | parser.add_argument('--diffusion_steps', type=int, default=100) 189 | 190 | # Log Config 191 | parser.add_argument('--report_to_wandb', default=False, action="store_true") 192 | parser.add_argument('--run_name', type=str, default='train_visual_planner') 193 | parser.add_argument('--results_folder', type=str, default='../results/visual_planner') 194 | 195 | args = parser.parse_args() 196 | args.local_rank, args.rank, args.world_size = world_info_from_env() 197 | 198 | if args.rank == 0 and args.report_to_wandb: 199 | wandb.init( 200 | name=args.run_name, 201 | config=vars(args), 202 | ) 203 | 204 | if args.mode == 'inference': 205 | assert args.checkpoint_num is not None 206 | assert args.inference_path is not None 207 | assert args.text is not None 208 | assert args.sample_steps <= 100 209 | main(args, wandb=wandb) -------------------------------------------------------------------------------- /visual_planner/train.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_cfg.yaml train.py \ 2 | --learning_rate 1e-4 \ 3 | --train_num_steps 300000 \ 4 | --save_and_sample_every 10000 \ 5 | --train_batch_size 32 \ 6 | --sample_per_seq 8 \ 7 | --sampling_step 5 \ 8 | --with_text_conditioning \ 9 | --diffusion_steps 100 \ 10 | --sample_steps 10 \ 11 | --with_depth \ 12 | --flow_reg \ 13 | --results_folder *path_to_save_ckpts* 14 | -------------------------------------------------------------------------------- /visual_planner/val.sh: 
-------------------------------------------------------------------------------- 1 | 2 | accelerate launch --config_file accelerate_cfg.yaml train.py \ 3 | --val_batch_size 32 \ 4 | --sample_per_seq 8 \ 5 | --sampling_step 5 \ 6 | --results_folder path_to_your_results_folder \ 7 | --resume_path path_to_your_model_ckpt \ 8 | --with_text_conditioning \ 9 | --diffusion_steps 100 \ 10 | --sample_steps 10 \ 11 | --mode val \ 12 | --with_depth \ 13 | --flow_reg \ 14 | -------------------------------------------------------------------------------- /visual_planner/visual_planner.py: -------------------------------------------------------------------------------- 1 | from diffusion_model.unet import UNetModel 2 | from torch import nn 3 | import torch 4 | from einops import repeat, rearrange 5 | 6 | 7 | class VisualPlanner(nn.Module): 8 | # ResidualBlocks + ExtendedSelfAttn + CausalTemporalAttn 9 | def __init__(self, 10 | image_size, 11 | in_channels, 12 | out_channels, # output channels (RGB + Depth) 13 | dims=3, # dimension of conv blocks 14 | temporal_length=8, # num of video frames 15 | use_vae=False, # whether to use VAE for frame encoding / decoding 16 | decoupled_output=False, # decoupled RGB & Depth output 17 | decoupled_input=False, # decoupled RGB & Depth input 18 | flow_reg=False, 19 | with_state_estimate=False 20 | ): 21 | super(VisualPlanner, self).__init__() 22 | self.unet = UNetModel( 23 | image_size=image_size, 24 | in_channels=in_channels, 25 | model_channels=64, 26 | out_channels=out_channels, 27 | num_res_blocks=2, 28 | attention_resolutions=(8, 16, ), 29 | dropout=0, 30 | channel_mult=(1, 2, 3, 4, 5), 31 | conv_resample=True, 32 | dims=dims, # whether to use temporal 1D-Conv 33 | num_classes=None, 34 | task_tokens=True, 35 | task_token_channels=768, # 768 for CLIP-Large Text Encoder 36 | use_checkpoint=False, 37 | use_fp16=False, 38 | num_head_channels=32, 39 | decoupled_output=decoupled_output, 40 | decoupled_input=decoupled_input, 41 | temporal_length=temporal_length, 42 | simple_adapter=False, 43 | flow_reg=flow_reg, 44 | ) 45 | self.use_vae = use_vae 46 | self.dims = dims 47 | self.flow_reg = flow_reg 48 | 49 | 50 | def forward(self, x, t, text_embedding=None, x_cond = None, **kwargs): 51 | if x_cond is not None: 52 | x_cond = repeat(x_cond.squeeze(1), 'b c h w -> b f c h w', f=x.shape[1]) 53 | x = torch.cat([x_cond, x], dim=2) # ( b, f, c * 2 , H, W ) 54 | 55 | b, f, c, h, w = x.shape 56 | if self.dims == 2: 57 | x = rearrange(x, 'b f c h w -> (b f) c h w') 58 | else: 59 | x = x.transpose(1,2) # ( b, c * 2 , f, H, W ) 60 | 61 | if self.flow_reg and kwargs['forward']: 62 | out, flow = self.unet(x, t, text_embedding, **kwargs) 63 | else: 64 | out = self.unet(x, t, text_embedding, **kwargs) 65 | 66 | if self.dims == 2: 67 | out = rearrange(out, '(b f) c h w -> b f c h w', b=b) 68 | else: 69 | out = out.transpose(1,2) 70 | 71 | if self.flow_reg and kwargs['forward']: 72 | return out, flow 73 | else: 74 | return out 75 | 76 | 77 | 78 | 79 | 80 | --------------------------------------------------------------------------------
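
Note on the RAFT flow upsampling used in visual_planner/raft_utils/raft.py: upsample_flow() lifts the 1/8-resolution flow to full resolution by taking, for every output pixel, a softmax-weighted (convex) combination of its 3x3 coarse neighbours. The snippet below is a self-contained sketch of that shape logic only; the mask here is random noise as a stand-in, whereas RAFT predicts it with the update block.

import torch
import torch.nn.functional as F

N, H, W = 2, 16, 16                               # coarse (1/8) resolution
flow = torch.randn(N, 2, H, W)                    # coarse flow field
mask = torch.randn(N, 64 * 9, H, W)               # 8*8 sub-pixels x 9 neighbours (random stand-in)

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)                 # convex weights over the 9 neighbours

up_flow = F.unfold(8 * flow, [3, 3], padding=1)   # gather each pixel's 3x3 neighbourhood
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)        # convex combination per sub-pixel
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
print(up_flow.reshape(N, 2, 8 * H, 8 * W).shape)  # torch.Size([2, 2, 128, 128])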
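
Similarly, a minimal shape check for the all-pairs correlation pyramid in visual_planner/raft_utils/corr.py (a sketch, not part of the repository; the import paths assume the script is run from the visual_planner/ directory so that raft_utils is importable): with num_levels=4 and radius=4, CorrBlock returns 4 * (2*4+1)^2 = 324 correlation features per pixel.

import torch
from raft_utils.corr import CorrBlock
from raft_utils.utils import coords_grid

B, D, H8, W8 = 1, 128, 32, 32                           # 1/8-resolution feature maps
fmap1 = torch.randn(B, D, H8, W8)
fmap2 = torch.randn(B, D, H8, W8)

corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
coords = coords_grid(B, H8, W8, device=fmap1.device)    # (B, 2, H8, W8) identity coordinate grid

cost = corr_fn(coords)                                  # lookup around the current coordinates
print(cost.shape)                                       # torch.Size([1, 324, 32, 32])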