├── .gitignore ├── assets └── teaser.png ├── visualize ├── requirements.txt ├── segment_viewer.py ├── skel_viewer.py └── smplx_viewer.py ├── src ├── utils │ ├── misc.py │ ├── model_utils.py │ ├── word_vectorizer.py │ ├── dist_utils.py │ └── rotation_conversion.py ├── model │ ├── cfg_sampler.py │ ├── blocks.py │ ├── layers.py │ ├── net_2o.py │ └── net_3o.py ├── train │ ├── train_net_2o.py │ ├── train_net_3o.py │ ├── train_platforms.py │ └── training_loop.py ├── diffusion │ ├── losses.py │ ├── respace.py │ ├── resample.py │ ├── nn.py │ └── fp16_util.py ├── dataset │ ├── decomp_dataset.py │ ├── himo_2o_dataset.py │ ├── himo_3o_dataset.py │ ├── fe_dataset.py │ ├── tensors.py │ ├── eval_dataset.py │ └── eval_gen_dataset.py ├── feature_extractor │ ├── train_decomp.py │ ├── train_tex_mot_match.py │ ├── eval_wrapper.py │ └── modules.py └── eval │ ├── metrics.py │ ├── eval_himo_3o.py │ └── eval_himo_2o.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /body_models/ 2 | /data/ 3 | imgui.ini 4 | __pycache__/ 5 | save/ -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LvXinTao/HIMO_dataset/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /visualize/requirements.txt: -------------------------------------------------------------------------------- 1 | aitviewer 2 | trimesh 3 | glfw 4 | chumpy 5 | smplx 6 | tqdm 7 | scikit-learn 8 | pandas 9 | matplotlib -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import os 5 | 6 | def fixseed(seed): 7 | torch.backends.cudnn.benchmark = False 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | def makepath(desired_path,isfile=False): 13 | import os 14 | if isfile: 15 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path),exist_ok=True) 16 | else: 17 | if not os.path.exists(desired_path): os.makedirs(desired_path,exist_ok=True) 18 | return desired_path 19 | 20 | def to_numpy(tensor): 21 | if torch.is_tensor(tensor): 22 | return tensor.cpu().numpy() 23 | elif type(tensor).__module__ != 'numpy': 24 | raise ValueError("Cannot convert {} to numpy array".format( 25 | type(tensor))) 26 | return tensor 27 | 28 | def to_tensor(ndarray): 29 | if type(ndarray).__module__ == 'numpy': 30 | return torch.from_numpy(ndarray).float() 31 | elif not torch.is_tensor(ndarray): 32 | raise ValueError("Cannot convert {} to torch tensor".format( 33 | type(ndarray))) 34 | return ndarray -------------------------------------------------------------------------------- /src/model/cfg_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from copy import deepcopy 5 | 6 | # A wrapper model for Classifier-free guidance **SAMPLING** only 7 | # https://arxiv.org/abs/2207.12598 8 | class ClassifierFreeSampleModel(nn.Module): 9 | 10 | def __init__(self, model): 11 | super().__init__() 12 | self.model = model # model is the actual model to run 13 | 14 | assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that 
has not been trained with no conditions' 15 | 16 | # pointers to inner model 17 | self.rot2xyz = self.model.rot2xyz 18 | self.translation = self.model.translation 19 | self.njoints = self.model.njoints 20 | self.nfeats = self.model.nfeats 21 | self.data_rep = self.model.data_rep 22 | self.cond_mode = self.model.cond_mode 23 | 24 | def forward(self, x, timesteps, y=None): 25 | cond_mode = self.model.cond_mode 26 | assert cond_mode in ['text', 'action'] 27 | y_uncond = deepcopy(y) 28 | y_uncond['uncond'] = True 29 | out = self.model(x, timesteps, y) 30 | out_uncond = self.model(x, timesteps, y_uncond) 31 | return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) 32 | 33 | -------------------------------------------------------------------------------- /src/train/train_net_2o.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | 5 | from src.utils.misc import fixseed,makepath 6 | from src.utils.parser_utils import train_net_2o_args 7 | from src.utils.model_utils import create_model_and_diffusion 8 | from src.utils import dist_utils 9 | from src.train.train_platforms import * 10 | from src.train.training_loop import TrainLoop 11 | from src.dataset.himo_2o_dataset import HIMO_2O 12 | from src.dataset.tensors import himo_2o_collate_fn 13 | 14 | from torch.utils.data import DataLoader 15 | from loguru import logger 16 | 17 | def main(): 18 | args=train_net_2o_args() 19 | # save path 20 | save_path=osp.join(args.save_dir,args.exp_name) 21 | if osp.exists(save_path): 22 | raise FileExistsError(f'{save_path} already exists!') 23 | # pass 24 | else: 25 | makepath(save_path) 26 | args.save_path=save_path 27 | # training plateform 28 | train_platform_type=eval(args.train_platform) 29 | train_platform=train_platform_type(args.save_path) 30 | train_platform.report_args(args,'Args') 31 | # config logger 32 | logger.add(osp.join(save_path,'train.log')) 33 | # save args 34 | with open(osp.join(save_path,'args.json'),'w') as f: 35 | json.dump(vars(args),f,indent=4) 36 | 37 | dist_utils.setup_dist(args.device) 38 | # get dataset loader 39 | logger.info('Loading Training dataset...') 40 | train_dataset=HIMO_2O(args,split='train') 41 | data_loader=DataLoader(train_dataset,batch_size=args.batch_size, 42 | shuffle=True,num_workers=8,drop_last=True,collate_fn=himo_2o_collate_fn) 43 | 44 | # get model and diffusion 45 | logger.info('Loading Model and Diffusion...') 46 | model,diffusion=create_model_and_diffusion(args) 47 | model.to(dist_utils.dev()) 48 | 49 | logger.info('Start Training...') 50 | TrainLoop(args,train_platform,model,diffusion,data_loader).run_loop() 51 | train_platform.close() 52 | 53 | if __name__=='__main__': 54 | main() -------------------------------------------------------------------------------- /src/train/train_net_3o.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | 5 | from src.utils.misc import fixseed,makepath 6 | from src.utils.parser_utils import train_net_3o_args 7 | from src.utils.model_utils import create_model_and_diffusion 8 | from src.utils import dist_utils 9 | from src.train.train_platforms import * 10 | from src.train.training_loop import TrainLoop 11 | from src.dataset.himo_3o_dataset import HIMO_3O 12 | from src.dataset.tensors import himo_3o_collate_fn 13 | 14 | from torch.utils.data import DataLoader 15 | from loguru import logger 16 | 17 | def main(): 18 | 
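    # NOTE: fixseed is imported above but never called in this script; calling
    # fixseed(args.seed) here would make training runs reproducible (this assumes
    # the argument parser exposes a seed option).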
args=train_net_3o_args() 19 | # save path 20 | save_path=osp.join(args.save_dir,args.exp_name) 21 | if osp.exists(save_path): 22 | raise FileExistsError(f'{save_path} already exists!') 23 | # pass 24 | else: 25 | makepath(save_path) 26 | args.save_path=save_path 27 | # training plateform 28 | train_platform_type=eval(args.train_platform) 29 | train_platform=train_platform_type(args.save_path) 30 | train_platform.report_args(args,'Args') 31 | # config logger 32 | logger.add(osp.join(save_path,'train.log')) 33 | # save args 34 | with open(osp.join(save_path,'args.json'),'w') as f: 35 | json.dump(vars(args),f,indent=4) 36 | 37 | dist_utils.setup_dist(args.device) 38 | # get dataset loader 39 | logger.info('Loading Training dataset...') 40 | train_dataset=HIMO_3O(args,split='train') 41 | data_loader=DataLoader(train_dataset,batch_size=args.batch_size, 42 | shuffle=True,num_workers=8,drop_last=True,collate_fn=himo_3o_collate_fn) 43 | 44 | # get model and diffusion 45 | logger.info('Loading Model and Diffusion...') 46 | model,diffusion=create_model_and_diffusion(args) 47 | model.to(dist_utils.dev()) 48 | 49 | logger.info('Start Training...') 50 | TrainLoop(args,train_platform,model,diffusion,data_loader).run_loop() 51 | train_platform.close() 52 | 53 | if __name__=='__main__': 54 | main() -------------------------------------------------------------------------------- /src/train/train_platforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | 4 | class TrainPlatform: 5 | def __init__(self, save_dir): 6 | pass 7 | 8 | def report_scalar(self, name, value, iteration, group_name=None): 9 | pass 10 | 11 | def report_args(self, args, name): 12 | pass 13 | 14 | def close(self): 15 | pass 16 | 17 | 18 | class ClearmlPlatform(TrainPlatform): 19 | def __init__(self, save_dir): 20 | from clearml import Task 21 | path, name = os.path.split(save_dir) 22 | self.task = Task.init(project_name='motion_diffusion', 23 | task_name=name, 24 | output_uri=path) 25 | self.logger = self.task.get_logger() 26 | 27 | def report_scalar(self, name, value, iteration, group_name): 28 | self.logger.report_scalar(title=group_name, series=name, iteration=iteration, value=value) 29 | 30 | def report_args(self, args, name): 31 | self.task.connect(args, name=name) 32 | 33 | def close(self): 34 | self.task.close() 35 | 36 | 37 | class TensorboardPlatform(TrainPlatform): 38 | def __init__(self, save_dir): 39 | from torch.utils.tensorboard import SummaryWriter 40 | self.writer = SummaryWriter(log_dir=save_dir) 41 | 42 | def report_scalar(self, name, value, iteration, group_name=None): 43 | self.writer.add_scalar(f'{group_name}/{name}', value, iteration) 44 | 45 | def close(self): 46 | self.writer.close() 47 | 48 | class WandbPlatform(TrainPlatform): 49 | def __init__(self, save_dir): 50 | import wandb 51 | wandb.init(project='HIMO_eccv', 52 | name=os.path.split(save_dir)[-1], dir=save_dir) 53 | 54 | def report_scalar(self, name, value, iteration, group_name=None): 55 | wandb.log({f'{group_name}/{name}': value}, step=iteration) 56 | 57 | def report_args(self, args, name): 58 | wandb.config.update(args, allow_val_change=True) 59 | 60 | def close(self): 61 | wandb.finish() 62 | 63 | class NoPlatform(TrainPlatform): 64 | def __init__(self, save_dir): 65 | pass 66 | 67 | 68 | -------------------------------------------------------------------------------- /src/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from 
src.diffusion import gaussian_diffusion as gd 2 | from src.model.net_2o import NET_2O 3 | from src.model.net_3o import NET_3O 4 | from src.diffusion.respace import SpacedDiffusion,space_timesteps 5 | 6 | def load_model_wo_clip(model, state_dict): 7 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 8 | assert len(unexpected_keys) == 0 9 | assert all([k.startswith('clip_model.') for k in missing_keys]) 10 | 11 | def create_model_and_diffusion(args): 12 | if args.network=='net_2o': 13 | model=NET_2O() 14 | diffusion=create_gaussian_diffusion(args) 15 | elif args.network=='net_3o': 16 | model=NET_3O() 17 | diffusion=create_gaussian_diffusion(args) 18 | 19 | return model,diffusion 20 | 21 | def create_gaussian_diffusion(args): 22 | # default params 23 | predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! 24 | steps = args.diffusion_steps 25 | scale_beta = 1. # no scaling 26 | timestep_respacing = '' # can be used for ddim sampling, we don't use it. 27 | learn_sigma = False 28 | rescale_timesteps = False 29 | 30 | betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) 31 | loss_type = gd.LossType.MSE 32 | 33 | if not timestep_respacing: 34 | timestep_respacing = [steps] 35 | return SpacedDiffusion( 36 | use_timesteps=space_timesteps(steps, timestep_respacing), 37 | betas=betas, 38 | model_mean_type=( 39 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 40 | ), 41 | model_var_type=( 42 | ( 43 | gd.ModelVarType.FIXED_LARGE 44 | if not args.sigma_small 45 | else gd.ModelVarType.FIXED_SMALL 46 | ) 47 | if not learn_sigma 48 | else gd.ModelVarType.LEARNED_RANGE 49 | ), 50 | loss_type=loss_type, 51 | rescale_timesteps=rescale_timesteps, 52 | lambda_recon=args.lambda_recon if hasattr(args, 'lambda_recon') else 0.0, 53 | lambda_pos=args.lambda_pos if hasattr(args, 'lambda_pos') else 0.0, 54 | lambda_geo=args.lambda_geo if hasattr(args, 'lambda_geo') else 0.0, 55 | lambda_vel=args.lambda_vel if hasattr(args, 'lambda_vel') else 0.0, 56 | lambda_sp=args.lambda_sp if hasattr(args, 'lambda_sp') else 0.0, 57 | train_args=args 58 | ) 59 | -------------------------------------------------------------------------------- /src/diffusion/losses.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Helpers for various likelihood-based losses. These are ported from the original 4 | Ho et al. diffusion models codebase: 5 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 6 | """ 7 | 8 | import numpy as np 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 
28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /src/utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 
'rb')) 50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec -------------------------------------------------------------------------------- /src/dataset/decomp_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tqdm import tqdm 3 | import h5py 4 | import numpy as np 5 | from loguru import logger 6 | import os.path as osp 7 | 8 | class decomp_dataset(Dataset): 9 | def __init__(self,opt,split='train'): 10 | self.opt=opt 11 | 12 | self.data=[] 13 | self.lengths=[] 14 | self.split=split 15 | self.load_data() 16 | 17 | self.cumsum = np.cumsum([0] + self.lengths) 18 | logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1])) 19 | 20 | def __len__(self): 21 | return self.cumsum[-1] 22 | 23 | def __getitem__(self, item): 24 | if item != 0: 25 | motion_id = np.searchsorted(self.cumsum, item) - 1 26 | idx = item - self.cumsum[motion_id] - 1 27 | else: 28 | motion_id = 0 29 | idx = 0 30 | motion = self.data[motion_id][idx:idx+self.opt.window_size] 31 | return motion 32 | 33 | def load_data(self): 34 | with h5py.File(osp.join(self.opt.data_root,"processed_%s"%self.opt.mode,"{}.h5".format(self.split)),'r') as f: 35 | for seq in tqdm(f.keys(),desc=f'Loading {self.split} dataset'): 36 | 37 | body_pose=f[seq]['body_pose'][:] # nf,21,6 38 | global_orient=f[seq]['global_orient'][:][:,None,:] # nf,1,6 39 | lhand_pose=f[seq]['lhand_pose'][:] # nf,15,6 40 | rhand_pose=f[seq]['rhand_pose'][:] # nf,15,6 41 | transl=f[seq]['transl'][:] # nf,3 42 | smplx_joints=f[seq]['smplx_joints'][:] 43 | smplx_joints=smplx_joints.reshape(smplx_joints.shape[0],-1) # nf,52*3 44 | full_pose=np.concatenate([global_orient,body_pose,lhand_pose,rhand_pose],axis=1) 45 | full_pose=full_pose.reshape(full_pose.shape[0],-1) # nf,52*6 46 | human_motion=np.concatenate([smplx_joints,full_pose,transl],axis=-1) # nf,52*6+52*3+3 47 | if self.opt.mode=='2o': 48 | o1,o2=sorted(f[seq]['object_state'].keys()) 49 | o1_state=f[seq]['object_state'][o1][:] 50 | o2_state=f[seq]['object_state'][o2][:] 51 | object_state=np.concatenate([o1_state,o2_state],axis=-1) 52 | elif self.opt.mode=='3o': 53 | o1,o2,o3=sorted(f[seq]['object_state'].keys()) 54 | o1_state=f[seq]['object_state'][o1][:] 55 | o2_state=f[seq]['object_state'][o2][:] 56 | o3_state=f[seq]['object_state'][o3][:] 57 | object_state=np.concatenate([o1_state,o2_state,o3_state],axis=-1) 58 | motion=np.concatenate([human_motion,object_state],axis=-1) # nf,(52*6+52*3+3)+(2*(6+3)) 59 | if motion.shape[0] 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
<!-- README header: paper title banner, project links, and teaser image (assets/teaser.png) -->
18 | 19 | This repository accompanies the following paper: 20 | > HIMO: A New Benchmark for Full-Body Human Interacting with Multiple Objects
Xintao Lv 1,*, Liang Xu 1,2*, Yichao Yan 1, Xin Jin 2, Congsheng Xu 1, Shuwen Wu 1, Yifan Liu 1, Lincheng Li 3, Mengxiao Bi 3, Wenjun Zeng 2, Xiaokang Yang 1
21 | > 1 Shanghai Jiao Tong University , 2 Eastern Institute of Technology, Ningbo, 3 NetEase Fuxi AI Lab 22 | 23 | 24 | ## News 25 | - [2025.03.10] We release the processed data for training and evaluation, please fill out the [this form](https://docs.google.com/forms/d/e/1FAIpQLSdl5adeyKxBSBFZpgs0A7-dAouRkMFAGUP5iz3zxGDj_PhB1w/viewform) to request `processed_2o.tar.gz` and `processed_3o.tar.gz`. 26 | 27 | 28 | ## Dataset Download 29 | Please fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSdl5adeyKxBSBFZpgs0A7-dAouRkMFAGUP5iz3zxGDj_PhB1w/viewform) to request authorization to download HIMO for research purposes. 30 | After downloading the dataset, unzip the data in `./data` and you'll get the following structure: 31 | ```shell 32 | ./data 33 | |-- joints 34 | | |-- S01T001.npy 35 | | |-- ... 36 | |-- smplx 37 | | |-- S01T001.npz 38 | | |-- ... 39 | |-- object_pose 40 | | |-- S01T001.npy 41 | | |-- ... 42 | |-- text 43 | | |-- S01T001.txt 44 | | |-- ... 45 | |-- segments 46 | | |-- S01T001.json 47 | | |-- ... 48 | |-- object_mesh 49 | | |-- Apple.obj 50 | | |-- ... 51 | |-- processed_2o 52 | | |-- ... 53 | |-- processed_3o 54 | | |-- ... 55 | 56 | ``` 57 | 58 | ## Data Visualization 59 | We use the [AIT-Viewer](https://github.com/eth-ait/aitviewer) to visualize the dataset. You can follow the instructions below to visualize it. 60 | ```bash 61 | pip install -r visualize/requirements.txt 62 | ``` 63 | You also need to download the [SMPL-X models](https://smpl-x.is.tue.mpg.de/) and place them in `./body_models`, which should look like: 64 | ```shell 65 | ./body_models 66 | |-- smplx 67 | ├── SMPLX_FEMALE.npz 68 | ├── SMPLX_FEMALE.pkl 69 | ├── SMPLX_MALE.npz 70 | ├── SMPLX_MALE.pkl 71 | ├── SMPLX_NEUTRAL.npz 72 | ├── SMPLX_NEUTRAL.pkl 73 | └── SMPLX_NEUTRAL_2020.npz 74 | ``` 75 | Then you can run the following command to visualize the dataset. 76 | ```bash 77 | # Visualize the skeleton 78 | python visualize/skel_viewer.py 79 | # Visualize the SMPLX 80 | python visualize/smplx_viewer.py 81 | # Visualize the segment data 82 | python visualize/segment_viewer.py 83 | ``` 84 | 85 | ## Training 86 | We provide preprocessd objects' BPS and vocabulary glove files [here](https://drive.google.com/drive/folders/11PYdla0R9GIyYXqDPMle9208Hv8MaUMo?usp=sharing). 87 | 88 | To train the model in 2-object setting, run 89 | ```bash 90 | python -m src.train.train_net_2o --exp_name net_2o --num_epochs 1000 91 | ``` 92 | To train the model in 3-object setting, run 93 | ```bash 94 | python -m src.train.train_net_3o --exp_name net_3o --num_epochs 1000 95 | ``` 96 | To evaluate the model, you need to train your own evaluator or use the checkpoint we provide [here](https://drive.google.com/drive/folders/11PYdla0R9GIyYXqDPMle9208Hv8MaUMo?usp=sharing) (put them under `./save`). 
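For reference, the evaluation wrapper (`src/feature_extractor/eval_wrapper.py`) looks for the evaluator checkpoints at the following paths under `./save`:
```shell
./save
|-- fe_2o_epoch300
| |-- model
| | |-- finest.tar
|-- fe_3o_epoch300
| |-- model
| | |-- finest.tar
```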
97 | Then run 98 | ```bash 99 | python -m src.eval.eval_himo_2o 100 | ``` 101 | or 102 | ```bash 103 | python -m src.eval.eval_himo_3o 104 | ``` 105 | -------------------------------------------------------------------------------- /src/feature_extractor/train_tex_mot_match.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | from loguru import logger 5 | import torch 6 | from src.utils.parser_utils import train_feature_extractor_args 7 | from src.utils.misc import fixseed 8 | from src.utils.word_vectorizer import POS_enumerator,WordVectorizer 9 | from src.feature_extractor.modules import TextEncoderBiGRUCo,MotionEncoderBiGRUCo,MovementConvEncoder 10 | from src.feature_extractor.trainer import TextMotionMatchTrainer 11 | from src.dataset.fe_dataset import feature_extractor_dataset,collate_fn 12 | from src.train.train_platforms import WandbPlatform 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | def build_models(args): 17 | movement_enc=MovementConvEncoder(dim_pose,args.dim_movement_enc_hidden,args.dim_movement_latent) 18 | text_enc = TextEncoderBiGRUCo(word_size=dim_word, 19 | pos_size=dim_pos_ohot, 20 | hidden_size=args.dim_text_hidden, 21 | output_size=args.dim_coemb_hidden, 22 | device=args.device) 23 | motion_enc = MotionEncoderBiGRUCo(input_size=args.dim_movement_latent, 24 | hidden_size=args.dim_motion_hidden, 25 | output_size=args.dim_coemb_hidden, 26 | device=args.device) 27 | 28 | if not args.is_continue: 29 | logger.info('Loading Decomp......') 30 | checkpoint = torch.load(osp.join(args.checkpoints_dir, args.decomp_name, 'model', 'latest.tar'), 31 | map_location=args.device) 32 | movement_enc.load_state_dict(checkpoint['movement_enc']) 33 | return text_enc,motion_enc,movement_enc 34 | 35 | if __name__=='__main__': 36 | args=train_feature_extractor_args() 37 | args.device = torch.device("cpu" if args.gpu_id==-1 else "cuda:" + str(args.gpu_id)) 38 | torch.autograd.set_detect_anomaly(True) 39 | fixseed(args.seed) 40 | if args.gpu_id!=-1: 41 | torch.cuda.set_device(args.gpu_id) 42 | args.save_path=osp.join(args.save_dir,args.exp_name) 43 | args.model_dir=osp.join(args.save_path,'model') 44 | args.log_dir=osp.join(args.save_path,'log') 45 | args.eval_dir=osp.join(args.save_path,'eval') 46 | 47 | os.makedirs(args.save_path,exist_ok=True) 48 | os.makedirs(args.model_dir,exist_ok=True) 49 | os.makedirs(args.log_dir,exist_ok=True) 50 | os.makedirs(args.eval_dir,exist_ok=True) 51 | 52 | logger.add(osp.join(args.log_dir,'train_feature_extractor.log'),rotation='10 MB') 53 | 54 | args.data_root='/data/xuliang/HO2_subsets_original/HO2_final' 55 | args.max_motion_length=300 56 | if args.mode=='2o': 57 | dim_pose=52*6+52*3+3+2*(6+3) 58 | elif args.mode=='3o': 59 | dim_pose=52*6+52*3+3+3*(6+3) 60 | 61 | meta_root=osp.join(args.data_root,'glove') 62 | dim_word=300 63 | dim_pos_ohot=len(POS_enumerator) 64 | 65 | w_vectorizer=WordVectorizer(meta_root,'himo_vab') 66 | text_encoder,motion_encoder,movement_encoder=build_models(args) 67 | 68 | pc_text_enc = sum(param.numel() for param in text_encoder.parameters()) 69 | logger.info("Total parameters of text encoder: {}".format(pc_text_enc)) 70 | pc_motion_enc = sum(param.numel() for param in motion_encoder.parameters()) 71 | logger.info("Total parameters of motion encoder: {}".format(pc_motion_enc)) 72 | logger.info("Total parameters: {}".format(pc_motion_enc + pc_text_enc)) 73 | 74 | train_platform=WandbPlatform(args.save_path) 75 | 76 | 
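    # Build the text-motion matching trainer (TextMotionMatchTrainer comes from
    # src.feature_extractor.trainer); the movement encoder passed in was
    # initialized from the pretrained Decomp checkpoint inside build_models()
    # above (unless training is resumed via the is_continue option).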
trainer=TextMotionMatchTrainer(args,text_encoder,motion_encoder,movement_encoder,train_platform) 77 | 78 | train_dataset=feature_extractor_dataset(args,'train',w_vectorizer) 79 | val_dataset=feature_extractor_dataset(args,'val',w_vectorizer) 80 | 81 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=4, 82 | shuffle=True, collate_fn=collate_fn, pin_memory=True) 83 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=4, 84 | shuffle=True, collate_fn=collate_fn, pin_memory=True) 85 | 86 | trainer.train(train_loader,val_loader) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/feature_extractor/eval_wrapper.py: -------------------------------------------------------------------------------- 1 | from src.utils import dist_utils 2 | from src.utils.word_vectorizer import POS_enumerator 3 | from src.feature_extractor.modules import TextEncoderBiGRUCo,MotionEncoderBiGRUCo,MovementConvEncoder 4 | import os.path as osp 5 | import torch 6 | import numpy as np 7 | 8 | def build_evaluators(opt): 9 | movement_enc = MovementConvEncoder(opt['dim_pose'], opt['dim_movement_enc_hidden'], opt['dim_movement_latent']) 10 | text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'], 11 | pos_size=opt['dim_pos_ohot'], 12 | hidden_size=opt['dim_text_hidden'], 13 | output_size=opt['dim_coemb_hidden'], 14 | device=opt['device']) 15 | motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_enc_hidden'], 16 | hidden_size=opt['dim_motion_hidden'], 17 | output_size=opt['dim_coemb_hidden'], 18 | device=opt['device']) 19 | checkpoint=torch.load(osp.join(opt['checkpoint_dir'],'model','finest.tar'), 20 | map_location=opt['device']) 21 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 22 | text_enc.load_state_dict(checkpoint['text_encoder']) 23 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 24 | print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) 25 | 26 | return text_enc,motion_enc,movement_enc 27 | 28 | class EvaluationWrapper: 29 | def __init__(self,val_args): 30 | opt={ 31 | 'device':dist_utils.dev(), 32 | 'dim_word':300, 33 | 'max_motion_length':300, 34 | 'dim_pos_ohot': len(POS_enumerator), 35 | 'dim_motion_hidden': 1024, 36 | 'max_text_len':40, 37 | 'dim_text_hidden':512, 38 | 'dim_coemb_hidden':512, 39 | 'dim_pose':52*3+52*6+3+2*9 if val_args.obj=='2o' else 52*3+52*6+3+3*9, 40 | 'dim_movement_enc_hidden': 512, 41 | 'dim_movement_latent': 512, 42 | 'checkpoint_dir':'./save/fe_2o_epoch300' if val_args.obj=='2o' else './save/fe_3o_epoch300', 43 | 'unit_length':5 44 | } 45 | self.text_encoder,self.motion_encoder,self.movement_encoder=build_evaluators(opt) 46 | self.opt = opt 47 | self.device = opt['device'] 48 | 49 | self.text_encoder.to(opt['device']) 50 | self.motion_encoder.to(opt['device']) 51 | self.movement_encoder.to(opt['device']) 52 | 53 | self.text_encoder.eval() 54 | self.motion_encoder.eval() 55 | self.movement_encoder.eval() 56 | 57 | # Please note that the results does not following the order of inputs 58 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): 59 | with torch.no_grad(): 60 | word_embs = word_embs.detach().to(self.device).float() 61 | pos_ohot = pos_ohot.detach().to(self.device).float() 62 | motions = motions.detach().to(self.device).float() 63 | 64 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 65 | motions = motions[align_idx] 66 | m_lens = m_lens[align_idx] 67 | 68 | '''Movement Encoding''' 69 | movements = self.movement_encoder(motions).detach() 70 | m_lens = m_lens // self.opt['unit_length'] 71 | motion_embedding = self.motion_encoder(movements, m_lens) 72 | 73 | '''Text Encoding''' 74 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) 75 | text_embedding = text_embedding[align_idx] 76 | return text_embedding, motion_embedding 77 | 78 | def get_motion_embeddings(self,motions,m_lens): 79 | with torch.no_grad(): 80 | motions = motions.detach().to(self.device).float() 81 | 82 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 83 | motions = motions[align_idx] 84 | m_lens = m_lens[align_idx] 85 | 86 | '''Movement Encoding''' 87 | movements = self.movement_encoder(motions).detach() 88 | m_lens = m_lens // self.opt['unit_length'] 89 | motion_embedding = self.motion_encoder(movements, m_lens) 90 | return motion_embedding 91 | -------------------------------------------------------------------------------- /src/dataset/fe_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from typing import Any 4 | import numpy as np 5 | import h5py 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | import codecs as cs 9 | from torch.utils.data._utils.collate import default_collate 10 | 11 | class feature_extractor_dataset(Dataset): 12 | def __init__(self,args,split,w_vectorizer): 13 | self.args=args 14 | self.w_vectorizer=w_vectorizer 15 | self.max_motion_length=args.max_motion_length 16 | self.split=split 17 | self.data_path=osp.join(args.data_root,'processed_{}'.format(args.mode),f'{split}.h5') 18 | 19 | self.data=[] 20 | self.load_data() 21 | 22 | def __len__(self): 23 | return len(self.data) 24 | 25 | def __getitem__(self, idx): 26 | data=self.data[idx] 27 | motion,m_length,text_list=data['motion'],data['length'],data['text'] 28 | caption,tokens=text_list['caption'],text_list['tokens'] 29 | 30 | if len(tokens) < 
self.args.max_text_len: 31 | # pad with "unk" 32 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 33 | sent_len = len(tokens) 34 | tokens = tokens + ['unk/OTHER'] * (self.args.max_text_len + 2 - sent_len) 35 | else: 36 | # crop 37 | tokens = tokens[:self.args.max_text_len] 38 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 39 | sent_len = len(tokens) 40 | pos_one_hots = [] 41 | word_embeddings = [] 42 | for token in tokens: 43 | word_emb, pos_oh = self.w_vectorizer[token] 44 | pos_one_hots.append(pos_oh[None, :]) 45 | word_embeddings.append(word_emb[None, :]) 46 | pos_one_hots = np.concatenate(pos_one_hots, axis=0) 47 | word_embeddings = np.concatenate(word_embeddings, axis=0) 48 | 49 | if m_lengthbhij', queries, keys) * self.scale 111 | 112 | if mask is not None: 113 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 114 | assert mask.shape[-1] == dots.shape[-1], 'Mask has incorrect dimensions' 115 | mask = mask[:, None, :].expand(-1, n, -1) 116 | dots.masked_fill_(~mask, float('-inf')) 117 | 118 | attn = dots.softmax(dim=-1) 119 | out = torch.einsum('bhij,bhjd->bhid', attn, values) 120 | out = out.transpose(1, 2).contiguous().view(b, n, -1) 121 | return out -------------------------------------------------------------------------------- /src/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import socket 6 | 7 | import torch as th 8 | import torch.distributed as dist 9 | 10 | # Change this to reflect your cluster layout. 11 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 12 | GPUS_PER_NODE = 8 13 | 14 | SETUP_RETRY_COUNT = 3 15 | 16 | used_device = 0 17 | 18 | def setup_dist(device=0): 19 | """ 20 | Setup a distributed process group. 21 | """ 22 | global used_device 23 | used_device = device 24 | if dist.is_initialized(): 25 | return 26 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 27 | 28 | # comm = MPI.COMM_WORLD 29 | # backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | # if backend == "gloo": 32 | # hostname = "localhost" 33 | # else: 34 | # hostname = socket.gethostbyname(socket.getfqdn()) 35 | # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | # os.environ["RANK"] = str(comm.rank) 37 | # os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | # port = comm.bcast(_find_free_port(), root=used_device) 40 | # os.environ["MASTER_PORT"] = str(port) 41 | # dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | global used_device 49 | if th.cuda.is_available() and used_device>=0: 50 | return th.device(f"cuda:{used_device}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | return th.load(path, **kwargs) 59 | 60 | 61 | def sync_params(params): 62 | """ 63 | Synchronize a sequence of Tensors across ranks from rank 0. 64 | """ 65 | for p in params: 66 | with th.no_grad(): 67 | dist.broadcast(p, 0) 68 | 69 | 70 | def _find_free_port(): 71 | try: 72 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 73 | s.bind(("", 0)) 74 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 75 | return s.getsockname()[1] 76 | finally: 77 | s.close() 78 | 79 | # """ 80 | # Helpers for distributed training. 
81 | # """ 82 | 83 | # import io 84 | # import os 85 | # import socket 86 | 87 | # import blobfile as bf 88 | # # from mpi4py import MPI 89 | # import torch as th 90 | # import torch.distributed as dist 91 | 92 | # # Change this to reflect your cluster layout. 93 | # # The GPU for a given rank is (rank % GPUS_PER_NODE). 94 | # GPUS_PER_NODE = 2 95 | 96 | # SETUP_RETRY_COUNT = 3 97 | 98 | 99 | # def setup_dist(): 100 | # """ 101 | # Setup a distributed process group. 102 | # """ 103 | # if dist.is_initialized(): 104 | # return 105 | # # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 106 | 107 | # # comm = MPI.COMM_WORLD 108 | # backend = "gloo" if not th.cuda.is_available() else "nccl" 109 | 110 | # if backend == "gloo": 111 | # hostname = "localhost" 112 | # else: 113 | # hostname = socket.gethostbyname(socket.getfqdn()) 114 | # # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 115 | # # os.environ["RANK"] = str(comm.rank) 116 | # # os.environ["WORLD_SIZE"] = str(comm.size) 117 | 118 | # # port = comm.bcast(_find_free_port(), root=0) 119 | # # os.environ["MASTER_PORT"] = str(port) 120 | # dist.init_process_group(backend=backend) 121 | 122 | 123 | # def dev(): 124 | # """ 125 | # Get the device to use for torch.distributed. 126 | # """ 127 | # if th.cuda.is_available(): 128 | # return th.device(f"cuda") 129 | # return th.device("cpu") 130 | 131 | 132 | # def load_state_dict(path, **kwargs): 133 | # """ 134 | # Load a PyTorch file without redundant fetches across MPI ranks. 135 | # """ 136 | # chunk_size = 2 ** 30 # MPI has a relatively small size limit 137 | # if MPI.COMM_WORLD.Get_rank() == 0: 138 | # with bf.BlobFile(path, "rb") as f: 139 | # data = f.read() 140 | # num_chunks = len(data) // chunk_size 141 | # if len(data) % chunk_size: 142 | # num_chunks += 1 143 | # MPI.COMM_WORLD.bcast(num_chunks) 144 | # for i in range(0, len(data), chunk_size): 145 | # MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 146 | # else: 147 | # num_chunks = MPI.COMM_WORLD.bcast(None) 148 | # data = bytes() 149 | # for _ in range(num_chunks): 150 | # data += MPI.COMM_WORLD.bcast(None) 151 | 152 | # return th.load(io.BytesIO(data), **kwargs) 153 | 154 | 155 | # def sync_params(params): 156 | # """ 157 | # Synchronize a sequence of Tensors across ranks from rank 0. 
158 | # """ 159 | # for p in params: 160 | # with th.no_grad(): 161 | # dist.broadcast(p, 0) 162 | 163 | 164 | # def _find_free_port(): 165 | # try: 166 | # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 167 | # s.bind(("", 0)) 168 | # s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 169 | # return s.getsockname()[1] 170 | # finally: 171 | # s.close() -------------------------------------------------------------------------------- /src/dataset/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data._utils.collate import default_collate 3 | 4 | def lengths_to_mask(lengths, max_len): 5 | # max_len = max(lengths) 6 | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) 7 | return mask 8 | 9 | def collate_tensors(batch): 10 | dims = batch[0].dim() 11 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)] 12 | size = (len(batch),) + tuple(max_size) 13 | canvas = batch[0].new_zeros(size=size) 14 | for i, b in enumerate(batch): 15 | sub_tensor = canvas[i] 16 | for d in range(dims): 17 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) 18 | sub_tensor.add_(b) 19 | return canvas 20 | 21 | def himo_2o_collate_fn(batch): 22 | adapteed_batch=[ 23 | { 24 | 'inp':torch.tensor(b[6]).float(), # [nf,52*3+52*6+3+2*9] 25 | 'length':b[7], 26 | 'text':b[0], 27 | 'obj1_bps':torch.tensor(b[1]).squeeze(0).float(), # [1024,3] 28 | 'obj2_bps':torch.tensor(b[2]).squeeze(0).float(), 29 | 'obj1_sampled_verts':torch.tensor(b[3]).float(), # [1024,3] 30 | 'obj2_sampled_verts':torch.tensor(b[4]).float(), 31 | 'init_state':torch.tensor(b[5]).float(), # [52*3+52*6+3+2*9] 32 | 'betas':torch.tensor(b[8]).float(), # [10] 33 | } for b in batch 34 | ] 35 | inp_tensor=collate_tensors([b['inp'] for b in adapteed_batch]) # [bs,nf,52*3+52*6+3+2*9] 36 | len_batch=[b['length'] for b in adapteed_batch] 37 | len_tensor=torch.tensor(len_batch).long() # [B] 38 | mask_tensor=lengths_to_mask(len_tensor,inp_tensor.shape[1]).unsqueeze(1).unsqueeze(1) # [B,1,1,nf] 39 | 40 | text_batch=[b['text'] for b in adapteed_batch] 41 | o1b_tensor=torch.stack([b['obj1_bps'] for b in adapteed_batch],dim=0) # [B,1024,3] 42 | o2b_tensor=torch.stack([b['obj2_bps'] for b in adapteed_batch],dim=0) # [B,1024,3] 43 | o1sv_tensor=torch.stack([b['obj1_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3] 44 | o2sv_tensor=torch.stack([b['obj2_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3] 45 | init_state_tensor=torch.stack([b['init_state'] for b in adapteed_batch],dim=0) # [B,52*3+52*6+3+2*9] 46 | betas_tensor=torch.stack([b['betas'] for b in adapteed_batch],dim=0) # [B,10] 47 | 48 | cond={ 49 | 'y':{ 50 | 'mask':mask_tensor, 51 | 'length':len_tensor, 52 | 'text':text_batch, 53 | 'obj1_bps':o1b_tensor, 54 | 'obj2_bps':o2b_tensor, 55 | 'obj1_sampled_verts':o1sv_tensor, 56 | 'obj2_sampled_verts':o2sv_tensor, 57 | 'init_state':init_state_tensor, 58 | 'betas':betas_tensor 59 | } 60 | } 61 | return inp_tensor,cond 62 | 63 | def himo_3o_collate_fn(batch): 64 | adapteed_batch=[ 65 | { 66 | 'inp':torch.tensor(b[8]).float(), # [nf,52*3+52*6+3+3*9] 67 | 'length':b[9], 68 | 'text':b[0], 69 | 'obj1_bps':torch.tensor(b[1]).squeeze(0).float(), # [1024,3] 70 | 'obj2_bps':torch.tensor(b[2]).squeeze(0).float(), 71 | 'obj3_bps':torch.tensor(b[3]).squeeze(0).float(), 72 | 'obj1_sampled_verts':torch.tensor(b[4]).float(), # [1024,3] 73 | 'obj2_sampled_verts':torch.tensor(b[5]).float(), 74 | 
'obj3_sampled_verts':torch.tensor(b[6]).float(), 75 | 'init_state':torch.tensor(b[7]).float(), # [52*3+52*6+3+3*9] 76 | 'betas':torch.tensor(b[10]).float(), # [10] 77 | 78 | } for b in batch 79 | ] 80 | inp_tensor=collate_tensors([b['inp'] for b in adapteed_batch]) # [bs,nf,52*3+52*6+3+2*9] 81 | len_batch=[b['length'] for b in adapteed_batch] 82 | len_tensor=torch.tensor(len_batch).long() # [B] 83 | mask_tensor=lengths_to_mask(len_tensor,inp_tensor.shape[1]).unsqueeze(1).unsqueeze(1) # [B,1,1,nf] 84 | 85 | text_batch=[b['text'] for b in adapteed_batch] 86 | o1b_tensor=torch.stack([b['obj1_bps'] for b in adapteed_batch],dim=0) # [B,1024,3] 87 | o2b_tensor=torch.stack([b['obj2_bps'] for b in adapteed_batch],dim=0) # [B,1024,3] 88 | o3b_tensor=torch.stack([b['obj3_bps'] for b in adapteed_batch],dim=0) # [B,1024,3] 89 | o1sv_tensor=torch.stack([b['obj1_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3] 90 | o2sv_tensor=torch.stack([b['obj2_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3] 91 | o3sv_tensor=torch.stack([b['obj3_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3] 92 | init_state_tensor=torch.stack([b['init_state'] for b in adapteed_batch],dim=0) # [B,52*3+52*6+3+3*9] 93 | betas_tensor=torch.stack([b['betas'] for b in adapteed_batch],dim=0) # [B,10] 94 | 95 | cond={ 96 | 'y':{ 97 | 'mask':mask_tensor, 98 | 'length':len_tensor, 99 | 'text':text_batch, 100 | 'obj1_bps':o1b_tensor, 101 | 'obj2_bps':o2b_tensor, 102 | 'obj3_bps':o3b_tensor, 103 | 'obj1_sampled_verts':o1sv_tensor, 104 | 'obj2_sampled_verts':o2sv_tensor, 105 | 'obj3_sampled_verts':o3sv_tensor, 106 | 'init_state':init_state_tensor, 107 | 'betas':betas_tensor 108 | } 109 | } 110 | return inp_tensor,cond 111 | 112 | def gt_collate_fn(batch): 113 | # sort batch by sent length 114 | batch.sort(key=lambda x: x[3], reverse=True) 115 | return default_collate(batch) 116 | -------------------------------------------------------------------------------- /src/eval/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 6 | def euclidean_distance_matrix(matrix1, matrix2): 7 | """ 8 | Params: 9 | -- matrix1: N1 x D 10 | -- matrix2: N2 x D 11 | Returns: 12 | -- dist: N1 x N2 13 | dist[i, j] == distance(matrix1[i], matrix2[j]) 14 | """ 15 | assert matrix1.shape[1] == matrix2.shape[1] 16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 19 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 20 | return dists 21 | 22 | def calculate_top_k(mat, top_k): 23 | size = mat.shape[0] 24 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 25 | bool_mat = (mat == gt_mat) 26 | correct_vec = False 27 | top_k_list = [] 28 | for i in range(top_k): 29 | # print(correct_vec, bool_mat[:, i]) 30 | correct_vec = (correct_vec | bool_mat[:, i]) 31 | # print(correct_vec) 32 | top_k_list.append(correct_vec[:, None]) 33 | top_k_mat = np.concatenate(top_k_list, axis=1) 34 | return top_k_mat 35 | 36 | 37 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 38 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 39 | argmax = np.argsort(dist_mat, axis=1) 40 | top_k_mat = calculate_top_k(argmax, top_k) 41 | if sum_all: 42 | return top_k_mat.sum(axis=0) 43 | else: 44 | 
return top_k_mat 45 | 46 | 47 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 48 | assert len(embedding1.shape) == 2 49 | assert embedding1.shape[0] == embedding2.shape[0] 50 | assert embedding1.shape[1] == embedding2.shape[1] 51 | 52 | dist = linalg.norm(embedding1 - embedding2, axis=1) 53 | if sum_all: 54 | return dist.sum(axis=0) 55 | else: 56 | return dist 57 | 58 | 59 | 60 | def calculate_activation_statistics(activations): 61 | """ 62 | Params: 63 | -- activation: num_samples x dim_feat 64 | Returns: 65 | -- mu: dim_feat 66 | -- sigma: dim_feat x dim_feat 67 | """ 68 | mu = np.mean(activations, axis=0) 69 | cov = np.cov(activations, rowvar=False) 70 | return mu, cov 71 | 72 | 73 | def calculate_diversity(activation, diversity_times): 74 | assert len(activation.shape) == 2 75 | assert activation.shape[0] > diversity_times 76 | num_samples = activation.shape[0] 77 | 78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 81 | return dist.mean() 82 | 83 | 84 | def calculate_multimodality(activation, multimodality_times): 85 | assert len(activation.shape) == 3 86 | assert activation.shape[1] > multimodality_times 87 | num_per_sent = activation.shape[1] 88 | 89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 92 | return dist.mean() 93 | 94 | 95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 96 | """Numpy implementation of the Frechet Distance. 97 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 98 | and X_2 ~ N(mu_2, C_2) is 99 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 100 | Stable version by Dougal J. Sutherland. 101 | Params: 102 | -- mu1 : Numpy array containing the activations of a layer of the 103 | inception net (like returned by the function 'get_predictions') 104 | for generated samples. 105 | -- mu2 : The sample mean over activations, precalculated on an 106 | representative dataset set. 107 | -- sigma1: The covariance matrix over activations for generated samples. 108 | -- sigma2: The covariance matrix over activations, precalculated on an 109 | representative dataset set. 110 | Returns: 111 | -- : The Frechet Distance. 
112 | """ 113 | 114 | mu1 = np.atleast_1d(mu1) 115 | mu2 = np.atleast_1d(mu2) 116 | 117 | sigma1 = np.atleast_2d(sigma1) 118 | sigma2 = np.atleast_2d(sigma2) 119 | 120 | assert mu1.shape == mu2.shape, \ 121 | 'Training and test mean vectors have different lengths' 122 | assert sigma1.shape == sigma2.shape, \ 123 | 'Training and test covariances have different dimensions' 124 | 125 | diff = mu1 - mu2 126 | 127 | # Product might be almost singular 128 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 129 | if not np.isfinite(covmean).all(): 130 | msg = ('fid calculation produces singular product; ' 131 | 'adding %s to diagonal of cov estimates') % eps 132 | print(msg) 133 | offset = np.eye(sigma1.shape[0]) * eps 134 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 135 | 136 | # Numerical error might give slight imaginary component 137 | if np.iscomplexobj(covmean): 138 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 139 | m = np.max(np.abs(covmean.imag)) 140 | raise ValueError('Imaginary component {}'.format(m)) 141 | covmean = covmean.real 142 | 143 | tr_covmean = np.trace(covmean) 144 | 145 | return (diff.dot(diff) + np.trace(sigma1) + 146 | np.trace(sigma2) - 2 * tr_covmean) -------------------------------------------------------------------------------- /src/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | import numpy as np 3 | import torch as th 4 | 5 | from .gaussian_diffusion import GaussianDiffusion 6 | 7 | 8 | def space_timesteps(num_timesteps, section_counts): 9 | """ 10 | Create a list of timesteps to use from an original diffusion process, 11 | given the number of timesteps we want to take from equally-sized portions 12 | of the original process. 13 | 14 | For example, if there's 300 timesteps and the section counts are [10,15,20] 15 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 16 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 17 | 18 | If the stride is a string starting with "ddim", then the fixed striding 19 | from the DDIM paper is used, and only one section is allowed. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 
29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | section_counts = [int(x) for x in section_counts.split(",")] 40 | size_per = num_timesteps // len(section_counts) 41 | extra = num_timesteps % len(section_counts) 42 | start_idx = 0 43 | all_steps = [] 44 | for i, section_count in enumerate(section_counts): 45 | size = size_per + (1 if i < extra else 0) 46 | if size < section_count: 47 | raise ValueError( 48 | f"cannot divide section of {size} steps into {section_count}" 49 | ) 50 | if section_count <= 1: 51 | frac_stride = 1 52 | else: 53 | frac_stride = (size - 1) / (section_count - 1) 54 | cur_idx = 0.0 55 | taken_steps = [] 56 | for _ in range(section_count): 57 | taken_steps.append(start_idx + round(cur_idx)) 58 | cur_idx += frac_stride 59 | all_steps += taken_steps 60 | start_idx += size 61 | return set(all_steps) 62 | 63 | 64 | class SpacedDiffusion(GaussianDiffusion): 65 | """ 66 | A diffusion process which can skip steps in a base diffusion process. 67 | 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | if self.rescale_timesteps: 128 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # HIMO License 2 | 3 | ## Dataset Copyright License for non-commercial scientific research purposes 4 | 5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the HIMO Dataset and the accompanying Software (jointly referred to as the "Dataset"). By downloading and/or using the Dataset, you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Dataset. Any infringement of the terms of this agreement will automatically terminate your rights under this License. 6 | 7 | ## Ownership / Licensees 8 | 9 | Notwithstanding the disclosure of HIMO and any material pertaining to this dataset, all copyrights are maintained by and remained proprietary property of the authors and/or the copyright holders. You may use the dataset by ways in compliance with all applicable laws and regulations. 10 | 11 | The annotations in HIMO are licensed under the Attribution-Non Commercial-Share Alike 4.0 International License (CC-BY-NC-SA 4.0). With this free CC-license, we encourage the adoption and use of HIMO and related annotations. 12 | 13 | You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use. You may not use the material for commercial purposes. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original. 14 | 15 | ## License Grant 16 | 17 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 18 | 19 | To obtain and install the Dataset on computers owned, leased or otherwise controlled by you and/or your organization; 20 | To use the Dataset for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; 21 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Dataset may not be reproduced, modified and/or made available in any form to any third party without Licensor’s prior written permission. 22 | 23 | The Dataset may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Dataset to train methods/algorithms/neural networks/etc. for commercial use of any kind. 
By downloading the Dataset, you agree not to reverse engineer it. 24 | 25 | ## No Distribution 26 | 27 | The Dataset and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 28 | 29 | ## Disclaimer of Representations and Warranties 30 | 31 | You expressly acknowledge and agree that the Dataset results from basic research, is provided “AS IS”, may contain errors, and that any use of the Dataset is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Dataset, (ii) that the use of the Dataset will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Dataset will not cause any damage of any kind to you or a third party. 32 | 33 | ## Limitation of Liability 34 | 35 | Because this Software License Agreement qualifies as a donation, Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 36 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the Chinese Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 37 | Patent claims generated through the usage of the Dataset cannot be directed towards the copyright holders. 38 | The contractor points out that add-ons as well as minor modifications to the Dataset may lead to unforeseeable and considerable disruptions. 39 | 40 | ## No Maintenance Services 41 | 42 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Dataset. Licensor nevertheless reserves the right to update, modify, or discontinue the Dataset at any time. 43 | 44 | Defects of the Dataset must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. 45 | 46 | ## Publications using the Dataset 47 | 48 | You acknowledge that the Dataset is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Dataset. 
49 | 50 | #### Citation: 51 | ``` 52 | @article{lv2024himo, 53 | title={HIMO: A New Benchmark for Full-Body Human Interacting with Multiple Objects}, 54 | author={Xintao, Lv and Liang, Xu and YiChao, Yan and Xin, Jin and Congsheng, Xu and Shuwen, Wu and Yifan, Liu and Lincheng, Li and Mengxiao, Bi and Wenjun, Zeng and Xiaokang, Yang}, 55 | year={2024} 56 | } 57 | ``` 58 | 59 | ## Commercial licensing opportunities 60 | 61 | For commercial uses of the Dataset, please send email to [lvxintao@sjtu.edu.cn](mailto:lvxintao@sjtu.edu.cn). -------------------------------------------------------------------------------- /src/feature_extractor/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | import math 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | # from networks.layers import * 8 | import torch.nn.functional as F 9 | 10 | def init_weight(m): 11 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): 12 | nn.init.xavier_normal_(m.weight) 13 | # m.bias.data.fill_(0.01) 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0) 16 | 17 | class TextEncoderBiGRUCo(nn.Module): 18 | def __init__(self, word_size, pos_size, hidden_size, output_size, device): 19 | super(TextEncoderBiGRUCo, self).__init__() 20 | self.device = device 21 | 22 | self.pos_emb = nn.Linear(pos_size, word_size) 23 | self.input_emb = nn.Linear(word_size, hidden_size) 24 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 25 | self.output_net = nn.Sequential( 26 | nn.Linear(hidden_size * 2, hidden_size), 27 | nn.LayerNorm(hidden_size), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Linear(hidden_size, output_size) 30 | ) 31 | 32 | self.input_emb.apply(init_weight) 33 | self.pos_emb.apply(init_weight) 34 | self.output_net.apply(init_weight) 35 | # self.linear2.apply(init_weight) 36 | # self.batch_size = batch_size 37 | self.hidden_size = hidden_size 38 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 39 | 40 | # input(batch_size, seq_len, dim) 41 | def forward(self, word_embs, pos_onehot, cap_lens): 42 | num_samples = word_embs.shape[0] 43 | 44 | pos_embs = self.pos_emb(pos_onehot) 45 | inputs = word_embs + pos_embs 46 | input_embs = self.input_emb(inputs) 47 | hidden = self.hidden.repeat(1, num_samples, 1) 48 | 49 | cap_lens = cap_lens.data.tolist() 50 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 51 | 52 | gru_seq, gru_last = self.gru(emb, hidden) 53 | 54 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 55 | 56 | return self.output_net(gru_last) 57 | 58 | class MovementConvEncoder(nn.Module): 59 | def __init__(self, input_size, hidden_size, output_size): 60 | super(MovementConvEncoder, self).__init__() 61 | self.main = nn.Sequential( 62 | nn.Conv1d(input_size, hidden_size, 4, 2, 1), 63 | nn.Dropout(0.2, inplace=True), 64 | nn.LeakyReLU(0.2, inplace=True), 65 | nn.Conv1d(hidden_size, output_size, 4, 2, 1), 66 | nn.Dropout(0.2, inplace=True), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | ) 69 | self.out_net = nn.Linear(output_size, output_size) 70 | self.main.apply(init_weight) 71 | self.out_net.apply(init_weight) 72 | 73 | def forward(self, inputs): 74 | # bs,nf,dim 75 | inputs = inputs.permute(0, 2, 1) 76 | outputs = self.main(inputs).permute(0, 2, 1) 77 | # print(outputs.shape) 78 | return self.out_net(outputs) 79 | 80 | 81 | 
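# Shape sketch (hypothetical sizes, for illustration only): each Conv1d in
# MovementConvEncoder uses kernel_size=4, stride=2, padding=1, so the frame
# axis is halved twice, e.g.
#
#   enc = MovementConvEncoder(input_size=263, hidden_size=512, output_size=512)
#   x = torch.randn(4, 196, 263)   # (batch, frames, features)
#   z = enc(x)                     # -> (4, 49, 512): 196 frames / 4
#
# MovementConvDecoder below mirrors this with two ConvTranspose1d layers,
# upsampling the frame axis back by a factor of 4.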
class MovementConvDecoder(nn.Module): 82 | def __init__(self, input_size, hidden_size, output_size): 83 | super(MovementConvDecoder, self).__init__() 84 | self.main = nn.Sequential( 85 | nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1), 86 | # nn.Dropout(0.2, inplace=True), 87 | nn.LeakyReLU(0.2, inplace=True), 88 | nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1), 89 | # nn.Dropout(0.2, inplace=True), 90 | nn.LeakyReLU(0.2, inplace=True), 91 | ) 92 | self.out_net = nn.Linear(output_size, output_size) 93 | 94 | self.main.apply(init_weight) 95 | self.out_net.apply(init_weight) 96 | 97 | def forward(self, inputs): 98 | inputs = inputs.permute(0, 2, 1) 99 | outputs = self.main(inputs).permute(0, 2, 1) 100 | return self.out_net(outputs) 101 | 102 | class MotionEncoderBiGRUCo(nn.Module): 103 | def __init__(self, input_size, hidden_size, output_size, device): 104 | super(MotionEncoderBiGRUCo, self).__init__() 105 | self.device = device 106 | 107 | self.input_emb = nn.Linear(input_size, hidden_size) 108 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 109 | self.output_net = nn.Sequential( 110 | nn.Linear(hidden_size*2, hidden_size), 111 | nn.LayerNorm(hidden_size), 112 | nn.LeakyReLU(0.2, inplace=True), 113 | nn.Linear(hidden_size, output_size) 114 | ) 115 | 116 | self.input_emb.apply(init_weight) 117 | self.output_net.apply(init_weight) 118 | self.hidden_size = hidden_size 119 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 120 | 121 | # input(batch_size, seq_len, dim) 122 | def forward(self, inputs, m_lens): 123 | num_samples = inputs.shape[0] 124 | 125 | input_embs = self.input_emb(inputs) 126 | hidden = self.hidden.repeat(1, num_samples, 1) 127 | 128 | cap_lens = m_lens.data.tolist() 129 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 130 | 131 | gru_seq, gru_last = self.gru(emb, hidden) 132 | 133 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 134 | 135 | return self.output_net(gru_last) 136 | 137 | class ContrastiveLoss(torch.nn.Module): 138 | """ 139 | Contrastive loss function. 140 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 141 | """ 142 | def __init__(self, margin=3.0): 143 | super(ContrastiveLoss, self).__init__() 144 | self.margin = margin 145 | 146 | def forward(self, output1, output2, label): 147 | euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True) 148 | loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + 149 | (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 150 | return loss_contrastive -------------------------------------------------------------------------------- /src/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
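:return: a ScheduleSampler (UniformSampler for "uniform",
    LossSecondMomentResampler for "loss-second-moment").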
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 
113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /src/dataset/eval_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import h5py 5 | from tqdm import tqdm 6 | from torch.utils.data import Dataset 7 | import codecs as cs 8 | from torch.utils.data._utils.collate import default_collate 9 | from src.utils.word_vectorizer import WordVectorizer 10 | 11 | class Evaluation_Dataset(Dataset): 12 | def __init__(self,args,split='test',mode='gt'): 13 | self.args=args 14 | self.max_motion_length=self.args.max_motion_length 15 | self.split=split 16 | self.mode=mode 17 | self.obj=self.args.obj 18 | self.data_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),f'{split}.h5') 19 | self.data_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),f'{split}.h5') 20 | self.obj_bps_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),'object_bps.npz') 21 | self.obj_sampled_verts_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),'sampled_obj_verts.npz') 22 | self.w_vectorizer=WordVectorizer(osp.join(self.args.data_dir,'glove'),'himo_vab') 23 | 24 | self.data=[] 25 | self.load_data() 26 | 27 | self.object_bps=dict(np.load(self.obj_bps_path,allow_pickle=True)) # 1,1024,3 28 | self.object_sampled_verts=dict(np.load(self.obj_sampled_verts_path,allow_pickle=True)) # 1024,3 29 | 30 | def __len__(self): 31 | return len(self.data) 32 | 33 | def __getitem__(self, idx): 34 | data=self.data[idx] 35 | motion,m_length,text_list=data['motion'],data['length'],data['text'] 36 | caption,tokens=text_list['caption'],text_list['tokens'] 37 | 38 | if len(tokens) < self.args.max_text_len: 39 | # pad with "unk" 40 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 41 | sent_len = len(tokens) 42 | tokens = 
tokens + ['unk/OTHER'] * (self.args.max_text_len + 2 - sent_len) 43 | else: 44 | # crop 45 | tokens = tokens[:self.args.max_text_len] 46 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 47 | sent_len = len(tokens) 48 | pos_one_hots = [] 49 | word_embeddings = [] 50 | for token in tokens: 51 | word_emb, pos_oh = self.w_vectorizer[token] 52 | pos_one_hots.append(pos_oh[None, :]) 53 | word_embeddings.append(word_emb[None, :]) 54 | pos_one_hots = np.concatenate(pos_one_hots, axis=0) 55 | word_embeddings = np.concatenate(word_embeddings, axis=0) 56 | 57 | if m_length>') 111 | if next_button: 112 | self.set_next_record() 113 | imgui.same_line() 114 | tmp_idx = '' 115 | imgui.set_next_item_width(imgui.get_window_width() * 0.1) 116 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line() 117 | if is_go_to: 118 | try: 119 | self.go_to_idx = int(tmp_idx) - 1 120 | except: 121 | pass 122 | go_to_button = imgui.button('>>Go<<'); imgui.same_line() 123 | if go_to_button: 124 | self.set_goto_record(self.go_to_idx) 125 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks)) 126 | 127 | imgui.text_wrapped(self.text_val) 128 | imgui.end() 129 | 130 | def set_prev_record(self): 131 | self.label_pid = (self.label_pid - 1) % self.total_tasks 132 | self.clear_one_sequence() 133 | self.load_one_sequence() 134 | self.scene.current_frame_id=0 135 | 136 | def set_next_record(self): 137 | self.label_pid = (self.label_pid + 1) % self.total_tasks 138 | self.clear_one_sequence() 139 | self.load_one_sequence() 140 | self.scene.current_frame_id=0 141 | 142 | def set_goto_record(self, idx): 143 | self.label_pid = int(idx) % self.total_tasks 144 | self.clear_one_sequence() 145 | self.load_one_sequence() 146 | self.scene.current_frame_id=0 147 | 148 | def get_label_file_list(self): 149 | for clip in sorted(os.listdir(self.clip_folder)): 150 | if not clip.startswith('.'): 151 | self.label_npy_list.append(os.path.join(self.clip_folder, clip)) 152 | 153 | def load_text_from_file(self): 154 | self.text_val = '' 155 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4] 156 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')): 157 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f: 158 | for line in f.readlines(): 159 | self.text_val += line 160 | self.text_val += '\n' 161 | 162 | 163 | def load_one_sequence(self): 164 | skel_file = self.label_npy_list[self.label_pid] 165 | clip_name=os.path.split(skel_file)[-1][:-4] 166 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy') 167 | 168 | # load skeleton 169 | skel_data = np.load(skel_file, allow_pickle=True) 170 | skel_data=skel_data[:,SELECTED_JOINTS] 171 | skel=Skeletons( 172 | joint_positions=skel_data, 173 | joint_connections=OPTITRACK_LIMBS, 174 | radius=0.005 175 | ) 176 | self.scene.add(skel) 177 | 178 | # Load object 179 | object_pose=np.load(opj_pose_file, allow_pickle=True).item() 180 | meshes=[] 181 | for obj_name in object_pose.keys(): 182 | obj_pose=object_pose[obj_name] 183 | obj_mesh=self.object_mesh[obj_name] 184 | verts, faces = obj_mesh.vertices, obj_mesh.faces 185 | mesh = Meshes( 186 | vertices=verts, 187 | faces=faces, 188 | name=obj_name, 189 | position=obj_pose['transl'], 190 | rotation=obj_pose['rot'], 191 | color= (0.3,0.3,0.5,1) 192 | ) 193 | meshes.append(mesh) 194 | self.scene.add(*meshes) 195 | 196 | self.load_text_from_file() 197 | 198 | 199 | def clear_one_sequence(self): 200 | for x in self.scene.nodes.copy(): 201 | if type(x) is Skeletons or type(x) 
is Meshes: 202 | self.scene.remove(x) 203 | 204 | 205 | if __name__=='__main__': 206 | 207 | viewer=Seg_Viewer() 208 | viewer.scene.fps=30 209 | viewer.playback_fps=30 210 | viewer.run() -------------------------------------------------------------------------------- /visualize/skel_viewer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import glfw 5 | import imgui 6 | import numpy as np 7 | import trimesh 8 | 9 | from aitviewer.configuration import CONFIG as C 10 | from aitviewer.renderables.meshes import Meshes 11 | from aitviewer.viewer import Viewer 12 | from aitviewer.renderables.skeletons import Skeletons 13 | 14 | glfw.init() 15 | primary_monitor = glfw.get_primary_monitor() 16 | mode = glfw.get_video_mode(primary_monitor) 17 | width = mode.size.width 18 | height = mode.size.height 19 | 20 | C.update_conf({'window_width': width*0.9, 'window_height': height*0.9}) 21 | 22 | OPTITRACK_LIMBS=[ 23 | [0,1],[1,2],[2,3],[3,4], 24 | [0,5],[5,6],[6,7],[7,8], 25 | [0,9],[9,10], 26 | [10,11],[11,12],[12,13],[13,14], 27 | [14,15],[15,16],[16,17],[17,18], 28 | [14,19],[19,20],[20,21],[21,22], 29 | [14,23],[23,24],[24,25],[25,26], 30 | [14,27],[27,28],[28,29],[29,30], 31 | [14,31],[31,32],[32,33],[33,34], 32 | [10,35],[35,36],[36,37],[37,38], 33 | [38,39],[39,40],[40,41],[41,42], 34 | [38,43],[43,44],[44,45],[45,46], 35 | [38,47],[47,48],[48,49],[49,50], 36 | [38,51],[51,52],[52,53],[53,54], 37 | [38,55],[55,56],[56,57],[57,58], 38 | [10,59],[59,60] 39 | ] 40 | 41 | SELECTED_JOINTS=np.concatenate( 42 | [range(0,5),range(6,10),range(11,63)] 43 | ) 44 | 45 | class Skel_Viewer(Viewer): 46 | title='HIMO Viewer for Skeleton' 47 | 48 | def __init__(self,**kwargs): 49 | super().__init__(**kwargs) 50 | self.gui_controls.update( 51 | { 52 | 'show_text':self.gui_show_text 53 | } 54 | ) 55 | self._set_prev_record=self.wnd.keys.UP 56 | self._set_next_record=self.wnd.keys.DOWN 57 | 58 | # reset 59 | self.reset_for_himo() 60 | self.load_one_sequence() 61 | 62 | def reset_for_himo(self): 63 | 64 | self.text_val = '' 65 | 66 | self.clip_folder = os.path.join('data','joints') 67 | self.text_folder = os.path.join('data','text') 68 | self.object_pose_folder = os.path.join('data','object_pose') 69 | self.object_mesh_folder = os.path.join('data','object_mesh') 70 | 71 | # Pre-load object meshes 72 | self.object_mesh={} 73 | for obj in os.listdir(self.object_mesh_folder): 74 | if not obj.startswith('.'): 75 | obj_name = obj.split('.')[0] 76 | obj_path = os.path.join(self.object_mesh_folder, obj) 77 | mesh = trimesh.load(obj_path) 78 | self.object_mesh[obj_name] = mesh 79 | 80 | self.label_npy_list = [] 81 | self.get_label_file_list() 82 | self.total_tasks = len(self.label_npy_list) 83 | 84 | self.label_pid = 0 85 | self.go_to_idx = 0 86 | 87 | def key_event(self, key, action, modifiers): 88 | if action==self.wnd.keys.ACTION_PRESS: 89 | if key==self._set_prev_record: 90 | self.set_prev_record() 91 | elif key==self._set_next_record: 92 | self.set_next_record() 93 | else: 94 | return super().key_event(key, action, modifiers) 95 | else: 96 | return super().key_event(key, action, modifiers) 97 | 98 | def gui_show_text(self): 99 | imgui.set_next_window_position(self.window_size[0] * 0.6, self.window_size[1]*0.25, imgui.FIRST_USE_EVER) 100 | imgui.set_next_window_size(self.window_size[0] * 0.35, self.window_size[1]*0.4, imgui.FIRST_USE_EVER) 101 | expanded, _ = imgui.begin("HIMO Text Descriptions", None) 102 | 103 | if expanded: 104 | npy_folder = 
self.label_npy_list[self.label_pid].split('/')[-1] 105 | imgui.text(str(npy_folder)) 106 | bef_button = imgui.button('<>') 111 | if next_button: 112 | self.set_next_record() 113 | imgui.same_line() 114 | tmp_idx = '' 115 | imgui.set_next_item_width(imgui.get_window_width() * 0.1) 116 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line() 117 | if is_go_to: 118 | try: 119 | self.go_to_idx = int(tmp_idx) - 1 120 | except: 121 | pass 122 | go_to_button = imgui.button('>>Go<<'); imgui.same_line() 123 | if go_to_button: 124 | self.set_goto_record(self.go_to_idx) 125 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks)) 126 | 127 | imgui.text_wrapped(self.text_val) 128 | imgui.end() 129 | 130 | def set_prev_record(self): 131 | self.label_pid = (self.label_pid - 1) % self.total_tasks 132 | self.clear_one_sequence() 133 | self.load_one_sequence() 134 | self.scene.current_frame_id=0 135 | 136 | def set_next_record(self): 137 | self.label_pid = (self.label_pid + 1) % self.total_tasks 138 | self.clear_one_sequence() 139 | self.load_one_sequence() 140 | self.scene.current_frame_id=0 141 | 142 | def set_goto_record(self, idx): 143 | self.label_pid = int(idx) % self.total_tasks 144 | self.clear_one_sequence() 145 | self.load_one_sequence() 146 | self.scene.current_frame_id=0 147 | 148 | def get_label_file_list(self): 149 | for clip in sorted(os.listdir(self.clip_folder)): 150 | if not clip.startswith('.'): 151 | self.label_npy_list.append(os.path.join(self.clip_folder, clip)) 152 | 153 | def load_text_from_file(self): 154 | self.text_val = '' 155 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4] 156 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')): 157 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f: 158 | for line in f.readlines(): 159 | self.text_val += line 160 | self.text_val += '\n' 161 | 162 | 163 | def load_one_sequence(self): 164 | skel_file = self.label_npy_list[self.label_pid] 165 | clip_name=os.path.split(skel_file)[-1][:-4] 166 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy') 167 | 168 | # load skeleton 169 | skel_data = np.load(skel_file, allow_pickle=True) 170 | skel_data=skel_data[:,SELECTED_JOINTS] 171 | skel=Skeletons( 172 | joint_positions=skel_data, 173 | joint_connections=OPTITRACK_LIMBS, 174 | radius=0.005 175 | ) 176 | self.scene.add(skel) 177 | 178 | # Load object 179 | object_pose=np.load(opj_pose_file, allow_pickle=True).item() 180 | meshes=[] 181 | for obj_name in object_pose.keys(): 182 | obj_pose=object_pose[obj_name] 183 | obj_mesh=self.object_mesh[obj_name] 184 | verts, faces = obj_mesh.vertices, obj_mesh.faces 185 | mesh = Meshes( 186 | vertices=verts, 187 | faces=faces, 188 | name=obj_name, 189 | position=obj_pose['transl'], 190 | rotation=obj_pose['rot'], 191 | color= (0.3,0.3,0.5,1) 192 | ) 193 | meshes.append(mesh) 194 | self.scene.add(*meshes) 195 | 196 | self.load_text_from_file() 197 | 198 | 199 | def clear_one_sequence(self): 200 | for x in self.scene.nodes.copy(): 201 | if type(x) is Skeletons or type(x) is Meshes: 202 | self.scene.remove(x) 203 | 204 | 205 | if __name__=='__main__': 206 | 207 | viewer=Skel_Viewer() 208 | viewer.scene.fps=30 209 | viewer.playback_fps=30 210 | viewer.run() -------------------------------------------------------------------------------- /visualize/smplx_viewer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import glfw 5 | import imgui 6 | 
import numpy as np 7 | import trimesh 8 | 9 | from aitviewer.configuration import CONFIG as C 10 | from aitviewer.renderables.meshes import Meshes 11 | from aitviewer.viewer import Viewer 12 | from aitviewer.models.smpl import SMPLLayer 13 | from aitviewer.renderables.smpl import SMPLSequence 14 | 15 | glfw.init() 16 | primary_monitor = glfw.get_primary_monitor() 17 | mode = glfw.get_video_mode(primary_monitor) 18 | width = mode.size.width 19 | height = mode.size.height 20 | 21 | C.update_conf({'window_width': width*0.9, 'window_height': height*0.9}) 22 | C.update_conf({'smplx_models':'./body_models'}) 23 | 24 | class SMPLX_Viewer(Viewer): 25 | title='HIMO Viewer for SMPL-X' 26 | 27 | def __init__(self,**kwargs): 28 | super().__init__(**kwargs) 29 | self.gui_controls.update( 30 | { 31 | 'show_text':self.gui_show_text 32 | } 33 | ) 34 | self._set_prev_record=self.wnd.keys.UP 35 | self._set_next_record=self.wnd.keys.DOWN 36 | 37 | # reset 38 | self.reset_for_himo() 39 | self.load_one_sequence() 40 | 41 | def reset_for_himo(self): 42 | 43 | self.text_val = '' 44 | 45 | self.clip_folder = os.path.join('data','smplx') 46 | self.text_folder = os.path.join('data','text') 47 | self.object_pose_folder = os.path.join('data','object_pose') 48 | self.object_mesh_folder = os.path.join('data','object_mesh') 49 | 50 | # Pre-load object meshes 51 | self.object_mesh={} 52 | for obj in os.listdir(self.object_mesh_folder): 53 | if not obj.startswith('.'): 54 | obj_name = obj.split('.')[0] 55 | obj_path = os.path.join(self.object_mesh_folder, obj) 56 | mesh = trimesh.load(obj_path) 57 | self.object_mesh[obj_name] = mesh 58 | 59 | self.label_npy_list = [] 60 | self.get_label_file_list() 61 | self.total_tasks = len(self.label_npy_list) 62 | 63 | self.label_pid = 0 64 | self.go_to_idx = 0 65 | 66 | def key_event(self, key, action, modifiers): 67 | if action==self.wnd.keys.ACTION_PRESS: 68 | if key==self._set_prev_record: 69 | self.set_prev_record() 70 | elif key==self._set_next_record: 71 | self.set_next_record() 72 | else: 73 | return super().key_event(key, action, modifiers) 74 | else: 75 | return super().key_event(key, action, modifiers) 76 | 77 | def gui_show_text(self): 78 | imgui.set_next_window_position(self.window_size[0] * 0.6, self.window_size[1]*0.25, imgui.FIRST_USE_EVER) 79 | imgui.set_next_window_size(self.window_size[0] * 0.35, self.window_size[1]*0.4, imgui.FIRST_USE_EVER) 80 | expanded, _ = imgui.begin("HIMO Text Descriptions", None) 81 | 82 | if expanded: 83 | npy_folder = self.label_npy_list[self.label_pid].split('/')[-1] 84 | imgui.text(str(npy_folder)) 85 | bef_button = imgui.button('<>') 90 | if next_button: 91 | self.set_next_record() 92 | imgui.same_line() 93 | tmp_idx = '' 94 | imgui.set_next_item_width(imgui.get_window_width() * 0.1) 95 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line() 96 | if is_go_to: 97 | try: 98 | self.go_to_idx = int(tmp_idx) - 1 99 | except: 100 | pass 101 | go_to_button = imgui.button('>>Go<<'); imgui.same_line() 102 | if go_to_button: 103 | self.set_goto_record(self.go_to_idx) 104 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks)) 105 | 106 | imgui.text_wrapped(self.text_val) 107 | imgui.end() 108 | 109 | def set_prev_record(self): 110 | self.label_pid = (self.label_pid - 1) % self.total_tasks 111 | self.clear_one_sequence() 112 | self.load_one_sequence() 113 | self.scene.current_frame_id=0 114 | 115 | def set_next_record(self): 116 | self.label_pid = (self.label_pid + 1) % self.total_tasks 117 | self.clear_one_sequence() 118 
| self.load_one_sequence() 119 | self.scene.current_frame_id=0 120 | 121 | def set_goto_record(self, idx): 122 | self.label_pid = int(idx) % self.total_tasks 123 | self.clear_one_sequence() 124 | self.load_one_sequence() 125 | self.scene.current_frame_id=0 126 | 127 | def get_label_file_list(self): 128 | for clip in sorted(os.listdir(self.clip_folder)): 129 | if not clip.startswith('.'): 130 | self.label_npy_list.append(os.path.join(self.clip_folder, clip)) 131 | 132 | def load_text_from_file(self): 133 | self.text_val = '' 134 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4] 135 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')): 136 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f: 137 | for line in f.readlines(): 138 | self.text_val += line 139 | self.text_val += '\n' 140 | 141 | 142 | def load_one_sequence(self): 143 | smplx_file = self.label_npy_list[self.label_pid] 144 | clip_name=os.path.split(smplx_file)[-1][:-4] 145 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy') 146 | 147 | # load smplx 148 | 149 | smplx_params = np.load(smplx_file, allow_pickle=True) 150 | nf = smplx_params['body_pose'].shape[0] 151 | 152 | betas = smplx_params['betas'] 153 | poses_root = smplx_params['global_orient'] 154 | poses_body = smplx_params['body_pose'].reshape(nf,-1) 155 | poses_lhand = smplx_params['lhand_pose'].reshape(nf,-1) 156 | poses_rhand = smplx_params['rhand_pose'].reshape(nf,-1) 157 | transl = smplx_params['transl'] 158 | 159 | # create body models 160 | smplx_layer = SMPLLayer(model_type='smplx',gender='neutral',num_betas=10,device=C.device) 161 | 162 | # create smplx sequence for two persons 163 | smplx_seq = SMPLSequence(poses_body=poses_body, 164 | smpl_layer=smplx_layer, 165 | poses_root=poses_root, 166 | betas=betas, 167 | trans=transl, 168 | poses_left_hand=poses_lhand, 169 | poses_right_hand=poses_rhand, 170 | device=C.device, 171 | ) 172 | 173 | self.scene.add(smplx_seq) 174 | 175 | # Load object 176 | object_pose=np.load(opj_pose_file, allow_pickle=True).item() 177 | meshes=[] 178 | for obj_name in object_pose.keys(): 179 | obj_pose=object_pose[obj_name] 180 | obj_mesh=self.object_mesh[obj_name] 181 | verts, faces = obj_mesh.vertices, obj_mesh.faces 182 | mesh = Meshes( 183 | vertices=verts, 184 | faces=faces, 185 | name=obj_name, 186 | position=obj_pose['transl'], 187 | rotation=obj_pose['rot'], 188 | color= (0.3,0.3,0.5,1) 189 | ) 190 | meshes.append(mesh) 191 | self.scene.add(*meshes) 192 | 193 | self.load_text_from_file() 194 | 195 | 196 | def clear_one_sequence(self): 197 | for x in self.scene.nodes.copy(): 198 | if type(x) is SMPLSequence or type(x) is SMPLLayer or type(x) is Meshes: 199 | self.scene.remove(x) 200 | 201 | 202 | if __name__=='__main__': 203 | 204 | viewer=SMPLX_Viewer() 205 | viewer.scene.fps=30 206 | viewer.playback_fps=30 207 | viewer.run() -------------------------------------------------------------------------------- /src/diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from src.diffusion import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
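Only Conv1d/Conv2d/Conv3d weights and biases are converted; other modules
are left untouched.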
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = 
get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def 
master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /src/utils/rotation_conversion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains functions for converting between different rotation representations. 3 | Including: 4 | - Euler angles (euler) default to (XYZ,intrinsic) 5 | - Rotation matrices (rot) 6 | - Quaternions (quat) 7 | - Axis-angle (aa) 8 | - 6D representation (6d) 9 | We also provide numpy and torch versions of these functions. 10 | Note that all functions are in batch mode. 11 | """ 12 | import torch 13 | import numpy as np 14 | from scipy.spatial.transform import Rotation as R 15 | from pytorch3d.transforms.rotation_conversions import ( 16 | quaternion_to_axis_angle,quaternion_to_matrix, 17 | axis_angle_to_matrix,axis_angle_to_quaternion, 18 | matrix_to_axis_angle,matrix_to_quaternion, 19 | euler_angles_to_matrix,matrix_to_euler_angles, 20 | rotation_6d_to_matrix,matrix_to_rotation_6d 21 | ) 22 | 23 | #--------------------Numpy Version--------------------------# 24 | def euler2rot_numpy(euler,degrees=False): 25 | """ 26 | euler:[B,3] (XYZ,extrinsic) 27 | degrees are False if they are radians 28 | return: [B,3,3] 29 | """ 30 | assert isinstance(euler, np.ndarray) 31 | ori_shape = euler.shape[:-1] 32 | rots = np.reshape(euler, (-1, 3)) 33 | rots = R.as_matrix(R.from_euler("XYZ", rots, degrees=degrees)) 34 | rot = np.reshape(rots, ori_shape + (3, 3)) 35 | return rot 36 | 37 | def rot2euler_numpy(rot,degrees=False): 38 | """ 39 | rot:[B,3,3] 40 | return: [B,3] 41 | """ 42 | assert isinstance(rot, np.ndarray) 43 | ori_shape = rot.shape[:-2] 44 | rots = np.reshape(rot, (-1, 3, 3)) 45 | rots = R.as_euler(R.from_matrix(rots), "XYZ", degrees=degrees) 46 | euler = np.reshape(rots, ori_shape + (3,)) 47 | return euler 48 | 49 | def euler2quat_numpy(euler,degrees=False): 50 | """ 51 | euler:[B,3] 52 | return [B,4] 53 | """ 54 | assert isinstance(euler,np.ndarray) 55 | ori_shape=euler.shape[:-1] 56 | rots=np.reshape(euler,(-1,3)) 57 | quats=R.as_quat(R.from_euler("XYZ",rots,degrees=degrees)) 58 | quat=np.reshape(quats,ori_shape+(4,)) 59 | return quat 60 | 61 | def quat2euler_numpy(quat,degrees=False): 62 | """ 63 | quat:[B,4] 64 | return [B,3] 65 | """ 66 | assert isinstance(quat,np.ndarray) 67 | ori_shape=quat.shape[:-1] 68 | rots=np.reshape(quat,(-1,3)) 69 | eulers=R.as_euler("XYZ",R.from_quat(rots),degrees=degrees) 70 | euler=np.reshape(eulers,ori_shape+(3,)) 71 | return euler 72 | 73 | def euler2aa_numpy(euler,degrees=False): 74 | """ 75 | euler:[B,3] 76 | return: [B,3] 77 | """ 78 | assert isinstance(euler, np.ndarray) 79 | ori_shape = euler.shape[:-1] 80 | rots = np.reshape(euler, (-1, 3)) 81 | aas = R.as_rotvec(R.from_euler("XYZ", rots, degrees=degrees)) 82 | rotation_vectors = np.reshape(aas, ori_shape + (3,)) 83 | return rotation_vectors 84 | 85 | def aa2euler_numpy(aa,degrees=False): 86 | """ 87 | aa:[B,3] 88 | return [B,3] 89 | """ 90 | assert isinstance(aa, np.ndarray) 91 | ori_shape = aa.shape[:-1] 92 | aas = np.reshape(aa, (-1, 3)) 93 
| rots = R.as_euler(R.from_rotvec(aas), "XYZ", degrees=degrees) 94 | euler_angles = np.reshape(rots, ori_shape + (3,)) 95 | return euler_angles 96 | 97 | def rot2quat_numpy(rot): 98 | """ 99 | rot:[B,3,3] 100 | return [B,4] 101 | """ 102 | return euler2quat_numpy(rot2euler_numpy(rot)) 103 | 104 | def quat2rot_numpy(quat): 105 | """ 106 | quat:[B,4] (w,x,y,z) 107 | return: [B,3,3] 108 | """ 109 | return euler2rot_numpy(quat2euler_numpy(quat)) 110 | 111 | def rot2aa_numpy(rot): 112 | """ 113 | rot:[B,3,3] 114 | return:[B,3] 115 | """ 116 | assert isinstance(rot, np.ndarray) 117 | ori_shape = rot.shape[:-2] 118 | rots = np.reshape(rot, (-1, 3, 3)) 119 | aas = R.as_rotvec(R.from_matrix(rots)) 120 | rotation_vectors = np.reshape(aas, ori_shape + (3,)) 121 | return rotation_vectors 122 | 123 | def aa2rot_numpy(aa): 124 | """ 125 | aa:[B,3] 126 | Rodirgues formula 127 | return: [B,3,3] 128 | """ 129 | assert isinstance(aa,np.ndarray) 130 | ori_shape=aa.shape[:-1] 131 | aas=np.reshape(aa,(-1,3)) 132 | rots=R.as_matrix(R.from_rotvec(aas)) 133 | rot_mat=np.reshape(rots,ori_shape+(3,3)) 134 | return rot_mat 135 | 136 | def quat2aa_numpy(quat): 137 | """ 138 | quat:[B,4] 139 | return [B,3] 140 | """ 141 | return euler2aa_numpy(quat2euler_numpy(quat)) 142 | 143 | def aa2quat_numpy(aa): 144 | """ 145 | aa:[B,3] 146 | return [B,4] 147 | """ 148 | return euler2quat_numpy(aa2euler_numpy(aa)) 149 | 150 | def rot2sixd_numpy(rot): 151 | """ 152 | rot:[B,3,3] 153 | return [B,6] 154 | """ 155 | assert isinstance(rot,np.ndarray) 156 | ori_shape=rot.shape[:-2] 157 | return rot[...,:2,:].copy().reshape(ori_shape+(6,)) 158 | 159 | def sixd2rot_numpy(sixd): 160 | """ 161 | sixd:[B,6] 162 | return [B,3,3] 163 | """ 164 | assert isinstance(sixd,np.ndarray) 165 | a1,a2=sixd[...,:3],sixd[...,3:] 166 | b1=a1/np.linalg.norm(a1,axis=-1,keepdims=True) 167 | b2=a2-(b1*a2).sum(-1,keepdims=True)*b1 168 | b2=b2/np.linalg.norm(b2,axis=-1,keepdims=True) 169 | b3=np.cross(b1,b2,axis=-1) 170 | return np.stack([b1,b2,b3],axis=-2) 171 | 172 | def rpy2rot_numpy(rpy,degrees=False): 173 | """ 174 | rpy: [B,3] (ZYX,intrinsic) 175 | return [B,3,3] 176 | """ 177 | assert isinstance(rpy, np.ndarray) 178 | ori_shape = rpy.shape[:-1] 179 | rots = np.reshape(rpy, (-1, 3)) 180 | rots = R.as_matrix(R.from_euler("ZYX", rots, degrees=degrees)) 181 | rotation_matrices = np.reshape(rots, ori_shape + (3, 3)) 182 | return rotation_matrices 183 | 184 | #--------------------Pytorch Version--------------------------# 185 | def euler2rot_torch(euler,degrees=False): 186 | """ 187 | euler [B,3] (XYZ,intrinsic) 188 | degrees are False if they are radians 189 | """ 190 | if degrees: 191 | euler_rad=torch.deg2rad(euler) 192 | return euler_angles_to_matrix(euler_rad,"XYZ") 193 | else: 194 | return euler_angles_to_matrix(euler,"XYZ") 195 | 196 | def rot2euler_torch(rot,degrees=False): 197 | """ 198 | rot:[B,3,3] 199 | return: [B,3] 200 | """ 201 | if degrees: 202 | euler_rad=matrix_to_euler_angles(rot,"XYZ") 203 | return torch.rad2deg(euler_rad) 204 | else: 205 | return matrix_to_euler_angles(rot,"XYZ") 206 | 207 | def euler2quat_torch(euler,degrees=False): 208 | """ 209 | euler:[B,3] 210 | return [B,4] 211 | """ 212 | if degrees: 213 | euler_rad=torch.deg2rad(euler) 214 | return matrix_to_quaternion(euler_angles_to_matrix(euler_rad,"XYZ")) 215 | else: 216 | return matrix_to_quaternion(euler_angles_to_matrix(euler,"XYZ")) 217 | 218 | def quat2euler_torch(quat,degrees=False): 219 | """ 220 | quat:[B,4] 221 | return [B,3] 222 | """ 223 | if degrees: 224 | 
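# note: quaternion_to_matrix returns a rotation matrix here despite the
# variable name; the Euler angles come from matrix_to_euler_angles below.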
euler_rad=quaternion_to_matrix(quat) 225 | return torch.rad2deg(matrix_to_euler_angles(euler_rad,"XYZ")) 226 | else: 227 | return matrix_to_euler_angles(quaternion_to_matrix(quat),"XYZ") 228 | 229 | def euler2aa_torch(euler,degrees=False): 230 | """ 231 | euler:[B,3] 232 | return: [B,3] 233 | """ 234 | if degrees: 235 | euler_rad=torch.deg2rad(euler) 236 | return matrix_to_axis_angle(euler_angles_to_matrix(euler_rad,"XYZ")) 237 | else: 238 | return matrix_to_axis_angle(euler_angles_to_matrix(euler,"XYZ")) 239 | 240 | def aa2euler_torch(aa,degrees=False): 241 | """ 242 | aa:[B,3] 243 | return [B,3] 244 | """ 245 | if degrees: 246 | euler_rad=axis_angle_to_matrix(aa) 247 | return torch.rad2deg(matrix_to_euler_angles(euler_rad,"XYZ")) 248 | else: 249 | return matrix_to_euler_angles(axis_angle_to_matrix(aa),"XYZ") 250 | 251 | def rot2quat_torch(rot): 252 | """ 253 | rot:[B,3,3] 254 | return [B,4] 255 | """ 256 | return matrix_to_quaternion(rot) 257 | 258 | def quat2rot_torch(quat): 259 | """ 260 | quat:[B,4] (w,x,y,z) 261 | return: [B,3,3] 262 | """ 263 | return quaternion_to_matrix(quat) 264 | 265 | def rot2aa_torch(rot): 266 | """ 267 | rot:[B,3,3] 268 | return:[B,3] 269 | """ 270 | return matrix_to_axis_angle(rot) 271 | 272 | def aa2rot_torch(aa): 273 | """ 274 | aa:[B,3] 275 | Rodirgues formula 276 | return: [B,3,3] 277 | """ 278 | return axis_angle_to_matrix(aa) 279 | 280 | def quat2aa_torch(quat): 281 | """ 282 | quat:[B,4] 283 | return [B,3] 284 | """ 285 | return quaternion_to_axis_angle(quat) 286 | 287 | def aa2quat_torch(aa): 288 | """ 289 | aa:[B,3] 290 | return [B,4] 291 | """ 292 | return axis_angle_to_quaternion(aa) 293 | 294 | def rot2sixd_torch(rot): 295 | """ 296 | rot:[B,3,3] 297 | return [B,6] 298 | """ 299 | return matrix_to_rotation_6d(rot) 300 | 301 | def sixd2rot_torch(sixd): 302 | """ 303 | sixd:[B,6] 304 | return [B,3,3] 305 | """ 306 | return rotation_6d_to_matrix(sixd) 307 | 308 | if __name__=='__main__': 309 | rot=np.eye(3)[None,...].repeat(2,axis=0) 310 | sixd=rot2sixd_numpy(rot) 311 | new_rot=sixd2rot_numpy(sixd) 312 | print(new_rot) -------------------------------------------------------------------------------- /src/model/net_2o.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from loguru import logger 6 | import clip 7 | from src.model.blocks import TransformerBlock 8 | 9 | class NET_2O(nn.Module): 10 | def __init__(self, 11 | n_feats=(52*3+52*6+3+2*9), 12 | clip_dim=512, 13 | latent_dim=512, 14 | ff_size=1024, 15 | num_layers=8, 16 | num_heads=4, 17 | dropout=0.1, 18 | ablation=None, 19 | activation="gelu",**kwargs): 20 | super().__init__() 21 | 22 | self.n_feats = n_feats 23 | self.clip_dim = clip_dim 24 | self.latent_dim = latent_dim 25 | self.ff_size = ff_size 26 | self.num_layers = num_layers 27 | self.num_heads = num_heads 28 | self.dropout = dropout 29 | self.ablation = ablation 30 | self.activation = activation 31 | 32 | self.cond_mask_prob=kwargs.get('cond_mask_prob',0.1) 33 | 34 | # clip and text embedder 35 | self.embed_text=nn.Linear(self.clip_dim,self.latent_dim) 36 | self.clip_version='ViT-B/32' 37 | self.clip_model=self.load_and_freeze_clip(self.clip_version) 38 | 39 | # object_geometry embedder 40 | self.embed_obj_bps=nn.Linear(1024*3,self.latent_dim) 41 | # object init state embeddr 42 | self.embed_obj_pose=nn.Linear(2*self.latent_dim+2*9+2*9,self.latent_dim) 43 | 44 | # human embedder 45 | 
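# input dim is 2*(52*3+52*6+3): forward() concatenates the initial human
# state with the noisy human pose (presumably 52 joint positions, 52 6D
# joint rotations, and the root translation).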
self.embed_human_pose=nn.Linear(2*(52*3+52*6+3),self.latent_dim) 46 | 47 | # position encoding 48 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=self.dropout) 49 | 50 | # TODO:unshared transformer layers for human and objects,they fuse feature in the middle layers 51 | seqTransEncoderLayer=nn.TransformerEncoderLayer( 52 | d_model=self.latent_dim, 53 | nhead=self.num_heads, 54 | dim_feedforward=self.ff_size, 55 | dropout=self.dropout, 56 | activation=self.activation, 57 | ) 58 | # # object transformer layers 59 | # self.object_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers) 60 | # # human transformer encoder 61 | # self.human_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers) 62 | 63 | # # Mutal cross attention 64 | # self.communication_module=nn.ModuleList() 65 | # for i in range(8): 66 | # self.communication_module.append(MutalCrossAttentionBlock(self.latent_dim,self.num_heads,self.ff_size,self.dropout)) 67 | self.obj_blocks=nn.ModuleList() 68 | self.human_blocks=nn.ModuleList() 69 | for i in range(self.num_layers): 70 | self.obj_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation)) 71 | self.human_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation)) 72 | 73 | 74 | # embed the timestep 75 | self.embed_timestep=TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 76 | 77 | # object output process 78 | self.obj_output_process=nn.Linear(self.latent_dim,2*9) 79 | # human motion output process 80 | self.human_output_process=nn.Linear(self.latent_dim,52*3+52*6+3) 81 | 82 | def forward(self,x,timesteps,y=None): 83 | bs,nframes,n_feats=x.shape 84 | emb=self.embed_timestep(timesteps) # [1, bs, latent_dim] 85 | 86 | enc_text=self.encode_text(y['text']) 87 | emb+=self.mask_cond(enc_text) # 1,bs,latent_dim 88 | 89 | x=x.permute((1,0,2)) # nframes,bs,nfeats 90 | human_x,obj_x=torch.split(x,[52*3+52*6+3,2*9],dim=-1) # [nframes,bs,52*3+52*6+3],[nframes,bs,2*9] 91 | 92 | # encode object geometry 93 | obj1_bps,obj2_bps=y['obj1_bps'].reshape(bs,-1),y['obj2_bps'].reshape(bs,-1) # [b,1024,3] 94 | obj1_bps_emb=self.embed_obj_bps(obj1_bps) # [b,latent_dim] 95 | obj2_bps_emb=self.embed_obj_bps(obj2_bps) # [b,latent_dim] 96 | obj_geo_emb=torch.concat([obj1_bps_emb,obj2_bps_emb],dim=-1).unsqueeze(0).repeat((nframes,1,1)) # [nf,b,2*latent_dim] 97 | 98 | # init_state,mask the other frames by padding zeros 99 | init_state=y['init_state'].unsqueeze(0) # [1,b,52*3+52*6+3+2*9] 100 | padded_zeros=torch.zeros((nframes-1,bs,52*3+52*6+3+2*9),device=init_state.device) 101 | init_state=torch.concat([init_state,padded_zeros],dim=0) # [nf,b,52*3+52*6+3+2*9] 102 | 103 | # seperate the object and human init state 104 | human_init_state=init_state[:,:,:52*3+52*6+3] # [nf,b,52*3+52*6+3] 105 | obj_init_state=init_state[:,:,52*3+52*6+3:] # [nf,b,2*9] 106 | 107 | # Object branch 108 | obj_emb=self.embed_obj_pose(torch.concat([obj_geo_emb,obj_init_state,obj_x],dim=-1)) # nframes,bs,latent_dim 109 | obj_seq_prev=self.sequence_pos_encoder(obj_emb) # [nf,bs,latent_dim] 110 | 111 | 112 | # Human branch 113 | human_emb=self.embed_human_pose(torch.concat([human_init_state,human_x],dim=-1)) # nframes,bs,latent_dim 114 | human_seq_prev=self.sequence_pos_encoder(human_emb) # [nf,bs,latent_dim] 115 | 116 | mask=y['mask'].squeeze(1).squeeze(1) # [bs,nf] 117 | key_padding_mask=~mask # [bs,nf] 118 | 119 | 
obj_seq_prev=obj_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim] 120 | human_seq_prev=human_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim] 121 | emb=emb.squeeze(0) # [bs,latent_dim] 122 | for i in range(self.num_layers): 123 | obj_seq=self.obj_blocks[i](obj_seq_prev,human_seq_prev,emb, key_padding_mask=key_padding_mask) 124 | human_seq=self.human_blocks[i](human_seq_prev,obj_seq_prev,emb, key_padding_mask=key_padding_mask) 125 | obj_seq_prev=obj_seq 126 | human_seq_prev=human_seq 127 | obj_seq=obj_seq.permute((1,0,2)) # [nf,bs,latent_dim] 128 | human_seq=human_seq.permute((1,0,2)) 129 | 130 | 131 | obj_output=self.obj_output_process(obj_seq) # [nf,bs,2*9] 132 | human_output=self.human_output_process(human_seq) # [nf,bs,52*3+52*6+3] 133 | 134 | output=torch.concat([human_output,obj_output],dim=-1) # [nf,bs,52*3+52*6+3+2*9] 135 | 136 | return output.permute((1,0,2)) 137 | 138 | def encode_text(self,raw_text): 139 | # raw_text - list (batch_size length) of strings with input text prompts 140 | device = next(self.parameters()).device 141 | max_text_len = 40 142 | if max_text_len is not None: 143 | default_context_length = 77 144 | context_length = max_text_len + 2 # start_token + 20 + end_token 145 | assert context_length < default_context_length 146 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate 147 | # print('texts', texts.shape) 148 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) 149 | texts = torch.cat([texts, zero_pad], dim=1) 150 | # print('texts after pad', texts.shape, texts) 151 | else: 152 | texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate 153 | return self.clip_model.encode_text(texts).float() 154 | 155 | def mask_cond(self, cond, force_mask=False): 156 | bs, d = cond.shape 157 | if force_mask: 158 | return torch.zeros_like(cond) 159 | elif self.training and self.cond_mask_prob > 0.: 160 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond 161 | return cond * (1. 
- mask) 162 | else: 163 | return cond 164 | 165 | def load_and_freeze_clip(self, clip_version): 166 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu', 167 | jit=False) # Must set jit=False for training 168 | clip.model.convert_weights( 169 | clip_model) # Actually this line is unnecessary since clip by default already on float16 170 | 171 | # Freeze CLIP weights 172 | clip_model.eval() 173 | for p in clip_model.parameters(): 174 | p.requires_grad = False 175 | 176 | return clip_model 177 | 178 | class PositionalEncoding(nn.Module): 179 | def __init__(self, d_model, dropout=0.1, max_len=5000): 180 | super(PositionalEncoding, self).__init__() 181 | self.dropout = nn.Dropout(p=dropout) 182 | 183 | pe = torch.zeros(max_len, d_model) 184 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 185 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 186 | pe[:, 0::2] = torch.sin(position * div_term) 187 | pe[:, 1::2] = torch.cos(position * div_term) 188 | pe = pe.unsqueeze(0).transpose(0, 1) 189 | 190 | self.register_buffer('pe', pe) 191 | 192 | def forward(self, x): 193 | # not used in the final model 194 | x = x + self.pe[:x.shape[0], :] 195 | return self.dropout(x) 196 | 197 | 198 | class TimestepEmbedder(nn.Module): 199 | def __init__(self, latent_dim, sequence_pos_encoder): 200 | super().__init__() 201 | self.latent_dim = latent_dim 202 | self.sequence_pos_encoder = sequence_pos_encoder 203 | 204 | time_embed_dim = self.latent_dim 205 | self.time_embed = nn.Sequential( 206 | nn.Linear(self.latent_dim, time_embed_dim), 207 | nn.SiLU(), 208 | nn.Linear(time_embed_dim, time_embed_dim), 209 | ) 210 | 211 | def forward(self, timesteps): 212 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) 213 | -------------------------------------------------------------------------------- /src/model/net_3o.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from loguru import logger 6 | import clip 7 | from src.model.blocks import TransformerBlock 8 | 9 | class NET_3O(nn.Module): 10 | def __init__(self, 11 | n_feats=(52*3+52*6+3+3*9), 12 | clip_dim=512, 13 | latent_dim=512, 14 | ff_size=1024, 15 | num_layers=8, 16 | num_heads=4, 17 | dropout=0.1, 18 | ablation=None, 19 | activation="gelu",**kwargs): 20 | super().__init__() 21 | 22 | self.n_feats = n_feats 23 | self.clip_dim = clip_dim 24 | self.latent_dim = latent_dim 25 | self.ff_size = ff_size 26 | self.num_layers = num_layers 27 | self.num_heads = num_heads 28 | self.dropout = dropout 29 | self.ablation = ablation 30 | self.activation = activation 31 | 32 | self.cond_mask_prob=kwargs.get('cond_mask_prob',0.1) 33 | 34 | # clip and text embedder 35 | self.embed_text=nn.Linear(self.clip_dim,self.latent_dim) 36 | self.clip_version='ViT-B/32' 37 | self.clip_model=self.load_and_freeze_clip(self.clip_version) 38 | 39 | # object_geometry embedder 40 | self.embed_obj_bps=nn.Linear(1024*3,self.latent_dim) 41 | # object init state embeddr 42 | self.embed_obj_pose=nn.Linear(3*self.latent_dim+3*9+3*9,self.latent_dim) 43 | 44 | # human embedder 45 | self.embed_human_pose=nn.Linear(2*(52*3+52*6+3),self.latent_dim) 46 | 47 | # position encoding 48 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=self.dropout) 49 | 50 | # TODO:unshared transformer layers for human and objects,they fuse feature in the middle 
layers 51 | seqTransEncoderLayer=nn.TransformerEncoderLayer( 52 | d_model=self.latent_dim, 53 | nhead=self.num_heads, 54 | dim_feedforward=self.ff_size, 55 | dropout=self.dropout, 56 | activation=self.activation, 57 | ) 58 | # # object transformer layers 59 | # self.object_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers) 60 | # # human transformer encoder 61 | # self.human_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers) 62 | 63 | # # Mutal cross attention 64 | # self.communication_module=nn.ModuleList() 65 | # for i in range(8): 66 | # self.communication_module.append(MutalCrossAttentionBlock(self.latent_dim,self.num_heads,self.ff_size,self.dropout)) 67 | self.obj_blocks=nn.ModuleList() 68 | self.human_blocks=nn.ModuleList() 69 | for i in range(self.num_layers): 70 | self.obj_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation)) 71 | self.human_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation)) 72 | 73 | 74 | # embed the timestep 75 | self.embed_timestep=TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 76 | 77 | # object output process 78 | self.obj_output_process=nn.Linear(self.latent_dim,3*9) 79 | # human motion output process 80 | self.human_output_process=nn.Linear(self.latent_dim,52*3+52*6+3) 81 | 82 | def forward(self,x,timesteps,y=None): 83 | bs,nframes,n_feats=x.shape 84 | emb=self.embed_timestep(timesteps) # [1, bs, latent_dim] 85 | 86 | enc_text=self.encode_text(y['text']) 87 | emb+=self.mask_cond(enc_text) # 1,bs,latent_dim 88 | 89 | x=x.permute((1,0,2)) # nframes,bs,nfeats 90 | human_x,obj_x=torch.split(x,[52*3+52*6+3,3*9],dim=-1) # [nframes,bs,52*3+52*6+3],[nframes,bs,3*9] 91 | 92 | # encode object geometry 93 | obj1_bps,obj2_bps,obj3_bps=y['obj1_bps'].reshape(bs,-1),y['obj2_bps'].reshape(bs,-1),y['obj3_bps'].reshape(bs,-1) # [b,1024,3] 94 | obj1_bps_emb=self.embed_obj_bps(obj1_bps) # [b,latent_dim] 95 | obj2_bps_emb=self.embed_obj_bps(obj2_bps) # [b,latent_dim] 96 | obj3_bps_emb=self.embed_obj_bps(obj3_bps) # [b,latent_dim] 97 | obj_geo_emb=torch.concat([obj1_bps_emb,obj2_bps_emb,obj3_bps_emb],dim=-1).unsqueeze(0).repeat((nframes,1,1)) # [nf,b,3*latent_dim] 98 | 99 | # init_state,mask the other frames by padding zeros 100 | init_state=y['init_state'].unsqueeze(0) # [1,b,52*3+52*6+3+3*9] 101 | padded_zeros=torch.zeros((nframes-1,bs,52*3+52*6+3+3*9),device=init_state.device) 102 | init_state=torch.concat([init_state,padded_zeros],dim=0) # [nf,b,52*3+52*6+3+3*9] 103 | 104 | # seperate the object and human init state 105 | human_init_state=init_state[:,:,:52*3+52*6+3] # [nf,b,52*3+52*6+3] 106 | obj_init_state=init_state[:,:,52*3+52*6+3:] # [nf,b,3*9] 107 | 108 | # Object branch 109 | obj_emb=self.embed_obj_pose(torch.concat([obj_geo_emb,obj_init_state,obj_x],dim=-1)) # nframes,bs,latent_dim 110 | obj_seq_prev=self.sequence_pos_encoder(obj_emb) # [nf,bs,latent_dim] 111 | 112 | 113 | # Human branch 114 | human_emb=self.embed_human_pose(torch.concat([human_init_state,human_x],dim=-1)) # nframes,bs,latent_dim 115 | human_seq_prev=self.sequence_pos_encoder(human_emb) # [nf,bs,latent_dim] 116 | 117 | mask=y['mask'].squeeze(1).squeeze(1) # [bs,nf] 118 | key_padding_mask=~mask # [bs,nf] 119 | 120 | obj_seq_prev=obj_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim] 121 | human_seq_prev=human_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim] 122 | emb=emb.squeeze(0) # [bs,latent_dim] 123 | for i in 
range(self.num_layers): 124 | obj_seq=self.obj_blocks[i](obj_seq_prev,human_seq_prev,emb, key_padding_mask=key_padding_mask) 125 | human_seq=self.human_blocks[i](human_seq_prev,obj_seq_prev,emb, key_padding_mask=key_padding_mask) 126 | obj_seq_prev=obj_seq 127 | human_seq_prev=human_seq 128 | obj_seq=obj_seq.permute((1,0,2)) # [nf,bs,latent_dim] 129 | human_seq=human_seq.permute((1,0,2)) 130 | 131 | 132 | obj_output=self.obj_output_process(obj_seq) # [nf,bs,3*9] 133 | human_output=self.human_output_process(human_seq) # [nf,bs,52*3+52*6+3] 134 | 135 | output=torch.concat([human_output,obj_output],dim=-1) # [nf,bs,52*3+52*6+3+3*9] 136 | 137 | return output.permute((1,0,2)) 138 | 139 | def encode_text(self,raw_text): 140 | # raw_text - list (batch_size length) of strings with input text prompts 141 | device = next(self.parameters()).device 142 | max_text_len = 40 143 | if max_text_len is not None: 144 | default_context_length = 77 145 | context_length = max_text_len + 2 # start_token + 20 + end_token 146 | assert context_length < default_context_length 147 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate 148 | # print('texts', texts.shape) 149 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) 150 | texts = torch.cat([texts, zero_pad], dim=1) 151 | # print('texts after pad', texts.shape, texts) 152 | else: 153 | texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate 154 | return self.clip_model.encode_text(texts).float() 155 | 156 | def mask_cond(self, cond, force_mask=False): 157 | bs, d = cond.shape 158 | if force_mask: 159 | return torch.zeros_like(cond) 160 | elif self.training and self.cond_mask_prob > 0.: 161 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond 162 | return cond * (1. 
- mask) 163 | else: 164 | return cond 165 | 166 | def load_and_freeze_clip(self, clip_version): 167 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu', 168 | jit=False) # Must set jit=False for training 169 | clip.model.convert_weights( 170 | clip_model) # Actually this line is unnecessary since clip by default already on float16 171 | 172 | # Freeze CLIP weights 173 | clip_model.eval() 174 | for p in clip_model.parameters(): 175 | p.requires_grad = False 176 | 177 | return clip_model 178 | 179 | class PositionalEncoding(nn.Module): 180 | def __init__(self, d_model, dropout=0.1, max_len=5000): 181 | super(PositionalEncoding, self).__init__() 182 | self.dropout = nn.Dropout(p=dropout) 183 | 184 | pe = torch.zeros(max_len, d_model) 185 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 186 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 187 | pe[:, 0::2] = torch.sin(position * div_term) 188 | pe[:, 1::2] = torch.cos(position * div_term) 189 | pe = pe.unsqueeze(0).transpose(0, 1) 190 | 191 | self.register_buffer('pe', pe) 192 | 193 | def forward(self, x): 194 | # not used in the final model 195 | x = x + self.pe[:x.shape[0], :] 196 | return self.dropout(x) 197 | 198 | 199 | class TimestepEmbedder(nn.Module): 200 | def __init__(self, latent_dim, sequence_pos_encoder): 201 | super().__init__() 202 | self.latent_dim = latent_dim 203 | self.sequence_pos_encoder = sequence_pos_encoder 204 | 205 | time_embed_dim = self.latent_dim 206 | self.time_embed = nn.Sequential( 207 | nn.Linear(self.latent_dim, time_embed_dim), 208 | nn.SiLU(), 209 | nn.Linear(time_embed_dim, time_embed_dim), 210 | ) 211 | 212 | def forward(self, timesteps): 213 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) 214 | -------------------------------------------------------------------------------- /src/dataset/eval_gen_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset,DataLoader 3 | import numpy as np 4 | from tqdm import tqdm 5 | from bps_torch.bps import bps_torch 6 | import os.path as osp 7 | import os 8 | from src.utils.rotation_conversion import sixd2rot_torch 9 | from src.dataset.tensors import gt_collate_fn 10 | from src.utils import dist_utils 11 | from src.dataset.tensors import lengths_to_mask 12 | 13 | def get_eval_gen_loader(args,model,diffusion,gen_loader, 14 | max_motion_length,batch_size, 15 | mm_num_samples,mm_num_repeats,num_samples_limit,scale): 16 | dataset=Evaluation_generator_Dataset(args,model,diffusion,gen_loader, 17 | max_motion_length,mm_num_samples,mm_num_repeats,num_samples_limit,scale) 18 | 19 | mm_dataset=MM_generator_Dataset('test',dataset,gen_loader.dataset.w_vectorizer) 20 | 21 | motion_loader=DataLoader(dataset,batch_size=batch_size,num_workers=4,drop_last=True,collate_fn=gt_collate_fn) 22 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) 23 | 24 | print('Generated Dataset Loading Completed!!!') 25 | return motion_loader,mm_motion_loader 26 | 27 | class MM_generator_Dataset(Dataset): 28 | def __init__(self, opt, motion_dataset, w_vectorizer): 29 | self.opt = opt 30 | self.dataset = motion_dataset.mm_generated_motion 31 | self.w_vectorizer = w_vectorizer 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __getitem__(self, item): 37 | data = self.dataset[item] 38 | mm_motions = data['mm_motions'] 39 | m_lens = [] 40 | motions = [] 41 | for 
mm_motion in mm_motions: 42 | m_lens.append(mm_motion['length']) 43 | motion = mm_motion['motion'] 44 | # We don't need the following logic because our sample func generates the full tensor anyway: 45 | # if len(motion) < self.opt.max_motion_length: 46 | # motion = np.concatenate([motion, 47 | # np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1])) 48 | # ], axis=0) 49 | motion = motion[None, :] 50 | motions.append(motion) 51 | m_lens = np.array(m_lens, dtype=np.int) 52 | motions = np.concatenate(motions, axis=0) 53 | sort_indx = np.argsort(m_lens)[::-1].copy() 54 | # print(m_lens) 55 | # print(sort_indx) 56 | # print(m_lens[sort_indx]) 57 | m_lens = m_lens[sort_indx] 58 | motions = motions[sort_indx] 59 | return motions, m_lens 60 | 61 | class Evaluation_generator_Dataset(Dataset): 62 | def __init__(self,args,model,diffusion,gen_loader, 63 | max_motion_length,mm_num_samples,mm_num_repeats,num_samples_limit,scale=1.0): 64 | self.dataloader=gen_loader 65 | assert mm_num_samples 0: 82 | mm_idxs = np.random.choice(real_num_batches, mm_num_samples // self.dataloader.batch_size +1, replace=False) 83 | mm_idxs = np.sort(mm_idxs) 84 | else: 85 | mm_idxs = [] 86 | print('mm_idxs', mm_idxs) 87 | 88 | model.eval() 89 | 90 | with torch.no_grad(): 91 | for i,eval_data_batch in tqdm(enumerate(self.dataloader),total=len(self.dataloader)): 92 | 93 | # if i==1: 94 | # break 95 | if num_samples_limit is not None and len(generated_motion) >= num_samples_limit: 96 | break 97 | 98 | if args.obj=='2o': 99 | word_embeddings,pos_one_hots,caption,sent_len,motion,m_length,tokens,\ 100 | obj1_bps,obj2_bps,init_state,obj1_name,obj2_name,betas=eval_data_batch 101 | tokens=[t.split('_') for t in tokens] 102 | 103 | model_kwargs={ 104 | 'y':{ 105 | 'length':m_length.to(dist_utils.dev()), # [bs] 106 | 'text':caption, 107 | 'obj1_bps':obj1_bps.to(dist_utils.dev()), 108 | 'obj2_bps':obj2_bps.to(dist_utils.dev()), 109 | 'init_state':init_state.to(dist_utils.dev()), 110 | 'mask':lengths_to_mask(m_length,motion.shape[1]).unsqueeze(1).unsqueeze(1).to(dist_utils.dev()) # [bs,1,1,nf] 111 | } 112 | } 113 | elif args.obj=='3o': 114 | word_embeddings,pos_one_hots,caption,sent_len,motion,m_length,tokens,\ 115 | obj1_bps,obj2_bps,obj3_bps,init_state,obj1_name,obj2_name,obj3_name,betas=eval_data_batch 116 | tokens=[t.split('_') for t in tokens] 117 | 118 | model_kwargs={ 119 | 'y':{ 120 | 'length':m_length.to(dist_utils.dev()), 121 | 'text':caption, 122 | 'obj1_bps':obj1_bps.to(dist_utils.dev()), 123 | 'obj2_bps':obj2_bps.to(dist_utils.dev()), 124 | 'obj3_bps':obj3_bps.to(dist_utils.dev()), 125 | 'init_state':init_state.to(dist_utils.dev()), 126 | 'mask':lengths_to_mask(m_length,motion.shape[1]).unsqueeze(1).unsqueeze(1).to(dist_utils.dev()) # [bs,1,1,nf] 127 | } 128 | } 129 | 130 | # add CFG scale to batch 131 | if scale != 1.: 132 | model_kwargs['y']['scale'] = torch.ones(motion.shape[0], 133 | device=dist_utils.dev()) * scale 134 | 135 | mm_num_now = len(mm_generated_motion) // self.dataloader.batch_size 136 | is_mm = i in mm_idxs 137 | repeat_times = mm_num_repeats if is_mm else 1 138 | mm_motions = [] 139 | for t in range(repeat_times): 140 | 141 | model_out_sample=model_sample_fn( 142 | model, 143 | motion.shape, 144 | clip_denoised=clip_denoised, 145 | model_kwargs=model_kwargs, 146 | skip_timesteps=0, # 0 is the default value - i.e. 
don't skip any step 147 | init_image=None, 148 | progress=False, 149 | dump_steps=None, 150 | noise=None, 151 | const_noise=False, 152 | ) # bs,nf,315 153 | 154 | # export the model result 155 | network=args.model_path.split('/')[-2] 156 | export_path=osp.join('./export_results',network,'{}.npz'.format(i)) 157 | os.makedirs(osp.dirname(export_path),exist_ok=True) 158 | if args.obj=='2o': 159 | output_dict={str(bs_i): 160 | { 161 | 'out':model_out_sample[bs_i].squeeze().cpu().numpy(), 162 | 'length':m_length[bs_i].cpu().numpy(), 163 | 'caption':caption[bs_i], 164 | 'obj1_name':obj1_name[bs_i], 165 | 'obj2_name':obj2_name[bs_i], 166 | 'betas':betas[bs_i].cpu().numpy(), 167 | } for bs_i in range(self.dataloader.batch_size) 168 | } 169 | elif args.obj=='3o': 170 | output_dict={str(bs_i): 171 | { 172 | 'out':model_out_sample[bs_i].squeeze().cpu().numpy(), 173 | 'length':m_length[bs_i].cpu().numpy(), 174 | 'caption':caption[bs_i], 175 | 'obj1_name':obj1_name[bs_i], 176 | 'obj2_name':obj2_name[bs_i], 177 | 'obj3_name':obj3_name[bs_i], 178 | 'betas':betas[bs_i].cpu().numpy(), 179 | } for bs_i in range(self.dataloader.batch_size) 180 | } 181 | np.savez(export_path,**output_dict) 182 | 183 | if t==0: 184 | sub_dicts=[ 185 | { 186 | 'motion':model_out_sample[bs_i].squeeze().cpu().numpy(), 187 | 'length':m_length[bs_i].cpu().numpy(), 188 | 'caption':caption[bs_i], 189 | 'tokens':tokens[bs_i], 190 | 'cap_len':sent_len[bs_i].cpu().numpy(), 191 | } for bs_i in range(self.dataloader.batch_size) 192 | ] 193 | generated_motion+=sub_dicts 194 | if is_mm: 195 | mm_motions+=[ 196 | { 197 | 'motion':model_out_sample[bs_i].squeeze().cpu().numpy(), 198 | 'length':m_length[bs_i].cpu().numpy(), 199 | } for bs_i in range(self.dataloader.batch_size) 200 | ] 201 | if is_mm: 202 | mm_generated_motion+=[ 203 | { 204 | 'caption':model_kwargs['y']['text'][bs_i], 205 | 'tokens':tokens[bs_i], 206 | 'cap_len':len(tokens[bs_i]), 207 | 'mm_motions':mm_motions[bs_i::self.dataloader.batch_size], # collect all 10 repeats from the (32*10) generated motions 208 | } for bs_i in range(self.dataloader.batch_size) 209 | ] 210 | self.generated_motion=generated_motion 211 | self.mm_generated_motion=mm_generated_motion 212 | self.w_vectorizer=self.dataloader.dataset.w_vectorizer 213 | 214 | def __len__(self): 215 | return len(self.generated_motion) 216 | 217 | def __getitem__(self, item): 218 | data = self.generated_motion[item] 219 | motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens'] 220 | sent_len = data['cap_len'] 221 | 222 | pos_one_hots = [] 223 | word_embeddings = [] 224 | for token in tokens: 225 | word_emb, pos_oh = self.w_vectorizer[token] 226 | pos_one_hots.append(pos_oh[None, :]) 227 | word_embeddings.append(word_emb[None, :]) 228 | pos_one_hots = np.concatenate(pos_one_hots, axis=0) 229 | word_embeddings = np.concatenate(word_embeddings, axis=0) 230 | 231 | return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens) -------------------------------------------------------------------------------- /src/train/training_loop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.diffusion.fp16_util import MixedPrecisionTrainer 3 | from src.diffusion.resample import LossAwareSampler 4 | from src.dataset.eval_dataset import Evaluation_Dataset 5 | from src.dataset.tensors import gt_collate_fn 6 | from src.eval.eval_himo_2o import evaluation 7 | from src.feature_extractor.eval_wrapper import 
EvaluationWrapper 8 | from src.dataset.eval_gen_dataset import get_eval_gen_loader 9 | from src.diffusion import logger 10 | import functools 11 | from loguru import logger as log 12 | 13 | from src.utils import dist_utils 14 | 15 | from torch.optim import AdamW 16 | from torch.utils.data import DataLoader 17 | import blobfile as bf 18 | from src.diffusion.resample import create_named_schedule_sampler 19 | from tqdm import tqdm 20 | import numpy as np 21 | import os 22 | import time 23 | 24 | 25 | class TrainLoop: 26 | def __init__(self,args,train_platform,model,diffusion,data_loader): 27 | self.args=args 28 | self.train_platform=train_platform 29 | self.model=model 30 | self.diffusion=diffusion 31 | self.data=data_loader 32 | 33 | self.batch_size=args.batch_size 34 | self.microbatch = args.batch_size # deprecating this option 35 | self.lr=args.lr 36 | self.log_interval=args.log_interval 37 | self.save_interval=args.save_interval 38 | self.resume_checkpoint=args.resume_checkpoint 39 | self.use_fp16 = False # deprecating this option 40 | self.fp16_scale_growth = 1e-3 # deprecating this option 41 | self.weight_decay=args.weight_decay 42 | self.lr_anneal_steps=args.lr_anneal_steps 43 | 44 | self.step = 0 45 | self.resume_step = 0 46 | self.global_batch = self.batch_size # * dist.get_world_size() 47 | # self.num_steps = args.num_steps 48 | # self.num_epochs = self.num_steps // len(self.data) + 1 49 | self.num_epochs=args.num_epochs 50 | 51 | self.sync_cuda = torch.cuda.is_available() 52 | 53 | self._load_and_sync_parameters() 54 | self.mp_trainer = MixedPrecisionTrainer( 55 | model=self.model, 56 | use_fp16=self.use_fp16, 57 | fp16_scale_growth=self.fp16_scale_growth, 58 | ) 59 | 60 | self.save_path=args.save_path 61 | 62 | self.opt = AdamW( 63 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 64 | ) 65 | if self.resume_step: 66 | self._load_optimizer_state() 67 | # Model was resumed, either due to a restart or a checkpoint 68 | # being specified at the command line. 69 | 70 | self.device = torch.device("cpu") 71 | if torch.cuda.is_available() and dist_utils.dev() != 'cpu': 72 | self.device = torch.device(dist_utils.dev()) 73 | 74 | self.schedule_sampler_type = 'uniform' 75 | self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion) 76 | self.eval_wrapper,self.eval_data,self.eval_gt_data=None,None,None 77 | if args.eval_during_training: 78 | mm_num_samples = 0 # mm is super slow hence we won't run it during training 79 | mm_num_repeats = 0 # mm is super slow hence we won't run it during training 80 | gen_dataset=Evaluation_Dataset(args,split='val',mode='eval') 81 | gen_loader=DataLoader(gen_dataset,batch_size=args.eval_batch_size, 82 | shuffle=True,num_workers=8,drop_last=True,collate_fn=gt_collate_fn) 83 | gt_dataset=Evaluation_Dataset(args,split='val',mode='gt') 84 | self.eval_gt_data=DataLoader(gt_dataset,batch_size=args.eval_batch_size, 85 | shuffle=True,num_workers=8,drop_last=True,collate_fn=gt_collate_fn) 86 | 87 | self.eval_wrapper=EvaluationWrapper(args) 88 | self.eval_data={ 89 | 'test':lambda :get_eval_gen_loader( 90 | args,model,diffusion,gen_loader,gen_loader.dataset.max_motion_length, 91 | args.eval_batch_size,mm_num_samples,mm_num_repeats, 92 | 1000,scale=1. 
93 | 94 | ) 95 | } 96 | self.use_ddp=False 97 | self.ddp_model=self.model 98 | 99 | def _load_and_sync_parameters(self): 100 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 101 | 102 | if resume_checkpoint: 103 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 104 | log.info(f"loading model from checkpoint: {resume_checkpoint}...") 105 | self.model.load_state_dict( 106 | dist_utils.load_state_dict( 107 | resume_checkpoint, map_location=dist_utils.dev() 108 | ) 109 | ) 110 | 111 | def _load_optimizer_state(self): 112 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 113 | opt_checkpoint = bf.join( 114 | bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt" 115 | ) 116 | if bf.exists(opt_checkpoint): 117 | log.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 118 | state_dict = dist_utils.load_state_dict( 119 | opt_checkpoint, map_location=dist_utils.dev() 120 | ) 121 | self.opt.load_state_dict(state_dict) 122 | 123 | def run_loop(self): 124 | for epoch in range(self.num_epochs): 125 | log.info(f'Starting epoch {epoch}/{self.num_epochs}') 126 | for inp,cond in tqdm(self.data): 127 | if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): 128 | break 129 | 130 | inp=inp.to(self.device) 131 | cond['y'] ={k:v.to(self.device) if torch.is_tensor(v) else v for k,v in cond['y'].items()} 132 | 133 | self.run_step(inp,cond) 134 | if self.step % self.log_interval == 0: 135 | for k,v in logger.get_current().name2val.items(): 136 | if k == 'loss': 137 | print('step[{}]: loss[{:0.5f}]'.format(self.step+self.resume_step, v)) 138 | 139 | if k in ['step', 'samples'] or '_q' in k: 140 | continue 141 | else: 142 | self.train_platform.report_scalar(name=k, value=v, iteration=self.step, group_name='Loss') 143 | if self.step % self.save_interval == 0: 144 | self.save() 145 | self.model.eval() 146 | self.evaluate() 147 | self.model.train() 148 | 149 | # Run for a finite amount of time in integration tests. 150 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 151 | return 152 | self.step += 1 153 | if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): 154 | break 155 | # Save the last checkpoint if it wasn't already saved. 156 | if (self.step - 1) % self.save_interval != 0: 157 | self.save() 158 | self.evaluate() 159 | 160 | def run_step(self,inp,cond): 161 | self.forward_backward(inp,cond) 162 | self.mp_trainer.optimize(self.opt) 163 | self._anneal_lr() 164 | self.log_step() 165 | 166 | def forward_backward(self,batch,cond): 167 | self.mp_trainer.zero_grad() 168 | for i in range(0, batch.shape[0], self.microbatch): 169 | # Eliminates the microbatch feature 170 | assert i == 0 171 | assert self.microbatch == self.batch_size 172 | micro = batch 173 | micro_cond = cond 174 | last_batch = (i + self.microbatch) >= batch.shape[0] 175 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_utils.dev()) 176 | 177 | compute_losses = functools.partial( 178 | self.diffusion.training_losses, 179 | self.ddp_model, 180 | micro, # [bs, ...] 
181 | t, # [bs](int) sampled timesteps 182 | model_kwargs=micro_cond, 183 | dataset=self.args.dataset 184 | ) 185 | 186 | if last_batch or not self.use_ddp: 187 | losses = compute_losses() 188 | else: 189 | with self.ddp_model.no_sync(): 190 | losses = compute_losses() 191 | 192 | if isinstance(self.schedule_sampler, LossAwareSampler): 193 | self.schedule_sampler.update_with_local_losses( 194 | t, losses["loss"].detach() 195 | ) 196 | 197 | loss = (losses["loss"] * weights).mean() 198 | log_loss_dict( 199 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 200 | ) 201 | self.mp_trainer.backward(loss) 202 | 203 | def _anneal_lr(self): 204 | if not self.lr_anneal_steps: 205 | return 206 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 207 | lr = self.lr * (1 - frac_done) 208 | for param_group in self.opt.param_groups: 209 | param_group["lr"] = lr 210 | 211 | def log_step(self): 212 | logger.logkv("step", self.step + self.resume_step) 213 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 214 | 215 | def ckpt_file_name(self): 216 | return f"model{(self.step+self.resume_step):09d}.pt" 217 | 218 | def save(self): 219 | def save_checkpoint(params): 220 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 221 | 222 | # Do not save CLIP weights 223 | clip_weights = [e for e in state_dict.keys() if e.startswith('clip_model.')] 224 | for e in clip_weights: 225 | del state_dict[e] 226 | 227 | log.info(f"saving model...") 228 | filename = self.ckpt_file_name() 229 | with bf.BlobFile(bf.join(self.save_path, filename), "wb") as f: 230 | torch.save(state_dict, f) 231 | 232 | save_checkpoint(self.mp_trainer.master_params) 233 | 234 | with bf.BlobFile( 235 | bf.join(self.save_path, f"opt{(self.step+self.resume_step):09d}.pt"), 236 | "wb", 237 | ) as f: 238 | torch.save(self.opt.state_dict(), f) 239 | 240 | def evaluate(self): 241 | if not self.args.eval_during_training: 242 | return 243 | start_eval = time.time() 244 | if self.eval_wrapper is not None: 245 | log.info('Running evaluation loop: [Should take about 90 min]') 246 | log_file = os.path.join(self.save_path, f'eval_model_{(self.step + self.resume_step):09d}.log') 247 | diversity_times = 100 if self.args.obj=='2o' else 40# 200 248 | eval_rep_time=1 # 3 249 | mm_num_times = 0 # mm is super slow hence we won't run it during training 250 | eval_dict = evaluation( 251 | self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file, 252 | replication_times=eval_rep_time, diversity_times=diversity_times, mm_num_times=mm_num_times, run_mm=False) 253 | log.info(eval_dict) 254 | for k, v in eval_dict.items(): 255 | if k.startswith('R_precision'): 256 | for i in range(len(v)): 257 | self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i], 258 | iteration=self.step + self.resume_step, 259 | group_name='Eval') 260 | else: 261 | self.train_platform.report_scalar(name=k, value=v, iteration=self.step + self.resume_step, 262 | group_name='Eval') 263 | end_eval=time.time() 264 | log.info(f'Evaluation time: {round(end_eval-start_eval)/60} seconds') 265 | 266 | 267 | def find_resume_checkpoint(): 268 | # On your infrastructure, you may want to override this to automatically 269 | # discover the latest checkpoint on your blob storage, etc. 270 | return None 271 | 272 | def parse_resume_step_from_filename(filename): 273 | """ 274 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 275 | checkpoint's number of steps. 
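Returns 0 if the filename contains no parsable step count.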
276 | """ 277 | split = filename.split("model") 278 | if len(split) < 2: 279 | return 0 280 | split1 = split[-1].split(".")[0] 281 | try: 282 | return int(split1) 283 | except ValueError: 284 | return 0 285 | 286 | def log_loss_dict(diffusion, ts, losses): 287 | for key, values in losses.items(): 288 | logger.logkv_mean(key, values.mean().item()) 289 | # Log the quantiles (four quartiles, in particular). 290 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 291 | quartile = int(4 * sub_t / diffusion.num_timesteps) 292 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) -------------------------------------------------------------------------------- /src/eval/eval_himo_3o.py: -------------------------------------------------------------------------------- 1 | from src.utils.parser_utils import eval_himo_args 2 | from src.utils.misc import fixseed 3 | from src.utils import dist_utils 4 | # from src.diffusion import logger 5 | 6 | from collections import OrderedDict 7 | from src.utils.model_utils import create_model_and_diffusion,load_model_wo_clip 8 | from torch.utils.data import DataLoader 9 | from src.dataset.eval_dataset import Evaluation_Dataset 10 | from src.dataset.tensors import gt_collate_fn 11 | from src.dataset.eval_gen_dataset import get_eval_gen_loader 12 | from src.feature_extractor.eval_wrapper import EvaluationWrapper 13 | from src.eval.metrics import * 14 | from src.model.cfg_sampler import ClassifierFreeSampleModel 15 | import torch 16 | import os 17 | import os.path as osp 18 | import numpy as np 19 | from datetime import datetime 20 | from loguru import logger 21 | 22 | def evaluate_matching_score(eval_wrapper, motion_loaders, file): 23 | match_score_dict = OrderedDict({}) 24 | R_precision_dict = OrderedDict({}) 25 | activation_dict = OrderedDict({}) 26 | logger.info('========== Evaluating Matching Score ==========') 27 | for motion_loader_name, motion_loader in motion_loaders.items(): 28 | all_motion_embeddings = [] 29 | score_list = [] 30 | all_size = 0 31 | matching_score_sum = 0 32 | top_k_count = 0 33 | # logger.info(motion_loader_name) 34 | with torch.no_grad(): 35 | for idx, batch in enumerate(motion_loader): 36 | word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch 37 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings( 38 | word_embs=word_embeddings, 39 | pos_ohot=pos_one_hots, 40 | cap_lens=sent_lens, 41 | motions=motions, 42 | m_lens=m_lens 43 | ) 44 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), 45 | motion_embeddings.cpu().numpy()) 46 | matching_score_sum += dist_mat.trace() 47 | 48 | argsmax = np.argsort(dist_mat, axis=1) 49 | top_k_mat = calculate_top_k(argsmax, top_k=3) 50 | top_k_count += top_k_mat.sum(axis=0) 51 | 52 | all_size += text_embeddings.shape[0] 53 | 54 | all_motion_embeddings.append(motion_embeddings.cpu().numpy()) 55 | 56 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) 57 | matching_score = matching_score_sum / all_size 58 | R_precision = top_k_count / all_size 59 | match_score_dict[motion_loader_name] = matching_score 60 | R_precision_dict[motion_loader_name] = R_precision 61 | activation_dict[motion_loader_name] = all_motion_embeddings 62 | 63 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}') 64 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True) 65 | 66 | line = f'---> [{motion_loader_name}] R_precision: ' 67 | for i in range(len(R_precision)): 
68 | line += '(top %d): %.4f ' % (i+1, R_precision[i]) 69 | logger.info(line) 70 | logger.info(line, file=file, flush=True) 71 | 72 | return match_score_dict, R_precision_dict, activation_dict 73 | 74 | def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file): 75 | eval_dict = OrderedDict({}) 76 | gt_motion_embeddings = [] 77 | logger.info('========== Evaluating FID ==========') 78 | with torch.no_grad(): 79 | for idx, batch in enumerate(groundtruth_loader): 80 | _, _, _, sent_lens, motions, m_lens, _ = batch 81 | motion_embeddings = eval_wrapper.get_motion_embeddings( 82 | motions=motions, 83 | m_lens=m_lens 84 | ) 85 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) 86 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) 87 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) 88 | 89 | # logger.info(gt_mu) 90 | for model_name, motion_embeddings in activation_dict.items(): 91 | mu, cov = calculate_activation_statistics(motion_embeddings) 92 | # logger.info(mu) 93 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 94 | logger.info(f'---> [{model_name}] FID: {fid:.4f}') 95 | logger.info(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) 96 | eval_dict[model_name] = fid 97 | return eval_dict 98 | 99 | 100 | def evaluate_diversity(activation_dict, file, diversity_times): 101 | eval_dict = OrderedDict({}) 102 | logger.info('========== Evaluating Diversity ==========') 103 | for model_name, motion_embeddings in activation_dict.items(): 104 | diversity = calculate_diversity(motion_embeddings, diversity_times) 105 | eval_dict[model_name] = diversity 106 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}') 107 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) 108 | return eval_dict 109 | 110 | 111 | def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times): 112 | eval_dict = OrderedDict({}) 113 | logger.info('========== Evaluating MultiModality ==========') 114 | for model_name, mm_motion_loader in mm_motion_loaders.items(): 115 | mm_motion_embeddings = [] 116 | with torch.no_grad(): 117 | for idx, batch in enumerate(mm_motion_loader): 118 | # (1, mm_replications, dim_pos) 119 | motions, m_lens = batch 120 | motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0]) 121 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) 122 | if len(mm_motion_embeddings) == 0: 123 | multimodality = 0 124 | else: 125 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() 126 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) 127 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}') 128 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) 129 | eval_dict[model_name] = multimodality 130 | return eval_dict 131 | 132 | 133 | def get_metric_statistics(values, replication_times): 134 | mean = np.mean(values, axis=0) 135 | std = np.std(values, axis=0) 136 | conf_interval = 1.96 * std / np.sqrt(replication_times) 137 | return mean, conf_interval 138 | 139 | def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False): 140 | with open(log_file, 'w') as f: 141 | all_metrics = OrderedDict({'Matching Score': OrderedDict({}), 142 | 'R_precision': OrderedDict({}), 143 | 'FID': OrderedDict({}), 144 | 'Diversity': OrderedDict({}), 145 | 'MultiModality': 
OrderedDict({})}) 146 | for replication in range(replication_times): 147 | motion_loaders = {} 148 | mm_motion_loaders = {} 149 | motion_loaders['ground truth'] = gt_loader 150 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): 151 | motion_loader, mm_motion_loader = motion_loader_getter() 152 | motion_loaders[motion_loader_name] = motion_loader 153 | mm_motion_loaders[motion_loader_name] = mm_motion_loader 154 | 155 | logger.info(f'==================== Replication {replication} ====================') 156 | logger.info(f'==================== Replication {replication} ====================', file=f, flush=True) 157 | logger.info(f'Time: {datetime.now()}') 158 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 159 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f) 160 | 161 | logger.info(f'Time: {datetime.now()}') 162 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 163 | fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f) 164 | 165 | logger.info(f'Time: {datetime.now()}') 166 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 167 | div_score_dict = evaluate_diversity(acti_dict, f, diversity_times) 168 | 169 | if run_mm: 170 | logger.info(f'Time: {datetime.now()}') 171 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 172 | mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times) 173 | 174 | logger.info(f'!!! DONE !!!') 175 | logger.info(f'!!! DONE !!!', file=f, flush=True) 176 | 177 | for key, item in mat_score_dict.items(): 178 | if key not in all_metrics['Matching Score']: 179 | all_metrics['Matching Score'][key] = [item] 180 | else: 181 | all_metrics['Matching Score'][key] += [item] 182 | 183 | for key, item in R_precision_dict.items(): 184 | if key not in all_metrics['R_precision']: 185 | all_metrics['R_precision'][key] = [item] 186 | else: 187 | all_metrics['R_precision'][key] += [item] 188 | 189 | for key, item in fid_score_dict.items(): 190 | if key not in all_metrics['FID']: 191 | all_metrics['FID'][key] = [item] 192 | else: 193 | all_metrics['FID'][key] += [item] 194 | 195 | for key, item in div_score_dict.items(): 196 | if key not in all_metrics['Diversity']: 197 | all_metrics['Diversity'][key] = [item] 198 | else: 199 | all_metrics['Diversity'][key] += [item] 200 | if run_mm: 201 | for key, item in mm_score_dict.items(): 202 | if key not in all_metrics['MultiModality']: 203 | all_metrics['MultiModality'][key] = [item] 204 | else: 205 | all_metrics['MultiModality'][key] += [item] 206 | 207 | 208 | # logger.info(all_metrics['Diversity']) 209 | mean_dict = {} 210 | for metric_name, metric_dict in all_metrics.items(): 211 | logger.info('========== %s Summary ==========' % metric_name) 212 | logger.info('========== %s Summary ==========' % metric_name, file=f, flush=True) 213 | for model_name, values in metric_dict.items(): 214 | # logger.info(metric_name, model_name) 215 | mean, conf_interval = get_metric_statistics(np.array(values), replication_times) 216 | mean_dict[metric_name + '_' + model_name] = mean 217 | # logger.info(mean, mean.dtype) 218 | if isinstance(mean, np.float64) or isinstance(mean, np.float32): 219 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') 220 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) 221 | elif isinstance(mean, np.ndarray): 222 | line = f'---> [{model_name}]' 223 | for i in 
range(len(mean)): 224 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) 225 | logger.info(line) 226 | logger.info(line, file=f, flush=True) 227 | return mean_dict 228 | 229 | if __name__ == '__main__': 230 | val_args,model_args=eval_himo_args() 231 | fixseed(val_args.seed) 232 | val_args.batch_size = 32 # This must be 32! Don't change it! otherwise it will cause a bug in R precision calc! 233 | 234 | model_name = os.path.basename(os.path.dirname(val_args.model_path)) 235 | model_niter = os.path.basename(val_args.model_path).replace('model', '').replace('.pt', '') 236 | 237 | log_file=osp.join('val_results','eval_{}_{}'.format(model_name,model_niter)) 238 | log_file+=f'_{val_args.eval_mode}' 239 | log_file+='.log' 240 | logger.add(log_file) 241 | logger.info(f'Will save to log file [{log_file}]') 242 | logger.info(f'Eval mode [{val_args.eval_mode}]') 243 | 244 | if val_args.eval_mode == 'debug': 245 | num_samples_limit = 1000 # None means no limit (eval over all dataset) 246 | run_mm = False 247 | mm_num_samples = 0 248 | mm_num_repeats = 0 249 | mm_num_times = 0 250 | diversity_times = 100 251 | replication_times = 1 252 | elif val_args.eval_mode == 'wo_mm': 253 | num_samples_limit = 1000 254 | run_mm = False 255 | mm_num_samples = 0 256 | mm_num_repeats = 0 257 | mm_num_times = 0 258 | diversity_times = 100 259 | replication_times = 20 260 | elif val_args.eval_mode == 'mm_short': 261 | num_samples_limit = 1000 262 | run_mm = True 263 | mm_num_samples = 100 264 | mm_num_repeats = 20 # 30 265 | mm_num_times = 10 266 | diversity_times = 100 267 | replication_times = 5 268 | else: 269 | raise ValueError() 270 | 271 | dist_utils.setup_dist(val_args.device) 272 | logger.configure() 273 | 274 | logger.info("creating data loader...") 275 | split = 'test' 276 | 277 | gt_dataset=Evaluation_Dataset(val_args,split=split,mode='gt') 278 | gen_dataset=Evaluation_Dataset(val_args,split=split,mode='eval') 279 | gt_loader=DataLoader(gt_dataset,batch_size=val_args.batch_size,shuffle=True,num_workers=8, 280 | drop_last=True,collate_fn=gt_collate_fn) 281 | gen_loader=DataLoader(gen_dataset,batch_size=val_args.batch_size,shuffle=True,num_workers=8, 282 | drop_last=True,collate_fn=gt_collate_fn) 283 | 284 | logger.info("creating model and diffusion...") 285 | model,diffusion=create_model_and_diffusion(model_args) 286 | logger.info(f"Loading model from [{val_args.model_path}]...") 287 | state_dict=torch.load(val_args.model_path,map_location='cpu') 288 | load_model_wo_clip(model,state_dict) 289 | 290 | if val_args.guidance_param!=1: 291 | model=ClassifierFreeSampleModel(model) 292 | 293 | model.to(dist_utils.dev()) 294 | model.eval() 295 | 296 | eval_motion_loaders={ 297 | 'vald': lambda:get_eval_gen_loader( 298 | val_args,model,diffusion, 299 | gen_loader,val_args.max_motion_length,val_args.batch_size, 300 | mm_num_samples,mm_num_repeats,num_samples_limit,val_args.guidance_param 301 | ) 302 | } 303 | eval_wrapper=EvaluationWrapper(val_args) 304 | evaluation(eval_wrapper,gt_loader,eval_motion_loaders,log_file,replication_times,diversity_times, 305 | mm_num_times,run_mm=run_mm) 306 | 307 | -------------------------------------------------------------------------------- /src/eval/eval_himo_2o.py: -------------------------------------------------------------------------------- 1 | from src.utils.parser_utils import eval_himo_args 2 | from src.utils.misc import fixseed 3 | from src.utils import dist_utils 4 | # from src.diffusion import logger 5 | 6 | from collections import OrderedDict 7 
| from src.utils.model_utils import create_model_and_diffusion,load_model_wo_clip 8 | from torch.utils.data import DataLoader 9 | from src.dataset.eval_dataset import Evaluation_Dataset 10 | from src.dataset.tensors import gt_collate_fn 11 | from src.dataset.eval_gen_dataset import get_eval_gen_loader 12 | from src.feature_extractor.eval_wrapper import EvaluationWrapper 13 | from src.eval.metrics import * 14 | from src.model.cfg_sampler import ClassifierFreeSampleModel 15 | import torch 16 | import os 17 | import os.path as osp 18 | import numpy as np 19 | from datetime import datetime 20 | from loguru import logger 21 | 22 | def evaluate_matching_score(eval_wrapper, motion_loaders, file): 23 | match_score_dict = OrderedDict({}) 24 | R_precision_dict = OrderedDict({}) 25 | activation_dict = OrderedDict({}) 26 | logger.info('========== Evaluating Matching Score ==========') 27 | for motion_loader_name, motion_loader in motion_loaders.items(): 28 | all_motion_embeddings = [] 29 | score_list = [] 30 | all_size = 0 31 | matching_score_sum = 0 32 | top_k_count = 0 33 | logger.info(motion_loader_name) 34 | with torch.no_grad(): 35 | for idx, batch in enumerate(motion_loader): 36 | word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch 37 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings( 38 | word_embs=word_embeddings, 39 | pos_ohot=pos_one_hots, 40 | cap_lens=sent_lens, 41 | motions=motions, 42 | m_lens=m_lens 43 | ) 44 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), 45 | motion_embeddings.cpu().numpy()) 46 | matching_score_sum += dist_mat.trace() 47 | 48 | argsmax = np.argsort(dist_mat, axis=1) 49 | top_k_mat = calculate_top_k(argsmax, top_k=3) 50 | top_k_count += top_k_mat.sum(axis=0) 51 | 52 | all_size += text_embeddings.shape[0] 53 | 54 | all_motion_embeddings.append(motion_embeddings.cpu().numpy()) 55 | 56 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) 57 | matching_score = matching_score_sum / all_size 58 | R_precision = top_k_count / all_size 59 | match_score_dict[motion_loader_name] = matching_score 60 | R_precision_dict[motion_loader_name] = R_precision 61 | activation_dict[motion_loader_name] = all_motion_embeddings 62 | 63 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}') 64 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True) 65 | 66 | line = f'---> [{motion_loader_name}] R_precision: ' 67 | for i in range(len(R_precision)): 68 | line += '(top %d): %.4f ' % (i+1, R_precision[i]) 69 | logger.info(line) 70 | logger.info(line, file=file, flush=True) 71 | 72 | return match_score_dict, R_precision_dict, activation_dict 73 | 74 | def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file): 75 | eval_dict = OrderedDict({}) 76 | gt_motion_embeddings = [] 77 | logger.info('========== Evaluating FID ==========') 78 | with torch.no_grad(): 79 | for idx, batch in enumerate(groundtruth_loader): 80 | _, _, _, sent_lens, motions, m_lens, _ = batch 81 | motion_embeddings = eval_wrapper.get_motion_embeddings( 82 | motions=motions, 83 | m_lens=m_lens 84 | ) 85 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) 86 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) 87 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) 88 | 89 | # logger.info(gt_mu) 90 | for model_name, motion_embeddings in activation_dict.items(): 91 | mu, cov = 
calculate_activation_statistics(motion_embeddings) 92 | # logger.info(mu) 93 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 94 | logger.info(f'---> [{model_name}] FID: {fid:.4f}') 95 | logger.info(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) 96 | eval_dict[model_name] = fid 97 | return eval_dict 98 | 99 | 100 | def evaluate_diversity(activation_dict, file, diversity_times): 101 | eval_dict = OrderedDict({}) 102 | logger.info('========== Evaluating Diversity ==========') 103 | for model_name, motion_embeddings in activation_dict.items(): 104 | diversity = calculate_diversity(motion_embeddings, diversity_times) 105 | eval_dict[model_name] = diversity 106 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}') 107 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) 108 | return eval_dict 109 | 110 | 111 | def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times): 112 | eval_dict = OrderedDict({}) 113 | logger.info('========== Evaluating MultiModality ==========') 114 | for model_name, mm_motion_loader in mm_motion_loaders.items(): 115 | mm_motion_embeddings = [] 116 | with torch.no_grad(): 117 | for idx, batch in enumerate(mm_motion_loader): 118 | # (1, mm_replications, dim_pos) 119 | motions, m_lens = batch 120 | motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0]) 121 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) 122 | if len(mm_motion_embeddings) == 0: 123 | multimodality = 0 124 | else: 125 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() 126 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) 127 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}') 128 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) 129 | eval_dict[model_name] = multimodality 130 | return eval_dict 131 | 132 | 133 | def get_metric_statistics(values, replication_times): 134 | mean = np.mean(values, axis=0) 135 | std = np.std(values, axis=0) 136 | conf_interval = 1.96 * std / np.sqrt(replication_times) 137 | return mean, conf_interval 138 | 139 | def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False): 140 | with open(log_file, 'w') as f: 141 | all_metrics = OrderedDict({'Matching Score': OrderedDict({}), 142 | 'R_precision': OrderedDict({}), 143 | 'FID': OrderedDict({}), 144 | 'Diversity': OrderedDict({}), 145 | 'MultiModality': OrderedDict({})}) 146 | for replication in range(replication_times): 147 | motion_loaders = {} 148 | mm_motion_loaders = {} 149 | motion_loaders['ground truth'] = gt_loader 150 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): 151 | motion_loader, mm_motion_loader = motion_loader_getter() 152 | motion_loaders[motion_loader_name] = motion_loader 153 | mm_motion_loaders[motion_loader_name] = mm_motion_loader 154 | 155 | logger.info(f'==================== Replication {replication} ====================') 156 | logger.info(f'==================== Replication {replication} ====================', file=f, flush=True) 157 | logger.info(f'Time: {datetime.now()}') 158 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 159 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f) 160 | 161 | logger.info(f'Time: {datetime.now()}') 162 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 
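# FID: embed the ground-truth motions, fit mean/covariance Gaussians to the ground-truth and
# generated motion embeddings, and report their Frechet distance per model.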
163 | fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f) 164 | 165 | logger.info(f'Time: {datetime.now()}') 166 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 167 | div_score_dict = evaluate_diversity(acti_dict, f, diversity_times) 168 | 169 | if run_mm: 170 | logger.info(f'Time: {datetime.now()}') 171 | logger.info(f'Time: {datetime.now()}', file=f, flush=True) 172 | mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times) 173 | 174 | logger.info(f'!!! DONE !!!') 175 | logger.info(f'!!! DONE !!!', file=f, flush=True) 176 | 177 | for key, item in mat_score_dict.items(): 178 | if key not in all_metrics['Matching Score']: 179 | all_metrics['Matching Score'][key] = [item] 180 | else: 181 | all_metrics['Matching Score'][key] += [item] 182 | 183 | for key, item in R_precision_dict.items(): 184 | if key not in all_metrics['R_precision']: 185 | all_metrics['R_precision'][key] = [item] 186 | else: 187 | all_metrics['R_precision'][key] += [item] 188 | 189 | for key, item in fid_score_dict.items(): 190 | if key not in all_metrics['FID']: 191 | all_metrics['FID'][key] = [item] 192 | else: 193 | all_metrics['FID'][key] += [item] 194 | 195 | for key, item in div_score_dict.items(): 196 | if key not in all_metrics['Diversity']: 197 | all_metrics['Diversity'][key] = [item] 198 | else: 199 | all_metrics['Diversity'][key] += [item] 200 | if run_mm: 201 | for key, item in mm_score_dict.items(): 202 | if key not in all_metrics['MultiModality']: 203 | all_metrics['MultiModality'][key] = [item] 204 | else: 205 | all_metrics['MultiModality'][key] += [item] 206 | 207 | 208 | # logger.info(all_metrics['Diversity']) 209 | mean_dict = {} 210 | for metric_name, metric_dict in all_metrics.items(): 211 | logger.info('========== %s Summary ==========' % metric_name) 212 | logger.info('========== %s Summary ==========' % metric_name, file=f, flush=True) 213 | for model_name, values in metric_dict.items(): 214 | # logger.info(metric_name, model_name) 215 | mean, conf_interval = get_metric_statistics(np.array(values), replication_times) 216 | mean_dict[metric_name + '_' + model_name] = mean 217 | # logger.info(mean, mean.dtype) 218 | if isinstance(mean, np.float64) or isinstance(mean, np.float32): 219 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') 220 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) 221 | elif isinstance(mean, np.ndarray): 222 | line = f'---> [{model_name}]' 223 | for i in range(len(mean)): 224 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) 225 | logger.info(line) 226 | logger.info(line, file=f, flush=True) 227 | return mean_dict 228 | 229 | if __name__ == '__main__': 230 | val_args,model_args=eval_himo_args() 231 | fixseed(val_args.seed) 232 | val_args.batch_size = 32 # This must be 32! Don't change it! otherwise it will cause a bug in R precision calc! 
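# R-precision above is computed per batch as retrieval among the samples of that batch
# (a 32x32 text-to-motion distance matrix), so a different batch size changes the size of
# the candidate pool and hence the metric itself.
# Hypothetical example invocation (flag names are assumptions; the actual arguments are
# defined in src.utils.parser_utils.eval_himo_args):
#   python -m src.eval.eval_himo_2o --model_path save/<exp_name>/model000600000.pt --eval_mode wo_mm --guidance_param 2.5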
233 | 234 | model_name = os.path.basename(os.path.dirname(val_args.model_path)) 235 | model_niter = os.path.basename(val_args.model_path).replace('model', '').replace('.pt', '') 236 | 237 | log_file=osp.join('val_results','eval_{}_{}'.format(model_name,model_niter)) 238 | log_file+=f'_{val_args.eval_mode}' 239 | log_file+='.log' 240 | logger.add(log_file) 241 | logger.info(f'Will save to log file [{log_file}]') 242 | logger.info(f'Eval mode [{val_args.eval_mode}]') 243 | 244 | if val_args.eval_mode == 'debug': 245 | num_samples_limit = 1000 # None means no limit (eval over all dataset) 246 | run_mm = False 247 | mm_num_samples = 0 248 | mm_num_repeats = 0 249 | mm_num_times = 0 250 | diversity_times = 200 251 | replication_times = 1 252 | elif val_args.eval_mode == 'wo_mm': 253 | num_samples_limit = 1000 254 | run_mm = False 255 | mm_num_samples = 0 256 | mm_num_repeats = 0 257 | mm_num_times = 0 258 | diversity_times = 200 # 300 259 | replication_times = 10 #20 260 | elif val_args.eval_mode == 'mm_short': 261 | num_samples_limit = 1000 262 | run_mm = True 263 | mm_num_samples = 100 264 | mm_num_repeats = 20 # 30 265 | mm_num_times = 10 266 | diversity_times = 200 # 300 267 | replication_times = 3 268 | else: 269 | raise ValueError() 270 | 271 | dist_utils.setup_dist(val_args.device) 272 | logger.configure() 273 | 274 | logger.info("creating data loader...") 275 | split = 'test' 276 | 277 | gt_dataset=Evaluation_Dataset(val_args,split=split,mode='gt') 278 | gen_dataset=Evaluation_Dataset(val_args,split=split,mode='eval') 279 | gt_loader=DataLoader(gt_dataset,batch_size=val_args.batch_size,shuffle=False,num_workers=8, 280 | drop_last=True,collate_fn=gt_collate_fn) 281 | gen_loader=DataLoader(gen_dataset,batch_size=val_args.batch_size,shuffle=False,num_workers=8, 282 | drop_last=True,collate_fn=gt_collate_fn) 283 | 284 | logger.info("creating model and diffusion...") 285 | model,diffusion=create_model_and_diffusion(model_args) 286 | logger.info(f"Loading model from [{val_args.model_path}]...") 287 | state_dict=torch.load(val_args.model_path,map_location='cpu') 288 | load_model_wo_clip(model,state_dict) 289 | 290 | if val_args.guidance_param!=1: 291 | model=ClassifierFreeSampleModel(model) 292 | 293 | model.to(dist_utils.dev()) 294 | model.eval() 295 | 296 | eval_motion_loaders={ 297 | 'vald': lambda:get_eval_gen_loader( 298 | val_args,model,diffusion, 299 | gen_loader,val_args.max_motion_length,val_args.batch_size, 300 | mm_num_samples,mm_num_repeats,num_samples_limit,val_args.guidance_param 301 | ) 302 | } 303 | eval_wrapper=EvaluationWrapper(val_args) 304 | evaluation(eval_wrapper,gt_loader,eval_motion_loaders,log_file,replication_times,diversity_times, 305 | mm_num_times,run_mm=run_mm) 306 | 307 | --------------------------------------------------------------------------------