├── common ├── __init__.py └── skeleton.py ├── data └── __init__.py ├── models ├── __init__.py ├── vq │ ├── __init__.py │ ├── encdec.py │ ├── resnet.py │ ├── model.py │ ├── quantizer.py │ └── residual_vq.py ├── mask_transformer │ ├── __init__.py │ └── tools.py ├── .DS_Store ├── t2m_eval_modules.py └── t2m_eval_wrapper.py ├── dataset └── __init__.py ├── options ├── __init__.py ├── eval_option.py ├── base_option.py ├── train_option.py └── vq_option.py ├── motion_loaders ├── __init__.py └── dataset_motion_loader.py ├── visualization ├── __init__.py ├── .DS_Store ├── data │ ├── .DS_Store │ ├── smpl │ │ └── smpl │ │ │ ├── .DS_Store │ │ │ └── smpl.txt │ └── gBR_sBM_cAll_d04_mBR0_ch01.pkl ├── joints2bvh.py ├── smpl2bvh.py ├── utils │ └── bvh.py ├── BVH.py ├── BVH_mod.py └── AnimationStructure.py ├── .gitignore ├── prepare ├── .DS_Store ├── download_glove.sh ├── download_evaluator.sh └── download_models.sh ├── example_data ├── 000612.mp4 └── 000612.npy ├── utils ├── fixseed.py ├── paramUtil.py ├── get_opt.py ├── word_vectorizer.py ├── utils.py └── metrics.py ├── requirements.txt ├── assets └── text_prompt.txt ├── LICENSE ├── train_vq.py ├── eval_t2m_vq.py ├── train_t2m_transformer.py ├── environment.yml ├── train_res_transformer.py ├── edit_t2m.py ├── eval_t2m_trans_res.py └── gen_t2m.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/vq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /motion_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/mask_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | checkpoints 3 | editing 4 | generation 5 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /prepare/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/prepare/.DS_Store -------------------------------------------------------------------------------- /example_data/000612.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/example_data/000612.mp4 -------------------------------------------------------------------------------- /example_data/000612.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/example_data/000612.npy -------------------------------------------------------------------------------- /visualization/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/visualization/.DS_Store -------------------------------------------------------------------------------- /visualization/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/visualization/data/.DS_Store -------------------------------------------------------------------------------- /visualization/data/smpl/smpl/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/visualization/data/smpl/smpl/.DS_Store -------------------------------------------------------------------------------- /visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EricGuo5513/momask-codes/HEAD/visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.pkl -------------------------------------------------------------------------------- /visualization/data/smpl/smpl/smpl.txt: -------------------------------------------------------------------------------- 1 | Once you have downloaded the SMPL model, place it here like below. 2 | 3 | data 4 | |_smpl 5 | |_smpl 6 | |_SMPL_FEMALE.pkl 7 | |_SMPL_MALE.pkl 8 | |_SMPL_NEUTRAL.pkl 9 | -------------------------------------------------------------------------------- /prepare/download_glove.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading glove (in use by the evaluators, not by MoMask itself)" 2 | gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing 3 | rm -rf glove 4 | 5 | unzip glove.zip 6 | echo -e "Cleaning\n" 7 | rm glove.zip 8 | 9 | echo -e "Downloading done!" 
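Note on the SMPL layout described in visualization/data/smpl/smpl/smpl.txt above: once SMPL_FEMALE.pkl, SMPL_MALE.pkl and SMPL_NEUTRAL.pkl are in place, they can be loaded with the smplx package pinned in requirements.txt. The snippet below is only a hedged sketch of consuming that layout; the repository's own loading code lives in visualization/smpl2bvh.py (not shown in this section), so the exact call it uses may differ.

    # Illustrative only: load the neutral SMPL body model from the layout above (assumes smplx).
    import torch
    import smplx

    smpl = smplx.create("visualization/data/smpl", model_type="smpl", gender="neutral")
    out = smpl(betas=torch.zeros(1, 10), body_pose=torch.zeros(1, 69))
    print(out.joints.shape)  # joint locations for a zero-shape, zero-pose body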
-------------------------------------------------------------------------------- /utils/fixseed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def fixseed(seed): 7 | torch.backends.cudnn.benchmark = False 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | 13 | # SEED = 10 14 | # EVALSEED = 0 15 | # # Provokes a warning: not fully functional yet 16 | # # torch.set_deterministic(True) 17 | # torch.backends.cudnn.benchmark = False 18 | # fixseed(SEED) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | clip @ git+https://github.com/openai/CLIP.git 3 | chumpy 4 | einops==0.6.1 5 | ffmpy==0.3.1 6 | ftfy==6.1.1 7 | gdown==4.7.1 8 | grpcio==1.54.2 9 | h11==0.14.0 10 | importlib-metadata==5.0.0 11 | importlib-resources==5.12.0 12 | joblib 13 | matplotlib==3.1.3 14 | numpy==1.21.5 15 | Pillow==9.2.0 16 | PyYAML==6.0 17 | scikit-learn 18 | scipy 19 | smplx==0.1.28 20 | sniffio==1.3.0 21 | torch==1.12.0 22 | torch-tb-profiler 23 | torchaudio 24 | torchvision 25 | tornado 26 | tqdm 27 | trimesh 28 | vector-quantize-pytorch==1.6.30 29 | -------------------------------------------------------------------------------- /prepare/download_evaluator.sh: -------------------------------------------------------------------------------- 1 | cd checkpoints 2 | 3 | cd t2m 4 | echo -e "Downloading evaluation models for HumanML3D dataset" 5 | gdown --fuzzy https://drive.google.com/file/d/19C_eiEr0kMGlYVJy_yFL6_Dhk3RvmwhM/view?usp=sharing 6 | echo -e "Unzipping humanml3d_evaluator.zip" 7 | unzip humanml3d_evaluator.zip 8 | 9 | echo -e "Cleaning humanml3d_evaluator.zip" 10 | rm humanml3d_evaluator.zip 11 | 12 | cd ../kit/ 13 | echo -e "Downloading evaluation models for KIT-ML dataset" 14 | gdown --fuzzy https://drive.google.com/file/d/1TKIZ3TSSZawpilC-7Kw7Ws4sNNuzb49p/view?usp=drive_link 15 | 16 | echo -e "Unzipping kit_evaluator.zip" 17 | unzip kit_evaluator.zip 18 | 19 | echo -e "Cleaning kit_evaluator.zip" 20 | rm kit_evaluator.zip 21 | 22 | cd ../../ 23 | 24 | echo -e "Downloading done!"
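As a usage note for utils/fixseed.py above: the training and evaluation entry points later in this listing (e.g. train_vq.py) call fixseed(opt.seed) before building datasets and models; a minimal standalone call is sketched below, where 3407 is just the repository's default seed, not a requirement.

    # Hedged example of the helper defined in utils/fixseed.py.
    from utils.fixseed import fixseed

    fixseed(3407)  # seeds random, numpy and torch, and disables cudnn.benchmark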
25 | -------------------------------------------------------------------------------- /assets/text_prompt.txt: -------------------------------------------------------------------------------- 1 | the person holds his left foot with his left hand, puts his right foot up and left hand up too.#132 2 | a man bends down and picks something up with his left hand.#84 3 | A man stands for few seconds and picks up his arms and shakes them.#176 4 | A person walks with a limp, their left leg get injured.#192 5 | a person jumps up and then lands.#52 6 | a person performs a standing back kick.#52 7 | A person pokes their right hand along the ground, like they might be planting seeds.#60 8 | the person steps forward and uses the left leg to kick something forward.#92 9 | the man walked forward, spun right on one foot and walked back to his original position.#92 10 | the person was pushed but did not fall.#124 11 | this person stumbles left and right while moving forward.#132 12 | a person reaching down and picking something up.#148 -------------------------------------------------------------------------------- /prepare/download_models.sh: -------------------------------------------------------------------------------- 1 | rm -rf checkpoints 2 | mkdir checkpoints 3 | cd checkpoints 4 | mkdir t2m 5 | 6 | cd t2m 7 | echo -e "Downloading pretrained models for HumanML3D dataset" 8 | gdown --fuzzy https://drive.google.com/file/d/1vXS7SHJBgWPt59wupQ5UUzhFObrnGkQ0/view?usp=sharing 9 | 10 | echo -e "Unzipping humanml3d_models.zip" 11 | unzip humanml3d_models.zip 12 | 13 | echo -e "Cleaning humanml3d_models.zip" 14 | rm humanml3d_models.zip 15 | 16 | cd ../ 17 | mkdir kit 18 | cd kit 19 | 20 | echo -e "Downloading pretrained models for KIT-ML dataset" 21 | gdown --fuzzy https://drive.google.com/file/d/1FapdHNkxPouasVM8MWgg1f6sd_4Lua2q/view?usp=sharing 22 | 23 | echo -e "Unzipping kit_models.zip" 24 | unzip kit_models.zip 25 | 26 | echo -e "Cleaning kit_models.zip" 27 | rm kit_models.zip 28 | 29 | cd ../../ 30 | 31 | echo -e "Downloading done!" 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chuan Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
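A note on assets/text_prompt.txt above: each line pairs a caption with a target length after the final '#'; the lengths are multiples of the VQ downscale unit (4) and stay below the 196-frame cap used for HumanML3D. The parser below is only a hedged illustration of that format; gen_t2m.py, which actually reads the file (via the --text_path option defined later in options/eval_option.py), is not included in this section.

    # Illustrative only: parse "caption#length" lines like those in assets/text_prompt.txt.
    def read_prompt_file(path):
        prompts = []
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                caption, length = line.rsplit("#", 1)
                prompts.append((caption, int(length)))  # length treated as a frame count
        return prompts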
22 | -------------------------------------------------------------------------------- /motion_loaders/dataset_motion_loader.py: -------------------------------------------------------------------------------- 1 | from data.t2m_dataset import Text2MotionDatasetEval, collate_fn # TODO 2 | from utils.word_vectorizer import WordVectorizer 3 | import numpy as np 4 | from os.path import join as pjoin 5 | from torch.utils.data import DataLoader 6 | from utils.get_opt import get_opt 7 | 8 | def get_dataset_motion_loader(opt_path, batch_size, fname, device): 9 | opt = get_opt(opt_path, device) 10 | 11 | # Configurations of T2M dataset and KIT dataset is almost the same 12 | if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': 13 | print('Loading dataset %s ...' % opt.dataset_name) 14 | 15 | mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) 16 | std = np.load(pjoin(opt.meta_dir, 'std.npy')) 17 | 18 | w_vectorizer = WordVectorizer('./glove', 'our_vab') 19 | split_file = pjoin(opt.data_root, '%s.txt'%fname) 20 | dataset = Text2MotionDatasetEval(opt, mean, std, split_file, w_vectorizer) 21 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True, 22 | collate_fn=collate_fn, shuffle=True) 23 | else: 24 | raise KeyError('Dataset not Recognized !!') 25 | 26 | print('Ground Truth Dataset Loading Completed!!!') 27 | return dataloader, dataset -------------------------------------------------------------------------------- /utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /models/vq/encdec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.vq.resnet import Resnet1D 3 | 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, 7 | input_emb_width=3, 8 | output_emb_width=512, 9 | down_t=2, 10 | stride_t=2, 11 | width=512, 12 | depth=3, 13 | dilation_growth_rate=3, 14 | activation='relu', 15 | norm=None): 16 | super().__init__() 17 | 18 | blocks = [] 19 | filter_t, pad_t = stride_t * 
2, stride_t // 2 20 | blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) 21 | blocks.append(nn.ReLU()) 22 | 23 | for i in range(down_t): 24 | input_dim = width 25 | block = nn.Sequential( 26 | nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), 27 | Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), 28 | ) 29 | blocks.append(block) 30 | blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) 31 | self.model = nn.Sequential(*blocks) 32 | 33 | def forward(self, x): 34 | return self.model(x) 35 | 36 | 37 | class Decoder(nn.Module): 38 | def __init__(self, 39 | input_emb_width=3, 40 | output_emb_width=512, 41 | down_t=2, 42 | stride_t=2, 43 | width=512, 44 | depth=3, 45 | dilation_growth_rate=3, 46 | activation='relu', 47 | norm=None): 48 | super().__init__() 49 | blocks = [] 50 | 51 | blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) 52 | blocks.append(nn.ReLU()) 53 | for i in range(down_t): 54 | out_dim = width 55 | block = nn.Sequential( 56 | Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm), 57 | nn.Upsample(scale_factor=2, mode='nearest'), 58 | nn.Conv1d(width, out_dim, 3, 1, 1) 59 | ) 60 | blocks.append(block) 61 | blocks.append(nn.Conv1d(width, width, 3, 1, 1)) 62 | blocks.append(nn.ReLU()) 63 | blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) 64 | self.model = nn.Sequential(*blocks) 65 | 66 | def forward(self, x): 67 | x = self.model(x) 68 | return x.permute(0, 2, 1) -------------------------------------------------------------------------------- /models/vq/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class nonlinearity(nn.Module): 5 | def __init(self): 6 | super().__init__() 7 | 8 | def forward(self, x): 9 | return x * torch.sigmoid(x) 10 | 11 | 12 | class ResConv1DBlock(nn.Module): 13 | def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=0.2): 14 | super(ResConv1DBlock, self).__init__() 15 | 16 | padding = dilation 17 | self.norm = norm 18 | 19 | if norm == "LN": 20 | self.norm1 = nn.LayerNorm(n_in) 21 | self.norm2 = nn.LayerNorm(n_in) 22 | elif norm == "GN": 23 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 24 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 25 | elif norm == "BN": 26 | self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 27 | self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 28 | else: 29 | self.norm1 = nn.Identity() 30 | self.norm2 = nn.Identity() 31 | 32 | if activation == "relu": 33 | self.activation1 = nn.ReLU() 34 | self.activation2 = nn.ReLU() 35 | 36 | elif activation == "silu": 37 | self.activation1 = nonlinearity() 38 | self.activation2 = nonlinearity() 39 | 40 | elif activation == "gelu": 41 | self.activation1 = nn.GELU() 42 | self.activation2 = nn.GELU() 43 | 44 | self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) 45 | self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, ) 46 | self.dropout = nn.Dropout(dropout) 47 | 48 | def forward(self, x): 49 | x_orig = x 50 | if self.norm == "LN": 51 | x = self.norm1(x.transpose(-2, -1)) 52 | x = self.activation1(x.transpose(-2, -1)) 53 | else: 54 | x = self.norm1(x) 55 | x = self.activation1(x) 56 | 57 | x = self.conv1(x) 58 | 59 | if self.norm == "LN": 60 | x = self.norm2(x.transpose(-2, -1)) 61 | x = self.activation2(x.transpose(-2, -1)) 62 | else: 63 | x = 
self.norm2(x) 64 | x = self.activation2(x) 65 | 66 | x = self.conv2(x) 67 | x = self.dropout(x) 68 | x = x + x_orig 69 | return x 70 | 71 | 72 | class Resnet1D(nn.Module): 73 | def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): 74 | super().__init__() 75 | 76 | blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) 77 | for depth in range(n_depth)] 78 | if reverse_dilation: 79 | blocks = blocks[::-1] 80 | 81 | self.model = nn.Sequential(*blocks) 82 | 83 | def forward(self, x): 84 | return self.model(x) -------------------------------------------------------------------------------- /options/eval_option.py: -------------------------------------------------------------------------------- 1 | from options.base_option import BaseOptions 2 | 3 | class EvalT2MOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}') 7 | self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size') 8 | 9 | self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder') 10 | self.parser.add_argument("--num_batch", default=2, type=int, 11 | help="Number of batch for generation") 12 | self.parser.add_argument("--repeat_times", default=1, type=int, 13 | help="Number of repetitions, per sample text prompt") 14 | self.parser.add_argument("--cond_scale", default=4, type=float, 15 | help="For classifier-free sampling - specifies the s parameter, as defined in the paper.") 16 | self.parser.add_argument("--temperature", default=1., type=float, 17 | help="Sampling Temperature.") 18 | self.parser.add_argument("--topkr", default=0.9, type=float, 19 | help="Filter out percentil low prop entries.") 20 | self.parser.add_argument("--time_steps", default=18, type=int, 21 | help="Mask Generate steps.") 22 | self.parser.add_argument("--seed", default=10107, type=int) 23 | 24 | self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.') 25 | self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.') 26 | # self.parser.add_argument('--est_length', action="store_true", help='Training iterations') 27 | 28 | self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer') 29 | self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file') 30 | 31 | 32 | self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section' 33 | 'type int will specify the token frame, type float will specify the ratio of seq_len') 34 | self.parser.add_argument('--text_prompt', default='', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.") 35 | self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. 
(new_joint_vecs format .npy file)") 36 | self.parser.add_argument("--motion_length", default=0, type=int, 37 | help="Motion length for generation, only applicable with single text prompt.") 38 | self.is_train = False 39 | -------------------------------------------------------------------------------- /options/base_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | class BaseOptions(): 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | self.initialized = False 9 | 10 | def initialize(self): 11 | self.parser.add_argument('--name', type=str, default="t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns", help='Name of this trial') 12 | 13 | self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.') 14 | 15 | self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id') 16 | self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml') 17 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.') 18 | 19 | self.parser.add_argument('--latent_dim', type=int, default=384, help='Dimension of transformer latent.') 20 | self.parser.add_argument('--n_heads', type=int, default=6, help='Number of heads.') 21 | self.parser.add_argument('--n_layers', type=int, default=8, help='Number of attention layers.') 22 | self.parser.add_argument('--ff_size', type=int, default=1024, help='FF_Size') 23 | self.parser.add_argument('--dropout', type=float, default=0.2, help='Dropout ratio in transformer') 24 | 25 | self.parser.add_argument("--max_motion_length", type=int, default=196, help="Max length of motion") 26 | self.parser.add_argument("--unit_length", type=int, default=4, help="Downscale ratio of VQ") 27 | 28 | self.parser.add_argument('--force_mask', action="store_true", help='True: mask out conditions') 29 | 30 | self.initialized = True 31 | 32 | def parse(self): 33 | if not self.initialized: 34 | self.initialize() 35 | 36 | self.opt = self.parser.parse_args() 37 | 38 | self.opt.is_train = self.is_train 39 | 40 | if self.opt.gpu_id != -1: 41 | # self.opt.gpu_id = int(self.opt.gpu_id) 42 | torch.cuda.set_device(self.opt.gpu_id) 43 | 44 | args = vars(self.opt) 45 | 46 | print('------------ Options -------------') 47 | for k, v in sorted(args.items()): 48 | print('%s: %s' % (str(k), str(v))) 49 | print('-------------- End ----------------') 50 | if self.is_train: 51 | # save to the disk 52 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name) 53 | if not os.path.exists(expr_dir): 54 | os.makedirs(expr_dir) 55 | file_name = os.path.join(expr_dir, 'opt.txt') 56 | with open(file_name, 'wt') as opt_file: 57 | opt_file.write('------------ Options -------------\n') 58 | for k, v in sorted(args.items()): 59 | opt_file.write('%s: %s\n' % (str(k), str(v))) 60 | opt_file.write('-------------- End ----------------\n') 61 | return self.opt -------------------------------------------------------------------------------- /utils/get_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import re 4 | from os.path import join as pjoin 5 | from utils.word_vectorizer import POS_enumerator 6 | 7 | 8 | def is_float(numStr): 9 | flag = False 10 | numStr = 
str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 11 | try: 12 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 13 | res = reg.match(str(numStr)) 14 | if res: 15 | flag = True 16 | except Exception as ex: 17 | print("is_float() - error: " + str(ex)) 18 | return flag 19 | 20 | 21 | def is_number(numStr): 22 | flag = False 23 | numStr = str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 24 | if str(numStr).isdigit(): 25 | flag = True 26 | return flag 27 | 28 | 29 | def get_opt(opt_path, device, **kwargs): 30 | opt = Namespace() 31 | opt_dict = vars(opt) 32 | 33 | skip = ('-------------- End ----------------', 34 | '------------ Options -------------', 35 | '\n') 36 | print('Reading', opt_path) 37 | with open(opt_path, 'r') as f: 38 | for line in f: 39 | if line.strip() not in skip: 40 | # print(line.strip()) 41 | key, value = line.strip('\n').split(': ') 42 | if value in ('True', 'False'): 43 | opt_dict[key] = (value == 'True') 44 | # print(key, value) 45 | elif is_float(value): 46 | opt_dict[key] = float(value) 47 | elif is_number(value): 48 | opt_dict[key] = int(value) 49 | else: 50 | opt_dict[key] = str(value) 51 | 52 | # print(opt) 53 | opt_dict['which_epoch'] = 'finest' 54 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 55 | opt.model_dir = pjoin(opt.save_root, 'model') 56 | opt.meta_dir = pjoin(opt.save_root, 'meta') 57 | 58 | if opt.dataset_name == 't2m': 59 | opt.data_root = './dataset/HumanML3D/' 60 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 61 | opt.text_dir = pjoin(opt.data_root, 'texts') 62 | opt.joints_num = 22 63 | opt.dim_pose = 263 64 | opt.max_motion_length = 196 65 | opt.max_motion_frame = 196 66 | opt.max_motion_token = 55 67 | elif opt.dataset_name == 'kit': 68 | opt.data_root = './dataset/KIT-ML/' 69 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 70 | opt.text_dir = pjoin(opt.data_root, 'texts') 71 | opt.joints_num = 21 72 | opt.dim_pose = 251 73 | opt.max_motion_length = 196 74 | opt.max_motion_frame = 196 75 | opt.max_motion_token = 55 76 | else: 77 | raise KeyError('Dataset not recognized') 78 | if not hasattr(opt, 'unit_length'): 79 | opt.unit_length = 4 80 | opt.dim_word = 300 81 | opt.num_classes = 200 // opt.unit_length 82 | opt.dim_pos_ohot = len(POS_enumerator) 83 | opt.is_train = False 84 | opt.is_continue = False 85 | opt.device = device 86 | 87 | opt_dict.update(kwargs) # Overwrite with kwargs params 88 | 89 | return opt -------------------------------------------------------------------------------- /utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 
'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[self.word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec 81 | 82 | 83 | class WordVectorizerV2(WordVectorizer): 84 | def __init__(self, meta_root, prefix): 85 | super(WordVectorizerV2, self).__init__(meta_root, prefix) 86 | self.idx2word = {self.word2idx[w]: w for w in self.word2idx} 87 | 88 | def __getitem__(self, item): 89 | word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item) 90 | word, pos = item.split('/') 91 | if word in self.word2vec: 92 | return word_vec, pose_vec, self.word2idx[word] 93 | else: 94 | return word_vec, pose_vec, self.word2idx['unk'] 95 | 96 | def itos(self, idx): 97 | if idx == len(self.idx2word): 98 | return "pad" 99 | return self.idx2word[idx] -------------------------------------------------------------------------------- /options/train_option.py: -------------------------------------------------------------------------------- 1 | from options.base_option import BaseOptions 2 | import argparse 3 | 4 | class TrainT2MOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size') 8 | self.parser.add_argument('--max_epoch', type=int, default=500, help='Maximum number of epoch for training') 9 | # self.parser.add_argument('--max_iters', type=int, default=150_000, help='Training iterations') 10 | 11 | '''LR scheduler''' 12 | self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate') 13 | self.parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate schedule factor') 14 | self.parser.add_argument('--milestones', default=[50_000], nargs="+", type=int, 15 | help="learning rate schedule (iterations)") 16 | self.parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup') 17 | 18 | '''Condition''' 19 | self.parser.add_argument('--cond_drop_prob', type=float, default=0.1, 
help='Drop ratio of condition, for classifier-free guidance') 20 | self.parser.add_argument("--seed", default=3407, type=int, help="Seed") 21 | 22 | self.parser.add_argument('--is_continue', action="store_true", help='Is this trial continuing previous state?') 23 | self.parser.add_argument('--gumbel_sample', action="store_true", help='Strategy for token sampling, True: Gumbel sampling, False: Categorical sampling') 24 | self.parser.add_argument('--share_weight', action="store_true", help='Whether to share weight for projection/embedding, for residual transformer.') 25 | 26 | self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress, (iteration)') 27 | # self.parser.add_argument('--save_every_e', type=int, default=100, help='Frequency of printing training progress') 28 | self.parser.add_argument('--eval_every_e', type=int, default=10, help='Frequency of animating eval results, (epoch)') 29 | self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving checkpoint, (iteration)') 30 | 31 | 32 | self.is_train = True 33 | 34 | 35 | class TrainLenEstOptions(): 36 | def __init__(self): 37 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 38 | self.parser.add_argument('--name', type=str, default="test", help='Name of this trial') 39 | self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id') 40 | 41 | self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name') 42 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 43 | 44 | self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size') 45 | 46 | self.parser.add_argument("--unit_length", type=int, default=4, help="Length of motion") 47 | self.parser.add_argument("--max_text_len", type=int, default=20, help="Length of motion") 48 | 49 | self.parser.add_argument('--max_epoch', type=int, default=300, help='Training iterations') 50 | 51 | self.parser.add_argument('--lr', type=float, default=1e-4, help='Layers of GRU') 52 | 53 | self.parser.add_argument('--is_continue', action="store_true", help='Training iterations') 54 | 55 | self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress') 56 | self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of printing training progress') 57 | self.parser.add_argument('--eval_every_e', type=int, default=3, help='Frequency of printing training progress') 58 | self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of printing training progress') 59 | 60 | def parse(self): 61 | self.opt = self.parser.parse_args() 62 | self.opt.is_train = True 63 | # args = vars(self.opt) 64 | return self.opt 65 | -------------------------------------------------------------------------------- /models/vq/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch.nn as nn 4 | from models.vq.encdec import Encoder, Decoder 5 | from models.vq.residual_vq import ResidualVQ 6 | 7 | class RVQVAE(nn.Module): 8 | def __init__(self, 9 | args, 10 | input_width=263, 11 | nb_code=1024, 12 | code_dim=512, 13 | output_emb_width=512, 14 | down_t=3, 15 | stride_t=2, 16 | width=512, 17 | depth=3, 18 | dilation_growth_rate=3, 19 | activation='relu', 20 | norm=None): 21 | 22 | super().__init__() 23 | assert output_emb_width == code_dim 24 | 
self.code_dim = code_dim 25 | self.num_code = nb_code 26 | # self.quant = args.quantizer 27 | self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth, 28 | dilation_growth_rate, activation=activation, norm=norm) 29 | self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth, 30 | dilation_growth_rate, activation=activation, norm=norm) 31 | rvqvae_config = { 32 | 'num_quantizers': args.num_quantizers, 33 | 'shared_codebook': args.shared_codebook, 34 | 'quantize_dropout_prob': args.quantize_dropout_prob, 35 | 'quantize_dropout_cutoff_index': 0, 36 | 'nb_code': nb_code, 37 | 'code_dim':code_dim, 38 | 'args': args, 39 | } 40 | self.quantizer = ResidualVQ(**rvqvae_config) 41 | 42 | def preprocess(self, x): 43 | # (bs, T, Jx3) -> (bs, Jx3, T) 44 | x = x.permute(0, 2, 1).float() 45 | return x 46 | 47 | def postprocess(self, x): 48 | # (bs, Jx3, T) -> (bs, T, Jx3) 49 | x = x.permute(0, 2, 1) 50 | return x 51 | 52 | def encode(self, x): 53 | N, T, _ = x.shape 54 | x_in = self.preprocess(x) 55 | x_encoder = self.encoder(x_in) 56 | # print(x_encoder.shape) 57 | code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True) 58 | # print(code_idx.shape) 59 | # code_idx = code_idx.view(N, -1) 60 | # (N, T, Q) 61 | # print() 62 | return code_idx, all_codes 63 | 64 | def forward(self, x): 65 | x_in = self.preprocess(x) 66 | # Encode 67 | x_encoder = self.encoder(x_in) 68 | 69 | ## quantization 70 | # x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5, 71 | # force_dropout_index=0) #TODO hardcode 72 | x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5) 73 | 74 | # print(code_idx[0, :, 1]) 75 | ## decoder 76 | x_out = self.decoder(x_quantized) 77 | # x_out = self.postprocess(x_decoder) 78 | return x_out, commit_loss, perplexity 79 | 80 | def forward_decoder(self, x): 81 | x_d = self.quantizer.get_codes_from_indices(x) 82 | # x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous() 83 | x = x_d.sum(dim=0).permute(0, 2, 1) 84 | 85 | # decoder 86 | x_out = self.decoder(x) 87 | # x_out = self.postprocess(x_decoder) 88 | return x_out 89 | 90 | class LengthEstimator(nn.Module): 91 | def __init__(self, input_size, output_size): 92 | super(LengthEstimator, self).__init__() 93 | nd = 512 94 | self.output = nn.Sequential( 95 | nn.Linear(input_size, nd), 96 | nn.LayerNorm(nd), 97 | nn.LeakyReLU(0.2, inplace=True), 98 | 99 | nn.Dropout(0.2), 100 | nn.Linear(nd, nd // 2), 101 | nn.LayerNorm(nd // 2), 102 | nn.LeakyReLU(0.2, inplace=True), 103 | 104 | nn.Dropout(0.2), 105 | nn.Linear(nd // 2, nd // 4), 106 | nn.LayerNorm(nd // 4), 107 | nn.LeakyReLU(0.2, inplace=True), 108 | 109 | nn.Linear(nd // 4, output_size) 110 | ) 111 | 112 | self.output.apply(self.__init_weights) 113 | 114 | def __init_weights(self, module): 115 | if isinstance(module, (nn.Linear, nn.Embedding)): 116 | module.weight.data.normal_(mean=0.0, std=0.02) 117 | if isinstance(module, nn.Linear) and module.bias is not None: 118 | module.bias.data.zero_() 119 | elif isinstance(module, nn.LayerNorm): 120 | module.bias.data.zero_() 121 | module.weight.data.fill_(1.0) 122 | 123 | def forward(self, text_emb): 124 | return self.output(text_emb) -------------------------------------------------------------------------------- /train_vq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | 4 | import 
torch 5 | from torch.utils.data import DataLoader 6 | 7 | from models.vq.model import RVQVAE 8 | from models.vq.vq_trainer import RVQTokenizerTrainer 9 | from options.vq_option import arg_parse 10 | from data.t2m_dataset import MotionDataset 11 | from utils import paramUtil 12 | import numpy as np 13 | 14 | from models.t2m_eval_wrapper import EvaluatorModelWrapper 15 | from utils.get_opt import get_opt 16 | from motion_loaders.dataset_motion_loader import get_dataset_motion_loader 17 | 18 | from utils.motion_process import recover_from_ric 19 | from utils.plot_script import plot_3d_motion 20 | from utils.fixseed import fixseed 21 | 22 | os.environ["OMP_NUM_THREADS"] = "1" 23 | 24 | def plot_t2m(data, save_dir): 25 | data = train_dataset.inv_transform(data) 26 | for i in range(len(data)): 27 | joint_data = data[i] 28 | joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy() 29 | save_path = pjoin(save_dir, '%02d.mp4' % (i)) 30 | plot_3d_motion(save_path, kinematic_chain, joint, title="None", fps=fps, radius=radius) 31 | 32 | 33 | if __name__ == "__main__": 34 | # torch.autograd.set_detect_anomaly(True) 35 | opt = arg_parse(True) 36 | fixseed(opt.seed) 37 | 38 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 39 | print(f"Using Device: {opt.device}") 40 | 41 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 42 | opt.model_dir = pjoin(opt.save_root, 'model') 43 | opt.meta_dir = pjoin(opt.save_root, 'meta') 44 | opt.eval_dir = pjoin(opt.save_root, 'animation') 45 | opt.log_dir = pjoin('./log/vq/', opt.dataset_name, opt.name) 46 | 47 | os.makedirs(opt.model_dir, exist_ok=True) 48 | os.makedirs(opt.meta_dir, exist_ok=True) 49 | os.makedirs(opt.eval_dir, exist_ok=True) 50 | os.makedirs(opt.log_dir, exist_ok=True) 51 | 52 | if opt.dataset_name == "t2m": 53 | opt.data_root = './dataset/HumanML3D/' 54 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 55 | opt.text_dir = pjoin(opt.data_root, 'texts') 56 | opt.joints_num = 22 57 | dim_pose = 263 58 | fps = 20 59 | radius = 4 60 | kinematic_chain = paramUtil.t2m_kinematic_chain 61 | dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt' 62 | 63 | elif opt.dataset_name == "kit": 64 | opt.data_root = './dataset/KIT-ML/' 65 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 66 | opt.text_dir = pjoin(opt.data_root, 'texts') 67 | opt.joints_num = 21 68 | radius = 240 * 8 69 | fps = 12.5 70 | dim_pose = 251 71 | opt.max_motion_length = 196 72 | kinematic_chain = paramUtil.kit_kinematic_chain 73 | dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt' 74 | else: 75 | raise KeyError('Dataset Does not Exists') 76 | 77 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 78 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 79 | 80 | mean = np.load(pjoin(opt.data_root, 'Mean.npy')) 81 | std = np.load(pjoin(opt.data_root, 'Std.npy')) 82 | 83 | train_split_file = pjoin(opt.data_root, 'train.txt') 84 | val_split_file = pjoin(opt.data_root, 'val.txt') 85 | 86 | 87 | net = RVQVAE(opt, 88 | dim_pose, 89 | opt.nb_code, 90 | opt.code_dim, 91 | opt.code_dim, 92 | opt.down_t, 93 | opt.stride_t, 94 | opt.width, 95 | opt.depth, 96 | opt.dilation_growth_rate, 97 | opt.vq_act, 98 | opt.vq_norm) 99 | 100 | pc_vq = sum(param.numel() for param in net.parameters()) 101 | print(net) 102 | # print("Total parameters of discriminator net: {}".format(pc_vq)) 103 | # all_params += pc_vq_dis 104 | 105 | print('Total parameters of all models: 
{}M'.format(pc_vq/1000_000)) 106 | 107 | trainer = RVQTokenizerTrainer(opt, vq_model=net) 108 | 109 | train_dataset = MotionDataset(opt, mean, std, train_split_file) 110 | val_dataset = MotionDataset(opt, mean, std, val_split_file) 111 | 112 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, drop_last=True, num_workers=4, 113 | shuffle=True, pin_memory=True) 114 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, drop_last=True, num_workers=4, 115 | shuffle=True, pin_memory=True) 116 | eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device) 117 | trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper, plot_t2m) 118 | 119 | ## train_vq.py --dataset_name kit --batch_size 512 --name VQVAE_dp2 --gpu_id 3 120 | ## train_vq.py --dataset_name kit --batch_size 256 --name VQVAE_dp2_b256 --gpu_id 2 121 | ## train_vq.py --dataset_name kit --batch_size 1024 --name VQVAE_dp2_b1024 --gpu_id 1 122 | ## python train_vq.py --dataset_name kit --batch_size 256 --name VQVAE_dp1_b256 --gpu_id 2 -------------------------------------------------------------------------------- /eval_t2m_vq.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os.path import join as pjoin 4 | 5 | import torch 6 | from models.vq.model import RVQVAE 7 | from options.vq_option import arg_parse 8 | from motion_loaders.dataset_motion_loader import get_dataset_motion_loader 9 | import utils.eval_t2m as eval_t2m 10 | from utils.get_opt import get_opt 11 | from models.t2m_eval_wrapper import EvaluatorModelWrapper 12 | import warnings 13 | warnings.filterwarnings('ignore') 14 | import numpy as np 15 | from utils.word_vectorizer import WordVectorizer 16 | 17 | def load_vq_model(vq_opt, which_epoch): 18 | # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt') 19 | 20 | vq_model = RVQVAE(vq_opt, 21 | dim_pose, 22 | vq_opt.nb_code, 23 | vq_opt.code_dim, 24 | vq_opt.code_dim, 25 | vq_opt.down_t, 26 | vq_opt.stride_t, 27 | vq_opt.width, 28 | vq_opt.depth, 29 | vq_opt.dilation_growth_rate, 30 | vq_opt.vq_act, 31 | vq_opt.vq_norm) 32 | ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', which_epoch), 33 | map_location='cpu') 34 | model_key = 'vq_model' if 'vq_model' in ckpt else 'net' 35 | vq_model.load_state_dict(ckpt[model_key]) 36 | vq_epoch = ckpt['ep'] if 'ep' in ckpt else -1 37 | print(f'Loading VQ Model {vq_opt.name} Completed!, Epoch {vq_epoch}') 38 | return vq_model, vq_epoch 39 | 40 | if __name__ == "__main__": 41 | ##### ---- Exp dirs ---- ##### 42 | args = arg_parse(False) 43 | args.device = torch.device("cpu" if args.gpu_id == -1 else "cuda:" + str(args.gpu_id)) 44 | 45 | args.out_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'eval') 46 | os.makedirs(args.out_dir, exist_ok=True) 47 | 48 | f = open(pjoin(args.out_dir, '%s.log'%args.ext), 'w') 49 | 50 | dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataset_name == 'kit' \ 51 | else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' 52 | 53 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 54 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 55 | 56 | ##### ---- Dataloader ---- ##### 57 | args.nb_joints = 21 if args.dataset_name == 'kit' else 22 58 | dim_pose = 251 if args.dataset_name == 'kit' else 263 59 | 60 | eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=args.device) 61 | 62 | 
print(len(eval_val_loader)) 63 | 64 | ##### ---- Network ---- ##### 65 | vq_opt_path = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'opt.txt') 66 | vq_opt = get_opt(vq_opt_path, device=args.device) 67 | # net = load_vq_model() 68 | 69 | model_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'model') 70 | for file in os.listdir(model_dir): 71 | # if not file.endswith('tar'): 72 | # continue 73 | # if not file.startswith('net_best_fid'): 74 | # continue 75 | if args.which_epoch != "all" and args.which_epoch not in file: 76 | continue 77 | print(file) 78 | net, ep = load_vq_model(vq_opt, file) 79 | 80 | net.eval() 81 | net.cuda() 82 | 83 | fid = [] 84 | div = [] 85 | top1 = [] 86 | top2 = [] 87 | top3 = [] 88 | matching = [] 89 | mae = [] 90 | repeat_time = 20 91 | for i in range(repeat_time): 92 | best_fid, best_div, Rprecision, best_matching, l1_dist = \ 93 | eval_t2m.evaluation_vqvae_plus_mpjpe(eval_val_loader, net, i, eval_wrapper=eval_wrapper, num_joint=args.nb_joints) 94 | fid.append(best_fid) 95 | div.append(best_div) 96 | top1.append(Rprecision[0]) 97 | top2.append(Rprecision[1]) 98 | top3.append(Rprecision[2]) 99 | matching.append(best_matching) 100 | mae.append(l1_dist) 101 | 102 | fid = np.array(fid) 103 | div = np.array(div) 104 | top1 = np.array(top1) 105 | top2 = np.array(top2) 106 | top3 = np.array(top3) 107 | matching = np.array(matching) 108 | mae = np.array(mae) 109 | 110 | print(f'{file} final result, epoch {ep}') 111 | print(f'{file} final result, epoch {ep}', file=f, flush=True) 112 | 113 | msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}\n" \ 114 | f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}\n" \ 115 | f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}\n" \ 116 | f"\tMatching: {np.mean(matching):.3f}, conf. 
{np.std(matching)*1.96/np.sqrt(repeat_time):.3f}\n" \ 117 | f"\tMAE:{np.mean(mae):.3f}, conf.{np.std(mae)*1.96/np.sqrt(repeat_time):.3f}\n\n" 118 | # logger.info(msg_final) 119 | print(msg_final) 120 | print(msg_final, file=f, flush=True) 121 | 122 | f.close() 123 | 124 | -------------------------------------------------------------------------------- /options/vq_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | def arg_parse(is_train=False): 6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | 8 | ## dataloader 9 | parser.add_argument('--dataset_name', type=str, default='humanml3d', help='dataset directory') 10 | parser.add_argument('--batch_size', default=256, type=int, help='batch size') 11 | parser.add_argument('--window_size', type=int, default=64, help='training motion length') 12 | parser.add_argument("--gpu_id", type=int, default=0, help='GPU id') 13 | 14 | ## optimization 15 | parser.add_argument('--max_epoch', default=50, type=int, help='number of total epochs to run') 16 | # parser.add_argument('--total_iter', default=None, type=int, help='number of total iterations to run') 17 | parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup') 18 | parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate') 19 | parser.add_argument('--milestones', default=[150000, 250000], nargs="+", type=int, help="learning rate schedule (iterations)") 20 | parser.add_argument('--gamma', default=0.1, type=float, help="learning rate decay") 21 | 22 | parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay') 23 | parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss") 24 | parser.add_argument('--loss_vel', type=float, default=0.5, help='hyper-parameter for the velocity loss') 25 | parser.add_argument('--recons_loss', type=str, default='l1_smooth', help='reconstruction loss') 26 | 27 | ## vqvae arch 28 | parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension") 29 | parser.add_argument("--nb_code", type=int, default=512, help="nb of embedding") 30 | parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook") 31 | parser.add_argument("--down_t", type=int, default=2, help="downsampling rate") 32 | parser.add_argument("--stride_t", type=int, default=2, help="stride size") 33 | parser.add_argument("--width", type=int, default=512, help="width of the network") 34 | parser.add_argument("--depth", type=int, default=3, help="num of resblocks for each res") 35 | parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate") 36 | parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width") 37 | parser.add_argument('--vq_act', type=str, default='relu', choices=['relu', 'silu', 'gelu'], 38 | help='dataset directory') 39 | parser.add_argument('--vq_norm', type=str, default=None, help='dataset directory') 40 | 41 | parser.add_argument('--num_quantizers', type=int, default=3, help='num_quantizers') 42 | parser.add_argument('--shared_codebook', action="store_true") 43 | parser.add_argument('--quantize_dropout_prob', type=float, default=0.2, help='quantize_dropout_prob') 44 | # parser.add_argument('--use_vq_prob', type=float, default=0.8, help='quantize_dropout_prob') 45 | 46 | 
parser.add_argument('--ext', type=str, default='default', help='reconstruction loss') 47 | 48 | 49 | ## other 50 | parser.add_argument('--name', type=str, default="test", help='Name of this trial') 51 | parser.add_argument('--is_continue', action="store_true", help='Name of this trial') 52 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 53 | parser.add_argument('--log_every', default=10, type=int, help='iter log frequency') 54 | parser.add_argument('--save_latest', default=500, type=int, help='iter save latest model frequency') 55 | parser.add_argument('--save_every_e', default=2, type=int, help='save model every n epoch') 56 | parser.add_argument('--eval_every_e', default=1, type=int, help='save eval results every n epoch') 57 | # parser.add_argument('--early_stop_e', default=5, type=int, help='early stopping epoch') 58 | parser.add_argument('--feat_bias', type=float, default=5, help='Layers of GRU') 59 | 60 | parser.add_argument('--which_epoch', type=str, default="all", help='Name of this trial') 61 | 62 | ## For Res Predictor only 63 | parser.add_argument('--vq_name', type=str, default="rvq_nq6_dc512_nc512_noshare_qdp0.2", help='Name of this trial') 64 | # parser.add_argument('--n_res', type=int, default=2, help='Name of this trial') 65 | # parser.add_argument('--do_vq_res', action="store_true") 66 | parser.add_argument("--seed", default=3407, type=int) 67 | 68 | opt = parser.parse_args() 69 | torch.cuda.set_device(opt.gpu_id) 70 | 71 | args = vars(opt) 72 | 73 | print('------------ Options -------------') 74 | for k, v in sorted(args.items()): 75 | print('%s: %s' % (str(k), str(v))) 76 | print('-------------- End ----------------') 77 | opt.is_train = is_train 78 | if is_train: 79 | # save to the disk 80 | expr_dir = os.path.join(opt.checkpoints_dir, opt.dataset_name, opt.name) 81 | if not os.path.exists(expr_dir): 82 | os.makedirs(expr_dir) 83 | file_name = os.path.join(expr_dir, 'opt.txt') 84 | with open(file_name, 'wt') as opt_file: 85 | opt_file.write('------------ Options -------------\n') 86 | for k, v in sorted(args.items()): 87 | opt_file.write('%s: %s\n' % (str(k), str(v))) 88 | opt_file.write('-------------- End ----------------\n') 89 | return opt -------------------------------------------------------------------------------- /models/mask_transformer/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from einops import rearrange 5 | 6 | # return mask where padding is FALSE 7 | def lengths_to_mask(lengths, max_len): 8 | # max_len = max(lengths) 9 | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) 10 | return mask #(b, len) 11 | 12 | # return mask where padding is ALL FALSE 13 | def get_pad_mask_idx(seq, pad_idx): 14 | return (seq != pad_idx).unsqueeze(1) 15 | 16 | # Given seq: (b, s) 17 | # Return mat: (1, s, s) 18 | # Example Output: 19 | # [[[ True, False, False], 20 | # [ True, True, False], 21 | # [ True, True, True]]] 22 | # For causal attention 23 | def get_subsequent_mask(seq): 24 | sz_b, seq_len = seq.shape 25 | subsequent_mask = (1 - torch.triu( 26 | torch.ones((1, seq_len, seq_len)), diagonal=1)).bool() 27 | return subsequent_mask.to(seq.device) 28 | 29 | 30 | def exists(val): 31 | return val is not None 32 | 33 | def default(val, d): 34 | return val if exists(val) else d 35 | 36 | def eval_decorator(fn): 37 | def inner(model, *args, 
**kwargs): 38 | was_training = model.training 39 | model.eval() 40 | out = fn(model, *args, **kwargs) 41 | model.train(was_training) 42 | return out 43 | return inner 44 | 45 | def l2norm(t): 46 | return F.normalize(t, dim = -1) 47 | 48 | # tensor helpers 49 | 50 | # Get a random subset of TRUE mask, with prob 51 | def get_mask_subset_prob(mask, prob): 52 | subset_mask = torch.bernoulli(mask, p=prob) & mask 53 | return subset_mask 54 | 55 | 56 | # Get mask of special_tokens in ids 57 | def get_mask_special_tokens(ids, special_ids): 58 | mask = torch.zeros_like(ids).bool() 59 | for special_id in special_ids: 60 | mask |= (ids==special_id) 61 | return mask 62 | 63 | # network builder helpers 64 | def _get_activation_fn(activation): 65 | if activation == "relu": 66 | return F.relu 67 | elif activation == "gelu": 68 | return F.gelu 69 | 70 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 71 | 72 | # classifier free guidance functions 73 | 74 | def uniform(shape, device=None): 75 | return torch.zeros(shape, device=device).float().uniform_(0, 1) 76 | 77 | def prob_mask_like(shape, prob, device=None): 78 | if prob == 1: 79 | return torch.ones(shape, device=device, dtype=torch.bool) 80 | elif prob == 0: 81 | return torch.zeros(shape, device=device, dtype=torch.bool) 82 | else: 83 | return uniform(shape, device=device) < prob 84 | 85 | # sampling helpers 86 | 87 | def log(t, eps = 1e-20): 88 | return torch.log(t.clamp(min = eps)) 89 | 90 | def gumbel_noise(t): 91 | noise = torch.zeros_like(t).uniform_(0, 1) 92 | return -log(-log(noise)) 93 | 94 | def gumbel_sample(t, temperature = 1., dim = 1): 95 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) 96 | 97 | 98 | # Example input: 99 | # [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000], 100 | # [ 0.4141, 0.1781, 0.6628, 0.5721, -1.0000, -1.0000], 101 | # [ 0.9428, 0.3586, 0.1659, 0.8172, 0.9273, -1.0000]] 102 | # Example output: 103 | # [[ -inf, -inf, 0.9771, -inf, -inf, -inf], 104 | # [ -inf, -inf, 0.6628, -inf, -inf, -inf], 105 | # [0.9428, -inf, -inf, -inf, -inf, -inf]] 106 | def top_k(logits, thres = 0.9, dim = 1): 107 | k = math.ceil((1 - thres) * logits.shape[dim]) 108 | val, ind = logits.topk(k, dim = dim) 109 | probs = torch.full_like(logits, float('-inf')) 110 | probs.scatter_(dim, ind, val) 111 | # func verified 112 | # print(probs) 113 | # print(logits) 114 | # raise 115 | return probs 116 | 117 | # noise schedules 118 | 119 | # More on large value, less on small 120 | def cosine_schedule(t): 121 | return torch.cos(t * math.pi * 0.5) 122 | 123 | def scale_cosine_schedule(t, scale): 124 | return torch.clip(scale*torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.) 
125 | 126 | # More on small value, less on large 127 | def q_schedule(bs, low, high, device): 128 | noise = uniform((bs,), device=device) 129 | schedule = 1 - cosine_schedule(noise) 130 | return torch.round(schedule * (high - low - 1)).long() + low 131 | 132 | def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1): 133 | loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing) 134 | # pred_id = torch.argmax(pred, dim=1) 135 | # mask = labels.ne(ignore_index) 136 | # n_correct = pred_id.eq(labels).masked_select(mask) 137 | # acc = torch.mean(n_correct.float()).item() 138 | pred_id_k = torch.topk(pred, k=tk, dim=1).indices 139 | pred_id = pred_id_k[:, 0] 140 | mask = labels.ne(ignore_index) 141 | n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask) 142 | acc = torch.mean(n_correct.float()).item() 143 | 144 | return loss, pred_id, acc 145 | 146 | 147 | def cal_loss(pred, labels, ignore_index=None, smoothing=0.): 148 | '''Calculate cross entropy loss, apply label smoothing if needed.''' 149 | # print(pred.shape, labels.shape) #torch.Size([64, 1028, 55]) torch.Size([64, 55]) 150 | # print(pred.shape, labels.shape) #torch.Size([64, 1027, 55]) torch.Size([64, 55]) 151 | if smoothing: 152 | space = 2 153 | n_class = pred.size(1) 154 | mask = labels.ne(ignore_index) 155 | one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class] 156 | # one_hot = torch.zeros_like(pred).scatter(1, labels.unsqueeze(1), 1) 157 | sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1) 158 | neg_log_prb = -F.log_softmax(pred, dim=1) 159 | loss = (sm_one_hot * neg_log_prb).sum(dim=1) 160 | # loss = F.cross_entropy(pred, sm_one_hot, reduction='none') 161 | loss = torch.mean(loss.masked_select(mask)) 162 | else: 163 | loss = F.cross_entropy(pred, labels, ignore_index=ignore_index) 164 | 165 | return loss -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # import cv2 4 | from PIL import Image 5 | from utils import paramUtil 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | # from scipy.ndimage import gaussian_filter 10 | 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 17 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 18 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 19 | 20 | MISSING_VALUE = -1 21 | 22 | def save_image(image_numpy, image_path): 23 | img_pil = Image.fromarray(image_numpy) 24 | img_pil.save(image_path) 25 | 26 | 27 | def save_logfile(log_loss, save_path): 28 | with open(save_path, 'wt') as f: 29 | for k, v in log_loss.items(): 30 | w_line = k 31 | for digit in v: 32 | w_line += ' %.3f' % digit 33 | f.write(w_line + '\n') 34 | 35 | 36 | def print_current_loss(start_time, niter_state, total_niters, losses, epoch=None, sub_epoch=None, 37 | inner_iter=None, tf_ratio=None, sl_steps=None): 38 | 39 | def as_minutes(s): 40 | m = math.floor(s / 60) 41 | s -= m * 60 42 | return '%dm %ds' % (m, s) 43 | 44 | def time_since(since, percent): 45 | now = time.time() 46 | s = now - since 47 | es = s / percent 48 | rs = es - s 49 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 50 | 51 | if epoch is not 
None: 52 | print('ep/it:%2d-%4d niter:%6d' % (epoch, inner_iter, niter_state), end=" ") 53 | 54 | message = ' %s completed:%3d%%)' % (time_since(start_time, niter_state / total_niters), niter_state / total_niters * 100) 55 | # now = time.time() 56 | # message += '%s'%(as_minutes(now - start_time)) 57 | 58 | 59 | for k, v in losses.items(): 60 | message += ' %s: %.4f ' % (k, v) 61 | # message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio) 62 | print(message) 63 | 64 | def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None): 65 | 66 | def as_minutes(s): 67 | m = math.floor(s / 60) 68 | s -= m * 60 69 | return '%dm %ds' % (m, s) 70 | 71 | def time_since(since, percent): 72 | now = time.time() 73 | s = now - since 74 | es = s / percent 75 | rs = es - s 76 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 77 | 78 | print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ") 79 | # now = time.time() 80 | message = '%s niter: %07d completed: %3d%%)'%(time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100) 81 | for k, v in losses.items(): 82 | message += ' %s: %.4f ' % (k, v) 83 | print(message) 84 | 85 | 86 | def compose_gif_img_list(img_list, fp_out, duration): 87 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] 88 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, 89 | save_all=True, loop=0, duration=duration) 90 | 91 | 92 | def save_images(visuals, image_path): 93 | if not os.path.exists(image_path): 94 | os.makedirs(image_path) 95 | 96 | for i, (label, img_numpy) in enumerate(visuals.items()): 97 | img_name = '%d_%s.jpg' % (i, label) 98 | save_path = os.path.join(image_path, img_name) 99 | save_image(img_numpy, save_path) 100 | 101 | 102 | def save_images_test(visuals, image_path, from_name, to_name): 103 | if not os.path.exists(image_path): 104 | os.makedirs(image_path) 105 | 106 | for i, (label, img_numpy) in enumerate(visuals.items()): 107 | img_name = "%s_%s_%s" % (from_name, to_name, label) 108 | save_path = os.path.join(image_path, img_name) 109 | save_image(img_numpy, save_path) 110 | 111 | 112 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): 113 | # print(col, row) 114 | compose_img = compose_image(img_list, col, row, img_size) 115 | if not os.path.exists(save_dir): 116 | os.makedirs(save_dir) 117 | img_path = os.path.join(save_dir, img_name) 118 | # print(img_path) 119 | compose_img.save(img_path) 120 | 121 | 122 | def compose_image(img_list, col, row, img_size): 123 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) 124 | for y in range(0, row): 125 | for x in range(0, col): 126 | from_img = Image.fromarray(img_list[y * col + x]) 127 | # print((x * img_size[0], y*img_size[1], 128 | # (x + 1) * img_size[0], (y + 1) * img_size[1])) 129 | paste_area = (x * img_size[0], y*img_size[1], 130 | (x + 1) * img_size[0], (y + 1) * img_size[1]) 131 | to_image.paste(from_img, paste_area) 132 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img 133 | return to_image 134 | 135 | 136 | def plot_loss_curve(losses, save_path, intervals=500): 137 | plt.figure(figsize=(10, 5)) 138 | plt.title("Loss During Training") 139 | for key in losses.keys(): 140 | plt.plot(list_cut_average(losses[key], intervals), label=key) 141 | plt.xlabel("Iterations/" + str(intervals)) 142 | plt.ylabel("Loss") 143 | plt.legend() 144 | plt.savefig(save_path) 145 | 
plt.show() 146 | 147 | 148 | def list_cut_average(ll, intervals): 149 | if intervals == 1: 150 | return ll 151 | 152 | bins = math.ceil(len(ll) * 1.0 / intervals) 153 | ll_new = [] 154 | for i in range(bins): 155 | l_low = intervals * i 156 | l_high = l_low + intervals 157 | l_high = l_high if l_high < len(ll) else len(ll) 158 | ll_new.append(np.mean(ll[l_low:l_high])) 159 | return ll_new 160 | 161 | 162 | # def motion_temporal_filter(motion, sigma=1): 163 | # motion = motion.reshape(motion.shape[0], -1) 164 | # # print(motion.shape) 165 | # for i in range(motion.shape[1]): 166 | # motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 167 | # return motion.reshape(motion.shape[0], -1, 3) 168 | 169 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | import torch 4 | 5 | 6 | def calculate_mpjpe(gt_joints, pred_joints): 7 | """ 8 | gt_joints: num_poses x num_joints(22) x 3 9 | pred_joints: num_poses x num_joints(22) x 3 10 | (obtained from recover_from_ric()) 11 | """ 12 | assert gt_joints.shape == pred_joints.shape, f"GT shape: {gt_joints.shape}, pred shape: {pred_joints.shape}" 13 | 14 | # Align by root (pelvis) 15 | pelvis = gt_joints[:, [0]].mean(1) 16 | gt_joints = gt_joints - torch.unsqueeze(pelvis, dim=1) 17 | pelvis = pred_joints[:, [0]].mean(1) 18 | pred_joints = pred_joints - torch.unsqueeze(pelvis, dim=1) 19 | 20 | # Compute MPJPE 21 | mpjpe = torch.linalg.norm(pred_joints - gt_joints, dim=-1) # num_poses x num_joints=22 22 | mpjpe_seq = mpjpe.mean(-1) # num_poses 23 | 24 | return mpjpe_seq 25 | 26 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 27 | def euclidean_distance_matrix(matrix1, matrix2): 28 | """ 29 | Params: 30 | -- matrix1: N1 x D 31 | -- matrix2: N2 x D 32 | Returns: 33 | -- dist: N1 x N2 34 | dist[i, j] == distance(matrix1[i], matrix2[j]) 35 | """ 36 | assert matrix1.shape[1] == matrix2.shape[1] 37 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 38 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 39 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 40 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 41 | return dists 42 | 43 | def calculate_top_k(mat, top_k): 44 | size = mat.shape[0] 45 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 46 | bool_mat = (mat == gt_mat) 47 | correct_vec = False 48 | top_k_list = [] 49 | for i in range(top_k): 50 | # print(correct_vec, bool_mat[:, i]) 51 | correct_vec = (correct_vec | bool_mat[:, i]) 52 | # print(correct_vec) 53 | top_k_list.append(correct_vec[:, None]) 54 | top_k_mat = np.concatenate(top_k_list, axis=1) 55 | return top_k_mat 56 | 57 | 58 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 59 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 60 | argmax = np.argsort(dist_mat, axis=1) 61 | top_k_mat = calculate_top_k(argmax, top_k) 62 | if sum_all: 63 | return top_k_mat.sum(axis=0) 64 | else: 65 | return top_k_mat 66 | 67 | 68 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 69 | assert len(embedding1.shape) == 2 70 | assert embedding1.shape[0] == embedding2.shape[0] 71 | assert embedding1.shape[1] == embedding2.shape[1] 72 | 73 | dist = linalg.norm(embedding1 - embedding2, axis=1) 74 | if sum_all: 75 | return dist.sum(axis=0) 76 | else: 77 | return dist 78 | 79 | 80 
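# Illustration: a minimal R-precision check using the retrieval helpers above.
# Matched text/motion embeddings should rank each other highly; the random
# embeddings and sizes below are placeholders for demonstration only.
def _demo_r_precision(num_samples=32, dim=512, top_k=3):
    text_emb = np.random.randn(num_samples, dim)
    motion_emb = text_emb + 0.1 * np.random.randn(num_samples, dim)  # roughly paired embeddings
    # Row i of the distance matrix compares text i against every motion clip.
    counts = calculate_R_precision(text_emb, motion_emb, top_k=top_k, sum_all=True)
    return counts / num_samples  # R-precision at ranks 1..top_k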
| 81 | def calculate_activation_statistics(activations): 82 | """ 83 | Params: 84 | -- activation: num_samples x dim_feat 85 | Returns: 86 | -- mu: dim_feat 87 | -- sigma: dim_feat x dim_feat 88 | """ 89 | mu = np.mean(activations, axis=0) 90 | cov = np.cov(activations, rowvar=False) 91 | return mu, cov 92 | 93 | 94 | def calculate_diversity(activation, diversity_times): 95 | assert len(activation.shape) == 2 96 | assert activation.shape[0] > diversity_times 97 | num_samples = activation.shape[0] 98 | 99 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 100 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 101 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 102 | return dist.mean() 103 | 104 | 105 | def calculate_multimodality(activation, multimodality_times): 106 | assert len(activation.shape) == 3 107 | assert activation.shape[1] > multimodality_times 108 | num_per_sent = activation.shape[1] 109 | 110 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 111 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 112 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 113 | return dist.mean() 114 | 115 | 116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 117 | """Numpy implementation of the Frechet Distance. 118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 119 | and X_2 ~ N(mu_2, C_2) is 120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 121 | Stable version by Dougal J. Sutherland. 122 | Params: 123 | -- mu1 : Numpy array containing the activations of a layer of the 124 | inception net (like returned by the function 'get_predictions') 125 | for generated samples. 126 | -- mu2 : The sample mean over activations, precalculated on an 127 | representative data set. 128 | -- sigma1: The covariance matrix over activations for generated samples. 129 | -- sigma2: The covariance matrix over activations, precalculated on an 130 | representative data set. 131 | Returns: 132 | -- : The Frechet Distance. 
133 | """ 134 | 135 | mu1 = np.atleast_1d(mu1) 136 | mu2 = np.atleast_1d(mu2) 137 | 138 | sigma1 = np.atleast_2d(sigma1) 139 | sigma2 = np.atleast_2d(sigma2) 140 | 141 | assert mu1.shape == mu2.shape, \ 142 | 'Training and test mean vectors have different lengths' 143 | assert sigma1.shape == sigma2.shape, \ 144 | 'Training and test covariances have different dimensions' 145 | 146 | diff = mu1 - mu2 147 | 148 | # Product might be almost singular 149 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 150 | if not np.isfinite(covmean).all(): 151 | msg = ('fid calculation produces singular product; ' 152 | 'adding %s to diagonal of cov estimates') % eps 153 | print(msg) 154 | offset = np.eye(sigma1.shape[0]) * eps 155 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 156 | 157 | # Numerical error might give slight imaginary component 158 | if np.iscomplexobj(covmean): 159 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 160 | m = np.max(np.abs(covmean.imag)) 161 | raise ValueError('Imaginary component {}'.format(m)) 162 | covmean = covmean.real 163 | 164 | tr_covmean = np.trace(covmean) 165 | 166 | return (diff.dot(diff) + np.trace(sigma1) + 167 | np.trace(sigma2) - 2 * tr_covmean) 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /train_t2m_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import DataLoader 6 | from os.path import join as pjoin 7 | 8 | from models.mask_transformer.transformer import MaskTransformer 9 | from models.mask_transformer.transformer_trainer import MaskTransformerTrainer 10 | from models.vq.model import RVQVAE 11 | 12 | from options.train_option import TrainT2MOptions 13 | 14 | from utils.plot_script import plot_3d_motion 15 | from utils.motion_process import recover_from_ric 16 | from utils.get_opt import get_opt 17 | from utils.fixseed import fixseed 18 | from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain 19 | 20 | from data.t2m_dataset import Text2MotionDataset 21 | from motion_loaders.dataset_motion_loader import get_dataset_motion_loader 22 | from models.t2m_eval_wrapper import EvaluatorModelWrapper 23 | 24 | 25 | def plot_t2m(data, save_dir, captions, m_lengths): 26 | data = train_dataset.inv_transform(data) 27 | 28 | # print(ep_curves.shape) 29 | for i, (caption, joint_data) in enumerate(zip(captions, data)): 30 | joint_data = joint_data[:m_lengths[i]] 31 | joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy() 32 | save_path = pjoin(save_dir, '%02d.mp4'%i) 33 | # print(joint.shape) 34 | plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=fps, radius=radius) 35 | 36 | def load_vq_model(): 37 | opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt') 38 | vq_opt = get_opt(opt_path, opt.device) 39 | vq_model = RVQVAE(vq_opt, 40 | dim_pose, 41 | vq_opt.nb_code, 42 | vq_opt.code_dim, 43 | vq_opt.output_emb_width, 44 | vq_opt.down_t, 45 | vq_opt.stride_t, 46 | vq_opt.width, 47 | vq_opt.depth, 48 | vq_opt.dilation_growth_rate, 49 | vq_opt.vq_act, 50 | vq_opt.vq_norm) 51 | ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'), 52 | map_location='cpu') 53 | model_key = 'vq_model' if 'vq_model' in ckpt else 'net' 54 | vq_model.load_state_dict(ckpt[model_key]) 55 | print(f'Loading VQ Model {opt.vq_name}') 
56 | return vq_model, vq_opt 57 | 58 | if __name__ == '__main__': 59 | parser = TrainT2MOptions() 60 | opt = parser.parse() 61 | fixseed(opt.seed) 62 | 63 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 64 | torch.autograd.set_detect_anomaly(True) 65 | 66 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 67 | opt.model_dir = pjoin(opt.save_root, 'model') 68 | # opt.meta_dir = pjoin(opt.save_root, 'meta') 69 | opt.eval_dir = pjoin(opt.save_root, 'animation') 70 | opt.log_dir = pjoin('./log/t2m/', opt.dataset_name, opt.name) 71 | 72 | os.makedirs(opt.model_dir, exist_ok=True) 73 | # os.makedirs(opt.meta_dir, exist_ok=True) 74 | os.makedirs(opt.eval_dir, exist_ok=True) 75 | os.makedirs(opt.log_dir, exist_ok=True) 76 | 77 | if opt.dataset_name == 't2m': 78 | opt.data_root = './dataset/HumanML3D' 79 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 80 | opt.joints_num = 22 81 | opt.max_motion_len = 55 82 | dim_pose = 263 83 | radius = 4 84 | fps = 20 85 | kinematic_chain = t2m_kinematic_chain 86 | dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt' 87 | 88 | elif opt.dataset_name == 'kit': #TODO 89 | opt.data_root = './dataset/KIT-ML' 90 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 91 | opt.joints_num = 21 92 | radius = 240 * 8 93 | fps = 12.5 94 | dim_pose = 251 95 | opt.max_motion_len = 55 96 | kinematic_chain = kit_kinematic_chain 97 | dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt' 98 | 99 | else: 100 | raise KeyError('Dataset Does Not Exist') 101 | 102 | opt.text_dir = pjoin(opt.data_root, 'texts') 103 | 104 | vq_model, vq_opt = load_vq_model() 105 | 106 | clip_version = 'ViT-B/32' 107 | 108 | opt.num_tokens = vq_opt.nb_code 109 | 110 | t2m_transformer = MaskTransformer(code_dim=vq_opt.code_dim, 111 | cond_mode='text', 112 | latent_dim=opt.latent_dim, 113 | ff_size=opt.ff_size, 114 | num_layers=opt.n_layers, 115 | num_heads=opt.n_heads, 116 | dropout=opt.dropout, 117 | clip_dim=512, 118 | cond_drop_prob=opt.cond_drop_prob, 119 | clip_version=clip_version, 120 | opt=opt) 121 | 122 | # if opt.fix_token_emb: 123 | # t2m_transformer.load_and_freeze_token_emb(vq_model.quantizer.codebooks[0]) 124 | 125 | all_params = 0 126 | pc_transformer = sum(param.numel() for param in t2m_transformer.parameters_wo_clip()) 127 | 128 | # print(t2m_transformer) 129 | # print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000)) 130 | all_params += pc_transformer 131 | 132 | print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000)) 133 | 134 | mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy')) 135 | std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy')) 136 | 137 | train_split_file = pjoin(opt.data_root, 'train.txt') 138 | val_split_file = pjoin(opt.data_root, 'val.txt') 139 | 140 | train_dataset = Text2MotionDataset(opt, mean, std, train_split_file) 141 | val_dataset = Text2MotionDataset(opt, mean, std, val_split_file) 142 | 143 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True) 144 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True) 145 | 146 | eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device) 147 | 148 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 149 | eval_wrapper = 
EvaluatorModelWrapper(wrapper_opt) 150 | 151 | trainer = MaskTransformerTrainer(opt, t2m_transformer, vq_model) 152 | 153 | trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m) -------------------------------------------------------------------------------- /visualization/joints2bvh.py: -------------------------------------------------------------------------------- 1 | import visualization.Animation as Animation 2 | 3 | from visualization.InverseKinematics import BasicInverseKinematics, BasicJacobianIK, InverseKinematics 4 | from visualization.Quaternions import Quaternions 5 | import visualization.BVH_mod as BVH 6 | from visualization.remove_fs import * 7 | 8 | from utils.plot_script import plot_3d_motion 9 | from utils import paramUtil 10 | from common.skeleton import Skeleton 11 | import torch 12 | 13 | from torch import nn 14 | from visualization.utils.quat import ik_rot, between, fk, ik 15 | from tqdm import tqdm 16 | 17 | 18 | def get_grot(glb, parent, offset): 19 | root_quat = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(glb.shape[0], axis=0)[:, None] 20 | local_pos = glb[:, 1:] - glb[:, parent[1:]] 21 | norm_offset = offset[1:] / np.linalg.norm(offset[1:], axis=-1, keepdims=True) 22 | norm_lpos = local_pos / np.linalg.norm(local_pos, axis=-1, keepdims=True) 23 | grot = between(norm_offset, norm_lpos) 24 | grot = np.concatenate((root_quat, grot), axis=1) 25 | grot /= np.linalg.norm(grot, axis=-1, keepdims=True) 26 | return grot 27 | 28 | 29 | class Joint2BVHConvertor: 30 | def __init__(self): 31 | self.template = BVH.load('./visualization/data/template.bvh', need_quater=True) 32 | self.re_order = [0, 1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12, 15, 13, 16, 18, 20, 14, 17, 19, 21] 33 | 34 | self.re_order_inv = [0, 1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12, 14, 18, 13, 15, 19, 16, 20, 17, 21] 35 | self.end_points = [4, 8, 13, 17, 21] 36 | 37 | self.template_offset = self.template.offsets.copy() 38 | self.parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20] 39 | 40 | def convert(self, positions, filename, iterations=10, foot_ik=True): 41 | ''' 42 | Convert the SMPL joint positions to Mocap BVH 43 | :param positions: (N, 22, 3) 44 | :param filename: Save path for resulting BVH 45 | :param iterations: iterations for optimizing rotations, 10 is usually enough 46 | :param foot_ik: whether to enfore foot inverse kinematics, removing foot slide issue. 
47 | :return: 48 | ''' 49 | positions = positions[:, self.re_order] 50 | new_anim = self.template.copy() 51 | new_anim.rotations = Quaternions.id(positions.shape[:-1]) 52 | new_anim.positions = new_anim.positions[0:1].repeat(positions.shape[0], axis=-0) 53 | new_anim.positions[:, 0] = positions[:, 0] 54 | 55 | if foot_ik: 56 | positions = remove_fs(positions, None, fid_l=(3, 4), fid_r=(7, 8), interp_length=5, 57 | force_on_floor=True) 58 | ik_solver = BasicInverseKinematics(new_anim, positions, iterations=iterations, silent=True) 59 | new_anim = ik_solver() 60 | 61 | # BVH.save(filename, new_anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 62 | glb = Animation.positions_global(new_anim)[:, self.re_order_inv] 63 | if filename is not None: 64 | BVH.save(filename, new_anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 65 | return new_anim, glb 66 | 67 | def convert_sgd(self, positions, filename, iterations=100, foot_ik=True): 68 | ''' 69 | Convert the SMPL joint positions to Mocap BVH 70 | 71 | :param positions: (N, 22, 3) 72 | :param filename: Save path for resulting BVH 73 | :param iterations: iterations for optimizing rotations, 10 is usually enough 74 | :param foot_ik: whether to enfore foot inverse kinematics, removing foot slide issue. 75 | :return: 76 | ''' 77 | 78 | ## Positional Foot locking ## 79 | glb = positions[:, self.re_order] 80 | 81 | if foot_ik: 82 | glb = remove_fs(glb, None, fid_l=(3, 4), fid_r=(7, 8), interp_length=2, 83 | force_on_floor=True) 84 | 85 | ## Fit BVH ## 86 | new_anim = self.template.copy() 87 | new_anim.rotations = Quaternions.id(glb.shape[:-1]) 88 | new_anim.positions = new_anim.positions[0:1].repeat(glb.shape[0], axis=-0) 89 | new_anim.positions[:, 0] = glb[:, 0] 90 | anim = new_anim.copy() 91 | 92 | rot = torch.tensor(anim.rotations.qs, dtype=torch.float) 93 | pos = torch.tensor(anim.positions[:, 0, :], dtype=torch.float) 94 | offset = torch.tensor(anim.offsets, dtype=torch.float) 95 | 96 | glb = torch.tensor(glb, dtype=torch.float) 97 | ik_solver = InverseKinematics(rot, pos, offset, anim.parents, glb) 98 | print('Fixing foot contact using IK...') 99 | for i in tqdm(range(iterations)): 100 | mse = ik_solver.step() 101 | # print(i, mse) 102 | 103 | rotations = ik_solver.rotations.detach().cpu() 104 | norm = torch.norm(rotations, dim=-1, keepdim=True) 105 | rotations /= norm 106 | 107 | anim.rotations = Quaternions(rotations.numpy()) 108 | anim.rotations[:, self.end_points] = Quaternions.id((anim.rotations.shape[0], len(self.end_points))) 109 | anim.positions[:, 0, :] = ik_solver.position.detach().cpu().numpy() 110 | if filename is not None: 111 | BVH.save(filename, anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 112 | # BVH.save(filename[:-3] + 'bvh', anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 113 | glb = Animation.positions_global(anim)[:, self.re_order_inv] 114 | return anim, glb 115 | 116 | 117 | 118 | if __name__ == "__main__": 119 | # file = 'batch0_sample13_repeat0_len196.npy' 120 | # file = 'batch2_sample10_repeat0_len156.npy' 121 | # file = 'batch2_sample13_repeat0_len196.npy' #line #57 new_anim.positions = lpos #new_anim.positions[0:1].repeat(positions.shape[0], axis=-0) #TODO, figure out why it's important 122 | # file = 'batch1_sample12_repeat0_len196.npy' #hard case karate 123 | # file = 'batch1_sample14_repeat0_len180.npy' 124 | # file = 'batch0_sample3_repeat0_len192.npy' 125 | # file = 'batch1_sample4_repeat0_len136.npy' 126 | 127 | # file = 
'batch0_sample0_repeat0_len152.npy' 128 | # path = f'/Users/yuxuanmu/project/MaskMIT/demo/cond4_topkr0.9_ts18_tau1.0_s1009/joints/{file}' 129 | # joints = np.load(path) 130 | # converter = Joint2BVHConvertor() 131 | # new_anim = converter.convert(joints, './gen_L196.mp4', foot_ik=True) 132 | 133 | folder = '/Users/yuxuanmu/project/MaskMIT/demo/cond4_topkr0.9_ts18_tau1.0_s1009' 134 | files = os.listdir(os.path.join(folder, 'joints')) 135 | files = [f for f in files if 'repeat' in f] 136 | converter = Joint2BVHConvertor() 137 | for f in tqdm(files): 138 | joints = np.load(os.path.join(folder, 'joints', f)) 139 | converter.convert(joints, os.path.join(folder, 'ik_animations', f'ik_{f}'.replace('npy', 'mp4')), foot_ik=True) -------------------------------------------------------------------------------- /models/vq/quantizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat, reduce, pack, unpack 6 | 7 | # from vector_quantize_pytorch import ResidualVQ 8 | 9 | #Borrow from vector_quantize_pytorch 10 | 11 | def log(t, eps = 1e-20): 12 | return torch.log(t.clamp(min = eps)) 13 | 14 | def gumbel_noise(t): 15 | noise = torch.zeros_like(t).uniform_(0, 1) 16 | return -log(-log(noise)) 17 | 18 | def gumbel_sample( 19 | logits, 20 | temperature = 1., 21 | stochastic = False, 22 | dim = -1, 23 | training = True 24 | ): 25 | 26 | if training and stochastic and temperature > 0: 27 | sampling_logits = (logits / temperature) + gumbel_noise(logits) 28 | else: 29 | sampling_logits = logits 30 | 31 | ind = sampling_logits.argmax(dim = dim) 32 | 33 | return ind 34 | 35 | class QuantizeEMAReset(nn.Module): 36 | def __init__(self, nb_code, code_dim, args): 37 | super(QuantizeEMAReset, self).__init__() 38 | self.nb_code = nb_code 39 | self.code_dim = code_dim 40 | self.mu = args.mu ##TO_DO 41 | self.reset_codebook() 42 | 43 | def reset_codebook(self): 44 | self.init = False 45 | self.code_sum = None 46 | self.code_count = None 47 | self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False)) 48 | 49 | def _tile(self, x): 50 | nb_code_x, code_dim = x.shape 51 | if nb_code_x < self.nb_code: 52 | n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x 53 | std = 0.01 / np.sqrt(code_dim) 54 | out = x.repeat(n_repeats, 1) 55 | out = out + torch.randn_like(out) * std 56 | else: 57 | out = x 58 | return out 59 | 60 | def init_codebook(self, x): 61 | out = self._tile(x) 62 | self.codebook = out[:self.nb_code] 63 | self.code_sum = self.codebook.clone() 64 | self.code_count = torch.ones(self.nb_code, device=self.codebook.device) 65 | self.init = True 66 | 67 | def quantize(self, x, sample_codebook_temp=0.): 68 | # N X C -> C X N 69 | k_w = self.codebook.t() 70 | # x: NT X C 71 | # NT X N 72 | distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \ 73 | 2 * torch.matmul(x, k_w) + \ 74 | torch.sum(k_w ** 2, dim=0, keepdim=True) # (N * L, b) 75 | 76 | # code_idx = torch.argmin(distance, dim=-1) 77 | 78 | code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training) 79 | 80 | return code_idx 81 | 82 | def dequantize(self, code_idx): 83 | x = F.embedding(code_idx, self.codebook) 84 | return x 85 | 86 | def get_codebook_entry(self, indices): 87 | return self.dequantize(indices).permute(0, 2, 1) 88 | 89 | @torch.no_grad() 90 | def compute_perplexity(self, 
code_idx): 91 | # Calculate new centres 92 | code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L 93 | code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) 94 | 95 | code_count = code_onehot.sum(dim=-1) # nb_code 96 | prob = code_count / torch.sum(code_count) 97 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 98 | return perplexity 99 | 100 | @torch.no_grad() 101 | def update_codebook(self, x, code_idx): 102 | code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L 103 | code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) 104 | 105 | code_sum = torch.matmul(code_onehot, x) # nb_code, c 106 | code_count = code_onehot.sum(dim=-1) # nb_code 107 | 108 | out = self._tile(x) 109 | code_rand = out[:self.nb_code] 110 | 111 | # Update centres 112 | self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum 113 | self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count 114 | 115 | usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() 116 | code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) 117 | self.codebook = usage * code_update + (1-usage) * code_rand 118 | 119 | 120 | prob = code_count / torch.sum(code_count) 121 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 122 | 123 | return perplexity 124 | 125 | def preprocess(self, x): 126 | # NCT -> NTC -> [NT, C] 127 | # x = x.permute(0, 2, 1).contiguous() 128 | # x = x.view(-1, x.shape[-1]) 129 | x = rearrange(x, 'n c t -> (n t) c') 130 | return x 131 | 132 | def forward(self, x, return_idx=False, temperature=0.): 133 | N, width, T = x.shape 134 | 135 | x = self.preprocess(x) 136 | if self.training and not self.init: 137 | self.init_codebook(x) 138 | 139 | code_idx = self.quantize(x, temperature) 140 | x_d = self.dequantize(code_idx) 141 | 142 | if self.training: 143 | perplexity = self.update_codebook(x, code_idx) 144 | else: 145 | perplexity = self.compute_perplexity(code_idx) 146 | 147 | commit_loss = F.mse_loss(x, x_d.detach()) # It's right. the t2m-gpt paper is wrong on embed loss and commitment loss. 148 | 149 | # Passthrough 150 | x_d = x + (x_d - x).detach() 151 | 152 | # Postprocess 153 | x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() 154 | code_idx = code_idx.view(N, T).contiguous() 155 | # print(code_idx[0]) 156 | if return_idx: 157 | return x_d, code_idx, commit_loss, perplexity 158 | return x_d, commit_loss, perplexity 159 | 160 | class QuantizeEMA(QuantizeEMAReset): 161 | @torch.no_grad() 162 | def update_codebook(self, x, code_idx): 163 | code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L 164 | code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) 165 | 166 | code_sum = torch.matmul(code_onehot, x) # nb_code, c 167 | code_count = code_onehot.sum(dim=-1) # nb_code 168 | 169 | # Update centres 170 | self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum 171 | self.code_count = self.mu * self.code_count + (1. 
- self.mu) * code_count 172 | 173 | usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() 174 | code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) 175 | self.codebook = usage * code_update + (1-usage) * self.codebook 176 | 177 | prob = code_count / torch.sum(code_count) 178 | perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) 179 | 180 | return perplexity 181 | -------------------------------------------------------------------------------- /visualization/smpl2bvh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import pickle 5 | import smplx 6 | 7 | from utils import bvh, quat 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--model_path", type=str, default="./visualization/data/smpl/") 13 | parser.add_argument("--model_type", type=str, default="smpl", choices=["smpl", "smplx"]) 14 | parser.add_argument("--gender", type=str, default="MALE", choices=["MALE", "FEMALE", "NEUTRAL"]) 15 | parser.add_argument("--num_betas", type=int, default=10, choices=[10, 300]) 16 | parser.add_argument("--poses", type=str, default="data/gWA_sFM_cAll_d27_mWA5_ch20.pkl") 17 | parser.add_argument("--fps", type=int, default=60) 18 | parser.add_argument("--output", type=str, default="data/gWA_sFM_cAll_d27_mWA5_ch20.bvh") 19 | parser.add_argument("--mirror", action="store_true") 20 | return parser.parse_args() 21 | 22 | def mirror_rot_trans(lrot, trans, names, parents): 23 | joints_mirror = np.array([( 24 | names.index("Left"+n[5:]) if n.startswith("Right") else ( 25 | names.index("Right"+n[4:]) if n.startswith("Left") else 26 | names.index(n))) for n in names]) 27 | 28 | mirror_pos = np.array([-1, 1, 1]) 29 | mirror_rot = np.array([1, 1, -1, -1]) 30 | grot = quat.fk_rot(lrot, parents) 31 | trans_mirror = mirror_pos * trans 32 | grot_mirror = mirror_rot * grot[:,joints_mirror] 33 | 34 | return quat.ik_rot(grot_mirror, parents), trans_mirror 35 | 36 | def smpl2bvh(model_path:str, poses:str, output:str, mirror:bool, 37 | model_type="smpl", gender="MALE", 38 | num_betas=10, fps=60) -> None: 39 | """Save bvh file created by smpl parameters. 40 | 41 | Args: 42 | model_path (str): Path to smpl models. 43 | poses (str): Path to npz or pkl file. 44 | output (str): Where to save bvh. 45 | mirror (bool): Whether save mirror motion or not. 46 | model_type (str, optional): I prepared "smpl" only. Defaults to "smpl". 47 | gender (str, optional): Gender Information. Defaults to "MALE". 48 | num_betas (int, optional): How many pca parameters to use in SMPL. Defaults to 10. 49 | fps (int, optional): Frame per second. Defaults to 30. 
50 | """ 51 | 52 | # names = [ 53 | # "Pelvis", 54 | # "Left_hip", 55 | # "Right_hip", 56 | # "Spine1", 57 | # "Left_knee", 58 | # "Right_knee", 59 | # "Spine2", 60 | # "Left_ankle", 61 | # "Right_ankle", 62 | # "Spine3", 63 | # "Left_foot", 64 | # "Right_foot", 65 | # "Neck", 66 | # "Left_collar", 67 | # "Right_collar", 68 | # "Head", 69 | # "Left_shoulder", 70 | # "Right_shoulder", 71 | # "Left_elbow", 72 | # "Right_elbow", 73 | # "Left_wrist", 74 | # "Right_wrist", 75 | # "Left_palm", 76 | # "Right_palm", 77 | # ] 78 | 79 | names = [ 80 | "Hips", 81 | "LeftUpLeg", 82 | "RightUpLeg", 83 | "Spine", 84 | "LeftLeg", 85 | "RightLeg", 86 | "Spine1", 87 | "LeftFoot", 88 | "RightFoot", 89 | "Spine2", 90 | "LeftToe", 91 | "RightToe", 92 | "Neck", 93 | "LeftShoulder", 94 | "RightShoulder", 95 | "Head", 96 | "LeftArm", 97 | "RightArm", 98 | "LeftForeArm", 99 | "RightForeArm", 100 | "LeftHand", 101 | "RightHand", 102 | "LeftThumb", 103 | "RightThumb", 104 | ] 105 | 106 | # I prepared smpl models only, 107 | # but I will release for smplx models recently. 108 | model = smplx.create(model_path=model_path, 109 | model_type=model_type, 110 | gender=gender, 111 | batch_size=1) 112 | 113 | parents = model.parents.detach().cpu().numpy() 114 | 115 | # You can define betas like this.(default betas are 0 at all.) 116 | rest = model( 117 | # betas = torch.randn([1, num_betas], dtype=torch.float32) 118 | ) 119 | rest_pose = rest.joints.detach().cpu().numpy().squeeze()[:24,:] 120 | 121 | root_offset = rest_pose[0] 122 | offsets = rest_pose - rest_pose[parents] 123 | offsets[0] = root_offset 124 | offsets *= 1 125 | 126 | scaling = None 127 | 128 | # Pose setting. 129 | if poses.endswith(".npz"): 130 | poses = np.load(poses) 131 | rots = np.squeeze(poses["poses"], axis=0) # (N, 24, 3) 132 | trans = np.squeeze(poses["trans"], axis=0) # (N, 3) 133 | 134 | elif poses.endswith(".pkl"): 135 | with open(poses, "rb") as f: 136 | poses = pickle.load(f) 137 | rots = poses["smpl_poses"] # (N, 72) 138 | rots = rots.reshape(rots.shape[0], -1, 3) # (N, 24, 3) 139 | scaling = poses["smpl_scaling"] # (1,) 140 | trans = poses["smpl_trans"] # (N, 3) 141 | 142 | else: 143 | raise Exception("This file type is not supported!") 144 | 145 | if scaling is not None: 146 | trans /= scaling 147 | 148 | # to quaternion 149 | rots = quat.from_axis_angle(rots) 150 | 151 | order = "zyx" 152 | pos = offsets[None].repeat(len(rots), axis=0) 153 | positions = pos.copy() 154 | # positions[:,0] += trans * 10 155 | positions[:, 0] += trans 156 | rotations = np.degrees(quat.to_euler(rots, order=order)) 157 | 158 | bvh_data ={ 159 | "rotations": rotations[:, :22], 160 | "positions": positions[:, :22], 161 | "offsets": offsets[:22], 162 | "parents": parents[:22], 163 | "names": names[:22], 164 | "order": order, 165 | "frametime": 1 / fps, 166 | } 167 | 168 | if not output.endswith(".bvh"): 169 | output = output + ".bvh" 170 | 171 | bvh.save(output, bvh_data) 172 | 173 | if mirror: 174 | rots_mirror, trans_mirror = mirror_rot_trans( 175 | rots, trans, names, parents) 176 | positions_mirror = pos.copy() 177 | positions_mirror[:,0] += trans_mirror 178 | rotations_mirror = np.degrees( 179 | quat.to_euler(rots_mirror, order=order)) 180 | 181 | bvh_data ={ 182 | "rotations": rotations_mirror, 183 | "positions": positions_mirror, 184 | "offsets": offsets, 185 | "parents": parents, 186 | "names": names, 187 | "order": order, 188 | "frametime": 1 / fps, 189 | } 190 | 191 | output_mirror = output.split(".")[0] + "_mirror.bvh" 192 | bvh.save(output_mirror, 
bvh_data) 193 | 194 | 195 | def joints2bvh() 196 | 197 | if __name__ == "__main__": 198 | args = parse_args() 199 | 200 | smpl2bvh(model_path=args.model_path, model_type=args.model_type, 201 | mirror = args.mirror, gender=args.gender, 202 | poses=args.poses, num_betas=args.num_betas, 203 | fps=args.fps, output=args.output) 204 | 205 | print("finished!") -------------------------------------------------------------------------------- /models/t2m_eval_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | import math 6 | import random 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | # from networks.layers import * 9 | 10 | 11 | def init_weight(m): 12 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): 13 | nn.init.xavier_normal_(m.weight) 14 | # m.bias.data.fill_(0.01) 15 | if m.bias is not None: 16 | nn.init.constant_(m.bias, 0) 17 | 18 | 19 | # batch_size, dimension and position 20 | # output: (batch_size, dim) 21 | def positional_encoding(batch_size, dim, pos): 22 | assert batch_size == pos.shape[0] 23 | positions_enc = np.array([ 24 | [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)] 25 | for j in range(batch_size) 26 | ], dtype=np.float32) 27 | positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2]) 28 | positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2]) 29 | return torch.from_numpy(positions_enc).float() 30 | 31 | 32 | def get_padding_mask(batch_size, seq_len, cap_lens): 33 | cap_lens = cap_lens.data.tolist() 34 | mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32) 35 | for i, cap_len in enumerate(cap_lens): 36 | mask_2d[i, :, :cap_len] = 0 37 | return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone() 38 | 39 | 40 | def top_k_logits(logits, k): 41 | v, ix = torch.topk(logits, k) 42 | out = logits.clone() 43 | out[out < v[:, [-1]]] = -float('Inf') 44 | return out 45 | 46 | 47 | class PositionalEncoding(nn.Module): 48 | 49 | def __init__(self, d_model, max_len=300): 50 | super(PositionalEncoding, self).__init__() 51 | 52 | pe = torch.zeros(max_len, d_model) 53 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 54 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 55 | pe[:, 0::2] = torch.sin(position * div_term) 56 | pe[:, 1::2] = torch.cos(position * div_term) 57 | # pe = pe.unsqueeze(0).transpose(0, 1) 58 | self.register_buffer('pe', pe) 59 | 60 | def forward(self, pos): 61 | return self.pe[pos] 62 | 63 | 64 | class MovementConvEncoder(nn.Module): 65 | def __init__(self, input_size, hidden_size, output_size): 66 | super(MovementConvEncoder, self).__init__() 67 | self.main = nn.Sequential( 68 | nn.Conv1d(input_size, hidden_size, 4, 2, 1), 69 | nn.Dropout(0.2, inplace=True), 70 | nn.LeakyReLU(0.2, inplace=True), 71 | nn.Conv1d(hidden_size, output_size, 4, 2, 1), 72 | nn.Dropout(0.2, inplace=True), 73 | nn.LeakyReLU(0.2, inplace=True), 74 | ) 75 | self.out_net = nn.Linear(output_size, output_size) 76 | self.main.apply(init_weight) 77 | self.out_net.apply(init_weight) 78 | 79 | def forward(self, inputs): 80 | inputs = inputs.permute(0, 2, 1) 81 | outputs = self.main(inputs).permute(0, 2, 1) 82 | # print(outputs.shape) 83 | return self.out_net(outputs) 84 | 85 | 86 | class MovementConvDecoder(nn.Module): 87 | def __init__(self, input_size, hidden_size, output_size): 88 | super(MovementConvDecoder, 
self).__init__() 89 | self.main = nn.Sequential( 90 | nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1), 91 | # nn.Dropout(0.2, inplace=True), 92 | nn.LeakyReLU(0.2, inplace=True), 93 | nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1), 94 | # nn.Dropout(0.2, inplace=True), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | ) 97 | self.out_net = nn.Linear(output_size, output_size) 98 | 99 | self.main.apply(init_weight) 100 | self.out_net.apply(init_weight) 101 | 102 | def forward(self, inputs): 103 | inputs = inputs.permute(0, 2, 1) 104 | outputs = self.main(inputs).permute(0, 2, 1) 105 | return self.out_net(outputs) 106 | 107 | class TextEncoderBiGRUCo(nn.Module): 108 | def __init__(self, word_size, pos_size, hidden_size, output_size, device): 109 | super(TextEncoderBiGRUCo, self).__init__() 110 | self.device = device 111 | 112 | self.pos_emb = nn.Linear(pos_size, word_size) 113 | self.input_emb = nn.Linear(word_size, hidden_size) 114 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 115 | self.output_net = nn.Sequential( 116 | nn.Linear(hidden_size * 2, hidden_size), 117 | nn.LayerNorm(hidden_size), 118 | nn.LeakyReLU(0.2, inplace=True), 119 | nn.Linear(hidden_size, output_size) 120 | ) 121 | 122 | self.input_emb.apply(init_weight) 123 | self.pos_emb.apply(init_weight) 124 | self.output_net.apply(init_weight) 125 | # self.linear2.apply(init_weight) 126 | # self.batch_size = batch_size 127 | self.hidden_size = hidden_size 128 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 129 | 130 | # input(batch_size, seq_len, dim) 131 | def forward(self, word_embs, pos_onehot, cap_lens): 132 | num_samples = word_embs.shape[0] 133 | 134 | pos_embs = self.pos_emb(pos_onehot) 135 | inputs = word_embs + pos_embs 136 | input_embs = self.input_emb(inputs) 137 | hidden = self.hidden.repeat(1, num_samples, 1) 138 | 139 | cap_lens = cap_lens.data.tolist() 140 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 141 | 142 | gru_seq, gru_last = self.gru(emb, hidden) 143 | 144 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 145 | 146 | return self.output_net(gru_last) 147 | 148 | 149 | class MotionEncoderBiGRUCo(nn.Module): 150 | def __init__(self, input_size, hidden_size, output_size, device): 151 | super(MotionEncoderBiGRUCo, self).__init__() 152 | self.device = device 153 | 154 | self.input_emb = nn.Linear(input_size, hidden_size) 155 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 156 | self.output_net = nn.Sequential( 157 | nn.Linear(hidden_size*2, hidden_size), 158 | nn.LayerNorm(hidden_size), 159 | nn.LeakyReLU(0.2, inplace=True), 160 | nn.Linear(hidden_size, output_size) 161 | ) 162 | 163 | self.input_emb.apply(init_weight) 164 | self.output_net.apply(init_weight) 165 | self.hidden_size = hidden_size 166 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 167 | 168 | # input(batch_size, seq_len, dim) 169 | def forward(self, inputs, m_lens): 170 | num_samples = inputs.shape[0] 171 | 172 | input_embs = self.input_emb(inputs) 173 | hidden = self.hidden.repeat(1, num_samples, 1) 174 | 175 | cap_lens = m_lens.data.tolist() 176 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 177 | 178 | gru_seq, gru_last = self.gru(emb, hidden) 179 | 180 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 181 | 182 | return self.output_net(gru_last) -------------------------------------------------------------------------------- 
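The two bidirectional GRU encoders above (TextEncoderBiGRUCo and MotionEncoderBiGRUCo) map captions and motion features into a shared embedding space. A minimal usage sketch follows; the batch size, sequence lengths, and feature dimensions (300-d word vectors, 15-d POS one-hots, 512-d movement features) are assumptions for illustration, not values read from this repository's configs.

import torch
from models.t2m_eval_modules import TextEncoderBiGRUCo, MotionEncoderBiGRUCo

device = torch.device('cpu')
text_enc = TextEncoderBiGRUCo(word_size=300, pos_size=15, hidden_size=512,
                              output_size=512, device=device)
motion_enc = MotionEncoderBiGRUCo(input_size=512, hidden_size=1024,
                                  output_size=512, device=device)

word_embs = torch.randn(4, 20, 300)        # (batch, caption length, word-vector dim)
pos_onehot = torch.randn(4, 20, 15)        # POS one-hots in the real pipeline
cap_lens = torch.tensor([20, 18, 15, 12])  # descending, as required by pack_padded_sequence
movements = torch.randn(4, 49, 512)        # movement features in the real pipeline
m_lens = torch.tensor([49, 40, 30, 25])

text_emb = text_enc(word_embs, pos_onehot, cap_lens)  # (4, 512)
motion_emb = motion_enc(movements, m_lens)            # (4, 512)

The resulting embeddings are what the distance-based retrieval metrics in utils/metrics.py (R-precision, matching score) consume.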
/environment.yml: -------------------------------------------------------------------------------- 1 | name: momask 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - absl-py=1.4.0=pyhd8ed1ab_0 11 | - aiohttp=3.8.3=py37h5eee18b_0 12 | - aiosignal=1.2.0=pyhd3eb1b0_0 13 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 14 | - argon2-cffi-bindings=21.2.0=py37h7f8727e_0 15 | - async-timeout=4.0.2=py37h06a4308_0 16 | - asynctest=0.13.0=py_0 17 | - attrs=22.1.0=py37h06a4308_0 18 | - backcall=0.2.0=pyhd3eb1b0_0 19 | - beautifulsoup4=4.11.1=pyha770c72_0 20 | - blas=1.0=mkl 21 | - bleach=4.1.0=pyhd3eb1b0_0 22 | - blinker=1.4=py37h06a4308_0 23 | - brotlipy=0.7.0=py37h540881e_1004 24 | - c-ares=1.19.0=h5eee18b_0 25 | - ca-certificates=2023.05.30=h06a4308_0 26 | - catalogue=2.0.8=py37h89c1867_0 27 | - certifi=2022.12.7=py37h06a4308_0 28 | - cffi=1.15.1=py37h74dc2b5_0 29 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 30 | - click=8.0.4=py37h89c1867_0 31 | - colorama=0.4.5=pyhd8ed1ab_0 32 | - cryptography=35.0.0=py37hf1a17b8_2 33 | - cudatoolkit=11.0.221=h6bb024c_0 34 | - cycler=0.11.0=pyhd3eb1b0_0 35 | - cymem=2.0.6=py37hd23a5d3_3 36 | - cython-blis=0.7.7=py37hda87dfa_1 37 | - dataclasses=0.8=pyhc8e2a94_3 38 | - dbus=1.13.18=hb2f20db_0 39 | - debugpy=1.5.1=py37h295c915_0 40 | - decorator=5.1.1=pyhd3eb1b0_0 41 | - defusedxml=0.7.1=pyhd3eb1b0_0 42 | - entrypoints=0.4=py37h06a4308_0 43 | - expat=2.4.9=h6a678d5_0 44 | - fftw=3.3.9=h27cfd23_1 45 | - filelock=3.8.0=pyhd8ed1ab_0 46 | - fontconfig=2.13.1=h6c09931_0 47 | - freetype=2.11.0=h70c0345_0 48 | - frozenlist=1.3.3=py37h5eee18b_0 49 | - giflib=5.2.1=h7b6447c_0 50 | - glib=2.69.1=h4ff587b_1 51 | - gst-plugins-base=1.14.0=h8213a91_2 52 | - gstreamer=1.14.0=h28cd5cc_2 53 | - h5py=3.7.0=py37h737f45e_0 54 | - hdf5=1.10.6=h3ffc7dd_1 55 | - icu=58.2=he6710b0_3 56 | - idna=3.4=pyhd8ed1ab_0 57 | - importlib-metadata=4.11.4=py37h89c1867_0 58 | - intel-openmp=2021.4.0=h06a4308_3561 59 | - ipykernel=6.15.2=py37h06a4308_0 60 | - ipython=7.31.1=py37h06a4308_1 61 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 62 | - jedi=0.18.1=py37h06a4308_1 63 | - jinja2=3.1.2=pyhd8ed1ab_1 64 | - joblib=1.1.0=pyhd3eb1b0_0 65 | - jpeg=9b=h024ee3a_2 66 | - jsonschema=3.0.2=py37_0 67 | - jupyter_client=7.4.9=py37h06a4308_0 68 | - jupyter_core=4.11.2=py37h06a4308_0 69 | - jupyterlab_pygments=0.1.2=py_0 70 | - kiwisolver=1.4.2=py37h295c915_0 71 | - langcodes=3.3.0=pyhd8ed1ab_0 72 | - lcms2=2.12=h3be6417_0 73 | - ld_impl_linux-64=2.38=h1181459_1 74 | - libffi=3.3=he6710b0_2 75 | - libgcc-ng=11.2.0=h1234567_1 76 | - libgfortran-ng=11.2.0=h00389a5_1 77 | - libgfortran5=11.2.0=h1234567_1 78 | - libgomp=11.2.0=h1234567_1 79 | - libpng=1.6.37=hbc83047_0 80 | - libprotobuf=3.15.8=h780b84a_1 81 | - libsodium=1.0.18=h7b6447c_0 82 | - libstdcxx-ng=11.2.0=h1234567_1 83 | - libtiff=4.1.0=h2733197_1 84 | - libuuid=1.0.3=h7f8727e_2 85 | - libuv=1.40.0=h7b6447c_0 86 | - libwebp=1.2.0=h89dd481_0 87 | - libxcb=1.15=h7f8727e_0 88 | - libxml2=2.9.14=h74e7548_0 89 | - lz4-c=1.9.3=h295c915_1 90 | - markdown=3.4.3=pyhd8ed1ab_0 91 | - markupsafe=2.1.1=py37h540881e_1 92 | - matplotlib=3.1.3=py37_0 93 | - matplotlib-base=3.1.3=py37hef1b27d_0 94 | - matplotlib-inline=0.1.6=py37h06a4308_0 95 | - mistune=0.8.4=py37h14c3975_1001 96 | - mkl=2021.4.0=h06a4308_640 97 | - mkl-service=2.4.0=py37h7f8727e_0 98 | - mkl_fft=1.3.1=py37hd3c417c_0 99 | - mkl_random=1.2.2=py37h51133e4_0 100 | - multidict=6.0.2=py37h5eee18b_0 101 | - 
murmurhash=1.0.7=py37hd23a5d3_0 102 | - nb_conda_kernels=2.3.1=py37h06a4308_0 103 | - nbclient=0.5.13=py37h06a4308_0 104 | - nbconvert=6.4.4=py37h06a4308_0 105 | - nbformat=5.5.0=py37h06a4308_0 106 | - ncurses=6.3=h5eee18b_3 107 | - nest-asyncio=1.5.6=py37h06a4308_0 108 | - ninja=1.10.2=h06a4308_5 109 | - ninja-base=1.10.2=hd09550d_5 110 | - notebook=6.4.12=py37h06a4308_0 111 | - numpy=1.21.5=py37h6c91a56_3 112 | - numpy-base=1.21.5=py37ha15fc14_3 113 | - openssl=1.1.1v=h7f8727e_0 114 | - packaging=21.3=pyhd8ed1ab_0 115 | - pandocfilters=1.5.0=pyhd3eb1b0_0 116 | - parso=0.8.3=pyhd3eb1b0_0 117 | - pathy=0.6.2=pyhd8ed1ab_0 118 | - pcre=8.45=h295c915_0 119 | - pexpect=4.8.0=pyhd3eb1b0_3 120 | - pickleshare=0.7.5=pyhd3eb1b0_1003 121 | - pillow=9.2.0=py37hace64e9_1 122 | - pip=22.2.2=py37h06a4308_0 123 | - preshed=3.0.6=py37hd23a5d3_2 124 | - prometheus_client=0.14.1=py37h06a4308_0 125 | - prompt-toolkit=3.0.36=py37h06a4308_0 126 | - psutil=5.9.0=py37h5eee18b_0 127 | - ptyprocess=0.7.0=pyhd3eb1b0_2 128 | - pycparser=2.21=pyhd8ed1ab_0 129 | - pydantic=1.8.2=py37h5e8e339_2 130 | - pygments=2.11.2=pyhd3eb1b0_0 131 | - pyjwt=2.4.0=py37h06a4308_0 132 | - pyopenssl=22.0.0=pyhd8ed1ab_1 133 | - pyparsing=3.0.9=py37h06a4308_0 134 | - pyqt=5.9.2=py37h05f1152_2 135 | - pyrsistent=0.18.0=py37heee7806_0 136 | - pysocks=1.7.1=py37h89c1867_5 137 | - python=3.7.13=h12debd9_0 138 | - python-dateutil=2.8.2=pyhd3eb1b0_0 139 | - python-fastjsonschema=2.16.2=py37h06a4308_0 140 | - python_abi=3.7=2_cp37m 141 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 142 | - pyzmq=23.2.0=py37h6a678d5_0 143 | - qt=5.9.7=h5867ecd_1 144 | - readline=8.1.2=h7f8727e_1 145 | - requests=2.28.1=pyhd8ed1ab_1 146 | - scikit-learn=1.0.2=py37h51133e4_1 147 | - scipy=1.7.3=py37h6c91a56_2 148 | - send2trash=1.8.0=pyhd3eb1b0_1 149 | - setuptools=63.4.1=py37h06a4308_0 150 | - shellingham=1.5.0=pyhd8ed1ab_0 151 | - sip=4.19.8=py37hf484d3e_0 152 | - six=1.16.0=pyhd3eb1b0_1 153 | - smart_open=5.2.1=pyhd8ed1ab_0 154 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0 155 | - spacy=3.3.1=py37h79cecc1_0 156 | - spacy-legacy=3.0.10=pyhd8ed1ab_0 157 | - spacy-loggers=1.0.3=pyhd8ed1ab_0 158 | - sqlite=3.39.3=h5082296_0 159 | - srsly=2.4.3=py37hd23a5d3_1 160 | - tensorboard-plugin-wit=1.8.1=py37h06a4308_0 161 | - terminado=0.17.1=py37h06a4308_0 162 | - testpath=0.6.0=py37h06a4308_0 163 | - thinc=8.0.15=py37h48bf904_0 164 | - threadpoolctl=2.2.0=pyh0d69192_0 165 | - tk=8.6.12=h1ccaba5_0 166 | - torchaudio=0.7.2=py37 167 | - torchvision=0.8.2=py37_cu110 168 | - tornado=6.2=py37h5eee18b_0 169 | - tqdm=4.64.1=py37h06a4308_0 170 | - traitlets=5.7.1=py37h06a4308_0 171 | - trimesh=3.15.3=pyh1a96a4e_0 172 | - typer=0.4.2=pyhd8ed1ab_0 173 | - typing-extensions=3.10.0.2=hd8ed1ab_0 174 | - typing_extensions=3.10.0.2=pyha770c72_0 175 | - urllib3=1.26.15=pyhd8ed1ab_0 176 | - wasabi=0.10.1=pyhd8ed1ab_1 177 | - webencodings=0.5.1=py37_1 178 | - werkzeug=2.2.3=pyhd8ed1ab_0 179 | - wheel=0.37.1=pyhd3eb1b0_0 180 | - xz=5.2.6=h5eee18b_0 181 | - yarl=1.8.1=py37h5eee18b_0 182 | - zeromq=4.3.4=h2531618_0 183 | - zipp=3.8.1=pyhd8ed1ab_0 184 | - zlib=1.2.12=h5eee18b_3 185 | - zstd=1.4.9=haebb681_0 186 | - pip: 187 | - cachetools==5.3.1 188 | - einops==0.6.1 189 | - ftfy==6.1.1 190 | - gdown==4.7.1 191 | - google-auth==2.22.0 192 | - google-auth-oauthlib==0.4.6 193 | - grpcio==1.57.0 194 | - oauthlib==3.2.2 195 | - protobuf==3.20.3 196 | - pyasn1==0.5.0 197 | - pyasn1-modules==0.3.0 198 | - regex==2023.8.8 199 | - requests-oauthlib==1.3.1 200 | - rsa==4.9 201 | - tensorboard==2.11.2 202 | 
- tensorboard-data-server==0.6.1 203 | - wcwidth==0.2.6 204 | prefix: /home/chuan/anaconda3/envs/momask -------------------------------------------------------------------------------- /models/vq/residual_vq.py: -------------------------------------------------------------------------------- 1 | import random 2 | from math import ceil 3 | from functools import partial 4 | from itertools import zip_longest 5 | from random import randrange 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | # from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize 11 | from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA 12 | 13 | from einops import rearrange, repeat, pack, unpack 14 | 15 | # helper functions 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(val, d): 21 | return val if exists(val) else d 22 | 23 | def round_up_multiple(num, mult): 24 | return ceil(num / mult) * mult 25 | 26 | # main class 27 | 28 | class ResidualVQ(nn.Module): 29 | """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ 30 | def __init__( 31 | self, 32 | num_quantizers, 33 | shared_codebook=False, 34 | quantize_dropout_prob=0.5, 35 | quantize_dropout_cutoff_index=0, 36 | **kwargs 37 | ): 38 | super().__init__() 39 | 40 | self.num_quantizers = num_quantizers 41 | 42 | # self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)]) 43 | if shared_codebook: 44 | layer = QuantizeEMAReset(**kwargs) 45 | self.layers = nn.ModuleList([layer for _ in range(num_quantizers)]) 46 | else: 47 | self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)]) 48 | # self.layers = nn.ModuleList([QuantizeEMA(**kwargs) for _ in range(num_quantizers)]) 49 | 50 | # self.quantize_dropout = quantize_dropout and num_quantizers > 1 51 | 52 | assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0 53 | 54 | self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index 55 | self.quantize_dropout_prob = quantize_dropout_prob 56 | 57 | 58 | @property 59 | def codebooks(self): 60 | codebooks = [layer.codebook for layer in self.layers] 61 | codebooks = torch.stack(codebooks, dim = 0) 62 | return codebooks # 'q c d' 63 | 64 | def get_codes_from_indices(self, indices): #indices shape 'b n q' # dequantize 65 | 66 | batch, quantize_dim = indices.shape[0], indices.shape[-1] 67 | 68 | # because of quantize dropout, one can pass in indices that are coarse 69 | # and the network should be able to reconstruct 70 | 71 | if quantize_dim < self.num_quantizers: 72 | indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1) 73 | 74 | # get ready for gathering 75 | 76 | codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch) 77 | gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1]) 78 | 79 | # take care of quantizer dropout 80 | 81 | mask = gather_indices == -1. 82 | gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later 83 | 84 | # print(gather_indices.max(), gather_indices.min()) 85 | all_codes = codebooks.gather(2, gather_indices) # gather all codes 86 | 87 | # mask out any codes that were dropout-ed 88 | 89 | all_codes = all_codes.masked_fill(mask, 0.) 
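        # Note: indices equal to -1 mark quantizer layers that were dropped out (or
        # simply not provided); their gathered "dummy" codes are zeroed just above, so
        # summing over the quantizer axis in get_codebook_entry() still reconstructs
        # the latent from whichever layers are present.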
90 | 91 | return all_codes # 'q b n d' 92 | 93 | def get_codebook_entry(self, indices): #indices shape 'b n q' 94 | all_codes = self.get_codes_from_indices(indices) #'q b n d' 95 | latent = torch.sum(all_codes, dim=0) #'b n d' 96 | latent = latent.permute(0, 2, 1) 97 | return latent 98 | 99 | def forward(self, x, return_all_codes = False, sample_codebook_temp = None, force_dropout_index=-1): 100 | # debug check 101 | # print(self.codebooks[:,0,0].detach().cpu().numpy()) 102 | num_quant, quant_dropout_prob, device = self.num_quantizers, self.quantize_dropout_prob, x.device 103 | 104 | quantized_out = 0. 105 | residual = x 106 | 107 | all_losses = [] 108 | all_indices = [] 109 | all_perplexity = [] 110 | 111 | 112 | should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob 113 | 114 | start_drop_quantize_index = num_quant 115 | # To ensure the first-k layers learn things as much as possible, we randomly dropout the last q - k layers 116 | if should_quantize_dropout: 117 | start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant) # keep quant layers <= quantize_dropout_cutoff_index, TODO vary in batch 118 | null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n' 119 | null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long) 120 | # null_loss = 0. 121 | 122 | if force_dropout_index >= 0: 123 | should_quantize_dropout = True 124 | start_drop_quantize_index = force_dropout_index 125 | null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n' 126 | null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long) 127 | 128 | # print(force_dropout_index) 129 | # go through the layers 130 | 131 | for quantizer_index, layer in enumerate(self.layers): 132 | 133 | if should_quantize_dropout and quantizer_index > start_drop_quantize_index: 134 | all_indices.append(null_indices) 135 | # all_losses.append(null_loss) 136 | continue 137 | 138 | # layer_indices = None 139 | # if return_loss: 140 | # layer_indices = indices[..., quantizer_index] #gt indices 141 | 142 | # quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) #single quantizer TODO 143 | quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp) #single quantizer 144 | 145 | # print(quantized.shape, residual.shape) 146 | residual -= quantized.detach() 147 | quantized_out += quantized 148 | 149 | embed_indices, loss, perplexity = rest 150 | all_indices.append(embed_indices) 151 | all_losses.append(loss) 152 | all_perplexity.append(perplexity) 153 | 154 | 155 | # stack all losses and indices 156 | all_indices = torch.stack(all_indices, dim=-1) 157 | all_losses = sum(all_losses)/len(all_losses) 158 | all_perplexity = sum(all_perplexity)/len(all_perplexity) 159 | 160 | ret = (quantized_out, all_indices, all_losses, all_perplexity) 161 | 162 | if return_all_codes: 163 | # whether to return all codes from all codebooks across layers 164 | all_codes = self.get_codes_from_indices(all_indices) 165 | 166 | # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) 167 | ret = (*ret, all_codes) 168 | 169 | return ret 170 | 171 | def quantize(self, x, return_latent=False): 172 | all_indices = [] 173 | quantized_out = 0. 
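        # Residual quantization: each layer quantizes what the previous layers could not
        # represent (the residual), and the reconstructed latent is the running sum of the
        # per-layer codes, mirroring Algorithm 1 of the paper referenced in the class docstring.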
174 | residual = x 175 | all_codes = [] 176 | for quantizer_index, layer in enumerate(self.layers): 177 | 178 | quantized, *rest = layer(residual, return_idx=True) #single quantizer 179 | 180 | residual = residual - quantized.detach() 181 | quantized_out = quantized_out + quantized 182 | 183 | embed_indices, loss, perplexity = rest 184 | all_indices.append(embed_indices) 185 | # print(quantizer_index, embed_indices[0]) 186 | # print(quantizer_index, quantized[0]) 187 | # break 188 | all_codes.append(quantized) 189 | 190 | code_idx = torch.stack(all_indices, dim=-1) 191 | all_codes = torch.stack(all_codes, dim=0) 192 | if return_latent: 193 | return code_idx, all_codes 194 | return code_idx -------------------------------------------------------------------------------- /train_res_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import DataLoader 6 | from os.path import join as pjoin 7 | 8 | from models.mask_transformer.transformer import ResidualTransformer 9 | from models.mask_transformer.transformer_trainer import ResidualTransformerTrainer 10 | from models.vq.model import RVQVAE 11 | 12 | from options.train_option import TrainT2MOptions 13 | 14 | from utils.plot_script import plot_3d_motion 15 | from utils.motion_process import recover_from_ric 16 | from utils.get_opt import get_opt 17 | from utils.fixseed import fixseed 18 | from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain 19 | 20 | from data.t2m_dataset import Text2MotionDataset 21 | from motion_loaders.dataset_motion_loader import get_dataset_motion_loader 22 | from models.t2m_eval_wrapper import EvaluatorModelWrapper 23 | 24 | 25 | def plot_t2m(data, save_dir, captions, m_lengths): 26 | data = train_dataset.inv_transform(data) 27 | 28 | # print(ep_curves.shape) 29 | for i, (caption, joint_data) in enumerate(zip(captions, data)): 30 | joint_data = joint_data[:m_lengths[i]] 31 | joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy() 32 | save_path = pjoin(save_dir, '%02d.mp4'%i) 33 | # print(joint.shape) 34 | plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=fps, radius=radius) 35 | 36 | def load_vq_model(): 37 | opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt') 38 | vq_opt = get_opt(opt_path, opt.device) 39 | vq_model = RVQVAE(vq_opt, 40 | dim_pose, 41 | vq_opt.nb_code, 42 | vq_opt.code_dim, 43 | vq_opt.output_emb_width, 44 | vq_opt.down_t, 45 | vq_opt.stride_t, 46 | vq_opt.width, 47 | vq_opt.depth, 48 | vq_opt.dilation_growth_rate, 49 | vq_opt.vq_act, 50 | vq_opt.vq_norm) 51 | ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'), 52 | map_location=opt.device) 53 | model_key = 'vq_model' if 'vq_model' in ckpt else 'net' 54 | vq_model.load_state_dict(ckpt[model_key]) 55 | print(f'Loading VQ Model {opt.vq_name}') 56 | vq_model.to(opt.device) 57 | return vq_model, vq_opt 58 | 59 | if __name__ == '__main__': 60 | parser = TrainT2MOptions() 61 | opt = parser.parse() 62 | fixseed(opt.seed) 63 | 64 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 65 | torch.autograd.set_detect_anomaly(True) 66 | 67 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 68 | opt.model_dir = pjoin(opt.save_root, 'model') 69 | # opt.meta_dir = pjoin(opt.save_root, 'meta') 70 | opt.eval_dir = pjoin(opt.save_root, 
'animation') 71 | opt.log_dir = pjoin('./log/res/', opt.dataset_name, opt.name) 72 | 73 | os.makedirs(opt.model_dir, exist_ok=True) 74 | # os.makedirs(opt.meta_dir, exist_ok=True) 75 | os.makedirs(opt.eval_dir, exist_ok=True) 76 | os.makedirs(opt.log_dir, exist_ok=True) 77 | 78 | if opt.dataset_name == 't2m': 79 | opt.data_root = './dataset/HumanML3D' 80 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 81 | opt.joints_num = 22 82 | opt.max_motion_len = 55 83 | dim_pose = 263 84 | radius = 4 85 | fps = 20 86 | kinematic_chain = t2m_kinematic_chain 87 | dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt' 88 | 89 | elif opt.dataset_name == 'kit': #TODO 90 | opt.data_root = './dataset/KIT-ML' 91 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 92 | opt.joints_num = 21 93 | radius = 240 * 8 94 | fps = 12.5 95 | dim_pose = 251 96 | opt.max_motion_len = 55 97 | kinematic_chain = kit_kinematic_chain 98 | dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt' 99 | 100 | else: 101 | raise KeyError('Dataset Does Not Exist') 102 | 103 | opt.text_dir = pjoin(opt.data_root, 'texts') 104 | 105 | vq_model, vq_opt = load_vq_model() 106 | 107 | clip_version = 'ViT-B/32' 108 | 109 | opt.num_tokens = vq_opt.nb_code 110 | opt.num_quantizers = vq_opt.num_quantizers 111 | 112 | # if opt.is_v2: 113 | res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim, 114 | cond_mode='text', 115 | latent_dim=opt.latent_dim, 116 | ff_size=opt.ff_size, 117 | num_layers=opt.n_layers, 118 | num_heads=opt.n_heads, 119 | dropout=opt.dropout, 120 | clip_dim=512, 121 | shared_codebook=vq_opt.shared_codebook, 122 | cond_drop_prob=opt.cond_drop_prob, 123 | # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None, 124 | share_weight=opt.share_weight, 125 | clip_version=clip_version, 126 | opt=opt) 127 | # else: 128 | # res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim, 129 | # cond_mode='text', 130 | # latent_dim=opt.latent_dim, 131 | # ff_size=opt.ff_size, 132 | # num_layers=opt.n_layers, 133 | # num_heads=opt.n_heads, 134 | # dropout=opt.dropout, 135 | # clip_dim=512, 136 | # shared_codebook=vq_opt.shared_codebook, 137 | # cond_drop_prob=opt.cond_drop_prob, 138 | # # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None, 139 | # clip_version=clip_version, 140 | # opt=opt) 141 | 142 | 143 | all_params = 0 144 | pc_transformer = sum(param.numel() for param in res_transformer.parameters_wo_clip()) 145 | 146 | print(res_transformer) 147 | # print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000)) 148 | all_params += pc_transformer 149 | 150 | print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000)) 151 | 152 | mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy')) 153 | std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy')) 154 | 155 | train_split_file = pjoin(opt.data_root, 'train.txt') 156 | val_split_file = pjoin(opt.data_root, 'val.txt') 157 | 158 | train_dataset = Text2MotionDataset(opt, mean, std, train_split_file) 159 | val_dataset = Text2MotionDataset(opt, mean, std, val_split_file) 160 | 161 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True) 162 | val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True) 163 | 164 | eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', 
device=opt.device) 165 | 166 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 167 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 168 | 169 | trainer = ResidualTransformerTrainer(opt, res_transformer, vq_model) 170 | 171 | trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m) -------------------------------------------------------------------------------- /edit_t2m.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer 8 | from models.vq.model import RVQVAE, LengthEstimator 9 | 10 | from options.eval_option import EvalT2MOptions 11 | from utils.get_opt import get_opt 12 | 13 | from utils.fixseed import fixseed 14 | from visualization.joints2bvh import Joint2BVHConvertor 15 | 16 | from utils.motion_process import recover_from_ric 17 | from utils.plot_script import plot_3d_motion 18 | 19 | from utils.paramUtil import t2m_kinematic_chain 20 | 21 | import numpy as np 22 | 23 | from gen_t2m import load_vq_model, load_res_model, load_trans_model 24 | 25 | if __name__ == '__main__': 26 | parser = EvalT2MOptions() 27 | opt = parser.parse() 28 | fixseed(opt.seed) 29 | 30 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 31 | torch.autograd.set_detect_anomaly(True) 32 | 33 | dim_pose = 251 if opt.dataset_name == 'kit' else 263 34 | 35 | root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 36 | model_dir = pjoin(root_dir, 'model') 37 | result_dir = pjoin('./editing', opt.ext) 38 | joints_dir = pjoin(result_dir, 'joints') 39 | animation_dir = pjoin(result_dir, 'animations') 40 | os.makedirs(joints_dir, exist_ok=True) 41 | os.makedirs(animation_dir,exist_ok=True) 42 | 43 | model_opt_path = pjoin(root_dir, 'opt.txt') 44 | model_opt = get_opt(model_opt_path, device=opt.device) 45 | 46 | ####################### 47 | ######Loading RVQ###### 48 | ####################### 49 | vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt') 50 | vq_opt = get_opt(vq_opt_path, device=opt.device) 51 | vq_opt.dim_pose = dim_pose 52 | vq_model, vq_opt = load_vq_model(vq_opt) 53 | 54 | model_opt.num_tokens = vq_opt.nb_code 55 | model_opt.num_quantizers = vq_opt.num_quantizers 56 | model_opt.code_dim = vq_opt.code_dim 57 | 58 | ################################# 59 | ######Loading R-Transformer###### 60 | ################################# 61 | res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt') 62 | res_opt = get_opt(res_opt_path, device=opt.device) 63 | res_model = load_res_model(res_opt, vq_opt, opt) 64 | 65 | assert res_opt.vq_name == model_opt.vq_name 66 | 67 | ################################# 68 | ######Loading M-Transformer###### 69 | ################################# 70 | t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar') 71 | 72 | t2m_transformer.eval() 73 | vq_model.eval() 74 | res_model.eval() 75 | 76 | res_model.to(opt.device) 77 | t2m_transformer.to(opt.device) 78 | vq_model.to(opt.device) 79 | 80 | ##### ---- Data ---- ##### 81 | max_motion_length = 196 82 | mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy')) 83 | std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy')) 84 | def inv_transform(data): 85 
| return data * std + mean 86 | ### We provided an example source motion (from 'new_joint_vecs') for editing. See './example_data/000612.mp4'### 87 | motion = np.load(opt.source_motion) 88 | m_length = len(motion) 89 | motion = (motion - mean) / std 90 | if max_motion_length > m_length: 91 | motion = np.concatenate([motion, np.zeros((max_motion_length - m_length, motion.shape[1])) ], axis=0) 92 | motion = torch.from_numpy(motion)[None].to(opt.device) 93 | 94 | prompt_list = [] 95 | length_list = [] 96 | if opt.motion_length == 0: 97 | opt.motion_length = m_length 98 | print("Using default motion length.") 99 | 100 | prompt_list.append(opt.text_prompt) 101 | length_list.append(opt.motion_length) 102 | if opt.text_prompt == "": 103 | raise "Using an empty text prompt." 104 | 105 | token_lens = torch.LongTensor(length_list) // 4 106 | token_lens = token_lens.to(opt.device).long() 107 | 108 | m_length = token_lens * 4 109 | captions = prompt_list 110 | print_captions = captions[0] 111 | 112 | _edit_slice = opt.mask_edit_section 113 | edit_slice = [] 114 | for eds in _edit_slice: 115 | _start, _end = eds.split(',') 116 | _start = eval(_start) 117 | _end = eval(_end) 118 | edit_slice.append([_start, _end]) 119 | 120 | sample = 0 121 | kinematic_chain = t2m_kinematic_chain 122 | converter = Joint2BVHConvertor() 123 | 124 | with torch.no_grad(): 125 | tokens, features = vq_model.encode(motion) 126 | ### build editing mask, TOEDIT marked as 1 ### 127 | edit_mask = torch.zeros_like(tokens[..., 0]) 128 | seq_len = tokens.shape[1] 129 | for _start, _end in edit_slice: 130 | if isinstance(_start, float): 131 | _start = int(_start*seq_len) 132 | _end = int(_end*seq_len) 133 | else: 134 | _start //= 4 135 | _end //= 4 136 | edit_mask[:, _start: _end] = 1 137 | print_captions = f'{print_captions} [{_start*4/20.}s - {_end*4/20.}s]' 138 | edit_mask = edit_mask.bool() 139 | for r in range(opt.repeat_times): 140 | print("-->Repeat %d"%r) 141 | with torch.no_grad(): 142 | mids = t2m_transformer.edit( 143 | captions, tokens[..., 0].clone(), m_length//4, 144 | timesteps=opt.time_steps, 145 | cond_scale=opt.cond_scale, 146 | temperature=opt.temperature, 147 | topk_filter_thres=opt.topkr, 148 | gsample=opt.gumbel_sample, 149 | force_mask=opt.force_mask, 150 | edit_mask=edit_mask.clone(), 151 | ) 152 | if opt.use_res_model: 153 | mids = res_model.generate(mids, captions, m_length//4, temperature=1, cond_scale=2) 154 | else: 155 | mids.unsqueeze_(-1) 156 | 157 | pred_motions = vq_model.forward_decoder(mids) 158 | 159 | pred_motions = pred_motions.detach().cpu().numpy() 160 | 161 | source_motions = motion.detach().cpu().numpy() 162 | 163 | data = inv_transform(pred_motions) 164 | source_data = inv_transform(source_motions) 165 | 166 | for k, (caption, joint_data, source_data) in enumerate(zip(captions, data, source_data)): 167 | print("---->Sample %d: %s %d"%(k, caption, m_length[k])) 168 | animation_path = pjoin(animation_dir, str(k)) 169 | joint_path = pjoin(joints_dir, str(k)) 170 | 171 | os.makedirs(animation_path, exist_ok=True) 172 | os.makedirs(joint_path, exist_ok=True) 173 | 174 | joint_data = joint_data[:m_length[k]] 175 | joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy() 176 | 177 | source_data = source_data[:m_length[k]] 178 | soucre_joint = recover_from_ric(torch.from_numpy(source_data).float(), 22).numpy() 179 | 180 | bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k])) 181 | _, ik_joint = converter.convert(joint, filename=bvh_path, 
iterations=100) 182 | 183 | bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k])) 184 | _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False) 185 | 186 | 187 | save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k])) 188 | ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k])) 189 | source_save_path = pjoin(animation_path, "sample%d_source_len%d.mp4"%(k, m_length[k])) 190 | 191 | plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=print_captions, fps=20) 192 | plot_3d_motion(save_path, kinematic_chain, joint, title=print_captions, fps=20) 193 | plot_3d_motion(source_save_path, kinematic_chain, soucre_joint, title='None', fps=20) 194 | np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint) 195 | np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint) -------------------------------------------------------------------------------- /visualization/utils/bvh.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | channelmap = { 5 | 'Xrotation': 'x', 6 | 'Yrotation': 'y', 7 | 'Zrotation': 'z' 8 | } 9 | 10 | channelmap_inv = { 11 | 'x': 'Xrotation', 12 | 'y': 'Yrotation', 13 | 'z': 'Zrotation', 14 | } 15 | 16 | ordermap = { 17 | 'x': 0, 18 | 'y': 1, 19 | 'z': 2, 20 | } 21 | 22 | def load(filename:str, order:str=None) -> dict: 23 | """Loads a BVH file. 24 | 25 | Args: 26 | filename (str): Path to the BVH file. 27 | order (str): The order of the rotation channels. (i.e."xyz") 28 | 29 | Returns: 30 | dict: A dictionary containing the following keys: 31 | * names (list)(jnum): The names of the joints. 32 | * parents (list)(jnum): The parent indices. 33 | * offsets (np.ndarray)(jnum, 3): The offsets of the joints. 34 | * rotations (np.ndarray)(fnum, jnum, 3) : The local coordinates of rotations of the joints. 35 | * positions (np.ndarray)(fnum, jnum, 3) : The positions of the joints. 36 | * order (str): The order of the channels. 37 | * frametime (float): The time between two frames. 
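    Example (a minimal usage sketch; the file path is an assumption, not part of the repo):
        >>> data = load("walk.bvh")
        >>> data["rotations"].shape      # (num_frames, num_joints, 3) Euler angles in `order`
        >>> data["parents"], data["frametime"]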
38 | """ 39 | 40 | f = open(filename, "r") 41 | 42 | i = 0 43 | active = -1 44 | end_site = False 45 | 46 | # Create empty lists for saving parameters 47 | names = [] 48 | offsets = np.array([]).reshape((0, 3)) 49 | parents = np.array([], dtype=int) 50 | 51 | # Parse the file, line by line 52 | for line in f: 53 | 54 | if "HIERARCHY" in line: continue 55 | if "MOTION" in line: continue 56 | 57 | rmatch = re.match(r"ROOT (\w+)", line) 58 | if rmatch: 59 | names.append(rmatch.group(1)) 60 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 61 | parents = np.append(parents, active) 62 | active = (len(parents) - 1) 63 | continue 64 | 65 | if "{" in line: continue 66 | 67 | if "}" in line: 68 | if end_site: 69 | end_site = False 70 | else: 71 | active = parents[active] 72 | continue 73 | 74 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 75 | if offmatch: 76 | if not end_site: 77 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 78 | continue 79 | 80 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 81 | if chanmatch: 82 | channels = int(chanmatch.group(1)) 83 | if order is None: 84 | channelis = 0 if channels == 3 else 3 85 | channelie = 3 if channels == 3 else 6 86 | parts = line.split()[2 + channelis:2 + channelie] 87 | if any([p not in channelmap for p in parts]): 88 | continue 89 | order = "".join([channelmap[p] for p in parts]) 90 | continue 91 | 92 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 93 | if jmatch: 94 | names.append(jmatch.group(1)) 95 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 96 | parents = np.append(parents, active) 97 | active = (len(parents) - 1) 98 | continue 99 | 100 | if "End Site" in line: 101 | end_site = True 102 | continue 103 | 104 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 105 | if fmatch: 106 | fnum = int(fmatch.group(1)) 107 | positions = offsets[None].repeat(fnum, axis=0) 108 | rotations = np.zeros((fnum, len(offsets), 3)) 109 | continue 110 | 111 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 112 | if fmatch: 113 | frametime = float(fmatch.group(1)) 114 | continue 115 | 116 | dmatch = line.strip().split(' ') 117 | if dmatch: 118 | data_block = np.array(list(map(float, dmatch))) 119 | N = len(parents) 120 | fi = i 121 | if channels == 3: 122 | positions[fi, 0:1] = data_block[0:3] 123 | rotations[fi, :] = data_block[3:].reshape(N, 3) 124 | elif channels == 6: 125 | data_block = data_block.reshape(N, 6) 126 | positions[fi, :] = data_block[:, 0:3] 127 | rotations[fi, :] = data_block[:, 3:6] 128 | elif channels == 9: 129 | positions[fi, 0] = data_block[0:3] 130 | data_block = data_block[3:].reshape(N - 1, 9) 131 | rotations[fi, 1:] = data_block[:, 3:6] 132 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 133 | else: 134 | raise Exception("Too many channels! 
%i" % channels) 135 | 136 | i += 1 137 | 138 | f.close() 139 | 140 | return { 141 | 'rotations': rotations, 142 | 'positions': positions, 143 | 'offsets': offsets, 144 | 'parents': parents, 145 | 'names': names, 146 | 'order': order, 147 | 'frametime': frametime 148 | } 149 | 150 | 151 | def save_joint(f, data, t, i, save_order, order='zyx', save_positions=False): 152 | 153 | save_order.append(i) 154 | 155 | f.write("%sJOINT %s\n" % (t, data['names'][i])) 156 | f.write("%s{\n" % t) 157 | t += '\t' 158 | 159 | f.write("%sOFFSET %f %f %f\n" % (t, data['offsets'][i,0], data['offsets'][i,1], data['offsets'][i,2])) 160 | 161 | if save_positions: 162 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 163 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 164 | else: 165 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 166 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 167 | 168 | end_site = True 169 | 170 | for j in range(len(data['parents'])): 171 | if data['parents'][j] == i: 172 | t = save_joint(f, data, t, j, save_order, order=order, save_positions=save_positions) 173 | end_site = False 174 | 175 | if end_site: 176 | f.write("%sEnd Site\n" % t) 177 | f.write("%s{\n" % t) 178 | t += '\t' 179 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 180 | t = t[:-1] 181 | f.write("%s}\n" % t) 182 | 183 | t = t[:-1] 184 | f.write("%s}\n" % t) 185 | 186 | return t 187 | 188 | 189 | def save(filename, data, save_positions=False): 190 | """ Save a joint hierarchy to a file. 191 | 192 | Args: 193 | filename (str): The output will save on the bvh file. 194 | data (dict): The data to save.(rotations, positions, offsets, parents, names, order, frametime) 195 | save_positions (bool): Whether to save all of joint positions on MOTION. (False is recommended.) 
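    Example (a minimal round-trip sketch; the file names are assumptions):
        >>> data = load("input.bvh")
        >>> save("output.bvh", data, save_positions=False)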
196 | """ 197 | 198 | order = data['order'] 199 | frametime = data['frametime'] 200 | 201 | with open(filename, 'w') as f: 202 | 203 | t = "" 204 | f.write("%sHIERARCHY\n" % t) 205 | f.write("%sROOT %s\n" % (t, data['names'][0])) 206 | f.write("%s{\n" % t) 207 | t += '\t' 208 | 209 | f.write("%sOFFSET %f %f %f\n" % (t, data['offsets'][0,0], data['offsets'][0,1], data['offsets'][0,2]) ) 210 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 211 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 212 | 213 | save_order = [0] 214 | 215 | for i in range(len(data['parents'])): 216 | if data['parents'][i] == 0: 217 | t = save_joint(f, data, t, i, save_order, order=order, save_positions=save_positions) 218 | 219 | t = t[:-1] 220 | f.write("%s}\n" % t) 221 | 222 | rots, poss = data['rotations'], data['positions'] 223 | 224 | f.write("MOTION\n") 225 | f.write("Frames: %i\n" % len(rots)); 226 | f.write("Frame Time: %f\n" % frametime); 227 | 228 | for i in range(rots.shape[0]): 229 | for j in save_order: 230 | 231 | if save_positions or j == 0: 232 | 233 | f.write("%f %f %f %f %f %f " % ( 234 | poss[i,j,0], poss[i,j,1], poss[i,j,2], 235 | rots[i,j,0], rots[i,j,1], rots[i,j,2])) 236 | 237 | else: 238 | 239 | f.write("%f %f %f " % ( 240 | rots[i,j,0], rots[i,j,1], rots[i,j,2])) 241 | 242 | f.write("\n") -------------------------------------------------------------------------------- /models/t2m_eval_wrapper.py: -------------------------------------------------------------------------------- 1 | from models.t2m_eval_modules import * 2 | from utils.word_vectorizer import POS_enumerator 3 | from os.path import join as pjoin 4 | 5 | def build_models(opt): 6 | movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) 7 | text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, 8 | pos_size=opt.dim_pos_ohot, 9 | hidden_size=opt.dim_text_hidden, 10 | output_size=opt.dim_coemb_hidden, 11 | device=opt.device) 12 | 13 | motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, 14 | hidden_size=opt.dim_motion_hidden, 15 | output_size=opt.dim_coemb_hidden, 16 | device=opt.device) 17 | 18 | checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), 19 | map_location=opt.device) 20 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 21 | text_enc.load_state_dict(checkpoint['text_encoder']) 22 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 23 | print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) 24 | return text_enc, motion_enc, movement_enc 25 | 26 | 27 | class EvaluatorModelWrapper(object): 28 | 29 | def __init__(self, opt): 30 | 31 | if opt.dataset_name == 't2m': 32 | opt.dim_pose = 263 33 | elif opt.dataset_name == 'kit': 34 | opt.dim_pose = 251 35 | else: 36 | raise KeyError('Dataset not Recognized!!!') 37 | 38 | opt.dim_word = 300 39 | opt.max_motion_length = 196 40 | opt.dim_pos_ohot = len(POS_enumerator) 41 | opt.dim_motion_hidden = 1024 42 | opt.max_text_len = 20 43 | opt.dim_text_hidden = 512 44 | opt.dim_coemb_hidden = 512 45 | 46 | # print(opt) 47 | 48 | self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt) 49 | self.opt = opt 50 | self.device = opt.device 51 | 52 | self.text_encoder.to(opt.device) 53 | self.motion_encoder.to(opt.device) 54 | self.movement_encoder.to(opt.device) 55 | 56 | self.text_encoder.eval() 57 | self.motion_encoder.eval() 58 | self.movement_encoder.eval() 59 | 60 | # Please note that the results does not follow the order of inputs 61 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): 62 | with torch.no_grad(): 63 | word_embs = word_embs.detach().to(self.device).float() 64 | pos_ohot = pos_ohot.detach().to(self.device).float() 65 | motions = motions.detach().to(self.device).float() 66 | 67 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 68 | motions = motions[align_idx] 69 | m_lens = m_lens[align_idx] 70 | 71 | '''Movement Encoding''' 72 | movements = self.movement_encoder(motions[..., :-4]).detach() 73 | m_lens = m_lens // self.opt.unit_length 74 | motion_embedding = self.motion_encoder(movements, m_lens) 75 | 76 | '''Text Encoding''' 77 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) 78 | text_embedding = text_embedding[align_idx] 79 | return text_embedding, motion_embedding 80 | 81 | # Please note that the results does not follow the order of inputs 82 | def get_motion_embeddings(self, motions, m_lens): 83 | with torch.no_grad(): 84 | motions = motions.detach().to(self.device).float() 85 | 86 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 87 | motions = motions[align_idx] 88 | m_lens = m_lens[align_idx] 89 | 90 | '''Movement Encoding''' 91 | movements = self.movement_encoder(motions[..., :-4]).detach() 92 | m_lens = m_lens // self.opt.unit_length 93 | motion_embedding = self.motion_encoder(movements, m_lens) 94 | return motion_embedding 95 | 96 | ## Borrowed form MDM 97 | # our version 98 | def build_evaluators(opt): 99 | movement_enc = MovementConvEncoder(opt['dim_pose']-4, opt['dim_movement_enc_hidden'], opt['dim_movement_latent']) 100 | text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'], 101 | pos_size=opt['dim_pos_ohot'], 102 | hidden_size=opt['dim_text_hidden'], 103 | output_size=opt['dim_coemb_hidden'], 104 | device=opt['device']) 105 | 106 | motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'], 107 | hidden_size=opt['dim_motion_hidden'], 108 | output_size=opt['dim_coemb_hidden'], 109 | device=opt['device']) 110 | 111 | ckpt_dir = opt['dataset_name'] 112 | if opt['dataset_name'] == 'humanml': 113 | ckpt_dir = 't2m' 114 | 115 | checkpoint = torch.load(pjoin(opt['checkpoints_dir'], ckpt_dir, 'text_mot_match', 'model', 'finest.tar'), 116 | map_location=opt['device']) 117 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 118 | text_enc.load_state_dict(checkpoint['text_encoder']) 119 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 120 | print('Loading Evaluation Model 
Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch'])) 121 | return text_enc, motion_enc, movement_enc 122 | 123 | # our wrapper 124 | class EvaluatorWrapper(object): 125 | 126 | def __init__(self, dataset_name, device): 127 | opt = { 128 | 'dataset_name': dataset_name, 129 | 'device': device, 130 | 'dim_word': 300, 131 | 'max_motion_length': 196, 132 | 'dim_pos_ohot': len(POS_enumerator), 133 | 'dim_motion_hidden': 1024, 134 | 'max_text_len': 20, 135 | 'dim_text_hidden': 512, 136 | 'dim_coemb_hidden': 512, 137 | 'dim_pose': 263 if dataset_name == 'humanml' else 251, 138 | 'dim_movement_enc_hidden': 512, 139 | 'dim_movement_latent': 512, 140 | 'checkpoints_dir': './checkpoints', 141 | 'unit_length': 4, 142 | } 143 | 144 | self.text_encoder, self.motion_encoder, self.movement_encoder = build_evaluators(opt) 145 | self.opt = opt 146 | self.device = opt['device'] 147 | 148 | self.text_encoder.to(opt['device']) 149 | self.motion_encoder.to(opt['device']) 150 | self.movement_encoder.to(opt['device']) 151 | 152 | self.text_encoder.eval() 153 | self.motion_encoder.eval() 154 | self.movement_encoder.eval() 155 | 156 | # Please note that the results does not following the order of inputs 157 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): 158 | with torch.no_grad(): 159 | word_embs = word_embs.detach().to(self.device).float() 160 | pos_ohot = pos_ohot.detach().to(self.device).float() 161 | motions = motions.detach().to(self.device).float() 162 | 163 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 164 | motions = motions[align_idx] 165 | m_lens = m_lens[align_idx] 166 | 167 | '''Movement Encoding''' 168 | movements = self.movement_encoder(motions[..., :-4]).detach() 169 | m_lens = m_lens // self.opt['unit_length'] 170 | motion_embedding = self.motion_encoder(movements, m_lens) 171 | # print(motions.shape, movements.shape, motion_embedding.shape, m_lens) 172 | 173 | '''Text Encoding''' 174 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) 175 | text_embedding = text_embedding[align_idx] 176 | return text_embedding, motion_embedding 177 | 178 | # Please note that the results does not following the order of inputs 179 | def get_motion_embeddings(self, motions, m_lens): 180 | with torch.no_grad(): 181 | motions = motions.detach().to(self.device).float() 182 | 183 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 184 | motions = motions[align_idx] 185 | m_lens = m_lens[align_idx] 186 | 187 | '''Movement Encoding''' 188 | movements = self.movement_encoder(motions[..., :-4]).detach() 189 | m_lens = m_lens // self.opt['unit_length'] 190 | motion_embedding = self.motion_encoder(movements, m_lens) 191 | return motion_embedding -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | from common.quaternion import * 2 | import scipy.ndimage.filters as filters 3 | 4 | class Skeleton(object): 5 | def __init__(self, offset, kinematic_tree, device): 6 | self.device = device 7 | self._raw_offset_np = offset.numpy() 8 | self._raw_offset = offset.clone().detach().to(device).float() 9 | self._kinematic_tree = kinematic_tree 10 | self._offset = None 11 | self._parents = [0] * len(self._raw_offset) 12 | self._parents[0] = -1 13 | for chain in self._kinematic_tree: 14 | for j in range(1, len(chain)): 15 | self._parents[chain[j]] = chain[j-1] 16 | 17 | def njoints(self): 18 | return len(self._raw_offset) 19 | 20 | def 
offset(self): 21 | return self._offset 22 | 23 | def set_offset(self, offsets): 24 | self._offset = offsets.clone().detach().to(self.device).float() 25 | 26 | def kinematic_tree(self): 27 | return self._kinematic_tree 28 | 29 | def parents(self): 30 | return self._parents 31 | 32 | # joints (batch_size, joints_num, 3) 33 | def get_offsets_joints_batch(self, joints): 34 | assert len(joints.shape) == 3 35 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 36 | for i in range(1, self._raw_offset.shape[0]): 37 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 38 | 39 | self._offset = _offsets.detach() 40 | return _offsets 41 | 42 | # joints (joints_num, 3) 43 | def get_offsets_joints(self, joints): 44 | assert len(joints.shape) == 2 45 | _offsets = self._raw_offset.clone() 46 | for i in range(1, self._raw_offset.shape[0]): 47 | # print(joints.shape) 48 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 49 | 50 | self._offset = _offsets.detach() 51 | return _offsets 52 | 53 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 54 | # joints (batch_size, joints_num, 3) 55 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 56 | assert len(face_joint_idx) == 4 57 | '''Get Forward Direction''' 58 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 59 | across1 = joints[:, r_hip] - joints[:, l_hip] 60 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 61 | across = across1 + across2 62 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 63 | # print(across1.shape, across2.shape) 64 | 65 | # forward (batch_size, 3) 66 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 67 | if smooth_forward: 68 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 69 | # forward (batch_size, 3) 70 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 71 | 72 | '''Get Root Rotation''' 73 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 74 | root_quat = qbetween_np(forward, target) 75 | 76 | '''Inverse Kinematics''' 77 | # quat_params (batch_size, joints_num, 4) 78 | # print(joints.shape[:-1]) 79 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 80 | # print(quat_params.shape) 81 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 82 | quat_params[:, 0] = root_quat 83 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 84 | for chain in self._kinematic_tree: 85 | R = root_quat 86 | for j in range(len(chain) - 1): 87 | # (batch, 3) 88 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 89 | # print(u.shape) 90 | # (batch, 3) 91 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 92 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 93 | # print(u.shape, v.shape) 94 | rot_u_v = qbetween_np(u, v) 95 | 96 | R_loc = qmul_np(qinv_np(R), rot_u_v) 97 | 98 | quat_params[:,chain[j + 1], :] = R_loc 99 | R = qmul_np(R, R_loc) 100 | 101 | return quat_params 102 | 103 | # Be sure root joint is at the beginning of kinematic chains 104 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 105 | # quat_params (batch_size, joints_num, 4) 106 | # joints (batch_size, joints_num, 3) 107 | # root_pos (batch_size, 3) 108 | if skel_joints is not None: 109 | offsets = self.get_offsets_joints_batch(skel_joints) 110 | if len(self._offset.shape) == 2: 111 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 
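        # Forward kinematics: walk each kinematic chain from the root, accumulate the
        # joint rotations with quaternion multiplication, rotate the bone's rest offset
        # by the accumulated rotation, and add it to the parent joint's position.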
112 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 113 | joints[:, 0] = root_pos 114 | for chain in self._kinematic_tree: 115 | if do_root_R: 116 | R = quat_params[:, 0] 117 | else: 118 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 119 | for i in range(1, len(chain)): 120 | R = qmul(R, quat_params[:, chain[i]]) 121 | offset_vec = offsets[:, chain[i]] 122 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 123 | return joints 124 | 125 | # Be sure root joint is at the beginning of kinematic chains 126 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 127 | # quat_params (batch_size, joints_num, 4) 128 | # joints (batch_size, joints_num, 3) 129 | # root_pos (batch_size, 3) 130 | if skel_joints is not None: 131 | skel_joints = torch.from_numpy(skel_joints) 132 | offsets = self.get_offsets_joints_batch(skel_joints) 133 | if len(self._offset.shape) == 2: 134 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 135 | offsets = offsets.numpy() 136 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 137 | joints[:, 0] = root_pos 138 | for chain in self._kinematic_tree: 139 | if do_root_R: 140 | R = quat_params[:, 0] 141 | else: 142 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 143 | for i in range(1, len(chain)): 144 | R = qmul_np(R, quat_params[:, chain[i]]) 145 | offset_vec = offsets[:, chain[i]] 146 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 147 | return joints 148 | 149 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 150 | # cont6d_params (batch_size, joints_num, 6) 151 | # joints (batch_size, joints_num, 3) 152 | # root_pos (batch_size, 3) 153 | if skel_joints is not None: 154 | skel_joints = torch.from_numpy(skel_joints) 155 | offsets = self.get_offsets_joints_batch(skel_joints) 156 | if len(self._offset.shape) == 2: 157 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 158 | offsets = offsets.numpy() 159 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 160 | joints[:, 0] = root_pos 161 | for chain in self._kinematic_tree: 162 | if do_root_R: 163 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 164 | else: 165 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 166 | for i in range(1, len(chain)): 167 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 168 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 169 | # print(matR.shape, offset_vec.shape) 170 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 171 | return joints 172 | 173 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 174 | # cont6d_params (batch_size, joints_num, 6) 175 | # joints (batch_size, joints_num, 3) 176 | # root_pos (batch_size, 3) 177 | if skel_joints is not None: 178 | # skel_joints = torch.from_numpy(skel_joints) 179 | offsets = self.get_offsets_joints_batch(skel_joints) 180 | if len(self._offset.shape) == 2: 181 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 182 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 183 | joints[..., 0, :] = root_pos 184 | for chain in self._kinematic_tree: 185 | if do_root_R: 186 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 187 | else: 188 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 189 | for i in 
range(1, len(chain)): 190 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) 191 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 192 | # print(matR.shape, offset_vec.shape) 193 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 194 | return joints 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /eval_t2m_trans_res.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | 4 | import torch 5 | 6 | from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer 7 | from models.vq.model import RVQVAE 8 | 9 | from options.eval_option import EvalT2MOptions 10 | from utils.get_opt import get_opt 11 | from motion_loaders.dataset_motion_loader import get_dataset_motion_loader 12 | from models.t2m_eval_wrapper import EvaluatorModelWrapper 13 | 14 | import utils.eval_t2m as eval_t2m 15 | from utils.fixseed import fixseed 16 | 17 | import numpy as np 18 | 19 | def load_vq_model(vq_opt): 20 | # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt') 21 | vq_model = RVQVAE(vq_opt, 22 | dim_pose, 23 | vq_opt.nb_code, 24 | vq_opt.code_dim, 25 | vq_opt.output_emb_width, 26 | vq_opt.down_t, 27 | vq_opt.stride_t, 28 | vq_opt.width, 29 | vq_opt.depth, 30 | vq_opt.dilation_growth_rate, 31 | vq_opt.vq_act, 32 | vq_opt.vq_norm) 33 | ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'), 34 | map_location=opt.device) 35 | model_key = 'vq_model' if 'vq_model' in ckpt else 'net' 36 | vq_model.load_state_dict(ckpt[model_key]) 37 | print(f'Loading VQ Model {vq_opt.name} Completed!') 38 | return vq_model, vq_opt 39 | 40 | def load_trans_model(model_opt, which_model): 41 | t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim, 42 | cond_mode='text', 43 | latent_dim=model_opt.latent_dim, 44 | ff_size=model_opt.ff_size, 45 | num_layers=model_opt.n_layers, 46 | num_heads=model_opt.n_heads, 47 | dropout=model_opt.dropout, 48 | clip_dim=512, 49 | cond_drop_prob=model_opt.cond_drop_prob, 50 | clip_version=clip_version, 51 | opt=model_opt) 52 | ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model), 53 | map_location=opt.device) 54 | model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans' 55 | # print(ckpt.keys()) 56 | missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False) 57 | assert len(unexpected_keys) == 0 58 | assert all([k.startswith('clip_model.') for k in missing_keys]) 59 | print(f'Loading Mask Transformer {opt.name} from epoch {ckpt["ep"]}!') 60 | return t2m_transformer 61 | 62 | def load_res_model(res_opt): 63 | res_opt.num_quantizers = vq_opt.num_quantizers 64 | res_opt.num_tokens = vq_opt.nb_code 65 | res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim, 66 | cond_mode='text', 67 | latent_dim=res_opt.latent_dim, 68 | ff_size=res_opt.ff_size, 69 | num_layers=res_opt.n_layers, 70 | num_heads=res_opt.n_heads, 71 | dropout=res_opt.dropout, 72 | clip_dim=512, 73 | shared_codebook=vq_opt.shared_codebook, 74 | cond_drop_prob=res_opt.cond_drop_prob, 75 | # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None, 76 | share_weight=res_opt.share_weight, 77 | clip_version=clip_version, 78 | opt=res_opt) 79 | 80 | ckpt = torch.load(pjoin(res_opt.checkpoints_dir, 
res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'), 81 | map_location=opt.device) 82 | missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False) 83 | assert len(unexpected_keys) == 0 84 | assert all([k.startswith('clip_model.') for k in missing_keys]) 85 | print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!') 86 | return res_transformer 87 | 88 | if __name__ == '__main__': 89 | parser = EvalT2MOptions() 90 | opt = parser.parse() 91 | fixseed(opt.seed) 92 | 93 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 94 | torch.autograd.set_detect_anomaly(True) 95 | 96 | dim_pose = 251 if opt.dataset_name == 'kit' else 263 97 | 98 | # out_dir = pjoin(opt.check) 99 | root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 100 | model_dir = pjoin(root_dir, 'model') 101 | out_dir = pjoin(root_dir, 'eval') 102 | os.makedirs(out_dir, exist_ok=True) 103 | 104 | out_path = pjoin(out_dir, "%s.log"%opt.ext) 105 | 106 | f = open(pjoin(out_path), 'w') 107 | 108 | model_opt_path = pjoin(root_dir, 'opt.txt') 109 | model_opt = get_opt(model_opt_path, device=opt.device) 110 | clip_version = 'ViT-B/32' 111 | 112 | vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt') 113 | vq_opt = get_opt(vq_opt_path, device=opt.device) 114 | vq_model, vq_opt = load_vq_model(vq_opt) 115 | 116 | model_opt.num_tokens = vq_opt.nb_code 117 | model_opt.num_quantizers = vq_opt.num_quantizers 118 | model_opt.code_dim = vq_opt.code_dim 119 | 120 | res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt') 121 | res_opt = get_opt(res_opt_path, device=opt.device) 122 | res_model = load_res_model(res_opt) 123 | 124 | assert res_opt.vq_name == model_opt.vq_name 125 | 126 | dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if opt.dataset_name == 'kit' \ 127 | else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' 128 | 129 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 130 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 131 | 132 | ##### ---- Dataloader ---- ##### 133 | opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22 134 | 135 | eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=opt.device) 136 | 137 | # model_dir = pjoin(opt.) 
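    # Evaluate every checkpoint found in model_dir (or only the one matching --which_epoch);
    # each checkpoint is scored over `repeat_time` runs and the metrics are reported below
    # as a mean with a 95% confidence interval (1.96 * std / sqrt(repeat_time)).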
138 | for file in os.listdir(model_dir): 139 | if opt.which_epoch != "all" and opt.which_epoch not in file: 140 | continue 141 | print('loading checkpoint {}'.format(file)) 142 | t2m_transformer = load_trans_model(model_opt, file) 143 | t2m_transformer.eval() 144 | vq_model.eval() 145 | res_model.eval() 146 | 147 | t2m_transformer.to(opt.device) 148 | vq_model.to(opt.device) 149 | res_model.to(opt.device) 150 | 151 | fid = [] 152 | div = [] 153 | top1 = [] 154 | top2 = [] 155 | top3 = [] 156 | matching = [] 157 | mm = [] 158 | 159 | repeat_time = 20 160 | for i in range(repeat_time): 161 | with torch.no_grad(): 162 | best_fid, best_div, Rprecision, best_matching, best_mm = \ 163 | eval_t2m.evaluation_mask_transformer_test_plus_res(eval_val_loader, vq_model, res_model, t2m_transformer, 164 | i, eval_wrapper=eval_wrapper, 165 | time_steps=opt.time_steps, cond_scale=opt.cond_scale, 166 | temperature=opt.temperature, topkr=opt.topkr, 167 | force_mask=opt.force_mask, cal_mm=True) 168 | fid.append(best_fid) 169 | div.append(best_div) 170 | top1.append(Rprecision[0]) 171 | top2.append(Rprecision[1]) 172 | top3.append(Rprecision[2]) 173 | matching.append(best_matching) 174 | mm.append(best_mm) 175 | 176 | fid = np.array(fid) 177 | div = np.array(div) 178 | top1 = np.array(top1) 179 | top2 = np.array(top2) 180 | top3 = np.array(top3) 181 | matching = np.array(matching) 182 | mm = np.array(mm) 183 | 184 | print(f'{file} final result:') 185 | print(f'{file} final result:', file=f, flush=True) 186 | 187 | msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid) * 1.96 / np.sqrt(repeat_time):.3f}\n" \ 188 | f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div) * 1.96 / np.sqrt(repeat_time):.3f}\n" \ 189 | f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1) * 1.96 / np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2) * 1.96 / np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3) * 1.96 / np.sqrt(repeat_time):.3f}\n" \ 190 | f"\tMatching: {np.mean(matching):.3f}, conf. {np.std(matching) * 1.96 / np.sqrt(repeat_time):.3f}\n" \ 191 | f"\tMultimodality:{np.mean(mm):.3f}, conf.{np.std(mm) * 1.96 / np.sqrt(repeat_time):.3f}\n\n" 192 | # logger.info(msg_final) 193 | print(msg_final) 194 | print(msg_final, file=f, flush=True) 195 | 196 | f.close() 197 | 198 | 199 | # python eval_t2m_trans.py --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_vq --dataset_name t2m --gpu_id 3 --cond_scale 4 --time_steps 18 --temperature 1 --topkr 0.9 --gumbel_sample --ext cs4_ts18_tau1_topkr0.9_gs -------------------------------------------------------------------------------- /visualization/BVH.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from common.quaternion import * 4 | from visualization.Animation import Animation 5 | 6 | channelmap = { 7 | 'Xrotation': 'x', 8 | 'Yrotation': 'y', 9 | 'Zrotation': 'z' 10 | } 11 | 12 | channelmap_inv = { 13 | 'x': 'Xrotation', 14 | 'y': 'Yrotation', 15 | 'z': 'Zrotation', 16 | } 17 | 18 | ordermap = { 19 | 'x': 0, 20 | 'y': 1, 21 | 'z': 2, 22 | } 23 | 24 | def load(filename, start=None, end=None, world=False, need_quater=True): 25 | """ 26 | Reads a BVH file and constructs an animation 27 | Parameters 28 | ---------- 29 | filename: str 30 | File to be opened 31 | start : int 32 | Optional Starting Frame 33 | end : int 34 | Optional Ending Frame 35 | order : str 36 | Optional Specifier for joint order. 
37 | Given as string E.G 'xyz', 'zxy' 38 | world : bool 39 | If set to true euler angles are applied 40 | together in world space rather than local 41 | space 42 | Returns 43 | ------- 44 | (animation, joint_names, frametime) 45 | Tuple of loaded animation and joint names 46 | """ 47 | 48 | f = open(filename, "r") 49 | 50 | i = 0 51 | active = -1 52 | end_site = False 53 | 54 | names = [] 55 | orients = Quaterions.id(0) 56 | offsets = np.array([]).reshape((0, 3)) 57 | parents = np.array([], dtype=int) 58 | orders = [] 59 | 60 | for line in f: 61 | 62 | if "HIERARCHY" in line: continue 63 | if "MOTION" in line: continue 64 | 65 | # """ Modified line read to handle mixamo data """ 66 | rmatch = re.match(r"ROOT (\w+)", line) 67 | # rmatch = re.match(r"ROOT (\w+:?\w+)", line) 68 | if rmatch: 69 | names.append(rmatch.group(1)) 70 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 71 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 72 | parents = np.append(parents, active) 73 | active = (len(parents) - 1) 74 | continue 75 | 76 | if "{" in line: continue 77 | 78 | if "}" in line: 79 | if end_site: 80 | end_site = False 81 | else: 82 | active = parents[active] 83 | continue 84 | 85 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 86 | if offmatch: 87 | if not end_site: 88 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 89 | continue 90 | 91 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 92 | if chanmatch: 93 | channels = int(chanmatch.group(1)) 94 | 95 | channelis = 0 if channels == 3 else 3 96 | channelie = 3 if channels == 3 else 6 97 | parts = line.split()[2 + channelis:2 + channelie] 98 | if any([p not in channelmap for p in parts]): 99 | continue 100 | order = "".join([channelmap[p] for p in parts]) 101 | orders.append(order) 102 | continue 103 | 104 | # """ Modified line read to handle mixamo data """ 105 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 106 | # jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line) 107 | if jmatch: 108 | names.append(jmatch.group(1)) 109 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 110 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 111 | parents = np.append(parents, active) 112 | active = (len(parents) - 1) 113 | continue 114 | 115 | if "End Site" in line: 116 | end_site = True 117 | continue 118 | 119 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 120 | if fmatch: 121 | if start and end: 122 | fnum = (end - start) - 1 123 | else: 124 | fnum = int(fmatch.group(1)) 125 | jnum = len(parents) 126 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 127 | rotations = np.zeros((fnum, len(orients), 3)) 128 | continue 129 | 130 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 131 | if fmatch: 132 | frametime = float(fmatch.group(1)) 133 | continue 134 | 135 | if (start and end) and (i < start or i >= end - 1): 136 | i += 1 137 | continue 138 | 139 | # dmatch = line.strip().split(' ') 140 | dmatch = line.strip().split() 141 | if dmatch: 142 | data_block = np.array(list(map(float, dmatch))) 143 | N = len(parents) 144 | fi = i - start if start else i 145 | if channels == 3: 146 | positions[fi, 0:1] = data_block[0:3] 147 | rotations[fi, :] = data_block[3:].reshape(N, 3) 148 | elif channels == 6: 149 | data_block = data_block.reshape(N, 6) 150 | positions[fi, :] = data_block[:, 0:3] 151 | rotations[fi, :] = data_block[:, 3:6] 152 | elif channels == 9: 153 | positions[fi, 0] = data_block[0:3] 154 | data_block = data_block[3:].reshape(N - 1, 9) 
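                # 9-channel layout (a reading of the indexing here, not documented in the
                # source): per joint, columns 0:3 hold a position offset, 3:6 the rotation,
                # and 6:9 a per-axis factor applied to that offset.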
155 | rotations[fi, 1:] = data_block[:, 3:6] 156 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 157 | else: 158 | raise Exception("Too many channels! %i" % channels) 159 | 160 | i += 1 161 | 162 | f.close() 163 | 164 | all_rotations = [] 165 | canonical_order = 'xyz' 166 | for i, order in enumerate(orders): 167 | rot = rotations[:, i:i + 1] 168 | if need_quater: 169 | quat = euler_to_quat_np(np.radians(rot), order=order, world=world) 170 | all_rotations.append(quat) 171 | continue 172 | elif order != canonical_order: 173 | quat = euler_to_quat_np(np.radians(rot), order=order, world=world) 174 | rot = np.degrees(qeuler_np(quat, order=canonical_order)) 175 | all_rotations.append(rot) 176 | rotations = np.concatenate(all_rotations, axis=1) 177 | 178 | return Animation(rotations, positions, orients, offsets, parents, names, frametime) 179 | 180 | def write_bvh(parent, offset, rotation, rot_position, names, frametime, order, path, endsite=None): 181 | file = open(path, 'w') 182 | frame = rotation.shape[0] 183 | assert rotation.shape[-1] == 3 184 | joint_num = rotation.shape[1] 185 | order = order.upper() 186 | 187 | file_string = 'HIERARCHY\n' 188 | 189 | seq = [] 190 | 191 | def write_static(idx, prefix): 192 | nonlocal parent, offset, rotation, names, order, endsite, file_string, seq 193 | seq.append(idx) 194 | if idx == 0: 195 | name_label = 'ROOT ' + names[idx] 196 | channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format( 197 | *order) 198 | else: 199 | name_label = 'JOINT ' + names[idx] 200 | channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order) 201 | offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2]) 202 | 203 | file_string += prefix + name_label + '\n' 204 | file_string += prefix + '{\n' 205 | file_string += prefix + '\t' + offset_label + '\n' 206 | file_string += prefix + '\t' + channel_label + '\n' 207 | 208 | has_child = False 209 | for y in range(idx + 1, rotation.shape[1]): 210 | if parent[y] == idx: 211 | has_child = True 212 | write_static(y, prefix + '\t') 213 | if not has_child: 214 | file_string += prefix + '\t' + 'End Site\n' 215 | file_string += prefix + '\t' + '{\n' 216 | file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n' 217 | file_string += prefix + '\t' + '}\n' 218 | 219 | file_string += prefix + '}\n' 220 | 221 | write_static(0, '') 222 | 223 | file_string += 'MOTION\n' + 'Frames: {}\n'.format(frame) + 'Frame Time: %.8f\n' % frametime 224 | for i in range(frame): 225 | file_string += '%.6f %.6f %.6f ' % (rot_position[i][0], rot_position[i][1], 226 | rot_position[i][2]) 227 | for j in range(joint_num): 228 | idx = seq[j] 229 | file_string += '%.6f %.6f %.6f ' % (rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2]) 230 | file_string += '\n' 231 | 232 | file.write(file_string) 233 | return file_string 234 | 235 | class WriterWrapper: 236 | def __init__(self, parents, frametime, offset=None, names=None): 237 | self.parents = parents 238 | self.offset = offset 239 | self.frametime = frametime 240 | self.names = names 241 | 242 | def write(self, filename, rot, r_pos, order, offset=None, names=None, repr='quat'): 243 | """ 244 | Write animation to bvh file 245 | :param filename: 246 | :param rot: Quaternion as (w, x, y, z) 247 | :param pos: 248 | :param offset: 249 | :return: 250 | """ 251 | if repr not in ['euler', 'quat', 'quaternion', 'cont6d']: 252 | raise Exception('Unknown rotation representation') 253 | if offset is None: 254 | offset = 
self.offset 255 | if not isinstance(offset, torch.Tensor): 256 | offset = torch.tensor(offset) 257 | n_bone = offset.shape[0] 258 | 259 | if repr == 'cont6d': 260 | rot = rot.reshape(rot.shape[0], -1, 6) 261 | rot = cont6d_to_quat_np(rot) 262 | if repr == 'cont6d' or repr == 'quat' or repr == 'quaternion': 263 | # rot = rot.reshape(rot.shape[0], -1, 4) 264 | # rot /= rot.norm(dim=-1, keepdim=True) ** 0.5 265 | euler = qeuler_np(rot, order=order) 266 | rot = euler 267 | 268 | if names is None: 269 | if self.names is None: 270 | names = ['%02d' % i for i in range(n_bone)] 271 | else: 272 | names = self.names 273 | write_bvh(self.parents, offset, rot, r_pos, names, self.frametime, order, filename) -------------------------------------------------------------------------------- /visualization/BVH_mod.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | from visualization.Animation import Animation 5 | from visualization.Quaternions import Quaternions 6 | 7 | channelmap = { 8 | 'Xrotation': 'x', 9 | 'Yrotation': 'y', 10 | 'Zrotation': 'z' 11 | } 12 | 13 | channelmap_inv = { 14 | 'x': 'Xrotation', 15 | 'y': 'Yrotation', 16 | 'z': 'Zrotation', 17 | } 18 | 19 | ordermap = { 20 | 'x': 0, 21 | 'y': 1, 22 | 'z': 2, 23 | } 24 | 25 | 26 | def load(filename, start=None, end=None, order=None, world=False, need_quater=True): 27 | """ 28 | Reads a BVH file and constructs an animation 29 | 30 | Parameters 31 | ---------- 32 | filename: str 33 | File to be opened 34 | 35 | start : int 36 | Optional Starting Frame 37 | 38 | end : int 39 | Optional Ending Frame 40 | 41 | order : str 42 | Optional Specifier for joint order. 43 | Given as string E.G 'xyz', 'zxy' 44 | 45 | world : bool 46 | If set to true euler angles are applied 47 | together in world space rather than local 48 | space 49 | 50 | Returns 51 | ------- 52 | 53 | (animation, joint_names, frametime) 54 | Tuple of loaded animation and joint names 55 | """ 56 | 57 | f = open(filename, "r") 58 | 59 | i = 0 60 | active = -1 61 | end_site = False 62 | 63 | names = [] 64 | orients = Quaternions.id(0) 65 | offsets = np.array([]).reshape((0, 3)) 66 | parents = np.array([], dtype=int) 67 | 68 | for line in f: 69 | 70 | if "HIERARCHY" in line: continue 71 | if "MOTION" in line: continue 72 | 73 | """ Modified line read to handle mixamo data """ 74 | # rmatch = re.match(r"ROOT (\w+)", line) 75 | rmatch = re.match(r"ROOT (\w+:?\w+)", line) 76 | if rmatch: 77 | names.append(rmatch.group(1)) 78 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 79 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 80 | parents = np.append(parents, active) 81 | active = (len(parents) - 1) 82 | continue 83 | 84 | if "{" in line: continue 85 | 86 | if "}" in line: 87 | if end_site: 88 | end_site = False 89 | else: 90 | active = parents[active] 91 | continue 92 | 93 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 94 | if offmatch: 95 | if not end_site: 96 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 97 | continue 98 | 99 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 100 | if chanmatch: 101 | channels = int(chanmatch.group(1)) 102 | if order is None: 103 | channelis = 0 if channels == 3 else 3 104 | channelie = 3 if channels == 3 else 6 105 | parts = line.split()[2 + channelis:2 + channelie] 106 | if any([p not in channelmap for p in parts]): 107 | continue 108 | order = "".join([channelmap[p] for p in parts]) 
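                # Rotation channel order (e.g. 'zyx') inferred from the CHANNELS line
                # when the caller did not pass `order` explicitly.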
109 | continue 110 | 111 | """ Modified line read to handle mixamo data """ 112 | # jmatch = re.match("\s*JOINT\s+(\w+)", line) 113 | jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line) 114 | if jmatch: 115 | names.append(jmatch.group(1)) 116 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 117 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 118 | parents = np.append(parents, active) 119 | active = (len(parents) - 1) 120 | continue 121 | 122 | if "End Site" in line: 123 | end_site = True 124 | continue 125 | 126 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 127 | if fmatch: 128 | if start and end: 129 | fnum = (end - start) - 1 130 | else: 131 | fnum = int(fmatch.group(1)) 132 | jnum = len(parents) 133 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 134 | rotations = np.zeros((fnum, len(orients), 3)) 135 | continue 136 | 137 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 138 | if fmatch: 139 | frametime = float(fmatch.group(1)) 140 | continue 141 | 142 | if (start and end) and (i < start or i >= end - 1): 143 | i += 1 144 | continue 145 | 146 | # dmatch = line.strip().split(' ') 147 | dmatch = line.strip().split() 148 | if dmatch: 149 | data_block = np.array(list(map(float, dmatch))) 150 | N = len(parents) 151 | fi = i - start if start else i 152 | if channels == 3: 153 | positions[fi, 0:1] = data_block[0:3] 154 | rotations[fi, :] = data_block[3:].reshape(N, 3) 155 | elif channels == 6: 156 | data_block = data_block.reshape(N, 6) 157 | positions[fi, :] = data_block[:, 0:3] 158 | rotations[fi, :] = data_block[:, 3:6] 159 | elif channels == 9: 160 | positions[fi, 0] = data_block[0:3] 161 | data_block = data_block[3:].reshape(N - 1, 9) 162 | rotations[fi, 1:] = data_block[:, 3:6] 163 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 164 | else: 165 | raise Exception("Too many channels! %i" % channels) 166 | 167 | i += 1 168 | 169 | f.close() 170 | 171 | if need_quater: 172 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 173 | elif order != 'xyz': 174 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 175 | rotations = np.degrees(rotations.euler()) 176 | 177 | return Animation(rotations, positions, orients, offsets, parents, names, frametime) 178 | 179 | 180 | def save(filename, anim, names=None, frametime=1.0 / 24.0, order='zyx', positions=False, mask=None, quater=False): 181 | """ 182 | Saves an Animation to file as BVH 183 | 184 | Parameters 185 | ---------- 186 | filename: str 187 | File to be saved to 188 | 189 | anim : Animation 190 | Animation to save 191 | 192 | names : [str] 193 | List of joint names 194 | 195 | order : str 196 | Optional Specifier for joint order. 197 | Given as string E.G 'xyz', 'zxy' 198 | 199 | frametime : float 200 | Optional Animation Frame time 201 | 202 | positions : bool 203 | Optional specfier to save bone 204 | positions for each frame 205 | 206 | orients : bool 207 | Multiply joint orients to the rotations 208 | before saving. 
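mask : (J) ndarray, optional
    Per-joint flags; joints whose entry is not 1
    are written with zero rotation

quater : bool
    If true the rotations in `anim` are Quaternions
    and are converted to Euler degrees before
    writing, otherwise they are written as given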
209 | 210 | """ 211 | 212 | if names is None: 213 | names = ["joint_" + str(i) for i in range(len(anim.parents))] 214 | 215 | with open(filename, 'w') as f: 216 | 217 | t = "" 218 | f.write("%sHIERARCHY\n" % t) 219 | f.write("%sROOT %s\n" % (t, names[0])) 220 | f.write("%s{\n" % t) 221 | t += '\t' 222 | 223 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0, 0], anim.offsets[0, 1], anim.offsets[0, 2])) 224 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 225 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 226 | 227 | for i in range(anim.shape[1]): 228 | if anim.parents[i] == 0: 229 | t = save_joint(f, anim, names, t, i, order=order, positions=positions) 230 | 231 | t = t[:-1] 232 | f.write("%s}\n" % t) 233 | 234 | f.write("MOTION\n") 235 | f.write("Frames: %i\n" % anim.shape[0]); 236 | f.write("Frame Time: %f\n" % frametime); 237 | 238 | # if orients: 239 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1])) 240 | # else: 241 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 242 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 243 | if quater: 244 | rots = np.degrees(anim.rotations.euler(order=order[::-1])) 245 | else: 246 | rots = anim.rotations 247 | poss = anim.positions 248 | 249 | for i in range(anim.shape[0]): 250 | for j in range(anim.shape[1]): 251 | 252 | if positions or j == 0: 253 | 254 | f.write("%f %f %f %f %f %f " % ( 255 | poss[i, j, 0], poss[i, j, 1], poss[i, j, 2], 256 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], rots[i, j, ordermap[order[2]]])) 257 | 258 | else: 259 | if mask == None or mask[j] == 1: 260 | f.write("%f %f %f " % ( 261 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], 262 | rots[i, j, ordermap[order[2]]])) 263 | else: 264 | f.write("%f %f %f " % (0, 0, 0)) 265 | 266 | f.write("\n") 267 | 268 | 269 | def save_joint(f, anim, names, t, i, order='zyx', positions=False): 270 | f.write("%sJOINT %s\n" % (t, names[i])) 271 | f.write("%s{\n" % t) 272 | t += '\t' 273 | 274 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i, 0], anim.offsets[i, 1], anim.offsets[i, 2])) 275 | 276 | if positions: 277 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 278 | channelmap_inv[order[0]], 279 | channelmap_inv[order[1]], 280 | channelmap_inv[order[2]])) 281 | else: 282 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 283 | channelmap_inv[order[0]], channelmap_inv[order[1]], 284 | channelmap_inv[order[2]])) 285 | 286 | end_site = True 287 | 288 | for j in range(anim.shape[1]): 289 | if anim.parents[j] == i: 290 | t = save_joint(f, anim, names, t, j, order=order, positions=positions) 291 | end_site = False 292 | 293 | if end_site: 294 | f.write("%sEnd Site\n" % t) 295 | f.write("%s{\n" % t) 296 | t += '\t' 297 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 298 | t = t[:-1] 299 | f.write("%s}\n" % t) 300 | 301 | t = t[:-1] 302 | f.write("%s}\n" % t) 303 | 304 | return t -------------------------------------------------------------------------------- /visualization/AnimationStructure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import scipy.sparse as sparse 3 | import visualization.Animation as Animation 4 | 5 | 6 | """ Family Functions """ 7 | 8 | 9 | def joints(parents): 10 | """ 11 | Parameters 12 | ---------- 13 | 14 | parents : (J) ndarray 15 | parents array 16 | 17 | Returns 18 | ------- 19 | 20 | joints : (J) ndarray 21 | 
Array of joint indices 22 | """ 23 | return np.arange(len(parents), dtype=int) 24 | 25 | 26 | def joints_list(parents): 27 | """ 28 | Parameters 29 | ---------- 30 | 31 | parents : (J) ndarray 32 | parents array 33 | 34 | Returns 35 | ------- 36 | 37 | joints : [ndarray] 38 | List of arrays of joint idices for 39 | each joint 40 | """ 41 | return list(joints(parents)[:, np.newaxis]) 42 | 43 | 44 | def parents_list(parents): 45 | """ 46 | Parameters 47 | ---------- 48 | 49 | parents : (J) ndarray 50 | parents array 51 | 52 | Returns 53 | ------- 54 | 55 | parents : [ndarray] 56 | List of arrays of joint idices for 57 | the parents of each joint 58 | """ 59 | return list(parents[:, np.newaxis]) 60 | 61 | 62 | def children_list(parents): 63 | """ 64 | Parameters 65 | ---------- 66 | 67 | parents : (J) ndarray 68 | parents array 69 | 70 | Returns 71 | ------- 72 | 73 | children : [ndarray] 74 | List of arrays of joint indices for 75 | the children of each joint 76 | """ 77 | 78 | def joint_children(i): 79 | return [j for j, p in enumerate(parents) if p == i] 80 | 81 | return list(map(lambda j: np.array(joint_children(j)), joints(parents))) 82 | 83 | 84 | def descendants_list(parents): 85 | """ 86 | Parameters 87 | ---------- 88 | 89 | parents : (J) ndarray 90 | parents array 91 | 92 | Returns 93 | ------- 94 | 95 | descendants : [ndarray] 96 | List of arrays of joint idices for 97 | the descendants of each joint 98 | """ 99 | 100 | children = children_list(parents) 101 | 102 | def joint_descendants(i): 103 | return sum([joint_descendants(j) for j in children[i]], list(children[i])) 104 | 105 | return list(map(lambda j: np.array(joint_descendants(j)), joints(parents))) 106 | 107 | 108 | def ancestors_list(parents): 109 | """ 110 | Parameters 111 | ---------- 112 | 113 | parents : (J) ndarray 114 | parents array 115 | 116 | Returns 117 | ------- 118 | 119 | ancestors : [ndarray] 120 | List of arrays of joint idices for 121 | the ancestors of each joint 122 | """ 123 | 124 | decendants = descendants_list(parents) 125 | 126 | def joint_ancestors(i): 127 | return [j for j in joints(parents) if i in decendants[j]] 128 | 129 | return list(map(lambda j: np.array(joint_ancestors(j)), joints(parents))) 130 | 131 | 132 | """ Mask Functions """ 133 | 134 | 135 | def mask(parents, filter): 136 | """ 137 | Constructs a Mask for a give filter 138 | 139 | A mask is a (J, J) ndarray truth table for a given 140 | condition over J joints. For example there 141 | may be a mask specifying if a joint N is a 142 | child of another joint M. 143 | 144 | This could be constructed into a mask using 145 | `m = mask(parents, children_list)` and the condition 146 | of childhood tested using `m[N, M]`. 
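For instance, with `parents = np.array([-1, 0, 0])`, `mask(parents, children_list)[0, 1]` is True because joint 1 is a child of the root joint 0.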
147 | 148 | Parameters 149 | ---------- 150 | 151 | parents : (J) ndarray 152 | parents array 153 | 154 | filter : (J) ndarray -> [ndarray] 155 | function that outputs a list of arrays 156 | of joint indices for some condition 157 | 158 | Returns 159 | ------- 160 | 161 | mask : (N, N) ndarray 162 | boolean truth table of given condition 163 | """ 164 | m = np.zeros((len(parents), len(parents))).astype(bool) 165 | jnts = joints(parents) 166 | fltr = filter(parents) 167 | for i, f in enumerate(fltr): m[i, :] = np.any(jnts[:, np.newaxis] == f[np.newaxis, :], axis=1) 168 | return m 169 | 170 | 171 | def joints_mask(parents): return np.eye(len(parents)).astype(bool) 172 | 173 | 174 | def children_mask(parents): return mask(parents, children_list) 175 | 176 | 177 | def parents_mask(parents): return mask(parents, parents_list) 178 | 179 | 180 | def descendants_mask(parents): return mask(parents, descendants_list) 181 | 182 | 183 | def ancestors_mask(parents): return mask(parents, ancestors_list) 184 | 185 | 186 | """ Search Functions """ 187 | 188 | 189 | def joint_chain_ascend(parents, start, end): 190 | chain = [] 191 | while start != end: 192 | chain.append(start) 193 | start = parents[start] 194 | chain.append(end) 195 | return np.array(chain, dtype=int) 196 | 197 | 198 | """ Constraints """ 199 | 200 | 201 | def constraints(anim, **kwargs): 202 | """ 203 | Constraint list for Animation 204 | 205 | This constraint list can be used in the 206 | VerletParticle solver to constrain 207 | a animation global joint positions. 208 | 209 | Parameters 210 | ---------- 211 | 212 | anim : Animation 213 | Input animation 214 | 215 | masses : (F, J) ndarray 216 | Optional list of masses 217 | for joints J across frames F 218 | defaults to weighting by 219 | vertical height 220 | 221 | Returns 222 | ------- 223 | 224 | constraints : [(int, int, (F, J) ndarray, (F, J) ndarray, (F, J) ndarray)] 225 | A list of constraints in the format: 226 | (Joint1, Joint2, Masses1, Masses2, Lengths) 227 | 228 | """ 229 | 230 | masses = kwargs.pop('masses', None) 231 | 232 | children = children_list(anim.parents) 233 | constraints = [] 234 | 235 | points_offsets = Animation.offsets_global(anim) 236 | points = Animation.positions_global(anim) 237 | 238 | if masses is None: 239 | masses = 1.0 / (0.1 + np.absolute(points_offsets[:, 1])) 240 | masses = masses[np.newaxis].repeat(len(anim), axis=0) 241 | 242 | for j in range(anim.shape[1]): 243 | 244 | """ Add constraints between all joints and their children """ 245 | for c0 in children[j]: 246 | 247 | dists = np.sum((points[:, c0] - points[:, j]) ** 2.0, axis=1) ** 0.5 248 | constraints.append((c0, j, masses[:, c0], masses[:, j], dists)) 249 | 250 | """ Add constraints between all children of joint """ 251 | for c1 in children[j]: 252 | if c0 == c1: continue 253 | 254 | dists = np.sum((points[:, c0] - points[:, c1]) ** 2.0, axis=1) ** 0.5 255 | constraints.append((c0, c1, masses[:, c0], masses[:, c1], dists)) 256 | 257 | return constraints 258 | 259 | 260 | """ Graph Functions """ 261 | 262 | 263 | def graph(anim): 264 | """ 265 | Generates a weighted adjacency matrix 266 | using local joint distances along 267 | the skeletal structure. 268 | 269 | Joints which are not connected 270 | are assigned the weight `0`. 271 | 272 | Joints which actually have zero distance 273 | between them, but are still connected, are 274 | perturbed by some minimal amount. 275 | 276 | The output of this routine can be used 277 | with the `scipy.sparse.csgraph` 278 | routines for graph analysis. 
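For example, `scipy.sparse.csgraph.shortest_path` applied to this matrix gives pairwise path lengths along the skeleton, since zero entries in a dense input are treated as non-edges.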
279 | 280 | Parameters 281 | ---------- 282 | 283 | anim : Animation 284 | input animation 285 | 286 | Returns 287 | ------- 288 | 289 | graph : (N, N) ndarray 290 | weight adjacency matrix using 291 | local distances along the 292 | skeletal structure from joint 293 | N to joint M. Joints which are not 294 | directly connected are assigned 295 | the weight `0`. 296 | """ 297 | 298 | graph = np.zeros((anim.shape[1], anim.shape[1])) 299 | lengths = np.sum(anim.offsets ** 2.0, axis=1) ** 0.5 + 0.001 300 | 301 | for i, p in enumerate(anim.parents): 302 | if p == -1: continue 303 | graph[i, p] = lengths[p] 304 | graph[p, i] = lengths[p] 305 | 306 | return graph 307 | 308 | 309 | def distances(anim): 310 | """ 311 | Generates a distance matrix for 312 | pairwise joint distances along 313 | the skeletal structure 314 | 315 | Parameters 316 | ---------- 317 | 318 | anim : Animation 319 | input animation 320 | 321 | Returns 322 | ------- 323 | 324 | distances : (N, N) ndarray 325 | array of pairwise distances 326 | along skeletal structure 327 | from some joint N to some 328 | joint M 329 | """ 330 | 331 | distances = np.zeros((anim.shape[1], anim.shape[1])) 332 | generated = distances.copy().astype(bool) 333 | 334 | joint_lengths = np.sum(anim.offsets ** 2.0, axis=1) ** 0.5 335 | joint_children = children_list(anim.parents) 336 | joint_parents = parents_list(anim.parents) 337 | 338 | def find_distance(distances, generated, prev, i, j): 339 | 340 | """ If root, identity, or already generated, return """ 341 | if j == -1: return (0.0, True) 342 | if j == i: return (0.0, True) 343 | if generated[i, j]: return (distances[i, j], True) 344 | 345 | """ Find best distances along parents and children """ 346 | par_dists = [(joint_lengths[j], find_distance(distances, generated, j, i, p)) for p in joint_parents[j] if 347 | p != prev] 348 | out_dists = [(joint_lengths[c], find_distance(distances, generated, j, i, c)) for c in joint_children[j] if 349 | c != prev] 350 | 351 | """ Check valid distance and not dead end """ 352 | par_dists = [a + d for (a, (d, f)) in par_dists if f] 353 | out_dists = [a + d for (a, (d, f)) in out_dists if f] 354 | 355 | """ All dead ends """ 356 | if (out_dists + par_dists) == []: return (0.0, False) 357 | 358 | """ Get minimum path """ 359 | dist = min(out_dists + par_dists) 360 | distances[i, j] = dist; 361 | distances[j, i] = dist 362 | generated[i, j] = True; 363 | generated[j, i] = True 364 | return (dist, True) 365 | for i in range(anim.shape[1]): 366 | for j in range(anim.shape[1]): 367 | find_distance(distances, generated, -1, i, j) 368 | 369 | return distances 370 | 371 | 372 | def edges(parents): 373 | """ 374 | Animation structure edges 375 | 376 | Parameters 377 | ---------- 378 | 379 | parents : (J) ndarray 380 | parents array 381 | 382 | Returns 383 | ------- 384 | 385 | edges : (M, 2) ndarray 386 | array of pairs where each 387 | pair contains two joint indices 388 | corresponding to an edge in the 389 | joint structure going from parent to child.
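For example, `edges(np.array([-1, 0, 1]))` returns `[[0, 1], [1, 2]]` for a simple three-joint chain.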
390 | """ 391 | 392 | return np.array(list(zip(parents, joints(parents)))[1:]) 393 | 394 | 395 | def incidence(parents): 396 | """ 397 | Incidence Matrix 398 | 399 | Parameters 400 | ---------- 401 | 402 | parents : (J) ndarray 403 | parents array 404 | 405 | Returns 406 | ------- 407 | 408 | incidence : (N, M) ndarray 409 | 410 | Matrix of N joint positions by 411 | M edges which each entry is either 412 | 1 or -1 and multiplication by the 413 | joint positions returns the an 414 | array of vectors along each edge 415 | of the structure 416 | """ 417 | 418 | es = edges(parents) 419 | 420 | inc = np.zeros((len(parents) - 1, len(parents))).astype(np.int) 421 | for i, e in enumerate(es): 422 | inc[i, e[0]] = 1 423 | inc[i, e[1]] = -1 424 | 425 | return inc.T 426 | -------------------------------------------------------------------------------- /gen_t2m.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer 8 | from models.vq.model import RVQVAE, LengthEstimator 9 | 10 | from options.eval_option import EvalT2MOptions 11 | from utils.get_opt import get_opt 12 | 13 | from utils.fixseed import fixseed 14 | from visualization.joints2bvh import Joint2BVHConvertor 15 | from torch.distributions.categorical import Categorical 16 | 17 | 18 | from utils.motion_process import recover_from_ric 19 | from utils.plot_script import plot_3d_motion 20 | 21 | from utils.paramUtil import t2m_kinematic_chain 22 | 23 | import numpy as np 24 | clip_version = 'ViT-B/32' 25 | 26 | def load_vq_model(vq_opt): 27 | # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt') 28 | vq_model = RVQVAE(vq_opt, 29 | vq_opt.dim_pose, 30 | vq_opt.nb_code, 31 | vq_opt.code_dim, 32 | vq_opt.output_emb_width, 33 | vq_opt.down_t, 34 | vq_opt.stride_t, 35 | vq_opt.width, 36 | vq_opt.depth, 37 | vq_opt.dilation_growth_rate, 38 | vq_opt.vq_act, 39 | vq_opt.vq_norm) 40 | ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'), 41 | map_location='cpu') 42 | model_key = 'vq_model' if 'vq_model' in ckpt else 'net' 43 | vq_model.load_state_dict(ckpt[model_key]) 44 | print(f'Loading VQ Model {vq_opt.name} Completed!') 45 | return vq_model, vq_opt 46 | 47 | def load_trans_model(model_opt, opt, which_model): 48 | t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim, 49 | cond_mode='text', 50 | latent_dim=model_opt.latent_dim, 51 | ff_size=model_opt.ff_size, 52 | num_layers=model_opt.n_layers, 53 | num_heads=model_opt.n_heads, 54 | dropout=model_opt.dropout, 55 | clip_dim=512, 56 | cond_drop_prob=model_opt.cond_drop_prob, 57 | clip_version=clip_version, 58 | opt=model_opt) 59 | ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model), 60 | map_location='cpu') 61 | model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans' 62 | # print(ckpt.keys()) 63 | missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False) 64 | assert len(unexpected_keys) == 0 65 | assert all([k.startswith('clip_model.') for k in missing_keys]) 66 | print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!') 67 | return t2m_transformer 68 | 69 | def load_res_model(res_opt, vq_opt, opt): 70 | res_opt.num_quantizers = vq_opt.num_quantizers 71 | 
res_opt.num_tokens = vq_opt.nb_code 72 | res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim, 73 | cond_mode='text', 74 | latent_dim=res_opt.latent_dim, 75 | ff_size=res_opt.ff_size, 76 | num_layers=res_opt.n_layers, 77 | num_heads=res_opt.n_heads, 78 | dropout=res_opt.dropout, 79 | clip_dim=512, 80 | shared_codebook=vq_opt.shared_codebook, 81 | cond_drop_prob=res_opt.cond_drop_prob, 82 | # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None, 83 | share_weight=res_opt.share_weight, 84 | clip_version=clip_version, 85 | opt=res_opt) 86 | 87 | ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'), 88 | map_location=opt.device) 89 | missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False) 90 | assert len(unexpected_keys) == 0 91 | assert all([k.startswith('clip_model.') for k in missing_keys]) 92 | print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!') 93 | return res_transformer 94 | 95 | def load_len_estimator(opt): 96 | model = LengthEstimator(512, 50) 97 | ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'), 98 | map_location=opt.device) 99 | model.load_state_dict(ckpt['estimator']) 100 | print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!') 101 | return model 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = EvalT2MOptions() 106 | opt = parser.parse() 107 | fixseed(opt.seed) 108 | 109 | opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id)) 110 | torch.autograd.set_detect_anomaly(True) 111 | 112 | dim_pose = 251 if opt.dataset_name == 'kit' else 263 113 | 114 | # out_dir = pjoin(opt.check) 115 | root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 116 | model_dir = pjoin(root_dir, 'model') 117 | result_dir = pjoin('./generation', opt.ext) 118 | joints_dir = pjoin(result_dir, 'joints') 119 | animation_dir = pjoin(result_dir, 'animations') 120 | os.makedirs(joints_dir, exist_ok=True) 121 | os.makedirs(animation_dir,exist_ok=True) 122 | 123 | model_opt_path = pjoin(root_dir, 'opt.txt') 124 | model_opt = get_opt(model_opt_path, device=opt.device) 125 | 126 | 127 | ####################### 128 | ######Loading RVQ###### 129 | ####################### 130 | vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt') 131 | vq_opt = get_opt(vq_opt_path, device=opt.device) 132 | vq_opt.dim_pose = dim_pose 133 | vq_model, vq_opt = load_vq_model(vq_opt) 134 | 135 | model_opt.num_tokens = vq_opt.nb_code 136 | model_opt.num_quantizers = vq_opt.num_quantizers 137 | model_opt.code_dim = vq_opt.code_dim 138 | 139 | ################################# 140 | ######Loading R-Transformer###### 141 | ################################# 142 | res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt') 143 | res_opt = get_opt(res_opt_path, device=opt.device) 144 | res_model = load_res_model(res_opt, vq_opt, opt) 145 | 146 | assert res_opt.vq_name == model_opt.vq_name 147 | 148 | ################################# 149 | ######Loading M-Transformer###### 150 | ################################# 151 | t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar') 152 | 153 | ################################## 154 | #####Loading Length Predictor##### 155 | ################################## 156 | length_estimator = load_len_estimator(model_opt) 157 | 158 | t2m_transformer.eval() 159 | 
vq_model.eval() 160 | res_model.eval() 161 | length_estimator.eval() 162 | 163 | res_model.to(opt.device) 164 | t2m_transformer.to(opt.device) 165 | vq_model.to(opt.device) 166 | length_estimator.to(opt.device) 167 | 168 | ##### ---- Dataloader ---- ##### 169 | opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22 170 | 171 | mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy')) 172 | std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy')) 173 | def inv_transform(data): 174 | return data * std + mean 175 | 176 | prompt_list = [] 177 | length_list = [] 178 | 179 | est_length = False 180 | if opt.text_prompt != "": 181 | prompt_list.append(opt.text_prompt) 182 | if opt.motion_length == 0: 183 | est_length = True 184 | else: 185 | length_list.append(opt.motion_length) 186 | elif opt.text_path != "": 187 | with open(opt.text_path, 'r') as f: 188 | lines = f.readlines() 189 | for line in lines: 190 | infos = line.split('#') 191 | prompt_list.append(infos[0]) 192 | if len(infos) == 1 or (not infos[1].strip().isdigit()): 193 | est_length = True 194 | length_list = [] 195 | else: 196 | length_list.append(int(infos[-1])) 197 | else: 198 | raise ValueError("A text prompt, or a file of text prompts, is required!") 199 | # print('loading checkpoint {}'.format(file)) 200 | 201 | if est_length: 202 | print("Since no motion length is specified, we will use estimated motion lengths!") 203 | text_embedding = t2m_transformer.encode_text(prompt_list) 204 | pred_dis = length_estimator(text_embedding) 205 | probs = F.softmax(pred_dis, dim=-1) # (b, ntoken) 206 | token_lens = Categorical(probs).sample() # (b, seqlen) 207 | # lengths = torch.multinomial() 208 | else: 209 | token_lens = torch.LongTensor(length_list) // 4 210 | token_lens = token_lens.to(opt.device).long() 211 | 212 | m_length = token_lens * 4 213 | captions = prompt_list 214 | 215 | sample = 0 216 | kinematic_chain = t2m_kinematic_chain 217 | converter = Joint2BVHConvertor() 218 | 219 | for r in range(opt.repeat_times): 220 | print("-->Repeat %d"%r) 221 | with torch.no_grad(): 222 | mids = t2m_transformer.generate(captions, token_lens, 223 | timesteps=opt.time_steps, 224 | cond_scale=opt.cond_scale, 225 | temperature=opt.temperature, 226 | topk_filter_thres=opt.topkr, 227 | gsample=opt.gumbel_sample) 228 | # print(mids) 229 | # print(mids.shape) 230 | mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5) 231 | pred_motions = vq_model.forward_decoder(mids) 232 | 233 | pred_motions = pred_motions.detach().cpu().numpy() 234 | 235 | data = inv_transform(pred_motions) 236 | 237 | for k, (caption, joint_data) in enumerate(zip(captions, data)): 238 | print("---->Sample %d: %s %d"%(k, caption, m_length[k])) 239 | animation_path = pjoin(animation_dir, str(k)) 240 | joint_path = pjoin(joints_dir, str(k)) 241 | 242 | os.makedirs(animation_path, exist_ok=True) 243 | os.makedirs(joint_path, exist_ok=True) 244 | 245 | joint_data = joint_data[:m_length[k]] 246 | joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy() 247 | 248 | bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k])) 249 | _, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100) 250 | 251 | bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k])) 252 | _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False) 253 | 254 | 255 | save_path = pjoin(animation_path, 
"sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k])) 256 | ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k])) 257 | 258 | plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20) 259 | plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20) 260 | np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint) 261 | np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint) --------------------------------------------------------------------------------