├── diffip2d
    ├── utils
    │   ├── __init__.py
    │   ├── dist_util.py
    │   ├── fp16_util.py
    │   ├── nn.py
    │   ├── losses.py
    │   └── logger.py
    ├── post_decoder.py
    ├── pre_encoder.py
    ├── rounding.py
    └── step_sample.py
├── docs
    ├── pred.gif
    ├── motivation.png
    └── Diff-IP2D_appendix.pdf
├── train.sh
├── val_traj.sh
├── val_affordance.sh
├── requirements.txt
├── run_train.py
├── run_val_affordance.py
├── run_val_traj.py
├── netscripts
    ├── get_datasets.py
    ├── epoch_utils.py
    ├── get_network.py
    ├── modelio.py
    └── get_optimizer.py
├── networks
    ├── embedding.py
    ├── decoder_modules.py
    ├── affordance_decoder.py
    ├── traj_decoder.py
    ├── net_utils.py
    ├── model.py
    └── layer.py
├── basic_utils.py
├── options
    ├── expopts.py
    └── netsopts.py
├── datasets
    ├── datasetopts.py
    ├── ho_utils.py
    ├── input_loaders.py
    ├── dataloaders.py
    └── dataset_utils.py
├── preprocess
    ├── vis_util.py
    ├── dataset_util.py
    ├── traj_util.py
    ├── affordance_util.py
    ├── ho_types.py
    └── obj_util.py
├── evaluation
    ├── traj_eval.py
    └── affordance_eval.py
├── traineval.py
└── README.md
/diffip2d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/docs/pred.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/Diff-IP2D/HEAD/docs/pred.gif
--------------------------------------------------------------------------------
/docs/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/Diff-IP2D/HEAD/docs/motivation.png
--------------------------------------------------------------------------------
/docs/Diff-IP2D_appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/Diff-IP2D/HEAD/docs/Diff-IP2D_appendix.pdf
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | TORCH_DISTRIBUTED_DEBUG=DETAIL python -m torch.distributed.launch --nproc_per_node=2 --master_port=12224 --use_env run_train.py \
2 | 
--------------------------------------------------------------------------------
/val_traj.sh:
--------------------------------------------------------------------------------
1 | TORCH_DISTRIBUTED_DEBUG=DETAIL python -m torch.distributed.launch --nproc_per_node=2 --master_port=12263 --use_env run_val_traj.py \
2 | 
--------------------------------------------------------------------------------
/val_affordance.sh:
--------------------------------------------------------------------------------
1 | TORCH_DISTRIBUTED_DEBUG=DETAIL python -m torch.distributed.launch --nproc_per_node=2 --master_port=12233 --use_env run_val_affordance.py \
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.8.0
2 | joblib==1.3.1
3 | lmdb==1.4.1
4 | lmdbdict==0.2.2
5 | matplotlib==3.5.2
6 | numpy==1.19.0
7 | opencv_python==4.8.0.74
8 | pandas==1.3.5
9 | Pillow==11.0.0
10 | pretrainedmodels==0.7.4
11 | protobuf==3.20.3
12 | PyYAML==6.0.2
13 | scikit_learn==1.0.2
14 | scipy==1.7.3
15 | tensorflow==2.18.0
16 | torch==1.10.1+cu111
17 | torchnet==0.0.4
18 | torchvision==0.11.2+cu111
19 | tqdm==4.65.0
20 | transformers==4.26.0
21 | 
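The three launch scripts above pass --use_env to torch.distributed.launch, so each worker reads its rank from environment variables instead of a --local_rank argument. A minimal sketch of how a launched entry point can bind its GPU under that convention (illustrative only; the repository's own handling is in diffip2d/utils/dist_util.py further below):

import os
import torch

# torch.distributed.launch --use_env exports LOCAL_RANK for every worker process
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)  # pin this process to its assigned GPU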
-------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 5 | 6 | import sys 7 | import os 8 | import argparse 9 | import time 10 | sys.path.append('.') 11 | 12 | if __name__ == '__main__': 13 | 14 | # TODO add more options 15 | COMMANDLINE = f"python traineval.py --ek_version=ek100" \ 16 | 17 | print(COMMANDLINE) 18 | os.system(COMMANDLINE) -------------------------------------------------------------------------------- /run_val_affordance.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 5 | 6 | import sys 7 | import os 8 | sys.path.append('.') 9 | 10 | if __name__ == '__main__': 11 | 12 | 13 | COMMANDLINE = f"python traineval.py --evaluate --ek_version=ek100 --resume=./diffip_weights/checkpoint_aff.pth.tar" \ 14 | 15 | print(COMMANDLINE) 16 | os.system(COMMANDLINE) -------------------------------------------------------------------------------- /run_val_traj.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 
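# Like run_train.py and run_val_affordance.py above, this wrapper only composes a
# traineval.py command line and executes it with os.system; invoking the underlying
# command directly is equivalent, e.g.:
#   python traineval.py --evaluate --ek_version=ek100 \
#       --resume=./diffip_weights/checkpoint_traj.pth.tar --traj_only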
5 | 6 | import sys 7 | import os 8 | sys.path.append('.') 9 | 10 | if __name__ == '__main__': 11 | 12 | COMMANDLINE = f"python traineval.py --evaluate --ek_version=ek100 --resume=./diffip_weights/checkpoint_traj.pth.tar --traj_only" \ 13 | 14 | print(COMMANDLINE) 15 | os.system(COMMANDLINE) -------------------------------------------------------------------------------- /netscripts/get_datasets.py: -------------------------------------------------------------------------------- 1 | from datasets.datasetopts import DatasetArgs 2 | from datasets.holoaders import EpicHODataset as HODataset, FeaturesHOLoader, get_dataloaders 3 | 4 | 5 | def get_dataset(args, base_path="./"): 6 | if args.evaluate: 7 | mode = "validation" 8 | else: 9 | mode = 'train' 10 | 11 | datasetargs = DatasetArgs(ek_version=args.ek_version, mode=mode, 12 | use_label_only=True, base_path=base_path, 13 | batch_size=args.batch_size, num_workers=args.workers) 14 | dls = get_dataloaders(datasetargs, HODataset, featuresloader=FeaturesHOLoader) 15 | return mode, dls 16 | -------------------------------------------------------------------------------- /netscripts/epoch_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def progress_bar(msg=None): 5 | 6 | L = [] 7 | if msg: 8 | L.append(msg) 9 | 10 | msg = ''.join(L) 11 | sys.stdout.write(msg+'\n') 12 | 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | 33 | class AverageMeters: 34 | def __init__(self): 35 | super().__init__() 36 | self.average_meters = {} 37 | 38 | def add_loss_value(self, loss_name, loss_val, n=1): 39 | if loss_name not in self.average_meters: 40 | self.average_meters[loss_name] = AverageMeter() 41 | self.average_meters[loss_name].update(loss_val, n=n) -------------------------------------------------------------------------------- /diffip2d/post_decoder.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 
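# TrajDecoder below maps per-timestep latent features to 2-D waypoints: the forward
# pass flattens (B, T, input_dims) to (B*T, input_dims), applies the MLP, and reshapes
# back to (B, T, output_dims); the final SIG() squashes outputs into (0, 1), i.e.
# normalized image coordinates. Illustrative usage with hypothetical hidden sizes:
#   decoder = TrajDecoder(input_dims=512, output_dims=2,
#                         encoder_hidden_dims1=256, encoder_hidden_dims2=64)
#   waypoints = decoder(torch.randn(8, 4, 512))   # -> (8, 4, 2), values in (0, 1)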
5 | 6 | from transformers import AutoConfig 7 | import torch 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .utils.nn import ( 15 | SiLU, 16 | linear, 17 | timestep_embedding, 18 | SIG, 19 | ) 20 | 21 | class TrajDecoder(nn.Module): 22 | """ 23 | transform 512 channel to 2 channel 24 | """ 25 | 26 | def __init__( 27 | self, 28 | input_dims, 29 | output_dims, 30 | encoder_hidden_dims1, 31 | encoder_hidden_dims2, 32 | ): 33 | super().__init__() 34 | 35 | self.input_dims = input_dims 36 | self.hidden_t_dim1 = encoder_hidden_dims1 37 | self.hidden_t_dim2 = encoder_hidden_dims2 38 | self.output_dims = output_dims 39 | 40 | self.feat_embed = nn.Sequential( 41 | linear(input_dims, encoder_hidden_dims1), 42 | nn.ELU(), 43 | linear(encoder_hidden_dims1, encoder_hidden_dims2), 44 | nn.ELU(), 45 | linear(encoder_hidden_dims2, output_dims), 46 | SIG(), 47 | ) 48 | 49 | def forward(self, x): 50 | 51 | B = x.shape[0] 52 | T = x.shape[1] 53 | F = x.shape[2] 54 | x = x.view(B*T, F) 55 | x = self.feat_embed(x) 56 | x = x.view(B, T, x.shape[-1]) 57 | 58 | return x -------------------------------------------------------------------------------- /networks/embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, d_model, max_len=5000): 9 | super(PositionalEncoding, self).__init__() 10 | pe = torch.zeros(max_len, d_model) 11 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 12 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 13 | pe[:, 0::2] = torch.sin(position * div_term) 14 | pe[:, 1::2] = torch.cos(position * div_term) 15 | pe = pe.unsqueeze(0) 16 | 17 | self.register_buffer('pe', pe) 18 | 19 | def forward(self, x): 20 | return x + self.pe[:, :x.shape[1]] 21 | 22 | 23 | class Encoder_PositionalEmbedding(nn.Module): 24 | def __init__(self, d_model, seq_len): 25 | super(Encoder_PositionalEmbedding, self).__init__() 26 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model)) 27 | 28 | def forward(self, x): 29 | B, T = x.shape[:2] 30 | if T != self.position_embedding.size(1): 31 | position_embedding = self.position_embedding.transpose(1, 2) 32 | new_position_embedding = F.interpolate(position_embedding, size=(T), mode='nearest') 33 | new_position_embedding = new_position_embedding.transpose(1, 2) 34 | x = x + new_position_embedding 35 | else: 36 | x = x + self.position_embedding 37 | return x 38 | 39 | 40 | class Decoder_PositionalEmbedding(nn.Module): 41 | def __init__(self, d_model, seq_len): 42 | super(Decoder_PositionalEmbedding, self).__init__() 43 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model)) 44 | 45 | def forward(self, x): 46 | x = x + self.position_embedding[:, :x.shape[1], :] 47 | return x 48 | 49 | -------------------------------------------------------------------------------- /diffip2d/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | import torch as th 9 | import torch.distributed as dist 10 | 11 | 12 | def setup_dist(): 13 | """ 14 | Setup a distributed process group. 
15 | """ 16 | if dist.is_initialized(): 17 | return 18 | 19 | backend = "gloo" if not th.cuda.is_available() else "nccl" 20 | 21 | if backend == "gloo": 22 | hostname = "localhost" 23 | else: 24 | hostname = socket.gethostbyname(socket.getfqdn()) 25 | 26 | if os.environ.get("LOCAL_RANK") is None: 27 | os.environ["MASTER_ADDR"] = hostname 28 | os.environ["RANK"] = str(0) 29 | os.environ["WORLD_SIZE"] = str(1) 30 | port = _find_free_port() 31 | os.environ["MASTER_PORT"] = str(port) 32 | os.environ['LOCAL_RANK'] = str(0) 33 | 34 | dist.init_process_group(backend=backend, init_method="env://") 35 | 36 | if th.cuda.is_available(): # This clears remaining caches in GPU 0 37 | th.cuda.set_device(dev()) 38 | th.cuda.empty_cache() 39 | 40 | 41 | def dev(): 42 | """ 43 | Get the device to use for torch.distributed. 44 | """ 45 | if th.cuda.is_available(): 46 | return th.device(f"cuda:{os.environ['LOCAL_RANK']}") 47 | return th.device("cpu") 48 | 49 | 50 | def sync_params(params): 51 | """ 52 | Synchronize a sequence of Tensors across ranks from rank 0. 53 | """ 54 | for p in params: 55 | with th.no_grad(): 56 | dist.broadcast(p, 0) 57 | 58 | 59 | def _find_free_port(): 60 | try: 61 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 62 | s.bind(("", 0)) 63 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 64 | return s.getsockname()[1] 65 | finally: 66 | s.close() 67 | 68 | def get_rank(): 69 | return dist.get_rank() -------------------------------------------------------------------------------- /diffip2d/pre_encoder.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 
5 | 6 | from transformers import AutoConfig 7 | import torch 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .utils.nn import ( 15 | SiLU, 16 | linear, 17 | timestep_embedding, 18 | ) 19 | 20 | class SideFusionEncoder(nn.Module): 21 | """ 22 | transform 5 channel to 1 channel 23 | """ 24 | 25 | def __init__( 26 | self, 27 | input_dims, 28 | output_dims, 29 | encoder_hidden_dims, 30 | ): 31 | super().__init__() 32 | 33 | self.input_dims = input_dims 34 | self.hidden_t_dim = encoder_hidden_dims 35 | self.output_dims = output_dims 36 | 37 | self.feat_embed = nn.Sequential( 38 | # linear(input_dims, encoder_hidden_dims), 39 | # SiLU(), 40 | # linear(encoder_hidden_dims, output_dims), 41 | linear(input_dims, output_dims), 42 | ) 43 | 44 | def forward(self, x): 45 | 46 | B = x.shape[0] 47 | T = x.shape[1] 48 | F = x.shape[2] 49 | x = x.view(B*T, F) 50 | x = self.feat_embed(x) 51 | x = x.view(B, T, x.shape[-1]) 52 | 53 | return x 54 | 55 | class MotionEncoder(nn.Module): 56 | """ 57 | transform 9 channel to 512 channel 58 | """ 59 | 60 | def __init__( 61 | self, 62 | input_dims, 63 | output_dims, 64 | encoder_hidden_dims, 65 | ): 66 | super().__init__() 67 | 68 | self.input_dims = input_dims 69 | self.hidden_t_dim = encoder_hidden_dims 70 | self.output_dims = output_dims 71 | 72 | self.feat_embed = nn.Sequential( 73 | linear(input_dims, output_dims), 74 | ) 75 | 76 | def forward(self, x): 77 | 78 | B = x.shape[0] 79 | T = x.shape[1] 80 | F = x.shape[2] 81 | x = x.view(B*T, F) 82 | x = self.feat_embed(x) 83 | x = x.view(B, T, x.shape[-1]) 84 | 85 | return x -------------------------------------------------------------------------------- /networks/decoder_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VAE(nn.Module): 6 | 7 | def __init__(self, in_dim, hidden_dim, latent_dim, conditional=False, condition_dim=None): 8 | 9 | super().__init__() 10 | 11 | self.latent_dim = latent_dim 12 | self.conditional = conditional 13 | 14 | if self.conditional and condition_dim is not None: 15 | input_dim = in_dim + condition_dim 16 | dec_dim = latent_dim + condition_dim 17 | else: 18 | input_dim = in_dim 19 | dec_dim = latent_dim 20 | self.enc_MLP = nn.Sequential( 21 | nn.Linear(input_dim, hidden_dim), 22 | nn.ELU()) 23 | self.linear_means = nn.Linear(hidden_dim, latent_dim) 24 | self.linear_log_var = nn.Linear(hidden_dim, latent_dim) 25 | self.dec_MLP = nn.Sequential( 26 | nn.Linear(dec_dim, hidden_dim), 27 | nn.ELU(), 28 | nn.Linear(hidden_dim, in_dim)) 29 | 30 | 31 | def forward(self, x, c=None, return_pred=False): 32 | if self.conditional and c is not None: 33 | inp = torch.cat((x, c), dim=-1) 34 | else: 35 | inp = x 36 | h = self.enc_MLP(inp) 37 | mean = self.linear_means(h) 38 | log_var = self.linear_log_var(h) 39 | z = self.reparameterize(mean, log_var) 40 | if self.conditional and c is not None: 41 | z = torch.cat((z, c), dim=-1) 42 | recon_x = self.dec_MLP(z) 43 | recon_loss, KLD = self.loss_fn(recon_x, x, mean, log_var) 44 | if not return_pred: 45 | return recon_loss, KLD 46 | else: 47 | return recon_x, recon_loss, KLD 48 | 49 | def loss_fn(self, recon_x, x, mean, log_var): 50 | recon_loss = torch.sum((recon_x - x) ** 2, dim=1) 51 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1) 52 | return recon_loss, KLD 53 | 54 | def reparameterize(self, mu, log_var): 55 | std = torch.exp(0.5 * log_var) 56 
| eps = torch.randn_like(std) 57 | return mu + eps * std 58 | 59 | def inference(self, z, c=None): 60 | if self.conditional and c is not None: 61 | z = torch.cat((z, c), dim=-1) 62 | recon_x = self.dec_MLP(z) 63 | return recon_x -------------------------------------------------------------------------------- /diffip2d/utils/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /basic_utils.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 
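# create_network_and_diffusion() below assembles the pipeline and returns the tuple
# (sf_encoder, denoised_model, diffusion, traj_decoder, motion_encoder): the side-fusion
# and motion encoders project inputs into the latent space, MADT is the denoising
# network (output width doubled when learn_sigma is set), HOIDiffusion wraps the
# (possibly respaced) noise schedule, and TrajDecoder maps denoised latents back to
# 2-D waypoints. Illustrative call; all keyword names come from the signature below,
# but the values are hypothetical placeholders, not the repository's configuration:
#   sf_enc, madt, diffusion, traj_dec, motion_enc = create_network_and_diffusion(
#       hidden_t_dim=128, hidden_dim=512, vocab_size=None, config_name=None,
#       use_plm_init=False, dropout=0.1, diffusion_steps=1000, noise_schedule="sqrt",
#       learn_sigma=False, timestep_respacing="", predict_xstart=True,
#       rescale_timesteps=True, sigma_small=False, rescale_learned_sigmas=True,
#       use_kl=False, sf_encoder_hidden=512, traj_decoder_hidden1=256,
#       traj_decoder_hidden2=64, motion_encoder_hidden=512, madt_depth=6)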
5 | 6 | from diffip2d import gaussian_diffusion as gd 7 | from diffip2d.gaussian_diffusion import HOIDiffusion, space_timesteps 8 | from diffip2d.transformer_model import TransformerNetModel, MADT, RL_HOITransformerNetModel 9 | from diffip2d.pre_encoder import SideFusionEncoder, MotionEncoder 10 | from diffip2d.post_decoder import TrajDecoder 11 | 12 | def create_network_and_diffusion( 13 | hidden_t_dim, 14 | hidden_dim, 15 | vocab_size, 16 | config_name, 17 | use_plm_init, 18 | dropout, 19 | diffusion_steps, 20 | noise_schedule, 21 | learn_sigma, 22 | timestep_respacing, 23 | predict_xstart, 24 | rescale_timesteps, 25 | sigma_small, 26 | rescale_learned_sigmas, 27 | use_kl, 28 | sf_encoder_hidden, 29 | traj_decoder_hidden1, 30 | traj_decoder_hidden2, 31 | motion_encoder_hidden, 32 | madt_depth, 33 | feat_num=3, # global hand object 34 | traj_dim=2, # 2D traj on egocentric video 35 | homo_dim=3, # homography matrix 36 | **kwargs, 37 | ): 38 | 39 | # we will support more input params for different structures 40 | sf_encoder = SideFusionEncoder(input_dims=feat_num * hidden_dim, output_dims=hidden_dim, encoder_hidden_dims=sf_encoder_hidden) 41 | traj_decoder = TrajDecoder(input_dims=hidden_dim, output_dims=traj_dim, encoder_hidden_dims1=traj_decoder_hidden1, encoder_hidden_dims2=traj_decoder_hidden2) 42 | motion_encoder = MotionEncoder(input_dims=homo_dim * homo_dim, output_dims=hidden_dim, encoder_hidden_dims=motion_encoder_hidden) 43 | denoised_model = MADT( 44 | input_dims=hidden_dim, 45 | output_dims=(hidden_dim if not learn_sigma else hidden_dim*2), 46 | hidden_t_dim=hidden_t_dim, 47 | dropout=dropout, 48 | depth=madt_depth, 49 | ) 50 | 51 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 52 | if not timestep_respacing: 53 | timestep_respacing = [diffusion_steps] 54 | diffusion = HOIDiffusion( 55 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 56 | betas=betas, 57 | rescale_timesteps=rescale_timesteps, 58 | predict_xstart=predict_xstart, 59 | learn_sigmas = learn_sigma, 60 | sigma_small = sigma_small, 61 | use_kl = use_kl, 62 | rescale_learned_sigmas=rescale_learned_sigmas 63 | ) 64 | 65 | return sf_encoder, denoised_model, diffusion, traj_decoder, motion_encoder -------------------------------------------------------------------------------- /networks/affordance_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from networks.decoder_modules import VAE 4 | 5 | 6 | class AffordanceCVAE(nn.Module): 7 | def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None, 8 | pred_len=4, condition_traj=True, z_scale=2.0): 9 | super().__init__() 10 | self.latent_dim = latent_dim 11 | self.condition_traj = True 12 | self.z_scale = z_scale 13 | if self.condition_traj: 14 | if coord_dim is None: 15 | coord_dim = hidden_dim // 2 16 | self.coord_dim = coord_dim 17 | self.traj_to_feature = nn.Sequential( 18 | nn.Linear(2*(pred_len+1), coord_dim*(pred_len+1), bias=False), 19 | nn.ELU(inplace=True)) 20 | self.traj_context_fusion = nn.Sequential( 21 | nn.Linear(condition_dim+coord_dim*(pred_len+1), condition_dim, bias=False), 22 | nn.ELU(inplace=True)) 23 | 24 | self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, 25 | conditional=True, condition_dim=condition_dim) 26 | 27 | 28 | def forward(self, context, contact_point, hand_traj=None, return_pred=False): 29 | if self.condition_traj: 30 | assert hand_traj is not None 31 | batch_size = 
context.shape[0] 32 | hand_traj = hand_traj.reshape(batch_size, -1) 33 | traj_feat = self.traj_to_feature(hand_traj) 34 | fusion_feat = torch.cat([context, traj_feat], dim=1) 35 | condition_context = self.traj_context_fusion(fusion_feat) 36 | else: 37 | condition_context = context 38 | if not return_pred: 39 | recon_loss, KLD = self.cvae(contact_point, c=condition_context) 40 | return recon_loss, KLD 41 | else: 42 | pred_contact, recon_loss, KLD = self.cvae(contact_point, c=condition_context, return_pred=return_pred) 43 | return pred_contact, recon_loss, KLD 44 | 45 | def inference(self, context, hand_traj=None): 46 | if self.condition_traj: 47 | assert hand_traj is not None 48 | batch_size = context.shape[0] 49 | hand_traj = hand_traj.reshape(batch_size, -1) 50 | traj_feat = self.traj_to_feature(hand_traj) 51 | fusion_feat = torch.cat([context, traj_feat], dim=1) 52 | condition_context = self.traj_context_fusion(fusion_feat) 53 | else: 54 | condition_context = context 55 | z = self.z_scale * torch.randn([condition_context.shape[0], self.latent_dim], device=condition_context.device) 56 | recon_x = self.cvae.inference(z, c=condition_context) 57 | return recon_x -------------------------------------------------------------------------------- /options/expopts.py: -------------------------------------------------------------------------------- 1 | def add_exp_opts(parser): 2 | parser.add_argument("--resume", type=str, nargs="+", metavar="PATH", 3 | help="path to latest checkpoint (default: none)") 4 | parser.add_argument("--evaluate", dest="evaluate", action="store_true", 5 | help="evaluate model on validation set") 6 | parser.add_argument("--test_freq", type=int, default=100, 7 | help="testing frequency on evaluation dataset (set specific in traineval.py)") 8 | parser.add_argument("--snapshot", default=1, type=int, metavar="N", 9 | help="How often to take a snapshot of the model (0 = never)") 10 | parser.add_argument("--use_cuda", default=1, type=int, help="use GPU (default: True)") 11 | parser.add_argument('--ek_version', default="ek55", choices=["ek55", "ek100"], help="epic dataset version") 12 | parser.add_argument("--traj_only", action="store_true", help="evaluate traj on validation dataset") 13 | parser.add_argument("--schedule_sampler_args", default="lossaware", choices=["uniform", "lossaware", "fixstep"], help="loss schedule for diffusion") 14 | parser.add_argument("--seq_len_obs", default=10, type=int, help="length of observed (past) sequence") 15 | parser.add_argument("--seq_len_unobs", default=4, type=int, help="length of unobserved (future) sequence") 16 | parser.add_argument("--learnable_weight", default=False, type=bool, help="whether to use learnable loss weights") 17 | parser.add_argument("--rec_loss_weight", default=1.0, type=float, help="initial value of diffusion losses") 18 | parser.add_argument("--reg_loss_weight", default=0.2, type=float, help="initial value of regularization loss") 19 | parser.add_argument("--use_schedule", default=False, type=bool, help="whether to specify optimizer schedule") 20 | parser.add_argument("--sample_times", default=10, type=int, help="how many samples for one prediction") 21 | parser.add_argument("--fast_test", default=True, type=bool, help="whether to use faster inference") 22 | 23 | def add_path_opts(parser): 24 | parser.add_argument("--base_model", default="./base_models/model.pth.tar", type=str, nargs="+", metavar="PATH", help="path to base model") 25 | parser.add_argument("--log_path", default="./log", type=str, nargs="+", 
metavar="PATH", help="path to record logs") 26 | parser.add_argument("--checkpoint_path", default="./diffip_weights", type=str, nargs="+", metavar="PATH", help="path to save checkpoints") 27 | parser.add_argument("--collection_path_traj", default="./collected_pred_traj", type=str, nargs="+", metavar="PATH", help="path to gather traj eval results") 28 | parser.add_argument("--collection_path_aff", default="./collected_pred_aff", type=str, nargs="+", metavar="PATH", help="path to gather aff eval results") 29 | -------------------------------------------------------------------------------- /datasets/datasetopts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class DatasetArgs(object): 6 | def __init__(self, ek_version='ek55', mode="train", use_label_only=True, 7 | base_path="./", batch_size=32, num_workers=0, modalities=['feat'], 8 | fps=4, t_buffer=2.5): 9 | 10 | self.features_paths = { 11 | 'ek55': os.path.join(base_path, 'data/ek55/feats'), 12 | 'ek100': os.path.join(base_path, 'data/ek100/feats')} 13 | # generated data labels 14 | self.label_path = { 15 | 'ek55': os.path.join(base_path, 'data/ek55'), 16 | 'ek100': os.path.join(base_path, 'data/ek100')} 17 | 18 | # amazon-annotated eval labels 19 | self.eval_label_path = { 20 | 'ek55': os.path.join(base_path, 'data/ek55/ek55_eval_labels.pkl'), 21 | 'ek100': os.path.join(base_path, 'data/ek100/ek100_eval_labels.pkl') 22 | } 23 | 24 | self.annot_path = { 25 | 'ek55': os.path.join(base_path, 'common/epic-kitchens-55-annotations'), 26 | 'ek100': os.path.join(base_path, 'common/epic-kitchens-100-annotations')} 27 | 28 | self.rulstm_annot_path = { 29 | 'ek55': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek55'), 30 | 'ek100': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek100')} 31 | 32 | self.pretrained_backbone_path = { 33 | 'ek55': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek55', 'TSN-rgb.pth.tar'), 34 | 'ek100': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek100', 'TSN-rgb-ek100.pth.tar'), 35 | } 36 | 37 | # default settings, no need changes 38 | if fps is None: 39 | self.fps = 4 40 | else: 41 | self.fps = fps 42 | 43 | if t_buffer is None: 44 | self.t_buffer = 2.5 45 | else: 46 | self.t_buffer = t_buffer 47 | 48 | self.ori_fps = 60.0 49 | self.t_ant = 1.0 50 | 51 | self.validation_ratio = 0.2 52 | self.use_rulstm_splits = True 53 | 54 | # only preprocess uids that have corresponding labels, in "video_info.json" 55 | self.use_label_only = use_label_only 56 | 57 | self.task = 'anticipation' 58 | self.num_actions_prev = 1 59 | 60 | self.batch_size = batch_size 61 | self.num_workers = num_workers 62 | 63 | self.modalities = modalities 64 | self.ek_version = ek_version # 'ek55' or 'ek100' 65 | self.mode = mode # 'train' 66 | 67 | def add_attr(self, attr_name, attr_value): 68 | setattr(self, attr_name, attr_value) 69 | 70 | def has_attr(self, attr_name): 71 | return hasattr(self, attr_name) 72 | 73 | def __repr__(self): 74 | return 'Input Args: ' + json.dumps(self.__dict__, indent=4) 75 | 76 | -------------------------------------------------------------------------------- /netscripts/get_network.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) 
for providing the codebases. 5 | 6 | import torch 7 | from networks.traj_decoder import TrajCVAE 8 | from networks.affordance_decoder import AffordanceCVAE 9 | from networks.transformer import ObjectTransformer,ObjectTransformerModel 10 | from networks.model import Model 11 | 12 | def get_network_for_diffip(args, num_frames_input=10, num_frames_output=4): 13 | net = ObjectTransformerModel(src_in_features=args.src_in_features, 14 | trg_in_features=args.trg_in_features, 15 | num_patches=args.num_patches, 16 | encoder_time_embed_type=args.encoder_time_embed_type, 17 | decoder_time_embed_type=args.decoder_time_embed_type, 18 | num_frames_input=num_frames_input, 19 | num_frames_output=num_frames_output, 20 | embed_dim=args.embed_dim, coord_dim=args.coord_dim, 21 | num_heads=args.num_heads, enc_depth=args.enc_depth, dec_depth=args.dec_depth) 22 | 23 | obj_head = AffordanceCVAE(in_dim=2, hidden_dim=512, latent_dim=256, condition_dim=512) 24 | return net, obj_head 25 | 26 | 27 | def get_network(args, num_frames_input=10, num_frames_output=4): 28 | hand_head = TrajCVAE(in_dim=2, hidden_dim=args.hidden_dim, 29 | latent_dim=args.latent_dim, condition_dim=args.embed_dim, 30 | coord_dim=args.coord_dim) 31 | obj_head = AffordanceCVAE(in_dim=2, hidden_dim=args.hidden_dim, 32 | latent_dim=args.latent_dim, condition_dim=args.embed_dim) 33 | net = ObjectTransformer(src_in_features=args.src_in_features, 34 | trg_in_features=args.trg_in_features, 35 | num_patches=args.num_patches, 36 | hand_head=hand_head, obj_head=obj_head, 37 | encoder_time_embed_type=args.encoder_time_embed_type, 38 | decoder_time_embed_type=args.decoder_time_embed_type, 39 | num_frames_input=num_frames_input, 40 | num_frames_output=num_frames_output, 41 | embed_dim=args.embed_dim, coord_dim=args.coord_dim, 42 | num_heads=args.num_heads, enc_depth=args.enc_depth, dec_depth=args.dec_depth) 43 | net = torch.nn.DataParallel(net) 44 | 45 | model = Model(net, lambda_obj=args.lambda_obj, lambda_traj=args.lambda_traj, 46 | lambda_obj_kl=args.lambda_obj_kl, lambda_traj_kl=args.lambda_traj_kl) 47 | return model 48 | 49 | -------------------------------------------------------------------------------- /preprocess/vis_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | from preprocess.affordance_util import compute_heatmap 5 | 6 | hand_rgb = {"LEFT": (0, 90, 181), "RIGHT": (220, 50, 32)} 7 | object_rgb = (255, 194, 10) 8 | 9 | 10 | def vis_traj(frame_vis, traj, fill_indices=None, side=None, circle_radis=4, circle_thickness=3, line_thickness=2, style='line', gap=5): 11 | for idx in range(len(traj)): 12 | x, y = traj[idx] 13 | if fill_indices is not None and idx in fill_indices: 14 | thickness = -1 15 | else: 16 | thickness = -1 17 | color = hand_rgb[side][::-1] if side is not None else (0, 255, 255) 18 | frame_vis = cv2.circle(frame_vis, (int(round(x)), int(round(y))), radius=circle_radis, color=color, 19 | thickness=thickness) 20 | if idx > 0: 21 | pt1 = (int(round(traj[idx-1][0])), int(round(traj[idx-1][1]))) 22 | pt2 = (int(round(traj[idx][0])), int(round(traj[idx][1]))) 23 | dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5 24 | pts = [] 25 | for i in np.arange(0, dist, gap): 26 | r = i / dist 27 | x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5) 28 | y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5) 29 | p = (x, y) 30 | pts.append(p) 31 | if style == 'dotted': 32 | for p in pts: 33 | cv2.circle(frame_vis, p, circle_thickness, 
color, -1) 34 | else: 35 | if len(pts) > 0: 36 | s = pts[0] 37 | e = pts[0] 38 | i = 0 39 | for p in pts: 40 | s = e 41 | e = p 42 | if i % 2 == 1: 43 | cv2.line(frame_vis, s, e, color, line_thickness) 44 | i += 1 45 | return frame_vis 46 | 47 | 48 | def vis_hand_traj(frames, hand_trajs): 49 | frame_vis = frames[0].copy() 50 | for side in hand_trajs: 51 | meta = hand_trajs[side] 52 | traj, fill_indices = meta["traj"], meta["fill_indices"] 53 | frame_vis = vis_traj(frame_vis, traj, fill_indices, side) 54 | return frame_vis 55 | 56 | 57 | def vis_affordance(frame, affordance_info): 58 | select_points = affordance_info["select_points_homo"] 59 | hmap = compute_heatmap(select_points, (frame.shape[1], frame.shape[0])) 60 | hmap = (hmap * 255).astype(np.uint8) 61 | hmap = cv2.applyColorMap(hmap, colormap=cv2.COLORMAP_JET) 62 | for idx in range((len(select_points))): 63 | point = select_points[idx].astype(np.int) 64 | frame_vis = cv2.circle(frame, (point[0], point[1]), radius=2, color=(255, 0, 255), 65 | thickness=-1) 66 | overlay = (0.7 * frame + 0.3 * hmap).astype(np.uint8) 67 | return overlay 68 | -------------------------------------------------------------------------------- /netscripts/modelio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import traceback 4 | import warnings 5 | import torch 6 | 7 | 8 | def load_checkpoint(model, resume_path, strict=True, device=None): 9 | if os.path.isfile(resume_path): 10 | print("=> loading checkpoint '{}'".format(resume_path)) 11 | if device is not None: 12 | checkpoint = torch.load(resume_path, map_location=device) 13 | else: 14 | checkpoint = torch.load(resume_path) 15 | if "module" in list(checkpoint["state_dict"].keys())[0]: 16 | state_dict = checkpoint["state_dict"] 17 | else: 18 | state_dict = { 19 | "module.{}".format(key): item 20 | for key, item in checkpoint["state_dict"].items()} 21 | print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint["epoch"])) 22 | missing_states = set(model.state_dict().keys()) - set(state_dict.keys()) 23 | if len(missing_states) > 0: 24 | warnings.warn("Missing keys ! : {}".format(missing_states)) 25 | model.load_state_dict(state_dict, strict=strict) 26 | else: 27 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) 28 | return checkpoint["epoch"] 29 | 30 | 31 | def load_checkpoint_by_name(model, resume_path, state_dict_name, strict=True, device=None): 32 | if os.path.isfile(resume_path): 33 | print("=> loading "+state_dict_name+" checkpoint '{}'".format(resume_path)) 34 | if device is not None: 35 | checkpoint = torch.load(resume_path, map_location=device) 36 | else: 37 | checkpoint = torch.load(resume_path) 38 | if "module" in list(checkpoint[state_dict_name].keys())[0]: 39 | state_dict = checkpoint[state_dict_name] 40 | else: 41 | state_dict = { 42 | "module.{}".format(key): item 43 | for key, item in checkpoint[state_dict_name].items()} 44 | print("=> loaded "+state_dict_name+" checkpoint '{}' (epoch {})".format(resume_path, checkpoint["epoch"])) 45 | missing_states = set(model.state_dict().keys()) - set(state_dict.keys()) 46 | if len(missing_states) > 0: 47 | warnings.warn("Missing keys ! 
: {}".format(missing_states)) 48 | model.load_state_dict(state_dict, strict=strict) 49 | else: 50 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) 51 | return checkpoint["epoch"] 52 | 53 | def load_optimizer(optimizer, resume_path, state_dict_name='optimizer', device=None): 54 | if os.path.isfile(resume_path): 55 | print("=> loading "+state_dict_name+" checkpoint '{}'".format(resume_path)) 56 | checkpoint = torch.load(resume_path) 57 | optimizer.load_state_dict(checkpoint) 58 | else: 59 | raise ValueError("=> no optimizer checkpoint found at '{}'".format(resume_path)) 60 | 61 | def save_checkpoint(state, checkpoint="checkpoint", filename="checkpoint.pth.tar"): 62 | filepath = os.path.join(checkpoint, filename) 63 | torch.save(state, filepath) 64 | -------------------------------------------------------------------------------- /diffip2d/utils/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 11 | class SiLU(nn.Module): 12 | def forward(self, x): 13 | return x * th.sigmoid(x) 14 | 15 | class SIG(nn.Module): 16 | def forward(self, x): 17 | return th.sigmoid(x) 18 | 19 | class GroupNorm32(nn.GroupNorm): 20 | def forward(self, x): 21 | return super().forward(x.float()).type(x.dtype) 22 | 23 | def linear(*args, **kwargs): 24 | """ 25 | Create a linear module. 26 | """ 27 | return nn.Linear(*args, **kwargs) 28 | 29 | 30 | def avg_pool_nd(dims, *args, **kwargs): 31 | """ 32 | Create a 1D, 2D, or 3D average pooling module. 33 | """ 34 | if dims == 1: 35 | return nn.AvgPool1d(*args, **kwargs) 36 | elif dims == 2: 37 | return nn.AvgPool2d(*args, **kwargs) 38 | elif dims == 3: 39 | return nn.AvgPool3d(*args, **kwargs) 40 | raise ValueError(f"unsupported dimensions: {dims}") 41 | 42 | 43 | def update_ema(target_params, source_params, rate=0.99): 44 | """ 45 | Update target parameters to be closer to those of source parameters using 46 | an exponential moving average. 47 | 48 | :param target_params: the target parameter sequence. 49 | :param source_params: the source parameter sequence. 50 | :param rate: the EMA rate (closer to 1 means slower). 51 | """ 52 | for targ, src in zip(target_params, source_params): 53 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 54 | 55 | 56 | def zero_module(module): 57 | """ 58 | Zero out the parameters of a module and return it. 59 | """ 60 | for p in module.parameters(): 61 | p.detach().zero_() 62 | return module 63 | 64 | 65 | def scale_module(module, scale): 66 | """ 67 | Scale the parameters of a module and return it. 68 | """ 69 | for p in module.parameters(): 70 | p.detach().mul_(scale) 71 | return module 72 | 73 | 74 | def mean_flat(tensor): 75 | """ 76 | Take the mean over all non-batch dimensions. 77 | """ 78 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 79 | 80 | 81 | def normalization(channels): 82 | """ 83 | Make a standard normalization layer. 84 | 85 | :param channels: number of input channels. 86 | :return: an nn.Module for normalization. 
87 | """ 88 | return GroupNorm32(32, channels) 89 | 90 | 91 | def timestep_embedding(timesteps, dim, max_period=10000): 92 | 93 | half = dim // 2 94 | freqs = th.exp( 95 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 96 | ).to(device=timesteps.device) 97 | args = timesteps[:, None].float() * freqs[None] 98 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 99 | if dim % 2: 100 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 101 | return embedding 102 | -------------------------------------------------------------------------------- /networks/traj_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | from networks.affordance_decoder import VAE 5 | 6 | 7 | class TrajCVAE(nn.Module): 8 | def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None, 9 | condition_contact=False, z_scale=2.0): 10 | super().__init__() 11 | self.latent_dim = latent_dim 12 | self.condition_contact = condition_contact 13 | self.z_scale = z_scale 14 | if self.condition_contact: 15 | if coord_dim is None: 16 | coord_dim = hidden_dim // 2 17 | self.coord_dim = coord_dim 18 | self.contact_to_feature = nn.Sequential( 19 | nn.Linear(2, coord_dim, bias=False), 20 | nn.ELU(inplace=True)) 21 | self.contact_context_fusion = nn.Sequential( 22 | nn.Linear(condition_dim+coord_dim, condition_dim, bias=False), 23 | nn.ELU(inplace=True)) 24 | 25 | self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, 26 | conditional=True, condition_dim=condition_dim) 27 | 28 | 29 | def forward(self, context, target_hand, future_valid, contact_point=None, return_pred=False): 30 | batch_size = future_valid.shape[0] 31 | if self.condition_contact: 32 | assert contact_point is not None 33 | time_steps = int(context.shape[0] / batch_size / 2) 34 | contact_feat = self.contact_to_feature(contact_point) 35 | contact_feat = einops.repeat(contact_feat, 'm n -> m p q n', p=2, q=time_steps) 36 | contact_feat = contact_feat.reshape(-1, self.coord_dim) 37 | fusion_feat = torch.cat([context, contact_feat], dim=1) 38 | condition_context = self.contact_context_fusion(fusion_feat) 39 | else: 40 | condition_context = context 41 | if not return_pred: 42 | recon_loss, KLD = self.cvae(target_hand, c=condition_context) 43 | else: 44 | pred_hand, recon_loss, KLD = self.cvae(target_hand, c=condition_context, return_pred=return_pred) 45 | KLD = KLD.reshape(batch_size, 2, -1).sum(-1) 46 | KLD = (KLD * future_valid).sum(1) 47 | recon_loss = recon_loss.reshape(batch_size, 2, -1).sum(-1) 48 | traj_loss = (recon_loss * future_valid).sum(1) 49 | if not return_pred: 50 | return traj_loss, KLD 51 | else: 52 | return pred_hand, traj_loss, KLD 53 | 54 | def inference(self, context, contact_point=None): 55 | if self.condition_contact: 56 | assert contact_point is not None 57 | batch_size = contact_point.shape[0] 58 | time_steps = int(context.shape[0] / batch_size) 59 | contact_feat = self.contact_to_feature(contact_point) 60 | contact_feat = einops.repeat(contact_feat, 'm n -> m p n', p=time_steps) 61 | contact_feat = contact_feat.reshape(-1, self.coord_dim) 62 | fusion_feat = torch.cat([context, contact_feat], dim=1) 63 | condition_context = self.contact_context_fusion(fusion_feat) 64 | else: 65 | condition_context = context 66 | z = self.z_scale * torch.randn([context.shape[0], self.latent_dim], device=context.device) 67 | recon_x = self.cvae.inference(z, 
c=condition_context) 68 | return recon_x -------------------------------------------------------------------------------- /networks/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import warnings 5 | 6 | 7 | def get_pad_mask(seq, pad_idx=0): 8 | if seq.dim() != 2: 9 | raise ValueError(" has to be a 2-dimensional tensor!") 10 | if not isinstance(pad_idx, int): 11 | raise TypeError(" has to be an int!") 12 | 13 | return (seq != pad_idx).unsqueeze(1) 14 | 15 | 16 | def get_subsequent_mask(seq, diagonal=1): 17 | if seq.dim() < 2: 18 | raise ValueError(" has to be at least a 2-dimensional tensor!") 19 | 20 | seq_len = seq.size(1) 21 | mask = (1 - torch.triu(torch.ones((1, seq_len, seq_len), device=seq.device), diagonal=diagonal)).bool() 22 | return mask 23 | 24 | 25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 26 | def norm_cdf(x): 27 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 28 | 29 | if (mean < a - 2 * std) or (mean > b + 2 * std): 30 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 31 | "The distribution of values may be incorrect.", 32 | stacklevel=2) 33 | 34 | with torch.no_grad(): 35 | l = norm_cdf((a - mean) / std) 36 | u = norm_cdf((b - mean) / std) 37 | 38 | tensor.uniform_(2 * l - 1, 2 * u - 1) 39 | tensor.erfinv_() 40 | 41 | tensor.mul_(std * math.sqrt(2.)) 42 | tensor.add_(mean) 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 48 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 49 | 50 | 51 | def drop_path(x, drop_prob: float = 0., training: bool = False): 52 | if drop_prob == 0. or not training: 53 | return x 54 | keep_prob = 1 - drop_prob 55 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 56 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 57 | random_tensor.floor_() 58 | output = x.div(keep_prob) * random_tensor 59 | return output 60 | 61 | 62 | class DropPath(nn.Module): 63 | def __init__(self, drop_prob=None): 64 | super(DropPath, self).__init__() 65 | self.drop_prob = drop_prob # 0.1 66 | 67 | def forward(self, x): 68 | return drop_path(x, self.drop_prob, self.training) 69 | 70 | 71 | def traj_affordance_dist(hand_traj, contact_point, future_valid=None, invalid_value=9): 72 | batch_size = contact_point.shape[0] 73 | expand_size = int(hand_traj.shape[0] / batch_size) 74 | contact_point = contact_point.unsqueeze(dim=1).expand(-1, expand_size, 2).reshape(-1, 2) 75 | dist = torch.sum((hand_traj - contact_point) ** 2, dim=1).reshape(batch_size, -1) 76 | if future_valid is None: 77 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) 78 | return sorted_dist[:, 0] # (B, ) 79 | else: 80 | dist = dist.reshape(batch_size, 2, -1) 81 | future_valid = future_valid > 0 82 | future_invalid = ~future_valid[:, :, None].expand(dist.shape) 83 | dist[future_invalid] = invalid_value 84 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) 85 | selected_dist = sorted_dist[:, :, 0] 86 | selected_dist, selected_idx = selected_dist.min(dim=1) 87 | valid = torch.gather(future_valid, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1) 88 | selected_dist = selected_dist * valid 89 | return selected_dist -------------------------------------------------------------------------------- /networks/model.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | 4 | 5 | class Model(nn.Module): 6 | 7 | def __init__(self, net, lambda_obj=None, lambda_traj=None, lambda_obj_kl=None, lambda_traj_kl=None): 8 | super(Model, self).__init__() 9 | self.net = net 10 | self.lambda_obj = lambda_obj 11 | self.lambda_obj_kl = lambda_obj_kl 12 | self.lambda_traj = lambda_traj 13 | self.lambda_traj_kl = lambda_traj_kl 14 | 15 | def forward(self, feat, bbox_feat, valid_mask, future_hands=None, contact_point=None, future_valid=None, 16 | num_samples=5, pred_len=4): 17 | 18 | if self.training: 19 | losses = {} 20 | total_loss = 0 21 | traj_loss, traj_kl_loss, obj_loss, obj_kl_loss = self.net(feat, bbox_feat, valid_mask, future_hands, 22 | contact_point, future_valid) 23 | if self.lambda_traj is not None and traj_loss is not None: 24 | traj_loss = self.lambda_traj * traj_loss.sum() 25 | total_loss += traj_loss 26 | losses['traj_loss'] = traj_loss.detach().cpu() 27 | else: 28 | losses['traj_loss'] = 0. 29 | 30 | if self.lambda_traj_kl is not None and traj_kl_loss is not None: 31 | traj_kl_loss = self.lambda_traj_kl * traj_kl_loss.sum() 32 | total_loss += traj_kl_loss 33 | losses['traj_kl_loss'] = traj_kl_loss.detach().cpu() 34 | else: 35 | losses['traj_kl_loss'] = 0. 36 | 37 | if self.lambda_obj is not None and obj_loss is not None: 38 | obj_loss = self.lambda_obj * obj_loss.sum() 39 | total_loss += obj_loss 40 | losses['obj_loss'] = obj_loss.detach().cpu() 41 | else: 42 | losses['obj_loss'] = 0. 43 | 44 | if self.lambda_obj_kl is not None and obj_kl_loss is not None: 45 | obj_kl_loss = self.lambda_obj_kl * obj_kl_loss.sum() 46 | total_loss += obj_kl_loss 47 | losses['obj_kl_loss'] = obj_kl_loss.detach().cpu() 48 | else: 49 | losses['obj_kl_loss'] = 0. 50 | 51 | if total_loss is not None: 52 | losses["total_loss"] = total_loss.detach().cpu() 53 | else: 54 | losses["total_loss"] = 0. 
55 | return total_loss, losses 56 | 57 | else: 58 | future_hands_list = [] 59 | contact_points_list = [] 60 | sentence_feature_output = 0 61 | 62 | for i in range(10): 63 | future_hands, contact_point, sentence_feature = self.net.module.inference(feat, bbox_feat, valid_mask, 64 | future_valid=future_valid, 65 | pred_len=pred_len) 66 | future_hands_list.append(future_hands) 67 | contact_points_list.append(contact_point) 68 | 69 | if i == 0: 70 | sentence_feature_output = sentence_feature 71 | else: 72 | assert torch.all(sentence_feature_output == sentence_feature) 73 | 74 | contact_points = torch.stack(contact_points_list, dim=0) 75 | 76 | assert len(contact_points.shape) == 3 77 | contact_points = contact_points.transpose(0, 1) 78 | 79 | future_hands_list = torch.stack(future_hands_list, dim=0) 80 | future_hands_list = future_hands_list.transpose(0, 1) 81 | return future_hands_list, contact_points, sentence_feature_output 82 | -------------------------------------------------------------------------------- /diffip2d/rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, GPT2TokenizerFast 3 | import sys, yaml, os 4 | import json 5 | 6 | import numpy as np 7 | 8 | def get_knn(model_emb, text_emb, dist='cos'): 9 | if dist == 'cos': 10 | adjacency = model_emb @ text_emb.transpose(1, 0).to(model_emb.device) 11 | elif dist == 'l2': 12 | adjacency = model_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand( 13 | model_emb.size(0), -1, -1) 14 | adjacency = -torch.norm(adjacency, dim=-1) 15 | topk_out = torch.topk(adjacency, k=6, dim=0) 16 | return topk_out.values, topk_out.indices 17 | 18 | def get_efficient_knn(model_emb, text_emb): 19 | emb_norm = (model_emb**2).sum(-1).view(-1, 1) 20 | text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) 21 | arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) 22 | dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(model_emb, text_emb_t) 23 | dist = torch.clamp(dist, 0.0, np.inf) 24 | topk_out = torch.topk(-dist, k=1, dim=0) 25 | return topk_out.values, topk_out.indices 26 | 27 | def rounding_func(text_emb_lst, model, tokenizer, emb_scale_factor=1.0): 28 | decoded_out_lst = [] 29 | 30 | model_emb = model.weight 31 | down_proj_emb2 = None 32 | 33 | dist = 'l2' 34 | 35 | for text_emb in text_emb_lst: 36 | import torch 37 | text_emb = torch.tensor(text_emb) 38 | if len(text_emb.shape) > 2: 39 | text_emb = text_emb.view(-1, text_emb.size(-1)) 40 | else: 41 | text_emb = text_emb 42 | val, indices = get_knn((down_proj_emb2 if dist == 'cos' else model_emb), 43 | text_emb.to(model_emb.device), dist=dist) 44 | 45 | decoded_out_lst.append(tokenizer.decode_token(indices[0])) 46 | 47 | return decoded_out_lst 48 | 49 | def compute_logp(args, model, x, input_ids): 50 | word_emb = model.weight 51 | sigma = 0.1 52 | if args.model_arch == '1d-unet': 53 | x = x.permute(0, 2, 1) 54 | 55 | bsz, seqlen, dim = x.shape 56 | 57 | x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0) 58 | word_emb_flat = word_emb.unsqueeze(1) 59 | diff = (x_flat - word_emb_flat) ** 2 60 | 61 | logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2) 62 | logp_expanded = logp_expanded.permute((1, 0)) 63 | 64 | ce = torch.nn.CrossEntropyLoss(reduction='none') 65 | loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen) 66 | 67 | return loss 68 | 69 | def get_weights(model, args): 70 | if 
hasattr(model, 'transformer'): 71 | input_embs = model.transformer.wte 72 | down_proj = model.down_proj 73 | model_emb = down_proj(input_embs.weight) 74 | print(model_emb.shape) 75 | model = torch.nn.Embedding(model_emb.size(0), model_emb.size(1)) 76 | print(args.emb_scale_factor) 77 | model.weight.data = model_emb * args.emb_scale_factor 78 | 79 | elif hasattr(model, 'weight'): 80 | pass 81 | else: 82 | assert NotImplementedError 83 | 84 | model.weight.requires_grad = False 85 | return model 86 | 87 | def denoised_fn_round(args, model, text_emb, t): 88 | model_emb = model.weight 89 | old_shape = text_emb.shape 90 | old_device = text_emb.device 91 | 92 | if len(text_emb.shape) > 2: 93 | text_emb = text_emb.reshape(-1, text_emb.size(-1)) 94 | else: 95 | text_emb = text_emb 96 | val, indices = get_efficient_knn(model_emb, text_emb.to(model_emb.device)) 97 | rounded_tokens = indices[0] 98 | new_embeds = model(rounded_tokens).view(old_shape).to(old_device) 99 | 100 | return new_embeds -------------------------------------------------------------------------------- /datasets/ho_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | def load_video_info(label_path, video_index): 7 | with open(os.path.join(label_path, "label_{}.pkl".format(video_index)), 'rb') as f: 8 | video_info = pickle.load(f) 9 | return video_info 10 | 11 | 12 | def sample_hand_traj(meta, fps, t_ant, shape=(456, 256)): 13 | width, height = shape 14 | traj = meta["traj"] 15 | ori_fps = int((len(traj) - 1) / t_ant) 16 | gap = int(ori_fps // fps) 17 | stop_idx = len(traj) 18 | indices = [0] + list(range(gap, stop_idx, gap)) 19 | hand_traj = [] 20 | for idx in indices: 21 | x, y = traj[idx] 22 | x, y, = x / width, y / height 23 | hand_traj.append(np.array([x, y], dtype=np.float32)) 24 | hand_traj = np.array(hand_traj, dtype=np.float32) 25 | return hand_traj, indices 26 | 27 | 28 | def sample_homo(homo_all, fps, t_ant): 29 | ori_fps = int((len(homo_all) - 1) / t_ant) 30 | gap = int(ori_fps // fps) 31 | stop_idx = len(homo_all) 32 | indices = [0] + list(range(gap, stop_idx, gap)) 33 | homo_mat = [] 34 | for idx in indices: 35 | homo_mat.append(homo_all[idx]) 36 | homo_transform = np.array(homo_mat, dtype=np.float32) 37 | return homo_transform, indices 38 | 39 | def process_video_info(video_info, fps=4, t_ant=1.0, shape=(456, 256)): 40 | frames_idxs = video_info["frame_indices"] 41 | hand_trajs = video_info["hand_trajs"] 42 | 43 | ''' 44 | hand_trajs["RIGHT"] = {"traj": right_complete_traj, "fill_indices": right_fill_indices, 45 | "fit_curve": right_curve, "centers": right_centers} 46 | ''' 47 | obj_affordance = video_info['affordance']['select_points_homo'] 48 | num_points = obj_affordance.shape[0] 49 | select_idx = np.random.choice(num_points, 1, replace=False) 50 | contact_point = obj_affordance[select_idx] 51 | cx, cy = contact_point[0] 52 | width, height = shape 53 | cx, cy = cx / width, cy/ height 54 | contact_point = np.array([cx, cy], dtype=np.float32) 55 | 56 | valid_mask = [] 57 | if "RIGHT" in hand_trajs: 58 | meta = hand_trajs["RIGHT"] 59 | rhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape) 60 | valid_mask.append(1) 61 | else: 62 | length = int(fps * t_ant + 1) 63 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), length, axis=0) 64 | valid_mask.append(0) 65 | 66 | if "LEFT" in hand_trajs: 67 | meta = hand_trajs["LEFT"] 68 | lhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape) 69 
| valid_mask.append(1) 70 | else: 71 | length = int(fps * t_ant + 1) 72 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), length, axis=0) 73 | valid_mask.append(0) 74 | 75 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0) 76 | future_valid = np.array(valid_mask, dtype=np.int) 77 | 78 | last_frame_index = frames_idxs[0] 79 | 80 | return future_hands, contact_point, future_valid, last_frame_index 81 | 82 | 83 | def process_eval_video_info(video_info, fps=4, t_ant=1.0): 84 | valid_mask = [] 85 | if "RIGHT" in video_info: 86 | rhand_traj = video_info["RIGHT"] 87 | assert rhand_traj.shape[0] == int(fps * t_ant + 1) 88 | valid_mask.append(1) 89 | else: 90 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0) 91 | valid_mask.append(0) 92 | 93 | if "LEFT" in video_info: 94 | lhand_traj = video_info['LEFT'] 95 | assert lhand_traj.shape[0] == int(fps * t_ant + 1) 96 | valid_mask.append(1) 97 | else: 98 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0) 99 | valid_mask.append(0) 100 | 101 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0) 102 | future_valid = np.array(valid_mask, dtype=np.int) 103 | return future_hands, future_valid 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /diffip2d/utils/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | 34 | return 0.5 * ( 35 | -1.0 36 | + logvar2 37 | - logvar1 38 | + th.exp(logvar1 - logvar2) 39 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 40 | ) 41 | 42 | 43 | def approx_standard_normal_cdf(x): 44 | """ 45 | A fast approximation of the cumulative distribution function of the 46 | standard normal. 47 | """ 48 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 49 | 50 | 51 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 52 | """ 53 | Compute the log-likelihood of a Gaussian distribution discretizing to a 54 | given image. 55 | 56 | :param x: the target images. It is assumed that this was uint8 values, 57 | rescaled to the range [-1, 1]. 58 | :param means: the Gaussian mean Tensor. 59 | :param log_scales: the Gaussian log stddev Tensor. 60 | :return: a tensor like x of log probabilities (in nats). 
61 | """ 62 | assert x.shape == means.shape == log_scales.shape 63 | centered_x = x - means 64 | inv_stdv = th.exp(-log_scales) 65 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 66 | cdf_plus = approx_standard_normal_cdf(plus_in) 67 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 68 | cdf_min = approx_standard_normal_cdf(min_in) 69 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 70 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 71 | cdf_delta = cdf_plus - cdf_min 72 | log_probs = th.where( 73 | x < -0.999, 74 | log_cdf_plus, 75 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 76 | ) 77 | assert log_probs.shape == x.shape 78 | return log_probs 79 | 80 | def gaussian_density(x, *, means, log_scales): 81 | from torch.distributions import Normal 82 | normal_dist = Normal(means, log_scales.exp()) 83 | logp = normal_dist.log_prob(x) 84 | return logp 85 | 86 | 87 | def discretized_text_log_likelihood(x, *, means, log_scales): 88 | """ 89 | Compute the log-likelihood of a Gaussian distribution discretizing to a 90 | given image. 91 | 92 | :param x: the target images. It is assumed that this was uint8 values, 93 | rescaled to the range [-1, 1]. 94 | :param means: the Gaussian mean Tensor. 95 | :param log_scales: the Gaussian log stddev Tensor. 96 | :return: a tensor like x of log probabilities (in nats). 97 | """ 98 | 99 | centered_x = x - means 100 | inv_stdv = th.exp(-log_scales) 101 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 102 | cdf_plus = approx_standard_normal_cdf(plus_in) 103 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 104 | cdf_min = approx_standard_normal_cdf(min_in) 105 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 106 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 107 | cdf_delta = cdf_plus - cdf_min 108 | log_probs = th.where( 109 | x < -0.999, 110 | log_cdf_plus, 111 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 112 | ) 113 | assert log_probs.shape == x.shape 114 | return log_probs 115 | -------------------------------------------------------------------------------- /netscripts/get_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Warmup(torch.optim.lr_scheduler._LRScheduler): 4 | def __init__( 5 | self, 6 | optimizer: torch.optim.Optimizer, 7 | scheduler: torch.optim.lr_scheduler._LRScheduler, 8 | init_lr_ratio: float = 0.0, 9 | num_epochs: int = 5, 10 | last_epoch: int = -1, 11 | iters_per_epoch: int = None, 12 | ): 13 | self.base_scheduler = scheduler 14 | self.warmup_iters = max(num_epochs * iters_per_epoch, 1) 15 | if self.warmup_iters > 1: 16 | self.init_lr_ratio = init_lr_ratio 17 | else: 18 | self.init_lr_ratio = 1.0 19 | super().__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | assert self.last_epoch < self.warmup_iters 23 | return [ 24 | el * (self.init_lr_ratio + (1 - self.init_lr_ratio) * 25 | (float(self.last_epoch) / self.warmup_iters)) 26 | for el in self.base_lrs 27 | ] 28 | 29 | def step(self, *args, **kwargs): 30 | if self.last_epoch < (self.warmup_iters - 1): 31 | super().step(*args, **kwargs) 32 | else: 33 | self.base_scheduler.step(*args, **kwargs) 34 | 35 | 36 | def get_optimizer(args, sf_encoder, model_denoise, traj_decoder, train_loader, model_hoi=None, motion_encoder=None, obj_head=None): 37 | assert train_loader is not None, "train_loader is None, " \ 38 | "warmup or cosine learning rate need number of iterations in dataloader" 39 | 
iters_per_epoch = len(train_loader) 40 | sf_encoder_params = [p for p_name, p in sf_encoder.named_parameters() if p.requires_grad] 41 | model_denoise_params = [p for p_name, p in model_denoise.named_parameters() if p.requires_grad] 42 | traj_decoder_params = [p for p_name, p in traj_decoder.named_parameters() if p.requires_grad] 43 | model_hoi_params = [p for p_name, p in model_hoi.named_parameters() if p.requires_grad] 44 | motion_encoder_params = [p for p_name, p in motion_encoder.named_parameters() if p.requires_grad] 45 | obj_head_params = [p for p_name, p in obj_head.named_parameters() if p.requires_grad] 46 | 47 | if args.optimizer == "adam": 48 | optimizer = torch.optim.Adam([{'params': model_denoise_params, 'weight_decay': 0.0}, {'params': sf_encoder_params}, {'params': traj_decoder_params}, {'params': model_hoi_params}], 49 | lr=args.lr, weight_decay=args.weight_decay) 50 | elif args.optimizer == "rms": 51 | optimizer = torch.optim.RMSprop([{'params': model_denoise_params, 'weight_decay': 0.0}, {'params': sf_encoder_params}, {'params': traj_decoder_params}, {'params': model_hoi_params}], 52 | lr=args.lr, weight_decay=args.weight_decay) 53 | elif args.optimizer == "sgd": 54 | optimizer = torch.optim.SGD([{'params': model_denoise_params, 'weight_decay': 0.0}, {'params': sf_encoder_params}, {'params': traj_decoder_params}, {'params': model_hoi_params}], 55 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 56 | elif args.optimizer == 'adamw': 57 | optimizer = torch.optim.AdamW([{'params': model_denoise_params, 'weight_decay': 0.0}, {'params': sf_encoder_params}, 58 | {'params': traj_decoder_params}, {'params': model_hoi_params}, {'params': motion_encoder_params}, {'params': obj_head_params}], 59 | lr=args.lr) 60 | else: 61 | raise ValueError("unsupported optimizer type") 62 | 63 | for group in optimizer.param_groups: 64 | group["lr"] = args.lr 65 | group["initial_lr"] = args.lr 66 | 67 | if args.scheduler == "step": 68 | assert isinstance(args.lr_decay_step, int), "step learning rate scheduler needs an integer lr_decay_step" 69 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma) 70 | elif args.scheduler == "multistep": 71 | if isinstance(args.lr_decay_step, list): 72 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma) 73 | else: 74 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epochs // 2], gamma=0.1) 75 | elif args.scheduler == "cosine": 76 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs*iters_per_epoch, 77 | last_epoch=-1, eta_min=0) 78 | else: 79 | raise ValueError("Unrecognized learning rate scheduler {}".format(args.scheduler)) 80 | 81 | main_scheduler = Warmup(optimizer, scheduler, init_lr_ratio=0., num_epochs=args.warmup_epochs, 82 | iters_per_epoch=iters_per_epoch) 83 | return optimizer, main_scheduler 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /evaluation/traj_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def compute_ade(pred_traj, gt_traj, valid_traj=None, reduction=True): 6 | valid_loc = (gt_traj[:, :, :, 0] >= 0) & (gt_traj[:, :, :, 1] >= 0) \ 7 | & (gt_traj[:, :, :, 0] < 1) & (gt_traj[:, :, :, 1] < 1) 8 | 9 | error = gt_traj - pred_traj 10 | error = error * valid_loc[:, :, :, None] 11 | 12 | if torch.is_tensor(error): 13 | if 
valid_traj is None: 14 | valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1]) 15 | error = error ** 2 16 | ade = torch.sqrt(error.sum(dim=3)).mean(dim=2) * valid_traj 17 | if reduction: 18 | ade = ade.sum() / valid_traj.sum() 19 | valid_traj = valid_traj.sum() 20 | else: 21 | if valid_traj is None: 22 | valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int) 23 | error = np.linalg.norm(error, axis=3) 24 | ade = error.mean(axis=2) * valid_traj 25 | if reduction: 26 | ade = ade.sum() / valid_traj.sum() 27 | valid_traj = valid_traj.sum() 28 | 29 | return ade, valid_traj 30 | 31 | 32 | def compute_fde(pred_traj, gt_traj, valid_traj=None, reduction=True): 33 | pred_last = pred_traj[:, :, -1, :] 34 | gt_last = gt_traj[:, :, -1, :] 35 | 36 | valid_loc = (gt_last[:, :, 0] >= 0) & (gt_last[:, :, 1] >= 0) \ 37 | & (gt_last[:, :, 0] < 1) & (gt_last[:, :, 1] < 1) 38 | 39 | error = gt_last - pred_last 40 | error = error * valid_loc[:, :, None] 41 | 42 | if torch.is_tensor(error): 43 | if valid_traj is None: 44 | valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1]) 45 | error = error ** 2 46 | fde = torch.sqrt(error.sum(dim=2)) * valid_traj 47 | if reduction: 48 | fde = fde.sum() / valid_traj.sum() 49 | valid_traj = valid_traj.sum() 50 | else: 51 | if valid_traj is None: 52 | valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int) 53 | error = np.linalg.norm(error, axis=2) 54 | fde = error * valid_traj 55 | if reduction: 56 | fde = fde.sum() / valid_traj.sum() 57 | valid_traj = valid_traj.sum() 58 | 59 | return fde, valid_traj 60 | 61 | 62 | def evaluate_traj_stochastic(preds, gts, valids): 63 | # preds[B*bs,20,2,n,2] gts[B*bs,2,n,2] 64 | len_dataset, num_samples, num_obj = preds.shape[0], preds.shape[1], preds.shape[2] 65 | ade_list, fde_list = [], [] 66 | for idx in range(num_samples): 67 | ade, _ = compute_ade(preds[:, idx, :, :, :], gts, valids, reduction=False) 68 | ade_list.append(ade) 69 | fde, _ = compute_fde(preds[:, idx, :, :, :], gts, valids, reduction=False) 70 | fde_list.append(fde) 71 | 72 | if torch.is_tensor(preds): 73 | ade_list = torch.stack(ade_list, dim=0) 74 | fde_list = torch.stack(fde_list, dim=0) 75 | 76 | ade_err_min, _ = torch.min(ade_list, dim=0) 77 | ade_err_min = ade_err_min * valids 78 | fde_err_min, _ = torch.min(fde_list, dim=0) 79 | fde_err_min = fde_err_min * valids 80 | 81 | ade_err_mean = torch.mean(ade_list, dim=0) 82 | ade_err_mean = ade_err_mean * valids 83 | fde_err_mean = torch.mean(fde_list, dim=0) 84 | fde_err_mean = fde_err_mean * valids 85 | 86 | ade_err_std = torch.std(ade_list, dim=0) * np.sqrt((ade_list.shape[0] - 1.) / ade_list.shape[0]) 87 | ade_err_std = ade_err_std * valids 88 | fde_err_std = torch.std(fde_list, dim=0) * np.sqrt((fde_list.shape[0] - 1.) 
/ fde_list.shape[0]) 89 | fde_err_std = fde_err_std * valids 90 | 91 | else: 92 | ade_list = np.array(ade_list, dtype=np.float32) 93 | fde_list = np.array(fde_list, dtype=np.float32) 94 | 95 | ade_err_min = ade_list.min(axis=0) * valids 96 | fde_err_min = fde_list.min(axis=0) * valids 97 | 98 | ade_err_mean = ade_list.mean(axis=0) * valids 99 | fde_err_mean = fde_list.mean(axis=0) * valids 100 | 101 | ade_err_std = ade_list.std(axis=0) * valids 102 | fde_err_std = fde_list.std(axis=0) * valids 103 | 104 | ade_mean = ade_err_mean.sum() / valids.sum() 105 | fde_mean = fde_err_mean.sum() / valids.sum() 106 | 107 | ade_std = ade_err_std.sum() / valids.sum() 108 | fde_std = fde_err_std.sum() / valids.sum() 109 | ade_mean_info = 'ADE: %.3f ± %.3f (%d/%d)' % (ade_mean, ade_std, valids.sum(), len_dataset * num_obj) 110 | fde_mean_info = "FDE: %.3f ± %.3f (%d/%d)" % (fde_mean, fde_std, valids.sum(), len_dataset * num_obj) 111 | 112 | ade_min = ade_err_min.sum() / valids.sum() 113 | fde_min = fde_err_min.sum() / valids.sum() 114 | ade_min_info = 'min ADE: %.3f (%d/%d)' % (ade_min, valids.sum(), len_dataset * num_obj) 115 | fde_min_info = "min FDE: %.3f (%d/%d)" % (fde_min, valids.sum(), len_dataset * num_obj) 116 | 117 | print(ade_min_info) 118 | print(fde_min_info) 119 | print(ade_mean_info) 120 | print(fde_mean_info) 121 | 122 | return ade_mean, fde_mean 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /options/netsopts.py: -------------------------------------------------------------------------------- 1 | def add_nets_opts(parser): 2 | parser.add_argument('--src_in_features', type=int, default=1024, help='Network encoder input size') 3 | parser.add_argument('--trg_in_features', type=int, default=2, help='Network decoder input size') 4 | parser.add_argument('--num_patches', type=int, default=5, help='Number of classes') 5 | parser.add_argument('--num_classes', type=int, default=2513, help='Number of classes') 6 | 7 | parser.add_argument('--embed_dim', type=int, default=512, help='embedded dimension') 8 | parser.add_argument('--num_heads', type=int, default=8, help='num of heads in transformer') 9 | parser.add_argument('--enc_depth', type=int, default=6, help='transformer encoder depth') 10 | parser.add_argument('--dec_depth', type=int, default=4, help='transformer decoder depth') 11 | 12 | parser.add_argument('--coord_dim', type=int, default=64, help='coordinates feature dimension') 13 | parser.add_argument('--hidden_dim', type=int, default=512, help='stochastic modules hidden dimension') 14 | parser.add_argument('--latent_dim', type=int, default=256, help='stochastic modules latent dimension') 15 | 16 | parser.add_argument("--encoder_time_embed_type", default="sin", 17 | choices=["sin", "param"], help="transformer encoder time position embedding") 18 | parser.add_argument("--decoder_time_embed_type", default="sin", 19 | choices=["sin", "param"], help="transformer decoder time position embedding") 20 | 21 | parser.add_argument("--num_samples", default=20, type=int, help="get number of samples during inference, " 22 | "stochastic model multiple runs") 23 | parser.add_argument("--num_points", default=5, type=int, 24 | help="number of remaining contact points after farthest point " 25 | "sampling for evaluation affordance") 26 | parser.add_argument("--gaussian_sigma", default=3., type=float, 27 | help="predicted contact points gaussian kernel sigma") 28 | parser.add_argument("--gaussian_k_ratio", default=3., type=float, 29 | 
help="predicted contact points gaussian kernel size") 30 | parser.add_argument("--dropout", default=0.1, type=float, help="dropout rate") 31 | parser.add_argument("--diffusion_steps", default=1000, type=int, help="diffusion steps") 32 | parser.add_argument("--noise_schedule", default="sqrt", type=str, help="noise schedule for diffusion") 33 | parser.add_argument("--learn_sigma", default=False, type=bool, help="whether to learn sigma") 34 | parser.add_argument("--timestep_respacing", default="", type=str, help="timestep respacing") 35 | parser.add_argument("--rescale_timesteps", default=True, type=bool, help="whether to rescale timesteps") 36 | parser.add_argument("--predict_xstart", default=True, type=bool, help="whether to predict start x") 37 | parser.add_argument("--sigma_small", default=False, type=bool, help='small sigma') 38 | parser.add_argument("--rescale_learned_sigmas", default=False, type=bool, help="whether to rescale sigmas") 39 | parser.add_argument("--use_kl", default=False, type=bool, help="whether to cal KLD") 40 | parser.add_argument("--sf_encoder_hidden", default=64, type=int, help="hidden layer of sidefusion encoder") 41 | parser.add_argument("--motion_encoder_hidden", default=64, type=int, help="hidden layer of motion encoder") 42 | parser.add_argument("--traj_decoder_hidden1", default=256, type=int, help="hidden layer 1 of traj decoder") 43 | parser.add_argument("--traj_decoder_hidden2", default=64, type=int, help="hidden layer 2 of traj decoder") 44 | parser.add_argument("--madt_depth", default=6, type=int, help="number of transformer layers in madt") 45 | parser.add_argument("--holi_past", default=False, type=bool, help="whether to use holistic past seq as condition") 46 | parser.add_argument("--test_start_idx", default=0, type=int, help="start index for test seq") 47 | 48 | 49 | parser.add_argument("--lambda_obj", default=1e-1, type=float, help="Weight to supervise object affordance") 50 | parser.add_argument("--lambda_traj", default=1., type=float, help="Weight to supervise hand traj") 51 | parser.add_argument("--lambda_obj_kl", default=1e-3, type=float, help="Weight to supervise object affordance KLD") 52 | parser.add_argument("--lambda_traj_kl", default=1e-3, type=float, help="Weight to supervise hand traj KLD") 53 | 54 | 55 | def add_train_opts(parser): 56 | parser.add_argument("--manual_seed", default=1, type=int, help="manual seed") 57 | parser.add_argument("-j", "--workers", default=16, type=int, help="number of workers") 58 | parser.add_argument("--epochs", default=35, type=int, help="number epochs") 59 | parser.add_argument("--batch_size", default=16, type=int, help="batch size") 60 | 61 | parser.add_argument("--optimizer", default="adamw", choices=["rms", "adam", "sgd", "adamw"]) 62 | parser.add_argument("--lr", "--learning-rate", default=2e-4, type=float, metavar="LR", help="initial learning rate") 63 | parser.add_argument("--momentum", default=0.9, type=float) 64 | 65 | parser.add_argument("--scheduler", default="cosine", choices=['cosine', 'step', 'multistep'], 66 | help="learning rate scheduler") 67 | parser.add_argument("--warmup_epochs", default=0, type=int, help="number of warmup epochs to run") 68 | parser.add_argument("--lr_decay_step", nargs="+", default=10, type=int, 69 | help="Epochs after which to decay learning rate") 70 | parser.add_argument( 71 | "--lr_decay_gamma", default=0.5, type=float, help="Factor by which to decay the learning rate") 72 | parser.add_argument("--weight_decay", default=1e-4, type=float) 73 | 
-------------------------------------------------------------------------------- /datasets/input_loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torchvision import transforms 4 | import torch 5 | import os 6 | import lmdb 7 | 8 | 9 | class ActionAnticipationSampler(object): 10 | def __init__(self, t_buffer, t_ant=1.0, fps=4.0, ori_fps=60.0): 11 | self.t_buffer = t_buffer 12 | self.t_ant = t_ant 13 | self.fps = fps 14 | self.ori_fps = ori_fps 15 | 16 | def __call__(self, action): 17 | times, frames_idxs = sample_history_frames(action.start_frame, self.t_buffer, 18 | self.t_ant, fps=self.fps, 19 | fps_init=self.ori_fps) 20 | return times, frames_idxs 21 | 22 | def get_sampler(args): 23 | sampler = ActionAnticipationSampler(t_buffer=args.t_buffer, t_ant=args.t_ant, 24 | fps=args.fps, ori_fps=args.ori_fps) 25 | return sampler 26 | 27 | 28 | def sample_history_frames(frame_start, t_buffer=2.5, t_ant=1.0, fps=4.0, fps_init=60.0): 29 | time_start = (frame_start - 1) / fps_init 30 | num_frames = int(np.floor(t_buffer * fps)) 31 | time_ant = time_start - t_ant 32 | times = (np.arange(1, num_frames + 1) - num_frames) / fps + time_ant 33 | times = np.clip(times, 0, np.inf) 34 | times = times.astype(np.float32) 35 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 36 | times = (frames_idxs - 1) / fps_init 37 | return times, frames_idxs 38 | 39 | 40 | def sample_future_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0): 41 | time_start = (frame_start - 1) / fps_init 42 | num_frames = int(np.floor(t_buffer * fps)) 43 | times = (np.arange(num_frames + 1) - num_frames) / fps + time_start 44 | times = np.clip(times, 0, np.inf) 45 | times = times.astype(np.float32) 46 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 47 | if frames_idxs.max() >= 1: 48 | frames_idxs[frames_idxs < 1] = frames_idxs[frames_idxs >= 1].min() 49 | return list(frames_idxs) 50 | 51 | 52 | class FeaturesLoader(object): 53 | def __init__(self, sampler, feature_base_path, fps, input_name='rgb', 54 | frame_tmpl='frame_{:010d}.jpg', transform_feat=None, 55 | transform_video=None): 56 | self.feature_base_path = feature_base_path 57 | self.env = lmdb.open(os.path.join(self.feature_base_path, input_name), readonly=True, lock=False) 58 | self.fps = fps 59 | self.input_name = input_name 60 | self.frame_tmpl = frame_tmpl 61 | self.transform_feat = transform_feat 62 | self.transform_video = transform_video 63 | self.sampler = sampler 64 | 65 | def __call__(self, action): 66 | times, frames_idxs = self.sampler(action) 67 | frames_names = [self.frame_tmpl.format(action.video_id, i) for i in frames_idxs] 68 | feats = [] 69 | with self.env.begin() as env: 70 | for f_name in frames_names: 71 | feat = env.get(f_name.strip().encode('utf-8')) 72 | if feat is None: 73 | print(f_name) 74 | feat = np.frombuffer(feat, 'float32') 75 | 76 | if self.transform_feat is not None: 77 | feat = self.transform_feat(feat) 78 | feats += [feat] 79 | 80 | if self.transform_video is not None: 81 | feats = self.transform_video(feats) 82 | out = {self.input_name: feats} 83 | out['times'] = times 84 | out['start_time'] = action.start_time 85 | out['frames_idxs'] = frames_idxs 86 | return out 87 | 88 | 89 | class PipeLoaders(object): 90 | def __init__(self, loader_list): 91 | self.loader_list = loader_list 92 | 93 | def __call__(self, action): 94 | out = {} 95 | for loader in self.loader_list: 96 | out.update(loader(action)) 97 | return out 
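# --- Editor's sketch (illustrative, not part of the original file) -------------
# sample_history_frames() maps an action's start frame to the indices of the
# observed (past) frames: t_buffer*fps frames spaced 1/fps seconds apart in the
# fps_init-rate video, ending t_ant seconds before the action starts. The numbers
# below follow directly from the arithmetic above, e.g. for an action starting at
# frame 601 of a 60-fps video:
def _history_frames_example():
    times, frames_idxs = sample_history_frames(601, t_buffer=2.5, t_ant=1.0,
                                               fps=4.0, fps_init=60.0)
    # frames_idxs == [406 421 436 451 466 481 496 511 526 541]: ten frames,
    # 15 source frames (0.25 s) apart, the last one 60 frames (1 s) before the
    # action's start frame 601. FeaturesLoader then looks these indices up in
    # the LMDB store via frame_tmpl.format(...).
    return times, frames_idxs
# --------------------------------------------------------------------------------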
98 | 99 | 100 | def get_features_loader(args, featuresloader=None): 101 | sampler = get_sampler(args) 102 | feat_in_modalities = list({'feat'}.intersection(args.modalities)) 103 | transform_feat = lambda x: torch.tensor(x.copy()) 104 | transform_video = lambda x: torch.stack(x, 0) 105 | loader_args = { 106 | 'feature_base_path': args.features_paths[args.ek_version], 107 | 'fps': args.fps, 108 | 'frame_tmpl': 'frame_{:010d}.jpg', 109 | 'transform_feat': transform_feat, 110 | 'transform_video': transform_video, 111 | 'sampler': sampler, 112 | 'mode': args.mode} 113 | if featuresloader is None: 114 | featuresloader = FeaturesLoader 115 | feat_loader_list = [] 116 | for modality in feat_in_modalities: 117 | feat_loader = featuresloader(input_name=modality, **loader_args) 118 | feat_loader_list += [feat_loader] 119 | feat_loaders = { 120 | 'train': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 121 | 'validation': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 122 | 'test': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 123 | } 124 | return feat_loaders 125 | 126 | 127 | def get_loaders(args, featuresloader=None): 128 | loaders = { 129 | 'train': [], 130 | 'validation': [], 131 | 'test': [], 132 | } 133 | 134 | if 'feat' in args.modalities: 135 | feat_loaders = get_features_loader(args, featuresloader=featuresloader) 136 | for k, l in feat_loaders.items(): 137 | if l is not None: 138 | loaders[k] += [l] 139 | 140 | for k, l in loaders.items(): 141 | loaders[k] = PipeLoaders(l) 142 | return loaders 143 | -------------------------------------------------------------------------------- /networks/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.net_utils import DropPath, get_pad_mask 5 | from einops import rearrange 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | 27 | class ScaledDotProductAttention(nn.Module): 28 | def __init__(self, temperature, attn_dropout=0.1): 29 | super().__init__() 30 | self.temperature = temperature 31 | self.dropout = nn.Dropout(attn_dropout) 32 | 33 | def forward(self, q, k, v, mask=None): 34 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 35 | if mask is not None: 36 | attn = attn.masked_fill(mask == 0, -1e9) 37 | attn = self.dropout(F.softmax(attn, dim=-1)) 38 | output = torch.matmul(attn, v) 39 | return output, attn 40 | 41 | 42 | class MultiHeadAttention(nn.Module): 43 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True): 44 | super().__init__() 45 | self.num_heads = num_heads 46 | head_dim = dim // num_heads 47 | self.with_qkv = with_qkv 48 | if self.with_qkv: 49 | self.proj_q = nn.Linear(dim, dim, bias=qkv_bias) 50 | self.proj_k = nn.Linear(dim, dim, bias=qkv_bias) 51 | self.proj_v = nn.Linear(dim, dim, bias=qkv_bias) 52 | 53 | self.proj = nn.Linear(dim, 
dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attention = ScaledDotProductAttention(temperature=qk_scale or head_dim ** 0.5) 56 | self.attn_drop = nn.Dropout(attn_drop) 57 | 58 | def forward(self, q, k, v, mask=None): 59 | B, Nq, Nk, Nv, C = q.shape[0], q.shape[1], k.shape[1], v.shape[1], q.shape[2] 60 | if self.with_qkv: 61 | q = self.proj_q(q).reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2) 62 | k = self.proj_k(k).reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2) 63 | v = self.proj_v(v).reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2) 64 | else: 65 | q = q.reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2) 66 | k = k.reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2) 67 | v = v.reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2) 68 | if mask is not None: 69 | mask = mask.unsqueeze(1) 70 | 71 | x, attn = self.attention(q, k, v, mask=mask) 72 | x = x.transpose(1, 2).reshape(B, Nq, C) 73 | if self.with_qkv: 74 | x = self.proj(x) 75 | x = self.proj_drop(x) 76 | return x 77 | 78 | 79 | class EncoderBlock(nn.Module): 80 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 81 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm): 82 | super().__init__() 83 | self.norm1 = norm_layer(dim) 84 | self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 85 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 86 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 87 | self.norm2 = norm_layer(dim) 88 | mlp_hidden_dim = int(dim * mlp_ratio) 89 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 90 | 91 | def forward(self, x, B, T, N, mask=None): 92 | 93 | if mask is not None: 94 | src_mask = rearrange(mask, 'b n t -> b (n t)', b=B, n=N, t=T) 95 | src_mask = get_pad_mask(src_mask, 0) 96 | else: 97 | src_mask = None 98 | x2 = self.norm1(x) 99 | x = x + self.drop_path(self.attn(q=x2, k=x2, v=x2, mask=src_mask)) 100 | x = x + self.drop_path(self.mlp(self.norm2(x))) 101 | return x 102 | 103 | 104 | class DecoderBlock(nn.Module): 105 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 106 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm): 107 | super().__init__() 108 | self.norm1 = norm_layer(dim) 109 | self.self_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 110 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 111 | 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 113 | 114 | self.norm2 = norm_layer(dim) 115 | self.enc_dec_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 116 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 117 | 118 | self.norm3 = nn.LayerNorm(dim) 119 | mlp_hidden_dim = int(dim * mlp_ratio) 120 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 121 | 122 | def forward(self, tgt, memory, memory_mask=None, trg_mask=None): 123 | tgt_2 = self.norm1(tgt) 124 | tgt = tgt + self.drop_path(self.self_attn(q=tgt_2, k=tgt_2, v=tgt_2, mask=trg_mask)) 125 | tgt = tgt + self.drop_path(self.enc_dec_attn(q=self.norm2(tgt), k=memory, v=memory, mask=memory_mask)) 126 | tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) 127 | return tgt 128 | -------------------------------------------------------------------------------- /diffip2d/step_sample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "lossaware": 18 | return LossSecondMomentResampler(diffusion) 19 | elif name == "fixstep": 20 | return FixSampler(diffusion) 21 | else: 22 | raise NotImplementedError(f"unknown schedule sampler: {name}") 23 | 24 | 25 | class ScheduleSampler(ABC): 26 | """ 27 | A distribution over timesteps in the diffusion process, intended to reduce 28 | variance of the objective. 29 | 30 | By default, samplers perform unbiased importance sampling, in which the 31 | objective's mean is unchanged. 32 | However, subclasses may override sample() to change how the resampled 33 | terms are reweighted, allowing for actual changes in the objective. 34 | """ 35 | 36 | @abstractmethod 37 | def weights(self): 38 | """ 39 | Get a numpy array of weights, one per diffusion step. 40 | 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | 48 | :param batch_size: the number of timesteps. 49 | :param device: the torch device to save to. 50 | :return: a tuple (timesteps, weights): 51 | - timesteps: a tensor of timestep indices. 52 | - weights: a tensor of weights to scale the resulting losses. 53 | """ 54 | w = self.weights() 55 | p = w / np.sum(w) 56 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 57 | indices = th.from_numpy(indices_np).long().to(device) 58 | weights_np = 1 / (len(p) * p[indices_np]) 59 | weights = th.from_numpy(weights_np).float().to(device) 60 | return indices, weights 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | class FixSampler(ScheduleSampler): 71 | def __init__(self, diffusion): 72 | self.diffusion = diffusion 73 | 74 | ############################################################### 75 | ### You can custome your own sampling weight of steps here. 
### 76 | ############################################################### 77 | self._weights = np.concatenate([np.ones([diffusion.num_timesteps//2]), np.zeros([diffusion.num_timesteps//2]) + 0.5]) 78 | 79 | def weights(self): 80 | return self._weights 81 | 82 | 83 | class LossAwareSampler(ScheduleSampler): 84 | def update_with_local_losses(self, local_ts, local_losses): 85 | """ 86 | Update the reweighting using losses from a model. 87 | 88 | Call this method from each rank with a batch of timesteps and the 89 | corresponding losses for each of those timesteps. 90 | This method will perform synchronization to make sure all of the ranks 91 | maintain the exact same reweighting. 92 | 93 | :param local_ts: an integer Tensor of timesteps. 94 | :param local_losses: a 1D Tensor of losses. 95 | """ 96 | batch_sizes = [ 97 | th.tensor([0], dtype=th.int32, device=local_ts.device) 98 | for _ in range(dist.get_world_size()) 99 | ] 100 | dist.all_gather( 101 | batch_sizes, 102 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 103 | ) 104 | 105 | batch_sizes = [x.item() for x in batch_sizes] 106 | max_bs = max(batch_sizes) 107 | 108 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 109 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 110 | dist.all_gather(timestep_batches, local_ts) 111 | dist.all_gather(loss_batches, local_losses) 112 | timesteps = [ 113 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 114 | ] 115 | 116 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 117 | self.update_with_all_losses(timesteps, losses) 118 | 119 | @abstractmethod 120 | def update_with_all_losses(self, ts, losses): 121 | """ 122 | Update the reweighting using losses from a model. 123 | 124 | Sub-classes should override this method to update the reweighting 125 | using losses from the model. 126 | 127 | This method directly updates the reweighting without synchronizing 128 | between workers. It is called by update_with_local_losses from all 129 | ranks with identical arguments. Thus, it should have deterministic 130 | behavior to maintain state across workers. 131 | 132 | :param ts: a list of int timesteps. 133 | :param losses: a list of float losses, one per timestep. 
134 | """ 135 | 136 | 137 | class LossSecondMomentResampler(LossAwareSampler): 138 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 139 | self.diffusion = diffusion 140 | self.history_per_term = history_per_term 141 | self.uniform_prob = uniform_prob 142 | self._loss_history = np.zeros( 143 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 144 | ) 145 | self._loss_counts = np.zeros([diffusion.num_timesteps]).astype(int) 146 | 147 | def weights(self): 148 | if not self._warmed_up(): 149 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 150 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 151 | weights /= np.sum(weights) 152 | weights *= 1 - self.uniform_prob 153 | weights += self.uniform_prob / len(weights) 154 | return weights 155 | 156 | def update_with_all_losses(self, ts, losses): 157 | for t, loss in zip(ts, losses): 158 | if self._loss_counts[t] == self.history_per_term: 159 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 160 | self._loss_history[t, -1] = loss 161 | else: 162 | self._loss_history[t, self._loss_counts[t]] = loss 163 | self._loss_counts[t] += 1 164 | 165 | def _warmed_up(self): 166 | return (self._loss_counts == self.history_per_term).all() 167 | -------------------------------------------------------------------------------- /traineval.py: -------------------------------------------------------------------------------- 1 | # Developed by Junyi Ma 2 | # Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos 3 | # https://github.com/IRMVLab/Diff-IP2D 4 | # We thank OCT (Liu et al.), Diffuseq (Gong et al.), and USST (Bao et al.) for providing the codebases. 5 | 6 | 7 | import argparse 8 | import os 9 | import datetime 10 | import random 11 | import numpy as np 12 | import torch 13 | import torch.nn.parallel 14 | import torch.optim 15 | from netscripts.get_datasets import get_dataset 16 | from netscripts.get_network import get_network_for_diffip 17 | from netscripts.get_optimizer import get_optimizer 18 | from netscripts import modelio 19 | from options import netsopts, expopts 20 | from datasets.datasetopts import DatasetArgs 21 | from diffip2d.step_sample import create_named_schedule_sampler 22 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 23 | from netscripts.epoch_feat import TrainValLoop 24 | from basic_utils import create_network_and_diffusion 25 | import logging.config 26 | import logging 27 | logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s') 28 | from diffip2d.utils import dist_util, logger 29 | 30 | 31 | def main(args): 32 | 33 | # Initialization 34 | torch.cuda.manual_seed_all(args.manual_seed) 35 | torch.manual_seed(args.manual_seed) 36 | np.random.seed(args.manual_seed) 37 | random.seed(args.manual_seed) 38 | dist_util.setup_dist() 39 | 40 | datasetargs = DatasetArgs(ek_version=args.ek_version) 41 | num_frames_input = int(datasetargs.fps * datasetargs.t_buffer) 42 | num_frames_output = int(datasetargs.fps * datasetargs.t_ant) 43 | start_epoch = 0 44 | 45 | # building architecture 46 | model_hoi, obj_head = get_network_for_diffip(args, num_frames_input=num_frames_input, 47 | num_frames_output=num_frames_output) 48 | 49 | model_diff_args = { 50 | "hidden_t_dim": args.hidden_dim, 51 | "hidden_dim": args.hidden_dim, 52 | "vocab_size": None, # deprecated in non-nlp task 53 | "config_name": "huggingface-config", # deprecated in non-nlp task 54 | "use_plm_init": "no", 55 | "dropout": args.dropout, 56 | 
"diffusion_steps": args.diffusion_steps, 57 | "noise_schedule": args.noise_schedule, 58 | "learn_sigma": args.learn_sigma, 59 | "timestep_respacing": args.timestep_respacing, 60 | "predict_xstart": args.predict_xstart, 61 | "rescale_timesteps": args.rescale_timesteps, 62 | "sigma_small": args.sigma_small, 63 | "rescale_learned_sigmas": args.rescale_learned_sigmas, 64 | "use_kl": args.use_kl, 65 | "sf_encoder_hidden": args.sf_encoder_hidden, 66 | "traj_decoder_hidden1": args.traj_decoder_hidden1, 67 | "traj_decoder_hidden2": args.traj_decoder_hidden2, 68 | "motion_encoder_hidden": args.motion_encoder_hidden, 69 | "madt_depth": args.madt_depth, 70 | } 71 | if int(os.environ['LOCAL_RANK']) == 0: 72 | logging.info("diffusion setups\n================= \n%s \n=================", model_diff_args) 73 | sf_encoder, model_denoise, diffusion, traj_decoder, motion_encoder = create_network_and_diffusion(**model_diff_args) 74 | if int(os.environ['LOCAL_RANK']) == 0: 75 | logging.info("finish building diffusion model!") 76 | 77 | schedule_sampler_args = args.schedule_sampler_args 78 | schedule_sampler = create_named_schedule_sampler(schedule_sampler_args, diffusion) 79 | if int(os.environ['LOCAL_RANK']) == 0: 80 | logging.info("finish building schedule sampler!") 81 | 82 | _, dls = get_dataset(args, base_path="./") 83 | 84 | if args.evaluate: 85 | args.epochs = start_epoch + 1 86 | traj_val_loader = None 87 | optimizer=None 88 | scheduler=None 89 | else: 90 | train_loader = dls['train'] 91 | traj_val_loader = dls['validation'] 92 | print("training dataset size: {}".format(len(train_loader.dataset))) 93 | optimizer, scheduler = get_optimizer(args, sf_encoder=sf_encoder, model_denoise=model_denoise,traj_decoder=traj_decoder, 94 | train_loader=train_loader,model_hoi=model_hoi, motion_encoder=motion_encoder, obj_head=obj_head) 95 | 96 | # We follow data structure of OCT to train and test our models 97 | if not args.traj_only: 98 | val_loader = dls['eval'] 99 | else: 100 | traj_val_loader = val_loader = dls['validation'] 101 | print("evaluation dataset size: {}".format(len(val_loader.dataset))) 102 | 103 | if args.evaluate and args.traj_only: 104 | loader = traj_val_loader 105 | elif args.evaluate and (not args.traj_only): 106 | loader = val_loader 107 | else: 108 | loader = train_loader 109 | 110 | TrainValLoop( 111 | epochs = args.epochs, 112 | loader=loader, 113 | evaluate=args.evaluate, 114 | optimizer=optimizer, 115 | use_schedule=args.use_schedule, 116 | scheduler=scheduler, 117 | model_hoi=model_hoi, 118 | obj_head=obj_head, 119 | sf_encoder=sf_encoder, 120 | model_denoise=model_denoise, 121 | diffusion=diffusion, 122 | diffusion_steps=args.diffusion_steps, 123 | traj_decoder=traj_decoder, 124 | motion_encoder=motion_encoder, 125 | holi_past=args.holi_past, 126 | fast_test=args.fast_test, 127 | seq_len_obs=args.seq_len_obs, 128 | seq_len_unobs=args.seq_len_unobs, 129 | feat_dim=args.hidden_dim, 130 | sample_times=args.sample_times, 131 | learnable_weight=args.learnable_weight, 132 | reg_loss_weight=args.reg_loss_weight, 133 | rec_loss_weight=args.rec_loss_weight, 134 | schedule_sampler=schedule_sampler, 135 | test_start_idx=args.test_start_idx, 136 | resume=args.resume, 137 | base_model=args.base_model, 138 | log_path=args.log_path, 139 | checkpoint_path=args.checkpoint_path, 140 | collection_path_traj=args.collection_path_traj, 141 | collection_path_aff=args.collection_path_aff, 142 | ).run_loop() 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser(description="HOI 
Forecasting") 147 | netsopts.add_nets_opts(parser) 148 | netsopts.add_train_opts(parser) 149 | expopts.add_exp_opts(parser) 150 | expopts.add_path_opts(parser) 151 | args = parser.parse_args() 152 | 153 | if args.use_cuda and torch.cuda.is_available(): 154 | num_gpus = torch.cuda.device_count() 155 | args.batch_size = args.batch_size * num_gpus 156 | if int(os.environ['LOCAL_RANK']) == 0: 157 | logging.info("use batch size: %s", args.batch_size) 158 | 159 | if args.traj_only: assert args.evaluate, "evaluate trajectory on validation set must set --evaluate" 160 | main(args) 161 | gpu_id = os.environ['LOCAL_RANK'] 162 | logging.info("GPU: %s Done!", gpu_id) -------------------------------------------------------------------------------- /preprocess/dataset_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from preprocess.ho_types import FrameDetections, HandDetection, HandSide, HandState, ObjectDetection 5 | 6 | 7 | def sample_action_anticipation_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0): 8 | time_start = (frame_start - 1) / fps_init 9 | num_frames = int(np.floor(t_buffer * fps)) 10 | times = (np.arange(num_frames + 1) - num_frames) / fps + time_start 11 | times = np.clip(times, 0, np.inf) 12 | times = times.astype(np.float32) 13 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 14 | if frames_idxs.max() >= 1: 15 | frames_idxs[frames_idxs < 1] = frames_idxs[frames_idxs >= 1].min() 16 | print(len(list(frames_idxs))) 17 | print(frames_idxs) 18 | return list(frames_idxs) 19 | 20 | 21 | def load_ho_annot(video_detections, frame_index, imgW=456, imgH=256): 22 | annot = video_detections[frame_index-1] # frame_index start from 1 23 | assert annot.frame_number == frame_index, "wrong frame index" 24 | annot.scale(width_factor=imgW, height_factor=imgH) 25 | return annot 26 | 27 | 28 | def load_img(frames_path, frame_index): 29 | frame = cv2.imread(os.path.join(frames_path, "frame_{:010d}.jpg".format(frame_index))) 30 | return frame 31 | 32 | 33 | def get_mask(frame, annot, hand_threshold=0.1, obj_threshold=0.1): 34 | msk_img = np.ones((frame.shape[:2]), dtype=frame.dtype) 35 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 36 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 37 | for hand in hands: 38 | (x1, y1), (x2, y2) = hand.bbox.coords_int 39 | msk_img[y1:y2, x1:x2] = 0 40 | 41 | if len(objs) > 0: 42 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 43 | hand_threshold=hand_threshold) 44 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 45 | hand = annot.hands[hand_idx] 46 | object = annot.objects[object_idx] 47 | if not hand.state.value == HandState.STATIONARY_OBJECT.value: 48 | (x1, y1), (x2, y2) = object.bbox.coords_int 49 | msk_img[y1:y2, x1:x2] = 0 50 | return msk_img 51 | 52 | 53 | def bbox_inter(boxA, boxB): 54 | xA = max(boxA[0], boxB[0]) 55 | yA = max(boxA[1], boxB[1]) 56 | xB = min(boxA[2], boxB[2]) 57 | yB = min(boxA[3], boxB[3]) 58 | 59 | interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0)) 60 | if interArea == 0: 61 | return xA, yA, xB, yB, 0 62 | 63 | boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1])) 64 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1])) 65 | iou = interArea / float(boxAArea + boxBArea - interArea) 66 | return xA, yA, xB, yB, iou 67 | 68 | 69 | def compute_iou(boxA, boxB): 70 | boxA = np.array(boxA).reshape(-1) 
71 | boxB = np.array(boxB).reshape(-1) 72 | xA = max(boxA[0], boxB[0]) 73 | yA = max(boxA[1], boxB[1]) 74 | xB = min(boxA[2], boxB[2]) 75 | yB = min(boxA[3], boxB[3]) 76 | interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0)) 77 | if interArea == 0: 78 | return 0 79 | boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1])) 80 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1])) 81 | iou = interArea / float(boxAArea + boxBArea - interArea) 82 | return iou 83 | 84 | 85 | def points_in_bbox(point, bbox): 86 | (x1, y1), (x2, y2) = bbox 87 | (x, y) = point 88 | return (x1 <= x <= x2) and (y1 <= y <= y2) 89 | 90 | 91 | def valid_point(point, imgW=456, imgH=256): 92 | if point is None: 93 | return False 94 | else: 95 | x, y = point 96 | return (0 <= x < imgW) and (0 <=y < imgH) 97 | 98 | 99 | def valid_traj(traj, imgW=456, imgH=256): 100 | if len(traj) > 0: 101 | num_outlier = np.sum([not valid_point(point, imgW=imgW, imgH=imgH) 102 | for point in traj if point is not None]) 103 | valid_ratio = np.sum([valid_point(point, imgW=imgW, imgH=imgH) for point in traj[1:]]) / len(traj[1:]) 104 | valid_last = valid_point(traj[-1], imgW=imgW, imgH=imgH) 105 | if num_outlier > 1 or valid_ratio < 0.5 or not valid_last: 106 | traj = [] 107 | return traj 108 | 109 | 110 | def get_valid_traj(traj, imgW=456, imgH=256): 111 | try: 112 | traj[traj < 0] = traj[traj >= 0].min() 113 | except: 114 | traj[traj < 0] = 0 115 | try: 116 | traj[:, 0][traj[:, 0] >= imgW] = imgW - 1 117 | except: 118 | traj[:, 0][traj[:, 0] >= imgW] = imgW - 1 119 | try: 120 | traj[:, 1][traj[:, 1] >= imgH] = imgH - 1 121 | except: 122 | traj[:, 1][traj[:, 1] >= imgH] = imgH - 1 123 | return traj 124 | 125 | 126 | def fetch_data(frames_path, video_detections, frames_idxs, hand_threshold=0.1, obj_threshold=0.1): 127 | tolerance = frames_idxs[1] - frames_idxs[0] # extend future act frame by tolerance to find ho interaction 128 | frames = [] 129 | annots = [] 130 | 131 | miss_hand = 0 132 | for frame_idx in frames_idxs[:-1]: 133 | frame = load_img(frames_path, frame_idx) 134 | annot = load_ho_annot(video_detections, frame_idx) 135 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 136 | if len(hands) == 0: 137 | miss_hand += 1 138 | frames.append(frame) 139 | annots.append(annot) 140 | if miss_hand == len(frames_idxs[:-1]): 141 | return None 142 | frame_idx = frames_idxs[-1] 143 | frames_idxs = frames_idxs[:-1] 144 | 145 | hand_sides = [] 146 | idx = 0 147 | flag = False 148 | while idx < tolerance: 149 | annot = load_ho_annot(video_detections, frame_idx) 150 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 151 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 152 | if len(hands) > 0 and len(objs) > 0: # at least one hand is contact with obj 153 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 154 | hand_threshold=hand_threshold) 155 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 156 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords).reshape(-1) 157 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords).reshape(-1) 158 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 159 | contact_state = annot.hands[hand_idx].state.value 160 | if iou > 0 and (contact_state == HandState.STATIONARY_OBJECT.value or 161 | contact_state == HandState.PORTABLE_OBJECT.value): 162 | hand_side = annot.hands[hand_idx].side.name 163 | hand_sides.append(hand_side) 164 | flag = True 165 | if flag: 166 | 
break 167 | else: 168 | idx += 1 169 | frame_idx += 1 170 | else: 171 | idx += 1 172 | frame_idx += 1 173 | if flag: 174 | frames_idxs.append(frame_idx) 175 | frames.append(load_img(frames_path, frame_idx)) 176 | annots.append(annot) 177 | return frames_idxs, frames, annots, list(set(hand_sides)) # remove redundant hand sides 178 | else: 179 | return None 180 | 181 | 182 | def save_video_info(save_path, video_index, frames_idxs, homography_stack, contacts, 183 | hand_trajs, obj_trajs, affordance_info): 184 | import pickle 185 | video_info = {"frame_indices": frames_idxs, 186 | "homography": homography_stack, 187 | "contact": contacts} 188 | video_info.update({"hand_trajs": hand_trajs}) 189 | video_info.update({"obj_trajs": obj_trajs}) 190 | video_info.update({"affordance": affordance_info}) 191 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'wb') as f: 192 | pickle.dump(video_info, f) 193 | 194 | 195 | def load_video_info(save_path, video_index): 196 | import pickle 197 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'rb') as f: 198 | video_info = pickle.load(f) 199 | return video_info 200 | -------------------------------------------------------------------------------- /evaluation/affordance_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | from sklearn.cluster import KMeans 5 | from joblib import Parallel, delayed 6 | 7 | 8 | def farthest_sampling(pcd, n_samples): 9 | def compute_distance(a, b): 10 | return np.linalg.norm(a - b, ord=2, axis=2) 11 | 12 | n_pts, dim = pcd.shape[0], pcd.shape[1] 13 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim)) 14 | remaining_pts = np.copy(pcd) 15 | 16 | if n_pts > 1: 17 | start_idx = np.random.randint(low=0, high=n_pts - 1) 18 | else: 19 | start_idx = 0 20 | selected_pts_expanded[0] = remaining_pts[start_idx] 21 | n_selected_pts = 1 22 | 23 | for _ in range(1, n_samples): 24 | if n_selected_pts < n_samples: 25 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T 26 | dist_pts_to_selected_min = np.min(dist_pts_to_selected, axis=1, keepdims=True) 27 | res_selected_idx = np.argmax(dist_pts_to_selected_min) 28 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx] 29 | n_selected_pts += 1 30 | 31 | selected_pts = np.squeeze(selected_pts_expanded, axis=1) 32 | return selected_pts 33 | 34 | 35 | def makeGaussian(size, fwhm=3., center=None): 36 | x = np.arange(0, size, 1, float) 37 | y = x[:, np.newaxis] 38 | if center is None: 39 | x0 = y0 = size // 2 40 | else: 41 | x0 = center[0] 42 | y0 = center[1] 43 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2) 44 | 45 | 46 | def compute_heatmap(normalized_points, image_size, k_ratio=3.0, transpose=True, 47 | fps=False, kmeans=False, n_pts=5, gaussian_sigma=0.): 48 | normalized_points = np.asarray(normalized_points) 49 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32) 50 | n_points = normalized_points.shape[0] 51 | if n_points > n_pts and kmeans: 52 | kmeans = KMeans(n_clusters=n_pts, random_state=0).fit(normalized_points) 53 | normalized_points = kmeans.cluster_centers_ 54 | elif n_points > n_pts and fps: 55 | normalized_points = farthest_sampling(normalized_points, n_samples=n_pts) 56 | n_points = normalized_points.shape[0] 57 | for i in range(n_points): 58 | x = normalized_points[i, 0] * image_size[0] 59 | y = normalized_points[i, 1] * 
image_size[1] 60 | col = int(x) 61 | row = int(y) 62 | try: 63 | heatmap[col, row] += 1.0 64 | except: 65 | col = min(max(col, 0), image_size[0] - 1) 66 | row = min(max(row, 0), image_size[1] - 1) 67 | heatmap[col, row] += 1.0 68 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio) 69 | if k_size % 2 == 0: 70 | k_size += 1 71 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), gaussian_sigma) 72 | if heatmap.max() > 0: 73 | heatmap /= heatmap.max() 74 | if transpose: 75 | heatmap = heatmap.transpose() 76 | return heatmap 77 | 78 | 79 | def SIM(map1, map2, eps=1e-12): 80 | map1, map2 = map1 / (map1.sum() + eps), map2 / (map2.sum() + eps) 81 | intersection = np.minimum(map1, map2) 82 | return np.sum(intersection) 83 | 84 | 85 | def AUC_Judd(saliency_map, fixation_map, jitter=True): 86 | saliency_map = np.array(saliency_map, copy=False) 87 | fixation_map = np.array(fixation_map, copy=False) > 0.5 88 | if not np.any(fixation_map): 89 | return np.nan 90 | if saliency_map.shape != fixation_map.shape: 91 | saliency_map = cv2.resize(saliency_map, fixation_map.shape, interpolation=cv2.INTER_AREA) 92 | if jitter: 93 | saliency_map += np.random.rand(*saliency_map.shape) * 1e-7 94 | saliency_map = (saliency_map - np.min(saliency_map)) / (np.max(saliency_map) - np.min(saliency_map) + 1e-12) 95 | 96 | S = saliency_map.ravel() 97 | F = fixation_map.ravel() 98 | S_fix = S[F] 99 | n_fix = len(S_fix) 100 | n_pixels = len(S) 101 | thresholds = sorted(S_fix, reverse=True) 102 | tp = np.zeros(len(thresholds) + 2) 103 | fp = np.zeros(len(thresholds) + 2) 104 | tp[0] = 0; 105 | tp[-1] = 1 106 | fp[0] = 0; 107 | fp[-1] = 1 108 | for k, thresh in enumerate(thresholds): 109 | above_th = np.sum(S >= thresh) 110 | tp[k + 1] = (k + 1) / float(n_fix) 111 | fp[k + 1] = (above_th - (k + 1)) / float(n_pixels - n_fix) 112 | return np.trapz(tp, fp) 113 | 114 | 115 | def NSS(saliency_map, fixation_map): 116 | MAP = (saliency_map - saliency_map.mean()) / (saliency_map.std()) 117 | mask = fixation_map.astype(np.bool) 118 | score = MAP[mask].mean() 119 | return score 120 | 121 | 122 | def compute_score(pred, gt, valid_thresh=0.001): 123 | if torch.is_tensor(pred): 124 | pred = pred.numpy() 125 | if torch.is_tensor(gt): 126 | gt = gt.numpy() 127 | 128 | pred = pred / (pred.max() + 1e-12) 129 | 130 | all_thresh = np.linspace(0.001, 1.0, 41) 131 | tp = np.zeros((all_thresh.shape[0],)) 132 | fp = np.zeros((all_thresh.shape[0],)) 133 | fn = np.zeros((all_thresh.shape[0],)) 134 | tn = np.zeros((all_thresh.shape[0],)) 135 | valid_gt = gt > valid_thresh 136 | for idx, thresh in enumerate(all_thresh): 137 | mask = (pred >= thresh) 138 | tp[idx] += np.sum(np.logical_and(mask == 1, valid_gt == 1)) 139 | tn[idx] += np.sum(np.logical_and(mask == 0, valid_gt == 0)) 140 | fp[idx] += np.sum(np.logical_and(mask == 1, valid_gt == 0)) 141 | fn[idx] += np.sum(np.logical_and(mask == 0, valid_gt == 1)) 142 | 143 | scores = {} 144 | gt_real = np.array(gt) 145 | if gt_real.sum() == 0: 146 | gt_real = np.ones(gt_real.shape) / np.product(gt_real.shape) 147 | 148 | score = SIM(pred, gt_real) 149 | scores['SIM'] = score if not np.isnan(score) else None 150 | 151 | gt_binary = np.array(gt) 152 | gt_binary = (gt_binary / gt_binary.max() + 1e-12) if gt_binary.max() > 0 else gt_binary 153 | gt_binary = np.where(gt_binary > 0.5, 1, 0) 154 | score = AUC_Judd(pred, gt_binary) 155 | scores['AUC-J'] = score if not np.isnan(score) else None 156 | 157 | score = NSS(pred, gt_binary) 158 | scores['NSS'] = score if not np.isnan(score) else None 159 | 
160 | return dict(scores), tp, tn, fp, fn 161 | 162 | 163 | def evaluate_affordance(preds_dict, gts_dict, val_log=None, 164 | sz=32, fps=False, kmeans=False, n_pts=5, 165 | gaussian_sigma=3., gaussian_k_ratio=3.): 166 | scores = [] 167 | all_thresh = np.linspace(0.001, 1.0, 41) 168 | tp = np.zeros((all_thresh.shape[0],)) 169 | fp = np.zeros((all_thresh.shape[0],)) 170 | fn = np.zeros((all_thresh.shape[0],)) 171 | tn = np.zeros((all_thresh.shape[0],)) 172 | 173 | pred_hmaps = Parallel(n_jobs=16, verbose=0)(delayed(compute_heatmap)(norm_contacts, (sz, sz), 174 | fps=fps, kmeans=kmeans, n_pts=n_pts, 175 | gaussian_sigma=gaussian_sigma, 176 | k_ratio=gaussian_k_ratio) 177 | for (uid, norm_contacts) in preds_dict.items()) 178 | gt_hmaps = Parallel(n_jobs=16, verbose=0)(delayed(compute_heatmap)(norm_contacts, (sz, sz), 179 | fps=fps, n_pts=n_pts, 180 | gaussian_sigma=0, 181 | k_ratio=3.) 182 | for (uid, norm_contacts) in gts_dict.items()) 183 | 184 | for (pred_hmap, gt_hmap) in zip(pred_hmaps, gt_hmaps): 185 | score, ctp, ctn, cfp, cfn = compute_score(pred_hmap, gt_hmap) 186 | scores.append(score) 187 | tp = tp + ctp 188 | tn = tn + ctn 189 | fp = fp + cfp 190 | fn = fn + cfn 191 | 192 | metrics = {} 193 | for key in ['SIM', 'AUC-J', 'NSS']: 194 | key_score = [s[key] for s in scores if s[key] is not None] 195 | mean, stderr = np.mean(key_score), np.std(key_score) / (np.sqrt(len(key_score))) 196 | metrics[key] = mean 197 | 198 | prec = tp / (tp + fp + 1e-6) 199 | recall = tp / (tp + fn + 1e-6) 200 | f1 = 2 * prec * recall / (prec + recall + 1e-6) 201 | idx = np.argmax(f1) 202 | prec_score = prec[idx] 203 | f1_score = f1[idx] 204 | recall_score = recall[idx] 205 | 206 | return metrics -------------------------------------------------------------------------------- /preprocess/traj_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from preprocess.dataset_util import get_mask, valid_traj 4 | 5 | 6 | def match_keypoints(kpsA, kpsB, featuresA, featuresB, ratio=0.7, reprojThresh=4.0): 7 | matcher = cv2.DescriptorMatcher_create("BruteForce") 8 | rawMatches = matcher.knnMatch(featuresA, featuresB, 2) 9 | matches = [] 10 | 11 | for m in rawMatches: 12 | if len(m) == 2 and m[0].distance < m[1].distance * ratio: 13 | matches.append((m[0])) 14 | 15 | if len(matches) > 4: 16 | ptsA = np.float32([kpsA[m.queryIdx].pt for m in matches]) 17 | ptsB = np.float32([kpsB[m.trainIdx].pt for m in matches]) 18 | (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh) 19 | matchesMask = status.ravel().tolist() 20 | return matches, H, matchesMask 21 | return None 22 | 23 | 24 | def get_pair_homography(frame_1, frame_2, annot_1, annot_2, hand_threshold=0.1, obj_threshold=0.1): 25 | flag = True 26 | descriptor = cv2.xfeatures2d.SURF_create() 27 | msk_img_1 = get_mask(frame_1, annot_1, hand_threshold=hand_threshold, obj_threshold=obj_threshold) 28 | msk_img_2 = get_mask(frame_2, annot_2, hand_threshold=hand_threshold, obj_threshold=obj_threshold) 29 | (kpsA, featuresA) = descriptor.detectAndCompute(frame_1, mask=msk_img_1) 30 | (kpsB, featuresB) = descriptor.detectAndCompute(frame_2, mask=msk_img_2) 31 | matches, matchesMask = None, None 32 | try: 33 | (matches, H_BA, matchesMask) = match_keypoints(kpsB, kpsA, featuresB, featuresA) 34 | except Exception: 35 | print("compute homography failed!") 36 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 37 | flag = False 38 | 39 | NoneType = type(None) 40 | if type(H_BA) 
== NoneType: 41 | print("compute homography failed!") 42 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 43 | flag = False 44 | try: 45 | np.linalg.inv(H_BA) 46 | except Exception: 47 | print("compute homography failed!") 48 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 49 | flag = False 50 | return matches, H_BA, matchesMask, flag 51 | 52 | 53 | def get_homo_point(point, homography): 54 | cx, cy = point 55 | center = np.array((cx, cy, 1.0), dtype=np.float32) 56 | x, y, z = np.dot(homography, center) 57 | x, y = x / z, y / z 58 | point = np.array((x, y), dtype=np.float32) 59 | return point 60 | 61 | 62 | def get_homo_bbox_point(bbox, homography): 63 | x1, y1, x2, y2 = np.array(bbox).reshape(-1) 64 | points = np.array([[x1, y1], [x2, y1], [x1, y2], [x2, y2]], dtype=np.float32) 65 | points_homo = np.concatenate((points, np.ones((4, 1), dtype=np.float32)), axis=1) 66 | points_coord = np.dot(points_homo, homography.T) 67 | points_coord2d = points_coord[:, :2] / points_coord[:, None, 2] 68 | return points_coord2d 69 | 70 | 71 | def get_hand_center(annot, hand_threshold=0.1): 72 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 73 | hands_center= {} 74 | hands_score = {} 75 | for hand in hands: 76 | side = hand.side.name 77 | score = hand.score 78 | if side not in hands_center or score > hands_score[side]: 79 | hands_center[side] = hand.bbox.center 80 | hands_score[side] = score 81 | return hands_center 82 | 83 | 84 | def get_hand_point(hands_center, homography, side): 85 | point, homo_point = None, None 86 | if side in hands_center: 87 | point = hands_center[side] 88 | homo_point = get_homo_point(point, homography) 89 | return point, homo_point 90 | 91 | 92 | def traj_compute(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 93 | imgH, imgW = frames[0].shape[:2] 94 | left_traj, right_traj = [], [] 95 | left_centers , right_centers= [], [] 96 | homography_stack = [np.eye(3)] 97 | for idx in range(1, len(frames)): 98 | matches, H_BA, matchesMask, flag = get_pair_homography(frames[idx - 1], frames[idx], 99 | annots[idx - 1], annots[idx], 100 | hand_threshold=hand_threshold, 101 | obj_threshold=obj_threshold) 102 | if not flag: 103 | return None 104 | else: 105 | homography_stack.append(np.dot(homography_stack[-1], H_BA)) 106 | for idx in range(len(frames)): 107 | hands_center = get_hand_center(annots[idx], hand_threshold=hand_threshold) 108 | if "LEFT" in hand_sides: 109 | left_center, left_point = get_hand_point(hands_center, homography_stack[idx], "LEFT") 110 | left_centers.append(left_center) 111 | left_traj.append(left_point) 112 | if "RIGHT" in hand_sides: 113 | right_center, right_point = get_hand_point(hands_center, homography_stack[idx], "RIGHT") 114 | right_centers.append(right_center) 115 | right_traj.append(right_point) 116 | 117 | left_traj = valid_traj(left_traj, imgW=imgW, imgH=imgH) 118 | right_traj = valid_traj(right_traj, imgW=imgW, imgH=imgH) 119 | return left_traj, left_centers, right_traj, right_centers, homography_stack 120 | 121 | 122 | def traj_completion(traj, side, imgW=456, imgH=256): 123 | from scipy.interpolate import CubicHermiteSpline 124 | 125 | def get_valid_traj(traj, imgW, imgH): 126 | traj[traj < 0] = traj[traj >= 0].min() 127 | traj[:, 0][traj[:, 0] > 1.5 * imgW] = 1.5 * imgW 128 | traj[:, 1][traj[:, 1] > 1.5 * imgH] = 1.5 * imgH 129 | return traj 130 | 131 | def spline_interpolation(axis): 132 | fill_times = np.array(fill_indices, dtype=np.float32) 133 | fill_traj = 
np.array([traj[idx][axis] for idx in fill_indices], dtype=np.float32) 134 | dt = fill_times[2:] - fill_times[:-2] 135 | dt = np.hstack([fill_times[1] - fill_times[0], dt, fill_times[-1] - fill_times[-2]]) 136 | dx = fill_traj[2:] - fill_traj[:-2] 137 | dx = np.hstack([fill_traj[1] - fill_traj[0], dx, fill_traj[-1] - fill_traj[-2]]) 138 | dxdt = dx / dt 139 | curve = CubicHermiteSpline(fill_times, fill_traj, dxdt) 140 | full_traj = curve(np.arange(len(traj), dtype=np.float32)) 141 | return full_traj, curve 142 | 143 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None] 144 | if 0 not in fill_indices: 145 | if side == "LEFT": 146 | traj[0] = np.array((0.25*imgW, 1.5*imgH), dtype=np.float32) 147 | else: 148 | traj[0] = np.array((0.75*imgW, 1.5*imgH), dtype=np.float32) 149 | fill_indices = np.insert(fill_indices, 0, 0).tolist() 150 | fill_indices.sort() 151 | full_traj_x, curve_x = spline_interpolation(axis=0) 152 | full_traj_y, curve_y = spline_interpolation(axis=1) 153 | full_traj = np.stack([full_traj_x, full_traj_y], axis=1) 154 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH) 155 | curve = [curve_x, curve_y] 156 | return full_traj, fill_indices, curve 157 | 158 | 159 | def compute_hand_traj(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 160 | imgH, imgW = frames[0].shape[:2] 161 | results = traj_compute(frames, annots, hand_sides, 162 | hand_threshold=hand_threshold, obj_threshold=obj_threshold) 163 | if results is None: 164 | print("compute homography failed") 165 | return None 166 | else: 167 | left_traj, left_centers, right_traj, right_centers, homography_stack = results 168 | if len(left_traj) == 0 and len(right_traj) == 0: 169 | print("compute traj failed") 170 | return None 171 | hand_trajs = {} 172 | if len(left_traj) == 0: 173 | print("left traj filtered out") 174 | else: 175 | left_complete_traj, left_fill_indices, left_curve = traj_completion(left_traj, side="LEFT", 176 | imgW=imgW, imgH=imgH) 177 | hand_trajs["LEFT"] = {"traj": left_complete_traj, "fill_indices": left_fill_indices, 178 | "fit_curve": left_curve, "centers": left_centers} 179 | if len(right_traj) == 0: 180 | print("right traj filtered out") 181 | else: 182 | right_complete_traj, right_fill_indices, right_curve = traj_completion(right_traj, side="RIGHT", 183 | imgW=imgW, imgH=imgH) 184 | hand_trajs["RIGHT"] = {"traj": right_complete_traj, "fill_indices": right_fill_indices, 185 | "fit_curve": right_curve, "centers": right_centers} 186 | return homography_stack, hand_trajs 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diff-IP2D: Diffusion Models for 2D HOI Prediction 2 | 3 | This is the official implementation for our paper accepted by **IROS 2025**: 4 | 5 | [Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos](https://arxiv.org/abs/2405.04370). 6 | 7 | [Junyi Ma](https://github.com/BIT-MJY)1, [Jingyi Xu](https://github.com/BIT-XJY)1, [Xieyuanli Chen](https://xieyuanli-chen.com/)2, [Hesheng Wang](https://scholar.google.com/citations?hl=en&user=q6AY9XsAAAAJ&view_op=list_works&sortby=pubdate)1* 8 | 9 | 1SJTU  2NUDT  *Corresponding author 10 | 11 | 12 | Diff-IP2D is the first work using the devised denoising diffusion probabilistic model to jointly forecast future hand trajectories and object affordances with only 2D egocentric videos as input. 
It provides a foundation generative paradigm in the field of HOI prediction. 13 | 14 | pred 15 | 16 | white: ours, blue: baseline, red: GT. Diff-IP2D generates plausible future hand waypoints and final hand positions (even if there is a large error in the early stage) with bidirectional constraints. 17 | 18 | pred 19 | 20 | A cup example :coffee:. The hand waypoints from ground-truth labels and HOI prediction approaches are connected by blue and white dashed lines respectively. **There is a lack of backward constraints in AR-based USST, leading to a shorter predicted trajectory (almost curled up into a point) and larger accumulated displacement errors. In contrast, our Diff-IP2D with iter-NAR paradigm is potentially guided by final HOI states, and thus predicts more accurate hand trajectories following both spatial causality and temporal causality.** 21 | 22 | 23 | If you find our work helpful to your research, please cite our paper as 24 | ``` 25 | @article{ma2024diffip2d, 26 | title={Diff-IP2D: Diffusion-Based Hand-Object Interaction Prediction on Egocentric Videos}, 27 | author={Ma, Junyi and Xu, Jingyi and Chen, Xieyuanli and Wang, Hesheng}, 28 | journal={arXiv preprint arXiv:2405.04370}, 29 | year={2024}} 30 | ``` 31 | 32 | ## 1. Setup 33 | 34 | Clone the repository (requires git): 35 | 36 | ```bash 37 | git clone https://github.com/IRMVLab/Diff-IP2D.git 38 | cd Diff-IP2D 39 | ``` 40 | 41 | Create the environment and install dependencies into it: 42 | 43 | ```bash 44 | conda create -n diffip python=3.8 pip 45 | conda activate diffip 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | ## 2. Data Structure 50 | 51 | We suggest using our proposed data structure for faster reproducing, which is posted here: 52 | 53 | ```bash 54 | ├── base_models 55 | │ └── model.pth.tar 56 | ├── common 57 | │ ├── epic-kitchens-100-annotations # from OCT or merged ourselves 58 | │ │ ├── EPIC_100_test_timestamps.csv 59 | │ │ ├── EPIC_100_test_timestamps.pkl 60 | │ │ ├── EPIC_100_train.csv 61 | │ │ ├── EPIC_100_train.pkl 62 | │ │ ├── EPIC_100_train_val_test.csv 63 | │ │ ├── EPIC_100_verb_classes.csv 64 | │ │ ├── EPIC_100_video_info.csv 65 | │ │ ├── actions.csv 66 | │ │ └── ... 67 | │ └── rulstm # raw rulstm repo 68 | │ ├── FEATEXT 69 | │ ├── FasterRCNN 70 | │ └── RULSTM 71 | ├── data 72 | │ ├── ek100 # manually generated ourselves or from OCT or from raw EK 73 | │ │ ├── feats_train 74 | │ │ │ ├── full_data_with_future_train_part1.lmdb 75 | │ │ │ └── full_data_with_future_train_part2.lmdb 76 | │ │ ├── feats_test 77 | │ │ │ └── data.lmdb 78 | │ │ ├── labels # 79 | │ │ │ ├── label_0.pkl 80 | │ │ │ └── ... 81 | │ │ ├── ek100_eval_labels.pkl 82 | │ │ └── video_info.json 83 | │ ├── raw_images # raw EPIC-KITCHENS dataset 84 | │ │ └── EPIC-KITCHENS 85 | │ ├── homos_train # auto generated when first running 86 | │ ├── homos_test # auto generated when first running 87 | ├── diffip_weights # auto generated when first saving checkpoints 88 | │ ├── checkpoint_1.pth.tar 89 | │ └── ... 90 | ├── collected_pred_traj # auto generated when first eval traj 91 | ├── collected_pred_aff # auto generated when first eval affordance 92 | ├── log # auto generated when first running 93 | └── uid2future_file_name.pickle 94 | ``` 95 | 96 | ## 3. 
Access to Required Files 97 | Here we provide the links to access all the above-mentioned files that cannot be generated automatically by running the scripts of this repo: 98 | * [base_models/model.pth.tar](https://drive.google.com/file/d/16IkQ4hOQk2_Klhd806J-46hLN-OGokxa/view): Base model from OCT [1]. 99 | * [common/epic-kitchens-100-annotations](https://1drv.ms/u/c/0e8794c880029a8f/EVhoi5yXoGNKh0FGUFNgQM4Ba26rLoBaN4cmDTaOyj9WVA?e=5JypW7): Annotations from raw EK [2] and our manually merged files. Please do not confuse this folder with the one provided by OCT [1]. 100 | * [common/rulstm](https://github.com/fpv-iplab/rulstm): Original RULSTM [3] repo. 101 | * [data/ek100/feats_train](https://1drv.ms/u/c/0e8794c880029a8f/EYne7Qr09u1Mie-3N0U0CkgB1c-72AguS5nj0mGojLIflg?e=KMi2Mn): Our manually generated feature files for training our model. 102 | * [data/ek100/feats_test](https://1drv.ms/u/c/0e8794c880029a8f/EfqtPairdOxJudetBi66Fz4BxN1W6c7TXFjcVUIJFfJrxA?e=IZ5j3U): Feature files provided by OCT [1] for testing our model. 103 | * [data/ek100/labels](https://1drv.ms/u/c/0e8794c880029a8f/EXTy2gkcv69LrvyYMGcZ5YsBi949htwa60QGEVCcIkv-4w?e=UM38ap): Labels from OCT [1] for training models. 104 | * [data/ek100/ek100_eval_labels.pkl](https://drive.google.com/file/d/1s7qpBa-JjjuGk7v_aiuU2lvRjgP6fi9C/view): Labels from OCT [1] for affordance evaluation. Please refer to the original OCT folder. 105 | * [data/ek100/video_info.json](https://1drv.ms/u/c/0e8794c880029a8f/ERcBH9ic9AxMg5czXFkFqooBlF-q-TQS1kHyJ3L6iUt0vQ?e=wdEtBK): Indices of the videos used. 106 | * [data/raw_images](https://github.com/epic-kitchens/epic-kitchens-100-annotations): Original EK images [2]. Follow the instructions in the EK repo to download the raw RGB frames via `python epic_downloader.py --rgb-frames`; only raw images are required by Diff-IP2D. 107 | * [uid2future_file_name.pickle](https://1drv.ms/u/c/0e8794c880029a8f/EQAUO4TrytBJp--tziBD8q0B1RooqdzotTGlRNkX6o-qtQ?e=v83ehP): Indicator file generated by ourselves. 108 | 109 | ## 4. How to Use 110 | 111 | We have released the deployment of Diff-IP2D on EK100. The relevant code and data for EK55 and EG will be released soon. 112 | 113 | ### 4.1 Pretrained Weights 114 | 115 | | Version | Download link | Notes | 116 | |----|----|----| 117 | | 1.1 | [OneDrive]() / [Google Drive]() | pretrained on EK100 (two val) | 118 | | 1.2 | [OneDrive]() / [Google Drive]() | pretrained on EK100 (one val) | 119 | 120 | Please change the paths to the pretrained weights in `run_train.py`, `run_val_traj.py`, and `run_val_affordance.py`. 121 | 122 | ### 4.2 Train 123 | 124 | ```bash 125 | bash train.sh 126 | ``` 127 | 128 | ### 4.3 Test 129 | 130 | Test trajectory prediction by 131 | 132 | ```bash 133 | bash val_traj.sh 134 | ``` 135 | 136 | Test affordance prediction by 137 | 138 | ```bash 139 | bash val_affordance.sh 140 | ``` 141 | 142 | ### 4.4 Other Notes 143 | 144 | * We are working hard to organize and release a more polished version of the code, along with its application on the new dataset. 145 | * You may obtain results that slightly differ from those presented in the paper due to the stochastic nature of diffusion inference with different seeds. Prediction clusters can be obtained by running inference with multiple different seeds. 146 | * Homography will be automatically saved to `data/homos_train` and `data/homos_test` after the first training/test epoch for quick reuse. 147 | * Separate validation sets will lead to the best checkpoints for the two tasks falling at different epochs. 
148 | * Please modify the params in [config files]() before training and testing. For example, change the paths in `options/expopts.py`. You can also set `fast_test=True` for faster inference without sacrificing much accuracy. 149 | 150 | ## 5. Acknowledgment 151 | 152 | 153 | We sincerely appreciate the fantastic pioneering works that provide codebases and datasets for this work. Please also cite them if you use the relevant code and data. 154 | 155 | [1] Shaowei Liu, Subarna Tripathi, Somdeb Majumdar, and Xiaolong Wang. Joint hand motion and interaction hotspots prediction from egocentric videos. In CVPR, pages 3282–3292, 2022. [Paper](https://arxiv.org/abs/2204.01696) 156 | 157 | [2] Dima Damen, Hazel Doughty, Giovanni Maria Farinella, Antonino Furnari, Evangelos Kazakos, Jian Ma, Davide Moltisanti, Jonathan Munro, Toby Perrett, Will Price, et al. Rescaling egocentric vision: Collection, pipeline and challenges for epic-kitchens-100. IJCV, pages 1–23, 2022. [Paper](https://arxiv.org/abs/2006.13256) 158 | 159 | [3] Antonino Furnari and Giovanni Maria Farinella. Rolling-unrolling lstms for action anticipation from first-person video. IEEE TPAMI, 43(11):4021–4036, 2020. [Paper](https://arxiv.org/abs/2005.02190) 160 | 161 | [4] Shansan Gong, Mukai Li, Jiangtao Feng, Zhiyong Wu, and Lingpeng Kong. Diffuseq: Sequence to sequence text generation with diffusion models. In ICLR, 2023. [Paper](https://arxiv.org/abs/2210.08933) 162 | 163 | [5] Wentao Bao, Lele Chen, Libing Zeng, Zhong Li, Yi Xu, Junsong Yuan, and Yu Kong. Uncertainty-aware state space transformer for egocentric 3d hand trajectory forecasting. In ICCV, pages 13702–13711, 2023. [Paper](https://arxiv.org/abs/2307.08243) 164 | -------------------------------------------------------------------------------- /preprocess/affordance_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from preprocess.dataset_util import bbox_inter 4 | 5 | 6 | def skin_extract(image): 7 | def color_segmentation(): 8 | lower_HSV_values = np.array([0, 40, 0], dtype="uint8") 9 | upper_HSV_values = np.array([25, 255, 255], dtype="uint8") 10 | lower_YCbCr_values = np.array((0, 138, 67), dtype="uint8") 11 | upper_YCbCr_values = np.array((255, 173, 133), dtype="uint8") 12 | mask_YCbCr = cv2.inRange(YCbCr_image, lower_YCbCr_values, upper_YCbCr_values) 13 | mask_HSV = cv2.inRange(HSV_image, lower_HSV_values, upper_HSV_values) 14 | binary_mask_image = cv2.add(mask_HSV, mask_YCbCr) 15 | return binary_mask_image 16 | 17 | HSV_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 18 | YCbCr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCR_CB) 19 | binary_mask_image = color_segmentation() 20 | image_foreground = cv2.erode(binary_mask_image, None, iterations=3) 21 | dilated_binary_image = cv2.dilate(binary_mask_image, None, iterations=3) 22 | ret, image_background = cv2.threshold(dilated_binary_image, 1, 128, cv2.THRESH_BINARY) 23 | 24 | image_marker = cv2.add(image_foreground, image_background) 25 | image_marker32 = np.int32(image_marker) 26 | cv2.watershed(image, image_marker32) 27 | m = cv2.convertScaleAbs(image_marker32) 28 | ret, image_mask = cv2.threshold(m, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 29 | kernel = np.ones((20, 20), np.uint8) 30 | image_mask = cv2.morphologyEx(image_mask, cv2.MORPH_CLOSE, kernel) 31 | return image_mask 32 | 33 | 34 | def farthest_sampling(pcd, n_samples, init_pcd=None): 35 | def compute_distance(a, b): 36 | return np.linalg.norm(a - b, ord=2, axis=2) 37 | 38 | 
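    # Farthest point sampling: seed the selection with the optional init_pcd
    # points (or a single random point if none are given), then repeatedly add
    # the remaining point whose minimum distance to the already-selected set
    # is largest, until n_samples points have been collected.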
n_pts, dim = pcd.shape[0], pcd.shape[1] 39 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim)) 40 | remaining_pts = np.copy(pcd) 41 | 42 | if init_pcd is None: 43 | if n_pts > 1: 44 | start_idx = np.random.randint(low=0, high=n_pts - 1) 45 | else: 46 | start_idx = 0 47 | selected_pts_expanded[0] = remaining_pts[start_idx] 48 | n_selected_pts = 1 49 | else: 50 | num_points = min(init_pcd.shape[0], n_samples) 51 | selected_pts_expanded[:num_points] = init_pcd[:num_points, None, :] 52 | n_selected_pts = num_points 53 | 54 | for _ in range(1, n_samples): 55 | if n_selected_pts < n_samples: 56 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T 57 | dist_pts_to_selected_min = np.min(dist_pts_to_selected, axis=1, keepdims=True) 58 | res_selected_idx = np.argmax(dist_pts_to_selected_min) 59 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx] 60 | n_selected_pts += 1 61 | 62 | selected_pts = np.squeeze(selected_pts_expanded, axis=1) 63 | return selected_pts 64 | 65 | 66 | def compute_heatmap(points, image_size, k_ratio=3.0): 67 | points = np.asarray(points) 68 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32) 69 | n_points = points.shape[0] 70 | for i in range(n_points): 71 | x = points[i, 0] 72 | y = points[i, 1] 73 | col = int(x) 74 | row = int(y) 75 | try: 76 | heatmap[col, row] += 1.0 77 | except: 78 | col = min(max(col, 0), image_size[0] - 1) 79 | row = min(max(row, 0), image_size[1] - 1) 80 | heatmap[col, row] += 1.0 81 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio) 82 | if k_size % 2 == 0: 83 | k_size += 1 84 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), 0) 85 | if heatmap.max() > 0: 86 | heatmap /= heatmap.max() 87 | heatmap = heatmap.transpose() 88 | return heatmap 89 | 90 | 91 | def select_points_bbox(bbox, points, tolerance=2): 92 | x1, y1, x2, y2 = bbox 93 | ind_x = np.logical_and(points[:, 0] > x1-tolerance, points[:, 0] < x2+tolerance) 94 | ind_y = np.logical_and(points[:, 1] > y1-tolerance, points[:, 1] < y2+tolerance) 95 | ind = np.logical_and(ind_x, ind_y) 96 | indices = np.where(ind == True)[0] 97 | return points[indices] 98 | 99 | 100 | def find_contour_points(mask): 101 | _, contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 102 | if len(contours) != 0: 103 | c = max(contours, key=cv2.contourArea) 104 | c = c.squeeze(axis=1) 105 | return c 106 | else: 107 | return None 108 | 109 | 110 | def get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj): 111 | # active_obj_traj: active obj traj in last observation frame 112 | # obj_bboxs_traj: active obj bbox traj in last observation frame 113 | select_points_homo = np.concatenate((select_points, np.ones((select_points.shape[0], 1), dtype=np.float32)), axis=1) 114 | select_points_homo = np.dot(select_points_homo, homography.T) 115 | select_points_homo = select_points_homo[:, :2] / select_points_homo[:, None, 2] 116 | 117 | obj_point_last_observe = np.array(active_obj_traj[0]) 118 | obj_point_future_start = np.array(active_obj_traj[-1]) 119 | 120 | future2last_trans = obj_point_last_observe - obj_point_future_start 121 | select_points_homo = select_points_homo + future2last_trans 122 | 123 | fill_indices = [idx for idx, points in enumerate(obj_bboxs_traj) if points is not None] 124 | contour_last_observe = obj_bboxs_traj[fill_indices[0]] 125 | contour_future_homo = obj_bboxs_traj[fill_indices[-1]] + future2last_trans 126 | contour_last_observe = 
contour_last_observe[:, None, :].astype(np.int) 127 | contour_future_homo = contour_future_homo[:, None, :].astype(np.int) 128 | filtered_points = [] 129 | for point in select_points_homo: 130 | if cv2.pointPolygonTest(contour_last_observe, (point[0], point[1]), False) >= 0 \ 131 | or cv2.pointPolygonTest(contour_future_homo, (point[0], point[1]), False) >= 0: 132 | filtered_points.append(point) 133 | filtered_points = np.array(filtered_points) 134 | return filtered_points 135 | 136 | 137 | def compute_affordance(frame, active_hand, active_obj, num_points=5, num_sampling=20): 138 | skin_mask = skin_extract(frame) 139 | hand_bbox = np.array(active_hand.bbox.coords_int).reshape(-1) 140 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1) 141 | obj_center = active_obj.bbox.center 142 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 143 | if not iou > 0: 144 | return None 145 | x1, y1, x2, y2 = hand_bbox 146 | hand_mask = np.zeros_like(skin_mask, dtype=np.uint8) 147 | hand_mask[y1:y2, x1:x2] = 255 148 | hand_mask = cv2.bitwise_and(skin_mask, hand_mask) 149 | select_points, init_points = None, None 150 | contact_points = find_contour_points(hand_mask) 151 | 152 | if contact_points is not None and contact_points.shape[0] > 0: 153 | contact_points = select_points_bbox((xA, yA, xB, yB), contact_points) 154 | if contact_points.shape[0] >= num_points: 155 | if contact_points.shape[0] > num_sampling: 156 | contact_points = farthest_sampling(contact_points, n_samples=num_sampling) 157 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1) 158 | indices = np.argsort(distance)[:num_points] 159 | select_points = contact_points[indices] 160 | elif contact_points.shape[0] > 0: 161 | print("no enough boundary points detected, sampling points in interaction region") 162 | init_points = contact_points 163 | else: 164 | print("no boundary points detected, use farthest point sampling") 165 | else: 166 | print("no boundary points detected, use farthest point sampling") 167 | if select_points is None: 168 | ho_mask = np.zeros_like(skin_mask, dtype=np.uint8) 169 | ho_mask[yA:yB, xA:xB] = 255 170 | ho_mask = cv2.bitwise_and(skin_mask, ho_mask) 171 | points = np.array(np.where(ho_mask[yA:yB, xA:xB] > 0)).T 172 | if points.shape[0] == 0: 173 | xx, yy = np.meshgrid(np.arange(xB - xA), np.arange(yB - yA)) 174 | xx += xA 175 | yy += yA 176 | points = np.vstack([xx.reshape(-1), yy.reshape(-1)]).T 177 | else: 178 | points = points[:, [1, 0]] 179 | points[:, 0] += xA 180 | points[:, 1] += yA 181 | if not points.shape[0] > 0: 182 | return None 183 | contact_points = farthest_sampling(points, n_samples=min(num_sampling, points.shape[0]), init_pcd=init_points) 184 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1) 185 | indices = np.argsort(distance)[:num_points] 186 | select_points = contact_points[indices] 187 | return select_points 188 | 189 | 190 | def compute_obj_affordance(frame, annot, active_obj, active_obj_idx, homography, 191 | active_obj_traj, obj_bboxs_traj, 192 | num_points=5, num_sampling=20, 193 | hand_threshold=0.1, obj_threshold=0.1): 194 | affordance_info = {} 195 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 196 | hand_threshold=hand_threshold) 197 | select_points = None 198 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 199 | if object_idx == active_obj_idx: 200 | active_hand = annot.hands[hand_idx] 201 | affordance_info[active_hand.side.name] = 
np.array(active_hand.bbox.coords_int).reshape(-1) 202 | cmap_points = compute_affordance(frame, active_hand, active_obj, num_points=num_points, num_sampling=num_sampling) 203 | if select_points is None and (cmap_points is not None and cmap_points.shape[0] > 0): 204 | select_points = cmap_points 205 | elif select_points is not None and (cmap_points is not None and cmap_points.shape[0] > 0): 206 | select_points = np.concatenate((select_points, cmap_points), axis=0) 207 | if select_points is None: 208 | print("affordance contact points filtered out") 209 | return None 210 | select_points_homo = get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj) 211 | if len(select_points_homo) == 0: 212 | print("affordance contact points filtered out") 213 | return None 214 | else: 215 | affordance_info["select_points"] = select_points 216 | affordance_info["select_points_homo"] = select_points_homo 217 | 218 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1) 219 | affordance_info["obj_bbox"] = obj_bbox 220 | return affordance_info -------------------------------------------------------------------------------- /datasets/dataloaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from tqdm import tqdm 3 | import json 4 | import numpy as np 5 | 6 | from datasets.dataset_utils import get_ek55_annotation, get_ek100_annotation 7 | from datasets.input_loaders import get_loaders 8 | 9 | 10 | class EpicAction(object): 11 | def __init__(self, uid, participant_id, video_id, verb, verb_class, 12 | noun, noun_class, all_nouns, all_noun_classes, start_frame, 13 | stop_frame, start_time, stop_time, ori_fps, partition, action, action_class): 14 | self.uid = uid 15 | self.participant_id = participant_id 16 | self.video_id = video_id 17 | self.verb = verb 18 | self.verb_class = verb_class 19 | self.noun = noun 20 | self.noun_class = noun_class 21 | self.all_nouns = all_nouns 22 | self.all_noun_classes = all_noun_classes 23 | self.start_frame = start_frame 24 | self.stop_frame = stop_frame 25 | self.start_time = start_time 26 | self.stop_time = stop_time 27 | self.ori_fps = ori_fps 28 | self.partition = partition 29 | self.action = action 30 | self.action_class = action_class 31 | 32 | self.duration = self.stop_time - self.start_time 33 | 34 | def __repr__(self): 35 | return json.dumps(self.__dict__, indent=4) 36 | 37 | def set_previous_actions(self, actions): 38 | self.actions_prev = actions 39 | 40 | 41 | class EpicVideo(object): 42 | def __init__(self, df_video, ori_fps, partition, t_ant=None): 43 | self.df = df_video 44 | self.ori_fps = ori_fps 45 | self.partition = partition 46 | self.t_ant = t_ant 47 | 48 | self.actions, self.actions_invalid = self._get_actions() 49 | self.duration = max([a.stop_time for a in self.actions]) 50 | 51 | def _get_actions(self): 52 | actions = [] 53 | _actions_all = [] 54 | actions_invalid = [] 55 | for _, row in self.df.iterrows(): 56 | action_args = { 57 | 'uid': row.uid, 58 | 'participant_id': row.participant_id, 59 | 'video_id': row.video_id, 60 | 'verb': row.verb if 'test' not in self.partition else None, 61 | 'verb_class': row.verb_class if 'test' not in self.partition else None, 62 | 'noun': row.noun if 'test' not in self.partition else None, 63 | 'noun_class': row.noun_class if 'test' not in self.partition else None, 64 | 'all_nouns': row.all_nouns if 'test' not in self.partition else None, 65 | 'all_noun_classes': row.all_noun_classes if 'test' not in 
self.partition else None, 66 | 'start_frame': row.start_frame, 67 | 'stop_frame': row.stop_time, 68 | 'start_time': row.start_time, 69 | 'stop_time': row.stop_time, 70 | 'ori_fps': self.ori_fps, 71 | 'partition': self.partition, 72 | 'action': row.action if 'test' not in self.partition else None, 73 | 'action_class': row.action_class if 'test' not in self.partition else None, 74 | } 75 | action = EpicAction(**action_args) 76 | action.set_previous_actions([aa for aa in _actions_all]) 77 | assert self.t_ant is not None 78 | assert self.t_ant > 0.0 79 | if action.start_time - self.t_ant >= 0: 80 | actions += [action] 81 | else: 82 | actions_invalid += [action] 83 | _actions_all += [action] 84 | return actions, actions_invalid 85 | 86 | 87 | class EpicDataset(Dataset): 88 | def __init__(self, df, partition, ori_fps=60.0, fps=4.0, loader=None, t_ant=None, transform=None, 89 | num_actions_prev=None, label_path=None, eval_label_path=None, 90 | annot_path=None, rulstm_annot_path=None, ek_version=None): 91 | super().__init__() 92 | self.partition = partition 93 | self.ori_fps = ori_fps 94 | self.fps = fps 95 | self.df = df 96 | self.loader = loader 97 | self.t_ant = t_ant 98 | self.transform = transform 99 | self.num_actions_prev = num_actions_prev 100 | 101 | self.videos = self._get_videos() 102 | self.actions, self.actions_invalid = self._get_actions() 103 | 104 | def _get_videos(self): 105 | video_ids = sorted(list(set(self.df['video_id'].values.tolist()))) 106 | videos = [] 107 | pbar = tqdm(desc=f'Loading {self.partition} samples', total=len(self.df)) 108 | for video_id in video_ids: 109 | video_args = { 110 | 'df_video': self.df[self.df['video_id'] == video_id].copy(), 111 | 'ori_fps': self.ori_fps, 112 | 'partition': self.partition, 113 | 't_ant': self.t_ant 114 | } 115 | video = EpicVideo(**video_args) 116 | videos += [video] 117 | pbar.update(len(video.actions)) 118 | pbar.close() 119 | return videos 120 | 121 | def _get_actions(self): 122 | actions = [] 123 | actions_invalid = [] 124 | for video in self.videos: 125 | actions += video.actions 126 | actions_invalid += video.actions_invalid 127 | return actions, actions_invalid 128 | 129 | def __len__(self): 130 | return len(self.actions) 131 | 132 | 133 | def __getitem__(self, idx): 134 | a = self.actions[idx] 135 | sample = {'uid': a.uid} 136 | 137 | inputs = self.loader(a) 138 | sample.update(inputs) 139 | 140 | if 'test' not in self.partition: 141 | sample['verb_class'] = a.verb_class 142 | sample['noun_class'] = a.noun_class 143 | sample['action_class'] = a.action_class 144 | 145 | actions_prev = [-1] + [aa.action_class for aa in a.actions_prev] 146 | actions_prev = actions_prev[-self.num_actions_prev:] 147 | if len(actions_prev) < self.num_actions_prev: 148 | actions_prev = actions_prev[0:1] * (self.num_actions_prev - len(actions_prev)) + actions_prev 149 | actions_prev = np.array(actions_prev, dtype=np.int64) 150 | sample['action_class_prev'] = actions_prev 151 | return sample 152 | 153 | 154 | def get_datasets(args, epic_ds=None, featuresloader=None): 155 | loaders = get_loaders(args, featuresloader=featuresloader) 156 | 157 | annotation_args = { 158 | 'annot_path': args.annot_path, 159 | 'label_path': args.label_path, 160 | 'eval_label_path': args.eval_label_path, 161 | 'rulstm_annot_path': args.rulstm_annot_path, 162 | 'validation_ratio': args.validation_ratio, 163 | 'use_rulstm_splits': args.use_rulstm_splits, 164 | 'use_label_only': args.use_label_only 165 | } 166 | 167 | if args.ek_version == 'ek55': 168 | dfs = { 169 | 'train': 
get_ek55_annotation(partition='train', **annotation_args), 170 | 'validation': get_ek55_annotation(partition='validation', **annotation_args), 171 | 'eval': get_ek55_annotation(partition='eval', **annotation_args), 172 | 'test_s1': get_ek55_annotation(partition='test_s1', **annotation_args), 173 | 'test_s2': get_ek55_annotation(partition='test_s2', **annotation_args), 174 | } 175 | elif args.ek_version == 'ek100': 176 | dfs = { 177 | 'train': get_ek100_annotation(partition='train', **annotation_args), 178 | 'validation': get_ek100_annotation(partition='validation', **annotation_args), 179 | 'eval': get_ek100_annotation(partition='eval', **annotation_args), 180 | 'test': get_ek100_annotation(partition='test', **annotation_args), 181 | } 182 | else: 183 | raise Exception(f'Error. EPIC-Kitchens Version "{args.ek_version}" not supported.') 184 | 185 | ds_args = { 186 | 'label_path': args.label_path[args.ek_version], 187 | 'eval_label_path': args.eval_label_path[args.ek_version], 188 | 'annot_path': args.annot_path, 189 | 'rulstm_annot_path': args.rulstm_annot_path[args.ek_version], 190 | 'ek_version': args.ek_version, 191 | 'ori_fps': args.ori_fps, 192 | 'fps': args.fps, 193 | 't_ant': args.t_ant, 194 | 'num_actions_prev': args.num_actions_prev if args.task in ['anticipation'] else None, 195 | 'mode': args.mode, 196 | } 197 | 198 | if epic_ds is None: 199 | epic_ds = EpicDataset 200 | 201 | if args.mode in ['train', 'training']: 202 | dss = { 203 | 'train': epic_ds(df=dfs['train'], partition='train', loader=loaders['train'], **ds_args), 204 | 'validation': epic_ds(df=dfs['validation'], partition='validation', loader=loaders['validation'], 205 | **ds_args), 206 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args), 207 | } 208 | elif args.mode in ['validation', 'validating', 'validate']: 209 | dss = { 210 | 'validation': epic_ds(df=dfs['validation'], partition='validation', 211 | loader=loaders['validation'], **ds_args), 212 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args), 213 | } 214 | elif args.mode in ['test', 'testing']: 215 | 216 | if args.ek_version == "ek55": 217 | dss = { 218 | 'test_s1': epic_ds(df=dfs['test_s1'], partition='test_s1', loader=loaders['test'], **ds_args), 219 | 'test_s2': epic_ds(df=dfs['test_s2'], partition='test_s2', loader=loaders['test'], **ds_args), 220 | } 221 | elif args.ek_version == "ek100": 222 | dss = { 223 | 'test': epic_ds(df=dfs['test'], partition='test', loader=loaders['test'], **ds_args), 224 | } 225 | else: 226 | raise Exception(f'Error. 
Mode "{args.mode}" not supported.') 227 | 228 | return dss 229 | 230 | 231 | def get_dataloaders(args, epic_ds=None, featuresloader=None): 232 | dss = get_datasets(args, epic_ds=epic_ds,featuresloader=featuresloader) 233 | dl_args = { 234 | 'batch_size': args.batch_size, 235 | 'pin_memory': True, 236 | 'num_workers': args.num_workers, 237 | } 238 | if args.mode in ['train', 'training']: 239 | dls = { 240 | 'train': DataLoader(dss['train'], shuffle=False, **dl_args), 241 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 242 | } 243 | elif args.mode in ['validate', 'validation', 'validating']: 244 | dls = { 245 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 246 | } 247 | elif args.mode == 'test': 248 | if args.ek_version == "ek55": 249 | dls = { 250 | 'test_s1': DataLoader(dss['test_s1'], shuffle=False, **dl_args), 251 | 'test_s2': DataLoader(dss['test_s2'], shuffle=False, **dl_args), 252 | } 253 | elif args.ek_version == "ek100": 254 | dls = { 255 | 'test': DataLoader(dss['test'], shuffle=False, **dl_args), 256 | } 257 | else: 258 | raise Exception(f'Error. Mode "{args.mode}" not supported.') 259 | return dls 260 | -------------------------------------------------------------------------------- /preprocess/ho_types.py: -------------------------------------------------------------------------------- 1 | """The core set of types that represent hand-object detections""" 2 | 3 | from enum import Enum, unique 4 | from itertools import chain 5 | from typing import Dict, Iterator, List, Tuple, cast 6 | 7 | import numpy as np 8 | from dataclasses import dataclass 9 | import preprocess.types_pb2 as pb 10 | 11 | __all__ = [ 12 | "HandSide", 13 | "HandState", 14 | "FloatVector", 15 | "BBox", 16 | "HandDetection", 17 | "ObjectDetection", 18 | "FrameDetections", 19 | ] 20 | 21 | 22 | @unique 23 | class HandSide(Enum): 24 | LEFT = 0 25 | RIGHT = 1 26 | 27 | 28 | @unique 29 | class HandState(Enum): 30 | """An enum describing the different states a hand can be in: 31 | - No contact: The hand isn't touching anything 32 | - Self contact: The hand is touching itself 33 | - Another person: The hand is touching another person 34 | - Portable object: The hand is in contact with a portable object 35 | - Stationary object: The hand is in contact with an immovable/stationary object""" 36 | 37 | NO_CONTACT = 0 38 | SELF_CONTACT = 1 39 | ANOTHER_PERSON = 2 40 | PORTABLE_OBJECT = 3 41 | STATIONARY_OBJECT = 4 42 | 43 | 44 | @dataclass 45 | class FloatVector: 46 | """A floating-point 2D vector representation""" 47 | x: np.float32 48 | y: np.float32 49 | 50 | def to_protobuf(self) -> pb.FloatVector: 51 | vector = pb.FloatVector() 52 | vector.x = self.x 53 | vector.y = self.y 54 | assert vector.IsInitialized() 55 | return vector 56 | 57 | @staticmethod 58 | def from_protobuf(vector: pb.FloatVector) -> "FloatVector": 59 | return FloatVector(x=vector.x, y=vector.y) 60 | 61 | def __add__(self, other: "FloatVector") -> "FloatVector": 62 | return FloatVector(x=self.x + other.x, y=self.y + other.y) 63 | 64 | def __mul__(self, scaler: float) -> "FloatVector": 65 | return FloatVector(x=self.x * scaler, y=self.y * scaler) 66 | 67 | def __iter__(self) -> Iterator[float]: 68 | yield from (self.x, self.y) 69 | 70 | @property 71 | def coord(self) -> Tuple[float, float]: 72 | """Return coordinates as a tuple""" 73 | return (self.x, self.y) 74 | 75 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 76 | """Scale x component by ``width_factor`` and y component by 
``height_factor``""" 77 | self.x *= width_factor 78 | self.y *= height_factor 79 | 80 | 81 | @dataclass 82 | class BBox: 83 | left: float 84 | top: float 85 | right: float 86 | bottom: float 87 | 88 | def to_protobuf(self) -> pb.BBox: 89 | bbox = pb.BBox() 90 | bbox.left = self.left 91 | bbox.top = self.top 92 | bbox.right = self.right 93 | bbox.bottom = self.bottom 94 | assert bbox.IsInitialized() 95 | return bbox 96 | 97 | @staticmethod 98 | def from_protobuf(bbox: pb.BBox) -> "BBox": 99 | return BBox( 100 | left=bbox.left, 101 | top=bbox.top, 102 | right=bbox.right, 103 | bottom=bbox.bottom, 104 | ) 105 | 106 | @property 107 | def center(self) -> Tuple[float, float]: 108 | x = (self.left + self.right) / 2 109 | y = (self.top + self.bottom) / 2 110 | return x, y 111 | 112 | @property 113 | def center_int(self) -> Tuple[int, int]: 114 | """Get center position as a tuple of integers (rounded)""" 115 | x, y = self.center 116 | return (round(x), round(y)) 117 | 118 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 119 | self.left *= width_factor 120 | self.right *= width_factor 121 | self.top *= height_factor 122 | self.bottom *= height_factor 123 | 124 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 125 | x, y = self.center 126 | new_width = self.width * width_factor 127 | new_height = self.height * height_factor 128 | self.left = x - new_width / 2 129 | self.right = x + new_width / 2 130 | self.top = y - new_height / 2 131 | self.bottom = y + new_height / 2 132 | 133 | @property 134 | def coords(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: 135 | return ( 136 | self.top_left, 137 | self.bottom_right, 138 | ) 139 | 140 | @property 141 | def coords_int(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: 142 | return ( 143 | self.top_left_int, 144 | self.bottom_right_int, 145 | ) 146 | 147 | @property 148 | def width(self) -> float: 149 | return self.right - self.left 150 | 151 | @property 152 | def height(self) -> float: 153 | return self.bottom - self.top 154 | 155 | @property 156 | def top_left(self) -> Tuple[float, float]: 157 | return (self.left, self.top) 158 | 159 | @property 160 | def bottom_right(self) -> Tuple[float, float]: 161 | return (self.right, self.bottom) 162 | 163 | @property 164 | def top_left_int(self) -> Tuple[int, int]: 165 | return (round(self.left), round(self.top)) 166 | 167 | @property 168 | def bottom_right_int(self) -> Tuple[int, int]: 169 | return (round(self.right), round(self.bottom)) 170 | 171 | 172 | @dataclass 173 | class HandDetection: 174 | """Dataclass representing a hand detection, consisting of a bounding box, 175 | a score (representing the model's confidence this is a hand), the predicted state 176 | of the hand, whether this is a left/right hand, and a predicted offset to the 177 | interacted object if the hand is interacting.""" 178 | 179 | bbox: BBox 180 | score: np.float32 181 | state: HandState 182 | side: HandSide 183 | object_offset: FloatVector 184 | 185 | def to_protobuf(self) -> pb.HandDetection: 186 | detection = pb.HandDetection() 187 | detection.bbox.MergeFrom(self.bbox.to_protobuf()) 188 | detection.score = self.score 189 | detection.state = self.state.value 190 | detection.object_offset.MergeFrom(self.object_offset.to_protobuf()) 191 | detection.side = self.side.value 192 | assert detection.IsInitialized() 193 | return detection 194 | 195 | @staticmethod 196 | def from_protobuf(detection: pb.HandDetection) -> "HandDetection": 197 | return HandDetection( 198 | 
bbox=BBox.from_protobuf(detection.bbox), 199 | score=detection.score, 200 | state=HandState(detection.state), 201 | object_offset=FloatVector.from_protobuf(detection.object_offset), 202 | side=HandSide(detection.side), 203 | ) 204 | 205 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 206 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor) 207 | self.object_offset.scale(width_factor=width_factor, height_factor=height_factor) 208 | 209 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 210 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor) 211 | 212 | 213 | @dataclass 214 | class ObjectDetection: 215 | """Dataclass representing an object detection, consisting of a bounding box and a 216 | score (the model's confidence this is an object)""" 217 | 218 | bbox: BBox 219 | score: np.float32 220 | 221 | def to_protobuf(self) -> pb.ObjectDetection: 222 | detection = pb.ObjectDetection() 223 | detection.bbox.MergeFrom(self.bbox.to_protobuf()) 224 | detection.score = self.score 225 | assert detection.IsInitialized() 226 | return detection 227 | 228 | @staticmethod 229 | def from_protobuf(detection: pb.ObjectDetection) -> "ObjectDetection": 230 | return ObjectDetection( 231 | bbox=BBox.from_protobuf(detection.bbox), score=detection.score 232 | ) 233 | 234 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 235 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor) 236 | 237 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 238 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor) 239 | 240 | 241 | @dataclass 242 | class FrameDetections: 243 | """Dataclass representing hand-object detections for a frame of a video""" 244 | 245 | video_id: str 246 | frame_number: int 247 | objects: List[ObjectDetection] 248 | hands: List[HandDetection] 249 | 250 | def to_protobuf(self) -> pb.Detections: 251 | detections = pb.Detections() 252 | detections.video_id = self.video_id 253 | detections.frame_number = self.frame_number 254 | detections.hands.extend([hand.to_protobuf() for hand in self.hands]) 255 | detections.objects.extend([object.to_protobuf() for object in self.objects]) 256 | assert detections.IsInitialized() 257 | return detections 258 | 259 | @staticmethod 260 | def from_protobuf(detections: pb.Detections) -> "FrameDetections": 261 | return FrameDetections( 262 | video_id=detections.video_id, 263 | frame_number=detections.frame_number, 264 | hands=[HandDetection.from_protobuf(pb) for pb in detections.hands], 265 | objects=[ObjectDetection.from_protobuf(pb) for pb in detections.objects], 266 | ) 267 | 268 | @staticmethod 269 | def from_protobuf_str(pb_str: bytes) -> "FrameDetections": 270 | pb_detection = pb.Detections() 271 | pb_detection.MergeFromString(pb_str) 272 | return FrameDetections.from_protobuf(pb_detection) 273 | 274 | def get_hand_object_interactions( 275 | self, object_threshold: float = 0, hand_threshold: float = 0 276 | ) -> Dict[int, int]: 277 | """Match the hands to objects based on the hand offset vector that the model 278 | uses to predict the location of the interacted object. 279 | 280 | Args: 281 | object_threshold: Object score threshold above which to consider objects 282 | for matching 283 | hand_threshold: Hand score threshold above which to consider hands for 284 | matching. 
285 | 286 | Returns: 287 | A dictionary mapping hand detections to objects by indices 288 | """ 289 | interactions = dict() 290 | object_idxs = [ 291 | i for i, obj in enumerate(self.objects) if obj.score >= object_threshold 292 | ] 293 | object_centers = np.array( 294 | [self.objects[object_id].bbox.center for object_id in object_idxs] 295 | ) 296 | for hand_idx, hand_detection in enumerate(self.hands): 297 | if ( 298 | hand_detection.state.value == HandState.NO_CONTACT.value 299 | or hand_detection.score <= hand_threshold 300 | ): 301 | continue 302 | estimated_object_position = ( 303 | np.array(hand_detection.bbox.center) + 304 | np.array(hand_detection.object_offset.coord) 305 | ) 306 | distances = ((object_centers - estimated_object_position) ** 2).sum( 307 | axis=-1) 308 | interactions[hand_idx] = object_idxs[cast(int, np.argmin(distances))] 309 | return interactions 310 | 311 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 312 | """ 313 | Scale the coordinates of all the hands/objects. x components are multiplied 314 | by the ``width_factor`` and y components by the ``height_factor`` 315 | """ 316 | for det in chain(self.hands, self.objects): 317 | det.scale(width_factor=width_factor, height_factor=height_factor) 318 | 319 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 320 | """ 321 | Scale all the hands/objects about their center points. 322 | """ 323 | for det in chain(self.hands, self.objects): 324 | det.center_scale(width_factor=width_factor, height_factor=height_factor) -------------------------------------------------------------------------------- /preprocess/obj_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from preprocess.dataset_util import bbox_inter, HandState, compute_iou, \ 3 | valid_traj, get_valid_traj, points_in_bbox 4 | from preprocess.traj_util import get_homo_point, get_homo_bbox_point 5 | 6 | 7 | def find_active_side(annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 8 | if len(hand_sides) == 1: 9 | return hand_sides[0] 10 | else: 11 | hand_counter = {"LEFT": 0, "RIGHT": 0} 12 | for annot in annots: 13 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 14 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 15 | if len(hands) > 0 and len(objs) > 0: 16 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 17 | hand_threshold=hand_threshold) 18 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 19 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords_int).reshape(-1) 20 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords_int).reshape(-1) 21 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 22 | if iou > 0: 23 | hand_side = annot.hands[hand_idx].side.name 24 | if annot.hands[hand_idx].state.value == HandState.PORTABLE_OBJECT.value: 25 | hand_counter[hand_side] += 1 26 | elif annot.hands[hand_idx].state.value == HandState.STATIONARY_OBJECT.value: 27 | hand_counter[hand_side] += 0.5 28 | if hand_counter["LEFT"] == hand_counter["RIGHT"]: 29 | return "RIGHT" 30 | else: 31 | return max(hand_counter, key=hand_counter.get) 32 | 33 | 34 | def compute_contact(annots, hand_side, contact_state, hand_threshold=0.1): 35 | contacts = [] 36 | for annot in annots: 37 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold 38 | and hand.side.name == hand_side and hand.state.value == contact_state] 
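        # A frame counts as in-contact when at least one confident hand on the
        # active side is in the reference contact state; the binary series is
        # then smoothed with a 3-tap moving average and everything before the
        # last state change is cleared, so only a trailing contact run is kept.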
39 | if len(hands) > 0: 40 | contacts.append(1) 41 | else: 42 | contacts.append(0) 43 | contacts = np.array(contacts) 44 | padding_contacts = np.pad(contacts, [1, 1], 'edge') 45 | contacts = np.convolve(padding_contacts, [1, 1, 1], 'same') 46 | contacts = contacts[1:-1] / 3 47 | contacts = contacts > 0.5 48 | indices = np.diff(contacts) != 0 49 | if indices.sum() == 0: 50 | return contacts 51 | else: 52 | split = np.where(indices)[0] + 1 53 | contacts_idx = split[-1] 54 | contacts[:contacts_idx] = False 55 | return contacts 56 | 57 | 58 | def find_active_obj_side(annot, hand_side, return_hand=False, return_idx=False, hand_threshold=0.1, obj_threshold=0.1): 59 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 60 | if len(objs) == 0: 61 | return None 62 | else: 63 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 64 | hand_threshold=hand_threshold) 65 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 66 | if annot.hands[hand_idx].side.name == hand_side: 67 | if return_hand and return_idx: 68 | return annot.objects[object_idx], object_idx, annot.hands[hand_idx], hand_idx 69 | elif return_hand: 70 | return annot.objects[object_idx], annot.hands[hand_idx] 71 | elif return_idx: 72 | return annot.objects[object_idx], object_idx 73 | else: 74 | return annot.objects[object_idx] 75 | return None 76 | 77 | 78 | def find_active_obj_iou(objs, bbox): 79 | max_iou = 0 80 | active_obj = None 81 | for obj in objs: 82 | iou = compute_iou(obj.bbox.coords, bbox) 83 | if iou > max_iou: 84 | max_iou = iou 85 | active_obj = obj 86 | return active_obj, max_iou 87 | 88 | 89 | def traj_compute(annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1): 90 | annot = annots[-1] 91 | obj_traj = [] 92 | obj_centers = [] 93 | obj_bboxs =[] 94 | obj_bboxs_traj = [] 95 | active_hand_side = find_active_side(annots, hand_sides, hand_threshold=hand_threshold, 96 | obj_threshold=obj_threshold) 97 | active_obj, active_object_idx, active_hand, active_hand_idx = find_active_obj_side(annot, 98 | hand_side=active_hand_side, 99 | return_hand=True, return_idx=True, 100 | hand_threshold=hand_threshold, 101 | obj_threshold=obj_threshold) 102 | contact_state = active_hand.state.value 103 | contacts = compute_contact(annots, active_hand_side, contact_state, 104 | hand_threshold=hand_threshold) 105 | obj_center = active_obj.bbox.center 106 | obj_centers.append(obj_center) 107 | obj_point = get_homo_point(obj_center, homography_stack[-1]) 108 | obj_bbox = active_obj.bbox.coords 109 | obj_traj.append(obj_point) 110 | obj_bboxs.append(obj_bbox) 111 | 112 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[-1]) 113 | obj_bboxs_traj.append(obj_points2d) 114 | 115 | for idx in np.arange(len(annots)-2, -1, -1): 116 | annot = annots[idx] 117 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 118 | contact = contacts[idx] 119 | if not contact: 120 | obj_centers.append(None) 121 | obj_traj.append(None) 122 | obj_bboxs_traj.append(None) 123 | else: 124 | if len(objs) >= 2: 125 | target_obj, max_iou = find_active_obj_iou(objs, obj_bboxs[-1]) 126 | if target_obj is None: 127 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side, 128 | hand_threshold=hand_threshold, 129 | obj_threshold=obj_threshold) 130 | if target_obj is None: 131 | obj_centers.append(None) 132 | obj_traj.append(None) 133 | obj_bboxs_traj.append(None) 134 | else: 135 | obj_center = target_obj.bbox.center 136 | 
obj_centers.append(obj_center) 137 | obj_point = get_homo_point(obj_center, homography_stack[idx]) 138 | obj_bbox = target_obj.bbox.coords 139 | obj_traj.append(obj_point) 140 | obj_bboxs.append(obj_bbox) 141 | 142 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx]) 143 | obj_bboxs_traj.append(obj_points2d) 144 | 145 | elif len(objs) > 0: 146 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side, 147 | hand_threshold=hand_threshold, 148 | obj_threshold=obj_threshold) 149 | if target_obj is None: 150 | obj_centers.append(None) 151 | obj_traj.append(None) 152 | obj_bboxs_traj.append(None) 153 | else: 154 | obj_center = target_obj.bbox.center 155 | obj_centers.append(obj_center) 156 | obj_point = get_homo_point(obj_center, homography_stack[idx]) 157 | obj_bbox = target_obj.bbox.coords 158 | obj_traj.append(obj_point) 159 | obj_bboxs.append(obj_bbox) 160 | 161 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx]) 162 | obj_bboxs_traj.append(obj_points2d) 163 | else: 164 | obj_centers.append(None) 165 | obj_traj.append(None) 166 | obj_bboxs_traj.append(None) 167 | obj_bboxs.reverse() 168 | obj_traj.reverse() 169 | obj_centers.reverse() 170 | obj_bboxs_traj.reverse() 171 | return obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj 172 | 173 | 174 | def traj_filter(obj_traj, obj_centers, obj_bbox, contacts, homography_stack, contact_ratio=0.4): 175 | assert len(obj_traj) == len(obj_centers), "traj length and center length not equal" 176 | assert len(obj_centers) == len(homography_stack), "center length and homography length not equal" 177 | homo_last2first = homography_stack[-1] 178 | homo_first2last = np.linalg.inv(homo_last2first) 179 | obj_points = [] 180 | obj_inside, obj_detect = [], [] 181 | for idx, obj_center in enumerate(obj_centers): 182 | if obj_center is not None: 183 | homo_current2first = homography_stack[idx] 184 | homo_current2last = homo_current2first.dot(homo_first2last) 185 | obj_point = get_homo_point(obj_center, homo_current2last) 186 | obj_points.append(obj_point) 187 | obj_inside.append(points_in_bbox(obj_point, obj_bbox)) 188 | obj_detect.append(True) 189 | else: 190 | obj_detect.append(False) 191 | obj_inside = np.array(obj_inside) 192 | obj_detect = np.array(obj_detect) 193 | contacts = np.bitwise_and(obj_detect, contacts) 194 | if np.sum(obj_inside) == len(obj_inside) and np.sum(contacts) / len(contacts) < contact_ratio: 195 | obj_traj = np.tile(obj_traj[-1], (len(obj_traj), 1)) 196 | return obj_traj, contacts 197 | 198 | 199 | def traj_completion(traj, imgW=456, imgH=256): 200 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None] 201 | full_traj = traj.copy() 202 | if len(fill_indices) == 1: 203 | point = traj[fill_indices[0]] 204 | full_traj = np.array([point] * len(traj), dtype=np.float32) 205 | else: 206 | contact_time = fill_indices[0] 207 | if contact_time > 0: 208 | full_traj[:contact_time] = [traj[contact_time]] * contact_time 209 | for previous_idx, current_idx in zip(fill_indices[:-1], fill_indices[1:]): 210 | start_point, end_point = traj[previous_idx], traj[current_idx] 211 | time_expand = current_idx - previous_idx 212 | for idx in range(previous_idx+1, current_idx): 213 | full_traj[idx] = (idx-previous_idx) / time_expand * end_point + (current_idx-idx) / time_expand * start_point 214 | full_traj = np.array(full_traj, dtype=np.float32) 215 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH) 216 | return full_traj, fill_indices 217 | 218 
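# Illustrative sketch (not part of the original pipeline): traj_completion()
# above fills the None gaps of an object-centre trajectory by linearly
# interpolating between observed indices and then clipping the result with
# get_valid_traj(). The helper below, whose name _demo_linear_gap_fill is ours,
# reproduces only the interpolation step on toy data and relies on the
# module-level numpy import.
def _demo_linear_gap_fill():
    toy = [np.array([10.0, 10.0]), None, None, np.array([40.0, 40.0])]
    observed = [i for i, p in enumerate(toy) if p is not None]  # [0, 3]
    for prev_i, cur_i in zip(observed[:-1], observed[1:]):
        span = cur_i - prev_i
        for i in range(prev_i + 1, cur_i):
            # same weighting as traj_completion: linear blend of the two
            # neighbouring observed points
            toy[i] = (i - prev_i) / span * toy[cur_i] + (cur_i - i) / span * toy[prev_i]
    return toy  # toy[1] -> [20., 20.], toy[2] -> [30., 30.]

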
| 219 | def compute_obj_traj(frames, annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1, 220 | contact_ratio=0.4): 221 | imgH, imgW = frames[0].shape[:2] 222 | obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj = traj_compute(annots, hand_sides, homography_stack, 223 | hand_threshold=hand_threshold, obj_threshold=obj_threshold) 224 | obj_traj, contacts = traj_filter(obj_traj, obj_centers, obj_bboxs[-1], contacts, homography_stack, 225 | contact_ratio=contact_ratio) 226 | obj_traj = valid_traj(obj_traj, imgW=imgW, imgH=imgH) 227 | if len(obj_traj) == 0: 228 | print("object traj filtered out") 229 | return None 230 | else: 231 | complete_traj, fill_indices = traj_completion(obj_traj, imgW=imgW, imgH=imgH) 232 | obj_trajs = {"traj": complete_traj, "fill_indices": fill_indices, "centers": obj_centers} 233 | return contacts, obj_trajs, active_obj, active_object_idx, obj_bboxs_traj -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import pickle 5 | from sklearn.model_selection import train_test_split 6 | import numpy as np 7 | 8 | 9 | def timestr2sec(t_str): 10 | hh, mm, ss = [float(x) for x in t_str.split(':')] 11 | t_sec = hh * 3600.0 + mm * 60.0 + ss 12 | return t_sec 13 | 14 | 15 | def read_rulstm_splits(rulstm_annotation_path): 16 | header = ['uid', 'video_id', 'start_frame', 'stop_frame', 'verb_class', 'noun_class', 'action_class'] 17 | df_train = pd.read_csv(os.path.join(rulstm_annotation_path, 'training.csv'), names=header) 18 | df_validation = pd.read_csv(os.path.join(rulstm_annotation_path, 'validation.csv'), names=header) 19 | return df_train, df_validation 20 | 21 | 22 | def str2list(s, out_type=None): 23 | """ 24 | Convert a string "[i1, i2, ...]" of items into a list [i1, i2, ...] of items. 25 | """ 26 | s = s.replace('[', '').replace(']', '') 27 | s = s.replace('\'', '') 28 | s = s.split(', ') 29 | if out_type is not None: 30 | s = [out_type(ss) for ss in s] 31 | return s 32 | 33 | 34 | def split_train_val(df, validation_ratio=0.2, use_rulstm_splits=False, 35 | rulstm_annotation_path=None, label_info_path=None, 36 | use_label_only=True): 37 | if label_info_path is not None and use_label_only: 38 | with open(label_info_path, 'r') as f: 39 | uids_label = json.load(f) 40 | df = df.loc[df['uid'].isin(uids_label)] 41 | if use_rulstm_splits: 42 | assert rulstm_annotation_path is not None 43 | df_train_rulstm, df_validation_rulstm = read_rulstm_splits(rulstm_annotation_path) 44 | uids_train = df_train_rulstm['uid'].values.tolist() 45 | uids_validation = df_validation_rulstm['uid'].values.tolist() 46 | df_train = df.loc[df['uid'].isin(uids_train)] 47 | df_validation = df.loc[df['uid'].isin(uids_validation)] 48 | else: 49 | if validation_ratio == 0.0: 50 | df_train = df 51 | df_validation = pd.DataFrame(columns=df.columns) 52 | elif validation_ratio == 1.0: 53 | df_train = pd.DataFrame(columns=df.columns) 54 | df_validation = df 55 | elif 0.0 < validation_ratio < 1.0: 56 | df_train, df_validation = train_test_split(df, test_size=validation_ratio, 57 | random_state=3577, 58 | shuffle=True, stratify=df['participant_id']) 59 | else: 60 | raise Exception(f'Error. 
Validation "{validation_ratio}" not supported.') 61 | return df_train, df_validation 62 | 63 | 64 | def create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, ek_version, out_path='actions.csv', use_rulstm_splits=True): 65 | if use_rulstm_splits: 66 | if ek_version == 'ek55': 67 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek55'], 'actions.csv')) 68 | elif ek_version == 'ek100': 69 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek100'], 'actions.csv')) 70 | df_actions['action'] = df_actions.action.map(lambda x: x.replace(' ', '_')) 71 | 72 | df_actions['verb_class'] = df_actions.verb 73 | df_actions['noun_class'] = df_actions.noun 74 | df_actions['verb'] = df_actions.action.map(lambda x: x.split('_')[0]) 75 | df_actions['noun'] = df_actions.action.map(lambda x: x.split('_')[1]) 76 | df_actions['action'] = df_actions.action 77 | df_actions['action_class'] = df_actions.id 78 | del df_actions['id'] 79 | 80 | else: 81 | if ek_version == 'ek55': 82 | df_train = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train', 83 | use_label_only=False, raw=True) 84 | df_validation = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation', 85 | use_label_only=False, raw=True) 86 | df = pd.concat([df_train, df_validation]) 87 | df.sort_values(by=['uid'], inplace=True) 88 | 89 | elif ek_version == 'ek100': 90 | df_train = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train', 91 | use_label_only=False, raw=True) 92 | df_validation = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation', 93 | use_label_only=False, raw=True) 94 | df = pd.concat([df_train, df_validation]) 95 | df.sort_values(by=['narration_id'], inplace=True) 96 | 97 | noun_classes = df.noun_class.values 98 | nouns = df.noun.values 99 | verb_classes = df.verb_class.values 100 | verbs = df.verb.values 101 | 102 | actions_combinations = [f'{v}_{n}' for v, n in zip(verb_classes, noun_classes)] 103 | actions = [f'{v}_{n}' for v, n in zip(verbs, nouns)] 104 | 105 | df_actions = {'verb_class': [], 'noun_class': [], 'verb': [], 'noun': [], 'action': []} 106 | vn_combinations = [] 107 | for i, a in enumerate(actions_combinations): 108 | if a in vn_combinations: 109 | continue 110 | 111 | v, n = a.split('_') 112 | v = int(v) 113 | n = int(n) 114 | df_actions['verb_class'] += [v] 115 | df_actions['noun_class'] += [n] 116 | df_actions['action'] += [actions[i]] 117 | df_actions['verb'] += [verbs[i]] 118 | df_actions['noun'] += [nouns[i]] 119 | vn_combinations += [a] 120 | df_actions = pd.DataFrame(df_actions) 121 | df_actions.sort_values(by=['verb_class', 'noun_class'], inplace=True) 122 | df_actions['action_class'] = range(len(df_actions)) 123 | 124 | df_actions.to_csv(out_path, index=False) 125 | print(f'Saved file at "{out_path}".') 126 | 127 | 128 | def get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2, 129 | use_rulstm_splits=False, use_label_only=True, raw=False): 130 | if partition in ['train', 'validation']: 131 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv') 132 | label_info_path = os.path.join(label_path['ek55'], "video_info.json") 133 | df = pd.read_csv(csv_path) 134 | df_train, df_validation = split_train_val(df, validation_ratio=validation_ratio, 135 | use_rulstm_splits=use_rulstm_splits, 136 | 
rulstm_annotation_path=rulstm_annot_path['ek55'], 137 | label_info_path=label_info_path, 138 | use_label_only=use_label_only) 139 | 140 | df = df_train if partition == 'train' else df_validation 141 | if not use_rulstm_splits: 142 | df.sort_values(by=['uid'], inplace=True) 143 | 144 | elif partition in ['eval', 'evaluation']: 145 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv') 146 | df = pd.read_csv(csv_path) 147 | with open(eval_label_path['ek55'], 'rb') as f: 148 | eval_labels = pickle.load(f) 149 | eval_uids = eval_labels.keys() 150 | df = df.loc[df['uid'].isin(eval_uids)] 151 | 152 | elif partition == 'test_s1': 153 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s1_timestamps.csv') 154 | df = pd.read_csv(csv_path) 155 | 156 | elif partition == 'test_s2': 157 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s2_timestamps.csv') 158 | df = pd.read_csv(csv_path) 159 | else: 160 | raise Exception(f'Error. Partition "{partition}" not supported.') 161 | 162 | if raw: 163 | return df 164 | 165 | actions_df_path = os.path.join(annot_path['ek55'], 'actions.csv') 166 | if not os.path.exists(actions_df_path): 167 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek55', out_path=actions_df_path, use_rulstm_splits=True) 168 | df_actions = pd.read_csv(actions_df_path) 169 | 170 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t)) 171 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t)) 172 | if 'test' not in partition: 173 | action_classes = [] 174 | actions = [] 175 | for _, row in df.iterrows(): 176 | v, n = row.verb_class, row.noun_class 177 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)] 178 | a_cl = df_a_sub['action_class'].values 179 | a = df_a_sub['action'].values 180 | if len(a_cl) > 1: 181 | print(a_cl) 182 | action_classes += [a_cl[0]] 183 | actions += [a[0]] 184 | df['action_class'] = action_classes 185 | df['action'] = actions 186 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x)) 187 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int)) 188 | 189 | return df 190 | 191 | 192 | def get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2, 193 | use_rulstm_splits=False, use_label_only=True, raw=False): 194 | 195 | if partition in 'train': 196 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 197 | uids = np.arange(len(df)) 198 | 199 | elif partition in 'validation': 200 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 201 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 202 | uids = np.arange(len(df)) + len(df_train) 203 | 204 | elif partition in 'evaluation': 205 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 206 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 207 | uids = np.arange(len(df)) + len(df_train) 208 | df['uid'] = uids 209 | with open(eval_label_path['ek100'], 'rb') as f: 210 | eval_labels = pickle.load(f) 211 | eval_uids = eval_labels.keys() 212 | df = df.loc[df['uid'].isin(eval_uids)] 213 | 214 | elif partition == 'test': 215 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 216 | df_validation = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 217 | df = pd.read_csv(os.path.join(annot_path['ek100'], 
'EPIC_100_test_timestamps.csv')) 218 | uids = np.arange(len(df)) + len(df_train) + len(df_validation) 219 | 220 | else: 221 | raise Exception(f'Error. Partition "{partition}" not supported.') 222 | if raw: 223 | return df 224 | 225 | actions_df_path = os.path.join(annot_path['ek100'], 'actions.csv') 226 | if not os.path.exists(actions_df_path): 227 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek100', actions_df_path) 228 | df_actions = pd.read_csv(actions_df_path) 229 | 230 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t)) 231 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t)) 232 | if not 'uid' in df: 233 | df['uid'] = uids 234 | 235 | if use_label_only: 236 | label_info_path = os.path.join(label_path['ek100'], "video_info.json") 237 | with open(label_info_path, 'r') as f: 238 | uids_label = json.load(f) 239 | df = df.loc[df['uid'].isin(uids_label)] 240 | 241 | if 'test' not in partition: 242 | action_classes = [] 243 | actions = [] 244 | for _, row in df.iterrows(): 245 | v, n = row.verb_class, row.noun_class 246 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)] 247 | a_cl = df_a_sub['action_class'].values 248 | a = df_a_sub['action'].values 249 | if len(a_cl) > 1: 250 | print(a_cl) 251 | action_classes += [a_cl[0]] 252 | actions += [a[0]] 253 | df['action_class'] = action_classes 254 | df['action'] = actions 255 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x)) 256 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int)) 257 | return df -------------------------------------------------------------------------------- /diffip2d/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | # import wandb 18 | 19 | DEBUG = 10 20 | INFO = 20 21 | WARN = 30 22 | ERROR = 40 23 | 24 | DISABLED = 50 25 | 26 | 27 | class KVWriter(object): 28 | def writekvs(self, kvs): 29 | raise NotImplementedError 30 | 31 | 32 | class SeqWriter(object): 33 | def writeseq(self, seq): 34 | raise NotImplementedError 35 | 36 | 37 | class HumanOutputFormat(KVWriter, SeqWriter): 38 | def __init__(self, filename_or_file): 39 | if isinstance(filename_or_file, str): 40 | self.file = open(filename_or_file, "wt") 41 | self.own_file = True 42 | else: 43 | assert hasattr(filename_or_file, "read"), ( 44 | "expected file or str, got %s" % filename_or_file 45 | ) 46 | self.file = filename_or_file 47 | self.own_file = False 48 | 49 | def writekvs(self, kvs): 50 | # Create strings for printing 51 | key2str = {} 52 | for (key, val) in sorted(kvs.items()): 53 | if hasattr(val, "__float__"): 54 | valstr = "%-8.3g" % val 55 | else: 56 | valstr = str(val) 57 | key2str[self._truncate(key)] = self._truncate(valstr) 58 | 59 | # Find max widths 60 | if len(key2str) == 0: 61 | print("WARNING: tried to write empty key-value dict") 62 | return 63 | else: 64 | keywidth = max(map(len, key2str.keys())) 65 | valwidth = max(map(len, key2str.values())) 66 | 67 | # Write out the data 68 | dashes 
= "-" * (keywidth + valwidth + 7) 69 | lines = [dashes] 70 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 71 | lines.append( 72 | "| %s%s | %s%s |" 73 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 74 | ) 75 | lines.append(dashes) 76 | self.file.write("\n".join(lines) + "\n") 77 | 78 | # Flush the output to the file 79 | self.file.flush() 80 | 81 | def _truncate(self, s): 82 | maxlen = 30 83 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 84 | 85 | def writeseq(self, seq): 86 | seq = list(seq) 87 | for (i, elem) in enumerate(seq): 88 | self.file.write(elem) 89 | if i < len(seq) - 1: # add space unless this is the last one 90 | self.file.write(" ") 91 | self.file.write("\n") 92 | self.file.flush() 93 | 94 | def close(self): 95 | if self.own_file: 96 | self.file.close() 97 | 98 | 99 | class JSONOutputFormat(KVWriter): 100 | def __init__(self, filename): 101 | self.file = open(filename, "wt") 102 | 103 | def writekvs(self, kvs): 104 | for k, v in sorted(kvs.items()): 105 | if hasattr(v, "dtype"): 106 | kvs[k] = float(v) 107 | self.file.write(json.dumps(kvs) + "\n") 108 | self.file.flush() 109 | 110 | def close(self): 111 | self.file.close() 112 | 113 | 114 | class CSVOutputFormat(KVWriter): 115 | def __init__(self, filename): 116 | self.file = open(filename, "w+t") 117 | self.keys = [] 118 | self.sep = "," 119 | 120 | def writekvs(self, kvs): 121 | # Add our current row to the history 122 | extra_keys = list(kvs.keys() - self.keys) 123 | extra_keys.sort() 124 | if extra_keys: 125 | self.keys.extend(extra_keys) 126 | self.file.seek(0) 127 | lines = self.file.readlines() 128 | self.file.seek(0) 129 | for (i, k) in enumerate(self.keys): 130 | if i > 0: 131 | self.file.write(",") 132 | self.file.write(k) 133 | self.file.write("\n") 134 | for line in lines[1:]: 135 | self.file.write(line[:-1]) 136 | self.file.write(self.sep * len(extra_keys)) 137 | self.file.write("\n") 138 | for (i, k) in enumerate(self.keys): 139 | if i > 0: 140 | self.file.write(",") 141 | v = kvs.get(k) 142 | if v is not None: 143 | self.file.write(str(v)) 144 | self.file.write("\n") 145 | self.file.flush() 146 | 147 | def close(self): 148 | self.file.close() 149 | 150 | 151 | class TensorBoardOutputFormat(KVWriter): 152 | """ 153 | Dumps key/value pairs into TensorBoard's numeric format. 154 | """ 155 | 156 | def __init__(self, dir): 157 | os.makedirs(dir, exist_ok=True) 158 | self.dir = dir 159 | self.step = 1 160 | prefix = "events" 161 | path = osp.join(osp.abspath(dir), prefix) 162 | import tensorflow as tf 163 | from tensorflow.python import pywrap_tensorflow 164 | from tensorflow.core.util import event_pb2 165 | from tensorflow.python.util import compat 166 | 167 | self.tf = tf 168 | self.event_pb2 = event_pb2 169 | self.pywrap_tensorflow = pywrap_tensorflow 170 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 171 | 172 | def writekvs(self, kvs): 173 | def summary_val(k, v): 174 | kwargs = {"tag": k, "simple_value": float(v)} 175 | return self.tf.Summary.Value(**kwargs) 176 | 177 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 178 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 179 | event.step = ( 180 | self.step 181 | ) # is there any reason why you'd want to specify the step? 
182 | self.writer.WriteEvent(event) 183 | self.writer.Flush() 184 | self.step += 1 185 | 186 | def close(self): 187 | if self.writer: 188 | self.writer.Close() 189 | self.writer = None 190 | 191 | 192 | def make_output_format(format, ev_dir, log_suffix=""): 193 | os.makedirs(ev_dir, exist_ok=True) 194 | if format == "stdout": 195 | return HumanOutputFormat(sys.stdout) 196 | elif format == "log": 197 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 198 | elif format == "json": 199 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 200 | elif format == "csv": 201 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 202 | elif format == "tensorboard": 203 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 204 | else: 205 | raise ValueError("Unknown format specified: %s" % (format,)) 206 | 207 | 208 | # ================================================================ 209 | # API 210 | # ================================================================ 211 | 212 | 213 | def logkv(key, val): 214 | """ 215 | Log a value of some diagnostic 216 | Call this once for each diagnostic quantity, each iteration 217 | If called many times, last value will be used. 218 | """ 219 | get_current().logkv(key, val) 220 | 221 | 222 | def logkv_mean(key, val): 223 | """ 224 | The same as logkv(), but if called many times, values averaged. 225 | """ 226 | get_current().logkv_mean(key, val) 227 | 228 | 229 | def logkvs(d): 230 | """ 231 | Log a dictionary of key-value pairs 232 | """ 233 | for (k, v) in d.items(): 234 | logkv(k, v) 235 | 236 | 237 | def dumpkvs(): 238 | """ 239 | Write all of the diagnostics from the current iteration 240 | """ 241 | return get_current().dumpkvs() 242 | 243 | 244 | def getkvs(): 245 | return get_current().name2val 246 | 247 | 248 | def log(*args, level=INFO): 249 | """ 250 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 251 | """ 252 | get_current().log(*args, level=level) 253 | 254 | 255 | def debug(*args): 256 | log(*args, level=DEBUG) 257 | 258 | 259 | def info(*args): 260 | log(*args, level=INFO) 261 | 262 | 263 | def warn(*args): 264 | log(*args, level=WARN) 265 | 266 | 267 | def error(*args): 268 | log(*args, level=ERROR) 269 | 270 | 271 | def set_level(level): 272 | """ 273 | Set logging threshold on current logger. 274 | """ 275 | get_current().set_level(level) 276 | 277 | 278 | def set_comm(comm): 279 | get_current().set_comm(comm) 280 | 281 | 282 | def get_dir(): 283 | """ 284 | Get directory that log files are being written to. 
285 | will be None if there is no output directory (i.e., if you didn't call start) 286 | """ 287 | return get_current().get_dir() 288 | 289 | 290 | record_tabular = logkv 291 | dump_tabular = dumpkvs 292 | 293 | 294 | @contextmanager 295 | def profile_kv(scopename): 296 | logkey = "wait_" + scopename 297 | tstart = time.time() 298 | try: 299 | yield 300 | finally: 301 | get_current().name2val[logkey] += time.time() - tstart 302 | 303 | 304 | def profile(n): 305 | """ 306 | Usage: 307 | @profile("my_func") 308 | def my_func(): code 309 | """ 310 | 311 | def decorator_with_name(func): 312 | def func_wrapper(*args, **kwargs): 313 | with profile_kv(n): 314 | return func(*args, **kwargs) 315 | 316 | return func_wrapper 317 | 318 | return decorator_with_name 319 | 320 | 321 | # ================================================================ 322 | # Backend 323 | # ================================================================ 324 | 325 | 326 | def get_current(): 327 | if Logger.CURRENT is None: 328 | _configure_default_logger() 329 | 330 | return Logger.CURRENT 331 | 332 | 333 | class Logger(object): 334 | DEFAULT = None # A logger with no output files. (See right below class definition) 335 | # So that you can still log to the terminal without setting up any output files 336 | CURRENT = None # Current logger being used by the free functions above 337 | 338 | def __init__(self, dir, output_formats, comm=None): 339 | self.name2val = defaultdict(float) # values this iteration 340 | self.name2cnt = defaultdict(int) 341 | self.level = INFO 342 | self.dir = dir 343 | self.output_formats = output_formats 344 | self.comm = comm 345 | 346 | # Logging API, forwarded 347 | # ---------------------------------------- 348 | def logkv(self, key, val): 349 | self.name2val[key] = val 350 | 351 | def logkv_mean(self, key, val): 352 | oldval, cnt = self.name2val[key], self.name2cnt[key] 353 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 354 | self.name2cnt[key] = cnt + 1 355 | 356 | def dumpkvs(self, prefix=None): 357 | if self.comm is None: 358 | d = self.name2val 359 | else: 360 | d = mpi_weighted_mean( 361 | self.comm, 362 | { 363 | name: (val, self.name2cnt.get(name, 1)) 364 | for (name, val) in self.name2val.items() 365 | }, 366 | ) 367 | if self.comm.rank != 0: 368 | d["dummy"] = 1 # so we don't get a warning about empty dict 369 | # LISA 370 | out = d.copy() # Return the dict for unit testing purposes 371 | if int(os.environ['LOCAL_RANK']) == 0: 372 | # wandb.log({**d}) 373 | for fmt in self.output_formats: 374 | if isinstance(fmt, KVWriter): 375 | fmt.writekvs(d) 376 | self.name2val.clear() 377 | self.name2cnt.clear() 378 | return out 379 | 380 | def log(self, *args, level=INFO): 381 | if self.level <= level: 382 | self._do_log(args) 383 | 384 | # Configuration 385 | # ---------------------------------------- 386 | def set_level(self, level): 387 | self.level = level 388 | 389 | def set_comm(self, comm): 390 | self.comm = comm 391 | 392 | def get_dir(self): 393 | return self.dir 394 | 395 | def close(self): 396 | for fmt in self.output_formats: 397 | fmt.close() 398 | 399 | # Misc 400 | # ---------------------------------------- 401 | def _do_log(self, args): 402 | for fmt in self.output_formats: 403 | if isinstance(fmt, SeqWriter): 404 | fmt.writeseq(map(str, args)) 405 | 406 | 407 | def get_rank_without_mpi_import(): 408 | # check environment variables here instead of importing mpi4py 409 | # to avoid calling MPI_Init() when this module is imported 410 | for varname in 
["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 411 | if varname in os.environ: 412 | return int(os.environ[varname]) 413 | return 0 414 | 415 | 416 | def mpi_weighted_mean(comm, local_name2valcount): 417 | """ 418 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 419 | Perform a weighted average over dicts that are each on a different node 420 | Input: local_name2valcount: dict mapping key -> (value, count) 421 | Returns: key -> mean 422 | """ 423 | all_name2valcount = comm.gather(local_name2valcount) 424 | if comm.rank == 0: 425 | name2sum = defaultdict(float) 426 | name2count = defaultdict(float) 427 | for n2vc in all_name2valcount: 428 | for (name, (val, count)) in n2vc.items(): 429 | try: 430 | val = float(val) 431 | except ValueError: 432 | if comm.rank == 0: 433 | warnings.warn( 434 | "WARNING: tried to compute mean on non-float {}={}".format( 435 | name, val 436 | ) 437 | ) 438 | else: 439 | name2sum[name] += val * count 440 | name2count[name] += count 441 | return {name: name2sum[name] / name2count[name] for name in name2sum} 442 | else: 443 | return {} 444 | 445 | 446 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 447 | """ 448 | If comm is provided, average all numerical stats across that comm 449 | """ 450 | if dir is None: 451 | dir = os.getenv("OPENAI_LOGDIR") 452 | if dir is None: 453 | dir = osp.join( 454 | tempfile.gettempdir(), 455 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 456 | ) 457 | assert isinstance(dir, str) 458 | dir = os.path.expanduser(dir) 459 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 460 | 461 | rank = get_rank_without_mpi_import() 462 | if rank > 0: 463 | log_suffix = log_suffix + "-rank%03i" % rank 464 | 465 | if format_strs is None: 466 | if rank == 0: 467 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 468 | else: 469 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 470 | format_strs = filter(None, format_strs) 471 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 472 | 473 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 474 | if output_formats: 475 | log("Logging to %s" % dir) 476 | 477 | 478 | def _configure_default_logger(): 479 | configure() 480 | Logger.DEFAULT = Logger.CURRENT 481 | 482 | 483 | def reset(): 484 | if Logger.CURRENT is not Logger.DEFAULT: 485 | Logger.CURRENT.close() 486 | Logger.CURRENT = Logger.DEFAULT 487 | log("Reset logger") 488 | 489 | 490 | @contextmanager 491 | def scoped_configure(dir=None, format_strs=None, comm=None): 492 | prevlogger = Logger.CURRENT 493 | configure(dir=dir, format_strs=format_strs, comm=comm) 494 | try: 495 | yield 496 | finally: 497 | Logger.CURRENT.close() 498 | Logger.CURRENT = prevlogger 499 | 500 | --------------------------------------------------------------------------------
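Usage note for the logger above (a minimal sketch, not taken from the repository): configure() selects the output directory and writers, logkv()/logkv_mean() stage diagnostics for the current iteration, and dumpkvs() flushes them to every configured format. Because this fork's dumpkvs() reads LOCAL_RANK from the environment, that variable must be set when running outside torch.distributed.launch. The import path and log directory below are assumptions for illustration.

    import os
    from diffip2d.utils import logger       # assumed import path for the module listed above

    os.environ.setdefault("LOCAL_RANK", "0")                            # dumpkvs() checks this rank variable
    logger.configure(dir="./logs_demo", format_strs=["stdout", "csv"])  # hypothetical output directory
    for step in range(3):
        logger.logkv("step", step)               # last value wins within an iteration
        logger.logkv_mean("loss", 0.1 * step)    # running mean within an iteration
        logger.dumpkvs()                         # write one row to stdout and progress.csv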