├── .gitignore
├── assets
│   └── teaser.png
├── visualize
│   ├── requirements.txt
│   ├── segment_viewer.py
│   ├── skel_viewer.py
│   └── smplx_viewer.py
├── src
│   ├── utils
│   │   ├── misc.py
│   │   ├── model_utils.py
│   │   ├── word_vectorizer.py
│   │   ├── dist_utils.py
│   │   └── rotation_conversion.py
│   ├── model
│   │   ├── cfg_sampler.py
│   │   ├── blocks.py
│   │   ├── layers.py
│   │   ├── net_2o.py
│   │   └── net_3o.py
│   ├── train
│   │   ├── train_net_2o.py
│   │   ├── train_net_3o.py
│   │   ├── train_platforms.py
│   │   └── training_loop.py
│   ├── diffusion
│   │   ├── losses.py
│   │   ├── respace.py
│   │   ├── resample.py
│   │   ├── nn.py
│   │   └── fp16_util.py
│   ├── dataset
│   │   ├── decomp_dataset.py
│   │   ├── himo_2o_dataset.py
│   │   ├── himo_3o_dataset.py
│   │   ├── fe_dataset.py
│   │   ├── tensors.py
│   │   ├── eval_dataset.py
│   │   └── eval_gen_dataset.py
│   ├── feature_extractor
│   │   ├── train_decomp.py
│   │   ├── train_tex_mot_match.py
│   │   ├── eval_wrapper.py
│   │   └── modules.py
│   └── eval
│       ├── metrics.py
│       ├── eval_himo_3o.py
│       └── eval_himo_2o.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | /body_models/
2 | /data/
3 | imgui.ini
4 | __pycache__/
5 | save/
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LvXinTao/HIMO_dataset/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/visualize/requirements.txt:
--------------------------------------------------------------------------------
1 | aitviewer
2 | trimesh
3 | glfw
4 | chumpy
5 | smplx
6 | tqdm
7 | scikit-learn
8 | pandas
9 | matplotlib
--------------------------------------------------------------------------------
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import os
5 |
6 | def fixseed(seed):
7 | torch.backends.cudnn.benchmark = False
8 | random.seed(seed)
9 | np.random.seed(seed)
10 | torch.manual_seed(seed)
11 |
12 | def makepath(desired_path,isfile=False):
13 | # create the parent directory (isfile=True) or the directory itself
14 | if isfile:
15 | os.makedirs(os.path.dirname(desired_path),exist_ok=True)
16 | else:
17 | os.makedirs(desired_path,exist_ok=True)
18 | return desired_path
19 |
20 | def to_numpy(tensor):
21 | if torch.is_tensor(tensor):
22 | return tensor.cpu().numpy()
23 | elif type(tensor).__module__ != 'numpy':
24 | raise ValueError("Cannot convert {} to numpy array".format(
25 | type(tensor)))
26 | return tensor
27 |
28 | def to_tensor(ndarray):
29 | if type(ndarray).__module__ == 'numpy':
30 | return torch.from_numpy(ndarray).float()
31 | elif not torch.is_tensor(ndarray):
32 | raise ValueError("Cannot convert {} to torch tensor".format(
33 | type(ndarray)))
34 | return ndarray
--------------------------------------------------------------------------------
/src/model/cfg_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from copy import deepcopy
5 |
6 | # A wrapper model for Classifier-free guidance **SAMPLING** only
7 | # https://arxiv.org/abs/2207.12598
8 | class ClassifierFreeSampleModel(nn.Module):
9 |
10 | def __init__(self, model):
11 | super().__init__()
12 | self.model = model # model is the actual model to run
13 |
14 | assert self.model.cond_mask_prob > 0, 'Cannot run classifier-free guidance on a model that was not trained with condition masking (cond_mask_prob > 0)'
15 |
16 | # pointers to inner model
17 | self.rot2xyz = self.model.rot2xyz
18 | self.translation = self.model.translation
19 | self.njoints = self.model.njoints
20 | self.nfeats = self.model.nfeats
21 | self.data_rep = self.model.data_rep
22 | self.cond_mode = self.model.cond_mode
23 |
24 | def forward(self, x, timesteps, y=None):
25 | cond_mode = self.model.cond_mode
26 | assert cond_mode in ['text', 'action']
27 | y_uncond = deepcopy(y)
28 | y_uncond['uncond'] = True
29 | out = self.model(x, timesteps, y)
30 | out_uncond = self.model(x, timesteps, y_uncond)
31 | return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))
32 |
33 |
--------------------------------------------------------------------------------
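The wrapper above only changes how the inner network is called at sampling time: it runs one conditional and one unconditional pass and blends them with the per-sample guidance scale carried in `y['scale']`. A minimal sketch of that contract, using a hypothetical stand-in for NET_2O / NET_3O (only the attributes the wrapper reads in `__init__` are mocked):

```python
# A minimal sketch (not part of the repo) of ClassifierFreeSampleModel's
# sampling-time contract. DummyNet is a hypothetical stand-in for NET_2O/NET_3O.
import torch
from src.model.cfg_sampler import ClassifierFreeSampleModel

class DummyNet(torch.nn.Module):
    cond_mask_prob = 0.1            # wrapper asserts this is > 0
    rot2xyz = None                  # pointers the wrapper copies
    translation = None
    njoints, nfeats = 52, 3
    data_rep, cond_mode = 'xyz', 'text'

    def forward(self, x, timesteps, y=None):
        # Pretend the unconditional branch (y['uncond'] is True) predicts zeros.
        return torch.zeros_like(x) if y.get('uncond', False) else 0.5 * x

sampler = ClassifierFreeSampleModel(DummyNet())
x = torch.randn(4, 52, 3, 60)                   # [bs, njoints, nfeats, nframes]
t = torch.zeros(4, dtype=torch.long)
y = {'text': ['pick up the kettle'] * 4,
     'scale': torch.full((4,), 2.5)}            # per-sample guidance scale
out = sampler(x, t, y)                          # = uncond + scale * (cond - uncond)
print(out.shape)                                # torch.Size([4, 52, 3, 60])
```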
/src/train/train_net_2o.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 |
5 | from src.utils.misc import fixseed,makepath
6 | from src.utils.parser_utils import train_net_2o_args
7 | from src.utils.model_utils import create_model_and_diffusion
8 | from src.utils import dist_utils
9 | from src.train.train_platforms import *
10 | from src.train.training_loop import TrainLoop
11 | from src.dataset.himo_2o_dataset import HIMO_2O
12 | from src.dataset.tensors import himo_2o_collate_fn
13 |
14 | from torch.utils.data import DataLoader
15 | from loguru import logger
16 |
17 | def main():
18 | args=train_net_2o_args()
19 | # save path
20 | save_path=osp.join(args.save_dir,args.exp_name)
21 | if osp.exists(save_path):
22 | raise FileExistsError(f'{save_path} already exists!')
23 | # pass
24 | else:
25 | makepath(save_path)
26 | args.save_path=save_path
27 | # training platform
28 | train_platform_type=eval(args.train_platform)
29 | train_platform=train_platform_type(args.save_path)
30 | train_platform.report_args(args,'Args')
31 | # config logger
32 | logger.add(osp.join(save_path,'train.log'))
33 | # save args
34 | with open(osp.join(save_path,'args.json'),'w') as f:
35 | json.dump(vars(args),f,indent=4)
36 |
37 | dist_utils.setup_dist(args.device)
38 | # get dataset loader
39 | logger.info('Loading Training dataset...')
40 | train_dataset=HIMO_2O(args,split='train')
41 | data_loader=DataLoader(train_dataset,batch_size=args.batch_size,
42 | shuffle=True,num_workers=8,drop_last=True,collate_fn=himo_2o_collate_fn)
43 |
44 | # get model and diffusion
45 | logger.info('Loading Model and Diffusion...')
46 | model,diffusion=create_model_and_diffusion(args)
47 | model.to(dist_utils.dev())
48 |
49 | logger.info('Start Training...')
50 | TrainLoop(args,train_platform,model,diffusion,data_loader).run_loop()
51 | train_platform.close()
52 |
53 | if __name__=='__main__':
54 | main()
--------------------------------------------------------------------------------
/src/train/train_net_3o.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 |
5 | from src.utils.misc import fixseed,makepath
6 | from src.utils.parser_utils import train_net_3o_args
7 | from src.utils.model_utils import create_model_and_diffusion
8 | from src.utils import dist_utils
9 | from src.train.train_platforms import *
10 | from src.train.training_loop import TrainLoop
11 | from src.dataset.himo_3o_dataset import HIMO_3O
12 | from src.dataset.tensors import himo_3o_collate_fn
13 |
14 | from torch.utils.data import DataLoader
15 | from loguru import logger
16 |
17 | def main():
18 | args=train_net_3o_args()
19 | # save path
20 | save_path=osp.join(args.save_dir,args.exp_name)
21 | if osp.exists(save_path):
22 | raise FileExistsError(f'{save_path} already exists!')
23 | # pass
24 | else:
25 | makepath(save_path)
26 | args.save_path=save_path
27 | # training platform
28 | train_platform_type=eval(args.train_platform)
29 | train_platform=train_platform_type(args.save_path)
30 | train_platform.report_args(args,'Args')
31 | # config logger
32 | logger.add(osp.join(save_path,'train.log'))
33 | # save args
34 | with open(osp.join(save_path,'args.json'),'w') as f:
35 | json.dump(vars(args),f,indent=4)
36 |
37 | dist_utils.setup_dist(args.device)
38 | # get dataset loader
39 | logger.info('Loading Training dataset...')
40 | train_dataset=HIMO_3O(args,split='train')
41 | data_loader=DataLoader(train_dataset,batch_size=args.batch_size,
42 | shuffle=True,num_workers=8,drop_last=True,collate_fn=himo_3o_collate_fn)
43 |
44 | # get model and diffusion
45 | logger.info('Loading Model and Diffusion...')
46 | model,diffusion=create_model_and_diffusion(args)
47 | model.to(dist_utils.dev())
48 |
49 | logger.info('Start Training...')
50 | TrainLoop(args,train_platform,model,diffusion,data_loader).run_loop()
51 | train_platform.close()
52 |
53 | if __name__=='__main__':
54 | main()
--------------------------------------------------------------------------------
/src/train/train_platforms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 |
4 | class TrainPlatform:
5 | def __init__(self, save_dir):
6 | pass
7 |
8 | def report_scalar(self, name, value, iteration, group_name=None):
9 | pass
10 |
11 | def report_args(self, args, name):
12 | pass
13 |
14 | def close(self):
15 | pass
16 |
17 |
18 | class ClearmlPlatform(TrainPlatform):
19 | def __init__(self, save_dir):
20 | from clearml import Task
21 | path, name = os.path.split(save_dir)
22 | self.task = Task.init(project_name='motion_diffusion',
23 | task_name=name,
24 | output_uri=path)
25 | self.logger = self.task.get_logger()
26 |
27 | def report_scalar(self, name, value, iteration, group_name):
28 | self.logger.report_scalar(title=group_name, series=name, iteration=iteration, value=value)
29 |
30 | def report_args(self, args, name):
31 | self.task.connect(args, name=name)
32 |
33 | def close(self):
34 | self.task.close()
35 |
36 |
37 | class TensorboardPlatform(TrainPlatform):
38 | def __init__(self, save_dir):
39 | from torch.utils.tensorboard import SummaryWriter
40 | self.writer = SummaryWriter(log_dir=save_dir)
41 |
42 | def report_scalar(self, name, value, iteration, group_name=None):
43 | self.writer.add_scalar(f'{group_name}/{name}', value, iteration)
44 |
45 | def close(self):
46 | self.writer.close()
47 |
48 | class WandbPlatform(TrainPlatform):
49 | def __init__(self, save_dir):
50 | import wandb
51 | wandb.init(project='HIMO_eccv',
52 | name=os.path.split(save_dir)[-1], dir=save_dir)
53 |
54 | def report_scalar(self, name, value, iteration, group_name=None):
55 | wandb.log({f'{group_name}/{name}': value}, step=iteration)
56 |
57 | def report_args(self, args, name):
58 | wandb.config.update(args, allow_val_change=True)
59 |
60 | def close(self):
61 | wandb.finish()
62 |
63 | class NoPlatform(TrainPlatform):
64 | def __init__(self, save_dir):
65 | pass
66 |
67 |
68 |
--------------------------------------------------------------------------------
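The training scripts resolve one of these classes by name via `eval(args.train_platform)` and then only call `report_args`, `report_scalar`, and `close`. A small usage sketch under that assumption (TensorboardPlatform additionally needs the `tensorboard` package installed; the save directory below is hypothetical):

```python
# Minimal sketch (not part of the repo) of how a platform is driven.
from src.train.train_platforms import TensorboardPlatform

platform = TensorboardPlatform('./save/demo_run')   # hypothetical save dir
platform.report_args({'lr': 1e-4, 'batch_size': 32}, name='Args')  # no-op for this platform
for step in range(3):
    platform.report_scalar('loss', 1.0 / (step + 1), iteration=step, group_name='train')
platform.close()
```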
/src/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | from src.diffusion import gaussian_diffusion as gd
2 | from src.model.net_2o import NET_2O
3 | from src.model.net_3o import NET_3O
4 | from src.diffusion.respace import SpacedDiffusion,space_timesteps
5 |
6 | def load_model_wo_clip(model, state_dict):
7 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
8 | assert len(unexpected_keys) == 0
9 | assert all([k.startswith('clip_model.') for k in missing_keys])
10 |
11 | def create_model_and_diffusion(args):
12 | if args.network=='net_2o':
13 | model=NET_2O()
14 | diffusion=create_gaussian_diffusion(args)
15 | elif args.network=='net_3o':
16 | model=NET_3O()
17 | diffusion=create_gaussian_diffusion(args)
18 |
19 | return model,diffusion
20 |
21 | def create_gaussian_diffusion(args):
22 | # default params
23 | predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal!
24 | steps = args.diffusion_steps
25 | scale_beta = 1. # no scaling
26 | timestep_respacing = '' # can be used for ddim sampling, we don't use it.
27 | learn_sigma = False
28 | rescale_timesteps = False
29 |
30 | betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
31 | loss_type = gd.LossType.MSE
32 |
33 | if not timestep_respacing:
34 | timestep_respacing = [steps]
35 | return SpacedDiffusion(
36 | use_timesteps=space_timesteps(steps, timestep_respacing),
37 | betas=betas,
38 | model_mean_type=(
39 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
40 | ),
41 | model_var_type=(
42 | (
43 | gd.ModelVarType.FIXED_LARGE
44 | if not args.sigma_small
45 | else gd.ModelVarType.FIXED_SMALL
46 | )
47 | if not learn_sigma
48 | else gd.ModelVarType.LEARNED_RANGE
49 | ),
50 | loss_type=loss_type,
51 | rescale_timesteps=rescale_timesteps,
52 | lambda_recon=args.lambda_recon if hasattr(args, 'lambda_recon') else 0.0,
53 | lambda_pos=args.lambda_pos if hasattr(args, 'lambda_pos') else 0.0,
54 | lambda_geo=args.lambda_geo if hasattr(args, 'lambda_geo') else 0.0,
55 | lambda_vel=args.lambda_vel if hasattr(args, 'lambda_vel') else 0.0,
56 | lambda_sp=args.lambda_sp if hasattr(args, 'lambda_sp') else 0.0,
57 | train_args=args
58 | )
59 |
--------------------------------------------------------------------------------
/src/diffusion/losses.py:
--------------------------------------------------------------------------------
1 | # This code is based on https://github.com/openai/guided-diffusion
2 | """
3 | Helpers for various likelihood-based losses. These are ported from the original
4 | Ho et al. diffusion models codebase:
5 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
6 | """
7 |
8 | import numpy as np
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
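A quick sanity check of `normal_kl` that follows directly from the formula above: the KL between two identical Gaussians is zero, and between two unit-variance Gaussians it reduces to half the squared mean difference.

```python
# Quick numerical check (not part of the repo) of normal_kl.
import torch as th
from src.diffusion.losses import normal_kl

zero = th.zeros(4)
print(normal_kl(zero, zero, zero, zero))      # tensor([0., 0., 0., 0.])
mu1 = th.tensor([0.0, 1.0, 2.0, 3.0])
print(normal_kl(mu1, zero, zero, zero))       # 0.5 * mu1**2 -> tensor([0.0, 0.5, 2.0, 4.5])
```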
/src/utils/word_vectorizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | from os.path import join as pjoin
4 |
5 | POS_enumerator = {
6 | 'VERB': 0,
7 | 'NOUN': 1,
8 | 'DET': 2,
9 | 'ADP': 3,
10 | 'NUM': 4,
11 | 'AUX': 5,
12 | 'PRON': 6,
13 | 'ADJ': 7,
14 | 'ADV': 8,
15 | 'Loc_VIP': 9,
16 | 'Body_VIP': 10,
17 | 'Obj_VIP': 11,
18 | 'Act_VIP': 12,
19 | 'Desc_VIP': 13,
20 | 'OTHER': 14,
21 | }
22 |
23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
24 | 'up', 'down', 'straight', 'curve')
25 |
26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
27 |
28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
29 |
30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
33 |
34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
35 | 'angrily', 'sadly')
36 |
37 | VIP_dict = {
38 | 'Loc_VIP': Loc_list,
39 | 'Body_VIP': Body_list,
40 | 'Obj_VIP': Obj_List,
41 | 'Act_VIP': Act_list,
42 | 'Desc_VIP': Desc_list,
43 | }
44 |
45 |
46 | class WordVectorizer(object):
47 | def __init__(self, meta_root, prefix):
48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
51 | self.word2vec = {w: vectors[word2idx[w]] for w in words}
52 |
53 | def _get_pos_ohot(self, pos):
54 | pos_vec = np.zeros(len(POS_enumerator))
55 | if pos in POS_enumerator:
56 | pos_vec[POS_enumerator[pos]] = 1
57 | else:
58 | pos_vec[POS_enumerator['OTHER']] = 1
59 | return pos_vec
60 |
61 | def __len__(self):
62 | return len(self.word2vec)
63 |
64 | def __getitem__(self, item):
65 | word, pos = item.split('/')
66 | if word in self.word2vec:
67 | word_vec = self.word2vec[word]
68 | vip_pos = None
69 | for key, values in VIP_dict.items():
70 | if word in values:
71 | vip_pos = key
72 | break
73 | if vip_pos is not None:
74 | pos_vec = self._get_pos_ohot(vip_pos)
75 | else:
76 | pos_vec = self._get_pos_ohot(pos)
77 | else:
78 | word_vec = self.word2vec['unk']
79 | pos_vec = self._get_pos_ohot('OTHER')
80 | return word_vec, pos_vec
--------------------------------------------------------------------------------
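A usage sketch of the vectorizer; the `himo_vab` prefix matches how it is constructed in `train_tex_mot_match.py`, while the `./data/glove` path is an assumption based on the GloVe files linked in the README.

```python
# Usage sketch (not part of the repo); path and file layout are assumptions.
from src.utils.word_vectorizer import WordVectorizer, POS_enumerator

vectorizer = WordVectorizer('./data/glove', 'himo_vab')  # expects himo_vab_data.npy / _words.pkl / _idx.pkl
word_vec, pos_vec = vectorizer['pick/VERB']              # 300-d GloVe vector + POS one-hot
print(word_vec.shape, pos_vec.shape)                     # (300,) (15,)  i.e. len(POS_enumerator)
# 'pick' appears in Act_list, so the one-hot marks the Act_VIP slot rather than VERB.
```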
/src/dataset/decomp_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tqdm import tqdm
3 | import h5py
4 | import numpy as np
5 | from loguru import logger
6 | import os.path as osp
7 |
8 | class decomp_dataset(Dataset):
9 | def __init__(self,opt,split='train'):
10 | self.opt=opt
11 |
12 | self.data=[]
13 | self.lengths=[]
14 | self.split=split
15 | self.load_data()
16 |
17 | self.cumsum = np.cumsum([0] + self.lengths)
18 | logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
19 |
20 | def __len__(self):
21 | return self.cumsum[-1]
22 |
23 | def __getitem__(self, item):
24 | if item != 0:
25 | motion_id = np.searchsorted(self.cumsum, item) - 1
26 | idx = item - self.cumsum[motion_id] - 1
27 | else:
28 | motion_id = 0
29 | idx = 0
30 | motion = self.data[motion_id][idx:idx+self.opt.window_size]
31 | return motion
32 |
33 | def load_data(self):
34 | with h5py.File(osp.join(self.opt.data_root,"processed_%s"%self.opt.mode,"{}.h5".format(self.split)),'r') as f:
35 | for seq in tqdm(f.keys(),desc=f'Loading {self.split} dataset'):
36 |
37 | body_pose=f[seq]['body_pose'][:] # nf,21,6
38 | global_orient=f[seq]['global_orient'][:][:,None,:] # nf,1,6
39 | lhand_pose=f[seq]['lhand_pose'][:] # nf,15,6
40 | rhand_pose=f[seq]['rhand_pose'][:] # nf,15,6
41 | transl=f[seq]['transl'][:] # nf,3
42 | smplx_joints=f[seq]['smplx_joints'][:]
43 | smplx_joints=smplx_joints.reshape(smplx_joints.shape[0],-1) # nf,52*3
44 | full_pose=np.concatenate([global_orient,body_pose,lhand_pose,rhand_pose],axis=1)
45 | full_pose=full_pose.reshape(full_pose.shape[0],-1) # nf,52*6
46 | human_motion=np.concatenate([smplx_joints,full_pose,transl],axis=-1) # nf,52*6+52*3+3
47 | if self.opt.mode=='2o':
48 | o1,o2=sorted(f[seq]['object_state'].keys())
49 | o1_state=f[seq]['object_state'][o1][:]
50 | o2_state=f[seq]['object_state'][o2][:]
51 | object_state=np.concatenate([o1_state,o2_state],axis=-1)
52 | elif self.opt.mode=='3o':
53 | o1,o2,o3=sorted(f[seq]['object_state'].keys())
54 | o1_state=f[seq]['object_state'][o1][:]
55 | o2_state=f[seq]['object_state'][o2][:]
56 | o3_state=f[seq]['object_state'][o3][:]
57 | object_state=np.concatenate([o1_state,o2_state,o3_state],axis=-1)
58 | motion=np.concatenate([human_motion,object_state],axis=-1) # nf,(52*6+52*3+3)+(2*(6+3))
58 | if motion.shape[0]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
19 | This repository contains the content of the following paper:
20 | > HIMO: A New Benchmark for Full-Body Human Interacting with Multiple Objects
> Xintao Lv¹*, Liang Xu¹,²*, Yichao Yan¹, Xin Jin², Congsheng Xu¹, Shuwen Wu¹, Yifan Liu¹, Lincheng Li³, Mengxiao Bi³, Wenjun Zeng², Xiaokang Yang¹
21 | > ¹Shanghai Jiao Tong University, ²Eastern Institute of Technology, Ningbo, ³NetEase Fuxi AI Lab
22 |
23 |
24 | ## News
25 | - [2025.03.10] We release the processed data for training and evaluation. Please fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSdl5adeyKxBSBFZpgs0A7-dAouRkMFAGUP5iz3zxGDj_PhB1w/viewform) to request `processed_2o.tar.gz` and `processed_3o.tar.gz`.
26 |
27 |
28 | ## Dataset Download
29 | Please fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSdl5adeyKxBSBFZpgs0A7-dAouRkMFAGUP5iz3zxGDj_PhB1w/viewform) to request authorization to download HIMO for research purposes.
30 | After downloading the dataset, unzip the data into `./data`; you'll get the following structure:
31 | ```shell
32 | ./data
33 | |-- joints
34 | | |-- S01T001.npy
35 | | |-- ...
36 | |-- smplx
37 | | |-- S01T001.npz
38 | | |-- ...
39 | |-- object_pose
40 | | |-- S01T001.npy
41 | | |-- ...
42 | |-- text
43 | | |-- S01T001.txt
44 | | |-- ...
45 | |-- segments
46 | | |-- S01T001.json
47 | | |-- ...
48 | |-- object_mesh
49 | | |-- Apple.obj
50 | | |-- ...
51 | |-- processed_2o
52 | | |-- ...
53 | |-- processed_3o
54 | | |-- ...
55 |
56 | ```
57 |
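As a quick sanity check, you can load a single sequence like below (a minimal sketch; the exact array shapes and `.npz` keys are best inspected after loading):

```python
import json
import numpy as np

seq = 'S01T001'
joints = np.load(f'./data/joints/{seq}.npy')                  # per-frame skeleton joints
object_pose = np.load(f'./data/object_pose/{seq}.npy', allow_pickle=True)
smplx_params = np.load(f'./data/smplx/{seq}.npz')             # SMPL-X parameter archive
with open(f'./data/text/{seq}.txt') as f:
    captions = [line.strip() for line in f if line.strip()]
with open(f'./data/segments/{seq}.json') as f:
    segments = json.load(f)

print(joints.shape, type(object_pose), list(smplx_params.keys()))
print(len(captions), 'captions,', len(segments), 'segment entries')
```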
58 | ## Data Visualization
59 | We use the [AIT-Viewer](https://github.com/eth-ait/aitviewer) to visualize the dataset. First, install the visualization dependencies:
60 | ```bash
61 | pip install -r visualize/requirements.txt
62 | ```
63 | You also need to download the [SMPL-X models](https://smpl-x.is.tue.mpg.de/) and place them in `./body_models`, which should look like:
64 | ```shell
65 | ./body_models
66 | |-- smplx
67 | ├── SMPLX_FEMALE.npz
68 | ├── SMPLX_FEMALE.pkl
69 | ├── SMPLX_MALE.npz
70 | ├── SMPLX_MALE.pkl
71 | ├── SMPLX_NEUTRAL.npz
72 | ├── SMPLX_NEUTRAL.pkl
73 | └── SMPLX_NEUTRAL_2020.npz
74 | ```
75 | Then run the following commands to visualize the dataset:
76 | ```bash
77 | # Visualize the skeleton
78 | python visualize/skel_viewer.py
79 | # Visualize the SMPLX
80 | python visualize/smplx_viewer.py
81 | # Visualize the segment data
82 | python visualize/segment_viewer.py
83 | ```
84 |
85 | ## Training
86 | We provide the preprocessed object BPS files and the vocabulary GloVe files [here](https://drive.google.com/drive/folders/11PYdla0R9GIyYXqDPMle9208Hv8MaUMo?usp=sharing).
87 |
88 | To train the model in 2-object setting, run
89 | ```bash
90 | python -m src.train.train_net_2o --exp_name net_2o --num_epochs 1000
91 | ```
92 | To train the model in 3-object setting, run
93 | ```bash
94 | python -m src.train.train_net_3o --exp_name net_3o --num_epochs 1000
95 | ```
96 | To evaluate the model, you need to train your own evaluator or use the checkpoints we provide [here](https://drive.google.com/drive/folders/11PYdla0R9GIyYXqDPMle9208Hv8MaUMo?usp=sharing) (put them under `./save`).
97 | Then run
98 | ```bash
99 | python -m src.eval.eval_himo_2o
100 | ```
101 | or
102 | ```bash
103 | python -m src.eval.eval_himo_3o
104 | ```
105 |
--------------------------------------------------------------------------------
/src/feature_extractor/train_tex_mot_match.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | from loguru import logger
5 | import torch
6 | from src.utils.parser_utils import train_feature_extractor_args
7 | from src.utils.misc import fixseed
8 | from src.utils.word_vectorizer import POS_enumerator,WordVectorizer
9 | from src.feature_extractor.modules import TextEncoderBiGRUCo,MotionEncoderBiGRUCo,MovementConvEncoder
10 | from src.feature_extractor.trainer import TextMotionMatchTrainer
11 | from src.dataset.fe_dataset import feature_extractor_dataset,collate_fn
12 | from src.train.train_platforms import WandbPlatform
13 |
14 | from torch.utils.data import DataLoader
15 |
16 | def build_models(args):
17 | movement_enc=MovementConvEncoder(dim_pose,args.dim_movement_enc_hidden,args.dim_movement_latent)
18 | text_enc = TextEncoderBiGRUCo(word_size=dim_word,
19 | pos_size=dim_pos_ohot,
20 | hidden_size=args.dim_text_hidden,
21 | output_size=args.dim_coemb_hidden,
22 | device=args.device)
23 | motion_enc = MotionEncoderBiGRUCo(input_size=args.dim_movement_latent,
24 | hidden_size=args.dim_motion_hidden,
25 | output_size=args.dim_coemb_hidden,
26 | device=args.device)
27 |
28 | if not args.is_continue:
29 | logger.info('Loading Decomp......')
30 | checkpoint = torch.load(osp.join(args.checkpoints_dir, args.decomp_name, 'model', 'latest.tar'),
31 | map_location=args.device)
32 | movement_enc.load_state_dict(checkpoint['movement_enc'])
33 | return text_enc,motion_enc,movement_enc
34 |
35 | if __name__=='__main__':
36 | args=train_feature_extractor_args()
37 | args.device = torch.device("cpu" if args.gpu_id==-1 else "cuda:" + str(args.gpu_id))
38 | torch.autograd.set_detect_anomaly(True)
39 | fixseed(args.seed)
40 | if args.gpu_id!=-1:
41 | torch.cuda.set_device(args.gpu_id)
42 | args.save_path=osp.join(args.save_dir,args.exp_name)
43 | args.model_dir=osp.join(args.save_path,'model')
44 | args.log_dir=osp.join(args.save_path,'log')
45 | args.eval_dir=osp.join(args.save_path,'eval')
46 |
47 | os.makedirs(args.save_path,exist_ok=True)
48 | os.makedirs(args.model_dir,exist_ok=True)
49 | os.makedirs(args.log_dir,exist_ok=True)
50 | os.makedirs(args.eval_dir,exist_ok=True)
51 |
52 | logger.add(osp.join(args.log_dir,'train_feature_extractor.log'),rotation='10 MB')
53 |
54 | args.data_root='/data/xuliang/HO2_subsets_original/HO2_final'
55 | args.max_motion_length=300
56 | if args.mode=='2o':
57 | dim_pose=52*6+52*3+3+2*(6+3)
58 | elif args.mode=='3o':
59 | dim_pose=52*6+52*3+3+3*(6+3)
60 |
61 | meta_root=osp.join(args.data_root,'glove')
62 | dim_word=300
63 | dim_pos_ohot=len(POS_enumerator)
64 |
65 | w_vectorizer=WordVectorizer(meta_root,'himo_vab')
66 | text_encoder,motion_encoder,movement_encoder=build_models(args)
67 |
68 | pc_text_enc = sum(param.numel() for param in text_encoder.parameters())
69 | logger.info("Total parameters of text encoder: {}".format(pc_text_enc))
70 | pc_motion_enc = sum(param.numel() for param in motion_encoder.parameters())
71 | logger.info("Total parameters of motion encoder: {}".format(pc_motion_enc))
72 | logger.info("Total parameters: {}".format(pc_motion_enc + pc_text_enc))
73 |
74 | train_platform=WandbPlatform(args.save_path)
75 |
76 | trainer=TextMotionMatchTrainer(args,text_encoder,motion_encoder,movement_encoder,train_platform)
77 |
78 | train_dataset=feature_extractor_dataset(args,'train',w_vectorizer)
79 | val_dataset=feature_extractor_dataset(args,'val',w_vectorizer)
80 |
81 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, num_workers=4,
82 | shuffle=True, collate_fn=collate_fn, pin_memory=True)
83 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, num_workers=4,
84 | shuffle=True, collate_fn=collate_fn, pin_memory=True)
85 |
86 | trainer.train(train_loader,val_loader)
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/src/feature_extractor/eval_wrapper.py:
--------------------------------------------------------------------------------
1 | from src.utils import dist_utils
2 | from src.utils.word_vectorizer import POS_enumerator
3 | from src.feature_extractor.modules import TextEncoderBiGRUCo,MotionEncoderBiGRUCo,MovementConvEncoder
4 | import os.path as osp
5 | import torch
6 | import numpy as np
7 |
8 | def build_evaluators(opt):
9 | movement_enc = MovementConvEncoder(opt['dim_pose'], opt['dim_movement_enc_hidden'], opt['dim_movement_latent'])
10 | text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'],
11 | pos_size=opt['dim_pos_ohot'],
12 | hidden_size=opt['dim_text_hidden'],
13 | output_size=opt['dim_coemb_hidden'],
14 | device=opt['device'])
15 | motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_enc_hidden'],
16 | hidden_size=opt['dim_motion_hidden'],
17 | output_size=opt['dim_coemb_hidden'],
18 | device=opt['device'])
19 | checkpoint=torch.load(osp.join(opt['checkpoint_dir'],'model','finest.tar'),
20 | map_location=opt['device'])
21 | movement_enc.load_state_dict(checkpoint['movement_encoder'])
22 | text_enc.load_state_dict(checkpoint['text_encoder'])
23 | motion_enc.load_state_dict(checkpoint['motion_encoder'])
24 | print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
25 |
26 | return text_enc,motion_enc,movement_enc
27 |
28 | class EvaluationWrapper:
29 | def __init__(self,val_args):
30 | opt={
31 | 'device':dist_utils.dev(),
32 | 'dim_word':300,
33 | 'max_motion_length':300,
34 | 'dim_pos_ohot': len(POS_enumerator),
35 | 'dim_motion_hidden': 1024,
36 | 'max_text_len':40,
37 | 'dim_text_hidden':512,
38 | 'dim_coemb_hidden':512,
39 | 'dim_pose':52*3+52*6+3+2*9 if val_args.obj=='2o' else 52*3+52*6+3+3*9,
40 | 'dim_movement_enc_hidden': 512,
41 | 'dim_movement_latent': 512,
42 | 'checkpoint_dir':'./save/fe_2o_epoch300' if val_args.obj=='2o' else './save/fe_3o_epoch300',
43 | 'unit_length':5
44 | }
45 | self.text_encoder,self.motion_encoder,self.movement_encoder=build_evaluators(opt)
46 | self.opt = opt
47 | self.device = opt['device']
48 |
49 | self.text_encoder.to(opt['device'])
50 | self.motion_encoder.to(opt['device'])
51 | self.movement_encoder.to(opt['device'])
52 |
53 | self.text_encoder.eval()
54 | self.motion_encoder.eval()
55 | self.movement_encoder.eval()
56 |
57 | # Please note that the returned embeddings do not follow the order of the inputs
58 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
59 | with torch.no_grad():
60 | word_embs = word_embs.detach().to(self.device).float()
61 | pos_ohot = pos_ohot.detach().to(self.device).float()
62 | motions = motions.detach().to(self.device).float()
63 |
64 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
65 | motions = motions[align_idx]
66 | m_lens = m_lens[align_idx]
67 |
68 | '''Movement Encoding'''
69 | movements = self.movement_encoder(motions).detach()
70 | m_lens = m_lens // self.opt['unit_length']
71 | motion_embedding = self.motion_encoder(movements, m_lens)
72 |
73 | '''Text Encoding'''
74 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
75 | text_embedding = text_embedding[align_idx]
76 | return text_embedding, motion_embedding
77 |
78 | def get_motion_embeddings(self,motions,m_lens):
79 | with torch.no_grad():
80 | motions = motions.detach().to(self.device).float()
81 |
82 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
83 | motions = motions[align_idx]
84 | m_lens = m_lens[align_idx]
85 |
86 | '''Movement Encoding'''
87 | movements = self.movement_encoder(motions).detach()
88 | m_lens = m_lens // self.opt['unit_length']
89 | motion_embedding = self.motion_encoder(movements, m_lens)
90 | return motion_embedding
91 |
--------------------------------------------------------------------------------
/src/dataset/fe_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from typing import Any
4 | import numpy as np
5 | import h5py
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 | import codecs as cs
9 | from torch.utils.data._utils.collate import default_collate
10 |
11 | class feature_extractor_dataset(Dataset):
12 | def __init__(self,args,split,w_vectorizer):
13 | self.args=args
14 | self.w_vectorizer=w_vectorizer
15 | self.max_motion_length=args.max_motion_length
16 | self.split=split
17 | self.data_path=osp.join(args.data_root,'processed_{}'.format(args.mode),f'{split}.h5')
18 |
19 | self.data=[]
20 | self.load_data()
21 |
22 | def __len__(self):
23 | return len(self.data)
24 |
25 | def __getitem__(self, idx):
26 | data=self.data[idx]
27 | motion,m_length,text_list=data['motion'],data['length'],data['text']
28 | caption,tokens=text_list['caption'],text_list['tokens']
29 |
30 | if len(tokens) < self.args.max_text_len:
31 | # pad with "unk"
32 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
33 | sent_len = len(tokens)
34 | tokens = tokens + ['unk/OTHER'] * (self.args.max_text_len + 2 - sent_len)
35 | else:
36 | # crop
37 | tokens = tokens[:self.args.max_text_len]
38 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
39 | sent_len = len(tokens)
40 | pos_one_hots = []
41 | word_embeddings = []
42 | for token in tokens:
43 | word_emb, pos_oh = self.w_vectorizer[token]
44 | pos_one_hots.append(pos_oh[None, :])
45 | word_embeddings.append(word_emb[None, :])
46 | pos_one_hots = np.concatenate(pos_one_hots, axis=0)
47 | word_embeddings = np.concatenate(word_embeddings, axis=0)
48 |
49 | if m_length
--------------------------------------------------------------------------------
110 | dots = torch.einsum('bhid,bhjd->bhij', queries, keys) * self.scale
111 |
112 | if mask is not None:
113 | mask = F.pad(mask.flatten(1), (1, 0), value = True)
114 | assert mask.shape[-1] == dots.shape[-1], 'Mask has incorrect dimensions'
115 | mask = mask[:, None, :].expand(-1, n, -1)
116 | dots.masked_fill_(~mask, float('-inf'))
117 |
118 | attn = dots.softmax(dim=-1)
119 | out = torch.einsum('bhij,bhjd->bhid', attn, values)
120 | out = out.transpose(1, 2).contiguous().view(b, n, -1)
121 | return out
--------------------------------------------------------------------------------
/src/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import socket
6 |
7 | import torch as th
8 | import torch.distributed as dist
9 |
10 | # Change this to reflect your cluster layout.
11 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
12 | GPUS_PER_NODE = 8
13 |
14 | SETUP_RETRY_COUNT = 3
15 |
16 | used_device = 0
17 |
18 | def setup_dist(device=0):
19 | """
20 | Setup a distributed process group.
21 | """
22 | global used_device
23 | used_device = device
24 | if dist.is_initialized():
25 | return
26 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
27 |
28 | # comm = MPI.COMM_WORLD
29 | # backend = "gloo" if not th.cuda.is_available() else "nccl"
30 |
31 | # if backend == "gloo":
32 | # hostname = "localhost"
33 | # else:
34 | # hostname = socket.gethostbyname(socket.getfqdn())
35 | # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
36 | # os.environ["RANK"] = str(comm.rank)
37 | # os.environ["WORLD_SIZE"] = str(comm.size)
38 |
39 | # port = comm.bcast(_find_free_port(), root=used_device)
40 | # os.environ["MASTER_PORT"] = str(port)
41 | # dist.init_process_group(backend=backend, init_method="env://")
42 |
43 |
44 | def dev():
45 | """
46 | Get the device to use for torch.distributed.
47 | """
48 | global used_device
49 | if th.cuda.is_available() and used_device>=0:
50 | return th.device(f"cuda:{used_device}")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | return th.load(path, **kwargs)
59 |
60 |
61 | def sync_params(params):
62 | """
63 | Synchronize a sequence of Tensors across ranks from rank 0.
64 | """
65 | for p in params:
66 | with th.no_grad():
67 | dist.broadcast(p, 0)
68 |
69 |
70 | def _find_free_port():
71 | try:
72 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
73 | s.bind(("", 0))
74 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
75 | return s.getsockname()[1]
76 | finally:
77 | s.close()
78 |
79 | # """
80 | # Helpers for distributed training.
81 | # """
82 |
83 | # import io
84 | # import os
85 | # import socket
86 |
87 | # import blobfile as bf
88 | # # from mpi4py import MPI
89 | # import torch as th
90 | # import torch.distributed as dist
91 |
92 | # # Change this to reflect your cluster layout.
93 | # # The GPU for a given rank is (rank % GPUS_PER_NODE).
94 | # GPUS_PER_NODE = 2
95 |
96 | # SETUP_RETRY_COUNT = 3
97 |
98 |
99 | # def setup_dist():
100 | # """
101 | # Setup a distributed process group.
102 | # """
103 | # if dist.is_initialized():
104 | # return
105 | # # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
106 |
107 | # # comm = MPI.COMM_WORLD
108 | # backend = "gloo" if not th.cuda.is_available() else "nccl"
109 |
110 | # if backend == "gloo":
111 | # hostname = "localhost"
112 | # else:
113 | # hostname = socket.gethostbyname(socket.getfqdn())
114 | # # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
115 | # # os.environ["RANK"] = str(comm.rank)
116 | # # os.environ["WORLD_SIZE"] = str(comm.size)
117 |
118 | # # port = comm.bcast(_find_free_port(), root=0)
119 | # # os.environ["MASTER_PORT"] = str(port)
120 | # dist.init_process_group(backend=backend)
121 |
122 |
123 | # def dev():
124 | # """
125 | # Get the device to use for torch.distributed.
126 | # """
127 | # if th.cuda.is_available():
128 | # return th.device(f"cuda")
129 | # return th.device("cpu")
130 |
131 |
132 | # def load_state_dict(path, **kwargs):
133 | # """
134 | # Load a PyTorch file without redundant fetches across MPI ranks.
135 | # """
136 | # chunk_size = 2 ** 30 # MPI has a relatively small size limit
137 | # if MPI.COMM_WORLD.Get_rank() == 0:
138 | # with bf.BlobFile(path, "rb") as f:
139 | # data = f.read()
140 | # num_chunks = len(data) // chunk_size
141 | # if len(data) % chunk_size:
142 | # num_chunks += 1
143 | # MPI.COMM_WORLD.bcast(num_chunks)
144 | # for i in range(0, len(data), chunk_size):
145 | # MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
146 | # else:
147 | # num_chunks = MPI.COMM_WORLD.bcast(None)
148 | # data = bytes()
149 | # for _ in range(num_chunks):
150 | # data += MPI.COMM_WORLD.bcast(None)
151 |
152 | # return th.load(io.BytesIO(data), **kwargs)
153 |
154 |
155 | # def sync_params(params):
156 | # """
157 | # Synchronize a sequence of Tensors across ranks from rank 0.
158 | # """
159 | # for p in params:
160 | # with th.no_grad():
161 | # dist.broadcast(p, 0)
162 |
163 |
164 | # def _find_free_port():
165 | # try:
166 | # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
167 | # s.bind(("", 0))
168 | # s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
169 | # return s.getsockname()[1]
170 | # finally:
171 | # s.close()
--------------------------------------------------------------------------------
/src/dataset/tensors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data._utils.collate import default_collate
3 |
4 | def lengths_to_mask(lengths, max_len):
5 | # max_len = max(lengths)
6 | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
7 | return mask
8 |
9 | def collate_tensors(batch):
10 | dims = batch[0].dim()
11 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
12 | size = (len(batch),) + tuple(max_size)
13 | canvas = batch[0].new_zeros(size=size)
14 | for i, b in enumerate(batch):
15 | sub_tensor = canvas[i]
16 | for d in range(dims):
17 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
18 | sub_tensor.add_(b)
19 | return canvas
20 |
21 | def himo_2o_collate_fn(batch):
22 | adapteed_batch=[
23 | {
24 | 'inp':torch.tensor(b[6]).float(), # [nf,52*3+52*6+3+2*9]
25 | 'length':b[7],
26 | 'text':b[0],
27 | 'obj1_bps':torch.tensor(b[1]).squeeze(0).float(), # [1024,3]
28 | 'obj2_bps':torch.tensor(b[2]).squeeze(0).float(),
29 | 'obj1_sampled_verts':torch.tensor(b[3]).float(), # [1024,3]
30 | 'obj2_sampled_verts':torch.tensor(b[4]).float(),
31 | 'init_state':torch.tensor(b[5]).float(), # [52*3+52*6+3+2*9]
32 | 'betas':torch.tensor(b[8]).float(), # [10]
33 | } for b in batch
34 | ]
35 | inp_tensor=collate_tensors([b['inp'] for b in adapteed_batch]) # [bs,nf,52*3+52*6+3+2*9]
36 | len_batch=[b['length'] for b in adapteed_batch]
37 | len_tensor=torch.tensor(len_batch).long() # [B]
38 | mask_tensor=lengths_to_mask(len_tensor,inp_tensor.shape[1]).unsqueeze(1).unsqueeze(1) # [B,1,1,nf]
39 |
40 | text_batch=[b['text'] for b in adapteed_batch]
41 | o1b_tensor=torch.stack([b['obj1_bps'] for b in adapteed_batch],dim=0) # [B,1024,3]
42 | o2b_tensor=torch.stack([b['obj2_bps'] for b in adapteed_batch],dim=0) # [B,1024,3]
43 | o1sv_tensor=torch.stack([b['obj1_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3]
44 | o2sv_tensor=torch.stack([b['obj2_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3]
45 | init_state_tensor=torch.stack([b['init_state'] for b in adapteed_batch],dim=0) # [B,52*3+52*6+3+2*9]
46 | betas_tensor=torch.stack([b['betas'] for b in adapteed_batch],dim=0) # [B,10]
47 |
48 | cond={
49 | 'y':{
50 | 'mask':mask_tensor,
51 | 'length':len_tensor,
52 | 'text':text_batch,
53 | 'obj1_bps':o1b_tensor,
54 | 'obj2_bps':o2b_tensor,
55 | 'obj1_sampled_verts':o1sv_tensor,
56 | 'obj2_sampled_verts':o2sv_tensor,
57 | 'init_state':init_state_tensor,
58 | 'betas':betas_tensor
59 | }
60 | }
61 | return inp_tensor,cond
62 |
63 | def himo_3o_collate_fn(batch):
64 | adapteed_batch=[
65 | {
66 | 'inp':torch.tensor(b[8]).float(), # [nf,52*3+52*6+3+3*9]
67 | 'length':b[9],
68 | 'text':b[0],
69 | 'obj1_bps':torch.tensor(b[1]).squeeze(0).float(), # [1024,3]
70 | 'obj2_bps':torch.tensor(b[2]).squeeze(0).float(),
71 | 'obj3_bps':torch.tensor(b[3]).squeeze(0).float(),
72 | 'obj1_sampled_verts':torch.tensor(b[4]).float(), # [1024,3]
73 | 'obj2_sampled_verts':torch.tensor(b[5]).float(),
74 | 'obj3_sampled_verts':torch.tensor(b[6]).float(),
75 | 'init_state':torch.tensor(b[7]).float(), # [52*3+52*6+3+3*9]
76 | 'betas':torch.tensor(b[10]).float(), # [10]
77 |
78 | } for b in batch
79 | ]
80 | inp_tensor=collate_tensors([b['inp'] for b in adapteed_batch]) # [bs,nf,52*3+52*6+3+3*9]
81 | len_batch=[b['length'] for b in adapteed_batch]
82 | len_tensor=torch.tensor(len_batch).long() # [B]
83 | mask_tensor=lengths_to_mask(len_tensor,inp_tensor.shape[1]).unsqueeze(1).unsqueeze(1) # [B,1,1,nf]
84 |
85 | text_batch=[b['text'] for b in adapteed_batch]
86 | o1b_tensor=torch.stack([b['obj1_bps'] for b in adapteed_batch],dim=0) # [B,1024,3]
87 | o2b_tensor=torch.stack([b['obj2_bps'] for b in adapteed_batch],dim=0) # [B,1024,3]
88 | o3b_tensor=torch.stack([b['obj3_bps'] for b in adapteed_batch],dim=0) # [B,1024,3]
89 | o1sv_tensor=torch.stack([b['obj1_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3]
90 | o2sv_tensor=torch.stack([b['obj2_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3]
91 | o3sv_tensor=torch.stack([b['obj3_sampled_verts'] for b in adapteed_batch],dim=0) # [B,1024,3]
92 | init_state_tensor=torch.stack([b['init_state'] for b in adapteed_batch],dim=0) # [B,52*3+52*6+3+3*9]
93 | betas_tensor=torch.stack([b['betas'] for b in adapteed_batch],dim=0) # [B,10]
94 |
95 | cond={
96 | 'y':{
97 | 'mask':mask_tensor,
98 | 'length':len_tensor,
99 | 'text':text_batch,
100 | 'obj1_bps':o1b_tensor,
101 | 'obj2_bps':o2b_tensor,
102 | 'obj3_bps':o3b_tensor,
103 | 'obj1_sampled_verts':o1sv_tensor,
104 | 'obj2_sampled_verts':o2sv_tensor,
105 | 'obj3_sampled_verts':o3sv_tensor,
106 | 'init_state':init_state_tensor,
107 | 'betas':betas_tensor
108 | }
109 | }
110 | return inp_tensor,cond
111 |
112 | def gt_collate_fn(batch):
113 | # sort batch by sent length
114 | batch.sort(key=lambda x: x[3], reverse=True)
115 | return default_collate(batch)
116 |
--------------------------------------------------------------------------------
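A standalone sketch of the two padding helpers the HIMO collate functions build on: `collate_tensors` zero-pads every clip to the longest one in the batch, and `lengths_to_mask` marks which frames are real.

```python
# Standalone sketch (not part of the repo).
import torch
from src.dataset.tensors import collate_tensors, lengths_to_mask

a = torch.ones(3, 2)            # a 3-frame clip with 2 features per frame
b = torch.ones(5, 2) * 2        # a 5-frame clip
batch = collate_tensors([a, b])
print(batch.shape)              # torch.Size([2, 5, 2]); clip `a` is zero-padded to 5 frames

lengths = torch.tensor([3, 5])
mask = lengths_to_mask(lengths, max_len=batch.shape[1])
print(mask)                     # [[True, True, True, False, False], [True, True, True, True, True]]
```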
/src/eval/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 |
5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
6 | def euclidean_distance_matrix(matrix1, matrix2):
7 | """
8 | Params:
9 | -- matrix1: N1 x D
10 | -- matrix2: N2 x D
11 | Returns:
12 | -- dist: N1 x N2
13 | dist[i, j] == distance(matrix1[i], matrix2[j])
14 | """
15 | assert matrix1.shape[1] == matrix2.shape[1]
16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
19 | dists = np.sqrt(d1 + d2 + d3) # broadcasting
20 | return dists
21 |
22 | def calculate_top_k(mat, top_k):
23 | size = mat.shape[0]
24 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
25 | bool_mat = (mat == gt_mat)
26 | correct_vec = False
27 | top_k_list = []
28 | for i in range(top_k):
29 | # print(correct_vec, bool_mat[:, i])
30 | correct_vec = (correct_vec | bool_mat[:, i])
31 | # print(correct_vec)
32 | top_k_list.append(correct_vec[:, None])
33 | top_k_mat = np.concatenate(top_k_list, axis=1)
34 | return top_k_mat
35 |
36 |
37 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
38 | dist_mat = euclidean_distance_matrix(embedding1, embedding2)
39 | argmax = np.argsort(dist_mat, axis=1)
40 | top_k_mat = calculate_top_k(argmax, top_k)
41 | if sum_all:
42 | return top_k_mat.sum(axis=0)
43 | else:
44 | return top_k_mat
45 |
46 |
47 | def calculate_matching_score(embedding1, embedding2, sum_all=False):
48 | assert len(embedding1.shape) == 2
49 | assert embedding1.shape[0] == embedding2.shape[0]
50 | assert embedding1.shape[1] == embedding2.shape[1]
51 |
52 | dist = linalg.norm(embedding1 - embedding2, axis=1)
53 | if sum_all:
54 | return dist.sum(axis=0)
55 | else:
56 | return dist
57 |
58 |
59 |
60 | def calculate_activation_statistics(activations):
61 | """
62 | Params:
63 | -- activation: num_samples x dim_feat
64 | Returns:
65 | -- mu: dim_feat
66 | -- sigma: dim_feat x dim_feat
67 | """
68 | mu = np.mean(activations, axis=0)
69 | cov = np.cov(activations, rowvar=False)
70 | return mu, cov
71 |
72 |
73 | def calculate_diversity(activation, diversity_times):
74 | assert len(activation.shape) == 2
75 | assert activation.shape[0] > diversity_times
76 | num_samples = activation.shape[0]
77 |
78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
81 | return dist.mean()
82 |
83 |
84 | def calculate_multimodality(activation, multimodality_times):
85 | assert len(activation.shape) == 3
86 | assert activation.shape[1] > multimodality_times
87 | num_per_sent = activation.shape[1]
88 |
89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
92 | return dist.mean()
93 |
94 |
95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
96 | """Numpy implementation of the Frechet Distance.
97 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
98 | and X_2 ~ N(mu_2, C_2) is
99 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
100 | Stable version by Dougal J. Sutherland.
101 | Params:
102 | -- mu1 : Numpy array containing the activations of a layer of the
103 | inception net (like returned by the function 'get_predictions')
104 | for generated samples.
105 | -- mu2 : The sample mean over activations, precalculated on an
106 | representative dataset set.
107 | -- sigma1: The covariance matrix over activations for generated samples.
108 | -- sigma2: The covariance matrix over activations, precalculated on an
109 | representative dataset set.
110 | Returns:
111 | -- : The Frechet Distance.
112 | """
113 |
114 | mu1 = np.atleast_1d(mu1)
115 | mu2 = np.atleast_1d(mu2)
116 |
117 | sigma1 = np.atleast_2d(sigma1)
118 | sigma2 = np.atleast_2d(sigma2)
119 |
120 | assert mu1.shape == mu2.shape, \
121 | 'Training and test mean vectors have different lengths'
122 | assert sigma1.shape == sigma2.shape, \
123 | 'Training and test covariances have different dimensions'
124 |
125 | diff = mu1 - mu2
126 |
127 | # Product might be almost singular
128 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
129 | if not np.isfinite(covmean).all():
130 | msg = ('fid calculation produces singular product; '
131 | 'adding %s to diagonal of cov estimates') % eps
132 | print(msg)
133 | offset = np.eye(sigma1.shape[0]) * eps
134 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135 |
136 | # Numerical error might give slight imaginary component
137 | if np.iscomplexobj(covmean):
138 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139 | m = np.max(np.abs(covmean.imag))
140 | raise ValueError('Imaginary component {}'.format(m))
141 | covmean = covmean.real
142 |
143 | tr_covmean = np.trace(covmean)
144 |
145 | return (diff.dot(diff) + np.trace(sigma1) +
146 | np.trace(sigma2) - 2 * tr_covmean)
--------------------------------------------------------------------------------
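A sketch of how these helpers are typically chained during evaluation: activation statistics feed the FID, and embedding distances feed R-precision. Random features stand in for the text/motion embeddings produced by `EvaluationWrapper`.

```python
# Sketch (not part of the repo); random features replace real embeddings.
import numpy as np
from src.eval.metrics import (calculate_activation_statistics,
                              calculate_frechet_distance,
                              calculate_R_precision)

rng = np.random.default_rng(0)
gt_feats = rng.normal(size=(256, 64))
gen_feats = rng.normal(loc=0.1, size=(256, 64))

mu_gt, cov_gt = calculate_activation_statistics(gt_feats)
mu_gen, cov_gen = calculate_activation_statistics(gen_feats)
print('FID:', calculate_frechet_distance(mu_gt, cov_gt, mu_gen, cov_gen))

# R-precision: row i of the first embedding set should rank row i of the second
# among its top-k nearest neighbours (counts stay near chance for random data).
top_k_counts = calculate_R_precision(gt_feats[:32], gen_feats[:32], top_k=3, sum_all=True)
print('top-1/2/3 matches out of 32:', top_k_counts)
```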
/src/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # This code is based on https://github.com/openai/guided-diffusion
2 | import numpy as np
3 | import torch as th
4 |
5 | from .gaussian_diffusion import GaussianDiffusion
6 |
7 |
8 | def space_timesteps(num_timesteps, section_counts):
9 | """
10 | Create a list of timesteps to use from an original diffusion process,
11 | given the number of timesteps we want to take from equally-sized portions
12 | of the original process.
13 |
14 | For example, if there's 300 timesteps and the section counts are [10,15,20]
15 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
16 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
17 |
18 | If the stride is a string starting with "ddim", then the fixed striding
19 | from the DDIM paper is used, and only one section is allowed.
20 |
21 | :param num_timesteps: the number of diffusion steps in the original
22 | process to divide up.
23 | :param section_counts: either a list of numbers, or a string containing
24 | comma-separated numbers, indicating the step count
25 | per section. As a special case, use "ddimN" where N
26 | is a number of steps to use the striding from the
27 | DDIM paper.
28 | :return: a set of diffusion steps from the original process to use.
29 | """
30 | if isinstance(section_counts, str):
31 | if section_counts.startswith("ddim"):
32 | desired_count = int(section_counts[len("ddim") :])
33 | for i in range(1, num_timesteps):
34 | if len(range(0, num_timesteps, i)) == desired_count:
35 | return set(range(0, num_timesteps, i))
36 | raise ValueError(
37 | f"cannot create exactly {desired_count} steps with an integer stride"
38 | )
39 | section_counts = [int(x) for x in section_counts.split(",")]
40 | size_per = num_timesteps // len(section_counts)
41 | extra = num_timesteps % len(section_counts)
42 | start_idx = 0
43 | all_steps = []
44 | for i, section_count in enumerate(section_counts):
45 | size = size_per + (1 if i < extra else 0)
46 | if size < section_count:
47 | raise ValueError(
48 | f"cannot divide section of {size} steps into {section_count}"
49 | )
50 | if section_count <= 1:
51 | frac_stride = 1
52 | else:
53 | frac_stride = (size - 1) / (section_count - 1)
54 | cur_idx = 0.0
55 | taken_steps = []
56 | for _ in range(section_count):
57 | taken_steps.append(start_idx + round(cur_idx))
58 | cur_idx += frac_stride
59 | all_steps += taken_steps
60 | start_idx += size
61 | return set(all_steps)
62 |
63 |
64 | class SpacedDiffusion(GaussianDiffusion):
65 | """
66 | A diffusion process which can skip steps in a base diffusion process.
67 |
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | if self.rescale_timesteps:
128 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
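A worked example of the spacing rule described in the `space_timesteps` docstring, including the `ddimN` shorthand for a single evenly strided section:

```python
# Worked example (not part of the repo) of space_timesteps.
from src.diffusion.respace import space_timesteps

steps = space_timesteps(300, [10, 15, 20])
print(len(steps))                     # 45 retained timesteps out of 300
print(sorted(steps)[:5])              # first few retained indices from the first section

ddim_steps = space_timesteps(300, "ddim50")
print(len(ddim_steps), min(ddim_steps), max(ddim_steps))   # 50 evenly strided steps
```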
/LICENSE:
--------------------------------------------------------------------------------
1 | # HIMO License
2 |
3 | ## Dataset Copyright License for non-commercial scientific research purposes
4 |
5 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the HIMO Dataset and the accompanying Software (jointly referred to as the "Dataset"). By downloading and/or using the Dataset, you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Dataset. Any infringement of the terms of this agreement will automatically terminate your rights under this License.
6 |
7 | ## Ownership / Licensees
8 |
9 | Notwithstanding the disclosure of HIMO and any material pertaining to this dataset, all copyrights are maintained by and remained proprietary property of the authors and/or the copyright holders. You may use the dataset by ways in compliance with all applicable laws and regulations.
10 |
11 | The annotations in HIMO are licensed under the Attribution-Non Commercial-Share Alike 4.0 International License (CC-BY-NC-SA 4.0). With this free CC-license, we encourage the adoption and use of HIMO and related annotations.
12 |
13 | You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use. You may not use the material for commercial purposes. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
14 |
15 | ## License Grant
16 |
17 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
18 |
19 | To obtain and install the Dataset on computers owned, leased or otherwise controlled by you and/or your organization;
20 | To use the Dataset for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
21 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Dataset may not be reproduced, modified and/or made available in any form to any third party without Licensor’s prior written permission.
22 |
23 | The Dataset may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Dataset to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Dataset, you agree not to reverse engineer it.
24 |
25 | ## No Distribution
26 |
27 | The Dataset and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
28 |
29 | ## Disclaimer of Representations and Warranties
30 |
31 | You expressly acknowledge and agree that the Dataset results from basic research, is provided “AS IS”, may contain errors, and that any use of the Dataset is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Dataset, (ii) that the use of the Dataset will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Dataset will not cause any damage of any kind to you or a third party.
32 |
33 | ## Limitation of Liability
34 |
35 | Because this Software License Agreement qualifies as a donation, Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
36 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the Chinese Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
37 | Patent claims generated through the usage of the Dataset cannot be directed towards the copyright holders.
38 | The contractor points out that add-ons as well as minor modifications to the Dataset may lead to unforeseeable and considerable disruptions.
39 |
40 | ## No Maintenance Services
41 |
42 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Dataset. Licensor nevertheless reserves the right to update, modify, or discontinue the Dataset at any time.
43 |
44 | Defects of the Dataset must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
45 |
46 | ## Publications using the Dataset
47 |
48 | You acknowledge that the Dataset is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Dataset.
49 |
50 | #### Citation:
51 | ```
52 | @article{lv2024himo,
53 | title={HIMO: A New Benchmark for Full-Body Human Interacting with Multiple Objects},
54 | author={Lv, Xintao and Xu, Liang and Yan, Yichao and Jin, Xin and Xu, Congsheng and Wu, Shuwen and Liu, Yifan and Li, Lincheng and Bi, Mengxiao and Zeng, Wenjun and Yang, Xiaokang},
55 | year={2024}
56 | }
57 | ```
58 |
59 | ## Commercial licensing opportunities
60 |
61 | For commercial uses of the Dataset, please send email to [lvxintao@sjtu.edu.cn](mailto:lvxintao@sjtu.edu.cn).
--------------------------------------------------------------------------------
/src/feature_extractor/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import time
5 | import math
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | # from networks.layers import *
8 | import torch.nn.functional as F
9 |
10 | def init_weight(m):
11 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
12 | nn.init.xavier_normal_(m.weight)
13 | # m.bias.data.fill_(0.01)
14 | if m.bias is not None:
15 | nn.init.constant_(m.bias, 0)
16 |
17 | class TextEncoderBiGRUCo(nn.Module):
18 | def __init__(self, word_size, pos_size, hidden_size, output_size, device):
19 | super(TextEncoderBiGRUCo, self).__init__()
20 | self.device = device
21 |
22 | self.pos_emb = nn.Linear(pos_size, word_size)
23 | self.input_emb = nn.Linear(word_size, hidden_size)
24 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
25 | self.output_net = nn.Sequential(
26 | nn.Linear(hidden_size * 2, hidden_size),
27 | nn.LayerNorm(hidden_size),
28 | nn.LeakyReLU(0.2, inplace=True),
29 | nn.Linear(hidden_size, output_size)
30 | )
31 |
32 | self.input_emb.apply(init_weight)
33 | self.pos_emb.apply(init_weight)
34 | self.output_net.apply(init_weight)
35 | # self.linear2.apply(init_weight)
36 | # self.batch_size = batch_size
37 | self.hidden_size = hidden_size
38 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
39 |
40 | # input(batch_size, seq_len, dim)
41 | def forward(self, word_embs, pos_onehot, cap_lens):
42 | num_samples = word_embs.shape[0]
43 |
44 | pos_embs = self.pos_emb(pos_onehot)
45 | inputs = word_embs + pos_embs
46 | input_embs = self.input_emb(inputs)
47 | hidden = self.hidden.repeat(1, num_samples, 1)
48 |
49 | cap_lens = cap_lens.data.tolist()
50 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
51 |
52 | gru_seq, gru_last = self.gru(emb, hidden)
53 |
54 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
55 |
56 | return self.output_net(gru_last)
57 |
58 | class MovementConvEncoder(nn.Module):
59 | def __init__(self, input_size, hidden_size, output_size):
60 | super(MovementConvEncoder, self).__init__()
61 | self.main = nn.Sequential(
62 | nn.Conv1d(input_size, hidden_size, 4, 2, 1),
63 | nn.Dropout(0.2, inplace=True),
64 | nn.LeakyReLU(0.2, inplace=True),
65 | nn.Conv1d(hidden_size, output_size, 4, 2, 1),
66 | nn.Dropout(0.2, inplace=True),
67 | nn.LeakyReLU(0.2, inplace=True),
68 | )
69 | self.out_net = nn.Linear(output_size, output_size)
70 | self.main.apply(init_weight)
71 | self.out_net.apply(init_weight)
72 |
73 | def forward(self, inputs):
74 | # bs,nf,dim
75 | inputs = inputs.permute(0, 2, 1)
76 | outputs = self.main(inputs).permute(0, 2, 1)
77 | # print(outputs.shape)
78 | return self.out_net(outputs)
79 |
80 |
81 | class MovementConvDecoder(nn.Module):
82 | def __init__(self, input_size, hidden_size, output_size):
83 | super(MovementConvDecoder, self).__init__()
84 | self.main = nn.Sequential(
85 | nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
86 | # nn.Dropout(0.2, inplace=True),
87 | nn.LeakyReLU(0.2, inplace=True),
88 | nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
89 | # nn.Dropout(0.2, inplace=True),
90 | nn.LeakyReLU(0.2, inplace=True),
91 | )
92 | self.out_net = nn.Linear(output_size, output_size)
93 |
94 | self.main.apply(init_weight)
95 | self.out_net.apply(init_weight)
96 |
97 | def forward(self, inputs):
98 | inputs = inputs.permute(0, 2, 1)
99 | outputs = self.main(inputs).permute(0, 2, 1)
100 | return self.out_net(outputs)
101 |
102 | class MotionEncoderBiGRUCo(nn.Module):
103 | def __init__(self, input_size, hidden_size, output_size, device):
104 | super(MotionEncoderBiGRUCo, self).__init__()
105 | self.device = device
106 |
107 | self.input_emb = nn.Linear(input_size, hidden_size)
108 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
109 | self.output_net = nn.Sequential(
110 | nn.Linear(hidden_size*2, hidden_size),
111 | nn.LayerNorm(hidden_size),
112 | nn.LeakyReLU(0.2, inplace=True),
113 | nn.Linear(hidden_size, output_size)
114 | )
115 |
116 | self.input_emb.apply(init_weight)
117 | self.output_net.apply(init_weight)
118 | self.hidden_size = hidden_size
119 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
120 |
121 | # input(batch_size, seq_len, dim)
122 | def forward(self, inputs, m_lens):
123 | num_samples = inputs.shape[0]
124 |
125 | input_embs = self.input_emb(inputs)
126 | hidden = self.hidden.repeat(1, num_samples, 1)
127 |
128 | cap_lens = m_lens.data.tolist()
129 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
130 |
131 | gru_seq, gru_last = self.gru(emb, hidden)
132 |
133 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
134 |
135 | return self.output_net(gru_last)
136 |
137 | class ContrastiveLoss(torch.nn.Module):
138 | """
139 | Contrastive loss function.
140 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
141 | """
142 | def __init__(self, margin=3.0):
143 | super(ContrastiveLoss, self).__init__()
144 | self.margin = margin
145 |
146 | def forward(self, output1, output2, label):
147 | euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
148 | loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
149 | (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
150 | return loss_contrastive
--------------------------------------------------------------------------------
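For orientation, here is a minimal usage sketch of the text/motion encoders and the contrastive loss defined above. All dimensions, sequence lengths, and the CPU device are illustrative assumptions, not the values used in HIMO training. Note that `pack_padded_sequence` expects lengths sorted in descending order.

```python
import torch
from src.feature_extractor.modules import (
    TextEncoderBiGRUCo, MotionEncoderBiGRUCo, ContrastiveLoss,
)

B, T_text, T_mot = 4, 12, 24
word_size, pos_size, motion_dim = 300, 15, 512   # illustrative feature sizes

text_enc = TextEncoderBiGRUCo(word_size, pos_size, hidden_size=512, output_size=512, device='cpu')
motion_enc = MotionEncoderBiGRUCo(input_size=motion_dim, hidden_size=1024, output_size=512, device='cpu')

word_embs = torch.randn(B, T_text, word_size)
pos_onehot = torch.randn(B, T_text, pos_size)
cap_lens = torch.tensor([12, 10, 8, 6])          # descending, as pack_padded_sequence requires
motions = torch.randn(B, T_mot, motion_dim)
m_lens = torch.tensor([24, 20, 18, 16])

text_feat = text_enc(word_embs, pos_onehot, cap_lens)   # [B, 512]
motion_feat = motion_enc(motions, m_lens)               # [B, 512]

# label 0 = matched text/motion pair (pulled together), 1 = mismatched (pushed apart)
labels = torch.tensor([[0.], [0.], [1.], [1.]])
loss = ContrastiveLoss(margin=3.0)(text_feat, motion_feat, labels)
```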
/src/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
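A small, self-contained sketch of how `LossSecondMomentResampler` behaves: until its loss history is full it falls back to uniform weights, and once warmed up it importance-samples timesteps in proportion to the square root of the second moment of their recent losses. The mocked `diffusion` object and the synthetic losses are assumptions for illustration; in training the history is normally filled through `update_with_local_losses`.

```python
import torch as th
from types import SimpleNamespace
from src.diffusion.resample import LossSecondMomentResampler

diffusion = SimpleNamespace(num_timesteps=100)   # only num_timesteps is needed here
sampler = LossSecondMomentResampler(diffusion, history_per_term=10, uniform_prob=0.001)

# Not warmed up yet: uniform weights, so this is plain uniform sampling over timesteps.
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))

# Feed back synthetic per-timestep losses (larger at small t) until the history is full.
for _ in range(sampler.history_per_term):
    ts = list(range(diffusion.num_timesteps))
    losses = [10.0 / (step + 1) for step in ts]
    sampler.update_with_all_losses(ts, losses)

w = sampler.weights()   # now proportional to sqrt(E[loss^2]) per timestep, plus a small uniform floor
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))  # early timesteps dominate
```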
/src/dataset/eval_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import h5py
5 | from tqdm import tqdm
6 | from torch.utils.data import Dataset
7 | import codecs as cs
8 | from torch.utils.data._utils.collate import default_collate
9 | from src.utils.word_vectorizer import WordVectorizer
10 |
11 | class Evaluation_Dataset(Dataset):
12 | def __init__(self,args,split='test',mode='gt'):
13 | self.args=args
14 | self.max_motion_length=self.args.max_motion_length
15 | self.split=split
16 | self.mode=mode
17 | self.obj=self.args.obj
18 | self.data_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),f'{split}.h5')
19 | self.data_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),f'{split}.h5')
20 | self.obj_bps_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),'object_bps.npz')
21 | self.obj_sampled_verts_path=osp.join(args.data_dir,'processed_{}'.format(self.obj),'sampled_obj_verts.npz')
22 | self.w_vectorizer=WordVectorizer(osp.join(self.args.data_dir,'glove'),'himo_vab')
23 |
24 | self.data=[]
25 | self.load_data()
26 |
27 | self.object_bps=dict(np.load(self.obj_bps_path,allow_pickle=True)) # 1,1024,3
28 | self.object_sampled_verts=dict(np.load(self.obj_sampled_verts_path,allow_pickle=True)) # 1024,3
29 |
30 | def __len__(self):
31 | return len(self.data)
32 |
33 | def __getitem__(self, idx):
34 | data=self.data[idx]
35 | motion,m_length,text_list=data['motion'],data['length'],data['text']
36 | caption,tokens=text_list['caption'],text_list['tokens']
37 |
38 | if len(tokens) < self.args.max_text_len:
39 | # pad with "unk"
40 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
41 | sent_len = len(tokens)
42 | tokens = tokens + ['unk/OTHER'] * (self.args.max_text_len + 2 - sent_len)
43 | else:
44 | # crop
45 | tokens = tokens[:self.args.max_text_len]
46 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
47 | sent_len = len(tokens)
48 | pos_one_hots = []
49 | word_embeddings = []
50 | for token in tokens:
51 | word_emb, pos_oh = self.w_vectorizer[token]
52 | pos_one_hots.append(pos_oh[None, :])
53 | word_embeddings.append(word_emb[None, :])
54 | pos_one_hots = np.concatenate(pos_one_hots, axis=0)
55 | word_embeddings = np.concatenate(word_embeddings, axis=0)
56 |
57 | if m_length<self.max_motion_length:
--------------------------------------------------------------------------------
/visualize/segment_viewer.py:
--------------------------------------------------------------------------------
106 | bef_button = imgui.button('<<')
107 | if bef_button:
108 | self.set_prev_record()
109 | imgui.same_line()
110 | next_button = imgui.button('>>')
111 | if next_button:
112 | self.set_next_record()
113 | imgui.same_line()
114 | tmp_idx = ''
115 | imgui.set_next_item_width(imgui.get_window_width() * 0.1)
116 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line()
117 | if is_go_to:
118 | try:
119 | self.go_to_idx = int(tmp_idx) - 1
120 | except:
121 | pass
122 | go_to_button = imgui.button('>>Go<<'); imgui.same_line()
123 | if go_to_button:
124 | self.set_goto_record(self.go_to_idx)
125 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks))
126 |
127 | imgui.text_wrapped(self.text_val)
128 | imgui.end()
129 |
130 | def set_prev_record(self):
131 | self.label_pid = (self.label_pid - 1) % self.total_tasks
132 | self.clear_one_sequence()
133 | self.load_one_sequence()
134 | self.scene.current_frame_id=0
135 |
136 | def set_next_record(self):
137 | self.label_pid = (self.label_pid + 1) % self.total_tasks
138 | self.clear_one_sequence()
139 | self.load_one_sequence()
140 | self.scene.current_frame_id=0
141 |
142 | def set_goto_record(self, idx):
143 | self.label_pid = int(idx) % self.total_tasks
144 | self.clear_one_sequence()
145 | self.load_one_sequence()
146 | self.scene.current_frame_id=0
147 |
148 | def get_label_file_list(self):
149 | for clip in sorted(os.listdir(self.clip_folder)):
150 | if not clip.startswith('.'):
151 | self.label_npy_list.append(os.path.join(self.clip_folder, clip))
152 |
153 | def load_text_from_file(self):
154 | self.text_val = ''
155 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4]
156 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')):
157 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f:
158 | for line in f.readlines():
159 | self.text_val += line
160 | self.text_val += '\n'
161 |
162 |
163 | def load_one_sequence(self):
164 | skel_file = self.label_npy_list[self.label_pid]
165 | clip_name=os.path.split(skel_file)[-1][:-4]
166 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy')
167 |
168 | # load skeleton
169 | skel_data = np.load(skel_file, allow_pickle=True)
170 | skel_data=skel_data[:,SELECTED_JOINTS]
171 | skel=Skeletons(
172 | joint_positions=skel_data,
173 | joint_connections=OPTITRACK_LIMBS,
174 | radius=0.005
175 | )
176 | self.scene.add(skel)
177 |
178 | # Load object
179 | object_pose=np.load(opj_pose_file, allow_pickle=True).item()
180 | meshes=[]
181 | for obj_name in object_pose.keys():
182 | obj_pose=object_pose[obj_name]
183 | obj_mesh=self.object_mesh[obj_name]
184 | verts, faces = obj_mesh.vertices, obj_mesh.faces
185 | mesh = Meshes(
186 | vertices=verts,
187 | faces=faces,
188 | name=obj_name,
189 | position=obj_pose['transl'],
190 | rotation=obj_pose['rot'],
191 | color= (0.3,0.3,0.5,1)
192 | )
193 | meshes.append(mesh)
194 | self.scene.add(*meshes)
195 |
196 | self.load_text_from_file()
197 |
198 |
199 | def clear_one_sequence(self):
200 | for x in self.scene.nodes.copy():
201 | if type(x) is Skeletons or type(x) is Meshes:
202 | self.scene.remove(x)
203 |
204 |
205 | if __name__=='__main__':
206 |
207 | viewer=Seg_Viewer()
208 | viewer.scene.fps=30
209 | viewer.playback_fps=30
210 | viewer.run()
--------------------------------------------------------------------------------
/visualize/skel_viewer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import glfw
5 | import imgui
6 | import numpy as np
7 | import trimesh
8 |
9 | from aitviewer.configuration import CONFIG as C
10 | from aitviewer.renderables.meshes import Meshes
11 | from aitviewer.viewer import Viewer
12 | from aitviewer.renderables.skeletons import Skeletons
13 |
14 | glfw.init()
15 | primary_monitor = glfw.get_primary_monitor()
16 | mode = glfw.get_video_mode(primary_monitor)
17 | width = mode.size.width
18 | height = mode.size.height
19 |
20 | C.update_conf({'window_width': width*0.9, 'window_height': height*0.9})
21 |
22 | OPTITRACK_LIMBS=[
23 | [0,1],[1,2],[2,3],[3,4],
24 | [0,5],[5,6],[6,7],[7,8],
25 | [0,9],[9,10],
26 | [10,11],[11,12],[12,13],[13,14],
27 | [14,15],[15,16],[16,17],[17,18],
28 | [14,19],[19,20],[20,21],[21,22],
29 | [14,23],[23,24],[24,25],[25,26],
30 | [14,27],[27,28],[28,29],[29,30],
31 | [14,31],[31,32],[32,33],[33,34],
32 | [10,35],[35,36],[36,37],[37,38],
33 | [38,39],[39,40],[40,41],[41,42],
34 | [38,43],[43,44],[44,45],[45,46],
35 | [38,47],[47,48],[48,49],[49,50],
36 | [38,51],[51,52],[52,53],[53,54],
37 | [38,55],[55,56],[56,57],[57,58],
38 | [10,59],[59,60]
39 | ]
40 |
41 | SELECTED_JOINTS=np.concatenate(
42 | [range(0,5),range(6,10),range(11,63)]
43 | )
44 |
45 | class Skel_Viewer(Viewer):
46 | title='HIMO Viewer for Skeleton'
47 |
48 | def __init__(self,**kwargs):
49 | super().__init__(**kwargs)
50 | self.gui_controls.update(
51 | {
52 | 'show_text':self.gui_show_text
53 | }
54 | )
55 | self._set_prev_record=self.wnd.keys.UP
56 | self._set_next_record=self.wnd.keys.DOWN
57 |
58 | # reset
59 | self.reset_for_himo()
60 | self.load_one_sequence()
61 |
62 | def reset_for_himo(self):
63 |
64 | self.text_val = ''
65 |
66 | self.clip_folder = os.path.join('data','joints')
67 | self.text_folder = os.path.join('data','text')
68 | self.object_pose_folder = os.path.join('data','object_pose')
69 | self.object_mesh_folder = os.path.join('data','object_mesh')
70 |
71 | # Pre-load object meshes
72 | self.object_mesh={}
73 | for obj in os.listdir(self.object_mesh_folder):
74 | if not obj.startswith('.'):
75 | obj_name = obj.split('.')[0]
76 | obj_path = os.path.join(self.object_mesh_folder, obj)
77 | mesh = trimesh.load(obj_path)
78 | self.object_mesh[obj_name] = mesh
79 |
80 | self.label_npy_list = []
81 | self.get_label_file_list()
82 | self.total_tasks = len(self.label_npy_list)
83 |
84 | self.label_pid = 0
85 | self.go_to_idx = 0
86 |
87 | def key_event(self, key, action, modifiers):
88 | if action==self.wnd.keys.ACTION_PRESS:
89 | if key==self._set_prev_record:
90 | self.set_prev_record()
91 | elif key==self._set_next_record:
92 | self.set_next_record()
93 | else:
94 | return super().key_event(key, action, modifiers)
95 | else:
96 | return super().key_event(key, action, modifiers)
97 |
98 | def gui_show_text(self):
99 | imgui.set_next_window_position(self.window_size[0] * 0.6, self.window_size[1]*0.25, imgui.FIRST_USE_EVER)
100 | imgui.set_next_window_size(self.window_size[0] * 0.35, self.window_size[1]*0.4, imgui.FIRST_USE_EVER)
101 | expanded, _ = imgui.begin("HIMO Text Descriptions", None)
102 |
103 | if expanded:
104 | npy_folder = self.label_npy_list[self.label_pid].split('/')[-1]
105 | imgui.text(str(npy_folder))
106 | bef_button = imgui.button('<<')
107 | if bef_button:
108 | self.set_prev_record()
109 | imgui.same_line()
110 | next_button = imgui.button('>>')
111 | if next_button:
112 | self.set_next_record()
113 | imgui.same_line()
114 | tmp_idx = ''
115 | imgui.set_next_item_width(imgui.get_window_width() * 0.1)
116 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line()
117 | if is_go_to:
118 | try:
119 | self.go_to_idx = int(tmp_idx) - 1
120 | except:
121 | pass
122 | go_to_button = imgui.button('>>Go<<'); imgui.same_line()
123 | if go_to_button:
124 | self.set_goto_record(self.go_to_idx)
125 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks))
126 |
127 | imgui.text_wrapped(self.text_val)
128 | imgui.end()
129 |
130 | def set_prev_record(self):
131 | self.label_pid = (self.label_pid - 1) % self.total_tasks
132 | self.clear_one_sequence()
133 | self.load_one_sequence()
134 | self.scene.current_frame_id=0
135 |
136 | def set_next_record(self):
137 | self.label_pid = (self.label_pid + 1) % self.total_tasks
138 | self.clear_one_sequence()
139 | self.load_one_sequence()
140 | self.scene.current_frame_id=0
141 |
142 | def set_goto_record(self, idx):
143 | self.label_pid = int(idx) % self.total_tasks
144 | self.clear_one_sequence()
145 | self.load_one_sequence()
146 | self.scene.current_frame_id=0
147 |
148 | def get_label_file_list(self):
149 | for clip in sorted(os.listdir(self.clip_folder)):
150 | if not clip.startswith('.'):
151 | self.label_npy_list.append(os.path.join(self.clip_folder, clip))
152 |
153 | def load_text_from_file(self):
154 | self.text_val = ''
155 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4]
156 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')):
157 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f:
158 | for line in f.readlines():
159 | self.text_val += line
160 | self.text_val += '\n'
161 |
162 |
163 | def load_one_sequence(self):
164 | skel_file = self.label_npy_list[self.label_pid]
165 | clip_name=os.path.split(skel_file)[-1][:-4]
166 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy')
167 |
168 | # load skeleton
169 | skel_data = np.load(skel_file, allow_pickle=True)
170 | skel_data=skel_data[:,SELECTED_JOINTS]
171 | skel=Skeletons(
172 | joint_positions=skel_data,
173 | joint_connections=OPTITRACK_LIMBS,
174 | radius=0.005
175 | )
176 | self.scene.add(skel)
177 |
178 | # Load object
179 | object_pose=np.load(opj_pose_file, allow_pickle=True).item()
180 | meshes=[]
181 | for obj_name in object_pose.keys():
182 | obj_pose=object_pose[obj_name]
183 | obj_mesh=self.object_mesh[obj_name]
184 | verts, faces = obj_mesh.vertices, obj_mesh.faces
185 | mesh = Meshes(
186 | vertices=verts,
187 | faces=faces,
188 | name=obj_name,
189 | position=obj_pose['transl'],
190 | rotation=obj_pose['rot'],
191 | color= (0.3,0.3,0.5,1)
192 | )
193 | meshes.append(mesh)
194 | self.scene.add(*meshes)
195 |
196 | self.load_text_from_file()
197 |
198 |
199 | def clear_one_sequence(self):
200 | for x in self.scene.nodes.copy():
201 | if type(x) is Skeletons or type(x) is Meshes:
202 | self.scene.remove(x)
203 |
204 |
205 | if __name__=='__main__':
206 |
207 | viewer=Skel_Viewer()
208 | viewer.scene.fps=30
209 | viewer.playback_fps=30
210 | viewer.run()
--------------------------------------------------------------------------------
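A hedged sketch of the per-clip skeleton and text inputs that `Skel_Viewer` reads from `data/`. The exact raw joint count is an assumption (`SELECTED_JOINTS` only requires at least 63 joints, 61 of which are kept), and `demo_clip` is a hypothetical clip name; the object pose and mesh files shared by all viewers are sketched after `smplx_viewer.py` below.

```python
import os
import numpy as np

# data/joints/<clip>.npy : [T, J, 3] joint positions; J >= 63 in the raw export (here 63, an assumption).
T, J = 30, 63
os.makedirs("data/joints", exist_ok=True)
np.save(os.path.join("data", "joints", "demo_clip.npy"), np.zeros((T, J, 3), dtype=np.float32))

# data/text/<clip>.txt : one or more lines of description shown in the GUI text panel.
os.makedirs("data/text", exist_ok=True)
with open(os.path.join("data", "text", "demo_clip.txt"), "w") as f:
    f.write("A placeholder description.\n")
```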
/visualize/smplx_viewer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import glfw
5 | import imgui
6 | import numpy as np
7 | import trimesh
8 |
9 | from aitviewer.configuration import CONFIG as C
10 | from aitviewer.renderables.meshes import Meshes
11 | from aitviewer.viewer import Viewer
12 | from aitviewer.models.smpl import SMPLLayer
13 | from aitviewer.renderables.smpl import SMPLSequence
14 |
15 | glfw.init()
16 | primary_monitor = glfw.get_primary_monitor()
17 | mode = glfw.get_video_mode(primary_monitor)
18 | width = mode.size.width
19 | height = mode.size.height
20 |
21 | C.update_conf({'window_width': width*0.9, 'window_height': height*0.9})
22 | C.update_conf({'smplx_models':'./body_models'})
23 |
24 | class SMPLX_Viewer(Viewer):
25 | title='HIMO Viewer for SMPL-X'
26 |
27 | def __init__(self,**kwargs):
28 | super().__init__(**kwargs)
29 | self.gui_controls.update(
30 | {
31 | 'show_text':self.gui_show_text
32 | }
33 | )
34 | self._set_prev_record=self.wnd.keys.UP
35 | self._set_next_record=self.wnd.keys.DOWN
36 |
37 | # reset
38 | self.reset_for_himo()
39 | self.load_one_sequence()
40 |
41 | def reset_for_himo(self):
42 |
43 | self.text_val = ''
44 |
45 | self.clip_folder = os.path.join('data','smplx')
46 | self.text_folder = os.path.join('data','text')
47 | self.object_pose_folder = os.path.join('data','object_pose')
48 | self.object_mesh_folder = os.path.join('data','object_mesh')
49 |
50 | # Pre-load object meshes
51 | self.object_mesh={}
52 | for obj in os.listdir(self.object_mesh_folder):
53 | if not obj.startswith('.'):
54 | obj_name = obj.split('.')[0]
55 | obj_path = os.path.join(self.object_mesh_folder, obj)
56 | mesh = trimesh.load(obj_path)
57 | self.object_mesh[obj_name] = mesh
58 |
59 | self.label_npy_list = []
60 | self.get_label_file_list()
61 | self.total_tasks = len(self.label_npy_list)
62 |
63 | self.label_pid = 0
64 | self.go_to_idx = 0
65 |
66 | def key_event(self, key, action, modifiers):
67 | if action==self.wnd.keys.ACTION_PRESS:
68 | if key==self._set_prev_record:
69 | self.set_prev_record()
70 | elif key==self._set_next_record:
71 | self.set_next_record()
72 | else:
73 | return super().key_event(key, action, modifiers)
74 | else:
75 | return super().key_event(key, action, modifiers)
76 |
77 | def gui_show_text(self):
78 | imgui.set_next_window_position(self.window_size[0] * 0.6, self.window_size[1]*0.25, imgui.FIRST_USE_EVER)
79 | imgui.set_next_window_size(self.window_size[0] * 0.35, self.window_size[1]*0.4, imgui.FIRST_USE_EVER)
80 | expanded, _ = imgui.begin("HIMO Text Descriptions", None)
81 |
82 | if expanded:
83 | npy_folder = self.label_npy_list[self.label_pid].split('/')[-1]
84 | imgui.text(str(npy_folder))
85 | bef_button = imgui.button('<<')
86 | if bef_button:
87 | self.set_prev_record()
88 | imgui.same_line()
89 | next_button = imgui.button('>>')
90 | if next_button:
91 | self.set_next_record()
92 | imgui.same_line()
93 | tmp_idx = ''
94 | imgui.set_next_item_width(imgui.get_window_width() * 0.1)
95 | is_go_to, tmp_idx = imgui.input_text('', tmp_idx); imgui.same_line()
96 | if is_go_to:
97 | try:
98 | self.go_to_idx = int(tmp_idx) - 1
99 | except:
100 | pass
101 | go_to_button = imgui.button('>>Go<<'); imgui.same_line()
102 | if go_to_button:
103 | self.set_goto_record(self.go_to_idx)
104 | imgui.text(str(self.label_pid+1) + '/' + str(self.total_tasks))
105 |
106 | imgui.text_wrapped(self.text_val)
107 | imgui.end()
108 |
109 | def set_prev_record(self):
110 | self.label_pid = (self.label_pid - 1) % self.total_tasks
111 | self.clear_one_sequence()
112 | self.load_one_sequence()
113 | self.scene.current_frame_id=0
114 |
115 | def set_next_record(self):
116 | self.label_pid = (self.label_pid + 1) % self.total_tasks
117 | self.clear_one_sequence()
118 | self.load_one_sequence()
119 | self.scene.current_frame_id=0
120 |
121 | def set_goto_record(self, idx):
122 | self.label_pid = int(idx) % self.total_tasks
123 | self.clear_one_sequence()
124 | self.load_one_sequence()
125 | self.scene.current_frame_id=0
126 |
127 | def get_label_file_list(self):
128 | for clip in sorted(os.listdir(self.clip_folder)):
129 | if not clip.startswith('.'):
130 | self.label_npy_list.append(os.path.join(self.clip_folder, clip))
131 |
132 | def load_text_from_file(self):
133 | self.text_val = ''
134 | clip_name = os.path.split(self.label_npy_list[self.label_pid])[-1][:-4]
135 | if os.path.exists(os.path.join(self.text_folder, clip_name+'.txt')):
136 | with open(os.path.join(self.text_folder, clip_name+'.txt'), 'r') as f:
137 | for line in f.readlines():
138 | self.text_val += line
139 | self.text_val += '\n'
140 |
141 |
142 | def load_one_sequence(self):
143 | smplx_file = self.label_npy_list[self.label_pid]
144 | clip_name=os.path.split(smplx_file)[-1][:-4]
145 | opj_pose_file=os.path.join(self.object_pose_folder, clip_name+'.npy')
146 |
147 | # load smplx
148 |
149 | smplx_params = np.load(smplx_file, allow_pickle=True)
150 | nf = smplx_params['body_pose'].shape[0]
151 |
152 | betas = smplx_params['betas']
153 | poses_root = smplx_params['global_orient']
154 | poses_body = smplx_params['body_pose'].reshape(nf,-1)
155 | poses_lhand = smplx_params['lhand_pose'].reshape(nf,-1)
156 | poses_rhand = smplx_params['rhand_pose'].reshape(nf,-1)
157 | transl = smplx_params['transl']
158 |
159 | # create body models
160 | smplx_layer = SMPLLayer(model_type='smplx',gender='neutral',num_betas=10,device=C.device)
161 |
162 | # create the SMPL-X sequence
163 | smplx_seq = SMPLSequence(poses_body=poses_body,
164 | smpl_layer=smplx_layer,
165 | poses_root=poses_root,
166 | betas=betas,
167 | trans=transl,
168 | poses_left_hand=poses_lhand,
169 | poses_right_hand=poses_rhand,
170 | device=C.device,
171 | )
172 |
173 | self.scene.add(smplx_seq)
174 |
175 | # Load object
176 | object_pose=np.load(opj_pose_file, allow_pickle=True).item()
177 | meshes=[]
178 | for obj_name in object_pose.keys():
179 | obj_pose=object_pose[obj_name]
180 | obj_mesh=self.object_mesh[obj_name]
181 | verts, faces = obj_mesh.vertices, obj_mesh.faces
182 | mesh = Meshes(
183 | vertices=verts,
184 | faces=faces,
185 | name=obj_name,
186 | position=obj_pose['transl'],
187 | rotation=obj_pose['rot'],
188 | color= (0.3,0.3,0.5,1)
189 | )
190 | meshes.append(mesh)
191 | self.scene.add(*meshes)
192 |
193 | self.load_text_from_file()
194 |
195 |
196 | def clear_one_sequence(self):
197 | for x in self.scene.nodes.copy():
198 | if type(x) is SMPLSequence or type(x) is SMPLLayer or type(x) is Meshes:
199 | self.scene.remove(x)
200 |
201 |
202 | if __name__=='__main__':
203 |
204 | viewer=SMPLX_Viewer()
205 | viewer.scene.fps=30
206 | viewer.playback_fps=30
207 | viewer.run()
--------------------------------------------------------------------------------
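A hedged sketch of the remaining inputs shared by the viewers: the per-clip SMPL-X parameter archive (its keys match what `SMPLX_Viewer.load_one_sequence()` reads above) and the object pose file, a pickled dict keyed by object name with `transl` and `rot` entries. The array shapes and the per-frame 3x3 rotation matrices are assumptions about the on-disk format, and `demo_clip`/`bowl` are hypothetical names.

```python
import os
import numpy as np

T = 30
os.makedirs("data/smplx", exist_ok=True)
os.makedirs("data/object_pose", exist_ok=True)

# data/smplx/<clip>.npz with the keys read by SMPLX_Viewer.load_one_sequence()
np.savez(
    os.path.join("data", "smplx", "demo_clip.npz"),
    betas=np.zeros(10, dtype=np.float32),
    global_orient=np.zeros((T, 3), dtype=np.float32),
    body_pose=np.zeros((T, 21, 3), dtype=np.float32),
    lhand_pose=np.zeros((T, 15, 3), dtype=np.float32),
    rhand_pose=np.zeros((T, 15, 3), dtype=np.float32),
    transl=np.zeros((T, 3), dtype=np.float32),
)

# data/object_pose/<clip>.npy : {object_name: {"transl": [T, 3], "rot": [T, 3, 3]}} (shapes assumed)
object_pose = {
    "bowl": {
        "transl": np.zeros((T, 3), dtype=np.float32),
        "rot": np.tile(np.eye(3, dtype=np.float32), (T, 1, 1)),
    }
}
np.save(os.path.join("data", "object_pose", "demo_clip.npy"), object_pose)
```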
/src/diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from src.diffusion import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
203 | opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
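The core of `MixedPrecisionTrainer` above is dynamic loss scaling: the loss is multiplied by `2**lg_loss_scale` before `backward()`, the scale drops by one and the step is skipped on overflow, and otherwise the gradients are un-scaled and the log scale grows slowly by `fp16_scale_growth`. A tiny self-contained sketch of just that arithmetic, with made-up numbers standing in for the real gradient norms:

```python
lg_loss_scale, growth = 20.0, 1e-3   # mirrors INITIAL_LOG_LOSS_SCALE and the fp16_scale_growth default

def check_overflow(value):
    # same test as check_overflow() above: +/-inf or NaN
    return value in (float("inf"), -float("inf")) or value != value

for raw_grad_norm in [0.5, float("inf"), 0.5]:              # the middle step simulates an fp16 overflow
    scaled_grad_norm = raw_grad_norm * 2 ** lg_loss_scale    # gradients carry the loss scale after backward()
    if check_overflow(scaled_grad_norm):
        lg_loss_scale -= 1                                   # shrink the scale and skip the optimizer step
        print(f"overflow, lg_loss_scale -> {lg_loss_scale:.3f}")
        continue
    usable_grad_norm = scaled_grad_norm / 2 ** lg_loss_scale  # un-scale before opt.step()
    lg_loss_scale += growth                                   # grow the scale slowly after each good step
    print(f"step ok, grad norm {usable_grad_norm:.3f}, lg_loss_scale {lg_loss_scale:.3f}")
```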
/src/utils/rotation_conversion.py:
--------------------------------------------------------------------------------
1 | """
2 | This script contains functions for converting between different rotation representations.
3 | Including:
4 | - Euler angles (euler) default to (XYZ,intrinsic)
5 | - Rotation matrices (rot)
6 | - Quaternions (quat)
7 | - Axis-angle (aa)
8 | - 6D representation (6d)
9 | We also provide numpy and torch versions of these functions.
10 | Note that all functions are in batch mode.
11 | """
12 | import torch
13 | import numpy as np
14 | from scipy.spatial.transform import Rotation as R
15 | from pytorch3d.transforms.rotation_conversions import (
16 | quaternion_to_axis_angle,quaternion_to_matrix,
17 | axis_angle_to_matrix,axis_angle_to_quaternion,
18 | matrix_to_axis_angle,matrix_to_quaternion,
19 | euler_angles_to_matrix,matrix_to_euler_angles,
20 | rotation_6d_to_matrix,matrix_to_rotation_6d
21 | )
22 |
23 | #--------------------Numpy Version--------------------------#
24 | def euler2rot_numpy(euler,degrees=False):
25 | """
26 | euler:[B,3] (XYZ,intrinsic)
27 | degrees are False if they are radians
28 | return: [B,3,3]
29 | """
30 | assert isinstance(euler, np.ndarray)
31 | ori_shape = euler.shape[:-1]
32 | rots = np.reshape(euler, (-1, 3))
33 | rots = R.as_matrix(R.from_euler("XYZ", rots, degrees=degrees))
34 | rot = np.reshape(rots, ori_shape + (3, 3))
35 | return rot
36 |
37 | def rot2euler_numpy(rot,degrees=False):
38 | """
39 | rot:[B,3,3]
40 | return: [B,3]
41 | """
42 | assert isinstance(rot, np.ndarray)
43 | ori_shape = rot.shape[:-2]
44 | rots = np.reshape(rot, (-1, 3, 3))
45 | rots = R.as_euler(R.from_matrix(rots), "XYZ", degrees=degrees)
46 | euler = np.reshape(rots, ori_shape + (3,))
47 | return euler
48 |
49 | def euler2quat_numpy(euler,degrees=False):
50 | """
51 | euler:[B,3]
52 | return [B,4]
53 | """
54 | assert isinstance(euler,np.ndarray)
55 | ori_shape=euler.shape[:-1]
56 | rots=np.reshape(euler,(-1,3))
57 | quats=R.as_quat(R.from_euler("XYZ",rots,degrees=degrees))
58 | quat=np.reshape(quats,ori_shape+(4,))
59 | return quat
60 |
61 | def quat2euler_numpy(quat,degrees=False):
62 | """
63 | quat:[B,4]
64 | return [B,3]
65 | """
66 | assert isinstance(quat,np.ndarray)
67 | ori_shape=quat.shape[:-1]
68 | rots=np.reshape(quat,(-1,4))
69 | eulers=R.as_euler(R.from_quat(rots),"XYZ",degrees=degrees)
70 | euler=np.reshape(eulers,ori_shape+(3,))
71 | return euler
72 |
73 | def euler2aa_numpy(euler,degrees=False):
74 | """
75 | euler:[B,3]
76 | return: [B,3]
77 | """
78 | assert isinstance(euler, np.ndarray)
79 | ori_shape = euler.shape[:-1]
80 | rots = np.reshape(euler, (-1, 3))
81 | aas = R.as_rotvec(R.from_euler("XYZ", rots, degrees=degrees))
82 | rotation_vectors = np.reshape(aas, ori_shape + (3,))
83 | return rotation_vectors
84 |
85 | def aa2euler_numpy(aa,degrees=False):
86 | """
87 | aa:[B,3]
88 | return [B,3]
89 | """
90 | assert isinstance(aa, np.ndarray)
91 | ori_shape = aa.shape[:-1]
92 | aas = np.reshape(aa, (-1, 3))
93 | rots = R.as_euler(R.from_rotvec(aas), "XYZ", degrees=degrees)
94 | euler_angles = np.reshape(rots, ori_shape + (3,))
95 | return euler_angles
96 |
97 | def rot2quat_numpy(rot):
98 | """
99 | rot:[B,3,3]
100 | return [B,4]
101 | """
102 | return euler2quat_numpy(rot2euler_numpy(rot))
103 |
104 | def quat2rot_numpy(quat):
105 | """
106 | quat:[B,4] (x,y,z,w), scipy convention
107 | return: [B,3,3]
108 | """
109 | return euler2rot_numpy(quat2euler_numpy(quat))
110 |
111 | def rot2aa_numpy(rot):
112 | """
113 | rot:[B,3,3]
114 | return:[B,3]
115 | """
116 | assert isinstance(rot, np.ndarray)
117 | ori_shape = rot.shape[:-2]
118 | rots = np.reshape(rot, (-1, 3, 3))
119 | aas = R.as_rotvec(R.from_matrix(rots))
120 | rotation_vectors = np.reshape(aas, ori_shape + (3,))
121 | return rotation_vectors
122 |
123 | def aa2rot_numpy(aa):
124 | """
125 | aa:[B,3]
126 | Rodrigues formula
127 | return: [B,3,3]
128 | """
129 | assert isinstance(aa,np.ndarray)
130 | ori_shape=aa.shape[:-1]
131 | aas=np.reshape(aa,(-1,3))
132 | rots=R.as_matrix(R.from_rotvec(aas))
133 | rot_mat=np.reshape(rots,ori_shape+(3,3))
134 | return rot_mat
135 |
136 | def quat2aa_numpy(quat):
137 | """
138 | quat:[B,4]
139 | return [B,3]
140 | """
141 | return euler2aa_numpy(quat2euler_numpy(quat))
142 |
143 | def aa2quat_numpy(aa):
144 | """
145 | aa:[B,3]
146 | return [B,4]
147 | """
148 | return euler2quat_numpy(aa2euler_numpy(aa))
149 |
150 | def rot2sixd_numpy(rot):
151 | """
152 | rot:[B,3,3]
153 | return [B,6]
154 | """
155 | assert isinstance(rot,np.ndarray)
156 | ori_shape=rot.shape[:-2]
157 | return rot[...,:2,:].copy().reshape(ori_shape+(6,))
158 |
159 | def sixd2rot_numpy(sixd):
160 | """
161 | sixd:[B,6]
162 | return [B,3,3]
163 | """
164 | assert isinstance(sixd,np.ndarray)
165 | a1,a2=sixd[...,:3],sixd[...,3:]
166 | b1=a1/np.linalg.norm(a1,axis=-1,keepdims=True)
167 | b2=a2-(b1*a2).sum(-1,keepdims=True)*b1
168 | b2=b2/np.linalg.norm(b2,axis=-1,keepdims=True)
169 | b3=np.cross(b1,b2,axis=-1)
170 | return np.stack([b1,b2,b3],axis=-2)
171 |
172 | def rpy2rot_numpy(rpy,degrees=False):
173 | """
174 | rpy: [B,3] (ZYX,intrinsic)
175 | return [B,3,3]
176 | """
177 | assert isinstance(rpy, np.ndarray)
178 | ori_shape = rpy.shape[:-1]
179 | rots = np.reshape(rpy, (-1, 3))
180 | rots = R.as_matrix(R.from_euler("ZYX", rots, degrees=degrees))
181 | rotation_matrices = np.reshape(rots, ori_shape + (3, 3))
182 | return rotation_matrices
183 |
184 | #--------------------Pytorch Version--------------------------#
185 | def euler2rot_torch(euler,degrees=False):
186 | """
187 | euler [B,3] (XYZ,intrinsic)
188 | degrees are False if they are radians
189 | """
190 | if degrees:
191 | euler_rad=torch.deg2rad(euler)
192 | return euler_angles_to_matrix(euler_rad,"XYZ")
193 | else:
194 | return euler_angles_to_matrix(euler,"XYZ")
195 |
196 | def rot2euler_torch(rot,degrees=False):
197 | """
198 | rot:[B,3,3]
199 | return: [B,3]
200 | """
201 | if degrees:
202 | euler_rad=matrix_to_euler_angles(rot,"XYZ")
203 | return torch.rad2deg(euler_rad)
204 | else:
205 | return matrix_to_euler_angles(rot,"XYZ")
206 |
207 | def euler2quat_torch(euler,degrees=False):
208 | """
209 | euler:[B,3]
210 | return [B,4]
211 | """
212 | if degrees:
213 | euler_rad=torch.deg2rad(euler)
214 | return matrix_to_quaternion(euler_angles_to_matrix(euler_rad,"XYZ"))
215 | else:
216 | return matrix_to_quaternion(euler_angles_to_matrix(euler,"XYZ"))
217 |
218 | def quat2euler_torch(quat,degrees=False):
219 | """
220 | quat:[B,4]
221 | return [B,3]
222 | """
223 | if degrees:
224 | euler_rad=quaternion_to_matrix(quat)
225 | return torch.rad2deg(matrix_to_euler_angles(euler_rad,"XYZ"))
226 | else:
227 | return matrix_to_euler_angles(quaternion_to_matrix(quat),"XYZ")
228 |
229 | def euler2aa_torch(euler,degrees=False):
230 | """
231 | euler:[B,3]
232 | return: [B,3]
233 | """
234 | if degrees:
235 | euler_rad=torch.deg2rad(euler)
236 | return matrix_to_axis_angle(euler_angles_to_matrix(euler_rad,"XYZ"))
237 | else:
238 | return matrix_to_axis_angle(euler_angles_to_matrix(euler,"XYZ"))
239 |
240 | def aa2euler_torch(aa,degrees=False):
241 | """
242 | aa:[B,3]
243 | return [B,3]
244 | """
245 | if degrees:
246 | euler_rad=axis_angle_to_matrix(aa)
247 | return torch.rad2deg(matrix_to_euler_angles(euler_rad,"XYZ"))
248 | else:
249 | return matrix_to_euler_angles(axis_angle_to_matrix(aa),"XYZ")
250 |
251 | def rot2quat_torch(rot):
252 | """
253 | rot:[B,3,3]
254 | return [B,4]
255 | """
256 | return matrix_to_quaternion(rot)
257 |
258 | def quat2rot_torch(quat):
259 | """
260 | quat:[B,4] (w,x,y,z)
261 | return: [B,3,3]
262 | """
263 | return quaternion_to_matrix(quat)
264 |
265 | def rot2aa_torch(rot):
266 | """
267 | rot:[B,3,3]
268 | return:[B,3]
269 | """
270 | return matrix_to_axis_angle(rot)
271 |
272 | def aa2rot_torch(aa):
273 | """
274 | aa:[B,3]
275 | Rodrigues formula
276 | return: [B,3,3]
277 | """
278 | return axis_angle_to_matrix(aa)
279 |
280 | def quat2aa_torch(quat):
281 | """
282 | quat:[B,4]
283 | return [B,3]
284 | """
285 | return quaternion_to_axis_angle(quat)
286 |
287 | def aa2quat_torch(aa):
288 | """
289 | aa:[B,3]
290 | return [B,4]
291 | """
292 | return axis_angle_to_quaternion(aa)
293 |
294 | def rot2sixd_torch(rot):
295 | """
296 | rot:[B,3,3]
297 | return [B,6]
298 | """
299 | return matrix_to_rotation_6d(rot)
300 |
301 | def sixd2rot_torch(sixd):
302 | """
303 | sixd:[B,6]
304 | return [B,3,3]
305 | """
306 | return rotation_6d_to_matrix(sixd)
307 |
308 | if __name__=='__main__':
309 | rot=np.eye(3)[None,...].repeat(2,axis=0)
310 | sixd=rot2sixd_numpy(rot)
311 | new_rot=sixd2rot_numpy(sixd)
312 | print(new_rot)
--------------------------------------------------------------------------------
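A quick sanity-check round-trip through the numpy converters above (axis-angle -> matrix -> 6D -> matrix -> axis-angle). Importing the module also pulls in `pytorch3d`, so this assumes that dependency is installed.

```python
import numpy as np
from src.utils.rotation_conversion import (
    aa2rot_numpy, rot2aa_numpy, rot2sixd_numpy, sixd2rot_numpy,
)

aa = np.array([[0.0, 0.0, np.pi / 2]])   # batch of one: 90 degrees about Z
rot = aa2rot_numpy(aa)                    # [1, 3, 3]
sixd = rot2sixd_numpy(rot)                # [1, 6], the first two rows of the matrix
rot_rec = sixd2rot_numpy(sixd)            # [1, 3, 3], Gram-Schmidt reconstruction
aa_rec = rot2aa_numpy(rot_rec)            # [1, 3]

assert np.allclose(rot, rot_rec, atol=1e-6)
assert np.allclose(aa, aa_rec, atol=1e-6)
```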
/src/model/net_2o.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from loguru import logger
6 | import clip
7 | from src.model.blocks import TransformerBlock
8 |
9 | class NET_2O(nn.Module):
10 | def __init__(self,
11 | n_feats=(52*3+52*6+3+2*9),
12 | clip_dim=512,
13 | latent_dim=512,
14 | ff_size=1024,
15 | num_layers=8,
16 | num_heads=4,
17 | dropout=0.1,
18 | ablation=None,
19 | activation="gelu",**kwargs):
20 | super().__init__()
21 |
22 | self.n_feats = n_feats
23 | self.clip_dim = clip_dim
24 | self.latent_dim = latent_dim
25 | self.ff_size = ff_size
26 | self.num_layers = num_layers
27 | self.num_heads = num_heads
28 | self.dropout = dropout
29 | self.ablation = ablation
30 | self.activation = activation
31 |
32 | self.cond_mask_prob=kwargs.get('cond_mask_prob',0.1)
33 |
34 | # clip and text embedder
35 | self.embed_text=nn.Linear(self.clip_dim,self.latent_dim)
36 | self.clip_version='ViT-B/32'
37 | self.clip_model=self.load_and_freeze_clip(self.clip_version)
38 |
39 | # object_geometry embedder
40 | self.embed_obj_bps=nn.Linear(1024*3,self.latent_dim)
41 | # object init state embedder
42 | self.embed_obj_pose=nn.Linear(2*self.latent_dim+2*9+2*9,self.latent_dim)
43 |
44 | # human embedder
45 | self.embed_human_pose=nn.Linear(2*(52*3+52*6+3),self.latent_dim)
46 |
47 | # position encoding
48 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=self.dropout)
49 |
50 | # TODO: unshared transformer layers for humans and objects; they fuse features in the middle layers
51 | seqTransEncoderLayer=nn.TransformerEncoderLayer(
52 | d_model=self.latent_dim,
53 | nhead=self.num_heads,
54 | dim_feedforward=self.ff_size,
55 | dropout=self.dropout,
56 | activation=self.activation,
57 | )
58 | # # object transformer layers
59 | # self.object_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers)
60 | # # human transformer encoder
61 | # self.human_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers)
62 |
63 | # # Mutal cross attention
64 | # self.communication_module=nn.ModuleList()
65 | # for i in range(8):
66 | # self.communication_module.append(MutalCrossAttentionBlock(self.latent_dim,self.num_heads,self.ff_size,self.dropout))
67 | self.obj_blocks=nn.ModuleList()
68 | self.human_blocks=nn.ModuleList()
69 | for i in range(self.num_layers):
70 | self.obj_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation))
71 | self.human_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation))
72 |
73 |
74 | # embed the timestep
75 | self.embed_timestep=TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
76 |
77 | # object output process
78 | self.obj_output_process=nn.Linear(self.latent_dim,2*9)
79 | # human motion output process
80 | self.human_output_process=nn.Linear(self.latent_dim,52*3+52*6+3)
81 |
82 | def forward(self,x,timesteps,y=None):
83 | bs,nframes,n_feats=x.shape
84 | emb=self.embed_timestep(timesteps) # [1, bs, latent_dim]
85 |
86 | enc_text=self.encode_text(y['text'])
87 | emb+=self.mask_cond(enc_text) # 1,bs,latent_dim
88 |
89 | x=x.permute((1,0,2)) # nframes,bs,nfeats
90 | human_x,obj_x=torch.split(x,[52*3+52*6+3,2*9],dim=-1) # [nframes,bs,52*3+52*6+3],[nframes,bs,2*9]
91 |
92 | # encode object geometry
93 | obj1_bps,obj2_bps=y['obj1_bps'].reshape(bs,-1),y['obj2_bps'].reshape(bs,-1) # [b,1024,3]
94 | obj1_bps_emb=self.embed_obj_bps(obj1_bps) # [b,latent_dim]
95 | obj2_bps_emb=self.embed_obj_bps(obj2_bps) # [b,latent_dim]
96 | obj_geo_emb=torch.concat([obj1_bps_emb,obj2_bps_emb],dim=-1).unsqueeze(0).repeat((nframes,1,1)) # [nf,b,2*latent_dim]
97 |
98 | # init_state: mask the other frames by padding zeros
99 | init_state=y['init_state'].unsqueeze(0) # [1,b,52*3+52*6+3+2*9]
100 | padded_zeros=torch.zeros((nframes-1,bs,52*3+52*6+3+2*9),device=init_state.device)
101 | init_state=torch.concat([init_state,padded_zeros],dim=0) # [nf,b,52*3+52*6+3+2*9]
102 |
103 | # separate the object and human init state
104 | human_init_state=init_state[:,:,:52*3+52*6+3] # [nf,b,52*3+52*6+3]
105 | obj_init_state=init_state[:,:,52*3+52*6+3:] # [nf,b,2*9]
106 |
107 | # Object branch
108 | obj_emb=self.embed_obj_pose(torch.concat([obj_geo_emb,obj_init_state,obj_x],dim=-1)) # nframes,bs,latent_dim
109 | obj_seq_prev=self.sequence_pos_encoder(obj_emb) # [nf,bs,latent_dim]
110 |
111 |
112 | # Human branch
113 | human_emb=self.embed_human_pose(torch.concat([human_init_state,human_x],dim=-1)) # nframes,bs,latent_dim
114 | human_seq_prev=self.sequence_pos_encoder(human_emb) # [nf,bs,latent_dim]
115 |
116 | mask=y['mask'].squeeze(1).squeeze(1) # [bs,nf]
117 | key_padding_mask=~mask # [bs,nf]
118 |
119 | obj_seq_prev=obj_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim]
120 | human_seq_prev=human_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim]
121 | emb=emb.squeeze(0) # [bs,latent_dim]
122 | for i in range(self.num_layers):
123 | obj_seq=self.obj_blocks[i](obj_seq_prev,human_seq_prev,emb, key_padding_mask=key_padding_mask)
124 | human_seq=self.human_blocks[i](human_seq_prev,obj_seq_prev,emb, key_padding_mask=key_padding_mask)
125 | obj_seq_prev=obj_seq
126 | human_seq_prev=human_seq
127 | obj_seq=obj_seq.permute((1,0,2)) # [nf,bs,latent_dim]
128 | human_seq=human_seq.permute((1,0,2))
129 |
130 |
131 | obj_output=self.obj_output_process(obj_seq) # [nf,bs,2*9]
132 | human_output=self.human_output_process(human_seq) # [nf,bs,52*3+52*6+3]
133 |
134 | output=torch.concat([human_output,obj_output],dim=-1) # [nf,bs,52*3+52*6+3+2*9]
135 |
136 | return output.permute((1,0,2))
137 |
138 | def encode_text(self,raw_text):
139 | # raw_text - list (batch_size length) of strings with input text prompts
140 | device = next(self.parameters()).device
141 | max_text_len = 40
142 | if max_text_len is not None:
143 | default_context_length = 77
144 | context_length = max_text_len + 2 # start_token + max_text_len + end_token
145 | assert context_length < default_context_length
146 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate
147 | # print('texts', texts.shape)
148 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device)
149 | texts = torch.cat([texts, zero_pad], dim=1)
150 | # print('texts after pad', texts.shape, texts)
151 | else:
152 | texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate
153 | return self.clip_model.encode_text(texts).float()
154 |
155 | def mask_cond(self, cond, force_mask=False):
156 | bs, d = cond.shape
157 | if force_mask:
158 | return torch.zeros_like(cond)
159 | elif self.training and self.cond_mask_prob > 0.:
160 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond
161 | return cond * (1. - mask)
162 | else:
163 | return cond
164 |
165 | def load_and_freeze_clip(self, clip_version):
166 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
167 | jit=False) # Must set jit=False for training
168 | clip.model.convert_weights(
169 | clip_model) # arguably unnecessary: CLIP weights are already float16 by default
170 |
171 | # Freeze CLIP weights
172 | clip_model.eval()
173 | for p in clip_model.parameters():
174 | p.requires_grad = False
175 |
176 | return clip_model
177 |
178 | class PositionalEncoding(nn.Module):
179 | def __init__(self, d_model, dropout=0.1, max_len=5000):
180 | super(PositionalEncoding, self).__init__()
181 | self.dropout = nn.Dropout(p=dropout)
182 |
183 | pe = torch.zeros(max_len, d_model)
184 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
185 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
186 | pe[:, 0::2] = torch.sin(position * div_term)
187 | pe[:, 1::2] = torch.cos(position * div_term)
188 | pe = pe.unsqueeze(0).transpose(0, 1)
189 |
190 | self.register_buffer('pe', pe)
191 |
192 | def forward(self, x):
193 | # add the fixed sinusoidal positional encoding up to the sequence length, then apply dropout
194 | x = x + self.pe[:x.shape[0], :]
195 | return self.dropout(x)
196 |
197 |
198 | class TimestepEmbedder(nn.Module):
199 | def __init__(self, latent_dim, sequence_pos_encoder):
200 | super().__init__()
201 | self.latent_dim = latent_dim
202 | self.sequence_pos_encoder = sequence_pos_encoder
203 |
204 | time_embed_dim = self.latent_dim
205 | self.time_embed = nn.Sequential(
206 | nn.Linear(self.latent_dim, time_embed_dim),
207 | nn.SiLU(),
208 | nn.Linear(time_embed_dim, time_embed_dim),
209 | )
210 |
211 | def forward(self, timesteps):
212 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
213 |
--------------------------------------------------------------------------------
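
The `mask_cond` method in net_2o.py above implements the condition dropout used for classifier-free guidance training: with probability `cond_mask_prob`, a sample's text embedding is replaced by an all-zero "null" condition so the network also learns an unconditional denoiser. A minimal standalone sketch of that masking, assuming only PyTorch; the name `drop_condition` and the numbers are illustrative, not part of the repository:

import torch

def drop_condition(cond: torch.Tensor, cond_mask_prob: float = 0.1) -> torch.Tensor:
    # cond: [bs, latent_dim] text embeddings; rows drawn as 1 are zeroed out (null condition)
    bs = cond.shape[0]
    mask = torch.bernoulli(torch.full((bs, 1), cond_mask_prob))  # 1 -> use null cond
    return cond * (1.0 - mask)

cond = torch.randn(8, 512)            # e.g. CLIP text embeddings
dropped = drop_condition(cond, 0.25)  # on average 2 of the 8 rows become all-zero
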
/src/model/net_3o.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from loguru import logger
6 | import clip
7 | from src.model.blocks import TransformerBlock
8 |
9 | class NET_3O(nn.Module):
10 | def __init__(self,
11 | n_feats=(52*3+52*6+3+3*9),
12 | clip_dim=512,
13 | latent_dim=512,
14 | ff_size=1024,
15 | num_layers=8,
16 | num_heads=4,
17 | dropout=0.1,
18 | ablation=None,
19 | activation="gelu",**kwargs):
20 | super().__init__()
21 |
22 | self.n_feats = n_feats
23 | self.clip_dim = clip_dim
24 | self.latent_dim = latent_dim
25 | self.ff_size = ff_size
26 | self.num_layers = num_layers
27 | self.num_heads = num_heads
28 | self.dropout = dropout
29 | self.ablation = ablation
30 | self.activation = activation
31 |
32 | self.cond_mask_prob=kwargs.get('cond_mask_prob',0.1)
33 |
34 | # clip and text embedder
35 | self.embed_text=nn.Linear(self.clip_dim,self.latent_dim)
36 | self.clip_version='ViT-B/32'
37 | self.clip_model=self.load_and_freeze_clip(self.clip_version)
38 |
39 | # object_geometry embedder
40 | self.embed_obj_bps=nn.Linear(1024*3,self.latent_dim)
41 | # object pose embedder (geometry, init state and noisy pose)
42 | self.embed_obj_pose=nn.Linear(3*self.latent_dim+3*9+3*9,self.latent_dim)
43 |
44 | # human embedder
45 | self.embed_human_pose=nn.Linear(2*(52*3+52*6+3),self.latent_dim)
46 |
47 | # position encoding
48 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=self.dropout)
49 |
50 | # TODO: unshared transformer layers for human and objects; they fuse features in the middle layers
51 | seqTransEncoderLayer=nn.TransformerEncoderLayer(
52 | d_model=self.latent_dim,
53 | nhead=self.num_heads,
54 | dim_feedforward=self.ff_size,
55 | dropout=self.dropout,
56 | activation=self.activation,
57 | )
58 | # # object transformer layers
59 | # self.object_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers)
60 | # # human transformer encoder
61 | # self.human_trans_encoder=nn.TransformerEncoder(seqTransEncoderLayer,num_layers=self.num_layers)
62 |
63 | # # Mutual cross attention
64 | # self.communication_module=nn.ModuleList()
65 | # for i in range(8):
66 | # self.communication_module.append(MutalCrossAttentionBlock(self.latent_dim,self.num_heads,self.ff_size,self.dropout))
67 | self.obj_blocks=nn.ModuleList()
68 | self.human_blocks=nn.ModuleList()
69 | for i in range(self.num_layers):
70 | self.obj_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation))
71 | self.human_blocks.append(TransformerBlock(self.latent_dim, self.num_heads, self.ff_size, self.dropout, self.activation))
72 |
73 |
74 | # embed the timestep
75 | self.embed_timestep=TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
76 |
77 | # object output process
78 | self.obj_output_process=nn.Linear(self.latent_dim,3*9)
79 | # human motion output process
80 | self.human_output_process=nn.Linear(self.latent_dim,52*3+52*6+3)
81 |
82 | def forward(self,x,timesteps,y=None):
83 | bs,nframes,n_feats=x.shape
84 | emb=self.embed_timestep(timesteps) # [1, bs, latent_dim]
85 |
86 | enc_text=self.encode_text(y['text'])
87 | emb+=self.mask_cond(enc_text) # 1,bs,latent_dim
88 |
89 | x=x.permute((1,0,2)) # nframes,bs,nfeats
90 | human_x,obj_x=torch.split(x,[52*3+52*6+3,3*9],dim=-1) # [nframes,bs,52*3+52*6+3],[nframes,bs,3*9]
91 |
92 | # encode object geometry
93 | obj1_bps,obj2_bps,obj3_bps=y['obj1_bps'].reshape(bs,-1),y['obj2_bps'].reshape(bs,-1),y['obj3_bps'].reshape(bs,-1) # [bs,1024*3] each (flattened BPS features)
94 | obj1_bps_emb=self.embed_obj_bps(obj1_bps) # [b,latent_dim]
95 | obj2_bps_emb=self.embed_obj_bps(obj2_bps) # [b,latent_dim]
96 | obj3_bps_emb=self.embed_obj_bps(obj3_bps) # [b,latent_dim]
97 | obj_geo_emb=torch.concat([obj1_bps_emb,obj2_bps_emb,obj3_bps_emb],dim=-1).unsqueeze(0).repeat((nframes,1,1)) # [nf,b,3*latent_dim]
98 |
99 | # init_state: keep the first frame's state and zero-pad the remaining frames
100 | init_state=y['init_state'].unsqueeze(0) # [1,b,52*3+52*6+3+3*9]
101 | padded_zeros=torch.zeros((nframes-1,bs,52*3+52*6+3+3*9),device=init_state.device)
102 | init_state=torch.concat([init_state,padded_zeros],dim=0) # [nf,b,52*3+52*6+3+3*9]
103 |
104 | # separate the object and human init state
105 | human_init_state=init_state[:,:,:52*3+52*6+3] # [nf,b,52*3+52*6+3]
106 | obj_init_state=init_state[:,:,52*3+52*6+3:] # [nf,b,3*9]
107 |
108 | # Object branch
109 | obj_emb=self.embed_obj_pose(torch.concat([obj_geo_emb,obj_init_state,obj_x],dim=-1)) # nframes,bs,latent_dim
110 | obj_seq_prev=self.sequence_pos_encoder(obj_emb) # [nf,bs,latent_dim]
111 |
112 |
113 | # Human branch
114 | human_emb=self.embed_human_pose(torch.concat([human_init_state,human_x],dim=-1)) # nframes,bs,latent_dim
115 | human_seq_prev=self.sequence_pos_encoder(human_emb) # [nf,bs,latent_dim]
116 |
117 | mask=y['mask'].squeeze(1).squeeze(1) # [bs,nf]
118 | key_padding_mask=~mask # [bs,nf]
119 |
120 | obj_seq_prev=obj_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim]
121 | human_seq_prev=human_seq_prev.permute((1,0,2)) # [bs,nf,latent_dim]
122 | emb=emb.squeeze(0) # [bs,latent_dim]
123 | for i in range(self.num_layers):
124 | obj_seq=self.obj_blocks[i](obj_seq_prev,human_seq_prev,emb, key_padding_mask=key_padding_mask)
125 | human_seq=self.human_blocks[i](human_seq_prev,obj_seq_prev,emb, key_padding_mask=key_padding_mask)
126 | obj_seq_prev=obj_seq
127 | human_seq_prev=human_seq
128 | obj_seq=obj_seq.permute((1,0,2)) # [nf,bs,latent_dim]
129 | human_seq=human_seq.permute((1,0,2))
130 |
131 |
132 | obj_output=self.obj_output_process(obj_seq) # [nf,bs,3*9]
133 | human_output=self.human_output_process(human_seq) # [nf,bs,52*3+52*6+3]
134 |
135 | output=torch.concat([human_output,obj_output],dim=-1) # [nf,bs,52*3+52*6+3+3*9]
136 |
137 | return output.permute((1,0,2))
138 |
139 | def encode_text(self,raw_text):
140 | # raw_text - list (batch_size length) of strings with input text prompts
141 | device = next(self.parameters()).device
142 | max_text_len = 40
143 | if max_text_len is not None:
144 | default_context_length = 77
145 | context_length = max_text_len + 2 # start_token + max_text_len + end_token
146 | assert context_length < default_context_length
147 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate
148 | # print('texts', texts.shape)
149 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device)
150 | texts = torch.cat([texts, zero_pad], dim=1)
151 | # print('texts after pad', texts.shape, texts)
152 | else:
153 | texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate
154 | return self.clip_model.encode_text(texts).float()
155 |
156 | def mask_cond(self, cond, force_mask=False):
157 | bs, d = cond.shape
158 | if force_mask:
159 | return torch.zeros_like(cond)
160 | elif self.training and self.cond_mask_prob > 0.:
161 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond
162 | return cond * (1. - mask)
163 | else:
164 | return cond
165 |
166 | def load_and_freeze_clip(self, clip_version):
167 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
168 | jit=False) # Must set jit=False for training
169 | clip.model.convert_weights(
170 | clip_model) # arguably unnecessary: CLIP weights are already float16 by default
171 |
172 | # Freeze CLIP weights
173 | clip_model.eval()
174 | for p in clip_model.parameters():
175 | p.requires_grad = False
176 |
177 | return clip_model
178 |
179 | class PositionalEncoding(nn.Module):
180 | def __init__(self, d_model, dropout=0.1, max_len=5000):
181 | super(PositionalEncoding, self).__init__()
182 | self.dropout = nn.Dropout(p=dropout)
183 |
184 | pe = torch.zeros(max_len, d_model)
185 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
186 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
187 | pe[:, 0::2] = torch.sin(position * div_term)
188 | pe[:, 1::2] = torch.cos(position * div_term)
189 | pe = pe.unsqueeze(0).transpose(0, 1)
190 |
191 | self.register_buffer('pe', pe)
192 |
193 | def forward(self, x):
194 | # add the fixed sinusoidal positional encoding up to the sequence length, then apply dropout
195 | x = x + self.pe[:x.shape[0], :]
196 | return self.dropout(x)
197 |
198 |
199 | class TimestepEmbedder(nn.Module):
200 | def __init__(self, latent_dim, sequence_pos_encoder):
201 | super().__init__()
202 | self.latent_dim = latent_dim
203 | self.sequence_pos_encoder = sequence_pos_encoder
204 |
205 | time_embed_dim = self.latent_dim
206 | self.time_embed = nn.Sequential(
207 | nn.Linear(self.latent_dim, time_embed_dim),
208 | nn.SiLU(),
209 | nn.Linear(time_embed_dim, time_embed_dim),
210 | )
211 |
212 | def forward(self, timesteps):
213 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
214 |
--------------------------------------------------------------------------------
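
As a sanity check of the shapes NET_3O.forward expects, here is a hedged smoke-test sketch; it is not part of the repository. It assumes the `clip` package with ViT-B/32 weights and `src.model.blocks.TransformerBlock` are importable and behave as used above, and it prefers a CUDA device since CLIP's half-precision weights may not run on CPU. The caption and all tensor values are dummies.

import torch
from src.model.net_3o import NET_3O

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = NET_3O().to(device).eval()

bs, nframes = 2, 16
n_human = 52 * 3 + 52 * 6 + 3   # per-frame human features, as read off the code above
n_obj = 3 * 9                   # three objects, 9 pose parameters each
x = torch.randn(bs, nframes, n_human + n_obj, device=device)   # noisy motion at step t
timesteps = torch.randint(0, 1000, (bs,), device=device)
y = {
    'text': ['pick up the kettle and pour water into the cup'] * bs,  # dummy captions
    'obj1_bps': torch.randn(bs, 1024, 3, device=device),
    'obj2_bps': torch.randn(bs, 1024, 3, device=device),
    'obj3_bps': torch.randn(bs, 1024, 3, device=device),
    'init_state': torch.randn(bs, n_human + n_obj, device=device),    # first-frame state
    'mask': torch.ones(bs, 1, 1, nframes, dtype=torch.bool, device=device),  # all frames valid
}
with torch.no_grad():
    out = model(x, timesteps, y=y)
print(out.shape)  # expected: [bs, nframes, n_human + n_obj]
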
/src/dataset/eval_gen_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset,DataLoader
3 | import numpy as np
4 | from tqdm import tqdm
5 | from bps_torch.bps import bps_torch
6 | import os.path as osp
7 | import os
8 | from src.utils.rotation_conversion import sixd2rot_torch
9 | from src.dataset.tensors import gt_collate_fn
10 | from src.utils import dist_utils
11 | from src.dataset.tensors import lengths_to_mask
12 |
13 | def get_eval_gen_loader(args,model,diffusion,gen_loader,
14 | max_motion_length,batch_size,
15 | mm_num_samples,mm_num_repeats,num_samples_limit,scale):
16 | dataset=Evaluation_generator_Dataset(args,model,diffusion,gen_loader,
17 | max_motion_length,mm_num_samples,mm_num_repeats,num_samples_limit,scale)
18 |
19 | mm_dataset=MM_generator_Dataset('test',dataset,gen_loader.dataset.w_vectorizer)
20 |
21 | motion_loader=DataLoader(dataset,batch_size=batch_size,num_workers=4,drop_last=True,collate_fn=gt_collate_fn)
22 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
23 |
24 | print('Generated Dataset Loading Completed!!!')
25 | return motion_loader,mm_motion_loader
26 |
27 | class MM_generator_Dataset(Dataset):
28 | def __init__(self, opt, motion_dataset, w_vectorizer):
29 | self.opt = opt
30 | self.dataset = motion_dataset.mm_generated_motion
31 | self.w_vectorizer = w_vectorizer
32 |
33 | def __len__(self):
34 | return len(self.dataset)
35 |
36 | def __getitem__(self, item):
37 | data = self.dataset[item]
38 | mm_motions = data['mm_motions']
39 | m_lens = []
40 | motions = []
41 | for mm_motion in mm_motions:
42 | m_lens.append(mm_motion['length'])
43 | motion = mm_motion['motion']
44 | # We don't need the following logic because our sample func generates the full tensor anyway:
45 | # if len(motion) < self.opt.max_motion_length:
46 | # motion = np.concatenate([motion,
47 | # np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1]))
48 | # ], axis=0)
49 | motion = motion[None, :]
50 | motions.append(motion)
51 | m_lens = np.array(m_lens, dtype=int)
52 | motions = np.concatenate(motions, axis=0)
53 | sort_indx = np.argsort(m_lens)[::-1].copy()
54 | # print(m_lens)
55 | # print(sort_indx)
56 | # print(m_lens[sort_indx])
57 | m_lens = m_lens[sort_indx]
58 | motions = motions[sort_indx]
59 | return motions, m_lens
60 |
61 | class Evaluation_generator_Dataset(Dataset):
62 | def __init__(self,args,model,diffusion,gen_loader,
63 | max_motion_length,mm_num_samples,mm_num_repeats,num_samples_limit,scale=1.0):
64 | self.dataloader=gen_loader
65 | assert mm_num_samples < len(self.dataloader.dataset)
66 | # the defaults below are assumptions, inferred from how these variables
67 | # are used later in this class
68 | clip_denoised = False
69 | self.max_motion_length = max_motion_length
70 | model_sample_fn = diffusion.p_sample_loop
71 |
72 | real_num_batches = len(self.dataloader)
73 | if num_samples_limit is not None:
74 | real_num_batches = num_samples_limit // self.dataloader.batch_size + 1
75 | print('real_num_batches', real_num_batches)
76 |
77 | generated_motion = []
78 | mm_generated_motion = []
79 |
80 | # generate motions batch by batch; a subset of batches is repeated for the multimodality metric
81 | if mm_num_samples > 0:
82 | mm_idxs = np.random.choice(real_num_batches, mm_num_samples // self.dataloader.batch_size +1, replace=False)
83 | mm_idxs = np.sort(mm_idxs)
84 | else:
85 | mm_idxs = []
86 | print('mm_idxs', mm_idxs)
87 |
88 | model.eval()
89 |
90 | with torch.no_grad():
91 | for i,eval_data_batch in tqdm(enumerate(self.dataloader),total=len(self.dataloader)):
92 |
93 | # if i==1:
94 | # break
95 | if num_samples_limit is not None and len(generated_motion) >= num_samples_limit:
96 | break
97 |
98 | if args.obj=='2o':
99 | word_embeddings,pos_one_hots,caption,sent_len,motion,m_length,tokens,\
100 | obj1_bps,obj2_bps,init_state,obj1_name,obj2_name,betas=eval_data_batch
101 | tokens=[t.split('_') for t in tokens]
102 |
103 | model_kwargs={
104 | 'y':{
105 | 'length':m_length.to(dist_utils.dev()), # [bs]
106 | 'text':caption,
107 | 'obj1_bps':obj1_bps.to(dist_utils.dev()),
108 | 'obj2_bps':obj2_bps.to(dist_utils.dev()),
109 | 'init_state':init_state.to(dist_utils.dev()),
110 | 'mask':lengths_to_mask(m_length,motion.shape[1]).unsqueeze(1).unsqueeze(1).to(dist_utils.dev()) # [bs,1,1,nf]
111 | }
112 | }
113 | elif args.obj=='3o':
114 | word_embeddings,pos_one_hots,caption,sent_len,motion,m_length,tokens,\
115 | obj1_bps,obj2_bps,obj3_bps,init_state,obj1_name,obj2_name,obj3_name,betas=eval_data_batch
116 | tokens=[t.split('_') for t in tokens]
117 |
118 | model_kwargs={
119 | 'y':{
120 | 'length':m_length.to(dist_utils.dev()),
121 | 'text':caption,
122 | 'obj1_bps':obj1_bps.to(dist_utils.dev()),
123 | 'obj2_bps':obj2_bps.to(dist_utils.dev()),
124 | 'obj3_bps':obj3_bps.to(dist_utils.dev()),
125 | 'init_state':init_state.to(dist_utils.dev()),
126 | 'mask':lengths_to_mask(m_length,motion.shape[1]).unsqueeze(1).unsqueeze(1).to(dist_utils.dev()) # [bs,1,1,nf]
127 | }
128 | }
129 |
130 | # add CFG scale to batch
131 | if scale != 1.:
132 | model_kwargs['y']['scale'] = torch.ones(motion.shape[0],
133 | device=dist_utils.dev()) * scale
134 |
135 | mm_num_now = len(mm_generated_motion) // self.dataloader.batch_size
136 | is_mm = i in mm_idxs
137 | repeat_times = mm_num_repeats if is_mm else 1
138 | mm_motions = []
139 | for t in range(repeat_times):
140 |
141 | model_out_sample=model_sample_fn(
142 | model,
143 | motion.shape,
144 | clip_denoised=clip_denoised,
145 | model_kwargs=model_kwargs,
146 | skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
147 | init_image=None,
148 | progress=False,
149 | dump_steps=None,
150 | noise=None,
151 | const_noise=False,
152 | ) # bs,nf,315
153 |
154 | # export the model result
155 | network=args.model_path.split('/')[-2]
156 | export_path=osp.join('./export_results',network,'{}.npz'.format(i))
157 | os.makedirs(osp.dirname(export_path),exist_ok=True)
158 | if args.obj=='2o':
159 | output_dict={str(bs_i):
160 | {
161 | 'out':model_out_sample[bs_i].squeeze().cpu().numpy(),
162 | 'length':m_length[bs_i].cpu().numpy(),
163 | 'caption':caption[bs_i],
164 | 'obj1_name':obj1_name[bs_i],
165 | 'obj2_name':obj2_name[bs_i],
166 | 'betas':betas[bs_i].cpu().numpy(),
167 | } for bs_i in range(self.dataloader.batch_size)
168 | }
169 | elif args.obj=='3o':
170 | output_dict={str(bs_i):
171 | {
172 | 'out':model_out_sample[bs_i].squeeze().cpu().numpy(),
173 | 'length':m_length[bs_i].cpu().numpy(),
174 | 'caption':caption[bs_i],
175 | 'obj1_name':obj1_name[bs_i],
176 | 'obj2_name':obj2_name[bs_i],
177 | 'obj3_name':obj3_name[bs_i],
178 | 'betas':betas[bs_i].cpu().numpy(),
179 | } for bs_i in range(self.dataloader.batch_size)
180 | }
181 | np.savez(export_path,**output_dict)
182 |
183 | if t==0:
184 | sub_dicts=[
185 | {
186 | 'motion':model_out_sample[bs_i].squeeze().cpu().numpy(),
187 | 'length':m_length[bs_i].cpu().numpy(),
188 | 'caption':caption[bs_i],
189 | 'tokens':tokens[bs_i],
190 | 'cap_len':sent_len[bs_i].cpu().numpy(),
191 | } for bs_i in range(self.dataloader.batch_size)
192 | ]
193 | generated_motion+=sub_dicts
194 | if is_mm:
195 | mm_motions+=[
196 | {
197 | 'motion':model_out_sample[bs_i].squeeze().cpu().numpy(),
198 | 'length':m_length[bs_i].cpu().numpy(),
199 | } for bs_i in range(self.dataloader.batch_size)
200 | ]
201 | if is_mm:
202 | mm_generated_motion+=[
203 | {
204 | 'caption':model_kwargs['y']['text'][bs_i],
205 | 'tokens':tokens[bs_i],
206 | 'cap_len':len(tokens[bs_i]),
207 | 'mm_motions':mm_motions[bs_i::self.dataloader.batch_size], # collect all 10 repeats from the (32*10) generated motions
208 | } for bs_i in range(self.dataloader.batch_size)
209 | ]
210 | self.generated_motion=generated_motion
211 | self.mm_generated_motion=mm_generated_motion
212 | self.w_vectorizer=self.dataloader.dataset.w_vectorizer
213 |
214 | def __len__(self):
215 | return len(self.generated_motion)
216 |
217 | def __getitem__(self, item):
218 | data = self.generated_motion[item]
219 | motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
220 | sent_len = data['cap_len']
221 |
222 | pos_one_hots = []
223 | word_embeddings = []
224 | for token in tokens:
225 | word_emb, pos_oh = self.w_vectorizer[token]
226 | pos_one_hots.append(pos_oh[None, :])
227 | word_embeddings.append(word_emb[None, :])
228 | pos_one_hots = np.concatenate(pos_one_hots, axis=0)
229 | word_embeddings = np.concatenate(word_embeddings, axis=0)
230 |
231 | return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
--------------------------------------------------------------------------------
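
`lengths_to_mask` is imported from src/dataset/tensors.py, which is not shown in this listing. Given how its result is consumed (unsqueezed to [bs, 1, 1, nf] here and inverted into a key-padding mask inside the networks), it is presumably a boolean validity mask over frames. A minimal sketch of that assumed behaviour, for reference only:

import torch

def lengths_to_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # lengths: [bs] valid sequence lengths; returns [bs, max_len], True on valid frames
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

print(lengths_to_mask_sketch(torch.tensor([3, 5]), 6))
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])
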
/src/train/training_loop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.diffusion.fp16_util import MixedPrecisionTrainer
3 | from src.diffusion.resample import LossAwareSampler
4 | from src.dataset.eval_dataset import Evaluation_Dataset
5 | from src.dataset.tensors import gt_collate_fn
6 | from src.eval.eval_himo_2o import evaluation
7 | from src.feature_extractor.eval_wrapper import EvaluationWrapper
8 | from src.dataset.eval_gen_dataset import get_eval_gen_loader
9 | from src.diffusion import logger
10 | import functools
11 | from loguru import logger as log
12 |
13 | from src.utils import dist_utils
14 |
15 | from torch.optim import AdamW
16 | from torch.utils.data import DataLoader
17 | import blobfile as bf
18 | from src.diffusion.resample import create_named_schedule_sampler
19 | from tqdm import tqdm
20 | import numpy as np
21 | import os
22 | import time
23 |
24 |
25 | class TrainLoop:
26 | def __init__(self,args,train_platform,model,diffusion,data_loader):
27 | self.args=args
28 | self.train_platform=train_platform
29 | self.model=model
30 | self.diffusion=diffusion
31 | self.data=data_loader
32 |
33 | self.batch_size=args.batch_size
34 | self.microbatch = args.batch_size # deprecating this option
35 | self.lr=args.lr
36 | self.log_interval=args.log_interval
37 | self.save_interval=args.save_interval
38 | self.resume_checkpoint=args.resume_checkpoint
39 | self.use_fp16 = False # deprecating this option
40 | self.fp16_scale_growth = 1e-3 # deprecating this option
41 | self.weight_decay=args.weight_decay
42 | self.lr_anneal_steps=args.lr_anneal_steps
43 |
44 | self.step = 0
45 | self.resume_step = 0
46 | self.global_batch = self.batch_size # * dist.get_world_size()
47 | # self.num_steps = args.num_steps
48 | # self.num_epochs = self.num_steps // len(self.data) + 1
49 | self.num_epochs=args.num_epochs
50 |
51 | self.sync_cuda = torch.cuda.is_available()
52 |
53 | self._load_and_sync_parameters()
54 | self.mp_trainer = MixedPrecisionTrainer(
55 | model=self.model,
56 | use_fp16=self.use_fp16,
57 | fp16_scale_growth=self.fp16_scale_growth,
58 | )
59 |
60 | self.save_path=args.save_path
61 |
62 | self.opt = AdamW(
63 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
64 | )
65 | if self.resume_step:
66 | self._load_optimizer_state()
67 | # Model was resumed, either due to a restart or a checkpoint
68 | # being specified at the command line.
69 |
70 | self.device = torch.device("cpu")
71 | if torch.cuda.is_available() and dist_utils.dev() != 'cpu':
72 | self.device = torch.device(dist_utils.dev())
73 |
74 | self.schedule_sampler_type = 'uniform'
75 | self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion)
76 | self.eval_wrapper,self.eval_data,self.eval_gt_data=None,None,None
77 | if args.eval_during_training:
78 | mm_num_samples = 0 # mm is super slow hence we won't run it during training
79 | mm_num_repeats = 0 # mm is super slow hence we won't run it during training
80 | gen_dataset=Evaluation_Dataset(args,split='val',mode='eval')
81 | gen_loader=DataLoader(gen_dataset,batch_size=args.eval_batch_size,
82 | shuffle=True,num_workers=8,drop_last=True,collate_fn=gt_collate_fn)
83 | gt_dataset=Evaluation_Dataset(args,split='val',mode='gt')
84 | self.eval_gt_data=DataLoader(gt_dataset,batch_size=args.eval_batch_size,
85 | shuffle=True,num_workers=8,drop_last=True,collate_fn=gt_collate_fn)
86 |
87 | self.eval_wrapper=EvaluationWrapper(args)
88 | self.eval_data={
89 | 'test':lambda :get_eval_gen_loader(
90 | args,model,diffusion,gen_loader,gen_loader.dataset.max_motion_length,
91 | args.eval_batch_size,mm_num_samples,mm_num_repeats,
92 | 1000,scale=1.
93 |
94 | )
95 | }
96 | self.use_ddp=False
97 | self.ddp_model=self.model
98 |
99 | def _load_and_sync_parameters(self):
100 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
101 |
102 | if resume_checkpoint:
103 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
104 | log.info(f"loading model from checkpoint: {resume_checkpoint}...")
105 | self.model.load_state_dict(
106 | dist_utils.load_state_dict(
107 | resume_checkpoint, map_location=dist_utils.dev()
108 | )
109 | )
110 |
111 | def _load_optimizer_state(self):
112 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
113 | opt_checkpoint = bf.join(
114 | bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
115 | )
116 | if bf.exists(opt_checkpoint):
117 | log.info(f"loading optimizer state from checkpoint: {opt_checkpoint}")
118 | state_dict = dist_utils.load_state_dict(
119 | opt_checkpoint, map_location=dist_utils.dev()
120 | )
121 | self.opt.load_state_dict(state_dict)
122 |
123 | def run_loop(self):
124 | for epoch in range(self.num_epochs):
125 | log.info(f'Starting epoch {epoch}/{self.num_epochs}')
126 | for inp,cond in tqdm(self.data):
127 | if self.lr_anneal_steps and self.step + self.resume_step >= self.lr_anneal_steps:
128 | break
129 |
130 | inp=inp.to(self.device)
131 | cond['y'] ={k:v.to(self.device) if torch.is_tensor(v) else v for k,v in cond['y'].items()}
132 |
133 | self.run_step(inp,cond)
134 | if self.step % self.log_interval == 0:
135 | for k,v in logger.get_current().name2val.items():
136 | if k == 'loss':
137 | print('step[{}]: loss[{:0.5f}]'.format(self.step+self.resume_step, v))
138 |
139 | if k in ['step', 'samples'] or '_q' in k:
140 | continue
141 | else:
142 | self.train_platform.report_scalar(name=k, value=v, iteration=self.step, group_name='Loss')
143 | if self.step % self.save_interval == 0:
144 | self.save()
145 | self.model.eval()
146 | self.evaluate()
147 | self.model.train()
148 |
149 | # Run for a finite amount of time in integration tests.
150 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
151 | return
152 | self.step += 1
153 | if self.lr_anneal_steps and self.step + self.resume_step >= self.lr_anneal_steps:
154 | break
155 | # Save the last checkpoint if it wasn't already saved.
156 | if (self.step - 1) % self.save_interval != 0:
157 | self.save()
158 | self.evaluate()
159 |
160 | def run_step(self,inp,cond):
161 | self.forward_backward(inp,cond)
162 | self.mp_trainer.optimize(self.opt)
163 | self._anneal_lr()
164 | self.log_step()
165 |
166 | def forward_backward(self,batch,cond):
167 | self.mp_trainer.zero_grad()
168 | for i in range(0, batch.shape[0], self.microbatch):
169 | # Eliminates the microbatch feature
170 | assert i == 0
171 | assert self.microbatch == self.batch_size
172 | micro = batch
173 | micro_cond = cond
174 | last_batch = (i + self.microbatch) >= batch.shape[0]
175 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_utils.dev())
176 |
177 | compute_losses = functools.partial(
178 | self.diffusion.training_losses,
179 | self.ddp_model,
180 | micro, # [bs, ...]
181 | t, # [bs](int) sampled timesteps
182 | model_kwargs=micro_cond,
183 | dataset=self.args.dataset
184 | )
185 |
186 | if last_batch or not self.use_ddp:
187 | losses = compute_losses()
188 | else:
189 | with self.ddp_model.no_sync():
190 | losses = compute_losses()
191 |
192 | if isinstance(self.schedule_sampler, LossAwareSampler):
193 | self.schedule_sampler.update_with_local_losses(
194 | t, losses["loss"].detach()
195 | )
196 |
197 | loss = (losses["loss"] * weights).mean()
198 | log_loss_dict(
199 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
200 | )
201 | self.mp_trainer.backward(loss)
202 |
203 | def _anneal_lr(self):
204 | if not self.lr_anneal_steps:
205 | return
206 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
207 | lr = self.lr * (1 - frac_done)
208 | for param_group in self.opt.param_groups:
209 | param_group["lr"] = lr
210 |
211 | def log_step(self):
212 | logger.logkv("step", self.step + self.resume_step)
213 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
214 |
215 | def ckpt_file_name(self):
216 | return f"model{(self.step+self.resume_step):09d}.pt"
217 |
218 | def save(self):
219 | def save_checkpoint(params):
220 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
221 |
222 | # Do not save CLIP weights
223 | clip_weights = [e for e in state_dict.keys() if e.startswith('clip_model.')]
224 | for e in clip_weights:
225 | del state_dict[e]
226 |
227 | log.info(f"saving model...")
228 | filename = self.ckpt_file_name()
229 | with bf.BlobFile(bf.join(self.save_path, filename), "wb") as f:
230 | torch.save(state_dict, f)
231 |
232 | save_checkpoint(self.mp_trainer.master_params)
233 |
234 | with bf.BlobFile(
235 | bf.join(self.save_path, f"opt{(self.step+self.resume_step):09d}.pt"),
236 | "wb",
237 | ) as f:
238 | torch.save(self.opt.state_dict(), f)
239 |
240 | def evaluate(self):
241 | if not self.args.eval_during_training:
242 | return
243 | start_eval = time.time()
244 | if self.eval_wrapper is not None:
245 | log.info('Running evaluation loop: [Should take about 90 min]')
246 | log_file = os.path.join(self.save_path, f'eval_model_{(self.step + self.resume_step):09d}.log')
247 | diversity_times = 100 if self.args.obj == '2o' else 40 # 200
248 | eval_rep_time = 1 # 3
249 | mm_num_times = 0 # mm is super slow hence we won't run it during training
250 | eval_dict = evaluation(
251 | self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file,
252 | replication_times=eval_rep_time, diversity_times=diversity_times, mm_num_times=mm_num_times, run_mm=False)
253 | log.info(eval_dict)
254 | for k, v in eval_dict.items():
255 | if k.startswith('R_precision'):
256 | for i in range(len(v)):
257 | self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i],
258 | iteration=self.step + self.resume_step,
259 | group_name='Eval')
260 | else:
261 | self.train_platform.report_scalar(name=k, value=v, iteration=self.step + self.resume_step,
262 | group_name='Eval')
263 | end_eval=time.time()
264 | log.info(f'Evaluation time: {(end_eval - start_eval) / 60:.1f} minutes')
265 |
266 |
267 | def find_resume_checkpoint():
268 | # On your infrastructure, you may want to override this to automatically
269 | # discover the latest checkpoint on your blob storage, etc.
270 | return None
271 |
272 | def parse_resume_step_from_filename(filename):
273 | """
274 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
275 | checkpoint's number of steps.
276 | """
277 | split = filename.split("model")
278 | if len(split) < 2:
279 | return 0
280 | split1 = split[-1].split(".")[0]
281 | try:
282 | return int(split1)
283 | except ValueError:
284 | return 0
285 |
286 | def log_loss_dict(diffusion, ts, losses):
287 | for key, values in losses.items():
288 | logger.logkv_mean(key, values.mean().item())
289 | # Log the quantiles (four quartiles, in particular).
290 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
291 | quartile = int(4 * sub_t / diffusion.num_timesteps)
292 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
--------------------------------------------------------------------------------
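
`_anneal_lr` above is a plain linear decay: the learning rate falls from `args.lr` to zero over `lr_anneal_steps` optimizer steps, and is left untouched when `lr_anneal_steps` is 0 or None. A small worked example of that schedule, with made-up numbers:

# Illustrative only; base_lr and lr_anneal_steps are made-up values.
base_lr, lr_anneal_steps = 1e-4, 10_000
for step in (0, 2_500, 5_000, 7_500, 10_000):
    frac_done = step / lr_anneal_steps
    print(step, base_lr * (1 - frac_done))
# roughly: 0 -> 1e-4, 2500 -> 7.5e-5, 5000 -> 5e-5, 7500 -> 2.5e-5, 10000 -> 0.0
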
/src/eval/eval_himo_3o.py:
--------------------------------------------------------------------------------
1 | from src.utils.parser_utils import eval_himo_args
2 | from src.utils.misc import fixseed
3 | from src.utils import dist_utils
4 | # from src.diffusion import logger
5 |
6 | from collections import OrderedDict
7 | from src.utils.model_utils import create_model_and_diffusion,load_model_wo_clip
8 | from torch.utils.data import DataLoader
9 | from src.dataset.eval_dataset import Evaluation_Dataset
10 | from src.dataset.tensors import gt_collate_fn
11 | from src.dataset.eval_gen_dataset import get_eval_gen_loader
12 | from src.feature_extractor.eval_wrapper import EvaluationWrapper
13 | from src.eval.metrics import *
14 | from src.model.cfg_sampler import ClassifierFreeSampleModel
15 | import torch
16 | import os
17 | import os.path as osp
18 | import numpy as np
19 | from datetime import datetime
20 | from loguru import logger
21 |
22 | def evaluate_matching_score(eval_wrapper, motion_loaders, file):
23 | match_score_dict = OrderedDict({})
24 | R_precision_dict = OrderedDict({})
25 | activation_dict = OrderedDict({})
26 | logger.info('========== Evaluating Matching Score ==========')
27 | for motion_loader_name, motion_loader in motion_loaders.items():
28 | all_motion_embeddings = []
29 | score_list = []
30 | all_size = 0
31 | matching_score_sum = 0
32 | top_k_count = 0
33 | # logger.info(motion_loader_name)
34 | with torch.no_grad():
35 | for idx, batch in enumerate(motion_loader):
36 | word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
37 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
38 | word_embs=word_embeddings,
39 | pos_ohot=pos_one_hots,
40 | cap_lens=sent_lens,
41 | motions=motions,
42 | m_lens=m_lens
43 | )
44 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
45 | motion_embeddings.cpu().numpy())
46 | matching_score_sum += dist_mat.trace()
47 |
48 | argsmax = np.argsort(dist_mat, axis=1)
49 | top_k_mat = calculate_top_k(argsmax, top_k=3)
50 | top_k_count += top_k_mat.sum(axis=0)
51 |
52 | all_size += text_embeddings.shape[0]
53 |
54 | all_motion_embeddings.append(motion_embeddings.cpu().numpy())
55 |
56 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
57 | matching_score = matching_score_sum / all_size
58 | R_precision = top_k_count / all_size
59 | match_score_dict[motion_loader_name] = matching_score
60 | R_precision_dict[motion_loader_name] = R_precision
61 | activation_dict[motion_loader_name] = all_motion_embeddings
62 |
63 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
64 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
65 |
66 | line = f'---> [{motion_loader_name}] R_precision: '
67 | for i in range(len(R_precision)):
68 | line += '(top %d): %.4f ' % (i+1, R_precision[i])
69 | logger.info(line)
70 | logger.info(line, file=file, flush=True)
71 |
72 | return match_score_dict, R_precision_dict, activation_dict
73 |
74 | def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file):
75 | eval_dict = OrderedDict({})
76 | gt_motion_embeddings = []
77 | logger.info('========== Evaluating FID ==========')
78 | with torch.no_grad():
79 | for idx, batch in enumerate(groundtruth_loader):
80 | _, _, _, sent_lens, motions, m_lens, _ = batch
81 | motion_embeddings = eval_wrapper.get_motion_embeddings(
82 | motions=motions,
83 | m_lens=m_lens
84 | )
85 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
86 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
87 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
88 |
89 | # logger.info(gt_mu)
90 | for model_name, motion_embeddings in activation_dict.items():
91 | mu, cov = calculate_activation_statistics(motion_embeddings)
92 | # logger.info(mu)
93 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
94 | logger.info(f'---> [{model_name}] FID: {fid:.4f}')
95 | logger.info(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
96 | eval_dict[model_name] = fid
97 | return eval_dict
98 |
99 |
100 | def evaluate_diversity(activation_dict, file, diversity_times):
101 | eval_dict = OrderedDict({})
102 | logger.info('========== Evaluating Diversity ==========')
103 | for model_name, motion_embeddings in activation_dict.items():
104 | diversity = calculate_diversity(motion_embeddings, diversity_times)
105 | eval_dict[model_name] = diversity
106 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}')
107 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
108 | return eval_dict
109 |
110 |
111 | def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
112 | eval_dict = OrderedDict({})
113 | logger.info('========== Evaluating MultiModality ==========')
114 | for model_name, mm_motion_loader in mm_motion_loaders.items():
115 | mm_motion_embeddings = []
116 | with torch.no_grad():
117 | for idx, batch in enumerate(mm_motion_loader):
118 | # (1, mm_replications, dim_pos)
119 | motions, m_lens = batch
120 | motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
121 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
122 | if len(mm_motion_embeddings) == 0:
123 | multimodality = 0
124 | else:
125 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
126 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
127 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
128 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
129 | eval_dict[model_name] = multimodality
130 | return eval_dict
131 |
132 |
133 | def get_metric_statistics(values, replication_times):
134 | mean = np.mean(values, axis=0)
135 | std = np.std(values, axis=0)
136 | conf_interval = 1.96 * std / np.sqrt(replication_times)
137 | return mean, conf_interval
138 |
139 | def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
140 | with open(log_file, 'w') as f:
141 | all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
142 | 'R_precision': OrderedDict({}),
143 | 'FID': OrderedDict({}),
144 | 'Diversity': OrderedDict({}),
145 | 'MultiModality': OrderedDict({})})
146 | for replication in range(replication_times):
147 | motion_loaders = {}
148 | mm_motion_loaders = {}
149 | motion_loaders['ground truth'] = gt_loader
150 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
151 | motion_loader, mm_motion_loader = motion_loader_getter()
152 | motion_loaders[motion_loader_name] = motion_loader
153 | mm_motion_loaders[motion_loader_name] = mm_motion_loader
154 |
155 | logger.info(f'==================== Replication {replication} ====================')
156 | logger.info(f'==================== Replication {replication} ====================', file=f, flush=True)
157 | logger.info(f'Time: {datetime.now()}')
158 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
159 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
160 |
161 | logger.info(f'Time: {datetime.now()}')
162 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
163 | fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
164 |
165 | logger.info(f'Time: {datetime.now()}')
166 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
167 | div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
168 |
169 | if run_mm:
170 | logger.info(f'Time: {datetime.now()}')
171 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
172 | mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
173 |
174 | logger.info(f'!!! DONE !!!')
175 | logger.info(f'!!! DONE !!!', file=f, flush=True)
176 |
177 | for key, item in mat_score_dict.items():
178 | if key not in all_metrics['Matching Score']:
179 | all_metrics['Matching Score'][key] = [item]
180 | else:
181 | all_metrics['Matching Score'][key] += [item]
182 |
183 | for key, item in R_precision_dict.items():
184 | if key not in all_metrics['R_precision']:
185 | all_metrics['R_precision'][key] = [item]
186 | else:
187 | all_metrics['R_precision'][key] += [item]
188 |
189 | for key, item in fid_score_dict.items():
190 | if key not in all_metrics['FID']:
191 | all_metrics['FID'][key] = [item]
192 | else:
193 | all_metrics['FID'][key] += [item]
194 |
195 | for key, item in div_score_dict.items():
196 | if key not in all_metrics['Diversity']:
197 | all_metrics['Diversity'][key] = [item]
198 | else:
199 | all_metrics['Diversity'][key] += [item]
200 | if run_mm:
201 | for key, item in mm_score_dict.items():
202 | if key not in all_metrics['MultiModality']:
203 | all_metrics['MultiModality'][key] = [item]
204 | else:
205 | all_metrics['MultiModality'][key] += [item]
206 |
207 |
208 | # logger.info(all_metrics['Diversity'])
209 | mean_dict = {}
210 | for metric_name, metric_dict in all_metrics.items():
211 | logger.info('========== %s Summary ==========' % metric_name)
212 | logger.info('========== %s Summary ==========' % metric_name, file=f, flush=True)
213 | for model_name, values in metric_dict.items():
214 | # logger.info(metric_name, model_name)
215 | mean, conf_interval = get_metric_statistics(np.array(values), replication_times)
216 | mean_dict[metric_name + '_' + model_name] = mean
217 | # logger.info(mean, mean.dtype)
218 | if isinstance(mean, np.float64) or isinstance(mean, np.float32):
219 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
220 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
221 | elif isinstance(mean, np.ndarray):
222 | line = f'---> [{model_name}]'
223 | for i in range(len(mean)):
224 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
225 | logger.info(line)
226 | logger.info(line, file=f, flush=True)
227 | return mean_dict
228 |
229 | if __name__ == '__main__':
230 | val_args,model_args=eval_himo_args()
231 | fixseed(val_args.seed)
232 | val_args.batch_size = 32 # This must be 32; changing it breaks the R-precision calculation
233 |
234 | model_name = os.path.basename(os.path.dirname(val_args.model_path))
235 | model_niter = os.path.basename(val_args.model_path).replace('model', '').replace('.pt', '')
236 |
237 | log_file=osp.join('val_results','eval_{}_{}'.format(model_name,model_niter))
238 | log_file+=f'_{val_args.eval_mode}'
239 | log_file+='.log'
240 | logger.add(log_file)
241 | logger.info(f'Will save to log file [{log_file}]')
242 | logger.info(f'Eval mode [{val_args.eval_mode}]')
243 |
244 | if val_args.eval_mode == 'debug':
245 | num_samples_limit = 1000 # None means no limit (eval over all dataset)
246 | run_mm = False
247 | mm_num_samples = 0
248 | mm_num_repeats = 0
249 | mm_num_times = 0
250 | diversity_times = 100
251 | replication_times = 1
252 | elif val_args.eval_mode == 'wo_mm':
253 | num_samples_limit = 1000
254 | run_mm = False
255 | mm_num_samples = 0
256 | mm_num_repeats = 0
257 | mm_num_times = 0
258 | diversity_times = 100
259 | replication_times = 20
260 | elif val_args.eval_mode == 'mm_short':
261 | num_samples_limit = 1000
262 | run_mm = True
263 | mm_num_samples = 100
264 | mm_num_repeats = 20 # 30
265 | mm_num_times = 10
266 | diversity_times = 100
267 | replication_times = 5
268 | else:
269 | raise ValueError(f'Unknown eval_mode: {val_args.eval_mode}')
270 |
271 | dist_utils.setup_dist(val_args.device)
272 | logger.configure()
273 |
274 | logger.info("creating data loader...")
275 | split = 'test'
276 |
277 | gt_dataset=Evaluation_Dataset(val_args,split=split,mode='gt')
278 | gen_dataset=Evaluation_Dataset(val_args,split=split,mode='eval')
279 | gt_loader=DataLoader(gt_dataset,batch_size=val_args.batch_size,shuffle=True,num_workers=8,
280 | drop_last=True,collate_fn=gt_collate_fn)
281 | gen_loader=DataLoader(gen_dataset,batch_size=val_args.batch_size,shuffle=True,num_workers=8,
282 | drop_last=True,collate_fn=gt_collate_fn)
283 |
284 | logger.info("creating model and diffusion...")
285 | model,diffusion=create_model_and_diffusion(model_args)
286 | logger.info(f"Loading model from [{val_args.model_path}]...")
287 | state_dict=torch.load(val_args.model_path,map_location='cpu')
288 | load_model_wo_clip(model,state_dict)
289 |
290 | if val_args.guidance_param!=1:
291 | model=ClassifierFreeSampleModel(model)
292 |
293 | model.to(dist_utils.dev())
294 | model.eval()
295 |
296 | eval_motion_loaders={
297 | 'vald': lambda:get_eval_gen_loader(
298 | val_args,model,diffusion,
299 | gen_loader,val_args.max_motion_length,val_args.batch_size,
300 | mm_num_samples,mm_num_repeats,num_samples_limit,val_args.guidance_param
301 | )
302 | }
303 | eval_wrapper=EvaluationWrapper(val_args)
304 | evaluation(eval_wrapper,gt_loader,eval_motion_loaders,log_file,replication_times,diversity_times,
305 | mm_num_times,run_mm=run_mm)
306 |
307 |
--------------------------------------------------------------------------------
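
The matching-score / R-precision loop above relies on `euclidean_distance_matrix` and `calculate_top_k` from src/eval/metrics.py, which are not shown in this listing. The self-contained sketch below mirrors the bookkeeping as it is used here (text i is taken to match motion i, so correct retrievals lie on the diagonal of the distance matrix); the name `r_precision` and the toy distance matrix are illustrative only, not the repository's implementation:

import numpy as np

def r_precision(dist_mat: np.ndarray, top_k: int = 3) -> np.ndarray:
    bs = dist_mat.shape[0]
    order = np.argsort(dist_mat, axis=1)               # nearest motions first
    hit = order == np.arange(bs)[:, None]              # True where the matching motion sits
    top_k_mat = np.cumsum(hit, axis=1)[:, :top_k] > 0  # is the match within the k nearest?
    return top_k_mat.sum(axis=0) / bs                  # R-precision @1 .. @top_k

rng = np.random.default_rng(0)
dist = rng.random((32, 32)) + 1.0
np.fill_diagonal(dist, 0.0)                            # make every match the nearest motion
print(r_precision(dist))                               # -> [1. 1. 1.]
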
/src/eval/eval_himo_2o.py:
--------------------------------------------------------------------------------
1 | from src.utils.parser_utils import eval_himo_args
2 | from src.utils.misc import fixseed
3 | from src.utils import dist_utils
4 | # from src.diffusion import logger
5 |
6 | from collections import OrderedDict
7 | from src.utils.model_utils import create_model_and_diffusion,load_model_wo_clip
8 | from torch.utils.data import DataLoader
9 | from src.dataset.eval_dataset import Evaluation_Dataset
10 | from src.dataset.tensors import gt_collate_fn
11 | from src.dataset.eval_gen_dataset import get_eval_gen_loader
12 | from src.feature_extractor.eval_wrapper import EvaluationWrapper
13 | from src.eval.metrics import *
14 | from src.model.cfg_sampler import ClassifierFreeSampleModel
15 | import torch
16 | import os
17 | import os.path as osp
18 | import numpy as np
19 | from datetime import datetime
20 | from loguru import logger
21 |
22 | def evaluate_matching_score(eval_wrapper, motion_loaders, file):
23 | match_score_dict = OrderedDict({})
24 | R_precision_dict = OrderedDict({})
25 | activation_dict = OrderedDict({})
26 | logger.info('========== Evaluating Matching Score ==========')
27 | for motion_loader_name, motion_loader in motion_loaders.items():
28 | all_motion_embeddings = []
29 | score_list = []
30 | all_size = 0
31 | matching_score_sum = 0
32 | top_k_count = 0
33 | logger.info(motion_loader_name)
34 | with torch.no_grad():
35 | for idx, batch in enumerate(motion_loader):
36 | word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
37 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
38 | word_embs=word_embeddings,
39 | pos_ohot=pos_one_hots,
40 | cap_lens=sent_lens,
41 | motions=motions,
42 | m_lens=m_lens
43 | )
44 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
45 | motion_embeddings.cpu().numpy())
46 | matching_score_sum += dist_mat.trace()
47 |
48 | argsmax = np.argsort(dist_mat, axis=1)
49 | top_k_mat = calculate_top_k(argsmax, top_k=3)
50 | top_k_count += top_k_mat.sum(axis=0)
51 |
52 | all_size += text_embeddings.shape[0]
53 |
54 | all_motion_embeddings.append(motion_embeddings.cpu().numpy())
55 |
56 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
57 | matching_score = matching_score_sum / all_size
58 | R_precision = top_k_count / all_size
59 | match_score_dict[motion_loader_name] = matching_score
60 | R_precision_dict[motion_loader_name] = R_precision
61 | activation_dict[motion_loader_name] = all_motion_embeddings
62 |
63 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
64 | logger.info(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
65 |
66 | line = f'---> [{motion_loader_name}] R_precision: '
67 | for i in range(len(R_precision)):
68 | line += '(top %d): %.4f ' % (i+1, R_precision[i])
69 | logger.info(line)
70 | logger.info(line, file=file, flush=True)
71 |
72 | return match_score_dict, R_precision_dict, activation_dict
73 |
74 | def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file):
75 | eval_dict = OrderedDict({})
76 | gt_motion_embeddings = []
77 | logger.info('========== Evaluating FID ==========')
78 | with torch.no_grad():
79 | for idx, batch in enumerate(groundtruth_loader):
80 | _, _, _, sent_lens, motions, m_lens, _ = batch
81 | motion_embeddings = eval_wrapper.get_motion_embeddings(
82 | motions=motions,
83 | m_lens=m_lens
84 | )
85 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
86 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
87 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
88 |
89 | # logger.info(gt_mu)
90 | for model_name, motion_embeddings in activation_dict.items():
91 | mu, cov = calculate_activation_statistics(motion_embeddings)
92 | # logger.info(mu)
93 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
94 | logger.info(f'---> [{model_name}] FID: {fid:.4f}')
95 | logger.info(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
96 | eval_dict[model_name] = fid
97 | return eval_dict
98 |
99 |
100 | def evaluate_diversity(activation_dict, file, diversity_times):
101 | eval_dict = OrderedDict({})
102 | logger.info('========== Evaluating Diversity ==========')
103 | for model_name, motion_embeddings in activation_dict.items():
104 | diversity = calculate_diversity(motion_embeddings, diversity_times)
105 | eval_dict[model_name] = diversity
106 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}')
107 | logger.info(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
108 | return eval_dict
109 |
110 |
111 | def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
112 | eval_dict = OrderedDict({})
113 | logger.info('========== Evaluating MultiModality ==========')
114 | for model_name, mm_motion_loader in mm_motion_loaders.items():
115 | mm_motion_embeddings = []
116 | with torch.no_grad():
117 | for idx, batch in enumerate(mm_motion_loader):
118 | # (1, mm_replications, dim_pos)
119 | motions, m_lens = batch
120 | motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
121 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
122 | if len(mm_motion_embeddings) == 0:
123 | multimodality = 0
124 | else:
125 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
126 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
127 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
128 | logger.info(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
129 | eval_dict[model_name] = multimodality
130 | return eval_dict
131 |
132 |
133 | def get_metric_statistics(values, replication_times):
134 | mean = np.mean(values, axis=0)
135 | std = np.std(values, axis=0)
136 | conf_interval = 1.96 * std / np.sqrt(replication_times)
137 | return mean, conf_interval
138 |
139 | def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
140 | with open(log_file, 'w') as f:
141 | all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
142 | 'R_precision': OrderedDict({}),
143 | 'FID': OrderedDict({}),
144 | 'Diversity': OrderedDict({}),
145 | 'MultiModality': OrderedDict({})})
146 | for replication in range(replication_times):
147 | motion_loaders = {}
148 | mm_motion_loaders = {}
149 | motion_loaders['ground truth'] = gt_loader
150 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
151 | motion_loader, mm_motion_loader = motion_loader_getter()
152 | motion_loaders[motion_loader_name] = motion_loader
153 | mm_motion_loaders[motion_loader_name] = mm_motion_loader
154 |
155 | logger.info(f'==================== Replication {replication} ====================')
156 | logger.info(f'==================== Replication {replication} ====================', file=f, flush=True)
157 | logger.info(f'Time: {datetime.now()}')
158 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
159 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
160 |
161 | logger.info(f'Time: {datetime.now()}')
162 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
163 | fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
164 |
165 | logger.info(f'Time: {datetime.now()}')
166 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
167 | div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
168 |
169 | if run_mm:
170 | logger.info(f'Time: {datetime.now()}')
171 | logger.info(f'Time: {datetime.now()}', file=f, flush=True)
172 | mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
173 |
174 | logger.info(f'!!! DONE !!!')
175 | logger.info(f'!!! DONE !!!', file=f, flush=True)
176 |
177 | for key, item in mat_score_dict.items():
178 | if key not in all_metrics['Matching Score']:
179 | all_metrics['Matching Score'][key] = [item]
180 | else:
181 | all_metrics['Matching Score'][key] += [item]
182 |
183 | for key, item in R_precision_dict.items():
184 | if key not in all_metrics['R_precision']:
185 | all_metrics['R_precision'][key] = [item]
186 | else:
187 | all_metrics['R_precision'][key] += [item]
188 |
189 | for key, item in fid_score_dict.items():
190 | if key not in all_metrics['FID']:
191 | all_metrics['FID'][key] = [item]
192 | else:
193 | all_metrics['FID'][key] += [item]
194 |
195 | for key, item in div_score_dict.items():
196 | if key not in all_metrics['Diversity']:
197 | all_metrics['Diversity'][key] = [item]
198 | else:
199 | all_metrics['Diversity'][key] += [item]
200 | if run_mm:
201 | for key, item in mm_score_dict.items():
202 | if key not in all_metrics['MultiModality']:
203 | all_metrics['MultiModality'][key] = [item]
204 | else:
205 | all_metrics['MultiModality'][key] += [item]
206 |
207 |
208 | # logger.info(all_metrics['Diversity'])
209 | mean_dict = {}
210 | for metric_name, metric_dict in all_metrics.items():
211 | logger.info('========== %s Summary ==========' % metric_name)
212 | logger.info('========== %s Summary ==========' % metric_name, file=f, flush=True)
213 | for model_name, values in metric_dict.items():
214 | # logger.info(metric_name, model_name)
215 | mean, conf_interval = get_metric_statistics(np.array(values), replication_times)
216 | mean_dict[metric_name + '_' + model_name] = mean
217 | # logger.info(mean, mean.dtype)
218 | if isinstance(mean, np.float64) or isinstance(mean, np.float32):
219 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
220 | logger.info(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
221 | elif isinstance(mean, np.ndarray):
222 | line = f'---> [{model_name}]'
223 | for i in range(len(mean)):
224 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
225 | logger.info(line)
226 | logger.info(line, file=f, flush=True)
227 | return mean_dict
228 |
229 | if __name__ == '__main__':
230 | val_args,model_args=eval_himo_args()
231 | fixseed(val_args.seed)
232 | val_args.batch_size = 32 # This must be 32; changing it breaks the R-precision calculation
233 |
234 | model_name = os.path.basename(os.path.dirname(val_args.model_path))
235 | model_niter = os.path.basename(val_args.model_path).replace('model', '').replace('.pt', '')
236 |
237 | log_file=osp.join('val_results','eval_{}_{}'.format(model_name,model_niter))
238 | log_file+=f'_{val_args.eval_mode}'
239 | log_file+='.log'
240 | logger.add(log_file)
241 | logger.info(f'Will save to log file [{log_file}]')
242 | logger.info(f'Eval mode [{val_args.eval_mode}]')
243 |
244 | if val_args.eval_mode == 'debug':
245 | num_samples_limit = 1000 # None means no limit (eval over all dataset)
246 | run_mm = False
247 | mm_num_samples = 0
248 | mm_num_repeats = 0
249 | mm_num_times = 0
250 | diversity_times = 200
251 | replication_times = 1
252 | elif val_args.eval_mode == 'wo_mm':
253 | num_samples_limit = 1000
254 | run_mm = False
255 | mm_num_samples = 0
256 | mm_num_repeats = 0
257 | mm_num_times = 0
258 | diversity_times = 200 # 300
259 | replication_times = 10 #20
260 | elif val_args.eval_mode == 'mm_short':
261 | num_samples_limit = 1000
262 | run_mm = True
263 | mm_num_samples = 100
264 | mm_num_repeats = 20 # 30
265 | mm_num_times = 10
266 | diversity_times = 200 # 300
267 | replication_times = 3
268 | else:
269 | raise ValueError(f'Unknown eval_mode: {val_args.eval_mode}')
270 |
271 | dist_utils.setup_dist(val_args.device)
272 | logger.configure()
273 |
274 | logger.info("creating data loader...")
275 | split = 'test'
276 |
277 | gt_dataset=Evaluation_Dataset(val_args,split=split,mode='gt')
278 | gen_dataset=Evaluation_Dataset(val_args,split=split,mode='eval')
279 | gt_loader=DataLoader(gt_dataset,batch_size=val_args.batch_size,shuffle=False,num_workers=8,
280 | drop_last=True,collate_fn=gt_collate_fn)
281 | gen_loader=DataLoader(gen_dataset,batch_size=val_args.batch_size,shuffle=False,num_workers=8,
282 | drop_last=True,collate_fn=gt_collate_fn)
283 |
284 | logger.info("creating model and diffusion...")
285 | model,diffusion=create_model_and_diffusion(model_args)
286 | logger.info(f"Loading model from [{val_args.model_path}]...")
287 | state_dict=torch.load(val_args.model_path,map_location='cpu')
288 | load_model_wo_clip(model,state_dict)
289 |
290 | if val_args.guidance_param!=1:
291 | model=ClassifierFreeSampleModel(model)
292 |
293 | model.to(dist_utils.dev())
294 | model.eval()
295 |
296 | eval_motion_loaders={
297 | 'vald': lambda:get_eval_gen_loader(
298 | val_args,model,diffusion,
299 | gen_loader,val_args.max_motion_length,val_args.batch_size,
300 | mm_num_samples,mm_num_repeats,num_samples_limit,val_args.guidance_param
301 | )
302 | }
303 | eval_wrapper=EvaluationWrapper(val_args)
304 | evaluation(eval_wrapper,gt_loader,eval_motion_loaders,log_file,replication_times,diversity_times,
305 | mm_num_times,run_mm=run_mm)
306 |
307 |
--------------------------------------------------------------------------------
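
`get_metric_statistics` in both eval scripts reports each metric as a mean plus a 95% confidence half-width, 1.96 * std / sqrt(replications). A tiny numeric example with made-up per-replication FID values:

import numpy as np

fids = np.array([0.52, 0.47, 0.55, 0.50, 0.49])   # made-up FIDs from 5 replications
mean, std = np.mean(fids), np.std(fids)
conf_interval = 1.96 * std / np.sqrt(len(fids))
print(f'FID: {mean:.4f} +/- {conf_interval:.4f}')  # prints: FID: 0.5060 +/- 0.0239
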