├── common ├── __init__.py └── skeleton.py ├── visualization ├── __init__.py ├── .DS_Store ├── data │ ├── .DS_Store │ ├── smpl │ │ └── smpl │ │ │ ├── .DS_Store │ │ │ └── smpl.txt │ └── gBR_sBM_cAll_d04_mBR0_ch01.pkl ├── quality-comp-walk_page-0001.jpg ├── joints2smpl │ ├── smpl_models │ │ ├── SMPL_downsample_index.pkl │ │ └── neutral_smpl_mean_params.h5 │ ├── environment.yaml │ ├── src │ │ ├── config.py │ │ └── prior.py │ ├── README.md │ └── fit_seq.py ├── blender_scripts │ ├── remove_frames.py │ ├── inpainting.py │ ├── suffix.py │ ├── prefix.py │ └── framing_coloring.py ├── render_mesh.py ├── vis_utils.py ├── motions2hik.py ├── simplify_loc2rot.py ├── joints2bvh.py ├── smpl2bvh.py └── utils │ └── bvh.py ├── my_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py └── model.py ├── exit ├── t2m-std.npy └── t2m-mean.npy ├── dataset ├── prepare │ ├── extract_kit.sh │ ├── download_glove.sh │ ├── download_smpl_files.sh │ ├── download_extractor.sh │ ├── download_models.sh │ └── download_kit.sh ├── dataset_VQ.py ├── dataset_tokenize.py ├── dataset_TM_train.py └── dataset_TM_eval.py ├── utils ├── fixseed.py ├── config.py ├── losses.py ├── my_sampler.py ├── misc.py ├── PYTORCH3D_LICENSE ├── paramUtil.py ├── dist_util.py ├── utils_model.py ├── humanml_utils.py ├── model_util.py ├── word_vectorizer.py ├── rotation2xyz.py └── smpl.py ├── .gitignore ├── models ├── clip_model.py ├── pos_encoding.py ├── encdec.py ├── resnet.py ├── evaluator_wrapper.py ├── modules.py └── vqvae.py ├── options ├── get_eval_option.py └── option_vq.py ├── environment2.yml ├── eval_edit.py ├── vq_eval.py ├── edit_eval └── main_edit_eval.py └── environment.yml /common/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /my_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /exit/t2m-std.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/exit/t2m-std.npy -------------------------------------------------------------------------------- /exit/t2m-mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/exit/t2m-mean.npy -------------------------------------------------------------------------------- /visualization/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/.DS_Store -------------------------------------------------------------------------------- /visualization/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/data/.DS_Store -------------------------------------------------------------------------------- /my_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/my_clip/bpe_simple_vocab_16e6.txt.gz 
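The `exit/t2m-mean.npy` and `exit/t2m-std.npy` files listed above are dataset statistics. A minimal sketch of how such mean/std pairs are typically applied to HumanML3D-style 263-dimensional feature vectors is shown below; the z-score formula, array shapes, and variable names are assumptions for illustration, not taken from this repository's code.

```python
# Sketch (assumption): exit/t2m-mean.npy and exit/t2m-std.npy hold per-dimension
# statistics of the 263-dim HumanML3D feature vectors and are applied as a standard
# z-score (de)normalization. Shapes and the random stand-in motion are illustrative.
import numpy as np

mean = np.load("exit/t2m-mean.npy")   # expected shape: (263,)
std = np.load("exit/t2m-std.npy")     # expected shape: (263,)

motion = np.random.randn(196, 263).astype(np.float32)  # stand-in for a real clip
normalized = (motion - mean) / std                      # network-space features
denormalized = normalized * std + mean                  # back to raw feature space
```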
-------------------------------------------------------------------------------- /visualization/data/smpl/smpl/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/data/smpl/smpl/.DS_Store -------------------------------------------------------------------------------- /visualization/quality-comp-walk_page-0001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/quality-comp-walk_page-0001.jpg -------------------------------------------------------------------------------- /visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.pkl -------------------------------------------------------------------------------- /dataset/prepare/extract_kit.sh: -------------------------------------------------------------------------------- 1 | cd dataset/KIT-ML 2 | 3 | unrar x new_joint_vecs.rar 4 | unrar x new_joints.rar 5 | unrar x texts.rar 6 | rm -rf *.rar 7 | 8 | cd ../.. -------------------------------------------------------------------------------- /visualization/joints2smpl/smpl_models/SMPL_downsample_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/joints2smpl/smpl_models/SMPL_downsample_index.pkl -------------------------------------------------------------------------------- /visualization/joints2smpl/smpl_models/neutral_smpl_mean_params.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RohollahHS/BAD/HEAD/visualization/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 -------------------------------------------------------------------------------- /visualization/data/smpl/smpl/smpl.txt: -------------------------------------------------------------------------------- 1 | Once you have downloaded the SMPL model, place it here like below. 2 | 3 | data 4 | |_smpl 5 | |_smpl 6 | |_SMPL_FEMALE.pkl 7 | |_SMPL_MALE.pkl 8 | |_SMPL_NEUTRAL.pkl 9 | -------------------------------------------------------------------------------- /dataset/prepare/download_glove.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading glove (in use by the evaluators)" 2 | gdown --fuzzy https://drive.google.com/file/d/1bCeS6Sh_mLVTebxIgiUHgdPrroW06mb6/view?usp=sharing 3 | rm -rf glove 4 | 5 | unzip glove.zip 6 | echo -e "Cleaning\n" 7 | rm glove.zip 8 | 9 | echo -e "Downloading done!" 
-------------------------------------------------------------------------------- /visualization/blender_scripts/remove_frames.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | # Iterate over all objects in the scene 4 | for obj in bpy.data.objects: 5 | # Check if "frame" is in the object's name 6 | if "frame" in obj.name: 7 | # Delete the object 8 | bpy.data.objects.remove(obj, do_unlink=True) -------------------------------------------------------------------------------- /dataset/prepare/download_smpl_files.sh: -------------------------------------------------------------------------------- 1 | mkdir -p checkpoints/smpl_models 2 | cd checkpoints/smpl_models/ 3 | 4 | echo -e "The smpl files will be stored in the 'checkpoints/smpl_models/smpl/' folder\n" 5 | gdown "https://drive.google.com/uc?id=1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2" 6 | rm -rf smpl 7 | 8 | unzip smpl.zip 9 | echo -e "Cleaning\n" 10 | rm smpl.zip 11 | 12 | echo -e "Downloading done!" 13 | cd ../.. -------------------------------------------------------------------------------- /utils/fixseed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def fixseed(seed): 7 | torch.backends.cudnn.benchmark = False 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | 13 | # SEED = 10 14 | # EVALSEED = 0 15 | # # Provoc warning: not fully functionnal yet 16 | # # torch.set_deterministic(True) 17 | # torch.backends.cudnn.benchmark = False 18 | # fixseed(SEED) 19 | -------------------------------------------------------------------------------- /dataset/prepare/download_extractor.sh: -------------------------------------------------------------------------------- 1 | rm -rf checkpoints 2 | mkdir checkpoints 3 | cd checkpoints 4 | echo -e "Downloading extractors" 5 | gdown --fuzzy https://drive.google.com/file/d/1FIiqtkt4F-GVWmnBgtZnv9W3cPWS-oM-/view 6 | gdown --fuzzy https://drive.google.com/file/d/1KNU8CsMAnxFrwopKBBkC8jEULGLPBHQp/view 7 | 8 | unzip t2m.zip 9 | unzip kit.zip 10 | 11 | echo -e "Cleaning\n" 12 | rm t2m.zip 13 | rm kit.zip 14 | echo -e "Downloading done!" 
15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | ./output/* 5 | 6 | checkpoints 7 | glove 8 | dataset/HumanML3D 9 | dataset/KIT-ML 10 | dataset/KIT-ML 11 | ./dataset/prepare/download_humanml3d.sh 12 | ./dataset/prepare/extract_humanml3d.sh 13 | output 14 | slurm* 15 | *ipynb* 16 | *.zip 17 | dest_dir/ 18 | output_results/ 19 | core* 20 | *.pth 21 | *.ckpt 22 | 23 | 24 | fid_results/ 25 | jobs/ 26 | diffusion/ 27 | .vscode/ -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | SMPL_DATA_PATH = "./checkpoints/smpl_models/smpl" 4 | 5 | SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") 6 | SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") 7 | JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') 8 | 9 | ROT_CONVENTION_TO_ROT_NUMBER = { 10 | 'legacy': 23, 11 | 'no_hands': 21, 12 | 'full_hands': 51, 13 | 'mitten_hands': 33, 14 | } 15 | 16 | GENDERS = ['neutral', 'male', 'female'] 17 | NUM_BETAS = 10 -------------------------------------------------------------------------------- /models/clip_model.py: -------------------------------------------------------------------------------- 1 | import my_clip 2 | import torch 3 | import os 4 | 5 | 6 | def load_clip_model(args): 7 | clip_model = my_clip.load("ViT-B/32", device=args.device, jit=False, 8 | download_root=os.path.join(os.environ.get("TORCH_HOME"), 'clip')) # Must set jit=False for training 9 | my_clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16 10 | clip_model.eval() 11 | for p in clip_model.parameters(): 12 | p.requires_grad = False 13 | return clip_model 14 | 15 | -------------------------------------------------------------------------------- /dataset/prepare/download_models.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading pretrained models" 2 | 3 | mkdir -p ./output/vq/ 4 | mkdir -p ./output/t2m/ 5 | mkdir -p ./checkpoints/t2m/length_estimator/model/ 6 | 7 | cd ./output 8 | 9 | gdown --fuzzy https://drive.google.com/file/d/1fchcM7vWJpMKbDP7wTVrufIgfyYaXK49/view?usp=sharing 10 | mv vq_last.pth vq/ 11 | 12 | 13 | gdown --fuzzy https://drive.google.com/file/d/1ZldeaE9mYOAsG9B2UM-Oc_BpQw75xW4l/view?usp=sharing 14 | mv trans_best_fid.pth t2m/ 15 | 16 | cd .. 
17 | 18 | cd ./checkpoints/t2m/length_estimator/model/ 19 | gdown --fuzzy https://drive.google.com/file/d/1eFphHaWX669pXVgXJTRvCOdKonQ0Pns_/view?usp=sharing 20 | 21 | 22 | echo -e "Finished" -------------------------------------------------------------------------------- /visualization/joints2smpl/environment.yaml: -------------------------------------------------------------------------------- 1 | name: fit3d 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | - pytorch3d 7 | - open3d-admin 8 | - anaconda 9 | dependencies: 10 | - pip=21.1.3 11 | - numpy=1.20.3 12 | - numpy-base=1.20.3 13 | - matplotlib=3.4.2 14 | - matplotlib-base=3.4.2 15 | - pandas=1.3.1 16 | - python=3.7.6 17 | - pytorch=1.7.1 18 | - tensorboardx=2.2 19 | - cudatoolkit=10.2.89 20 | - torchvision=0.8.2 21 | - einops=0.3.0 22 | - pytorch3d=0.4.0 23 | - tqdm=4.61.2 24 | - trimesh=3.9.24 25 | - joblib=1.0.1 26 | - open3d=0.13.0 27 | - pip: 28 | - h5py==2.9.0 29 | - chumpy==0.70 30 | - smplx==0.1.28 31 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ReConsLoss(nn.Module): 5 | def __init__(self, recons_loss, nb_joints): 6 | super(ReConsLoss, self).__init__() 7 | 8 | if recons_loss == 'l1': 9 | self.Loss = torch.nn.L1Loss() 10 | elif recons_loss == 'l2' : 11 | self.Loss = torch.nn.MSELoss() 12 | elif recons_loss == 'l1_smooth' : 13 | self.Loss = torch.nn.SmoothL1Loss() 14 | 15 | # 4 global motion associated to root 16 | # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d) 17 | # 3 global vel xyz 18 | # 4 foot contact 19 | self.nb_joints = nb_joints 20 | self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4 21 | 22 | def forward(self, motion_pred, motion_gt) : 23 | loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim]) 24 | return loss 25 | 26 | def forward_joint(self, motion_pred, motion_gt) : 27 | loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4]) 28 | return loss 29 | 30 | -------------------------------------------------------------------------------- /utils/my_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler 3 | 4 | 5 | class CustomSampler(Sampler): 6 | def __init__(self, data_source, num_replicas, rank, special_rank=None, special_ratio=2): 7 | self.data_source = data_source 8 | self.num_replicas = num_replicas 9 | self.rank = rank 10 | self.special_rank = special_rank 11 | self.special_ratio = special_ratio 12 | 13 | # Calculate the number of samples for each rank 14 | if rank == special_rank: 15 | self.num_samples = len(self.data_source) * special_ratio // (num_replicas + special_ratio - 1) 16 | else: 17 | self.num_samples = len(self.data_source) // (num_replicas + special_ratio - 1) 18 | 19 | self.total_size = self.num_samples * self.num_replicas 20 | 21 | def __iter__(self): 22 | indices = list(range(len(self.data_source))) 23 | if self.rank == self.special_rank: 24 | indices = indices * self.special_ratio 25 | 26 | indices += indices[:(self.total_size - len(indices))] 27 | indices = indices[self.rank:self.total_size:self.num_replicas] 28 | 29 | return iter(indices) 30 | 31 | def __len__(self): 32 | return self.num_samples 33 | -------------------------------------------------------------------------------- 
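`utils/my_sampler.py` above defines `CustomSampler`, which gives one designated rank roughly `special_ratio` times more samples per epoch than the other ranks. A minimal sketch of wiring it into a `DataLoader` follows; the toy `TensorDataset`, `num_replicas=4`, and batch size are illustrative assumptions rather than values used by this repository.

```python
# Minimal sketch (assumptions: toy dataset, num_replicas=4, batch_size=32) of using
# CustomSampler so that rank 0 iterates over roughly twice as many samples per epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.my_sampler import CustomSampler

dataset = TensorDataset(torch.randn(1000, 263))  # stand-in for a motion dataset

sampler = CustomSampler(dataset, num_replicas=4, rank=0, special_rank=0, special_ratio=2)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, drop_last=True)

for (batch,) in loader:
    pass  # the training step for this rank would go here
```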
/dataset/prepare/download_kit.sh: -------------------------------------------------------------------------------- 1 | cd dataset 2 | mkdir KIT-ML 3 | 4 | cd KIT-ML 5 | echo -e "Downloading KIT-ML dataset" 6 | gdown --fuzzy https://drive.google.com/file/d/1ui0VuFl-4S3mXjHtlKOMhfXEqa1p2vlN/view?usp=drive_link&confirm=t 7 | gdown --fuzzy https://drive.google.com/file/d/1skF34AV8gKe2_4peasE4sdcHWSppN6MT/view?usp=drive_link&confirm=t 8 | gdown --fuzzy https://drive.google.com/file/d/1rkIyu5xVQaU669kyfkOKo_ObvPv8IOlM/view?usp=drive_link&confirm=t 9 | gdown --fuzzy https://drive.google.com/file/d/1m7e130_cOHAj3OjjeW5VOcwotL62IjwT/view?usp=drive_link&confirm=t 10 | gdown --fuzzy https://drive.google.com/file/d/1_cpQbCYKVgWwQV2r1eMLn32Ne1UFRHf_/view?usp=drive_link&confirm=t 11 | gdown --fuzzy https://drive.google.com/file/d/1_3n5MNkKCdDyCTPU6mCn-LYMoK0y9_yj/view?usp=drive_link&confirm=t 12 | gdown --fuzzy https://drive.google.com/file/d/1UcFEOIbMi_BfJRrNsgbFy6qEPf86M45s/view?usp=drive_link&confirm=t 13 | gdown --fuzzy https://drive.google.com/file/d/1OwgZEnsflyE90bMEfaGVFP_g2kNB5N5z/view?usp=drive_link&confirm=t 14 | gdown --fuzzy https://drive.google.com/file/d/1FF6cDP8h3q3OUa337WWb2d423ZtljBiR/view?usp=drive_link&confirm=t 15 | gdown --fuzzy https://drive.google.com/file/d/170u5YbFUq-BHqnMRbMXfJ4uhzhobSPkw/view?usp=drive_link&confirm=t 16 | cd ../.. 17 | 18 | echo -e "Downloading Done!\n" -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_numpy(tensor): 5 | if torch.is_tensor(tensor): 6 | return tensor.cpu().numpy() 7 | elif type(tensor).__module__ != 'numpy': 8 | raise ValueError("Cannot convert {} to numpy array".format( 9 | type(tensor))) 10 | return tensor 11 | 12 | 13 | def to_torch(ndarray): 14 | if type(ndarray).__module__ == 'numpy': 15 | return torch.from_numpy(ndarray) 16 | elif not torch.is_tensor(ndarray): 17 | raise ValueError("Cannot convert {} to torch tensor".format( 18 | type(ndarray))) 19 | return ndarray 20 | 21 | 22 | def cleanexit(): 23 | import sys 24 | import os 25 | try: 26 | sys.exit(0) 27 | except SystemExit: 28 | os._exit(0) 29 | 30 | def load_model_wo_clip(model, state_dict): 31 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 32 | assert len(unexpected_keys) == 0 33 | assert all([k.startswith('clip_model.') for k in missing_keys]) 34 | 35 | def freeze_joints(x, joints_to_freeze): 36 | # Freezes selected joint *rotations* as they appear in the first frame 37 | # x [bs, [root+n_joints], joint_dim(6), seqlen] 38 | frozen = x.detach().clone() 39 | frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] 40 | return frozen 41 | -------------------------------------------------------------------------------- /visualization/joints2smpl/src/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Map joints Name to SMPL joints idx 4 | JOINT_MAP = { 5 | 'MidHip': 0, 6 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 7 | 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 8 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, 9 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, 10 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 11 | 'LCollar':13, 'Rcollar' :14, 12 | 'Nose':24, 'REye':26, 'LEye':26, 'REar':27, 'LEar':28, 13 | 'LHeel': 31, 'RHeel': 34, 14 | 'OP RShoulder': 
17, 'OP LShoulder': 16, 15 | 'OP RHip': 2, 'OP LHip': 1, 16 | 'OP Neck': 12, 17 | } 18 | 19 | full_smpl_idx = range(24) 20 | key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] 21 | 22 | 23 | AMASS_JOINT_MAP = { 24 | 'MidHip': 0, 25 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 26 | 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 27 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 28 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 29 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 30 | 'LCollar':13, 'Rcollar' :14, 31 | } 32 | amass_idx = range(22) 33 | amass_smpl_idx = range(22) 34 | 35 | 36 | SMPL_MODEL_DIR = "./checkpoints/smpl_models/" 37 | GMM_MODEL_DIR = "./visualization/joints2smpl/smpl_models/" 38 | SMPL_MEAN_FILE = "./visualization/joints2smpl/smpl_models/neutral_smpl_mean_params.h5" 39 | # for collsion 40 | Part_Seg_DIR = "./visualization/joints2smpl/smpl_models/smplx_parts_segm.pkl" -------------------------------------------------------------------------------- /models/pos_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | def PE1d_sincos(seq_length, dim): 9 | """ 10 | :param d_model: dimension of the model 11 | :param length: length of positions 12 | :return: length*d_model position matrix 13 | """ 14 | if dim % 2 != 0: 15 | raise ValueError("Cannot use sin/cos positional encoding with " 16 | "odd dim (got dim={:d})".format(dim)) 17 | pe = torch.zeros(seq_length, dim) 18 | position = torch.arange(0, seq_length).unsqueeze(1) 19 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 20 | -(math.log(10000.0) / dim))) 21 | pe[:, 0::2] = torch.sin(position.float() * div_term) 22 | pe[:, 1::2] = torch.cos(position.float() * div_term) 23 | 24 | return pe.unsqueeze(1) 25 | 26 | 27 | class PositionEmbedding(nn.Module): 28 | """ 29 | Absolute pos embedding (standard), learned. 30 | """ 31 | def __init__(self, seq_length, dim, dropout, grad=False): 32 | super().__init__() 33 | self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad) 34 | self.dropout = nn.Dropout(p=dropout) 35 | 36 | def forward(self, x): 37 | # x.shape: bs, seq_len, feat_dim 38 | l = x.shape[1] 39 | x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape) 40 | x = self.dropout(x.permute(1, 0, 2)) 41 | return x 42 | 43 | -------------------------------------------------------------------------------- /utils/PYTORCH3D_LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For PyTorch3D software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /visualization/blender_scripts/inpainting.py: -------------------------------------------------------------------------------- 1 | # inbetween 2 | import bpy 3 | 4 | yellow = (1, 0.905, 0.05, 1.0) 5 | red = (0.9, 0.1, 0.1, 1.0) 6 | 7 | # Create materials if they don't exist 8 | yellow_material = bpy.data.materials.get("YellowMaterial") 9 | if not yellow_material: 10 | yellow_material = bpy.data.materials.new(name="YellowMaterial") 11 | yellow_material.diffuse_color = yellow 12 | yellow_material.use_nodes = False 13 | 14 | red_material = bpy.data.materials.get("RedMaterial") 15 | if not red_material: 16 | red_material = bpy.data.materials.new(name="RedMaterial") 17 | red_material.diffuse_color = red 18 | red_material.use_nodes = False 19 | 20 | objects = [obj for obj in bpy.data.objects if obj.type == 'MESH' and "frame" in obj.name] 21 | 22 | for i, obj in enumerate(objects): 23 | # Determine material based on frame range 24 | if i < 49 or i >= len(objects) - 49: 25 | material = red_material 26 | else: 27 | material = yellow_material 28 | 29 | # Assign the material 30 | if len(obj.data.materials) > 0: 31 | obj.data.materials[0] = material 32 | else: 33 | obj.data.materials.append(material) 34 | 35 | # Set visibility keyframes 36 | obj.hide_viewport = True 37 | obj.hide_render = True 38 | obj.keyframe_insert(data_path="hide_viewport", frame=1) 39 | obj.keyframe_insert(data_path="hide_render", frame=1) 40 | 41 | frame = i + 1 42 | obj.hide_viewport = False 43 | obj.hide_render = False 44 | obj.keyframe_insert(data_path="hide_viewport", frame=frame) 45 | obj.keyframe_insert(data_path="hide_render", frame=frame) 46 | 47 | obj.hide_viewport = True 48 | obj.hide_render = True 49 | obj.keyframe_insert(data_path="hide_viewport", frame=frame + 1) 50 | obj.keyframe_insert(data_path="hide_render", frame=frame + 1) 51 | 52 | print(f"Processed {len(objects)} objects with 'frame' in their names.") -------------------------------------------------------------------------------- /visualization/blender_scripts/suffix.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | yellow = (1, 0.905, 0.05, 1.0) 4 | red = (0.9, 0.1, 0.1, 1.0) 5 | 6 | 7 | # Create materials if they don't exist 8 | yellow_material = bpy.data.materials.get("YellowMaterial") 9 | if not yellow_material: 10 | yellow_material = bpy.data.materials.new(name="YellowMaterial") 11 | yellow_material.diffuse_color = yellow 12 | yellow_material.use_nodes = False 13 | 14 | red_material = bpy.data.materials.get("RedMaterial") 15 | if not red_material: 16 | red_material = bpy.data.materials.new(name="RedMaterial") 17 | 
red_material.diffuse_color = red 18 | red_material.use_nodes = False 19 | 20 | objects = [obj for obj in bpy.data.objects if obj.type == 'MESH' and "frame" in obj.name] 21 | total_ob = len(objects) 22 | 23 | for i, obj in enumerate(objects): 24 | # Determine material based on frame range 25 | if i > total_ob//2: 26 | material = red_material 27 | else: 28 | material = yellow_material 29 | 30 | # Assign the material 31 | if len(obj.data.materials) > 0: 32 | obj.data.materials[0] = material 33 | else: 34 | obj.data.materials.append(material) 35 | 36 | # Set visibility keyframes 37 | obj.hide_viewport = True 38 | obj.hide_render = True 39 | obj.keyframe_insert(data_path="hide_viewport", frame=1) 40 | obj.keyframe_insert(data_path="hide_render", frame=1) 41 | 42 | frame = i + 1 43 | obj.hide_viewport = False 44 | obj.hide_render = False 45 | obj.keyframe_insert(data_path="hide_viewport", frame=frame) 46 | obj.keyframe_insert(data_path="hide_render", frame=frame) 47 | 48 | obj.hide_viewport = True 49 | obj.hide_render = True 50 | obj.keyframe_insert(data_path="hide_viewport", frame=frame + 1) 51 | obj.keyframe_insert(data_path="hide_render", frame=frame + 1) 52 | 53 | print(f"Processed {len(objects)} objects with 'frame' in their names.") 54 | -------------------------------------------------------------------------------- /visualization/blender_scripts/prefix.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | yellow = (1, 0.905, 0.05, 1.0) 4 | red = (0.9, 0.1, 0.1, 1.0) 5 | 6 | total_frames = 198 7 | 8 | # Create materials if they don't exist 9 | yellow_material = bpy.data.materials.get("YellowMaterial") 10 | if not yellow_material: 11 | yellow_material = bpy.data.materials.new(name="YellowMaterial") 12 | yellow_material.diffuse_color = yellow 13 | yellow_material.use_nodes = False 14 | 15 | red_material = bpy.data.materials.get("RedMaterial") 16 | if not red_material: 17 | red_material = bpy.data.materials.new(name="RedMaterial") 18 | red_material.diffuse_color = red 19 | red_material.use_nodes = False 20 | 21 | objects = [obj for obj in bpy.data.objects if obj.type == 'MESH' and "frame" in obj.name] 22 | 23 | for i, obj in enumerate(objects): 24 | # Determine material based on frame range 25 | if i < int(total_frames//2): 26 | material = red_material 27 | else: 28 | material = yellow_material 29 | 30 | # Assign the material 31 | if len(obj.data.materials) > 0: 32 | obj.data.materials[0] = material 33 | else: 34 | obj.data.materials.append(material) 35 | 36 | # Set visibility keyframes 37 | obj.hide_viewport = True 38 | obj.hide_render = True 39 | obj.keyframe_insert(data_path="hide_viewport", frame=1) 40 | obj.keyframe_insert(data_path="hide_render", frame=1) 41 | 42 | frame = i + 1 43 | obj.hide_viewport = False 44 | obj.hide_render = False 45 | obj.keyframe_insert(data_path="hide_viewport", frame=frame) 46 | obj.keyframe_insert(data_path="hide_render", frame=frame) 47 | 48 | obj.hide_viewport = True 49 | obj.hide_render = True 50 | obj.keyframe_insert(data_path="hide_viewport", frame=frame + 1) 51 | obj.keyframe_insert(data_path="hide_render", frame=frame + 1) 52 | 53 | print(f"Processed {len(objects)} objects with 'frame' in their names.") 54 | -------------------------------------------------------------------------------- /visualization/joints2smpl/README.md: -------------------------------------------------------------------------------- 1 | # joints2smpl 2 | fit SMPL model using 3D joints 3 | 4 | ## Prerequisites 5 | We have 
tested the code on Ubuntu 18.04/20.04 with CUDA 10.2/11.3. 6 | 7 | ## Installation 8 | First you have to make sure that you have all dependencies in place. 9 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/). 10 | 11 | You can create an anaconda environment called `fit3d` using 12 | ``` 13 | conda env create -f environment.yaml 14 | conda activate fit3d 15 | ``` 16 | 17 | ## Download SMPL models 18 | Download [SMPL Female and Male](https://smpl.is.tue.mpg.de/) and [SMPL Neutral](https://smplify.is.tue.mpg.de/), rename the files, and extract them to `/smpl_models/smpl/`. Eventually, the `/smpl_models` folder should have the following structure: 19 | ``` 20 | smpl_models 21 | └-- smpl 22 | └-- SMPL_FEMALE.pkl 23 | └-- SMPL_MALE.pkl 24 | └-- SMPL_NEUTRAL.pkl 25 | ``` 26 | 27 | ## Demo 28 | ### Demo for sequences 29 | python fit_seq.py --files test_motion2.npy 30 | 31 | The results will be located in ./demo/demo_results/ 32 | 33 | ## Citation 34 | If you find this project useful for your research, please consider citing: 35 | ``` 36 | @article{zuo2021sparsefusion, 37 | title={Sparsefusion: Dynamic human avatar modeling from sparse rgbd images}, 38 | author={Zuo, Xinxin and Wang, Sen and Zheng, Jiangbin and Yu, Weiwei and Gong, Minglun and Yang, Ruigang and Cheng, Li}, 39 | journal={IEEE Transactions on Multimedia}, 40 | volume={23}, 41 | pages={1617--1629}, 42 | year={2021} 43 | } 44 | ``` 45 | 46 | ## References 47 | We indicate if a function or script is borrowed externally inside each file. Here are some great resources we 48 | benefit from: 49 | 50 | - The shape/pose prior and some functions are borrowed from [VIBE](https://github.com/mkocabas/VIBE). 51 | - The SMPL models and layer are from the [SMPL-X model](https://github.com/vchoutas/smplx). 52 | - Some functions are borrowed from [HMR-pytorch](https://github.com/MandyMo/pytorch_HMR). 
53 | -------------------------------------------------------------------------------- /visualization/render_mesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | from visualization import vis_utils 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--input_path", type=str, required=True, help='stick figure mp4 file to be rendered.') 12 | parser.add_argument("--npy_path", type=str) 13 | parser.add_argument("--n_iter_smpl", type=int, default=150) 14 | parser.add_argument("--cuda", type=bool, default=True, help='') 15 | parser.add_argument("--device", type=str, default='cuda', help='') 16 | params = parser.parse_args() 17 | 18 | assert params.input_path.endswith('.mp4') 19 | parsed_name = os.path.basename(params.input_path).replace('.mp4', '').replace('sample', '').replace('repeat', '') 20 | # sample_i, rep_i = [int(e) for e in parsed_name.split('_')[:2]] 21 | sample_i, rep_i = 0, 0 22 | if params.npy_path is None: 23 | npy_path = f"/output/visualization/joints/{params.input_path.split('/')[-1][:-4]}.npy" 24 | else: 25 | npy_path = params.npy_path 26 | # npy_path = os.path.join(os.path.dirname(params.input_path), 'results.npy') 27 | out_npy_path = params.input_path.replace('.mp4', '_smpl_params.npy') 28 | assert os.path.exists(npy_path) 29 | results_dir = params.input_path.replace('.mp4', '_obj') 30 | if os.path.exists(results_dir): 31 | shutil.rmtree(results_dir) 32 | os.makedirs(results_dir) 33 | 34 | npy2obj = vis_utils.npy2obj(npy_path, sample_i, rep_i, 35 | device=params.device, cuda=params.cuda, n_iter=params.n_iter_smpl) 36 | 37 | print('Saving obj files to [{}]'.format(os.path.abspath(results_dir))) 38 | for frame_i in tqdm(range(npy2obj.real_num_frames)): 39 | npy2obj.save_obj(os.path.join(results_dir, 'frame{:03d}.obj'.format(frame_i)), frame_i) 40 | 41 | print('Saving SMPL params to [{}]'.format(os.path.abspath(out_npy_path))) 42 | npy2obj.save_npy(out_npy_path) -------------------------------------------------------------------------------- /utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 
45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /visualization/blender_scripts/framing_coloring.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | # Define the new color (RGBA) for the material 4 | #new_color = (0.230, 0.834, 1.0, 1.0) # Example: cyan color 5 | new_color = (1, 0.905, 0.05, 1.0) # Alternative: yellow color 6 | #new_color = (1, 0.1, 0.1, 1.0) 7 | 8 | # Create a new material for colorization 9 | material_name = "NewMaterial" 10 | new_material = bpy.data.materials.get(material_name) 11 | if not new_material: 12 | new_material = bpy.data.materials.new(name=material_name) 13 | new_material.diffuse_color = new_color # Set the material's color 14 | new_material.use_nodes = False # Disable nodes to make diffuse_color work 15 | 16 | # Get all mesh objects with "frame" in their name 17 | objects = [obj for obj in bpy.data.objects if obj.type == 'MESH' and "frame" in obj.name] 18 | 19 | # Iterate through each object to set visibility keyframes and apply the material 20 | for i, obj in enumerate(objects): 21 | # Assign the new material to the object 22 | if len(obj.data.materials) > 0: 23 | obj.data.materials[0] = new_material # Replace the first material slot 24 | else: 25 | obj.data.materials.append(new_material) # Add material if no slots exist 26 | 27 | # Set initial keyframes to hide the object by default 28 | obj.hide_viewport = True 29 | obj.hide_render = True 30 | obj.keyframe_insert(data_path="hide_viewport", frame=1) 31 | obj.keyframe_insert(data_path="hide_render", frame=1) 32 | 33 | # Make the object visible on its assigned frame (i+1) 34 | frame = i + 1 # Assign each object a frame starting from 1 35 | obj.hide_viewport = False 36 | obj.hide_render = False 37 | obj.keyframe_insert(data_path="hide_viewport", frame=frame) 38 | obj.keyframe_insert(data_path="hide_render", frame=frame) 39 | 40 | # Immediately hide the object again on the next frame 41 | obj.hide_viewport = True 42 | obj.hide_render = True 43 | obj.keyframe_insert(data_path="hide_viewport", frame=frame + 1) 44 | obj.keyframe_insert(data_path="hide_render", frame=frame + 1) 45 | 46 | print(f"Processed {len(objects)} objects with 'frame' in their names.") -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import socket 6 | 7 | import torch as th 8 | import torch.distributed as dist 9 | 10 | # Change this to reflect your cluster layout. 11 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 12 | GPUS_PER_NODE = 8 13 | 14 | SETUP_RETRY_COUNT = 3 15 | 16 | used_device = 0 17 | 18 | def setup_dist(device=0): 19 | """ 20 | Setup a distributed process group. 
21 | """ 22 | global used_device 23 | used_device = device 24 | if dist.is_initialized(): 25 | return 26 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 27 | 28 | # comm = MPI.COMM_WORLD 29 | # backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | # if backend == "gloo": 32 | # hostname = "localhost" 33 | # else: 34 | # hostname = socket.gethostbyname(socket.getfqdn()) 35 | # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | # os.environ["RANK"] = str(comm.rank) 37 | # os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | # port = comm.bcast(_find_free_port(), root=used_device) 40 | # os.environ["MASTER_PORT"] = str(port) 41 | # dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | global used_device 49 | if th.cuda.is_available() and used_device>=0: 50 | return th.device(f"cuda:{used_device}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | return th.load(path, **kwargs) 59 | 60 | 61 | def sync_params(params): 62 | """ 63 | Synchronize a sequence of Tensors across ranks from rank 0. 64 | """ 65 | for p in params: 66 | with th.no_grad(): 67 | dist.broadcast(p, 0) 68 | 69 | 70 | def _find_free_port(): 71 | try: 72 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 73 | s.bind(("", 0)) 74 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 75 | return s.getsockname()[1] 76 | finally: 77 | s.close() 78 | -------------------------------------------------------------------------------- /utils/utils_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import logging 5 | import os 6 | import sys 7 | 8 | def getCi(accLog): 9 | 10 | mean = np.mean(accLog) 11 | std = np.std(accLog) 12 | ci95 = 1.96*std/np.sqrt(len(accLog)) 13 | 14 | return mean, ci95 15 | 16 | def get_logger(out_dir, resume_pth=None, args=None): 17 | 18 | # Determine the file path and mode 19 | try: 20 | file_path = os.path.join(out_dir, f"run_{args.rank}.log") 21 | except: 22 | file_path = os.path.join(out_dir, f"run.log") 23 | 24 | logger = logging.getLogger('Exp') 25 | logger.setLevel(logging.INFO) 26 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 27 | 28 | file_hdlr = logging.FileHandler(file_path) 29 | file_hdlr.setFormatter(formatter) 30 | 31 | strm_hdlr = logging.StreamHandler(sys.stdout) 32 | strm_hdlr.setFormatter(formatter) 33 | 34 | logger.addHandler(file_hdlr) 35 | logger.addHandler(strm_hdlr) 36 | 37 | return logger 38 | 39 | ## Optimizer 40 | def initial_optim(decay_option, lr, weight_decay, net, optimizer) : 41 | 42 | if optimizer == 'adamw' : 43 | optimizer_adam_family = optim.AdamW 44 | elif optimizer == 'adam' : 45 | optimizer_adam_family = optim.Adam 46 | if decay_option == 'all': 47 | #optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay) 48 | optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay) 49 | 50 | elif decay_option == 'noVQ': 51 | all_params = set(net.parameters()) 52 | no_decay = set([net.vq_layer]) 53 | 54 | decay = all_params - no_decay 55 | optimizer = optimizer_adam_family([ 56 | {'params': list(no_decay), 'weight_decay': 0}, 57 | {'params': 
list(decay), 'weight_decay' : weight_decay}], lr=lr) 58 | 59 | return optimizer 60 | 61 | 62 | def get_motion_with_trans(motion, velocity) : 63 | ''' 64 | motion : torch.tensor, shape (batch_size, T, 72), with the global translation = 0 65 | velocity : torch.tensor, shape (batch_size, T, 3), contain the information of velocity = 0 66 | 67 | ''' 68 | trans = torch.cumsum(velocity, dim=1) 69 | trans = trans - trans[:, :1] ## the first root is initialized at 0 (just for visualization) 70 | trans = trans.repeat((1, 1, 21)) 71 | motion_with_trans = motion + trans 72 | return motion_with_trans 73 | -------------------------------------------------------------------------------- /models/encdec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.resnet import Resnet1D 3 | 4 | class PrintModule(nn.Module): 5 | def __init__(self, me=''): 6 | super().__init__() 7 | self.me = me 8 | 9 | def forward(self, x): 10 | print(self.me, x.shape) 11 | return x 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, 15 | input_emb_width = 3, 16 | output_emb_width = 512, 17 | down_t = 3, 18 | stride_t = 2, 19 | width = 512, 20 | depth = 3, 21 | dilation_growth_rate = 3, 22 | activation='relu', 23 | norm=None): 24 | super().__init__() 25 | 26 | blocks = [] 27 | filter_t, pad_t = stride_t * 2, stride_t // 2 28 | blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) 29 | blocks.append(nn.ReLU()) 30 | 31 | for i in range(down_t): 32 | input_dim = width 33 | block = nn.Sequential( 34 | nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), 35 | Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), 36 | ) 37 | blocks.append(block) 38 | blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) 39 | self.model = nn.Sequential(*blocks) 40 | 41 | def forward(self, x): 42 | return self.model(x) 43 | 44 | class Decoder(nn.Module): 45 | def __init__(self, 46 | input_emb_width = 3, 47 | output_emb_width = 512, 48 | down_t = 3, 49 | stride_t = 2, 50 | width = 512, 51 | depth = 3, 52 | dilation_growth_rate = 3, 53 | activation='relu', 54 | norm=None): 55 | super().__init__() 56 | blocks = [] 57 | 58 | filter_t, pad_t = stride_t * 2, stride_t // 2 59 | blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) 60 | blocks.append(nn.ReLU()) 61 | for i in range(down_t): 62 | out_dim = width 63 | block = nn.Sequential( 64 | Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm), 65 | nn.Upsample(scale_factor=2, mode='nearest'), 66 | nn.Conv1d(width, out_dim, 3, 1, 1) 67 | ) 68 | blocks.append(block) 69 | blocks.append(nn.Conv1d(width, width, 3, 1, 1)) 70 | blocks.append(nn.ReLU()) 71 | blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) 72 | self.model = nn.Sequential(*blocks) 73 | 74 | def forward(self, x): 75 | return self.model(x) 76 | 77 | -------------------------------------------------------------------------------- /utils/humanml_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | HML_JOINT_NAMES = [ 4 | 'pelvis', 5 | 'left_hip', 6 | 'right_hip', 7 | 'spine1', 8 | 'left_knee', 9 | 'right_knee', 10 | 'spine2', 11 | 'left_ankle', 12 | 'right_ankle', 13 | 'spine3', 14 | 'left_foot', 15 | 'right_foot', 16 | 'neck', 17 | 'left_collar', 18 | 'right_collar', 19 | 'head', 20 | 'left_shoulder', 21 | 'right_shoulder', 22 | 'left_elbow', 23 | 'right_elbow', 24 | 'left_wrist', 25 | 'right_wrist', 26 | ] 27 
| 28 | NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints 29 | 30 | HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]] 31 | SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS] 32 | 33 | 34 | # Recover global angle and positions for rotation data 35 | # root_rot_velocity (B, seq_len, 1) 36 | # root_linear_velocity (B, seq_len, 2) 37 | # root_y (B, seq_len, 1) 38 | # ric_data (B, seq_len, (joint_num - 1)*3) 39 | # rot_data (B, seq_len, (joint_num - 1)*6) 40 | # local_velocity (B, seq_len, joint_num*3) 41 | # foot contact (B, seq_len, 4) 42 | HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1)) 43 | HML_ROOT_MASK = np.concatenate(([True]*(1+2+1), 44 | HML_ROOT_BINARY[1:].repeat(3), 45 | HML_ROOT_BINARY[1:].repeat(6), 46 | HML_ROOT_BINARY.repeat(3), 47 | [False] * 4)) 48 | HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) 49 | HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1), 50 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3), 51 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6), 52 | HML_LOWER_BODY_JOINTS_BINARY.repeat(3), 53 | [True]*4)) 54 | HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK 55 | 56 | 57 | ALL_JOINT_FALSE = np.full(*HML_ROOT_BINARY.shape, False) 58 | HML_UPPER_BODY_JOINTS_BINARY = np.array([i in SMPL_UPPER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) 59 | 60 | UPPER_JOINT_Y_TRUE = np.array([ALL_JOINT_FALSE[1:], HML_UPPER_BODY_JOINTS_BINARY[1:], ALL_JOINT_FALSE[1:]]) 61 | UPPER_JOINT_Y_TRUE = UPPER_JOINT_Y_TRUE.T 62 | UPPER_JOINT_Y_TRUE = UPPER_JOINT_Y_TRUE.reshape(ALL_JOINT_FALSE[1:].shape[0]*3) 63 | 64 | UPPER_JOINT_Y_MASK = np.concatenate(([False]*(1+2+1), 65 | UPPER_JOINT_Y_TRUE, 66 | ALL_JOINT_FALSE[1:].repeat(6), 67 | ALL_JOINT_FALSE.repeat(3), 68 | [False] * 4)) -------------------------------------------------------------------------------- /options/get_eval_option.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import re 3 | from os.path import join as pjoin 4 | 5 | 6 | def is_float(numStr): 7 | flag = False 8 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 9 | try: 10 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 11 | res = reg.match(str(numStr)) 12 | if res: 13 | flag = True 14 | except Exception as ex: 15 | print("is_float() - error: " + str(ex)) 16 | return flag 17 | 18 | 19 | def is_number(numStr): 20 | flag = False 21 | numStr = str(numStr).strip().lstrip('-').lstrip('+') 22 | if str(numStr).isdigit(): 23 | flag = True 24 | return flag 25 | 26 | 27 | def get_opt(opt_path, device): 28 | opt = Namespace() 29 | opt_dict = vars(opt) 30 | 31 | skip = ('-------------- End ----------------', 32 | '------------ Options -------------', 33 | '\n') 34 | print('Reading', opt_path) 35 | with open(opt_path) as f: 36 | for line in f: 37 | if line.strip() not in skip: 38 | # print(line.strip()) 39 | key, value = line.strip().split(': ') 40 | if value in ('True', 'False'): 41 | opt_dict[key] = (value == 'True') 42 | # print(key, value) 43 | elif is_float(value): 44 | opt_dict[key] = float(value) 45 | elif is_number(value): 46 | opt_dict[key] = int(value) 47 | else: 48 | opt_dict[key] = str(value) 49 | 50 | # print(opt) 51 | opt_dict['which_epoch'] = 'finest' 52 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 53 | opt.model_dir = 
pjoin(opt.save_root, 'model') 54 | opt.meta_dir = pjoin(opt.save_root, 'meta') 55 | 56 | if opt.dataset_name == 't2m': 57 | opt.data_root = './dataset/HumanML3D/' 58 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 59 | opt.text_dir = pjoin(opt.data_root, 'texts') 60 | opt.joints_num = 22 61 | opt.dim_pose = 263 62 | opt.max_motion_length = 196 63 | opt.max_motion_frame = 196 64 | opt.max_motion_token = 55 65 | elif opt.dataset_name == 'kit': 66 | opt.data_root = './dataset/KIT-ML/' 67 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 68 | opt.text_dir = pjoin(opt.data_root, 'texts') 69 | opt.joints_num = 21 70 | opt.dim_pose = 251 71 | opt.max_motion_length = 196 72 | opt.max_motion_frame = 196 73 | opt.max_motion_token = 55 74 | else: 75 | raise KeyError('Dataset not recognized') 76 | 77 | opt.dim_word = 300 78 | opt.num_classes = 200 // opt.unit_length 79 | opt.is_train = False 80 | opt.is_continue = False 81 | opt.device = device 82 | 83 | return opt -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class nonlinearity(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, x): 9 | # swish 10 | return x * torch.sigmoid(x) 11 | 12 | class ResConv1DBlock(nn.Module): 13 | def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): 14 | super().__init__() 15 | padding = dilation 16 | self.norm = norm 17 | if norm == "LN": 18 | self.norm1 = nn.LayerNorm(n_in) 19 | self.norm2 = nn.LayerNorm(n_in) 20 | elif norm == "GN": 21 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 22 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) 23 | elif norm == "BN": 24 | self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 25 | self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) 26 | 27 | else: 28 | self.norm1 = nn.Identity() 29 | self.norm2 = nn.Identity() 30 | 31 | if activation == "relu": 32 | self.activation1 = nn.ReLU() 33 | self.activation2 = nn.ReLU() 34 | 35 | elif activation == "silu": 36 | self.activation1 = nonlinearity() 37 | self.activation2 = nonlinearity() 38 | 39 | elif activation == "gelu": 40 | self.activation1 = nn.GELU() 41 | self.activation2 = nn.GELU() 42 | 43 | 44 | 45 | self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) 46 | self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,) 47 | 48 | 49 | def forward(self, x): 50 | x_orig = x 51 | if self.norm == "LN": 52 | x = self.norm1(x.transpose(-2, -1)) 53 | x = self.activation1(x.transpose(-2, -1)) 54 | else: 55 | x = self.norm1(x) 56 | x = self.activation1(x) 57 | 58 | x = self.conv1(x) 59 | 60 | if self.norm == "LN": 61 | x = self.norm2(x.transpose(-2, -1)) 62 | x = self.activation2(x.transpose(-2, -1)) 63 | else: 64 | x = self.norm2(x) 65 | x = self.activation2(x) 66 | 67 | x = self.conv2(x) 68 | x = x + x_orig 69 | return x 70 | 71 | class Resnet1D(nn.Module): 72 | def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): 73 | super().__init__() 74 | 75 | blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)] 76 | if reverse_dilation: 77 | blocks = blocks[::-1] 78 | 79 | self.model = nn.Sequential(*blocks) 80 | 81 | def forward(self, x): 82 
| return self.model(x) -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | from model.mdm import MDM 2 | from diffusion import gaussian_diffusion as gd 3 | from diffusion.respace import SpacedDiffusion, space_timesteps 4 | from utils.parser_util import get_cond_mode 5 | 6 | 7 | def load_model_wo_clip(model, state_dict): 8 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 9 | assert len(unexpected_keys) == 0 10 | assert all([k.startswith('clip_model.') for k in missing_keys]) 11 | 12 | 13 | def create_model_and_diffusion(args, data): 14 | model = MDM(**get_model_args(args, data)) 15 | diffusion = create_gaussian_diffusion(args) 16 | return model, diffusion 17 | 18 | 19 | def get_model_args(args, data): 20 | 21 | # default args 22 | clip_version = 'ViT-B/32' 23 | action_emb = 'tensor' 24 | cond_mode = get_cond_mode(args) 25 | if hasattr(data.dataset, 'num_actions'): 26 | num_actions = data.dataset.num_actions 27 | else: 28 | num_actions = 1 29 | 30 | # SMPL defaults 31 | data_rep = 'rot6d' 32 | njoints = 25 33 | nfeats = 6 34 | 35 | if args.dataset == 'humanml': 36 | data_rep = 'hml_vec' 37 | njoints = 263 38 | nfeats = 1 39 | elif args.dataset == 'kit': 40 | data_rep = 'hml_vec' 41 | njoints = 251 42 | nfeats = 1 43 | 44 | return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions, 45 | 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True, 46 | 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4, 47 | 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode, 48 | 'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch, 49 | 'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset} 50 | 51 | 52 | def create_gaussian_diffusion(args): 53 | # default params 54 | predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! 55 | steps = args.diffusion_steps 56 | scale_beta = 1. # no scaling 57 | timestep_respacing = '' # can be used for ddim sampling, we don't use it. 
58 | learn_sigma = False 59 | rescale_timesteps = False 60 | 61 | betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) 62 | loss_type = gd.LossType.MSE 63 | 64 | if not timestep_respacing: 65 | timestep_respacing = [steps] 66 | 67 | return SpacedDiffusion( 68 | use_timesteps=space_timesteps(steps, timestep_respacing), 69 | betas=betas, 70 | model_mean_type=( 71 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 72 | ), 73 | model_var_type=( 74 | ( 75 | gd.ModelVarType.FIXED_LARGE 76 | if not args.sigma_small 77 | else gd.ModelVarType.FIXED_SMALL 78 | ) 79 | if not learn_sigma 80 | else gd.ModelVarType.LEARNED_RANGE 81 | ), 82 | loss_type=loss_type, 83 | rescale_timesteps=rescale_timesteps, 84 | lambda_vel=args.lambda_vel, 85 | lambda_rcxyz=args.lambda_rcxyz, 86 | lambda_fc=args.lambda_fc, 87 | ) 88 | -------------------------------------------------------------------------------- /utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[self.word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = 
self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec 81 | 82 | 83 | class WordVectorizerV2(WordVectorizer): 84 | def __init__(self, meta_root, prefix): 85 | super(WordVectorizerV2, self).__init__(meta_root, prefix) 86 | self.idx2word = {self.word2idx[w]: w for w in self.word2idx} 87 | 88 | def __getitem__(self, item): 89 | word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item) 90 | word, pos = item.split('/') 91 | if word in self.word2vec: 92 | return word_vec, pose_vec, self.word2idx[word] 93 | else: 94 | return word_vec, pose_vec, self.word2idx['unk'] 95 | 96 | def itos(self, idx): 97 | if idx == len(self.idx2word): 98 | return "pad" 99 | return self.idx2word[idx] -------------------------------------------------------------------------------- /utils/rotation2xyz.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import torch 3 | import utils.rotation_conversions as geometry 4 | 5 | 6 | from utils.smpl import SMPL, JOINTSTYPE_ROOT 7 | # from .get_model import JOINTSTYPES 8 | JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] 9 | 10 | 11 | class Rotation2xyz: 12 | def __init__(self, device, dataset='amass'): 13 | self.device = device 14 | self.dataset = dataset 15 | self.smpl_model = SMPL().eval().to(device) 16 | 17 | def __call__(self, x, mask, pose_rep, translation, glob, 18 | jointstype, vertstrans, betas=None, beta=0, 19 | glob_rot=None, get_rotations_back=False, **kwargs): 20 | if pose_rep == "xyz": 21 | return x 22 | 23 | if mask is None: 24 | mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) 25 | 26 | if not glob and glob_rot is None: 27 | raise TypeError("You must specify global rotation if glob is False") 28 | 29 | if jointstype not in JOINTSTYPES: 30 | raise NotImplementedError("This jointstype is not implemented.") 31 | 32 | if translation: 33 | x_translations = x[:, -1, :3] 34 | x_rotations = x[:, :-1] 35 | else: 36 | x_rotations = x 37 | 38 | x_rotations = x_rotations.permute(0, 3, 1, 2) 39 | nsamples, time, njoints, feats = x_rotations.shape 40 | 41 | # Compute rotations (convert only masked sequences output) 42 | if pose_rep == "rotvec": 43 | rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) 44 | elif pose_rep == "rotmat": 45 | rotations = x_rotations[mask].view(-1, njoints, 3, 3) 46 | elif pose_rep == "rotquat": 47 | rotations = geometry.quaternion_to_matrix(x_rotations[mask]) 48 | elif pose_rep == "rot6d": 49 | rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) 50 | else: 51 | raise NotImplementedError("No geometry for this one.") 52 | 53 | if not glob: 54 | global_orient = torch.tensor(glob_rot, device=x.device) 55 | global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) 56 | global_orient = global_orient.repeat(len(rotations), 1, 1, 1) 57 | else: 58 | global_orient = rotations[:, 0] 59 | rotations = rotations[:, 1:] 60 | 61 | if betas is None: 62 | betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], 63 | dtype=rotations.dtype, device=rotations.device) 64 | betas[:, 1] = beta 65 | # import ipdb; ipdb.set_trace() 66 | out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) 67 | 68 | # get the desirable joints 69 | joints = out[jointstype] 70 | 71 | x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) 72 | x_xyz[~mask] = 0 73 | x_xyz[mask] = joints 74 | 75 | x_xyz 
= x_xyz.permute(0, 2, 3, 1).contiguous() 76 | 77 | # the first translation root at the origin on the prediction 78 | if jointstype != "vertices": 79 | rootindex = JOINTSTYPE_ROOT[jointstype] 80 | x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] 81 | 82 | if translation and vertstrans: 83 | # the first translation root at the origin 84 | x_translations = x_translations - x_translations[:, :, [0]] 85 | 86 | # add the translation to all the joints 87 | x_xyz = x_xyz + x_translations[:, None, :, :] 88 | 89 | if get_rotations_back: 90 | return x_xyz, rotations, global_orient 91 | else: 92 | return x_xyz -------------------------------------------------------------------------------- /visualization/vis_utils.py: -------------------------------------------------------------------------------- 1 | from utils.rotation2xyz import Rotation2xyz 2 | import numpy as np 3 | from trimesh import Trimesh 4 | import os 5 | import torch 6 | from visualization.simplify_loc2rot import joints2smpl 7 | 8 | class npy2obj: 9 | def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True, n_iter=2): 10 | self.npy_path = npy_path 11 | self.motions = np.load(self.npy_path, allow_pickle=True) 12 | if self.npy_path.endswith('.npz'): 13 | self.motions = self.motions['arr_0'] 14 | self.motions = self.motions[None][0] 15 | self.rot2xyz = Rotation2xyz(device='cpu') 16 | self.faces = self.rot2xyz.smpl_model.faces 17 | if self.motions.ndim == 3: 18 | m = torch.from_numpy(self.motions).permute(1, 2, 0).unsqueeze(0).numpy() 19 | self.motions = {} 20 | self.motions.update({'motion': m}) 21 | self.motions.update({'text': ['']}) 22 | self.motions.update({'num_samples': 1}) 23 | self.motions.update({'lengths': [m.shape[-1]]}) 24 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 25 | self.opt_cache = {} 26 | self.sample_idx = sample_idx 27 | self.total_num_samples = self.motions['num_samples'] 28 | self.rep_idx = rep_idx 29 | self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx 30 | self.num_frames = self.motions['motion'][self.absl_idx].shape[-1] 31 | self.j2s = joints2smpl(num_frames=self.num_frames, device=device, n_iter=n_iter) 32 | 33 | if self.nfeats == 3: 34 | print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.') 35 | motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3] 36 | self.motions['motion'] = motion_tensor.cpu().numpy() 37 | elif self.nfeats == 6: 38 | self.motions['motion'] = self.motions['motion'][[self.absl_idx]] 39 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 40 | self.real_num_frames = self.motions['lengths'][self.absl_idx] 41 | 42 | self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None, 43 | pose_rep='rot6d', translation=True, glob=True, 44 | jointstype='vertices', 45 | # jointstype='smpl', # for joint locations 46 | vertstrans=True) 47 | self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1) 48 | # self.vertices += self.root_loc 49 | 50 | def get_vertices(self, sample_i, frame_i): 51 | return self.vertices[sample_i, :, :, frame_i].squeeze().tolist() 52 | 53 | def get_trimesh(self, sample_i, frame_i): 54 | return Trimesh(vertices=self.get_vertices(sample_i, frame_i), 55 | faces=self.faces) 56 | 57 | def save_obj(self, save_path, frame_i): 58 | mesh = self.get_trimesh(0, frame_i) 59 | with open(save_path, 'w') as fw: 60 | mesh.export(fw, 'obj') 61 | return 
save_path 62 | 63 | def save_npy(self, save_path): 64 | data_dict = { 65 | 'motion': self.motions['motion'][0, :, :, :self.real_num_frames], 66 | 'thetas': self.motions['motion'][0, :-1, :, :self.real_num_frames], 67 | 'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames], 68 | 'faces': self.faces, 69 | 'vertices': self.vertices[0, :, :, :self.real_num_frames], 70 | 'text': self.motions['text'][0], 71 | 'length': self.real_num_frames, 72 | } 73 | np.save(save_path, data_dict) 74 | -------------------------------------------------------------------------------- /visualization/motions2hik.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.rotation_conversions import rotation_6d_to_matrix, matrix_to_euler_angles 5 | from visualization.simplify_loc2rot import joints2smpl 6 | 7 | """ 8 | Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder 9 | by converting joint positions to joint rotations in degrees. Based on visualization.vis_utils.npy2obj 10 | """ 11 | 12 | # Mapping of SMPL joint index to HIK joint Name 13 | JOINT_MAP = [ 14 | 'Hips', 15 | 'LeftUpLeg', 16 | 'RightUpLeg', 17 | 'Spine', 18 | 'LeftLeg', 19 | 'RightLeg', 20 | 'Spine1', 21 | 'LeftFoot', 22 | 'RightFoot', 23 | 'Spine2', 24 | 'LeftToeBase', 25 | 'RightToeBase', 26 | 'Neck', 27 | 'LeftShoulder', 28 | 'RightShoulder', 29 | 'Head', 30 | 'LeftArm', 31 | 'RightArm', 32 | 'LeftForeArm', 33 | 'RightForeArm', 34 | 'LeftHand', 35 | 'RightHand' 36 | ] 37 | 38 | 39 | def motions2hik(motions, device=0, cuda=True): 40 | """ 41 | Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder 42 | by converting joint positions to joint rotations in degrees. 
Based on visualization.vis_utils.npy2obj 43 | 44 | :param motions: numpy array containing MDM model output [num_reps, num_joints, num_params (xyz), num_frames 45 | :param device: 46 | :param cuda: 47 | 48 | :returns: JSON serializable dict to be used with the Replicate API implementation 49 | """ 50 | 51 | nreps, njoints, nfeats, nframes = motions.shape 52 | j2s = joints2smpl(num_frames=nframes, device_id=device, cuda=cuda) 53 | 54 | thetas = [] 55 | root_translation = [] 56 | for rep_idx in range(nreps): 57 | rep_motions = motions[rep_idx].transpose(2, 0, 1) # [nframes, njoints, 3] 58 | 59 | if nfeats == 3: 60 | print(f'Running SMPLify for repetition [{rep_idx + 1}] of {nreps}, it may take a few minutes.') 61 | motion_tensor, opt_dict = j2s.joint2smpl(rep_motions) # [nframes, njoints, 3] 62 | motion = motion_tensor.cpu().numpy() 63 | 64 | elif nfeats == 6: 65 | motion = rep_motions 66 | thetas.append(rep_motions) 67 | 68 | # Convert 6D rotation representation to Euler angles 69 | thetas_6d = motion[0, :-1, :, :nframes].transpose(2, 0, 1) # [nframes, njoints, 6] 70 | thetas_deg = [] 71 | for frame, d6 in enumerate(thetas_6d): 72 | thetas_deg.append([_rotation_6d_to_euler(d6)]) 73 | 74 | thetas.append([np.concatenate(thetas_deg, axis=0)]) 75 | root_translation.append([motion[0, -1, :3, :nframes].transpose(1, 0)]) # [nframes, 3] 76 | 77 | thetas = np.concatenate(thetas, axis=0)[:nframes] 78 | root_translation = np.concatenate(root_translation, axis=0)[:nframes] 79 | 80 | data_dict = { 81 | 'joint_map': JOINT_MAP, 82 | 'thetas': thetas.tolist(), # [nreps, nframes, njoints, 3 (deg)] 83 | 'root_translation': root_translation.tolist(), # [nreps, nframes, 3 (xyz)] 84 | } 85 | 86 | return data_dict 87 | 88 | 89 | def _rotation_6d_to_euler(d6): 90 | """ 91 | Converts 6D rotation representation by Zhou et al. [1] to euler angles 92 | using Gram--Schmidt orthogonalisation per Section B of [1]. 
93 | 94 | :param d6: numpy Array 6D rotation representation, of size (*, 6) 95 | :returns: JSON serializable dict to be used with the Replicate API implementation 96 | :returns: euler angles in degrees as a numpy array with shape (*, 3) 97 | """ 98 | rot_mat = rotation_6d_to_matrix(torch.tensor(d6)) 99 | rot_eul_rad = matrix_to_euler_angles(rot_mat, 'XYZ') 100 | eul_deg = torch.rad2deg(rot_eul_rad).numpy() 101 | 102 | return eul_deg 103 | 104 | -------------------------------------------------------------------------------- /dataset/dataset_VQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | 9 | 10 | 11 | class VQMotionDataset(data.Dataset): 12 | def __init__(self, dataset_name, window_size = 64, unit_length = 4, args=None): 13 | self.window_size = window_size 14 | self.unit_length = unit_length 15 | self.dataset_name = dataset_name 16 | 17 | if dataset_name == 't2m': 18 | self.data_root = './dataset/HumanML3D' 19 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 20 | self.text_dir = pjoin(self.data_root, 'texts') 21 | self.joints_num = 22 22 | self.max_motion_length = 196 23 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 24 | 25 | elif dataset_name == 'kit': 26 | self.data_root = './dataset/KIT-ML' 27 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 28 | self.text_dir = pjoin(self.data_root, 'texts') 29 | self.joints_num = 21 30 | 31 | self.max_motion_length = 196 32 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 33 | 34 | joints_num = self.joints_num 35 | 36 | mean = np.load(pjoin(self.meta_dir, 'mean.npy')) 37 | std = np.load(pjoin(self.meta_dir, 'std.npy')) 38 | 39 | split_file = pjoin(self.data_root, 'train.txt') 40 | 41 | self.data = [] 42 | self.lengths = [] 43 | id_list = [] 44 | with cs.open(split_file, 'r') as f: 45 | for line in f.readlines(): 46 | id_list.append(line.strip()) 47 | 48 | i_debug = 0 49 | for name in tqdm(id_list): 50 | try: 51 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 52 | if motion.shape[0] < self.window_size: 53 | continue 54 | self.lengths.append(motion.shape[0] - self.window_size) 55 | self.data.append(motion) 56 | 57 | if args.debug: 58 | if i_debug >= args.maxdata: 59 | break 60 | i_debug += 1 61 | 62 | except: 63 | # Some motion may not exist in KIT dataset 64 | pass 65 | 66 | 67 | self.mean = mean 68 | self.std = std 69 | print("Total number of motions {}".format(len(self.data))) 70 | 71 | def inv_transform(self, data): 72 | return data * self.std + self.mean 73 | 74 | def compute_sampling_prob(self) : 75 | 76 | prob = np.array(self.lengths, dtype=np.float32) 77 | prob /= np.sum(prob) 78 | return prob 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self, item): 84 | motion = self.data[item] 85 | 86 | idx = random.randint(0, len(motion) - self.window_size) 87 | 88 | motion = motion[idx:idx+self.window_size] 89 | "Z Normalization" 90 | motion = (motion - self.mean) / self.std 91 | 92 | return motion 93 | 94 | def VQMDataset(dataset_name, 95 | batch_size, 96 | num_workers = 8, 97 | window_size = 64, 98 | unit_length = 4, args=None): 99 | 100 | trainSet = VQMotionDataset(dataset_name, window_size=window_size, unit_length=unit_length, args=args) 101 | # prob = trainSet.compute_sampling_prob() 102 | # sampler = 
torch.utils.data.WeightedRandomSampler(prob, num_samples = len(trainSet) * 1000, replacement=True) 103 | return trainSet 104 | 105 | 106 | return train_loader 107 | 108 | def cycle(iterable): 109 | while True: 110 | for x in iterable: 111 | yield x 112 | -------------------------------------------------------------------------------- /models/evaluator_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from os.path import join as pjoin 4 | import numpy as np 5 | from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo 6 | from utils.word_vectorizer import POS_enumerator 7 | 8 | def build_models(opt): 9 | movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) 10 | text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, 11 | pos_size=opt.dim_pos_ohot, 12 | hidden_size=opt.dim_text_hidden, 13 | output_size=opt.dim_coemb_hidden, 14 | device=opt.device) 15 | 16 | motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, 17 | hidden_size=opt.dim_motion_hidden, 18 | output_size=opt.dim_coemb_hidden, 19 | device=opt.device) 20 | 21 | checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), 22 | map_location=opt.device) 23 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 24 | text_enc.load_state_dict(checkpoint['text_encoder']) 25 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 26 | print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch'])) 27 | return text_enc, motion_enc, movement_enc 28 | 29 | 30 | class EvaluatorModelWrapper(object): 31 | 32 | def __init__(self, opt): 33 | 34 | if opt.dataset_name == 't2m': 35 | opt.dim_pose = 263 36 | elif opt.dataset_name == 'kit': 37 | opt.dim_pose = 251 38 | else: 39 | raise KeyError('Dataset not Recognized!!!') 40 | 41 | opt.dim_word = 300 42 | opt.max_motion_length = 196 43 | opt.dim_pos_ohot = len(POS_enumerator) 44 | opt.dim_motion_hidden = 1024 45 | opt.max_text_len = 20 46 | opt.dim_text_hidden = 512 47 | opt.dim_coemb_hidden = 512 48 | 49 | # print(opt) 50 | 51 | self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt) 52 | self.opt = opt 53 | self.device = opt.device 54 | 55 | self.text_encoder.to(opt.device) 56 | self.motion_encoder.to(opt.device) 57 | self.movement_encoder.to(opt.device) 58 | 59 | self.text_encoder.eval() 60 | self.motion_encoder.eval() 61 | self.movement_encoder.eval() 62 | 63 | # Please note that the results does not following the order of inputs 64 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): 65 | with torch.no_grad(): 66 | word_embs = word_embs.detach().to(self.device).float() 67 | pos_ohot = pos_ohot.detach().to(self.device).float() 68 | motions = motions.detach().to(self.device).float() 69 | 70 | '''Movement Encoding''' 71 | movements = self.movement_encoder(motions[..., :-4]).detach() 72 | m_lens = m_lens // self.opt.unit_length 73 | motion_embedding = self.motion_encoder(movements, m_lens) 74 | 75 | '''Text Encoding''' 76 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) 77 | return text_embedding, motion_embedding 78 | 79 | # Please note that the results does not following the order of inputs 80 | def get_motion_embeddings(self, motions, m_lens): 81 | with torch.no_grad(): 82 | motions = motions.detach().to(self.device).float() 83 | 84 | align_idx = 
np.argsort(m_lens.data.tolist())[::-1].copy() 85 | motions = motions[align_idx] 86 | m_lens = m_lens[align_idx] 87 | 88 | '''Movement Encoding''' 89 | movements = self.movement_encoder(motions[..., :-4]).detach() 90 | m_lens = m_lens // self.opt.unit_length 91 | motion_embedding = self.motion_encoder(movements, m_lens) 92 | return motion_embedding 93 | -------------------------------------------------------------------------------- /utils/smpl.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import numpy as np 3 | import torch 4 | 5 | import contextlib 6 | 7 | from smplx import SMPLLayer as _SMPLLayer 8 | from smplx.lbs import vertices2joints 9 | 10 | 11 | # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] 12 | # change 0 and 8 13 | action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] 14 | 15 | from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA 16 | 17 | JOINTSTYPE_ROOT = {"a2m": 0, # action2motion 18 | "smpl": 0, 19 | "a2mpl": 0, # set(smpl, a2m) 20 | "vibe": 8} # 0 is the 8 position: OP MidHip below 21 | 22 | JOINT_MAP = { 23 | 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 24 | 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, 25 | 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 26 | 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, 27 | 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 28 | 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, 29 | 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 30 | 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 31 | 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 32 | 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 33 | 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 34 | 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 35 | 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 36 | 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 37 | 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 38 | 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 39 | 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 40 | } 41 | 42 | JOINT_NAMES = [ 43 | 'OP Nose', 'OP Neck', 'OP RShoulder', 44 | 'OP RElbow', 'OP RWrist', 'OP LShoulder', 45 | 'OP LElbow', 'OP LWrist', 'OP MidHip', 46 | 'OP RHip', 'OP RKnee', 'OP RAnkle', 47 | 'OP LHip', 'OP LKnee', 'OP LAnkle', 48 | 'OP REye', 'OP LEye', 'OP REar', 49 | 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 50 | 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 51 | 'Right Ankle', 'Right Knee', 'Right Hip', 52 | 'Left Hip', 'Left Knee', 'Left Ankle', 53 | 'Right Wrist', 'Right Elbow', 'Right Shoulder', 54 | 'Left Shoulder', 'Left Elbow', 'Left Wrist', 55 | 'Neck (LSP)', 'Top of Head (LSP)', 56 | 'Pelvis (MPII)', 'Thorax (MPII)', 57 | 'Spine (H36M)', 'Jaw (H36M)', 58 | 'Head (H36M)', 'Nose', 'Left Eye', 59 | 'Right Eye', 'Left Ear', 'Right Ear' 60 | ] 61 | 62 | 63 | # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints 64 | class SMPL(_SMPLLayer): 65 | """ Extension of the official SMPL implementation to support more joints """ 66 | 67 | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): 68 | kwargs["model_path"] = model_path 69 | 70 | # remove the verbosity for the 10-shapes beta parameters 71 | with contextlib.redirect_stdout(None): 72 | super(SMPL, self).__init__(**kwargs) 73 | 74 | J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) 75 | self.register_buffer('J_regressor_extra', 
torch.tensor(J_regressor_extra, dtype=torch.float32)) 76 | vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) 77 | a2m_indexes = vibe_indexes[action2motion_joints] 78 | smpl_indexes = np.arange(24) 79 | a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) 80 | 81 | self.maps = {"vibe": vibe_indexes, 82 | "a2m": a2m_indexes, 83 | "smpl": smpl_indexes, 84 | "a2mpl": a2mpl_indexes} 85 | 86 | def forward(self, *args, **kwargs): 87 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 88 | 89 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 90 | all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 91 | 92 | output = {"vertices": smpl_output.vertices} 93 | 94 | for joinstype, indexes in self.maps.items(): 95 | output[joinstype] = all_joints[:, indexes] 96 | 97 | return output -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence 4 | 5 | def init_weight(m): 6 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): 7 | nn.init.xavier_normal_(m.weight) 8 | # m.bias.data.fill_(0.01) 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | 13 | class MovementConvEncoder(nn.Module): 14 | def __init__(self, input_size, hidden_size, output_size): 15 | super(MovementConvEncoder, self).__init__() 16 | self.main = nn.Sequential( 17 | nn.Conv1d(input_size, hidden_size, 4, 2, 1), 18 | nn.Dropout(0.2, inplace=True), 19 | nn.LeakyReLU(0.2, inplace=True), 20 | nn.Conv1d(hidden_size, output_size, 4, 2, 1), 21 | nn.Dropout(0.2, inplace=True), 22 | nn.LeakyReLU(0.2, inplace=True), 23 | ) 24 | self.out_net = nn.Linear(output_size, output_size) 25 | self.main.apply(init_weight) 26 | self.out_net.apply(init_weight) 27 | 28 | def forward(self, inputs): 29 | inputs = inputs.permute(0, 2, 1) 30 | outputs = self.main(inputs).permute(0, 2, 1) 31 | # print(outputs.shape) 32 | return self.out_net(outputs) 33 | 34 | 35 | 36 | class TextEncoderBiGRUCo(nn.Module): 37 | def __init__(self, word_size, pos_size, hidden_size, output_size, device): 38 | super(TextEncoderBiGRUCo, self).__init__() 39 | self.device = device 40 | 41 | self.pos_emb = nn.Linear(pos_size, word_size) 42 | self.input_emb = nn.Linear(word_size, hidden_size) 43 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 44 | self.output_net = nn.Sequential( 45 | nn.Linear(hidden_size * 2, hidden_size), 46 | nn.LayerNorm(hidden_size), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Linear(hidden_size, output_size) 49 | ) 50 | 51 | self.input_emb.apply(init_weight) 52 | self.pos_emb.apply(init_weight) 53 | self.output_net.apply(init_weight) 54 | self.hidden_size = hidden_size 55 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 56 | 57 | # input(batch_size, seq_len, dim) 58 | def forward(self, word_embs, pos_onehot, cap_lens): 59 | num_samples = word_embs.shape[0] 60 | 61 | pos_embs = self.pos_emb(pos_onehot) 62 | inputs = word_embs + pos_embs 63 | input_embs = self.input_emb(inputs) 64 | hidden = self.hidden.repeat(1, num_samples, 1) 65 | 66 | cap_lens = cap_lens.data.tolist() 67 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) 68 | 69 | gru_seq, gru_last = self.gru(emb, hidden) 70 | 71 | gru_last = torch.cat([gru_last[0], gru_last[1]], 
dim=-1) 72 | 73 | return self.output_net(gru_last) 74 | 75 | 76 | class MotionEncoderBiGRUCo(nn.Module): 77 | def __init__(self, input_size, hidden_size, output_size, device): 78 | super(MotionEncoderBiGRUCo, self).__init__() 79 | self.device = device 80 | 81 | self.input_emb = nn.Linear(input_size, hidden_size) 82 | self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) 83 | self.output_net = nn.Sequential( 84 | nn.Linear(hidden_size*2, hidden_size), 85 | nn.LayerNorm(hidden_size), 86 | nn.LeakyReLU(0.2, inplace=True), 87 | nn.Linear(hidden_size, output_size) 88 | ) 89 | 90 | self.input_emb.apply(init_weight) 91 | self.output_net.apply(init_weight) 92 | self.hidden_size = hidden_size 93 | self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) 94 | 95 | # input(batch_size, seq_len, dim) 96 | def forward(self, inputs, m_lens): 97 | num_samples = inputs.shape[0] 98 | 99 | input_embs = self.input_emb(inputs) 100 | hidden = self.hidden.repeat(1, num_samples, 1) 101 | 102 | cap_lens = m_lens.data.tolist() 103 | emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False) 104 | 105 | gru_seq, gru_last = self.gru(emb, hidden) 106 | 107 | gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) 108 | 109 | return self.output_net(gru_last) 110 | -------------------------------------------------------------------------------- /models/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.encdec import Encoder, Decoder 4 | from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset 5 | from exit.utils import generate_src_mask 6 | import torch.nn.functional as F 7 | import math 8 | 9 | 10 | class VQVAE_251(nn.Module): 11 | def __init__(self, 12 | args, 13 | nb_code=1024, 14 | code_dim=512, 15 | output_emb_width=512, 16 | down_t=3, 17 | stride_t=2, 18 | width=512, 19 | depth=3, 20 | dilation_growth_rate=3, 21 | activation='relu', 22 | norm=None): 23 | 24 | super().__init__() 25 | self.code_dim = code_dim 26 | self.num_code = nb_code 27 | self.quant = args.quantizer 28 | output_dim = 251 if args.dataname == 'kit' else 263 29 | self.encoder = Encoder(output_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) 30 | self.decoder = Decoder(output_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) 31 | if args.quantizer == "ema_reset": 32 | self.quantizer = QuantizeEMAReset(nb_code, code_dim, args) 33 | elif args.quantizer == "orig": 34 | self.quantizer = Quantizer(nb_code, code_dim, 1.0) 35 | elif args.quantizer == "ema": 36 | self.quantizer = QuantizeEMA(nb_code, code_dim, args) 37 | elif args.quantizer == "reset": 38 | self.quantizer = QuantizeReset(nb_code, code_dim, args) 39 | 40 | 41 | def preprocess(self, x): 42 | # (bs, T, Jx3) -> (bs, Jx3, T) 43 | x = x.permute(0,2,1).float() 44 | return x 45 | 46 | 47 | def postprocess(self, x): 48 | # (bs, Jx3, T) -> (bs, T, Jx3) 49 | x = x.permute(0,2,1) 50 | return x 51 | 52 | 53 | def encode(self, x, *args, **kwargs): 54 | N, T, _ = x.shape 55 | x_in = self.preprocess(x) 56 | x_encoder = self.encoder(x_in) 57 | x_encoder = self.postprocess(x_encoder) 58 | x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1]) # (NT, C) 59 | code_idx = self.quantizer.quantize(x_encoder) 60 | code_idx = code_idx.view(N, -1) 61 | return code_idx 62 | 
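    # A minimal usage sketch of the encode/decode round trip (illustrative only:
    # `args` is assumed to be the namespace parsed by options/option_vq.py and
    # `motion` a Z-normalised (bs, T, dim_pose) batch such as the ones produced by
    # dataset/dataset_VQ.py):
    #
    #   net = VQVAE_251(args, args.nb_code, args.code_dim, args.code_dim, args.down_t)
    #   code_idx = net.encode(motion)          # (bs, T // 2**down_t) codebook indices
    #   recon = net.forward_decoder(code_idx)  # (bs, T, dim_pose) reconstructed motion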
63 | 64 | def forward(self, x, *args, **kwargs): 65 | 66 | x_in = self.preprocess(x) 67 | x_encoder = self.encoder(x_in) 68 | 69 | ## quantization 70 | x_quantized, loss, perplexity = self.quantizer(x_encoder) 71 | 72 | ## decoder 73 | x_decoder = self.decoder(x_quantized) 74 | x_out = self.postprocess(x_decoder) 75 | return x_out, loss, perplexity 76 | 77 | 78 | def forward_decoder(self, x, *args, **kwargs): 79 | x_d = self.quantizer.dequantize(x) 80 | x_d = x_d.permute(0, 2, 1).contiguous() 81 | 82 | # decoder 83 | x_decoder = self.decoder(x_d) 84 | x_out = self.postprocess(x_decoder) 85 | return x_out 86 | 87 | 88 | class HumanVQVAE(nn.Module): 89 | def __init__(self, 90 | args, 91 | nb_code=512, 92 | code_dim=512, 93 | output_emb_width=512, 94 | down_t=3, 95 | stride_t=2, 96 | width=512, 97 | depth=3, 98 | dilation_growth_rate=3, 99 | activation='relu', 100 | norm=None): 101 | 102 | super().__init__() 103 | 104 | self.nb_joints = 21 if args.dataname == 'kit' else 22 105 | self.vqvae = VQVAE_251(args, nb_code, code_dim, code_dim, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) 106 | 107 | self.mask_id = self.vqvae.num_code + 2 108 | self.pad_id = self.vqvae.num_code + 1 109 | self.end_id = self.vqvae.num_code 110 | 111 | def forward(self, x, dataname=None, type='full', *argv, **kwargs): 112 | '''type=[full, encode, decode]''' 113 | if type=='full': 114 | x_out, loss, perplexity = self.vqvae(x, *argv, **kwargs) 115 | return x_out, loss, perplexity 116 | elif type=='encode': 117 | b, t, c = x.size() 118 | quants = self.vqvae.encode(x, *argv, **kwargs) # (N, T) 119 | return quants 120 | elif type=='decode': 121 | x_out = self.vqvae.forward_decoder(x, dataname) 122 | return x_out 123 | else: 124 | raise ValueError(f'Unknown "{type}" type') 125 | 126 | -------------------------------------------------------------------------------- /options/option_vq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for AIST', 5 | add_help=True, 6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | 8 | ## dataloader 9 | parser.add_argument('--dataname', type=str, default='kit', help='dataset directory', choices=['kit', 't2m', 'both']) 10 | parser.add_argument('--total_batch_size', default=256, type=int, help='batch size') 11 | parser.add_argument('--window_size', type=int, default=64, help='training motion length') 12 | 13 | ## optimization 14 | parser.add_argument('--total_iters', default=300_000, type=int) 15 | parser.add_argument('--warm_up_iter', default=1000, type=int) 16 | parser.add_argument('--reset_codebook_every', default=1000, type=int) 17 | parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate') 18 | parser.add_argument('--lr_scheduler', default=[200000], nargs="+", type=int, help="learning rate schedule (iterations)") 19 | parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay") 20 | 21 | parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay') 22 | parser.add_argument("--commit", type=float, default=0.02, help="hyper_parameter for the commitment loss") 23 | parser.add_argument("--orth", type=float, default=0.01) 24 | parser.add_argument('--loss_vel', type=float, default=0.5, help='hyper_parameter for the velocity loss') 25 | parser.add_argument('--recons_loss', type=str, default='l1_smooth', 
help='reconstruction loss') 26 | 27 | ## vqvae arch 28 | parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook") 29 | parser.add_argument("--down_t", type=int, default=2, help="downsampling rate") 30 | parser.add_argument("--stride_t", type=int, default=2, help="stride size") 31 | parser.add_argument("--width", type=int, default=512, help="width of the network") 32 | parser.add_argument("--depth", type=int, default=3, help="depth of the network") 33 | parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate") 34 | parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width") 35 | parser.add_argument('--vq_act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='dataset directory') 36 | parser.add_argument('--vq_norm', type=str, default=None, help='dataset directory') 37 | 38 | ## quantizer 39 | parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="eps for optimal transport") 40 | parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ') 41 | 42 | ## resume 43 | parser.add_argument("--resume_pth", type=str, default=None, help='resume pth for VQ') 44 | parser.add_argument("--resume_gpt", type=str, default=None, help='resume pth for GPT') 45 | 46 | 47 | ## output directory 48 | parser.add_argument('--out_dir', type=str, default='output', help='output directory') 49 | parser.add_argument('--results_dir', type=str, default='visual_results/', help='output directory') 50 | parser.add_argument('--visual_name', type=str, default='baseline', help='output directory') 51 | parser.add_argument('--exp_name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out_dir') 52 | ## other 53 | parser.add_argument('--print_iter', default=200, type=int, help='print frequency') 54 | parser.add_argument('--eval_iter', default=5000, type=int, help='evaluation frequency') 55 | parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.') 56 | 57 | parser.add_argument('--vis_gt', action='store_true', help='whether visualize GT motions') 58 | parser.add_argument('--nb_vis', default=20, type=int, help='nb of visualizations') 59 | 60 | parser.add_argument('--sep_uplow', action='store_true', help='whether visualize GT motions') 61 | 62 | ### New Options 63 | parser.add_argument('--pin_memory', action='store_false') 64 | parser.add_argument('--num_workers', default=8, type=int) 65 | parser.add_argument('--device', type=str, default='cuda') 66 | parser.add_argument('--debug', action='store_true') 67 | parser.add_argument('--maxdata', default=124, type=int) 68 | 69 | ### VQVAE 70 | parser.add_argument("--code_dim", type=int, default=32, help="embedding dimension") 71 | parser.add_argument("--nb_code", type=int, default=8192, help="nb of embedding") 72 | 73 | ### distributed 74 | parser.add_argument('--init_method', default='tcp://127.0.0.1:3456', type=str, help='') 75 | parser.add_argument('--dist_backend', default='nccl', type=str, help='') 76 | parser.add_argument('--world_size', default=1, type=int, help='') 77 | parser.add_argument('--ddp', action='store_true', help='') 78 | 79 | return parser.parse_args() -------------------------------------------------------------------------------- /my_clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 
| import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 
110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /visualization/joints2smpl/fit_seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import torch 4 | import os,sys 5 | from os import walk, listdir 6 | from os.path import isfile, join 7 | import numpy as np 8 | import joblib 9 | import smplx 10 | import trimesh 11 | import h5py 12 | from tqdm import tqdm 13 | 14 | sys.path.append(os.path.join(os.path.dirname(__file__), "src")) 15 | from smplify import SMPLify3D 16 | import config 17 | 18 | # parsing argmument 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batchSize', type=int, default=1, 21 | help='input batch size') 22 | parser.add_argument('--num_smplify_iters', type=int, default=100, 23 | help='num of smplify iters') 24 | parser.add_argument('--cuda', type=bool, default=False, 25 | help='enables cuda') 26 | parser.add_argument('--gpu_ids', type=int, default=0, 27 | help='choose gpu ids') 28 | parser.add_argument('--num_joints', type=int, default=22, 29 | help='joint number') 30 | parser.add_argument('--joint_category', type=str, default="AMASS", 31 | help='use correspondence') 32 | parser.add_argument('--fix_foot', type=str, default="False", 33 | help='fix foot or not') 34 | parser.add_argument('--data_folder', type=str, default="./demo/demo_data/", 35 | help='data in the folder') 36 | parser.add_argument('--save_folder', type=str, default="./demo/demo_results/", 37 | help='results save folder') 38 | parser.add_argument('--files', type=str, default="test_motion.npy", 39 | help='files use') 40 | opt = parser.parse_args() 41 | print(opt) 42 | 43 | # ---load predefined something 44 | device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") 45 | print(config.SMPL_MODEL_DIR) 46 | smplmodel = smplx.create(config.SMPL_MODEL_DIR, 47 | model_type="smpl", gender="neutral", ext="pkl", 48 | batch_size=opt.batchSize).to(device) 49 | 50 | # ## --- load the mean pose as original ---- 51 | smpl_mean_file = config.SMPL_MEAN_FILE 52 | 53 | file = h5py.File(smpl_mean_file, 'r') 54 | init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).float() 55 | init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).float() 56 | cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).to(device) 57 | # 58 | pred_pose = torch.zeros(opt.batchSize, 72).to(device) 59 | pred_betas = torch.zeros(opt.batchSize, 10).to(device) 60 | pred_cam_t = torch.zeros(opt.batchSize, 3).to(device) 61 | keypoints_3d = torch.zeros(opt.batchSize, opt.num_joints, 3).to(device) 62 | 63 | # # #-------------initialize SMPLify 64 | smplify = 
SMPLify3D(smplxmodel=smplmodel, 65 | batch_size=opt.batchSize, 66 | joints_category=opt.joint_category, 67 | num_iters=opt.num_smplify_iters, 68 | device=device) 69 | #print("initialize SMPLify3D done!") 70 | 71 | 72 | purename = os.path.splitext(opt.files)[0] 73 | # --- load data --- 74 | data = np.load(opt.data_folder + "/" + purename + ".npy") # [nframes, njoints, 3] 75 | 76 | dir_save = os.path.join(opt.save_folder, purename) 77 | if not os.path.isdir(dir_save): 78 | os.makedirs(dir_save, exist_ok=True) 79 | 80 | # run the whole seqs 81 | num_seqs = data.shape[0] 82 | 83 | for idx in tqdm(range(num_seqs)): 84 | #print(idx) 85 | 86 | joints3d = data[idx] #*1.2 #scale problem [check first] 87 | keypoints_3d[0, :, :] = torch.Tensor(joints3d).to(device).float() 88 | 89 | if idx == 0: 90 | pred_betas[0, :] = init_mean_shape 91 | pred_pose[0, :] = init_mean_pose 92 | pred_cam_t[0, :] = cam_trans_zero 93 | else: 94 | data_param = joblib.load(dir_save + "/" + "%04d"%(idx-1) + ".pkl") 95 | pred_betas[0, :] = torch.from_numpy(data_param['beta']).unsqueeze(0).float() 96 | pred_pose[0, :] = torch.from_numpy(data_param['pose']).unsqueeze(0).float() 97 | pred_cam_t[0, :] = torch.from_numpy(data_param['cam']).unsqueeze(0).float() 98 | 99 | if opt.joint_category =="AMASS": 100 | confidence_input = torch.ones(opt.num_joints) 101 | # make sure the foot and ankle 102 | if opt.fix_foot == True: 103 | confidence_input[7] = 1.5 104 | confidence_input[8] = 1.5 105 | confidence_input[10] = 1.5 106 | confidence_input[11] = 1.5 107 | else: 108 | print("Such category not settle down!") 109 | 110 | # ----- from initial to fitting ------- 111 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 112 | new_opt_cam_t, new_opt_joint_loss = smplify( 113 | pred_pose.detach(), 114 | pred_betas.detach(), 115 | pred_cam_t.detach(), 116 | keypoints_3d, 117 | conf_3d=confidence_input.to(device), 118 | seq_ind=idx 119 | ) 120 | 121 | # # -- save the results to ply--- 122 | outputp = smplmodel(betas=new_opt_betas, global_orient=new_opt_pose[:, :3], body_pose=new_opt_pose[:, 3:], 123 | transl=new_opt_cam_t, return_verts=True) 124 | mesh_p = trimesh.Trimesh(vertices=outputp.vertices.detach().cpu().numpy().squeeze(), faces=smplmodel.faces, process=False) 125 | mesh_p.export(dir_save + "/" + "%04d"%idx + ".ply") 126 | 127 | # save the pkl 128 | param = {} 129 | param['beta'] = new_opt_betas.detach().cpu().numpy() 130 | param['pose'] = new_opt_pose.detach().cpu().numpy() 131 | param['cam'] = new_opt_cam_t.detach().cpu().numpy() 132 | joblib.dump(param, dir_save + "/" + "%04d"%idx + ".pkl", compress=3) 133 | -------------------------------------------------------------------------------- /environment2.yml: -------------------------------------------------------------------------------- 1 | name: bad2 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - absl-py=1.4.0 10 | - aiohttp=3.8.3 11 | - aiosignal=1.2.0 12 | - argon2-cffi=21.3.0 13 | - argon2-cffi-bindings=21.2.0 14 | - async-timeout=4.0.2 15 | - asynctest=0.13.0 16 | - attrs=22.1.0 17 | - backcall=0.2.0 18 | - beautifulsoup4=4.11.1 19 | - blas=1.0 20 | - bleach=4.1.0 21 | - blinker=1.4 22 | - brotlipy=0.7.0 23 | - c-ares=1.19.0 24 | - ca-certificates=2023.05.30 25 | - catalogue=2.0.8 26 | - certifi=2022.12.7 27 | - cffi=1.15.1 28 | - charset-normalizer=2.1.1 29 | - click=8.0.4 30 | - colorama=0.4.5 31 | - cryptography=35.0.0 32 | - cudatoolkit=11.0.221 33 | - cycler=0.11.0 34 | - cymem=2.0.6 
35 | - cython-blis=0.7.7 36 | - dataclasses=0.8 37 | - debugpy=1.5.1 38 | - decorator=5.1.1 39 | - defusedxml=0.7.1 40 | - entrypoints=0.4 41 | - expat=2.4.9 42 | - fftw=3.3.9 43 | - filelock=3.8.0 44 | - fontconfig=2.13.1 45 | - freetype=2.10.4 46 | - frozenlist=1.3.3 47 | - giflib=5.2.1 48 | - glib=2.69.1 49 | - h5py=3.7.0 50 | - hdf5=1.10.6 51 | # - icc_rt=2022.1.0 52 | - icu=58.2 53 | - idna=3.4 54 | - importlib-metadata=4.11.4 55 | - intel-openmp=2021.4.0 56 | - ipykernel=6.15.2 57 | - ipython=7.31.1 58 | - ipython_genutils=0.2.0 59 | - jedi=0.18.1 60 | - jinja2=3.1.2 61 | - joblib=1.1.0 62 | - jpeg=9b 63 | - jsonschema=3.0.2 64 | - jupyter_client=7.4.9 65 | - jupyter_core=4.11.2 66 | - jupyterlab_pygments=0.1.2 67 | - kiwisolver=1.4.2 68 | - langcodes=3.3.0 69 | - lcms2=2.12 70 | - libffi=3.3 71 | - libiconv=1.16 72 | - libpng=1.6.37 73 | - libprotobuf=3.15.8 74 | - libsodium=1.0.18 75 | - libtiff=4.1.0 76 | - libuv=1.40.0 77 | - libwebp=1.2.0 78 | - libxcb=1.15 79 | - libxml2=2.9.14 80 | - lz4-c=1.9.3 81 | # - m2w64-gcc-libgfortran=5.3.0 82 | # - m2w64-gcc-libs=5.3.0 83 | # - m2w64-gcc-libs-core=5.3.0 84 | # - m2w64-gmp=6.1.0 85 | # - m2w64-libwinpthread-git=5.0.0.4634.697f757 86 | - markdown=3.4.3 87 | - markupsafe=2.1.1 88 | - matplotlib=3.1.3 89 | - matplotlib-base=3.1.3 90 | - matplotlib-inline=0.1.6 91 | - mistune=0.8.4 92 | - mkl=2021.4.0 93 | - mkl-service=2.4.0 94 | - mkl_fft=1.3.1 95 | - mkl_random=1.2.2 96 | # - msys2-conda-epoch=20160418 97 | - multidict=6.0.2 98 | - murmurhash=1.0.7 99 | - nb_conda_kernels=2.3.1 100 | - nbclient=0.5.13 101 | - nbconvert=6.4.4 102 | - nbformat=5.5.0 103 | - nest-asyncio=1.5.6 104 | - ninja=1.10.2 105 | - ninja-base=1.10.2 106 | - notebook=6.4.12 107 | - numpy=1.21.5 108 | - numpy-base=1.21.5 109 | - openssl=1.1.1v 110 | - packaging=21.3 111 | - pandocfilters=1.5.0 112 | - parso=0.8.3 113 | - pathy=0.6.2 114 | - pcre=8.45 115 | - pexpect=4.8.0 116 | - pickleshare=0.7.5 117 | - pillow=9.2.0 118 | - pip=22.2.2 119 | - preshed=3.0.6 120 | - prometheus_client=0.14.1 121 | - prompt-toolkit=3.0.36 122 | - psutil=5.9.0 123 | - pthread-stubs=0.3 124 | - ptyprocess=0.7.0 125 | - pycparser=2.21 126 | - pydantic=1.8.2 127 | - pygments=2.11.2 128 | - pyjwt=2.4.0 129 | - pyopenssl=22.0.0 130 | - pyparsing=3.0.9 131 | - pyqt=5.9.2 132 | - pyrsistent=0.18.0 133 | - pysocks=1.7.1 134 | - python=3.7.13 135 | - python-dateutil=2.8.2 136 | - python-fastjsonschema=2.16.2 137 | - python_abi=3.7 138 | - pytorch=1.7.1 139 | # - pywin32=305 140 | # - pywinpty=2.0.10 141 | - pyzmq=23.2.0 142 | - qt=5.9.7 143 | - requests=2.28.1 144 | - scikit-learn=1.0.2 145 | - scipy=1.7.3 146 | - send2trash=1.8.0 147 | - setuptools=63.4.1 148 | - shellingham=1.5.0 149 | - sip=4.19.8 150 | - six=1.16.0 151 | - smart_open=5.2.1 152 | - soupsieve=2.3.2.post1 153 | - spacy=3.3.1 154 | - spacy-legacy=3.0.10 155 | - spacy-loggers=1.0.3 156 | - sqlite=3.39.3 157 | - srsly=2.4.3 158 | - tensorboard-plugin-wit=1.8.1 159 | - terminado=0.17.1 160 | - testpath=0.6.0 161 | - thinc=8.0.15 162 | - threadpoolctl=2.2.0 163 | - tk=8.6.12 164 | - torchaudio=0.7.2 165 | - torchvision=0.8.2 166 | - tornado=6.2 167 | - tqdm=4.64.1 168 | - traitlets=5.7.1 169 | - trimesh=3.15.3 170 | - typer=0.4.2 171 | - typing-extensions=3.10.0.2 172 | - typing_extensions=3.10.0.2 173 | - urllib3=1.26.15 174 | # - vc=14.2 175 | # - vs2015_runtime=14.29.30133 176 | - wasabi=0.10.1 177 | - webencodings=0.5.1 178 | - werkzeug=2.2.3 179 | - wheel=0.37.1 180 | # - win_inet_pton=1.1.0 181 | # - wincertstore=0.2 182 | # - 
winpty=0.4.3 183 | - xorg-libxau=1.0.11 184 | - xorg-libxdmcp=1.1.3 185 | - xz=5.2.6 186 | - yarl=1.8.1 187 | - zeromq=4.3.4 188 | - zipp=3.8.1 189 | - zlib=1.2.12 190 | - zstd=1.4.9 191 | - pip: 192 | - cachetools==5.3.1 193 | - chumpy==0.70 194 | # - clip==1.0 195 | - einops==0.6.1 196 | - fsspec==2023.1.0 197 | - ftfy==6.1.1 198 | - gdown==4.7.1 199 | - google-auth==2.22.0 200 | - google-auth-oauthlib==0.4.6 201 | - grpcio==1.57.0 202 | - huggingface-hub==0.16.4 203 | - oauthlib==3.2.2 204 | - plotly==5.18.0 205 | - protobuf==3.20.3 206 | - pyasn1==0.5.0 207 | - pyasn1-modules==0.3.0 208 | - pyyaml==6.0.1 209 | - regex==2023.8.8 210 | - requests-oauthlib==1.3.1 211 | - rsa==4.9 212 | - safetensors==0.4.5 213 | - smplx==0.1.28 214 | - tenacity==8.2.3 215 | - tensorboard==2.11.2 216 | - tensorboard-data-server==0.6.1 217 | - timm==0.9.12 218 | - wcwidth==0.2.6 219 | -------------------------------------------------------------------------------- /dataset/dataset_tokenize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | 9 | 10 | 11 | class VQMotionDataset(data.Dataset): 12 | def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8, fill_max_len=False, args=None): 13 | self.window_size = window_size 14 | self.unit_length = unit_length 15 | self.feat_bias = feat_bias 16 | self.fill_max_len = fill_max_len 17 | 18 | self.dataset_name = dataset_name 19 | min_motion_len = 40 if dataset_name =='t2m' else 24 20 | 21 | if dataset_name == 't2m': 22 | self.data_root = './dataset/HumanML3D' 23 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 24 | self.text_dir = pjoin(self.data_root, 'texts') 25 | self.joints_num = 22 26 | radius = 4 27 | fps = 20 28 | self.max_motion_length = 196 29 | self.dim_pose = 263 30 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 31 | #kinematic_chain = paramUtil.t2m_kinematic_chain 32 | elif dataset_name == 'kit': 33 | self.data_root = './dataset/KIT-ML' 34 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 35 | self.text_dir = pjoin(self.data_root, 'texts') 36 | self.joints_num = 21 37 | radius = 240 * 8 38 | fps = 12.5 39 | self.dim_pose = 251 40 | self.max_motion_length = 196 41 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 42 | #kinematic_chain = paramUtil.kit_kinematic_chain 43 | 44 | joints_num = self.joints_num 45 | 46 | mean = np.load(pjoin(self.meta_dir, 'mean.npy')) 47 | std = np.load(pjoin(self.meta_dir, 'std.npy')) 48 | 49 | split_file = pjoin(self.data_root, 'train.txt') 50 | 51 | data_dict = {} 52 | id_list = [] 53 | with cs.open(split_file, 'r') as f: 54 | for line in f.readlines(): 55 | id_list.append(line.strip()) 56 | 57 | new_name_list = [] 58 | length_list = [] 59 | i_debug = 0 60 | for name in tqdm(id_list): 61 | try: 62 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 63 | if (len(motion)) < min_motion_len or (len(motion) >= 200): 64 | continue 65 | 66 | data_dict[name] = {'motion': motion, 67 | 'length': len(motion), 68 | 'name': name} 69 | new_name_list.append(name) 70 | length_list.append(len(motion)) 71 | 72 | if args.debug: 73 | i_debug += 1 74 | if i_debug >= args.maxdata: 75 | break 76 | except: 77 | # Some motion may not exist in KIT dataset 78 | pass 79 | 80 | 81 | self.mean = mean 82 | self.std = std 83 | self.length_arr = 
np.array(length_list) 84 | self.data_dict = data_dict 85 | self.name_list = new_name_list 86 | 87 | def inv_transform(self, data): 88 | return data * self.std + self.mean 89 | 90 | def __len__(self): 91 | return len(self.data_dict) 92 | 93 | def __getitem__(self, item): 94 | name = self.name_list[item] 95 | data = self.data_dict[name] 96 | motion, m_length = data['motion'], data['length'] 97 | 98 | m_length = (m_length // self.unit_length) * self.unit_length 99 | 100 | idx = random.randint(0, len(motion) - m_length) 101 | motion = motion[idx:idx+m_length] 102 | 103 | if self.fill_max_len: 104 | motion_zero = np.zeros((self.max_motion_length, self.dim_pose)) 105 | motion_zero[:m_length] = motion 106 | motion = motion_zero 107 | motion = (motion - self.mean) / self.std 108 | return motion, m_length 109 | 110 | "Z Normalization" 111 | motion = (motion - self.mean) / self.std 112 | 113 | return motion, name 114 | 115 | def DATALoader(dataset_name, 116 | batch_size = 1, 117 | num_workers = 8, unit_length = 4, shuffle=True, args=None) : 118 | 119 | train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length, fill_max_len=batch_size!=1, args=args), 120 | batch_size, 121 | shuffle=shuffle, 122 | num_workers=0, 123 | #collate_fn=collate_fn, 124 | drop_last = True) 125 | 126 | return train_loader 127 | 128 | def cycle(iterable): 129 | while True: 130 | for x in iterable: 131 | yield x 132 | 133 | 134 | def save_tokens(args, net): 135 | print("Extracting Code") 136 | train_loader_token = DATALoader(args.dataname, 1, unit_length=2**args.down_t, args=args) 137 | for batch in tqdm(train_loader_token, "Starting Extracting..."): 138 | pose, name = batch 139 | bs, seq = pose.shape[0], pose.shape[1] 140 | 141 | pose = pose.to(args.device).float() # bs, nb_joints, joints_dim, seq_len 142 | target = net(pose, type='encode') 143 | target = target.cpu().numpy() 144 | 145 | np.save(pjoin(args.codebook_dir, name[0] +'.npy'), target) 146 | -------------------------------------------------------------------------------- /visualization/simplify_loc2rot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from visualization.joints2smpl.src import config 5 | import smplx 6 | import h5py 7 | from visualization.joints2smpl.src.smplify import SMPLify3D 8 | from tqdm import tqdm 9 | import utils.rotation_conversions as geometry 10 | import argparse 11 | 12 | 13 | class joints2smpl: 14 | 15 | def __init__(self, num_frames, device, cuda=True, n_iter=150): 16 | self.device = device 17 | # self.device = torch.device("cpu") 18 | self.batch_size = num_frames 19 | self.num_joints = 22 # for HumanML3D 20 | self.joint_category = "AMASS" 21 | self.num_smplify_iters = n_iter 22 | self.fix_foot = False 23 | print(config.SMPL_MODEL_DIR) 24 | smplmodel = smplx.create(config.SMPL_MODEL_DIR, 25 | model_type="smpl", gender="neutral", ext="pkl", 26 | batch_size=self.batch_size).to(self.device) 27 | 28 | # ## --- load the mean pose as original ---- 29 | smpl_mean_file = config.SMPL_MEAN_FILE 30 | 31 | file = h5py.File(smpl_mean_file, 'r') 32 | self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 33 | self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 34 | self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device) 35 | # 36 | 37 | # # #-------------initialize 
SMPLify 38 | self.smplify = SMPLify3D(smplxmodel=smplmodel, 39 | batch_size=self.batch_size, 40 | joints_category=self.joint_category, 41 | num_iters=self.num_smplify_iters, 42 | device=self.device) 43 | 44 | 45 | def npy2smpl(self, npy_path): 46 | out_path = npy_path.replace('.npy', '_rot.npy') 47 | motions = np.load(npy_path, allow_pickle=True)[None][0] 48 | # print_batch('', motions) 49 | n_samples = motions['motion'].shape[0] 50 | all_thetas = [] 51 | for sample_i in tqdm(range(n_samples)): 52 | thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3] 53 | all_thetas.append(thetas.cpu().numpy()) 54 | motions['motion'] = np.concatenate(all_thetas, axis=0) 55 | print('motions', motions['motion'].shape) 56 | 57 | print(f'Saving [{out_path}]') 58 | np.save(out_path, motions) 59 | exit() 60 | 61 | 62 | 63 | def joint2smpl(self, input_joints, init_params=None): 64 | _smplify = self.smplify # if init_params is None else self.smplify_fast 65 | pred_pose = torch.zeros(self.batch_size, 72).to(self.device) 66 | pred_betas = torch.zeros(self.batch_size, 10).to(self.device) 67 | pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device) 68 | keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device) 69 | 70 | # run the whole seqs 71 | num_seqs = input_joints.shape[0] 72 | 73 | 74 | # joints3d = input_joints[idx] # *1.2 #scale problem [check first] 75 | keypoints_3d = torch.Tensor(input_joints).to(self.device).float() 76 | 77 | # if idx == 0: 78 | if init_params is None: 79 | pred_betas = self.init_mean_shape 80 | pred_pose = self.init_mean_pose 81 | pred_cam_t = self.cam_trans_zero 82 | else: 83 | pred_betas = init_params['betas'] 84 | pred_pose = init_params['pose'] 85 | pred_cam_t = init_params['cam'] 86 | 87 | if self.joint_category == "AMASS": 88 | confidence_input = torch.ones(self.num_joints) 89 | # make sure the foot and ankle 90 | if self.fix_foot == True: 91 | confidence_input[7] = 1.5 92 | confidence_input[8] = 1.5 93 | confidence_input[10] = 1.5 94 | confidence_input[11] = 1.5 95 | else: 96 | print("Such category not settle down!") 97 | 98 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 99 | new_opt_cam_t, new_opt_joint_loss = _smplify( 100 | pred_pose.detach(), 101 | pred_betas.detach(), 102 | pred_cam_t.detach(), 103 | keypoints_3d, 104 | conf_3d=confidence_input.to(self.device), 105 | # seq_ind=idx 106 | ) 107 | 108 | thetas = new_opt_pose.reshape(self.batch_size, 24, 3) 109 | thetas = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(thetas)) # [bs, 24, 6] 110 | root_loc = torch.tensor(keypoints_3d[:, 0]) # [bs, 3] 111 | root_loc = torch.cat([root_loc, torch.zeros_like(root_loc)], dim=-1).unsqueeze(1) # [bs, 1, 6] 112 | thetas = torch.cat([thetas, root_loc], dim=1).unsqueeze(0).permute(0, 2, 3, 1) # [1, 25, 6, 196] 113 | 114 | return thetas.clone().detach(), {'pose': new_opt_joints[0, :24].flatten().clone().detach(), 'betas': new_opt_betas.clone().detach(), 'cam': new_opt_cam_t.clone().detach()} 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("--input_path", type=str, required=True, help='Blender file or dir with blender files') 120 | parser.add_argument("--cuda", type=bool, default=True, help='') 121 | parser.add_argument("--device", type=int, default=0, help='') 122 | params = parser.parse_args() 123 | 124 | simplify = joints2smpl(device_id=params.device, cuda=params.cuda) 125 | 126 | if os.path.isfile(params.input_path) and 
params.input_path.endswith('.npy'): 127 | simplify.npy2smpl(params.input_path) 128 | elif os.path.isdir(params.input_path): 129 | files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')] 130 | for f in files: 131 | simplify.npy2smpl(f) -------------------------------------------------------------------------------- /visualization/joints2bvh.py: -------------------------------------------------------------------------------- 1 | import visualization.Animation as Animation 2 | 3 | from visualization.InverseKinematics import BasicInverseKinematics, BasicJacobianIK, InverseKinematics 4 | from visualization.Quaternions import Quaternions 5 | import visualization.BVH_mod as BVH 6 | from visualization.remove_fs import * 7 | 8 | from utils.plot_script import plot_3d_motion 9 | from utils import paramUtil 10 | from common.skeleton import Skeleton 11 | import torch 12 | 13 | from torch import nn 14 | from visualization.utils.quat import ik_rot, between, fk, ik 15 | from tqdm import tqdm 16 | 17 | 18 | def get_grot(glb, parent, offset): 19 | root_quat = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(glb.shape[0], axis=0)[:, None] 20 | local_pos = glb[:, 1:] - glb[:, parent[1:]] 21 | norm_offset = offset[1:] / np.linalg.norm(offset[1:], axis=-1, keepdims=True) 22 | norm_lpos = local_pos / np.linalg.norm(local_pos, axis=-1, keepdims=True) 23 | grot = between(norm_offset, norm_lpos) 24 | grot = np.concatenate((root_quat, grot), axis=1) 25 | grot /= np.linalg.norm(grot, axis=-1, keepdims=True) 26 | return grot 27 | 28 | 29 | class Joint2BVHConvertor: 30 | def __init__(self): 31 | self.template = BVH.load('./visualization/data/template.bvh', need_quater=True) 32 | self.re_order = [0, 1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12, 15, 13, 16, 18, 20, 14, 17, 19, 21] 33 | 34 | self.re_order_inv = [0, 1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12, 14, 18, 13, 15, 19, 16, 20, 17, 21] 35 | self.end_points = [4, 8, 13, 17, 21] 36 | 37 | self.template_offset = self.template.offsets.copy() 38 | self.parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20] 39 | 40 | def convert(self, positions, filename, iterations=10, foot_ik=True): 41 | ''' 42 | Convert the SMPL joint positions to Mocap BVH 43 | :param positions: (N, 22, 3) 44 | :param filename: Save path for resulting BVH 45 | :param iterations: iterations for optimizing rotations, 10 is usually enough 46 | :param foot_ik: whether to enfore foot inverse kinematics, removing foot slide issue. 
47 | :return: 48 | ''' 49 | positions = positions[:, self.re_order] 50 | new_anim = self.template.copy() 51 | new_anim.rotations = Quaternions.id(positions.shape[:-1]) 52 | new_anim.positions = new_anim.positions[0:1].repeat(positions.shape[0], axis=-0) 53 | new_anim.positions[:, 0] = positions[:, 0] 54 | 55 | if foot_ik: 56 | positions = remove_fs(positions, None, fid_l=(3, 4), fid_r=(7, 8), interp_length=5, 57 | force_on_floor=True) 58 | ik_solver = BasicInverseKinematics(new_anim, positions, iterations=iterations, silent=True) 59 | new_anim = ik_solver() 60 | 61 | # BVH.save(filename, new_anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 62 | glb = Animation.positions_global(new_anim)[:, self.re_order_inv] 63 | if filename is not None: 64 | BVH.save(filename, new_anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 65 | return new_anim, glb 66 | 67 | def convert_sgd(self, positions, filename, iterations=100, foot_ik=True): 68 | ''' 69 | Convert the SMPL joint positions to Mocap BVH 70 | 71 | :param positions: (N, 22, 3) 72 | :param filename: Save path for resulting BVH 73 | :param iterations: iterations for optimizing rotations, 10 is usually enough 74 | :param foot_ik: whether to enfore foot inverse kinematics, removing foot slide issue. 75 | :return: 76 | ''' 77 | 78 | ## Positional Foot locking ## 79 | glb = positions[:, self.re_order] 80 | 81 | if foot_ik: 82 | glb = remove_fs(glb, None, fid_l=(3, 4), fid_r=(7, 8), interp_length=2, 83 | force_on_floor=True) 84 | 85 | ## Fit BVH ## 86 | new_anim = self.template.copy() 87 | new_anim.rotations = Quaternions.id(glb.shape[:-1]) 88 | new_anim.positions = new_anim.positions[0:1].repeat(glb.shape[0], axis=-0) 89 | new_anim.positions[:, 0] = glb[:, 0] 90 | anim = new_anim.copy() 91 | 92 | rot = torch.tensor(anim.rotations.qs, dtype=torch.float) 93 | pos = torch.tensor(anim.positions[:, 0, :], dtype=torch.float) 94 | offset = torch.tensor(anim.offsets, dtype=torch.float) 95 | 96 | glb = torch.tensor(glb, dtype=torch.float) 97 | ik_solver = InverseKinematics(rot, pos, offset, anim.parents, glb) 98 | print('Fixing foot contact using IK...') 99 | for i in tqdm(range(iterations)): 100 | mse = ik_solver.step() 101 | # print(i, mse) 102 | 103 | rotations = ik_solver.rotations.detach().cpu() 104 | norm = torch.norm(rotations, dim=-1, keepdim=True) 105 | rotations /= norm 106 | 107 | anim.rotations = Quaternions(rotations.numpy()) 108 | anim.rotations[:, self.end_points] = Quaternions.id((anim.rotations.shape[0], len(self.end_points))) 109 | anim.positions[:, 0, :] = ik_solver.position.detach().cpu().numpy() 110 | if filename is not None: 111 | BVH.save(filename, anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 112 | # BVH.save(filename[:-3] + 'bvh', anim, names=new_anim.names, frametime=1 / 20, order='zyx', quater=True) 113 | glb = Animation.positions_global(anim)[:, self.re_order_inv] 114 | return anim, glb 115 | 116 | 117 | 118 | if __name__ == "__main__": 119 | # file = 'batch0_sample13_repeat0_len196.npy' 120 | # file = 'batch2_sample10_repeat0_len156.npy' 121 | # file = 'batch2_sample13_repeat0_len196.npy' #line #57 new_anim.positions = lpos #new_anim.positions[0:1].repeat(positions.shape[0], axis=-0) #TODO, figure out why it's important 122 | # file = 'batch1_sample12_repeat0_len196.npy' #hard case karate 123 | # file = 'batch1_sample14_repeat0_len180.npy' 124 | # file = 'batch0_sample3_repeat0_len192.npy' 125 | # file = 'batch1_sample4_repeat0_len136.npy' 126 | 127 | # file = 
'batch0_sample0_repeat0_len152.npy' 128 | # path = f'/Users/yuxuanmu/project/MaskMIT/demo/cond4_topkr0.9_ts18_tau1.0_s1009/joints/{file}' 129 | # joints = np.load(path) 130 | # converter = Joint2BVHConvertor() 131 | # new_anim = converter.convert(joints, './gen_L196.mp4', foot_ik=True) 132 | 133 | folder = '/Users/yuxuanmu/project/MaskMIT/demo/cond4_topkr0.9_ts18_tau1.0_s1009' 134 | files = os.listdir(os.path.join(folder, 'joints')) 135 | files = [f for f in files if 'repeat' in f] 136 | converter = Joint2BVHConvertor() 137 | for f in tqdm(files): 138 | joints = np.load(os.path.join(folder, 'joints', f)) 139 | converter.convert(joints, os.path.join(folder, 'ik_animations', f'ik_{f}'.replace('npy', 'mp4')), foot_ik=True) -------------------------------------------------------------------------------- /visualization/smpl2bvh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import pickle 5 | import smplx 6 | 7 | from utils import bvh, quat 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--model_path", type=str, default="./visualization/data/smpl/") 13 | parser.add_argument("--model_type", type=str, default="smpl", choices=["smpl", "smplx"]) 14 | parser.add_argument("--gender", type=str, default="MALE", choices=["MALE", "FEMALE", "NEUTRAL"]) 15 | parser.add_argument("--num_betas", type=int, default=10, choices=[10, 300]) 16 | parser.add_argument("--poses", type=str, default="data/gWA_sFM_cAll_d27_mWA5_ch20.pkl") 17 | parser.add_argument("--fps", type=int, default=60) 18 | parser.add_argument("--output", type=str, default="data/gWA_sFM_cAll_d27_mWA5_ch20.bvh") 19 | parser.add_argument("--mirror", action="store_true") 20 | return parser.parse_args() 21 | 22 | def mirror_rot_trans(lrot, trans, names, parents): 23 | joints_mirror = np.array([( 24 | names.index("Left"+n[5:]) if n.startswith("Right") else ( 25 | names.index("Right"+n[4:]) if n.startswith("Left") else 26 | names.index(n))) for n in names]) 27 | 28 | mirror_pos = np.array([-1, 1, 1]) 29 | mirror_rot = np.array([1, 1, -1, -1]) 30 | grot = quat.fk_rot(lrot, parents) 31 | trans_mirror = mirror_pos * trans 32 | grot_mirror = mirror_rot * grot[:,joints_mirror] 33 | 34 | return quat.ik_rot(grot_mirror, parents), trans_mirror 35 | 36 | def smpl2bvh(model_path:str, poses:str, output:str, mirror:bool, 37 | model_type="smpl", gender="MALE", 38 | num_betas=10, fps=60) -> None: 39 | """Save bvh file created by smpl parameters. 40 | 41 | Args: 42 | model_path (str): Path to smpl models. 43 | poses (str): Path to npz or pkl file. 44 | output (str): Where to save bvh. 45 | mirror (bool): Whether save mirror motion or not. 46 | model_type (str, optional): I prepared "smpl" only. Defaults to "smpl". 47 | gender (str, optional): Gender Information. Defaults to "MALE". 48 | num_betas (int, optional): How many pca parameters to use in SMPL. Defaults to 10. 49 | fps (int, optional): Frame per second. Defaults to 30. 
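        Example (a minimal usage sketch, not part of the original docstring; it assumes the
        AIST++ pickle shipped under ./visualization/data/ and an SMPL model folder at the
        default --model_path, with fps=60 for the AIST++ clip):

            smpl2bvh(model_path="./visualization/data/smpl/",
                     poses="./visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.pkl",
                     output="./visualization/data/gBR_sBM_cAll_d04_mBR0_ch01.bvh",
                     mirror=False, model_type="smpl", gender="MALE",
                     num_betas=10, fps=60)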
50 | """ 51 | 52 | # names = [ 53 | # "Pelvis", 54 | # "Left_hip", 55 | # "Right_hip", 56 | # "Spine1", 57 | # "Left_knee", 58 | # "Right_knee", 59 | # "Spine2", 60 | # "Left_ankle", 61 | # "Right_ankle", 62 | # "Spine3", 63 | # "Left_foot", 64 | # "Right_foot", 65 | # "Neck", 66 | # "Left_collar", 67 | # "Right_collar", 68 | # "Head", 69 | # "Left_shoulder", 70 | # "Right_shoulder", 71 | # "Left_elbow", 72 | # "Right_elbow", 73 | # "Left_wrist", 74 | # "Right_wrist", 75 | # "Left_palm", 76 | # "Right_palm", 77 | # ] 78 | 79 | names = [ 80 | "Hips", 81 | "LeftUpLeg", 82 | "RightUpLeg", 83 | "Spine", 84 | "LeftLeg", 85 | "RightLeg", 86 | "Spine1", 87 | "LeftFoot", 88 | "RightFoot", 89 | "Spine2", 90 | "LeftToe", 91 | "RightToe", 92 | "Neck", 93 | "LeftShoulder", 94 | "RightShoulder", 95 | "Head", 96 | "LeftArm", 97 | "RightArm", 98 | "LeftForeArm", 99 | "RightForeArm", 100 | "LeftHand", 101 | "RightHand", 102 | "LeftThumb", 103 | "RightThumb", 104 | ] 105 | 106 | # I prepared smpl models only, 107 | # but I will release for smplx models recently. 108 | model = smplx.create(model_path=model_path, 109 | model_type=model_type, 110 | gender=gender, 111 | batch_size=1) 112 | 113 | parents = model.parents.detach().cpu().numpy() 114 | 115 | # You can define betas like this.(default betas are 0 at all.) 116 | rest = model( 117 | # betas = torch.randn([1, num_betas], dtype=torch.float32) 118 | ) 119 | rest_pose = rest.joints.detach().cpu().numpy().squeeze()[:24,:] 120 | 121 | root_offset = rest_pose[0] 122 | offsets = rest_pose - rest_pose[parents] 123 | offsets[0] = root_offset 124 | offsets *= 1 125 | 126 | scaling = None 127 | 128 | # Pose setting. 129 | if poses.endswith(".npz"): 130 | poses = np.load(poses) 131 | rots = np.squeeze(poses["poses"], axis=0) # (N, 24, 3) 132 | trans = np.squeeze(poses["trans"], axis=0) # (N, 3) 133 | 134 | elif poses.endswith(".pkl"): 135 | with open(poses, "rb") as f: 136 | poses = pickle.load(f) 137 | rots = poses["smpl_poses"] # (N, 72) 138 | rots = rots.reshape(rots.shape[0], -1, 3) # (N, 24, 3) 139 | scaling = poses["smpl_scaling"] # (1,) 140 | trans = poses["smpl_trans"] # (N, 3) 141 | 142 | else: 143 | raise Exception("This file type is not supported!") 144 | 145 | if scaling is not None: 146 | trans /= scaling 147 | 148 | # to quaternion 149 | rots = quat.from_axis_angle(rots) 150 | 151 | order = "zyx" 152 | pos = offsets[None].repeat(len(rots), axis=0) 153 | positions = pos.copy() 154 | # positions[:,0] += trans * 10 155 | positions[:, 0] += trans 156 | rotations = np.degrees(quat.to_euler(rots, order=order)) 157 | 158 | bvh_data ={ 159 | "rotations": rotations[:, :22], 160 | "positions": positions[:, :22], 161 | "offsets": offsets[:22], 162 | "parents": parents[:22], 163 | "names": names[:22], 164 | "order": order, 165 | "frametime": 1 / fps, 166 | } 167 | 168 | if not output.endswith(".bvh"): 169 | output = output + ".bvh" 170 | 171 | bvh.save(output, bvh_data) 172 | 173 | if mirror: 174 | rots_mirror, trans_mirror = mirror_rot_trans( 175 | rots, trans, names, parents) 176 | positions_mirror = pos.copy() 177 | positions_mirror[:,0] += trans_mirror 178 | rotations_mirror = np.degrees( 179 | quat.to_euler(rots_mirror, order=order)) 180 | 181 | bvh_data ={ 182 | "rotations": rotations_mirror, 183 | "positions": positions_mirror, 184 | "offsets": offsets, 185 | "parents": parents, 186 | "names": names, 187 | "order": order, 188 | "frametime": 1 / fps, 189 | } 190 | 191 | output_mirror = output.split(".")[0] + "_mirror.bvh" 192 | bvh.save(output_mirror, 
bvh_data) 193 | 194 | 195 | # def joints2bvh():  # TODO: implement or remove this unfinished stub 196 | 197 | if __name__ == "__main__": 198 | args = parse_args() 199 | 200 | smpl2bvh(model_path=args.model_path, model_type=args.model_type, 201 | mirror = args.mirror, gender=args.gender, 202 | poses=args.poses, num_betas=args.num_betas, 203 | fps=args.fps, output=args.output) 204 | 205 | print("finished!") -------------------------------------------------------------------------------- /eval_edit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | 4 | import options.option_transformer as option_trans 5 | import models.vqvae as vqvae 6 | import models.t2m_trans as trans 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | from exit.utils import get_model, generate_src_mask, set_seed 10 | from edit_eval.main_edit_eval import run_all_eval 11 | import subprocess 12 | 13 | 14 | def get_latest_commit_info(): 15 | # Run the git command to get the latest commit hash and message 16 | result = subprocess.run( 17 | ["git", "log", "-1", "--pretty=format:%H %s"], 18 | stdout=subprocess.PIPE, 19 | text=True, 20 | check=True 21 | ) 22 | 23 | # Extract the output 24 | commit_info = result.stdout.strip().replace(' ', '_') 25 | return commit_info[:10] 26 | 27 | 28 | ##### ---- Exp dirs ---- ##### 29 | args = option_trans.get_args_parser() 30 | args.exp_name = f'{get_latest_commit_info()}__{args.exp_name}' 31 | set_seed(args.seed) 32 | 33 | #### ---- ##### 34 | clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) # Must set jit=False for training 35 | clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16 36 | clip_model.eval() 37 | for p in clip_model.parameters(): 38 | p.requires_grad = False 39 | 40 | # https://github.com/openai/CLIP/issues/111 41 | class TextCLIP(torch.nn.Module): 42 | def __init__(self, model) : 43 | super(TextCLIP, self).__init__() 44 | self.model = model 45 | 46 | def forward(self,text): 47 | with torch.no_grad(): 48 | word_emb = self.model.token_embedding(text).type(self.model.dtype) 49 | word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype) 50 | word_emb = word_emb.permute(1, 0, 2) # NLD -> LND 51 | word_emb = self.model.transformer(word_emb) 52 | word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float() 53 | enctxt = self.model.encode_text(text).float() 54 | return enctxt, word_emb 55 | clip_model = TextCLIP(clip_model) 56 | 57 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 58 | args.nb_code, 59 | args.code_dim, 60 | args.output_emb_width, 61 | args.down_t, 62 | args.stride_t, 63 | args.width, 64 | args.depth, 65 | args.dilation_growth_rate) 66 | 67 | 68 | trans_encoder = trans.Text2Motion_Transformer(net, 69 | num_vq=args.nb_code, 70 | embed_dim=args.embed_dim_gpt, 71 | clip_dim=args.clip_dim, 72 | num_layers=args.num_layers, 73 | num_local_layer=args.num_local_layer, 74 | n_head=args.n_head_gpt, 75 | drop_out_rate=args.drop_out_rate, 76 | fc_rate=args.ff_rate, args=args) 77 | 78 | 79 | print ('loading checkpoint from {}'.format(args.resume_pth)) 80 | ckpt = torch.load(args.resume_pth, map_location='cpu') 81 | net.load_state_dict(ckpt['net'], strict=True) 82 | net.eval() 83 | net.cuda() 84 | 85 | if args.resume_trans is not None and not args.debug: 86 | print ('loading transformer checkpoint from {}'.format(args.resume_trans)) 87 | ckpt = torch.load(args.resume_trans, 
map_location='cpu') 88 | trans_encoder.load_state_dict(ckpt['trans'], strict=True) 89 | trans_encoder.eval() 90 | trans_encoder.cuda() 91 | 92 | 93 | 94 | def call_T2MBD(clip_text, pose, m_length): 95 | ### FOR NO TEST ###: 96 | # clip_text = [''] * len(clip_text) 97 | # edit_task = 'prefix' # inbetween, 'outpainting', prefix, suffix upperbody 98 | edit_task = trans_encoder.args.edit_task 99 | 100 | text = clip.tokenize(clip_text, truncate=True).cuda() 101 | feat_clip_text, word_emb = clip_model(text) 102 | 103 | bs, seq = pose.shape[:2] 104 | tokens = -1*torch.ones((bs, 50), dtype=torch.long).cuda() 105 | 106 | if edit_task in ['inbetween', 'outpainting']: 107 | m_token_length = torch.ceil((m_length)/4).int().cpu().numpy() 108 | m_token_length_init = (m_token_length * .25).astype(int) 109 | m_length_init = (m_length * .25).int() 110 | for k in range(bs): 111 | l = m_length_init[k] 112 | l_token = m_token_length_init[k] 113 | 114 | if edit_task == 'inbetween': 115 | # start tokens 116 | index_motion = net(pose[k:k+1, :l].cuda(), type='encode') 117 | tokens[k,:index_motion.shape[1]] = index_motion[0] 118 | 119 | # end tokens 120 | index_motion = net(pose[k:k+1, m_length[k]-l :m_length[k]].cuda(), type='encode') 121 | tokens[k, m_token_length[k]-l_token :m_token_length[k]] = index_motion[0] 122 | elif edit_task == 'outpainting': 123 | # inside tokens 124 | index_motion = net(pose[k:k+1, l:m_length[k]-l].cuda(), type='encode') 125 | tokens[k, l_token: l_token+index_motion.shape[1]] = index_motion[0] 126 | 127 | if edit_task in ['prefix', 'suffix']: 128 | m_token_length = torch.ceil((m_length)/4).int().cpu().numpy() 129 | m_token_length_half = (m_token_length * .5).astype(int) 130 | m_length_half = (m_length * .5).int() 131 | for k in range(bs): 132 | if edit_task == 'prefix': 133 | index_motion = net(pose[k:k+1, :m_length_half[k]].cuda(), type='encode') 134 | tokens[k, :m_token_length_half[k]] = index_motion[0] 135 | elif edit_task == 'suffix': 136 | index_motion = net(pose[k:k+1, m_length_half[k]:m_length[k]].cuda(), type='encode') 137 | tokens[k, m_token_length[k]-m_token_length_half[k] :m_token_length[k]] = index_motion[0] 138 | 139 | inpaint_index = trans_encoder(feat_clip_text, word_emb, type="sample", 140 | m_length=m_length.cuda(), token_cond=tokens) 141 | 142 | pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda() 143 | for k in range(bs): 144 | pred_pose = net(inpaint_index[k:k+1, :m_token_length[k]], type='decode') 145 | pred_pose_eval[k:k+1, :int(m_length[k].item())] = pred_pose 146 | 147 | return pred_pose_eval 148 | 149 | run_all_eval(call_T2MBD, args.out_dir, args.exp_name, args_orig=args) 150 | 151 | 152 | # from instantmotion import InstantMotion 153 | # from dataset import dataset_TM_eval 154 | # from utils.word_vectorizer import WordVectorizer 155 | # def call_InstantMotionUpper(clip_text, pose, m_length): 156 | # return instant_motion_upper.upper_edit(pose, m_length, clip_text) 157 | 158 | # w_vectorizer = WordVectorizer('./glove', 'our_vab') 159 | # val_loader = dataset_TM_eval.DATALoader('t2m', True, 32, w_vectorizer) 160 | # instant_motion_upper = InstantMotion(is_upper_edit=True, 161 | # extra_args = {'mean':val_loader.dataset.mean, 162 | # 'std':val_loader.dataset.std}).cuda() 163 | # run_all_eval(call_InstantMotionUpper, args.out_dir, args.exp_name) -------------------------------------------------------------------------------- /vq_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 
import torch 5 | import torch.optim as optim 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | import models.vqvae as vqvae 9 | import utils.losses as losses 10 | import options.option_vq as option_vq 11 | import utils.utils_model as utils_model 12 | from dataset import dataset_VQ, dataset_TM_eval 13 | import utils.eval_trans as eval_trans 14 | from options.get_eval_option import get_opt 15 | from models.evaluator_wrapper import EvaluatorModelWrapper 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | from utils.word_vectorizer import WordVectorizer 19 | from tqdm import tqdm 20 | from exit.utils import load_vqvae_from_MMM, init_save_folder, load_last_vqvae, set_seed, seed_worker 21 | from models.vqvae_sep import VQVAE_SEP 22 | from torch.nn.parallel import DistributedDataParallel as DDP 23 | import torch.distributed as dist 24 | from torch.utils.data.distributed import DistributedSampler 25 | import numpy as np 26 | from torch.utils.data._utils.collate import default_collate 27 | 28 | 29 | def collate_fn(batch): 30 | batch.sort(key=lambda x: x[3], reverse=True) 31 | return default_collate(batch) 32 | 33 | 34 | def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr): 35 | 36 | current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) 37 | for param_group in optimizer.param_groups: 38 | param_group["lr"] = current_lr 39 | 40 | return optimizer, current_lr 41 | 42 | ##### ---- Exp dirs ---- ##### 43 | args = option_vq.get_args_parser() 44 | torch.manual_seed(args.seed) 45 | 46 | 47 | ##### ---- DDP ---- ##### 48 | ddp = args.ddp 49 | if ddp: 50 | world_size = args.world_size 51 | ngpus_per_node = torch.cuda.device_count() 52 | local_rank = int(os.environ.get("SLURM_LOCALID")) 53 | rank = int(os.environ.get("SLURM_NODEID")) * ngpus_per_node + local_rank 54 | torch.cuda.set_device(local_rank) 55 | device = f'cuda:{local_rank}' 56 | print(20*'-----') 57 | print('ngpus_per_node: ', ngpus_per_node) 58 | print('From Rank: {}, ==> Initializing Process Group...'.format(rank)) 59 | dist.init_process_group(backend=args.dist_backend, init_method=args.init_method, world_size=args.world_size, rank=rank) 60 | print("process group ready") 61 | print(f"From rank {rank} making model...") 62 | print(20*'-----') 63 | master_process = rank == 0 # this process will do logging, checkpointing etc. 
64 | else: 65 | # vanilla, non-DDP run 66 | rank = 0 67 | local_rank = 0 68 | world_size = 1 69 | master_process = True 70 | device = args.device 71 | 72 | args.rank = rank 73 | args.local_rank = local_rank 74 | args.world_size = world_size 75 | args.device = device 76 | args.master_process = master_process 77 | args.batch_size = args.total_batch_size // world_size 78 | 79 | ########## ------------- Seed -----------############## 80 | set_seed(args.seed) 81 | ########## ------------- DIRS -----------############## 82 | if master_process: 83 | args.out_dir = os.path.join(args.out_dir, f'vq', 'eval') # /{args.exp_name} 84 | # os.makedirs(args.out_dir, exist_ok = True) 85 | init_save_folder(args) 86 | 87 | ##### ---- Logger ---- ##### 88 | logger = utils_model.get_logger(args.out_dir, args=args) 89 | writer = SummaryWriter(args.out_dir) 90 | if master_process: 91 | logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) 92 | args.logger = logger 93 | args.writer = writer 94 | 95 | w_vectorizer = WordVectorizer('./glove', 'our_vab') 96 | 97 | if args.dataname == 'kit' : 98 | dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' 99 | args.nb_joints = 21 100 | else: 101 | dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' 102 | args.nb_joints = 22 103 | 104 | logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints') 105 | 106 | wrapper_opt = get_opt(dataset_opt_path, device) 107 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 108 | 109 | val_dataset = dataset_TM_eval.T2MDataset(args.dataname, True, 110 | 32, 111 | w_vectorizer, 112 | unit_length=2**args.down_t, 113 | args=args) 114 | 115 | val_loader = torch.utils.data.DataLoader( 116 | val_dataset, 117 | batch_size=32, 118 | shuffle=True, 119 | num_workers=args.num_workers, 120 | collate_fn=collate_fn, 121 | drop_last=True, 122 | pin_memory=args.pin_memory) 123 | 124 | if master_process: 125 | logger.info(f"len valid dataset {len(val_dataset)}") 126 | logger.info(f"len valid loader {len(val_loader)}") 127 | 128 | ##### ---- Network ---- ##### 129 | net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers 130 | args.nb_code, 131 | args.code_dim, 132 | args.output_emb_width, 133 | args.down_t, 134 | args.stride_t, 135 | args.width, 136 | args.depth, 137 | args.dilation_growth_rate, 138 | args.vq_act, 139 | args.vq_norm) 140 | net.eval() 141 | net.to(local_rank) 142 | 143 | if master_process: 144 | n = sum([p.numel() for k, p in net.named_parameters()]) 145 | logger.info(f"Number of VQ-VAE parameters: {n/1e6} M") 146 | 147 | if args.resume_pth : 148 | logger.info('loading checkpoint from {}'.format(args.resume_pth)) 149 | ckpt = torch.load(args.resume_pth, map_location='cpu') 150 | 151 | try: 152 | net.load_state_dict(ckpt['net'], strict=True) 153 | del ckpt 154 | except: 155 | sd = {} 156 | for k, v in ckpt['net'].items(): 157 | new_k = k.split('module.')[-1] # strip the DDP 'module.' prefix 158 | sd[new_k] = v 159 | net.load_state_dict(sd, strict=True) 160 | del sd 161 | del ckpt 162 | 163 | ##### ------ warm-up ------- ##### 164 | fid = [] 165 | div = [] 166 | top1 = [] 167 | top2 = [] 168 | top3 = [] 169 | matching = [] 170 | repeat_time = 10 171 | 172 | for i in range(repeat_time): 173 | best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching \ 174 | = eval_trans.evaluation_vqvae( 175 | args.out_dir, val_loader, net, 0, best_fid=1000, best_iter=0,\ 176 | best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, 177 | eval_wrapper=eval_wrapper, 
args=args) 178 | 179 | fid.append(best_fid) 180 | div.append(best_div) 181 | top1.append(best_top1) 182 | top2.append(best_top2) 183 | top3.append(best_top3) 184 | matching.append(best_matching) 185 | 186 | logger.info('final result:') 187 | logger.info(f'fid: {sum(fid)/repeat_time}') 188 | logger.info(f'div: {sum(div)/repeat_time}') 189 | logger.info(f'top1: {sum(top1)/repeat_time}') 190 | logger.info(f'top2: {sum(top2)/repeat_time}') 191 | logger.info(f'top3: {sum(top3)/repeat_time}') 192 | logger.info(f'matching: {sum(matching)/repeat_time}') 193 | 194 | fid = np.array(fid) 195 | div = np.array(div) 196 | top1 = np.array(top1) 197 | top2 = np.array(top2) 198 | top3 = np.array(top3) 199 | matching = np.array(matching) 200 | msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}" 201 | logger.info(msg_final) 202 | -------------------------------------------------------------------------------- /dataset/dataset_TM_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | import utils.paramUtil as paramUtil 9 | from torch.utils.data._utils.collate import default_collate 10 | import random 11 | import math 12 | 13 | def collate_fn(batch): 14 | batch.sort(key=lambda x: x[3], reverse=True) 15 | return default_collate(batch) 16 | 17 | 18 | '''For use of training text-2-motion generative model''' 19 | class Text2MotionDataset(data.Dataset): 20 | def __init__(self, dataset_name, feat_bias = 5, unit_length = 4, codebook_size = 1024, tokenizer_name=None, up_low_sep=False, args=None): 21 | 22 | self.max_length = 64 23 | self.pointer = 0 24 | self.dataset_name = dataset_name 25 | self.up_low_sep = up_low_sep 26 | 27 | self.unit_length = unit_length 28 | # self.mot_start_idx = codebook_size 29 | self.mot_end_idx = codebook_size 30 | self.mot_pad_idx = codebook_size + 1 # [TODO] I think 513 (codebook_size+1) can be what ever, it will be croped out 31 | if dataset_name == 't2m': 32 | self.data_root = './dataset/HumanML3D' 33 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 34 | self.text_dir = pjoin(self.data_root, 'texts') 35 | self.joints_num = 22 36 | radius = 4 37 | fps = 20 38 | self.max_motion_length = 26 if unit_length == 8 else 50 39 | dim_pose = 263 40 | kinematic_chain = paramUtil.t2m_kinematic_chain 41 | elif dataset_name == 'kit': 42 | self.data_root = './dataset/KIT-ML' 43 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 44 | self.text_dir = pjoin(self.data_root, 'texts') 45 | self.joints_num = 21 46 | radius = 240 * 8 47 | fps = 12.5 48 | dim_pose = 251 49 | self.max_motion_length = 26 if unit_length == 8 else 50 50 | kinematic_chain = paramUtil.kit_kinematic_chain 51 | 52 | split_file = pjoin(self.data_root, 'train.txt') 53 | 54 | 55 | id_list = [] 56 | with cs.open(split_file, 'r') as f: 57 | for line in f.readlines(): 58 | id_list.append(line.strip()) 59 | 60 | new_name_list = [] 61 | data_dict = {} 62 | i_debug = 0 
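        # The loop below assumes the HumanML3D / KIT-ML text annotation layout:
        # every line of texts/<name>.txt is "caption#POS-tagged tokens#f_tag#to_tag",
        # for example (the caption here is purely illustrative):
        #   a person walks forward#a/DET person/NOUN walks/VERB forward/ADV#0.0#0.0
        # f_tag / to_tag are start / end times in seconds for the sub-clip the caption
        # describes; 0.0#0.0 means the caption applies to the whole motion.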
63 | for name in tqdm(id_list): 64 | try: 65 | m_token_list = np.load(pjoin(tokenizer_name, '%s.npy'%name)) 66 | 67 | # Read text 68 | with cs.open(pjoin(self.text_dir, name + '.txt')) as f: 69 | text_data = [] 70 | flag = False 71 | lines = f.readlines() 72 | 73 | for line in lines: 74 | try: 75 | text_dict = {} 76 | line_split = line.strip().split('#') 77 | caption = line_split[0] 78 | t_tokens = line_split[1].split(' ') 79 | f_tag = float(line_split[2]) 80 | to_tag = float(line_split[3]) 81 | f_tag = 0.0 if np.isnan(f_tag) else f_tag 82 | to_tag = 0.0 if np.isnan(to_tag) else to_tag 83 | 84 | text_dict['caption'] = caption 85 | text_dict['tokens'] = t_tokens 86 | if f_tag == 0.0 and to_tag == 0.0: 87 | flag = True 88 | text_data.append(text_dict) 89 | else: 90 | # [INFO] Check with KIT, doesn't come here that mean f_tag & to_tag are 0.0 (tag for caption from-to frames) 91 | m_token_list_new = [tokens[int(f_tag*fps/unit_length) : int(to_tag*fps/unit_length)] for tokens in m_token_list if int(f_tag*fps/unit_length) < int(to_tag*fps/unit_length)] 92 | 93 | if len(m_token_list_new) == 0: 94 | continue 95 | new_name = '%s_%f_%f'%(name, f_tag, to_tag) 96 | 97 | data_dict[new_name] = {'m_token_list': m_token_list_new, 98 | 'text':[text_dict]} 99 | new_name_list.append(new_name) 100 | 101 | if args.debug: 102 | i_debug += 1 103 | if i_debug >= args.maxdata: 104 | break 105 | 106 | except: 107 | pass 108 | 109 | if flag: 110 | data_dict[name] = {'m_token_list': m_token_list, 111 | 'text':text_data} 112 | new_name_list.append(name) 113 | 114 | if args.debug: 115 | i_debug += 1 116 | if i_debug >= args.maxdata: 117 | break 118 | except: 119 | pass 120 | self.data_dict = data_dict 121 | self.name_list = new_name_list 122 | 123 | def __len__(self): 124 | return len(self.data_dict) 125 | 126 | def __getitem__(self, item): 127 | data = self.data_dict[self.name_list[item]] 128 | m_token_list, text_list = data['m_token_list'], data['text'] 129 | m_tokens = random.choice(m_token_list) 130 | 131 | text_data = random.choice(text_list) 132 | caption= text_data['caption'] 133 | 134 | 135 | coin = np.random.choice([False, False, True]) 136 | # print(len(m_tokens)) 137 | if coin: 138 | # drop one token at the head or tail 139 | coin2 = np.random.choice([True, False]) 140 | if coin2: 141 | m_tokens = m_tokens[:-1] 142 | else: 143 | m_tokens = m_tokens[1:] 144 | m_tokens_len = m_tokens.shape[0] 145 | 146 | if self.up_low_sep: 147 | new_len = random.randint(20, self.max_motion_length-1) 148 | len_mult = math.ceil(new_len/m_tokens_len) 149 | m_tokens = np.tile(m_tokens, (len_mult, 1))[:new_len] 150 | m_tokens_len = new_len 151 | if m_tokens_len+1 < self.max_motion_length: 152 | m_tokens = np.concatenate([m_tokens, np.ones((1, 2), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length-1-m_tokens_len, 2), dtype=int) * self.mot_pad_idx], axis=0) 153 | else: 154 | m_tokens = np.concatenate([m_tokens, np.ones((1, 2), dtype=int) * self.mot_end_idx], axis=0) 155 | else: 156 | if m_tokens_len+1 < self.max_motion_length: 157 | m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length-1-m_tokens_len), dtype=int) * self.mot_pad_idx], axis=0) 158 | else: 159 | m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx], axis=0) 160 | return caption, m_tokens, m_tokens_len 161 | 162 | 163 | 164 | 165 | def T2MDataset(dataset_name, 166 | batch_size, codebook_size, tokenizer_name, unit_length=4, 167 | num_workers = 8, up_low_sep=False, args=None) 
: 168 | dataset = Text2MotionDataset(dataset_name, codebook_size = codebook_size, tokenizer_name = tokenizer_name, unit_length=unit_length, up_low_sep=up_low_sep, args=args) 169 | return dataset 170 | 171 | 172 | def cycle(iterable): 173 | while True: 174 | for x in iterable: 175 | yield x 176 | 177 | 178 | -------------------------------------------------------------------------------- /edit_eval/main_edit_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import utils.eval_trans as eval_trans 4 | 5 | import numpy as np 6 | import json 7 | 8 | import utils.utils_model as utils_model 9 | import utils.eval_trans as eval_trans 10 | from dataset import dataset_TM_eval 11 | from options.get_eval_option import get_opt 12 | from models.evaluator_wrapper import EvaluatorModelWrapper 13 | from exit.utils import base_dir, init_save_folder, set_seed, seed_worker 14 | from torch.utils.data._utils.collate import default_collate 15 | 16 | 17 | def collate_fn(batch): 18 | batch.sort(key=lambda x: x[3], reverse=True) 19 | return default_collate(batch) 20 | 21 | 22 | def eval_inbetween(eval_wrapper, logger, val_loader, call_model, nb_iter): 23 | num_repeat = 1 24 | motion_annotation_list = [] 25 | motion_pred_list = [] 26 | motion_multimodality = [] 27 | R_precision_real = 0 28 | R_precision = 0 29 | matching_score_real = 0 30 | matching_score_pred = 0 31 | nb_sample = 0 32 | 33 | count = 0 34 | for word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name in tqdm(val_loader): 35 | bs, seq = pose.shape[:2] 36 | motion_multimodality_batch = [] 37 | for i in range(num_repeat): 38 | pred_pose_eval = call_model(clip_text, pose, m_length) 39 | 40 | et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length) 41 | motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1)) 42 | 43 | if i == 0: 44 | et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length) 45 | motion_annotation_list.append(em) 46 | motion_pred_list.append(em_pred) 47 | 48 | temp_R, temp_match = eval_trans.calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True) 49 | R_precision_real += temp_R 50 | matching_score_real += temp_match 51 | temp_R, temp_match = eval_trans.calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True) 52 | R_precision += temp_R 53 | matching_score_pred += temp_match 54 | 55 | nb_sample += bs 56 | ### end if 57 | ### end for 58 | motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1)) 59 | # if count > 2: 60 | # break 61 | count += 1 62 | motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy() 63 | motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy() 64 | gt_mu, gt_cov = eval_trans.calculate_activation_statistics(motion_annotation_np) 65 | mu, cov= eval_trans.calculate_activation_statistics(motion_pred_np) 66 | 67 | diversity_real = eval_trans.calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100) 68 | diversity = eval_trans.calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100) 69 | 70 | R_precision_real = R_precision_real / nb_sample 71 | R_precision = R_precision / nb_sample 72 | 73 | matching_score_real = matching_score_real / nb_sample 74 | matching_score_pred = matching_score_pred / nb_sample 75 | 76 | multimodality = 0 77 | motion_multimodality = 
torch.cat(motion_multimodality, dim=0).cpu().numpy() 78 | if num_repeat > 1: 79 | multimodality = eval_trans.calculate_multimodality(motion_multimodality, 10) 80 | 81 | fid = eval_trans.calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 82 | 83 | msg = f"--> \t Eva. Iter {nb_iter} :, \n\ 84 | FID. {fid:.4f} , \n\ 85 | Diversity Real. {diversity_real:.4f}, \n\ 86 | Diversity. {diversity:.4f}, \n\ 87 | R_precision_real. {R_precision_real}, \n\ 88 | R_precision. {R_precision}, \n\ 89 | matching_score_real. {matching_score_real}, \n\ 90 | matching_score_pred. {matching_score_pred}, \n\ 91 | multimodality. {multimodality:.4f}" 92 | logger.info(msg) 93 | return fid, diversity, R_precision, matching_score_pred, multimodality 94 | 95 | def run_all_eval(call_model, out_dir, exp_name, copysource=True, args_orig=None): 96 | from tqdm import tqdm 97 | import os 98 | out_dir = f'{out_dir}/eval_edit' 99 | os.makedirs(out_dir, exist_ok = True) 100 | 101 | class Temp: 102 | def __init__(self): 103 | print('mock:: opt') 104 | args = Temp() 105 | args.out_dir = out_dir 106 | args.exp_name = exp_name 107 | args.seed = 123 108 | args.debug = args_orig.debug 109 | args.maxdata = args_orig.maxdata 110 | init_save_folder(args, copysource) 111 | 112 | ##### ---- Logger ---- ##### 113 | logger = utils_model.get_logger(args.out_dir, args=args) 114 | logger.info(json.dumps(vars(args_orig), indent=4, sort_keys=True)) 115 | logger.info(50*'===') 116 | logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) 117 | 118 | from utils.word_vectorizer import WordVectorizer 119 | w_vectorizer = WordVectorizer('./glove', 'our_vab') 120 | val_dataset = dataset_TM_eval.T2MDataset('t2m', True, 32, w_vectorizer, args=args) 121 | 122 | def get_val_loader(): 123 | g = torch.Generator() 124 | g.manual_seed(args.seed) 125 | return torch.utils.data.DataLoader( 126 | val_dataset, 127 | batch_size=32, 128 | shuffle=True, 129 | num_workers=8 if 'shahab' in os.getcwd() else 0, 130 | collate_fn=collate_fn, 131 | drop_last=True, 132 | pin_memory=True if 'shahab' in os.getcwd() else False, 133 | worker_init_fn=seed_worker, generator=g) 134 | 135 | val_loader = get_val_loader() 136 | 137 | dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' 138 | 139 | wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) 140 | eval_wrapper = EvaluatorModelWrapper(wrapper_opt) 141 | 142 | fid = [] 143 | div = [] 144 | top1 = [] 145 | top2 = [] 146 | top3 = [] 147 | matching = [] 148 | multi = [] 149 | repeat_time = 20 150 | 151 | for i in tqdm(range(repeat_time)): 152 | _fid, diversity, R_precision, matching_score_pred, multimodality = eval_inbetween(eval_wrapper, logger, val_loader, call_model, nb_iter=i) 153 | 154 | fid.append(_fid) 155 | div.append(diversity) 156 | top1.append(R_precision[0]) 157 | top2.append(R_precision[1]) 158 | top3.append(R_precision[2]) 159 | matching.append(matching_score_pred) 160 | multi.append(multimodality) 161 | 162 | print('final result:') 163 | print('fid: ', sum(fid)/repeat_time) 164 | print('div: ', sum(div)/repeat_time) 165 | print('top1: ', sum(top1)/repeat_time) 166 | print('top2: ', sum(top2)/repeat_time) 167 | print('top3: ', sum(top3)/repeat_time) 168 | print('matching: ', sum(matching)/repeat_time) 169 | print('multi: ', sum(multi)/repeat_time) 170 | 171 | fid = np.array(fid) 172 | div = np.array(div) 173 | top1 = np.array(top1) 174 | top2 = np.array(top2) 175 | top3 = np.array(top3) 176 | matching = np.array(matching) 177 | multi = np.array(multi) 178 | msg_final = f"FID. 
{np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}, Multi. {np.mean(multi):.3f}, conf. {np.std(multi)*1.96/np.sqrt(repeat_time):.3f}" 179 | logger.info(msg_final) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bad 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - abseil-cpp=20230802.0=h6a678d5_2 11 | - absl-py=2.1.0=py312h06a4308_0 12 | - aiohttp=3.9.5=py312h5eee18b_0 13 | - aiosignal=1.2.0=pyhd3eb1b0_0 14 | - asttokens=2.4.1=pyhd8ed1ab_0 15 | - attrs=23.1.0=py312h06a4308_0 16 | - blas=1.0=mkl 17 | - blinker=1.6.2=py312h06a4308_0 18 | - brotli=1.0.9=h5eee18b_8 19 | - brotli-bin=1.0.9=h5eee18b_8 20 | - brotli-python=1.0.9=py312h6a678d5_8 21 | - bzip2=1.0.8=h5eee18b_6 22 | - c-ares=1.19.1=h5eee18b_0 23 | - ca-certificates=2024.6.2=hbcca054_0 24 | - cachetools=5.3.3=py312h06a4308_0 25 | - certifi=2024.2.2=py312h06a4308_0 26 | - cffi=1.16.0=py312h5eee18b_1 27 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 28 | - click=8.1.7=py312h06a4308_0 29 | - comm=0.2.2=pyhd8ed1ab_0 30 | - contourpy=1.2.0=py312hdb19cb5_0 31 | - cryptography=42.0.5=py312hdda0065_1 32 | - cuda-cudart=11.8.89=0 33 | - cuda-cupti=11.8.87=0 34 | - cuda-libraries=11.8.0=0 35 | - cuda-nvrtc=11.8.89=0 36 | - cuda-nvtx=11.8.86=0 37 | - cuda-runtime=11.8.0=0 38 | - cuda-version=12.5=3 39 | - cycler=0.11.0=pyhd3eb1b0_0 40 | - cyrus-sasl=2.1.28=h52b45da_1 41 | - dbus=1.13.18=hb2f20db_0 42 | - debugpy=1.6.7=py312h6a678d5_0 43 | - decorator=5.1.1=pyhd8ed1ab_0 44 | - exceptiongroup=1.2.0=pyhd8ed1ab_2 45 | - executing=2.0.1=pyhd8ed1ab_0 46 | - expat=2.6.2=h6a678d5_0 47 | - ffmpeg=4.3=hf484d3e_0 48 | - filelock=3.13.1=py312h06a4308_0 49 | - fontconfig=2.14.1=h4c34cd2_2 50 | - fonttools=4.51.0=py312h5eee18b_0 51 | - freetype=2.12.1=h4a9f257_0 52 | - frozenlist=1.4.0=py312h5eee18b_0 53 | - glib=2.78.4=h6a678d5_0 54 | - glib-tools=2.78.4=h6a678d5_0 55 | - gmp=6.2.1=h295c915_3 56 | - gnutls=3.6.15=he1e5248_0 57 | - google-auth=2.29.0=py312h06a4308_0 58 | - google-auth-oauthlib=0.4.1=py_2 59 | - grpc-cpp=1.48.2=he1ff14a_4 60 | - grpcio=1.48.2=py312he1ff14a_4 61 | - gst-plugins-base=1.14.1=h6a678d5_1 62 | - gstreamer=1.14.1=h5eee18b_1 63 | - gtest=1.14.0=hdb19cb5_1 64 | - icu=73.1=h6a678d5_0 65 | - idna=3.7=py312h06a4308_0 66 | - importlib-metadata=7.1.0=pyha770c72_0 67 | - importlib_metadata=7.1.0=hd8ed1ab_0 68 | - intel-openmp=2023.1.0=hdb19cb5_46306 69 | - ipykernel=6.29.3=pyhd33586a_0 70 | - ipython=8.25.0=pyh707e725_0 71 | - jedi=0.19.1=pyhd8ed1ab_0 72 | - jinja2=3.1.4=py312h06a4308_0 73 | - jpeg=9e=h5eee18b_1 74 | - jupyter_client=8.6.2=pyhd8ed1ab_0 75 | - jupyter_core=5.5.0=py312h06a4308_0 76 | - kiwisolver=1.4.4=py312h6a678d5_0 77 | - krb5=1.20.1=h143b758_1 78 | - lame=3.100=h7b6447c_0 79 | - lcms2=2.12=h3be6417_0 80 | - ld_impl_linux-64=2.38=h1181459_1 81 | - lerc=3.0=h295c915_0 82 | - libbrotlicommon=1.0.9=h5eee18b_8 83 | - libbrotlidec=1.0.9=h5eee18b_8 84 | - 
libbrotlienc=1.0.9=h5eee18b_8 85 | - libclang=14.0.6=default_hc6dbbc7_1 86 | - libclang13=14.0.6=default_he11475f_1 87 | - libcublas=11.11.3.6=0 88 | - libcufft=10.9.0.58=0 89 | - libcufile=1.10.0.4=0 90 | - libcups=2.4.2=h2d74bed_1 91 | - libcurand=10.3.6.39=0 92 | - libcusolver=11.4.1.48=0 93 | - libcusparse=11.7.5.86=0 94 | - libdeflate=1.17=h5eee18b_1 95 | - libedit=3.1.20230828=h5eee18b_0 96 | - libffi=3.4.4=h6a678d5_1 97 | - libgcc-ng=13.2.0=h77fa898_7 98 | - libgfortran-ng=11.2.0=h00389a5_1 99 | - libgfortran5=11.2.0=h1234567_1 100 | - libglib=2.78.4=hdc74915_0 101 | - libgomp=13.2.0=h77fa898_7 102 | - libiconv=1.16=h5eee18b_3 103 | - libidn2=2.3.4=h5eee18b_0 104 | - libjpeg-turbo=2.0.0=h9bf148f_0 105 | - libllvm14=14.0.6=hdb19cb5_3 106 | - libnpp=11.8.0.86=0 107 | - libnvjpeg=11.9.0.86=0 108 | - libpng=1.6.39=h5eee18b_0 109 | - libpq=12.17=hdbd6064_0 110 | - libprotobuf=3.20.3=he621ea3_0 111 | - libsodium=1.0.18=h36c2ea0_1 112 | - libstdcxx-ng=11.2.0=h1234567_1 113 | - libtasn1=4.19.0=h5eee18b_0 114 | - libtiff=4.5.1=h6a678d5_0 115 | - libunistring=0.9.10=h27cfd23_0 116 | - libuuid=1.41.5=h5eee18b_0 117 | - libwebp-base=1.3.2=h5eee18b_0 118 | - libxcb=1.15=h7f8727e_0 119 | - libxkbcommon=1.0.1=h5eee18b_1 120 | - libxml2=2.10.4=hfdd30dd_2 121 | - llvm-openmp=14.0.6=h9e868ea_0 122 | - lz4-c=1.9.4=h6a678d5_1 123 | - markdown=3.4.1=py312h06a4308_0 124 | - markupsafe=2.1.3=py312h5eee18b_0 125 | - matplotlib=3.8.4=py312h06a4308_0 126 | - matplotlib-base=3.8.4=py312h526ad5a_0 127 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 128 | - mkl=2023.1.0=h213fc3f_46344 129 | - mkl-service=2.4.0=py312h5eee18b_1 130 | - mkl_fft=1.3.8=py312h5eee18b_0 131 | - mkl_random=1.2.4=py312hdb19cb5_0 132 | - mpmath=1.3.0=py312h06a4308_0 133 | - multidict=6.0.4=py312h5eee18b_0 134 | - mysql=5.7.24=h721c034_2 135 | - ncurses=6.4=h6a678d5_0 136 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 137 | - nettle=3.7.3=hbbd107a_1 138 | - networkx=3.1=py312h06a4308_0 139 | - numpy=1.26.4=py312hc5e2394_0 140 | - numpy-base=1.26.4=py312h0da6c21_0 141 | - oauthlib=3.2.2=py312h06a4308_0 142 | - openh264=2.1.1=h4ff587b_0 143 | - openjpeg=2.4.0=h3ad879b_0 144 | - openssl=3.3.0=h4ab18f5_3 145 | - packaging=23.2=py312h06a4308_0 146 | - parso=0.8.4=pyhd8ed1ab_0 147 | - pcre2=10.42=hebb0a14_1 148 | - pexpect=4.9.0=pyhd8ed1ab_0 149 | - pickleshare=0.7.5=py_1003 150 | - pillow=10.3.0=py312h5eee18b_0 151 | - pip=24.0=py312h06a4308_0 152 | - platformdirs=4.2.2=pyhd8ed1ab_0 153 | - plotly=5.19.0=py312he106c6f_0 154 | - ply=3.11=py312h06a4308_1 155 | - prompt-toolkit=3.0.42=pyha770c72_0 156 | - protobuf=3.20.3=py312h6a678d5_0 157 | - psutil=5.9.0=py312h5eee18b_0 158 | - ptyprocess=0.7.0=pyhd3deb0d_0 159 | - pure_eval=0.2.2=pyhd8ed1ab_0 160 | - pyasn1=0.4.8=pyhd3eb1b0_0 161 | - pyasn1-modules=0.2.8=py_0 162 | - pybind11-abi=5=hd3eb1b0_0 163 | - pycparser=2.21=pyhd3eb1b0_0 164 | - pygments=2.18.0=pyhd8ed1ab_0 165 | - pyjwt=2.8.0=py312h06a4308_0 166 | - pyopenssl=24.0.0=py312h06a4308_0 167 | - pyparsing=3.0.9=py312h06a4308_0 168 | - pyqt=5.15.10=py312h6a678d5_0 169 | - pyqt5-sip=12.13.0=py312h5eee18b_0 170 | - pysocks=1.7.1=py312h06a4308_0 171 | - python=3.12.3=h996f2a0_1 172 | - python-dateutil=2.9.0=pyhd8ed1ab_0 173 | - pytorch=2.3.0=py3.12_cuda11.8_cudnn8.7.0_0 174 | - pytorch-cuda=11.8=h7e8668a_5 175 | - pytorch-mutex=1.0=cuda 176 | - pyyaml=6.0.1=py312h5eee18b_0 177 | - pyzmq=25.1.2=py312h6a678d5_0 178 | - qt-main=5.15.2=h53bd1ea_10 179 | - re2=2022.04.01=h295c915_0 180 | - readline=8.2=h5eee18b_0 181 | - requests=2.32.2=py312h06a4308_0 182 | - 
requests-oauthlib=1.3.0=py_0 183 | - rsa=4.7.2=pyhd3eb1b0_1 184 | - scipy=1.13.0=py312hc5e2394_0 185 | - setuptools=69.5.1=py312h06a4308_0 186 | - sip=6.7.12=py312h6a678d5_0 187 | - six=1.16.0=pyhd3eb1b0_1 188 | - sqlite=3.45.3=h5eee18b_0 189 | - stack_data=0.6.2=pyhd8ed1ab_0 190 | - sympy=1.12=py312h06a4308_0 191 | - tbb=2021.8.0=hdb19cb5_0 192 | - tenacity=8.2.2=py312h06a4308_1 193 | - tensorboard=2.6.0=py_0 194 | - tensorboard-plugin-wit=1.6.0=py_0 195 | - tk=8.6.14=h39e8969_0 196 | - torchvision=0.18.0=py312_cu118 197 | - tornado=6.3.3=py312h5eee18b_0 198 | - traitlets=5.14.3=pyhd8ed1ab_0 199 | - typing_extensions=4.11.0=py312h06a4308_0 200 | - tzdata=2024a=h04d1e81_0 201 | - unicodedata2=15.1.0=py312h5eee18b_0 202 | - urllib3=2.2.1=py312h06a4308_0 203 | - wcwidth=0.2.13=pyhd8ed1ab_0 204 | - werkzeug=3.0.3=py312h06a4308_0 205 | - wheel=0.43.0=py312h06a4308_0 206 | - xz=5.4.6=h5eee18b_1 207 | - yaml=0.2.5=h7b6447c_0 208 | - yarl=1.9.3=py312h5eee18b_0 209 | - zeromq=4.3.5=h6a678d5_0 210 | - zipp=3.17.0=pyhd8ed1ab_0 211 | - zlib=1.2.13=h5eee18b_1 212 | - zstd=1.5.5=hc292b87_2 213 | - pip: 214 | - beautifulsoup4==4.12.3 215 | # - clip==1.0 216 | - einops==0.8.0 217 | - fastjsonschema==2.19.1 218 | - fsspec==2024.5.0 219 | - ftfy==6.2.0 220 | - gdown==5.2.0 221 | - jsonschema==4.22.0 222 | - jsonschema-specifications==2023.12.1 223 | - nbformat==5.10.4 224 | - referencing==0.35.1 225 | - regex==2024.5.15 226 | - rpds-py==0.18.1 227 | - soupsieve==2.5 228 | - tqdm==4.66.4 -------------------------------------------------------------------------------- /my_clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from einops import rearrange, repeat 10 | 11 | 12 | def mid_feature_append(x): 13 | return repeat(x, 'n b c -> b n 1 c') 14 | 15 | 16 | class LayerNorm(nn.LayerNorm): 17 | """Subclass torch's LayerNorm to handle fp16.""" 18 | 19 | def forward(self, x: torch.Tensor): 20 | orig_type = x.dtype 21 | ret = super().forward(x.type(torch.float32)) 22 | return ret.type(orig_type) 23 | 24 | 25 | class QuickGELU(nn.Module): 26 | def forward(self, x: torch.Tensor): 27 | return x * torch.sigmoid(1.702 * x) 28 | 29 | 30 | class ResidualAttentionBlock(nn.Module): 31 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 32 | super().__init__() 33 | 34 | self.attn = nn.MultiheadAttention(d_model, n_head) 35 | self.ln_1 = LayerNorm(d_model) 36 | self.mlp = nn.Sequential(OrderedDict([ 37 | ("c_fc", nn.Linear(d_model, d_model * 4)), 38 | ("gelu", QuickGELU()), 39 | ("c_proj", nn.Linear(d_model * 4, d_model)) 40 | ])) 41 | self.ln_2 = LayerNorm(d_model) 42 | self.attn_mask = attn_mask 43 | 44 | def attention(self, x: torch.Tensor, text_mask=None): 45 | if text_mask is not None: 46 | B, T = text_mask.shape 47 | attn_mask = text_mask.view(B, 1, T).repeat(1, self.attn.num_heads, T, 1).view(B*self.attn.num_heads, T, T) 48 | attn_mask = torch.where(attn_mask, 0, -torch.inf).to(x.dtype) 49 | else: 50 | attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)[:x.shape[0], :x.shape[0]] if self.attn_mask is not None else None 51 | 52 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 53 | 54 | def forward(self, x: torch.Tensor, text_mask=None): 55 | x = x + self.attention(self.ln_1(x), text_mask=text_mask) 56 | x = x + 
self.mlp(self.ln_2(x)) 57 | return x 58 | 59 | 60 | class Transformer(nn.Module): 61 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 62 | super().__init__() 63 | self.width = width 64 | self.layers = layers 65 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 66 | 67 | def forward(self, x: torch.Tensor, text_mask=None): 68 | for block in self.resblocks: 69 | x = block(x, text_mask=text_mask) 70 | return x 71 | 72 | 73 | class CLIP(nn.Module): 74 | def __init__(self, 75 | embed_dim: int, 76 | # text 77 | context_length: int, 78 | vocab_size: int, 79 | transformer_width: int, 80 | transformer_heads: int, 81 | transformer_layers: int 82 | ): 83 | super().__init__() 84 | 85 | self.context_length = context_length 86 | 87 | self.transformer = Transformer( 88 | width=transformer_width, 89 | layers=transformer_layers, 90 | heads=transformer_heads, 91 | attn_mask=self.build_attention_mask() 92 | ) 93 | 94 | self.vocab_size = vocab_size 95 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 96 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 97 | self.ln_final = LayerNorm(transformer_width) 98 | 99 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 100 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 101 | 102 | self.initialize_parameters() 103 | 104 | def initialize_parameters(self): 105 | nn.init.normal_(self.token_embedding.weight, std=0.02) 106 | nn.init.normal_(self.positional_embedding, std=0.01) 107 | 108 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 109 | attn_std = self.transformer.width ** -0.5 110 | fc_std = (2 * self.transformer.width) ** -0.5 111 | for block in self.transformer.resblocks: 112 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 113 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 114 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 115 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 116 | 117 | if self.text_projection is not None: 118 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 119 | 120 | def build_attention_mask(self): 121 | # lazily create causal attention mask, with full attention between the vision tokens 122 | # pytorch uses additive attention mask; fill with -inf 123 | mask = torch.empty(self.context_length, self.context_length) 124 | mask.fill_(float("-inf")) 125 | mask.triu_(1) # zero out the lower diagonal 126 | return mask 127 | 128 | @property 129 | def dtype(self): 130 | return next(self.transformer.parameters()).dtype 131 | 132 | def encode_image(self, image): 133 | return self.transformer(image.type(self.dtype)) 134 | 135 | def encode_text(self, text, text_mask=None): 136 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 137 | 138 | x = x + self.positional_embedding[:text.shape[-1]].type(self.dtype) 139 | x = x.permute(1, 0, 2) # NLD -> LND 140 | x = self.transformer(x, text_mask=text_mask) 141 | x = x.permute(1, 0, 2) # LND -> NLD 142 | w = self.ln_final(x).type(self.dtype) 143 | 144 | # x.shape = [batch_size, n_ctx, transformer.width] 145 | # take features from the eot embedding (eot_token is the highest number in each sequence) 146 | x = w[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 147 | 148 | return x, w 149 | 150 | def forward(self, text, text_mask=None): 151 | # image_features = 
self.encode_image(image) 152 | sentence_feature, mid_features = self.encode_text(text, text_mask=text_mask) 153 | # sentence_feature = sentence_feature / sentence_feature.norm(dim=1, keepdim=True) 154 | 155 | # shape = [global_batch_size, global_batch_size] 156 | return sentence_feature, mid_features 157 | 158 | 159 | def convert_weights(model: nn.Module): 160 | """Convert applicable model parameters to fp16""" 161 | 162 | def _convert_weights_to_fp16(l): 163 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 164 | l.weight.data = l.weight.data.half() 165 | if l.bias is not None: 166 | l.bias.data = l.bias.data.half() 167 | 168 | if isinstance(l, nn.MultiheadAttention): 169 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 170 | tensor = getattr(l, attr) 171 | if tensor is not None: 172 | tensor.data = tensor.data.half() 173 | 174 | for name in ["text_projection", "proj"]: 175 | if hasattr(l, name): 176 | attr = getattr(l, name) 177 | if attr is not None: 178 | attr.data = attr.data.half() 179 | 180 | model.apply(_convert_weights_to_fp16) 181 | 182 | 183 | def build_model(state_dict: dict): 184 | 185 | new_state_dict = {} 186 | for k, v in state_dict.items(): 187 | if 'visua' not in k: 188 | new_state_dict[k] = state_dict[k] 189 | state_dict = new_state_dict 190 | 191 | embed_dim = state_dict["text_projection"].shape[1] 192 | context_length = state_dict["positional_embedding"].shape[0] 193 | vocab_size = state_dict["token_embedding.weight"].shape[0] 194 | transformer_width = state_dict["ln_final.weight"].shape[0] 195 | transformer_heads = transformer_width // 64 196 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 197 | 198 | model = CLIP( 199 | embed_dim, 200 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 201 | ) 202 | 203 | for key in ["input_resolution", "context_length", "vocab_size"]: 204 | if key in state_dict: 205 | del state_dict[key] 206 | 207 | convert_weights(model) 208 | model.load_state_dict(state_dict) 209 | return model.eval() 210 | -------------------------------------------------------------------------------- /visualization/utils/bvh.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | channelmap = { 5 | 'Xrotation': 'x', 6 | 'Yrotation': 'y', 7 | 'Zrotation': 'z' 8 | } 9 | 10 | channelmap_inv = { 11 | 'x': 'Xrotation', 12 | 'y': 'Yrotation', 13 | 'z': 'Zrotation', 14 | } 15 | 16 | ordermap = { 17 | 'x': 0, 18 | 'y': 1, 19 | 'z': 2, 20 | } 21 | 22 | def load(filename:str, order:str=None) -> dict: 23 | """Loads a BVH file. 24 | 25 | Args: 26 | filename (str): Path to the BVH file. 27 | order (str): The order of the rotation channels. (i.e."xyz") 28 | 29 | Returns: 30 | dict: A dictionary containing the following keys: 31 | * names (list)(jnum): The names of the joints. 32 | * parents (list)(jnum): The parent indices. 33 | * offsets (np.ndarray)(jnum, 3): The offsets of the joints. 34 | * rotations (np.ndarray)(fnum, jnum, 3) : The local coordinates of rotations of the joints. 35 | * positions (np.ndarray)(fnum, jnum, 3) : The positions of the joints. 36 | * order (str): The order of the channels. 37 | * frametime (float): The time between two frames. 
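        Example (an illustrative sketch only; "walk.bvh" is a placeholder path, not a file shipped with this repo):

            anim = load("walk.bvh")
            anim['rotations'].shape    # (fnum, jnum, 3) Euler angles in the parsed `order`
            anim['positions'][:, 0]    # root joint trajectory across all frames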
38 | """ 39 | 40 | f = open(filename, "r") 41 | 42 | i = 0 43 | active = -1 44 | end_site = False 45 | 46 | # Create empty lists for saving parameters 47 | names = [] 48 | offsets = np.array([]).reshape((0, 3)) 49 | parents = np.array([], dtype=int) 50 | 51 | # Parse the file, line by line 52 | for line in f: 53 | 54 | if "HIERARCHY" in line: continue 55 | if "MOTION" in line: continue 56 | 57 | rmatch = re.match(r"ROOT (\w+)", line) 58 | if rmatch: 59 | names.append(rmatch.group(1)) 60 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 61 | parents = np.append(parents, active) 62 | active = (len(parents) - 1) 63 | continue 64 | 65 | if "{" in line: continue 66 | 67 | if "}" in line: 68 | if end_site: 69 | end_site = False 70 | else: 71 | active = parents[active] 72 | continue 73 | 74 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 75 | if offmatch: 76 | if not end_site: 77 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 78 | continue 79 | 80 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 81 | if chanmatch: 82 | channels = int(chanmatch.group(1)) 83 | if order is None: 84 | channelis = 0 if channels == 3 else 3 85 | channelie = 3 if channels == 3 else 6 86 | parts = line.split()[2 + channelis:2 + channelie] 87 | if any([p not in channelmap for p in parts]): 88 | continue 89 | order = "".join([channelmap[p] for p in parts]) 90 | continue 91 | 92 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 93 | if jmatch: 94 | names.append(jmatch.group(1)) 95 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 96 | parents = np.append(parents, active) 97 | active = (len(parents) - 1) 98 | continue 99 | 100 | if "End Site" in line: 101 | end_site = True 102 | continue 103 | 104 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 105 | if fmatch: 106 | fnum = int(fmatch.group(1)) 107 | positions = offsets[None].repeat(fnum, axis=0) 108 | rotations = np.zeros((fnum, len(offsets), 3)) 109 | continue 110 | 111 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 112 | if fmatch: 113 | frametime = float(fmatch.group(1)) 114 | continue 115 | 116 | dmatch = line.strip().split(' ') 117 | if dmatch: 118 | data_block = np.array(list(map(float, dmatch))) 119 | N = len(parents) 120 | fi = i 121 | if channels == 3: 122 | positions[fi, 0:1] = data_block[0:3] 123 | rotations[fi, :] = data_block[3:].reshape(N, 3) 124 | elif channels == 6: 125 | data_block = data_block.reshape(N, 6) 126 | positions[fi, :] = data_block[:, 0:3] 127 | rotations[fi, :] = data_block[:, 3:6] 128 | elif channels == 9: 129 | positions[fi, 0] = data_block[0:3] 130 | data_block = data_block[3:].reshape(N - 1, 9) 131 | rotations[fi, 1:] = data_block[:, 3:6] 132 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 133 | else: 134 | raise Exception("Too many channels! 
%i" % channels) 135 | 136 | i += 1 137 | 138 | f.close() 139 | 140 | return { 141 | 'rotations': rotations, 142 | 'positions': positions, 143 | 'offsets': offsets, 144 | 'parents': parents, 145 | 'names': names, 146 | 'order': order, 147 | 'frametime': frametime 148 | } 149 | 150 | 151 | def save_joint(f, data, t, i, save_order, order='zyx', save_positions=False): 152 | 153 | save_order.append(i) 154 | 155 | f.write("%sJOINT %s\n" % (t, data['names'][i])) 156 | f.write("%s{\n" % t) 157 | t += '\t' 158 | 159 | f.write("%sOFFSET %f %f %f\n" % (t, data['offsets'][i,0], data['offsets'][i,1], data['offsets'][i,2])) 160 | 161 | if save_positions: 162 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 163 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 164 | else: 165 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 166 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 167 | 168 | end_site = True 169 | 170 | for j in range(len(data['parents'])): 171 | if data['parents'][j] == i: 172 | t = save_joint(f, data, t, j, save_order, order=order, save_positions=save_positions) 173 | end_site = False 174 | 175 | if end_site: 176 | f.write("%sEnd Site\n" % t) 177 | f.write("%s{\n" % t) 178 | t += '\t' 179 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 180 | t = t[:-1] 181 | f.write("%s}\n" % t) 182 | 183 | t = t[:-1] 184 | f.write("%s}\n" % t) 185 | 186 | return t 187 | 188 | 189 | def save(filename, data, save_positions=False): 190 | """ Save a joint hierarchy to a file. 191 | 192 | Args: 193 | filename (str): The output will save on the bvh file. 194 | data (dict): The data to save.(rotations, positions, offsets, parents, names, order, frametime) 195 | save_positions (bool): Whether to save all of joint positions on MOTION. (False is recommended.) 
196 | """ 197 | 198 | order = data['order'] 199 | frametime = data['frametime'] 200 | 201 | with open(filename, 'w') as f: 202 | 203 | t = "" 204 | f.write("%sHIERARCHY\n" % t) 205 | f.write("%sROOT %s\n" % (t, data['names'][0])) 206 | f.write("%s{\n" % t) 207 | t += '\t' 208 | 209 | f.write("%sOFFSET %f %f %f\n" % (t, data['offsets'][0,0], data['offsets'][0,1], data['offsets'][0,2]) ) 210 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 211 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 212 | 213 | save_order = [0] 214 | 215 | for i in range(len(data['parents'])): 216 | if data['parents'][i] == 0: 217 | t = save_joint(f, data, t, i, save_order, order=order, save_positions=save_positions) 218 | 219 | t = t[:-1] 220 | f.write("%s}\n" % t) 221 | 222 | rots, poss = data['rotations'], data['positions'] 223 | 224 | f.write("MOTION\n") 225 | f.write("Frames: %i\n" % len(rots)); 226 | f.write("Frame Time: %f\n" % frametime); 227 | 228 | for i in range(rots.shape[0]): 229 | for j in save_order: 230 | 231 | if save_positions or j == 0: 232 | 233 | f.write("%f %f %f %f %f %f " % ( 234 | poss[i,j,0], poss[i,j,1], poss[i,j,2], 235 | rots[i,j,0], rots[i,j,1], rots[i,j,2])) 236 | 237 | else: 238 | 239 | f.write("%f %f %f " % ( 240 | rots[i,j,0], rots[i,j,1], rots[i,j,2])) 241 | 242 | f.write("\n") -------------------------------------------------------------------------------- /common/skeleton.py: -------------------------------------------------------------------------------- 1 | from common.quaternion import * 2 | import scipy.ndimage.filters as filters 3 | 4 | 5 | class Skeleton(object): 6 | def __init__(self, offset, kinematic_tree, device): 7 | self.device = device 8 | self._raw_offset_np = offset.numpy() 9 | self._raw_offset = offset.clone().detach().to(device).float() 10 | self._kinematic_tree = kinematic_tree 11 | self._offset = None 12 | self._parents = [0] * len(self._raw_offset) 13 | self._parents[0] = -1 14 | for chain in self._kinematic_tree: 15 | for j in range(1, len(chain)): 16 | self._parents[chain[j]] = chain[j-1] 17 | 18 | def njoints(self): 19 | return len(self._raw_offset) 20 | 21 | def offset(self): 22 | return self._offset 23 | 24 | def set_offset(self, offsets): 25 | self._offset = offsets.clone().detach().to(self.device).float() 26 | 27 | def kinematic_tree(self): 28 | return self._kinematic_tree 29 | 30 | def parents(self): 31 | return self._parents 32 | 33 | # joints (batch_size, joints_num, 3) 34 | def get_offsets_joints_batch(self, joints): 35 | assert len(joints.shape) == 3 36 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 37 | for i in range(1, self._raw_offset.shape[0]): 38 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 39 | 40 | self._offset = _offsets.detach() 41 | return _offsets 42 | 43 | # joints (joints_num, 3) 44 | def get_offsets_joints(self, joints): 45 | assert len(joints.shape) == 2 46 | _offsets = self._raw_offset.clone() 47 | for i in range(1, self._raw_offset.shape[0]): 48 | # print(joints.shape) 49 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 50 | 51 | self._offset = _offsets.detach() 52 | return _offsets 53 | 54 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 55 | # joints (batch_size, joints_num, 3) 56 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 57 | assert len(face_joint_idx) == 4 
58 | '''Get Forward Direction''' 59 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 60 | across1 = joints[:, r_hip] - joints[:, l_hip] 61 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 62 | across = across1 + across2 63 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 64 | # print(across1.shape, across2.shape) 65 | 66 | # forward (batch_size, 3) 67 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 68 | if smooth_forward: 69 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 70 | # forward (batch_size, 3) 71 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 72 | 73 | '''Get Root Rotation''' 74 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 75 | root_quat = qbetween_np(forward, target) 76 | 77 | '''Inverse Kinematics''' 78 | # quat_params (batch_size, joints_num, 4) 79 | # print(joints.shape[:-1]) 80 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 81 | # print(quat_params.shape) 82 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 83 | quat_params[:, 0] = root_quat 84 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 85 | for chain in self._kinematic_tree: 86 | R = root_quat 87 | for j in range(len(chain) - 1): 88 | # (batch, 3) 89 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 90 | # print(u.shape) 91 | # (batch, 3) 92 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 93 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 94 | # print(u.shape, v.shape) 95 | rot_u_v = qbetween_np(u, v) 96 | 97 | R_loc = qmul_np(qinv_np(R), rot_u_v) 98 | 99 | quat_params[:,chain[j + 1], :] = R_loc 100 | R = qmul_np(R, R_loc) 101 | 102 | return quat_params 103 | 104 | # Be sure root joint is at the beginning of kinematic chains 105 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 106 | # quat_params (batch_size, joints_num, 4) 107 | # joints (batch_size, joints_num, 3) 108 | # root_pos (batch_size, 3) 109 | if skel_joints is not None: 110 | offsets = self.get_offsets_joints_batch(skel_joints) 111 | if len(self._offset.shape) == 2: 112 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 113 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 114 | joints[:, 0] = root_pos 115 | for chain in self._kinematic_tree: 116 | if do_root_R: 117 | R = quat_params[:, 0] 118 | else: 119 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 120 | for i in range(1, len(chain)): 121 | R = qmul(R, quat_params[:, chain[i]]) 122 | offset_vec = offsets[:, chain[i]] 123 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 124 | return joints 125 | 126 | # Be sure root joint is at the beginning of kinematic chains 127 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 128 | # quat_params (batch_size, joints_num, 4) 129 | # joints (batch_size, joints_num, 3) 130 | # root_pos (batch_size, 3) 131 | if skel_joints is not None: 132 | skel_joints = torch.from_numpy(skel_joints) 133 | offsets = self.get_offsets_joints_batch(skel_joints) 134 | if len(self._offset.shape) == 2: 135 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 136 | offsets = offsets.numpy() 137 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 138 | joints[:, 0] = root_pos 139 | for chain in self._kinematic_tree: 140 | if do_root_R: 141 | R = quat_params[:, 0] 142 | else: 143 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 144 | for i 
in range(1, len(chain)): 145 | R = qmul_np(R, quat_params[:, chain[i]]) 146 | offset_vec = offsets[:, chain[i]] 147 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 148 | return joints 149 | 150 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 151 | # cont6d_params (batch_size, joints_num, 6) 152 | # joints (batch_size, joints_num, 3) 153 | # root_pos (batch_size, 3) 154 | if skel_joints is not None: 155 | skel_joints = torch.from_numpy(skel_joints) 156 | offsets = self.get_offsets_joints_batch(skel_joints) 157 | if len(self._offset.shape) == 2: 158 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 159 | offsets = offsets.numpy() 160 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 161 | joints[:, 0] = root_pos 162 | for chain in self._kinematic_tree: 163 | if do_root_R: 164 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 165 | else: 166 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 167 | for i in range(1, len(chain)): 168 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 169 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 170 | # print(matR.shape, offset_vec.shape) 171 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 172 | return joints 173 | 174 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 175 | # cont6d_params (batch_size, joints_num, 6) 176 | # joints (batch_size, joints_num, 3) 177 | # root_pos (batch_size, 3) 178 | if skel_joints is not None: 179 | # skel_joints = torch.from_numpy(skel_joints) 180 | offsets = self.get_offsets_joints_batch(skel_joints) 181 | if len(self._offset.shape) == 2: 182 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 183 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 184 | joints[..., 0, :] = root_pos 185 | for chain in self._kinematic_tree: 186 | if do_root_R: 187 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 188 | else: 189 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 190 | for i in range(1, len(chain)): 191 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) 192 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 193 | # print(matR.shape, offset_vec.shape) 194 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 195 | return joints 196 | 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /visualization/joints2smpl/src/prior.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import sys 22 | import os 23 | 24 | import time 25 | import pickle 26 | 27 | import numpy as np 28 | 29 | import torch 30 | import torch.nn as nn 31 | 32 | DEFAULT_DTYPE = torch.float32 33 | 34 | 35 | def create_prior(prior_type, **kwargs): 36 | if prior_type == 'gmm': 37 | prior = MaxMixturePrior(**kwargs) 38 | elif prior_type == 'l2': 39 | return L2Prior(**kwargs) 40 | elif prior_type == 'angle': 41 | return SMPLifyAnglePrior(**kwargs) 42 | elif prior_type == 'none' or prior_type is None: 43 | # Don't use any pose prior 44 | def no_prior(*args, **kwargs): 45 | return 0.0 46 | prior = no_prior 47 | else: 48 | raise ValueError('Prior {}'.format(prior_type) + ' is not implemented') 49 | return prior 50 | 51 | 52 | class SMPLifyAnglePrior(nn.Module): 53 | def __init__(self, dtype=torch.float32, **kwargs): 54 | super(SMPLifyAnglePrior, self).__init__() 55 | 56 | # Indices for the rotation angle of 57 | # 55: left elbow, 90deg bend at -np.pi/2 58 | # 58: right elbow, 90deg bend at np.pi/2 59 | # 12: left knee, 90deg bend at np.pi/2 60 | # 15: right knee, 90deg bend at np.pi/2 61 | angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64) 62 | angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long) 63 | self.register_buffer('angle_prior_idxs', angle_prior_idxs) 64 | 65 | angle_prior_signs = np.array([1, -1, -1, -1], 66 | dtype=np.float32 if dtype == torch.float32 67 | else np.float64) 68 | angle_prior_signs = torch.tensor(angle_prior_signs, 69 | dtype=dtype) 70 | self.register_buffer('angle_prior_signs', angle_prior_signs) 71 | 72 | def forward(self, pose, with_global_pose=False): 73 | ''' Returns the angle prior loss for the given pose 74 | 75 | Args: 76 | pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle 77 | representation of the rotations of the joints of the SMPL model. 78 | Kwargs: 79 | with_global_pose: Whether the pose vector also contains the global 80 | orientation of the SMPL model. If not, then the indices must be 81 | corrected. 82 | Returns: 83 | A size (B) tensor containing the angle prior loss for each element 84 | in the batch.
85 | ''' 86 | angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 87 | return torch.exp(pose[:, angle_prior_idxs] * 88 | self.angle_prior_signs).pow(2) 89 | 90 | 91 | class L2Prior(nn.Module): 92 | def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): 93 | super(L2Prior, self).__init__() 94 | 95 | def forward(self, module_input, *args): 96 | return torch.sum(module_input.pow(2)) 97 | 98 | 99 | class MaxMixturePrior(nn.Module): 100 | 101 | def __init__(self, prior_folder='prior', 102 | num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, 103 | use_merged=True, 104 | **kwargs): 105 | super(MaxMixturePrior, self).__init__() 106 | 107 | if dtype == DEFAULT_DTYPE: 108 | np_dtype = np.float32 109 | elif dtype == torch.float64: 110 | np_dtype = np.float64 111 | else: 112 | print('Unknown float type {}, exiting!'.format(dtype)) 113 | sys.exit(-1) 114 | 115 | self.num_gaussians = num_gaussians 116 | self.epsilon = epsilon 117 | self.use_merged = use_merged 118 | gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) 119 | 120 | full_gmm_fn = os.path.join(prior_folder, gmm_fn) 121 | if not os.path.exists(full_gmm_fn): 122 | print('The path to the mixture prior "{}"'.format(full_gmm_fn) + 123 | ' does not exist, exiting!') 124 | sys.exit(-1) 125 | 126 | with open(full_gmm_fn, 'rb') as f: 127 | gmm = pickle.load(f, encoding='latin1') 128 | 129 | if type(gmm) == dict: 130 | means = gmm['means'].astype(np_dtype) 131 | covs = gmm['covars'].astype(np_dtype) 132 | weights = gmm['weights'].astype(np_dtype) 133 | elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): 134 | means = gmm.means_.astype(np_dtype) 135 | covs = gmm.covars_.astype(np_dtype) 136 | weights = gmm.weights_.astype(np_dtype) 137 | else: 138 | print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) 139 | sys.exit(-1) 140 | 141 | self.register_buffer('means', torch.tensor(means, dtype=dtype)) 142 | 143 | self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) 144 | 145 | precisions = [np.linalg.inv(cov) for cov in covs] 146 | precisions = np.stack(precisions).astype(np_dtype) 147 | 148 | self.register_buffer('precisions', 149 | torch.tensor(precisions, dtype=dtype)) 150 | 151 | # The constant term: 152 | sqrdets = np.array([(np.sqrt(np.linalg.det(c))) 153 | for c in gmm['covars']]) 154 | const = (2 * np.pi)**(69 / 2.) 
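        # 69 here appears to be the dimensionality of the SMPL body pose the GMM was fit on
        # (23 body joints x 3 axis-angle values); (2 * pi) ** (69 / 2) is then the dimension-dependent
        # factor of the multivariate Gaussian normalizer, with the covariance determinants entering
        # below through sqrdets.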
155 | 156 | nll_weights = np.asarray(gmm['weights'] / (const * 157 | (sqrdets / sqrdets.min()))) 158 | nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) 159 | self.register_buffer('nll_weights', nll_weights) 160 | 161 | weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) 162 | self.register_buffer('weights', weights) 163 | 164 | self.register_buffer('pi_term', 165 | torch.log(torch.tensor(2 * np.pi, dtype=dtype))) 166 | 167 | cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) 168 | for cov in covs] 169 | self.register_buffer('cov_dets', 170 | torch.tensor(cov_dets, dtype=dtype)) 171 | 172 | # The dimensionality of the random variable 173 | self.random_var_dim = self.means.shape[1] 174 | 175 | def get_mean(self): 176 | ''' Returns the mean of the mixture ''' 177 | mean_pose = torch.matmul(self.weights, self.means) 178 | return mean_pose 179 | 180 | def merged_log_likelihood(self, pose, betas): 181 | diff_from_mean = pose.unsqueeze(dim=1) - self.means 182 | 183 | prec_diff_prod = torch.einsum('mij,bmj->bmi', 184 | [self.precisions, diff_from_mean]) 185 | diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) 186 | 187 | curr_loglikelihood = 0.5 * diff_prec_quadratic - \ 188 | torch.log(self.nll_weights) 189 | # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + 190 | # self.random_var_dim * self.pi_term + 191 | # diff_prec_quadratic 192 | # ) - torch.log(self.weights) 193 | 194 | min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) 195 | return min_likelihood 196 | 197 | def log_likelihood(self, pose, betas, *args, **kwargs): 198 | ''' Create graph operation for negative log-likelihood calculation 199 | ''' 200 | likelihoods = [] 201 | 202 | for idx in range(self.num_gaussians): 203 | mean = self.means[idx] 204 | prec = self.precisions[idx] 205 | cov = self.covs[idx] 206 | diff_from_mean = pose - mean 207 | 208 | curr_loglikelihood = torch.einsum('bj,ji->bi', 209 | [diff_from_mean, prec]) 210 | curr_loglikelihood = torch.einsum('bi,bi->b', 211 | [curr_loglikelihood, 212 | diff_from_mean]) 213 | cov_term = torch.log(torch.det(cov) + self.epsilon) 214 | curr_loglikelihood += 0.5 * (cov_term + 215 | self.random_var_dim * 216 | self.pi_term) 217 | likelihoods.append(curr_loglikelihood) 218 | 219 | log_likelihoods = torch.stack(likelihoods, dim=1) 220 | min_idx = torch.argmin(log_likelihoods, dim=1) 221 | weight_component = self.nll_weights[:, min_idx] 222 | weight_component = -torch.log(weight_component) 223 | 224 | return weight_component + log_likelihoods[:, min_idx] 225 | 226 | def forward(self, pose, betas): 227 | if self.use_merged: 228 | return self.merged_log_likelihood(pose, betas) 229 | else: 230 | return self.log_likelihood(pose, betas) -------------------------------------------------------------------------------- /dataset/dataset_TM_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | from os.path import join as pjoin 5 | import random 6 | import codecs as cs 7 | from tqdm import tqdm 8 | 9 | import utils.paramUtil as paramUtil 10 | from torch.utils.data._utils.collate import default_collate 11 | 12 | 13 | def collate_fn(batch): 14 | batch.sort(key=lambda x: x[3], reverse=True) 15 | return default_collate(batch) 16 | 17 | 18 | '''For use of training text-2-motion generative model''' 19 | class Text2MotionDataset(data.Dataset): 20 | def __init__(self, dataset_name, is_test, w_vectorizer, 
feat_bias = 5, max_text_len = 20, unit_length = 4, shuffle=True, args=None): 21 | 22 | self.max_length = 20 23 | self.pointer = 0 24 | self.dataset_name = dataset_name 25 | self.is_test = is_test 26 | self.max_text_len = max_text_len 27 | self.unit_length = unit_length 28 | self.w_vectorizer = w_vectorizer 29 | if dataset_name == 't2m': 30 | self.data_root = './dataset/HumanML3D' 31 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 32 | self.text_dir = pjoin(self.data_root, 'texts') 33 | self.joints_num = 22 34 | radius = 4 35 | fps = 20 36 | self.max_motion_length = 196 37 | dim_pose = 263 38 | kinematic_chain = paramUtil.t2m_kinematic_chain 39 | self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 40 | elif dataset_name == 'kit': 41 | self.data_root = './dataset/KIT-ML' 42 | self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') 43 | self.text_dir = pjoin(self.data_root, 'texts') 44 | self.joints_num = 21 45 | radius = 240 * 8 46 | fps = 12.5 47 | dim_pose = 251 48 | self.max_motion_length = 196 49 | kinematic_chain = paramUtil.kit_kinematic_chain 50 | self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' 51 | 52 | mean = np.load(pjoin(self.meta_dir, 'mean.npy')) 53 | std = np.load(pjoin(self.meta_dir, 'std.npy')) 54 | 55 | if is_test: 56 | split_file = pjoin(self.data_root, 'test.txt') 57 | else: 58 | split_file = pjoin(self.data_root, 'val.txt') 59 | 60 | min_motion_len = 40 if self.dataset_name =='t2m' else 24 61 | # min_motion_len = 64 62 | 63 | joints_num = self.joints_num 64 | 65 | data_dict = {} 66 | id_list = [] 67 | with cs.open(split_file, 'r') as f: 68 | for line in f.readlines(): 69 | id_list.append(line.strip()) 70 | if args.debug: id_list = id_list[:args.maxdata] 71 | 72 | new_name_list = [] 73 | length_list = [] 74 | for name in tqdm(id_list): 75 | try: 76 | motion = np.load(pjoin(self.motion_dir, name + '.npy')) 77 | if np.isnan(motion.mean()): 78 | print(f"NaN found in {self.motion_dir}/{name}") 79 | continue 80 | if (len(motion)) < min_motion_len or (len(motion) >= 200): 81 | continue 82 | text_data = [] 83 | flag = False 84 | with cs.open(pjoin(self.text_dir, name + '.txt')) as f: 85 | for line in f.readlines(): 86 | text_dict = {} 87 | line_split = line.strip().split('#') 88 | caption = line_split[0] 89 | tokens = line_split[1].split(' ') 90 | f_tag = float(line_split[2]) 91 | to_tag = float(line_split[3]) 92 | f_tag = 0.0 if np.isnan(f_tag) else f_tag 93 | to_tag = 0.0 if np.isnan(to_tag) else to_tag 94 | 95 | text_dict['caption'] = caption 96 | text_dict['tokens'] = tokens 97 | if f_tag == 0.0 and to_tag == 0.0: 98 | flag = True 99 | text_data.append(text_dict) 100 | else: 101 | try: 102 | n_motion = motion[int(f_tag*fps) : int(to_tag*fps)] 103 | if np.isnan(n_motion.mean()): 104 | print(f"NaN found in {self.motion_dir}/{name}") 105 | continue 106 | if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): 107 | continue 108 | new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name 109 | while new_name in data_dict: 110 | new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name 111 | data_dict[new_name] = {'motion': n_motion, 112 | 'length': len(n_motion), 113 | 'text':[text_dict]} 114 | new_name_list.append(new_name) 115 | length_list.append(len(n_motion)) 116 | 117 | except: 118 | print(line_split) 119 | print(line_split[2], line_split[3], f_tag, to_tag, name) 120 | # break 121 | 122 | if flag: 123 | data_dict[name] = {'motion': motion, 124 | 'length': len(motion), 125 | 'text': text_data} 126 | 
new_name_list.append(name) 127 | length_list.append(len(motion)) 128 | 129 | except Exception as e: 130 | # print(e) 131 | pass 132 | 133 | name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) 134 | self.mean = mean 135 | self.std = std 136 | self.length_arr = np.array(length_list) 137 | self.data_dict = data_dict 138 | self.name_list = name_list 139 | self.reset_max_len(self.max_length) 140 | self.shuffle = shuffle 141 | 142 | def reset_max_len(self, length): 143 | assert length <= self.max_motion_length 144 | self.pointer = np.searchsorted(self.length_arr, length) 145 | print("Pointer Pointing at %d"%self.pointer) 146 | self.max_length = length 147 | 148 | def inv_transform(self, data): 149 | return data * self.std + self.mean 150 | 151 | def forward_transform(self, data): 152 | return (data - self.mean) / self.std 153 | 154 | def __len__(self): 155 | return len(self.data_dict) - self.pointer 156 | 157 | def __getitem__(self, item): 158 | idx = self.pointer + item 159 | name = self.name_list[idx] 160 | data = self.data_dict[name] 161 | # data = self.data_dict[self.name_list[idx]] 162 | motion, m_length, text_list = data['motion'], data['length'], data['text'] 163 | # Randomly select a caption 164 | text_data = random.choice(text_list) 165 | caption, tokens = text_data['caption'], text_data['tokens'] 166 | 167 | if len(tokens) < self.max_text_len: 168 | # pad with "unk" 169 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 170 | sent_len = len(tokens) 171 | tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len) 172 | else: 173 | # crop 174 | tokens = tokens[:self.max_text_len] 175 | tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] 176 | sent_len = len(tokens) 177 | pos_one_hots = [] 178 | word_embeddings = [] 179 | for token in tokens: 180 | word_emb, pos_oh = self.w_vectorizer[token] 181 | pos_one_hots.append(pos_oh[None, :]) 182 | word_embeddings.append(word_emb[None, :]) 183 | pos_one_hots = np.concatenate(pos_one_hots, axis=0) 184 | word_embeddings = np.concatenate(word_embeddings, axis=0) 185 | 186 | if self.unit_length < 10 and self.shuffle: 187 | coin2 = np.random.choice(['single', 'single', 'double']) 188 | else: 189 | coin2 = 'single' 190 | 191 | if coin2 == 'double': 192 | m_length = (m_length // self.unit_length - 1) * self.unit_length 193 | elif coin2 == 'single': 194 | m_length = (m_length // self.unit_length) * self.unit_length 195 | idx = random.randint(0, len(motion) - m_length) 196 | motion = motion[idx:idx+m_length] 197 | 198 | "Z Normalization" 199 | motion = (motion - self.mean) / self.std 200 | 201 | if m_length < self.max_motion_length and self.shuffle: 202 | motion = np.concatenate([motion, 203 | np.zeros((self.max_motion_length - m_length, motion.shape[1])) 204 | ], axis=0) 205 | 206 | return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name 207 | 208 | 209 | 210 | from torch.utils.data.distributed import DistributedSampler 211 | 212 | def T2MDataset(dataset_name, is_test, batch_size, w_vectorizer, 213 | num_workers = 0, unit_length = 4, shuffle=True, args=None) : 214 | 215 | return Text2MotionDataset(dataset_name, is_test, w_vectorizer, 216 | unit_length=unit_length, shuffle=shuffle, args=args) 217 | 218 | 219 | def cycle(iterable): 220 | while True: 221 | for x in iterable: 222 | yield x 223 | --------------------------------------------------------------------------------