├── parser ├── __init__.py ├── evaluation.py ├── training.py ├── parser_mixamo.py └── base.py ├── data_preprocess ├── Lafan1_and_dog │ ├── lafan1_test.txt │ ├── dog_test.txt │ ├── lafan1_train.txt │ ├── dog_train.txt │ ├── std_bvh │ │ ├── dog_std.bvh │ │ └── hum_std.bvh │ └── datasetserial.py └── Mixamo │ ├── download_test.sh │ ├── fbx2bvh.py │ ├── __init__.py │ ├── preprocess.py │ ├── split_joint.py │ ├── bvh_writer.py │ ├── motion_dataset.py │ └── combined_motion.py ├── models ├── __init__.py ├── functions.py ├── base_model.py ├── Intergrated.py ├── IK.py ├── Kinematics.py └── multi_attention_forward.py ├── setup.py ├── config.py ├── eval.py ├── LICENSE ├── get_error.py ├── loss_record.py ├── train_mixamo.py ├── demo_mixamo.py ├── train_lafan1dog.py ├── test_mixamo.py ├── eval_single_pair.py ├── utils ├── metrices.py ├── utils.py └── data_utils.py ├── README.md ├── demo_hum2dog.py ├── demo_dog2hum.py ├── loss_function.py └── outer_utils ├── Animation.py └── BVH_mod.py /parser/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/lafan1_test.txt: -------------------------------------------------------------------------------- 1 | aiming1_subject1 2 | walk1_subject1 3 | walk2_subject1 4 | walk3_subject1 5 | walk4_subject1 6 | run2_subject1 -------------------------------------------------------------------------------- /data_preprocess/Mixamo/download_test.sh: -------------------------------------------------------------------------------- 1 | export fileid=1_849LvuT3WBEHktBT97P2oMBzeJz7-UP 2 | export filename=test_set.tar.bz2 3 | 4 | wget -O $filename 'https://docs.google.com/uc?export=download&id='$fileid 5 | 6 | tar -jxvf $filename 7 | rm $filename 8 | -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/dog_test.txt: -------------------------------------------------------------------------------- 1 | D1_001_KAN01_001 2 | D1_004_KAN01_001 3 | D1_005_KAN01_001 4 | D1_008_KAN01_002 5 | D1_010_KAN01_001 6 | D1_013_KAN01_001 7 | D1_047_KAN01_001 8 | D1_047z_KAN01_002 9 | D1_047z_KAN01_003 10 | D1_057_KAN01_001 11 | D1_061z_KAN01_003 12 | D1_ex03_KAN02_006 13 | D1_ex04_KAN02_001 -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/lafan1_train.txt: -------------------------------------------------------------------------------- 1 | aiming1_subject4 2 | aiming2_subject2 3 | aiming2_subject3 4 | aiming2_subject5 5 | run1_subject2 6 | run1_subject5 7 | run2_subject4 8 | sprint1_subject2 9 | sprint1_subject4 10 | walk1_subject2 11 | walk1_subject5 12 | walk2_subject3 13 | walk2_subject4 14 | walk3_subject2 15 | walk3_subject3 16 | walk3_subject4 17 | walk3_subject5 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | def creat_model(args, body_parts, datasets, topology_name): 2 | if args.architecture_name == 'pan': 3 | import models.architecture_humdog 4 | return models.architecture_humdog.PAN_model(args, body_parts, datasets, topology_name) 5 | else: 6 | raise Exception('Unimplemented model') 7 | 8 | 9 | def create_model_mixamo(args, character_names, dataset): 10 | if args.model == 'pan': 11 | import models.architecture_mixamo 12 | return models.architecture_mixamo.PAN_model(args, 
character_names, dataset) 13 | else: 14 | raise Exception('Unimplemented model') -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='pan-motion-retargeting', 5 | description="The official PyTorch implementation of the paper `Pose-aware Attention Network for Flexible Motion Retargeting by Body Part.`", 6 | author='Shihong Xia', 7 | author_email='xsh@ict.ac.cn', 8 | python_requires='>=3.8.12', 9 | install_requires=[ 10 | 'setuptools==59.5.0', 11 | 'numpy== 1.21.4', 12 | 'scipy==1.7.3', 13 | 'scikit-learn==1.1.3', 14 | 'tensorboard==2.9.1', 15 | 'tqdm==4.62.3', 16 | 'torchsummary==1.5.1', 17 | ], 18 | 19 | packages=find_packages(), 20 | 21 | ) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | class Configuration(object): 2 | hum_njoints = 22 # joint number of Lafan1 dataset (exclude End sites) 3 | dog_njoints = 21 # joint number of Dog dataset (exclude End sites) 4 | 5 | # The value in correspondence are indices in the skeleton tree excluding the End sites. 6 | correspondence = [{"hum_joints": [3, 4, 7, 8, 1, 2, 5, 6], "dog_joints": [8, 15, 12, 18, 5, 6, 7, 9, 10, 11, 13, 14, 16, 17]}, # two/four legs 7 | {"hum_joints": [9, 10, 11], "dog_joints": [1, 2]}, # Spine 8 | {"hum_joints": [12, 13], "dog_joints": [3, 4]}] # Head 9 | 10 | dog_end = [10, 15, 19, 23] # These value are indices of the used End Sites in Dog skeleton 11 | hum_end = [5, 10, 16] # These value are indices of the used End Sites in human skeleton -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/dog_train.txt: -------------------------------------------------------------------------------- 1 | D1_010_KAN01_002 2 | D1_053_KAN01_002 3 | D1_058_KAN01_001 4 | D1_010_KAN01_003 5 | D1_ex04_KAN02_004 6 | D1_009_KAN01_002 7 | D1_ex03_KAN02_003 8 | D1_003_KAN01_001 9 | D1_ex05_KAN02_001 10 | D1_ex02_KAN02_001 11 | D1_047z_KAN01_004 12 | D1_ex06_KAN02_001 13 | D1_ex01_KAN01_001 14 | D1_071_KAN02_002 15 | D1_086_KAN02_001 16 | D1_010_KAN01_004 17 | D1_073_KAN02_002 18 | D1_ex03_KAN02_013 19 | D1_ex03_KAN02_014 20 | D1_059_KAN01_001 21 | D1_ex03_KAN02_012 22 | D1_047z_KAN01_005 23 | D1_006_KAN01_001 24 | D1_031_KAN01_001 25 | D1_025_KAN01_001 26 | D1_055_KAN01_001 27 | D1_007_KAN01_001 28 | D1_009_KAN01_001 29 | D1_008_KAN01_001 30 | D1_053_KAN01_001 31 | D1_ex03_KAN02_001 32 | D1_006_KAN01_002 33 | D1_045_KAN01_001 34 | D1_061z_KAN01_002 35 | D1_ex04_KAN02_003 36 | D1_049_KAN01_001 37 | D1_ex04_KAN02_002 38 | D1_061_KAN01_001 -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from models import create_model_mixamo 3 | from data_preprocess.Mixamo import create_dataset, get_character_names 4 | import parser.parser_mixamo as option_parser 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | def eval(eval_seq, save_dir, test_device='cpu', epoch=20000): 10 | para_path = os.path.join(save_dir, 'para.txt') 11 | with open(para_path, 'r') as para_file: 12 | argv_ = para_file.readline().split()[1:] 13 | args = option_parser.get_parser().parse_args(argv_) 14 | 15 | args.cuda_device = test_device if torch.cuda.is_available() else 'cpu' 
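# Note: 'para.txt' is written by the training script as ' '.join(sys.argv), so
# readline().split()[1:] above drops the script name and re-parses the exact flags the
# model was trained with; the evaluation-specific fields (cuda_device, is_train,
# rotation, eval_seq, save_dir) are then overridden here.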
16 | args.is_train = False 17 | args.rotation = 'quaternion' 18 | args.eval_seq = eval_seq 19 | args.save_dir = save_dir 20 | character_names = get_character_names(args) 21 | 22 | dataset = create_dataset(args, character_names) 23 | model = create_model_mixamo(args, character_names, dataset) 24 | model.load(epoch=epoch) 25 | 26 | for i, motions in tqdm(enumerate(dataset), total=len(dataset)): 27 | model.set_input(motions) 28 | model.test() 29 | 30 | if __name__ == '__main__': 31 | parser = option_parser.get_parser() 32 | args = parser.parse_args() 33 | eval(args.eval_seq, args.save_dir, args.cuda_device) 34 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/fbx2bvh.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code comes from https://github.com/rubenvillegas/cvpr2018nkn/blob/master/datasets/fbx2bvh.py 3 | """ 4 | import bpy 5 | import numpy as np 6 | 7 | from os import listdir, path 8 | 9 | data_path = './Mixamo/' 10 | 11 | directories = sorted([f for f in listdir(data_path) if not f.startswith(".") and path.isdir(path.join(data_path, f))]) 12 | for d in directories: 13 | files = sorted([f for f in listdir(data_path + d) if f.endswith(".fbx")]) 14 | 15 | for f in files: 16 | sourcepath = data_path + d + "/" + f 17 | dumppath = data_path+d + "/" + f.split(".fbx")[0] + ".bvh" 18 | 19 | bpy.ops.import_scene.fbx(filepath=sourcepath) 20 | 21 | frame_start = 9999 22 | frame_end = -9999 23 | action = bpy.data.actions[-1] 24 | if action.frame_range[1] > frame_end: 25 | frame_end = action.frame_range[1] 26 | if action.frame_range[0] < frame_start: 27 | frame_start = action.frame_range[0] 28 | 29 | frame_end = np.max([60, frame_end]) 30 | bpy.ops.export_anim.bvh(filepath=dumppath, 31 | frame_start=int(frame_start), 32 | frame_end=int(frame_end), root_transform_only=True) 33 | bpy.data.actions.remove(bpy.data.actions[-1]) 34 | 35 | print(data_path + d + "/" + f + " processed.") 36 | -------------------------------------------------------------------------------- /parser/evaluation.py: -------------------------------------------------------------------------------- 1 | # evaluation parser for biped-quadruped retargeting 2 | from parser.base import boolean_string, add_misc_options, \ 3 | add_cuda_options, ArgumentParser, add_dataset_options, add_model_options, add_losses_options 4 | 5 | def add_evaluation_options(parser): 6 | group = parser.add_argument_group('Evaluation options') 7 | group.add_argument('--is_train', type=boolean_string, default=False) 8 | group.add_argument('--batch_size', type=int, help="size of the batches", default=24) 9 | group.add_argument('--verbose', type=boolean_string, default=True) 10 | group.add_argument('--epoch_begin', type=int, default=0) 11 | group.add_argument("--save_iter", type=int, default=200, help="frequency of saving model/viz per xx epoch") 12 | group.add_argument("--epoch_num", type=int, default=4001) 13 | 14 | 15 | 16 | def get_parser(argv_=None): 17 | parser = ArgumentParser() 18 | # misc options 19 | add_misc_options(parser) 20 | 21 | # cuda options 22 | add_cuda_options(parser) 23 | 24 | # training options 25 | add_evaluation_options(parser) 26 | 27 | # dataset options 28 | add_dataset_options(parser) 29 | 30 | # model options 31 | add_model_options(parser) 32 | 33 | # loss options 34 | add_losses_options(parser) 35 | 36 | args = parser.parse_args(argv_) 37 | return args 38 | 39 | 
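# A minimal usage sketch (added annotation, not part of the original module). Unlike
# parser/parser_mixamo.py, whose get_parser() returns an ArgumentParser, this get_parser()
# builds every option group and returns the parsed namespace directly, and it accepts an
# explicit argv list, which is convenient when replaying a command line saved in para.txt.
# The example assumes none of the options registered in parser/base.py (not shown in this
# listing) are marked as required.
if __name__ == '__main__':
    demo_args = get_parser(['--batch_size', '24', '--epoch_num', '4001'])
    print(demo_args.is_train, demo_args.batch_size, demo_args.epoch_num)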
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, Lei Hu, Zihao Zhang, Chongyang Zhong, Boyuan Jiang, Shihong Xia 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /parser/training.py: -------------------------------------------------------------------------------- 1 | # training parser for biped-quadruped retargeting 2 | from parser.base import boolean_string, add_misc_options, \ 3 | add_cuda_options, add_dataset_options, \ 4 | add_model_options, add_losses_options, ArgumentParser 5 | 6 | 7 | def add_training_options(parser): 8 | group = parser.add_argument_group('Training options') 9 | group.add_argument('--is_train', type=boolean_string, default=True) 10 | group.add_argument("--batch_size", type=int, help="size of the batches", default=128) 11 | group.add_argument("--epoch_begin", type=int, default=0, help="load training epoch !") 12 | group.add_argument("--epoch_num", type=int, help="number of epochs of training", default=5001) 13 | group.add_argument("--save_iter", type=int, default=200, help="frequency of saving model/viz per xx epoch") 14 | 15 | group.add_argument("--scheduler", type=str, default='none') 16 | group.add_argument("--optimizer", type=str, default='Adam') 17 | group.add_argument('--lr_d', type=float, default=1e-4, help="discriminator learning rate") 18 | group.add_argument('--lr_g', type=float, default=1e-4, help="generator learning rate") 19 | 20 | 21 | def get_parser(): 22 | parser = ArgumentParser() 23 | # misc options 24 | add_misc_options(parser) 25 | 26 | # cuda options 27 | add_cuda_options(parser) 28 | 29 | # training options 30 | add_training_options(parser) 31 | 32 | # dataset options 33 | add_dataset_options(parser) 34 | 35 | # model options 36 | add_model_options(parser) 37 | 38 | # loss options 39 | add_losses_options(parser) 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | -------------------------------------------------------------------------------- /get_error.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from parser.parser_mixamo import get_std_bvh 4 | import outer_utils.BVH as BVH 5 | import numpy as np 6 | from data_preprocess.Mixamo.bvh_parser import BVH_file 7 | import outer_utils.Animation as Animation 8 | 9 | 10 | def full_batch(suffix, prefix): 11 | res = [] 12 | chars = ['Mousey_m', 'Goblin_m', 'Mremireh_m', 'Vampire_m'] 13 | for char in chars: 14 | res.append(batch(char, suffix, prefix)) 15 | return res 16 | 17 | 18 | def batch(char, suffix, prefix): 19 | input_path = os.path.join(prefix, 'results/bvh') 20 | 21 | all_err = [] 22 | ref_file = get_std_bvh(dataset=char) 23 | ref_file = BVH_file(ref_file) 24 | height = ref_file.get_height() 25 | 26 | test_num = 0 27 | 28 | new_p = os.path.join(input_path, char) 29 | files = [f for f in os.listdir(new_p) if 30 | f.endswith('_{}.bvh'.format(suffix)) and not f.endswith('_gt.bvh') and 'fix' not in f and not f.endswith('_input.bvh')] 31 | 32 | for file in files: 33 | file_full = os.path.join(new_p, file) 34 | anim, names, _ = BVH.load(file_full) 35 | test_num += 1 36 | index = [] 37 | for i, name in enumerate(names): 38 | if 'virtual' in name: 39 | continue 40 | index.append(i) 41 | 42 | file_ref = file_full[:-6] + '_gt.bvh' 43 | anim_ref, _, _ = BVH.load(file_ref) 44 | 45 | pos = Animation.positions_global(anim) # [T, J, 3] 46 | pos_ref = Animation.positions_global(anim_ref) 47 | 48 | pos = pos[:, index, :] 49 | pos_ref = pos_ref[:, index, :] 50 | 51 | err = (pos - pos_ref) * (pos - pos_ref) 52 | err /= height ** 2 53 | err = np.mean(err) 54 | all_err.append(err) 55 | 56 | all_err = np.array(all_err) 57 | return all_err.mean() 58 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/__init__.py: -------------------------------------------------------------------------------- 1 | def get_character_names(args): 2 | if args.is_train: 3 | """ 4 | Put the name of subdirectory in ./data_preprocess/Mixamo/Mixamo as [[names of group A], [names of group B]] 5 | """ 6 | characters = [['Aj', 'BigVegas', 'Kaya', 'SportyGranny'], 7 | ['Malcolm_m', 'Remy_m', 'Maria_m', 'Jasper_m', 'Knight_m', 8 | 'Liam_m', 'ParasiteLStarkie_m', 'Pearl_m', 'Michelle_m', 'LolaB_m', 9 | 'Pumpkinhulk_m', 'Ortiz_m', 'Paladin_m', 'James_m', 'Joe_m', 10 | 'Olivia_m', 'Yaku_m', 'Timmy_m', 'Racer_m', 'Abe_m']] 11 | 12 | else: 13 | """ 14 | To run evaluation successfully, number of characters in both groups must be the same. Repeat is okay. 
15 | """ 16 | characters = [['BigVegas', 'BigVegas', 'BigVegas', 'BigVegas'], 17 | ['Mousey_m', 'Goblin_m', 'Mremireh_m', 'Vampire_m']] 18 | 19 | tmp = characters[1][args.eval_seq] 20 | characters[1][args.eval_seq] = characters[1][0] 21 | characters[1][0] = tmp 22 | 23 | return characters 24 | 25 | 26 | def create_dataset(args, character_names=None): 27 | from data_preprocess.Mixamo.combined_motion import TestData, MixedData 28 | 29 | if args.is_train: 30 | return MixedData(args, character_names) 31 | else: 32 | return TestData(args, character_names) 33 | 34 | 35 | def get_test_set(): 36 | with open('./data_preprocess/Mixamo/Mixamo/test_list.txt', 'r') as file: 37 | list = file.readlines() 38 | list = [f[:-1] for f in list] 39 | return list 40 | 41 | 42 | def get_train_list(): 43 | with open('./data_preprocess/Mixamo/Mixamo/train_list.txt', 'r') as file: 44 | list = file.readlines() 45 | list = [f[:-1] for f in list] 46 | return list 47 | -------------------------------------------------------------------------------- /loss_record.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class SingleLoss: 7 | def __init__(self, name: str, writer: SummaryWriter): 8 | self.name = name 9 | self.loss_step = [] 10 | self.loss_epoch = [] 11 | self.loss_epoch_tmp = [] 12 | self.writer = writer 13 | 14 | def add_scalar(self, val, step=None): 15 | if step is None: step = len(self.loss_step) 16 | self.loss_step.append(val) 17 | self.loss_epoch_tmp.append(val) 18 | self.writer.add_scalar('Train/step_' + self.name, val, step) 19 | 20 | def epoch(self, step=None): 21 | if step is None: step = len(self.loss_epoch) 22 | loss_avg = sum(self.loss_epoch_tmp) / len(self.loss_epoch_tmp) 23 | self.loss_epoch_tmp = [] 24 | self.loss_epoch.append(loss_avg) 25 | self.writer.add_scalar('Train/epoch_' + self.name, loss_avg, step) 26 | 27 | def save(self, path): 28 | loss_step = np.array(self.loss_step) 29 | loss_epoch = np.array(self.loss_epoch) 30 | np.save(path + self.name + '_step.npy', loss_step) 31 | np.save(path + self.name + '_epoch.npy', loss_epoch) 32 | 33 | 34 | class LossRecorder: 35 | def __init__(self, writer: SummaryWriter): 36 | self.losses = {} 37 | self.writer = writer 38 | 39 | def add_scalar(self, name, val, step=None): 40 | if isinstance(val, torch.Tensor): val = val.item() 41 | if name not in self.losses: 42 | self.losses[name] = SingleLoss(name, self.writer) 43 | self.losses[name].add_scalar(val, step) 44 | 45 | def epoch(self, step=None): 46 | for loss in self.losses.values(): 47 | loss.epoch(step) 48 | 49 | def save(self, path): 50 | for loss in self.losses.values(): 51 | loss.save(path) 52 | -------------------------------------------------------------------------------- /train_mixamo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from torch.utils.data.dataloader import DataLoader 3 | from models import create_model_mixamo 4 | from data_preprocess.Mixamo import create_dataset, get_character_names 5 | import parser.parser_mixamo as option_parser 6 | import os 7 | from parser.base import try_mkdir 8 | import time 9 | import torch 10 | 11 | 12 | torch.autograd.set_detect_anomaly(True) 13 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 15 | 16 | def main(): 17 | args = option_parser.get_args() 18 | characters = get_character_names(args) 19 | 20 | log_path = 
os.path.join(args.save_dir, 'logs/') 21 | try_mkdir(args.save_dir) 22 | try_mkdir(log_path) 23 | 24 | with open(os.path.join(args.save_dir, 'para.txt'), 'w') as para_file: 25 | para_file.write(' '.join(sys.argv)) 26 | 27 | dataset = create_dataset(args, characters) 28 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=2) 29 | 30 | model = create_model_mixamo(args, characters, dataset) 31 | if args.use_parallel: 32 | model.parallel() 33 | if args.epoch_begin: 34 | model.load(epoch=args.epoch_begin) 35 | 36 | model.setup() 37 | 38 | start_time = time.time() 39 | 40 | for epoch in range(args.epoch_begin, args.epoch_num): 41 | for step, motions in enumerate(data_loader): 42 | model.set_input(motions) # motions: 0(256, 91, 64)(256) 1(256, 111, 64)(256) 43 | model.optimize_parameters() 44 | 45 | if args.verbose: 46 | res = model.verbose() 47 | print('[{}/{}]\t[{}/{}]\t'.format(epoch, args.epoch_num, step, len(data_loader)), res) 48 | 49 | if epoch % args.save_iter == 0 or epoch == args.epoch_num - 1: 50 | model.save() 51 | 52 | model.epoch() 53 | 54 | end_tiem = time.time() 55 | print('training time', end_tiem - start_time) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /demo_mixamo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data_preprocess.Mixamo.bvh_parser import BVH_file 3 | from data_preprocess.Mixamo.bvh_writer import BVH_writer 4 | from models.IK import remove_foot_sliding 5 | from os.path import join as pjoin 6 | from parser.base import try_mkdir 7 | 8 | 9 | # downsampling and remove redundant joints 10 | def copy_ref_file(src, dst): 11 | file = BVH_file(src) 12 | writer = BVH_writer(file.edges, file.names) 13 | writer.write_raw(file.to_tensor(quater=True)[..., ::2], 'quaternion', dst) 14 | 15 | 16 | def get_height(file): 17 | file = BVH_file(file) 18 | return file.get_height() 19 | 20 | 21 | def example(src_name, dest_name, bvh_name, test_type, output_path): 22 | try_mkdir(output_path) 23 | input_file = './data_preprocess/Mixamo/Mixamo/{}/{}'.format(src_name, bvh_name) 24 | ref_file = './data_preprocess/Mixamo/Mixamo/{}/{}'.format(dest_name, bvh_name) 25 | copy_ref_file(input_file, pjoin(output_path, 'input.bvh')) 26 | copy_ref_file(ref_file, pjoin(output_path, 'gt.bvh')) 27 | height = get_height(input_file) 28 | 29 | bvh_name = bvh_name.replace(' ', '_') 30 | input_file = './data_preprocess/Mixamo/Mixamo/{}/{}'.format(src_name, bvh_name) 31 | ref_file = './data_preprocess/Mixamo/Mixamo/{}/{}'.format(dest_name, bvh_name) 32 | 33 | cmd = 'python eval_single_pair.py --input_bvh={} --target_bvh={} ' \ 34 | '--output_filename={} --test_type={} --model_dir={} --epoch={}'.format( 35 | input_file, ref_file, pjoin(output_path, 'result.bvh'), 36 | test_type, './pretrained_mixamo', 1000 37 | ) 38 | os.system(cmd) 39 | 40 | remove_foot_sliding(pjoin(output_path, 'result.bvh'), 41 | pjoin(output_path, 'input.bvh'), 42 | pjoin(output_path, 'result.bvh'), 43 | height) 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | example('Aj', 'BigVegas', 'Dancing Running Man.bvh', 'intra', './pretrained_mixamo/demo/intra_structure') 49 | example('BigVegas', 'Mousey_m', 'Dual Weapon Combo.bvh', 'cross', './pretrained_mixamo/demo/cross_structure') 50 | print('Finished!') -------------------------------------------------------------------------------- /data_preprocess/Mixamo/preprocess.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import copy 4 | from data_preprocess.Mixamo.bvh_parser import BVH_file 5 | from data_preprocess.Mixamo.motion_dataset import MotionData 6 | from parser.parser_mixamo import get_args, try_mkdir 7 | 8 | 9 | def collect_bvh(data_path, character, files): 10 | print('begin {}'.format(character)) 11 | motions = [] 12 | args = get_args() 13 | for i, motion in enumerate(files): 14 | if not os.path.exists(data_path + character + '/' + motion): 15 | continue 16 | file = BVH_file(data_path + character + '/' + motion, args=args) 17 | new_motion = file.to_tensor().permute((1, 0)).numpy() 18 | motions.append(new_motion) 19 | 20 | save_file = data_path + character + '.npy' 21 | 22 | np.save(save_file, motions) 23 | print('Npy file saved at {}'.format(save_file)) 24 | 25 | 26 | def write_statistics(character, path): 27 | args = get_args() 28 | new_args = copy.copy(args) 29 | new_args.data_augment = 0 30 | new_args.dataset = character 31 | 32 | dataset = MotionData(new_args) 33 | 34 | mean = dataset.mean 35 | var = dataset.var 36 | mean = mean.cpu().numpy()[0, ...] 37 | var = var.cpu().numpy()[0, ...] 38 | 39 | np.save(path + '{}_mean.npy'.format(character), mean) 40 | np.save(path + '{}_var.npy'.format(character), var) 41 | 42 | 43 | def copy_std_bvh(data_path, character, files): 44 | """ 45 | copy an arbitrary bvh file as a static information (skeleton's offset) reference 46 | """ 47 | cmd = 'cp \"{}\" ./data_preprocess/Mixamo/Mixamo/std_bvhs/{}.bvh'.format(data_path + character + '/' + files[0], character) 48 | os.system(cmd) 49 | 50 | 51 | if __name__ == '__main__': 52 | prefix = './data_preprocess/Mixamo/Mixamo/' 53 | characters = [f for f in os.listdir(prefix) if os.path.isdir(os.path.join(prefix, f))] 54 | if 'std_bvhs' in characters: characters.remove('std_bvhs') 55 | if 'mean_var' in characters: characters.remove('mean_var') 56 | 57 | try_mkdir(os.path.join(prefix, 'std_bvhs')) 58 | try_mkdir(os.path.join(prefix, 'mean_var')) 59 | 60 | for character in characters: 61 | data_path = os.path.join(prefix, character) 62 | files = sorted([f for f in os.listdir(data_path) if f.endswith(".bvh")]) 63 | 64 | collect_bvh(prefix, character, files) 65 | copy_std_bvh(prefix, character, files) 66 | write_statistics(character, './data_preprocess/Mixamo/Mixamo/mean_var/') 67 | -------------------------------------------------------------------------------- /train_lafan1dog.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader 2 | from models import creat_model 3 | from data_preprocess.Lafan1_and_dog.datasetserial import HumDataset, DogDataset 4 | from parser.training import get_parser 5 | from parser.base import dict_to_object, try_mkdir 6 | import os, sys 7 | from config import Configuration 8 | import torch 9 | from utils.utils import get_body_part 10 | 11 | 12 | def main(): 13 | args = get_parser() 14 | parameters_config = {key: val for key, val in vars(Configuration).items() if val is not None} 15 | parameters_args = {key: val for key, val in vars(args).items() if val is not None} 16 | parameters_args.update(parameters_config) 17 | args = dict_to_object(parameters_args) 18 | 19 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 20 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device[-1] 21 | args.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 22 | 23 | log_path = 
os.path.join(args.save_dir, 'logs/') 24 | try_mkdir(args.save_dir) 25 | try_mkdir(log_path) 26 | 27 | with open(os.path.join(args.save_dir, 'para.txt'), 'w') as para_file: 28 | para_file.write(' '.join(sys.argv)) 29 | 30 | humdataset = HumDataset(args) 31 | dogdataset = DogDataset(args) 32 | 33 | humloader = DataLoader(humdataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 34 | dogloader = DataLoader(dogdataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 35 | dogfeeder = iter(dogloader) 36 | humfeeder = iter(humloader) 37 | 38 | hum_parts = get_body_part(args.correspondence, 'hum_joints') 39 | dog_parts = get_body_part(args.correspondence, 'dog_joints') 40 | 41 | body_parts = [hum_parts, dog_parts] 42 | datasets = [humdataset, dogdataset] 43 | 44 | model = creat_model(args, body_parts, datasets, ['human', 'dog']) 45 | 46 | if args.epoch_begin: 47 | model.load(epoch=args.epoch_begin) 48 | 49 | model.setup() 50 | 51 | epoch = args.epoch_begin 52 | while epoch < args.epoch_num: 53 | if epoch % args.save_iter == 0 or epoch == args.epoch_num - 1: 54 | model.save() 55 | 56 | flag = True 57 | while flag: 58 | try: 59 | input_d, d_yrot, d_offsets, d_offsets_withend = next(dogfeeder) 60 | except StopIteration: 61 | dogfeeder = iter(dogloader) 62 | input_d, d_yrot, d_offsets, d_offsets_withend = next(dogfeeder) 63 | 64 | try: 65 | input_h, h_yrot, h_offsets, h_offsets_withend = next(humfeeder) 66 | except StopIteration: 67 | epoch += 1 68 | flag = False 69 | humfeeder = iter(humloader) 70 | input_h, h_yrot, h_offsets, h_offsets_withend = next(humfeeder) 71 | 72 | vel_dim = 4 73 | input_h_encoder = (input_h[..., :args.hum_njoints * 4 + vel_dim]).transpose(1, 2) 74 | input_d_encoder = (input_d[..., :args.dog_njoints * 4 + vel_dim]).transpose(1, 2) 75 | 76 | input_h_encoder = (input_h_encoder, h_offsets, h_offsets_withend) 77 | input_d_encoder = (input_d_encoder, d_offsets, d_offsets_withend) 78 | 79 | model.set_input([input_h_encoder, input_d_encoder]) 80 | 81 | model.optimize_parameters() 82 | 83 | model.epoch() 84 | 85 | if __name__ == '__main__': 86 | main() -------------------------------------------------------------------------------- /models/functions.py: -------------------------------------------------------------------------------- 1 | import loss_function as lf 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | def get_gan_loss(loss_type): 6 | if loss_type == 'bce_gan': 7 | return lf.dis_loss 8 | elif loss_type == 'l2_gan': 9 | return lf.dis_loss_l2 10 | 11 | def get_rec_loss(loss_type): 12 | if loss_type == 'mse_rec': 13 | return nn.MSELoss() 14 | elif loss_type == 'quat_rec' or loss_type == 'norm_rec': 15 | return lf.caloutputloss 16 | 17 | def get_root_loss(loss_type): 18 | if loss_type == 'mse_root': 19 | return nn.MSELoss() 20 | 21 | def get_kine_loss(loss_type): 22 | if loss_type == 'mse_kine': 23 | return nn.MSELoss() 24 | elif loss_type == 'part_kine': 25 | return lf.calposloss 26 | 27 | def get_cycle_loss(loss_type): 28 | if loss_type == 'mse_cycle_motion': 29 | return lf.cycle_motions 30 | 31 | def get_cycle_latent_loss(loss_type): 32 | if loss_type == 'mse_latent': 33 | return lf.cycle_latents 34 | 35 | def get_retar_root_v_loss(loss_type): 36 | if loss_type == 'linear': 37 | return nn.MSELoss() 38 | 39 | 40 | def get_discriminator_input(gan_model, dis_mode, index, real): 41 | if dis_mode == 'norm_rotation': 42 | if real: 43 | return gan_model.motion[index] 44 | else: 45 | return gan_model.fake_retar[index].transpose(1, 2) 46 | elif 
dis_mode == 'denorm_rotation': 47 | if real: 48 | return gan_model.motion_denorm[index].transpose(1, 2) 49 | else: 50 | return gan_model.fake_retar_denorm[index].transpose(1, 2) 51 | elif dis_mode == 'denorm_pos': 52 | if real: 53 | return gan_model.gt_pos[index].reshape(gan_model.gt_pos[index].shape[:-2] + (-1, )).transpose(1, 2) 54 | else: 55 | return gan_model.fake_pos[index].reshape(gan_model.fake_pos[index].shape[:-2] + (-1, )).transpose(1, 2) 56 | elif dis_mode == 'latent': 57 | if real: 58 | return gan_model.latents[index] 59 | else: 60 | return gan_model.retar_latents[index] 61 | else: 62 | raise Exception("Discriminator input not defined !") 63 | 64 | 65 | def get_recloss_input(gan_model, rec_loss_mode, index): 66 | if rec_loss_mode == 'norm_rec': 67 | input_0 = gan_model.motion[index].transpose(1, 2) 68 | input_1 = gan_model.rec[index] 69 | elif rec_loss_mode == 'quat_rec': 70 | input_0 = gan_model.motion_denorm[index] 71 | input_1 = gan_model.rec_denorm[index] 72 | return input_0, input_1 73 | 74 | 75 | def get_retar_latents(gan_model, src): 76 | input = gan_model.latents[src] 77 | out_latent = input 78 | return out_latent 79 | 80 | 81 | def get_cyc_latents(gan_model, fake_latent): 82 | 83 | # Construct shared latent space without any transfer 84 | out_latent = fake_latent 85 | return out_latent 86 | 87 | 88 | def get_optimizer(optimizer_name, para, lr, **kwargs): 89 | if optimizer_name == 'RMSprop': 90 | optimizer = optim.RMSprop(para, lr=lr) 91 | 92 | elif optimizer_name == 'Adam': 93 | optimizer = optim.Adam(para, lr=lr, **kwargs) 94 | 95 | elif optimizer_name == 'Adadelta': 96 | optimizer = optim.Adadelta(para, lr=lr, **kwargs) 97 | 98 | else: 99 | raise Exception("Optimizer not find!") 100 | 101 | return optimizer 102 | 103 | 104 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/split_joint.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script splits three joints as we describe in the paper. 3 | It automatically detects all the dirs in "./data_preprocess/Mixamo/Mixamo" and 4 | finds these can be split then create a new dir with an extra _m 5 | as suffix to store the split files in the new dir. 
6 | """ 7 | 8 | import os 9 | from parser.base import try_mkdir 10 | import numpy as np 11 | from tqdm import tqdm 12 | from data_preprocess.Mixamo.bvh_parser import BVH_file 13 | import outer_utils as BVH 14 | 15 | 16 | def split_joint(file_name, save_file=None): 17 | if save_file is None: 18 | save_file = file_name 19 | target_joints = ['Spine1', 'LeftShoulder', 'RightShoulder'] 20 | target_idx = [-1] * len(target_joints) 21 | anim, names, ftime = BVH.load(file_name) 22 | 23 | n_joint = len(anim.parents) 24 | 25 | for i, name in enumerate(names): 26 | if ':' in name: 27 | name = name[name.find(':') + 1:] 28 | names[i] = name 29 | 30 | for j, joint in enumerate(target_joints): 31 | if joint == names[i]: 32 | target_idx[j] = i 33 | 34 | new_anim = anim.copy() 35 | new_anim.offsets = [] 36 | new_anim.parents = [] 37 | new_anim.rotations = [] 38 | new_names = [] 39 | 40 | target_idx.sort() 41 | 42 | bias = 0 43 | new_id = {-1: -1} 44 | target_idx.append(-1) 45 | for i in range(n_joint): 46 | new_id[i] = i + bias 47 | if i == target_idx[bias]: bias += 1 48 | 49 | identity = np.zeros_like(anim.rotations) 50 | identity = identity[:, :1, :] 51 | 52 | bias = 0 53 | for i in range(n_joint): 54 | new_anim.parents.append(new_id[anim.parents[i]]) 55 | new_names.append(names[i]) 56 | new_anim.rotations.append(anim.rotations[:, [i], :]) 57 | 58 | if i == target_idx[bias]: 59 | new_anim.offsets.append(anim.offsets[i] / 2) 60 | 61 | new_anim.parents.append(i + bias) 62 | new_names.append(names[i] + '_split') 63 | new_anim.offsets.append(anim.offsets[i] / 2) 64 | 65 | new_anim.rotations.append(identity) 66 | 67 | new_id[i] += 1 68 | bias += 1 69 | else: 70 | new_anim.offsets.append(anim.offsets[i]) 71 | 72 | new_anim.offsets = np.array(new_anim.offsets) 73 | 74 | offset_spine = anim.offsets[target_idx[0]] + anim.offsets[target_idx[0] + 1] 75 | new_anim.offsets[target_idx[0]:target_idx[0]+3, :] = offset_spine / 3 76 | 77 | new_anim.rotations = np.concatenate(new_anim.rotations, axis=1) 78 | try_mkdir(os.path.split(save_file)[0]) 79 | BVH.save(save_file, new_anim, names=new_names, frametime=ftime, order='xyz') 80 | 81 | 82 | def batch_split(source, dest): 83 | print(source) 84 | files = [f for f in os.listdir(source) if f.endswith('.bvh')] 85 | try: 86 | bvh_file = BVH_file(os.path.join(source, files[0])) 87 | # if bvh_file.skeleton_type != 1: return 88 | print(1) 89 | except: 90 | return 91 | 92 | print("Working on {}".format(os.path.split(source)[-1])) 93 | try_mkdir(dest) 94 | files = [f for f in os.listdir(source) if f.endswith('.bvh')] 95 | for i, file in tqdm(enumerate(files), total=len(files)): 96 | in_file = os.path.join(source, file) 97 | out_file = os.path.join(dest, file) 98 | split_joint(in_file, out_file) 99 | 100 | 101 | if __name__ == '__main__': 102 | prefix = './data_preprocess/Mixamo/Mixamo/' 103 | names = [f for f in os.listdir(prefix) if os.path.isdir(os.path.join(prefix, f))] 104 | 105 | for name in names: 106 | batch_split(os.path.join(prefix, name), os.path.join(prefix, name + '_m')) 107 | -------------------------------------------------------------------------------- /test_mixamo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | from get_error import full_batch 4 | import numpy as np 5 | from parser.parser_mixamo import try_mkdir 6 | from eval import eval 7 | import argparse 8 | 9 | 10 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 12 | 13 | 14 | 
def batch_copy(source_path, suffix, dest_path, dest_suffix=None):
15 | try_mkdir(dest_path)
16 | files = [f for f in os.listdir(source_path) if f.endswith('_{}.bvh'.format(suffix))]
17 |
18 | length = len('_{}.bvh'.format(suffix))
19 | for f in files:
20 | if dest_suffix is not None:
21 | cmd = 'cp \"{}\" \"{}\"'.format(os.path.join(source_path, f), os.path.join(dest_path, f[:-length] + '_{}.bvh'.format(dest_suffix)))
22 | else:
23 | cmd = 'cp \"{}\" \"{}\"'.format(os.path.join(source_path, f), os.path.join(dest_path, f[:-length] + '.bvh'))
24 | os.system(cmd)
25 |
26 |
27 | def batch_mat_copy(source_path, suffix, dest_path, dest_suffix=None):
28 | try_mkdir(dest_path)
29 | files = [f for f in os.listdir(source_path) if f.endswith('_{}_gt.mat'.format(suffix))]
30 |
31 | length = len('_{}_gt.mat'.format(suffix))
32 | for f in files:
33 | if dest_suffix is not None:
34 | cmd = 'cp \"{}\" \"{}\"'.format(os.path.join(source_path, f), os.path.join(dest_path, f[:-length] + '_{}.mat'.format(dest_suffix)))
35 | else:
36 | cmd = 'cp \"{}\" \"{}\"'.format(os.path.join(source_path, f), os.path.join(dest_path, f[:-length] + '.mat'))
37 | os.system(cmd)
38 |
39 |
40 | if __name__ == '__main__':
41 | test_characters = ['Mousey_m', 'Goblin_m', 'Mremireh_m', 'Vampire_m']
42 |
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('--save_dir', type=str, default='./pretrained_mixamo/')
45 | parser.add_argument('--model', type=str, default='pan')
46 | parser.add_argument('--epoch', type=int, default=20000)
47 | parser.add_argument('--device', type=str, default='cuda:0')
48 |
49 | args = parser.parse_args()
50 | prefix = args.save_dir
51 |
52 | cross_dest_path = pjoin(prefix, 'results/cross_structure/')
53 | intra_dest_path = pjoin(prefix, 'results/intra_structure/')
54 | source_path = pjoin(prefix, 'results/bvh/')
55 |
56 | cross_error = []
57 | intra_error = []
58 | for i in range(4):
59 | print('Batch [{}/4]'.format(i + 1))
60 | if args.device == 'cpu':
61 | eval(i, prefix, "cpu", epoch=args.epoch)
62 | else:
63 | eval(i, prefix, "cuda:0", epoch=args.epoch)
64 |
65 | print('Collecting test error...')
66 | if i == 0:
67 | cross_error += full_batch(0, prefix)
68 | for char in test_characters:
69 | batch_copy(os.path.join(source_path, char), 0, os.path.join(cross_dest_path, char))
70 | batch_copy(os.path.join(source_path, char), 'gt', os.path.join(cross_dest_path, char), 'gt')
71 | batch_mat_copy(os.path.join(source_path, char), 0, os.path.join(cross_dest_path, char), 'gt')
72 |
73 | intra_dest = os.path.join(intra_dest_path, 'from_{}'.format(test_characters[i]))
74 | for char in test_characters:
75 |
76 | batch_copy(os.path.join(source_path, char), 1, os.path.join(intra_dest, char))
77 | batch_copy(os.path.join(source_path, char), 'gt', os.path.join(intra_dest, char), 'gt')
78 | batch_mat_copy(os.path.join(source_path, char), 1, os.path.join(intra_dest, char), 'gt')
79 |
80 |
81 | intra_error += full_batch(1, prefix)
82 |
83 | cross_error = np.array(cross_error)
84 | intra_error = np.array(intra_error)
85 |
86 | cross_error_mean = cross_error.mean()
87 | intra_error_mean = intra_error.mean()
88 |
89 | os.system('rm -r %s' % pjoin(prefix, 'results/bvh'))
90 |
91 | print('Intra-retargeting error:', intra_error_mean)
92 | print('Cross-retargeting error:', cross_error_mean)
93 | print('Evaluation finished!')
94 |
95 |
96 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/bvh_writer.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from outer_utils.Quaternions import Quaternions 3 | from utils.utils import build_joint_topology 4 | 5 | 6 | # rotation with shape frame * J * 3 7 | def write_bvh(parent, offset, rotation, position, names, frametime, order, path, endsite=None): 8 | file = open(path, 'w') 9 | frame = rotation.shape[0] 10 | joint_num = rotation.shape[1] 11 | order = order.upper() 12 | 13 | file_string = 'HIERARCHY\n' 14 | 15 | def write_static(idx, prefix): 16 | nonlocal parent, offset, rotation, names, order, endsite, file_string 17 | if idx == 0: 18 | name_label = 'ROOT ' + names[idx] 19 | channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format(*order) 20 | else: 21 | name_label = 'JOINT ' + names[idx] 22 | channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order) 23 | offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2]) 24 | 25 | file_string += prefix + name_label + '\n' 26 | file_string += prefix + '{\n' 27 | file_string += prefix + '\t' + offset_label + '\n' 28 | file_string += prefix + '\t' + channel_label + '\n' 29 | 30 | has_child = False 31 | for y in range(idx+1, rotation.shape[1]): 32 | if parent[y] == idx: 33 | has_child = True 34 | write_static(y, prefix + '\t') 35 | if not has_child: 36 | file_string += prefix + '\t' + 'End Site\n' 37 | file_string += prefix + '\t' + '{\n' 38 | file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n' 39 | file_string += prefix + '\t' + '}\n' 40 | 41 | file_string += prefix + '}\n' 42 | 43 | write_static(0, '') 44 | 45 | file_string += 'MOTION\n' + 'Frames: {}\n'.format(frame) + 'Frame Time: %.8f\n' % frametime 46 | for i in range(frame): 47 | file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1], position[i][2]) 48 | for j in range(joint_num): 49 | file_string += '%.6f %.6f %.6f ' % (rotation[i][j][0], rotation[i][j][1], rotation[i][j][2]) 50 | file_string += '\n' 51 | 52 | file.write(file_string) 53 | return file_string 54 | 55 | 56 | class BVH_writer(): 57 | def __init__(self, edges, names): 58 | self.parent, self.offset, self.names, self.edge2joint = build_joint_topology(edges, names) 59 | self.joint_num = len(self.parent) 60 | 61 | # position, rotation with shape T * J * (3/4) 62 | def write(self, rotations, positions, order, path, frametime=1.0/30, offset=None, root_y=None): 63 | if order == 'quaternion': 64 | norm = rotations[:, :, 0] ** 2 + rotations[:, :, 1] ** 2 + rotations[:, :, 2] ** 2 + rotations[:, :, 3] ** 2 65 | norm = np.repeat(norm[:, :, np.newaxis], 4, axis=2) 66 | rotations /= norm 67 | rotations = Quaternions(rotations) 68 | rotations = np.degrees(rotations.euler()) 69 | order = 'xyz' 70 | 71 | rotations_full = np.zeros((rotations.shape[0], self.joint_num, 3)) 72 | for idx, edge in enumerate(self.edge2joint): 73 | if edge != -1: 74 | rotations_full[:, idx, :] = rotations[:, edge, :] 75 | if root_y is not None: rotations_full[0, 0, 1] = root_y 76 | 77 | if offset is None: offset = self.offset 78 | return write_bvh(self.parent, offset, rotations_full, positions, self.names, frametime, order, path) 79 | 80 | def write_raw(self, motion, order, path, frametime=1.0/30, root_y=None): 81 | motion = motion.permute(1, 0).detach().cpu().numpy() 82 | positions = motion[:, -3:] 83 | rotations = motion[:, :-3] 84 | if order == 'quaternion': 85 | rotations = rotations.reshape((motion.shape[0], -1, 4)) 86 | else: 87 | rotations = 
rotations.reshape((motion.shape[0], -1, 3)) 88 | 89 | return self.write(rotations, positions, order, path, frametime, root_y=root_y) 90 | -------------------------------------------------------------------------------- /eval_single_pair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models import create_model_mixamo 4 | from data_preprocess.Mixamo import create_dataset 5 | import parser.parser_mixamo as option_parser 6 | 7 | 8 | def eval_prepare(args): 9 | character = [] 10 | file_id = [] 11 | character_names = [] 12 | character_names.append(args.input_bvh.split('/')[-2]) 13 | character_names.append(args.target_bvh.split('/')[-2]) 14 | if args.test_type == 'intra': 15 | if character_names[0].endswith('_m'): 16 | character = [['BigVegas', 'BigVegas'], character_names] 17 | file_id = [[0, 0], [args.input_bvh, args.input_bvh]] 18 | src_id = 1 19 | else: 20 | character = [character_names, ['Goblin_m', 'Goblin_m']] 21 | file_id = [[args.input_bvh, args.input_bvh], [0, 0]] 22 | src_id = 0 23 | elif args.test_type == 'cross': 24 | if character_names[0].endswith('_m'): 25 | character = [[character_names[1]], [character_names[0]]] 26 | file_id = [[0], [args.input_bvh]] 27 | src_id = 1 28 | else: 29 | character = [[character_names[0]], [character_names[1]]] 30 | file_id = [[args.input_bvh], [0]] 31 | src_id = 0 32 | else: 33 | raise Exception('Unknown test type') 34 | return character, file_id, src_id 35 | 36 | 37 | def recover_space(file): 38 | l = file.split('/') 39 | l[-1] = l[-1].replace('_', ' ') 40 | return '/'.join(l) 41 | 42 | 43 | def main(): 44 | parser = option_parser.get_parser() 45 | parser.add_argument('--input_bvh', type=str, required=True) 46 | parser.add_argument('--target_bvh', type=str, required=True) 47 | parser.add_argument('--test_type', type=str, required=True) 48 | parser.add_argument('--output_filename', type=str, required=True) 49 | parser.add_argument('--model_dir', type=str, required=True) 50 | parser.add_argument('--epoch', type=int, required=True) 51 | 52 | args = parser.parse_args() 53 | 54 | # argsparse can't take space character as part of the argument 55 | args.input_bvh = recover_space(args.input_bvh) 56 | args.target_bvh = recover_space(args.target_bvh) 57 | args.output_filename = recover_space(args.output_filename) 58 | 59 | character_names, file_id, src_id = eval_prepare(args) 60 | input_character_name = args.input_bvh.split('/')[-2] 61 | output_character_name = args.target_bvh.split('/')[-2] 62 | output_filename = args.output_filename 63 | 64 | test_device = args.cuda_device 65 | eval_seq = args.eval_seq 66 | epoch = args.epoch 67 | 68 | para_path = os.path.join(args.model_dir, 'para.txt') 69 | with open(para_path, 'r') as para_file: 70 | argv_ = para_file.readline().split()[1:] 71 | args = option_parser.get_parser().parse_args(argv_) 72 | 73 | args.model = 'pan' 74 | args.cuda_device = test_device if torch.cuda.is_available() else 'cpu' 75 | args.is_train = False 76 | args.rotation = 'quaternion' 77 | args.eval_seq = eval_seq 78 | 79 | dataset = create_dataset(args, character_names) 80 | model = create_model_mixamo(args, character_names, dataset) 81 | model.load(epoch=epoch) 82 | 83 | input_motion = [] 84 | for i, character_group in enumerate(character_names): 85 | input_group = [] 86 | for j in range(len(character_group)): 87 | new_motion = dataset.get_item(i, j, file_id[i][j]) 88 | new_motion.unsqueeze_(0) 89 | new_motion = (new_motion - dataset.mean[i][j]) / dataset.var[i][j] 90 | 
input_group.append(new_motion) 91 | input_group = torch.cat(input_group, dim=0) 92 | input_motion.append([input_group, list(range(len(character_group)))]) 93 | 94 | model.set_input(input_motion) 95 | model.test() 96 | 97 | os.system('cp "{}/{}/0_{}.bvh" "./{}"'.format(model.bvh_path, output_character_name, src_id, output_filename)) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class BaseModel(ABC): 8 | """This class is an abstract base class (ABC) for models. 9 | To create a subclass, you need to implement the following five functions: 10 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 11 | -- : unpack data from dataset and apply preprocessing. 12 | -- : produce intermediate results. 13 | -- : calculate losses, gradients, and update network weights. 14 | """ 15 | 16 | def __init__(self, args): 17 | self.args = args 18 | self.is_train = args.is_train 19 | self.device = torch.device("cuda:0" if (torch.cuda.is_available()) else 'cpu') 20 | self.model_save_dir = os.path.join(args.save_dir, 'models') # save all the checkpoints to save_dir 21 | 22 | if self.is_train: 23 | from loss_record import LossRecorder 24 | from torch.utils.tensorboard import SummaryWriter 25 | self.log_path = os.path.join(args.save_dir, 'logs') 26 | self.writer = SummaryWriter(self.log_path) 27 | self.loss_recoder = LossRecorder(self.writer) 28 | 29 | self.epoch_cnt = 0 30 | self.schedulers = [] 31 | self.optimizers = [] 32 | 33 | @abstractmethod 34 | def set_input(self, input): 35 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 36 | Parameters: 37 | input (dict): includes the data itself and its metadata information. 
38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | def compute_test_result(self): 43 | """ 44 | After forward, do something like output bvh, get error value 45 | """ 46 | pass 47 | 48 | @abstractmethod 49 | def forward(self): 50 | """Run forward pass; called by both functions and .""" 51 | pass 52 | 53 | 54 | @abstractmethod 55 | def optimize_parameters(self): 56 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 57 | pass 58 | 59 | def get_scheduler(self, optimizer): 60 | if self.args.scheduler == 'linear': 61 | def lambda_rule(epoch): 62 | lr_l = 1.0 - max(0, epoch - self.args.n_epochs_origin) / float(self.args.n_epochs_decay + 1) 63 | return lr_l 64 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 65 | if self.args.scheduler == 'Step_LR': 66 | print('Step_LR scheduler set') 67 | return torch.optim.lr_scheduler.StepLR(optimizer, 50, 0.5) 68 | if self.args.scheduler == 'Plateau': 69 | print('Plateau_LR shceduler set') 70 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5, verbose=True) 71 | if self.args.scheduler == 'MultiStep': 72 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[]) 73 | 74 | def setup(self): 75 | """Load and print networks; create schedulers 76 | Parameters: 77 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 78 | """ 79 | if self.is_train: 80 | self.schedulers = [self.get_scheduler(optimizer) for optimizer in self.optimizers] 81 | 82 | def epoch(self): 83 | self.loss_recoder.epoch() 84 | for scheduler in self.schedulers: 85 | if scheduler is not None: 86 | scheduler.step() 87 | self.epoch_cnt += 1 88 | 89 | def test(self): 90 | """Forward function used in test time. 
91 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 92 | It also calls to produce additional visualization results 93 | """ 94 | with torch.no_grad(): 95 | self.forward() 96 | self.compute_test_result() -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/std_bvh/dog_std.bvh: -------------------------------------------------------------------------------- 1 | HIERARCHY 2 | ROOT Hips 3 | { 4 | OFFSET -10.0563 7.73376 -472.55 5 | CHANNELS 6 Xposition Yposition Zposition Zrotation Xrotation Yrotation 6 | JOINT Spine 7 | { 8 | OFFSET 0 0 0 9 | CHANNELS 3 Zrotation Xrotation Yrotation 10 | JOINT Spine1 11 | { 12 | OFFSET 19 0 0 13 | CHANNELS 3 Zrotation Xrotation Yrotation 14 | JOINT Neck 15 | { 16 | OFFSET 22.5 0.6 0 17 | CHANNELS 3 Zrotation Xrotation Yrotation 18 | JOINT Head 19 | { 20 | OFFSET 14 0.0308777 0 21 | CHANNELS 3 Zrotation Xrotation Yrotation 22 | End Site 23 | { 24 | OFFSET 17 0 0 25 | } 26 | } 27 | } 28 | JOINT LeftShoulder 29 | { 30 | OFFSET 19.8 3.7 4.3 31 | CHANNELS 3 Zrotation Xrotation Yrotation 32 | JOINT LeftArm 33 | { 34 | OFFSET 8 0 0 35 | CHANNELS 3 Zrotation Xrotation Yrotation 36 | JOINT LeftForeArm 37 | { 38 | OFFSET 15.2 0 0 39 | CHANNELS 3 Zrotation Xrotation Yrotation 40 | JOINT LeftHand 41 | { 42 | OFFSET 17.8 0 0 43 | CHANNELS 3 Zrotation Xrotation Yrotation 44 | End Site 45 | { 46 | OFFSET 7.2 0 0 47 | } 48 | } 49 | } 50 | } 51 | } 52 | JOINT RightShoulder 53 | { 54 | OFFSET 19.8 3.7 -4.3 55 | CHANNELS 3 Zrotation Xrotation Yrotation 56 | JOINT RightArm 57 | { 58 | OFFSET 8 0 0.151654 59 | CHANNELS 3 Zrotation Xrotation Yrotation 60 | JOINT RightForeArm 61 | { 62 | OFFSET 15.2 0 0 63 | CHANNELS 3 Zrotation Xrotation Yrotation 64 | JOINT RightHand 65 | { 66 | OFFSET 17.8 0 0 67 | CHANNELS 3 Zrotation Xrotation Yrotation 68 | End Site 69 | { 70 | OFFSET 7.2 0 0 71 | } 72 | } 73 | } 74 | } 75 | } 76 | } 77 | } 78 | JOINT LeftUpLeg 79 | { 80 | OFFSET 5.98425 -7.666 4.78879 81 | CHANNELS 3 Zrotation Xrotation Yrotation 82 | JOINT LeftLeg 83 | { 84 | OFFSET 16 0 0 85 | CHANNELS 3 Zrotation Xrotation Yrotation 86 | JOINT LeftFoot 87 | { 88 | OFFSET 18 0 0 89 | CHANNELS 3 Zrotation Xrotation Yrotation 90 | End Site 91 | { 92 | OFFSET 0 -10.8 0 93 | } 94 | } 95 | } 96 | } 97 | JOINT RightUpLeg 98 | { 99 | OFFSET 5.98425 -7.66598 -4.78879 100 | CHANNELS 3 Zrotation Xrotation Yrotation 101 | JOINT RightLeg 102 | { 103 | OFFSET 16 0 0 104 | CHANNELS 3 Zrotation Xrotation Yrotation 105 | JOINT RightFoot 106 | { 107 | OFFSET 18 0 0 108 | CHANNELS 3 Zrotation Xrotation Yrotation 109 | End Site 110 | { 111 | OFFSET 0 -10.8 0 112 | } 113 | } 114 | } 115 | } 116 | JOINT Tail 117 | { 118 | OFFSET 6.83696 -0.722574 0 119 | CHANNELS 3 Zrotation Xrotation Yrotation 120 | JOINT Tail1 121 | { 122 | OFFSET 12 0 0 123 | CHANNELS 3 Zrotation Xrotation Yrotation 124 | End Site 125 | { 126 | OFFSET 12 0 0 127 | } 128 | } 129 | } 130 | } 131 | MOTION 132 | Frames: 1 133 | Frame Time: 0.0166667 134 | -17.969 52.075 -519.751 -4.80122 8.28961 96.7006 -172.99 -0.76463 3.72698 2.64451 0.748447 -12.6336 -23.6352 -3.7871 8.94177 -121.389 71.5156 116.825 -104.468 -25.292 -146.921 132.594 -36.0807 -138.548 -31.6843 12.3227 -3.29186 -21.3944 2.73052 -13.1956 73.5913 -32.3901 44.3507 36.0035 -27.0275 -51.2977 -68.2346 -25.1381 0.183097 -27.9054 -6.90339 22.1542 -127.35 1.19208 -12.148 111.467 12.0059 -1.71467 31.1289 -2.32536 -4.23218 -138.319 4.57958 6.68372 112.572 -2.48073 -3.99924 17.5036 
-1.74821 -2.67415 -11.501 12.7199 -64.7021 2.87678 -0.242718 -18.0754 -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/std_bvh/hum_std.bvh: -------------------------------------------------------------------------------- 1 | HIERARCHY 2 | ROOT Hips 3 | { 4 | OFFSET 173.408295 91.952423 -518.280273 5 | CHANNELS 6 Xposition Yposition Zposition Zrotation Yrotation Xrotation 6 | JOINT LeftUpLeg 7 | { 8 | OFFSET 0.103459 1.857827 10.548504 9 | CHANNELS 3 Zrotation Yrotation Xrotation 10 | JOINT LeftLeg 11 | { 12 | OFFSET 43.500008 0.000000 0.000004 13 | CHANNELS 3 Zrotation Yrotation Xrotation 14 | JOINT LeftFoot 15 | { 16 | OFFSET 42.372192 0.000011 0.000000 17 | CHANNELS 3 Zrotation Yrotation Xrotation 18 | JOINT LeftToe 19 | { 20 | OFFSET 17.299973 -0.000013 -0.000010 21 | CHANNELS 3 Zrotation Yrotation Xrotation 22 | End Site 23 | { 24 | OFFSET 0.000000 0.000000 0.000000 25 | } 26 | } 27 | } 28 | } 29 | } 30 | JOINT RightUpLeg 31 | { 32 | OFFSET 0.103454 1.857830 -10.548500 33 | CHANNELS 3 Zrotation Yrotation Xrotation 34 | JOINT RightLeg 35 | { 36 | OFFSET 43.500038 -0.000038 0.000004 37 | CHANNELS 3 Zrotation Yrotation Xrotation 38 | JOINT RightFoot 39 | { 40 | OFFSET 42.372253 0.000019 0.000024 41 | CHANNELS 3 Zrotation Yrotation Xrotation 42 | JOINT RightToe 43 | { 44 | OFFSET 17.299988 -0.000007 0.000004 45 | CHANNELS 3 Zrotation Yrotation Xrotation 46 | End Site 47 | { 48 | OFFSET 0.000000 0.000000 0.000000 49 | } 50 | } 51 | } 52 | } 53 | } 54 | JOINT Spine 55 | { 56 | OFFSET 6.901963 -2.603744 0.000004 57 | CHANNELS 3 Zrotation Yrotation Xrotation 58 | JOINT Spine1 59 | { 60 | OFFSET 12.588104 0.000008 -0.000010 61 | CHANNELS 3 Zrotation Yrotation Xrotation 62 | JOINT Spine2 63 | { 64 | OFFSET 12.343202 -0.000005 0.000010 65 | CHANNELS 3 Zrotation Yrotation Xrotation 66 | JOINT Neck 67 | { 68 | OFFSET 25.832897 0.000000 0.000001 69 | CHANNELS 3 Zrotation Yrotation Xrotation 70 | JOINT Head 71 | { 72 | OFFSET 11.766611 -0.000006 -0.000000 73 | CHANNELS 3 Zrotation Yrotation Xrotation 74 | End Site 75 | { 76 | OFFSET 0.000000 0.000000 0.000000 77 | } 78 | } 79 | } 80 | JOINT LeftShoulder 81 | { 82 | OFFSET 19.745899 -1.480366 6.000108 83 | CHANNELS 3 Zrotation Yrotation Xrotation 84 | JOINT LeftArm 85 | { 86 | OFFSET 11.284111 -0.000018 -0.000015 87 | CHANNELS 3 Zrotation Yrotation Xrotation 88 | JOINT LeftForeArm 89 | { 90 | OFFSET 33.000050 -0.000005 0.000028 91 | CHANNELS 3 Zrotation Yrotation Xrotation 92 | JOINT LeftHand 93 | { 94 | OFFSET 25.200012 0.000000 0.000002 95 | CHANNELS 3 Zrotation Yrotation Xrotation 96 | End Site 97 | { 98 | OFFSET 0.000000 0.000000 0.000000 99 | } 100 | } 101 | } 102 | } 103 | } 104 | JOINT RightShoulder 105 | { 106 | OFFSET 19.746111 -1.480335 -6.000074 107 | CHANNELS 3 Zrotation Yrotation Xrotation 108 | JOINT RightArm 109 | { 110 | OFFSET 11.284151 0.000036 0.000001 111 | CHANNELS 3 Zrotation Yrotation Xrotation 112 | JOINT RightForeArm 113 | { 114 | OFFSET 33.000092 -0.000035 0.000022 115 | CHANNELS 3 Zrotation Yrotation Xrotation 116 | JOINT RightHand 117 | { 118 | OFFSET 25.199768 0.000178 0.000417 119 | CHANNELS 3 Zrotation Yrotation Xrotation 120 | End Site 121 | { 122 | OFFSET 0.000000 0.000000 0.000000 123 | } 124 | } 125 | } 126 | } 127 | } 128 | } 129 | } 130 | } 131 | } 132 | MOTION 133 | Frames: 1 134 | Frame Time: 0.033333 135 | 173.408295 91.952423 -518.280273 89.620408 0.652150 92.222382 175.529305 -3.825606 173.027898 -16.027724 0.907516 2.303384 76.117670 -2.318585 
-2.386197 21.454555 0.003095 0.000000 177.801203 2.282903 -176.783696 -14.571799 4.189206 6.242791 77.874716 -3.774645 2.921458 21.454560 -0.003101 -0.000015 5.477288 0.241115 0.396648 0.943347 0.474440 0.795290 0.625275 0.479302 0.792255 0.966223 6.327928 -2.427240 14.907216 -14.144141 -0.363863 -64.744039 -84.244694 -117.560273 -11.211029 3.037350 -15.035068 -8.166726 -10.412443 0.705441 12.859502 1.553254 18.358731 -82.670322 80.131137 99.953589 -3.698083 -8.076120 16.001649 -15.291380 15.341529 -1.080976 12.298188 10.353672 -7.473871 136 | -------------------------------------------------------------------------------- /parser/parser_mixamo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from parser.base import boolean_string 3 | 4 | 5 | def get_parser(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--save_dir', type=str, default='./pretrained', help='directory for all savings') 8 | parser.add_argument('--cuda_device', type=str, default='cuda:0', help='cuda device number, eg:[cuda:0]') 9 | parser.add_argument('--use_parallel', type=boolean_string, default=False) 10 | 11 | parser.add_argument('--learning_rate', type=float, default=2e-4, help='learning rate') 12 | parser.add_argument('--alpha', type=float, default=0, help='penalty of sparsity') 13 | parser.add_argument('--batch_size', type=int, default=32, help='batch_size') 14 | parser.add_argument('--upsampling', type=str, default='linear', help="'stride2' or 'nearest', 'linear'") 15 | parser.add_argument('--downsampling', type=str, default='stride2', help='stride2 or max_pooling') 16 | parser.add_argument('--batch_normalization', type=int, default=0, help='batch_norm: 1 or 0') 17 | parser.add_argument('--activation', type=str, default='LeakyReLU', help='activation: ReLU, LeakyReLU, tanh') 18 | parser.add_argument('--rotation', type=str, default='quaternion', help='representation of rotation:euler_angle, quaternion') 19 | parser.add_argument('--data_augment', type=int, default=1, help='data_augment: 1 or 0') 20 | parser.add_argument('--epoch_num', type=int, default=1001, help='epoch_num') 21 | parser.add_argument('--window_size', type=int, default=64, help='length of time axis per window') 22 | parser.add_argument('--kernel_size', type=int, default=15, help='must be odd') 23 | parser.add_argument('--base_channel_num', type=int, default=-1) 24 | parser.add_argument('--normalization', type=int, default=1) 25 | parser.add_argument('--verbose', type=int, default=0) 26 | parser.add_argument('--padding_mode', type=str, default='reflection') 27 | parser.add_argument('--dataset', type=str, default='Mixamo') 28 | parser.add_argument('--fk_world', type=int, default=0) 29 | parser.add_argument('--debug', type=int, default=0) 30 | parser.add_argument('--skeleton_info', type=str, default='additive') 31 | parser.add_argument('--ee_loss_fact', type=str, default='height') 32 | parser.add_argument('--pos_repr', type=str, default='3d') 33 | parser.add_argument('--gan_mode', type=str, default='lsgan') 34 | parser.add_argument('--pool_size', type=int, default=50) 35 | parser.add_argument('--is_train', type=int, default=1) 36 | 37 | parser.add_argument('--model', type=str, default='pan') 38 | parser.add_argument('--epoch_begin', type=int, default=0) 39 | parser.add_argument('--lambda_rec', type=float, default=1) 40 | parser.add_argument('--lambda_cycle', type=float, default=2.5) 41 | 42 | parser.add_argument('--scheduler', type=str, default='none') 43 | 
parser.add_argument('--rec_loss_mode', type=str, default='extra_global_pos') 44 | parser.add_argument('--adaptive_ee', type=int, default=0) 45 | parser.add_argument('--use_sep_ee', type=int, default=0) 46 | parser.add_argument('--eval_seq', type=int, default=0) 47 | parser.add_argument('--ee_velo', type=int, default=1) 48 | parser.add_argument('--ee_from_root', type=int, default=1) 49 | 50 | 51 | parser.add_argument('--save_iter', type=int, default=200) 52 | # transformer parsers 53 | parser.add_argument('--transformer_srcdim', type=int, default=4) 54 | parser.add_argument('--transformer_latents', type=int, default=32) 55 | parser.add_argument('--transformer_heads', type=int, default=2) 56 | parser.add_argument('--transformer_layers', type=int, default=2) 57 | parser.add_argument('--transformer_ffsize', type=int, default=256) 58 | parser.add_argument('--transformer_dropout', type=float, default=0.2) 59 | parser.add_argument('--conv_layers', type=int, default=2, help='number of conv layers') 60 | parser.add_argument('--fc_size', type=int, default=512) 61 | 62 | return parser 63 | 64 | 65 | def get_args(): 66 | parser = get_parser() 67 | return parser.parse_args() 68 | 69 | 70 | def get_std_bvh(args=None, dataset=None): 71 | if args is None and dataset is None: raise Exception('Unexpected parameter') 72 | if dataset is None: dataset = args.dataset 73 | std_bvh = './data_preprocess/Mixamo/Mixamo/std_bvhs/{}.bvh'.format(dataset) 74 | return std_bvh 75 | 76 | 77 | def try_mkdir(path): 78 | import os 79 | if not os.path.exists(path): 80 | os.system('mkdir -p {}'.format(path)) 81 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/motion_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import torch 4 | from outer_utils.Quaternions import Quaternions 5 | from parser.parser_mixamo import get_std_bvh 6 | 7 | 8 | class MotionData(Dataset): 9 | """ 10 | Clip long dataset into fixed length window for batched training 11 | each data is a 2d tensor with shape (Joint_num*3) * Time 12 | """ 13 | def __init__(self, args): 14 | super(MotionData, self).__init__() 15 | name = args.dataset 16 | file_path = './data_preprocess/Mixamo/Mixamo/{}.npy'.format(name) 17 | 18 | if args.debug: 19 | file_path = file_path[:-4] + '_debug' + file_path[-4:] 20 | 21 | print('load from file {}'.format(file_path)) 22 | self.total_frame = 0 23 | self.std_bvh = get_std_bvh(args) 24 | self.args = args 25 | self.data = [] 26 | self.motion_length = [] 27 | motions = np.load(file_path, allow_pickle=True) 28 | motions = list(motions) 29 | new_windows = self.get_windows(motions) 30 | self.data.append(new_windows) 31 | self.data = torch.cat(self.data) 32 | self.data = self.data.permute(0, 2, 1) 33 | 34 | if args.normalization == 1: 35 | self.mean = torch.mean(self.data, (0, 2), keepdim=True) 36 | self.var = torch.var(self.data, (0, 2), keepdim=True) 37 | self.var = self.var ** (1/2) 38 | idx = self.var < 1e-5 39 | self.var[idx] = 1 40 | self.data = (self.data - self.mean) / self.var 41 | else: 42 | self.mean = torch.mean(self.data, (0, 2), keepdim=True) 43 | self.mean.zero_() 44 | self.var = torch.ones_like(self.mean) 45 | 46 | train_len = self.data.shape[0] * 95 // 100 47 | self.test_set = self.data[train_len:, ...] 48 | self.data = self.data[:train_len, ...] 
49 | self.data_reverse = torch.tensor(self.data.numpy()[..., ::-1].copy()) 50 | 51 | self.reset_length_flag = 0 52 | self.virtual_length = 0 53 | print('Window count: {}, total frame (without downsampling): {}'.format(len(self), self.total_frame)) 54 | 55 | def reset_length(self, length): 56 | self.reset_length_flag = 1 57 | self.virtual_length = length 58 | 59 | def __len__(self): 60 | if self.reset_length_flag: 61 | return self.virtual_length 62 | else: 63 | return self.data.shape[0] 64 | 65 | def __getitem__(self, item): 66 | if isinstance(item, int): item %= self.data.shape[0] 67 | if self.args.data_augment == 0 or np.random.randint(0, 2) == 0: 68 | return self.data[item] 69 | else: 70 | return self.data_reverse[item] 71 | 72 | def get_windows(self, motions): 73 | new_windows = [] 74 | 75 | for motion in motions: 76 | self.total_frame += motion.shape[0] 77 | motion = self.subsample(motion) 78 | self.motion_length.append(motion.shape[0]) 79 | step_size = self.args.window_size // 2 80 | window_size = step_size * 2 81 | n_window = motion.shape[0] // step_size - 1 82 | for i in range(n_window): 83 | begin = i * step_size 84 | end = begin + window_size 85 | 86 | new = motion[begin:end, :] 87 | if self.args.rotation == 'quaternion': 88 | new = new.reshape(new.shape[0], -1, 3) 89 | rotations = new[:, :-1, :] 90 | rotations = Quaternions.from_euler(np.radians(rotations)).qs 91 | rotations = rotations.reshape(rotations.shape[0], -1) 92 | positions = new[:, -1, :] 93 | positions = np.concatenate((new, np.zeros((new.shape[0], new.shape[1], 1))), axis=2) 94 | new = np.concatenate((rotations, new[:, -1, :].reshape(new.shape[0], -1)), axis=1) 95 | 96 | new = new[np.newaxis, ...] 97 | 98 | new_window = torch.tensor(new, dtype=torch.float32) 99 | new_windows.append(new_window) 100 | 101 | return torch.cat(new_windows) 102 | 103 | def subsample(self, motion): 104 | return motion[::2, :] 105 | 106 | def denormalize(self, motion): 107 | if self.args.normalization: 108 | if self.var.device != motion.device: 109 | self.var = self.var.to(motion.device) 110 | self.mean = self.mean.to(motion.device) 111 | ans = motion * self.var + self.mean 112 | else: ans = motion 113 | return ans 114 | -------------------------------------------------------------------------------- /parser/base.py: -------------------------------------------------------------------------------- 1 | # base parser for biped-quadruped retargeting 2 | from argparse import ArgumentParser 3 | 4 | def boolean_string(s): 5 | if s not in {'False', 'True'}: 6 | raise ValueError('Not a valid boolean string') 7 | return s == 'True' 8 | 9 | 10 | def add_misc_options(parser): 11 | group = parser.add_argument_group('Miscellaneous options') 12 | group.add_argument("--save_dir", help="directory name to save models", default='./run') 13 | group.add_argument('--with_end', type=boolean_string, default=True, help='whether considering the endsites of the dog') 14 | 15 | def add_cuda_options(parser): 16 | group = parser.add_argument_group('Cuda options') 17 | group.add_argument('--device', type=str, default='cuda:3') 18 | 19 | 20 | def adding_cuda(parameters): 21 | import torch 22 | if parameters["cuda"] and torch.cuda.is_available(): 23 | parameters["device"] = torch.device("cuda") 24 | else: 25 | parameters["device"] = torch.device("cpu") 26 | 27 | 28 | def add_dataset_options(parser): 29 | group = parser.add_argument_group('Dataset options') 30 | group.add_argument("--humstats_path", type=str, default='./data_preprocess/Lafan1_and_dog/humstats.npz') 31 | 
group.add_argument("--dogstats_path", type=str, default='./data_preprocess/Lafan1_and_dog/dogstats.npz') 32 | group.add_argument("--dog_train_path", type=str, default='./data_preprocess/Lafan1_and_dog/dogtrain.npz') 33 | group.add_argument("--hum_train_path", type=str, default='./data_preprocess/Lafan1_and_dog/humtrain.npz') 34 | group.add_argument("--hum_test_path", type=str, default='./data_preprocess/Lafan1_and_dog/humtest.npz') 35 | group.add_argument("--dog_test_path", type=str, default='./data_preprocess/Lafan1_and_dog/dogtest.npz') 36 | group.add_argument("--time_size", type=int, default=64) 37 | 38 | 39 | def add_losses_options(parser): 40 | group = parser.add_argument_group('Losses options') 41 | 42 | group.add_argument("--rec_loss_type", type=str, 43 | choices=["mse_rec", "quat_rec", "norm_rec"], 44 | default='quat_rec') 45 | group.add_argument("--root_loss_type", type=str, choices=["mse_root"], default='mse_root') 46 | group.add_argument("--global_kine_loss_type", type=str, 47 | choices=["mse_kine", "l1_kine", "part_kine"], default="part_kine") 48 | group.add_argument("--cyc_loss_type", type=str, default="mse_cycle_motion") 49 | group.add_argument("--cyc_latent_loss_type", type=str, default="mse_latent") 50 | group.add_argument("--retar_vel_loss_type", type=str, default='linear') 51 | group.add_argument("--dis_loss_type", type=str, choices=["bce_gan", "l2_gan"], default='l2_gan') 52 | group.add_argument("--retar_vel_matching", type=str, default='mapping', choices=["mapping", 'direct', 'direction']) 53 | 54 | group.add_argument('--lambda_rec', type=float, default=1) 55 | group.add_argument('--lambda_cycle', type=float, default=1e-3) 56 | group.add_argument('--lambda_retar_vel', type=float, default=1e3) 57 | 58 | 59 | def add_model_options(parser): 60 | group = parser.add_argument_group('Model options') 61 | group.add_argument("--architecture_name", type=str, default='pan') 62 | group.add_argument("--fid_net_name", type=str, default='FIDAutoEncoder') 63 | 64 | group.add_argument("--transformer", type=boolean_string, default=True) 65 | group.add_argument("--transformer_layers", type=int, default=1) 66 | group.add_argument("--transformer_latents", type=int, default=32) 67 | group.add_argument("--transformer_ffsize", type=int, default=256) 68 | group.add_argument("--transformer_heads", type=int, default=1) 69 | group.add_argument("--transformer_dropout", type=int, default=0) 70 | group.add_argument("--transformer_srcdim", type=int, default=4) 71 | 72 | group.add_argument("--conv_input", type=int, default=4) 73 | group.add_argument("--conv_layers", type=int, default=2) 74 | group.add_argument("--kernel_size", type=int, default=15) 75 | group.add_argument("--dim_per_part", type=int, default=32) 76 | group.add_argument("--padding_mode", type=str, default='reflect') 77 | 78 | group.add_argument('--upsampling', type=str, default='linear', help="'stride2' or 'nearest', 'linear'") 79 | group.add_argument("--skeleton_info", type=str, default="additive") 80 | 81 | group.add_argument("--dis", type=boolean_string, help="use_discriminator", default=True) 82 | group.add_argument("--diter", type=int, default=3) 83 | group.add_argument("--dis_mode", type=str, 84 | choices=['norm_rotation', 'denorm_rotation', 'denorm_pos', 'latent'], default='denorm_pos') 85 | group.add_argument("--dis_hidden", type=int, default=256) 86 | group.add_argument("--dis_layers", type=int, default=3) 87 | group.add_argument("--dis_kernel_size", type=int, default=15) 88 | 89 | 90 | def try_mkdir(path): 91 | import os 92 | 
if not os.path.exists(path): 93 | os.system('mkdir -p {}'.format(path)) 94 | 95 | 96 | class Dict(dict): 97 | __setattr__ = dict.__setattr__ 98 | __getattr__ = dict.__getitem__ 99 | 100 | 101 | def dict_to_object(dictObj): 102 | if not isinstance(dictObj, dict): 103 | return dictObj 104 | inst = Dict() 105 | for k, v in dictObj.items(): 106 | inst[k] = dict_to_object(v) 107 | return inst -------------------------------------------------------------------------------- /utils/metrices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | import torch.nn as nn 6 | 7 | 8 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 9 | """Numpy implementation of the Frechet Distance. 10 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 11 | and X_2 ~ N(mu_2, C_2) is 12 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 13 | Stable version by Dougal J. Sutherland. 14 | Params: 15 | -- mu1 : Numpy array containing the activations of a layer of the 16 | inception net (like returned by the function 'get_predictions') 17 | for generated samples. 18 | -- mu2 : The sample mean over activations, precalculated on an 19 | representative data set. 20 | -- sigma1: The covariance matrix over activations for generated samples. 21 | -- sigma2: The covariance matrix over activations, precalculated on an 22 | representative data set. 23 | Returns: 24 | -- : The Frechet Distance. 25 | """ 26 | 27 | mu1 = np.atleast_1d(mu1) 28 | mu2 = np.atleast_1d(mu2) 29 | 30 | sigma1 = np.atleast_2d(sigma1) 31 | sigma2 = np.atleast_2d(sigma2) 32 | 33 | assert mu1.shape == mu2.shape, \ 34 | 'Training and test mean vectors have different lengths' 35 | assert sigma1.shape == sigma2.shape, \ 36 | 'Training and test covariances have different dimensions' 37 | 38 | diff = mu1 - mu2 39 | 40 | # Product might be almost singular 41 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 42 | if not np.isfinite(covmean).all(): 43 | msg = ('fid calculation produces singular product; ' 44 | 'adding %s to diagonal of cov estimates') % eps 45 | print(msg) 46 | offset = np.eye(sigma1.shape[0]) * eps 47 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 48 | 49 | # Numerical error might give slight imaginary component 50 | if np.iscomplexobj(covmean): 51 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 52 | m = np.max(np.abs(covmean.imag)) 53 | raise ValueError('Imaginary component {}'.format(m)) 54 | covmean = covmean.real 55 | 56 | tr_covmean = np.trace(covmean) 57 | 58 | return (diff.dot(diff) + np.trace(sigma1) + 59 | np.trace(sigma2) - 2 * tr_covmean) 60 | 61 | 62 | def calculate_activations(topology_name, args, motion_loader, gan_model, fid_model, datasets, is_san=False): 63 | print('Calculating Activations...') 64 | mse = nn.MSELoss() 65 | errs = [] 66 | activations = [] 67 | topology_names = ['human', 'dog'] 68 | index = topology_names.index(topology_name) 69 | model = gan_model.models[index] 70 | fid_net = fid_model.models[index].ae 71 | 72 | if index == 0: 73 | njoints = args.hum_njoints 74 | else: 75 | njoints = args.dog_njoints 76 | with torch.no_grad(): 77 | if isinstance(motion_loader, DataLoader): 78 | for idx, batch in enumerate(motion_loader): 79 | try: 80 | input_, yrot, phases = batch 81 | except: 82 | input_, yrot, phases, _ = batch 83 | batch_input = (input_[..., :njoints * 4 + 
3]).transpose(1, 2).float().to(args.device) 84 | batch_input = datasets[index].denorm(batch_input, transpose=True) 85 | 86 | _, batch_joints = model.fk.forward(batch_input) 87 | 88 | batch_joints_norm = fid_model.normalize(batch_joints.reshape(batch_joints.shape[:2] + (-1, )), 89 | fid_model.means[index], fid_model.std[index]).transpose(1, 2) 90 | batch_joints_norm = torch.clone(batch_joints_norm).float().detach_().to(args.device) 91 | 92 | latent, rec = fid_net(batch_joints_norm) # B C 1 93 | 94 | rec_denorm = fid_model.denormalize(rec.transpose(1, 2), fid_model.means[index], fid_model.std[index]) 95 | err = mse(rec_denorm, batch_joints.reshape(batch_joints.shape[:2] + (-1, ))) 96 | errs.append(err) 97 | activations.append(latent.reshape(latent.shape[0], -1)) # B C T 98 | activations = torch.cat(activations, dim=0) 99 | print(torch.mean(torch.stack(errs, 0))) 100 | else: 101 | for idx in range(motion_loader.shape[0]): 102 | if not is_san: 103 | input_ = motion_loader[idx] 104 | batch_input = torch.clone(input_).float().detach_().to(args.device) 105 | batch_input = datasets[index].denorm(batch_input, transpose=True) 106 | _, batch_joints = model.fk.forward(batch_input) 107 | else: 108 | batch_joints = motion_loader[idx] 109 | batch_joints = torch.from_numpy(batch_joints).to(args.device).float() 110 | batch_joints = fid_model.normalize(batch_joints.reshape(batch_joints.shape[:2] + (-1,)), 111 | fid_model.means[index], fid_model.std[index]).transpose(1, 2) 112 | latent, _ = fid_net(batch_joints) # B C B,512 113 | activations.append(latent.reshape(latent.shape[0], -1)) # B C 114 | 115 | activations = torch.cat(activations, dim=0) 116 | 117 | return activations 118 | 119 | 120 | def calculate_activation_statistics(activations): 121 | activations = activations.cpu().numpy() 122 | mu = np.mean(activations, axis=0) 123 | sigma = np.cov(activations, rowvar=False) 124 | 125 | return mu, sigma 126 | 127 | def calculate_fid(statistics_1, statistics_2): 128 | return calculate_frechet_distance(statistics_1[0], statistics_1[1], 129 | statistics_2[0], statistics_2[1]) 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pan-motion-retargeting 2 | Official implementation for the paper ["Pose-aware Attention Network for Flexible Motion Retargeting by Body Part"](https://ieeexplore.ieee.org/document/10129844) 3 | 4 | Please visit our [webpage](https://hlcdyy.github.io/pan-motion-retargeting/) for more details. 5 | 6 | ![hum2dog](https://raw.githubusercontent.com/hlcdyy/pan-motion-retargeting/pan-page/static/hum2dog.gif) 7 | 8 | ## Getting started 9 | This code was tested on `Ubuntu 18.04.4 LTS` and requires: 10 | * Python 3.8 11 | * conda3 12 | * CUDA capable GPU 13 | 14 | ### 1. Creat conda environment 15 |
Click to expand 16 | 17 | We strongly recommend activating a Python virtual environment prior to installing PAN. If conda is not installed yet, download and install it first. Then run the following commands: 18 | ```` 19 | # create and activate the virtual environment 20 | conda create --name pan_retargeting python=3.8.12 21 | conda activate pan_retargeting 22 | ```` 23 | Install [PyTorch 1.10.0](https://pytorch.org/) inside the conda environment. 24 | 25 | ```` 26 | # clone pan-motion-retargeting and use pip to install 27 | git clone https://github.com/hlcdyy/pan-motion-retargeting.git 28 | cd pan-motion-retargeting 29 | pip install -e . 30 | ```` 31 |
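An optional sanity check can be run inside the activated environment; it only verifies that the interpreter sees PyTorch and whether CUDA is available, nothing project-specific:
````
# optional sanity check inside the activated pan_retargeting environment
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
````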
32 | 33 | ### 2. Download the datasets 34 |
Click to expand 35 | 36 | **Mixamo dataset** 37 | 38 | **Be sure to read and follow their license agreements, and cite accordingly.** 39 | 40 | We use the [Mixamo](https://www.mixamo.com/#/) dataset to train our model for retargeting between humanoid characters, following the train-test setting of [SAN](https://github.com/DeepMotionEditing/deep-motion-editing). The download link for the preprocessed data, as well as instructions for generating the data from scratch, can be found on the [SAN GitHub page](https://github.com/DeepMotionEditing/deep-motion-editing). 41 | 42 | The `Mixamo` directory should be placed within `data_preprocess/Mixamo`. 43 | 44 | **Lafan1 and Dog datasets** 45 | 46 | **Be sure to read and follow their license agreements, and cite accordingly.** 47 | 48 | Create this folder: 49 | ```` 50 | mkdir data_preprocess/Lafan1_and_dog/Lafan1 51 | ```` 52 | Go to the [Lafan1 website](https://github.com/ubisoft/ubisoft-laforge-animation-dataset), download lafan1.zip, unzip it, and put all the .bvh files into `data_preprocess/Lafan1_and_dog/Lafan1`. 53 | 54 | Create this folder: 55 | ```` 56 | mkdir data_preprocess/Lafan1_and_dog/DogSet 57 | ```` 58 | Go to the [AI4Animation Website](https://github.com/sebastianstarke/AI4Animation) and get the Mocap Data from "Mode-Adaptive Neural Networks for Quadruped Motion Control". Then put all the .bvh files into `data_preprocess/Lafan1_and_dog/DogSet`. 59 | 60 | **Process the Lafan1 and dog data using the following command:** 61 | 62 | ```` 63 | python data_preprocess/Lafan1_and_dog/extract.py 64 | ```` 65 | It uses the train/test split files in the folder to generate the processed .npz files and the statistics files for training and testing. 66 | 67 | You can also download our preprocessed data from [Google Drive](https://drive.google.com/file/d/1q6xjlssq3G-O-SBr-IHGVJnCCM_KrSCA/view?usp=sharing) and put all the npz files into `data_preprocess/Lafan1_and_dog/` after unzipping. 68 | 69 |
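As a quick check that preprocessing (or the downloaded archives) worked, the .npz files can be opened with NumPy. The paths below are the defaults from `parser/base.py`; the exact array keys depend on `extract.py`, so they are printed rather than assumed:
````
import numpy as np

# default locations from parser/base.py; adjust if you pass different --*_path flags
for path in ['./data_preprocess/Lafan1_and_dog/humtrain.npz',
             './data_preprocess/Lafan1_and_dog/dogtrain.npz']:
    with np.load(path, allow_pickle=True) as archive:
        print(path, archive.files)  # list the arrays stored by extract.py
````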
 70 | 71 | ### 3. Download pretrained models 72 | 73 |
Click to expand 74 | 75 | **Model for retargeting between Mixamo characters** 76 | 77 | Download the models from [here](https://drive.google.com/file/d/1jYtOLCDye68nShXNlse-I5hCe7ZAHaG-/view?usp=sharing) and unzip the file in the workspace of this project with the following command: 78 | ```` 79 | unzip pretrained_mixamo.zip 80 | ```` 81 | After unzipping, the `./pretrained_mixamo` folder should have the following structure: 82 | ``` 83 | pretrained_mixamo 84 | └-- models 85 | └-- optimizers 86 | └-- topology0 87 | └-- topology1 88 | └-- para.txt 89 | ``` 90 | 91 | **Model for retargeting between biped and quadruped** 92 | 93 | Download the models from [here](https://drive.google.com/file/d/1p-fDC9nIuqktVaqxcAr4wSa09mGq1_63/view?usp=sharing) and unzip the file with the following command: 94 | 95 | ```` 96 | unzip pretrained_lafan1dog.zip 97 | ```` 98 | The `./pretrained_lafan1dog` folder should look like this: 99 | ``` 100 | pretrained_lafan1dog 101 | └-- models 102 | └-- dog 103 | └-- human 104 | └-- optimizers 105 | └-- para.txt 106 | ``` 107 | 108 |
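Note that `para.txt` stores, on its first line, the command-line arguments used at training time; the demo and evaluation scripts read it back to rebuild their configuration (see `demo_hum2dog.py`), so keep it next to the unzipped model folders. A minimal sketch of how it is consumed, mirroring those scripts:
````
# sketch mirroring demo_hum2dog.py: recover the saved argument list from para.txt
with open('./pretrained_mixamo/para.txt') as para_file:
    saved_argv = para_file.readline().split()[1:]  # the scripts skip the first token and keep the flags
print(saved_argv)
````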
109 | 110 | ## Quick Start 111 |
Click to expand 112 | 113 | We provide demo scripts together with example motion files in BVH format. 114 | 115 | To generate the example of retargeting from a biped to a quadruped skeleton, run the following command: 116 | ```` 117 | python demo_hum2dog.py 118 | ```` 119 | The retargeting source file and the results will be saved in the folder `./pretrained_lafan1dog/demo/hum2dog`. 120 | 121 | For retargeting from quadruped to biped, run: 122 | ```` 123 | python demo_dog2hum.py 124 | ```` 125 | 126 | To generate the retargeting results between Mixamo skeletons, run: 127 | ```` 128 | python demo_mixamo.py 129 | ```` 130 | The results are stored in the folder `./pretrained_mixamo/demo`, covering both intra- and cross-structural retargeting. 131 | 132 |
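The demo scripts simply iterate over every `.bvh` file in their input directory (`./demo_dir/Lafan1` for `demo_hum2dog.py`, `./demo_dir/Dog` for `demo_dog2hum.py`, set near the top of each script). To retarget your own clip, drop it into the matching directory, assuming it follows the corresponding skeleton (compare with the templates in `data_preprocess/Lafan1_and_dog/std_bvh/`); the file name below is only a placeholder:
````
mkdir -p demo_dir/Lafan1
cp /path/to/your_clip.bvh demo_dir/Lafan1/
python demo_hum2dog.py
````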
133 | 134 | ## Training models from scratch 135 |
Click to expand 136 | 137 | **Train models using the Mixamo dataset** 138 | ```` 139 | python train_mixamo.py --save_dir ./pretrained_mixamo --batch_size 128 --model pan --learning_rate 1e-3 --cuda_device cuda --use_parallel True 140 | ```` 141 | 142 | **Train models using the Lafan1 and dog datasets** 143 | ```` 144 | python train_lafan1dog.py --save_dir ./pretrained_lafan1dog --rec_loss_type norm_rec --lambda_cycle 1e-3 --lambda_retar_vel 1e3 --device cuda:0 --batch_size 128 --with_end True 145 | ```` 146 | 147 |
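TensorBoard is among the pinned dependencies in `setup.py`. Assuming the training scripts write their event files under the chosen `--save_dir` (not verified here), progress can be monitored with:
````
tensorboard --logdir ./pretrained_mixamo
````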
148 | 149 | ## Quantitative Evaluations 150 | 151 | ```` 152 | python test_mixamo.py --save_dir ./pretrained_mixamo --model pan --epoch 1000 153 | ```` 154 | This will evaluate the model performance on Mixamo dataset by intra- and cross-structural retargeting. The generated retargeting results will be saved in `./pretrained_mixamo/results`. 155 | 156 | ## Comments 157 | * Our code for the training architecture and strategy builds on [SAN](https://github.com/DeepMotionEditing/deep-motion-editing). 158 | * The data processing code for Lafan1 and dog is based on the project of [Ubisoft La Forge Animation Dataset](https://github.com/ubisoft/ubisoft-laforge-animation-dataset) 159 | 160 | ## Citation 161 | If you find this project useful for your research, please consider citing: 162 | ```` 163 | @article{hu2023pose, 164 | title={Pose-Aware Attention Network for Flexible Motion Retargeting by Body Part}, 165 | author={Hu, Lei and Zhang, Zihao and Zhong, Chongyang and Jiang, Boyuan and Xia, Shihong}, 166 | journal={IEEE Transactions on Visualization and Computer Graphics}, 167 | year={2023}, 168 | publisher={IEEE} 169 | } 170 | ```` 171 | -------------------------------------------------------------------------------- /demo_hum2dog.py: -------------------------------------------------------------------------------- 1 | import os 2 | from parser.base import try_mkdir 3 | from utils.rotation import * 4 | from models import creat_model 5 | from data_preprocess.Lafan1_and_dog.datasetserial import HumDataset, DogDataset 6 | from parser.base import dict_to_object 7 | from utils.bvh_utils import save_bvh, Anim, read_bvh, read_bvh_with_end 8 | from data_preprocess.Lafan1_and_dog.extract import get_lafan1_example 9 | from parser.evaluation import get_parser 10 | from config import Configuration 11 | from utils.utils import get_body_part 12 | from models.IK import remove_foot_sliding_humdog 13 | 14 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | # load standard 18 | human_std = './data_preprocess/Lafan1_and_dog/std_bvh/hum_std.bvh' 19 | dog_std = './data_preprocess/Lafan1_and_dog/std_bvh/dog_std.bvh' 20 | std_dog_anim = read_bvh(dog_std) 21 | std_hum_anim = read_bvh(human_std) 22 | standard_pos = std_dog_anim.pos[0:1, ...] 23 | dog_tmp = read_bvh_with_end(dog_std) 24 | hum_tmp = read_bvh_with_end(human_std) 25 | hum_end_sites = [] 26 | dog_end_sites = [] 27 | 28 | for i in range(len(hum_tmp.bones)): 29 | if hum_tmp.bones[i] == 'End Site': 30 | hum_end_sites.append(i) 31 | 32 | for i in range(len(dog_tmp.bones)): 33 | if dog_tmp.bones[i] == 'End Site': 34 | dog_end_sites.append(i) 35 | dog_end_offsets = dog_tmp.offsets[dog_end_sites, :] 36 | hum_end_offsets = hum_tmp.offsets[hum_end_sites, :] 37 | 38 | 39 | def main(): 40 | bvh_dir = './demo_dir/Lafan1' # source motion directory for retargeting from human to dog skeletons 41 | save_dir = './pretrained_lafan1dog' # save dictory and also the used model dictory. 
42 | 43 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 44 | para_path = os.path.join(save_dir, 'para.txt') 45 | with open(para_path, 'r') as para_file: 46 | argv_ = para_file.readline().split()[1:] 47 | args = get_parser(argv_) 48 | parameters_config = {key: val for key, val in vars(Configuration).items() if val is not None} 49 | parameters_args = {key: val for key, val in vars(args).items() if val is not None} 50 | parameters_args.update(parameters_config) 51 | args = dict_to_object(parameters_args) 52 | args.device = device 53 | args.batch_size = 1 54 | try_mkdir(os.path.join(args.save_dir, 'demo')) 55 | 56 | humdataset = HumDataset(args) 57 | dogdataset = DogDataset(args) 58 | 59 | hum_parts = get_body_part(args.correspondence, 'hum_joints') 60 | dog_parts = get_body_part(args.correspondence, 'dog_joints') 61 | 62 | body_parts = [hum_parts, dog_parts] 63 | datasets = [humdataset, dogdataset] 64 | 65 | model = creat_model(args, body_parts, datasets, ['human', 'dog']) 66 | model.load(epoch=None) # specify the epoch number for testing, None is for the latest. 67 | model.setup() 68 | 69 | ori_name = [name[:-4] for name in os.listdir(bvh_dir)] 70 | files = [os.path.join(bvh_dir, name) for name in os.listdir(bvh_dir)] 71 | num = 0 72 | 73 | for file in files: 74 | print("retargeting the human motion %s to dog skeleton" % str(file)) 75 | X, Q, Pos, V, parents, yrot, offsets, offsets_withend = get_lafan1_example(file) 76 | offsets = torch.Tensor(offsets).to(device) 77 | offsets_withend = torch.Tensor(offsets_withend).to(device) 78 | 79 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 80 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 81 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 82 | 83 | args.time_size = X.shape[1] - X.shape[1] % 4 84 | 85 | yrot = yrot[:, :args.time_size, ...] 86 | Q_src = Q.copy()[:, :args.time_size, ...] 87 | V_src = V.copy() 88 | Q_src[:, :args.time_size, :1, :] = wrap(qmultipy, yrot, Q[:, :args.time_size, :1, :]) 89 | V_src = wrap(qrot, yrot, V_src[:, :args.time_size]) 90 | for i in range(1, V_src.shape[1]): 91 | V_src[:, i, ...] = V_src[:, i - 1, ...] + V_src[:, i, ...] 92 | 93 | Pos_src = Pos[:, :args.time_size, ...] 94 | Pos_src[..., 0, :] = V_src[..., 0, :] 95 | 96 | src_anim = Anim(Q_src.squeeze(), Pos_src.squeeze(), 97 | std_hum_anim.offsets, std_hum_anim.parents, std_hum_anim.bones) 98 | 99 | indices = np.where(Q[..., 0] < 0) 100 | Q[indices] = -Q[indices] 101 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 102 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 103 | RootV = V[..., :3] 104 | data = np.concatenate([Q, RootV, rvel], axis=-1) 105 | data = (data - humdataset.mean[np.newaxis, np.newaxis, ...]) / humdataset.std[np.newaxis, np.newaxis, ...] 
106 | 107 | vel_dim = 4 108 | input_h_encoder = torch.Tensor(data[..., :args.hum_njoints * 4 + vel_dim]).transpose(1, 2).to(device) 109 | input_d_encoder = torch.zeros(data.shape[:-1] + (args.dog_njoints * 4 + vel_dim,) 110 | ).transpose(1, 2).to(device) # Placeholder, meaningless 111 | 112 | input_h_encoder = input_h_encoder[..., :args.time_size] 113 | input_d_encoder = input_d_encoder[..., :args.time_size] # Placeholder, meaningless 114 | 115 | input_h_encoder = (input_h_encoder, offsets, offsets_withend) 116 | input_d_encoder = (input_d_encoder, offsets[..., :(args.dog_njoints-1)*3], offsets_withend) # Placeholder, meaningless 117 | 118 | model.set_input([input_h_encoder, input_d_encoder]) 119 | model.forward() 120 | src, retar = model.motion_denorm[0], model.fake_retar_denorm[0] 121 | 122 | retar_q = qnorm(retar[..., :-vel_dim].reshape(-1, args.dog_njoints, 4)) 123 | 124 | retar_vel = retar[..., -4:-1].squeeze() 125 | 126 | retar_q[..., :1, :] = qmultipy(torch.Tensor(yrot).to(device), retar_q[:, :1, :].unsqueeze(0)).squeeze(0) 127 | retar_vel = qrot(torch.Tensor(yrot).to(device).squeeze(), retar_vel) 128 | 129 | for i in range(1, retar_vel.shape[0]): 130 | retar_vel[i, ...] = retar_vel[i-1, ...] + retar_vel[i, ...] 131 | 132 | retar_q_np = retar_q.detach().cpu().numpy() 133 | retar_vel_np = retar_vel.detach().cpu().numpy()[:, np.newaxis] 134 | pos = standard_pos.repeat(retar_q.shape[0], axis=0) 135 | pos[:, 0:1, :] = retar_vel_np 136 | retar_anim = Anim(retar_q_np, pos, std_dog_anim.offsets, std_dog_anim.parents, std_dog_anim.bones) 137 | 138 | if not os.path.exists(os.path.join(args.save_dir, 'demo/hum2dog')): 139 | os.mkdir(os.path.join(args.save_dir, 'demo/hum2dog')) 140 | bvh_name = os.path.join(os.path.join(args.save_dir, 'demo/hum2dog'), ori_name[num]+'_retar.bvh') 141 | save_bvh(bvh_name, retar_anim, frametime=1 / 30, order='zyx', with_end=False, 142 | names=retar_anim.bones, end_offset=dog_end_offsets) 143 | 144 | remove_foot_sliding_humdog(bvh_name, bvh_name, end_site=True) 145 | 146 | bvh_src = os.path.join(os.path.join(args.save_dir, 'demo/hum2dog'), ori_name[num] + '_source.bvh') 147 | save_bvh(bvh_src, src_anim, frametime=1 / 30, order='zyx', 148 | with_end=False, names=src_anim.bones, end_offset=hum_end_offsets) 149 | 150 | num += 1 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /demo_dog2hum.py: -------------------------------------------------------------------------------- 1 | import os 2 | from parser.base import try_mkdir 3 | from utils.rotation import * 4 | from models import creat_model 5 | from data_preprocess.Lafan1_and_dog.datasetserial import HumDataset, DogDataset 6 | from parser.base import dict_to_object 7 | from utils.bvh_utils import save_bvh, Anim, read_bvh, read_bvh_with_end 8 | from data_preprocess.Lafan1_and_dog.extract import get_dog_example 9 | from parser.evaluation import get_parser 10 | from config import Configuration 11 | from utils.utils import get_body_part 12 | from models.IK import remove_foot_sliding_humdog 13 | 14 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | # load standard 18 | human_std = './data_preprocess/Lafan1_and_dog/std_bvh/hum_std.bvh' 19 | dog_std = './data_preprocess/Lafan1_and_dog/std_bvh/dog_std.bvh' 20 | std_dog_anim = read_bvh(dog_std) 21 | std_hum_anim = read_bvh(human_std) 22 | standard_pos = std_dog_anim.pos[0:1, ...] 
23 | dog_tmp = read_bvh_with_end(dog_std) 24 | hum_tmp = read_bvh_with_end(human_std) 25 | hum_end_sites = [] 26 | dog_end_sites = [] 27 | 28 | for i in range(len(hum_tmp.bones)): 29 | if hum_tmp.bones[i] == 'End Site': 30 | hum_end_sites.append(i) 31 | 32 | for i in range(len(dog_tmp.bones)): 33 | if dog_tmp.bones[i] == 'End Site': 34 | dog_end_sites.append(i) 35 | dog_end_offsets = dog_tmp.offsets[dog_end_sites, :] 36 | hum_end_offsets = hum_tmp.offsets[hum_end_sites, :] 37 | 38 | def main(): 39 | bvh_dir = './demo_dir/Dog' # source motion directory 40 | save_dir = './pretrained_lafan1dog' # save dictory and also the used model dictory. 41 | 42 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 43 | para_path = os.path.join(save_dir, 'para.txt') 44 | with open(para_path, 'r') as para_file: 45 | argv_ = para_file.readline().split()[1:] 46 | args = get_parser(argv_) 47 | parameters_config = {key: val for key, val in vars(Configuration).items() if val is not None} 48 | parameters_args = {key: val for key, val in vars(args).items() if val is not None} 49 | parameters_args.update(parameters_config) 50 | args = dict_to_object(parameters_args) 51 | args.device = device 52 | args.batch_size = 1 53 | try_mkdir(os.path.join(args.save_dir, 'demo')) 54 | 55 | humdataset = HumDataset(args) 56 | dogdataset = DogDataset(args) 57 | 58 | hum_parts = get_body_part(args.correspondence, 'hum_joints') 59 | dog_parts = get_body_part(args.correspondence, 'dog_joints') 60 | 61 | body_parts = [hum_parts, dog_parts] 62 | datasets = [humdataset, dogdataset] 63 | model = creat_model(args, body_parts, datasets, ['human', 'dog']) 64 | model.load(epoch=None) # specify the epoch number for testing, None is for the latest. 65 | model.setup() 66 | 67 | ori_name = [name[:-4] for name in os.listdir(bvh_dir) if (name[-4:] == '.bvh' )] 68 | files = [os.path.join(bvh_dir, name) for name in os.listdir(bvh_dir) 69 | if (name[-4:] == '.bvh')] 70 | num = 0 71 | 72 | for file in files: 73 | print("retargeting the dog motion %s to human skeleton" % str(file)) 74 | X, Q, Pos, V, parents, yrot, offsets, offsets_withend = get_dog_example(file) 75 | offsets = torch.Tensor(offsets).to(device) 76 | offsets_withend = torch.Tensor(offsets_withend).to(device) 77 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 78 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 79 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 80 | 81 | args.time_size = X.shape[1] - X.shape[1] % 4 82 | 83 | yrot = yrot[:, :args.time_size, ...] 84 | Q_src = Q.copy()[:, :args.time_size, ...] 85 | V_src = V.copy() 86 | Q_src[:, :args.time_size, :1, :] = wrap(qmultipy, yrot, Q[:, :args.time_size, :1, :]) 87 | V_src = wrap(qrot, yrot, V_src[:, :args.time_size]) 88 | for i in range(1, V_src.shape[1]): 89 | V_src[:, i, ...] = V_src[:, i - 1, ...] + V_src[:, i, ...] 90 | Pos_src = Pos[:, :args.time_size, ...] 91 | Pos_src[..., 0, :] = V_src[..., 0, :] 92 | 93 | src_anim = Anim(Q_src.squeeze(), Pos_src.squeeze(), 94 | std_dog_anim.offsets, std_dog_anim.parents, std_dog_anim.bones) 95 | 96 | indices = np.where(Q[..., 0] < 0) 97 | Q[indices] = -Q[indices] 98 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 99 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 100 | RootV = V[..., :3] 101 | data = np.concatenate([Q, RootV, rvel], axis=-1) 102 | data = (data - dogdataset.mean[np.newaxis, np.newaxis, ...]) / dogdataset.std[np.newaxis, np.newaxis, ...] 
103 | 104 | 105 | vel_dim = 4 106 | input_d_encoder = torch.Tensor(data[..., :args.dog_njoints * 4 + vel_dim] 107 | ).transpose(1, 2).to(device) 108 | 109 | input_h_encoder = torch.zeros(data.shape[:-1] + (args.hum_njoints * 4 + vel_dim, ) 110 | ).transpose(1, 2).to(device) # Placeholder, meaningless 111 | 112 | input_d_encoder = input_d_encoder[..., :args.time_size] 113 | input_h_encoder = input_h_encoder[..., :args.time_size] 114 | 115 | input_d_encoder = ( 116 | input_d_encoder, offsets, offsets_withend) 117 | input_h_encoder = (input_h_encoder, torch.zeros(offsets.shape[:-1] + ((args.hum_njoints-1)*3, ) 118 | ).to(device), offsets_withend) # Placeholder, meaningless 119 | 120 | 121 | model.set_input([input_h_encoder, input_d_encoder]) 122 | model.forward() 123 | 124 | src, retar = model.motion_denorm[1], model.fake_retar_denorm[1] 125 | 126 | retar_q = qnorm(retar[..., :-vel_dim].reshape(-1, args.hum_njoints, 4)) 127 | 128 | retar_vel = retar[..., -4:-1].squeeze() 129 | 130 | retar_q[..., :1, :] = qmultipy(torch.Tensor(yrot).to(device), retar_q[:, :1, :].unsqueeze(0)).squeeze(0) 131 | retar_vel = qrot(torch.Tensor(yrot).to(device).squeeze(), retar_vel) 132 | 133 | for i in range(1, retar_vel.shape[0]): 134 | retar_vel[i, ...] = retar_vel[i-1, ...] + retar_vel[i, ...] 135 | retar_q_np = retar_q.detach().cpu().numpy() 136 | retar_vel_np = retar_vel.detach().cpu().numpy()[:, np.newaxis] 137 | pos = standard_pos.repeat(retar_q.shape[0], axis=0) 138 | pos[:, 0:1, :] = retar_vel_np 139 | retar_anim = Anim(retar_q_np, pos, std_hum_anim.offsets, std_hum_anim.parents, std_hum_anim.bones) 140 | 141 | if not os.path.exists(os.path.join(args.save_dir, 'demo/dog2hum')): 142 | os.mkdir(os.path.join(args.save_dir, 'demo/dog2hum')) 143 | bvh_name = os.path.join(os.path.join(args.save_dir, 'demo/dog2hum'), ori_name[num]+'_retar.bvh') 144 | save_bvh(bvh_name, retar_anim, frametime=1 / 30, order='zyx', with_end=False, 145 | names=retar_anim.bones, end_offset=hum_end_offsets) 146 | 147 | bvh_src = os.path.join(os.path.join(args.save_dir, 'demo/dog2hum'), ori_name[num] + '_source.bvh') 148 | save_bvh(bvh_src, src_anim, frametime=1 / 30, order='zyx', 149 | with_end=False, names=src_anim.bones, end_offset=dog_end_offsets) 150 | 151 | remove_foot_sliding_humdog(bvh_name, bvh_name, 152 | end_names=['LeftToe', 'RightToe', 'LeftFoot', 'RightFoot'], 153 | end_site=False) 154 | 155 | num += 1 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import random 4 | from torch.optim import lr_scheduler 5 | 6 | mse = torch.nn.MSELoss() 7 | l1 = torch.nn.L1Loss() 8 | 9 | def dis_loss(pred, bool): 10 | device = pred.device 11 | loss = torch.nn.BCEWithLogitsLoss() 12 | if bool: 13 | gt = torch.ones_like(pred).to(device) 14 | return loss(pred, gt) 15 | else: 16 | gt = torch.zeros_like(pred).to(device) 17 | return loss(pred, gt) 18 | 19 | def dis_loss_l2(pred, bool): 20 | device = pred.device 21 | if bool: 22 | gt = torch.ones_like(pred).to(device) 23 | return mse(pred, gt) 24 | else: 25 | gt = torch.zeros_like(pred).to(device) 26 | return mse(pred, gt) 27 | 28 | 29 | def caloutputloss(pred, gt, njoints, indices=None): 30 | pred_quat = pred[..., :njoints * 4].clone().reshape(pred.shape[0], pred.shape[1], njoints, 4) 31 | gt_quat = gt[..., :njoints * 4].clone().reshape(pred.shape[0], 
pred.shape[1], njoints, 4) 32 | if indices is None: 33 | quat_mse = torch.mean(mse(pred_quat, gt_quat)) 34 | else: 35 | quat_mse = torch.mean(mse(pred_quat[..., indices, :], gt_quat[..., indices, :])) 36 | 37 | loss_total = quat_mse 38 | return loss_total 39 | 40 | 41 | def calposloss(pred, gt, indices=None): 42 | if indices is None: 43 | return mse(pred, gt) 44 | else: 45 | return mse(pred[..., indices, :], gt[..., indices, :]) 46 | 47 | 48 | def cycle_latents(gan_model, src, tgt): 49 | return mse(gan_model.retar_latents[src], gan_model.cyc_latents[tgt]) 50 | 51 | 52 | def cycle_motions(gan_model, src, tgt, indice=None): 53 | if indice == None: 54 | return mse(gan_model.gt_pos[src], gan_model.cyc_pos[tgt]) 55 | else: 56 | return mse(gan_model.gt_pos[src][..., indice, :], gan_model.cyc_pos[tgt][..., indice, :]) 57 | 58 | 59 | 60 | class GAN_loss(nn.Module): 61 | def __init__(self, gan_mode, real_lable=1.0, fake_lable=0.0): 62 | super(GAN_loss, self).__init__() 63 | self.register_buffer('real_label', torch.tensor(real_lable)) 64 | self.register_buffer('fake_label', torch.tensor(fake_lable)) 65 | self.gan_mode = gan_mode 66 | if gan_mode == 'lsgan': 67 | self.loss = nn.MSELoss() 68 | elif gan_mode == 'vanilla': 69 | self.loss = nn.BCEWithLogitsLoss() 70 | elif gan_mode == 'none': 71 | self.loss = None 72 | else: 73 | raise Exception('Unknown GAN mode') 74 | 75 | def get_target_tensor(self, prediction, target_is_real): 76 | if target_is_real: 77 | target_tensor = self.real_label 78 | else: 79 | target_tensor = self.fake_label 80 | return target_tensor.expand_as(prediction) 81 | 82 | def __call__(self, prediction, target_is_real): 83 | target_tensor = self.get_target_tensor(prediction, target_is_real) 84 | loss = self.loss(prediction, target_tensor) 85 | return loss 86 | 87 | 88 | class Criterion_EE: 89 | def __init__(self, args, base_criterion, norm_eps=0.008): 90 | self.args = args 91 | self.base_criterion = base_criterion 92 | self.norm_eps = norm_eps 93 | 94 | def __call__(self, pred, gt): 95 | reg_ee_loss = self.base_criterion(pred, gt) 96 | if self.args.ee_velo: 97 | gt_norm = torch.norm(gt, dim=-1) 98 | contact_idx = gt_norm < self.norm_eps 99 | extra_ee_loss = self.base_criterion(pred[contact_idx], gt[contact_idx]) 100 | else: 101 | extra_ee_loss = 0 102 | return reg_ee_loss + extra_ee_loss * 100 103 | 104 | def parameters(self): 105 | return [] 106 | 107 | class Criterion_EE_2: 108 | def __init__(self, args, base_criterion, norm_eps=0.008): 109 | print('Using adaptive EE') 110 | self.args = args 111 | self.base_criterion = base_criterion 112 | self.norm_eps = norm_eps 113 | self.ada_para = nn.Linear(15, 15).to(torch.device(args.cuda_device)) 114 | 115 | def __call__(self, pred, gt): 116 | pred = pred.reshape(pred.shape[:-2] + (-1,)) 117 | gt = gt.reshape(gt.shape[:-2] + (-1,)) 118 | pred = self.ada_para(pred) 119 | reg_ee_loss = self.base_criterion(pred, gt) 120 | extra_ee_loss = 0 121 | return reg_ee_loss + extra_ee_loss * 100 122 | 123 | def parameters(self): 124 | return list(self.ada_para.parameters()) 125 | 126 | class Eval_Criterion: 127 | def __init__(self, parent): 128 | self.pa = parent 129 | self.base_criterion = nn.MSELoss() 130 | pass 131 | 132 | def __call__(self, pred, gt): 133 | for i in range(1, len(self.pa)): 134 | pred[..., i, :] += pred[..., self.pa[i], :] 135 | gt[..., i, :] += pred[..., self.pa[i], :] 136 | return self.base_criterion(pred, gt) 137 | 138 | 139 | class ImagePool(): 140 | """This class implements an image buffer that stores previously generated 
images. 141 | This buffer enables us to update discriminators using a history of generated images 142 | rather than the ones produced by the latest generators. 143 | """ 144 | 145 | def __init__(self, pool_size): 146 | """Initialize the ImagePool class 147 | Parameters: 148 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 149 | """ 150 | self.pool_size = pool_size 151 | if self.pool_size > 0: # create an empty pool 152 | self.num_imgs = 0 153 | self.images = [] 154 | 155 | def query(self, images): 156 | """Return an image from the pool. 157 | Parameters: 158 | images: the latest generated images from the generator 159 | Returns images from the buffer. 160 | By 50/100, the buffer will return input images. 161 | By 50/100, the buffer will return images previously stored in the buffer, 162 | and insert the current images to the buffer. 163 | """ 164 | if self.pool_size == 0: # if the buffer size is 0, do nothing 165 | return images 166 | return_images = [] 167 | for image in images: 168 | image = torch.unsqueeze(image.data, 0) 169 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 170 | self.num_imgs = self.num_imgs + 1 171 | self.images.append(image) 172 | return_images.append(image) 173 | else: 174 | p = random.uniform(0, 1) 175 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 176 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 177 | tmp = self.images[random_id].clone() 178 | self.images[random_id] = image 179 | return_images.append(tmp) 180 | else: # by another 50% chance, the buffer will return the current image 181 | return_images.append(image) 182 | return_images = torch.cat(return_images, 0) # collect all the images and return 183 | return return_images 184 | 185 | 186 | def get_scheduler(optimizer, opt): 187 | """Return a learning rate scheduler 188 | Parameters: 189 | optimizer -- the optimizer of the network 190 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  191 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 192 | For 'linear', we keep the same learning rate for the first epochs 193 | and linearly decay the rate to zero over the next epochs. 194 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 195 | See https://pytorch.org/docs/stable/optim.html for more details. 
196 | """ 197 | if opt.lr_policy == 'linear': 198 | def lambda_rule(epoch): 199 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 200 | return lr_l 201 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 202 | elif opt.lr_policy == 'step': 203 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 204 | elif opt.lr_policy == 'plateau': 205 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 206 | elif opt.lr_policy == 'cosine': 207 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 208 | else: 209 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 210 | return scheduler 211 | 212 | 213 | def get_ee(pos, pa, ees, velo=False, from_root=False): 214 | pos = pos.clone() 215 | for i, fa in enumerate(pa): 216 | if i == 0: continue 217 | if not from_root and fa == 0: continue 218 | pos[:, :, i, :] += pos[:, :, fa, :] 219 | 220 | pos = pos[:, :, ees, :] 221 | if velo: 222 | pos = pos[:, 1:, ...] - pos[:, :-1, ...] 223 | pos = pos * 10 224 | return pos -------------------------------------------------------------------------------- /outer_utils/Animation.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import numpy as np 4 | import numpy.core.umath_tests as ut 5 | 6 | from outer_utils.Quaternions_old import Quaternions 7 | 8 | class Animation: 9 | """ 10 | Animation is a numpy-like wrapper for animation data 11 | 12 | Animation data consists of several arrays consisting 13 | of F frames and J joints. 14 | 15 | The animation is specified by 16 | 17 | rotations : (F, J) Quaternions | Joint Rotations 18 | positions : (F, J, 3) ndarray | Joint Positions 19 | 20 | The base pose is specified by 21 | 22 | orients : (J) Quaternions | Joint Orientations 23 | offsets : (J, 3) ndarray | Joint Offsets 24 | 25 | And the skeletal structure is specified by 26 | 27 | parents : (J) ndarray | Joint Parents 28 | """ 29 | 30 | def __init__(self, rotations, positions, orients, offsets, parents): 31 | 32 | self.rotations = rotations 33 | self.positions = positions 34 | self.orients = orients 35 | self.offsets = offsets 36 | self.parents = parents 37 | 38 | def __op__(self, op, other): 39 | return Animation( 40 | op(self.rotations, other.rotations), 41 | op(self.positions, other.positions), 42 | op(self.orients, other.orients), 43 | op(self.offsets, other.offsets), 44 | op(self.parents, other.parents)) 45 | 46 | def __iop__(self, op, other): 47 | self.rotations = op(self.roations, other.rotations) 48 | self.positions = op(self.roations, other.positions) 49 | self.orients = op(self.orients, other.orients) 50 | self.offsets = op(self.offsets, other.offsets) 51 | self.parents = op(self.parents, other.parents) 52 | return self 53 | 54 | def __sop__(self, op): 55 | return Animation( 56 | op(self.rotations), 57 | op(self.positions), 58 | op(self.orients), 59 | op(self.offsets), 60 | op(self.parents)) 61 | 62 | def __add__(self, other): return self.__op__(operator.add, other) 63 | def __sub__(self, other): return self.__op__(operator.sub, other) 64 | def __mul__(self, other): return self.__op__(operator.mul, other) 65 | def __div__(self, other): return self.__op__(operator.div, other) 66 | 67 | def __abs__(self): return self.__sop__(operator.abs) 68 | def __neg__(self): return self.__sop__(operator.neg) 69 | 70 | def __iadd__(self, other): 
return self.__iop__(operator.iadd, other) 71 | def __isub__(self, other): return self.__iop__(operator.isub, other) 72 | def __imul__(self, other): return self.__iop__(operator.imul, other) 73 | def __idiv__(self, other): return self.__iop__(operator.idiv, other) 74 | 75 | def __len__(self): return len(self.rotations) 76 | 77 | def __getitem__(self, k): 78 | if isinstance(k, tuple): 79 | return Animation( 80 | self.rotations[k], 81 | self.positions[k], 82 | self.orients[k[1:]], 83 | self.offsets[k[1:]], 84 | self.parents[k[1:]]) 85 | else: 86 | return Animation( 87 | self.rotations[k], 88 | self.positions[k], 89 | self.orients, 90 | self.offsets, 91 | self.parents) 92 | 93 | def __setitem__(self, k, v): 94 | if isinstance(k, tuple): 95 | self.rotations.__setitem__(k, v.rotations) 96 | self.positions.__setitem__(k, v.positions) 97 | self.orients.__setitem__(k[1:], v.orients) 98 | self.offsets.__setitem__(k[1:], v.offsets) 99 | self.parents.__setitem__(k[1:], v.parents) 100 | else: 101 | self.rotations.__setitem__(k, v.rotations) 102 | self.positions.__setitem__(k, v.positions) 103 | self.orients.__setitem__(k, v.orients) 104 | self.offsets.__setitem__(k, v.offsets) 105 | self.parents.__setitem__(k, v.parents) 106 | 107 | @property 108 | def shape(self): return (self.rotations.shape[0], self.rotations.shape[1]) 109 | 110 | def copy(self): return Animation( 111 | self.rotations.copy(), self.positions.copy(), 112 | self.orients.copy(), self.offsets.copy(), 113 | self.parents.copy()) 114 | 115 | def repeat(self, *args, **kw): 116 | return Animation( 117 | self.rotations.repeat(*args, **kw), 118 | self.positions.repeat(*args, **kw), 119 | self.orients, self.offsets, self.parents) 120 | 121 | def ravel(self): 122 | return np.hstack([ 123 | self.rotations.log().ravel(), 124 | self.positions.ravel(), 125 | self.orients.log().ravel(), 126 | self.offsets.ravel()]) 127 | 128 | @classmethod 129 | def unravel(clas, anim, shape, parents): 130 | nf, nj = shape 131 | rotations = anim[nf*nj*0:nf*nj*3] 132 | positions = anim[nf*nj*3:nf*nj*6] 133 | orients = anim[nf*nj*6+nj*0:nf*nj*6+nj*3] 134 | offsets = anim[nf*nj*6+nj*3:nf*nj*6+nj*6] 135 | return cls( 136 | Quaternions.exp(rotations), positions, 137 | Quaternions.exp(orients), offsets, 138 | parents.copy()) 139 | 140 | 141 | # local transformation matrices 142 | def transforms_local(anim): 143 | """ 144 | Computes Animation Local Transforms 145 | 146 | As well as a number of other uses this can 147 | be used to compute global joint transforms, 148 | which in turn can be used to compete global 149 | joint positions 150 | 151 | Parameters 152 | ---------- 153 | 154 | anim : Animation 155 | Input animation 156 | 157 | Returns 158 | ------- 159 | 160 | transforms : (F, J, 4, 4) ndarray 161 | 162 | For each frame F, joint local 163 | transforms for each joint J 164 | """ 165 | 166 | transforms = anim.rotations.transforms() 167 | transforms = np.concatenate([transforms, np.zeros(transforms.shape[:2] + (3, 1))], axis=-1) 168 | transforms = np.concatenate([transforms, np.zeros(transforms.shape[:2] + (1, 4))], axis=-2) 169 | # the last column is filled with the joint positions! 
170 | transforms[:,:,0:3,3] = anim.positions 171 | transforms[:,:,3:4,3] = 1.0 172 | return transforms 173 | 174 | 175 | def transforms_multiply(t0s, t1s): 176 | """ 177 | Transforms Multiply 178 | 179 | Multiplies two arrays of animation transforms 180 | 181 | Parameters 182 | ---------- 183 | 184 | t0s, t1s : (F, J, 4, 4) ndarray 185 | Two arrays of transforms 186 | for each frame F and each 187 | joint J 188 | 189 | Returns 190 | ------- 191 | 192 | transforms : (F, J, 4, 4) ndarray 193 | Array of transforms for each 194 | frame F and joint J multiplied 195 | together 196 | """ 197 | 198 | return ut.matrix_multiply(t0s, t1s) 199 | 200 | 201 | def transforms_blank(anim): 202 | """ 203 | Blank Transforms 204 | 205 | Parameters 206 | ---------- 207 | 208 | anim : Animation 209 | Input animation 210 | 211 | Returns 212 | ------- 213 | 214 | transforms : (F, J, 4, 4) ndarray 215 | Array of identity transforms for 216 | each frame F and joint J 217 | """ 218 | 219 | ts = np.zeros(anim.shape + (4, 4)) 220 | ts[:,:,0,0] = 1.0; ts[:,:,1,1] = 1.0; 221 | ts[:,:,2,2] = 1.0; ts[:,:,3,3] = 1.0; 222 | return ts 223 | 224 | 225 | # global transformation matrices 226 | def transforms_global(anim): 227 | """ 228 | Global Animation Transforms 229 | 230 | This relies on joint ordering 231 | being incremental. That means a joint 232 | J1 must not be a ancestor of J0 if 233 | J0 appears before J1 in the joint 234 | ordering. 235 | 236 | Parameters 237 | ---------- 238 | 239 | anim : Animation 240 | Input animation 241 | 242 | Returns 243 | ------ 244 | 245 | transforms : (F, J, 4, 4) ndarray 246 | Array of global transforms for 247 | each frame F and joint J 248 | """ 249 | 250 | joints = np.arange(anim.shape[1]) 251 | parents = np.arange(anim.shape[1]) 252 | locals = transforms_local(anim) 253 | globals = transforms_blank(anim) 254 | 255 | globals[:,0] = locals[:,0] 256 | 257 | for i in range(1, anim.shape[1]): 258 | globals[:,i] = transforms_multiply(globals[:,anim.parents[i]], locals[:,i]) 259 | 260 | return globals 261 | 262 | # !!! useful! 
263 | def positions_global(anim): 264 | """ 265 | Global Joint Positions 266 | 267 | Given an animation compute the global joint 268 | positions at at every frame 269 | 270 | Parameters 271 | ---------- 272 | 273 | anim : Animation 274 | Input animation 275 | 276 | Returns 277 | ------- 278 | 279 | positions : (F, J, 3) ndarray 280 | Positions for every frame F 281 | and joint position J 282 | """ 283 | 284 | # get the last column -- corresponding to the coordinates 285 | positions = transforms_global(anim)[:,:,:,3] 286 | return positions[:,:,:3] / positions[:,:,3,np.newaxis] 287 | 288 | -------------------------------------------------------------------------------- /data_preprocess/Mixamo/combined_motion.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import copy 3 | from data_preprocess.Mixamo.motion_dataset import MotionData 4 | import os 5 | import numpy as np 6 | import torch 7 | from data_preprocess.Mixamo.bvh_parser import BVH_file 8 | from parser.parser_mixamo import get_std_bvh 9 | from data_preprocess.Mixamo import get_test_set 10 | 11 | 12 | class MixedData0(Dataset): 13 | """ 14 | Mixed data for many skeletons but one topologies 15 | """ 16 | def __init__(self, args, motions, skeleton_idx): 17 | super(MixedData0, self).__init__() 18 | 19 | self.motions = motions 20 | self.motions_reverse = torch.tensor(self.motions.numpy()[..., ::-1].copy()) 21 | self.skeleton_idx = skeleton_idx 22 | self.length = motions.shape[0] 23 | self.args = args 24 | 25 | def __len__(self): 26 | return self.length 27 | 28 | def __getitem__(self, item): 29 | if self.args.data_augment == 0 or torch.rand(1) < 0.5: 30 | return [self.motions[item], self.skeleton_idx[item]] 31 | else: 32 | return [self.motions_reverse[item], self.skeleton_idx[item]] 33 | 34 | 35 | class MixedData(Dataset): 36 | """ 37 | data_gruop_num * 2 * samples 38 | """ 39 | def __init__(self, args, datasets_groups): 40 | device = torch.device(args.cuda_device if (torch.cuda.is_available()) else 'cpu') 41 | self.final_data = [] 42 | self.length = 0 43 | self.offsets = [] 44 | self.joint_topologies = [] 45 | self.ee_ids = [] 46 | self.means = [] 47 | self.vars = [] 48 | dataset_num = 0 49 | seed = 19260817 50 | total_length = 10000000 51 | all_datas = [] 52 | for datasets in datasets_groups: 53 | offsets_group = [] 54 | means_group = [] 55 | vars_group = [] 56 | dataset_num += len(datasets) 57 | tmp = [] 58 | for i, dataset in enumerate(datasets): 59 | new_args = copy.copy(args) 60 | new_args.data_augment = 0 61 | new_args.dataset = dataset 62 | 63 | tmp.append(MotionData(new_args)) # append one character motion 64 | 65 | mean = np.load('./data_preprocess/Mixamo/Mixamo/mean_var/{}_mean.npy'.format(dataset)) 66 | var = np.load('./data_preprocess/Mixamo/Mixamo/mean_var/{}_var.npy'.format(dataset)) 67 | mean = torch.tensor(mean) 68 | mean = mean.reshape((1,) + mean.shape) 69 | var = torch.tensor(var) 70 | var = var.reshape((1,) + var.shape) 71 | 72 | means_group.append(mean) 73 | vars_group.append(var) 74 | 75 | file = BVH_file(get_std_bvh(dataset=dataset)) 76 | if i == 0: 77 | self.joint_topologies.append(file.topology) 78 | self.ee_ids.append(file.get_ee_id()) 79 | new_offset = file.offset 80 | new_offset = torch.tensor(new_offset, dtype=torch.float) 81 | new_offset = new_offset.reshape((1,) + new_offset.shape) 82 | offsets_group.append(new_offset) 83 | 84 | total_length = min(total_length, len(tmp[-1])) 85 | all_datas.append(tmp) # list contains groups of 
skeleton motions 86 | offsets_group = torch.cat(offsets_group, dim=0) 87 | offsets_group = offsets_group.to(device) 88 | means_group = torch.cat(means_group, dim=0).to(device) 89 | vars_group = torch.cat(vars_group, dim=0).to(device) 90 | self.offsets.append(offsets_group) 91 | self.means.append(means_group) 92 | self.vars.append(vars_group) 93 | 94 | for datasets in all_datas: 95 | pt = 0 96 | motions = [] 97 | skeleton_idx = [] 98 | for dataset in datasets: 99 | motions.append(dataset[:]) 100 | skeleton_idx += [pt] * len(dataset) 101 | pt += 1 102 | motions = torch.cat(motions, dim=0) 103 | if self.length != 0 and self.length != len(skeleton_idx): 104 | self.length = min(self.length, len(skeleton_idx)) 105 | else: 106 | self.length = len(skeleton_idx) 107 | self.final_data.append(MixedData0(args, motions, skeleton_idx)) 108 | 109 | def denorm(self, gid, pid, data): 110 | means = self.means[gid][pid, ...] 111 | var = self.vars[gid][pid, ...] 112 | return data * var + means 113 | 114 | def __len__(self): 115 | return self.length 116 | 117 | def __getitem__(self, item): 118 | res = [] 119 | for data in self.final_data: 120 | res.append(data[item]) #The biger datasets_group cant get whole items 121 | return res 122 | 123 | 124 | class TestData(Dataset): 125 | def __init__(self, args, characters): 126 | self.characters = characters 127 | self.file_list = get_test_set() 128 | self.mean = [] 129 | self.joint_topologies = [] 130 | self.inverse_simplify_maps = [] 131 | self.simplified_names = [] 132 | self.var = [] 133 | self.offsets = [] 134 | self.ee_ids = [] 135 | self.args = args 136 | self.device = torch.device(args.cuda_device) 137 | 138 | for i, character_group in enumerate(characters): 139 | mean_group = [] 140 | var_group = [] 141 | offsets_group = [] 142 | for j, character in enumerate(character_group): 143 | file = BVH_file(get_std_bvh(dataset=character)) 144 | # print(file.skeleton_type, character) 145 | if j == 0: 146 | self.joint_topologies.append(file.topology) 147 | self.ee_ids.append(file.get_ee_id()) 148 | self.inverse_simplify_maps.append(file.inverse_simplify_map) 149 | self.simplified_names.append(file.simplified_name) 150 | new_offset = file.offset 151 | new_offset = torch.tensor(new_offset, dtype=torch.float) 152 | new_offset = new_offset.reshape((1,) + new_offset.shape) 153 | 154 | offsets_group.append(new_offset) 155 | mean = np.load('./data_preprocess/Mixamo/Mixamo/mean_var/{}_mean.npy'.format(character)) 156 | var = np.load('./data_preprocess/Mixamo/Mixamo/mean_var/{}_var.npy'.format(character)) 157 | mean = torch.tensor(mean) 158 | mean = mean.reshape((1, ) + mean.shape) 159 | var = torch.tensor(var) 160 | var = var.reshape((1, ) + var.shape) 161 | if len(mean.shape) > 3: 162 | mean = mean.squeeze(1) 163 | var = var.squeeze(1) 164 | mean_group.append(mean) 165 | var_group.append(var) 166 | 167 | mean_group = torch.cat(mean_group, dim=0).to(self.device) 168 | var_group = torch.cat(var_group, dim=0).to(self.device) 169 | offsets_group = torch.cat(offsets_group, dim=0).to(self.device) 170 | self.mean.append(mean_group) 171 | self.var.append(var_group) 172 | self.offsets.append(offsets_group) 173 | 174 | def __getitem__(self, item): 175 | res = [] 176 | bad_flag = 0 177 | for i, character_group in enumerate(self.characters): 178 | res_group = [] 179 | ref_shape = None 180 | for j in range(len(character_group)): 181 | new_motion = self.get_item(i, j, item) 182 | if new_motion is not None: 183 | new_motion = new_motion.reshape((1, ) + new_motion.shape) 184 | new_motion = 
(new_motion - self.mean[i][j]) / self.var[i][j] 185 | ref_shape = new_motion 186 | res_group.append(new_motion) 187 | 188 | if ref_shape is None: 189 | print('Bad at {}'.format(item)) 190 | return None 191 | for j in range(len(character_group)): 192 | if res_group[j] is None: 193 | bad_flag = 1 194 | res_group[j] = torch.zeros_like(ref_shape) 195 | if bad_flag: 196 | print('Bad at {}'.format(item)) 197 | 198 | res_group = torch.cat(res_group, dim=0) 199 | res.append([res_group, list(range(len(character_group)))]) 200 | return res 201 | 202 | def __len__(self): 203 | return len(self.file_list) 204 | 205 | def get_item(self, gid, pid, id): 206 | character = self.characters[gid][pid] 207 | path = './data_preprocess/Mixamo/Mixamo/{}/'.format(character) 208 | if isinstance(id, int): 209 | file = path + self.file_list[id] 210 | elif isinstance(id, str): 211 | file = id 212 | else: 213 | raise Exception('Wrong input file type') 214 | if not os.path.exists(file): 215 | raise Exception('Cannot find file') 216 | file = BVH_file(file, args=self.args) 217 | motion = file.to_tensor(quater=self.args.rotation == 'quaternion') 218 | motion = motion[:, ::2] 219 | length = motion.shape[-1] 220 | length = length // 4 * 4 221 | return motion[..., :length].to(self.device) 222 | 223 | def denorm(self, gid, pid, data): 224 | means = self.mean[gid][pid, ...] 225 | var = self.var[gid][pid, ...] 226 | return data * var + means 227 | 228 | def normalize(self, gid, pid, data): 229 | means = self.mean[gid][pid, ...] 230 | var = self.var[gid][pid, ...] 231 | return (data - means) / var 232 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils import rotation as rt 4 | 5 | 6 | def get_lpos(skel, t_size, njoints, device): 7 | b_size = skel.shape[0] 8 | lpos = torch.cat([torch.zeros(b_size, t_size, 3).to(device), 9 | skel.unsqueeze(1).repeat(1, t_size, 1)], 10 | dim=-1).reshape(b_size, t_size, njoints, -1) # B T J 3 11 | 12 | return lpos 13 | 14 | def get_body_part(correspondence, topology_name): 15 | part_list = [] 16 | for dic in correspondence: 17 | part_list.append(dic[topology_name]) 18 | return part_list 19 | 20 | 21 | def get_part_matrix(part_list, njoints): 22 | matrix = torch.zeros(len(part_list), njoints) 23 | for i, part in enumerate(part_list): 24 | matrix[i, part] = 1 25 | matrix[:, -1] = 1 26 | matrix[:, 0] = 1 27 | return matrix 28 | 29 | 30 | def get_offset_part_matrix(part_list, num_offsets): 31 | matrix = torch.zeros(len(part_list), num_offsets+1) 32 | for i, part in enumerate(part_list): 33 | matrix[i, part] = 1 34 | return matrix[:, 1:] 35 | 36 | 37 | def get_transformer_matrix(part_list, njoints): 38 | """ 39 | :param part_list: body part list [[0 ,1 , 2], [1]] n * 4 40 | :param njoints: body joints' number plus root velocity 41 | :return: 42 | """ 43 | nparts = len(part_list) 44 | matrix = torch.zeros([nparts + njoints, njoints]) 45 | 46 | for i in range(nparts): 47 | matrix[i, part_list[i]] = 1 48 | for j in part_list[i]: 49 | for k in part_list[i]: 50 | matrix[j + nparts, k] = 1 51 | matrix[:, 0] = 1 52 | matrix[:, -1] = 1 53 | 54 | matrix = torch.cat((torch.zeros([njoints + nparts, nparts]), matrix), dim=1) 55 | for p in range(nparts + njoints): 56 | matrix[p, p] = 1 57 | 58 | matrix = matrix.float().masked_fill(matrix == 0., float(-1e20)).masked_fill(matrix == 1., float(0.0)) 59 | return matrix 60 | 61 | 62 | def 
quat2motion(input, lpos, parents, jnum): 63 | b_size, t_size = input.shape[0], input.shape[1] 64 | input_quat = input[..., :jnum * 4].reshape(b_size, t_size, jnum, 4) 65 | input_vel = input[..., jnum * 4: jnum * 4 + 3].unsqueeze(2) 66 | _, local_joints = rt.quat_fk(input_quat, lpos, parents) 67 | return torch.cat([local_joints, input_vel], dim=-2) 68 | 69 | 70 | def static2motion(local_joints): 71 | njoints = local_joints.shape[2] - 1 72 | global_motion = local_joints[..., :njoints, :].clone() 73 | for i in range(local_joints.shape[1]): 74 | if i == 0: 75 | translation = local_joints[:, 0, njoints:, :] 76 | else: 77 | translation = local_joints[:, i, njoints:, :] + translation 78 | global_motion[:, i, ...] = global_motion[:, i, ...] + translation 79 | return global_motion 80 | 81 | 82 | def forwardkinematics(input, lpos, parents, jnum): 83 | local_pos = quat2motion(input, lpos, parents, jnum) 84 | global_pos = static2motion(local_pos) 85 | return global_pos, local_pos[..., :-1, :] 86 | 87 | 88 | class ForwardKinematics: 89 | def __init__(self, parents, jnum, site_index=None): 90 | self.parents = parents 91 | self.jnum = jnum 92 | self.site_index = site_index 93 | 94 | def forward(self, input, lpos): 95 | if self.site_index is not None: 96 | input_new0 = torch.zeros([input.shape[0], input.shape[1], self.jnum, 4]).to(input.device) 97 | input_new0[..., 0] = 1 98 | input_new0[..., self.site_index, :] = input[..., :len(self.site_index) * 4].\ 99 | reshape(input.shape[0], input.shape[1], len(self.site_index), 4) 100 | input_new1 = input[..., len(self.site_index) * 4: len(self.site_index) * 4 + 3] 101 | input_new = torch.cat((input_new0.reshape(input.shape[0], input.shape[1], -1), input_new1), dim=-1) 102 | global_pose, local_pose = forwardkinematics(input_new, lpos, self.parents, self.jnum) 103 | else: 104 | global_pose, local_pose = forwardkinematics(input, lpos, self.parents, self.jnum) 105 | return global_pose, local_pose 106 | 107 | 108 | def findedgechains(edges): 109 | degree = [0] * 100 110 | seq_list = [] 111 | 112 | for edge in edges: 113 | degree[edge[0]] += 1 114 | degree[edge[1]] += 1 115 | 116 | def find_seq(j, seq): 117 | nonlocal degree, edges, seq_list 118 | 119 | if degree[j] > 2 and j != 0: 120 | seq_list.append(seq) 121 | seq = [] 122 | 123 | if degree[j] == 1: 124 | seq_list.append(seq) 125 | return 126 | 127 | for idx, edge in enumerate(edges): 128 | if edge[0] == j: 129 | find_seq(edge[1], seq + [idx]) 130 | 131 | find_seq(0, []) 132 | return seq_list 133 | 134 | 135 | def findbodychain(edge_seq, edges): 136 | joint_seq = [] 137 | for seq in edge_seq: 138 | joint_chain = [] 139 | for i, edge in enumerate(seq): 140 | joint_chain.append(edges[edge][0]) 141 | if i == len(seq)-1: 142 | joint_chain.append(edges[edge][1]) 143 | joint_seq.append(joint_chain) 144 | return joint_seq 145 | 146 | 147 | def getbodyparts(edges): 148 | edge_seq = findedgechains(edges) 149 | joint_seq = findbodychain(edge_seq, edges) 150 | return joint_seq 151 | 152 | 153 | def calselfmask(part_list, njoints, edges=None, is_conv=False, 154 | ): 155 | part_list = part_list.copy() 156 | nparts = len(part_list) 157 | 158 | matrix = torch.zeros([njoints + nparts, njoints]) 159 | n = 0 160 | 161 | if edges is not None: 162 | rotation_map = [] 163 | for i, edge in enumerate(edges): 164 | rotation_map.append(edge[1]) 165 | rotation_map_reverse = [] 166 | for i in range(1, njoints): 167 | rotation_map_reverse.append(rotation_map.index(i)) 168 | 169 | for part in part_list: 170 | if part[0] == 0: 171 | part.pop(0) 
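        # Editor's note (added comment): the loop below converts joint indices into the
        # column indices used by the attention mask.  When `edges` is supplied (the
        # edge-based Mixamo representation), each joint index is mapped to the edge whose
        # child is that joint via `rotation_map_reverse`; otherwise the root joint was
        # dropped above, so the remaining joint indices are simply shifted down by one.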
172 | for i in range(len(part)): 173 | if edges is not None: 174 | part[i] = rotation_map_reverse[part[i]-1] 175 | else: 176 | part[i] -= 1 177 | 178 | for part in part_list: 179 | matrix[n, part] = 1 180 | for k in part: 181 | matrix[k + nparts, part] = 1 182 | n += 1 183 | 184 | matrix = torch.cat((torch.zeros([njoints+nparts, nparts]), matrix), dim=1) 185 | for p in range(nparts + njoints): 186 | matrix[p, p] = 1 187 | 188 | matrix[:, -1] = 1 189 | if not is_conv: 190 | matrix = matrix.float().masked_fill(matrix == 0., float(-1e20)).masked_fill(matrix == 1., float(0.0)) 191 | else: 192 | matrix = matrix[:nparts, nparts:] 193 | return matrix 194 | 195 | 196 | def q_mul_q(a, b): 197 | # sqs, oqs = q_broadcast(a, b) 198 | sqs, oqs = torch.broadcast_tensors(a, b) 199 | if sqs.shape[-1] != 4: 200 | sqs = sqs.reshape(sqs.shape[:-1] + (-1, 4)) 201 | oqs = oqs.reshape(oqs.shape[:-1] + (-1, 4)) 202 | q0 = sqs[..., 0:1] 203 | q1 = sqs[..., 1:2] 204 | q2 = sqs[..., 2:3] 205 | q3 = sqs[..., 3:4] 206 | r0 = oqs[..., 0:1] 207 | r1 = oqs[..., 1:2] 208 | r2 = oqs[..., 2:3] 209 | r3 = oqs[..., 3:4] 210 | 211 | qs0 = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 212 | qs1 = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 213 | qs2 = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 214 | qs3 = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 215 | 216 | # return tf.concat([qs0, qs1, qs2, qs3], axis=-1) 217 | return torch.cat((qs0, qs1, qs2, qs3), -1).reshape(qs0.shape[:-2] + (-1,)) 218 | 219 | def qnormalize(quat): 220 | quat = quat.transpose(1, 2) 221 | b_s, t_s = quat.shape[0], quat.shape[1] 222 | quat = quat.reshape(b_s, t_s, -1, 4) 223 | quat = quat/torch.norm(quat).unsqueeze(-1) 224 | quat = quat.reshape(b_s, t_s, -1).transpose(1, 2) 225 | return quat 226 | 227 | def build_edge_topology(topology, offset): 228 | # get all edges (pa, child, offset) 229 | edges = [] 230 | joint_num = len(topology) 231 | for i in range(1, joint_num): 232 | edges.append((topology[i], i, offset[i])) 233 | return edges 234 | 235 | def build_joint_topology(edges, origin_names): 236 | parent = [] 237 | offset = [] 238 | names = [] 239 | edge2joint = [] 240 | joint_from_edge = [] # -1 means virtual joint 241 | joint_cnt = 0 242 | out_degree = [0] * (len(edges) + 10) 243 | for edge in edges: 244 | out_degree[edge[0]] += 1 245 | 246 | # add root joint 247 | joint_from_edge.append(-1) 248 | parent.append(0) 249 | offset.append(np.array([0, 0, 0])) 250 | names.append(origin_names[0]) 251 | joint_cnt += 1 252 | 253 | def make_topology(edge_idx, pa): 254 | nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt 255 | edge = edges[edge_idx] 256 | if out_degree[edge[0]] > 1: 257 | parent.append(pa) 258 | offset.append(np.array([0, 0, 0])) 259 | names.append(origin_names[edge[1]] + '_virtual') 260 | edge2joint.append(-1) 261 | pa = joint_cnt 262 | joint_cnt += 1 263 | 264 | parent.append(pa) 265 | offset.append(edge[2]) 266 | names.append(origin_names[edge[1]]) 267 | edge2joint.append(edge_idx) 268 | pa = joint_cnt 269 | joint_cnt += 1 270 | 271 | for idx, e in enumerate(edges): 272 | if e[0] == edge[1]: 273 | make_topology(idx, pa) 274 | 275 | for idx, e in enumerate(edges): 276 | if e[0] == 0: 277 | make_topology(idx, 0) 278 | 279 | return parent, offset, names, edge2joint 280 | -------------------------------------------------------------------------------- /models/Intergrated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.networks import MotionAE, LatentDiscriminator, 
SkeletonEncoder 3 | import os 4 | from utils.utils import ForwardKinematics, build_edge_topology 5 | 6 | # package for retargeting between characters in Mixamo 7 | from models.Kinematics import ForwardKinematics as ForwardKinematics_mixamo 8 | from data_preprocess.Mixamo.bvh_parser import BVH_file 9 | from parser.parser_mixamo import get_std_bvh 10 | import torch.nn as nn 11 | from collections import OrderedDict 12 | 13 | 14 | class IntegratedModel: 15 | def __init__(self, args, body_parts, njoints, parents, n_topology, device, **kwargs): 16 | self.args = args 17 | self.body_parts = body_parts 18 | self.part_num = len(self.body_parts) 19 | self.njoints = njoints 20 | self.indices = [0] 21 | 22 | for part in body_parts: 23 | self.indices += part 24 | if args.with_end: 25 | self.indices_withend = [] 26 | 27 | for idx in self.indices: 28 | self.indices_withend.append(kwargs["not_end"][idx]) 29 | self.indices_withend.extend(kwargs["part_end"]) 30 | 31 | else: 32 | self.indices_withend = self.indices 33 | 34 | if args.with_end: 35 | self.fk = ForwardKinematics(kwargs["parents_withend"], 36 | kwargs["njoints_withend"], 37 | site_index=kwargs["not_end"]) 38 | else: 39 | self.fk = ForwardKinematics(parents, njoints) 40 | 41 | self.ae = MotionAE(args, body_parts, njoints).to(device) 42 | self.skel_enc = SkeletonEncoder(args, body_parts, njoints).to(device) 43 | 44 | if self.args.dis: 45 | if self.args.dis_mode == 'norm_rotation' or self.args.dis_mode == 'denorm_rotation': 46 | input_dim = self.args.conv_input * self.njoints + 3 47 | hidden_dim = self.args.dis_hidden 48 | self.discriminator = \ 49 | LatentDiscriminator(args.dis_layers, args.dis_kernel_size, 50 | input_dim, hidden_dim).to(device) 51 | elif self.args.dis_mode == 'denorm_pos': 52 | if args.with_end: 53 | input_dim = 3 * kwargs["njoints_withend"] 54 | else: 55 | input_dim = 3 * self.njoints 56 | hidden_dim = self.args.dis_hidden 57 | self.discriminator = \ 58 | LatentDiscriminator(args.dis_layers, args.dis_kernel_size, 59 | input_dim, hidden_dim).to(device) 60 | 61 | 62 | def parameters(self): 63 | return self.G_parameters() + self.D_parameters() 64 | 65 | def G_parameters(self): 66 | parameters = list(self.ae.parameters()) + list(self.skel_enc.parameters()) 67 | return parameters 68 | 69 | def D_parameters(self): 70 | return list(self.discriminator.parameters()) 71 | 72 | def save(self, path, epoch): 73 | from parser.base import try_mkdir 74 | 75 | path = os.path.join(path, str(epoch)) 76 | try_mkdir(path) 77 | 78 | torch.save(self.ae.state_dict(), os.path.join(path, 'ae.pth')) 79 | torch.save(self.skel_enc.state_dict(), os.path.join(path, 'skel_enc.pth')) 80 | 81 | if self.args.dis: 82 | torch.save(self.discriminator.state_dict(), os.path.join(path, 'discriminator.pth')) 83 | 84 | print('Save at {} succeed!'.format(path)) 85 | 86 | def load(self, path, epoch=None): 87 | print('loading from', path) 88 | if not os.path.exists(path): 89 | raise Exception('Unknown loading path') 90 | 91 | if epoch is None: 92 | all = [int(q) for q in os.listdir(path) if os.path.isdir(os.path.join(path, q))] 93 | if len(all) == 0: 94 | raise Exception('Empty loading path') 95 | epoch = sorted(all)[-1] 96 | 97 | path = os.path.join(path, str(epoch)) 98 | print('loading from epoch {}......'.format(epoch)) 99 | 100 | self.ae.load_state_dict(torch.load(os.path.join(path, 'ae.pth') 101 | )) 102 | self.skel_enc.load_state_dict(torch.load(os.path.join(path, 'skel_enc.pth') 103 | )) 104 | 105 | if os.path.exists(os.path.join(path, 'discriminator.pth')): 106 | 
self.discriminator.load_state_dict(torch.load(os.path.join(path, 'discriminator.pth'))) 107 | print('load succeed!') 108 | 109 | def train(self): 110 | self.ae = self.ae.train() 111 | self.skel_enc = self.skel_enc.train() 112 | if self.args.dis: 113 | self.discriminator = self.discriminator.train() 114 | 115 | def eval(self): 116 | self.ae = self.ae.eval() 117 | self.skel_enc = self.skel_enc.eval() 118 | if self.args.dis: 119 | self.discriminator = self.discriminator.eval() 120 | 121 | 122 | class IntegratedModel_Mixamo: 123 | def __init__(self, args, joint_topology, device, characters): 124 | self.args = args 125 | self.joint_topology = joint_topology 126 | self.edges = build_edge_topology(joint_topology, torch.zeros((len(joint_topology), 3))) 127 | self.fk = ForwardKinematics_mixamo(args, self.edges) 128 | 129 | self.height = [] 130 | self.real_height = [] 131 | for char in characters: 132 | if args.use_sep_ee: 133 | h = BVH_file(get_std_bvh(dataset=char)).get_ee_length() 134 | else: 135 | h = BVH_file(get_std_bvh(dataset=char)).get_height() 136 | if args.ee_loss_fact == 'learn': 137 | h = torch.tensor(h, dtype=torch.float) 138 | else: 139 | h = torch.tensor(h, dtype=torch.float, requires_grad=False) 140 | self.real_height.append(BVH_file(get_std_bvh(dataset=char)).get_height()) 141 | self.height.append(h.unsqueeze(0)) 142 | self.real_height = torch.tensor(self.real_height, device=device) 143 | self.height = torch.cat(self.height, dim=0) 144 | self.height = self.height.to(device) 145 | if not args.use_sep_ee: self.height.unsqueeze_(-1) 146 | if args.ee_loss_fact == 'learn': self.height_para = [self.height] 147 | else: self.height_para = [] 148 | 149 | if args.model == "pan": 150 | self.auto_encoder = MotionAE(args, self.edges, None).to(device) 151 | self.static_encoder = SkeletonEncoder(args, self.edges, None).to(device) 152 | self.discriminator = LatentDiscriminator(3, 15, (len(self.edges) + 1)*3, 256, is_lafan1=False) 153 | 154 | 155 | def parameters(self): 156 | return self.G_parameters() + self.D_parameters() 157 | 158 | def G_parameters(self): 159 | return list(self.auto_encoder.parameters()) + list(self.static_encoder.parameters()) + self.height_para 160 | 161 | def D_parameters(self): 162 | return list(self.discriminator.parameters()) 163 | 164 | def train(self): 165 | self.auto_encoder.train() 166 | self.discriminator.train() 167 | self.static_encoder.train() 168 | 169 | def eval(self): 170 | self.auto_encoder.eval() 171 | self.discriminator.eval() 172 | self.static_encoder.eval() 173 | 174 | def save(self, path, epoch): 175 | from parser.parser_mixamo import try_mkdir 176 | 177 | path = os.path.join(path, str(epoch)) 178 | try_mkdir(path) 179 | 180 | torch.save(self.height, os.path.join(path, 'height.pt')) 181 | torch.save(self.auto_encoder.state_dict(), os.path.join(path, 'auto_encoder.pt')) 182 | torch.save(self.discriminator.state_dict(), os.path.join(path, 'discriminator.pt')) 183 | torch.save(self.static_encoder.state_dict(), os.path.join(path, 'static_encoder.pt')) 184 | 185 | print('Save at {} succeed!'.format(path)) 186 | 187 | def load(self, path, epoch=None): 188 | print('loading from', path) 189 | if not os.path.exists(path): 190 | raise Exception('Unknown loading path') 191 | 192 | if epoch is None: 193 | all = [int(q) for q in os.listdir(path) if os.path.isdir(os.path.join(path, q))] 194 | if len(all) == 0: 195 | raise Exception('Empty loading path') 196 | epoch = sorted(all)[-1] 197 | 198 | path = os.path.join(path, str(epoch)) 199 | print('loading from epoch 
{}......'.format(epoch)) 200 | if self.args.use_parallel: 201 | self.load_network(self.auto_encoder, os.path.join(path, 'auto_encoder.pt')) 202 | self.load_network(self.static_encoder, os.path.join(path, 'static_encoder.pt')) 203 | else: 204 | self.auto_encoder.load_state_dict(torch.load(os.path.join(path, 'auto_encoder.pt'), 205 | map_location=self.args.cuda_device)) 206 | self.static_encoder.load_state_dict(torch.load(os.path.join(path, 'static_encoder.pt'), 207 | map_location=self.args.cuda_device)) 208 | 209 | print('load succeed!') 210 | 211 | def load_network(self, network, save_path): 212 | state_dict = torch.load(save_path) 213 | # create new OrderedDict that does not contain `module.` 214 | new_state_dict = OrderedDict() 215 | for k, v in state_dict.items(): 216 | namekey = k[7:] # remove `module.` 217 | new_state_dict[namekey] = v 218 | # load params 219 | network.load_state_dict(new_state_dict) 220 | return network 221 | 222 | def DataParallel(self): 223 | self.static_encoder = nn.DataParallel(self.static_encoder) 224 | self.auto_encoder = nn.DataParallel(self.auto_encoder) 225 | self.discriminator = nn.DataParallel(self.discriminator) 226 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage.filters as filters 3 | 4 | 5 | def length(x, axis=-1, keepdims=True): 6 | """ 7 | Computes vector norm along a tensor axis(axes) 8 | 9 | :param x: tensor 10 | :param axis: axis(axes) along which to compute the norm 11 | :param keepdims: indicates if the dimension(s) on axis should be kept 12 | :return: The length or vector of lengths. 13 | """ 14 | lgth = np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims)) 15 | return lgth 16 | 17 | 18 | def normalize(x, axis=-1, eps=1e-8): 19 | """ 20 | Normalizes a tensor over some axis (axes) 21 | 22 | :param x: data tensor 23 | :param axis: axis(axes) along which to compute the norm 24 | :param eps: epsilon to prevent numerical instabilities 25 | :return: The normalized tensor 26 | """ 27 | res = x / (length(x, axis=axis) + eps) 28 | return res 29 | 30 | 31 | def quat_normalize(x, eps=1e-8): 32 | """ 33 | Normalizes a quaternion tensor 34 | 35 | :param x: data tensor 36 | :param eps: epsilon to prevent numerical instabilities 37 | :return: The normalized quaternions tensor 38 | """ 39 | res = normalize(x, eps=eps) 40 | return res 41 | 42 | 43 | def angle_axis_to_quat(angle, axis): 44 | """ 45 | Converts from and angle-axis representation to a quaternion representation 46 | 47 | :param angle: angles tensor 48 | :param axis: axis tensor 49 | :return: quaternion tensor 50 | """ 51 | c = np.cos(angle / 2.0)[..., np.newaxis] 52 | s = np.sin(angle / 2.0)[..., np.newaxis] 53 | q = np.concatenate([c, s * axis], axis=-1) 54 | return q 55 | 56 | 57 | def euler_to_quat(e, order='zyx'): 58 | """ 59 | 60 | Converts from an euler representation to a quaternion representation 61 | 62 | :param e: euler tensor 63 | :param order: order of euler rotations 64 | :return: quaternion tensor 65 | """ 66 | axis = { 67 | 'x': np.asarray([1, 0, 0], dtype=np.float32), 68 | 'y': np.asarray([0, 1, 0], dtype=np.float32), 69 | 'z': np.asarray([0, 0, 1], dtype=np.float32)} 70 | 71 | q0 = angle_axis_to_quat(e[..., 0], axis[order[0]]) 72 | q1 = angle_axis_to_quat(e[..., 1], axis[order[1]]) 73 | q2 = angle_axis_to_quat(e[..., 2], axis[order[2]]) 74 | 75 | return quat_mul(q0, quat_mul(q1, q2)) 76 | 77 | 78 | def 
quat_inv(q): 79 | """ 80 | Inverts a tensor of quaternions 81 | 82 | :param q: quaternion tensor 83 | :return: tensor of inverted quaternions 84 | """ 85 | res = np.asarray([1, -1, -1, -1], dtype=np.float32) * q 86 | return res 87 | 88 | 89 | def quat_fk(lrot, lpos, parents): 90 | """ 91 | Performs Forward Kinematics (FK) on local quaternions and local positions to retrieve global representations 92 | 93 | :param lrot: tensor of local quaternions with shape (..., Nb of joints, 4) 94 | :param lpos: tensor of local positions with shape (..., Nb of joints, 3) 95 | :param parents: list of parents indices 96 | :return: tuple of tensors of global quaternion, global positions 97 | """ 98 | gp, gr = [lpos[..., :1, :]], [lrot[..., :1, :]] 99 | for i in range(1, len(parents)): 100 | gp.append(quat_mul_vec(gr[parents[i]], lpos[..., i:i+1, :]) + gp[parents[i]]) 101 | gr.append(quat_mul (gr[parents[i]], lrot[..., i:i+1, :])) 102 | 103 | res = np.concatenate(gr, axis=-2), np.concatenate(gp, axis=-2) 104 | return res 105 | 106 | 107 | def quat_ik(grot, gpos, parents): 108 | """ 109 | Performs Inverse Kinematics (IK) on global quaternions and global positions to retrieve local representations 110 | 111 | :param grot: tensor of global quaternions with shape (..., Nb of joints, 4) 112 | :param gpos: tensor of global positions with shape (..., Nb of joints, 3) 113 | :param parents: list of parents indices 114 | :return: tuple of tensors of local quaternion, local positions 115 | """ 116 | res = [ 117 | np.concatenate([ 118 | grot[..., :1, :], 119 | quat_mul(quat_inv(grot[..., parents[1:], :]), grot[..., 1:, :]), 120 | ], axis=-2), 121 | np.concatenate([ 122 | gpos[..., :1, :], 123 | quat_mul_vec( 124 | quat_inv(grot[..., parents[1:], :]), 125 | gpos[..., 1:, :] - gpos[..., parents[1:], :]), 126 | ], axis=-2) 127 | ] 128 | 129 | return res 130 | 131 | 132 | def quat_mul(x, y): 133 | """ 134 | Performs quaternion multiplication on arrays of quaternions 135 | 136 | :param x: tensor of quaternions of shape (..., Nb of joints, 4) 137 | :param y: tensor of quaternions of shape (..., Nb of joints, 4) 138 | :return: The resulting quaternions 139 | """ 140 | x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4] 141 | y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4] 142 | 143 | res = np.concatenate([ 144 | y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, 145 | y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, 146 | y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, 147 | y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0], axis=-1) 148 | 149 | return res 150 | 151 | 152 | def quat_mul_vec(q, x): 153 | """ 154 | Performs multiplication of an array of 3D vectors by an array of quaternions (rotation). 
155 | 156 | :param q: tensor of quaternions of shape (..., Nb of joints, 4) 157 | :param x: tensor of vectors of shape (..., Nb of joints, 3) 158 | :return: the resulting array of rotated vectors 159 | """ 160 | t = 2.0 * np.cross(q[..., 1:], x) 161 | res = x + q[..., 0][..., np.newaxis] * t + np.cross(q[..., 1:], t) 162 | 163 | return res 164 | 165 | 166 | def quat_between(x, y): 167 | """ 168 | Quaternion rotations between two 3D-vector arrays 169 | 170 | :param x: tensor of 3D vectors 171 | :param y: tensor of 3D vetcors 172 | :return: tensor of quaternions 173 | """ 174 | res = np.concatenate([ 175 | np.sqrt(np.sum(x * x, axis=-1) * np.sum(y * y, axis=-1))[..., np.newaxis] + 176 | np.sum(x * y, axis=-1)[..., np.newaxis], 177 | np.cross(x, y)], axis=-1) 178 | return res 179 | 180 | def remove_quat_discontinuities(rotations): 181 | """ 182 | 183 | Removing quat discontinuities on the time dimension (removing flips) 184 | 185 | :param rotations: Array of quaternions of shape (T, J, 4) 186 | :return: The processed array without quaternion inversion. 187 | """ 188 | rots_inv = -rotations 189 | 190 | for i in range(1, rotations.shape[0]): 191 | # Compare dot products 192 | replace_mask = np.sum(rotations[i - 1: i] * rotations[i: i + 1], axis=-1) < np.sum( 193 | rotations[i - 1: i] * rots_inv[i: i + 1], axis=-1) 194 | replace_mask = replace_mask[..., np.newaxis] 195 | rotations[i] = replace_mask * rots_inv[i] + (1.0 - replace_mask) * rotations[i] 196 | 197 | return rotations 198 | 199 | 200 | def rotate_at_each_frame(X, Q, parents): 201 | """ 202 | Re-orients the animation data according to the last frame of past context. 203 | 204 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3) 205 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4) 206 | :param parents: list of parents' indices 207 | :param n_past: number of frames in the past context 208 | :return: The rotated positions X and quaternions Q 209 | """ 210 | # Get global quats and global poses (FK) 211 | 212 | global_q, global_x = quat_fk(Q, X, parents) 213 | 214 | key_glob_Q = global_q[..., 0:1, :] # (B, T, 1, 4) 215 | 216 | forward = np.array([1, 0, 1])[np.newaxis, np.newaxis, np.newaxis, :] \ 217 | * quat_mul_vec(key_glob_Q, np.array([0, 1, 0])[np.newaxis, np.newaxis, np.newaxis, :]) # (B, T, 1, 3) 218 | forward = normalize(forward) 219 | 220 | # yrot = quat_normalize(quat_between(np.array([1, 0, 0]), forward)) 221 | yrot = quat_normalize(quat_between(np.array([0, 0, 1]), forward)) # (B, T, 1, 4) 222 | new_glob_Q = quat_mul(quat_inv(yrot), global_q) 223 | new_glob_X = quat_mul_vec(quat_inv(yrot), global_x - global_x[..., 0:1, :]) + global_x[..., 0:1, :] 224 | # new_glob_X = quat_mul_vec(quat_inv(yrot), global_x) 225 | 226 | # back to local quat-pos 227 | Q, X = quat_ik(new_glob_Q, new_glob_X, parents) 228 | 229 | return X, Q, yrot, forward 230 | 231 | 232 | def rotate_at_each_dog_frame(X, Q, parents): 233 | """ 234 | Re-orients the animation data according to the last frame of past context. 
235 | 236 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3) 237 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4) 238 | :param parents: list of parents' indices 239 | :param n_past: number of frames in the past context 240 | :return: The rotated positions X and quaternions Q 241 | """ 242 | # Get global quats and global poses (FK) 243 | global_q, global_x = quat_fk(Q, X, parents) # (B, T, J, 4) (B, T, J, 3) 244 | 245 | # key_glob_Q = global_q[..., 0:1, :] # (B, T, 1, 4) 246 | 247 | """ Extract Forward Direction """ 248 | forward = ( 249 | (global_x[..., 2, :] - global_x[..., 0, :])) 250 | forward[..., 1] = 0 251 | forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis] 252 | """ Smooth Forward Direction """ 253 | direction_filterwidth = 8 254 | forward = filters.gaussian_filter1d( 255 | forward, direction_filterwidth, axis=1, mode='nearest') 256 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[...,np.newaxis] # B T 3 257 | forward = forward[...,np.newaxis, :] # B T 1 3 258 | 259 | # yrot = quat_normalize(quat_between(np.array([1, 0, 0]), forward)) 260 | yrot = quat_normalize(quat_between(np.array([0, 0, 1]), forward)) # (B, T, 1, 4) 261 | new_glob_Q = quat_mul(quat_inv(yrot), global_q) 262 | new_glob_X = quat_mul_vec(quat_inv(yrot), global_x - global_x[..., 0:1, :]) + global_x[..., 0:1, :] 263 | # new_glob_X = quat_mul_vec(quat_inv(yrot), global_x) 264 | 265 | # back to local quat-pos 266 | Q, X = quat_ik(new_glob_Q, new_glob_X, parents) 267 | 268 | return X, Q, yrot, forward 269 | 270 | 271 | def extract_local_velocities(vel, yrot): 272 | """ 273 | calculate velocities of whole joints after facing on (0, 0, 1) 274 | :param vel: velocities of joints B * T * J * 3 275 | :param yrot: return by function rotate_at_frame 276 | :return: 277 | """ 278 | local_velocities = quat_mul_vec(quat_inv(yrot), vel) 279 | return local_velocities 280 | -------------------------------------------------------------------------------- /outer_utils/BVH_mod.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | from outer_utils.Animation import Animation 5 | from outer_utils.Quaternions import Quaternions 6 | 7 | channelmap = { 8 | 'Xrotation' : 'x', 9 | 'Yrotation' : 'y', 10 | 'Zrotation' : 'z' 11 | } 12 | 13 | channelmap_inv = { 14 | 'x': 'Xrotation', 15 | 'y': 'Yrotation', 16 | 'z': 'Zrotation', 17 | } 18 | 19 | ordermap = { 20 | 'x' : 0, 21 | 'y' : 1, 22 | 'z' : 2, 23 | } 24 | 25 | def load(filename, start=None, end=None, order=None, world=False, need_quater=False): 26 | """ 27 | Reads a BVH file and constructs an animation 28 | 29 | Parameters 30 | ---------- 31 | filename: str 32 | File to be opened 33 | 34 | start : int 35 | Optional Starting Frame 36 | 37 | end : int 38 | Optional Ending Frame 39 | 40 | order : str 41 | Optional Specifier for joint order. 
42 | Given as string E.G 'xyz', 'zxy' 43 | 44 | world : bool 45 | If set to true euler angles are applied 46 | together in world space rather than local 47 | space 48 | 49 | Returns 50 | ------- 51 | 52 | (animation, joint_names, frametime) 53 | Tuple of loaded animation and joint names 54 | """ 55 | 56 | f = open(filename, "r") 57 | 58 | i = 0 59 | active = -1 60 | end_site = False 61 | 62 | names = [] 63 | orients = Quaternions.id(0) 64 | offsets = np.array([]).reshape((0,3)) 65 | parents = np.array([], dtype=int) 66 | 67 | for line in f: 68 | 69 | if "HIERARCHY" in line: continue 70 | if "MOTION" in line: continue 71 | 72 | """ Modified line read to handle mixamo data """ 73 | # rmatch = re.match(r"ROOT (\w+)", line) 74 | rmatch = re.match(r"ROOT (\w+:?\w+)", line) 75 | if rmatch: 76 | names.append(rmatch.group(1)) 77 | offsets = np.append(offsets, np.array([[0,0,0]]), axis=0) 78 | orients.qs = np.append(orients.qs, np.array([[1,0,0,0]]), axis=0) 79 | parents = np.append(parents, active) 80 | active = (len(parents)-1) 81 | continue 82 | 83 | if "{" in line: continue 84 | 85 | if "}" in line: 86 | if end_site: end_site = False 87 | else: active = parents[active] 88 | continue 89 | 90 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 91 | if offmatch: 92 | if not end_site: 93 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 94 | continue 95 | 96 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 97 | if chanmatch: 98 | channels = int(chanmatch.group(1)) 99 | if order is None: 100 | channelis = 0 if channels == 3 else 3 101 | channelie = 3 if channels == 3 else 6 102 | parts = line.split()[2+channelis:2+channelie] 103 | if any([p not in channelmap for p in parts]): 104 | continue 105 | order = "".join([channelmap[p] for p in parts]) 106 | continue 107 | 108 | """ Modified line read to handle mixamo data """ 109 | # jmatch = re.match("\s*JOINT\s+(\w+)", line) 110 | jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line) 111 | if jmatch: 112 | names.append(jmatch.group(1)) 113 | offsets = np.append(offsets, np.array([[0,0,0]]), axis=0) 114 | orients.qs = np.append(orients.qs, np.array([[1,0,0,0]]), axis=0) 115 | parents = np.append(parents, active) 116 | active = (len(parents)-1) 117 | continue 118 | 119 | if "End Site" in line: 120 | end_site = True 121 | continue 122 | 123 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 124 | if fmatch: 125 | if start and end: 126 | fnum = (end - start)-1 127 | else: 128 | fnum = int(fmatch.group(1)) 129 | jnum = len(parents) 130 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 131 | rotations = np.zeros((fnum, len(orients), 3)) 132 | continue 133 | 134 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 135 | if fmatch: 136 | frametime = float(fmatch.group(1)) 137 | continue 138 | 139 | if (start and end) and (i < start or i >= end-1): 140 | i += 1 141 | continue 142 | 143 | # dmatch = line.strip().split(' ') 144 | dmatch = line.strip().split() 145 | if dmatch: 146 | data_block = np.array(list(map(float, dmatch))) 147 | N = len(parents) 148 | fi = i - start if start else i 149 | if channels == 3: 150 | positions[fi,0:1] = data_block[0:3] 151 | rotations[fi, : ] = data_block[3: ].reshape(N,3) 152 | elif channels == 6: 153 | data_block = data_block.reshape(N,6) 154 | positions[fi,:] = data_block[:,0:3] 155 | rotations[fi,:] = data_block[:,3:6] 156 | elif channels == 9: 157 | positions[fi,0] = data_block[0:3] 158 | data_block = data_block[3:].reshape(N-1,9) 159 | rotations[fi,1:] = 
data_block[:,3:6] 160 | positions[fi,1:] += data_block[:,0:3] * data_block[:,6:9] 161 | else: 162 | raise Exception("Too many channels! %i" % channels) 163 | 164 | i += 1 165 | 166 | f.close() 167 | 168 | if need_quater: 169 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 170 | elif order != 'xyz': 171 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 172 | rotations = np.degrees(rotations.euler()) 173 | 174 | return (Animation(rotations, positions, orients, offsets, parents), names, frametime) 175 | 176 | 177 | 178 | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True, mask=None, quater=False): 179 | """ 180 | Saves an Animation to file as BVH 181 | 182 | Parameters 183 | ---------- 184 | filename: str 185 | File to be saved to 186 | 187 | anim : Animation 188 | Animation to save 189 | 190 | names : [str] 191 | List of joint names 192 | 193 | order : str 194 | Optional Specifier for joint order. 195 | Given as string E.G 'xyz', 'zxy' 196 | 197 | frametime : float 198 | Optional Animation Frame time 199 | 200 | positions : bool 201 | Optional specfier to save bone 202 | positions for each frame 203 | 204 | orients : bool 205 | Multiply joint orients to the rotations 206 | before saving. 207 | 208 | """ 209 | 210 | if names is None: 211 | names = ["joint_" + str(i) for i in range(len(anim.parents))] 212 | 213 | with open(filename, 'w') as f: 214 | 215 | t = "" 216 | f.write("%sHIERARCHY\n" % t) 217 | f.write("%sROOT %s\n" % (t, names[0])) 218 | f.write("%s{\n" % t) 219 | t += '\t' 220 | 221 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0,0], anim.offsets[0,1], anim.offsets[0,2]) ) 222 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 223 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 224 | 225 | for i in range(anim.shape[1]): 226 | if anim.parents[i] == 0: 227 | t = save_joint(f, anim, names, t, i, order=order, positions=positions) 228 | 229 | t = t[:-1] 230 | f.write("%s}\n" % t) 231 | 232 | f.write("MOTION\n") 233 | f.write("Frames: %i\n" % anim.shape[0]); 234 | f.write("Frame Time: %f\n" % frametime); 235 | 236 | #if orients: 237 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1])) 238 | #else: 239 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 240 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 241 | if quater: 242 | rots = np.degrees(anim.rotations.euler(order=order[::-1])) 243 | else: 244 | rots = anim.rotations 245 | poss = anim.positions 246 | 247 | for i in range(anim.shape[0]): 248 | for j in range(anim.shape[1]): 249 | 250 | if positions or j == 0: 251 | 252 | f.write("%f %f %f %f %f %f " % ( 253 | poss[i,j,0], poss[i,j,1], poss[i,j,2], 254 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]])) 255 | 256 | else: 257 | if mask == None or mask[j] == 1: 258 | f.write("%f %f %f " % ( 259 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]])) 260 | else: 261 | f.write("%f %f %f " % (0, 0, 0)) 262 | 263 | f.write("\n") 264 | 265 | 266 | def save_joint(f, anim, names, t, i, order='zyx', positions=False): 267 | 268 | f.write("%sJOINT %s\n" % (t, names[i])) 269 | f.write("%s{\n" % t) 270 | t += '\t' 271 | 272 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i,0], anim.offsets[i,1], anim.offsets[i,2])) 273 | 274 | if positions: 275 | f.write("%sCHANNELS 6 
Xposition Yposition Zposition %s %s %s \n" % (t, 276 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 277 | else: 278 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 279 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 280 | 281 | end_site = True 282 | 283 | for j in range(anim.shape[1]): 284 | if anim.parents[j] == i: 285 | t = save_joint(f, anim, names, t, j, order=order, positions=positions) 286 | end_site = False 287 | 288 | if end_site: 289 | f.write("%sEnd Site\n" % t) 290 | f.write("%s{\n" % t) 291 | t += '\t' 292 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 293 | t = t[:-1] 294 | f.write("%s}\n" % t) 295 | 296 | t = t[:-1] 297 | f.write("%s}\n" % t) 298 | 299 | return t 300 | -------------------------------------------------------------------------------- /data_preprocess/Lafan1_and_dog/datasetserial.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from utils.rotation import * 3 | 4 | 5 | class DogDataset(Dataset): 6 | def __init__(self, config): 7 | super(DogDataset, self).__init__() 8 | self.config = config 9 | self.njoints = config.dog_njoints 10 | _data = np.load(config.dog_train_path, allow_pickle=True) 11 | _stats = np.load(config.dogstats_path, allow_pickle=True) 12 | 13 | self.parents = _stats["parents"] 14 | self.parents_withend = _stats["parents_withend"] 15 | self.not_end = _stats["not_end"] 16 | self.njoints_withend = len(self.parents_withend) 17 | 18 | self.data, self.y_rot, self.skel_offsets, \ 19 | self.skel_offsets_withend, self.skel_names = self._get_in_out(_data) 20 | self.skel_offsets = self.skel_offsets.item() 21 | self.skel_offsets_withend = self.skel_offsets_withend.item() 22 | self.mean, self.std, self.min_vel, self.max_vel = \ 23 | _stats['mean'], _stats['std'], _stats['min_vel'], _stats['max_vel'] 24 | self.data = (self.data - self.mean[np.newaxis, np.newaxis, ...])/self.std[np.newaxis, np.newaxis, ...] 
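# --- Editor's note: illustrative round-trip sketch, not part of the original file ---
# The datasets in this module store z-score-normalised features and expose the inverse
# mapping through `denorm`.  A minimal sketch, assuming `cfg` is a Configuration object
# that carries valid dog_train_path / dogstats_path entries:
#
#     ds = DogDataset(cfg)
#     x, yrot, offsets, offsets_withend = ds[0]    # normalised sample, shape (T, C)
#     x_t = torch.Tensor(x).unsqueeze(0)           # (1, T, C)
#     raw = ds.denorm(x_t)                         # x * std + mean, back to original scale
#
# `denorm` slices mean/std to the channel width of its input, so it also accepts tensors
# that keep only the first few feature channels.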
25 | 26 | def _get_in_out(self, _data): 27 | Q = _data["Q"] 28 | V = _data["V"] 29 | yrot = _data["yrot"] 30 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 31 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 32 | skel_offsets, skel_offsets_withend, skel_names = \ 33 | _data["skel_offsets"], _data["skel_offsets_withend"], _data["skel_names"] 34 | indices = np.where(Q[..., 0] < 0) 35 | Q[indices] = -Q[indices] 36 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 37 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 38 | RootV = V[..., :3] 39 | Input = np.concatenate([Q, RootV], axis=-1) 40 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 41 | Input = np.concatenate([Input, rvel], -1) 42 | return Input, yrot, skel_offsets, skel_offsets_withend, skel_names 43 | 44 | def __getitem__(self, item): 45 | skel_name = self.skel_names[item] 46 | return self.data[item], self.y_rot[item], \ 47 | self.skel_offsets[skel_name], \ 48 | self.skel_offsets_withend[skel_name] 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def denorm(self, x, transpose=False): 54 | if transpose: 55 | x = x.transpose(1, 2) 56 | mean = torch.Tensor(self.mean).to(x.device) 57 | std = torch.Tensor(self.std).to(x.device) 58 | b_size, t_size, c_size = x.shape[0], x.shape[1], x.shape[2] 59 | x = x.reshape(-1, c_size) 60 | denorm_x = x * std[:c_size] + mean[:c_size] 61 | return denorm_x.reshape(b_size, t_size, c_size) 62 | 63 | 64 | class HumDataset(Dataset): 65 | def __init__(self, config): 66 | super(HumDataset, self).__init__() 67 | 68 | hum_path = config.hum_train_path 69 | self.njoints = config.hum_njoints 70 | _data = np.load(hum_path, allow_pickle=True) 71 | _stats = np.load(config.humstats_path, allow_pickle=True) 72 | self.config = config 73 | 74 | self.parents = _stats["parents"] 75 | self.parents_withend = _stats["parents_withend"] 76 | self.not_end = _stats["not_end"] 77 | self.njoints_withend = len(self.parents_withend) 78 | 79 | self.data, self.y_rot, self.skel_offsets, \ 80 | self.skel_offsets_withend, self.skel_names = self._get_in_out(_data) 81 | self.skel_offsets = self.skel_offsets.item() 82 | self.skel_offsets_withend = self.skel_offsets_withend.item() 83 | 84 | self.mean, self.std, self.min_vel, self.max_vel = \ 85 | _stats['mean'], _stats['std'], _stats['min_vel'], _stats['max_vel'] 86 | 87 | self.data = (self.data - self.mean[np.newaxis, np.newaxis, ...]) / self.std[np.newaxis, np.newaxis, ...] 
88 | 89 | def _get_in_out(self, _data): 90 | Q = _data["Q"] 91 | V = _data["V"] 92 | yrot = _data["yrot"] 93 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 94 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 95 | skel_offsets, skel_offsets_withend, skel_names = \ 96 | _data["skel_offsets"], _data["skel_offsets_withend"], _data["skel_names"] 97 | indices = np.where(Q[..., 0] < 0) 98 | Q[indices] = -Q[indices] 99 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 100 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 101 | RootV = V[..., :3] 102 | Input = np.concatenate([Q, RootV], axis=-1) 103 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 104 | Input = np.concatenate([Input, rvel], -1) 105 | return Input, yrot, skel_offsets, skel_offsets_withend, skel_names 106 | 107 | def __getitem__(self, item): 108 | return self.data[item], self.y_rot[item], \ 109 | self.skel_offsets[self.skel_names[item]], \ 110 | self.skel_offsets_withend[self.skel_names[item]] 111 | 112 | def __len__(self): 113 | return len(self.data) 114 | 115 | def denorm(self, x, transpose=False): 116 | if transpose: 117 | x = x.transpose(1, 2) 118 | mean = torch.Tensor(self.mean).to(x.device) 119 | std = torch.Tensor(self.std).to(x.device) 120 | b_size, t_size, c_size = x.shape[0], x.shape[1], x.shape[2] 121 | x = x.reshape(-1, c_size) 122 | denorm_x = x * std[:c_size] + mean[:c_size] 123 | return denorm_x.reshape(b_size, t_size, c_size) 124 | 125 | 126 | class DogDatasetTest(Dataset): 127 | def __init__(self, config): 128 | super(DogDatasetTest, self).__init__() 129 | 130 | self.config = config 131 | self.njoints = config.dog_njoints 132 | 133 | _data = np.load(config.dog_test_path, allow_pickle=True) 134 | _stats = np.load(config.dogstats_path, allow_pickle=True) 135 | 136 | self.parents = _stats["parents"] 137 | self.parents_withend = _stats["parents_withend"] 138 | self.not_end = _stats["not_end"] 139 | self.njoints_withend = len(self.parents_withend) 140 | 141 | self.data, self.y_rot, self.skel_offsets, \ 142 | self.skel_offsets_withend, self.skel_names = self._get_in_out(_data) 143 | self.skel_offsets = self.skel_offsets.item() 144 | self.skel_offsets_withend = self.skel_offsets_withend.item() 145 | self.mean, self.std, self.min_vel, self.max_vel = \ 146 | _stats['mean'], _stats['std'], _stats['min_vel'], _stats['max_vel'] 147 | self.data = (self.data - self.mean[np.newaxis, np.newaxis, ...]) / self.std[np.newaxis, np.newaxis, ...] 
148 | 149 | def _get_in_out(self, _data): 150 | Q = _data["Q"] 151 | V = _data["V"] 152 | yrot = _data["yrot"] 153 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 154 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 155 | skel_offsets, skel_offsets_withend, skel_names = \ 156 | _data["skel_offsets"], _data["skel_offsets_withend"], _data["skel_names"] 157 | indices = np.where(Q[..., 0] < 0) 158 | Q[indices] = -Q[indices] 159 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 160 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 161 | RootV = V[..., :3] 162 | Input = np.concatenate([Q, RootV], axis=-1) 163 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 164 | Input = np.concatenate([Input, rvel], -1) 165 | return Input, yrot, skel_offsets, skel_offsets_withend, skel_names 166 | 167 | def __getitem__(self, item): 168 | return self.data[item], self.y_rot[item], \ 169 | self.skel_offsets[self.skel_names[item]], \ 170 | self.skel_offsets_withend[self.skel_names[item]] 171 | 172 | def __len__(self): 173 | return len(self.data) 174 | 175 | def denorm(self, x, transpose=False): 176 | if transpose: 177 | x = x.transpose(1, 2) 178 | mean = torch.Tensor(self.mean).to(x.device) 179 | std = torch.Tensor(self.std).to(x.device) 180 | b_size, t_size, c_size = x.shape[0], x.shape[1], x.shape[2] 181 | x = x.reshape(-1, c_size) 182 | denorm_x = x * std[:c_size] + mean[:c_size] 183 | return denorm_x.reshape(b_size, t_size, c_size) 184 | 185 | 186 | class HumDatasetTest(Dataset): 187 | def __init__(self, config): 188 | super(HumDatasetTest, self).__init__() 189 | _data = np.load(config.hum_test_path, allow_pickle=True) 190 | _stats = np.load(config.humstats_path, allow_pickle=True) 191 | 192 | self.njoints = config.hum_njoints 193 | self.config = config 194 | 195 | self.parents = _stats["parents"] 196 | self.parents_withend = _stats["parents_withend"] 197 | self.not_end = _stats["not_end"] 198 | self.njoints_withend = len(self.parents_withend) 199 | 200 | self.data, self.y_rot, self.skel_offsets, \ 201 | self.skel_offsets_withend, self.skel_names = self._get_in_out(_data) 202 | self.skel_offsets = self.skel_offsets.item() 203 | self.skel_offsets_withend = self.skel_offsets_withend.item() 204 | 205 | self.mean, self.std, self.min_vel, self.max_vel = \ 206 | _stats['mean'], _stats['std'], _stats['min_vel'], _stats['max_vel'] 207 | 208 | self.data = (self.data - self.mean[np.newaxis, np.newaxis, ...]) / self.std[np.newaxis, np.newaxis, ...] 
209 | 210 | def _get_in_out(self, _data): 211 | Q = _data["Q"] 212 | V = _data["V"] 213 | yrot = _data["yrot"] 214 | rvel = wrap(quat2pivots, wrap(qmultipy, wrap(qinv, yrot[:, :-1, ...]), yrot[:, 1:, ...])) 215 | rvel = np.concatenate((rvel, rvel[:, -1:, ...]), axis=1) 216 | skel_offsets, skel_offsets_withend, skel_names = \ 217 | _data["skel_offsets"], _data["skel_offsets_withend"], _data["skel_names"] 218 | indices = np.where(Q[..., 0] < 0) 219 | Q[indices] = -Q[indices] 220 | Q = np.reshape(Q, [Q.shape[0], Q.shape[1], -1]) 221 | V = np.reshape(V, [V.shape[0], V.shape[1], -1]) 222 | RootV = V[..., :3] 223 | Input = np.concatenate([Q, RootV], axis=-1) 224 | rvel = np.reshape(rvel, rvel.shape[:2] + (-1,)) 225 | Input = np.concatenate([Input, rvel], -1) 226 | return Input, yrot, skel_offsets, skel_offsets_withend, skel_names 227 | 228 | def __getitem__(self, item): 229 | return self.data[item], self.y_rot[item], \ 230 | self.skel_offsets[self.skel_names[item]], \ 231 | self.skel_offsets_withend[self.skel_names[item]] 232 | 233 | def __len__(self): 234 | return len(self.data) 235 | 236 | def denorm(self, x, transpose=False): 237 | if transpose: 238 | x = x.transpose(1, 2) 239 | mean = torch.Tensor(self.mean).to(x.device) 240 | std = torch.Tensor(self.std).to(x.device) 241 | b_size, t_size, c_size = x.shape[0], x.shape[1], x.shape[2] 242 | x = x.reshape(-1, c_size) 243 | denorm_x = x * std[:c_size] + mean[:c_size] 244 | return denorm_x.reshape(b_size, t_size, c_size) -------------------------------------------------------------------------------- /models/IK.py: -------------------------------------------------------------------------------- 1 | from utils.bvh_utils import read_bvh, save_bvh, read_bvh_with_end 2 | from tqdm import tqdm 3 | import torch 4 | import copy 5 | import utils.rotation as rt 6 | import numpy as np 7 | import outer_utils.BVH as BVH 8 | import outer_utils.Animation as Animation 9 | from data_preprocess.Mixamo.bvh_parser import BVH_file 10 | from outer_utils.Quaternions_old import Quaternions 11 | from models.Kinematics import InverseKinematics, InverseKinematics_humdog 12 | 13 | L = 6 14 | from scipy import io 15 | 16 | 17 | def alpha(t): 18 | return 2.0 * t * t * t - 3.0 * t * t + 1 19 | 20 | 21 | def lerp(a, l, r): 22 | return (1 - a) * l + a * r 23 | 24 | 25 | def get_character_height(file_name): 26 | file = BVH_file(file_name) 27 | return file.get_height() 28 | 29 | 30 | def get_ee_id_by_names(joint_names, raw_bvh=False): 31 | if raw_bvh: 32 | ees = ['RightToe_End', 'LeftToe_End', 'LeftToeBase', 'RightToeBase'] 33 | else: 34 | ees = ['RightToeBase', 'LeftToeBase', 'LeftFoot', 'RightFoot'] 35 | 36 | ee_id = [] 37 | for i, name in enumerate(joint_names): 38 | if ':' in name: 39 | joint_names[i] = joint_names[i].split(':')[1] 40 | for i, ee in enumerate(ees): 41 | ee_id.append(joint_names.index(ee)) 42 | return ee_id 43 | 44 | 45 | def get_foot_contact(file_name, ref_height=None, thr=0.003, raw_bvh=False): 46 | anim, names, _ = BVH.load(file_name) 47 | 48 | ee_ids = get_ee_id_by_names(names, raw_bvh=raw_bvh) 49 | 50 | glb = Animation.positions_global(anim) # [T, J, 3] 51 | 52 | ee_pos = glb[:, ee_ids, :] 53 | ee_velo = ee_pos[1:, ...] - ee_pos[:-1, ...] 
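    # Editor's note (added comment): foot contacts are detected purely from velocity --
    # an end effector whose per-frame displacement (optionally scaled by the reference
    # character height below) has a norm under `thr` is treated as planted.  A zero row
    # is prepended afterwards so the contact array stays aligned with the frame count.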
54 | if ref_height is not None: 55 | ee_velo = torch.tensor(ee_velo) / ref_height 56 | else: 57 | ee_velo = torch.tensor(ee_velo) 58 | ee_velo_norm = torch.norm(ee_velo, dim=-1) 59 | contact = ee_velo_norm < thr 60 | contact = contact.int() 61 | padding = torch.zeros_like(contact) 62 | contact = torch.cat([padding[:1, :], contact], dim=0) 63 | return contact.numpy() 64 | 65 | 66 | def remove_foot_sliding(input_file, foot_file, output_file, 67 | ref_height, input_raw_bvh=False, foot_raw_bvh=False): 68 | 69 | anim, name, ftime = BVH.load(input_file) 70 | anim_with_end = read_bvh_with_end(input_file) 71 | anim_no_end = read_bvh(input_file) 72 | 73 | fid = get_ee_id_by_names(name, input_raw_bvh) 74 | contact = get_foot_contact(foot_file, ref_height, raw_bvh=foot_raw_bvh) 75 | 76 | glb = Animation.positions_global(anim) # [T, J, 3] 77 | 78 | T = glb.shape[0] 79 | 80 | for i, fidx in enumerate(fid): # fidx: index of the foot joint 81 | fixed = contact[:, i] # [T] 82 | s = 0 83 | while s < T: 84 | while s < T and fixed[s] == 0: 85 | s += 1 86 | if s >= T: 87 | break 88 | t = s 89 | avg = glb[t, fidx].copy() 90 | while t + 1 < T and fixed[t + 1] == 1: 91 | t += 1 92 | avg += glb[t, fidx].copy() 93 | avg /= (t - s + 1) 94 | 95 | for j in range(s, t + 1): 96 | glb[j, fidx] = avg.copy() 97 | s = t + 1 98 | 99 | for s in range(T): 100 | if fixed[s] == 1: 101 | continue 102 | l, r = None, None 103 | consl, consr = False, False 104 | for k in range(L): 105 | if s - k - 1 < 0: 106 | break 107 | if fixed[s - k - 1]: 108 | l = s - k - 1 109 | consl = True 110 | break 111 | for k in range(L): 112 | if s + k + 1 >= T: 113 | break 114 | if fixed[s + k + 1]: 115 | r = s + k + 1 116 | consr = True 117 | break 118 | if not consl and not consr: 119 | continue 120 | if consl and consr: 121 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 122 | glb[s, fidx], glb[l, fidx]) 123 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 124 | glb[s, fidx], glb[r, fidx]) 125 | itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)), 126 | ritp, litp) 127 | glb[s, fidx] = itp.copy() 128 | continue 129 | if consl: 130 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 131 | glb[s, fidx], glb[l, fidx]) 132 | glb[s, fidx] = litp.copy() 133 | continue 134 | if consr: 135 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 136 | glb[s, fidx], glb[r, fidx]) 137 | glb[s, fidx] = ritp.copy() 138 | 139 | # glb is ready 140 | 141 | anim = anim.copy() 142 | 143 | rot = torch.tensor(anim.rotations.qs, dtype=torch.float) 144 | pos = torch.tensor(anim.positions[:, 0, :], dtype=torch.float) 145 | offset = torch.tensor(anim.offsets, dtype=torch.float) 146 | 147 | glb = torch.tensor(glb, dtype=torch.float) 148 | 149 | ik_solver = InverseKinematics(rot, pos, offset, anim.parents, glb) 150 | 151 | print('Removing Foot sliding') 152 | for i in tqdm(range(50)): 153 | ik_solver.step() 154 | 155 | rotations = ik_solver.rotations.detach() 156 | norm = torch.norm(rotations, dim=-1, keepdim=True) 157 | rotations /= norm 158 | 159 | anim.rotations = Quaternions(rotations.numpy()) 160 | anim.positions[:, 0, :] = ik_solver.position.detach().numpy() 161 | 162 | # BVH.save(output_file, anim, name, ftime) 163 | anim_no_end.quats = anim.rotations.qs 164 | anim_no_end.pos = anim.positions 165 | end_offset = anim_with_end.offsets[anim_with_end.endsite, :] 166 | 167 | save_bvh(output_file, anim_no_end, anim_no_end.bones, ftime, 168 | order='zyx', with_end=False, 169 | end_offset=end_offset) 170 | 171 | 172 | 173 | def get_foot_contact_by_height2(file_name, end_names, 
end_site=False): 174 | if not end_site: 175 | anim = read_bvh(file_name) 176 | else: 177 | anim = read_bvh_with_end(file_name) 178 | 179 | ee_ids = get_ee_id_by_names_humdog(anim.bones, end_names, True) 180 | 181 | _, glb = rt.quat_fk(torch.Tensor(anim.quats), torch.Tensor(anim.pos), anim.parents) 182 | ee_pos = glb[:, ee_ids, :].numpy() 183 | 184 | contacts = [] 185 | for i in range(0, glb.shape[0], 40): 186 | end = i+40 if i+40 < glb.shape[0] else glb.shape[0] 187 | contact = [] 188 | for j in range(len(end_names)): 189 | min_height = np.min(ee_pos[i: end, j, 1], axis=0) # len(ee_ids) 190 | ground_height = min_height + 1.5 191 | contact.append(ee_pos[i:end, j, 1] < ground_height) 192 | contact = np.stack(contact, 1) 193 | contacts.append(contact) 194 | 195 | contacts = np.concatenate(contacts, 0) 196 | contacts = contacts.astype(np.int) 197 | if 'Toe' in end_names: 198 | contacts[:, 2] = contacts[:, 0] 199 | contacts[:, 3] = contacts[:, 1] 200 | 201 | return contacts 202 | 203 | 204 | def get_ee_id_by_names_humdog(joint_names, end_names, end_site=False): 205 | # ees = ['RightHand', 'LeftHand', 'LeftFoot', 'RightFoot'] 206 | ees = end_names 207 | ee_id = [] 208 | for i, ee in enumerate(ees): 209 | if end_site: 210 | ee_id.append(joint_names.index(ee)+1) 211 | else: 212 | ee_id.append(joint_names.index(ee)) 213 | return ee_id 214 | 215 | 216 | def remove_foot_sliding_humdog(input_file, output_file, 217 | end_names=['RightHand', 'LeftHand', 'LeftFoot', 'RightFoot'], 218 | end_site=False): 219 | if end_site: 220 | anim = read_bvh_with_end(input_file) 221 | end_index = [] 222 | not_end_index = [] 223 | for i in range(len(anim.bones)): 224 | if anim.bones[i] == 'End Site': 225 | end_index.append(i) 226 | else: 227 | not_end_index.append(i) 228 | end_offsets = anim.offsets[end_index, :] 229 | else: 230 | anim = read_bvh(input_file) 231 | 232 | fid = get_ee_id_by_names_humdog(anim.bones, end_names, end_site) 233 | contact = get_foot_contact_by_height2(input_file, end_names, end_site) 234 | 235 | 236 | _, glb = rt.quat_fk(torch.Tensor(anim.quats), torch.Tensor(anim.pos), anim.parents) 237 | glb = glb.cpu().numpy() 238 | T = glb.shape[0] 239 | 240 | for i, fidx in enumerate(fid): # fidx: index of the foot joint 241 | fixed = contact[:, i] # [T] 242 | s = 0 243 | while s < T: 244 | while s < T and fixed[s] == 0: 245 | s += 1 246 | if s >= T: 247 | break 248 | t = s 249 | avg = glb[t, fidx].copy() 250 | while t + 1 < T and fixed[t + 1] == 1: 251 | t += 1 252 | avg += glb[t, fidx].copy() 253 | avg /= (t - s + 1) 254 | 255 | for j in range(s, t + 1): 256 | glb[j, fidx] = avg.copy() 257 | s = t + 1 258 | 259 | for s in range(T): 260 | if fixed[s] == 1: 261 | continue 262 | l, r = None, None 263 | consl, consr = False, False 264 | for k in range(L): 265 | if s - k - 1 < 0: 266 | break 267 | if fixed[s - k - 1]: 268 | l = s - k - 1 269 | consl = True 270 | break 271 | for k in range(L): 272 | if s + k + 1 >= T: 273 | break 274 | if fixed[s + k + 1]: 275 | r = s + k + 1 276 | consr = True 277 | break 278 | if not consl and not consr: 279 | continue 280 | if consl and consr: 281 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 282 | glb[s, fidx], glb[l, fidx]) 283 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 284 | glb[s, fidx], glb[r, fidx]) 285 | itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)), 286 | ritp, litp) 287 | glb[s, fidx] = itp.copy() 288 | continue 289 | if consl: 290 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 291 | glb[s, fidx], glb[l, fidx]) 292 | glb[s, fidx] = litp.copy() 293 
| continue 294 | if consr: 295 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 296 | glb[s, fidx], glb[r, fidx]) 297 | glb[s, fidx] = ritp.copy() 298 | 299 | # glb is ready 300 | anim = copy.copy(anim) 301 | 302 | rot = torch.Tensor(anim.quats) 303 | pos = torch.Tensor(anim.pos[:, 0, :]) 304 | offset = torch.Tensor(anim.offsets) 305 | 306 | glb = torch.Tensor(glb) 307 | 308 | ik_solver = InverseKinematics_humdog(rot, pos, offset, anim.parents, glb) 309 | 310 | print('remove foot sliding using IK...') 311 | for i in tqdm(range(50)): 312 | ik_solver.step() 313 | 314 | rotations = ik_solver.rot.detach() 315 | norm = torch.norm(rotations, dim=-1, keepdim=True) 316 | rotations /= norm 317 | 318 | anim.quats = rotations.detach().numpy() 319 | anim.pos[:, 1, :] = ik_solver.pos.detach().numpy() 320 | if not end_site: 321 | save_bvh(output_file, anim, frametime=1 / 30, order='zyx', with_end=False, 322 | names=anim.bones) 323 | else: 324 | save_bvh(output_file, anim, frametime=1 / 30, order='zyx', with_end=True, 325 | names=anim.bones, end_offset=end_offsets, not_end_index=not_end_index) 326 | 327 | 328 | 329 | def normalize(x): 330 | return x/torch.norm(x, dim=-1, p=2, keepdim=True) 331 | 332 | 333 | -------------------------------------------------------------------------------- /models/Kinematics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from utils.rotation import quat_fk 6 | 7 | 8 | class ForwardKinematics: 9 | def __init__(self, args, edges): 10 | self.topology = [-1] * (len(edges) + 1) 11 | self.rotation_map = [] 12 | for i, edge in enumerate(edges): 13 | self.topology[edge[1]] = edge[0] 14 | self.rotation_map.append(edge[1]) 15 | 16 | self.world = args.fk_world 17 | self.pos_repr = args.pos_repr 18 | self.quater = args.rotation == 'quaternion' 19 | 20 | def forward_from_raw(self, raw, offset, world=None, quater=None): 21 | if world is None: world = self.world 22 | if quater is None: quater = self.quater 23 | if self.pos_repr == '3d': 24 | position = raw[:, -3:, :] 25 | rotation = raw[:, :-3, :] 26 | elif self.pos_repr == '4d': 27 | raise Exception('Not support') 28 | if quater: 29 | rotation = rotation.reshape((rotation.shape[0], -1, 4, rotation.shape[-1])) 30 | identity = torch.tensor((1, 0, 0, 0), dtype=torch.float, device=raw.device) 31 | else: 32 | rotation = rotation.reshape((rotation.shape[0], -1, 3, rotation.shape[-1])) 33 | identity = torch.zeros((3, ), dtype=torch.float, device=raw.device) 34 | identity = identity.reshape((1, 1, -1, 1)) 35 | new_shape = list(rotation.shape) 36 | new_shape[1] += 1 37 | new_shape[2] = 1 38 | rotation_final = identity.repeat(new_shape) 39 | for i, j in enumerate(self.rotation_map): 40 | rotation_final[:, j, :, :] = rotation[:, i, :, :] 41 | return self.forward(rotation_final, position, offset, world=world, quater=quater) 42 | 43 | ''' 44 | rotation should have shape batch_size * Joint_num * (3/4) * Time 45 | position should have shape batch_size * 3 * Time 46 | offset should have shape batch_size * Joint_num * 3 47 | output have shape batch_size * Time * Joint_num * 3 48 | ''' 49 | def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset: torch.Tensor, order='xyz', quater=False, world=True): 50 | if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation') 51 | if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation') 52 | rotation = 
rotation.permute(0, 3, 1, 2) 53 | position = position.permute(0, 2, 1) 54 | result = torch.empty(rotation.shape[:-1] + (3, ), device=position.device) 55 | 56 | 57 | norm = torch.norm(rotation, dim=-1, keepdim=True) 58 | #norm[norm < 1e-10] = 1 59 | rotation = rotation / norm 60 | 61 | 62 | if quater: 63 | transform = self.transform_from_quaternion(rotation) 64 | else: 65 | transform = self.transform_from_euler(rotation, order) 66 | 67 | offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1)) 68 | 69 | result[..., 0, :] = position 70 | for i, pi in enumerate(self.topology): 71 | if pi == -1: 72 | assert i == 0 73 | continue 74 | 75 | transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone()) 76 | result[..., i, :] = torch.matmul(transform[..., i, :, :], offset[..., i, :, :]).squeeze() 77 | if world: result[..., i, :] += result[..., pi, :] 78 | return result 79 | 80 | def from_local_to_world(self, res: torch.Tensor): 81 | res = res.clone() 82 | for i, pi in enumerate(self.topology): 83 | if pi == 0 or pi == -1: 84 | continue 85 | res[..., i, :] += res[..., pi, :] 86 | return res 87 | 88 | @staticmethod 89 | def transform_from_euler(rotation, order): 90 | rotation = rotation / 180 * math.pi 91 | transform = torch.matmul(ForwardKinematics.transform_from_axis(rotation[..., 1], order[1]), 92 | ForwardKinematics.transform_from_axis(rotation[..., 2], order[2])) 93 | transform = torch.matmul(ForwardKinematics.transform_from_axis(rotation[..., 0], order[0]), transform) 94 | return transform 95 | 96 | @staticmethod 97 | def transform_from_axis(euler, axis): 98 | transform = torch.empty(euler.shape[0:3] + (3, 3), device=euler.device) 99 | cos = torch.cos(euler) 100 | sin = torch.sin(euler) 101 | cord = ord(axis) - ord('x') 102 | 103 | transform[..., cord, :] = transform[..., :, cord] = 0 104 | transform[..., cord, cord] = 1 105 | 106 | if axis == 'x': 107 | transform[..., 1, 1] = transform[..., 2, 2] = cos 108 | transform[..., 1, 2] = -sin 109 | transform[..., 2, 1] = sin 110 | if axis == 'y': 111 | transform[..., 0, 0] = transform[..., 2, 2] = cos 112 | transform[..., 0, 2] = sin 113 | transform[..., 2, 0] = -sin 114 | if axis == 'z': 115 | transform[..., 0, 0] = transform[..., 1, 1] = cos 116 | transform[..., 0, 1] = -sin 117 | transform[..., 1, 0] = sin 118 | 119 | return transform 120 | 121 | @staticmethod 122 | def transform_from_quaternion(quater: torch.Tensor): 123 | qw = quater[..., 0] 124 | qx = quater[..., 1] 125 | qy = quater[..., 2] 126 | qz = quater[..., 3] 127 | 128 | x2 = qx + qx 129 | y2 = qy + qy 130 | z2 = qz + qz 131 | xx = qx * x2 132 | yy = qy * y2 133 | wx = qw * x2 134 | xy = qx * y2 135 | yz = qy * z2 136 | wy = qw * y2 137 | xz = qx * z2 138 | zz = qz * z2 139 | wz = qw * z2 140 | 141 | m = torch.empty(quater.shape[:-1] + (3, 3), device=quater.device) 142 | m[..., 0, 0] = 1.0 - (yy + zz) 143 | m[..., 0, 1] = xy - wz 144 | m[..., 0, 2] = xz + wy 145 | m[..., 1, 0] = xy + wz 146 | m[..., 1, 1] = 1.0 - (xx + zz) 147 | m[..., 1, 2] = yz - wx 148 | m[..., 2, 0] = xz - wy 149 | m[..., 2, 1] = yz + wx 150 | m[..., 2, 2] = 1.0 - (xx + yy) 151 | 152 | return m 153 | 154 | 155 | class InverseKinematics: 156 | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains): 157 | self.rotations = rotations 158 | self.rotations.requires_grad_(True) 159 | self.position = positions 160 | self.position.requires_grad_(True) 161 | 162 | self.parents = parents 163 | self.offset = offset 164 | 
self.constrains = constrains 165 | 166 | self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999)) 167 | self.crit = nn.MSELoss() 168 | 169 | def step(self): 170 | self.optimizer.zero_grad() 171 | glb = self.forward(self.rotations, self.position, self.offset, order='', quater=True, world=True) 172 | loss = self.crit(glb, self.constrains) 173 | loss.backward() 174 | self.optimizer.step() 175 | self.glb = glb 176 | return loss.item() 177 | 178 | def tloss(self, time): 179 | return self.crit(self.glb[time, :], self.constrains[time, :]) 180 | 181 | def all_loss(self): 182 | res = [self.tloss(t).detach().numpy() for t in range(self.constrains.shape[0])] 183 | return np.array(res) 184 | 185 | ''' 186 | rotation should have shape batch_size * Joint_num * (3/4) * Time 187 | position should have shape batch_size * 3 * Time 188 | offset should have shape batch_size * Joint_num * 3 189 | output have shape batch_size * Time * Joint_num * 3 190 | ''' 191 | 192 | def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset: torch.Tensor, order='xyz', quater=False, 193 | world=True): 194 | ''' 195 | if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation') 196 | if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation') 197 | rotation = rotation.permute(0, 3, 1, 2) 198 | position = position.permute(0, 2, 1) 199 | ''' 200 | result = torch.empty(rotation.shape[:-1] + (3,), device=position.device) 201 | 202 | norm = torch.norm(rotation, dim=-1, keepdim=True) 203 | rotation = rotation / norm 204 | 205 | if quater: 206 | transform = self.transform_from_quaternion(rotation) 207 | else: 208 | transform = self.transform_from_euler(rotation, order) 209 | 210 | offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1)) 211 | 212 | result[..., 0, :] = position 213 | for i, pi in enumerate(self.parents): 214 | if pi == -1: 215 | assert i == 0 216 | continue 217 | 218 | result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze() 219 | transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone()) 220 | if world: result[..., i, :] += result[..., pi, :] 221 | return result 222 | 223 | @staticmethod 224 | def transform_from_euler(rotation, order): 225 | rotation = rotation / 180 * math.pi 226 | transform = torch.matmul(ForwardKinematics.transform_from_axis(rotation[..., 1], order[1]), 227 | ForwardKinematics.transform_from_axis(rotation[..., 2], order[2])) 228 | transform = torch.matmul(ForwardKinematics.transform_from_axis(rotation[..., 0], order[0]), transform) 229 | return transform 230 | 231 | @staticmethod 232 | def transform_from_axis(euler, axis): 233 | transform = torch.empty(euler.shape[0:3] + (3, 3), device=euler.device) 234 | cos = torch.cos(euler) 235 | sin = torch.sin(euler) 236 | cord = ord(axis) - ord('x') 237 | 238 | transform[..., cord, :] = transform[..., :, cord] = 0 239 | transform[..., cord, cord] = 1 240 | 241 | if axis == 'x': 242 | transform[..., 1, 1] = transform[..., 2, 2] = cos 243 | transform[..., 1, 2] = -sin 244 | transform[..., 2, 1] = sin 245 | if axis == 'y': 246 | transform[..., 0, 0] = transform[..., 2, 2] = cos 247 | transform[..., 0, 2] = sin 248 | transform[..., 2, 0] = -sin 249 | if axis == 'z': 250 | transform[..., 0, 0] = transform[..., 1, 1] = cos 251 | transform[..., 0, 1] = -sin 252 | transform[..., 1, 0] = sin 253 | 254 | return transform 255 | 256 | @staticmethod 257 | def 
transform_from_quaternion(quater: torch.Tensor): 258 | qw = quater[..., 0] 259 | qx = quater[..., 1] 260 | qy = quater[..., 2] 261 | qz = quater[..., 3] 262 | 263 | x2 = qx + qx 264 | y2 = qy + qy 265 | z2 = qz + qz 266 | xx = qx * x2 267 | yy = qy * y2 268 | wx = qw * x2 269 | xy = qx * y2 270 | yz = qy * z2 271 | wy = qw * y2 272 | xz = qx * z2 273 | zz = qz * z2 274 | wz = qw * z2 275 | 276 | m = torch.empty(quater.shape[:-1] + (3, 3), device=quater.device) 277 | m[..., 0, 0] = 1.0 - (yy + zz) 278 | m[..., 0, 1] = xy - wz 279 | m[..., 0, 2] = xz + wy 280 | m[..., 1, 0] = xy + wz 281 | m[..., 1, 1] = 1.0 - (xx + zz) 282 | m[..., 1, 2] = yz - wx 283 | m[..., 2, 0] = xz - wy 284 | m[..., 2, 1] = yz + wx 285 | m[..., 2, 2] = 1.0 - (xx + yy) 286 | 287 | return m 288 | 289 | class InverseKinematics_humdog: 290 | def __init__(self, rot, pos, offset, parents, constraints): 291 | self.rot = rot 292 | self.pos = pos 293 | self.rot.requires_grad_(True) 294 | self.pos.requires_grad_(True) 295 | 296 | self.parents = parents 297 | self.offset = offset 298 | self.constraints = constraints 299 | 300 | self.optimizer = torch.optim.Adam([self.pos, self.rot], lr=1e-3, betas=(0.9, 0.999)) 301 | self.crit = nn.MSELoss() 302 | 303 | def step(self): 304 | self.optimizer.zero_grad() 305 | 306 | glb = self.forward(self.rot, self.pos, self.offset, self.parents) 307 | loss = self.crit(glb, self.constraints) 308 | loss.backward() 309 | self.optimizer.step() 310 | self.glb = glb 311 | return loss.item() 312 | 313 | def forward(self, rot, pos, offset, parents): 314 | offset = offset.reshape(1, offset.shape[-2], offset.shape[-1]).repeat(rot.shape[0], 1, 1) 315 | offset[:, 0, :] = pos 316 | norm = torch.norm(rot, dim=-1, keepdim=True) 317 | rot = rot / norm 318 | _, x = quat_fk(rot, offset, parents) 319 | return x -------------------------------------------------------------------------------- /models/multi_attention_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import softmax, dropout 3 | 4 | def linear(input, weight, bias=None): 5 | # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor 6 | r""" 7 | Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. 
8 | Shape: 9 | - Input: :math:`(N, *, in\_features)` where `*` means any number of 10 | additional dimensions 11 | - Weight: :math:`(out\_features, in\_features)` 12 | - Bias: :math:`(out\_features)` 13 | - Output: :math:`(N, *, out\_features)` 14 | """ 15 | if input.dim() == 2 and bias is not None: 16 | # fused op is marginally faster 17 | ret = torch.addmm(bias, input, weight.t()) 18 | else: 19 | output = input.matmul(weight.t()) 20 | if bias is not None: 21 | output += bias 22 | ret = output 23 | return ret 24 | 25 | 26 | def multi_head_attention_forward(query, # type: Tensor 27 | key, # type: Tensor 28 | value, # type: Tensor 29 | embed_dim_to_check, # type: int 30 | num_heads, # type: int 31 | in_proj_weight, # type: Tensor 32 | in_proj_bias, # type: Tensor 33 | bias_k, # type: Optional[Tensor] 34 | bias_v, # type: Optional[Tensor] 35 | add_zero_attn, # type: bool 36 | dropout_p, # type: float 37 | out_proj_weight, # type: Tensor 38 | out_proj_bias, # type: Tensor 39 | training=True, # type: bool 40 | key_padding_mask=None, # type: Optional[Tensor] 41 | need_weights=True, # type: bool 42 | attn_mask=None, # type: Optional[Tensor] 43 | use_separate_proj_weight=False, # type: bool 44 | q_proj_weight=None, # type: Optional[Tensor] 45 | k_proj_weight=None, # type: Optional[Tensor] 46 | v_proj_weight=None, # type: Optional[Tensor] 47 | static_k=None, # type: Optional[Tensor] 48 | static_v=None, # type: Optional[Tensor] 49 | layer_cache=None # type: Optional[dict] 50 | ): 51 | # type: (...) -> Tuple[Tensor, Optional[Tensor]] 52 | r""" 53 | Args: 54 | query, key, value: map a query and a set of key-value pairs to an output. 55 | See "Attention Is All You Need" for more details. 56 | embed_dim_to_check: total dimension of the model. 57 | num_heads: parallel attention heads. 58 | in_proj_weight, in_proj_bias: input projection weight and bias. 59 | bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 60 | add_zero_attn: add a new batch of zeros to the key and 61 | value sequences at dim=1. 62 | dropout_p: probability of an element to be zeroed. 63 | out_proj_weight, out_proj_bias: the output projection weight and bias. 64 | training: apply dropout if is ``True``. 65 | key_padding_mask: if provided, specified padding elements in the key will 66 | be ignored by the attention. This is an binary mask. When the value is True, 67 | the corresponding value on the attention layer will be filled with -inf. 68 | need_weights: output attn_output_weights. 69 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 70 | (i.e. the values will be added to the attention layer). 71 | use_separate_proj_weight: the function accept the proj. weights for query, key, 72 | and value in different forms. If false, in_proj_weight will be used, which is 73 | a combination of q_proj_weight, k_proj_weight, v_proj_weight. 74 | q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 75 | static_k, static_v: static key and value used for attention operators. 76 | Shape: 77 | Inputs: 78 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 79 | the embedding dimension. 80 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 81 | the embedding dimension. 82 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 83 | the embedding dimension. 
84 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 85 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 86 | - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 87 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 88 | - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 89 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 90 | Outputs: 91 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 92 | E is the embedding dimension. 93 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 94 | L is the target sequence length, S is the source sequence length. 95 | """ 96 | 97 | tgt_len, bsz, embed_dim = query.size() 98 | assert embed_dim == embed_dim_to_check 99 | assert key.size() == value.size() 100 | 101 | head_dim = embed_dim // num_heads 102 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 103 | scaling = float(head_dim) ** -0.5 104 | 105 | if not use_separate_proj_weight: 106 | if torch.equal(query, key) and torch.equal(key, value): # attn type: self 107 | # self-attention 108 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 109 | if layer_cache is not None: 110 | device = k.device 111 | if layer_cache["self_keys"] is not None: 112 | k = torch.cat( 113 | (layer_cache["self_keys"].to(device), k), 114 | dim=0) 115 | if layer_cache["self_values"] is not None: 116 | v = torch.cat( 117 | (layer_cache["self_values"].to(device), v), 118 | dim=0) 119 | layer_cache["self_keys"] = k 120 | layer_cache["self_values"] = v 121 | 122 | elif torch.equal(key, value): # attn type: context 123 | # encoder-decoder attention 124 | # This is inline in_proj function with in_proj_weight and in_proj_bias 125 | _b = in_proj_bias 126 | _start = 0 127 | _end = embed_dim 128 | _w = in_proj_weight[_start:_end, :] 129 | if _b is not None: 130 | _b = _b[_start:_end] 131 | q = linear(query, _w, _b) 132 | 133 | if key is None: 134 | assert value is None 135 | k = None 136 | v = None 137 | else: 138 | 139 | # This is inline in_proj function with in_proj_weight and in_proj_bias 140 | if layer_cache is not None: 141 | if layer_cache["memory_keys"] is None: 142 | _b = in_proj_bias 143 | _start = embed_dim 144 | _end = None 145 | _w = in_proj_weight[_start:, :] 146 | if _b is not None: 147 | _b = _b[_start:] 148 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 149 | else: 150 | k, v = layer_cache["memory_keys"], \ 151 | layer_cache["memory_values"] 152 | layer_cache["memory_keys"] = k 153 | layer_cache["memory_values"] = v 154 | else: 155 | _b = in_proj_bias 156 | _start = embed_dim 157 | _end = None 158 | _w = in_proj_weight[_start:, :] 159 | if _b is not None: 160 | _b = _b[_start:] 161 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 162 | else: 163 | # This is inline in_proj function with in_proj_weight and in_proj_bias 164 | _b = in_proj_bias 165 | _start = 0 166 | _end = embed_dim 167 | _w = in_proj_weight[_start:_end, :] 168 | if _b is not None: 169 | _b = _b[_start:_end] 170 | q = linear(query, _w, _b) 171 | 172 | # This is inline in_proj function with in_proj_weight and in_proj_bias 173 | _b = in_proj_bias 174 | _start = embed_dim 175 | _end = embed_dim * 2 176 | _w = in_proj_weight[_start:_end, :] 177 | if _b is not None: 178 | _b = _b[_start:_end] 
179 | k = linear(key, _w, _b) 180 | 181 | # This is inline in_proj function with in_proj_weight and in_proj_bias 182 | _b = in_proj_bias 183 | _start = embed_dim * 2 184 | _end = None 185 | _w = in_proj_weight[_start:, :] 186 | if _b is not None: 187 | _b = _b[_start:] 188 | v = linear(value, _w, _b) 189 | else: 190 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 191 | len1, len2 = q_proj_weight_non_opt.size() 192 | assert len1 == embed_dim and len2 == query.size(-1) 193 | 194 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 195 | len1, len2 = k_proj_weight_non_opt.size() 196 | assert len1 == embed_dim and len2 == key.size(-1) 197 | 198 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 199 | len1, len2 = v_proj_weight_non_opt.size() 200 | assert len1 == embed_dim and len2 == value.size(-1) 201 | 202 | if in_proj_bias is not None: 203 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 204 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 205 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 206 | else: 207 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 208 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 209 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 210 | q = q * scaling 211 | 212 | if bias_k is not None and bias_v is not None: 213 | if static_k is None and static_v is None: 214 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 215 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 216 | if attn_mask is not None: 217 | attn_mask = torch.cat([attn_mask, 218 | torch.zeros((attn_mask.size(0), 1), 219 | dtype=attn_mask.dtype, 220 | device=attn_mask.device)], dim=1) 221 | if key_padding_mask is not None: 222 | key_padding_mask = torch.cat( 223 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 224 | dtype=key_padding_mask.dtype, 225 | device=key_padding_mask.device)], dim=1) 226 | else: 227 | assert static_k is None, "bias cannot be added to static key." 228 | assert static_v is None, "bias cannot be added to static value." 
229 | else: 230 | assert bias_k is None 231 | assert bias_v is None 232 | 233 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 234 | if k is not None: 235 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 236 | if v is not None: 237 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 238 | 239 | if static_k is not None: 240 | assert static_k.size(0) == bsz * num_heads 241 | assert static_k.size(2) == head_dim 242 | k = static_k 243 | 244 | if static_v is not None: 245 | assert static_v.size(0) == bsz * num_heads 246 | assert static_v.size(2) == head_dim 247 | v = static_v 248 | 249 | src_len = k.size(1) 250 | 251 | if key_padding_mask is not None: 252 | assert key_padding_mask.size(0) == bsz 253 | assert key_padding_mask.size(1) == src_len 254 | 255 | if add_zero_attn: 256 | src_len += 1 257 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 258 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 259 | if attn_mask is not None: 260 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 261 | dtype=attn_mask.dtype, 262 | device=attn_mask.device)], dim=1) 263 | if key_padding_mask is not None: 264 | key_padding_mask = torch.cat( 265 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 266 | dtype=key_padding_mask.dtype, 267 | device=key_padding_mask.device)], dim=1) 268 | 269 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 270 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 271 | 272 | if attn_mask is not None: 273 | attn_mask = attn_mask.unsqueeze(0) 274 | attn_output_weights += attn_mask 275 | 276 | if key_padding_mask is not None: 277 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 278 | attn_output_weights = attn_output_weights.masked_fill( 279 | key_padding_mask.unsqueeze(1).unsqueeze(2), 280 | float('-inf'), 281 | ) 282 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 283 | 284 | attn_output_weights = softmax( 285 | attn_output_weights, dim=-1) 286 | 287 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 288 | 289 | attn_output = torch.bmm(attn_output_weights, v) 290 | # import pdb; pdb.set_trace() 291 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 292 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 293 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 294 | if need_weights: 295 | # average attention weights over heads 296 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 297 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 298 | else: 299 | return attn_output, None 300 | --------------------------------------------------------------------------------
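The foot-sliding removal in models/IK.py above rests on a simple velocity-threshold contact test: an end effector counts as planted whenever its (optionally height-normalized) per-frame displacement stays below a small threshold. Below is a minimal numpy sketch of that test, assuming a global-position array glb of shape [T, J, 3] and a list of end-effector joint indices ee_ids; the helper name detect_contacts and its defaults are illustrative, not part of the repository.

    import numpy as np

    def detect_contacts(glb, ee_ids, thr=0.003, ref_height=None):
        """Label frames in which an end effector is (nearly) static.

        glb        : [T, J, 3] global joint positions
        ee_ids     : indices of the end-effector joints to test
        thr        : velocity threshold below which a joint counts as in contact
        ref_height : optional character height used to normalize velocities
        """
        ee_pos = glb[:, ee_ids, :]              # [T, E, 3]
        ee_velo = ee_pos[1:] - ee_pos[:-1]      # frame-to-frame displacement
        if ref_height is not None:
            ee_velo = ee_velo / ref_height      # make the threshold scale-free
        contact = (np.linalg.norm(ee_velo, axis=-1) < thr).astype(np.int32)  # [T-1, E]
        # pad the first frame so the label array matches the clip length
        return np.concatenate([np.zeros_like(contact[:1]), contact], axis=0)  # [T, E]

remove_foot_sliding then replaces each contiguous contact segment by its average global position, blends the neighbouring non-contact frames with the cubic falloff alpha(t) = 2t^3 - 3t^2 + 1, and finally runs about 50 Adam steps of the InverseKinematics solver so that the joint rotations reproduce the corrected global trajectory.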
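Both InverseKinematics solvers in models/Kinematics.py differentiate through a forward-kinematics pass that turns local joint rotations into global positions. As a standalone reference, here is a minimal numpy sketch of that recursion under the usual BVH convention (each bone offset is rotated by its parent's accumulated rotation); the names quat_to_mat and fk_positions are ours, and the quaternion layout (w, x, y, z) matches transform_from_quaternion above.

    import numpy as np

    def quat_to_mat(q):
        """Unit quaternions [..., 4] in (w, x, y, z) order -> rotation matrices [..., 3, 3]."""
        w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
        xx, yy, zz = x * x, y * y, z * z
        wx, wy, wz = w * x, w * y, w * z
        xy, xz, yz = x * y, x * z, y * z
        m = np.empty(q.shape[:-1] + (3, 3))
        m[..., 0, 0] = 1 - 2 * (yy + zz)
        m[..., 0, 1] = 2 * (xy - wz)
        m[..., 0, 2] = 2 * (xz + wy)
        m[..., 1, 0] = 2 * (xy + wz)
        m[..., 1, 1] = 1 - 2 * (xx + zz)
        m[..., 1, 2] = 2 * (yz - wx)
        m[..., 2, 0] = 2 * (xz - wy)
        m[..., 2, 1] = 2 * (yz + wx)
        m[..., 2, 2] = 1 - 2 * (xx + yy)
        return m

    def fk_positions(quats, root_pos, offsets, parents):
        """quats: [T, J, 4] local rotations, root_pos: [T, 3], offsets: [J, 3],
        parents: parent index per joint with parents[0] == -1. Returns [T, J, 3]."""
        T, J = quats.shape[0], quats.shape[1]
        rot = quat_to_mat(quats)
        glb_rot = np.empty((T, J, 3, 3))
        glb_pos = np.empty((T, J, 3))
        glb_rot[:, 0] = rot[:, 0]
        glb_pos[:, 0] = root_pos
        for j in range(1, J):
            p = parents[j]
            # world position = parent position + bone offset rotated into world space
            glb_pos[:, j] = glb_pos[:, p] + np.einsum('tij,j->ti', glb_rot[:, p], offsets[j])
            # accumulate the world rotation for this joint's children
            glb_rot[:, j] = glb_rot[:, p] @ rot[:, j]
        return glb_pos

The IK classes simply run this map forward, compare the result against the constraint positions with an MSE loss, and let Adam update the unnormalized quaternions and the root translation; the explicit normalization after optimization (rotations /= norm) keeps the exported rotations valid unit quaternions.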
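models/multi_attention_forward.py closely follows the functional form of PyTorch's multi-head attention, extended with an optional layer_cache for incremental decoding. A minimal self-attention call is sketched below; the tensors and projection weights are random placeholders standing in for learned parameters, and only the argument order and shapes follow the signature above.

    import torch
    from models.multi_attention_forward import multi_head_attention_forward

    L_seq, N, E, heads = 16, 2, 64, 4      # sequence length, batch size, embed dim, heads
    q = k = v = torch.randn(L_seq, N, E)   # identical tensors -> the self-attention branch

    in_proj_w  = torch.randn(3 * E, E)     # packed q/k/v projection, rows ordered q, k, v
    in_proj_b  = torch.zeros(3 * E)
    out_proj_w = torch.randn(E, E)
    out_proj_b = torch.zeros(E)

    out, weights = multi_head_attention_forward(
        q, k, v, E, heads,
        in_proj_w, in_proj_b,
        None, None,                        # bias_k, bias_v
        False,                             # add_zero_attn
        0.0,                               # dropout_p
        out_proj_w, out_proj_b)

    print(out.shape)      # torch.Size([16, 2, 64])
    print(weights.shape)  # torch.Size([2, 16, 16]); attention map averaged over heads

When need_weights is left at its default True, the second return value is the attention matrix averaged over heads (the per-head weights summed over heads and divided by num_heads), as computed at the end of the function above.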