├── default_code_trevor_emoca2.pkl
├── requirements.txt
├── utils
│   ├── config.py
│   ├── paramUtil.py
│   ├── losses.py
│   ├── motion_process.py
│   ├── utils_model.py
│   ├── word_vectorizer.py
│   ├── skeleton.py
│   ├── quaternion.py
│   └── rotation_conversions.py
├── models
│   ├── pos_encoding.py
│   ├── encdec.py
│   ├── resnet.py
│   ├── rotation2xyz.py
│   ├── evaluator_wrapper.py
│   ├── smpl.py
│   ├── modules.py
│   ├── vqvae.py
│   └── quantize_cnn.py
├── options
│   ├── get_eval_option.py
│   ├── option_vq.py
│   └── option_transformer.py
├── dataset
│   ├── dataset_VQ.py
│   ├── dataset_tokenize.py
│   └── dataset_TM_eval.py
├── train_vq.py
├── README.md
├── baselines.py
├── visualize_listener.py
├── evaluate_listener.py
└── train_t2m_trans.py

/default_code_trevor_emoca2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanjayss34/lm-listener/HEAD/default_code_trevor_emoca2.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | einops
 2 | matplotlib
 3 | #numpy==1.20.3
 4 | numpy
 5 | opencv_contrib_python
 6 | scipy
 7 | seaborn
 8 | six
 9 | torch
10 | torchvision
11 | tensorboard
12 | transformers
13 | imageio
14 | pytorch_lightning
15 | adabound
16 | omegaconf
17 | scikit-image
18 | compress_pickle
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | SMPL_DATA_PATH = "./body_models/smpl"
 4 | 
 5 | SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
 6 | SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
 7 | JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')
 8 | 
 9 | ROT_CONVENTION_TO_ROT_NUMBER = {
10 |     'legacy': 23,
11 |     'no_hands': 21,
12 |     'full_hands': 51,
13 |     'mitten_hands': 33,
14 | }
15 | 
16 | GENDERS = ['neutral', 'male', 'female']
17 | NUM_BETAS = 10
--------------------------------------------------------------------------------
/models/pos_encoding.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Various positional encodings for the transformer.
 3 | """
 4 | import math
 5 | import torch
 6 | from torch import nn
 7 | 
 8 | def PE1d_sincos(seq_length, dim):
 9 |     """
10 |     :param seq_length: length of positions
11 |     :param dim: dimension of the model
12 |     :return: seq_length*dim position matrix
13 |     """
14 |     if dim % 2 != 0:
15 |         raise ValueError("Cannot use sin/cos positional encoding with "
16 |                          "odd dim (got dim={:d})".format(dim))
17 |     pe = torch.zeros(seq_length, dim)
18 |     position = torch.arange(0, seq_length).unsqueeze(1)
19 |     div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
20 |                          -(math.log(10000.0) / dim)))
21 |     pe[:, 0::2] = torch.sin(position.float() * div_term)
22 |     pe[:, 1::2] = torch.cos(position.float() * div_term)
23 | 
24 |     return pe.unsqueeze(1)
25 | 
26 | 
27 | class PositionEmbedding(nn.Module):
28 |     """
29 |     Absolute pos embedding (standard), learned.
30 | """ 31 | def __init__(self, seq_length, dim, dropout, grad=False): 32 | super().__init__() 33 | self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad) 34 | self.dropout = nn.Dropout(p=dropout) 35 | 36 | def forward(self, x): 37 | # x.shape: bs, seq_len, feat_dim 38 | # print('x',x.shape) 39 | l = x.shape[1] 40 | # print('l',l, 'embed',self.embed[:l].shape) 41 | x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape) 42 | # print('x2', x.shape) 43 | x = self.dropout(x.permute(1, 0, 2)) 44 | return x 45 | 46 | -------------------------------------------------------------------------------- /utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ReConsLoss(nn.Module): 5 | def __init__(self, recons_loss, nb_joints, pose_alpha): 6 | super(ReConsLoss, self).__init__() 7 | 8 | if recons_loss == 'l1': 9 | self.Loss = torch.nn.L1Loss() 10 | elif recons_loss == 'l2' : 11 | self.Loss = torch.nn.MSELoss() 12 | elif recons_loss == 'l1_smooth' : 13 | self.Loss = torch.nn.SmoothL1Loss() 14 | 15 | self.nb_joints = None 16 | # 4 global motion associated to root 17 | # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d) 18 | # 3 global vel xyz 19 | # 4 foot contact 20 | if nb_joints is not None: 21 | self.nb_joints = nb_joints 22 | self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4 23 | else: 24 | self.jaw_alpha = 1.0 25 | self.pose_alpha = pose_alpha 26 | 27 | def forward(self, motion_pred, motion_gt) : 28 | if self.nb_joints is not None: 29 | loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim]) 30 | else: 31 | exp_loss = self.Loss(motion_pred[:,:,:50], motion_gt[:,:,:50]) 32 | rot_loss = self.Loss(motion_pred[:,:,50:53], motion_gt[:,:,50:53])*self.pose_alpha 33 | jaw_loss = self.jaw_alpha * self.Loss(motion_pred[:,:,53:56], motion_gt[:,:,53:56]) 34 | loss = 
exp_loss+rot_loss+jaw_loss 35 | return loss 36 | 37 | def forward_vel(self, motion_pred, motion_gt) : 38 | if self.nb_joints is None: 39 | vel_pred = torch.cat(( 40 | torch.zeros_like(motion_pred[:,:1,:]), 41 | motion_pred[:,1:,:]-motion_pred[:,:-1,:] 42 | ), dim=1) 43 | vel_gt = torch.cat(( 44 | torch.zeros_like(motion_gt[:,:1,:]), 45 | motion_gt[:,1:,:]-motion_gt[:,:-1,:] 46 | ), dim=1) 47 | loss = self.Loss(vel_pred, vel_gt) 48 | else: 49 | loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4]) 50 | return loss 51 | 52 | 53 | -------------------------------------------------------------------------------- /utils/motion_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.quaternion import quaternion_to_cont6d, qrot, qinv 3 | 4 | def recover_root_rot_pos(data): 5 | rot_vel = data[..., 0] 6 | r_rot_ang = torch.zeros_like(rot_vel).to(data.device) 7 | '''Get Y-axis rotation from rotation velocity''' 8 | r_rot_ang[..., 1:] = rot_vel[..., :-1] 9 | r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) 10 | 11 | r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) 12 | r_rot_quat[..., 0] = torch.cos(r_rot_ang) 13 | r_rot_quat[..., 2] = torch.sin(r_rot_ang) 14 | 15 | r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) 16 | r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] 17 | '''Add Y-axis rotation to root position''' 18 | r_pos = qrot(qinv(r_rot_quat), r_pos) 19 | 20 | r_pos = torch.cumsum(r_pos, dim=-2) 21 | 22 | r_pos[..., 1] = data[..., 3] 23 | return r_rot_quat, r_pos 24 | 25 | 26 | def recover_from_rot(data, joints_num, skeleton): 27 | r_rot_quat, r_pos = recover_root_rot_pos(data) 28 | 29 | r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) 30 | 31 | start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 32 | end_indx = start_indx + (joints_num - 1) * 6 33 | cont6d_params = data[..., start_indx:end_indx] 34 | # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) 35 | cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) 36 | cont6d_params = cont6d_params.view(-1, joints_num, 6) 37 | 38 | positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) 39 | 40 | return positions 41 | 42 | 43 | def recover_from_ric(data, joints_num): 44 | r_rot_quat, r_pos = recover_root_rot_pos(data) 45 | positions = data[..., 4:(joints_num - 1) * 3 + 4] 46 | positions = positions.view(positions.shape[:-1] + (-1, 3)) 47 | 48 | '''Add Y-axis rotation to local joints''' 49 | positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) 50 | 51 | '''Add root XZ to joints''' 52 | positions[..., 0] += r_pos[..., 0:1] 53 | positions[..., 2] += r_pos[..., 2:3] 54 | 55 | '''Concate root and joints''' 56 | positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) 57 | 58 | return positions 59 | -------------------------------------------------------------------------------- /models/encdec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.resnet import Resnet1D 3 | 4 | class Encoder(nn.Module): 5 | def __init__(self, 6 | input_emb_width = 3, 7 | output_emb_width = 512, 8 | down_t = 3, 9 | stride_t = 2, 10 | width = 512, 11 | depth = 3, 12 | dilation_growth_rate = 3, 13 | activation='relu', 14 | norm=None): 15 | super().__init__() 16 | 17 | blocks = [] 18 | filter_t, pad_t = stride_t * 2, stride_t // 2 19 | blocks.append(nn.Conv1d(input_emb_width, 
width, 3, 1, 1)) 20 | blocks.append(nn.ReLU()) 21 | 22 | for i in range(down_t): 23 | input_dim = width 24 | block = nn.Sequential( 25 | nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), 26 | Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), 27 | ) 28 | blocks.append(block) 29 | blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) 30 | self.model = nn.Sequential(*blocks) 31 | 32 | def forward(self, x): 33 | # print('X:', x.shape) # 256, 56, 32 (B,F,T) 34 | return self.model(x) 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, 38 | input_emb_width = 3, 39 | output_emb_width = 512, 40 | down_t = 3, 41 | stride_t = 2, 42 | width = 512, 43 | depth = 3, 44 | dilation_growth_rate = 3, 45 | activation='relu', 46 | norm=None): 47 | super().__init__() 48 | blocks = [] 49 | 50 | filter_t, pad_t = stride_t * 2, stride_t // 2 51 | blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) 52 | blocks.append(nn.ReLU()) 53 | for i in range(down_t): 54 | out_dim = width 55 | block = nn.Sequential( 56 | Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm), 57 | nn.Upsample(scale_factor=2, mode='nearest'), 58 | nn.Conv1d(width, out_dim, 3, 1, 1) 59 | ) 60 | blocks.append(block) 61 | blocks.append(nn.Conv1d(width, width, 3, 1, 1)) 62 | blocks.append(nn.ReLU()) 63 | blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) 64 | self.model = nn.Sequential(*blocks) 65 | 66 | def forward(self, x): 67 | return self.model(x) 68 | 69 | -------------------------------------------------------------------------------- /utils/utils_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import logging 5 | import os 6 | import sys 7 | 8 | def getCi(accLog): 9 | 10 | mean = np.mean(accLog) 11 | std = np.std(accLog) 12 | ci95 = 1.96*std/np.sqrt(len(accLog)) 13 | 14 | return mean, ci95 15 | 16 | def get_logger(out_dir): 17 | logger = logging.getLogger('Exp') 18 | logger.setLevel(logging.INFO) 19 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 20 | 21 | file_path = os.path.join(out_dir, "run.log") 22 | file_hdlr = logging.FileHandler(file_path) 23 | file_hdlr.setFormatter(formatter) 24 | 25 | strm_hdlr = logging.StreamHandler(sys.stdout) 26 | strm_hdlr.setFormatter(formatter) 27 | 28 | logger.addHandler(file_hdlr) 29 | logger.addHandler(strm_hdlr) 30 | return logger 31 | 32 | ## Optimizer 33 | def initial_optim(decay_option, lr, weight_decay, net, optimizer) : 34 | 35 | if optimizer == 'adamw' : 36 | optimizer_adam_family = optim.AdamW 37 | elif optimizer == 'adam' : 38 | optimizer_adam_family = optim.Adam 39 | if decay_option == 'all': 40 | #optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay) 41 | optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay) 42 | 43 | elif decay_option == 'noVQ': 44 | all_params = set(net.parameters()) 45 | no_decay = set([net.vq_layer]) 46 | 47 | decay = all_params - no_decay 48 | optimizer = optimizer_adam_family([ 49 | {'params': list(no_decay), 'weight_decay': 0}, 50 | {'params': list(decay), 'weight_decay' : weight_decay}], lr=lr) 51 | 52 | return optimizer 53 | 54 | 55 | def get_motion_with_trans(motion, velocity) : 56 | ''' 57 | motion : torch.tensor, shape (batch_size, T, 72), with the global translation = 0 58 | velocity : torch.tensor, shape 
(batch_size, T, 3), contain the information of velocity = 0 59 | 60 | ''' 61 | trans = torch.cumsum(velocity, dim=1) 62 | trans = trans - trans[:, :1] ## the first root is initialized at 0 (just for visualization) 63 | trans = trans.repeat((1, 1, 21)) 64 | motion_with_trans = motion + trans 65 | return motion_with_trans 66 | 67 | def convert_vq_state_dict(state_dict): 68 | new_state_dict = {} 69 | for key in state_dict: 70 | new_state_dict[key.replace('encoder.model', 'encoder.0.model').replace('decoder.model', 'decoder.0.model').replace('quantizer.codebook', 'quantizer.0.codebook')] = state_dict[key] 71 | return new_state_dict 72 | -------------------------------------------------------------------------------- /options/get_eval_option.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import re 3 | from os.path import join as pjoin 4 | 5 | 6 | def is_float(numStr): 7 | flag = False 8 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 9 | try: 10 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 11 | res = reg.match(str(numStr)) 12 | if res: 13 | flag = True 14 | except Exception as ex: 15 | print("is_float() - error: " + str(ex)) 16 | return flag 17 | 18 | 19 | def is_number(numStr): 20 | flag = False 21 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 22 | if str(numStr).isdigit(): 23 | flag = True 24 | return flag 25 | 26 | 27 | def get_opt(opt_path, device): 28 | opt = Namespace() 29 | opt_dict = vars(opt) 30 | 31 | skip = ('-------------- End ----------------', 32 | '------------ Options -------------', 33 | '\n') 34 | print('Reading', opt_path) 35 | with open(opt_path) as f: 36 | for line in f: 37 | if line.strip() not in skip: 38 | # print(line.strip()) 39 | key, value = line.strip().split(': ') 40 | if value in ('True', 'False'): 41 | opt_dict[key] = (value == 'True') 42 | # print(key, value) 43 | elif is_float(value): 44 | opt_dict[key] = float(value) 45 | elif is_number(value): 46 | opt_dict[key] = int(value) 47 | else: 48 | opt_dict[key] = str(value) 49 | 50 | # print(opt) 51 | opt_dict['which_epoch'] = 'finest' 52 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 53 | opt.model_dir = pjoin(opt.save_root, 'model') 54 | opt.meta_dir = pjoin(opt.save_root, 'meta') 55 | 56 | if opt.dataset_name == 't2m': 57 | opt.data_root = './dataset/HumanML3D/' 58 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 59 | opt.text_dir = pjoin(opt.data_root, 'texts') 60 | opt.joints_num = 22 61 | opt.dim_pose = 263 62 | opt.max_motion_length = 196 63 | opt.max_motion_frame = 196 64 | opt.max_motion_token = 55 65 | elif opt.dataset_name == 'kit': 66 | opt.data_root = './dataset/KIT-ML/' 67 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 68 | opt.text_dir = pjoin(opt.data_root, 'texts') 69 | opt.joints_num = 21 70 | opt.dim_pose = 251 71 | opt.max_motion_length = 196 72 | opt.max_motion_frame = 196 73 | opt.max_motion_token = 55 74 | else: 75 | raise KeyError('Dataset not recognized') 76 | 77 | opt.dim_word = 300 78 | opt.num_classes = 200 // opt.unit_length 79 | opt.is_train = False 80 | opt.is_continue = False 81 | opt.device = device 82 | 83 | return opt -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class nonlinearity(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def 
forward(self, x): 9 | # swish 10 | return x * torch.sigmoid(x) 11 | 12 | class ResConv1DBlock(nn.Module): 13 | def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): 14 | super().__init__() 15 | padding = dilation 16 | self.norm = norm 17 | if norm == "LN": 18 | self.norm1 = nn.LayerNorm(n_in) 19 | self.norm2 = nn.LayerNorm(n_in) 20 | elif norm == "GN": 21 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 22 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 23 | elif norm == "BN": 24 | self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 25 | self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 26 | 27 | else: 28 | self.norm1 = nn.Identity() 29 | self.norm2 = nn.Identity() 30 | 31 | if activation == "relu": 32 | self.activation1 = nn.ReLU() 33 | self.activation2 = nn.ReLU() 34 | 35 | elif activation == "silu": 36 | self.activation1 = nonlinearity() 37 | self.activation2 = nonlinearity() 38 | 39 | elif activation == "gelu": 40 | self.activation1 = nn.GELU() 41 | self.activation2 = nn.GELU() 42 | 43 | 44 | 45 | self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) 46 | self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,) 47 | 48 | 49 | def forward(self, x): 50 | x_orig = x 51 | if self.norm == "LN": 52 | x = self.norm1(x.transpose(-2, -1)) 53 | x = self.activation1(x.transpose(-2, -1)) 54 | else: 55 | x = self.norm1(x) 56 | x = self.activation1(x) 57 | 58 | x = self.conv1(x) 59 | 60 | if self.norm == "LN": 61 | x = self.norm2(x.transpose(-2, -1)) 62 | x = self.activation2(x.transpose(-2, -1)) 63 | else: 64 | x = self.norm2(x) 65 | x = self.activation2(x) 66 | 67 | x = self.conv2(x) 68 | x = x + x_orig 69 | return x 70 | 71 | class Resnet1D(nn.Module): 72 | def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): 73 | super().__init__() 74 | 75 | blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)] 76 | if reverse_dilation: 77 | blocks = blocks[::-1] 78 | 79 | self.model = nn.Sequential(*blocks) 80 | 81 | def forward(self, x): 82 | return self.model(x) -------------------------------------------------------------------------------- /utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 
'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[self.word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec 81 | 82 | 83 | class WordVectorizerV2(WordVectorizer): 84 | def __init__(self, meta_root, prefix): 85 | super(WordVectorizerV2, self).__init__(meta_root, prefix) 86 | self.idx2word = {self.word2idx[w]: w for w in self.word2idx} 87 | 88 | def __getitem__(self, item): 89 | word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item) 90 | word, pos = item.split('/') 91 | if word in self.word2vec: 92 | return word_vec, pose_vec, self.word2idx[word] 93 | else: 94 | return word_vec, pose_vec, self.word2idx['unk'] 95 | 96 | def itos(self, idx): 97 | if idx == len(self.idx2word): 98 | return "pad" 99 | return self.idx2word[idx] -------------------------------------------------------------------------------- /models/rotation2xyz.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import torch 3 | import utils.rotation_conversions as geometry 4 | 5 | 6 | from models.smpl import SMPL, JOINTSTYPE_ROOT 7 | # from .get_model import JOINTSTYPES 8 | JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] 9 | 10 | 11 | class Rotation2xyz: 12 | def __init__(self, device, dataset='amass'): 13 | self.device = device 14 | self.dataset = dataset 15 | self.smpl_model = SMPL().eval().to(device) 16 | 17 | def __call__(self, x, mask, pose_rep, translation, glob, 18 | jointstype, vertstrans, betas=None, beta=0, 19 | glob_rot=None, get_rotations_back=False, **kwargs): 20 | if pose_rep == "xyz": 21 | return x 22 | 23 | if mask is None: 24 | mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) 25 | 26 | if not glob and glob_rot is None: 27 | raise TypeError("You must specify global rotation if glob is False") 28 | 29 | if jointstype not in JOINTSTYPES: 30 | raise NotImplementedError("This jointstype is not implemented.") 31 | 32 | if translation: 33 | x_translations = x[:, -1, :3] 34 | x_rotations = x[:, :-1] 35 | else: 36 | x_rotations = x 37 | 38 | x_rotations = x_rotations.permute(0, 3, 1, 2) 39 | nsamples, time, 
njoints, feats = x_rotations.shape 40 | 41 | # Compute rotations (convert only masked sequences output) 42 | if pose_rep == "rotvec": 43 | rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) 44 | elif pose_rep == "rotmat": 45 | rotations = x_rotations[mask].view(-1, njoints, 3, 3) 46 | elif pose_rep == "rotquat": 47 | rotations = geometry.quaternion_to_matrix(x_rotations[mask]) 48 | elif pose_rep == "rot6d": 49 | rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) 50 | else: 51 | raise NotImplementedError("No geometry for this one.") 52 | 53 | if not glob: 54 | global_orient = torch.tensor(glob_rot, device=x.device) 55 | global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) 56 | global_orient = global_orient.repeat(len(rotations), 1, 1, 1) 57 | else: 58 | global_orient = rotations[:, 0] 59 | rotations = rotations[:, 1:] 60 | 61 | if betas is None: 62 | betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], 63 | dtype=rotations.dtype, device=rotations.device) 64 | betas[:, 1] = beta 65 | # import ipdb; ipdb.set_trace() 66 | out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) 67 | 68 | # get the desirable joints 69 | joints = out[jointstype] 70 | 71 | x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) 72 | x_xyz[~mask] = 0 73 | x_xyz[mask] = joints 74 | 75 | x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() 76 | 77 | # the first translation root at the origin on the prediction 78 | if jointstype != "vertices": 79 | rootindex = JOINTSTYPE_ROOT[jointstype] 80 | x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] 81 | 82 | if translation and vertstrans: 83 | # the first translation root at the origin 84 | x_translations = x_translations - x_translations[:, :, [0]] 85 | 86 | # add the translation to all the joints 87 | x_xyz = x_xyz + x_translations[:, None, :, :] 88 | 89 | if get_rotations_back: 90 | return x_xyz, rotations, global_orient 91 | else: 92 | return x_xyz 93 | -------------------------------------------------------------------------------- /models/evaluator_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from os.path import join as pjoin 4 | import numpy as np 5 | from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo 6 | from utils.word_vectorizer import POS_enumerator 7 | 8 | def build_models(opt): 9 | movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) 10 | text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, 11 | pos_size=opt.dim_pos_ohot, 12 | hidden_size=opt.dim_text_hidden, 13 | output_size=opt.dim_coemb_hidden, 14 | device=opt.device) 15 | 16 | motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, 17 | hidden_size=opt.dim_motion_hidden, 18 | output_size=opt.dim_coemb_hidden, 19 | device=opt.device) 20 | 21 | checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), 22 | map_location=opt.device) 23 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 24 | text_enc.load_state_dict(checkpoint['text_encoder']) 25 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 26 | print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch']))
27 |     return text_enc, motion_enc, movement_enc
28 | 
29 | 
30 | class EvaluatorModelWrapper(object):
31 | 
32 |     def __init__(self, opt):
33 | 
34 |         if opt.dataset_name == 't2m':
35 |             opt.dim_pose = 263
36 |         elif opt.dataset_name == 'kit':
37 |             opt.dim_pose = 251
38 |         else:
39 |             raise KeyError('Dataset not Recognized!!!')
40 | 
41 |         opt.dim_word = 300
42 |         opt.max_motion_length = 196
43 |         opt.dim_pos_ohot = len(POS_enumerator)
44 |         opt.dim_motion_hidden = 1024
45 |         opt.max_text_len = 20
46 |         opt.dim_text_hidden = 512
47 |         opt.dim_coemb_hidden = 512
48 | 
49 |         # print(opt)
50 | 
51 |         self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
52 |         self.opt = opt
53 |         self.device = opt.device
54 | 
55 |         self.text_encoder.to(opt.device)
56 |         self.motion_encoder.to(opt.device)
57 |         self.movement_encoder.to(opt.device)
58 | 
59 |         self.text_encoder.eval()
60 |         self.motion_encoder.eval()
61 |         self.movement_encoder.eval()
62 | 
63 |     # Please note that the results do not follow the order of the inputs
64 |     def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
65 |         with torch.no_grad():
66 |             word_embs = word_embs.detach().to(self.device).float()
67 |             pos_ohot = pos_ohot.detach().to(self.device).float()
68 |             motions = motions.detach().to(self.device).float()
69 | 
70 |             '''Movement Encoding'''
71 |             movements = self.movement_encoder(motions[..., :-4]).detach()
72 |             m_lens = m_lens // self.opt.unit_length
73 |             motion_embedding = self.motion_encoder(movements, m_lens)
74 | 
75 |             '''Text Encoding'''
76 |             text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
77 |             return text_embedding, motion_embedding
78 | 
79 |     # Please note that the results do not follow the order of the inputs
80 |     def get_motion_embeddings(self, motions, m_lens):
81 |         with torch.no_grad():
82 |             motions = motions.detach().to(self.device).float()
83 | 
84 |             align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
85 |             motions = motions[align_idx]
86 |             m_lens = m_lens[align_idx]
87 | 
88 |             '''Movement Encoding'''
89 |             movements = self.movement_encoder(motions[..., :-4]).detach()
90 |             m_lens = m_lens // self.opt.unit_length
91 |             motion_embedding = self.motion_encoder(movements, m_lens)
92 |             return motion_embedding
93 | 
--------------------------------------------------------------------------------
/models/smpl.py:
--------------------------------------------------------------------------------
 1 | # This code is based on https://github.com/Mathux/ACTOR.git
 2 | import numpy as np
 3 | import torch
 4 | 
 5 | import contextlib
 6 | 
 7 | from smplx import SMPLLayer as _SMPLLayer
 8 | from smplx.lbs import vertices2joints
 9 | 
10 | 
11 | # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
12 | # change 0 and 8
13 | action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]
14 | 
15 | from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA
16 | 
17 | JOINTSTYPE_ROOT = {"a2m": 0,  # action2motion
18 |                    "smpl": 0,
19 |                    "a2mpl": 0,  # set(smpl, a2m)
20 |                    "vibe": 8}  # for "vibe" the root is index 8: OP MidHip (see JOINT_MAP below)
21 | 
22 | JOINT_MAP = {
23 |     'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
24 |     'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
25 |     'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
26 |     'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
27 |     'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
28 |     'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
29 |     'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
30 |     'OP LHeel': 
31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 31 | 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 32 | 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 33 | 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 34 | 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 35 | 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 36 | 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 37 | 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 38 | 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 39 | 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 40 | } 41 | 42 | JOINT_NAMES = [ 43 | 'OP Nose', 'OP Neck', 'OP RShoulder', 44 | 'OP RElbow', 'OP RWrist', 'OP LShoulder', 45 | 'OP LElbow', 'OP LWrist', 'OP MidHip', 46 | 'OP RHip', 'OP RKnee', 'OP RAnkle', 47 | 'OP LHip', 'OP LKnee', 'OP LAnkle', 48 | 'OP REye', 'OP LEye', 'OP REar', 49 | 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 50 | 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 51 | 'Right Ankle', 'Right Knee', 'Right Hip', 52 | 'Left Hip', 'Left Knee', 'Left Ankle', 53 | 'Right Wrist', 'Right Elbow', 'Right Shoulder', 54 | 'Left Shoulder', 'Left Elbow', 'Left Wrist', 55 | 'Neck (LSP)', 'Top of Head (LSP)', 56 | 'Pelvis (MPII)', 'Thorax (MPII)', 57 | 'Spine (H36M)', 'Jaw (H36M)', 58 | 'Head (H36M)', 'Nose', 'Left Eye', 59 | 'Right Eye', 'Left Ear', 'Right Ear' 60 | ] 61 | 62 | 63 | # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints 64 | class SMPL(_SMPLLayer): 65 | """ Extension of the official SMPL implementation to support more joints """ 66 | 67 | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): 68 | kwargs["model_path"] = model_path 69 | 70 | # remove the verbosity for the 10-shapes beta parameters 71 | with contextlib.redirect_stdout(None): 72 | super(SMPL, self).__init__(**kwargs) 73 | 74 | J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) 75 | self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) 76 | vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) 77 | a2m_indexes = vibe_indexes[action2motion_joints] 78 | smpl_indexes = np.arange(24) 79 | a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) 80 | 81 | self.maps = {"vibe": vibe_indexes, 82 | "a2m": a2m_indexes, 83 | "smpl": smpl_indexes, 84 | "a2mpl": a2mpl_indexes} 85 | 86 | def forward(self, *args, **kwargs): 87 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 88 | 89 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 90 | all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 91 | 92 | output = {"vertices": smpl_output.vertices} 93 | 94 | for joinstype, indexes in self.maps.items(): 95 | output[joinstype] = all_joints[:, indexes] 96 | 97 | return output -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence 4 | 5 | def init_weight(m): 6 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): 7 | nn.init.xavier_normal_(m.weight) 8 | # m.bias.data.fill_(0.01) 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | 13 | class MovementConvEncoder(nn.Module): 14 | def __init__(self, input_size, hidden_size, output_size): 15 | super(MovementConvEncoder, self).__init__() 16 | self.main = nn.Sequential( 17 | nn.Conv1d(input_size, hidden_size, 4, 2, 1), 18 | 
nn.Dropout(0.2, inplace=True), 19 | nn.LeakyReLU(0.2, inplace=True), 20 | nn.Conv1d(hidden_size, output_size, 4, 2, 1), 21 | nn.Dropout(0.2, inplace=True), 22 | nn.LeakyReLU(0.2, inplace=True), 23 | ) 24 | self.out_net = nn.Linear(output_size, output_size) 25 | self.main.apply(init_weight) 26 | self.out_net.apply(init_weight) 27 | 28 | def forward(self, inputs): 29 | inputs = inputs.permute(0, 2, 1) 30 | outputs = self.main(inputs).permute(0, 2, 1) 31 | # print(outputs.shape) 32 | return self.out_net(outputs) 33 | 34 | 35 | 36 | class TextEncoderBiGRUCo(nn.Module): 37 | def __init__(self, word_size, pos_size, hidden_size, output_size, device): 38 | super(TextEncoderBiGRUCo, self).__init__() 39 | self.device = device 40 | 41 | self.pos_emb = nn.Linear(pos_size, word_size) 42 | self.input_emb = nn.Linear(word_size, hidden_size) 43 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 44 | self.output_net = nn.Sequential( 45 | nn.Linear(hidden_size * 2, hidden_size), 46 | nn.LayerNorm(hidden_size), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Linear(hidden_size, output_size) 49 | ) 50 | 51 | self.input_emb.apply(init_weight) 52 | self.pos_emb.apply(init_weight) 53 | self.output_net.apply(init_weight) 54 | self.hidden_size = hidden_size 55 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 56 | 57 | # input(batch_size, seq_len, dim) 58 | def forward(self, word_embs, pos_onehot, cap_lens): 59 | num_samples = word_embs.shape[0] 60 | 61 | pos_embs = self.pos_emb(pos_onehot) 62 | inputs = word_embs + pos_embs 63 | input_embs = self.input_emb(inputs) 64 | hidden = self.hidden.repeat(1, num_samples, 1) 65 | 66 | cap_lens = cap_lens.data.tolist() 67 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 68 | 69 | gru_seq, gru_last = self.gru(emb, hidden) 70 | 71 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 72 | 73 | return self.output_net(gru_last) 74 | 75 | 76 | class MotionEncoderBiGRUCo(nn.Module): 77 | def __init__(self, input_size, hidden_size, output_size, device): 78 | super(MotionEncoderBiGRUCo, self).__init__() 79 | self.device = device 80 | 81 | self.input_emb = nn.Linear(input_size, hidden_size) 82 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 83 | self.output_net = nn.Sequential( 84 | nn.Linear(hidden_size*2, hidden_size), 85 | nn.LayerNorm(hidden_size), 86 | nn.LeakyReLU(0.2, inplace=True), 87 | nn.Linear(hidden_size, output_size) 88 | ) 89 | 90 | self.input_emb.apply(init_weight) 91 | self.output_net.apply(init_weight) 92 | self.hidden_size = hidden_size 93 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 94 | 95 | # input(batch_size, seq_len, dim) 96 | def forward(self, inputs, m_lens): 97 | num_samples = inputs.shape[0] 98 | 99 | input_embs = self.input_emb(inputs) 100 | hidden = self.hidden.repeat(1, num_samples, 1) 101 | 102 | cap_lens = m_lens.data.tolist() 103 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False) 104 | 105 | gru_seq, gru_last = self.gru(emb, hidden) 106 | 107 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 108 | 109 | return self.output_net(gru_last) 110 | -------------------------------------------------------------------------------- /options/option_vq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser(description='Optimal 
Transport AutoEncoder training for AIST',
 5 |                                      add_help=True,
 6 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 7 | 
 8 |     ## dataloader
 9 |     parser.add_argument('--dataname', type=str, default='kit', help='dataset name')
10 |     parser.add_argument('--batch-size', default=128, type=int, help='batch size')
11 |     parser.add_argument('--window-size', type=int, default=64, help='training motion length')
12 |     parser.add_argument('--eval-split', default='val')
13 |     parser.add_argument('--person-id', type=int, default=0)
14 | 
15 |     ## optimization
16 |     parser.add_argument('--total-iter', default=200000, type=int, help='number of total iterations to run')
17 |     parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
18 |     parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
19 |     parser.add_argument('--lr-scheduler', default=[50000, 400000], nargs="+", type=int, help="learning rate schedule (iterations)")
20 |     parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
21 | 
22 |     parser.add_argument('--weight-decay', default=0.0, type=float, help='weight decay')
23 |     parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
24 |     parser.add_argument('--loss-vel', type=float, default=0.1, help='hyper-parameter for the velocity loss')
25 |     parser.add_argument('--loss-aux', type=float, default=0.0, help='hyper-parameter for the auxiliary loss')
26 |     parser.add_argument('--recons-loss', type=str, default='l2', help='reconstruction loss')
27 |     parser.add_argument('--pose-alpha', type=float, default=1.0)
28 | 
29 |     ## vqvae arch
30 |     parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension")
31 |     parser.add_argument("--nb-code", type=int, default=512, help="number of codebook embeddings")
32 |     parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
33 |     parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
34 |     parser.add_argument("--stride-t", type=int, default=2, help="stride size")
35 |     parser.add_argument("--width", type=int, default=512, help="width of the network")
36 |     parser.add_argument("--depth", type=int, default=3, help="depth of the network")
37 |     parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
38 |     parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
39 |     parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function for the VQ-VAE')
40 |     parser.add_argument('--vq-norm', type=str, default=None, help='normalization layer for the VQ-VAE')
41 |     parser.add_argument('--separate_pose', type=int, default=None)
42 |     parser.add_argument('--aux_labels', type=int, default=0)
43 | 
44 |     ## quantizer
45 |     parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="type of quantizer")
46 |     parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ')
47 | 
48 |     ## resume
49 |     parser.add_argument("--resume-pth", type=str, default=None, help='resume pth for VQ')
50 |     parser.add_argument("--resume-gpt", type=str, default=None, help='resume pth for GPT')
51 | 
52 | 
53 |     ## output directory
54 |     parser.add_argument('--out-dir', type=str, default='output_vqfinal/', help='output directory')
55 |     parser.add_argument('--results-dir', type=str, 
default='visual_results/', help='output directory') 56 | parser.add_argument('--visual-name', type=str, default='baseline', help='output directory') 57 | parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir') 58 | ## other 59 | parser.add_argument('--print-iter', default=200, type=int, help='print frequency') 60 | parser.add_argument('--eval-iter', default=1000, type=int, help='evaluation frequency') 61 | parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.') 62 | 63 | parser.add_argument('--vis-gt', action='store_true', help='whether visualize GT motions') 64 | parser.add_argument('--nb-vis', default=20, type=int, help='nb of visualizations') 65 | 66 | 67 | return parser.parse_args() 68 | -------------------------------------------------------------------------------- /dataset/dataset_VQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import json 7 | import codecs as cs 8 | from tqdm import tqdm 9 | 10 | 11 | """ 12 | python train_vq.py --batch-size 256 --lr 2e-4 --total-iter 300000 --lr-scheduler 200000 --nb-code 256 --down-t 2 --depth 3 --dilation-growth-rate 3 --out-dir output --dataname audio_trevor --vq-act relu --quantizer ema_reset --loss-vel 0.5 --recons-loss l1_smooth --exp-name audio_p8 --window-size 32 13 | """ 14 | 15 | class VQMotionDataset(data.Dataset): 16 | def __init__(self, dataset_name, window_size = 64, unit_length = 4, split = "train", add_velocity=False, person_id = 0): 17 | self.window_size = window_size 18 | self.unit_length = unit_length 19 | self.dataset_name = dataset_name 20 | 21 | if dataset_name == 't2m': 22 | self.data_root = './dataset/HumanML3D' 23 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 24 | self.text_dir = pjoin(self.data_root, 'texts') 25 | self.joints_num = 22 26 | self.max_motion_length = 196 27 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 28 | 29 | elif dataset_name == 'kit': 30 | self.data_root = './dataset/KIT-ML' 31 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 32 | self.text_dir = pjoin(self.data_root, 'texts') 33 | self.joints_num = 21 34 | 35 | self.max_motion_length = 196 36 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 37 | 38 | elif dataset_name.split('_')[0] == 'face': 39 | self.data_root = './dataset/'+dataset_name.split('_')[-1] 40 | self.segments = torch.load(pjoin(self.data_root, 'segments_'+split+'.pth'), map_location='cpu') 41 | self.data = [torch.cat((seg['p'+str(person_id)+'_exp'], seg['p'+str(person_id)+'_pose']), dim=1).numpy() for seg in self.segments] 42 | # print(self.data[0].shape) 43 | # assert False 44 | self.fnames = [seg['fname'] for seg, faces in zip(self.segments, self.data) if faces.shape[0] >= window_size] 45 | self.starts = [seg['split_start_frame'] for seg, faces in zip(self.segments, self.data) if faces.shape[0] >= window_size] 46 | self.data = [datum for datum in self.data if datum.shape[0] >= window_size] 47 | self.lengths = [datum.shape[0] for datum in self.data] 48 | assert len(self.fnames) == len(self.data) == len(self.lengths) == len(self.starts) 49 | 50 | self.split = split 51 | mean = np.load(pjoin(self.data_root, 'p'+str(person_id)+'_mean.npy')) 52 | std = np.load(pjoin(self.data_root, 'p'+str(person_id)+'_std.npy')) 53 | 54 | 
self.add_velocity = add_velocity 55 | 56 | print(dataset_name.split('_')[0]) 57 | if dataset_name.split('_')[0] != 'face' and dataset_name.split('_')[0] != 'audio': 58 | self.data = [] 59 | self.lengths = [] 60 | id_list = [] 61 | self.names = [] 62 | split_file = pjoin(self.data_root, split+'.txt') 63 | with cs.open(split_file, 'r') as f: 64 | for line in f.readlines(): 65 | id_list.append(line.strip()) 66 | 67 | for name in tqdm(id_list): 68 | try: 69 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 70 | if motion.shape[0] < self.window_size: 71 | continue 72 | self.lengths.append(motion.shape[0] - self.window_size) 73 | self.data.append(motion) 74 | self.names.append(name) 75 | except: 76 | # Some motion may not exist in KIT dataset 77 | pass 78 | 79 | 80 | self.mean = mean 81 | self.std = std 82 | print("Total number of motions {}".format(len(self.data))) 83 | 84 | def inv_transform(self, data): 85 | return data * self.std + self.mean 86 | 87 | def compute_sampling_prob(self) : 88 | 89 | prob = np.array(self.lengths, dtype=np.float32) 90 | prob /= np.sum(prob) 91 | return prob 92 | 93 | def __len__(self): 94 | return len(self.data) 95 | 96 | def __getitem__(self, item): 97 | motion = self.data[item] 98 | 99 | # print(len(motion), self.data[item].shape) 100 | idx = random.randint(0, len(motion) - self.window_size) 101 | 102 | motion = motion[idx:idx+self.window_size] 103 | "Z Normalization" 104 | motion = (motion - self.mean) / self.std 105 | """if self.add_velocity: 106 | vel = np.concatenate(( 107 | np.zeros((1, motion.shape[-1]), dtype=motion.dtype), 108 | motion[1:] - motion[:-1] 109 | ), axis=0) 110 | motion = np.concatenate((motion, vel), axis=1)""" 111 | 112 | if "train" in self.split: 113 | return motion 114 | return motion, np.array([self.window_size]).squeeze(), self.fnames[item]+'_'+str(self.starts[item]+idx) 115 | 116 | def DATALoader(dataset_name, 117 | batch_size, 118 | num_workers = 8, 119 | window_size = 64, 120 | unit_length = 4, 121 | split = "train", 122 | person_id = 0): 123 | 124 | trainSet = VQMotionDataset(dataset_name, window_size=window_size, unit_length=unit_length, split=split, person_id=person_id) 125 | prob = trainSet.compute_sampling_prob() 126 | sampler = torch.utils.data.WeightedRandomSampler(prob, num_samples = len(trainSet) * 1000, replacement=True) 127 | train_loader = torch.utils.data.DataLoader(trainSet, 128 | batch_size, 129 | shuffle=True, 130 | #sampler=sampler, 131 | num_workers=num_workers, 132 | #collate_fn=collate_fn, 133 | drop_last = (split == "train")) 134 | 135 | return train_loader 136 | 137 | def cycle(iterable): 138 | while True: 139 | for x in iterable: 140 | yield x 141 | -------------------------------------------------------------------------------- /dataset/dataset_tokenize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | 9 | 10 | 11 | class VQMotionDataset(data.Dataset): 12 | def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8, split = "train", max_motion_length = None, delay_start_frames=0, fps=30, min_length=0): 13 | self.window_size = window_size 14 | self.unit_length = unit_length 15 | self.feat_bias = feat_bias 16 | 17 | self.dataset_name = dataset_name 18 | min_motion_len = 40 if dataset_name =='t2m' else 24 19 | 20 | if dataset_name == 't2m': 21 | self.data_root = 
'./dataset/HumanML3D' 22 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 23 | self.text_dir = pjoin(self.data_root, 'texts') 24 | self.joints_num = 22 25 | radius = 4 26 | fps = 20 27 | self.max_motion_length = 196 28 | dim_pose = 263 29 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 30 | #kinematic_chain = paramUtil.t2m_kinematic_chain 31 | elif dataset_name == 'kit': 32 | self.data_root = './dataset/KIT-ML' 33 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 34 | self.text_dir = pjoin(self.data_root, 'texts') 35 | self.joints_num = 21 36 | radius = 240 * 8 37 | fps = 12.5 38 | dim_pose = 251 39 | self.max_motion_length = 196 40 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 41 | #kinematic_chain = paramUtil.kit_kinematic_chain 42 | elif dataset_name.split('_')[0] == 'face': 43 | self.data_root = './dataset/'+dataset_name.split('_')[-1] 44 | self.data = torch.load(pjoin(self.data_root, "segments_"+split+".pth"), map_location="cpu") 45 | self.lengths = [datum["split_end_frame"]-datum["split_start_frame"] for datum in self.data] 46 | self.names = [datum["fname"]+"_"+str(datum["split_start_frame"]) for datum in self.data] 47 | self.max_motion_length = None 48 | 49 | min_motion_len = max(min_motion_len, min_length) 50 | if max_motion_length is not None: 51 | self.max_motion_length = max_motion_length 52 | self.split = split 53 | mean = np.load(pjoin(self.data_root, 'mean.npy')) 54 | std = np.load(pjoin(self.data_root, 'std.npy')) 55 | person_id = 0 56 | 57 | if dataset_name.split('_')[0] == 'face': 58 | data_dict = {} 59 | length_list = [] 60 | new_name_list = [] 61 | step_size = self.max_motion_length 62 | # print(person_id, unit_length, history_size) 63 | # print(unit_length, history_size) 64 | for i in range(len(self.data)): 65 | motion_len = self.lengths[i] 66 | for start in range(delay_start_frames, motion_len, step_size): 67 | segment_len = min(motion_len-start, self.max_motion_length) 68 | if segment_len >= min_motion_len: # and motion_len <= self.max_motion_length: 69 | s = start 70 | e = start+segment_len 71 | length_list.append(e-s) 72 | parts = self.names[i].split('_') 73 | new_name_list.append('_'.join(parts[:-1])+'_'+str(int(parts[-1])+start)) 74 | # print('NAME', split, new_name_list[-1], s, e, fix_vq_tokenizer_start_frame) 75 | if dataset_name.split('_')[0] == 'face': 76 | curr_motion = torch.cat((self.data[i]['p'+str(person_id)+'_exp'][s:e,:], self.data[i]['p'+str(person_id)+'_pose'][s:e,:]), dim=1).numpy() 77 | data_dict[new_name_list[-1]] = { 78 | 'motion': curr_motion, 79 | 'length': segment_len, 80 | 'name': new_name_list[-1], 81 | } 82 | else: 83 | split_file = pjoin(self.data_root, split+'.txt') 84 | joints_num = self.joints_num 85 | 86 | data_dict = {} 87 | id_list = [] 88 | with cs.open(split_file, 'r') as f: 89 | for line in f.readlines(): 90 | id_list.append(line.strip()) 91 | 92 | new_name_list = [] 93 | length_list = [] 94 | for name in tqdm(id_list): 95 | try: 96 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 97 | if (len(motion)) < min_motion_len or (len(motion) >= self.max_motion_length): 98 | continue 99 | 100 | data_dict[name] = {'motion': motion, 101 | 'length': len(motion), 102 | 'name': name} 103 | new_name_list.append(name) 104 | length_list.append(len(motion)) 105 | except: 106 | # Some motion may not exist in KIT dataset 107 | pass 108 | 109 | 110 | self.mean = mean 111 | self.std = std 112 | self.length_arr = np.array(length_list) 113 | self.data_dict = data_dict 114 | 
self.name_list = new_name_list 115 | 116 | def inv_transform(self, data): 117 | return data * self.std + self.mean 118 | 119 | def __len__(self): 120 | return len(self.data_dict) 121 | 122 | def __getitem__(self, item): 123 | name = self.name_list[item] 124 | data = self.data_dict[name] 125 | motion, m_length = data['motion'], data['length'] 126 | 127 | m_length = (m_length // self.unit_length) * self.unit_length 128 | 129 | idx = random.randint(0, len(motion) - m_length) 130 | motion = motion[idx:idx+m_length] 131 | 132 | "Z Normalization" 133 | motion = (motion - self.mean) / self.std 134 | 135 | return motion, name 136 | 137 | def DATALoader(dataset_name, 138 | batch_size = 1, 139 | num_workers = 8, unit_length = 4, split = "train", max_motion_length = None, delay_start_frames=0, fps=30, min_length=0) : 140 | 141 | train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length, split=split, max_motion_length=max_motion_length, delay_start_frames=delay_start_frames, fps=fps, min_length=min_length), 142 | batch_size, 143 | shuffle=True, 144 | num_workers=num_workers, 145 | #collate_fn=collate_fn, 146 | drop_last = True) 147 | 148 | return train_loader 149 | 150 | def cycle(iterable): 151 | while True: 152 | for x in iterable: 153 | yield x 154 | -------------------------------------------------------------------------------- /models/vqvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from models.encdec import Encoder, Decoder 5 | from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset 6 | 7 | 8 | class VQVAE_251(nn.Module): 9 | def __init__(self, 10 | args, 11 | nb_code=1024, 12 | code_dim=512, 13 | output_emb_width=512, 14 | down_t=3, 15 | stride_t=2, 16 | width=512, 17 | depth=3, 18 | dilation_growth_rate=3, 19 | activation='relu', 20 | norm=None, 21 | index_groups=None): 22 | 23 | super().__init__() 24 | self.quant = args.quantizer 25 | input_dim = 263 26 | if args.dataname.split('_')[0] == "face": 27 | input_dim = 56 28 | elif args.dataname.split('_')[0] == "pats": 29 | input_dim = 129 30 | elif args.dataname.split('_')[0] == "audio": 31 | input_dim = 128 32 | elif args.dataname == "kit": 33 | input_dim = 251 34 | if index_groups is None: 35 | index_groups = [list(range(input_dim))] 36 | if not isinstance(nb_code, list): 37 | nb_code = [nb_code for _ in index_groups] 38 | self.num_code = math.prod(nb_code) 39 | self.code_dim = code_dim*len(index_groups) 40 | self.index_groups = index_groups 41 | encoders = [Encoder(len(group), output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) for group in index_groups] 42 | self.encoder = nn.ModuleList(encoders) 43 | decoders = [Decoder(len(group), output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) for group in index_groups] 44 | self.decoder = nn.ModuleList(decoders) 45 | # self.encoder = Encoder(input_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) 46 | # self.decoder = Decoder(input_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) 47 | if args.quantizer == "ema_reset": 48 | self.quantizer = nn.ModuleList([QuantizeEMAReset(nb_c, code_dim, args) for nb_c in nb_code]) 49 | elif args.quantizer == "orig": 50 | self.quantizer = nn.ModuleList([Quantizer(nb_c, 
code_dim, 1.0) for nb_c in nb_code]) 51 | elif args.quantizer == "ema": 52 | self.quantizer = nn.ModuleList([QuantizeEMA(nb_c, code_dim, args) for nb_c in nb_code]) 53 | elif args.quantizer == "reset": 54 | self.quantizer = nn.ModuleList([QuantizeReset(nb_c, code_dim, args) for nb_c in nb_code]) 55 | 56 | 57 | def preprocess(self, x): 58 | # (bs, T, Jx3) -> (bs, Jx3, T) 59 | x = x.permute(0,2,1).float() 60 | return x 61 | 62 | 63 | def postprocess(self, x): 64 | # (bs, Jx3, T) -> (bs, T, Jx3) 65 | x = x.permute(0,2,1) 66 | return x 67 | 68 | 69 | def encode(self, x): 70 | code_idx_per_group = [] 71 | for group_index, group in enumerate(self.index_groups): 72 | N, T, _ = x[...,group].shape 73 | x_in = self.preprocess(x[...,group]) 74 | x_encoder = self.encoder[group_index](x_in) 75 | x_encoder = self.postprocess(x_encoder) 76 | x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1]) # (NT, C) 77 | code_idx = self.quantizer[group_index].quantize(x_encoder) 78 | # print("code", code_idx.shape, "x", x.shape) 79 | # assert False 80 | code_idx = code_idx.view(N, -1) 81 | code_idx_per_group.append(code_idx) 82 | if len(code_idx_per_group) == 1: 83 | return code_idx_per_group[0] 84 | return tuple(code_idx_per_group) 85 | 86 | 87 | def forward(self, x): 88 | total_loss = 0 89 | total_perplexity = 0 90 | output = torch.zeros_like(x).float() 91 | # print('x', x.shape) 92 | for group_index, group in enumerate(self.index_groups): 93 | # print('group', group) 94 | x_in = self.preprocess(x[...,group]) 95 | # print('x_in', x_in.shape) 96 | # Encode 97 | x_encoder = self.encoder[group_index](x_in) 98 | 99 | ## quantization 100 | x_quantized, loss, perplexity = self.quantizer[group_index](x_encoder) 101 | 102 | total_loss += loss 103 | total_perplexity += perplexity 104 | 105 | ## decoder 106 | x_decoder = self.decoder[group_index](x_quantized) 107 | x_out = self.postprocess(x_decoder) 108 | output[...,group] = x_out 109 | return output, total_loss, total_perplexity 110 | 111 | 112 | def forward_decoder(self, x): 113 | output = [] 114 | for group_index, group in enumerate(self.index_groups): 115 | x_d = self.quantizer[group_index].dequantize(x) 116 | x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous() 117 | 118 | # decoder 119 | x_decoder = self.decoder[group_index](x_d) 120 | x_out = self.postprocess(x_decoder) 121 | output.append((group, x_out)) 122 | x_out = torch.zeros((*output[0][1].shape[:-1], sum([len(out[0]) for out in output]))).to(x.device).to(output[0][1].dtype) 123 | for (group, out) in output: 124 | x_out[...,group] = out 125 | return x_out 126 | 127 | 128 | 129 | class HumanVQVAE(nn.Module): 130 | def __init__(self, 131 | args, 132 | nb_code=512, 133 | code_dim=512, 134 | output_emb_width=512, 135 | down_t=3, 136 | stride_t=2, 137 | width=512, 138 | depth=3, 139 | dilation_growth_rate=3, 140 | activation='relu', 141 | norm=None, 142 | aux_labels=0, 143 | index_groups=None): 144 | 145 | super().__init__() 146 | 147 | self.nb_joints = 21 if args.dataname == 'kit' else 22 148 | self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm, index_groups=index_groups) 149 | if aux_labels > 0: 150 | self.aux_classifier = nn.Linear(output_emb_width, aux_labels) 151 | 152 | def encode(self, x): 153 | b, t, c = x.size() 154 | quants = self.vqvae.encode(x) # (N, T) 155 | return quants 156 | 157 | def forward(self, x): 158 | 159 | x_out, loss, perplexity = self.vqvae(x) 160 | 161 | return 
x_out, loss, perplexity 162 | 163 | def forward_decoder(self, x): 164 | x_out = self.vqvae.forward_decoder(x) 165 | return x_out 166 | 167 | def forward_aux(self, x): 168 | encoding = self.vqvae.encoder(self.vqvae.preprocess(x)) 169 | return self.aux_classifier(encoding) 170 | -------------------------------------------------------------------------------- /train_vq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | import models.vqvae as vqvae 9 | import utils.losses as losses 10 | import options.option_vq as option_vq 11 | import utils.utils_model as utils_model 12 | from dataset import dataset_VQ, dataset_TM_eval 13 | import utils.eval_trans as eval_trans 14 | from options.get_eval_option import get_opt 15 | from models.evaluator_wrapper import EvaluatorModelWrapper 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | from utils.word_vectorizer import WordVectorizer 19 | 20 | def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr): 21 | 22 | current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) 23 | for param_group in optimizer.param_groups: 24 | param_group["lr"] = current_lr 25 | 26 | return optimizer, current_lr 27 | 28 | ##### ---- Exp dirs ---- ##### 29 | args = option_vq.get_args_parser() 30 | torch.manual_seed(args.seed) 31 | 32 | args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') 33 | os.makedirs(args.out_dir, exist_ok = True) 34 | 35 | ##### ---- Logger ---- ##### 36 | logger = utils_model.get_logger(args.out_dir) 37 | writer = SummaryWriter(args.out_dir) 38 | logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) 39 | 40 | 41 | 42 | # w_vectorizer = WordVectorizer('./glove', 'our_vab') 43 | 44 | if args.dataname.split('_')[0] == 'face': 45 | args.nb_joints = None 46 | else: 47 | 48 | if args.dataname == 'kit' : 49 | dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' 50 | args.nb_joints = 21 51 | 52 | else : 53 | dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' 54 | args.nb_joints = 22 55 | 56 | logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints') 57 | 58 | ##### ---- Dataloader ---- ##### 59 | train_loader = dataset_VQ.DATALoader(args.dataname, 60 | args.batch_size, 61 | window_size=args.window_size, 62 | unit_length=2**args.down_t, split="train", person_id=args.person_id) 63 | 64 | train_loader_iter = dataset_VQ.cycle(train_loader) 65 | 66 | val_loader = dataset_VQ.DATALoader(args.dataname, args.batch_size, window_size=args.window_size, unit_length=2**args.down_t, split=args.eval_split, person_id=args.person_id) 67 | 68 | index_groups = None 69 | if args.separate_pose is not None: 70 | index_groups = [list(range(50))+[53, 54, 55], [50, 51, 52]] 71 | args.nb_code = [args.nb_code, args.separate_pose] 72 | 73 | ##### ---- Network ---- ##### 74 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 75 | args.nb_code, 76 | args.code_dim, 77 | args.output_emb_width, 78 | args.down_t, 79 | args.stride_t, 80 | args.width, 81 | args.depth, 82 | args.dilation_growth_rate, 83 | args.vq_act, 84 | args.vq_norm, 85 | args.aux_labels, 86 | index_groups) 87 | 88 | 89 | if args.resume_pth : 90 | logger.info('loading checkpoint from {}'.format(args.resume_pth)) 91 | ckpt = torch.load(args.resume_pth, map_location='cpu') 92 | net.load_state_dict(ckpt['net'], strict=True) 93 | net.train() 94 | 
net.cuda() 95 | 96 | ##### ---- Optimizer & Scheduler ---- ##### 97 | optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) 98 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma) 99 | 100 | 101 | Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints, args.pose_alpha) 102 | 103 | ##### ------ warm-up ------- ##### 104 | avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 105 | 106 | for nb_iter in range(1, args.warm_up_iter): 107 | 108 | optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr) 109 | 110 | gt_motion = next(train_loader_iter) 111 | # print("mot!!", gt_motion.shape) # (256, 32, 56) 112 | # assert False 113 | if args.loss_aux > 0: 114 | gt_motion, aux_gt = gt_motion 115 | gt_motion = gt_motion.cuda().float() # (bs, 64, dim) 116 | 117 | pred_motion, loss_commit, perplexity = net(gt_motion) 118 | loss_motion = Loss(pred_motion, gt_motion) 119 | loss_vel = 0.0 120 | loss_aux = 0.0 121 | if args.loss_vel > 0: 122 | loss_vel = Loss.forward_vel(pred_motion, gt_motion) 123 | if args.loss_aux > 0: 124 | aux_pred = net.forward_aux(gt_motion) 125 | loss_aux = torch.nn.CrossEntropyLoss()(aux_pred.unsqueeze(2).repeat(1, 1, gt_motion.shape[1] // aux_pred.shape[1], 1).view(aux_pred.shape[0], -1, aux_pred.shape[-1]), aux_gt) 126 | 127 | loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel + args.loss_aux * loss_aux 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | avg_recons += loss_motion.item() 134 | avg_perplexity += perplexity.item() 135 | avg_commit += loss_commit.item() 136 | 137 | if nb_iter % args.print_iter == 0 : 138 | avg_recons /= args.print_iter 139 | avg_perplexity /= args.print_iter 140 | avg_commit /= args.print_iter 141 | 142 | logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") 143 | 144 | avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 145 | 146 | ##### ---- Training ---- ##### 147 | avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 
148 | best_perplexity, best_iter, best_commit, best_recons, writter, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_perplexity=float("inf"), best_commit=float("inf"), best_recons=float("inf"), best_iter=0, recons_loss_fn=Loss, savenpy=True, save=(args.total_iter > 0)) 149 | 150 | for nb_iter in range(1, args.total_iter + 1): 151 | 152 | gt_motion = next(train_loader_iter) 153 | gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len 154 | 155 | pred_motion, loss_commit, perplexity = net(gt_motion) 156 | loss_motion = Loss(pred_motion, gt_motion) 157 | loss_vel = 0.0 158 | if args.loss_vel > 0: 159 | loss_vel = Loss.forward_vel(pred_motion, gt_motion) 160 | 161 | loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel 162 | 163 | optimizer.zero_grad() 164 | loss.backward() 165 | optimizer.step() 166 | scheduler.step() 167 | 168 | avg_recons += loss_motion.item() 169 | avg_perplexity += perplexity.item() 170 | avg_commit += loss_commit.item() 171 | 172 | if nb_iter % args.print_iter == 0 : 173 | avg_recons /= args.print_iter 174 | avg_perplexity /= args.print_iter 175 | avg_commit /= args.print_iter 176 | 177 | writer.add_scalar('./Train/L1', avg_recons, nb_iter) 178 | writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter) 179 | writer.add_scalar('./Train/Commit', avg_commit, nb_iter) 180 | 181 | logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") 182 | 183 | avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 184 | 185 | if nb_iter % args.eval_iter==0 : 186 | best_perplexity, best_iter, best_commit, best_recons, writter, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_perplexity=best_perplexity, best_commit=best_commit, best_recons=best_recons, best_iter=best_iter, recons_loss_fn=Loss, savenpy=True) 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Can Language Models Learn to Listen? 2 | This is the repo for the paper [Can Language Models Learn to Listen?](https://arxiv.org/abs/2308.10897), appearing at ICCV 2023. 3 | 4 | ## Setup Environment 5 | Create a new Python 3 environment and install PyTorch 1.11.0 from https://pytorch.org/get-started/previous-versions/. Then install the requirements for this repo via `pip install -r requirements.txt`. 6 | Also, please clone the DECA (for visualization) and EMOCA (for emotion/valence evaluation) repositories, and set the following environment variables: 7 | ``` 8 | export PYTHONPATH=/PATH/TO/EMOCA/:$PYTHONPATH 9 | export DECA_PATH=/PATH/TO/DECA/ 10 | ``` 11 | You will need to change the EMOCA emotion recognition code so that it does not process from images. In `gdl/models/EmoDeca.py`, add the following lines to the beginning of the `forward` method: 12 | ``` 13 | if 'image' not in batch: 14 | values = batch 15 | else: 16 | values = self.deca.encode(batch, training=False) 17 | ``` 18 | You will also need to download the DECA and EMOCA models (there are instructions in those repos). 19 | 20 | ## Data Preparation 21 | Please download the data from the Google Drive folder linked [here](https://drive.google.com/file/d/1fR4sobslLB0gESQj6zya63XgupJFpVjc/view?usp=sharing). Place the data so that there are directories `dataset/trevor`, `dataset/conan`, `dataset/stephen`, and `dataset/trevorconanstephen` that have the corresponding segment files.
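For reference, after unpacking the download, the `dataset` directory should look roughly like the tree below. This is a sketch only: the per-speaker file names are an assumption inferred from the paths used by the commands later in this README (`segments_{train,val}.pth`, `mean.npy`, `std.npy`), and the `conan`, `stephen`, and `trevorconanstephen` directories follow the same pattern as `trevor`.
```
dataset
├── trevor
│   ├── segments_train.pth
│   ├── segments_val.pth
│   ├── mean.npy
│   └── std.npy
├── conan
├── stephen
└── trevorconanstephen
```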
22 | Note: If you want to use a cross-speaker VQ to train an LM Listener for a speaker (as we did for Conan and Stephen), you should copy the corresponding speaker's directory and then overwrite the `mean.npy` and `std.npy` files with the files from the `trevorconanstephen` directory. For instance, for Conan, you should copy `dataset/conan` to `dataset/conanglobal` and then copy `dataset/trevorconanstephen/{mean,std}.npy` to `dataset/conanglobal/`. 23 | 24 | ## Pre-trained model 25 | We provide a pre-trained VQ model and LM Listener for Trevor Noah [here](https://drive.google.com/drive/folders/1WMAsrky61gI36x_IstkoNzuNiqCmhKJV?usp=sharing). 26 | 27 | ## Training 28 | The following command will train a VQ encoder-decoder: 29 | ``` 30 | python3 train_vq.py \ 31 | --batch-size 256 \ 32 | --lr 2e-4 \ 33 | --total-iter 300000 \ 34 | --lr-scheduler 200000 \ 35 | --nb-code 256 \ 36 | --down-t 3 \ 37 | --depth 3 \ 38 | --window-size 32 \ 39 | --dilation-growth-rate 3 \ 40 | --out-dir output \ 41 | --dataname face_{trevor/trevorconanstephen} \ 42 | --vq-act relu \ 43 | --quantizer ema_reset \ 44 | --loss-vel 0.5 \ 45 | --recons-loss l1_smooth \ 46 | --exp-name VQVAE_{trevor/trevorconanstephen} 47 | ``` 48 | The following command will train an LM Listener: 49 | ``` 50 | python train_t2m_trans.py \ 51 | --exp-name listener_{trevor/conanglobal/stephenglobal} \ 52 | --batch-size 8 \ 53 | --nb-code 256 \ 54 | --drop-out-rate 0.1 \ 55 | --resume-pth output/VQVAE_{trevor/trevorconanstephen}/net_iter300000.pth \ 56 | --vq-name VQVAE_{trevor/trevorconanstephen} \ 57 | --out-dir output \ 58 | --total-iter 100000 \ 59 | --lr-scheduler 150000 \ 60 | --lr 0.00005 \ 61 | --dataname face_realtalkv2 \ 62 | --down-t 2 \ 63 | --depth 3 \ 64 | --quantizer ema_reset \ 65 | --eval-iter 2000 \ 66 | --pkeep 0.50 \ 67 | --dilation-growth-rate 3 \ 68 | --vq-act relu \ 69 | --max-motion-length 240 \ 70 | --gpt2 gpt2-medium \ 71 | --print_val_pred \ 72 | --gradient_accumulation_steps 2 \ 73 | --manual-bf16 \ 74 | --delay-start-frames 96 \ 75 | --max-time-before 3 76 | ``` 77 | 78 | ## Generation 79 | The following command can be used to generate prediction files (in `.npy` format) from a trained LM Listener: 80 | ``` 81 | python train_t2m_trans.py \ 82 | --exp-name listener_{trevor/conanglobal/stephenglobal} \ 83 | --batch-size 8 \ 84 | --nb-code 256 \ 85 | --drop-out-rate 0.1 \ 86 | --resume-pth output/VQVAE_{trevor/trevorconanstephen}/net_iter300000.pth \ 87 | --vq-name VQVAE_{trevor/trevorconanstephen} \ 88 | --out-dir output \ 89 | --total-iter 0 \ 90 | --lr-scheduler 150000 \ 91 | --lr 0.00005 \ 92 | --dataname face_trevor \ 93 | --down-t 3 \ 94 | --depth 3 \ 95 | --quantizer ema_reset \ 96 | --eval-iter 2000 \ 97 | --pkeep 0.50 \ 98 | --dilation-growth-rate 3 \ 99 | --vq-act relu \ 100 | --max-motion-length 240 \ 101 | --gpt2 gpt2-medium \ 102 | --print_val_pred \ 103 | --gradient_accumulation_steps 2 \ 104 | --manual-bf16 \ 105 | --delay-start-frames 96 \ 106 | --max-time-before 3 \ 107 | --save-name subdir_where_predictions_will_be_saved \ 108 | --seed 50 \ 109 | --resume-trans /path/to/model/checkpoint.pth 110 | ``` 111 | 112 | ## Evaluation 113 | The following command can be used to compute evaluation metrics for an LM Listener: 114 | ``` 115 | python evaluate_listener.py --output_dir output/{EXPERIMENT_NAME} --segments_path dataset/{trevor/conanglobal/stephenglobal}/segments_val.pth --mean_std_path dataset/{trevor/conanglobal/stephenglobal}/ 116 | ``` 117 | 118 | ## Baselines 119 | To produce a directory of 
predictions for the Random VQ, Random Train, and Nearest Neighbor baselines, use the following command templates: 120 | ``` 121 | python baselines.py --vq-dir dataset/{trevor/conanglobal/stephenglobal}/vqvae_{trevor/trevorconanstephen}_val/ --output-dir output/{trevor/conan/stephen}_random_vq --params-path path_to_vq_config.json --max-motion-length 240 --history-size 3 --mean-std-path dataset/{trevor/conanglobal/stephenglobal}/ 122 | python baselines.py --vq-dir dataset/{trevor/conanglobal/stephenglobal}/vqvae_{trevor/trevorconanstephen}_val/ --output-dir output/{trevor/conan/stephen}_nearest_neighbor --params-path path_to_vq_config.json --max-motion-length 240 --history-size 3 --mean-std-path dataset/{trevor/conanglobal/stephenglobal}/ --train-segments-path dataset/{trevor/conanglobal/stephenglobal}/segments_train.pth --val-segments-path dataset/{trevor/conanglobal/stephenglobal}/segments_val.pth --nearest-neighbor --embedding-model-name sentence-transformers/all-mpnet-base-v2 --batch-size 32 --normalize 123 | python baselines.py --vq-dir dataset/{trevor/conanglobal/stephenglobal}/vqvae_{trevor/trevorconanstephen}_val/ --output-dir output/{trevor/conan/stephen}_random_train --params-path path_to_vq_config.json --max-motion-length 240 --history-size 3 --mean-std-path dataset/{trevor/conanglobal/stephenglobal}/ --train-segments-path dataset/{trevor/conanglobal/stephenglobal}/segments_train.pth --val-segments-path dataset/{trevor/conanglobal/stephenglobal}/segments_val.pth 124 | ``` 125 | The format of the predictions is `.npy`, just like the predictions produced by the LM Listener. 126 | 127 | ## Visualization 128 | The following command can be used to generate visualizations for an LM Listener: 129 | ``` 130 | python visualize_listener.py --output_dir /path/to/output/dir/ --segments_path dataset/{trevor/conanglobal/stephenglobal}/segments_val.pth --default_code_path default_code_trevor_emoca2.pkl --params_path output/{EXPERIMENT_NAME}/config.json --items output/{EXPERIMENT_NAME}/,vq,gt,video --mean_std_path dataset/{trevor/conanglobal/stephenglobal}/ --audio_root /path/to/raw/audios/ --video_root /path/to/raw/videos/ --fps 30 131 | ``` 132 | The `--items` parameter allows you to specify a comma-separated list of what to visualize. The options are: `video` (raw video), `gt` (the ground-truth EMOCA face reconstruction of the listener), `vq` (the VQ reconstruction of the listener), or a path to the output directory containing the predicted `.npy` files of an LM Listener. 133 | 134 | ## Acknowledgements 135 | Much of the code in this repo is taken from [T2M-GPT](https://github.com/Mael-zys/T2M-GPT). 
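As a final note on the prediction files referenced in the Generation, Baselines, and Evaluation sections above: the saved `.npy` files can be inspected directly with NumPy. The snippet below is only a sketch with a hypothetical path; the 56-dimensional per-frame layout (50 expression coefficients, then 3 head-rotation and 3 jaw parameters) follows the slicing used in `utils/losses.py` and `evaluate_listener.py`.
```
import numpy as np

# Hypothetical path: any "*_pred.npy" written by the LM Listener or a baseline.
pred = np.load("output/trevor_random_vq/example_pred.npy").reshape(-1, 56)

exp = pred[:, :50]    # facial expression coefficients
rot = pred[:, 50:53]  # head rotation (pose)
jaw = pred[:, 53:56]  # jaw pose
print(pred.shape, exp.shape, rot.shape, jaw.shape)
```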
136 | 137 | ## Citation 138 | ``` 139 | @inproceedings{ng2023text2listen, 140 | title={Can Language Models Learn to Listen?}, 141 | author={Ng, Evonne and Subramanian, Sanjay 142 | and Klein, Dan and Kanazawa, Angjoo 143 | and Darrell, Trevor and Ginosar, Shiry}, 144 | booktitle={Proceedings of the International 145 | Conference on Computer Vision (ICCV)}, 146 | year={2023} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /options/option_transformer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser(description='LM Listener transformer training', 5 | add_help=True, 6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | 8 | ## dataloader 9 | 10 | parser.add_argument('--dataname', type=str, default='kit', help='dataset name') 11 | parser.add_argument('--batch-size', default=128, type=int, help='batch size') 12 | parser.add_argument('--fps', default=[30], nargs="+", type=int, help='frames per second') 13 | parser.add_argument('--seq-len', type=int, default=64, help='training motion length') 14 | parser.add_argument('--max-motion-length', type=int, default=128, help='max motion length') 15 | parser.add_argument('--max-tokens', type=int, default=None) 16 | parser.add_argument('--step-size', type=int, default=None, help='max motion length') 17 | parser.add_argument('--train_eval', action="store_true") 18 | parser.add_argument('--test-eval', action='store_true') 19 | parser.add_argument('--data_v2', action="store_true") 20 | parser.add_argument('--no-before-text', action="store_true") 21 | parser.add_argument('--max-time-before', type=int, default=None) 22 | parser.add_argument('--normalize-speaker', action="store_true") 23 | parser.add_argument('--normalize-audio', action="store_true") 24 | parser.add_argument('--delay-start-frames', type=int, default=0) 25 | parser.add_argument('--train-min-length', type=int, default=0) 26 | parser.add_argument('--val-min-length', type=int, default=0) 27 | 28 | ## optimization 29 | parser.add_argument('--total-iter', default=100000, type=int, help='number of total iterations to run') 30 | parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup') 31 | parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate') 32 | parser.add_argument('--lr-scheduler', default=[60000], nargs="+", type=int, help="learning rate schedule (iterations)") 33 | parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay") 34 | 35 | parser.add_argument('--weight-decay', default=1e-6, type=float, help='weight decay') 36 | parser.add_argument('--decay-option',default='all', type=str, choices=['all', 'noVQ'], help='disable weight decay on codebook') 37 | parser.add_argument('--optimizer',default='adamw', type=str, choices=['adam', 'adamw'], help='optimizer type') 38 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 39 | parser.add_argument('--grad-scaling', action="store_true") 40 | parser.add_argument("--fp16", action="store_true") 41 | parser.add_argument("--fp16-half", action="store_true") 42 | parser.add_argument("--manual-bf16", action="store_true") 43 | parser.add_argument("--gradient-checkpointing", action="store_true") 44 | parser.add_argument("--grad-clip", type=float, default=None) 45 | parser.add_argument("--train-loss-threshold",
type=float, default=0.001) 46 | parser.add_argument("--training-end-check-interval", type=int, default=600) 47 | parser.add_argument("--lora", action="store_true") 48 | parser.add_argument("--linear-scheduler", action="store_true") 49 | 50 | ## vqvae arch 51 | parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension") 52 | parser.add_argument("--nb-code", type=int, default=512, help="nb of embedding") 53 | parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook") 54 | parser.add_argument("--down-t", type=int, default=3, help="downsampling rate") 55 | parser.add_argument("--stride-t", type=int, default=2, help="stride size") 56 | parser.add_argument("--width", type=int, default=512, help="width of the network") 57 | parser.add_argument("--depth", type=int, default=3, help="depth of the network") 58 | parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate") 59 | parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width") 60 | parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='dataset directory') 61 | 62 | ## gpt arch 63 | parser.add_argument("--block-size", type=int, default=25, help="seq len") 64 | parser.add_argument("--embed-dim-gpt", type=int, default=512, help="embedding dimension") 65 | parser.add_argument("--clip-dim", type=int, default=512, help="latent dimension in the clip feature") 66 | parser.add_argument("--num-layers", type=int, default=2, help="nb of transformer layers") 67 | parser.add_argument("--n-head-gpt", type=int, default=8, help="nb of heads") 68 | parser.add_argument("--ff-rate", type=int, default=4, help="feedforward size") 69 | parser.add_argument("--drop-out-rate", type=float, default=0.1, help="dropout ratio in the pos encoding") 70 | parser.add_argument("--text-model-name", type=str, default="openai/clip-vit-base-patch32") 71 | parser.add_argument("--extra_input_dim", type=int, nargs='+', default=[]) 72 | parser.add_argument("--include-speaker", action="store_true") 73 | parser.add_argument("--include-audio", action="store_true") 74 | parser.add_argument("--include-speaker-before", action="store_true") 75 | parser.add_argument("--include-audio-before", action="store_true") 76 | parser.add_argument("--text_token_level", action="store_true") 77 | parser.add_argument("--top_p", type=float, default=None) 78 | parser.add_argument("--label_smoothing", type=float, default=0.0) 79 | parser.add_argument("--no-text", action="store_true") 80 | parser.add_argument("--no-end", action="store_true") 81 | parser.add_argument("--gpt2", type=str, default=None) 82 | parser.add_argument("--sentiment-token", type=str, default=None) 83 | parser.add_argument("--freeze-lm", action="store_true") 84 | parser.add_argument("--num-output-layers", type=int, default=0) 85 | parser.add_argument("--transformer-not-pretrained", action="store_true") 86 | parser.add_argument("--speaker-pkeep", type=float, default=None) 87 | parser.add_argument("--audio-pkeep", type=float, default=None) 88 | parser.add_argument("--speaker-vq-path") 89 | parser.add_argument("--speaker-vq-loss", action="store_true") 90 | parser.add_argument("--audio-vq-path") 91 | parser.add_argument("--audio-vq-loss", action="store_true") 92 | parser.add_argument("--fix-pkeep", action="store_true") 93 | parser.add_argument("--fixed-text-token", action="store_true") 94 | parser.add_argument("--fixed-text-token-not-space", 
action="store_true") 95 | parser.add_argument("--fixed-text-token-not-punctuation", action="store_true") 96 | parser.add_argument("--unaligned-text", action="store_true") 97 | parser.add_argument("--remove-space-before-vq-tokens", action="store_true") 98 | parser.add_argument("--random-text-token-order", action="store_true") 99 | 100 | ## quantizer 101 | parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="eps for optimal transport") 102 | parser.add_argument('--quantbeta', type=float, default=1.0, help='dataset directory') 103 | 104 | ## resume 105 | parser.add_argument("--resume-pth", type=str, default=None, help='resume vq pth') 106 | parser.add_argument("--resume-trans", type=str, default=None, help='resume gpt pth') 107 | 108 | 109 | ## output directory 110 | parser.add_argument('--out-dir', type=str, default='output_GPT_Final/', help='output directory') 111 | parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir') 112 | parser.add_argument('--vq-name', type=str, default='exp_debug', help='name of the generated dataset .npy, will create a file inside out-dir') 113 | ## other 114 | parser.add_argument('--print-iter', default=200, type=int, help='print frequency') 115 | parser.add_argument('--eval-iter', default=5000, type=int, help='evaluation frequency') 116 | parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ') 117 | parser.add_argument("--if-maxtest", action='store_true', help="test in max") 118 | parser.add_argument('--pkeep', type=float, default=1.0, help='keep rate for gpt training') 119 | parser.add_argument('--print_val_pred', action='store_true') 120 | parser.add_argument('--save-name', default=None) 121 | parser.add_argument('--num-samples', default=1, type=int) 122 | 123 | parser.add_argument("--valence-window-size", type=int, default=90) 124 | parser.add_argument('--control-sentiment', action="store_true") 125 | 126 | 127 | return parser.parse_args() 128 | -------------------------------------------------------------------------------- /baselines.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | import numpy as np 4 | from collections import defaultdict 5 | import argparse 6 | import os 7 | from tqdm import tqdm 8 | import torch 9 | from transformers import AutoTokenizer, AutoModel 10 | from sentence_transformers import SentenceTransformer 11 | 12 | import models.vqvae as vqvae 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--vq-dir") 17 | parser.add_argument("--output-dir") 18 | parser.add_argument("--seed", type=int, default=42) 19 | parser.add_argument("--params-path") 20 | parser.add_argument("--mean-std-path") 21 | parser.add_argument("--train-segments-path") 22 | parser.add_argument("--val-segments-path") 23 | parser.add_argument("--max-motion-length", type=int) 24 | parser.add_argument("--history-size", type=int) 25 | parser.add_argument("--nearest-neighbor", action="store_true") 26 | parser.add_argument("--normalize", action="store_true") 27 | parser.add_argument("--embedding-model-name") 28 | parser.add_argument("--batch-size", type=int, default=8) 29 | parser.add_argument("--fps", type=int, default=30) 30 | parser.add_argument("--static-face", action="store_true") 31 | parser.add_argument("--random-train-select", action="store_true") 32 | args = 
parser.parse_args() 33 | os.system('mkdir '+args.output_dir) 34 | random.seed(args.seed) 35 | np.random.seed(seed=args.seed) 36 | if args.mean_std_path is not None: 37 | mean = np.load(os.path.join(args.mean_std_path, 'mean.npy')) 38 | std = np.load(os.path.join(args.mean_std_path, 'std.npy')) 39 | with open(args.params_path) as f: 40 | params = json.load(f) 41 | for key in params: 42 | if not hasattr(args, key): 43 | setattr(args, key, params[key]) 44 | if args.nearest_neighbor or args.random_train_select: 45 | segments = torch.load(args.train_segments_path, map_location='cpu') 46 | text_to_motion = {} 47 | fps = args.fps 48 | text_to_file_id = {} 49 | for seg in segments: 50 | if seg['split_end_frame']-seg['split_start_frame'] < args.max_motion_length: 51 | continue 52 | for start in range(seg['split_start_frame'], seg['split_end_frame']-args.max_motion_length+1): 53 | words = [word for word in seg['before_words']+seg['during_words'] if word['end']*fps >= start-args.history_size*fps and word['end']*fps < start+args.max_motion_length] 54 | text = ' '.join([word['text'] for word in words]) 55 | if text not in text_to_motion: 56 | text_to_motion[text] = torch.cat((seg['p0_exp'][start-seg['split_start_frame']:start-seg['split_start_frame']+args.max_motion_length,:], seg['p0_pose'][start-seg['split_start_frame']:start-seg['split_start_frame']+args.max_motion_length,:]), dim=1).numpy() 57 | text_to_file_id[text] = seg['fname']+'_'+str(start) 58 | # text_to_motion[text] = torch.cat((seg['p0_exp'], seg['p0_pose']), dim=1) 59 | all_texts = [] 60 | text_id_to_motion = {} 61 | file_ids = [] 62 | for i, text in enumerate(text_to_motion): 63 | all_texts.append(text) 64 | text_id_to_motion[i] = text_to_motion[text] 65 | file_ids.append(text_to_file_id[text]) 66 | if args.nearest_neighbor: 67 | text_embeddings = [] 68 | tokenizer = AutoTokenizer.from_pretrained(args.embedding_model_name) 69 | if tokenizer.pad_token is None: 70 | tokenizer.pad_token = tokenizer.eos_token 71 | model = SentenceTransformer(args.embedding_model_name).eval() 72 | if torch.cuda.is_available(): 73 | model = model.cuda() 74 | # bos_token = '' 75 | for i in tqdm(range(0, len(all_texts), args.batch_size)): 76 | sentence_embeddings = torch.from_numpy(model.encode(all_texts[i:i+args.batch_size])) 77 | for j in range(sentence_embeddings.shape[0]): 78 | text_embeddings.append(sentence_embeddings[j:j+1,:].cpu()) 79 | text_embeddings = torch.cat(text_embeddings, dim=0) 80 | assert text_embeddings.shape[0] == len(all_texts), str(text_embeddings.shape)+', '+str(len(all_texts)) 81 | print(text_embeddings.shape) 82 | print(text_embeddings[:2,:]) 83 | if torch.cuda.is_available(): 84 | text_embeddings = text_embeddings.cuda() 85 | if args.normalize: 86 | # text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True) 87 | text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=1) 88 | val_segments = torch.load(args.val_segments_path, map_location="cpu") 89 | val_segments_dict = {} 90 | for seg in val_segments: 91 | for i in range(seg['split_start_frame'], seg['split_end_frame']): 92 | val_segments_dict[seg['fname'].split('/')[-1]+'_'+str(i)] = seg 93 | # print(val_segments_dict.keys()) 94 | else: 95 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 96 | args.nb_code, 97 | args.code_dim, 98 | args.output_emb_width, 99 | args.down_t, 100 | args.stride_t, 101 | args.width, 102 | args.depth, 103 | args.dilation_growth_rate) 104 | ckpt = 
torch.load(args.resume_pth, map_location='cpu') 105 | net.load_state_dict(ckpt['net'], strict=True) 106 | net.eval() 107 | if torch.cuda.is_available(): 108 | net.cuda() 109 | chosen = defaultdict(int) 110 | count = 0 111 | for root, _, files in tqdm(os.walk(args.vq_dir)): 112 | for fname in files: 113 | if fname[-4:] == ".npy": 114 | count += 1 115 | gt_vq = np.load(os.path.join(root, fname)) 116 | num_frames = gt_vq.reshape(-1).shape[0]*(2**args.down_t) 117 | if args.nearest_neighbor: 118 | seg = val_segments_dict[fname.split('.npy')[0]] 119 | start_frame = int(fname.split('.npy')[0].split('_')[-1]) 120 | words = [word['text'] for word in seg['before_words']+seg['during_words'] if word['end']*fps >= start_frame-args.history_size*fps and word['end']*fps < start_frame+num_frames] 121 | text = ' '.join(words) 122 | embedding = torch.from_numpy(model.encode([text])).view(1, -1) 123 | if torch.cuda.is_available(): 124 | embedding = embedding.to('cuda:0') 125 | if args.normalize: 126 | # text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True) 127 | embedding = torch.nn.functional.normalize(embedding, p=2, dim=1) 128 | best_index = (text_embeddings @ embedding.t()).view(-1).argmax().item() 129 | # if count > 180: 130 | print(fname, text, '|||', all_texts[best_index]) 131 | chosen[best_index] += 1 132 | motion = text_id_to_motion[best_index][:num_frames,:] 133 | elif args.random_train_select: 134 | index = random.choice(list(range(len(text_id_to_motion)))) 135 | motion = text_id_to_motion[index][:num_frames,:] 136 | elif args.static_face: 137 | motion = np.expand_dims(mean, axis=0) 138 | if len(gt_vq.shape) == 3: 139 | motion = np.repeat(np.expand_dims(motion, axis=0), num_frames, axis=1) 140 | else: 141 | motion = np.repeat(motion, num_frames, axis=0) 142 | else: 143 | random_pred = np.random.randint(low=0, high=args.nb_code, size=gt_vq.shape) 144 | inp = torch.from_numpy(random_pred).view(1, -1) 145 | if torch.cuda.is_available(): 146 | inp = inp.cuda() 147 | with torch.no_grad(): 148 | decoded = net.forward_decoder(inp) 149 | motion = decoded.cpu().view(-1, 56).numpy() 150 | motion = (motion*std.reshape(1, -1))+mean.reshape(1, -1) 151 | path_parts = os.path.join(root, fname).replace('.npy', '_pred.npy').split('/') 152 | new_path = os.path.join(args.output_dir, *path_parts[-4:]) 153 | path_parts = new_path.split('/') 154 | # print(path_parts) 155 | for j in range(len(path_parts)-1): 156 | if not os.path.exists('/'.join(path_parts[:j+1])): 157 | os.system('mkdir '+'/'.join(path_parts[:j+1])) 158 | np.save(new_path, motion) 159 | -------------------------------------------------------------------------------- /utils/skeleton.py: -------------------------------------------------------------------------------- 1 | from utils.quaternion import * 2 | import scipy.ndimage.filters as filters 3 | 4 | class Skeleton(object): 5 | def __init__(self, offset, kinematic_tree, device): 6 | self.device = device 7 | self._raw_offset_np = offset.numpy() 8 | self._raw_offset = offset.clone().detach().to(device).float() 9 | self._kinematic_tree = kinematic_tree 10 | self._offset = None 11 | self._parents = [0] * len(self._raw_offset) 12 | self._parents[0] = -1 13 | for chain in self._kinematic_tree: 14 | for j in range(1, len(chain)): 15 | self._parents[chain[j]] = chain[j-1] 16 | 17 | def njoints(self): 18 | return len(self._raw_offset) 19 | 20 | def offset(self): 21 | return self._offset 22 | 23 | def set_offset(self, offsets): 24 | self._offset = 
offsets.clone().detach().to(self.device).float() 25 | 26 | def kinematic_tree(self): 27 | return self._kinematic_tree 28 | 29 | def parents(self): 30 | return self._parents 31 | 32 | # joints (batch_size, joints_num, 3) 33 | def get_offsets_joints_batch(self, joints): 34 | assert len(joints.shape) == 3 35 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 36 | for i in range(1, self._raw_offset.shape[0]): 37 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 38 | 39 | self._offset = _offsets.detach() 40 | return _offsets 41 | 42 | # joints (joints_num, 3) 43 | def get_offsets_joints(self, joints): 44 | assert len(joints.shape) == 2 45 | _offsets = self._raw_offset.clone() 46 | for i in range(1, self._raw_offset.shape[0]): 47 | # print(joints.shape) 48 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 49 | 50 | self._offset = _offsets.detach() 51 | return _offsets 52 | 53 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 54 | # joints (batch_size, joints_num, 3) 55 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 56 | assert len(face_joint_idx) == 4 57 | '''Get Forward Direction''' 58 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 59 | across1 = joints[:, r_hip] - joints[:, l_hip] 60 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 61 | across = across1 + across2 62 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 63 | # print(across1.shape, across2.shape) 64 | 65 | # forward (batch_size, 3) 66 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 67 | if smooth_forward: 68 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 69 | # forward (batch_size, 3) 70 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 71 | 72 | '''Get Root Rotation''' 73 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 74 | root_quat = qbetween_np(forward, target) 75 | 76 | '''Inverse Kinematics''' 77 | # quat_params (batch_size, joints_num, 4) 78 | # print(joints.shape[:-1]) 79 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 80 | # print(quat_params.shape) 81 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 82 | quat_params[:, 0] = root_quat 83 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 84 | for chain in self._kinematic_tree: 85 | R = root_quat 86 | for j in range(len(chain) - 1): 87 | # (batch, 3) 88 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 89 | # print(u.shape) 90 | # (batch, 3) 91 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 92 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 93 | # print(u.shape, v.shape) 94 | rot_u_v = qbetween_np(u, v) 95 | 96 | R_loc = qmul_np(qinv_np(R), rot_u_v) 97 | 98 | quat_params[:,chain[j + 1], :] = R_loc 99 | R = qmul_np(R, R_loc) 100 | 101 | return quat_params 102 | 103 | # Be sure root joint is at the beginning of kinematic chains 104 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 105 | # quat_params (batch_size, joints_num, 4) 106 | # joints (batch_size, joints_num, 3) 107 | # root_pos (batch_size, 3) 108 | if skel_joints is not None: 109 | offsets = self.get_offsets_joints_batch(skel_joints) 110 | if len(self._offset.shape) == 2: 111 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 112 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 113 | joints[:, 0] = 
root_pos 114 | for chain in self._kinematic_tree: 115 | if do_root_R: 116 | R = quat_params[:, 0] 117 | else: 118 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 119 | for i in range(1, len(chain)): 120 | R = qmul(R, quat_params[:, chain[i]]) 121 | offset_vec = offsets[:, chain[i]] 122 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 123 | return joints 124 | 125 | # Be sure root joint is at the beginning of kinematic chains 126 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 127 | # quat_params (batch_size, joints_num, 4) 128 | # joints (batch_size, joints_num, 3) 129 | # root_pos (batch_size, 3) 130 | if skel_joints is not None: 131 | skel_joints = torch.from_numpy(skel_joints) 132 | offsets = self.get_offsets_joints_batch(skel_joints) 133 | if len(self._offset.shape) == 2: 134 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 135 | offsets = offsets.numpy() 136 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 137 | joints[:, 0] = root_pos 138 | for chain in self._kinematic_tree: 139 | if do_root_R: 140 | R = quat_params[:, 0] 141 | else: 142 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 143 | for i in range(1, len(chain)): 144 | R = qmul_np(R, quat_params[:, chain[i]]) 145 | offset_vec = offsets[:, chain[i]] 146 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 147 | return joints 148 | 149 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 150 | # cont6d_params (batch_size, joints_num, 6) 151 | # joints (batch_size, joints_num, 3) 152 | # root_pos (batch_size, 3) 153 | if skel_joints is not None: 154 | skel_joints = torch.from_numpy(skel_joints) 155 | offsets = self.get_offsets_joints_batch(skel_joints) 156 | if len(self._offset.shape) == 2: 157 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 158 | offsets = offsets.numpy() 159 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 160 | joints[:, 0] = root_pos 161 | for chain in self._kinematic_tree: 162 | if do_root_R: 163 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 164 | else: 165 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 166 | for i in range(1, len(chain)): 167 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 168 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 169 | # print(matR.shape, offset_vec.shape) 170 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 171 | return joints 172 | 173 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 174 | # cont6d_params (batch_size, joints_num, 6) 175 | # joints (batch_size, joints_num, 3) 176 | # root_pos (batch_size, 3) 177 | if skel_joints is not None: 178 | # skel_joints = torch.from_numpy(skel_joints) 179 | offsets = self.get_offsets_joints_batch(skel_joints) 180 | if len(self._offset.shape) == 2: 181 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 182 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 183 | joints[..., 0, :] = root_pos 184 | for chain in self._kinematic_tree: 185 | if do_root_R: 186 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 187 | else: 188 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 189 | for i in range(1, len(chain)): 190 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, 
chain[i]])) 191 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 192 | # print(matR.shape, offset_vec.shape) 193 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 194 | return joints 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /dataset/dataset_TM_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | 9 | import utils.paramUtil as paramUtil 10 | from torch.utils.data._utils.collate import default_collate 11 | 12 | 13 | def collate_fn(batch): 14 | batch.sort(key=lambda x: x[3], reverse=True) 15 | return default_collate(batch) 16 | 17 | 18 | '''For use of training text-2-motion generative model''' 19 | class Text2MotionDataset(data.Dataset): 20 | def __init__(self, dataset_name, is_test, w_vectorizer, feat_bias = 5, max_text_len = 20, unit_length = 4): 21 | 22 | self.max_length = 20 23 | self.pointer = 0 24 | self.dataset_name = dataset_name 25 | self.is_test = is_test 26 | self.max_text_len = max_text_len 27 | self.unit_length = unit_length 28 | self.w_vectorizer = w_vectorizer 29 | if dataset_name == 't2m': 30 | self.data_root = './dataset/HumanML3D' 31 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 32 | self.text_dir = pjoin(self.data_root, 'texts') 33 | self.joints_num = 22 34 | radius = 4 35 | fps = 20 36 | self.max_motion_length = 196 37 | dim_pose = 263 38 | kinematic_chain = paramUtil.t2m_kinematic_chain 39 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 40 | elif dataset_name == 'kit': 41 | self.data_root = './dataset/KIT-ML' 42 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 43 | self.text_dir = pjoin(self.data_root, 'texts') 44 | self.joints_num = 21 45 | radius = 240 * 8 46 | fps = 12.5 47 | dim_pose = 251 48 | self.max_motion_length = 196 49 | kinematic_chain = paramUtil.kit_kinematic_chain 50 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 51 | elif dataset_name.split('_')[0] == 'face': 52 | self.data_root = './dataset/'+dataset_name.split('_')[1] 53 | 54 | 55 | mean = np.load(pjoin(self.meta_dir, 'mean.npy')) 56 | std = np.load(pjoin(self.meta_dir, 'std.npy')) 57 | 58 | if is_test: 59 | split_file = pjoin(self.data_root, 'test.txt') 60 | else: 61 | split_file = pjoin(self.data_root, 'val.txt') 62 | 63 | min_motion_len = 40 if self.dataset_name =='t2m' else 24 64 | # min_motion_len = 64 65 | 66 | joints_num = self.joints_num 67 | 68 | data_dict = {} 69 | id_list = [] 70 | with cs.open(split_file, 'r') as f: 71 | for line in f.readlines(): 72 | id_list.append(line.strip()) 73 | 74 | new_name_list = [] 75 | length_list = [] 76 | for name in tqdm(id_list): 77 | try: 78 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 79 | if (len(motion)) < min_motion_len or (len(motion) >= 200): 80 | continue 81 | text_data = [] 82 | flag = False 83 | with cs.open(pjoin(self.text_dir, name + '.txt')) as f: 84 | for line in f.readlines(): 85 | text_dict = {} 86 | line_split = line.strip().split('#') 87 | caption = line_split[0] 88 | tokens = line_split[1].split(' ') 89 | f_tag = float(line_split[2]) 90 | to_tag = float(line_split[3]) 91 | f_tag = 0.0 if np.isnan(f_tag) else f_tag 92 | to_tag = 0.0 if np.isnan(to_tag) else to_tag 93 | 94 | text_dict['caption'] = caption 95 | 
text_dict['tokens'] = tokens 96 | if f_tag == 0.0 and to_tag == 0.0: 97 | flag = True 98 | text_data.append(text_dict) 99 | else: 100 | try: 101 | n_motion = motion[int(f_tag*fps) : int(to_tag*fps)] 102 | if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): 103 | continue 104 | new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name 105 | while new_name in data_dict: 106 | new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name 107 | data_dict[new_name] = {'motion': n_motion, 108 | 'length': len(n_motion), 109 | 'text':[text_dict]} 110 | new_name_list.append(new_name) 111 | length_list.append(len(n_motion)) 112 | except: 113 | print(line_split) 114 | print(line_split[2], line_split[3], f_tag, to_tag, name) 115 | # break 116 | 117 | if flag: 118 | data_dict[name] = {'motion': motion, 119 | 'length': len(motion), 120 | 'text': text_data} 121 | new_name_list.append(name) 122 | length_list.append(len(motion)) 123 | except Exception as e: 124 | # print(e) 125 | pass 126 | 127 | name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) 128 | self.mean = mean 129 | self.std = std 130 | self.length_arr = np.array(length_list) 131 | self.data_dict = data_dict 132 | self.name_list = name_list 133 | self.reset_max_len(self.max_length) 134 | 135 | def reset_max_len(self, length): 136 | assert length <= self.max_motion_length 137 | self.pointer = np.searchsorted(self.length_arr, length) 138 | print("Pointer Pointing at %d"%self.pointer) 139 | self.max_length = length 140 | 141 | def inv_transform(self, data): 142 | return data * self.std + self.mean 143 | 144 | def forward_transform(self, data): 145 | return (data - self.mean) / self.std 146 | 147 | def __len__(self): 148 | return len(self.data_dict) - self.pointer 149 | 150 | def __getitem__(self, item): 151 | idx = self.pointer + item 152 | name = self.name_list[idx] 153 | data = self.data_dict[name] 154 | # data = self.data_dict[self.name_list[idx]] 155 | motion, m_length, text_list = data['motion'], data['length'], data['text'] 156 | # Randomly select a caption 157 | text_data = random.choice(text_list) 158 | caption, tokens = text_data['caption'], text_data['tokens'] 159 | 160 | if len(tokens) < self.max_text_len: 161 | # pad with "unk" 162 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 163 | sent_len = len(tokens) 164 | tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len) 165 | else: 166 | # crop 167 | tokens = tokens[:self.max_text_len] 168 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 169 | sent_len = len(tokens) 170 | pos_one_hots = [] 171 | word_embeddings = [] 172 | for token in tokens: 173 | word_emb, pos_oh = self.w_vectorizer[token] 174 | pos_one_hots.append(pos_oh[None, :]) 175 | word_embeddings.append(word_emb[None, :]) 176 | pos_one_hots = np.concatenate(pos_one_hots, axis=0) 177 | word_embeddings = np.concatenate(word_embeddings, axis=0) 178 | 179 | if self.unit_length < 10: 180 | coin2 = np.random.choice(['single', 'single', 'double']) 181 | else: 182 | coin2 = 'single' 183 | 184 | if coin2 == 'double': 185 | m_length = (m_length // self.unit_length - 1) * self.unit_length 186 | elif coin2 == 'single': 187 | m_length = (m_length // self.unit_length) * self.unit_length 188 | idx = random.randint(0, len(motion) - m_length) 189 | motion = motion[idx:idx+m_length] 190 | 191 | "Z Normalization" 192 | motion = (motion - self.mean) / self.std 193 | 194 | if m_length < self.max_motion_length: 195 | motion = np.concatenate([motion, 196 | 
np.zeros((self.max_motion_length - m_length, motion.shape[1])) 197 | ], axis=0) 198 | 199 | return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name 200 | 201 | 202 | 203 | 204 | def DATALoader(dataset_name, is_test, 205 | batch_size, w_vectorizer, 206 | num_workers = 8, unit_length = 4) : 207 | 208 | val_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, is_test, w_vectorizer, unit_length=unit_length), 209 | batch_size, 210 | shuffle = True, 211 | num_workers=num_workers, 212 | collate_fn=collate_fn, 213 | drop_last = True) 214 | return val_loader 215 | 216 | 217 | def cycle(iterable): 218 | while True: 219 | for x in iterable: 220 | yield x 221 | -------------------------------------------------------------------------------- /visualize_listener.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import torch 4 | import models.vqvae as vqvae 5 | import numpy as np 6 | import subprocess 7 | import json 8 | import os 9 | import cv2 10 | from tqdm import tqdm 11 | import pickle as pkl 12 | 13 | import sys 14 | sys.path.append(os.environ['DECA_PATH']) 15 | from decalib.deca import DECA 16 | from decalib.utils.config import cfg as deca_cfg 17 | from decalib.datasets import datasets 18 | 19 | def gen_image(deca, codedict, include_im, fix_cam=True): 20 | #codedict['cam'] = [5.,-0.02,0.02] 21 | if fix_cam: 22 | codedict['cam'][0,0] = 5. 23 | codedict['cam'][0,1] = 0. 24 | codedict['cam'][0,2] = 0.05 25 | #print(codedict['cam']) 26 | opdict, visdict = deca.decode(codedict) # , include_im=include_im) #tensor 27 | landmarks = {'landmarks2d': visdict['landmarks2d']} 28 | if include_im: 29 | #remainder = {'inputs': visdict['inputs'], 'shape_detail_images': visdict['shape_detail_images']} 30 | remainder = {'shape_detail_images': visdict['shape_detail_images'], 'inputs': visdict['inputs']} 31 | else: 32 | remainder = {'shape_detail_images': visdict['shape_detail_images']} 33 | 34 | #if include_im: 35 | # remainder['inputs'] = visdict['inputs'] 36 | return deca.visualize(remainder, size=640), deca.visualize(landmarks, size=640) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--items", help="comma-separated list of things to visualize (choices: \"video\", \"gt\", \"vq\", or any output directory)") 42 | parser.add_argument("--output_dir") 43 | parser.add_argument("--segments_path") 44 | parser.add_argument("--default_code_path") 45 | parser.add_argument("--params_path") 46 | parser.add_argument("--mean_std_path") 47 | parser.add_argument("--audio_root") 48 | parser.add_argument("--video_root") 49 | parser.add_argument("--tmp-dir", default="vis_tmp") 50 | parser.add_argument("--fps", type=int, default=30) 51 | parser.add_argument("--history", type=int, default=0) 52 | args = parser.parse_args() 53 | 54 | deca = DECA(config = deca_cfg, device='cuda') 55 | with open(args.default_code_path, 'rb') as f: 56 | default_code = pkl.load(f) 57 | basename = os.path.basename(os.path.abspath(args.output_dir)) 58 | 59 | params = None 60 | if args.params_path is not None: 61 | with open(args.params_path) as f: 62 | params = json.load(f) 63 | for key in params: 64 | if not hasattr(args, key): 65 | setattr(args, key, params[key]) 66 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 67 | args.nb_code, 68 | args.code_dim, 69 | args.output_emb_width, 70 | args.down_t, 71 | args.stride_t, 72 
| args.width, 73 | args.depth, 74 | args.dilation_growth_rate) 75 | unit_length = 2**args.down_t 76 | ckpt = torch.load(args.resume_pth, map_location='cpu') 77 | net.load_state_dict(ckpt['net'], strict=True) 78 | net.eval() 79 | net.cuda() 80 | mean = torch.from_numpy(np.load(os.path.join(args.mean_std_path, 'mean.npy'))).cuda().view(1, -1) 81 | std = torch.from_numpy(np.load(os.path.join(args.mean_std_path, 'std.npy'))).cuda().view(1, -1) 82 | segments = torch.load(args.segments_path, map_location='cpu') 83 | frame_map = {} 84 | for index, seg in enumerate(segments): 85 | for i in range(seg['split_start_frame'], seg['split_end_frame']): 86 | frame_map[seg['fname'].split('/')[-1]+'_'+str(i)] = torch.from_numpy(np.concatenate((seg['p0_exp'][i-seg['split_start_frame'],:], seg['p0_pose'][i-seg['split_start_frame'],:]), axis=0)) 87 | 88 | items = [item.strip() for item in args.items.split(',')] 89 | fname_pairs = [] 90 | fname_maps = {} 91 | for item in items: 92 | if item not in ["gt", "vq", "video"]: 93 | fname_pairs = [] 94 | fname_maps[item] = {} 95 | for root, _, files in os.walk(item): 96 | for fname in files: 97 | if '_pred.npy' in fname: 98 | fname_pairs.append(('_'.join(fname.split('_')[:-2]), int(fname.split('_')[-2]))) 99 | fname_maps[item]['_'.join(fname.split('_')[:-1])] = root+'/'+'_'.join(fname.split('_')[:-1]) 100 | if len(fname_pairs) == 0: 101 | for i, datum in enumerate(segments): 102 | fname_pairs.append((datum['fname'], datum['split_start_frame'])) 103 | audio_fname_map = {} 104 | for root, dirs, files in os.walk(args.audio_root+'/'): 105 | for fname in files: 106 | if fname[-4:] in {'.mp3', '.wav'}: 107 | audio_fname_map[fname.split('.')[0]] = root.split(args.audio_root+'/')[1] 108 | video_fname_map = {} 109 | for root, dirs, files in os.walk(args.video_root+'/'): 110 | for fname in files: 111 | if fname[-4:] in {'.mp4'}: 112 | video_fname_map[fname.split('.')[0]] = root.split(args.video_root+'/')[1] 113 | os.system('mkdir '+args.tmp_dir) 114 | for path, start_frame in tqdm(fname_pairs): 115 | fname = path.split('/')[-1] 116 | frames = [] 117 | num_frames = None 118 | f = start_frame 119 | pred_dict = {} 120 | for item in items: 121 | if item not in {'video', 'gt', 'vq'}: 122 | pred = np.load(os.path.join(fname_maps[item][path+'_'+str(start_frame)]+'_pred.npy')) 123 | pred_dict[item] = pred.reshape(-1, pred.shape[-1]) 124 | cap = None 125 | video_fname = fname 126 | if video_fname[-4:] != '.mp4': 127 | video_fname += '.mp4' 128 | if "video" in items: 129 | assert args.video_root is not None, "video_root must be non-None if you want to include the video in the visualization" 130 | print(os.path.join(args.video_root, video_fname_map[fname.split('.')[0]], fname)) 131 | cap = cv2.VideoCapture(os.path.join(args.video_root, video_fname_map[fname.split('.')[0]], video_fname)) 132 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(cap.get(cv2.CAP_PROP_FPS)*start_frame/args.fps)-1) 133 | while num_frames is None or len(frames) < num_frames: 134 | frame = [] 135 | for item in items: 136 | if item == "gt": 137 | gt_code = { 138 | 'exp': frame_map[fname+'_'+str(f)][:50].cuda().view(1, -1), 139 | 'pose': frame_map[fname+'_'+str(f)][50:56].cuda().view(1, -1) 140 | } 141 | for key in default_code: 142 | if key not in {'exp', 'pose'}: 143 | gt_code[key] = default_code[key].float().cuda() 144 | gt_code[key] = gt_code[key].cuda() 145 | gt_image, _ = gen_image(deca, gt_code, include_im=False) 146 | frame.append(gt_image) 147 | if num_frames is None: 148 | t = f 149 | while fname+'_'+str(t) in 
frame_map: 150 | t += 1 151 | num_frames = t-f 152 | # print('GT', num_frames) 153 | elif item == "video": 154 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(cap.get(cv2.CAP_PROP_FPS)*(start_frame+len(frames))/args.fps)-1) 155 | res, video_frame = cap.read() 156 | assert res 157 | video_frame = cv2.resize(video_frame, (640, 640), interpolation = cv2.INTER_AREA) 158 | frame.append(video_frame) 159 | elif item not in {"vq", "video"}: 160 | pred_code = { 161 | 'exp': torch.from_numpy(pred_dict[item][f-start_frame,:50]).cuda().view(1, -1), 162 | 'pose': torch.from_numpy(pred_dict[item][f-start_frame,50:56]).cuda().view(1, -1) 163 | } 164 | for key in default_code: 165 | if key not in {'exp', 'pose'}: 166 | pred_code[key] = default_code[key].float().cuda() 167 | pred_code[key] = pred_code[key].cuda() 168 | pred_image, _ = gen_image(deca, pred_code, include_im=False) 169 | frame.append(pred_image) 170 | if num_frames is None: 171 | num_frames = pred_dict[item].shape[0] 172 | num_frames = min(num_frames, pred_dict[item].shape[0]) 173 | # print('PRED', num_frames, pred_dict[item].shape) 174 | frames.append(tuple(frame)) 175 | f += 1 176 | if "vq" in items: 177 | gt = torch.stack([ 178 | frame_map[fname+'_'+str(f)] 179 | for f in range(start_frame, start_frame+num_frames) 180 | ]).cuda() 181 | normalized = ((gt-mean.cuda()) / std.cuda()).unsqueeze(0).cuda() 182 | with torch.no_grad(): 183 | encoded = net.encode(normalized) 184 | decoded = net.forward_decoder(encoded).view(-1, 56) 185 | denorm = (std*decoded+mean) 186 | while denorm.shape[0] < num_frames: 187 | denorm = torch.cat((denorm, denorm[-1:,:]), dim=0) 188 | for f in range(num_frames): 189 | vq_code = { 190 | 'exp': denorm[f,:50].view(1, -1), 191 | 'pose': denorm[f,50:].view(1, -1), 192 | } 193 | for key in default_code: 194 | if key not in {'exp', 'pose'}: 195 | vq_code[key] = default_code[key].float().cuda() 196 | vq_code[key] = vq_code[key].cuda() 197 | vq_image, _ = gen_image(deca, vq_code, include_im=False) 198 | frames[f] = frames[f][:items.index('vq')]+(vq_image,)+frames[f][items.index('vq'):] 199 | print('NUM_FRAMES', len(frames)) 200 | vis_start_frame = max(0, start_frame - args.fps * args.history) 201 | start_time = vis_start_frame / args.fps 202 | num_frames += start_frame-vis_start_frame 203 | interval = (num_frames) / args.fps 204 | prefix_frames = [] 205 | for f in range(vis_start_frame, start_frame): 206 | cap.set(cv2.CAP_PROP_POS_FRAMES, int(cap.get(cv2.CAP_PROP_FPS)*(f)/args.fps)-1) 207 | res, video_frame = cap.read() 208 | assert res 209 | video_frame = cv2.resize(video_frame, (640, 640), interpolation = cv2.INTER_AREA) 210 | frame = [] 211 | for item in items: 212 | if item == "video": 213 | frame.append(video_frame) 214 | else: 215 | frame.append(np.zeros_like(video_frame)) 216 | prefix_frames.append(tuple(frame)) 217 | frames = prefix_frames + frames 218 | for f, frame in enumerate(frames): 219 | cv2.imwrite(args.tmp_dir+'/{:08d}.jpg'.format(f), np.concatenate(frame, axis=1)) 220 | audio_path = None 221 | if args.audio_root is not None: 222 | audio_path = os.path.join(args.audio_root, audio_fname_map[fname.split('.')[0]], video_fname.replace('.mp4', '.wav')) 223 | if not os.path.exists(audio_path): 224 | audio_path = audio_path.replace('.wav', '.mp3') 225 | subprocess.call('ffmpeg -y -ss '+str(start_time)+' -t '+str(interval)+' -i '+audio_path+' '+args.tmp_dir+'/audio.wav', shell=True) 226 | cmd = "ffmpeg -y -r "+str(args.fps)+f" -start_number 0 -i "+args.tmp_dir+"/%8d.jpg -i "+args.tmp_dir+f"/audio.wav -pix_fmt yuv420p 
-vframes {num_frames} "+os.path.join(args.output_dir, fname+'_'+str(start_frame))+'.mp4' 227 | else: 228 | cmd = "ffmpeg -y -r "+str(args.fps)+f" -start_number 0 -i "+args.tmp_dir+"/%8d.jpg -pix_fmt yuv420p -vframes {num_frames} "+os.path.join(args.output_dir, fname+'_'+str(start_frame))+'.mp4' 229 | subprocess.call(cmd, shell=True) 230 | os.system('rm -rf '+args.tmp_dir+'/*') 231 | -------------------------------------------------------------------------------- /evaluate_listener.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle as pkl 6 | import subprocess 7 | import json 8 | import cv2 9 | import random 10 | random.seed(224) 11 | from tqdm import tqdm 12 | from scipy import linalg 13 | from pathlib import Path 14 | import pandas as pd 15 | 16 | 17 | import models.vqvae as vqvae 18 | 19 | import sys 20 | sys.path.append(os.environ['DECA_PATH']) 21 | from decalib.deca import DECA 22 | from decalib.utils.config import cfg as deca_cfg 23 | from decalib.datasets import datasets 24 | 25 | from gdl.utils.other import get_path_to_assets 26 | from gdl_apps.EmotionRecognition.utils.io import load_model, test 27 | import scipy.stats as stats 28 | 29 | def calc_pearson(in_features, out_features): 30 | T,F = in_features.shape 31 | res_corr = np.zeros(F) 32 | for f in range(F): 33 | r,p = stats.pearsonr(in_features[:,f], out_features[:,f]) 34 | res_corr[f] = r 35 | return abs(np.mean(np.mean(res_corr, axis=-1))) 36 | 37 | def crosscorr(datax, datay, lag=0, wrap=False): 38 | if wrap: 39 | shiftedy = datay.shift(lag) 40 | shiftedy.iloc[:lag] = datay.iloc[-lag:].values 41 | return datax.corr(shiftedy) 42 | else: 43 | return datax.corr(datay.shift(lag)) 44 | 45 | def face_valence(gt_exp, gt_pose, gt_shape, affect_model): 46 | gt_dict = {"expcode": torch.reshape(gt_exp, (-1, 50)), 47 | "posecode": torch.reshape(gt_pose, (-1, 6)), 48 | "shapecode": torch.reshape(gt_shape, (-1, 100))} 49 | with torch.no_grad(): 50 | gt_affect = affect_model(gt_dict) 51 | return gt_affect['valence'] 52 | 53 | def calculate_diversity(activation, diversity_times): 54 | assert len(activation.shape) == 2 55 | assert activation.shape[0] > diversity_times 56 | num_samples = activation.shape[0] 57 | 58 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 59 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 60 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 61 | return dist.mean() 62 | 63 | def calculate_activation_statistics(activations): 64 | 65 | mu = np.mean(activations, axis=0) 66 | cov = np.cov(activations, rowvar=False) 67 | return mu, cov 68 | 69 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 70 | 71 | mu1 = np.atleast_1d(mu1) 72 | mu2 = np.atleast_1d(mu2) 73 | 74 | sigma1 = np.atleast_2d(sigma1) 75 | sigma2 = np.atleast_2d(sigma2) 76 | 77 | assert mu1.shape == mu2.shape, \ 78 | 'Training and test mean vectors have different lengths' 79 | assert sigma1.shape == sigma2.shape, \ 80 | 'Training and test covariances have different dimensions' 81 | 82 | diff = mu1 - mu2 83 | 84 | # Product might be almost singular 85 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 86 | if not np.isfinite(covmean).all(): 87 | msg = ('fid calculation produces singular product; ' 88 | 'adding %s to diagonal of cov estimates') % eps 89 | print(msg) 90 | offset = np.eye(sigma1.shape[0]) * eps 91 | 
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 92 | 93 | # Numerical error might give slight imaginary component 94 | if np.iscomplexobj(covmean): 95 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 96 | m = np.max(np.abs(covmean.imag)) 97 | raise ValueError('Imaginary component {}'.format(m)) 98 | covmean = covmean.real 99 | 100 | tr_covmean = np.trace(covmean) 101 | 102 | return (diff.dot(diff) + np.trace(sigma1) 103 | + np.trace(sigma2) - 2 * tr_covmean) 104 | 105 | def main(args): 106 | total_l2 = [] 107 | total_fid = [] 108 | total_fid2 = [] 109 | total_diversity = [] 110 | total_diversity_gt = [] 111 | total_var = [] 112 | total_var_gt = [] 113 | total_windowed_l2v = [] 114 | total_peak_windowed_l2v = [] 115 | total_l2v = [] 116 | processed = [] 117 | 118 | # NOTE: added affect model here 119 | model_name = 'EMOCA-emorec' 120 | path_to_models = get_path_to_assets() /"EmotionRecognition" 121 | path_to_models = path_to_models / "face_reconstruction_based" # for 3dmm model 122 | affect_model = load_model(Path(path_to_models) / model_name) 123 | affect_model.eval() # .cuda() 124 | 125 | segments = torch.load(args.segments_path, map_location='cpu') 126 | segments_dict = {datum['fname']+'_'+str(datum['split_start_frame']): datum for datum in segments} 127 | 128 | frame_map = {} 129 | for index, seg in enumerate(segments): 130 | for i in range(seg['split_start_frame'], seg['split_end_frame']): 131 | frame_map[seg['fname']+'_'+str(i)] = np.concatenate((seg['p0_exp'][i-seg['split_start_frame'],:], seg['p0_pose'][i-seg['split_start_frame'],:], seg['p0_shape'][i-seg['split_start_frame'],:]), axis=0) 132 | 133 | speaker_map = {} 134 | for index, seg in enumerate(segments): 135 | for i in range(seg['split_start_frame'], seg['split_end_frame']): 136 | speaker_map[seg['fname']+'_'+str(i)] = np.concatenate((seg['p1_exp'][i-seg['split_start_frame'],:], seg['p1_pose'][i-seg['split_start_frame'],:], seg['p1_shape'][i-seg['split_start_frame'],:]), axis=0) 137 | 138 | fps = args.fps 139 | fname_pairs = [] 140 | for root, _, files in os.walk(args.output_dir): 141 | for fname in files: 142 | if '_pred.npy' in fname: 143 | # print(fname) 144 | fname_pairs.append((root, fname)) 145 | fname_pairs = sorted(fname_pairs, key=lambda x: '/'.join(os.path.join(x[0], x[1]).split('/')[2:])) 146 | 147 | fids = [] 148 | fid2s = [] 149 | l2s = [] 150 | gt_diversities = [] 151 | pred_diversities = [] 152 | gt_vars = [] 153 | pred_vars = [] 154 | # trevor_videos/done_trevor_videos1/025YouTubetrevor_videos/done_trevor_videos1/025YouTube/025YouTube.mp4_10916 155 | 156 | for root, fname in fname_pairs: 157 | final_name = "_".join(root.split('/')[-3:]) 158 | 159 | pred = np.load(os.path.join(root, fname)).reshape(-1, 56) 160 | # gt = np.load(os.path.join(root, fname.replace('_pred.npy', '_gt.npy')))[:,:56] 161 | root_parts = root.split('/') 162 | # 10946 163 | 164 | if not fname.split('_')[-2].isnumeric(): 165 | continue 166 | start_frame = int(fname.split('_')[-2]) 167 | fn = '/'.join(root_parts[-3:])+'/'+root_parts[-1]+'.mp4' 168 | valid_keys = [x for x in frame_map.keys() if fn in x] 169 | # print(frame_map.keys()) 170 | res = [] 171 | 172 | if pred.shape[0] < args.min_num_frames: 173 | continue 174 | 175 | if any([fn+'_'+str(f) not in frame_map for f in range(start_frame, start_frame+pred.shape[0])]): 176 | print(fn+' NOT FOUND') 177 | continue 178 | 179 | gt = np.stack([frame_map[fn+'_'+str(f)] for f in range(start_frame, start_frame+pred.shape[0])]) 180 | speaker = 
np.stack([speaker_map[fn+'_'+str(f)] for f in range(start_frame, start_frame+pred.shape[0])]) 181 | 182 | gt_v = face_valence(torch.from_numpy(gt[:,:50]), torch.from_numpy(gt[:,50:56]), torch.from_numpy(gt[:,56:]), affect_model).cpu().detach().numpy() 183 | pred_v = face_valence(torch.from_numpy(pred[:,:50]), torch.from_numpy(pred[:,50:56]), torch.from_numpy(gt[:,56:]), affect_model).cpu().detach().numpy() 184 | 185 | # 1. fid 186 | gt_mu, gt_cov = calculate_activation_statistics(gt[:,:56]) 187 | mu, cov = calculate_activation_statistics(pred[:,:56]) 188 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 189 | total_fid.append(fid) 190 | # 2. paired fid 191 | gt_mu2, gt_cov2 = calculate_activation_statistics(np.concatenate([speaker[:,:56], gt[:,:56]], axis=-1)) 192 | mu2, cov2 = calculate_activation_statistics(np.concatenate([speaker[:,:56], pred[:,:56]], axis=-1)) 193 | fid2 = calculate_frechet_distance(gt_mu2, gt_cov2, mu2, cov2) 194 | total_fid2.append(fid2) 195 | # 3. l2 196 | mse = ((gt[:,:56] - pred[:,:56])**2).mean() 197 | total_l2.append(mse) 198 | # 4. diversity 199 | gt_diversity = calculate_diversity(gt[:,:56], 30 if len(gt[:,:56]) > 30 else 10) 200 | pred_diversity = calculate_diversity(pred[:,:56], 30 if len(pred[:,:56]) > 30 else 10) 201 | total_diversity.append(pred_diversity) 202 | total_diversity_gt.append(gt_diversity) 203 | # 5. variance 204 | gt_var = np.mean(np.var(gt[:,:56], axis=0)) 205 | pred_var = np.mean(np.var(pred[:,:56], axis=0)) 206 | total_var.append(pred_var) 207 | total_var_gt.append(gt_var) 208 | # # 7. diff in valence 209 | mse_v = ((gt_v - pred_v)**2).mean() 210 | total_l2v.append(mse_v) 211 | 212 | # Windowed valence 213 | windowed_gt_v = torch.from_numpy(gt_v).view(-1).unfold(dimension=0, size=min(args.valence_window_size, gt_v.shape[0]), step=args.valence_window_size) 214 | index_per_window_gt = windowed_gt_v.abs().argmax(dim=-1) 215 | assert windowed_gt_v.shape[-1] == min(args.valence_window_size, gt_v.shape[0]) 216 | # windowed_gt_v = windowed_gt_v.mean(dim=-1) 217 | windowed_pred_v = torch.from_numpy(pred_v).view(-1).unfold(dimension=0, size=min(args.valence_window_size, pred_v.shape[0]), step=args.valence_window_size) 218 | assert windowed_pred_v.shape[-1] == min(args.valence_window_size, pred_v.shape[0]) 219 | index_per_window_pred = windowed_pred_v.abs().argmax(dim=-1) 220 | value_per_window_gt = windowed_gt_v.gather(dim=1, index=index_per_window_gt.view(-1, 1)) 221 | value_per_window_pred = windowed_pred_v.gather(dim=1, index=index_per_window_pred.view(-1, 1)) 222 | windowed_mse_v = ((value_per_window_pred-value_per_window_gt)**2).mean() 223 | # windowed_pred_v = windowed_pred_v.mean(dim=-1).numpy() 224 | # windowed_mse_v = ((windowed_gt_v-windowed_pred_v)**2).mean() 225 | total_peak_windowed_l2v.append(windowed_mse_v) 226 | 227 | # Windowed valence 228 | windowed_gt_v = torch.from_numpy(gt_v).view(-1, 1).unfold(dimension=0, size=min(args.valence_window_size, gt_v.shape[0]), step=args.valence_window_size) 229 | windowed_gt_v = windowed_gt_v.mean(dim=-1) 230 | windowed_pred_v = torch.from_numpy(pred_v).view(-1, 1).unfold(dimension=0, size=min(args.valence_window_size, pred_v.shape[0]), step=args.valence_window_size) 231 | windowed_pred_v = windowed_pred_v.mean(dim=-1).numpy() 232 | windowed_mse_v = ((windowed_gt_v-windowed_pred_v)**2).mean() 233 | total_windowed_l2v.append(windowed_mse_v) 234 | 235 | processed.append((root, fname)) 236 | 237 | print("l2", np.mean(np.array(total_l2))) 238 | print("windowed avg.l2v", 
np.mean(np.array(total_windowed_l2v))) 239 | print("fid", np.mean(np.array(total_fid))) 240 | print("fid2", np.mean(np.array(total_fid2))) 241 | print("diversity", np.mean(np.array(total_diversity))) 242 | print("diversity GT", np.mean(np.array(total_diversity_gt))) 243 | print("var", np.mean(np.array(total_var))) 244 | print("var GT", np.mean(np.array(total_var_gt))) 245 | 246 | result = { 247 | "name": args.output_dir, 248 | "l2": str(np.mean(np.array(total_l2))), 249 | "windowed avg.l2v": str(np.mean(np.array(total_windowed_l2v))), 250 | "fid": str(np.mean(np.array(total_fid))), 251 | "fid2": str(np.mean(np.array(total_fid2))), 252 | "diversity": str(np.mean(np.array(total_diversity))), 253 | "var": str(np.mean(np.array(total_var))), 254 | } 255 | 256 | tag = "talkshow" 257 | with open(f"{args.output_dir}/{tag}_eval.json", "w") as f: 258 | json.dump(result, f, indent=2) 259 | with open(f"{args.output_dir}/{tag}_scores.json", "w") as fout: 260 | json.dump({ 261 | 'paths': processed, 262 | 'l2': [float(val) for val in total_l2], 263 | 'fid': [float(val) for val in total_fid], 264 | 'fid2': [float(val) for val in total_fid2], 265 | 'windowed_avg_l2v': [float(val) for val in total_windowed_l2v], 266 | 'diversity': [float(val) for val in total_diversity], 267 | 'diversity_gt': [float(val) for val in total_diversity_gt], 268 | 'var': [float(val) for val in total_var], 269 | 'var_gt': [float(val) for val in total_var_gt], 270 | }, fout) 271 | print(f"dumped to: {args.output_dir}/{tag}_eval.json") 272 | 273 | 274 | if __name__ == "__main__": 275 | parser = argparse.ArgumentParser() 276 | parser.add_argument("--output_dir") 277 | # parser.add_argument("--vq_dir") 278 | parser.add_argument("--segments_path") 279 | parser.add_argument("--default_code_path") 280 | parser.add_argument("--mean_std_path") 281 | parser.add_argument("--fps", type=int, default=30) 282 | parser.add_argument("--valence_window_size", type=int, default=30) 283 | parser.add_argument("--min-num-frames", type=int, default=0) 284 | args = parser.parse_args() 285 | main(args) 286 | -------------------------------------------------------------------------------- /utils/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(float).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(float).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 
38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | 128 | def 
qmul_np(q, r): 129 | q = torch.from_numpy(q).contiguous().float() 130 | r = torch.from_numpy(r).contiguous().float() 131 | return qmul(q, r).numpy() 132 | 133 | 134 | def qrot_np(q, v): 135 | q = torch.from_numpy(q).contiguous().float() 136 | v = torch.from_numpy(v).contiguous().float() 137 | return qrot(q, v).numpy() 138 | 139 | 140 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 141 | if use_gpu: 142 | q = torch.from_numpy(q).cuda().float() 143 | return qeuler(q, order, epsilon).cpu().numpy() 144 | else: 145 | q = torch.from_numpy(q).contiguous().float() 146 | return qeuler(q, order, epsilon).numpy() 147 | 148 | 149 | def qfix(q): 150 | """ 151 | Enforce quaternion continuity across the time dimension by selecting 152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 153 | between two consecutive frames. 154 | 155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 156 | Returns a tensor of the same shape. 157 | """ 158 | assert len(q.shape) == 3 159 | assert q.shape[-1] == 4 160 | 161 | result = q.copy() 162 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 163 | mask = dot_products < 0 164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 165 | result[1:][mask] *= -1 166 | return result 167 | 168 | 169 | def euler2quat(e, order, deg=True): 170 | """ 171 | Convert Euler angles to quaternions. 172 | """ 173 | assert e.shape[-1] == 3 174 | 175 | original_shape = list(e.shape) 176 | original_shape[-1] = 4 177 | 178 | e = e.view(-1, 3) 179 | 180 | ## if euler angles in degrees 181 | if deg: 182 | e = e * np.pi / 180. 183 | 184 | x = e[:, 0] 185 | y = e[:, 1] 186 | z = e[:, 2] 187 | 188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 191 | 192 | result = None 193 | for coord in order: 194 | if coord == 'x': 195 | r = rx 196 | elif coord == 'y': 197 | r = ry 198 | elif coord == 'z': 199 | r = rz 200 | else: 201 | raise 202 | if result is None: 203 | result = r 204 | else: 205 | result = qmul(result, r) 206 | 207 | # Reverse antipodal representation to have a non-negative "w" 208 | if order in ['xyz', 'yzx', 'zxy']: 209 | result *= -1 210 | 211 | return result.view(original_shape) 212 | 213 | 214 | def expmap_to_quaternion(e): 215 | """ 216 | Convert axis-angle rotations (aka exponential maps) to quaternions. 217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 219 | Returns a tensor of shape (*, 4). 220 | """ 221 | assert e.shape[-1] == 3 222 | 223 | original_shape = list(e.shape) 224 | original_shape[-1] = 4 225 | e = e.reshape(-1, 3) 226 | 227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 228 | w = np.cos(0.5 * theta).reshape(-1, 1) 229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 231 | 232 | 233 | def euler_to_quaternion(e, order): 234 | """ 235 | Convert Euler angles to quaternions. 
236 | """ 237 | assert e.shape[-1] == 3 238 | 239 | original_shape = list(e.shape) 240 | original_shape[-1] = 4 241 | 242 | e = e.reshape(-1, 3) 243 | 244 | x = e[:, 0] 245 | y = e[:, 1] 246 | z = e[:, 2] 247 | 248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 251 | 252 | result = None 253 | for coord in order: 254 | if coord == 'x': 255 | r = rx 256 | elif coord == 'y': 257 | r = ry 258 | elif coord == 'z': 259 | r = rz 260 | else: 261 | raise 262 | if result is None: 263 | result = r 264 | else: 265 | result = qmul_np(result, r) 266 | 267 | # Reverse antipodal representation to have a non-negative "w" 268 | if order in ['xyz', 'yzx', 'zxy']: 269 | result *= -1 270 | 271 | return result.reshape(original_shape) 272 | 273 | 274 | def quaternion_to_matrix(quaternions): 275 | """ 276 | Convert rotations given as quaternions to rotation matrices. 277 | Args: 278 | quaternions: quaternions with real part first, 279 | as tensor of shape (..., 4). 280 | Returns: 281 | Rotation matrices as tensor of shape (..., 3, 3). 282 | """ 283 | r, i, j, k = torch.unbind(quaternions, -1) 284 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 285 | 286 | o = torch.stack( 287 | ( 288 | 1 - two_s * (j * j + k * k), 289 | two_s * (i * j - k * r), 290 | two_s * (i * k + j * r), 291 | two_s * (i * j + k * r), 292 | 1 - two_s * (i * i + k * k), 293 | two_s * (j * k - i * r), 294 | two_s * (i * k - j * r), 295 | two_s * (j * k + i * r), 296 | 1 - two_s * (i * i + j * j), 297 | ), 298 | -1, 299 | ) 300 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 301 | 302 | 303 | def quaternion_to_matrix_np(quaternions): 304 | q = torch.from_numpy(quaternions).contiguous().float() 305 | return quaternion_to_matrix(q).numpy() 306 | 307 | 308 | def quaternion_to_cont6d_np(quaternions): 309 | rotation_mat = quaternion_to_matrix_np(quaternions) 310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 311 | return cont_6d 312 | 313 | 314 | def quaternion_to_cont6d(quaternions): 315 | rotation_mat = quaternion_to_matrix(quaternions) 316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 317 | return cont_6d 318 | 319 | 320 | def cont6d_to_matrix(cont6d): 321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 322 | x_raw = cont6d[..., 0:3] 323 | y_raw = cont6d[..., 3:6] 324 | 325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 326 | z = torch.cross(x, y_raw, dim=-1) 327 | z = z / torch.norm(z, dim=-1, keepdim=True) 328 | 329 | y = torch.cross(z, x, dim=-1) 330 | 331 | x = x[..., None] 332 | y = y[..., None] 333 | z = z[..., None] 334 | 335 | mat = torch.cat([x, y, z], dim=-1) 336 | return mat 337 | 338 | 339 | def cont6d_to_matrix_np(cont6d): 340 | q = torch.from_numpy(cont6d).contiguous().float() 341 | return cont6d_to_matrix(q).numpy() 342 | 343 | 344 | def qpow(q0, t, dtype=torch.float): 345 | ''' q0 : tensor of quaternions 346 | t: tensor of powers 347 | ''' 348 | q0 = qnormalize(q0) 349 | theta0 = torch.acos(q0[..., 0]) 350 | 351 | ## if theta0 is close to zero, add epsilon to avoid NaNs 352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 353 | theta0 = (1 - mask) * theta0 + mask * 10e-10 354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 355 | 356 | if isinstance(t, torch.Tensor): 357 | q = torch.zeros(t.shape + 
q0.shape) 358 | theta = t.view(-1, 1) * theta0.view(1, -1) 359 | else: ## if t is a number 360 | q = torch.zeros(q0.shape) 361 | theta = t * theta0 362 | 363 | q[..., 0] = torch.cos(theta) 364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 365 | 366 | return q.to(dtype) 367 | 368 | 369 | def qslerp(q0, q1, t): 370 | ''' 371 | q0: starting quaternion 372 | q1: ending quaternion 373 | t: array of points along the way 374 | 375 | Returns: 376 | Tensor of Slerps: t.shape + q0.shape 377 | ''' 378 | 379 | q0 = qnormalize(q0) 380 | q1 = qnormalize(q1) 381 | q_ = qpow(qmul(q1, qinv(q0)), t) 382 | 383 | return qmul(q_, 384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 385 | 386 | 387 | def qbetween(v0, v1): 388 | ''' 389 | find the quaternion used to rotate v0 to v1 390 | ''' 391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 393 | 394 | v = torch.cross(v0, v1) 395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, 396 | keepdim=True) 397 | return qnormalize(torch.cat([w, v], dim=-1)) 398 | 399 | 400 | def qbetween_np(v0, v1): 401 | ''' 402 | find the quaternion used to rotate v0 to v1 403 | ''' 404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 406 | 407 | v0 = torch.from_numpy(v0).float() 408 | v1 = torch.from_numpy(v1).float() 409 | return qbetween(v0, v1).numpy() 410 | 411 | 412 | def lerp(p0, p1, t): 413 | if not isinstance(t, torch.Tensor): 414 | t = torch.Tensor([t]) 415 | 416 | new_shape = t.shape + p0.shape 417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 419 | p0 = p0.view(new_view_p).expand(new_shape) 420 | p1 = p1.view(new_view_p).expand(new_shape) 421 | t = t.view(new_view_t).expand(new_shape) 422 | 423 | return p0 + t * (p1 - p0) 424 | -------------------------------------------------------------------------------- /models/quantize_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class QuantizeEMAReset(nn.Module): 7 | def __init__(self, nb_code, code_dim, args): 8 | super().__init__() 9 | self.nb_code = nb_code 10 | self.code_dim = code_dim 11 | self.mu = args.mu 12 | self.reset_codebook() 13 | 14 | def reset_codebook(self): 15 | self.init = False 16 | self.code_sum = None 17 | self.code_count = None 18 | self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda()) 19 | 20 | def _tile(self, x): 21 | nb_code_x, code_dim = x.shape 22 | if nb_code_x < self.nb_code: 23 | n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x 24 | std = 0.01 / np.sqrt(code_dim) 25 | out = x.repeat(n_repeats, 1) 26 | out = out + torch.randn_like(out) * std 27 | else : 28 | out = x 29 | return out 30 | 31 | def init_codebook(self, x): 32 | out = self._tile(x) 33 | self.codebook = out[:self.nb_code] 34 | self.code_sum = self.codebook.clone() 35 | self.code_count = torch.ones(self.nb_code, device=self.codebook.device) 36 | self.init = True 37 | 38 | @torch.no_grad() 39 | def compute_perplexity(self, code_idx) : 40 | # Calculate new centres 41 | code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L 42 | code_onehot.scatter_(0, code_idx.view(1, 
code_idx.shape[0]), 1) 43 | 44 | code_count = code_onehot.sum(dim=-1) # nb_code 45 | prob = code_count / torch.sum(code_count) 46 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 47 | return perplexity 48 | 49 | @torch.no_grad() 50 | def update_codebook(self, x, code_idx): 51 | 52 | code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L 53 | code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) 54 | 55 | code_sum = torch.matmul(code_onehot, x) # nb_code, w 56 | code_count = code_onehot.sum(dim=-1) # nb_code 57 | 58 | out = self._tile(x) 59 | code_rand = out[:self.nb_code] 60 | 61 | # Update centres 62 | self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code 63 | self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code 64 | 65 | usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() 66 | code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) 67 | 68 | self.codebook = usage * code_update + (1 - usage) * code_rand 69 | prob = code_count / torch.sum(code_count) 70 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 71 | 72 | 73 | return perplexity 74 | 75 | def preprocess(self, x): 76 | # NCT -> NTC -> [NT, C] 77 | x = x.permute(0, 2, 1).contiguous() 78 | x = x.view(-1, x.shape[-1]) 79 | return x 80 | 81 | def quantize(self, x): 82 | # Calculate latent code x_l 83 | k_w = self.codebook.t() 84 | distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, 85 | keepdim=True) # (N * L, b) 86 | _, code_idx = torch.min(distance, dim=-1) 87 | return code_idx 88 | 89 | def dequantize(self, code_idx): 90 | x = F.embedding(code_idx, self.codebook) 91 | return x 92 | 93 | 94 | def forward(self, x): 95 | N, width, T = x.shape 96 | 97 | # Preprocess 98 | x = self.preprocess(x) 99 | 100 | # Init codebook if not inited 101 | if self.training and not self.init: 102 | self.init_codebook(x) 103 | 104 | # quantize and dequantize through bottleneck 105 | code_idx = self.quantize(x) 106 | x_d = self.dequantize(code_idx) 107 | 108 | # Update embeddings 109 | if self.training: 110 | perplexity = self.update_codebook(x, code_idx) 111 | else : 112 | perplexity = self.compute_perplexity(code_idx) 113 | 114 | # Loss 115 | commit_loss = F.mse_loss(x, x_d.detach()) 116 | 117 | # Passthrough 118 | x_d = x + (x_d - x).detach() 119 | 120 | # Postprocess 121 | x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) 122 | 123 | return x_d, commit_loss, perplexity 124 | 125 | 126 | 127 | class Quantizer(nn.Module): 128 | def __init__(self, n_e, e_dim, beta): 129 | super(Quantizer, self).__init__() 130 | 131 | self.e_dim = e_dim 132 | self.n_e = n_e 133 | self.beta = beta 134 | 135 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 136 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 137 | 138 | def forward(self, z): 139 | 140 | N, width, T = z.shape 141 | z = self.preprocess(z) 142 | assert z.shape[-1] == self.e_dim 143 | z_flattened = z.contiguous().view(-1, self.e_dim) 144 | 145 | # B x V 146 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 147 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 148 | torch.matmul(z_flattened, self.embedding.weight.t()) 149 | # B x 1 150 | min_encoding_indices = torch.argmin(d, dim=1) 151 | z_q = self.embedding(min_encoding_indices).view(z.shape) 152 | 153 | # compute loss for embedding 154 | loss = torch.mean((z_q - 
z.detach())**2) + self.beta * \ 155 | torch.mean((z_q.detach() - z)**2) 156 | 157 | # preserve gradients 158 | z_q = z + (z_q - z).detach() 159 | z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) 160 | 161 | min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) 162 | e_mean = torch.mean(min_encodings, dim=0) 163 | perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10))) 164 | return z_q, loss, perplexity 165 | 166 | def quantize(self, z): 167 | 168 | assert z.shape[-1] == self.e_dim 169 | 170 | # B x V 171 | d = torch.sum(z ** 2, dim=1, keepdim=True) + \ 172 | torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ 173 | torch.matmul(z, self.embedding.weight.t()) 174 | # B x 1 175 | min_encoding_indices = torch.argmin(d, dim=1) 176 | return min_encoding_indices 177 | 178 | def dequantize(self, indices): 179 | 180 | index_flattened = indices.view(-1) 181 | z_q = self.embedding(index_flattened) 182 | z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous() 183 | return z_q 184 | 185 | def preprocess(self, x): 186 | # NCT -> NTC -> [NT, C] 187 | x = x.permute(0, 2, 1).contiguous() 188 | x = x.view(-1, x.shape[-1]) 189 | return x 190 | 191 | 192 | 193 | class QuantizeReset(nn.Module): 194 | def __init__(self, nb_code, code_dim, args): 195 | super().__init__() 196 | self.nb_code = nb_code 197 | self.code_dim = code_dim 198 | self.reset_codebook() 199 | self.codebook = nn.Parameter(torch.randn(nb_code, code_dim)) 200 | 201 | def reset_codebook(self): 202 | self.init = False 203 | self.code_count = None 204 | 205 | def _tile(self, x): 206 | nb_code_x, code_dim = x.shape 207 | if nb_code_x < self.nb_code: 208 | n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x 209 | std = 0.01 / np.sqrt(code_dim) 210 | out = x.repeat(n_repeats, 1) 211 | out = out + torch.randn_like(out) * std 212 | else : 213 | out = x 214 | return out 215 | 216 | def init_codebook(self, x): 217 | out = self._tile(x) 218 | self.codebook = nn.Parameter(out[:self.nb_code]) 219 | self.code_count = torch.ones(self.nb_code, device=self.codebook.device) 220 | self.init = True 221 | 222 | @torch.no_grad() 223 | def compute_perplexity(self, code_idx) : 224 | # Calculate new centres 225 | code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L 226 | code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) 227 | 228 | code_count = code_onehot.sum(dim=-1) # nb_code 229 | prob = code_count / torch.sum(code_count) 230 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 231 | return perplexity 232 | 233 | def update_codebook(self, x, code_idx): 234 | 235 | code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L 236 | code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) 237 | 238 | code_count = code_onehot.sum(dim=-1) # nb_code 239 | 240 | out = self._tile(x) 241 | code_rand = out[:self.nb_code] 242 | 243 | # Update centres 244 | self.code_count = code_count # nb_code 245 | usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() 246 | 247 | self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand 248 | prob = code_count / torch.sum(code_count) 249 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 250 | 251 | 252 | return perplexity 253 | 254 | def preprocess(self, x): 255 | # NCT -> NTC -> [NT, C] 256 | x = x.permute(0, 2, 1).contiguous() 257 | x = x.view(-1, x.shape[-1]) 258 | return x 259 | 260 | def quantize(self, x): 261 | # Calculate latent code 
x_l 262 | k_w = self.codebook.t() 263 | distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, 264 | keepdim=True) # (N * L, b) 265 | _, code_idx = torch.min(distance, dim=-1) 266 | return code_idx 267 | 268 | def dequantize(self, code_idx): 269 | x = F.embedding(code_idx, self.codebook) 270 | return x 271 | 272 | 273 | def forward(self, x): 274 | N, width, T = x.shape 275 | # Preprocess 276 | x = self.preprocess(x) 277 | # Init codebook if not inited 278 | if self.training and not self.init: 279 | self.init_codebook(x) 280 | # quantize and dequantize through bottleneck 281 | code_idx = self.quantize(x) 282 | x_d = self.dequantize(code_idx) 283 | # Update embeddings 284 | if self.training: 285 | perplexity = self.update_codebook(x, code_idx) 286 | else : 287 | perplexity = self.compute_perplexity(code_idx) 288 | 289 | # Loss 290 | commit_loss = F.mse_loss(x, x_d.detach()) 291 | 292 | # Passthrough 293 | x_d = x + (x_d - x).detach() 294 | 295 | # Postprocess 296 | x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) 297 | 298 | return x_d, commit_loss, perplexity 299 | 300 | 301 | class QuantizeEMA(nn.Module): 302 | def __init__(self, nb_code, code_dim, args): 303 | super().__init__() 304 | self.nb_code = nb_code 305 | self.code_dim = code_dim 306 | self.mu = 0.99 307 | self.reset_codebook() 308 | 309 | def reset_codebook(self): 310 | self.init = False 311 | self.code_sum = None 312 | self.code_count = None 313 | self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda()) 314 | 315 | def _tile(self, x): 316 | nb_code_x, code_dim = x.shape 317 | if nb_code_x < self.nb_code: 318 | n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x 319 | std = 0.01 / np.sqrt(code_dim) 320 | out = x.repeat(n_repeats, 1) 321 | out = out + torch.randn_like(out) * std 322 | else : 323 | out = x 324 | return out 325 | 326 | def init_codebook(self, x): 327 | out = self._tile(x) 328 | self.codebook = out[:self.nb_code] 329 | self.code_sum = self.codebook.clone() 330 | self.code_count = torch.ones(self.nb_code, device=self.codebook.device) 331 | self.init = True 332 | 333 | @torch.no_grad() 334 | def compute_perplexity(self, code_idx) : 335 | # Calculate new centres 336 | code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L 337 | code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) 338 | 339 | code_count = code_onehot.sum(dim=-1) # nb_code 340 | prob = code_count / torch.sum(code_count) 341 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 342 | return perplexity 343 | 344 | @torch.no_grad() 345 | def update_codebook(self, x, code_idx): 346 | 347 | code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L 348 | code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) 349 | 350 | code_sum = torch.matmul(code_onehot, x) # nb_code, w 351 | code_count = code_onehot.sum(dim=-1) # nb_code 352 | 353 | # Update centres 354 | self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code 355 | self.code_count = self.mu * self.code_count + (1. 
- self.mu) * code_count # nb_code 356 | 357 | code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) 358 | 359 | self.codebook = code_update 360 | prob = code_count / torch.sum(code_count) 361 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 362 | 363 | return perplexity 364 | 365 | def preprocess(self, x): 366 | # NCT -> NTC -> [NT, C] 367 | x = x.permute(0, 2, 1).contiguous() 368 | x = x.view(-1, x.shape[-1]) 369 | return x 370 | 371 | def quantize(self, x): 372 | # Calculate latent code x_l 373 | k_w = self.codebook.t() 374 | distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, 375 | keepdim=True) # (N * L, b) 376 | _, code_idx = torch.min(distance, dim=-1) 377 | return code_idx 378 | 379 | def dequantize(self, code_idx): 380 | x = F.embedding(code_idx, self.codebook) 381 | return x 382 | 383 | 384 | def forward(self, x): 385 | N, width, T = x.shape 386 | 387 | # Preprocess 388 | x = self.preprocess(x) 389 | 390 | # Init codebook if not inited 391 | if self.training and not self.init: 392 | self.init_codebook(x) 393 | 394 | # quantize and dequantize through bottleneck 395 | code_idx = self.quantize(x) 396 | x_d = self.dequantize(code_idx) 397 | 398 | # Update embeddings 399 | if self.training: 400 | perplexity = self.update_codebook(x, code_idx) 401 | else : 402 | perplexity = self.compute_perplexity(code_idx) 403 | 404 | # Loss 405 | commit_loss = F.mse_loss(x, x_d.detach()) 406 | 407 | # Passthrough 408 | x_d = x + (x_d - x).detach() 409 | 410 | # Postprocess 411 | x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) 412 | 413 | return x_d, commit_loss, perplexity 414 | -------------------------------------------------------------------------------- /train_t2m_trans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | from os.path import join as pjoin 7 | from torch.distributions import Categorical 8 | import json 9 | import clip 10 | from transformers import AutoModel, AutoTokenizer, AutoConfig, get_scheduler 11 | 12 | import options.option_transformer as option_trans 13 | import models.vqvae as vqvae 14 | import utils.utils_model as utils_model 15 | import utils.eval_trans as eval_trans 16 | from dataset import dataset_TM_train 17 | from dataset import dataset_TM_eval 18 | from dataset import dataset_tokenize 19 | import models.t2m_trans as trans 20 | from options.get_eval_option import get_opt 21 | from models.evaluator_wrapper import EvaluatorModelWrapper 22 | import warnings 23 | from argparse import Namespace 24 | warnings.filterwarnings('ignore') 25 | 26 | ##### ---- Exp dirs ---- ##### 27 | args = option_trans.get_args_parser() 28 | torch.manual_seed(args.seed) 29 | 30 | args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') 31 | args.vq_dir = './dataset/HumanML3D' 32 | if args.dataname == 'kit': 33 | args.vq_dir = './dataset/KIT-ML' 34 | elif args.dataname.split('_')[0] == 'face': 35 | args.vq_dir = './dataset/'+args.dataname.split('_')[1] 36 | args.vq_dir = os.path.join(args.vq_dir, args.vq_name) 37 | 38 | os.makedirs(args.out_dir, exist_ok = True) 39 | os.makedirs(args.vq_dir, exist_ok = True) 40 | 41 | ##### ---- Logger ---- ##### 42 | logger = utils_model.get_logger(args.out_dir) 43 | writer = SummaryWriter(args.out_dir) 44 | logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) 
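# The config dump just below writes the full argument set to config.json so the run can be
# reproduced later. A minimal sketch of reloading it, assuming every argument value survives
# a JSON round-trip ("restored_args" is an illustrative name, not used elsewhere in this repo):
#     with open(os.path.join(args.out_dir, "config.json")) as f:
#         restored_args = Namespace(**json.load(f))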
45 | with open(os.path.join(args.out_dir, "config.json"), 'w') as f: 46 | json.dump(vars(args), f, indent=4, sort_keys=True) 47 | 48 | ##### ---- Dataloader ---- ##### 49 | eval_split = "val" 50 | if args.test_eval: 51 | eval_split = "test" 52 | train_loader_token = dataset_tokenize.DATALoader(args.dataname, 1, unit_length=2**args.down_t, max_motion_length=args.max_motion_length, split="train", delay_start_frames=args.delay_start_frames, fps=args.fps[0], min_length=args.train_min_length) 53 | val_loader_token = dataset_tokenize.DATALoader(args.dataname, 1, unit_length=2**args.down_t, max_motion_length=args.max_motion_length, split=eval_split, delay_start_frames=args.delay_start_frames, fps=args.fps[0], min_length=args.val_min_length) 54 | 55 | 56 | ##### ---- Network ---- ##### 57 | if args.gpt2 is None: 58 | if "openai/clip" in args.text_model_name: 59 | text_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) 60 | else: 61 | text_model = AutoModel.from_pretrained(args.text_model_name).cuda() 62 | text_tokenizer = AutoTokenizer.from_pretrained(args.text_model_name) 63 | for p in text_model.parameters(): 64 | p.requires_grad = False 65 | 66 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 67 | args.nb_code, 68 | args.code_dim, 69 | args.output_emb_width, 70 | args.down_t, 71 | args.stride_t, 72 | args.width, 73 | args.depth, 74 | args.dilation_growth_rate) 75 | 76 | args.extra_input_dim={} 77 | if args.gpt2 is not None: 78 | trans_encoder = trans.GPT2MotionTransformer(num_vq=args.nb_code, num_input_vq=(0 if args.speaker_vq_path is None else speaker_vq_args.nb_code), model_name=args.gpt2, top_p=args.top_p, extra_input_dim=args.extra_input_dim, freeze_lm=args.freeze_lm, output_layers=args.num_output_layers, not_pretrained=args.transformer_not_pretrained, gradient_checkpointing=args.gradient_checkpointing, predict_input_vq=args.speaker_vq_loss) 79 | text_model = trans_encoder 80 | text_tokenizer = None 81 | gpt2_config = AutoConfig.from_pretrained(args.gpt2) 82 | else: 83 | trans_encoder = trans.Text2Motion_Transformer(num_vq=args.nb_code, 84 | embed_dim=args.embed_dim_gpt, 85 | clip_dim=args.clip_dim, 86 | block_size=args.block_size, 87 | num_layers=args.num_layers, 88 | n_head=args.n_head_gpt, 89 | drop_out_rate=args.drop_out_rate, 90 | fc_rate=args.ff_rate, 91 | extra_dim=args.extra_input_dim, 92 | top_p=args.top_p) 93 | gpt2_config = None 94 | 95 | 96 | print ('loading checkpoint from {}'.format(args.resume_pth)) 97 | ckpt = torch.load(args.resume_pth, map_location='cpu') 98 | net.load_state_dict(utils_model.convert_vq_state_dict(ckpt['net']), strict=True) 99 | net.eval() 100 | net.cuda() 101 | 102 | if args.resume_trans is not None: 103 | print ('loading transformer checkpoint from {}'.format(args.resume_trans)) 104 | ckpt = torch.load(args.resume_trans, map_location='cpu') 105 | trans_encoder.load_state_dict(ckpt['trans'], strict=True) 106 | trans_encoder.train() 107 | if args.fp16_half: 108 | trans_encoder = trans_encoder.half() 109 | if args.manual_bf16: 110 | trans_encoder = trans_encoder.bfloat16() 111 | trans_encoder.cuda() 112 | 113 | ##### ---- Optimizer & Scheduler ---- ##### 114 | optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer) 115 | if args.linear_scheduler: 116 | scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=args.warm_up_iter, 
num_training_steps=args.total_iter//args.gradient_accumulation_steps) 117 | else: 118 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma) 119 | scaler = torch.cuda.amp.GradScaler() 120 | 121 | ##### ---- Optimization goals ---- ##### 122 | loss_ce = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 123 | 124 | nb_iter, avg_loss_cls, avg_acc = 0, 0., 0. 125 | right_num = 0 126 | nb_sample_train = 0 127 | 128 | speaker_vq_suffix = "_ofspeaker" 129 | if args.total_iter > 0 or len(list(os.listdir(args.vq_dir))) == 0: 130 | ##### ---- get code ---- ##### 131 | for batch in train_loader_token: 132 | pose, name = batch 133 | bs, seq = pose.shape[0], pose.shape[1] 134 | 135 | pose = pose.cuda().float() # bs, nb_joints, joints_dim, seq_len 136 | target = net.encode(pose) 137 | target = target.cpu().numpy() 138 | os.system('mkdir -p '+pjoin(args.vq_dir, *name[0].split('/')[:-1])) 139 | np.save(pjoin(args.vq_dir, name[0] +'.npy'), target) 140 | 141 | if args.total_iter > 0 or (not os.path.exists(args.vq_dir+"_"+eval_split)) or len(list(os.listdir(args.vq_dir+"_"+eval_split))): 142 | for batch in val_loader_token: 143 | pose, name = batch 144 | bs, seq = pose.shape[0], pose.shape[1] 145 | 146 | pose = pose.cuda().float() 147 | target = net.encode(pose) 148 | target = target.cpu().numpy() 149 | os.system('mkdir -p '+pjoin(args.vq_dir+'_'+eval_split, *name[0].split('/')[:-1])) 150 | np.save(pjoin(args.vq_dir+'_'+eval_split, name[0]+'.npy'), target) 151 | 152 | 153 | train_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.nb_code, args.vq_name, unit_length=2**args.down_t, split="train", max_motion_length=args.max_motion_length, 154 | evaluation=False, gpt2_config=gpt2_config, no_text=args.no_text, max_tokens=args.max_tokens, no_before_text=args.no_before_text, max_time_before=args.max_time_before, 155 | fps=args.fps[0], 156 | fixed_text_token=args.fixed_text_token, fixed_text_token_not_space=args.fixed_text_token_not_space, fixed_text_token_not_punctuation=args.fixed_text_token_not_punctuation, unaligned_text=args.unaligned_text, remove_space_before_vq_tokens=args.remove_space_before_vq_tokens, random_text_token_order=args.random_text_token_order) 157 | train_loader_iter = dataset_TM_train.cycle(train_loader) 158 | 159 | val_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.nb_code, args.vq_name+'_'+eval_split if not args.train_eval else args.vq_name, unit_length=2**args.down_t, split=eval_split if not args.train_eval else "train", max_motion_length=args.max_motion_length, 160 | evaluation=True, gpt2_config=gpt2_config, no_text=args.no_text, max_tokens=args.max_tokens, no_before_text=args.no_before_text, max_time_before=args.max_time_before, 161 | fps=args.fps[0], 162 | fixed_text_token=args.fixed_text_token, fixed_text_token_not_space=args.fixed_text_token_not_space, fixed_text_token_not_punctuation=args.fixed_text_token_not_punctuation, unaligned_text=args.unaligned_text, remove_space_before_vq_tokens=args.remove_space_before_vq_tokens, random_text_token_order=args.random_text_token_order) 163 | 164 | ##### ---- Training ---- ##### 165 | best_acc, best_loss, best_v_loss, best_windowed_v_loss, best_a_loss, best_e_loss, best_l2, best_iter, writer, logger = eval_trans.evaluation_transformer2(args, args.out_dir, val_loader, net, trans_encoder, logger, writer, nb_iter, best_acc=0, best_loss=float("inf"), best_v_loss=float("inf"), best_windowed_v_loss=float("inf"), best_a_loss=float("inf"), 
best_e_loss=float("inf"), best_l2=float("inf"), best_iter=0, text_model=text_model, text_tokenizer=text_tokenizer, max_motion_length=args.max_motion_length, draw=True, save=(args.total_iter > 0), savenpy=True, save_name=args.save_name, valence_window_size=args.valence_window_size, num_samples=args.num_samples) 166 | optimizer.zero_grad() 167 | prev_loss_total = float("inf") 168 | curr_loss_total = 0.0 169 | while nb_iter < args.total_iter: 170 | batch = next(train_loader_iter) 171 | if len(batch) == 4: 172 | before_text, during_text, m_tokens, m_tokens_len = batch 173 | input_text = (before_text, during_text) 174 | else: 175 | input_text, m_tokens, m_tokens_len = batch 176 | m_tokens, m_tokens_len = m_tokens.cuda(), m_tokens_len.cuda() 177 | bs = m_tokens.shape[0] 178 | target = m_tokens # (bs, 26) 179 | target = target.cuda() 180 | 181 | with torch.cuda.amp.autocast(enabled=args.fp16): 182 | if not isinstance(input_text, tuple): 183 | input_text = (input_text,) 184 | text_feats = [] 185 | if args.no_text: 186 | feat_clip_text = torch.zeros((bs, args.clip_dim)).float().to(m_tokens.device) 187 | elif args.gpt2 is None: 188 | for txt in input_text: 189 | if args.text_token_level and isinstance(txt[0], list): 190 | char_indices = [[[] for _ in range(target.shape[1])] for _ in range(bs)] 191 | full_texts = ["" for _ in range(bs)] 192 | for j in range(bs): 193 | for t in range(len(txt[j])): 194 | if len(txt[j][t]) > 0: 195 | for word in txt[j][t]: 196 | if len(full_texts[j]) > 0: 197 | full_texts[j] += " " 198 | char_indices[j][t].append(len(full_texts[j])) 199 | full_texts[j] += word 200 | text_inputs = text_tokenizer(full_texts, return_tensors='pt', padding=True, truncation=True).to(m_tokens.device) 201 | with torch.no_grad(): 202 | if "openai/clip" in args.text_model_name: 203 | feats_clip_text = text_model.encode_text(text_inputs.input_ids) 204 | else: 205 | feats_clip_text = text_model(**text_inputs).last_hidden_state 206 | feats = torch.zeros_like((bs, target.shape[1], text_model.config.hidden_dim), dtype=torch.float32).to(m_tokens.device) 207 | for j in range(bs): 208 | for t in range(len(char_indices[j])): 209 | if len(char_indices[j][t]) > 0: 210 | feats[j,t,:] = feats_clip_text[j,[tok for c in char_indices[j][t] for tok in text_inputs[j].char_to_token(c)],:].mean(dim=0) 211 | text_feats.append(feats) 212 | else: 213 | text_inputs = text_tokenizer(txt, return_tensors='pt', padding=True, truncation=True).to(m_tokens.device) 214 | with torch.no_grad(): 215 | text_inputs = clip.tokenize(txt) 216 | feat_clip_text = text_model.encode_text(text_inputs) 217 | text_feats.append(feat_clip_text) 218 | feat_clip_text = torch.cat(text_feats, dim=1) 219 | if args.manual_bf16: 220 | feat_clip_text = feat_clip_text.bfloat16() 221 | 222 | 223 | if args.gpt2 is not None: 224 | input_index = target 225 | else: 226 | input_index = target[:,:-1] 227 | 228 | if args.pkeep == -1: 229 | proba = np.random.rand(1)[0] 230 | mask = torch.bernoulli(proba * torch.ones(input_index.shape, 231 | device=input_index.device)) 232 | else: 233 | mask = torch.bernoulli(args.pkeep * torch.ones(input_index.shape, 234 | device=input_index.device)) 235 | mask = mask.round().to(dtype=torch.int64) 236 | if args.gpt2 is not None: 237 | r_indices = torch.where( 238 | (input_index >= text_model.text_vocab_size) & (input_index < text_model.text_vocab_size+args.nb_code), 239 | torch.randint_like(input_index, low=text_model.text_vocab_size, high=text_model.text_vocab_size+args.nb_code), 240 | input_index 241 | ) 242 | else: 243 | 
r_indices = torch.randint_like(input_index, args.nb_code) 244 | a_indices = mask*input_index+(1-mask)*r_indices 245 | base_codebook_num = text_model.text_vocab_size+args.nb_code 246 | 247 | if args.gpt2 is not None: 248 | if (args.include_speaker and args.speaker_vq_path is None) or (args.include_audio and args.audio_vq_path is None): 249 | if args.fix_pkeep: 250 | input_idx = a_indices 251 | else: 252 | input_idx = input_index 253 | input_embeds = trans_encoder.gpt.transformer.wte(input_idx.clamp(min=0)) 254 | if args.include_audio: 255 | for i in range(input_idx.shape[0]): 256 | num_audio_inputs = (input_idx[i] == -2).long().sum().item() 257 | input_embeds[i,input_idx[i] == -2,:] = trans_encoder.extra_input_layers["aud"](audio_inputs[0][i,:num_audio_inputs,:]).to(input_embeds.dtype) 258 | if args.audio_pkeep is not None: 259 | audio_keep = (input_idx == -2) & (torch.rand_like(input_idx.float()) < args.audio_pkeep) 260 | input_embeds = torch.where( 261 | audio_keep.unsqueeze(-1).repeat(1, 1, input_embeds.shape[-1]), 262 | input_embeds, 263 | torch.zeros_like(input_embeds) 264 | ) 265 | if args.include_speaker: # search for -1 embeddings 266 | for i in range(input_idx.shape[0]): 267 | num_speaker_inputs = (input_idx[i] == -1).long().sum().item() 268 | input_embeds[i,input_idx[i] == -1,:] = trans_encoder.extra_input_layers["mot"](speaker_inputs[0][i,:num_speaker_inputs,:]).to(input_embeds.dtype) 269 | if args.speaker_pkeep is not None: 270 | speaker_keep = (input_idx == -1) & (torch.rand_like(input_idx.float()) < args.speaker_pkeep) 271 | input_embeds = torch.where( 272 | speaker_keep.unsqueeze(-1).repeat(1, 1, input_embeds.shape[-1]), 273 | input_embeds, 274 | torch.zeros_like(input_embeds) 275 | ) 276 | else: 277 | input_embeds = trans_encoder.gpt.transformer.wte(a_indices) 278 | cls_pred = trans_encoder(input_ids=a_indices, input_embeds=input_embeds, attention_mask=m_tokens_len, predict_input_vq=args.speaker_vq_loss) 279 | 280 | cls_pred = cls_pred.contiguous() 281 | 282 | loss_cls = 0.0 283 | for i in range(bs): 284 | if args.gpt2 is not None: 285 | length_i = m_tokens_len[i].sum().item() 286 | if args.speaker_vq_loss: 287 | mask_i = (target[i][1:length_i] >= text_model.text_vocab_size) 288 | else: 289 | mask_i = (target[i][1:length_i] >= text_model.text_vocab_size) & (target[i][1:length_i] < text_model.text_vocab_size + args.nb_code) 290 | loss_cls += loss_ce(cls_pred[i][:length_i-1][mask_i], target[i][1:length_i][mask_i]-text_model.text_vocab_size) / bs 291 | probs = torch.softmax(cls_pred[i][:length_i-1], dim=-1) 292 | 293 | if args.if_maxtest: 294 | _, cls_pred_index = torch.max(probs, dim=-1) 295 | 296 | else: 297 | with torch.no_grad(): 298 | dist = Categorical(probs.float()) 299 | cls_pred_index = dist.sample() 300 | if args.gpt2 is not None: 301 | right_num += (cls_pred_index[mask_i] == target[i][1:length_i][mask_i]-text_model.text_vocab_size).sum().item() 302 | nb_sample_train += mask_i.long().sum().item() 303 | 304 | loss_cls = loss_cls / args.gradient_accumulation_steps 305 | curr_loss_total = curr_loss_total + loss_cls.item() 306 | if args.grad_scaling: 307 | scaler.scale(loss_cls).backward() 308 | else: 309 | loss_cls.backward() 310 | if ((nb_iter + 1) % args.gradient_accumulation_steps == 0) or (nb_iter + 1 == args.total_iter): 311 | if args.training_end_check_interval is not None and (nb_iter + 1) % (args.training_end_check_interval * args.gradient_accumulation_steps) == 0: 312 | if (prev_loss_total - curr_loss_total) / args.training_end_check_interval < 
args.train_loss_threshold: 313 | break 314 | prev_loss_total = curr_loss_total 315 | curr_loss_total = 0.0 316 | if args.grad_clip is not None: 317 | torch.nn.utils.clip_grad_norm_(trans_encoder.parameters(), args.grad_clip) 318 | if args.grad_scaling: 319 | scaler.step(optimizer) 320 | scaler.update() 321 | else: 322 | optimizer.step() 323 | scheduler.step() 324 | optimizer.zero_grad() 325 | 326 | avg_loss_cls = avg_loss_cls + loss_cls.item() 327 | if args.gpt2 is None: 328 | nb_sample_train = nb_sample_train + (m_tokens_len + 1).sum().item() 329 | 330 | nb_iter += 1 331 | if nb_iter % args.print_iter == 0 : 332 | avg_loss_cls = avg_loss_cls / args.print_iter 333 | avg_acc = right_num * 100 / nb_sample_train 334 | writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter) 335 | writer.add_scalar('./ACC/train', avg_acc, nb_iter) 336 | msg = f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}, ACC. {avg_acc:.4f}" 337 | logger.info(msg) 338 | avg_loss_cls = 0. 339 | right_num = 0 340 | nb_sample_train = 0 341 | 342 | if nb_iter % args.eval_iter == 0: 343 | best_acc, best_loss, best_v_loss, best_windowed_v_loss, best_a_loss, best_e_loss, best_l2, best_iter, writer, logger = eval_trans.evaluation_transformer2(args, args.out_dir, val_loader, net, trans_encoder, logger, writer, nb_iter, best_acc=best_acc, best_loss=best_loss, best_v_loss=best_v_loss, best_windowed_v_loss=best_windowed_v_loss, best_a_loss=best_a_loss, best_e_loss=best_e_loss, best_l2=best_l2, best_iter=best_iter, text_model=text_model, text_tokenizer=text_tokenizer, max_motion_length=args.max_motion_length, draw=True, save=True, savenpy=True, save_name=args.save_name, valence_window_size=args.valence_window_size, num_samples=args.num_samples) 344 | 345 | if nb_iter == args.total_iter: 346 | msg_final = f"Train. Iter {best_iter} : Acc. {best_acc:.5f}" 347 | logger.info(msg_final) 348 | break 349 | -------------------------------------------------------------------------------- /utils/rotation_conversions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # Check PYTORCH3D_LICENCE before use 3 | 4 | import functools 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | """ 12 | The transformation matrices returned from the functions in this file assume 13 | the points on which the transformation will be applied are column vectors. 14 | i.e. the R matrix is structured as 15 | R = [ 16 | [Rxx, Rxy, Rxz], 17 | [Ryx, Ryy, Ryz], 18 | [Rzx, Rzy, Rzz], 19 | ] # (3, 3) 20 | This matrix can be applied to column vectors by post multiplication 21 | by the points e.g. 22 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point 23 | transformed_points = R * points 24 | To apply the same matrix to points which are row vectors, the R matrix 25 | can be transposed and pre multiplied by the points: 26 | e.g. 27 | points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point 28 | transformed_points = points * R.transpose(1, 0) 29 | """ 30 | 31 | 32 | def quaternion_to_matrix(quaternions): 33 | """ 34 | Convert rotations given as quaternions to rotation matrices. 35 | Args: 36 | quaternions: quaternions with real part first, 37 | as tensor of shape (..., 4). 38 | Returns: 39 | Rotation matrices as tensor of shape (..., 3, 3). 
40 | """ 41 | r, i, j, k = torch.unbind(quaternions, -1) 42 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 43 | 44 | o = torch.stack( 45 | ( 46 | 1 - two_s * (j * j + k * k), 47 | two_s * (i * j - k * r), 48 | two_s * (i * k + j * r), 49 | two_s * (i * j + k * r), 50 | 1 - two_s * (i * i + k * k), 51 | two_s * (j * k - i * r), 52 | two_s * (i * k - j * r), 53 | two_s * (j * k + i * r), 54 | 1 - two_s * (i * i + j * j), 55 | ), 56 | -1, 57 | ) 58 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 59 | 60 | 61 | def _copysign(a, b): 62 | """ 63 | Return a tensor where each element has the absolute value taken from the, 64 | corresponding element of a, with sign taken from the corresponding 65 | element of b. This is like the standard copysign floating-point operation, 66 | but is not careful about negative 0 and NaN. 67 | Args: 68 | a: source tensor. 69 | b: tensor whose signs will be used, of the same shape as a. 70 | Returns: 71 | Tensor of the same shape as a with the signs of b. 72 | """ 73 | signs_differ = (a < 0) != (b < 0) 74 | return torch.where(signs_differ, -a, a) 75 | 76 | 77 | def _sqrt_positive_part(x): 78 | """ 79 | Returns torch.sqrt(torch.max(0, x)) 80 | but with a zero subgradient where x is 0. 81 | """ 82 | ret = torch.zeros_like(x) 83 | positive_mask = x > 0 84 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 85 | return ret 86 | 87 | 88 | def matrix_to_quaternion(matrix): 89 | """ 90 | Convert rotations given as rotation matrices to quaternions. 91 | Args: 92 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 93 | Returns: 94 | quaternions with real part first, as tensor of shape (..., 4). 95 | """ 96 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 97 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 98 | m00 = matrix[..., 0, 0] 99 | m11 = matrix[..., 1, 1] 100 | m22 = matrix[..., 2, 2] 101 | o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) 102 | x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) 103 | y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) 104 | z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) 105 | o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) 106 | o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) 107 | o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) 108 | return torch.stack((o0, o1, o2, o3), -1) 109 | 110 | 111 | def _axis_angle_rotation(axis: str, angle): 112 | """ 113 | Return the rotation matrices for one of the rotations about an axis 114 | of which Euler angles describe, for each value of the angle given. 115 | Args: 116 | axis: Axis label "X" or "Y or "Z". 117 | angle: any shape tensor of Euler angles in radians 118 | Returns: 119 | Rotation matrices as tensor of shape (..., 3, 3). 120 | """ 121 | 122 | cos = torch.cos(angle) 123 | sin = torch.sin(angle) 124 | one = torch.ones_like(angle) 125 | zero = torch.zeros_like(angle) 126 | 127 | if axis == "X": 128 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 129 | if axis == "Y": 130 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 131 | if axis == "Z": 132 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 133 | 134 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 135 | 136 | 137 | def euler_angles_to_matrix(euler_angles, convention: str): 138 | """ 139 | Convert rotations given as Euler angles in radians to rotation matrices. 140 | Args: 141 | euler_angles: Euler angles in radians as tensor of shape (..., 3). 
142 | convention: Convention string of three uppercase letters from 143 | {"X", "Y", and "Z"}. 144 | Returns: 145 | Rotation matrices as tensor of shape (..., 3, 3). 146 | """ 147 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: 148 | raise ValueError("Invalid input euler angles.") 149 | if len(convention) != 3: 150 | raise ValueError("Convention must have 3 letters.") 151 | if convention[1] in (convention[0], convention[2]): 152 | raise ValueError(f"Invalid convention {convention}.") 153 | for letter in convention: 154 | if letter not in ("X", "Y", "Z"): 155 | raise ValueError(f"Invalid letter {letter} in convention string.") 156 | matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) 157 | return functools.reduce(torch.matmul, matrices) 158 | 159 | 160 | def _angle_from_tan( 161 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool 162 | ): 163 | """ 164 | Extract the first or third Euler angle from the two members of 165 | the matrix which are positive constant times its sine and cosine. 166 | Args: 167 | axis: Axis label "X" or "Y or "Z" for the angle we are finding. 168 | other_axis: Axis label "X" or "Y or "Z" for the middle axis in the 169 | convention. 170 | data: Rotation matrices as tensor of shape (..., 3, 3). 171 | horizontal: Whether we are looking for the angle for the third axis, 172 | which means the relevant entries are in the same row of the 173 | rotation matrix. If not, they are in the same column. 174 | tait_bryan: Whether the first and third axes in the convention differ. 175 | Returns: 176 | Euler Angles in radians for each matrix in data as a tensor 177 | of shape (...). 178 | """ 179 | 180 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] 181 | if horizontal: 182 | i2, i1 = i1, i2 183 | even = (axis + other_axis) in ["XY", "YZ", "ZX"] 184 | if horizontal == even: 185 | return torch.atan2(data[..., i1], data[..., i2]) 186 | if tait_bryan: 187 | return torch.atan2(-data[..., i2], data[..., i1]) 188 | return torch.atan2(data[..., i2], -data[..., i1]) 189 | 190 | 191 | def _index_from_letter(letter: str): 192 | if letter == "X": 193 | return 0 194 | if letter == "Y": 195 | return 1 196 | if letter == "Z": 197 | return 2 198 | 199 | 200 | def matrix_to_euler_angles(matrix, convention: str): 201 | """ 202 | Convert rotations given as rotation matrices to Euler angles in radians. 203 | Args: 204 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 205 | convention: Convention string of three uppercase letters. 206 | Returns: 207 | Euler angles in radians as tensor of shape (..., 3). 
208 | """ 209 | if len(convention) != 3: 210 | raise ValueError("Convention must have 3 letters.") 211 | if convention[1] in (convention[0], convention[2]): 212 | raise ValueError(f"Invalid convention {convention}.") 213 | for letter in convention: 214 | if letter not in ("X", "Y", "Z"): 215 | raise ValueError(f"Invalid letter {letter} in convention string.") 216 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 217 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 218 | i0 = _index_from_letter(convention[0]) 219 | i2 = _index_from_letter(convention[2]) 220 | tait_bryan = i0 != i2 221 | if tait_bryan: 222 | central_angle = torch.asin( 223 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) 224 | ) 225 | else: 226 | central_angle = torch.acos(matrix[..., i0, i0]) 227 | 228 | o = ( 229 | _angle_from_tan( 230 | convention[0], convention[1], matrix[..., i2], False, tait_bryan 231 | ), 232 | central_angle, 233 | _angle_from_tan( 234 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan 235 | ), 236 | ) 237 | return torch.stack(o, -1) 238 | 239 | 240 | def random_quaternions( 241 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 242 | ): 243 | """ 244 | Generate random quaternions representing rotations, 245 | i.e. versors with nonnegative real part. 246 | Args: 247 | n: Number of quaternions in a batch to return. 248 | dtype: Type to return. 249 | device: Desired device of returned tensor. Default: 250 | uses the current device for the default tensor type. 251 | requires_grad: Whether the resulting tensor should have the gradient 252 | flag set. 253 | Returns: 254 | Quaternions as tensor of shape (N, 4). 255 | """ 256 | o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) 257 | s = (o * o).sum(1) 258 | o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] 259 | return o 260 | 261 | 262 | def random_rotations( 263 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 264 | ): 265 | """ 266 | Generate random rotations as 3x3 rotation matrices. 267 | Args: 268 | n: Number of rotation matrices in a batch to return. 269 | dtype: Type to return. 270 | device: Device of returned tensor. Default: if None, 271 | uses the current device for the default tensor type. 272 | requires_grad: Whether the resulting tensor should have the gradient 273 | flag set. 274 | Returns: 275 | Rotation matrices as tensor of shape (n, 3, 3). 276 | """ 277 | quaternions = random_quaternions( 278 | n, dtype=dtype, device=device, requires_grad=requires_grad 279 | ) 280 | return quaternion_to_matrix(quaternions) 281 | 282 | 283 | def random_rotation( 284 | dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 285 | ): 286 | """ 287 | Generate a single random 3x3 rotation matrix. 288 | Args: 289 | dtype: Type to return 290 | device: Device of returned tensor. Default: if None, 291 | uses the current device for the default tensor type 292 | requires_grad: Whether the resulting tensor should have the gradient 293 | flag set 294 | Returns: 295 | Rotation matrix as tensor of shape (3, 3). 296 | """ 297 | return random_rotations(1, dtype, device, requires_grad)[0] 298 | 299 | 300 | def standardize_quaternion(quaternions): 301 | """ 302 | Convert a unit quaternion to a standard form: one in which the real 303 | part is non negative. 304 | Args: 305 | quaternions: Quaternions with real part first, 306 | as tensor of shape (..., 4). 
307 | Returns: 308 | Standardized quaternions as tensor of shape (..., 4). 309 | """ 310 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 311 | 312 | 313 | def quaternion_raw_multiply(a, b): 314 | """ 315 | Multiply two quaternions. 316 | Usual torch rules for broadcasting apply. 317 | Args: 318 | a: Quaternions as tensor of shape (..., 4), real part first. 319 | b: Quaternions as tensor of shape (..., 4), real part first. 320 | Returns: 321 | The product of a and b, a tensor of quaternions shape (..., 4). 322 | """ 323 | aw, ax, ay, az = torch.unbind(a, -1) 324 | bw, bx, by, bz = torch.unbind(b, -1) 325 | ow = aw * bw - ax * bx - ay * by - az * bz 326 | ox = aw * bx + ax * bw + ay * bz - az * by 327 | oy = aw * by - ax * bz + ay * bw + az * bx 328 | oz = aw * bz + ax * by - ay * bx + az * bw 329 | return torch.stack((ow, ox, oy, oz), -1) 330 | 331 | 332 | def quaternion_multiply(a, b): 333 | """ 334 | Multiply two quaternions representing rotations, returning the quaternion 335 | representing their composition, i.e. the versor with nonnegative real part. 336 | Usual torch rules for broadcasting apply. 337 | Args: 338 | a: Quaternions as tensor of shape (..., 4), real part first. 339 | b: Quaternions as tensor of shape (..., 4), real part first. 340 | Returns: 341 | The product of a and b, a tensor of quaternions of shape (..., 4). 342 | """ 343 | ab = quaternion_raw_multiply(a, b) 344 | return standardize_quaternion(ab) 345 | 346 | 347 | def quaternion_invert(quaternion): 348 | """ 349 | Given a quaternion representing rotation, get the quaternion representing 350 | its inverse. 351 | Args: 352 | quaternion: Quaternions as tensor of shape (..., 4), with real part 353 | first, which must be versors (unit quaternions). 354 | Returns: 355 | The inverse, a tensor of quaternions of shape (..., 4). 356 | """ 357 | 358 | return quaternion * quaternion.new_tensor([1, -1, -1, -1]) 359 | 360 | 361 | def quaternion_apply(quaternion, point): 362 | """ 363 | Apply the rotation given by a quaternion to a 3D point. 364 | Usual torch rules for broadcasting apply. 365 | Args: 366 | quaternion: Tensor of quaternions, real part first, of shape (..., 4). 367 | point: Tensor of 3D points of shape (..., 3). 368 | Returns: 369 | Tensor of rotated points of shape (..., 3). 370 | """ 371 | if point.size(-1) != 3: 372 | raise ValueError(f"Points are not in 3D, f{point.shape}.") 373 | real_parts = point.new_zeros(point.shape[:-1] + (1,)) 374 | point_as_quaternion = torch.cat((real_parts, point), -1) 375 | out = quaternion_raw_multiply( 376 | quaternion_raw_multiply(quaternion, point_as_quaternion), 377 | quaternion_invert(quaternion), 378 | ) 379 | return out[..., 1:] 380 | 381 | 382 | def axis_angle_to_matrix(axis_angle): 383 | """ 384 | Convert rotations given as axis/angle to rotation matrices. 385 | Args: 386 | axis_angle: Rotations given as a vector in axis angle form, 387 | as a tensor of shape (..., 3), where the magnitude is 388 | the angle turned anticlockwise in radians around the 389 | vector's direction. 390 | Returns: 391 | Rotation matrices as tensor of shape (..., 3, 3). 392 | """ 393 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 394 | 395 | 396 | def matrix_to_axis_angle(matrix): 397 | """ 398 | Convert rotations given as rotation matrices to axis/angle. 399 | Args: 400 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 
401 | Returns: 402 | Rotations given as a vector in axis angle form, as a tensor 403 | of shape (..., 3), where the magnitude is the angle 404 | turned anticlockwise in radians around the vector's 405 | direction. 406 | """ 407 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) 408 | 409 | 410 | def axis_angle_to_quaternion(axis_angle): 411 | """ 412 | Convert rotations given as axis/angle to quaternions. 413 | Args: 414 | axis_angle: Rotations given as a vector in axis angle form, 415 | as a tensor of shape (..., 3), where the magnitude is 416 | the angle turned anticlockwise in radians around the 417 | vector's direction. 418 | Returns: 419 | quaternions with real part first, as tensor of shape (..., 4). 420 | """ 421 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 422 | half_angles = 0.5 * angles 423 | eps = 1e-6 424 | small_angles = angles.abs() < eps 425 | sin_half_angles_over_angles = torch.empty_like(angles) 426 | sin_half_angles_over_angles[~small_angles] = ( 427 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 428 | ) 429 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 430 | # so sin(x/2)/x is about 1/2 - (x*x)/48 431 | sin_half_angles_over_angles[small_angles] = ( 432 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 433 | ) 434 | quaternions = torch.cat( 435 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 436 | ) 437 | return quaternions 438 | 439 | 440 | def quaternion_to_axis_angle(quaternions): 441 | """ 442 | Convert rotations given as quaternions to axis/angle. 443 | Args: 444 | quaternions: quaternions with real part first, 445 | as tensor of shape (..., 4). 446 | Returns: 447 | Rotations given as a vector in axis angle form, as a tensor 448 | of shape (..., 3), where the magnitude is the angle 449 | turned anticlockwise in radians around the vector's 450 | direction. 451 | """ 452 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 453 | half_angles = torch.atan2(norms, quaternions[..., :1]) 454 | angles = 2 * half_angles 455 | eps = 1e-6 456 | small_angles = angles.abs() < eps 457 | sin_half_angles_over_angles = torch.empty_like(angles) 458 | sin_half_angles_over_angles[~small_angles] = ( 459 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 460 | ) 461 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 462 | # so sin(x/2)/x is about 1/2 - (x*x)/48 463 | sin_half_angles_over_angles[small_angles] = ( 464 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 465 | ) 466 | return quaternions[..., 1:] / sin_half_angles_over_angles 467 | 468 | 469 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: 470 | """ 471 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 472 | using Gram--Schmidt orthogonalisation per Section B of [1]. 473 | Args: 474 | d6: 6D rotation representation, of size (*, 6) 475 | Returns: 476 | batch of rotation matrices of size (*, 3, 3) 477 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 478 | On the Continuity of Rotation Representations in Neural Networks. 479 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
480 | Retrieved from http://arxiv.org/abs/1812.07035 481 | """ 482 | 483 | a1, a2 = d6[..., :3], d6[..., 3:] 484 | b1 = F.normalize(a1, dim=-1) 485 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 486 | b2 = F.normalize(b2, dim=-1) 487 | b3 = torch.cross(b1, b2, dim=-1) 488 | return torch.stack((b1, b2, b3), dim=-2) 489 | 490 | 491 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 492 | """ 493 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 494 | by dropping the last row. Note that 6D representation is not unique. 495 | Args: 496 | matrix: batch of rotation matrices of size (*, 3, 3) 497 | Returns: 498 | 6D rotation representation, of size (*, 6) 499 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 500 | On the Continuity of Rotation Representations in Neural Networks. 501 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 502 | Retrieved from http://arxiv.org/abs/1812.07035 503 | """ 504 | return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) 505 | 506 | def canonicalize_smplh(poses, trans = None): 507 | bs, nframes, njoints = poses.shape[:3] 508 | 509 | global_orient = poses[:, :, 0] 510 | 511 | # first global rotations 512 | rot2d = matrix_to_axis_angle(global_orient[:, 0]) 513 | #rot2d[:, :2] = 0 # Remove the rotation along the vertical axis 514 | rot2d = axis_angle_to_matrix(rot2d) 515 | 516 | # Rotate the global rotation to eliminate Z rotations 517 | global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) 518 | 519 | # Construct canonicalized version of x 520 | xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2) 521 | 522 | if trans is not None: 523 | vel = trans[:, 1:] - trans[:, :-1] 524 | # Turn the translation as well 525 | vel = torch.einsum("ikj,ilk->ilj", rot2d, vel) 526 | trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), 527 | torch.cumsum(vel, 1)), 1) 528 | return xc, trans 529 | else: 530 | return xc 531 | 532 | --------------------------------------------------------------------------------
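Usage sketch for the rotation utilities in utils/rotation_conversions.py above, assuming the repository root is on PYTHONPATH so the file imports as utils.rotation_conversions (adjust the import to your layout): it round-trips a batch of random rotations through the axis-angle and 6D representations and checks that the original matrices are recovered.

import torch

from utils.rotation_conversions import (
    axis_angle_to_matrix,
    matrix_to_axis_angle,
    matrix_to_rotation_6d,
    random_rotations,
    rotation_6d_to_matrix,
)

# Sample a batch of 4 random rotation matrices, shape (4, 3, 3).
R = random_rotations(4)

# Matrix -> axis-angle -> matrix should reproduce the input rotations.
axis_angle = matrix_to_axis_angle(R)                 # (4, 3)
R_from_axis_angle = axis_angle_to_matrix(axis_angle)

# Matrix -> 6D -> matrix should also reproduce the input rotations,
# since rotation_6d_to_matrix re-orthonormalizes the first two rows.
d6 = matrix_to_rotation_6d(R)                        # (4, 6)
R_from_6d = rotation_6d_to_matrix(d6)

print(torch.allclose(R, R_from_axis_angle, atol=1e-4))  # expected: True
print(torch.allclose(R, R_from_6d, atol=1e-4))          # expected: True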