├── .gitignore ├── README.md ├── configs ├── bvh.yaml ├── skin.yaml └── vox.yaml ├── data └── README.md ├── datasets ├── mixamo_bvh_dataset.py ├── mixamo_skin_dataset.py └── mixamo_vox_dataset.py ├── evaluate_rig.py ├── models ├── gcn_modules.py ├── mixamo_bvh_model.py ├── mixamo_skin_model.py ├── mixamo_vox_model.py └── networks │ ├── bvh_simple.py │ ├── vox_hourglass.py │ └── vox_simple.py ├── notebooks └── Combine Autorigging Output for Evaluation.ipynb ├── prepare_vol_geo.py ├── preprocess ├── blender_utils │ ├── DataGeneration.md │ ├── dae_to_bvh.py │ ├── dae_to_obj.py │ ├── normalize_and_decimate_obj.py │ ├── remove_joints.py │ └── remove_vn_from_obj.py ├── collada_utils │ ├── extract_weights.py │ └── transform_parser.py └── volume │ ├── binvox │ ├── flood-fill │ ├── binvox_rw.py │ ├── main.ipynb │ ├── main.py │ └── util.py │ ├── notebook.ipynb │ ├── obj_rot_fixed_to_binvox.py │ ├── obj_to_binvox2.py │ └── util │ ├── binvox_rw.py │ └── rigging_parser │ └── obj_parser.py ├── requirements.txt ├── sdf visualizer.ipynb ├── test_bvh.py ├── test_skin.py ├── test_vox.py ├── train_bvh.py ├── train_skin.py ├── train_vox.py └── utils ├── binvox_rw.py ├── bvh_utils.py ├── common_ops.py ├── compute_volumetric_geodesic.py ├── ik_utils.py ├── joint_util.py ├── loss_utils.py ├── misc ├── calc_IBM.py ├── dataset_split.py ├── extract_weights.py ├── joint_tree_util.py ├── transpose.py └── vis_util.py ├── obj_utils.py ├── rig_parser.py ├── rotation_util.py ├── skin_util.py ├── test_skin_utils.py ├── train_bvh_utils.py ├── train_skin_utils.py ├── train_vox_utils.py ├── tree_utils.py └── voxel_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | 3 | data/* 4 | !data/README.md 5 | 6 | utils/test/ 7 | result/ 8 | logs/ 9 | logs 10 | seg/ 11 | *.pyc 12 | *.ipynb_checkpoints 13 | .idea* 14 | *.egg-info/ 15 | test_results/ 16 | *.sdf 17 | /.vscode/ 18 | notebooks/result 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auto-rigging 3D Bipedal Characters in Arbitrary Poses 2 | Pytorch code for the paper Auto-rigging 3D Bipedal Characters in Arbitrary Poses 3 | 4 | **Jeonghwan Kim, Hyeontae Son, Jinseok Bae, Young Min Kim** 5 | 6 | ## Prerequisites 7 | 8 | #### Download Mixamo Dataset 9 | - [Mixamo Dataset](https://drive.google.com/file/d/1d6o28Mu9yNaYCIWdZ-tiDDnTECnYDHzx/view?usp=sharing) 10 | 11 | 12 | #### Install Dependencies 13 | ``` 14 | # Create conda environment 15 | conda create -n autorigging python=3.7 16 | conda activate autorigging 17 | 18 | # Install torch-related 19 | conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch 20 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html 21 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html 22 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html 23 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html 24 | pip install torch-geometric 25 | 26 | # Install other packages 27 | pip install -r requirements.txt 28 | 29 | ``` 30 | 31 | ## Usage 32 | - Tested on Ubuntu 18.04 33 | 34 | ### Training 35 | - You can train the model either by modifying the config files in `configs` directory or by using command-line arguments 36 | ``` 37 | python train_vox.py -c configs/vox.yaml --log_dir=logs/vox 38 
| python train_bvh.py -c configs/bvh.yaml --log_dir=logs/bvh 39 | python train_skin.py -c configs/skin.yaml --log_dir=logs/skin 40 | ``` 41 | 42 | - Settings used for our paper: 43 | ``` 44 | python train_vox.py -c logs/vox.yaml --log_dir=logs/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce --loss_type=ce --batch_size=4 --downsample=2 --n_stack=2 --sigma=4.5 --mean_hourglass --data_dir=data/mixamo --no-reduce_motion --save_epoch=1 45 | 46 | python train_bvh.py -c configs/bvh.yaml --log_dir=logs/bvh/bvh_all_lr1e-3_zeroroot_bn --zero_root --bn --lr 1e-3 47 | 48 | python train_skin.py -c configs/skin.yaml --workers=32 --batch_size=4 --lr=1e-4 --use_bn --log_dir=logs/skin/b4_tpl_and_euc0.12_lr1e-4_bn --euc_radius=0.12 --vis_step=10 --save_step=10 --edge_type tpl_and_euc 49 | ``` 50 | 51 | ### Testing 52 | ``` 53 | # Joint position prediction 54 | python test_vox.py -c logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce/config.yaml --model=logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce/model_epoch_030.pth 55 | 56 | # Joint rotation prediction 57 | python test_bvh.py --config logs/bvh/bvh_all_lr1e-3_zeroroot_bn/config.yaml --model logs/bvh/bvh_all_lr1e-3_zeroroot_bn/model_epoch_709.pth --joint_path logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce/test 58 | 59 | # Calculate the volumetric geodesic distance based on the predicted joint positions 60 | ## Note: the geodesic distance for the ground-truth joint positions is computed in the dataset 61 | python prepare_vol_geo.py 62 | 63 | # Skin weight prediction 64 | python test_skin.py --config logs/skin/b4_tpl_and_euc0.12_lr1e-4_bn/config.yaml --model=logs/skin/b4_tpl_and_euc0.12_lr1e-4_bn/model_epoch_030.pth --vol_geo_dir logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce/volumetric_geodesic_ours_final 65 | ``` 66 | 67 | ### Evaluation 68 | - Evaluation is done via `evaluate_rig.py` 69 | - Place the result data in the correct `dataset_dir`, and add the `--same_skeleton` flag when evaluating our model; omit the flag when evaluating RigNet-generated results (an example invocation is shown below).
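70 | 71 | A minimal example, assuming the combined predictions are stored under `result/ours` and `result/rignet` (these paths and the `--dataset_dir` argument name are illustrative assumptions; check `evaluate_rig.py` for the actual argument names). `--same_skeleton` is the flag described above: 72 | ``` 73 | # Ours: predictions share the fixed 22-joint ground-truth skeleton (paths/argument names illustrative) 74 | python evaluate_rig.py --dataset_dir result/ours --same_skeleton 75 | 76 | # RigNet baseline: predicted skeleton topology may differ, so the flag is omitted 77 | python evaluate_rig.py --dataset_dir result/rignet 78 | ```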
-------------------------------------------------------------------------------- /configs/bvh.yaml: -------------------------------------------------------------------------------- 1 | # Data Setting 2 | batch_size: 128 3 | data_dir: data/mixamo 4 | num_joints: 22 5 | workers: 8 6 | 7 | # Network Setting 8 | network: simple_bn 9 | 10 | # Training Setting 11 | vis_overfit: false 12 | overfit: false 13 | nepoch: 1000 14 | reproduce: true 15 | ## Learning Rate Setting 16 | lr: 1e-3 17 | lr_gamma: 1.0 18 | lr_step_size: 5 19 | ## Logging Setting 20 | time: true 21 | save_epoch: 1 22 | vis_epoch: 10 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/skin.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 6 2 | bindpose_loss_type: glob 3 | bm_rot_hp: 0 4 | bm_shape_hp: 0 5 | bm_trans_hp: 0 6 | channels: [64, 128, 256] 7 | data_dir: data/mixamo 8 | feature_size: 512 9 | global_feature_size: 512 10 | joint_loss_type: glob 11 | lr: 1e-4 12 | lr_gamma: 1.0 13 | lr_step_size: 5 14 | nepoch: 300 15 | num_joints: 22 16 | overfit: false 17 | k: -1 18 | reproduce: true 19 | rot_hp: 1 20 | save_step: 1 21 | skin_hp: 1 22 | trans_hp: 1e2 23 | use_bindpose: true 24 | use_bn: false 25 | use_gt_ibm: false 26 | vis_overfit: false 27 | vis_step: 1 28 | workers: 32 29 | network_type: mesh 30 | # network_type: full 31 | 32 | -------------------------------------------------------------------------------- /configs/vox.yaml: -------------------------------------------------------------------------------- 1 | # Data Setting 2 | batch_size: 3 3 | data_dir: data/mixamo 4 | num_joints: 22 5 | dim_ori: 82 6 | dim_pad: 88 7 | padding: 3 8 | 9 | workers: 8 10 | 11 | # Network Setting 12 | network: HG2 13 | downsample: 1 14 | n_stack: 1 15 | activation: none 16 | normalize_heatmap: false 17 | loss_type: ce 18 | 19 | # Training Setting 20 | vis_overfit: false 21 | overfit: false 22 | nepoch: 10000 23 | reproduce: true 24 | ## Learning Rate Setting 25 | lr: 3e-4 26 | lr_gamma: 1.0 27 | lr_step_size: 5 28 | ## Logging Setting 29 | time: true 30 | save_epoch: 1 31 | vis_epoch: 1 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | #### Dataset Folder 2 | - Extract the ```mixamo.tar.gz``` folder as ```mixamo``` 3 | -------------------------------------------------------------------------------- /datasets/mixamo_bvh_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | 8 | from scipy.spatial.transform import Rotation as R 9 | from tqdm import tqdm 10 | from bvh import Bvh 11 | 12 | sys.path.insert(0, '..') 13 | sys.path.insert(0, '.') 14 | from utils.bvh_utils import get_bvh_offsets_and_animation_load_all, get_global_bvh_offsets, get_global_bvh_rotations, get_animated_bvh_joint_positions_single 15 | 16 | class MixamoBVHDataset(data.Dataset): 17 | def __init__(self, split, configs=None, use_front_faced=False, joint_path=None): 18 | self.data_dir = data_dir = configs['data_dir'] 19 | self.configs = configs 20 | self.split = split # [train_models.txt/valid_models.txt/test_models.txt] 21 | if 'train' in split: 22 | self.mode = 'train' 23 | elif 'valid' in split: 24 | self.mode = 'eval' 25 | elif 'test' in split: 26 | self.mode 
= 'vis' 27 | self.use_front_faced = use_front_faced 28 | self.joint_path = joint_path 29 | 30 | self.characters_list = open(os.path.join(data_dir, split)).read().splitlines() 31 | # count # of objs 32 | self.frames_list = sorted([motion.split('.binvox')[0] for motion in 33 | os.listdir(os.path.join(data_dir, 'objs/{}'.format(self.characters_list[0]))) 34 | if motion.endswith('.binvox') and motion != 'bindpose.binvox']) 35 | self.motions_list = [] 36 | for frame_name in self.frames_list: 37 | motion_name = '_'.join(frame_name.split('_')[:-1]) 38 | if motion_name not in self.motions_list: 39 | self.motions_list.append(motion_name) 40 | if configs['overfit']: 41 | self.characters_list = [self.characters_list[0]] 42 | self.frames_list = [self.frames_list[0]] 43 | 44 | # reduce number of motions 45 | # Samba Dancing: ceil(274/15), Warming Up: ceil(95/10)*2, Shoved Reaction With Spin: ceil(45/5)*2, 46 | def keep_motion(motion): 47 | if not configs['reduce_motion']: 48 | return True 49 | motion_name = motion.split('_')[0] 50 | frame_idx = int(motion.split('_')[-1]) 51 | if self.mode == 'vis': 52 | return frame_idx == 0 53 | if motion_name == 'Samba Dancing': 54 | return frame_idx % 15 == 0 55 | elif motion_name == 'Warming Up': 56 | return frame_idx % 10 == 0 57 | elif motion_name == 'Shoved Reaction With Spin': 58 | return frame_idx % 5 == 0 59 | else: # Back Squat: 18, Drunk Walk Backwards: 13 60 | return True 61 | self.frames_list = [motion for motion in self.frames_list 62 | if keep_motion(motion)] 63 | self.num_characters = len(self.characters_list) 64 | self.frames_per_character = len(self.frames_list) 65 | 66 | self.n_joint = 22 67 | 68 | self.mocap_files = [[0 for _ in range(len(self.motions_list))] for _ in range(self.num_characters)] 69 | self.mocap_info = [[0 for _ in range(len(self.motions_list))] for _ in range(self.num_characters)] 70 | 71 | try: 72 | self.mocap_files = np.load(os.path.join(data_dir, 'animated/mocap_files_%s.npy'%(self.mode)), allow_pickle=True) 73 | self.mocap_info = np.load(os.path.join(data_dir, 'animated/mocap_info_%s.npy'%(self.mode)), allow_pickle=True) 74 | except: 75 | if self.joint_path is None: 76 | print("Loading %s mocap files"%(self.mode)) 77 | for character_idx, character_name in tqdm(enumerate(self.characters_list)): 78 | for motion_idx, motion_name in enumerate(self.motions_list): 79 | bvh_file_name = os.path.join(data_dir, 'animated/{}/{}.bvh'.format(character_name, motion_name)) 80 | with open(bvh_file_name) as f: 81 | mocap = Bvh(f.read()) 82 | self.mocap_files[character_idx][motion_idx] = mocap 83 | offsets_list, rotations_list, root_position_list = get_bvh_offsets_and_animation_load_all(mocap) 84 | self.mocap_info[character_idx][motion_idx] = {'offsets': offsets_list, 'rotations': rotations_list, 85 | 'root_position': root_position_list} 86 | np.save(os.path.join(data_dir, 'animated/mocap_files_%s'%(self.mode)), self.mocap_files) 87 | np.save(os.path.join(data_dir, 'animated/mocap_info_%s'%(self.mode)), self.mocap_info) 88 | print("save complete") 89 | 90 | 91 | 92 | def __len__(self): 93 | return self.num_characters * self.frames_per_character 94 | 95 | def __getitem__(self, idx): 96 | if torch.is_tensor(idx): 97 | idx = idx.tolist() 98 | 99 | if self.joint_path is not None: 100 | character_idx = idx // self.frames_per_character 101 | character_name = self.characters_list[character_idx] 102 | 103 | frame_idx = idx % self.frames_per_character 104 | frame_name = self.frames_list[frame_idx] 105 | joint_pos_file = os.path.join(self.joint_path, 
character_name, '%s_joint.npy'%(frame_name)) 106 | joint_pos = np.load(joint_pos_file, allow_pickle=True).item() 107 | pred_heatmap_joint_pos_mask = joint_pos['pred_heatmap_joint_pos_mask'] 108 | joint_positions = pred_heatmap_joint_pos_mask 109 | return joint_positions, 0, {'character_name': character_name, 'motion_name': frame_name} 110 | 111 | 112 | return self.get_item(self.data_dir, idx, mode=self.mode) 113 | 114 | # @staticmethod 115 | def get_item(self, data_dir, idx, mode='vis'): 116 | character_idx = idx // self.frames_per_character 117 | character_name = self.characters_list[character_idx] 118 | 119 | frame_idx = idx % self.frames_per_character 120 | frame_name = self.frames_list[frame_idx] 121 | 122 | # Samba Dancing_000250 -> Samba Dancing, 250 123 | motion_name, frame_number = '_'.join(frame_name.split('_')[:-1]), int(frame_name.split('_')[-1]) 124 | motion_idx = self.motions_list.index(motion_name) 125 | mocap_data = self.mocap_info[character_idx][motion_idx] 126 | offsets, rotations, root_position = mocap_data['offsets'][frame_number], mocap_data['rotations'][frame_number], mocap_data['root_position'][frame_number] 127 | 128 | if self.configs['augment_rot_std'] > 0: 129 | rotations = self.augment_rotation(rotations, self.configs['augment_rot_std']) 130 | 131 | glob_rotations = get_global_bvh_rotations(rotations) 132 | glob_offsets = get_global_bvh_offsets(offsets) 133 | joint_positions = get_animated_bvh_joint_positions_single(offsets, glob_rotations, root_position) 134 | 135 | if self.configs['jitter_pos_std'] > 0: 136 | joint_positions = self.jitter_position(joint_positions, self.configs['jitter_pos_std']) 137 | 138 | meta = { 139 | 'character_name': character_name, 'motion_name': frame_name, 'mode': mode, 140 | 'offsets': offsets, 141 | 'rotations': rotations, 142 | 'root_position': root_position 143 | } 144 | # ToDo: do some jittering for robust bvh prediction 145 | return joint_positions, glob_rotations, meta 146 | 147 | def augment_rotation(self, rotations, angle_std): 148 | rotations_list = [] 149 | for rotation in rotations: # predefined joint_names 150 | rot_mat = R.from_matrix(rotation) 151 | rot_q = rot_mat.as_euler('XYZ', degrees=True) 152 | 153 | rot_noise = np.random.normal(0, angle_std, size=(3)) # degrees 154 | rot_q_aug = rot_q + rot_noise 155 | 156 | rot_q_aug = R.from_euler('XYZ', rot_q_aug, degrees=True) # should we use 'XYZ', degrees=True? 
157 | rot_aug = rot_q_aug.as_matrix() 158 | # rot_aug = rot_aug.reshape(1, 3, 3) 159 | rotations_list.append(rot_aug) 160 | rotation_matrices = np.stack(rotations_list) 161 | return rotation_matrices 162 | 163 | def jitter_position(self, positions, pos_std): 164 | jitter_p = np.random.normal(0, pos_std, size=positions.shape) 165 | return positions+jitter_p -------------------------------------------------------------------------------- /datasets/mixamo_skin_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import listdir 3 | from os.path import join 4 | 5 | import torch 6 | from torch_geometric.data import Data, Dataset 7 | import numpy as np 8 | from tqdm import tqdm 9 | try: 10 | import open3d as o3d 11 | use_o3d=True 12 | except: 13 | print("Unable to load open3d") 14 | use_o3d=False 15 | 16 | 17 | class MixamoSkinDataset(Dataset): 18 | def __init__(self, root="data/mixamo", transform=None, pre_transform=None, data_dir="data/mixamo", split='train', vol_geo_dir=None, 19 | num_joints=22, center_yaxis=True, preprocess=False, datatype=None, configs=None, test_all=False): 20 | self.split = split 21 | self.overfit = split.endswith('overfit') 22 | self.center_yaxis=center_yaxis 23 | self.preprocess = preprocess 24 | self.configs=configs 25 | self.test_all=test_all 26 | 27 | self.datatype = datatype # "point" 28 | 29 | self.data_dir = data_dir 30 | self.obj_dir = join(data_dir, 'objs') # inference based on objs inside obj_dir 31 | # self.obj_dir = join(data_dir, 'test_objs') 32 | self.skin_dir = join(data_dir, 'weights') 33 | if vol_geo_dir is not None: 34 | self.vol_geo_dir = join(data_dir, 'volumetric_geodesic') 35 | else: 36 | self.vol_geo_dir = vol_geo_dir 37 | self.num_joints = num_joints 38 | 39 | self.characters_skinweights = {} 40 | super(MixamoSkinDataset, self).__init__(root, transform, pre_transform) 41 | 42 | @property 43 | def processed_dir(self): 44 | if self.datatype is not None: 45 | return os.path.join(self.data_dir, 'processed') + '_' + self.datatype 46 | return os.path.join(self.data_dir, 'processed') 47 | 48 | @property 49 | def raw_dir(self): 50 | return self.data_dir 51 | 52 | @property 53 | def raw_file_names(self): 54 | if self.overfit: 55 | character, motion = 'aj', 'Samba Dancing_000000' 56 | raw_objs_list = [join(character, motion+'.obj')] 57 | else: 58 | self.characters_list = characters_list = open(os.path.join(self.data_dir, self.split+'.txt')).read().splitlines() 59 | if self.split == 'test_models' and not self.test_all: 60 | objs_per_character = [motion for motion 61 | in os.listdir(os.path.join(self.obj_dir, characters_list[0])) 62 | if motion.endswith('000000.obj')] # use only the first sequence 63 | else: 64 | objs_per_character = [motion for motion 65 | in os.listdir(os.path.join(self.obj_dir, characters_list[0])) 66 | if motion.endswith('.obj') and motion != 'bindpose.obj'] 67 | raw_objs_list = [] 68 | for character in characters_list: 69 | for obj in objs_per_character: 70 | raw_objs_list.append(join(character, obj)) 71 | return raw_objs_list 72 | @property 73 | def processed_file_names(self): 74 | return [join(self.split, 'data_{}.pt'.format(i)) for i in range(self.__len__())] 75 | 76 | def __len__(self): 77 | return len(self.raw_paths) 78 | 79 | def get(self, idx): 80 | if self.preprocess: 81 | try: 82 | data = torch.load(join(self.processed_dir, self.split, 'data_{}.pt'.format(idx))) 83 | return data 84 | except: 85 | print("\n\nFailed to load preprocessed_data\n") 86 | self.preprocess 
= False 87 | data = self.process_single(idx) 88 | return data 89 | 90 | def process(self): 91 | if not self.preprocess: 92 | return 93 | if not os.path.exists(join(self.processed_dir, self.split)): 94 | os.makedirs(join(self.processed_dir, self.split)) 95 | 96 | for i, motion_obj in tqdm(enumerate(self.raw_file_names)): 97 | data = self.process_single(i) 98 | torch.save(data, join(self.processed_dir, self.split, 'data_{}.pt'.format(i))) 99 | 100 | 101 | def process_single(self, idx): 102 | i, motion_obj = idx, self.raw_file_names[idx] 103 | character_name = motion_obj.split('/')[0] 104 | motion_name = motion_obj.split('/')[1].split('.obj')[0] 105 | vol_geo_npy = character_name + '_' + motion_name + '_volumetric_geo.npy' 106 | 107 | if use_o3d: 108 | mesh = o3d.io.read_triangle_mesh(join(self.obj_dir, motion_obj)) 109 | bp_mesh = o3d.io.read_triangle_mesh(join(self.obj_dir, character_name, 'bindpose.obj')) 110 | triangles = np.asarray(mesh.triangles) 111 | 112 | vertices = torch.Tensor(np.asarray(mesh.vertices)) 113 | vertex_normals = torch.Tensor(np.asarray(mesh.vertex_normals)) 114 | bp_vertices = torch.FloatTensor(np.asarray(bp_mesh.vertices)) 115 | 116 | else: 117 | raise NotImplementedError 118 | 119 | 120 | if len(self.characters_skinweights) == 0: 121 | for cn in tqdm(self.characters_list): 122 | skinweights = torch.FloatTensor( 123 | np.genfromtxt(join(self.skin_dir, cn + '.csv'), delimiter=',', dtype='float')) 124 | self.characters_skinweights[cn] = skinweights 125 | 126 | skinweights=self.characters_skinweights[character_name] 127 | # y-axis center 128 | if self.center_yaxis: 129 | x_max, x_min = vertices[:,0].max().item(), vertices[:,0].min().item() 130 | z_max, z_min = vertices[:,2].max().item(), vertices[:,2].min().item() 131 | x_mid, z_mid = (x_max+x_min)/2, (z_max+z_min)/2 132 | vertices[:, 0] -= x_mid 133 | vertices[:, 2] -= z_mid 134 | 135 | e=[] 136 | for f in triangles: 137 | e += [[f[0], f[1]], [f[1], f[2]], [f[2], f[0]], [f[1], f[0]], [f[0], f[2]], [f[2], f[1]]] 138 | E = torch.unique(torch.LongTensor(e).T, dim=0) 139 | 140 | vol_geo = torch.Tensor(np.load(join(self.vol_geo_dir, vol_geo_npy))) 141 | 142 | mesh = Data(pos=vertices, normal=vertex_normals, volumetric_geodesic=vol_geo, edge_index=E, skin=skinweights, bindpose=bp_vertices, 143 | character_name=character_name, motion_name=motion_name) 144 | return (mesh, 0, 0, character_name, motion_name) -------------------------------------------------------------------------------- /datasets/mixamo_vox_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | sys.path.insert(0, '..') 13 | sys.path.insert(0, '.') 14 | import utils.binvox_rw as binvox_rw 15 | from utils.joint_util import maketree, bfs 16 | from utils.voxel_utils import Cartesian2Voxcoord, bin2sdf, draw_jointmap, center_vox 17 | 18 | try: 19 | import open3d as o3d 20 | use_o3d=True 21 | except: 22 | print("Unable to load open3d") 23 | use_o3d=False 24 | 25 | from utils.obj_utils import ObjLoader 26 | 27 | class MixamoVoxDataset(data.Dataset): 28 | def __init__(self, split, configs=None, reduce_motion=True, use_front_faced=False): 29 | self.data_dir = data_dir = configs['data_dir'] 30 | self.configs = configs 31 | self.split = split # [train_models.txt/valid_models.txt/test_models.txt] 32 | if 'train' in split: 33 | self.mode = 'train' 34 | elif 'valid' in 
split: 35 | self.mode = 'eval' 36 | elif 'test' in split: 37 | self.mode = 'vis' 38 | self.use_front_faced = use_front_faced 39 | 40 | self.characters_list = open(os.path.join(data_dir, split)).read().splitlines() 41 | self.motions_list = sorted([motion.split('.binvox')[0] for motion in 42 | os.listdir(os.path.join(data_dir, 'objs/{}'.format(self.characters_list[0]))) 43 | if motion.endswith('.binvox') and motion != 'bindpose.binvox']) 44 | 45 | self.sigma = 2.4 46 | if 'sigma' in configs.keys(): 47 | self.sigma = configs['sigma'] 48 | if configs['overfit']: 49 | self.characters_list = [self.characters_list[0]] 50 | self.motions_list = [self.motions_list[0]] 51 | 52 | # reduce number of motions 53 | # Samba Dancing: ceil(274/15), Warming Up: ceil(95/10)*2, Shoved Reaction With Spin: ceil(45/5)*2, 54 | def keep_motion(motion): 55 | if not reduce_motion: 56 | return True 57 | motion_name = motion.split('_')[0] 58 | frame_idx = int(motion.split('_')[-1]) 59 | if self.mode == 'vis': 60 | return frame_idx == 0 61 | if motion_name == 'Samba Dancing': 62 | return frame_idx % 15 == 0 63 | elif motion_name == 'Warming Up': 64 | return frame_idx % 10 == 0 65 | elif motion_name == 'Shoved Reaction With Spin': 66 | return frame_idx % 5 == 0 67 | else: # Back Squat: 18, Drunk Walk Backwards: 13 68 | return True 69 | self.motions_list = [motion for motion in self.motions_list 70 | if keep_motion(motion)] 71 | self.num_characters = self.characters_list.__len__() 72 | self.motions_per_character = self.motions_list.__len__() 73 | 74 | self.n_joint = 22 75 | self.r = configs['padding'] # 3 76 | self.dim_ori = configs['dim_ori'] # 82 77 | self.dim_pad = configs['dim_pad'] # 88 78 | 79 | def __len__(self): 80 | return self.num_characters * self.motions_per_character 81 | 82 | def __getitem__(self, idx): 83 | if torch.is_tensor(idx): 84 | idx = idx.tolist() 85 | 86 | character_idx = idx // self.motions_per_character 87 | motion_idx = idx % self.motions_per_character 88 | 89 | character_name = self.characters_list[character_idx] 90 | motion_name = self.motions_list[motion_idx] 91 | 92 | data_dir = self.data_dir 93 | if self.use_front_faced: 94 | vox_file = os.path.join(data_dir, 'objs_fixed/{}/{}.binvox'.format(character_name, motion_name)) 95 | joint_matrix_file = os.path.join(data_dir, 96 | 'transforms_fixed/{}/{}.csv'.format(character_name, motion_name)) 97 | else: 98 | vox_file = os.path.join(data_dir, 'objs/{}/{}.binvox'.format(character_name, motion_name)) 99 | joint_matrix_file = os.path.join(data_dir, 100 | 'transforms/{}/{}.csv'.format(character_name, motion_name)) 101 | # Read binvox 102 | r, dim_ori, dim_pad = self.r, self.dim_ori, self.dim_pad 103 | with open(vox_file, 'rb') as f: 104 | bin_vox = binvox_rw.read_as_3d_array(f) 105 | meta = {'translate': bin_vox.translate, 'scale': bin_vox.scale, 'dims': bin_vox.dims[0], 106 | 'character_name': character_name, 'motion_name': motion_name, 107 | 'mode': self.mode} 108 | bin_vox_padded = np.zeros((bin_vox.dims[0] + 2 * r, bin_vox.dims[1] + 2 * r, bin_vox.dims[2] + 2 * r), dtype= np.float16) 109 | bin_vox_padded[r:bin_vox.dims[0] + r, r:bin_vox.dims[1] + r, r:bin_vox.dims[2] + r] = bin_vox.data 110 | # put the occupied voxels at the center instead of left-top corner 111 | bin_vox_padded, center_trans = center_vox(bin_vox_padded) 112 | meta['center_trans'] = np.array(center_trans) 113 | # convert binary voxels to SDF representation 114 | sdf_vox_padded = bin2sdf(bin_vox_padded) 115 | 116 | # Relative Coordinate 117 | JM4x4 = 
np.concatenate([np.genfromtxt(joint_matrix_file, delimiter=',', dtype='float'), np.array([[0, 0, 0, 1]]*22)], 118 | axis=1).reshape(22, 4, 4) 119 | # IBM4x4 = np.concatenate([np.genfromtxt(inverse_bind_matrix_file, delimiter=',', dtype='float'), np.array([[0, 0, 0, 1]]*22)], 120 | # axis=1).reshape(22, 4, 4) 121 | # skinweights = np.genfromtxt(skin_file, delimiter=',', dtype='float') 122 | tree = maketree(22) 123 | JM4x4_glob = [None] * 22 124 | bfs(tree, JM4x4, JM4x4_glob) 125 | JM4x4_glob = np.array(JM4x4_glob) 126 | JM4x4_glob_p = JM4x4_glob[:, :3, 3] 127 | 128 | target_heatmap = np.zeros((self.n_joint, int(bin_vox.dims[0] + 2 * r), int(bin_vox.dims[1] + 2 * r), 129 | int(bin_vox.dims[2] + 2 * r)), dtype=np.float16) 130 | for joint_idx in range(self.n_joint): 131 | heatmap, pos = target_heatmap[joint_idx], JM4x4_glob_p[joint_idx] 132 | pos = Cartesian2Voxcoord(pos, bin_vox.translate, bin_vox.scale, bin_vox.dims[0]) 133 | pos = (pos[0] - center_trans[0] + r, pos[1] - center_trans[1] + r, pos[2] - center_trans[2] + r) 134 | pos = np.clip(pos, a_min=0, a_max=dim_pad - 1) 135 | draw_jointmap(heatmap, pos, sigma=self.sigma) 136 | 137 | return bin_vox_padded, sdf_vox_padded, target_heatmap, JM4x4, meta -------------------------------------------------------------------------------- /models/gcn_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing, knn_graph 3 | from torch_scatter import scatter_max, scatter_mean 4 | from torch_geometric.utils import add_self_loops, remove_self_loops, softmax 5 | from torch.nn import Sequential, Dropout, Linear, ReLU, BatchNorm1d, Parameter 6 | 7 | def MLP(channels, use_bn=True): 8 | if use_bn: 9 | return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU(), BatchNorm1d(channels[i], momentum=0.1)) 10 | for i in range(1, len(channels))]) 11 | else: 12 | return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU()) for i in range(1, len(channels))]) 13 | 14 | 15 | class EdgeConv(MessagePassing): 16 | def __init__(self, in_channels, out_channels, nn, aggr='max', **kwargs): 17 | super(EdgeConv, self).__init__(aggr=aggr, **kwargs) 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.nn = nn 21 | 22 | def forward(self, x, edge_index): 23 | x = x.unsqueeze(-1) if x.dim() == 1 else x ## ToDo: verify 24 | edge_index, _ = remove_self_loops(edge_index) 25 | edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0)) 26 | return self.propagate(edge_index, x=x) # x: [V, out_channels] 27 | 28 | def message(self, x_i, x_j): 29 | return self.nn(torch.cat([x_i, (x_j - x_i)], dim=1)) 30 | 31 | def update(self, aggr_out): 32 | aggr_out = aggr_out.view(-1, self.out_channels) 33 | return aggr_out 34 | 35 | def __repr__(self): 36 | return '{}(nn={})'.format(self.__class__.__name__, self.nn) 37 | 38 | class DynamicEdgeConv(EdgeConv): 39 | def __init__(self, in_channels, out_channels, nn, k=20, aggr='max', **kwargs): 40 | super(DynamicEdgeConv, self).__init__(in_channels, out_channels, nn, aggr, **kwargs) 41 | self.k = k 42 | 43 | def forward(self, x, batch=None): 44 | edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow) 45 | return super(DynamicEdgeConv, self).forward(x, edge_index) 46 | 47 | class GCU(torch.nn.Module): # graph convolution unit : Mesh -> feature vector 48 | def __init__(self, in_channels, out_channels, k=-1, aggr='max'): 49 | super(GCU, self).__init__() 50 | assert out_channels % 
2 == 0 51 | # Edge conv based on mesh 52 | self.k = k 53 | self.edge_conv_0 = EdgeConv(in_channels=in_channels, out_channels=out_channels//2, 54 | nn=MLP([in_channels * 2, out_channels // 2, out_channels // 2]), aggr=aggr) 55 | # Dynamic Edge conv based on volumetric euclidean distance 56 | if k > 0: 57 | self.edge_conv_1 = DynamicEdgeConv(in_channels=in_channels, out_channels=out_channels//2, 58 | nn=MLP([in_channels * 2, out_channels // 2, out_channels // 2]), k=k, aggr=aggr) 59 | else: 60 | self.edge_conv_1 = EdgeConv(in_channels=in_channels, out_channels=out_channels//2, 61 | nn=MLP([in_channels * 2, out_channels // 2, out_channels // 2]), aggr=aggr) 62 | self.mlp = MLP([out_channels, out_channels]) 63 | 64 | def forward(self, x, batch=None, tpl_edge_index=None, euc_edge_index=None, radius=0.): 65 | # assuming that at least one of (tpl_edge_index or euc_edge_index) is not None. 66 | if tpl_edge_index is not None: 67 | x0 = self.edge_conv_0(x, tpl_edge_index) # [V, out_channels//2] 68 | else: 69 | x0 = self.edge_conv_0(x, euc_edge_index) 70 | if self.k > 0: 71 | x1 = self.edge_conv_1(x, batch) # [V, out_channels//2] 72 | else: 73 | if euc_edge_index is not None: 74 | x1 = self.edge_conv_1(x, euc_edge_index) 75 | else: 76 | x1 = self.edge_conv_1(x, tpl_edge_index) 77 | # x_euc = self.edge_conv_euc(x, batch) # [V, out_channels//2] 78 | x_out = torch.cat([x0, x1], dim=1) # [V, out_channels] 79 | x_out = self.mlp(x_out) # [V, out_channels] 80 | return x_out 81 | -------------------------------------------------------------------------------- /models/mixamo_bvh_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from collections import defaultdict 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from utils.rotation_util import compute_rotation_matrix_from_ortho6d, compute_geodesic_distance_from_two_matrices, compute_L2_distance_from_two_matrices 11 | from utils.bvh_utils import compute_bindpose_from_bvh_animation, get_global_bvh_rotations_torch 12 | from utils.joint_util import transform_rel2glob, toSE3 13 | from utils.skin_util import vertex_transform, mesh_transform 14 | 15 | blue = lambda x: '\033[94m' + x + '\033[0m' 16 | 17 | from utils.loss_utils import JointsMSELoss, compute_mpjpe, compute_bone_symmetry_loss 18 | from models.networks.bvh_simple import SimpleBVHNet, SimpleBVHNet_BN 19 | 20 | 21 | class MixamoBVHModel(nn.Module): 22 | # Input: sdf Vox, Output: joint heatmap 23 | def __init__(self, configs): 24 | super(MixamoBVHModel, self).__init__() 25 | self.configs = configs 26 | self.writer = SummaryWriter 27 | self.losses = defaultdict(torch.FloatTensor) 28 | if configs['network'] == 'simple': 29 | self.network = SimpleBVHNet(in_dim=22*3, out_dim=22*6, hidden_dim=[512, 256, 256]) 30 | elif configs['network'] == 'simple_bn': 31 | self.network = SimpleBVHNet_BN(in_dim=22*3, out_dim=22*6, hidden_dim=[512, 256, 256]) 32 | else: 33 | raise NotImplementedError 34 | 35 | def preprocess_position(self, input_position): 36 | # zero root position 37 | if 'zero_root' in self.configs.keys() and self.configs['zero_root'] is True: 38 | input_position = input_position - input_position[:, :1, :] 39 | # should we normalize? 
40 | return input_position 41 | 42 | def postprocess_rotation(self, pred_rotation): 43 | if self.configs['rel_rot']: # network is outputting relative joint rotation 44 | # pred_rotation1 = transform_rel2glob(pred_rotation) 45 | pred_rotation = get_global_bvh_rotations_torch(pred_rotation) 46 | 47 | return pred_rotation 48 | 49 | 50 | def forward(self, input_position): 51 | ######## Becareful to uncomment this when zeroing input ######## 52 | input_position = self.preprocess_position(input_position) 53 | ######## Becareful to uncomment this when zeroing input ######## 54 | pred_rotation_param = self.network(input_position) # [B, J, nParam] -> [B, J, rotParam] 55 | pred_rotation = self.to_rotation_matrix(pred_rotation_param) 56 | pred_rotation = self.postprocess_rotation(pred_rotation) 57 | return pred_rotation 58 | 59 | def to_rotation_matrix(self, rotation_param): 60 | # Todo: more representations 61 | # if True or self.configs['rotation_representation'] == '6D': 62 | return compute_rotation_matrix_from_ortho6d(rotation_param) 63 | # else: 64 | # raise NotImplementedError 65 | 66 | 67 | def compute_loss(self, pred_rotation, target_rotation, average=True): 68 | # SO3, SO3 -> radians 69 | theta = self.theta = compute_geodesic_distance_from_two_matrices(pred_rotation, target_rotation) 70 | if average: 71 | return torch.mean(theta) 72 | return theta 73 | 74 | 75 | def compute_accuracy(self, input_position, pred_rotation, target_rotation, average=True): 76 | target_bindpose_position = compute_bindpose_from_bvh_animation(input_position, target_rotation) 77 | pred_bindpose_position = compute_bindpose_from_bvh_animation(input_position, pred_rotation) 78 | bindpose_mpjpe = compute_mpjpe(pred_bindpose_position, target_bindpose_position) # mean per joint position error [B,] 79 | if average: 80 | return torch.mean(bindpose_mpjpe) 81 | return bindpose_mpjpe 82 | 83 | def print_loss(self, loss, acc, epoch, i, num_batch): 84 | mode = 'Train' if self.training else 'Eval' 85 | status, end = (blue(mode) + ' {}'.format(epoch), '') if not self.training else (mode + ' {}: {}/{}'.format(epoch, i, num_batch), '') 86 | print("\r[" + status + "] ", 87 | "loss : {:.8f}, acc : {:.4f}".format(loss, acc), end=end) 88 | 89 | def write_summary(self, losses_dict, step=None, epoch=None, writer=None): 90 | mode = 'Train' if self.training else 'Eval' 91 | if step is None and epoch is not None: 92 | mode += '/epoch' 93 | step=epoch 94 | l = defaultdict(float) 95 | for key, val in losses_dict.items(): 96 | if isinstance(val, list): 97 | l[key] = val[-1] 98 | else: 99 | l[key] = val 100 | writer.add_scalar(f'{mode}/{key}', l[key], step) 101 | -------------------------------------------------------------------------------- /models/mixamo_skin_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from torch_geometric.nn import radius_graph 9 | 10 | blue = lambda x: '\033[94m' + x + '\033[0m' 11 | 12 | 13 | from models.gcn_modules import MLP, GCU 14 | from torch_scatter import scatter_max 15 | from torch.nn import Sequential, Dropout, Linear 16 | import torch_geometric 17 | 18 | class EdgeConvfeat(torch.nn.Module): 19 | def __init__(self, out_channels, channels=[64, 256, 512], global_channel=1024, input_normal=False, arch='all_feat', 20 | aggr='max', use_bn=False, graph_configs=None): 21 | 
super(EdgeConvfeat, self).__init__() 22 | self.input_normal = input_normal 23 | self.arch = arch 24 | if self.input_normal: 25 | self.input_channel = 6 + 26 # pos(3) + norm(3) + vol_geo(26) 26 | raise NotImplementedError 27 | else: 28 | self.input_channel = 3 + 26 # pos(3) + vol_geo(26) 29 | 30 | self.global_feat_size = global_channel 31 | k = graph_configs['k'] 32 | 33 | self.gcu_1 = GCU(in_channels=self.input_channel, out_channels=channels[0], k=k, aggr=aggr) 34 | self.gcu_2 = GCU(in_channels=channels[0], out_channels=channels[1], k=k, aggr=aggr) 35 | self.gcu_3 = GCU(in_channels=channels[1], out_channels=channels[2], k=k, aggr=aggr) 36 | # feature compression 37 | self.mlp_glb = MLP([(channels[0] + channels[1] + channels[2]), self.global_feat_size], use_bn=use_bn) 38 | if self.arch != 'global_feat': 39 | self.mlp_transform = Sequential(MLP([self.global_feat_size + self.input_channel + channels[0] + channels[1] + channels[2], 40 | self.global_feat_size, 256], use_bn=use_bn), 41 | Dropout(0.7), Linear(256, out_channels)) 42 | # edge index type 43 | self.edge_type = graph_configs['edge_type'] 44 | 45 | def forward(self, mesh: torch_geometric.data.Batch): 46 | if self.input_normal: 47 | x = torch.cat([mesh.pos, mesh.x, mesh.volumetric_geodesic], dim=1) 48 | raise NotImplementedError 49 | else: 50 | x = torch.cat([mesh.pos, mesh.volumetric_geodesic], dim=1) 51 | 52 | edge_index, euc_edge_index, batch = mesh.edge_index, mesh.euc_edge_index, mesh.batch 53 | 54 | if self.edge_type == 'tpl_and_euc': 55 | x_1 = self.gcu_1(x, batch, edge_index, euc_edge_index) # [V, channels[0]] 56 | x_2 = self.gcu_2(x_1, batch, edge_index, euc_edge_index) # [V, channels[1]] 57 | x_3 = self.gcu_3(x_2, batch, edge_index, euc_edge_index) # [V, channels[2]] 58 | elif self.edge_type == 'tpl_only': 59 | x_1 = self.gcu_1(x, batch, edge_index, None) # [V, channels[0]] 60 | x_2 = self.gcu_2(x_1, batch, edge_index, None) # [V, channels[1]] 61 | x_3 = self.gcu_3(x_2, batch, edge_index, None) # [V, channels[2]] 62 | elif self.edge_type == 'euc_only': 63 | x_1 = self.gcu_1(x, batch, None, euc_edge_index) # [V, channels[0]] 64 | x_2 = self.gcu_2(x_1, batch, None, euc_edge_index) # [V, channels[1]] 65 | x_3 = self.gcu_3(x_2, batch, None, euc_edge_index) # [V, channels[2]] 66 | else: 67 | raise NotImplementedError 68 | x_4 = self.mlp_glb(torch.cat([x_1, x_2, x_3], dim=1)) # [V, 1024] 69 | 70 | x_global_feat, _ = scatter_max(x_4, batch, dim=0) # [B, 1024] 71 | if self.arch == 'global_feat': 72 | return x_global_feat # [B, 1024] 73 | x_global = torch.repeat_interleave(x_global_feat, torch.bincount(batch), dim=0) # global feature to each vertex 74 | x_5 = torch.cat([x_global, x, x_1, x_2, x_3], dim=1) # [V, 1024+input_channel+sum(channels)] 75 | 76 | out = self.mlp_transform(x_5) # [V, out_channels] 77 | if self.arch == 'mesh_feat': 78 | return out # [V, out_channels] 79 | elif self.arch == 'all_feat': 80 | return out, x_global_feat # [V, out_channels], [B, 1024] 81 | return out 82 | 83 | 84 | class MixamoMeshSkinModel(nn.Module): 85 | def __init__(self, configs, num_joints=22, use_bn=False): 86 | super(MixamoMeshSkinModel, self).__init__() 87 | self.writer = SummaryWriter 88 | self.global_feature_size = configs['global_feature_size'] 89 | self.feature_size = configs['feature_size'] 90 | self.channels = configs['channels'] 91 | self.configs = configs 92 | self.input_normal = configs['use_normal'] 93 | self.edgeconv_feat = EdgeConvfeat(out_channels=self.feature_size, channels=self.channels, global_channel=self.global_feature_size, 
94 | input_normal=self.input_normal, arch='all_feat', aggr='max', use_bn=use_bn, graph_configs=configs) 95 | self.skinnet = MeshSkinNet(input_channels=self.feature_size, num_joints=num_joints, use_bn=use_bn) 96 | 97 | def forward(self, data): 98 | # mesh = data[0] 99 | mesh = data 100 | if self.configs['euc_radius'] > 0: 101 | mesh.euc_edge_index = radius_graph(mesh.pos, self.configs['euc_radius']) 102 | else: 103 | mesh.euc_edge_index = None 104 | mesh_feat, global_feat = self.edgeconv_feat(mesh) 105 | skin_logits = self.skinnet(mesh_feat) 106 | return skin_logits 107 | 108 | def calculate_loss(self, pred_skin, mesh, writer=None, step=None, summary_step=1): 109 | """ 110 | In: 111 | preds = (pred_jm_rot6d, pred_jm_trans, pred_skin) + (pred_bm_rot6d, pred_bm_trans) 112 | targets = (mesh, gt_jm, gt_ibm) 113 | Out: 114 | loss 115 | """ 116 | self.skin_loss = 0 117 | l2_loss = torch.nn.MSELoss() 118 | batch_size = mesh.batch.max().item() + 1 119 | for i in range(batch_size): 120 | single_gt_skin, single_pred_skin = mesh.skin[mesh.batch == i], pred_skin[mesh.batch == i] 121 | single_skin_loss = - torch.mean(torch.sum(single_pred_skin * single_gt_skin, axis=-1)) 122 | self.skin_loss += single_skin_loss / batch_size 123 | if writer is not None and self.training: 124 | self.write_summary(writer, step, self.skin_loss) 125 | 126 | return self.skin_loss 127 | 128 | def write_summary(self, writer, step, skin_loss=None): 129 | # summary is written at every step during training. 130 | # you have to explicitly call this function when evaluating 131 | if self.training: 132 | mode = 'train' 133 | else: 134 | mode = 'eval' 135 | writer.add_scalar(f'{mode}/skin_loss', skin_loss, step) 136 | 137 | def print_running_loss(self, epoch, i, num_batch): 138 | self.print_loss(epoch, i, num_batch, self.skin_loss, is_mean=False) 139 | 140 | def print_loss(self, epoch, i, num_batch, skin_loss=None, is_mean=False): 141 | mode = 'Train' if self.training else 'Eval' 142 | status, end = (blue(mode) + ' {}'.format(epoch), '\n') if is_mean else (mode + ' {}: {}/{}'.format(epoch, i, num_batch), '') 143 | print("\r[" + status + "] " + 144 | "skin loss: {:.4f}".format( 145 | skin_loss 146 | ), end=end) 147 | 148 | 149 | class MeshSkinNet(nn.Module): 150 | def __init__(self, input_channels, num_joints, use_bn=False): 151 | super(MeshSkinNet, self).__init__() 152 | self.num_joints = num_joints 153 | self.use_bn = use_bn 154 | self.mlp_skin = MLP([input_channels, 256, 128, num_joints], use_bn) 155 | 156 | def forward(self, mesh_feat): 157 | out = self.mlp_skin(mesh_feat) # [V, num_joints] 158 | log_prob = F.log_softmax(out.view(-1, self.num_joints), dim=-1) 159 | return log_prob 160 | -------------------------------------------------------------------------------- /models/mixamo_vox_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from collections import defaultdict 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | blue = lambda x: '\033[94m' + x + '\033[0m' 11 | 12 | from utils.loss_utils import softmax3d, normalize3d, compute_mpjpe, compute_bone_symmetry_loss 13 | from utils.voxel_utils import get_final_preds, get_max_preds, get_final_preds_torch, get_max_preds_torch, downsample_heatmap 14 | from models.networks.vox_hourglass import V2V_HG 15 | 16 | class MixamoVoxModel(nn.Module): 17 | # Input: sdf Vox, Output: 
joint heatmap 18 | def __init__(self, configs): 19 | super(MixamoVoxModel, self).__init__() 20 | self.configs = configs 21 | self.writer = SummaryWriter 22 | self.losses = defaultdict(torch.FloatTensor) 23 | self.network = V2V_HG(input_channels=1, feature_channels=configs['feature_channels'], n_stack=configs['n_stack'], n_joint=22, 24 | downsample=self.configs['downsample'], configs=configs) 25 | 26 | if self.configs['activation'] == 'sigmoid': 27 | self.activation = F.sigmoid 28 | elif self.configs['activation'] == 'softmax': 29 | self.activation = softmax3d 30 | elif self.configs['activation'] == 'none': 31 | self.activation = None 32 | else: 33 | __import__('pdb').set_trace() 34 | print('Unrecognized activation') 35 | self.activation = None 36 | 37 | def forward(self, vox): 38 | heatmap = self.network(vox) 39 | if self.activation is not None: 40 | heatmap = self.activation(heatmap) 41 | return heatmap 42 | 43 | def compute_loss(self, pred_heatmap, target_heatmap, mask=None): 44 | loss_type = self.configs['loss_type'] 45 | if self.configs['normalize_heatmap']: 46 | target_heatmap = normalize3d(target_heatmap) 47 | if self.configs['downsample'] > 1: 48 | target_heatmap = downsample_heatmap(target_heatmap, self.configs['downsample']) 49 | if loss_type == 'ce': 50 | assert self.configs['activation'] not in ['sigmoid', 'softmax'] 51 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_heatmap, target_heatmap) 52 | # loss = torch.nn.BCEWithLogitsLoss()(pred_heatmap, target_heatmap) 53 | elif loss_type == 'ce_mask': 54 | assert self.configs['activation'] not in ['sigmoid', 'softmax'] 55 | loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_heatmap, target_heatmap, 56 | weight=mask, reduction='sum') / mask.sum() / 22 # divide by n_joint 57 | elif loss_type == 'mse': 58 | loss = torch.mean((pred_heatmap - target_heatmap) ** 2) 59 | else: 60 | print("Unrecognized loss type") 61 | raise NotImplementedError 62 | 63 | if self.configs['loss_symmetry'] != 0: 64 | loss += self.configs['loss_symmetry'] * compute_bone_symmetry_loss(pred_heatmap) 65 | return loss 66 | 67 | 68 | def compute_accuracy(self, pred_heatmap, target_heatmap, JM4x4_rel, meta, average=True): 69 | """ 70 | Input: torch 71 | output: np 72 | """ 73 | # JM4x4_global = transform_rel2glob(JM4x4_rel) 74 | # target_coords = JM4x4_global[:, :, :3, 3] 75 | device = pred_heatmap.device 76 | translate, scale = torch.stack(meta['translate'], axis=1).to(device)[:, None, :], \ 77 | meta['scale'].to(device)[:, None, None] 78 | target_coords, _ = get_final_preds_torch(target_heatmap, translate, scale) 79 | pred_coords, _ = get_final_preds_torch(pred_heatmap, translate, scale * self.configs['downsample']) 80 | # if not is_mean: 81 | # a = (pred_coords - target_coords) * (pred_coords - target_coords) 82 | # return torch.einsum('ijk->i', a)/float(torch.prod(torch.Tensor(a.shape[1:]))) 83 | mpjpe = compute_mpjpe(pred_coords, target_coords) # mean per joint position error [B,] 84 | if average: 85 | return torch.mean(mpjpe) 86 | else: 87 | return mpjpe 88 | 89 | 90 | def print_loss(self, loss, acc, epoch, i, num_batch): 91 | mode = 'Train' if self.training else 'Eval' 92 | status, end = (blue(mode) + ' {}'.format(epoch), '') if not self.training else (mode + ' {}: {}/{}'.format(epoch, i, num_batch), '') 93 | print("\r[" + status + "] ", 94 | "loss : {:.8f}, acc : {:.4f}".format(loss, acc), end=end) 95 | 96 | def write_summary(self, losses_dict, step=None, epoch=None, writer=None): 97 | mode = 'Train' if self.training else 
'Eval' 98 | if step is None and epoch is not None: 99 | mode += '/epoch' 100 | step=epoch 101 | l = defaultdict(float) 102 | for key, val in losses_dict.items(): 103 | if isinstance(val, list): 104 | l[key] = val[-1] 105 | else: 106 | l[key] = val 107 | writer.add_scalar(f'{mode}/{key}', l[key], step) 108 | -------------------------------------------------------------------------------- /models/networks/bvh_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class SimpleBVHNet(nn.Module): 6 | def __init__(self, in_dim, out_dim, hidden_dim): 7 | super(SimpleBVHNet, self).__init__() 8 | self.in_dim, self.out_dim = in_dim, out_dim 9 | self.fc_list = [] 10 | self.fc_list += [nn.Linear(in_dim, hidden_dim[0])] 11 | self.fc_list += [nn.Linear(hidden_dim[i], hidden_dim[i+1]) for i in range(len(hidden_dim)-1)] 12 | self.fc_list += [nn.Linear(hidden_dim[-1], out_dim)] 13 | self.fc_list = nn.ModuleList(self.fc_list) 14 | self.activation = nn.ReLU() 15 | 16 | 17 | def forward(self, x): 18 | """ 19 | Input: input_pos [B, J, in_dim//J] 20 | Output: rotation_param [B, J, out_dim//J] 21 | """ 22 | x = x.reshape([-1, self.in_dim]) # [B, J, in_dim//J] -> [B, J] 23 | for fc in self.fc_list[:-1]: 24 | x = self.activation(fc(x)) 25 | x = self.fc_list[-1](x) 26 | return x.view(-1, 22, self.out_dim//22) # [B, out_dim] -> [B, J, out_dim//J] 27 | 28 | class SimpleBVHNet_BN(nn.Module): 29 | def __init__(self, in_dim, out_dim, hidden_dim): 30 | super(SimpleBVHNet_BN, self).__init__() 31 | self.in_dim, self.out_dim = in_dim, out_dim 32 | self.fc_list = [] 33 | self.fc_list += [nn.Linear(in_dim, hidden_dim[0])] 34 | self.fc_list += [nn.Linear(hidden_dim[i], hidden_dim[i+1]) for i in range(len(hidden_dim)-1)] 35 | self.fc_list += [nn.Linear(hidden_dim[-1], out_dim)] 36 | self.fc_list = nn.ModuleList(self.fc_list) 37 | 38 | # self.bn_list = [nn.BatchNorm1d(dim) for dim in hidden_dim[:-1]] 39 | self.bn_list = [nn.BatchNorm1d(dim) for dim in hidden_dim] 40 | self.bn_list = nn.ModuleList(self.bn_list) 41 | self.activation = nn.ReLU() 42 | 43 | 44 | def forward(self, x): 45 | """ 46 | Input: input_pos [B, J, in_dim//J] 47 | Output: rotation_param [B, J, out_dim//J] 48 | """ 49 | x = x.reshape([-1, self.in_dim]) # [B, J, in_dim//J] -> [B, J] 50 | for layer_idx, fc in enumerate(self.fc_list[:-1]): 51 | x = fc(x) 52 | x = self.activation(x) 53 | x = self.bn_list[layer_idx](x) 54 | x = self.fc_list[-1](x) 55 | return x.view(-1, 22, self.out_dim//22) # [B, out_dim] -> [B, J, out_dim//J] 56 | -------------------------------------------------------------------------------- /models/networks/vox_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class SimpleVoxNet(nn.Module): 6 | def __init__(self, input, output): 7 | super(SimpleVoxNet, self).__init__() 8 | features = [64, 64, 32, 32] 9 | self.block0 = nn.Sequential( 10 | nn.Conv3d(input, features[0], kernel_size=3,padding=1), 11 | nn.BatchNorm3d(features[0]), 12 | nn.ReLU(True)) 13 | self.block1 = nn.Sequential( 14 | nn.Conv3d(features[0], features[1], kernel_size=3,padding=1), 15 | nn.BatchNorm3d(features[1]), 16 | nn.ReLU(True)) 17 | self.block2 = nn.Sequential( 18 | nn.Conv3d(features[1], features[2], kernel_size=3,padding=1), 19 | nn.BatchNorm3d(features[2]), 20 | nn.ReLU(True)) 21 | self.block3 = nn.Sequential( 22 | 
nn.Conv3d(features[2], features[3], kernel_size=3,padding=1), 23 | nn.BatchNorm3d(features[3]), 24 | nn.ReLU(True)) 25 | self.block4 = nn.Sequential( 26 | nn.Conv3d(features[3], output, kernel_size=3,padding=1)) 27 | 28 | def forward(self, x): 29 | x = x.unsqueeze(dim=1) 30 | return self.block4(self.block3(self.block2(self.block1(self.block0(x))))) 31 | -------------------------------------------------------------------------------- /prepare_vol_geo.py: -------------------------------------------------------------------------------- 1 | ## for calculating volumetric geodesic distance with rignet format. 2 | ## Slowest Part of the pipeline 3 | 4 | import os 5 | import numpy as np 6 | import open3d as o3d 7 | import argparse 8 | 9 | from utils.tree_utils import TreeNode 10 | from utils.rig_parser import Skel, Info, get_joint_names, maketree 11 | from utils.common_ops import get_bones, calc_surface_geodesic 12 | from utils.compute_volumetric_geodesic import pts2line, calc_pts2bone_visible_mat, calc_geodesic_matrix 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--character_idx', type=int, default=-1) 17 | parser.add_argument('--start', type=int, default=0) 18 | parser.add_argument('--last', type=int, default=-1) 19 | parser.add_argument('--data_dir', type=str, default='data/mixamo') 20 | parser.add_argument('--joint_log', type=str, default='logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce') 21 | args = parser.parse_args() 22 | # pre-calculate volumetric geodesic distance with input mesh & skeleton 23 | print("Calculating volumetric-geodesic distance...") 24 | downsample_skinning = True 25 | 26 | joint_name = get_joint_names() # joint name 27 | tree = maketree(22) # hard-coded tree 28 | 29 | data_dir = args.data_dir 30 | with open(os.path.join(data_dir, 'test_models.txt')) as test_file: 31 | characters_list = test_file.read().splitlines() 32 | joint_log = args.joint_log 33 | info_list = os.listdir(os.path.join(joint_log, 'test', characters_list[0])) 34 | motions_list = sorted([info.split('_info')[0] for info in info_list if info.endswith('_info.npy')]) 35 | prediction_method = 'pred_heatmap_joint_pos_mask' 36 | 37 | for i, character in enumerate(characters_list): 38 | # if i != args.character_idx: 39 | # continue 40 | 41 | for motion in motions_list[args.start:args.last]: 42 | joint_pos_file = os.path.join(joint_log, 'test', character, '%s_joint.npy' % (motion)) 43 | joint_pos = np.load(joint_pos_file, allow_pickle=True).item() 44 | joint_result = joint_pos[prediction_method] 45 | 46 | # save skeleton 47 | pred_skel = Info() 48 | nodes = [] 49 | for joint_index, joint_pos in enumerate(joint_result): 50 | nodes.append(TreeNode(name=joint_name[joint_index], pos=joint_pos)) 51 | 52 | pred_skel.root = nodes[0] 53 | for parent, children in enumerate(tree): 54 | for child in children: 55 | nodes[parent].children.append(nodes[child]) 56 | nodes[child].parent = nodes[parent] 57 | 58 | # calculate volumetric geodesic distance 59 | bones, _, _ = get_bones(pred_skel) 60 | mesh_filename = os.path.join(data_dir, 'objs', character + '/' + motion + '.obj') 61 | # mesh_filename = os.path.join(data_dir, 'test_objs', character + '_' + motion + '.obj') 62 | mesh = o3d.io.read_triangle_mesh(mesh_filename) 63 | mesh.compute_vertex_normals() 64 | mesh_v = np.asarray(mesh.vertices) 65 | surface_geodesic = calc_surface_geodesic(mesh) 66 | volumetric_geodesic = calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=downsample_skinning) 67 | # save the 
volumetric_geodesic distance of the prediction inside the joint log folder 68 | os.makedirs(os.path.join(joint_log, 'volumetric_geodesic_ours_final'), exist_ok=True) 69 | 70 | # save volumetric geodesic distance 71 | np.save( 72 | os.path.join(joint_log, 'volumetric_geodesic_ours_final', character + "_" + motion + "_volumetric_geo.npy"), 73 | volumetric_geodesic 74 | ) 75 | -------------------------------------------------------------------------------- /preprocess/blender_utils/DataGeneration.md: -------------------------------------------------------------------------------- 1 | ## Downloading Data 2 | 3 | - We downloaded 65 rigged 3D characters from the mixamo website(mixamo.com) 4 | - Since the Structure of the riggs differ from character to character, we first unrigged all the data and re-rigged using the 25-bone based autorigging platform provided by mixamo 5 | 6 | ### Rigging Procedure 7 | - All meshes in a character are merged into a single mesh. 8 | - All bones(armatures) are removed 9 | - Saved into FBX and OBJ formats 10 | - Upload OBJ format into autorigging platform provided by mixamo.com 11 | - Rig up to 25 bones (result in 25 vertex groups or joints. Exclude 2 Toe ends and 1 Head end) -------------------------------------------------------------------------------- /preprocess/blender_utils/dae_to_bvh.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import os 3 | import math 4 | from os import listdir, mkdir 5 | from os.path import join, exists 6 | 7 | def clear_scene(): 8 | bpy.ops.object.select_all(action='SELECT') 9 | bpy.ops.object.delete() 10 | 11 | # get keyframes of object list 12 | # def get_keyframes(obj_list): 13 | # keyframes = [] 14 | # for obj in obj_list: 15 | # anim = obj.animation_data 16 | # if anim is not None and anim.action is not None: 17 | # for fcu in anim.action.fcurves: 18 | # for keyframe in fcu.keyframe_points: 19 | # x, y = keyframe.co 20 | # print(x,y) 21 | # if x not in keyframes: 22 | # keyframes.append((math.ceil(x))) 23 | # return keyframes 24 | 25 | 26 | 27 | project_path = '/home/whitealex95/Projects/autorigging' 28 | # rigged_path = join(project_path, 'mixamo/rerigged_4k') 29 | animation_path = join(project_path, 'mixamo/animated') 30 | # animation_frame_path = join(project_path, 'mixamo/objs_4k') 31 | 32 | ## Load all animations 33 | animation_models_list = sorted(listdir(animation_path)) 34 | # animation_models_list = ['aj'] # select single animation_model 35 | # animation_models_list = ['zombie'] 36 | action = ["Back Squat", "Drunk Walk Backwards", "Samba Dancing", "Shoved Reaction With Spin_1", "Shoved Reaction With Spin_2", "Warming Up_1", "Warming Up_2"] 37 | fCount = [18, 13, 274, 45, 45, 95, 95] 38 | 39 | 40 | for animation_model in animation_models_list: 41 | # consider only collada animations 42 | # sorted with lower case first 43 | animations_list = sorted([ anim for anim in listdir(join(animation_path, animation_model)) if anim.endswith('.dae')]) 44 | print(f"Animations_list of {animation_model}: ", animations_list) 45 | 46 | # if not os.path.exists(join(animation_frame_path, animation_model)): 47 | # os.mkdir(join(animation_frame_path, animation_model)) 48 | frame_counts_list= [] 49 | for animation in animations_list: 50 | # import sys 51 | # save_stdout = sys.stdout 52 | # sys.stdout = open('trash', 'w') 53 | try: 54 | clear_scene() 55 | source_path = join(animation_path, animation_model, animation) 56 | dest_path = join(animation_path, animation_model, 
animation.split('.dae')[0]+'.bvh') 57 | # Some additional operations will be done to "bones" when importing collada(.dae) files 58 | # Documentation says collada only takes care of joints? g 59 | collada_import_args = { 60 | 'import_units': True, 61 | 'fix_orientation': True, 62 | 'find_chains': True, 63 | 'auto_connect': True, 64 | 'keep_bind_info': True 65 | } 66 | bpy.ops.wm.collada_import(filepath=source_path, **collada_import_args) 67 | n_frames = fCount[action.index(animation.split('.dae')[0])] 68 | bpy.context.scene.frame_start = 0 69 | bpy.context.scene.frame_end = n_frames - 1 # start from 0 to n_frames 70 | bpy.context.scene.render.fps = 30 71 | bpy.ops.export_anim.bvh(filepath=dest_path, check_existing=False, 72 | frame_start=0, frame_end=n_frames-1, root_transform_only=True) # fix bone length 73 | except: 74 | __import__('pdb').set_trace() 75 | # print('done something?') 76 | # sys.stdout = save_stdout 77 | -------------------------------------------------------------------------------- /preprocess/blender_utils/dae_to_obj.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import os 3 | import math 4 | from os import listdir, mkdir 5 | from os.path import join, exists 6 | 7 | def clear_scene(): 8 | bpy.ops.object.select_all(action='SELECT') 9 | bpy.ops.object.delete() 10 | 11 | # get keyframes of object list 12 | # def get_keyframes(obj_list): 13 | # keyframes = [] 14 | # for obj in obj_list: 15 | # anim = obj.animation_data 16 | # if anim is not None and anim.action is not None: 17 | # for fcu in anim.action.fcurves: 18 | # for keyframe in fcu.keyframe_points: 19 | # x, y = keyframe.co 20 | # print(x,y) 21 | # if x not in keyframes: 22 | # keyframes.append((math.ceil(x))) 23 | # return keyframes 24 | 25 | 26 | 27 | project_path = '/home/whitealex95/Projects/autorigging' 28 | rigged_path = join(project_path, 'mixamo/rerigged_4k') 29 | animation_path = join(project_path, 'mixamo/animated_4k') 30 | animation_frame_path = join(project_path, 'mixamo/objs_4k') 31 | 32 | ## Load all animations 33 | animation_models_list = sorted(listdir(animation_path)) 34 | #animation_models_list = ['aj'] # select single animation_model 35 | animation_models_list = ['zombie'] 36 | action = ["Back Squat", "Drunk Walk Backwards", "Samba Dancing", "Shoved Reaction With Spin_1", "Shoved Reaction With Spin_2", "Warming Up_1", "Warming Up_2"] 37 | fCount = [18, 13, 274, 45, 45, 95, 95] 38 | 39 | 40 | for animation_model in animation_models_list: 41 | # consider only collada animations 42 | # sorted with lower case first 43 | animations_list = sorted([ anim for anim in listdir(join(animation_path, animation_model)) if anim.endswith('.dae')]) 44 | print(f"Animations_list of {animation_model}: ", animations_list) 45 | 46 | if not os.path.exists(join(animation_frame_path, animation_model)): 47 | os.mkdir(join(animation_frame_path, animation_model)) 48 | frame_counts_list= [] 49 | for animation in animations_list: 50 | try: 51 | clear_scene() 52 | source_path = join(animation_path, animation_model, animation) 53 | dest_path = join(animation_frame_path, animation_model, animation.split('.dae')[0]+'.obj') 54 | # Some additional operations will be done to "bones" when importing collada(.dae) files 55 | # Documentation says collada only takes care of joints? 
g 56 | collada_import_args = { 57 | 'import_units': True, 58 | 'fix_orientation': True, 59 | 'find_chains': True, 60 | 'auto_connect': True, 61 | 'keep_bind_info': True 62 | } 63 | bpy.ops.wm.collada_import(filepath=source_path, **collada_import_args) 64 | n_frames = fCount[action.index(animation.split('.dae')[0])] 65 | bpy.context.scene.frame_start = 0 66 | bpy.context.scene.frame_end= n_frames - 1 # start from 0 to n_frames 67 | print(n_frames) 68 | obj_export_args = { 69 | 'use_animation': True, 70 | 'use_materials': False, # Do not create .mtl files 71 | 'keep_vertex_order': True, 72 | # Default Settings starting from below 73 | 'use_blen_objects': True, # Objects as OBJ Objects 74 | 'use_mesh_modifiers': True, # Apply Modifiers 75 | 'use_normals': True, # Apply Normals 76 | 'use_uvs': True, # Include UVs 77 | } 78 | bpy.ops.export_scene.obj(filepath=dest_path, **obj_export_args) 79 | except: 80 | __import__('pdb').set_trace() 81 | # Save mesh of obj at the end of animations_list 82 | bpy.context.scene.frame_end=0 83 | obj_export_args['use_animation']=False 84 | bind_pose_path = join(animation_frame_path, animation_model, 'bindpose.obj') 85 | for obj in bpy.context.selected_objects: 86 | obj.animation_data_clear() 87 | bpy.ops.export_scene.obj(filepath=bind_pose_path, **obj_export_args) 88 | -------------------------------------------------------------------------------- /preprocess/blender_utils/normalize_and_decimate_obj.py: -------------------------------------------------------------------------------- 1 | 2 | # blender2 --background --python ~/Projects/autorigging/autorigging/blender_utils/remove_joints.py 3 | import bpy 4 | import os 5 | import sys 6 | import math 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from os import listdir, makedirs, system 11 | from os.path import exists, join 12 | 13 | def delete_rig_and_join_mesh(): 14 | armature_name = 0 15 | mesh_names = [] 16 | for idx, obj in enumerate(bpy.data.objects): 17 | if obj.type == 'ARMATURE': 18 | armature_name = obj.name 19 | print(f"{idx} is ARMATURE") 20 | elif obj.type == 'MESH': 21 | mesh_names.append(obj.name) 22 | print(f"{idx} is MESH") 23 | elif obj.type == 'EMPTY': 24 | # Delete Empty Data 25 | bpy.ops.object.select_all(action='DESELECT') 26 | obj.select_set(True) 27 | bpy.ops.object.delete() 28 | else: 29 | print(obj.type) 30 | 31 | print(idx, armature_name, mesh_names) 32 | 33 | # Merge all mesh (if needed) 34 | bpy.ops.object.select_all(action='SELECT') 35 | bpy.context.view_layer.objects.active=bpy.data.objects[mesh_names[0]] 36 | bpy.ops.object.join() 37 | 38 | # Delete Rigged Info (Including Bind Pose Matrices) 39 | bpy.ops.object.select_all(action='DESELECT') 40 | bpy.data.objects[armature_name].select_set(True) 41 | bpy.ops.object.delete() 42 | 43 | def resize_obj(): 44 | # mesh coordinate normalization 45 | obj = bpy.data.objects[0] 46 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS') 47 | obj.location.x=0 48 | obj.location.y=0 49 | obj.location.z=0 50 | maxdim=max(obj.dimensions.x, max(obj.dimensions.y, obj.dimensions.z)) 51 | sf=2/maxdim 52 | bpy.ops.transform.resize(value=(sf, sf, sf)) 53 | 54 | def resize_and_rotate_obj(): 55 | # mesh coordinate normalization 56 | obj = bpy.data.objects[0] 57 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS') 58 | obj.location.x=0 59 | obj.location.y=0 60 | obj.location.z=0 61 | maxdim=max(obj.dimensions.x, max(obj.dimensions.y, obj.dimensions.z)) 62 | # maxdim = obj.dimensions.x 63 | # if maxdim != obj.dimensions.y: 
64 | # __import__('pdb').set_trace() 65 | # print('fucki') 66 | sf= 1/ obj.dimensions.z 67 | # sf = 2/obj.dimensions.y # y is the upward direction when loaded by blender 68 | bpy.ops.transform.resize(value=(sf, sf, sf)) 69 | bpy.ops.transform.rotate(value=math.pi/2, orient_axis='X') 70 | bpy.ops.transform.translate(value=(0, 0, 0.5)) 71 | 72 | 73 | 74 | def decimate_obj(target_vnum = 8000): 75 | def cleanAllDecimateModifiers(obj): 76 | for m in obj.modifiers: 77 | if(m.type=="DECIMATE"): 78 | print("Removing modifier ") 79 | obj.modifiers.remove(modifier=m) 80 | 81 | def binarysearch(low, high, target_vnum): 82 | decimateRatio = (low+high)/2 83 | objectList=bpy.data.objects 84 | vertexCount = 0 85 | for obj in objectList: 86 | if(obj.type=="MESH"): 87 | # Decimate Start 88 | cleanAllDecimateModifiers(obj) 89 | modifier=obj.modifiers.new('DecimateMod','DECIMATE') 90 | modifier.ratio=1-decimateRatio 91 | modifier.use_collapse_triangulate=True 92 | # Decimate End, count the number of vertices 93 | dg = bpy.context.evaluated_depsgraph_get() #getting the dependency graph 94 | vertexCount += len(obj.evaluated_get(dg).to_mesh().vertices) 95 | print("decimateRatio: "+str(decimateRatio)) 96 | print("vertexCount: "+str(vertexCount)) 97 | if(vertexCount <= target_vnum): 98 | return True 99 | elif vertexCount < target_vnum: 100 | return binarysearch(low, decimateRatio, target_vnum) 101 | else: 102 | return binarysearch(decimateRatio, high, target_vnum) 103 | 104 | binarysearch(0, 1.0, target_vnum) 105 | 106 | 107 | def clear_scene(): 108 | bpy.ops.object.select_all(action='SELECT') 109 | bpy.ops.object.delete() 110 | 111 | project_path = "/home/whitealex95/Projects/autorigging" 112 | #directories = sorted([f for f in listdir(data_path) if not f.startswith(".")]) 113 | source_data_path = "mixamo/unrigged" 114 | dest_data_path = "mixamo/unrigged_normalized_decimated_4k_up" 115 | source_files = sorted([f for f in listdir(join(project_path, source_data_path)) if f.endswith(".obj")]) 116 | print(source_files) 117 | 118 | for f in source_files: 119 | source_path = join(project_path, source_data_path, f) 120 | dest_path = join(project_path, dest_data_path, f) 121 | # Delete all objects 122 | clear_scene() 123 | # Load OBJ 124 | bpy.ops.import_scene.obj(filepath=source_path) 125 | # Delete and Merge operation in blender 126 | # resize_obj() 127 | resize_and_rotate_obj() 128 | # Decimate meshes that are too big 129 | decimate_obj(target_vnum=4000) 130 | # Save FBX 131 | # bpy.ops.export_scene.fbx(filepath=dest_path) 132 | # Save OBJ 133 | bpy.ops.export_scene.obj(filepath=dest_path.split(".obj")[0]+".obj") 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /preprocess/blender_utils/remove_joints.py: -------------------------------------------------------------------------------- 1 | 2 | # blender2 --background --python ~/Projects/autorigging/autorigging/blender_utils/remove_joints.py 3 | import bpy 4 | import os 5 | import sys 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from os import listdir, makedirs, system 10 | from os.path import exists, join 11 | 12 | def delete_rig_and_join_mesh(): 13 | armature_name = 0 14 | mesh_names = [] 15 | for idx, obj in enumerate(bpy.data.objects): 16 | if obj.type == 'ARMATURE': 17 | armature_name = obj.name 18 | print(f"{idx} is ARMATURE") 19 | elif obj.type == 'MESH': 20 | mesh_names.append(obj.name) 21 | print(f"{idx} is MESH") 22 | elif obj.type == 'EMPTY': 23 | # Delete Empty Data 24 | 
bpy.ops.object.select_all(action='DESELECT') 25 | obj.select_set(True) 26 | bpy.ops.object.delete() 27 | else: 28 | print(obj.type) 29 | 30 | print(idx, armature_name, mesh_names) 31 | 32 | # Merge all mesh (if needed) 33 | bpy.ops.object.select_all(action='SELECT') 34 | bpy.context.view_layer.objects.active=bpy.data.objects[mesh_names[0]] 35 | bpy.ops.object.join() 36 | 37 | # Delete Rigged Info (Including Bind Pose Matrices) 38 | bpy.ops.object.select_all(action='DESELECT') 39 | bpy.data.objects[armature_name].select_set(True) 40 | bpy.ops.object.delete() 41 | 42 | def clear_scene(): 43 | bpy.ops.object.select_all(action='SELECT') 44 | bpy.ops.object.delete() 45 | 46 | project_path = "./Projects/autorigging" 47 | #directories = sorted([f for f in listdir(data_path) if not f.startswith(".")]) 48 | source_data_path = "mixamo/original" 49 | dest_data_path = "mixamo/unrigged" 50 | source_files = sorted([f for f in listdir(join(project_path, source_data_path)) if f.endswith(".fbx") and not f.endswith("_unrigged.fbx")]) 51 | 52 | for f in tqdm(source_files): 53 | source_path = join(project_path, source_data_path, f) 54 | dest_path = join(project_path, dest_data_path, f) 55 | 56 | # Delete all objects 57 | clear_scene() 58 | # Load FBX 59 | bpy.ops.import_scene.fbx(filepath=source_path) 60 | # Delete and Merge operation in blender 61 | delete_rig_and_join_mesh() 62 | # Save FBX 63 | bpy.ops.export_scene.fbx(filepath=dest_path) 64 | # Save OBJ 65 | bpy.ops.export_scene.obj(filepath=dest_path.split(".fbx")[0]+".obj") 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /preprocess/blender_utils/remove_vn_from_obj.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import os 3 | import math 4 | from os import listdir, mkdir 5 | from os.path import join, exists 6 | 7 | def clear_scene(): 8 | bpy.ops.object.select_all(action='SELECT') 9 | bpy.ops.object.delete() 10 | 11 | project_path = '/home/whitealex95/Projects/autorigging' 12 | original_obj_path = join(project_path, 'mixamo/unrigged_normalized_decimated_4k_up') 13 | processed_obj_path = join(project_path, 'mixamo/unrigged_normalized_decimated_4k_up_processed') 14 | 15 | ## Load all animations 16 | obj_list = sorted([ obj for obj in listdir(original_obj_path) if obj.endswith('.obj')]) 17 | 18 | for obj in obj_list: 19 | clear_scene() 20 | source_path = join(original_obj_path, obj) 21 | dest_path = join(processed_obj_path, obj) 22 | bpy.ops.import_scene.obj(filepath=source_path) 23 | obj_export_args = { 24 | 'use_animation': False, 25 | 'use_materials': False, # Do not create .mtl files 26 | 'keep_vertex_order': True, 27 | 'use_triangles': True, 28 | # Default Settings starting from below 29 | 'use_blen_objects': True, # Objects as OBJ Objects 30 | 'use_normals': False, # Apply Normals 31 | 'use_uvs': False, # Include UVs 32 | } 33 | bpy.ops.export_scene.obj(filepath=dest_path, **obj_export_args) -------------------------------------------------------------------------------- /preprocess/collada_utils/extract_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | BASE_DIR=os.path.abspath('') 5 | sys.path.append(BASE_DIR) 6 | sys.path.append(os.path.join(BASE_DIR, '/pycollada')) 7 | from os import listdir, mkdir 8 | from os.path import join, exists 9 | from collada import * 10 | 11 | project_path = '/home/whitealex95/Projects/autorigging' 12 | animation_path 
= join(project_path, 'mixamo/animated_4k') 13 | vertex_weight_path = join(project_path, 'mixamo/weights_4k') 14 | joint_cnt = 22 15 | 16 | animation_models_list = sorted(listdir(animation_path)) 17 | animation_models_list = ['zombie'] 18 | for animation_model in animation_models_list: 19 | animations_list = sorted([ anim for anim in listdir(join(animation_path, animation_model)) if anim.endswith('.dae')]) 20 | animation = animations_list[0] 21 | 22 | source_path = join(animation_path, animation_model, animation) 23 | dest_path = join(vertex_weight_path, animation_model+'.csv') 24 | 25 | mesh = Collada(source_path) 26 | vertex_cnt = mesh.controllers[0].vcounts.__len__() 27 | output = np.zeros((vertex_cnt, joint_cnt)) 28 | 29 | joint_dictionary = {} 30 | for joint in range(joint_cnt - 2): 31 | jointname = mesh.animations[joint].id.split('_')[1][:-5] 32 | joint_dictionary[jointname]=joint # assign joint number for each joint name 33 | joint_dictionary['LeftHand']=20 34 | joint_dictionary['RightHand']=21 35 | # joint_dictionary['mixamorig_LeftHand']=20 36 | # joint_dictionary['mixamorig_RightHand']=21 37 | 38 | for idx in range(vertex_cnt): 39 | count = 0 40 | for jidx in mesh.controllers[0].joint_index[idx]: 41 | jointname = mesh.controllers[0].weight_joints[jidx].split('_')[1] 42 | correct_jidx = joint_dictionary[jointname] 43 | output[idx][correct_jidx] = mesh.controllers[0].weights[mesh.controllers[0].weight_index[idx][count]] 44 | count = count + 1 45 | print(dest_path) 46 | np.savetxt(dest_path, output, delimiter=',') 47 | -------------------------------------------------------------------------------- /preprocess/collada_utils/transform_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | BASE_DIR=os.path.abspath(__file__) 5 | sys.path.append(BASE_DIR) 6 | sys.path.append(os.path.join(BASE_DIR, '/pycollada')) 7 | from os import listdir 8 | from os.path import join 9 | from collada import * 10 | 11 | project_path = '/home/whitealex95/Projects/autorigging' 12 | animation_path = join(project_path, 'mixamo/animated_4k') 13 | transforms_path = join(project_path, 'mixamo/transforms_4k') 14 | 15 | 16 | indices = [*range(1,10)] 17 | #wrongIndices = [1, 3, 7, 8, 14] 18 | #for x in wrongIndices: 19 | # indices.remove(x) 20 | 21 | joint_num = 22 22 | action = ["Back Squat", "Drunk Walk Backwards", "Samba Dancing", "Shoved Reaction With Spin_1", "Shoved Reaction With Spin_2", "Warming Up_1", "Warming Up_2"] 23 | # fCount = [15, 11, 220, 37, 37, 77, 77] 24 | fCount = [18, 13, 274, 45, 45, 95, 95] 25 | end_string = 'Matrix-animation-output-transform' 26 | 27 | animation_models_list = sorted(listdir(animation_path)) 28 | for animation_model in animation_models_list: 29 | animations_list = sorted([ anim for anim in listdir(join(animation_path, animation_model)) if anim.endswith('.dae')]) 30 | 31 | if not os.path.exists(join(transforms_path, animation_model)): 32 | os.makedirs(join(transforms_path, animation_model)) 33 | 34 | for animation in animations_list: 35 | source_path = join(animation_path, animation_model, animation) 36 | dest_path = lambda frame: join(transforms_path, animation_model, animation.split('.dae')[0] + '_' + str(frame).zfill(6)+'.csv') 37 | dest_ibm_path = join(transforms_path, animation_model, 'ibm.csv') 38 | 39 | mesh = Collada(source_path) 40 | if mesh.controllers[0].max_joint_index != joint_num - 1: 41 | # print("number of joint(shape index: "+ str(modIdx) + f") is not {joint_num}.") 
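# drop into the debugger and skip any animation whose controller does not reference exactly 22 joints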
42 | __import__('pdb').set_trace() 43 | continue 44 | 45 | for frame in range(fCount[action.index(animation.split('.dae')[0])]): 46 | output = np.zeros((joint_num, 12)) 47 | ibm = np.zeros((joint_num, 12)) 48 | for joint in range(joint_num): 49 | if 0<= joint and joint<=19: 50 | animIdx = mesh.animations[joint].id[:-4] + end_string 51 | transform_matrices = mesh.animations[joint].sourceById[animIdx].data 52 | transform_matrices = transform_matrices.reshape(-1, 16) 53 | output[joint] = transform_matrices[frame][:12] 54 | ibm[joint] = mesh.controllers[0].joint_matrices[animIdx.split('-')[0]].reshape(16,)[:12] 55 | elif joint==20: # mixamorig_LeftHand 56 | try: 57 | parentIBM = mesh.controllers[0].joint_matrices['mixamorig_LeftForeArm'] 58 | myIBM = mesh.controllers[0].joint_matrices['mixamorig_LeftHand'] 59 | except: 60 | parentIBM = mesh.controllers[0].joint_matrices['boss_LeftForeArm'] 61 | myIBM = mesh.controllers[0].joint_matrices['boss_LeftHand'] 62 | output[joint] = np.matmul(parentIBM, np.linalg.inv(myIBM)).reshape(16,)[:12] 63 | ibm[joint] = myIBM.reshape(16,)[:12] 64 | elif joint==21: # mixamorig_RightHand 65 | try: 66 | parentIBM = mesh.controllers[0].joint_matrices['mixamorig_RightForeArm'] 67 | myIBM = mesh.controllers[0].joint_matrices['mixamorig_RightHand'] 68 | except: 69 | parentIBM = mesh.controllers[0].joint_matrices['boss_RightForeArm'] 70 | myIBM = mesh.controllers[0].joint_matrices['boss_RightHand'] 71 | output[joint] = np.matmul(parentIBM, np.linalg.inv(myIBM)).reshape(16,)[:12] 72 | ibm[joint] = myIBM.reshape(16,)[:12] 73 | print("[Saving JM] ", dest_path(frame), end='\r') 74 | np.savetxt(dest_path(frame), output, delimiter=',') 75 | print('[Saving IBM] ', dest_ibm_path) 76 | np.savetxt(dest_ibm_path, ibm, delimiter=',') 77 | 78 | def maketree(num_joints=22): 79 | # we assume that , 80 | # INDEX 20 -> mixamorig_LeftHand 81 | # INDEX 21 -> mixamorig_RightHan 82 | child = [[] for x in range(num_joints)] 83 | child[0] = [1, 12, 16] 84 | child[1] = [2] 85 | child[2] = [3] 86 | child[3] = [4, 6, 9] 87 | child[4] = [5] 88 | #child[5] 89 | child[6] = [7] 90 | child[7] = [8] 91 | child[8] = [20] ##### ATTENTION 92 | child[9] = [10] 93 | child[10] = [11] 94 | child[11] = [21] ##### ATTENTION 95 | child[12]=[13] 96 | child[13]=[14] 97 | child[14]=[15] 98 | #child[15] 99 | child[16]=[17] 100 | child[17]=[18] 101 | child[18]=[19] 102 | #child[19] 103 | 104 | return child 105 | -------------------------------------------------------------------------------- /preprocess/volume/binvox: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whitealex95/autorigging-bipedal/2851d5894ecdd90c14825dcf7398ab0d3d41433e/preprocess/volume/binvox -------------------------------------------------------------------------------- /preprocess/volume/flood-fill/binvox_rw.py: -------------------------------------------------------------------------------- 1 | ../util/binvox_rw.py -------------------------------------------------------------------------------- /preprocess/volume/flood-fill/main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "60000" 12 | ] 13 | }, 14 | "execution_count": 6, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import sys\n", 21 | "sys.setrecursionlimit(60000) # 3000\n", 
22 | "sys.getrecursionlimit()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 7, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "number of data 0\n" 35 | ] 36 | }, 37 | { 38 | "ename": "IndexError", 39 | "evalue": "list index out of range", 40 | "output_type": "error", 41 | "traceback": [ 42 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 43 | "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", 44 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# for i, path in enumerate(paths):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpaths\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_binvox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 45 | "\u001b[0;31mIndexError\u001b[0m: list index out of range" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "import numpy as np\n", 51 | "import glob\n", 52 | "import os\n", 53 | "import util\n", 54 | "import time\n", 55 | "\n", 56 | "# paths = glob.glob(\"path/to/*.binvox\")\n", 57 | "paths = glob.glob(\"../data/aj.binvox\")\n", 58 | "print(\"number of data\", len(paths))\n", 59 | "\n", 60 | "# for i, path in enumerate(paths):\n", 61 | "path = paths[0]\n", 62 | "\n", 63 | "data = util.read_binvox(path)\n", 64 | "invdata = np.ones(data.shape, dtype=np.bool) & ~data\n", 65 | "start = util.start_index(invdata)\n", 66 | "\n", 67 | "xdim, ydim, zdim = data.shape\n", 68 | "fill_value = False\n", 69 | "while start is not None:\n", 70 | " old_value = invdata[start[0], start[1], start[2]]\n", 71 | " stack = set([(start[0], start[1], start[2])])\n", 72 | "\n", 73 | " if fill_value == old_value:\n", 74 | " raise ValueError(\"Filling region with same value\")\n", 75 | "\n", 76 | " while stack:\n", 77 | " x, y, z = stack.pop()\n", 78 | " if invdata[x, y, z] == old_value:\n", 79 | " invdata[x, y, z] = fill_value\n", 80 | " if x > 0:\n", 81 | " stack.add((x - 1, y, z))\n", 82 | " if x < (xdim - 1):\n", 83 | " stack.add((x + 1, y, z))\n", 84 | " if y > 0:\n", 85 | " stack.add((x, y - 1, z))\n", 86 | " if y < (ydim - 1):\n", 87 | " stack.add((x, y + 1, z))\n", 88 | " if z > 0:\n", 89 | " stack.add((x, y, z - 1))\n", 90 | " if z < (zdim - 1):\n", 91 | " stack.add((x, y, z + 1))\n", 92 | "\n", 93 | " start = util.start_index(invdata)\n", 94 | "\n", 95 | "data += invdata\n", 96 | "util.save_binvox(path.split('.binvox')[0] + '_ff.binvox' , data)\n", 97 | "print(0, path)\n", 98 | "import binvox_rw\n", 99 | "\n", 100 | "with open('../data/aj_ff.binvox', 'rb') as f:\n", 101 | " data = binvox_rw.read_as_3d_array(f)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 8, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "ename": "NameError", 111 | "evalue": "name 'data' is not defined", 112 | "output_type": "error", 113 | "traceback": [ 114 | 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 115 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 116 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 117 | "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "data.data" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 9, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "ename": "FileNotFoundError", 132 | "evalue": "[Errno 2] No such file or directory: '../data/aj.binvox'", 133 | "output_type": "error", 134 | "traceback": [ 135 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 136 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 137 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../data/aj.binvox'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbinvox_rw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_as_3d_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'tmp.binvox'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'tmp.binvox'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 138 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../data/aj.binvox'" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "with open('../data/aj.binvox', 'rb') as f:\n", 144 | " data = binvox_rw.read_as_3d_array(f)\n", 145 | "with open('tmp.binvox', 'wb') as f:\n", 146 | " data.write(f)\n", 147 | "with open('tmp.binvox', 'rb') as f2:\n", 148 | " data2 = binvox_rw.read_as_3d_array(f2)\n", 149 | "print(data.data, data2.data)\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "autorigging", 170 | "language": "python", 171 | "name": "autorigging" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | 
"file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.7.7" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 4 188 | } 189 | -------------------------------------------------------------------------------- /preprocess/volume/flood-fill/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import util 5 | import time 6 | from os import listdir, mkdir 7 | from os.path import join, exists 8 | from tqdm import tqdm 9 | 10 | # paths = glob.glob("path/to/*.binvox") 11 | # paths = glob.glob("../data/aj.binvox") 12 | # print("number of data", len(paths)) 13 | 14 | project_path = '/home/whitealex95/Projects/autorigging' 15 | obj_path = join(project_path, 'mixamo_4k/objs') 16 | character_list = sorted(listdir(obj_path)) 17 | 18 | for character in tqdm(character_list): 19 | obj_list = sorted([obj for obj in listdir(os.path.join(obj_path, character)) if obj.endswith('.obj')]) 20 | for obj in obj_list: 21 | path = join(obj_path, character, obj.split('.obj')[0] + '.binvox') 22 | 23 | data = util.read_binvox(path) 24 | invdata = np.ones(data.shape, dtype=np.bool) & ~data 25 | start = util.start_index(invdata) 26 | 27 | xdim, ydim, zdim = data.shape 28 | fill_value = False 29 | while start is not None: 30 | old_value = invdata[start[0], start[1], start[2]] 31 | stack = set([(start[0], start[1], start[2])]) 32 | 33 | if fill_value == old_value: 34 | raise ValueError("Filling region with same value") 35 | 36 | while stack: 37 | x, y, z = stack.pop() 38 | if invdata[x, y, z] == old_value: 39 | invdata[x, y, z] = fill_value 40 | if x > 0: 41 | stack.add((x - 1, y, z)) 42 | if x < (xdim - 1): 43 | stack.add((x + 1, y, z)) 44 | if y > 0: 45 | stack.add((x, y - 1, z)) 46 | if y < (ydim - 1): 47 | stack.add((x, y + 1, z)) 48 | if z > 0: 49 | stack.add((x, y, z - 1)) 50 | if z < (zdim - 1): 51 | stack.add((x, y, z + 1)) 52 | 53 | start = util.start_index(invdata) 54 | 55 | data += invdata 56 | util.save_binvox(path, data) # overwrite previous binvox 57 | # util.save_binvox(path.split('.binvox')[0] + '_ff.binvox' , data) -------------------------------------------------------------------------------- /preprocess/volume/flood-fill/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import binvox_rw as binvox 3 | import os 4 | 5 | def read_binvox(filename): 6 | with open(filename, 'rb') as f: 7 | model = binvox.read_as_3d_array(f) 8 | return model.data 9 | 10 | def save_binvox(filename, data): 11 | dims = data.shape 12 | translate = [0.0, 0.0, 0.0] 13 | model = binvox.Voxels(data, dims, translate, 1.0, 'xyz') 14 | with open(filename, 'wb') as f: 15 | model.write(f) 16 | 17 | def extract_name(filename): 18 | head, tail = os.path.split(filename) 19 | name, ext = os.path.splitext(tail) 20 | return name 21 | 22 | def start_index(invdata): 23 | indices = np.transpose(np.nonzero(invdata[0, :, :])) 24 | if indices.shape[0] > 0: 25 | y, z = indices[0] 26 | return [0, y, z] 27 | 28 | indices = np.transpose(np.nonzero(invdata[:, 0, :])) 29 | if indices.shape[0] > 0: 30 | x, z = indices[0] 31 | return [x, 0, z] 32 | 33 | indices = np.transpose(np.nonzero(invdata[:, :, 0])) 34 | if indices.shape[0] > 0: 35 | x, y = indices[0] 36 | return [x, y, 0] 37 | 38 | indices = np.transpose(np.nonzero(invdata[-1, :, :])) 39 | if indices.shape[0] > 0: 
40 | y, z = indices[0] 41 | return [-1, y, z] 42 | 43 | indices = np.transpose(np.nonzero(invdata[:, -1, :])) 44 | if indices.shape[0] > 0: 45 | x, z = indices[0] 46 | return [x, -1, z] 47 | 48 | indices = np.transpose(np.nonzero(invdata[:, :, -1])) 49 | if indices.shape[0] > 0: 50 | x, y = indices[0] 51 | return [x, y, -1] 52 | 53 | return None 54 | -------------------------------------------------------------------------------- /preprocess/volume/obj_rot_fixed_to_binvox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from os import listdir, mkdir 4 | from os.path import join, exists 5 | from tqdm import tqdm 6 | import subprocess 7 | 8 | data_dir = '/home/whitealex95/Projects/autorigging/autorigging/data/mixamo' 9 | # characters_list = sorted(os.listdir(os.path.join(data_dir, 'animated'))) 10 | characters_list = ['Ch24_nonPBR', 'Ch29_nonPBR', 'kaya', 11 | 'maria_j_j_ong', 'paladin_j_nordstrom', 'pumpkinhulk_l_shaw'] 12 | characters_list = ['Ch14_nonPBR'] 13 | print(characters_list) 14 | # motions_list = sorted([f.split('.obj')[0] for f in os.listdir(os.path.join(data_dir, 'objs', characters_list[0])) if 15 | # f.endswith('.obj') and f != 'bindpose.obj']) 16 | motions_list = ['Back Squat', 'Drunk Walk Backwards', 'Samba Dancing', 'Shoved Reaction With Spin_1', 17 | 'Shoved Reaction With Spin_2', 'Warming Up_1', 'Warming Up_2'] 18 | # #TestKim 19 | # motions_list = ['Back Squat'] 20 | # motions_list = ['\ '.join(motion.split(' ')) for motion in motions_list] 21 | print(motions_list) 22 | 23 | frame_counts = [18, 13, 274, 45, 24 | 45, 95, 95] 25 | 26 | 27 | for character in characters_list: 28 | for motion_idx, motion in enumerate(motions_list): 29 | for frame_number in range(frame_counts[motion_idx]): 30 | frame_name = '%s_%06d'%(motion, frame_number) 31 | 32 | obj_filename = data_dir + '/objs_fixed/'+character+'/%s.obj'%(frame_name) 33 | 34 | try: 35 | # command = "xvfb-run -s '-screen 0 640x480x24' ./binvox -d 82 -pb '{}'".format(source_path) 36 | command = "./binvox -d 82 -pb '{}'".format(obj_filename) 37 | print(command) 38 | os.system(command) 39 | # subprocess.Popen(command) 40 | except: 41 | __import__('pdb').set_trace() 42 | -------------------------------------------------------------------------------- /preprocess/volume/obj_to_binvox2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from os import listdir, mkdir 4 | from os.path import join, exists 5 | from tqdm import tqdm 6 | 7 | project_path = '/home/whitealex95/Projects/autorigging' 8 | obj_path = join(project_path, 'mixamo/objs') 9 | character_list = sorted(listdir(obj_path)) 10 | 11 | 12 | print(character_list) 13 | __import__('pdb').set_trace() 14 | # character_list = [character for character in character_list if not character.startswith('Ch')] 15 | character_list = ['Ch15_nonPBR'] 16 | # character_list = ['Ch34_nonPBR', 'Ch35_nonPBR', 'Ch36_nonPBR', 'Ch39_nonPBR', 'Ch40_nonPBR', 'Ch42_nonPBR', 'Ch44_nonPBR', 'Ch45_nonPBR', 'Ch46_nonPBR'] 17 | for character in tqdm(character_list): 18 | obj_list = sorted([obj for obj in listdir(os.path.join(obj_path, character)) if obj.endswith('.obj')]) 19 | for obj in obj_list: 20 | source_path = join(obj_path, character, obj) 21 | if not os.path.exists(source_path.split('.obj')[0]+'.binvox'): 22 | try: 23 | source_path = join(obj_path, character, obj) 24 | # command = "xvfb-run -s '-screen 0 640x480x24' ./binvox -d 82 -pb '{}'".format(source_path) 
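# voxelize the OBJ into an 82^3 binvox grid; the commented xvfb-run variant above is for headless machines without a display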
25 | command = "./binvox -d 82 -pb '{}'".format(source_path) 26 | print(command) 27 | os.system(command) 28 | except: 29 | __import__('pdb').set_trace() 30 | -------------------------------------------------------------------------------- /preprocess/volume/util/binvox_rw.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | 18 | 19 | import numpy as np 20 | import struct 21 | 22 | 23 | class Voxels(object): 24 | """ Holds a binvox model. 25 | data is either a three-dimensional numpy boolean array (dense representation) 26 | or a two-dimensional numpy float array (coordinate representation). 27 | 28 | dims, translate and scale are the model metadata. 29 | 30 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 31 | 32 | scale and translate relate the voxels to the original model coordinates. 33 | 34 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 35 | 36 | x_n = (i+.5)/dims[0] 37 | y_n = (j+.5)/dims[1] 38 | z_n = (k+.5)/dims[2] 39 | x = scale*x_n + translate[0] 40 | y = scale*y_n + translate[1] 41 | z = scale*z_n + translate[2] 42 | 43 | """ 44 | 45 | def __init__(self, data, dims, translate, scale, axis_order): 46 | self.data = data 47 | self.dims = dims 48 | self.translate = translate 49 | self.scale = scale 50 | assert (axis_order in ('xzy', 'xyz')) 51 | self.axis_order = axis_order 52 | 53 | def clone(self): 54 | data = self.data.copy() 55 | dims = self.dims[:] 56 | translate = self.translate[:] 57 | return Voxels(data, dims, translate, self.scale, self.axis_order) 58 | 59 | def write(self, fp): 60 | write(self, fp) 61 | 62 | def read_header(fp): 63 | """ Read binvox header. Mostly meant for internal use. 64 | """ 65 | line = fp.readline().strip() 66 | if not line.startswith(b'#binvox'): 67 | raise IOError('Not a binvox file') 68 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 69 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 70 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 71 | line = fp.readline() 72 | 73 | return dims, translate, scale 74 | 75 | def read_as_3d_array(fp, fix_coords=True): 76 | """ Read binary binvox format as array. 77 | 78 | Returns the model with accompanying metadata. 79 | 80 | Voxels are stored in a three-dimensional numpy array, which is simple and 81 | direct, but may use a lot of memory for large models. (Storage requirements 82 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 83 | boolean arrays use a byte per element). 84 | 85 | Doesn't do any checks on input except for the '#binvox' line. 
86 | """ 87 | dims, translate, scale = read_header(fp) 88 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 89 | # if just using reshape() on the raw data: 90 | # indexing the array as array[i,j,k], the indices map into the 91 | # coords as: 92 | # i -> x 93 | # j -> z 94 | # k -> y 95 | # if fix_coords is true, then data is rearranged so that 96 | # mapping is 97 | # i -> x 98 | # j -> y 99 | # k -> z 100 | values, counts = raw_data[::2], raw_data[1::2] 101 | data = np.repeat(values, counts).astype(np.bool) 102 | data = data.reshape(dims) 103 | if fix_coords: 104 | # xzy to xyz TODO the right thing 105 | data = np.transpose(data, (0, 2, 1)) 106 | axis_order = 'xyz' 107 | else: 108 | axis_order = 'xzy' 109 | return Voxels(data, dims, translate, scale, axis_order) 110 | 111 | def read_as_coord_array(fp, fix_coords=True): 112 | """ Read binary binvox format as coordinates. 113 | 114 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 115 | 3 x N array where N is the number of nonzero voxels. Each column 116 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 117 | of the voxel. (The odd ordering is due to the way binvox format lays out 118 | data). Note that coordinates refer to the binvox voxels, without any 119 | scaling or translation. 120 | 121 | Use this to save memory if your model is very sparse (mostly empty). 122 | 123 | Doesn't do any checks on input except for the '#binvox' line. 124 | """ 125 | dims, translate, scale = read_header(fp) 126 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 127 | 128 | values, counts = raw_data[::2], raw_data[1::2] 129 | 130 | sz = np.prod(dims) 131 | index, end_index = 0, 0 132 | end_indices = np.cumsum(counts) 133 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 134 | 135 | values = values.astype(np.bool) 136 | indices = indices[values] 137 | end_indices = end_indices[values] 138 | 139 | nz_voxels = [] 140 | for index, end_index in zip(indices, end_indices): 141 | nz_voxels.extend(range(index, end_index)) 142 | nz_voxels = np.array(nz_voxels) 143 | # TODO are these dims correct? 144 | # according to docs, 145 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 146 | 147 | x = nz_voxels / (dims[0]*dims[1]) 148 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 149 | z = zwpy / dims[0] 150 | y = zwpy % dims[0] 151 | if fix_coords: 152 | data = np.vstack((x, y, z)) 153 | axis_order = 'xyz' 154 | else: 155 | data = np.vstack((x, z, y)) 156 | axis_order = 'xzy' 157 | 158 | #return Voxels(data, dims, translate, scale, axis_order) 159 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 160 | 161 | def dense_to_sparse(voxel_data, dtype=np.int): 162 | """ From dense representation to sparse (coordinate) representation. 163 | No coordinate reordering. 
164 | """ 165 | if voxel_data.ndim!=3: 166 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 167 | return np.asarray(np.nonzero(voxel_data), dtype) 168 | 169 | def sparse_to_dense(voxel_data, dims, dtype=np.bool): 170 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 171 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 172 | if np.isscalar(dims): 173 | dims = [dims]*3 174 | dims = np.atleast_2d(dims).T 175 | # truncate to integers 176 | xyz = voxel_data.astype(np.int) 177 | # discard voxels that fall outside dims 178 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 179 | xyz = xyz[:,valid_ix] 180 | out = np.zeros(dims.flatten(), dtype=dtype) 181 | out[tuple(xyz)] = True 182 | return out 183 | 184 | #def get_linear_index(x, y, z, dims): 185 | #""" Assuming xzy order. (y increasing fastest. 186 | #TODO ensure this is right when dims are not all same 187 | #""" 188 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 189 | 190 | def bwrite(fp,s): 191 | fp.write(s.encode()) 192 | 193 | def write_pair(fp,state, ctr): 194 | fp.write(struct.pack('B',state)) 195 | fp.write(struct.pack('B',ctr)) 196 | 197 | def write(voxel_model, fp): 198 | """ Write binary binvox format. 199 | 200 | Note that when saving a model in sparse (coordinate) format, it is first 201 | converted to dense format. 202 | 203 | Doesn't check if the model is 'sane'. 204 | 205 | """ 206 | if voxel_model.data.ndim==2: 207 | # TODO avoid conversion to dense 208 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 209 | else: 210 | dense_voxel_data = voxel_model.data 211 | 212 | bwrite(fp, '#binvox 1\n') 213 | bwrite(fp, 'dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') 214 | bwrite(fp, 'translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') 215 | bwrite(fp, 'scale ' + str(voxel_model.scale) + '\n') 216 | bwrite(fp, 'data\n') 217 | if not voxel_model.axis_order in ('xzy', 'xyz'): 218 | raise ValueError('Unsupported voxel model axis order') 219 | 220 | if voxel_model.axis_order=='xzy': 221 | voxels_flat = dense_voxel_data.flatten() 222 | elif voxel_model.axis_order=='xyz': 223 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 224 | 225 | # keep a sort of state machine for writing run length encoding 226 | state = voxels_flat[0] 227 | ctr = 0 228 | for c in voxels_flat: 229 | if c==state: 230 | ctr += 1 231 | # if ctr hits max, dump 232 | if ctr==255: 233 | write_pair(fp, state, ctr) 234 | ctr = 0 235 | else: 236 | # if switch state, dump 237 | write_pair(fp, state, ctr) 238 | state = c 239 | ctr = 1 240 | # flush out remainders 241 | if ctr > 0: 242 | write_pair(fp, state, ctr) 243 | 244 | if __name__ == '__main__': 245 | import doctest 246 | doctest.testmod() 247 | -------------------------------------------------------------------------------- /preprocess/volume/util/rigging_parser/obj_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | My simple obj file parser. 
3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | class Mesh_obj: 9 | def __init__(self, filename=None): 10 | self.v = [] 11 | self.vt = [] 12 | self.vn = [] 13 | self.f = [] 14 | self.hasTex = False 15 | self.hasNorm = False 16 | self.mtlfile = None 17 | self.materialList = [] 18 | if filename is not None: 19 | self.load(filename) 20 | 21 | def load(self, obj_filename): 22 | obj_file = open(obj_filename, 'r') 23 | line = obj_file.readline() 24 | while line: 25 | if len(line.split()) > 1 and line.split()[0] == 'v': 26 | self.v.append([float(line.split()[1]), float(line.split()[2]), float(line.split()[3])]) 27 | elif len(line.split()) > 1 and line.split()[0] == 'vt': 28 | self.vt.append([float(line.split()[1]), float(line.split()[2])]) 29 | elif len(line.split()) > 1 and line.split()[0] == 'vn': 30 | self.vn.append([float(line.split()[1]), float(line.split()[2]), float(line.split()[3])]) 31 | elif len(line.split()) > 1 and line.split()[0] == 'f': 32 | if '//' in line.split()[1] and len(line.split()[1].split('//'))==2: 33 | self.hasNorm = True 34 | cur_face = [] 35 | for ver in line.split()[1:]: 36 | cur_face.append(int(ver.split('//')[0])) 37 | self.f.append(cur_face) 38 | elif len(line.split()[1].split('/')) ==2: 39 | self.hasTex = True 40 | cur_face = [] 41 | for ver in line.split()[1:]: 42 | cur_face.append(list(map(int, ver.split('/')))) 43 | self.f.append(cur_face) 44 | elif len(line.split()[1].split('/')) ==3: 45 | self.hasTex = True 46 | self.hasNorm = True 47 | cur_face = [] 48 | for ver in line.split()[1:]: 49 | cur_face.append(ver.split('/')) 50 | self.f.append(cur_face) 51 | else: 52 | cur_face = [] 53 | for ver in line.split()[1:]: 54 | cur_face.append(int(ver)) 55 | self.f.append(cur_face) 56 | elif 'mtllib ' in line: 57 | self.mtlfile = line.split()[1] 58 | line = obj_file.readline() 59 | obj_file.close() 60 | # print len(self.v) 61 | self.v = np.stack(self.v, axis=0) 62 | if self.f: 63 | self.f = np.stack(self.f, axis=0) 64 | if self.vt: 65 | self.vt = np.stack(self.vt, axis=0) 66 | if self.vn: 67 | self.vn = np.stack(self.vn, axis=0) 68 | 69 | def write(self, obj_filename): 70 | f_out = open(obj_filename, 'w') 71 | f_out.write('#Export Obj file with Mesh_obj\n') 72 | #if self.mtlfile != '': 73 | # f_out.write('mtllib '+ self.mtlfile + '\n') 74 | for i in range(self.v.shape[0]): 75 | f_out.write('v {0} {1} {2}\n'.format(self.v[i,0],self.v[i,1],self.v[i,2])) 76 | if self.hasTex: 77 | for i in range(self.vt.shape[0]): 78 | f_out.write('vt {0} {1}\n'.format(self.vt[i, 0], self.vt[i, 1])) 79 | if self.hasNorm: 80 | for i in range(self.vn.shape[0]): 81 | f_out.write('vn {0} {1} {2}\n'.format(self.vn[i, 0], self.vn[i, 1], self.vn[i, 2])) 82 | for f in self.f: 83 | if self.hasTex and self.hasNorm: 84 | f_out.write('f') 85 | for i in range(len(f)): 86 | f_out.write(' {0}/{1}/{2}'.format(f[i][0],f[i][1],f[i][2])) 87 | f_out.write('\n') 88 | elif self.hasTex and not self.hasNorm: 89 | f_out.write('f') 90 | for i in range(len(f)): 91 | f_out.write(' {0}/{1}'.format(f[i][0], f[i][1])) 92 | f_out.write('\n') 93 | elif self.hasNorm and not self.hasTex: 94 | f_out.write('f') 95 | for i in range(len(f)): 96 | f_out.write(' {0}//{1}'.format(f[i][0], f[i][1])) 97 | f_out.write('\n') 98 | elif not self.hasTex and not self.hasNorm: 99 | f_out.write('f') 100 | for i in range(len(f)): 101 | f_out.write(' {0}'.format(f[i])) 102 | f_out.write('\n') 103 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | configargparse 2 | pyyaml 3 | open3d==0.9.0 4 | bvh 5 | tensorflow-gpu # for tensorboard 6 | jupyter 7 | ipykernel 8 | ipyvolume 9 | tqdm 10 | trimesh -------------------------------------------------------------------------------- /sdf visualizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import ipyvolume as ipv\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Load and parse .sdf" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 5, 24 | "metadata": { 25 | "scrolled": true 26 | }, 27 | "outputs": [ 28 | { 29 | "ename": "FileNotFoundError", 30 | "evalue": "[Errno 2] No such file or directory: '/home/whitealex95/Projects/autorigging/autorigging/preprocess/volume/data/aj.sdf'", 31 | "output_type": "error", 32 | "traceback": [ 33 | "\u001b[0;31m--------------------------------------------------------------\u001b[0m", 34 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 35 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0msdf_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetcwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"preprocess/volume/data/aj.sdf\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msdf_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"r\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplitlines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mori\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mvalues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 36 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/whitealex95/Projects/autorigging/autorigging/preprocess/volume/data/aj.sdf'" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "sdf_path = os.path.join(os.getcwd(), \"preprocess/volume/data/aj.sdf\")\n", 42 | "f = open(sdf_path, \"r\")\n", 43 | "values = f.read().splitlines()\n", 44 | "n = np.asarray([int(i) for i in values[0].split()])\n", 45 | "ori = np.asarray([float(i) for i in values[1].split()])\n", 46 | "dx = float(values[2])\n", 47 | "values = np.asarray([float(i) for i in values[3:]])\n", 48 | "print('n : {}'.format(n))\n", 49 | "print('ori : {}'.format(ori))\n", 50 | "print('dx : {}'.format(dx))" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Visualization - Scattering" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 8, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "points,colors = [], []\n", 67 | "count = 0\n", 68 | "MAX = values.max()\n", 69 | "MAX=0\n", 70 | "MIN = values.min()\n", 71 | "for k in range(n[2]):\n", 72 | " for j in range(n[1]):\n", 73 | " for i in range(n[0]):\n", 74 | " if(values[count] < 0):\n", 75 | " points.append(np.asarray([i,j,k]).astype(np.uint8))\n", 76 | " colors.append(np.asarray([MAX + MIN - values[count], MIN, values[count]]))\n", 77 | " count += 1\n", 78 | "N_MAX = n.max()\n", 79 | "points = np.stack(points,axis=0)\n", 80 | "colors = np.stack(colors,axis=0)\n", 81 | "colors = (colors - MIN)/(MAX-MIN)\n", 82 | "key = 1" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 9, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "application/vnd.jupyter.widget-view+json": { 93 | "model_id": "e4371cc3bb224bc1a0ec30dc96afee6d", 94 | "version_major": 2, 95 | "version_minor": 0 96 | }, 97 | "text/plain": [ 98 | "VBox(children=(Figure(camera=PerspectiveCamera(aspect=0.8, fov=46.0, position=(0.22326181515957758, -0.1555554…" 99 | ] 100 | }, 101 | "metadata": {}, 102 | "output_type": "display_data" 103 | } 104 | ], 105 | "source": [ 106 | "key+=1\n", 107 | "ipv.figure(key=key)\n", 108 | "\n", 109 | "ipv.scatter(points[:,0], points[:,1], points[:,2], size=1, marker=\"sphere\", color=colors)\n", 110 | "ipv.xyzlim(-50,N_MAX+50)\n", 111 | "ipv.show()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Visualization - Volume show" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 10, 124 | "metadata": { 125 | "scrolled": false 126 | }, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "application/vnd.jupyter.widget-view+json": { 131 | "model_id": "5e45d12b36074402a2a30239154b1964", 132 | "version_major": 2, 133 | "version_minor": 0 134 | }, 135 | "text/plain": [ 136 | "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.0, max=1.0, step=0.00…" 137 | ] 138 | }, 139 | "metadata": {}, 140 | "output_type": "display_data" 141 | } 142 | ], 143 | "source": [ 144 | "V = np.zeros(n)\n", 145 | "count = 0\n", 146 | "for k in range(n[2]):\n", 147 | " for j in range(n[1]):\n", 148 | " for i in range(n[0]):\n", 149 | " V[i,j,k] = values[count]\n", 150 | " count += 1\n", 151 | "N_MAX 
= n.max()\n", 152 | "# fig = ipv.figure(key=1, width=450, height=450)\n", 153 | "ipv.quickvolshow(V, level=[0, 3], opacity=0.03, level_width=0.1, data_min=-8, data_max=-0.00001)\n", 154 | "ipv.xyzlim(0,N_MAX)\n", 155 | "ipv.show()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "autorigging", 169 | "language": "python", 170 | "name": "autorigging" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.7.7" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 4 187 | } 188 | -------------------------------------------------------------------------------- /test_bvh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import yaml 5 | import argparse 6 | 7 | import torch.utils.data 8 | from torch.utils.data import DataLoader 9 | 10 | from models.mixamo_bvh_model import MixamoBVHModel 11 | from datasets.mixamo_bvh_dataset import MixamoBVHDataset 12 | 13 | parser = argparse.ArgumentParser(description="need trained_model, its config, joint_dir") 14 | parser.add_argument('--model', required=True) 15 | parser.add_argument('--config', required=True) 16 | parser.add_argument('--joint_path', required=True) 17 | args = parser.parse_args() 18 | 19 | # configs = get_configs_from_arguments(testing=True) 20 | with open(args.config) as f: 21 | configs = yaml.load(f, Loader=yaml.FullLoader) 22 | configs['model'] = args.model 23 | 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | 26 | joint_path = 'logs/vox/vox_ori_all_s4.5_HG_mean_stack2_down2_lr3e-4_b4_ce/test' 27 | 28 | test_dataset = MixamoBVHDataset('test_models.txt', configs, use_front_faced=True, joint_path=joint_path) # front faced only for test 29 | test_dataloader = DataLoader(test_dataset, batch_size=configs['batch_size'], shuffle=False, num_workers=int(configs['workers']), pin_memory=True) 30 | 31 | model = MixamoBVHModel(configs) 32 | model.to(device) 33 | 34 | checkpoint = torch.load(configs['model']) 35 | model.load_state_dict(checkpoint['model_state_dict']) 36 | print("[Info] Loaded model parameters from " + configs['model']) 37 | 38 | model.eval() 39 | with torch.no_grad(): 40 | for data in test_dataloader: 41 | data = [dat.to(device).float() for dat in data[:-1]] + [data[-1]] 42 | input_position, target_rotation, meta = data # forget target_rotation.... 
43 | for key in meta: 44 | if isinstance(meta[key], torch.Tensor): 45 | meta[key].to(device).float() 46 | pred_rotation = model(input_position) 47 | 48 | character_name, motion_name = meta['character_name'], meta['motion_name'] 49 | batch_size = input_position.shape[0] 50 | for i in range(batch_size): 51 | character_name_i, motion_name_i = character_name[i], motion_name[i] 52 | if not os.path.exists(os.path.join(configs['log_dir'],'test'+'_'+configs['model'].split('/')[-1].split('_')[-1].split('.pth')[0], character_name_i)): 53 | os.makedirs(os.path.join(configs['log_dir'], 'test'+'_'+configs['model'].split('/')[-1].split('_')[-1].split('.pth')[0], character_name_i)) 54 | 55 | model_epoch = configs['model'].split('/')[-1].split('_')[-1].split('.pth')[0] 56 | np.save(os.path.join(configs['log_dir'], 'test'+'_'+ model_epoch, character_name_i, motion_name_i + '_pred_rot'),pred_rotation[i].cpu().detach()) 57 | np.save(os.path.join(configs['log_dir'], 'test'+'_'+ model_epoch, character_name_i, motion_name_i + '_target_rot'),target_rotation[i].cpu().detach()) 58 | 59 | print("[Info] Inference Done") 60 | -------------------------------------------------------------------------------- /test_skin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | 5 | import torch.utils.data 6 | import numpy as np 7 | from torch_geometric.data import DataLoader 8 | 9 | from datasets.mixamo_skin_dataset import MixamoSkinDataset 10 | from models.mixamo_skin_model import MixamoMeshSkinModel 11 | 12 | parser = argparse.ArgumentParser(description="need trained_model, its config, joint_dir") 13 | parser.add_argument('--model', required=True) 14 | parser.add_argument('--config', required=True) 15 | parser.add_argument('--vol_geo_dir', required=True) 16 | args = parser.parse_args() 17 | 18 | # configs = get_configs_from_arguments(testing=True) 19 | with open(args.config) as f: 20 | configs = yaml.load(f, Loader=yaml.FullLoader) 21 | configs['model'] = args.model 22 | configs['vol_geo_dir'] = args.vol_geo_dir 23 | 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | 26 | test_dataset = MixamoSkinDataset(data_dir=configs['data_dir'], split='test_models', vol_geo_dir=configs['vol_geo_dir'], preprocess=configs['preprocess'], datatype=configs['datatype'], configs=configs, test_all=True) 27 | test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(configs['workers']), pin_memory=True) 28 | 29 | model = MixamoMeshSkinModel(configs, num_joints=configs['num_joints'], use_bn=configs['use_bn']) 30 | model.to(device) 31 | 32 | checkpoint = torch.load(configs['model']) 33 | model.load_state_dict(checkpoint['model_state_dict']) 34 | print("[Info] Loaded model parameters from " + configs['model']) 35 | 36 | with torch.no_grad(): 37 | print("Testing") 38 | for data in test_dataloader: 39 | data = [dat.to(device) for dat in data[:-2]] + list(data[-2:]) 40 | mesh, gt_jm, gt_ibm, character_name, motion_name = data 41 | pred_skin_logit = model(mesh) 42 | 43 | batch_size = mesh.batch.max().item() + 1 44 | for i in range(batch_size): 45 | pred_skin = torch.exp(pred_skin_logit)[mesh.batch==i].cpu().detach().numpy() 46 | character_name, motion_name = character_name[i], motion_name[i] 47 | print(character_name, motion_name) 48 | 49 | if not os.path.exists(os.path.join(configs['log_dir'], 'test', character_name)): 50 | os.makedirs(os.path.join(configs['log_dir'], 'test', character_name)) 51 | 52 | 
np.savetxt(os.path.join(configs['log_dir'], 'test', character_name, motion_name + '_skin.csv'), 53 | pred_skin, delimiter=',') 54 | -------------------------------------------------------------------------------- /test_vox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.optim as optim 3 | import torch.utils.data 4 | from torch.utils.tensorboard import SummaryWriter 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | from utils.joint_util import save_jm, save_jm2 9 | from utils.train_vox_utils import get_vox_dataloaders, train_vox, eval_vox, vis_vox, get_configs_from_arguments 10 | from models.mixamo_vox_model import MixamoVoxModel 11 | from time import time 12 | from datasets.mixamo_vox_dataset import MixamoVoxDataset 13 | from torch.utils.data import DataLoader 14 | 15 | from utils.voxel_utils import get_final_preds, extract_joint_pos_from_heatmap_softargmax, extract_joint_pos_from_heatmap, downsample_single_heatmap 16 | from utils.joint_util import transform_rel2glob 17 | from utils.loss_utils import normalize3d 18 | 19 | configs = get_configs_from_arguments(testing=True) 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | test_dataset = MixamoVoxDataset('test_models.txt', configs, reduce_motion=False, use_front_faced=True) # front faced only for test 23 | test_dataloader = DataLoader(test_dataset, batch_size=configs['batch_size'], shuffle=False, num_workers=int(configs['workers']), pin_memory=True) 24 | 25 | model = MixamoVoxModel(configs) 26 | model.to(device) 27 | 28 | checkpoint = torch.load(configs['model']) 29 | model.load_state_dict(checkpoint['model_state_dict']) 30 | print("[Info] Loaded model parameters from " + configs['model']) 31 | 32 | 33 | model.eval() 34 | with torch.no_grad(): 35 | for data in test_dataloader: 36 | data = [dat.to(device).float() for dat in data[:-1]] + [data[-1]] 37 | bin_vox_padded, sdf_vox_padded, target_heatmap, JM4x4, meta = data 38 | if configs['normalize_heatmap']: 39 | target_heatmap = normalize3d(target_heatmap) 40 | pred_heatmap_logit = model(sdf_vox_padded) 41 | if configs['loss_type'].endswith('mask'): 42 | pred_heatmap = torch.sigmoid(pred_heatmap_logit) * bin_vox_padded 43 | else: 44 | pred_heatmap = torch.sigmoid(pred_heatmap_logit) 45 | acc = joint_acc = model.compute_accuracy(pred_heatmap, target_heatmap, JM4x4, meta, average=False) 46 | 47 | character_name, motion_name = meta['character_name'], meta['motion_name'] 48 | batch_size = bin_vox_padded.shape[0] 49 | pred_coords, _ = get_final_preds(pred_heatmap, meta['translate'], meta['scale']* configs['downsample']) 50 | target_heatmap_coords, _ = get_final_preds(target_heatmap, meta['translate'], meta['scale']) 51 | JM4x4_global = transform_rel2glob(JM4x4) 52 | target_coords = JM4x4_global[:, :, :3, 3] 53 | 54 | for i in range(batch_size): 55 | character_name_i, motion_name_i = character_name[i], motion_name[i] 56 | print(character_name_i, motion_name_i) 57 | pred_heatmap_i = pred_heatmap[i].cpu().numpy() 58 | target_heatmap_i = target_heatmap[i].cpu().numpy() 59 | pred_coords_i = pred_coords[i] 60 | target_coords_i = target_coords[i].cpu().numpy() 61 | target_heatmap_coords_i = target_heatmap_coords[i] 62 | acc_i = acc[i].cpu().numpy() 63 | if not os.path.exists(os.path.join(configs['log_dir'], 'test', character_name_i)): 64 | os.makedirs(os.path.join(configs['log_dir'], 'test', character_name_i)) 65 | 66 | # np.save(os.path.join(configs['log_dir'], 'test', 
character_name_i, motion_name_i + '_pred_hm'),pred_heatmap_i) 67 | # np.save(os.path.join(configs['log_dir'], 'test', character_name_i, motion_name_i + '_gt_hm'),target_heatmap_i) 68 | np.save(os.path.join(configs['log_dir'], 'test', character_name_i, motion_name_i + '_info'),{ 69 | 'pred_coords': pred_coords_i, 70 | 'target_coords': target_coords_i, 71 | 'target_heatmap_coords': target_heatmap_coords_i, 72 | 'acc': acc_i 73 | }) 74 | 75 | scale_i, translate_i = meta['scale'][i].cpu().numpy(), torch.stack(meta['translate']).T[i].cpu().numpy() 76 | center_trans_i = meta['center_trans'][i].cpu().numpy() 77 | use_downsample = False 78 | mask = None 79 | if configs['downsample'] > 1: 80 | use_downsample = True 81 | mask = downsample_single_heatmap(torch.Tensor(bin_vox_padded[i][None,...].cpu()), 2).numpy() 82 | pred_heatmap_joint_pos = extract_joint_pos_from_heatmap(pred_heatmap_i, scale_i, translate_i, use_downsample=use_downsample, center_trans=center_trans_i, mask=None) 83 | pred_heatmap_joint_pos_soft = extract_joint_pos_from_heatmap_softargmax(pred_heatmap_i, scale_i, translate_i, use_downsample=use_downsample, center_trans=center_trans_i, mask=None) 84 | pred_heatmap_joint_pos_mask = extract_joint_pos_from_heatmap(pred_heatmap_i, scale_i, translate_i, use_downsample=use_downsample, center_trans=center_trans_i, mask=mask) 85 | 86 | np.save(os.path.join(configs['log_dir'], 'test', character_name_i, motion_name_i + '_joint'),{ 87 | 'pred_coords': pred_coords_i, 88 | 'pred_heatmap_joint_pos': pred_heatmap_joint_pos, 89 | 'pred_heatmap_joint_pos_soft': pred_heatmap_joint_pos_soft, 90 | 'pred_heatmap_joint_pos_mask': pred_heatmap_joint_pos_mask 91 | }) 92 | np.save(os.path.join(configs['log_dir'], 'test', character_name_i, motion_name_i + '_joint_pos_mask'), pred_heatmap_joint_pos_mask) 93 | 94 | 95 | -------------------------------------------------------------------------------- /train_bvh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.optim as optim 3 | import torch.utils.data 4 | from torch.utils.tensorboard import SummaryWriter 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | from utils.joint_util import save_jm, save_jm2 9 | from utils.train_vox_utils import get_configs_from_arguments 10 | from utils.train_bvh_utils import get_bvh_dataloaders, train_bvh, eval_bvh, vis_bvh 11 | from models.mixamo_bvh_model import MixamoBVHModel 12 | from time import time 13 | 14 | configs = get_configs_from_arguments() 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | train_dataloader, eval_dataloader, vis_dataloader = get_bvh_dataloaders(configs) 18 | train_num_batch, eval_num_batch = len(train_dataloader), len(eval_dataloader) 19 | model = MixamoBVHModel(configs) 20 | model.to(device) 21 | 22 | optimizer = optim.Adam(model.parameters(), lr=configs['lr'], betas=(0.9, 0.999)) 23 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['lr_step_size'], gamma=configs['lr_gamma']) 24 | writer = SummaryWriter(log_dir=configs['log_dir']) 25 | 26 | step = 0 27 | start_epoch = 0 28 | if configs['model'] != '': # Load model 29 | checkpoint = torch.load(configs['model']) 30 | model.load_state_dict(checkpoint['model_state_dict']) 31 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 32 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 33 | step = checkpoint['last_step'] + 1 34 | start_epoch = scheduler.last_epoch 35 | print("[Info] Loaded model 
parameters from " + configs['model']) 36 | 37 | time_log = defaultdict(float) 38 | 39 | for epoch in range(start_epoch, start_epoch + configs['nepoch']): 40 | # total_loss_list, rot_loss_list, trans_loss_list, skin_loss_list = [], [], [], [] 41 | model.train() 42 | losses_dict = defaultdict(list) 43 | losses_mean = defaultdict(float) 44 | configs['time'] = configs['time'] and epoch == start_epoch 45 | if configs['time']: 46 | time_log['before_load'] = time() 47 | for i, data in enumerate(train_dataloader, 0): 48 | train_bvh(i, epoch, step, data, model, optimizer, writer, losses_dict, train_num_batch, time_log, device, 49 | configs) 50 | step += 1 51 | for key, val in losses_dict.items(): 52 | losses_mean[key] = np.mean(val) 53 | model.print_loss(losses_mean['loss'], losses_mean['joint_acc'], epoch, 0, 0)# print() 54 | model.write_summary(losses_mean, epoch=epoch, writer=writer) 55 | scheduler.step() # schedule based on epochs 56 | 57 | # Evaluate 58 | if not configs['vis_overfit']: 59 | model.eval() 60 | with torch.no_grad(): 61 | losses_dict = defaultdict(list) 62 | losses_mean = defaultdict(float) 63 | for i, data in enumerate(eval_dataloader, 0): 64 | eval_bvh(i, epoch, step, data, model, writer, losses_dict, train_num_batch, device, configs) 65 | for key, val in losses_dict.items(): 66 | losses_mean[key] = np.mean(val) 67 | model.print_loss(losses_mean['loss'], losses_mean['joint_acc'], epoch, 0, 0)# print() 68 | model.write_summary(losses_mean, epoch=epoch, writer=writer) 69 | 70 | # Save training model 71 | if epoch % configs['save_epoch'] == 0: 72 | torch.save({'model_state_dict': model.state_dict(), 73 | 'optimizer_state_dict': optimizer.state_dict(), 74 | 'scheduler_state_dict': scheduler.state_dict(), 75 | 'last_step': step-1}, 76 | '%s/model_epoch_%.3d.pth' % (configs['log_dir'], epoch)) 77 | 78 | model.eval() 79 | with torch.no_grad(): 80 | if epoch % configs['vis_epoch'] == 0: 81 | print("Visualizing") 82 | for data in vis_dataloader: 83 | vis_bvh(epoch, data, model, device, configs) 84 | 85 | -------------------------------------------------------------------------------- /train_skin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.optim as optim 3 | import torch.utils.data 4 | from torch.utils.tensorboard import SummaryWriter 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | from utils.joint_util import save_jm, save_jm2 9 | from utils.train_skin_utils import get_skin_dataloaders, get_configs 10 | from models.mixamo_skin_model import MixamoMeshSkinModel 11 | 12 | 13 | args = get_configs() 14 | configs = vars(args) # class fields to dictionary (args.lr -> configs['lr']) 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | train_dataloader, eval_dataloader, vis_dataloader = get_skin_dataloaders(args) 17 | train_num_batch, eval_num_batch = len(train_dataloader), len(eval_dataloader) 18 | model = MixamoMeshSkinModel(configs, num_joints=configs['num_joints'], use_bn=configs['use_bn']) 19 | model.to(device) 20 | 21 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 22 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 23 | writer = SummaryWriter(log_dir=args.log_dir) 24 | 25 | step = 0 26 | if args.model != '': # Load model 27 | checkpoint = torch.load(args.model) 28 | model.load_state_dict(checkpoint['model_state_dict']) 29 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 30 
| scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 31 | step = checkpoint['last_step'] + 1 32 | print("[Info] Loaded model parameters from " + args.model) 33 | 34 | 35 | for epoch in range(scheduler.last_epoch, args.nepoch): 36 | # total_loss_list, rot_loss_list, trans_loss_list, skin_loss_list = [], [], [], [] 37 | losses_list = [] 38 | for i, data in enumerate(train_dataloader, 0): 39 | data = [dat.to(device) for dat in data[:-2]] + data[-2:] 40 | mesh, gt_jm, gt_ibm, character_name, motion_name = data # jm: relative, ibm: global 41 | optimizer.zero_grad() 42 | model.train() 43 | pred_skin_logit = model(mesh) 44 | skin_loss = model.calculate_loss(pred_skin_logit, mesh, writer=writer, step=step) 45 | model.print_running_loss(epoch, i, train_num_batch) 46 | losses_list.append([skin_loss.item()]) 47 | skin_loss.backward() 48 | optimizer.step() 49 | step += 1 50 | 51 | losses_mean = np.mean(losses_list) 52 | model.print_loss(epoch, i, train_num_batch, skin_loss=losses_mean, is_mean=True) 53 | model.write_summary(writer=writer, step=step, skin_loss=losses_mean) 54 | scheduler.step() # schedule based on epochs 55 | 56 | # Evaluate 57 | model.eval() 58 | with torch.no_grad(): 59 | losses_list = [] 60 | for i, data in enumerate(eval_dataloader, 0): 61 | data = [dat.to(device) for dat in data[:-2]] + data[-2:] 62 | mesh, gt_jm, gt_ibm, character_name, motion_name = data # jm: relative, ibm: global 63 | pred_skin_logit = model(mesh) 64 | skin_loss = model.calculate_loss(pred_skin_logit, mesh, writer=writer, step=step) 65 | model.print_running_loss(epoch, i, eval_num_batch) 66 | losses_list.append([skin_loss.item()]) 67 | 68 | losses_mean = np.mean(losses_list) 69 | model.print_loss(epoch, i, eval_num_batch, skin_loss=losses_mean, is_mean=True) 70 | model.write_summary(writer, step, skin_loss=losses_mean) 71 | 72 | # Save training model 73 | if epoch % args.save_step == 0: 74 | torch.save({'model_state_dict': model.state_dict(), 75 | 'optimizer_state_dict': optimizer.state_dict(), 76 | 'scheduler_state_dict': scheduler.state_dict(), 77 | 'last_step': step-1}, 78 | '%s/model_epoch_%.3d.pth' % (args.log_dir, epoch)) 79 | 80 | # Save inference result of training model 81 | model.eval() 82 | with torch.no_grad(): 83 | if epoch % args.vis_step == 0: 84 | print("Visualizing") 85 | for data in vis_dataloader: 86 | data = [dat.to(device) for dat in data[:-2]] + list(data[-2:]) 87 | mesh, gt_jm, gt_ibm, character_name, motion_name = data 88 | pred_skin_logit = model(mesh) 89 | 90 | batch_size = mesh.batch.max().item() + 1 91 | for i in range(batch_size): 92 | pred_skin = torch.exp(pred_skin_logit)[mesh.batch==i].cpu().detach().numpy() 93 | character_name, motion_name = character_name[i], motion_name[i] 94 | 95 | if not os.path.exists(os.path.join(args.log_dir, 'vis', character_name)): 96 | os.makedirs(os.path.join(args.log_dir, 'vis', character_name)) 97 | 98 | np.savetxt(os.path.join(args.log_dir, 'vis', character_name, motion_name + '_skin_%.3d.csv' % (epoch)), 99 | pred_skin, delimiter=',') -------------------------------------------------------------------------------- /train_vox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.optim as optim 3 | import torch.utils.data 4 | from torch.utils.tensorboard import SummaryWriter 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | from utils.joint_util import save_jm, save_jm2 9 | from utils.train_vox_utils import get_vox_dataloaders, train_vox, eval_vox, 
vis_vox, get_configs_from_arguments 10 | from models.mixamo_vox_model import MixamoVoxModel 11 | from time import time 12 | 13 | configs = get_configs_from_arguments() 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | train_dataloader, eval_dataloader, vis_dataloader = get_vox_dataloaders(configs) 17 | train_num_batch, eval_num_batch = len(train_dataloader), len(eval_dataloader) 18 | model = MixamoVoxModel(configs) 19 | model.to(device) 20 | 21 | optimizer = optim.Adam(model.parameters(), lr=configs['lr'], betas=(0.9, 0.999)) 22 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['lr_step_size'], gamma=configs['lr_gamma']) 23 | writer = SummaryWriter(log_dir=configs['log_dir']) 24 | 25 | step = 0 26 | start_epoch = 0 27 | if configs['model'] != '': # Load model 28 | checkpoint = torch.load(configs['model']) 29 | model.load_state_dict(checkpoint['model_state_dict']) 30 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 31 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 32 | step = checkpoint['last_step'] + 1 33 | start_epoch = scheduler.last_epoch 34 | print("[Info] Loaded model parameters from " + configs['model']) 35 | 36 | time_log = defaultdict(float) 37 | 38 | for epoch in range(start_epoch, start_epoch + configs['nepoch']): 39 | # total_loss_list, rot_loss_list, trans_loss_list, skin_loss_list = [], [], [], [] 40 | model.train() 41 | losses_dict = defaultdict(list) 42 | losses_mean = defaultdict(float) 43 | configs['time'] = configs['time'] and epoch == start_epoch 44 | if configs['time']: 45 | time_log['before_load'] = time() 46 | for i, data in enumerate(train_dataloader, 0): 47 | train_vox(i, epoch, step, data, model, optimizer, writer, losses_dict, train_num_batch, time_log, device, 48 | configs) 49 | step += 1 50 | for key, val in losses_dict.items(): 51 | losses_mean[key] = np.mean(val) 52 | model.print_loss(losses_mean['loss'], losses_mean['joint_acc'], epoch, 0, 0); print() 53 | model.write_summary(losses_mean, epoch=epoch, writer=writer) 54 | scheduler.step() # schedule based on epochs 55 | 56 | # Evaluate 57 | if not configs['vis_overfit']: 58 | model.eval() 59 | with torch.no_grad(): 60 | losses_dict = defaultdict(list) 61 | losses_mean = defaultdict(float) 62 | for i, data in enumerate(eval_dataloader, 0): 63 | eval_vox(i, epoch, step, data, model, writer, losses_dict, train_num_batch, device, configs) 64 | for key, val in losses_dict.items(): 65 | losses_mean[key] = np.mean(val) 66 | model.print_loss(losses_mean['loss'], losses_mean['joint_acc'], epoch, 0, 0);print() 67 | model.write_summary(losses_mean, epoch=epoch, writer=writer) 68 | 69 | # Save training model 70 | if epoch % configs['save_epoch'] == 0: 71 | torch.save({'model_state_dict': model.state_dict(), 72 | 'optimizer_state_dict': optimizer.state_dict(), 73 | 'scheduler_state_dict': scheduler.state_dict(), 74 | 'last_step': step-1}, 75 | '%s/model_epoch_%.3d.pth' % (configs['log_dir'], epoch)) 76 | 77 | model.eval() 78 | with torch.no_grad(): 79 | if epoch % configs['vis_epoch'] == 0: 80 | print("Visualizing") 81 | for data in vis_dataloader: 82 | vis_vox(epoch, data, model, device, configs) 83 | 84 | -------------------------------------------------------------------------------- /utils/binvox_rw.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 
3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | 18 | 19 | import numpy as np 20 | import struct 21 | 22 | 23 | class Voxels(object): 24 | """ Holds a binvox model. 25 | data is either a three-dimensional numpy boolean array (dense representation) 26 | or a two-dimensional numpy float array (coordinate representation). 27 | 28 | dims, translate and scale are the model metadata. 29 | 30 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 31 | 32 | scale and translate relate the voxels to the original model coordinates. 33 | 34 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 35 | 36 | x_n = (i+.5)/dims[0] 37 | y_n = (j+.5)/dims[1] 38 | z_n = (k+.5)/dims[2] 39 | x = scale*x_n + translate[0] 40 | y = scale*y_n + translate[1] 41 | z = scale*z_n + translate[2] 42 | 43 | """ 44 | 45 | def __init__(self, data, dims, translate, scale, axis_order): 46 | self.data = data 47 | self.dims = dims 48 | self.translate = translate 49 | self.scale = scale 50 | assert (axis_order in ('xzy', 'xyz')) 51 | self.axis_order = axis_order 52 | 53 | def clone(self): 54 | data = self.data.copy() 55 | dims = self.dims[:] 56 | translate = self.translate[:] 57 | return Voxels(data, dims, translate, self.scale, self.axis_order) 58 | 59 | def write(self, fp): 60 | write(self, fp) 61 | 62 | def read_header(fp): 63 | """ Read binvox header. Mostly meant for internal use. 64 | """ 65 | line = fp.readline().strip() 66 | if not line.startswith(b'#binvox'): 67 | raise IOError('Not a binvox file') 68 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 69 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 70 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 71 | line = fp.readline() 72 | 73 | return dims, translate, scale 74 | 75 | def read_as_3d_array(fp, fix_coords=True): 76 | """ Read binary binvox format as array. 77 | 78 | Returns the model with accompanying metadata. 79 | 80 | Voxels are stored in a three-dimensional numpy array, which is simple and 81 | direct, but may use a lot of memory for large models. (Storage requirements 82 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 83 | boolean arrays use a byte per element). 84 | 85 | Doesn't do any checks on input except for the '#binvox' line. 
86 | """ 87 | dims, translate, scale = read_header(fp) 88 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 89 | # if just using reshape() on the raw data: 90 | # indexing the array as array[i,j,k], the indices map into the 91 | # coords as: 92 | # i -> x 93 | # j -> z 94 | # k -> y 95 | # if fix_coords is true, then data is rearranged so that 96 | # mapping is 97 | # i -> x 98 | # j -> y 99 | # k -> z 100 | values, counts = raw_data[::2], raw_data[1::2] 101 | data = np.repeat(values, counts).astype(np.bool) 102 | data = data.reshape(dims) 103 | if fix_coords: 104 | # xzy to xyz TODO the right thing 105 | data = np.transpose(data, (0, 2, 1)) 106 | axis_order = 'xyz' 107 | else: 108 | axis_order = 'xzy' 109 | return Voxels(data, dims, translate, scale, axis_order) 110 | 111 | def read_as_coord_array(fp, fix_coords=True): 112 | """ Read binary binvox format as coordinates. 113 | 114 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 115 | 3 x N array where N is the number of nonzero voxels. Each column 116 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 117 | of the voxel. (The odd ordering is due to the way binvox format lays out 118 | data). Note that coordinates refer to the binvox voxels, without any 119 | scaling or translation. 120 | 121 | Use this to save memory if your model is very sparse (mostly empty). 122 | 123 | Doesn't do any checks on input except for the '#binvox' line. 124 | """ 125 | dims, translate, scale = read_header(fp) 126 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 127 | 128 | values, counts = raw_data[::2], raw_data[1::2] 129 | 130 | sz = np.prod(dims) 131 | index, end_index = 0, 0 132 | end_indices = np.cumsum(counts) 133 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 134 | 135 | values = values.astype(np.bool) 136 | indices = indices[values] 137 | end_indices = end_indices[values] 138 | 139 | nz_voxels = [] 140 | for index, end_index in zip(indices, end_indices): 141 | nz_voxels.extend(range(index, end_index)) 142 | nz_voxels = np.array(nz_voxels) 143 | # TODO are these dims correct? 144 | # according to docs, 145 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 146 | 147 | x = nz_voxels / (dims[0]*dims[1]) 148 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 149 | z = zwpy / dims[0] 150 | y = zwpy % dims[0] 151 | if fix_coords: 152 | data = np.vstack((x, y, z)) 153 | axis_order = 'xyz' 154 | else: 155 | data = np.vstack((x, z, y)) 156 | axis_order = 'xzy' 157 | 158 | #return Voxels(data, dims, translate, scale, axis_order) 159 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 160 | 161 | def dense_to_sparse(voxel_data, dtype=np.int): 162 | """ From dense representation to sparse (coordinate) representation. 163 | No coordinate reordering. 
164 | """ 165 | if voxel_data.ndim!=3: 166 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 167 | return np.asarray(np.nonzero(voxel_data), dtype) 168 | 169 | def sparse_to_dense(voxel_data, dims, dtype=np.bool): 170 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 171 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 172 | if np.isscalar(dims): 173 | dims = [dims]*3 174 | dims = np.atleast_2d(dims).T 175 | # truncate to integers 176 | xyz = voxel_data.astype(np.int) 177 | # discard voxels that fall outside dims 178 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 179 | xyz = xyz[:,valid_ix] 180 | out = np.zeros(dims.flatten(), dtype=dtype) 181 | out[tuple(xyz)] = True 182 | return out 183 | 184 | #def get_linear_index(x, y, z, dims): 185 | #""" Assuming xzy order. (y increasing fastest. 186 | #TODO ensure this is right when dims are not all same 187 | #""" 188 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 189 | 190 | def bwrite(fp,s): 191 | fp.write(s.encode()) 192 | 193 | def write_pair(fp,state, ctr): 194 | fp.write(struct.pack('B',state)) 195 | fp.write(struct.pack('B',ctr)) 196 | 197 | def write(voxel_model, fp): 198 | """ Write binary binvox format. 199 | 200 | Note that when saving a model in sparse (coordinate) format, it is first 201 | converted to dense format. 202 | 203 | Doesn't check if the model is 'sane'. 204 | 205 | """ 206 | if voxel_model.data.ndim==2: 207 | # TODO avoid conversion to dense 208 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 209 | else: 210 | dense_voxel_data = voxel_model.data 211 | 212 | bwrite(fp, '#binvox 1\n') 213 | bwrite(fp, 'dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') 214 | bwrite(fp, 'translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') 215 | bwrite(fp, 'scale ' + str(voxel_model.scale) + '\n') 216 | bwrite(fp, 'data\n') 217 | if not voxel_model.axis_order in ('xzy', 'xyz'): 218 | raise ValueError('Unsupported voxel model axis order') 219 | 220 | if voxel_model.axis_order=='xzy': 221 | voxels_flat = dense_voxel_data.flatten() 222 | elif voxel_model.axis_order=='xyz': 223 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 224 | 225 | # keep a sort of state machine for writing run length encoding 226 | state = voxels_flat[0] 227 | ctr = 0 228 | for c in voxels_flat: 229 | if c==state: 230 | ctr += 1 231 | # if ctr hits max, dump 232 | if ctr==255: 233 | write_pair(fp, state, ctr) 234 | ctr = 0 235 | else: 236 | # if switch state, dump 237 | write_pair(fp, state, ctr) 238 | state = c 239 | ctr = 1 240 | # flush out remainders 241 | if ctr > 0: 242 | write_pair(fp, state, ctr) 243 | 244 | if __name__ == '__main__': 245 | import doctest 246 | doctest.testmod() 247 | -------------------------------------------------------------------------------- /utils/common_ops.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Name: common_ops.py 3 | # Purpose: common functions for geometry processing 4 | # RigNet Copyright 2020 University of Massachusetts 5 | # RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License. 6 | # Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet. 
7 | #------------------------------------------------------------------------------- 8 | 9 | import numpy as np 10 | import time 11 | import open3d as o3d 12 | from scipy.sparse import lil_matrix 13 | from scipy.sparse.csgraph import dijkstra 14 | 15 | # This file is from RigNet(https://github.com/zhan-xu/RigNet) 16 | 17 | def get_bones(skel): 18 | """ 19 | extract bones from skeleton struction 20 | :param skel: input skeleton 21 | :return: bones are B*6 array where each row consists starting and ending points of a bone 22 | bone_name are a list of B elements, where each element consists starting and ending joint name 23 | leaf_bones indicate if this bone is a virtual "leaf" bone. 24 | We add virtual "leaf" bones to the leaf joints since they always have skinning weights as well 25 | """ 26 | bones = [] 27 | bone_name = [] 28 | leaf_bones = [] 29 | this_level = [skel.root] 30 | while this_level: 31 | next_level = [] 32 | for p_node in this_level: 33 | p_pos = np.array(p_node.pos) 34 | next_level += p_node.children 35 | for c_node in p_node.children: 36 | c_pos = np.array(c_node.pos) 37 | bones.append(np.concatenate((p_pos, c_pos))[np.newaxis, :]) 38 | bone_name.append([p_node.name, c_node.name]) 39 | leaf_bones.append(False) 40 | if len(c_node.children) == 0: 41 | bones.append(np.concatenate((c_pos, c_pos))[np.newaxis, :]) 42 | bone_name.append([c_node.name, c_node.name+'_leaf']) 43 | leaf_bones.append(True) 44 | this_level = next_level 45 | bones = np.concatenate(bones, axis=0) 46 | return bones, bone_name, leaf_bones 47 | 48 | 49 | def calc_surface_geodesic(mesh): 50 | # We denselu sample 4000 points to be more accuracy. 51 | samples = mesh.sample_points_poisson_disk(number_of_points=4000) 52 | pts = np.asarray(samples.points) 53 | pts_normal = np.asarray(samples.normals) 54 | 55 | time1 = time.time() 56 | N = len(pts) 57 | verts_dist = np.sqrt(np.sum((pts[np.newaxis, ...] - pts[:, np.newaxis, :]) ** 2, axis=2)) 58 | verts_nn = np.argsort(verts_dist, axis=1) 59 | conn_matrix = lil_matrix((N, N), dtype=np.float32) 60 | 61 | for p in range(N): 62 | nn_p = verts_nn[p, 1:6] 63 | norm_nn_p = np.linalg.norm(pts_normal[nn_p], axis=1) 64 | norm_p = np.linalg.norm(pts_normal[p]) 65 | cos_similar = np.dot(pts_normal[nn_p], pts_normal[p]) / (norm_nn_p * norm_p + 1e-10) 66 | nn_p = nn_p[cos_similar > -0.5] 67 | conn_matrix[p, nn_p] = verts_dist[p, nn_p] 68 | [dist, predecessors] = dijkstra(conn_matrix, directed=False, indices=range(N), 69 | return_predecessors=True, unweighted=False) 70 | 71 | # replace inf distance with euclidean distance + 8 72 | # 6.12 is the maximal geodesic distance without considering inf, I add 8 to be safer. 73 | inf_pos = np.argwhere(np.isinf(dist)) 74 | if len(inf_pos) > 0: 75 | euc_distance = np.sqrt(np.sum((pts[np.newaxis, ...] - pts[:, np.newaxis, :]) ** 2, axis=2)) 76 | dist[inf_pos[:, 0], inf_pos[:, 1]] = 8.0 + euc_distance[inf_pos[:, 0], inf_pos[:, 1]] 77 | 78 | verts = np.array(mesh.vertices) 79 | vert_pts_distance = np.sqrt(np.sum((verts[np.newaxis, ...] 
- pts[:, np.newaxis, :]) ** 2, axis=2)) 80 | vert_pts_nn = np.argmin(vert_pts_distance, axis=0) 81 | surface_geodesic = dist[vert_pts_nn, :][:, vert_pts_nn] 82 | time2 = time.time() 83 | print('surface geodesic calculation: {} seconds'.format((time2 - time1))) 84 | return surface_geodesic 85 | -------------------------------------------------------------------------------- /utils/compute_volumetric_geodesic.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Name: compute_volumetric_geodesic.py 3 | # Purpose: his script calculates volumetric geodesic distance between vertices and bones. 4 | # The shortest paths start from bones, then hit all "visible" vertices w.r.t each bone, i.e. the first hit on the surface is the vertex itself. 5 | # For "invisible" vertices, find the nearest "visible" vertices along surface by surface geodesic distance, and then go interior to the bone. 6 | # RigNet Copyright 2020 University of Massachusetts 7 | # RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License. 8 | # Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet. 9 | #------------------------------------------------------------------------------- 10 | 11 | # This file is from RigNet(https://github.com/zhan-xu/RigNet) 12 | 13 | import sys 14 | sys.path.append("./") 15 | import os 16 | import trimesh 17 | import numpy as np 18 | import open3d as o3d 19 | from utils.rig_parser import Info 20 | from utils.common_ops import get_bones, calc_surface_geodesic 21 | import time 22 | import pdb 23 | 24 | 25 | def pts2line(pts, lines): 26 | ''' 27 | Calculate points-to-bone distance. Point to line segment distance refer to 28 | https://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment 29 | :param pts: N*3 30 | :param lines: N*6, where [N,0:3] is the starting position and [N, 3:6] is the ending position 31 | :return: origins are the neatest projected position of the point on the line. 32 | ends are the points themselves. 33 | dist is the distance in between, which is the distance from points to lines. 34 | Origins and ends will be used for generate rays. 35 | ''' 36 | l2 = np.sum((lines[:, 3:6] - lines[:, 0:3]) ** 2, axis=1) 37 | origins = np.zeros((len(pts) * len(lines), 3)) 38 | ends = np.zeros((len(pts) * len(lines), 3)) 39 | dist = np.zeros((len(pts) * len(lines))) 40 | for l in range(len(lines)): 41 | if np.abs(l2[l]) < 1e-8: # for zero-length edges 42 | origins[l * len(pts):(l + 1) * len(pts)] = lines[l][0:3] 43 | else: # for other edges 44 | t = np.sum((pts - lines[l][0:3][np.newaxis, :]) * (lines[l][3:6] - lines[l][0:3])[np.newaxis, :], axis=1) / \ 45 | l2[l] 46 | t = np.clip(t, 0, 1) 47 | t_pos = lines[l][0:3][np.newaxis, :] + t[:, np.newaxis] * (lines[l][3:6] - lines[l][0:3])[np.newaxis, :] 48 | origins[l * len(pts):(l + 1) * len(pts)] = t_pos 49 | ends[l * len(pts):(l + 1) * len(pts)] = pts 50 | dist[l * len(pts):(l + 1) * len(pts)] = np.linalg.norm( 51 | origins[l * len(pts):(l + 1) * len(pts)] - ends[l * len(pts):(l + 1) * len(pts)], axis=1) 52 | return origins, ends, dist 53 | 54 | 55 | def calc_pts2bone_visible_mat(mesh, origins, ends): 56 | ''' 57 | Check whether the surface point is visible by the internal bone. 58 | Visible is defined as no occlusion on the path between. 
59 | :param mesh: 60 | :param surface_pts: points on the surface (n*3) 61 | :param origins: origins of rays 62 | :param ends: ends of the rays, together with origins, we can decide the direction of the ray. 63 | :return: binary visibility matrix (n*m), where 1 indicate the n-th surface point is visible to the m-th ray 64 | ''' 65 | ray_dir = ends - origins 66 | RayMeshIntersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh) 67 | locations, index_ray, index_tri = RayMeshIntersector.intersects_location(origins, ray_dir + 1e-15) 68 | locations_per_ray = [locations[index_ray == i] for i in range(len(ray_dir))] 69 | min_hit_distance = [] 70 | for i in range(len(locations_per_ray)): 71 | if len(locations_per_ray[i]) == 0: 72 | min_hit_distance.append(np.linalg.norm(ray_dir[i])) 73 | else: 74 | min_hit_distance.append(np.min(np.linalg.norm(locations_per_ray[i] - origins[i], axis=1))) 75 | min_hit_distance = np.array(min_hit_distance) 76 | distance = np.linalg.norm(ray_dir, axis=1) 77 | vis_mat = (np.abs(min_hit_distance - distance) < 1e-4) 78 | return vis_mat 79 | 80 | 81 | def show_visible_mat(mesh_filename, joint_pos, vis_mat, joint_id): 82 | from utils.vis_utils import drawSphere 83 | 84 | mesh_o3d = o3d.io.read_triangle_mesh(mesh_filename) 85 | mesh_trimesh = trimesh.load(mesh_filename) 86 | visible = vis_mat[:, joint_id] 87 | 88 | mesh_ls = o3d.geometry.LineSet.create_from_triangle_mesh(mesh_o3d) 89 | mesh_ls.colors = o3d.utility.Vector3dVector([[0.8, 0.8, 0.8] for i in range(len(mesh_ls.lines))]) 90 | pcd = o3d.geometry.PointCloud() 91 | pcd.points = o3d.utility.Vector3dVector(np.array(mesh_trimesh.vertices)[visible]) 92 | pcd.colors = o3d.utility.Vector3dVector(np.repeat(np.array([[0.0, 0.0, 1.0]]), int(np.sum(visible)), axis=0)) 93 | vis = o3d.visualization.Visualizer() 94 | vis.create_window() 95 | vis.add_geometry(mesh_ls) 96 | vis.add_geometry(drawSphere(joint_pos[joint_id], 0.005, color=[1.0, 0.0, 0.0])) 97 | vis.add_geometry(pcd) 98 | vis.run() 99 | vis.destroy_window() 100 | 101 | def calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=False): 102 | """ 103 | calculate volumetric geodesic distance from vertices to each bones 104 | :param bones: B*6 numpy array where each row stores the starting and ending joint position of a bone 105 | :param mesh_v: V*3 mesh vertices 106 | :param surface_geodesic: geodesic distance matrix of all vertices 107 | :param mesh_filename: mesh filename 108 | :return: an approaximate volumetric geodesic distance matrix V*B, were (v,b) is the distance from vertex v to bone b 109 | """ 110 | 111 | if subsampling: 112 | mesh0 = o3d.io.read_triangle_mesh(mesh_filename) 113 | mesh0 = mesh0.simplify_quadric_decimation(3000) 114 | o3d.io.write_triangle_mesh(mesh_filename.replace(".obj", "_simplified.obj"), mesh0) 115 | mesh_trimesh = trimesh.load(mesh_filename.replace(".obj", "_simplified.obj")) 116 | subsamples_ids = np.random.choice(len(mesh_v), np.min((len(mesh_v), 1500)), replace=False) 117 | subsamples = mesh_v[subsamples_ids, :] 118 | surface_geodesic = surface_geodesic[subsamples_ids, :][:, subsamples_ids] 119 | else: 120 | mesh_trimesh = trimesh.load(mesh_filename) 121 | subsamples = mesh_v 122 | origins, ends, pts_bone_dist = pts2line(subsamples, bones) 123 | pts_bone_visibility = calc_pts2bone_visible_mat(mesh_trimesh, origins, ends) 124 | pts_bone_visibility = pts_bone_visibility.reshape(len(bones), len(subsamples)).transpose() 125 | pts_bone_dist = pts_bone_dist.reshape(len(bones), len(subsamples)).transpose() 126 | 
# remove visible points which are too far 127 | for b in range(pts_bone_visibility.shape[1]): 128 | visible_pts = np.argwhere(pts_bone_visibility[:, b] == 1).squeeze(1) 129 | if len(visible_pts) == 0: 130 | continue 131 | threshold_b = np.percentile(pts_bone_dist[visible_pts, b], 15) 132 | pts_bone_visibility[pts_bone_dist[:, b] > 1.3 * threshold_b, b] = False 133 | 134 | visible_matrix = np.zeros(pts_bone_visibility.shape) 135 | visible_matrix[np.where(pts_bone_visibility == 1)] = pts_bone_dist[np.where(pts_bone_visibility == 1)] 136 | for c in range(visible_matrix.shape[1]): 137 | unvisible_pts = np.argwhere(pts_bone_visibility[:, c] == 0).squeeze(1) 138 | visible_pts = np.argwhere(pts_bone_visibility[:, c] == 1).squeeze(1) 139 | if len(visible_pts) == 0: 140 | visible_matrix[:, c] = pts_bone_dist[:, c] 141 | continue 142 | for r in unvisible_pts: 143 | dist1 = np.min(surface_geodesic[r, visible_pts]) 144 | nn_visible = visible_pts[np.argmin(surface_geodesic[r, visible_pts])] 145 | if np.isinf(dist1): 146 | visible_matrix[r, c] = 8.0 + pts_bone_dist[r, c] 147 | else: 148 | visible_matrix[r, c] = dist1 + visible_matrix[nn_visible, c] 149 | if subsampling: 150 | nn_dist = np.sum((mesh_v[:, np.newaxis, :] - subsamples[np.newaxis, ...])**2, axis=2) 151 | nn_ind = np.argmin(nn_dist, axis=1) 152 | visible_matrix = visible_matrix[nn_ind, :] 153 | os.remove(mesh_filename.replace(".obj", "_simplified.obj")) 154 | os.remove(mesh_filename.replace(".obj", "_simplified.mtl")) 155 | return visible_matrix -------------------------------------------------------------------------------- /utils/ik_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.joint_util import transform_rel2glob 3 | 4 | ''' 5 | rotation should have shape batch_size * Joint_num * (3/4) * Time 6 | position should have shape batch_size * 3 * Time 7 | offset should have shape batch_size * Joint_num * 3 8 | output have shape batch_size * Time * Joint_num * 3 9 | ''' 10 | 11 | ''' 12 | Input: 13 | joint_position: (B, J, 3) (Time x) 14 | offset: (B, J, 3) 15 | Output: 16 | rotation: (B, J, (3/4/6)) (Time x) 17 | ''' 18 | from utils.bvh_utils import get_animated_bvh_joint_positions 19 | 20 | 21 | class InverseKinematics: 22 | def __init__(self, rotations, positions, offset, parents): 23 | # self.quater = args.rotation == 'quaternion' 24 | if rotations is None: 25 | self.rotations = torch.rand(positions.shape[:-1] + (4,), device=positions.device) # random quaternion 26 | else: 27 | self.rotations = rotations # local bvh rotation 28 | self.rotations.requires_grad_(True) 29 | 30 | self.positions = positions # global joint position 31 | self.root_position = self.positions[..., 0, :] 32 | self.offset = offset 33 | self.parents = parents # topology info 34 | 35 | # Optimizers for IK 36 | self.optimizer = torch.optim.Adam([self.rotations], lr=1e-2, betas=(0.9, 0.999)) 37 | self.crit = torch.nn.MSELoss() 38 | 39 | def step(self): 40 | self.optimizer.zero_grad() 41 | positions_fk = self.forward(self.rotations, self.root_position, self.offset, order='', quater=True, world=True) 42 | self.loss = loss = self.crit(positions_fk, self.positions) 43 | loss.backward() 44 | self.optimizer.step() 45 | self.positions_fk = positions_fk 46 | return loss.item() 47 | 48 | ''' 49 | rotation should have shape batch_size * Joint_num * (3/4) * Time 50 | position should have shape batch_size * 3 * Time 51 | offset should have shape batch_size * Joint_num * 3 52 | output have shape batch_size * Time 
* Joint_num * 3 53 | ''' 54 | 55 | def forward_(self): 56 | return self.forward(self.rotations, self.root_position, self.offset, order='', quater=True, world=True) 57 | 58 | def forward(self, rotation: torch.Tensor, root_position: torch.Tensor, offset: torch.Tensor, order='xyz', 59 | quater=False, 60 | world=True): 61 | 62 | if quater: # Default rotation representation is quaternion. You should try something else though 63 | transform = self.transform_from_quaternion(rotation) # B, J, 3, 3 64 | transform_glob = transform_rel2glob(transform) 65 | else: 66 | raise NotImplementedError 67 | return get_animated_bvh_joint_positions(offset, transform_glob, root_position) 68 | 69 | @staticmethod 70 | def transform_from_axis(euler, axis): 71 | transform = torch.empty(euler.shape[0:3] + (3, 3), device=euler.device) 72 | cos = torch.cos(euler) 73 | sin = torch.sin(euler) 74 | cord = ord(axis) - ord('x') 75 | 76 | transform[..., cord, :] = transform[..., :, cord] = 0 77 | transform[..., cord, cord] = 1 78 | 79 | if axis == 'x': 80 | transform[..., 1, 1] = transform[..., 2, 2] = cos 81 | transform[..., 1, 2] = -sin 82 | transform[..., 2, 1] = sin 83 | if axis == 'y': 84 | transform[..., 0, 0] = transform[..., 2, 2] = cos 85 | transform[..., 0, 2] = sin 86 | transform[..., 2, 0] = -sin 87 | if axis == 'z': 88 | transform[..., 0, 0] = transform[..., 1, 1] = cos 89 | transform[..., 0, 1] = -sin 90 | transform[..., 1, 0] = sin 91 | 92 | return transform 93 | 94 | @staticmethod 95 | def transform_from_quaternion(quater: torch.Tensor): 96 | 97 | norm = torch.norm(quater, dim=-1, keepdim=True) 98 | quater = quater / norm 99 | 100 | # Fucking problem here... Fixed from (w, x, y, z) -> (x, y, z, w) 101 | qx = quater[..., 0] 102 | qy = quater[..., 1] 103 | qz = quater[..., 2] 104 | qw = quater[..., 3] 105 | 106 | x2 = qx + qx 107 | y2 = qy + qy 108 | z2 = qz + qz 109 | xx = qx * x2 110 | yy = qy * y2 111 | wx = qw * x2 112 | xy = qx * y2 113 | yz = qy * z2 114 | wy = qw * y2 115 | xz = qx * z2 116 | zz = qz * z2 117 | wz = qw * z2 118 | 119 | m = torch.empty(quater.shape[:-1] + (3, 3), device=quater.device) 120 | m[..., 0, 0] = 1.0 - (yy + zz) 121 | m[..., 0, 1] = xy - wz 122 | m[..., 0, 2] = xz + wy 123 | m[..., 1, 0] = xy + wz 124 | m[..., 1, 1] = 1.0 - (xx + zz) 125 | m[..., 1, 2] = yz - wx 126 | m[..., 2, 0] = xz - wy 127 | m[..., 2, 1] = yz + wx 128 | m[..., 2, 2] = 1.0 - (xx + yy) 129 | 130 | return m 131 | 132 | -------------------------------------------------------------------------------- /utils/joint_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # Numpy Function 5 | def bfs(child, T, T_glob): 6 | if torch.is_tensor(T): 7 | def bfs_rec(child, T, T_glob, node): 8 | for x in child[node]: 9 | T_glob[x] = torch.matmul(T_glob[node],T[x]) 10 | for x in child[node]: 11 | bfs_rec(child, T, T_glob, x) 12 | return 13 | 14 | else: 15 | def bfs_rec(child, T, T_glob, node): 16 | for x in child[node]: 17 | T_glob[x] = np.matmul(T_glob[node],T[x]) 18 | for x in child[node]: 19 | bfs_rec(child, T, T_glob, x) 20 | return 21 | T_glob[0] = T[0] 22 | bfs_rec(child, T, T_glob, 0) 23 | return 24 | 25 | # 22 joints 26 | joint_names = ['Hips', 'Spine', 'Spine1', 'Spine2', 'Neck', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 27 | 'RightShoulder', 'RightArm', 'RightForeArm', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 28 | 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 29 | 'LeftHand', 'RightHand'] 30 | # 21 bones 31 
| bone_names = ['Hips-Spine', 'Hips-LeftUpLeg', 'Hips-RightUpLeg', 'Spine-Spine1', 'Spine1-Spine2', 'Spine2-Neck', 'Spine2-LeftShoulder', 'Spine2-RightShoulder', 'Neck-Head', 'LeftShoulder-LeftArm', 'LeftArm-LeftForeArm', 'LeftForeArm-LeftHand', 'RightShoulder-RightArm', 'RightArm-RightForeArm', 'RightForeArm-RightHand', 'LeftUpLeg-LeftLeg', 'LeftLeg-LeftFoot', 'LeftFoot-LeftToeBase', 'RightUpLeg-RightLeg', 'RightLeg-RightFoot', 'RightFoot-RightToeBase'] 32 | 33 | def joint_pos2bone_len(joint_pos): 34 | if len(joint_pos.shape) == 2: 35 | J, _ = joint_pos.shape 36 | B = 0 37 | joint_pos = joint_pos.unsqueeze(0) 38 | else: 39 | B, J, _ = joint_pos.shape 40 | tree = maketree(J) 41 | bone_lengths = [] 42 | for node_idx, child_list in enumerate(tree): 43 | for child_idx in child_list: 44 | bone_lengths.append(torch.sqrt(torch.sum((joint_pos[:, node_idx] - joint_pos[:, child_idx])**2, dim=-1, keepdim=True))) # BxJ 45 | bone_lengths = torch.cat(bone_lengths, dim=1) 46 | if B==0: 47 | bone_lengths = bone_lengths.squeeze(0) 48 | return bone_lengths 49 | 50 | def maketree(num_joints): 51 | # hardcoded tree 52 | if num_joints==52: 53 | child = [[] for x in range(num_joints)] 54 | child[0] = [1, 44, 48] 55 | child[1] = [2] 56 | child[2] = [3] 57 | child[3] = [4, 6, 25] 58 | child[4] = [5] 59 | #child[5] 60 | child[6] = [7] 61 | child[7] = [8] 62 | child[8] = [9] 63 | child[9] = [10, 13, 16, 19, 22] 64 | child[10] = [11] 65 | child[11] = [12] 66 | #child[12] 67 | child[13] = [14] 68 | child[14] = [15] 69 | #child[15] 70 | child[16] = [17] 71 | child[17] = [18] 72 | #child[18] 73 | child[19] = [20] 74 | child[20] = [21] 75 | #child[21] 76 | child[22] = [23] 77 | child[23] = [24] 78 | #child[24] 79 | child[25] = [26] 80 | child[26] = [27] 81 | child[27] = [28] 82 | child[28] = [29, 32, 35, 38, 41] 83 | child[29] = [30] 84 | child[30] = [31] 85 | #child[31] 86 | child[32] = [33] 87 | child[33] = [34] 88 | #child[34] 89 | child[35] = [36] 90 | child[36] = [37] 91 | #child[37] 92 | child[38] = [39] 93 | child[39] = [40] 94 | #child[40] 95 | child[41] = [42] 96 | child[42] = [43] 97 | #child[43] 98 | child[44] = [45] 99 | child[45] = [46] 100 | child[46] = [47] 101 | #child[47] 102 | child[48] = [49] 103 | child[49] = [50] 104 | child[50] = [51] 105 | #child[51] 106 | return child 107 | else: # joint 22 108 | # we assume that , 109 | # INDEX 20 -> mixamorig_LeftHand 110 | # INDEX 21 -> mixamorig_RightHan 111 | child = [[] for x in range(num_joints)] 112 | child[0] = [1, 12, 16] 113 | child[1] = [2] 114 | child[2] = [3] 115 | child[3] = [4, 6, 9] 116 | child[4] = [5] 117 | #child[5] 118 | child[6] = [7] 119 | child[7] = [8] 120 | child[8] = [20] ##### ATTENTION 121 | child[9] = [10] 122 | child[10] = [11] 123 | child[11] = [21] ##### ATTENTION 124 | child[12]=[13] 125 | child[13]=[14] 126 | child[14]=[15] 127 | #child[15] 128 | child[16]=[17] 129 | child[17]=[18] 130 | child[18]=[19] 131 | #child[19] 132 | return child 133 | 134 | # numpy function 135 | def modify_joint_matrix(transforms, num_joints, inverse=True, local2global=True, short=False): 136 | tree =maketree(num_joints) 137 | T0 = transforms 138 | T = [None] * num_joints 139 | JM = [None] * num_joints 140 | IJM = [None] * num_joints 141 | 142 | for i,A in enumerate(T0): 143 | if A.shape[-2:] == (4, 4): 144 | pass 145 | elif A.shape[-1] == 12: 146 | A = np.reshape(np.append(A,[0., 0., 0., 1.]),(4,4)) 147 | T[i] = A 148 | 149 | if local2global: 150 | bfs(tree, T, JM) 151 | JM = np.array(JM) 152 | else: 153 | JM = np.array(T) 154 | if inverse: 155 | for 
i,A in enumerate(JM): 156 | IJM[i] = np.linalg.inv(A) 157 | 158 | IJM = np.array(IJM) 159 | IJM = np.reshape(IJM, (num_joints, 16)) 160 | if short: 161 | return IJM[:,:12] 162 | return IJM 163 | 164 | JM = np.reshape(JM, (num_joints, 16)) 165 | if short: 166 | return JM[:, :12] 167 | return JM 168 | 169 | # torch functions 170 | def transform_rel2glob(transforms): 171 | batch_size = transforms.shape[0] 172 | num_joints = transforms.shape[1] 173 | 174 | tree = maketree(num_joints) 175 | 176 | # glob_transforms = torch.zeros(transforms.shape).to(transforms.device).type(transforms.type()) 177 | glob_transforms = [] 178 | for batch_idx, T in enumerate(transforms): 179 | # T_glob = glob_transforms[batch_idx] 180 | T_glob = [None] * num_joints 181 | bfs(tree, T, T_glob) 182 | glob_transforms.append(torch.stack(T_glob).to(transforms.device).type(transforms.type())) 183 | return torch.stack(glob_transforms).to(transforms.device).type(transforms.type()) 184 | 185 | 186 | def toSE3(R_or_T3x4: torch.Tensor, p: torch.Tensor=None)->torch.Tensor: 187 | """ 188 | R: torch.Tensor (batch_size, num_joints, 3, 3) or (num_joints, 3, 3) 189 | p: torch.Tensor (batch_size, num_joints, 3) or (num_joints, 3) 190 | """ 191 | is_batched = R_or_T3x4.shape.__len__() == 4 192 | if is_batched: 193 | batch_size = R_or_T3x4.shape[0] 194 | num_joints = R_or_T3x4.shape[1] 195 | else: 196 | batch_size = 1 197 | num_joints = R_or_T3x4.shape[0] 198 | 199 | if R_or_T3x4.shape[-1] == 3: 200 | R = R_or_T3x4 201 | if not is_batched: 202 | R = R.unsqueeze(0) 203 | p = p.unsqueeze(0) 204 | T3x4 = torch.cat([R, p.unsqueeze(-1)], axis=-1) 205 | 206 | elif R_or_T3x4.shape[-1] == 4: 207 | T3x4 = R_or_T3x4 208 | if not is_batched: 209 | T3x4 = T3x4.unsqueeze(0) 210 | else: 211 | raise NotImplementedError 212 | 213 | T =torch.cat([T3x4, torch.Tensor([0, 0, 0, 1] * batch_size * num_joints).type(T3x4.type()).view(batch_size, num_joints, 1, 4)], axis=-2) 214 | 215 | if not is_batched: 216 | T = T.squeeze(0) 217 | 218 | return T 219 | 220 | 221 | 222 | # numpy function 223 | def get_transform_matrix(rot6d, trans, num_joints=22): 224 | a1 = rot6d[:, :3] 225 | a2 = rot6d[:, 3:] 226 | b1 = a1 / np.sqrt(np.reshape(np.sum(a1**2, axis=1), (num_joints, 1))) 227 | b2 = a2 - np.reshape(np.sum(np.multiply(a2, b1), axis=1), (num_joints, 1)) * b1 228 | b2 = b2 / np.sqrt(np.reshape(np.sum(b2**2, axis=1), (num_joints,1))) 229 | b3 = np.cross(b1, b2) 230 | 231 | b1 = np.expand_dims(b1, axis=2) 232 | b2 = np.expand_dims(b2, axis=2) 233 | b3 = np.expand_dims(b3, axis=2) 234 | 235 | R = np.concatenate((b1,b2,b3), axis=2)# (num_joints, 3, 3) 236 | T = np.expand_dims(trans, axis=2) # (num_joints, 3, 1) 237 | 238 | transform = np.reshape(np.concatenate((R,T), axis=2), (num_joints, 12)) # (args.num_joints, 12) 239 | 240 | return transform 241 | 242 | def transform_joint(joint_path="./data/mixamo/transforms/aj/Samba Dancing_000101.csv"): 243 | transforms = np.genfromtxt(joint_path, delimiter=',', dtype=float) 244 | IJM = modify_joint_matrix(transforms, num_joints=22, inverse=False, local2global=True, short=True) 245 | print("Saving IJM to ",joint_path.split('.csv')[0]+ '_IJM.csv' ) 246 | np.savetxt(joint_path.split('.csv')[0]+ '_JM_GT.csv', IJM, delimiter=',') 247 | return 248 | 249 | # numpy function 250 | def save_ijm(rot6d, trans, output_filename, joint_loss_type): 251 | # rot = compute_rotation_matrix_from_ortho6d(rot6d) 252 | transform = get_transform_matrix(rot6d, trans) 253 | local2global = joint_loss_type in ['rel', 'rel2glob'] # Todo: verify 254 | 
inverse_joint_matrix = modify_joint_matrix(transform, num_joints=22, inverse=True, 255 | local2global=local2global, short=False) 256 | np.savetxt(output_filename, inverse_joint_matrix, delimiter=',') 257 | 258 | def save_jm(rot6d, trans, output_filename, joint_loss_type): 259 | # rot = compute_rotation_matrix_from_ortho6d(rot6d) 260 | transform = get_transform_matrix(rot6d, trans) 261 | local2global = joint_loss_type in ['rel', 'rel2glob'] # Todo: verify 262 | joint_matrix = modify_joint_matrix(transform, num_joints=22, inverse=False, 263 | local2global=local2global, short=False) 264 | np.savetxt(output_filename, joint_matrix, delimiter=',') 265 | 266 | def save_jm2(jm, output_filename, joint_type='glob', inverse=False): 267 | jm = modify_joint_matrix(jm, num_joints=22, inverse=inverse, 268 | local2global=joint_type != 'global', short=False) 269 | np.savetxt(output_filename, jm, delimiter=', ') 270 | 271 | 272 | if __name__ == "__main__": 273 | # __import__('pdb').set_trace() 274 | transform_joint("./data/mixamo/transforms/aj/Samba Dancing_000000.csv") -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.voxel_utils import get_softmax_preds 6 | from utils.joint_util import joint_pos2bone_len 7 | 8 | def normalize3d(batch): 9 | """ 10 | batch (B, J, N, N, N) 11 | """ 12 | B, J = batch.shape[0], batch.shape[1] 13 | 14 | norm = batch.reshape(B, J, -1).sum(dim=-1, keepdim=True) 15 | return batch.reshape(B, J, -1).div(norm).reshape(batch.shape) 16 | 17 | 18 | def softmax3d(batch): 19 | batch_size = batch.shape[0] 20 | joint_num = batch.shape[1] 21 | return F.softmax(batch.reshape(batch_size, joint_num, -1), dim=2).reshape(batch.shape) 22 | # F.softmax(batch.view(batch_size, joint_num, -1)).view(batch.shape) 23 | 24 | def compute_mpjpe(pred_coords, target_coords): 25 | """ 26 | batch (B, J, 3) 27 | """ 28 | if isinstance(pred_coords, np.ndarray): 29 | return np.mean(np.sqrt(np.mean((pred_coords - target_coords) ** 2, axis=-1))) 30 | else: 31 | return torch.mean(torch.sqrt(torch.sum((pred_coords - target_coords) ** 2, dim=-1)), dim=-1) 32 | 33 | def compute_joint_ce_loss(joint_heatmap, target_heatmap): 34 | """ 35 | (B, J, 88, 88, 88) 36 | """ 37 | B, J = joint_heatmap.shape[0], joint_heatmap.shape[1] 38 | loss_joint = F.binary_cross_entropy_with_logits(joint_heatmap, target_heatmap) 39 | return loss_joint 40 | 41 | class JointsMSELoss(nn.Module): 42 | # Not used 43 | def __init__(self): 44 | super(JointsMSELoss, self).__init__() 45 | self.criterion = nn.MSELoss(size_average=True) 46 | 47 | def forward(self, pred, target): 48 | batch_size = pred.size(0) 49 | num_joints = pred.size(1) 50 | heatmaps_pred = pred.reshape((batch_size, num_joints, -1)).split(1, 1) 51 | heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1) 52 | loss = 0 53 | 54 | for idx in range(num_joints): 55 | heatmap_pred = heatmaps_pred[idx].squeeze() 56 | heatmap_gt = heatmaps_gt[idx].squeeze() 57 | loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt) 58 | 59 | return loss / num_joints 60 | 61 | def compute_bone_symmetry_loss(joint_heatmap): 62 | B, J, H, W, D = joint_heatmap.shape 63 | device = joint_heatmap.device 64 | joint_pos = get_softmax_preds(joint_heatmap) 65 | assert J == 22 66 | bone_len = joint_pos2bone_len(joint_pos) 67 | """ 68 | Symmetry: 69 | 0-1 70 | 
6-7 71 | 9,10,11 - 12, 13, 14 72 | 15,16,17 - 18,19,20 73 | """ 74 | left_symmetry_idx = [0, 6, 9, 10, 11, 15, 16, 17] 75 | right_symmetry_idx = [1, 7, 12, 13, 14, 18, 19, 20] 76 | # MSE Loss 77 | loss = 0 78 | for i in range(len(left_symmetry_idx)): 79 | loss += (bone_len[...,left_symmetry_idx[i]] - bone_len[...,right_symmetry_idx[i]])**2 80 | return torch.mean(loss) 81 | 82 | 83 | -------------------------------------------------------------------------------- /utils/misc/calc_IBM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def main(): 5 | for f in os.listdir("../test/gt/transforms"): 6 | print(f) 7 | convert(f) 8 | 9 | 10 | def convert(filename): 11 | num_joints = 22 #TODO 12 | tree =maketree(num_joints) 13 | T0 = np.genfromtxt('./test/gt/transforms/'+filename, delimiter=',') 14 | T = [None] * num_joints 15 | BM = [None]*num_joints 16 | IBM = [None] * num_joints 17 | 18 | for i,A in enumerate(T0): 19 | A = np.reshape(np.append(A,[0., 0., 0., 1.]),(4,4)) 20 | T[i] = A 21 | 22 | bfs(tree, T, BM) 23 | BM = np.array(BM) 24 | 25 | for i,A in enumerate(BM): 26 | IBM[i] = np.linalg.inv(A) 27 | 28 | IBM = np.array(IBM) 29 | IBM = np.reshape(IBM, (num_joints, 16)) 30 | np.savetxt('./test/gt/transforms/IBM_'+filename, IBM, delimiter=',') 31 | return 32 | 33 | def bfs(child, T, BM): 34 | def bfs_rec(child, T, BM, node): 35 | for x in child[node]: 36 | BM[x] = np.matmul(BM[node],T[x]) 37 | for x in child[node]: 38 | bfs_rec(child, T, BM, x) 39 | return 40 | BM[0] = T[0] 41 | bfs_rec(child, T, BM, 0) 42 | return 43 | 44 | def maketree(num_joints): 45 | # hardcoded tree 46 | if num_joints==52: 47 | child = [[] for x in range(num_joints)] 48 | child[0] = [1, 44, 48] 49 | child[1] = [2] 50 | child[2] = [3] 51 | child[3] = [4, 6, 25] 52 | child[4] = [5] 53 | #child[5] 54 | child[6] = [7] 55 | child[7] = [8] 56 | child[8] = [9] 57 | child[9] = [10, 13, 16, 19, 22] 58 | child[10] = [11] 59 | child[11] = [12] 60 | #child[12] 61 | child[13] = [14] 62 | child[14] = [15] 63 | #child[15] 64 | child[16] = [17] 65 | child[17] = [18] 66 | #child[18] 67 | child[19] = [20] 68 | child[20] = [21] 69 | #child[21] 70 | child[22] = [23] 71 | child[23] = [24] 72 | #child[24] 73 | child[25] = [26] 74 | child[26] = [27] 75 | child[27] = [28] 76 | child[28] = [29, 32, 35, 38, 41] 77 | child[29] = [30] 78 | child[30] = [31] 79 | #child[31] 80 | child[32] = [33] 81 | child[33] = [34] 82 | #child[34] 83 | child[35] = [36] 84 | child[36] = [37] 85 | #child[37] 86 | child[38] = [39] 87 | child[39] = [40] 88 | #child[40] 89 | child[41] = [42] 90 | child[42] = [43] 91 | #child[43] 92 | child[44] = [45] 93 | child[45] = [46] 94 | child[46] = [47] 95 | #child[47] 96 | child[48] = [49] 97 | child[49] = [50] 98 | child[50] = [51] 99 | #child[51] 100 | return child 101 | else: 102 | # we assume that , 103 | # INDEX 20 -> mixamorig_LeftHand 104 | # INDEX 21 -> mixamorig_RightHan 105 | child = [[] for x in range(num_joints)] 106 | child[0] = [1, 12, 16] 107 | child[1] = [2] 108 | child[2] = [3] 109 | child[3] = [4, 6, 9] 110 | child[4] = [5] 111 | #child[5] 112 | child[6] = [7] 113 | child[7] = [8] 114 | child[8] = [20] ##### ATTENTION 115 | child[9] = [10] 116 | child[10] = [11] 117 | child[11] = [21] ##### ATTENTION 118 | child[12]=[13] 119 | child[13]=[14] 120 | child[14]=[15] 121 | #child[15] 122 | child[16]=[17] 123 | child[17]=[18] 124 | child[18]=[19] 125 | #child[19] 126 | return child 127 | 128 | if __name__=="__main__": 129 | main() 130 | 
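# The helpers above compose each joint's global bind matrix by walking the hardcoded
# tree (BM[child] = BM[parent] @ T[child]) and then save the per-joint inverse (IBM).
# A minimal, self-contained sanity check of that relation on a made-up 3-joint chain
# (not part of the original pipeline; the joint offsets are illustrative only):
def _toy_ibm_example():
    child = [[1], [2], []]              # toy hierarchy: 0 -> 1 -> 2
    T = [np.eye(4) for _ in range(3)]   # local (parent-relative) transforms
    T[1][:3, 3] = [0.0, 1.0, 0.0]       # joint 1 sits 1.0 above joint 0
    T[2][:3, 3] = [0.0, 0.5, 0.0]       # joint 2 sits 0.5 above joint 1
    BM = [None] * 3
    bfs(child, T, BM)                   # compose global bind matrices
    IBM = [np.linalg.inv(A) for A in BM]
    # A joint's own IBM maps its global position back to the joint-local origin.
    p2 = np.append(BM[2][:3, 3], 1.0)
    assert np.allclose(IBM[2] @ p2, [0.0, 0.0, 0.0, 1.0])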
-------------------------------------------------------------------------------- /utils/misc/dataset_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | #BASE_DIR=os.path.dirname(os.path.abspath(__file__)) 6 | #sys.path.append(BASE_DIR) 7 | #sys.path.append(os.path.join(BASE_DIR, 'data')) 8 | 9 | action = ["Boxing", "Shoved", "Samba", "Walk"] 10 | fCount = [66, 137, 547, 117] 11 | owner = ['bjs', 'sht', 'kjh'] 12 | 13 | valid_num = 2 14 | 15 | # validset generation 16 | for i in range(valid_num): 17 | while(True): 18 | ownerIdx = np.random.randint(len(owner)) 19 | modIdx = np.random.randint(1, 15+1) 20 | actionIdx = np.random.randint(len(action)) 21 | frame = np.random.randint(fCount[actionIdx]) 22 | 23 | filename = owner[ownerIdx]+str(modIdx).zfill(2)+"_"+action[actionIdx]+"_"+str(frame).zfill(3)+'.csv' 24 | 25 | if os.path.isfile('../data/joint22/train/transforms/'+owner[ownerIdx]+'/'+filename): 26 | print(filename + ' exists! go to validation set...') 27 | os.rename('../data/joint22/train/transforms/'+owner[ownerIdx]+'/'+filename, '../data/joint22/valid/transforms/'+filename) 28 | os.rename('../data/joint22/train/vertices/'+owner[ownerIdx]+'/'+filename, '../data/joint22/valid/vertices/'+filename) 29 | break 30 | 31 | -------------------------------------------------------------------------------- /utils/misc/extract_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | BASE_DIR=os.path.abspath(__file__) 5 | sys.path.append(BASE_DIR) 6 | sys.path.append(os.path.join(BASE_DIR, '/pycollada')) 7 | from collada import * 8 | 9 | inputRepoPath = '~/autorigging/obj_dae' 10 | outputRepoPath = '~/autorigging/weights/sht/' 11 | joint_num = 22 12 | 13 | indices = [*range(1,10)] 14 | #wrongIndices = [1,3,7,8,11,14] 15 | 16 | 17 | #for x in wrongIndices: 18 | # indices.remove(x) 19 | 20 | for modIdx in indices: 21 | fileName = '%s/%d/Boxing.dae'%(inputRepoPath, modIdx) 22 | mesh = Collada(fileName) 23 | output = np.zeros((4096,joint_num)) 24 | 25 | joint_dictionary={} 26 | for joint in range(joint_num-2): 27 | jointname = mesh.animations[joint].id[:-5] 28 | joint_dictionary[jointname]=joint 29 | joint_dictionary['mixamorig_LeftHand']=20 30 | joint_dictionary['mixamorig_RightHand']=21 31 | 32 | for idx in range(0,4096): 33 | count = 0 34 | for jidx in mesh.controllers[0].joint_index[idx]: 35 | jointname = mesh.controllers[0].weight_joints[jidx] 36 | correct_jidx = joint_dictionary[jointname] 37 | output[idx][correct_jidx] = mesh.controllers[0].weights[mesh.controllers[0].weight_index[idx][count]] 38 | count=count+1 39 | print(outputRepoPath+'sht'+str(modIdx)+'.csv') 40 | np.savetxt(outputRepoPath+'sht'+ str(modIdx).zfill(2)+'.csv', output, delimiter=',') 41 | -------------------------------------------------------------------------------- /utils/misc/joint_tree_util.py: -------------------------------------------------------------------------------- 1 | 2 | def makedepth(child): 3 | def makedepth_rec(child, depth, node): 4 | for x in child[node]: 5 | depth[x] = depth[node] + 1 6 | for x in child[node]: 7 | makedepth_rec(child, depth, x) 8 | return 9 | depth = [0 for x in range(len(child))] 10 | root = 0 11 | makedepth_rec(child, depth, root) 12 | return depth 13 | -------------------------------------------------------------------------------- /utils/misc/transpose.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | num_joints = 52 #TODO 4 | filename = 'inference_tree_test'#TODO 5 | T=np.genfromtxt('./test/'+filename+'.csv', delimiter=',') 6 | newT = [None] * num_joints 7 | 8 | for i, A in enumerate(T): 9 | A = np.reshape(A, (3,4)) 10 | R = A[:, :3] 11 | newR = np.transpose(R) 12 | newA = np.reshape(np.append(newR, np.expand_dims(A[:, 3], axis=-1), axis=1), (12,)) 13 | newT[i] = newA 14 | 15 | newT = np.array(newT) 16 | np.savetxt('./new_'+filename+'.csv', newT, delimiter=',') 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /utils/misc/vis_util.py: -------------------------------------------------------------------------------- 1 | # Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018 2 | 3 | from matplotlib import pyplot as plt 4 | from mpl_toolkits.mplot3d import Axes3D 5 | 6 | 7 | def plot_pcd_three_views(filename, pcds, titles, suptitle='', sizes=None, cmap='Reds', zdir='y', 8 | xlim=(-0.3, 0.3), ylim=(-0.3, 0.3), zlim=(-0.3, 0.3)): 9 | if sizes is None: 10 | sizes = [0.5 for i in range(len(pcds))] 11 | fig = plt.figure(figsize=(len(pcds) * 3, 9)) 12 | for i in range(3): 13 | elev = 30 14 | azim = -45 + 90 * i 15 | for j, (pcd, size) in enumerate(zip(pcds, sizes)): 16 | color = pcd[:, 0] 17 | """if j==1: 18 | ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d') 19 | ax.view_init(elev, azim) 20 | ax.scatter(pcd[:512, 0], pcd[:512, 1], pcd[:512, 2], zdir=zdir, c=pcd[:512, 0] 21 | , s=size, cmap=cmap, vmin=-1, vmax=0.5) 22 | ax.scatter(pcd[512:, 0], pcd[512:, 1], pcd[512:, 2], zdir=zdir, c=pcd[512:, 0], s=size, cmap='Blues', vmin=-1, vmax=0.5) 23 | 24 | elif j==2: 25 | ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d') 26 | ax.view_init(elev, azim) 27 | ax.scatter(pcd[:512*16, 0], pcd[:512*16, 1], pcd[:512*16, 2], zdir=zdir, c=pcd[:512*16, 0] 28 | , s=size, cmap=cmap, vmin=-1, vmax=0.5) 29 | ax.scatter(pcd[512*16:, 0], pcd[512*16:, 1], pcd[512*16:, 2], zdir=zdir, c=pcd[512*16:, 0], s=size, cmap='Blues', vmin=-1, vmax=0.5) 30 | 31 | else:""" 32 | ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d') 33 | ax.view_init(elev, azim) 34 | ax.scatter(pcd[:, 0], pcd[:, 1], pcd[:, 2], zdir=zdir, c=color, s=size, cmap=cmap, vmin=-1, vmax=0.5) 35 | ax.set_title(titles[j]) 36 | ax.set_axis_off() 37 | ax.set_xlim(xlim) 38 | ax.set_ylim(ylim) 39 | ax.set_zlim(zlim) 40 | 41 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9, wspace=0.1, hspace=0.1) 42 | plt.suptitle(suptitle) 43 | fig.savefig(filename) 44 | plt.close(fig) 45 | -------------------------------------------------------------------------------- /utils/obj_utils.py: -------------------------------------------------------------------------------- 1 | class ObjLoader(object): 2 | def __init__(self, fileName, skip_face=False, process_face=True): 3 | self.vertices = [] 4 | self.faces = [] 5 | ## 6 | vertex_processing=False 7 | try: 8 | f = open(fileName) 9 | for line in f: 10 | if line[:2] == "v ": 11 | index1 = line.find(" ") + 1 12 | index2 = line.find(" ", index1 + 1) 13 | index3 = line.find(" ", index2 + 1) 14 | 15 | vertex = (float(line[index1:index2]), float(line[index2:index3]), float(line[index3:-1])) 16 | vertex = (round(vertex[0], 2), round(vertex[1], 2), round(vertex[2], 2)) 17 | self.vertices.append(vertex) 18 | vertex_processing = True 19 | 20 | elif line[0] == "f": 21 | string = line.replace("//", "/") 
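                    # Face tokens like "v//vn" (vertex + normal, no texture index) are
                    # collapsed to "v/vn" so that split("/")[0] below yields the vertex index.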
22 | ## 23 | i = string.find(" ") + 1 24 | face = [] 25 | for item in range(string.count(" ")): 26 | if string.find(" ", i) == -1: 27 | fragment = string[i:-1] 28 | if process_face: 29 | fragment = int(fragment.split("/")[0]) - 1 30 | face.append(fragment) 31 | break 32 | fragment = string[i:string.find(" ", i)] 33 | if process_face: 34 | fragment = int(fragment.split("/")[0]) - 1 35 | face.append(fragment) 36 | i = string.find(" ", i) + 1 37 | ## 38 | self.faces.append(tuple(face)) 39 | else: 40 | # single mesh assumption 41 | if skip_face and vertex_processing: 42 | break 43 | 44 | 45 | f.close() 46 | except IOError: 47 | print(".obj file not found.") 48 | 49 | 50 | def face2edge(faces:list): 51 | # faces = [(1,2,3), (3,4,5), (6,7,8)] 52 | e = [] 53 | for f in faces: 54 | e += [[f[0], f[1]], [f[1],f[2]], [f[2],f[0]], [f[1], f[0]], [f[0], f[2]], [f[2], f[1]]] 55 | return e -------------------------------------------------------------------------------- /utils/rotation_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | # [B, n] 6 | def normalize_vector( v, return_mag =False): 7 | batch=v.shape[0] 8 | v_mag = torch.sqrt(v.pow(2).sum(1))# batch 9 | v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda())) 10 | v_mag = v_mag.view(batch,1).expand(batch,v.shape[1]) 11 | v = v/v_mag 12 | if(return_mag==True): 13 | return v, v_mag[:,0] 14 | else: 15 | return v 16 | 17 | # u, v [B, n] 18 | def cross_product( u, v): 19 | batch = u.shape[0] 20 | #print (u.shape) 21 | #print (v.shape) 22 | i = u[:,1]*v[:,2] - u[:,2]*v[:,1] 23 | j = u[:,2]*v[:,0] - u[:,0]*v[:,2] 24 | k = u[:,0]*v[:,1] - u[:,1]*v[:,0] 25 | 26 | out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3 27 | 28 | return out 29 | 30 | def compute_rotation_matrix_from_ortho6d(ortho6d): # parameter range: -inf ~ inf 31 | if ortho6d.shape.__len__() != 2: 32 | reshape=True 33 | else: 34 | reshape=False 35 | if reshape: 36 | # [B', J, 6] -> [B, 3] 37 | batch_size = ortho6d.shape[0] 38 | num_joints = ortho6d.shape[1] 39 | ortho6d = ortho6d.view(-1, 6) 40 | 41 | x_raw = ortho6d[:,0:3] # [B,3] 42 | y_raw = ortho6d[:,3:6] # [B,3] 43 | 44 | x = normalize_vector(x_raw) 45 | z = cross_product(x, y_raw) 46 | z = normalize_vector(z) 47 | y = cross_product(z, x) 48 | 49 | x = x.view(-1, 3, 1) 50 | y = y.view(-1, 3, 1) 51 | z = z.view(-1, 3, 1) 52 | matrix = torch.cat((x,y,z), 2) # [B, 3, 3] 53 | 54 | if reshape: 55 | # [B, 3, 3] -> [B', J, 3, 3] 56 | matrix = matrix.view(batch_size, num_joints, 3, 3) 57 | return matrix 58 | 59 | #matrices batch*3*3 60 | #both matrix are orthogonal rotation matrices 61 | #out theta between 0 to 3.1416 radian (0 to 180 degree if degrees=True) batch 62 | #snippet from github.com/papagina/RotationContinuity 63 | def compute_geodesic_distance_from_two_matrices(m1, m2, degrees=False): 64 | m2 = m2.float() 65 | if m1.shape.__len__() != 3: 66 | reshape = True 67 | else: 68 | reshape = False 69 | if reshape: 70 | # [B', J, 3, 3] -> [B, 3, 3] 71 | batch_size = m1.shape[0] 72 | num_joints = m1.shape[1] 73 | m1 = m1.view(-1, 3, 3) 74 | m2 = m2.view(-1, 3, 3) 75 | 76 | batch=m1.shape[0] 77 | m = torch.bmm(m1, m2.transpose(1,2)) #batch*3*3 78 | 79 | cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2 80 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) ) 81 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 ) 82 | theta = torch.acos(cos) 83 | if reshape: 84 | # 
[B] -> [B', J] 85 | theta = theta.view(batch_size, num_joints) 86 | if degrees: 87 | theta = theta * (180 / math.pi) 88 | return theta 89 | 90 | def compute_L2_distance_from_two_matrices(m1, m2): 91 | # L2 loss : mean sum of squares of all matrix elements, is different from matrix L2-norm 92 | """ 93 | [B, J, 3, 3], [B, J, 3, 3] -> [B, ] 94 | """ 95 | B, J = m1.shape[:2] 96 | return ((m1 - m2)**2).view(B, -1).mean(1) 97 | # ref: https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/Model_pointnet.py#L97 98 | -------------------------------------------------------------------------------- /utils/skin_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.joint_util import transform_rel2glob 3 | from torch_geometric.data import Batch 4 | 5 | def mesh_transform(mesh, JM, BM, skin_weight, J=22, JM_type='rel', BM_type='glob'): 6 | # Todo: transform vertex using V_transformed = skin_weight * BM * JM^-1 * V 7 | """ 8 | Input: 9 | vertex: torch.Tensor [B, npts, 3] --> (jsbae) it must be raw input, i.e. without sampling applied 10 | JM: torch.Tensor [B, J, 4, 4] --> (capoo) Don't forget it's global! (kim)No, it's relative.... 11 | BM: torch.Tensor [B, J, 4, 4] --> (capoo) Don't forget it's global! 12 | skin_weight: torch.Tensor [B, npts, J] 13 | Output: 14 | transformed_vertex 15 | """ 16 | vertex_batch, batch = mesh.pos, mesh.batch 17 | B = JM.shape[0] # batch_size 18 | 19 | if JM_type != 'glob': 20 | JM_global = transform_rel2glob(JM) 21 | else: 22 | JM_global = JM 23 | IJM = [] 24 | for i in range(B): 25 | IJM.append(torch.stack([JM_j.inverse() for JM_j in JM_global[i]], dim=0).view(J, 4, 4)) 26 | IJM = torch.stack(IJM, dim=0) 27 | transformed_vertex_batch = torch.zeros_like(vertex_batch) 28 | for batch_idx in range(B): 29 | vertex = vertex_batch[batch==batch_idx] 30 | V = vertex.__len__() 31 | vertex = torch.cat((vertex, torch.ones(V, 1).to(vertex.device)), dim=-1) # vertx: [B, npts, 4] 32 | local_vertex = torch.einsum('kij,vj->kvi', IJM[batch_idx], vertex) 33 | BM_transformed = torch.einsum('kij,kvj->kvi', BM[batch_idx], local_vertex) 34 | transformed_vertex = torch.einsum('vk,kvi->vi', skin_weight[batch==batch_idx], BM_transformed) 35 | transformed_vertex_batch[batch==batch_idx] = transformed_vertex[:, :3] 36 | transformed_mesh = Batch(pos=transformed_vertex_batch, edge_index=mesh.edge_index, batch=batch) 37 | 38 | return transformed_mesh 39 | 40 | def vertex_transform(vertex, JM, BM, skin_weight, J=22, JM_type='rel', BM_type='glob'): 41 | # Todo: transform vertex using V_transformed = skin_weight * BM * JM^-1 * V 42 | """ 43 | Input: 44 | vertex: torch.Tensor [B, npts, 3] --> (jsbae) it must be raw input, i.e. without sampling applied 45 | JM: torch.Tensor [B, J, 4, 4] --> (capoo) Don't forget it's global! (kim)No, it's relative.... 46 | BM: torch.Tensor [B, J, 4, 4] --> (capoo) Don't forget it's global! 
47 | skin_weight: torch.Tensor [B, npts, J] 48 | Output: 49 | transformed_vertex 50 | """ 51 | if len(vertex.shape) == 2: 52 | B = 1 53 | V = vertex.shape[0] 54 | vertex = vertex.view(B, V, 3) 55 | JM = JM.view(B, J, 4, 4) 56 | BM = BM.view(B, J, 4, 4) 57 | skin_weight = skin_weight.view(B, V, J) 58 | else: 59 | B, V, _ = vertex.shape 60 | 61 | vertex = torch.cat((vertex, torch.ones(B, V, 1).to(vertex.device)), dim=-1) # vertx: [B, npts, 4] 62 | if JM_type != 'glob': 63 | JM_global = transform_rel2glob(JM) 64 | else: 65 | JM_global = JM 66 | IJM = [] 67 | for i in range(B): 68 | IJM.append(torch.stack([JM_j.inverse() for JM_j in JM_global[i]], dim=0).view(J, 4, 4)) 69 | IJM = torch.stack(IJM, dim=0) 70 | local_vertex = torch.einsum('bkij,bvj->bkvi', IJM, vertex) 71 | BM_transformed = torch.einsum('bkij,bkvj->bkvi', BM, local_vertex) 72 | transformed_vertex = torch.einsum('bvk,bkvi->bvi', skin_weight, BM_transformed) 73 | 74 | return transformed_vertex[...,:3] 75 | -------------------------------------------------------------------------------- /utils/test_skin_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | import datetime 5 | import shutil 6 | import yaml 7 | from configargparse import ArgumentParser, YAMLConfigFileParser 8 | 9 | import torch 10 | from torch_geometric.data import DataLoader 11 | from datasets.mixamo_skin_dataset import MixamoSkinDataset 12 | from models.mixamo_skin_model import MixamoMeshSkinModel 13 | 14 | def get_configs(): 15 | current_time = datetime.datetime.now().strftime('%d%b%Y-%H:%M') 16 | parser = ArgumentParser(config_file_parser_class=YAMLConfigFileParser) 17 | parser.add_argument('-c', '--config', required=True, is_config_file=True, 18 | help='config file path') 19 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size') 20 | parser.add_argument('--workers', type=int, help='number of data lodaing workers', default=8) 21 | parser.add_argument('--nepoch', type=int, default=4, help='number of epochs to train for') 22 | parser.add_argument('--reproduce', default=False, action='store_true') 23 | parser.add_argument('--time', default=False, action='store_true') 24 | parser.add_argument('--datatype', default=None, type=str) # point, point_uniform, point_poisson 25 | parser.add_argument('--preprocess', default=False, action='store_true') 26 | 27 | # Dataset Settings 28 | parser.add_argument('--model', type=str, default='', help='model path to load from') 29 | parser.add_argument('--data_dir', type=str, default='data/mixamo', help='dataset path') 30 | parser.add_argument('--dataset_type', type=str, default='MixamoMeshDataset', help='type of dataset') 31 | parser.add_argument('--version', type=str, default='', help='version of dataset') 32 | parser.add_argument('--split_version', type=str, default='', help='version of split') 33 | parser.add_argument('--no_skin', default=False, action='store_true') # hoped it will make faster dataloader 34 | 35 | 36 | parser.add_argument('--log_dir', type=str, default=f"./logs/{current_time}", help='log dir') 37 | parser.add_argument('--save_step', type=int, default=1000, help='model saving interval') 38 | parser.add_argument('--vis_step', type=int, default=1000, help='model visualization interval') 39 | parser.add_argument('--sample_eval', type=int, default=1, help='use only subset of eval dataset') 40 | 41 | # Model Settings 42 | parser.add_argument('--npts', type=int, default=4096, dest='npts') 43 | 
parser.add_argument('--num_joints', type=int, default=22) 44 | parser.add_argument('--joint_loss_type', type=str, default='rel') # 'rel'/'glob'/'rel2glob' 45 | parser.add_argument('--bindpose_loss_type', type=str, default='glob') # 'rel'/'glob'/'rel2glob' 46 | parser.add_argument('--use_bindpose', default=False, action='store_true') 47 | parser.add_argument('--use_gt_ibm', default=False, action='store_true') 48 | 49 | parser.add_argument('--use_normal', default=False, action='store_true') 50 | parser.add_argument('--quantize', type=int, default=0) 51 | 52 | # Network Settings 53 | parser.add_argument('--use_bn', default=False, action='store_true', help='Use Batch Norm in networks?') 54 | parser.add_argument('--global_feature_size', type=int, default=1024) 55 | parser.add_argument('--feature_size', type=int, default=1024) 56 | parser.add_argument('--channels', type=int, default=[64, 256, 512], nargs=3) 57 | parser.add_argument('--k', type=int, default=-1) # k for k-nearest neighbor in euclidean distance 58 | parser.add_argument('--euc_radius', type=float, default=0.0) # euclidean ball, 0.6 in RigNet 59 | parser.add_argument('--network_type', type=str, default='full') # k for k-nearest neighbor in euclidean distance 60 | parser.add_argument('--edge_type', type=str, default='tpl_and_euc', help='select one of tpl_and_euc, tpl_only, euc_only') 61 | 62 | # Hyperparameter Settings 63 | parser.add_argument('--rot_hp', type=float, default=1., help='weight of rotation loss') 64 | parser.add_argument('--trans_hp', type=float, default=1., help='weight of translation loss') 65 | parser.add_argument('--skin_hp', type=float, default=1e-3, help='weight of skin loss') 66 | parser.add_argument('--bm_rot_hp', type=float, default=1., help='weight of rotation loss') 67 | parser.add_argument('--bm_trans_hp', type=float, default=1, help='weight of translation loss') 68 | parser.add_argument('--bm_shape_hp', type=float, default=1e-3, help='weight of skin loss') 69 | 70 | # Optimization Settings 71 | parser.add_argument('--lr', type=float, default=0.001) 72 | parser.add_argument('--lr_step_size', type=int, default=100) 73 | parser.add_argument('--lr_gamma', type=float, default=0.8) 74 | parser.add_argument('--overfit', default=False, action='store_true') 75 | parser.add_argument('--vis_overfit', default=False, action='store_true') # overfit on vis dataset 76 | 77 | args = parser.parse_args() 78 | 79 | if not args.reproduce: 80 | manual_seed = random.randint(1, 10000) 81 | else: 82 | manual_seed = 0 83 | print("Random Seed: ", manual_seed) 84 | random.seed(manual_seed) 85 | torch.manual_seed(manual_seed) 86 | 87 | return args 88 | 89 | 90 | def get_dataloaders(args): 91 | configs = vars(args) 92 | version = args.version 93 | split_version = args.split_version 94 | Dataset = MixamoSkinDataset 95 | # if args.dataset_type == 'MixamoMeshDataset': 96 | # Dataset = MixamoMeshDataset 97 | # elif args.dataset_type == 'MixamoPointDataset': 98 | # Dataset = MixamoPointDataset 99 | # else: 100 | # raise NotImplementedError 101 | 102 | if args.overfit: 103 | # "Training done on single instance: 'aj', 'Samba Dancing_000000' 104 | raise NotImplementedError 105 | train_dataset = Dataset(data_dir=args.data_dir, split='train_overfit', version=version, split_version=split_version, preprocess=args.preprocess, datatype=args.datatype, configs=configs) 106 | eval_dataset = train_dataset 107 | vis_dataset = train_dataset 108 | args.batch_size = 2 109 | elif args.vis_overfit: 110 | vis_dataset = Dataset(data_dir=args.data_dir, 
split='test_models', version=version, split_version=split_version, preprocess=args.preprocess, datatype=args.datatype, configs=configs) 111 | train_dataset = vis_dataset 112 | eval_dataset = vis_dataset 113 | else: # normal training 114 | train_dataset = Dataset(data_dir=args.data_dir, split='train_models', version=version, split_version=split_version, preprocess=args.preprocess, datatype=args.datatype, configs=configs) 115 | eval_dataset = Dataset(data_dir=args.data_dir, split='valid_models', version=version, split_version=split_version, preprocess=args.preprocess, datatype=args.datatype, configs=configs) 116 | vis_dataset = Dataset(data_dir=args.data_dir, split='test_models', version=version, split_version=split_version, preprocess=args.preprocess, datatype=args.datatype, configs=configs) 117 | 118 | if args.sample_eval > 1: 119 | eval_dataset = torch.utils.data.Subset(eval_dataset, list(range(0, eval_dataset.__len__(), args.sample_eval))) 120 | 121 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers), pin_memory=True) 122 | eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers), pin_memory=True) 123 | vis_dataloader = DataLoader(vis_dataset, batch_size=1, shuffle=False, num_workers=int(args.workers), pin_memory=True) 124 | 125 | return train_dataloader, eval_dataloader, vis_dataloader 126 | 127 | def get_hyperparameters(args): 128 | hyper_parameters = { 129 | 'rot_hp': args.rot_hp, 130 | 'trans_hp': args.trans_hp, 131 | 'skin_hp': args.skin_hp, 132 | 'bm_rot_hp': args.bm_rot_hp, 133 | 'bm_trans_hp': args.bm_trans_hp, 134 | 'bm_shape_hp': args.bm_shape_hp 135 | } 136 | return hyper_parameters 137 | 138 | def get_networkconfigs(args): 139 | network_configs = { 140 | # pointnet 141 | 'npts': args.npts, 142 | 'quantize': args.quantize, 143 | # vertex features 144 | 'use_normal': args.use_normal, 145 | # euclidean edge-conv 146 | 'euc_radius': args.euc_radius, # euc_edge_index selected on runtime 147 | # edge-conv 148 | 'global_feature_size': args.global_feature_size, 149 | 'feature_size': args.feature_size, 150 | 'channels': args.channels, 151 | # dynmic-edgeconv graph configs 152 | 'k': args.k, 153 | 'edge_type': args.edge_type 154 | } 155 | return network_configs -------------------------------------------------------------------------------- /utils/train_bvh_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | import datetime 5 | import shutil 6 | import yaml 7 | from configargparse import ArgumentParser 8 | from time import time 9 | 10 | import numpy as np 11 | import torch 12 | # from torch_geometric.data import DataLoader 13 | from torch.utils.data import DataLoader 14 | from datasets.mixamo_bvh_dataset import MixamoBVHDataset 15 | blue = lambda x: '\033[94m' + x + '\033[0m' 16 | 17 | from utils.voxel_utils import get_final_preds 18 | from utils.joint_util import transform_rel2glob 19 | from utils.loss_utils import normalize3d 20 | 21 | def train_bvh(i, epoch, step, data, model, optimizer, writer, losses_dict, train_num_batch, time_log, device, configs): 22 | # Load data 23 | if configs['time'] and i < 5: 24 | torch.cuda.synchronize() 25 | time_log['after_load'] = time() 26 | # Send data to CUDA 27 | data = [dat.to(device).float() for dat in data[:-1]] + [data[-1]] # last data is meta data 28 | input_position, target_rotation, meta = data # 29 | for key in meta: 30 | if 
isinstance(meta[key], torch.Tensor): 31 | meta[key].to(device).float() 32 | 33 | optimizer.zero_grad() 34 | if configs['time'] and i < 5: 35 | torch.cuda.synchronize() 36 | time_log['before_pred'] = time() 37 | # Single Inference 38 | pred_rotation = model(input_position) 39 | if configs['time'] and i < 5: 40 | torch.cuda.synchronize() 41 | time_log['after_pred'] = time() 42 | # Compute Loss 43 | loss =joint_loss = model.compute_loss(pred_rotation, target_rotation) 44 | if configs['time'] and i < 5: 45 | time_log['after_loss'] = time() 46 | # # Compute accuracy 47 | acc = joint_acc = model.compute_accuracy(input_position, pred_rotation, target_rotation) 48 | losses_dict['loss'].append(loss.item()) 49 | losses_dict['joint_acc'].append(joint_acc.item()) 50 | if configs['time'] and i < 5: 51 | torch.cuda.synchronize() 52 | time_log['after_acc'] = time() 53 | # Print loss 54 | model.print_loss(loss, joint_acc, epoch, i, train_num_batch) 55 | model.write_summary(losses_dict, step=step, writer=writer) 56 | # Compute Gradients 57 | loss.backward() 58 | if configs['time'] and i < 5: 59 | torch.cuda.synchronize() 60 | time_log['after_backprop'] = time() 61 | optimizer.step() 62 | # Print Timing 63 | if configs['time'] and i < 5: 64 | torch.cuda.synchronize() 65 | time_log['after_update'] = time() 66 | print('\r' + blue('[Log b={:d} i={:d}]'.format(configs['batch_size'], i)) + 67 | ' Total: {:.3f}, Loading: {:.3f}, Inferring: {:.3f}, ComputeLoss: {:.3f}, ComputeAcc: {:.3f}, Backprop: {:.3f}, Update: {:.3f}'.format( 68 | time_log['after_update'] - time_log['before_load'], 69 | time_log['after_load'] - time_log['before_load'], 70 | time_log['after_pred'] - time_log['before_pred'], time_log['after_loss'] - time_log['after_pred'], 71 | time_log['after_acc'] - time_log['after_loss'], 72 | time_log['after_backprop'] - time_log['after_acc'], time_log['after_update'] - time_log['after_backprop'] 73 | )) 74 | time_log['before_load'] = time() 75 | # Gradient Descent 76 | 77 | def eval_bvh(i, epoch, step, data, model, writer, losses_dict, eval_num_batch, device, configs): 78 | data = [dat.to(device).float() for dat in data[:-1]] + [data[-1]] # last data is meta data 79 | input_position, target_rotation, meta = data # 80 | for key in meta: 81 | if isinstance(meta[key], torch.Tensor): 82 | meta[key].to(device).float() 83 | pred_rotation = model(input_position) 84 | loss =joint_loss = model.compute_loss(pred_rotation, target_rotation) 85 | acc = joint_acc = model.compute_accuracy(input_position, pred_rotation, target_rotation) 86 | losses_dict['loss'].append(loss.item()) 87 | losses_dict['joint_acc'].append(joint_acc.item()) 88 | model.print_loss(loss, joint_acc, epoch, i, eval_num_batch) 89 | 90 | def vis_bvh(epoch, data, model, device, configs): 91 | data = [dat.to(device).float() for dat in data[:-1]] + [data[-1]] 92 | input_position, target_rotation, meta = data # 93 | for key in meta: 94 | if isinstance(meta[key], torch.Tensor): 95 | meta[key].to(device).float() 96 | pred_rotation = model(input_position) 97 | loss =joint_loss = model.compute_loss(pred_rotation, target_rotation, average=False) 98 | acc = joint_acc = model.compute_accuracy(input_position, pred_rotation, target_rotation, average=False) 99 | 100 | character_name, motion_name = meta['character_name'], meta['motion_name'] 101 | batch_size = input_position.shape[0] 102 | for i in range(batch_size): 103 | character_name_i, motion_name_i = character_name[i], motion_name[i] 104 | if not os.path.exists(os.path.join(configs['log_dir'], 'vis', 
character_name_i)): 105 | os.makedirs(os.path.join(configs['log_dir'], 'vis', character_name_i)) 106 | 107 | np.save(os.path.join(configs['log_dir'], 'vis', character_name_i, motion_name_i + '_input_pos_%.3d.csv' % (epoch)),input_position[i].cpu().detach()) 108 | np.save(os.path.join(configs['log_dir'], 'vis', character_name_i, motion_name_i + '_pred_rot_%.3d.csv' % (epoch)),pred_rotation[i].cpu().detach()) 109 | np.save(os.path.join(configs['log_dir'], 'vis', character_name_i, motion_name_i + '_target_rot_%.3d.csv' % (epoch)),target_rotation[i].cpu().detach()) 110 | np.save(os.path.join(configs['log_dir'], 'vis', character_name_i, motion_name_i + '_info_%.3d.csv' % (epoch)),{ 111 | 'loss': loss[i].cpu().detach(), 112 | 'acc': acc[i].cpu().detach() 113 | }) 114 | # np.savetxt(os.path.join(configs['log_dir'], 'vis', character_name_i, motion_name_i + '_skin_%.3d.csv' % (epoch)), 115 | # pred_skin_i, delimiter=',') 116 | 117 | def get_bvh_dataloaders(configs): 118 | Dataset = MixamoBVHDataset 119 | data_dir = configs['data_dir'] 120 | train_split, eval_split, vis_split = 'train_models.txt', 'valid_models.txt', 'test_models.txt' 121 | if configs['vis_overfit']: 122 | train_split = eval_split = vis_split = 'test_models.txt' 123 | train_dataset = Dataset(train_split, configs) 124 | eval_dataset = Dataset(eval_split, configs) 125 | vis_dataset = Dataset(vis_split, configs) 126 | train_dataloader = DataLoader(train_dataset, batch_size=configs['batch_size'], shuffle=True, num_workers=int(configs['workers']), pin_memory=True) 127 | eval_dataloader = DataLoader(eval_dataset, batch_size=configs['batch_size'], shuffle=True, num_workers=int(configs['workers']), pin_memory=True) 128 | vis_dataloader = DataLoader(vis_dataset, batch_size=configs['batch_size'], shuffle=False, num_workers=int(configs['workers']), pin_memory=True) 129 | 130 | return train_dataloader, eval_dataloader, vis_dataloader 131 | 132 | -------------------------------------------------------------------------------- /utils/train_skin_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | import datetime 5 | import shutil 6 | import yaml 7 | from configargparse import ArgumentParser, YAMLConfigFileParser 8 | 9 | import torch 10 | from torch_geometric.data import DataLoader 11 | from datasets.mixamo_skin_dataset import MixamoSkinDataset 12 | 13 | def get_configs(): 14 | current_time = datetime.datetime.now().strftime('%d%b%Y-%H:%M') 15 | parser = ArgumentParser(config_file_parser_class=YAMLConfigFileParser) 16 | parser.add_argument('-c', '--config', required=True, is_config_file=True, 17 | help='config file path') 18 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size') 19 | parser.add_argument('--workers', type=int, help='number of data lodaing workers', default=8) 20 | parser.add_argument('--nepoch', type=int, default=4, help='number of epochs to train for') 21 | parser.add_argument('--reproduce', default=False, action='store_true') 22 | parser.add_argument('--time', default=False, action='store_true') 23 | parser.add_argument('--datatype', default=None, type=str) # point, point_uniform, point_poisson 24 | parser.add_argument('--preprocess', default=False, action='store_true') 25 | 26 | # Dataset Settings 27 | parser.add_argument('--model', type=str, default='', help='model path to load from') 28 | parser.add_argument('--data_dir', type=str, default='data/mixamo', help='dataset path') 29 | 
parser.add_argument('--no_skin', default=False, action='store_true') # hoped it will make faster dataloader 30 | 31 | 32 | parser.add_argument('--log_dir', type=str, default=f"./logs/{current_time}", help='log dir') 33 | parser.add_argument('--save_step', type=int, default=1000, help='model saving interval') 34 | parser.add_argument('--vis_step', type=int, default=1000, help='model visualization interval') 35 | 36 | # Model Settings 37 | parser.add_argument('--num_joints', type=int, default=22) 38 | parser.add_argument('--joint_loss_type', type=str, default='rel') # 'rel'/'glob'/'rel2glob' 39 | parser.add_argument('--bindpose_loss_type', type=str, default='glob') # 'rel'/'glob'/'rel2glob' 40 | parser.add_argument('--use_bindpose', default=False, action='store_true') 41 | parser.add_argument('--use_gt_ibm', default=False, action='store_true') 42 | 43 | parser.add_argument('--use_normal', default=False, action='store_true') 44 | parser.add_argument('--quantize', type=int, default=0) 45 | 46 | # Network Settings 47 | parser.add_argument('--use_bn', default=False, action='store_true', help='Use Batch Norm in networks?') 48 | parser.add_argument('--global_feature_size', type=int, default=1024) 49 | parser.add_argument('--feature_size', type=int, default=1024) 50 | parser.add_argument('--channels', type=int, default=[64, 256, 512], nargs=3) 51 | parser.add_argument('--k', type=int, default=-1) # k for k-nearest neighbor in euclidean distance 52 | parser.add_argument('--euc_radius', type=float, default=0.0) # euclidean ball, 0.6 in RigNet 53 | parser.add_argument('--network_type', type=str, default='full') # k for k-nearest neighbor in euclidean distance 54 | parser.add_argument('--edge_type', type=str, default='tpl_and_euc', help='select one of tpl_and_euc, tpl_only, euc_only') 55 | 56 | # Hyperparameter Settings 57 | parser.add_argument('--rot_hp', type=float, default=1., help='weight of rotation loss') 58 | parser.add_argument('--trans_hp', type=float, default=1., help='weight of translation loss') 59 | parser.add_argument('--skin_hp', type=float, default=1e-3, help='weight of skin loss') 60 | parser.add_argument('--bm_rot_hp', type=float, default=1., help='weight of rotation loss') 61 | parser.add_argument('--bm_trans_hp', type=float, default=1, help='weight of translation loss') 62 | parser.add_argument('--bm_shape_hp', type=float, default=1e-3, help='weight of skin loss') 63 | 64 | # Optimization Settings 65 | parser.add_argument('--lr', type=float, default=0.001) 66 | parser.add_argument('--lr_step_size', type=int, default=100) 67 | parser.add_argument('--lr_gamma', type=float, default=0.8) 68 | parser.add_argument('--overfit', default=False, action='store_true') 69 | parser.add_argument('--vis_overfit', default=False, action='store_true') # overfit on vis dataset 70 | 71 | args = parser.parse_args() 72 | 73 | print(args) 74 | if os.path.exists(args.log_dir): 75 | print("\nAre you re-training? [y/n]", end='') 76 | choice = input().lower() 77 | if choice not in ['y', 'n']: 78 | print("please type in valid response") 79 | sys.exit() 80 | elif choice == 'n': 81 | print("The log directory is already occupied. Do you want to remove and rewrite? 
[y/n]", end='') 82 | choice = input().lower() 83 | if choice == 'y': 84 | shutil.rmtree(args.log_dir, ignore_errors=True) 85 | os.makedirs(args.log_dir) 86 | else: 87 | print("Please choose a different log_dir") 88 | sys.exit() 89 | else: 90 | if args.model != '': 91 | print("You cannot restart when the model is specified") 92 | __import__('pdb').set_trace() 93 | else: 94 | ckpt_list = [ckpt for ckpt in os.listdir(args.log_dir) if ckpt.endswith('.pth')] 95 | args.model = os.path.join(args.log_dir, 96 | sorted(ckpt_list, key=lambda ckpt_str: ckpt_str.split('_')[-1].split('.pth')[0])[-1]) 97 | 98 | print("Retraining from ckpt: {}".format(args.model)) 99 | 100 | else: 101 | os.makedirs(args.log_dir) 102 | with open(os.path.join(args.log_dir, 'config.yaml'), 'w') as f: 103 | yaml.dump(vars(args), f)#, default_flow_style=None) 104 | 105 | if not args.reproduce: 106 | manual_seed = random.randint(1, 10000) 107 | else: 108 | manual_seed = 0 109 | print("Random Seed: ", manual_seed) 110 | random.seed(manual_seed) 111 | torch.manual_seed(manual_seed) 112 | 113 | return args 114 | 115 | 116 | def get_skin_dataloaders(args): 117 | configs = vars(args) 118 | Dataset = MixamoSkinDataset 119 | 120 | if args.overfit: 121 | # "Training done on single instance: 'aj', 'Samba Dancing_000000' 122 | raise NotImplementedError 123 | train_dataset = Dataset(data_dir=args.data_dir, split='train_overfit',preprocess=args.preprocess, datatype=args.datatype, configs=configs) 124 | eval_dataset = train_dataset 125 | vis_dataset = train_dataset 126 | args.batch_size = 2 127 | elif args.vis_overfit: 128 | vis_dataset = Dataset(data_dir=args.data_dir, split='test_models', preprocess=args.preprocess, datatype=args.datatype, configs=configs) 129 | train_dataset = vis_dataset 130 | eval_dataset = vis_dataset 131 | else: # normal training 132 | train_dataset = Dataset(data_dir=args.data_dir, split='train_models', preprocess=args.preprocess, datatype=args.datatype, configs=configs) 133 | eval_dataset = Dataset(data_dir=args.data_dir, split='valid_models', preprocess=args.preprocess, datatype=args.datatype, configs=configs) 134 | vis_dataset = Dataset(data_dir=args.data_dir, split='test_models', preprocess=args.preprocess, datatype=args.datatype, configs=configs) 135 | 136 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers), pin_memory=True) 137 | eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers), pin_memory=True) 138 | vis_dataloader = DataLoader(vis_dataset, batch_size=1, shuffle=False, num_workers=int(args.workers), pin_memory=True) 139 | 140 | return train_dataloader, eval_dataloader, vis_dataloader 141 | -------------------------------------------------------------------------------- /utils/tree_utils.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Name: tree_utils.py 3 | # Purpose: classes of node (joint) and tree-node (joint with its parent and children) 4 | # RigNet Copyright 2020 University of Massachusetts 5 | # RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License. 6 | # Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet. 
7 | #------------------------------------------------------------------------------- 8 | 9 | # This file is from RigNet(https://github.com/zhan-xu/RigNet) 10 | 11 | class Node(object): 12 | def __init__(self, name, pos): 13 | self.name = name 14 | self.pos = pos 15 | 16 | 17 | class TreeNode(Node): 18 | def __init__(self, name, pos): 19 | super(TreeNode, self).__init__(name, pos) 20 | self.children = [] 21 | self.parent = None 22 | --------------------------------------------------------------------------------