├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── heatmap3D_sdf.cpython-36.pyc └── heatmap3D_sdf.py ├── gen_dataset.py ├── maya_bind.py ├── models3D ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── model3d_hg.cpython-36.pyc └── model3d_hg.py ├── mst_generate.py ├── run_trainval.py └── util ├── __pycache__ ├── binvox_rw.cpython-36.pyc ├── open3d_utils.cpython-36.pyc ├── os_utils.cpython-36.pyc ├── train_utils.cpython-36.pyc ├── tree_utils.cpython-36.pyc └── vox_utils.cpython-36.pyc ├── binvox_rw.py ├── open3d_utils.py ├── os_utils.py ├── rigging_parser ├── __pycache__ │ ├── obj_parser.cpython-36.pyc │ └── skel_parser.cpython-36.pyc ├── obj_parser.py └── skel_parser.py ├── train_utils.py ├── tree_utils.py └── vox_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | checkpoints/ 3 | logs/ 4 | model_resource_data/ 5 | results/ 6 | *.sbatch 7 | *.txt 8 | *.err 9 | *.cpython-36.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the code repository implementing the paper "Predicting Animation Skeletons for 3D Articulated Models via Volumetric Nets". 2 | 3 | ### Dependecy and Setup 4 | 5 | The project is developed on Ubuntu 16.04 with cuda9.2 + cudnn7.5. 6 | We suggest to use conda virtual environment, which can be set up as following: 7 | 8 | ``` 9 | conda create -n AnimSkelVolNet python=3.6 10 | . activate AnimSkelVolNet 11 | pip install numpy, scipy, future, tensorboard, h5py, open3d, tqdm, opencv-python 12 | pip install torch==1.2.0 torchvision==0.4.0 -f https://download.pytorch.org/whl/torch_stable.html 13 | ``` 14 | 15 | ### Data 16 | 17 | Our dataset ModelResource has 3,193 models. 18 | We split it into 80% for training (2,554 models), 10% 19 | for validation (319 models), and 10% for testing. 20 | All models in fbx format can be downloaded [here](https://drive.google.com/file/d/1Zak7Ydfm6KuD9rwxdrfxV75rgtbxI2Hm/view?usp=sharing). 21 | 22 | To use this dataset in this project, we need some pre-processing, 23 | including calculating curvature and shape diameter, 24 | converting models into SDF voxels, calculating feature size as control parameter. 25 | Most of these works are done in C++. If you are interested 26 | in that part, you can implement with the help of [trimesh](https://gfx.cs.princeton.edu/proj/trimesh2/) 27 | and [Thea](https://github.com/sidch/Thea). 28 | We put the data after pre-processing [here](https://drive.google.com/file/d/1fcxTmRJAVEc0ZuiXM-NgqDdSmNLdOO1V/view?usp=sharing). 29 | The folder includes several sub-folders: 30 | 31 | * obj: all meshes in obj format. We triangulated them by [MeshLab](http://www.meshlab.net/), and fixed them by [meshfix](https://github.com/MarcoAttene/MeshFix-V2.1). 32 | * skel: we save skeleton information in txt files. Each row contains information for a joint, 33 | which are its level, joint name, joint position (x, y, z) and parent joint name. 34 | * curvature: curvature in voxel grids. Only surface voxels have value, otherwise 0. 35 | We use curvature along two directions, so the size is (2x82x82x82). 36 | * sd: shape diameter in voxel grids. Only surface voxels have value, otherwise 0. The size is (82x82x82). 37 | * vox_82: voxelized model with a resolution of (82x82x82). 38 | It is padded in the code to get a size of (88x88x88). 39 | * feature size (fs): Our feature size file contains one bone sample per row. 40 | Each row are (x, y, z) coordinates of a bone sample, followed by the feature size. 41 | Feature size is calculated by shooting rays on a plane perpendicular to the bone. 42 | For each bone sample, its "feature size" is the median distance to all nearest hits 43 | of the rays from it. The file use "new bone" to seperate samples from different bones. 44 | 45 | To create the data used directly by the code, see and run our script: 46 | 47 | `python gen_dataset.py` 48 | 49 | Remember to change the root_folder to the directory you uncompress the pre-processed data. 50 | 51 | ### Inference 52 | To run forward inference only, you can download a trained model from [here](https://drive.google.com/file/d/1MJMiMVIpilKbMFI_vRqFBgd2ExzniR0U/view?usp=sharing). 53 | Then you put it into REPO_PATH/checkpoints/volNet/, and run the following command: 54 | 55 | `python run_trainval.py -e --resume 'checkpoints/volNet/trained_model_volNet.pth.tar' --arch 'v2v_hg' --train-batch 4 --test-batch 4 --output_dir volNet --data_path 'DATA_PATH/model-resource-volumetric.h5' --json_file 'DATA_PATH/model-resource-volumetric.json' --input_feature curvature sd vertex_kde --num_stack 4` 56 | 57 | This will output the predicted joint&bone heatmaps, as well as the binary input voxels, 58 | into a folder called 'results/OUTPUT_DIR'. 59 | 60 | To generate the skeleton, you need to run our script: 61 | 62 | `python mst_generate.py` 63 | 64 | Remember to modify the result folder name and output folder name. 65 | 66 | You can run maya_bind.py in Maya to bind the predicted skeleton with the mesh. The skinning weights is generated by geodesic voxel binding in Maya. 67 | 68 | ### Training 69 | To train a model by yourself, run the following command 70 | 71 | `python run_trainval.py --arch 'v2v_hg' --data_path 'DATA_PATH/model-resource-volumetric.h5' --json_file 'DATA_PATH/model-resource-volumetric.json' --checkpoint 'checkpoints/volNet' --logdir 'logs/volNet' --lr 1e-4 --train-batch 4 --test-batch 4 --input_feature curvature sd vertex_kde --num_stack 4 --epochs 50` 72 | 73 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .heatmap3D_sdf import Heatmap3D_sdf 2 | __all__ = ('Heatmap3D_sdf') 3 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/heatmap3D_sdf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/datasets/__pycache__/heatmap3D_sdf.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/heatmap3D_sdf.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import numpy as np 4 | import json 5 | import torch 6 | import h5py 7 | import pdb 8 | import scipy.ndimage as ndimage 9 | import torch.utils.data as data 10 | from scipy.ndimage import convolve 11 | 12 | 13 | class Heatmap3D_sdf(data.Dataset): 14 | def __init__(self, h5path, json_file, subset, kde, input_feature): 15 | self.h5path = h5path # data files 16 | self.kde = kde 17 | self.input_feature = input_feature 18 | with open(json_file, 'r') as train_anno_file: 19 | annos = json.load(train_anno_file) 20 | h5data = h5py.File(h5path, "r") 21 | if subset == 'train': 22 | self.annos = [anno for anno in annos if anno['subset'] =='train'] 23 | self.data = h5data['train_data'] 24 | self.vert = h5data['train_vert'] 25 | self.joint_label = h5data['train_label_joint'] 26 | self.bone_label = h5data['train_label_bone'] 27 | self.curvature = h5data['train_curvature'] 28 | self.sd = h5data['train_sd'] 29 | elif subset == 'val': 30 | self.annos = [anno for anno in annos if anno['subset']=='val'] 31 | self.data = h5data['val_data'] 32 | self.vert = h5data['val_vert'] 33 | self.joint_label = h5data['val_label_joint'] 34 | self.bone_label = h5data['val_label_bone'] 35 | self.curvature = h5data['val_curvature'] 36 | self.sd = h5data['val_sd'] 37 | elif subset == 'test': 38 | self.annos = [anno for anno in annos if anno['subset'] == 'test'] 39 | self.data = h5data['test_data'] 40 | self.vert = h5data['test_vert'] 41 | self.joint_label = h5data['test_label_joint'] 42 | self.bone_label = h5data['test_label_bone'] 43 | self.curvature = h5data['test_curvature'] 44 | self.sd = h5data['test_sd'] 45 | self.struct = np.ones((3, 3, 3)).astype(bool) 46 | 47 | def __getitem__(self, index): 48 | model = self.data[index].astype(np.float32) 49 | model = torch.from_numpy(model) 50 | model = model.unsqueeze(0) 51 | 52 | mask = self.data[index] 53 | mask = (mask < 0) 54 | mask = ndimage.binary_dilation(mask, structure=self.struct, iterations=2).astype(np.float32) 55 | mask = torch.from_numpy(mask) 56 | 57 | target_joint = self.joint_label[index].astype(np.float32) 58 | target_joint = torch.from_numpy(target_joint) 59 | 60 | target_bone = self.bone_label[index].astype(np.float32) 61 | target_bone = torch.from_numpy(target_bone) 62 | 63 | if 'vertex_kde' in self.input_feature: 64 | vert = self.vert[index].astype(np.float32) 65 | g_vert = self.make_gaussian(self.kde * self.annos[index]['avg_edge']) 66 | vert = convolve(vert.astype(np.float32), g_vert, mode='constant', cval=0) 67 | vert = torch.from_numpy(vert) 68 | model = torch.cat((model, vert.unsqueeze(0)), dim=0) 69 | 70 | if 'curvature' in self.input_feature: 71 | curvature = self.curvature[index].astype(np.float32) 72 | curvature = torch.from_numpy(curvature) 73 | model = torch.cat((model, curvature), dim=0) 74 | 75 | if 'sd' in self.input_feature: 76 | sd = self.sd[index].astype(np.float32) 77 | sd = torch.from_numpy(sd) 78 | model = torch.cat((model, sd), dim=0) 79 | 80 | # Meta info 81 | meta = {'index': index, 'min_5_fs': self.annos[index]['min_5_fs'], 'name': self.annos[index]['name'], 82 | 'translate': self.annos[index]['translate'], 'scale': self.annos[index]['scale'], 83 | 'center_trans': self.annos[index]['center_trans']} 84 | 85 | return model, mask, target_joint, target_bone, meta 86 | 87 | def __len__(self): 88 | return len(self.data) 89 | 90 | def make_gaussian(self, sigma): 91 | size = 6 * sigma + 1 92 | x0 = y0 = z0 = size // 2 93 | x = np.arange(0, size, 1, float) 94 | y = x[:, np.newaxis] 95 | z = y[..., np.newaxis] 96 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2) / (2 * sigma ** 2)) 97 | return g 98 | -------------------------------------------------------------------------------- /gen_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -* 3 | 4 | """ 5 | This script is for generating training/testing data, which is saved as a h5 file. 6 | """ 7 | 8 | import sys 9 | sys.path.append('./') 10 | 11 | import numpy as np 12 | import os 13 | import h5py 14 | import json 15 | import scipy.ndimage as ndimage 16 | from tqdm import tqdm 17 | 18 | import util.binvox_rw as binvox_rw 19 | from util.vox_utils import Cartesian2Voxcoord 20 | from util.rigging_parser.skel_parser import Skel 21 | from util.rigging_parser.obj_parser import Mesh_obj 22 | 23 | 24 | def unique_rows(a): 25 | # remove repeat rows from numpy array 26 | a = np.ascontiguousarray(a) 27 | unique_a = np.unique(a.view([('', a.dtype)]*a.shape[1])) 28 | return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1])) 29 | 30 | 31 | def draw_jointmap(img, pt, sigma): 32 | # Draw a 3D gaussian 33 | # Check that any part of the gaussian is in-bounds 34 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma), int(pt[2] - 3*sigma)] 35 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1), int(pt[2] + 3*sigma +1)] 36 | if (ul[0] >= img.shape[0] or ul[1] >= img.shape[1] or ul[2] >= img.shape[2] or 37 | br[0] < 0 or br[1] < 0 or br[2] < 0): 38 | # If not, just return the image as is 39 | return img 40 | 41 | # Generate gaussian 42 | size = 6 * sigma + 1 43 | x = np.arange(0, size, 1, float) 44 | y = x[:, np.newaxis] 45 | z = y[..., np.newaxis] 46 | x0 = y0 = z0 = size // 2 47 | # The gaussian is not normalized, we want the center value to equal 1 48 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2) / (2 * sigma ** 2)) 49 | 50 | # Usable gaussian range 51 | g_x = max(0, -ul[0]), min(br[0], img.shape[0]) - ul[0] 52 | g_y = max(0, -ul[1]), min(br[1], img.shape[1]) - ul[1] 53 | g_z = max(0, -ul[2]), min(br[2], img.shape[2]) - ul[2] 54 | # Image range 55 | img_x = max(0, ul[0]), min(br[0], img.shape[0]) 56 | img_y = max(0, ul[1]), min(br[1], img.shape[1]) 57 | img_z = max(0, ul[2]), min(br[2], img.shape[2]) 58 | 59 | img[img_x[0]:img_x[1], img_y[0]:img_y[1], img_z[0]:img_z[1]] = \ 60 | np.maximum(img[img_x[0]:img_x[1], img_y[0]:img_y[1], img_z[0]:img_z[1]], 61 | g[g_x[0]:g_x[1], g_y[0]:g_y[1], g_z[0]:g_z[1]]) 62 | return img 63 | 64 | 65 | def draw_bonemap(heatmap, p_pos, c_pos, output_resulotion): 66 | # create 3D bone heatmap. Voxels along the bone have value 1, otherwise 0 67 | c_pos = np.asarray(c_pos) 68 | ray = c_pos - p_pos 69 | i_step = np.arange(1, 100) 70 | unit_step = (ray / 100)[np.newaxis,:] 71 | unit_step = np.repeat(unit_step, 99, axis=0) 72 | pos_step = p_pos + unit_step * i_step[:,np.newaxis] 73 | pos_step = np.round(pos_step).astype(np.uint8) 74 | pos_step = np.array([p for p in pos_step if np.all(p >= 0) and np.all(p < output_resulotion)]) 75 | if len(pos_step) != 0: 76 | heatmap[pos_step[:, 0], pos_step[:, 1], pos_step[:, 2]] += 1 77 | np.clip(heatmap, 0.0, 1.0, out=heatmap) 78 | return heatmap 79 | 80 | 81 | def getConditions(fs_filename): 82 | ''' 83 | Read in our feature size file and return the 5th percentile of all feature size 84 | Our feature size file contains one bone sample per row. 85 | The first three numbers are coordinates of the sample. The last number is the feature size. 86 | Feature size is calculated in cpp, where we shoot rays on a plane perpendicular to the bone. 87 | For each bone sample, its "feature size" is the median distance to all nearest hits of the rays from it. 88 | Our feature size file use "new bone" to seperate samples from different bones. 89 | :param fs_filename: filename of our feature size file. 90 | :return: 5th percentile of all the feature size, used as contional variable during training and testing. 91 | ''' 92 | with open(fs_filename, 'r') as f: 93 | lines = f.readlines() 94 | if len(lines) == 0: 95 | return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 96 | fs_all = [] 97 | fs_list_bone = [] 98 | for li in lines: 99 | if li.strip() == 'new bone': 100 | fs_all.append(fs_list_bone) 101 | fs_list_bone = [] 102 | else: 103 | words = li.split() 104 | fs_list_bone.append(float(words[3])) 105 | min_fs = [] 106 | for i in fs_all: 107 | min_fs += i 108 | min_fs = np.array(min_fs) 109 | min_5_fs = np.percentile(min_fs, 5) 110 | return min_5_fs 111 | 112 | 113 | def center_vox(volumn_input): 114 | #put the occupied voxels at the center instead of corner 115 | pos = np.where(volumn_input > 0) 116 | x_max, x_min = np.max(pos[0]), np.min(pos[0]) 117 | y_max, y_min = np.max(pos[1]), np.min(pos[1]) 118 | z_max, z_min = np.max(pos[2]), np.min(pos[2]) 119 | side_length = volumn_input.shape[0] 120 | mid_len = int(side_length / 2) 121 | xr_low = int((x_max - x_min + 1) / 2) 122 | xr_high = x_max - x_min + 1 - xr_low 123 | yr_low = int((y_max - y_min + 1) / 2) 124 | yr_high = y_max - y_min + 1 - yr_low 125 | zr_low = int((z_max - z_min + 1) / 2) 126 | zr_high = z_max - z_min + 1 - zr_low 127 | content = volumn_input[x_min: x_max + 1, y_min: y_max + 1, z_min: z_max + 1] 128 | volumn_output = np.zeros((volumn_input.shape), dtype=np.bool) 129 | center_trans = [x_min - mid_len + xr_low, y_min - mid_len + yr_low, z_min - mid_len + zr_low] 130 | center_trans = list(map(int, center_trans)) 131 | volumn_output[mid_len - xr_low:mid_len + xr_high, mid_len - yr_low:mid_len + yr_high, 132 | mid_len - zr_low:mid_len + zr_high] = content 133 | return volumn_output, center_trans 134 | 135 | 136 | def bin2sdf(input): 137 | ''' 138 | convert binary voxels into sign distance function field. Negetive for interior. Positive for exterior. Normalized. 139 | :param input: binary voxels 140 | :return: SDF representation of voxel. 141 | ''' 142 | fill_map = np.zeros(input.shape, dtype=np.bool) 143 | output = np.zeros(input.shape, dtype=np.float16) 144 | # fill inside 145 | changing_map = input.copy() 146 | sdf_in = -1 147 | while np.sum(fill_map) != np.sum(input): 148 | changing_map_new = ndimage.binary_erosion(changing_map) 149 | fill_map[np.where(changing_map_new!=changing_map)] = True 150 | output[np.where(changing_map_new!=changing_map)] = sdf_in 151 | changing_map = changing_map_new.copy() 152 | sdf_in -= 1 153 | # fill outside. No need to fill all of them, since during training, outside part will be masked. 154 | changing_map = input.copy() 155 | sdf_out = 1 156 | while np.sum(fill_map) != np.size(input): 157 | changing_map_new = ndimage.binary_dilation(changing_map) 158 | fill_map[np.where(changing_map_new!=changing_map)] = True 159 | output[np.where(changing_map_new!=changing_map)] = sdf_out 160 | changing_map = changing_map_new.copy() 161 | sdf_out += 1 162 | if sdf_out == -sdf_in: 163 | break 164 | # Normalization 165 | output[np.where(output < 0)] /= (-sdf_in-1) 166 | output[np.where(output > 0)] /= (sdf_out-1) 167 | return output 168 | 169 | 170 | def cal_avg_edge_length(mesh): 171 | # calculate the average length of all edges of the mesh, which is used as a multiplier when generating vertice density maps. 172 | points = mesh.v 173 | edge_index = np.concatenate((mesh.f[:, 0:2], mesh.f[:, 1:], mesh.f[:, [0, 2]]), axis=0) 174 | edge_index -= 1 175 | edge_index = unique_rows(edge_index) 176 | edge_length = np.linalg.norm(points[edge_index[:, 0]] - points[edge_index[:, 1]], axis=1) 177 | return points, np.mean(edge_length) 178 | 179 | 180 | def get_surface_vertice(mesh, trans, scale, center_trans, dim_ori=82, r=3, dim_pad=88): 181 | points, avg_edge = cal_avg_edge_length(mesh) 182 | res = np.zeros((dim_pad, dim_pad, dim_pad), dtype=np.uint8) 183 | vc = (points - np.array([trans])) / scale * dim_ori 184 | vc = np.round(vc).astype(int) 185 | vc = vc - np.array([center_trans]) + r 186 | for v in vc: 187 | if np.all((dim_pad - v) > 0) and np.all(v > 0): 188 | res[v[0], v[1], v[2]] += 1 189 | return res, avg_edge 190 | 191 | 192 | def genDataset_inner(root_folder, model_id, subset, dim_ori=82, r=3, dim_pad=88): 193 | ''' 194 | generate necessary data for one sample 195 | :param root_folder: directory with all raw data 196 | :param model_id: model ID 197 | :param subset: 'train', 'val' or 'test' 198 | :param dim_ori: original voxel grid resolution 199 | :param r: padding added to the original voxel grid 200 | :param dim_pad: padded voxel grid resolution 201 | :return: input representation for one sample, including SDF-voxelization, k1, k2 curvature maps, shape diameter, vertice density 202 | ''' 203 | mesh_file = os.path.join(root_folder, 'obj/{:d}.obj'.format(model_id)) 204 | vox_file = os.path.join(root_folder, 'vox_82/{:d}.binvox'.format(model_id)) 205 | skel_file = os.path.join(root_folder, 'skel/{:d}.txt'.format(model_id)) 206 | fs_file = os.path.join(root_folder, 'fs/{:d}_featuresize.txt'.format(model_id)) 207 | 208 | # read original voxels and pad it. 209 | with open(vox_file, 'rb') as f: 210 | mesh_vox = binvox_rw.read_as_3d_array(f) 211 | mesh_vox_padded = np.zeros((mesh_vox.dims[0] + 2 * r, mesh_vox.dims[1] + 2 * r, mesh_vox.dims[2] + 2 * r), dtype=np.float16) 212 | mesh_vox_padded[r:mesh_vox.dims[0] + r, r:mesh_vox.dims[1] + r, r:mesh_vox.dims[2] + r] = mesh_vox.data 213 | # put the occupied voxels at the center instead of left-top corner 214 | mesh_vox_padded, center_trans = center_vox(mesh_vox_padded) 215 | # convert binary voxels to SDF representation 216 | mesh_vox_padded = bin2sdf(mesh_vox_padded) 217 | 218 | mesh = Mesh_obj(mesh_file) 219 | min_5_fs = getConditions(fs_file) # get 5th-percentile control parameter 220 | skel = Skel(skel_file) # read in ground-truth skeleton 221 | heatmap_joint = np.zeros((int(mesh_vox.dims[0] + 2 * r), int(mesh_vox.dims[1] + 2 * r), 222 | int(mesh_vox.dims[2] + 2 * r)), dtype=np.float16) 223 | heatmap_bones = np.zeros((int(mesh_vox.dims[0] + 2 * r), int(mesh_vox.dims[1] + 2 * r), 224 | int(mesh_vox.dims[2] + 2 * r)), dtype=np.float16) 225 | # create vertice density heatmap. 226 | heatmap_verts, avg_edge = get_surface_vertice(mesh, mesh_vox.translate, mesh_vox.scale, center_trans, dim_ori, r, dim_pad) 227 | # start to create target joint&bone heatmaps. BFS iteration. 228 | this_level = [skel.root] 229 | while this_level: 230 | next_level = [] 231 | for p_node in this_level: 232 | pos = Cartesian2Voxcoord(np.array(p_node.pos), mesh_vox.translate, mesh_vox.scale, mesh_vox.dims[0]) 233 | pos = (pos[0] - center_trans[0] + r, pos[1] - center_trans[1] + r, pos[2] - center_trans[2] + r) 234 | pos = np.clip(pos, a_min=0, a_max=dim_pad-1) 235 | draw_jointmap(heatmap_joint, pos, sigma=0.6) 236 | next_level += p_node.children 237 | for c_node in p_node.children: 238 | ch_pos = Cartesian2Voxcoord(np.array(c_node.pos), mesh_vox.translate, mesh_vox.scale, mesh_vox.dims[0]) 239 | ch_pos = (ch_pos[0] - center_trans[0] + r, ch_pos[1] - center_trans[1] + r, ch_pos[2] - center_trans[2] + r) 240 | draw_bonemap(heatmap_bones, pos, ch_pos, output_resulotion=dim_pad) 241 | this_level = next_level 242 | 243 | # read original curvature 244 | curvature_raw = np.load(os.path.join(root_folder, 'curvature/{:d}_curvature.npy'.format(model_id))) 245 | curvature_surface = np.zeros((2, dim_pad, dim_pad, dim_pad), dtype=np.float16) 246 | # read original shape diameter 247 | sd_raw = np.load(os.path.join(root_folder, 'shape_diameter/{:d}_sd.npy'.format(model_id))) 248 | sd_surface = np.zeros((1, dim_pad, dim_pad, dim_pad), dtype=np.float16) 249 | # only preserve values at surface voxels. 250 | data_bin = (mesh_vox_padded < 0) 251 | changing_map_new = data_bin.copy() 252 | changing_map_new = ndimage.binary_erosion(changing_map_new) 253 | fill_map = (changing_map_new != data_bin) 254 | coord_v = np.argwhere(fill_map) 255 | coord_v_trans = coord_v + np.array([center_trans]) - r 256 | curvature_surface[:, coord_v[:, 0], coord_v[:, 1], coord_v[:, 2]] = \ 257 | curvature_raw[:, coord_v_trans[:, 0], coord_v_trans[:, 1], coord_v_trans[:, 2]] 258 | sd_surface[:, coord_v[:, 0], coord_v[:, 1], coord_v[:, 2]] = \ 259 | sd_raw[coord_v_trans[:, 0], coord_v_trans[:, 1], coord_v_trans[:, 2]] 260 | 261 | anno = {'name': str(model_id), 'min_5_fs': min_5_fs, 'translate': mesh_vox.translate, 'scale': mesh_vox.scale, 262 | 'subset': subset, 'center_trans': center_trans, 'avg_edge': avg_edge} 263 | return mesh_vox_padded, heatmap_joint, heatmap_bones, heatmap_verts, curvature_surface, sd_surface, anno 264 | 265 | 266 | def genDataset(root_folder, dim_ori=82, padding=3, dim_pad=88): 267 | # read in train/val/test split 268 | train_id_list = np.loadtxt(os.path.join(root_folder, 'train_final.txt'), dtype=int) 269 | val_id_list = np.loadtxt(os.path.join(root_folder, 'val_final.txt'), dtype=int) 270 | test_id_list = np.loadtxt(os.path.join(root_folder, 'test_final.txt'), dtype=int) 271 | 272 | num_train = len(train_id_list) 273 | num_val = len(val_id_list) 274 | num_test = len(test_id_list) 275 | 276 | # create sub-datasets 277 | hf = h5py.File(os.path.join(root_folder, 'model-resource-volumetric.h5'), 'w') 278 | hf.create_dataset('train_data', (num_train, dim_pad, dim_pad, dim_pad), np.float16) 279 | hf.create_dataset('train_vert', (num_train, dim_pad, dim_pad, dim_pad), np.uint8) 280 | hf.create_dataset('train_curvature', (num_train, 2, dim_pad, dim_pad, dim_pad), np.float16) 281 | hf.create_dataset('train_sd', (num_train, 1, dim_pad, dim_pad, dim_pad), np.float16) 282 | hf.create_dataset('train_label_joint', (num_train, dim_pad, dim_pad, dim_pad), np.float16) 283 | hf.create_dataset('train_label_bone', (num_train, dim_pad, dim_pad, dim_pad), np.float16) 284 | 285 | hf.create_dataset('val_data', (num_val, dim_pad, dim_pad, dim_pad), np.float16) 286 | hf.create_dataset('val_vert', (num_val, dim_pad, dim_pad, dim_pad), np.uint8) 287 | hf.create_dataset('val_curvature', (num_val, 2, dim_pad, dim_pad, dim_pad), np.float16) 288 | hf.create_dataset('val_sd', (num_val, 1, dim_pad, dim_pad, dim_pad), np.float16) 289 | hf.create_dataset('val_label_joint', (num_val, dim_pad, dim_pad, dim_pad), np.float16) 290 | hf.create_dataset('val_label_bone', (num_val, dim_pad, dim_pad, dim_pad), np.float16) 291 | 292 | hf.create_dataset('test_data', (num_test, dim_pad, dim_pad, dim_pad), np.float16) 293 | hf.create_dataset('test_vert', (num_test, dim_pad, dim_pad, dim_pad), np.uint8) 294 | hf.create_dataset('test_curvature', (num_test, 2, dim_pad, dim_pad, dim_pad), np.float16) 295 | hf.create_dataset('test_sd', (num_test, 1, dim_pad, dim_pad, dim_pad), np.float16) 296 | hf.create_dataset('test_label_joint', (num_test, dim_pad, dim_pad, dim_pad), np.float16) 297 | hf.create_dataset('test_label_bone', (num_test, dim_pad, dim_pad, dim_pad), np.float16) 298 | 299 | # start to fill all sub-datasets 300 | anno_all = [] 301 | for train_id in tqdm(range(len(train_id_list))): 302 | model_id = train_id_list[train_id] 303 | mesh_vox, heatmap_joint, heatmap_bones, heatmap_verts, curvature, sd, anno = \ 304 | genDataset_inner(root_folder, model_id, subset='train', dim_ori=dim_ori, r=padding, dim_pad=dim_pad) 305 | hf['train_data'][train_id, :, :, :] = mesh_vox 306 | hf['train_vert'][train_id, :, :, :] = heatmap_verts 307 | hf['train_label_joint'][train_id, :, :, :] = heatmap_joint 308 | hf['train_label_bone'][train_id, :, :, :] = heatmap_bones 309 | hf['train_curvature'][train_id, ...] = curvature[np.newaxis, ...] 310 | hf['train_sd'][train_id, ...] = sd[np.newaxis, ...] 311 | anno_all.append(anno) 312 | 313 | for val_id in tqdm(range(len(val_id_list))): 314 | model_id = val_id_list[val_id] 315 | mesh_vox, heatmap_joint, heatmap_bones, heatmap_verts, curvature, sd, anno = \ 316 | genDataset_inner(root_folder, model_id, subset='val', dim_ori=dim_ori, r=padding, dim_pad=dim_pad) 317 | hf['val_data'][val_id, :, :, :] = mesh_vox 318 | hf['val_vert'][val_id, :, :, :] = heatmap_verts 319 | hf['val_label_joint'][val_id, :, :, :] = heatmap_joint 320 | hf['val_label_bone'][val_id, :, :, :] = heatmap_bones 321 | hf['val_curvature'][val_id, ...] = curvature[np.newaxis, ...] 322 | hf['val_sd'][val_id, ...] = sd[np.newaxis, ...] 323 | anno_all.append(anno) 324 | 325 | for test_id in tqdm(range(len(test_id_list))): 326 | model_id = test_id_list[test_id] 327 | mesh_vox, heatmap_joint, heatmap_bones, heatmap_verts, curvature, sd, anno = \ 328 | genDataset_inner(root_folder, model_id, subset='test', dim_ori=dim_ori, r=padding, dim_pad=dim_pad) 329 | hf['test_data'][test_id, :, :, :] = mesh_vox 330 | hf['test_vert'][test_id, :, :, :] = heatmap_verts 331 | hf['test_label_joint'][test_id, :, :, :] = heatmap_joint 332 | hf['test_label_bone'][test_id, :, :, :] = heatmap_bones 333 | hf['test_curvature'][test_id, ...] = curvature[np.newaxis, ...] 334 | hf['test_sd'][test_id, ...] = sd[np.newaxis, ...] 335 | anno_all.append(anno) 336 | 337 | hf.close() 338 | 339 | # save accompanying information as json 340 | with open(os.path.join(root_folder, 'model-resource-volumetric.json'), 'w') as outfile: 341 | json.dump(anno_all, outfile) 342 | 343 | 344 | if __name__ == '__main__': 345 | root_folder = 'model_resource_data/' # the directory to put raw data and generated dataset 346 | genDataset(root_folder, dim_ori=82, padding=3, dim_pad=88) 347 | -------------------------------------------------------------------------------- /maya_bind.py: -------------------------------------------------------------------------------- 1 | """ 2 | The script can be used to bind predicted skeleton and original model together, with geodesic voxel skinning method of maya 3 | """ 4 | 5 | import pymel.core as pm 6 | 7 | def getGeometryGroups(): 8 | geo_list = [] 9 | geometries = cmds.ls(type='surfaceShape') 10 | for geo in geometries: 11 | if 'ShapeOrig' in geo: 12 | ''' 13 | we can also use cmds.ls(geo, l=True)[0].split("|")[0] 14 | to get the upper level node name, but stick on this way for now 15 | ''' 16 | geo_name = geo.replace('ShapeOrig', '') 17 | geo_list.append(geo_name) 18 | if not geo_list: 19 | geo_list = cmds.ls(type='surfaceShape') 20 | return geo_list 21 | 22 | 23 | def load_skel(filename): 24 | with open(filename, 'r') as fin: 25 | lines = fin.readlines() 26 | for li in lines: 27 | words = li.split() 28 | if words[5] == 'None': 29 | root = words[1] 30 | print 'root: '+root 31 | pos = (float(words[2]), float(words[3]), float(words[4])) 32 | cmds.joint(p=(pos[0], pos[1], pos[2]), name=root) 33 | break 34 | this_level = [root] 35 | while this_level: 36 | next_level = [] 37 | for pname in this_level: 38 | for li in lines: 39 | words = li.split() 40 | name_li = words[1] 41 | name_pa = words[5] 42 | if name_pa == pname: 43 | #print name_li 44 | cmds.select(pname, r=True) 45 | pos = (float(words[2]), float(words[3]), float(words[4])) 46 | cmds.joint(p=(pos[0], pos[1], pos[2]), name=name_li) 47 | next_level.append(name_li) 48 | this_level = next_level 49 | return root 50 | 51 | 52 | if __name__ == '__main__': 53 | obj_name = 'DATA_PATH\\obj\\1195.obj' 54 | skel_name = 'DATA_PATH\\skel\\1195.txt' 55 | 56 | # import obj 57 | cmds.file(new=True,force=True) 58 | cmds.file(obj_name, o=True) 59 | 60 | #import skel 61 | root = load_skel(skel_name) 62 | 63 | # geodesic volumetric skinning 64 | geo_list = getGeometryGroups() 65 | cmds.skinCluster(root, geo_list[0]) # The line only works for mesh with a single component. If the mesh has multiple groups, this line is incorrect! 66 | cmds.select('skinCluster1', r=True) 67 | cmds.select(geo_list[0], add=True) 68 | cmds.geomBind(bm=3, fo=0.5, mi=3) # adjust the parameters 69 | cmds.skinPercent('skinCluster1', geo_list[0], pruneWeights=0.2) # adjust the parameters 70 | 71 | # export fbx 72 | # pm.mel.FBXExport(f=out_name) -------------------------------------------------------------------------------- /models3D/__init__.py: -------------------------------------------------------------------------------- 1 | from .model3d_hg import * 2 | -------------------------------------------------------------------------------- /models3D/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/models3D/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models3D/__pycache__/model3d_hg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/models3D/__pycache__/model3d_hg.cpython-36.pyc -------------------------------------------------------------------------------- /models3D/model3d_hg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['V2V_HG', 'v2v_hg'] 6 | 7 | 8 | class Basic3DBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, kernel_size): 10 | super(Basic3DBlock, self).__init__() 11 | self.block = nn.Sequential( 12 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)), 13 | nn.BatchNorm3d(out_planes), 14 | nn.ReLU(True) 15 | ) 16 | 17 | def forward(self, x): 18 | return self.block(x) 19 | 20 | 21 | class Res3DBlock(nn.Module): 22 | def __init__(self, in_planes, out_planes): 23 | super(Res3DBlock, self).__init__() 24 | self.res_branch = nn.Sequential( 25 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm3d(out_planes), 27 | nn.ReLU(True), 28 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1), 29 | nn.BatchNorm3d(out_planes) 30 | ) 31 | 32 | if in_planes == out_planes: 33 | self.skip_con = nn.Sequential() 34 | else: 35 | self.skip_con = nn.Sequential( 36 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0), 37 | nn.BatchNorm3d(out_planes) 38 | ) 39 | 40 | def forward(self, x): 41 | res = self.res_branch(x) 42 | skip = self.skip_con(x) 43 | return F.relu(res + skip, True) 44 | 45 | 46 | class Pool3DBlock(nn.Module): 47 | def __init__(self, pool_size, input_plane): 48 | super(Pool3DBlock, self).__init__() 49 | self.stride_conv = nn.Sequential( 50 | nn.Conv3d(input_plane, input_plane, kernel_size=pool_size, stride=pool_size, padding=0), 51 | nn.BatchNorm3d(input_plane), 52 | nn.ReLU(True) 53 | ) 54 | 55 | def forward(self, x): 56 | y = self.stride_conv(x) 57 | return y 58 | 59 | 60 | class Upsample3DBlock(nn.Module): 61 | def __init__(self, in_planes, out_planes, kernel_size, stride): 62 | super(Upsample3DBlock, self).__init__() 63 | # assert (kernel_size == 2) 64 | assert (stride == 2) 65 | self.block = nn.Sequential( 66 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0), 67 | nn.BatchNorm3d(out_planes), 68 | nn.ReLU(True) 69 | ) 70 | 71 | def forward(self, x): 72 | return self.block(x) 73 | 74 | 75 | class HG(nn.Module): 76 | def __init__(self, front=False): 77 | super(HG, self).__init__() 78 | if front: 79 | input_channel = 8 80 | else: 81 | input_channel = 10 82 | self.encoder_pool1 = Pool3DBlock(2, input_channel) 83 | self.encoder_res1 = Res3DBlock(input_channel, 16) 84 | self.encoder_pool2 = Pool3DBlock(2, 16) 85 | self.encoder_res2 = Res3DBlock(16, 24) 86 | self.encoder_pool3 = Pool3DBlock(2, 24) 87 | self.encoder_res3 = Res3DBlock(24, 36) 88 | 89 | self.mid_res = Res3DBlock(40, 36) 90 | 91 | self.decoder_res3 = Res3DBlock(36, 36) 92 | self.decoder_upsample3 = Upsample3DBlock(36, 24, 2, 2) 93 | self.decoder_res2 = Res3DBlock(24, 24) 94 | self.decoder_upsample2 = Upsample3DBlock(24, 16, 2, 2) 95 | self.decoder_res1 = Res3DBlock(16, 16) 96 | self.decoder_upsample1 = Upsample3DBlock(16, 8, 2, 2) 97 | 98 | self.skip_res1 = Res3DBlock(input_channel, 8) 99 | self.skip_res2 = Res3DBlock(16, 16) 100 | self.skip_res3 = Res3DBlock(24, 24) 101 | 102 | def forward(self, x, c): 103 | skip_x1 = self.skip_res1(x) 104 | x = self.encoder_pool1(x) 105 | x = self.encoder_res1(x) 106 | skip_x2 = self.skip_res2(x) 107 | x = self.encoder_pool2(x) 108 | x = self.encoder_res2(x) 109 | skip_x3 = self.skip_res3(x) 110 | x = self.encoder_pool3(x) 111 | x = self.encoder_res3(x) 112 | 113 | c = c.repeat(1, 4, 11, 11, 11) 114 | x = torch.cat((x, c), dim=1) 115 | x = self.mid_res(x) 116 | 117 | x = self.decoder_res3(x) 118 | x = self.decoder_upsample3(x) 119 | x = x + skip_x3 120 | x = self.decoder_res2(x) 121 | x = self.decoder_upsample2(x) 122 | x = x + skip_x2 123 | x = self.decoder_res1(x) 124 | x = self.decoder_upsample1(x) 125 | x = x + skip_x1 126 | return x 127 | 128 | 129 | class V2V_HG(nn.Module): 130 | def __init__(self, input_channels, n_stack): 131 | super(V2V_HG, self).__init__() 132 | self.input_channels = input_channels 133 | self.n_stack = n_stack 134 | self.front_layers = nn.Sequential( 135 | Basic3DBlock(input_channels, 8, 5), 136 | Res3DBlock(8, 8) 137 | ) 138 | 139 | self.hg_1 = HG(front=True) 140 | 141 | self.joint_output_1 = nn.Sequential( 142 | Res3DBlock(8, 4), 143 | Basic3DBlock(4, 4, 1), 144 | nn.Dropout3d(p=0.2), 145 | nn.Conv3d(4, 1, kernel_size=1, stride=1, padding=0) 146 | ) 147 | self.bone_output_1 = nn.Sequential( 148 | Res3DBlock(8, 4), 149 | Basic3DBlock(4, 4, 1), 150 | nn.Dropout3d(p=0.2), 151 | nn.Conv3d(4, 1, kernel_size=1, stride=1, padding=0) 152 | ) 153 | 154 | if n_stack > 1: 155 | self.hg_list = nn.ModuleList([HG(front=False) for i in range(1, n_stack)]) 156 | self.joint_output_list = nn.ModuleList([nn.Sequential( 157 | Res3DBlock(8, 4), Basic3DBlock(4, 4, 1), nn.Dropout3d(p=0.2), 158 | nn.Conv3d(4, 1, kernel_size=1, stride=1, padding=0)) for i in range(1, n_stack)]) 159 | self.bone_output_list = nn.ModuleList([nn.Sequential( 160 | Res3DBlock(8, 4), Basic3DBlock(4, 4, 1), nn.Dropout3d(p=0.2), 161 | nn.Conv3d(4, 1, kernel_size=1, stride=1, padding=0)) for i in range(1, n_stack)]) 162 | self._initialize_weights() 163 | 164 | def forward(self, x_in, c): 165 | x = self.front_layers(x_in) 166 | x_hg_1 = self.hg_1(x, c) 167 | x_joint_out1 = self.joint_output_1(x_hg_1) 168 | x_bone_out1 = self.bone_output_1(x_hg_1) 169 | 170 | x_joint_out = [x_joint_out1] 171 | x_bone_out = [x_bone_out1] 172 | 173 | for i in range(1, self.n_stack): 174 | x_in = torch.cat((x, x_joint_out1, x_bone_out1), dim=1) 175 | x_hg = self.hg_list[i-1](x_in, c) 176 | x_joint = self.joint_output_list[i-1](x_hg) 177 | x_bone = self.bone_output_list[i-1](x_hg) 178 | x_joint_out.append(x_joint) 179 | x_bone_out.append(x_bone) 180 | 181 | return x_joint_out, x_bone_out 182 | 183 | def _initialize_weights(self): 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv3d): 186 | nn.init.normal_(m.weight, 0, 0.001) 187 | nn.init.constant_(m.bias, 0) 188 | elif isinstance(m, nn.ConvTranspose3d): 189 | nn.init.normal_(m.weight, 0, 0.001) 190 | nn.init.constant_(m.bias, 0) 191 | 192 | 193 | def v2v_hg(**kwargs): 194 | model = V2V_HG(input_channels=kwargs['input_channels'], n_stack=kwargs['n_stack']) 195 | return model 196 | -------------------------------------------------------------------------------- /mst_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -* 3 | 4 | """ 5 | This script is for generating skeleton based on the predicted joint&bone probability maps. 6 | It is formulated as a minimal spinning tree problem. 7 | A soft non-maximum suppression approach is used for better results. 8 | Refer to our paper for more details. 9 | """ 10 | 11 | import os 12 | import glob 13 | import cv2 14 | import sys 15 | import numpy as np 16 | from scipy.ndimage import correlate as correlate 17 | from scipy.ndimage import binary_erosion as binary_erosion 18 | 19 | from util.tree_utils import TreeNode 20 | from util.rigging_parser.obj_parser import Mesh_obj 21 | from util.rigging_parser.skel_parser import Skel 22 | from util.open3d_utils import show_obj_skel 23 | 24 | 25 | def nms_soft(heatmap, th_conf=0.1, size=15, sigma=6): 26 | ''' 27 | soft nms algorithm proposed in "Soft-nms - improving object detection with one line of code" 28 | :param heatmap: predicted joint heatmaps 29 | :param th_conf: lowest threshold to stop 30 | :param size: gaussian kernel size 31 | :param sigma: gaussian kernel sigma 32 | :return: heatmap after soft-nms 33 | ''' 34 | heatmap[heatmap < th_conf] = 0 35 | 36 | # we find all local maximums by looking into the gradient of each voxel. 37 | # Local maximums should have higher value to all its 26 neighbors. 38 | # so we apply 26 filters (each of them is 3*3*3) to compute the discrete gradient w.r.t. all directions/neighbors. 39 | kernels = np.zeros((26, 3, 3, 3)) 40 | kernels[:, 1, 1, 1] = 1.0 41 | gradient_map = np.zeros((26, heatmap.shape[0], heatmap.shape[1], heatmap.shape[2]), np.float32) 42 | # create 26 kernels to compute gradient 43 | for c in range(26): 44 | if c < 13: 45 | i = c // 9 46 | j = (c - i * 9) // 3 47 | k = (c - i * 9) % 3 48 | kernels[c, i, j, k] = -1.0 49 | else: 50 | i = (c + 1) // 9 51 | j = (c + 1 - i * 9) // 3 52 | k = (c + 1 - i * 9) % 3 53 | kernels[c, i, j, k] = -1.0 54 | gradient_map[c, ...] = correlate(heatmap, kernels[c, ...], mode='constant', cval=0.0) 55 | gradient_map[gradient_map >= 0] = 0 56 | gradient_map[gradient_map < 0] = -1 57 | # Sum up all 26 filters. Local maximums will have 26 zeros. Otherwise summation will be negetive. 58 | gradient_map = np.sum(gradient_map, axis=0) 59 | 60 | maximum_map = np.logical_and((gradient_map == 0), (heatmap > 0)) # local maximum map 61 | heatmap = np.multiply(heatmap, maximum_map.astype(np.float32)) # pick all local maximum by masking 62 | 63 | # create gaussian kernel for soft-nms 64 | radius = size//2 65 | x = np.arange(0, size, 1, float) 66 | y = x[:, np.newaxis] 67 | z = y[..., np.newaxis] 68 | x0 = y0 = z0 = size // 2 69 | # The gaussian is not normalized, we want the center value to equal 1 70 | # This is actually 1 - gaussian 71 | g = 1 - np.exp(- ((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2) / (2 * sigma ** 2)) 72 | result = np.zeros((heatmap.shape), dtype=np.float32) 73 | # for symmetric reason, we only do nms to the left half. 74 | # Note we add 1 more column, because we don't want to ruin the middle part of heat map 75 | heatmap_half = heatmap[0:int(heatmap.shape[0] / 2) + 1, ...] 76 | 77 | while np.any(heatmap_half >= th_conf): 78 | idx_list = np.argwhere(heatmap_half == np.amax(heatmap_half)) 79 | if len(idx_list) > 1: 80 | # first merge adjacent voxels 81 | idx_list = idx_list.tolist() 82 | # For adjacent voxels, we only preserve the one closest to the middle plane. 83 | # So here we sort according to their distance to the middle plane. 84 | idx_list.sort(key=lambda x: abs(x[0] - 43.5)) 85 | idx_list_merge = [] # list to store joints after adjacent merging. 86 | for idx in idx_list: 87 | preserve_flag = True 88 | # check if any adjacent voxel has already preserved. 89 | for prefix in idx_list_merge: 90 | if abs(prefix[0] - idx[0]) <= 1 and abs(prefix[1] - idx[1]) <= 1 and abs(prefix[2] - idx[2]) <= 1: 91 | preserve_flag = False 92 | break 93 | if preserve_flag: # if no neighbor has been preserved, we add this voxel to the preserved list. 94 | idx_list_merge.append(idx) 95 | idx_list = np.array(idx_list_merge) 96 | # decay map is used to store how much we should reduce the heatmap probability 97 | decay_map = np.ones((heatmap_half.shape), dtype=np.float32) 98 | for idx in idx_list: 99 | # set the local maximum position to the original probability in the result heatmap 100 | result[idx[0], idx[1], idx[2]] = heatmap_half[idx[0], idx[1], idx[2]] 101 | heatmap_half[idx[0], idx[1], idx[2]] = 0 102 | ul = [int(idx[0] - radius), int(idx[1] - radius), int(idx[2] - radius)] 103 | br = [int(idx[0] + radius + 1), int(idx[1] + radius + 1), int(idx[2] + radius + 1)] 104 | g_x = max(0, -ul[0]), min(br[0], heatmap_half.shape[0]) - ul[0] 105 | g_y = max(0, -ul[1]), min(br[1], heatmap_half.shape[1]) - ul[1] 106 | g_z = max(0, -ul[2]), min(br[2], heatmap_half.shape[2]) - ul[2] 107 | # Image range 108 | img_x = max(0, ul[0]), min(br[0], heatmap_half.shape[0]) 109 | img_y = max(0, ul[1]), min(br[1], heatmap_half.shape[1]) 110 | img_z = max(0, ul[2]), min(br[2], heatmap_half.shape[2]) 111 | decay_map[img_x[0]:img_x[1], img_y[0]:img_y[1], img_z[0]:img_z[1]] = \ 112 | np.multiply(decay_map[img_x[0]:img_x[1], img_y[0]:img_y[1], img_z[0]:img_z[1]], 113 | g[g_x[0]:g_x[1], g_y[0]:g_y[1], g_z[0]:g_z[1]]) 114 | heatmap_half = np.multiply(heatmap_half, decay_map) 115 | heatmap_half[heatmap_half < th_conf] = 0 116 | 117 | return result 118 | 119 | 120 | def load_ts(ts_filename): 121 | trans = np.zeros(3) 122 | center_trans = np.zeros(3, dtype=np.int8) 123 | with open(ts_filename, 'r') as fts: 124 | line1 = fts.readline().strip().split() 125 | center_trans[0], center_trans[1], center_trans[2] = int(line1[0]), int(line1[1]), int(line1[2]) 126 | line2 = fts.readline().strip().split() 127 | trans[0], trans[1], trans[2] = float(line2[0]), float(line2[1]), float(line2[2]) 128 | line3 = fts.readline().strip() 129 | scl = float(line3) 130 | return trans, scl, center_trans 131 | 132 | 133 | def loadSkel_recur(p_node, parent_id, joint_pos, parent): 134 | for i in range(len(parent)): 135 | if parent[i] == parent_id: 136 | ch_node = TreeNode('joint_{}'.format(i), tuple(joint_pos[i])) 137 | p_node.children.append(ch_node) 138 | ch_node.parent = p_node 139 | loadSkel_recur(ch_node, i, joint_pos, parent) 140 | 141 | 142 | def getInitId(joint_pos): 143 | ''' 144 | Root joint is chosen as the lowest joint near the middle symmetric plane 145 | :param joint_pos: all joint positions 146 | :return: root joint ID 147 | ''' 148 | sorted_id = np.argsort(joint_pos[:,1]) 149 | for i in range(len(sorted_id)): 150 | id = sorted_id[i] 151 | if joint_pos[id, 0] < 0.2: 152 | continue 153 | if abs(joint_pos[id, 0]) < 2e-2: 154 | return id 155 | return np.argsort(abs(joint_pos[:,0]))[0] 156 | 157 | 158 | def unique_rows(a): 159 | a = np.ascontiguousarray(a) 160 | unique_a = np.unique(a.view([('', a.dtype)]*a.shape[1])) 161 | return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1])) 162 | 163 | 164 | def minKey(key, mstSet, nV): 165 | # Initilaize min value 166 | min = sys.maxsize 167 | 168 | for v in range(nV): 169 | if key[v] < min and mstSet[v] == False: 170 | min = key[v] 171 | min_index = v 172 | 173 | return min_index 174 | 175 | 176 | def flip(pred_joint, pred_bone, trans, center_trans, scl, input_dim=88, r=3): 177 | ''' 178 | Enforcing predicted heatmap symmetric by reflecting left half to the right 179 | The symmetric voxel positions are computed by converting to euclidean space coordinates, finding symmetric positions and converting back. 180 | This is because voxel space has low resolution and accuracy. Euclidean space is more accurate. 181 | :param pred_joint: predicted joint heatmap 182 | :param pred_bone: predicted bone heatmap 183 | :param trans: translation vector from voxel space to euclidean space 184 | :param center_trans: translation vector from centered volume to original volume 185 | :param scl: scale from voxel space to euclidean space 186 | :param input_dim: voxel grid dimension 187 | :param r: voxel grid padding 188 | :return: symmetric heatmaps 189 | ''' 190 | ori_dim = input_dim - 2 * r # original voxel grid dimension without padding 191 | 192 | grid_x, grid_y, grid_z = np.meshgrid(np.arange(input_dim), np.arange(input_dim), np.arange(input_dim)) 193 | grid_coord = np.concatenate((grid_y.flatten()[:, None], grid_x.flatten()[:, None], grid_z.flatten()[:, None]), axis=1) 194 | grid_coord = grid_coord - r + center_trans 195 | # grid_coord (88^3 * 3) stores correpsonding euclidean coordinates for all voxels in voxel grid 196 | grid_coord = grid_coord / np.array([ori_dim, ori_dim, ori_dim]) * scl + trans 197 | 198 | # for joint heatmap 199 | if pred_joint is not None: 200 | # reflect the left-half heatmap values to the right-half 201 | # flatten predicted joint map to (88^3, ), with the same order as grid_coord 202 | val_pred_joint = pred_joint.flatten() 203 | val_left = np.copy(val_pred_joint[np.logical_and(grid_coord[:, 0] < -2e-2, val_pred_joint > 0)]) 204 | grid_coord_left = np.copy(grid_coord[np.logical_and(grid_coord[:, 0] < -2e-2, val_pred_joint > 0), :]) 205 | grid_coord_right_ = grid_coord_left.copy() 206 | grid_coord_right_[:, 0] = -grid_coord_right_[:, 0] 207 | 208 | # transform euclidean coodinates back into voxel space. 209 | vc = np.round((grid_coord_right_ - trans) / scl * ori_dim - center_trans + r) # 210 | vc = vc.astype(int) 211 | vc = np.clip(vc, 0, input_dim-1) 212 | pred_joint[vc[:,0], vc[:,1], vc[:,2]] = (pred_joint[vc[:,0], vc[:,1], vc[:,2]] + val_left) # add left to right 213 | 214 | # do the same for the right half, reflecting them into the left half. 215 | val_pred_joint = pred_joint.flatten() 216 | val_right = np.copy(val_pred_joint[np.logical_and(grid_coord[:, 0] > 2e-2, val_pred_joint > 0)]) 217 | grid_coord_right = np.copy(grid_coord[np.logical_and(grid_coord[:, 0] > 2e-2, val_pred_joint > 0), :]) 218 | grid_coord_left_ = grid_coord_right.copy() 219 | grid_coord_left_[:, 0] = -grid_coord_left_[:, 0] 220 | 221 | vc = np.round((grid_coord_left_ - trans) / scl * ori_dim - center_trans + r) 222 | vc = vc.astype(int) 223 | vc = np.clip(vc, 0, input_dim - 1) 224 | pred_joint[vc[:, 0], vc[:, 1], vc[:, 2]] = val_right # don't add values here. Just copy. Otherwise it will unsymmetric again! 225 | 226 | # for bone heatmap 227 | if pred_bone is not None: 228 | val_pred_bone = pred_bone.flatten() 229 | # reflect left-half voxels to right half 230 | val_left = np.copy(val_pred_bone[np.logical_and(grid_coord[:, 0] < -2e-2, val_pred_bone > 0)]) 231 | grid_coord_left = np.copy(grid_coord[np.logical_and(grid_coord[:, 0] < -2e-2, val_pred_bone > 0), :]) 232 | grid_coord_right_ = grid_coord_left.copy() 233 | grid_coord_right_[:, 0] = -grid_coord_right_[:, 0] 234 | 235 | vc = np.round((grid_coord_right_ - trans) / scl * ori_dim - center_trans + r) 236 | vc = vc.astype(int) 237 | vc = np.clip(vc, 0, input_dim - 1) 238 | pred_bone[vc[:, 0], vc[:, 1], vc[:, 2]] = (pred_bone[vc[:, 0], vc[:, 1], vc[:, 2]] + val_left) / 2 239 | 240 | val_pred_bone = pred_bone.flatten() 241 | # reflect right-half voxels to left half 242 | val_right = np.copy(val_pred_bone[np.logical_and(grid_coord[:, 0] > 2e-2, val_pred_bone > 0)]) 243 | grid_coord_right = np.copy(grid_coord[np.logical_and(grid_coord[:, 0] > 2e-2, val_pred_bone > 0), :]) 244 | grid_coord_left_ = grid_coord_right.copy() 245 | grid_coord_left_[:, 0] = -grid_coord_left_[:, 0] 246 | 247 | vc = np.round((grid_coord_left_ - trans) / scl * ori_dim - center_trans + r) 248 | vc = vc.astype(int) 249 | vc = np.clip(vc, 0, input_dim - 1) 250 | pred_bone[vc[:, 0], vc[:, 1], vc[:, 2]] = val_right 251 | 252 | return pred_joint, pred_bone 253 | 254 | 255 | def primMST(graph, init_id): 256 | nV = graph.shape[0] 257 | # Key values used to pick minimum weight edge in cut 258 | key = [sys.maxsize] * nV 259 | parent = [None] * nV # Array to store constructed MST 260 | mstSet = [False] * nV 261 | # Make key init_id so that this vertex is picked as first vertex 262 | key[init_id] = 0 263 | parent[init_id] = -1 # First node is always the root of 264 | 265 | for cout in range(nV): 266 | # Pick the minimum distance vertex from 267 | # the set of vertices not yet processed. 268 | # u is always equal to src in first iteration 269 | u = minKey(key, mstSet, nV) 270 | 271 | # Put the minimum distance vertex in 272 | # the shortest path tree 273 | mstSet[u] = True 274 | 275 | # Update dist value of the adjacent vertices 276 | # of the picked vertex only if the current 277 | # distance is greater than new distance and 278 | # the vertex in not in the shotest path tree 279 | for v in range(nV): 280 | # graph[u][v] is non zero only for adjacent vertices of m 281 | # mstSet[v] is false for vertices not yet included in MST 282 | # Update the key only if graph[u][v] is smaller than key[v] 283 | if graph[u,v] > 0 and mstSet[v] == False and key[v] > graph[u,v]: 284 | key[v] = graph[u,v] 285 | parent[v] = u 286 | 287 | return parent, key 288 | 289 | 290 | def primMST_symmetry(graph, init_id, joints): 291 | ''' 292 | My revised prim algorithm to find a MST as symmetric as possible. 293 | The function is sort of messy but... 294 | The basic idea is if a bone on the left is picked, its counterpart on the right should also be picked 295 | :param graph: cost matrix (N*N) 296 | :param init_id: joint ID to be first picked 297 | :param joints: joint position (N*3) 298 | :return: 299 | ''' 300 | joint_mapping = {} 301 | # this is trick. Since we already reflect the joints, "joints" have the order as left_joints->middle_joints->right_joints 302 | # so we can find correspondence by simply following the original order after splitting three parts. 303 | left_joint_ids = np.argwhere(joints[:, 0] < -2e-2).squeeze(1).tolist() 304 | middle_joint_ids = np.argwhere(np.abs(joints[:, 0]) <= 2e-2).squeeze(1).tolist() 305 | right_joint_ids = np.argwhere(joints[:, 0] > 2e-2).squeeze(1).tolist() 306 | for i in range(len(left_joint_ids)): 307 | joint_mapping[left_joint_ids[i]] = right_joint_ids[i] 308 | for i in range(len(right_joint_ids)): 309 | joint_mapping[right_joint_ids[i]] = left_joint_ids[i] 310 | 311 | if init_id not in middle_joint_ids: 312 | #find nearest joint in the middle to be root 313 | if len(middle_joint_ids) > 0: 314 | nearest_id = np.argmin(np.linalg.norm(joints[middle_joint_ids, :] - joints[init_id, :][np.newaxis, :], axis=1)) 315 | init_id = middle_joint_ids[nearest_id] 316 | 317 | nV = graph.shape[0] 318 | # Key values used to pick minimum weight edge in cut 319 | key = [sys.maxsize] * nV 320 | parent = [None] * nV # Array to store constructed MST 321 | mstSet = [False] * nV 322 | # Make key init_id so that this vertex is picked as first vertex 323 | key[init_id] = 0 324 | parent[init_id] = -1 # First node is always the root of 325 | 326 | while not all(mstSet): 327 | # Pick the minimum distance vertex from 328 | # the set of vertices not yet processed. 329 | # u is always equal to src in first iteration 330 | u = minKey(key, mstSet, nV) 331 | # left cases 332 | if u in left_joint_ids and parent[u] in middle_joint_ids: 333 | u2 = joint_mapping[u] 334 | if mstSet[u2] is False: 335 | mstSet[u2] = True 336 | parent[u2] = parent[u] 337 | key[u2] = graph[u2, parent[u2]] 338 | elif u in left_joint_ids and parent[u] in left_joint_ids: 339 | u2 = joint_mapping[u] 340 | if mstSet[u2] is False: 341 | mstSet[u2] = True 342 | parent[u2] = joint_mapping[parent[u]] 343 | key[u2] = graph[u2, parent[u2]] 344 | if mstSet[parent[u2]] is False: 345 | mstSet[parent[u2]] = True 346 | key[parent[u2]] = graph[parent[u2], parent[parent[u2]]] 347 | 348 | elif u in middle_joint_ids and parent[u] in left_joint_ids: 349 | # form loop in the tree, but we can do nothing 350 | u2 = None 351 | # right cases 352 | elif u in right_joint_ids and parent[u] in middle_joint_ids: 353 | u2 = joint_mapping[u] 354 | if mstSet[u2] is False: 355 | mstSet[u2] = True 356 | parent[u2] = parent[u] 357 | key[u2] = graph[u2, parent[u2]] 358 | elif u in right_joint_ids and parent[u] in right_joint_ids: 359 | u2 = joint_mapping[u] 360 | if mstSet[u2] is False: 361 | mstSet[u2] = True 362 | parent[u2] = joint_mapping[parent[u]] 363 | key[u2] = graph[u2, parent[u2]] 364 | if mstSet[parent[u2]] is False: 365 | mstSet[parent[u2]] = True 366 | key[parent[u2]] = graph[parent[u2], parent[parent[u2]]] 367 | elif u in middle_joint_ids and parent[u] in right_joint_ids: 368 | # form loop in the tree, but we can do nothing 369 | u2 = None 370 | # middle case 371 | else: 372 | u2 = None 373 | 374 | mstSet[u] = True 375 | 376 | # Update dist value of the adjacent vertices 377 | # of the picked vertex only if the current 378 | # distance is greater than new distance and 379 | # the vertex in not in the shotest path tree 380 | for v in range(nV): 381 | # graph[u][v] is non zero only for adjacent vertices of m 382 | # mstSet[v] is false for vertices not yet included in MST 383 | # Update the key only if graph[u][v] is smaller than key[v] 384 | if graph[u,v] > 0 and mstSet[v] == False and key[v] > graph[u,v]: 385 | key[v] = graph[u,v] 386 | parent[v] = u 387 | if u2 is not None and graph[u2,v] > 0 and mstSet[v] == False and key[v] > graph[u2,v]: 388 | key[v] = graph[u2, v] 389 | parent[v] = u2 390 | 391 | return parent, key 392 | 393 | 394 | def getMSTcost(joint_pos, joint_pos_cartesian, joint_pred, bone_pred, volume): 395 | n_joint = len(joint_pos) 396 | cost_matrix = np.zeros((n_joint, n_joint), dtype=np.float32) 397 | # fill upper triangular matrix 398 | for r in range(n_joint): 399 | for c in range(r + 1, n_joint): 400 | pos_start = joint_pos[r] 401 | pos_end = joint_pos[c] 402 | num_step = np.round((np.linalg.norm(pos_end - pos_start)) / 1.0).astype(int) 403 | pos_sample = pos_start[np.newaxis, :] + (pos_end - pos_start)[np.newaxis, :] * np.linspace(0, 1, num_step)[:, np.newaxis] 404 | pos_sample = np.round(pos_sample).astype(int) 405 | pos_sample = [tuple(row) for row in pos_sample] 406 | pos_sample = unique_rows(pos_sample) 407 | 408 | cost_matrix[r, c] = -np.log(bone_pred[pos_sample[:, 0], pos_sample[:, 1], pos_sample[:, 2]] + 1e-8).sum() 409 | cost_matrix[r, c] += 500 * np.sum(volume[pos_sample[:, 0], pos_sample[:, 1], pos_sample[:, 2]] == 0) 410 | 411 | cost_matrix = cost_matrix + cost_matrix.transpose() 412 | init_id = getInitId(joint_pos_cartesian) 413 | parent, key = primMST_symmetry(cost_matrix, init_id, joint_pos_cartesian) 414 | return joint_pos_cartesian, cost_matrix, parent, key 415 | 416 | 417 | def mst_generate(res_folder, best_thred, sigma, size, visualize=True, out_folder='results/mst_volNet/', 418 | mesh_folder='model_resource_data/obj_fixed/'): 419 | ''' 420 | Generate skeleton as a MST problem. 421 | :param res_folder: folde that contains predicted joints, bones, voxelized input and 422 | transformation information between vox-space coordinates and original (cartesian) coordinates. 423 | :param best_thred: minimal threshold to NMS 424 | :param sigma: sigma for soft NMS 425 | :param size: gaussian kernel size for soft NMS 426 | :param visualize: visualize result or not with Open3D library 427 | :param out_folder: folder to output final results 428 | ''' 429 | joint_pred_list = glob.glob(res_folder + 'joint_pred_*.npy') 430 | for i in range(len(joint_pred_list)): 431 | joint_pred_file = joint_pred_list[i] 432 | model_id = joint_pred_file.split('_')[-1][:-4] 433 | print(model_id) 434 | joint_pred = np.load(joint_pred_file) 435 | bone_pred = np.load(joint_pred_file.replace('joint_', 'bone_')) 436 | input = np.load(joint_pred_file.replace('joint_pred_', 'input_')).astype(np.float32) 437 | erode_input = binary_erosion(input, structure=None, iterations=1).astype(np.float32) 438 | mesh = Mesh_obj(os.path.join(mesh_folder, '{}.obj'.format(model_id))) 439 | 440 | joint_pred = np.clip(joint_pred, 0.0, 1.0) 441 | bone_pred = np.clip(bone_pred, 0.0, 1.0) 442 | joint_pred = joint_pred * erode_input 443 | bone_pred = bone_pred * erode_input 444 | ts_filename = joint_pred_file.replace('joint_pred_', 'ts_').replace('.npy', '.txt') 445 | trans, scl, center_trans = load_ts(ts_filename) 446 | 447 | joint_pred, bone_pred = flip(joint_pred, bone_pred, trans, center_trans, scl) 448 | joint_pred = nms_soft(joint_pred, best_thred, size=size, sigma=sigma) 449 | joint_pred, _ = flip(joint_pred, None, trans, center_trans, scl) 450 | 451 | joint_pos = np.argwhere(joint_pred > 1e-10) 452 | joint_pos_cartesian = joint_pos - 3 + center_trans 453 | joint_pos_cartesian = joint_pos_cartesian / 82 * scl + trans 454 | #contain_flag = filter_joints_between_legs(joint_pos_cartesian, mesh.v) 455 | #joint_pos_cartesian = joint_pos_cartesian[contain_flag] 456 | #joint_pos = joint_pos[contain_flag] 457 | 458 | # make extracted joints symmetric 459 | joint_pos_cartesian_left = joint_pos_cartesian[joint_pos_cartesian[:, 0] < -2e-2] 460 | joint_pos_cartesian_middle = joint_pos_cartesian[np.abs(joint_pos_cartesian[:, 0]) < 2e-2] 461 | joint_pos_cartesian_right = joint_pos_cartesian_left * np.array([[-1, 1, 1]]) 462 | joint_pos_cartesian = np.concatenate((joint_pos_cartesian_left, joint_pos_cartesian_middle, joint_pos_cartesian_right), axis=0) 463 | joint_pos = np.round((joint_pos_cartesian - trans) / scl * 82 - center_trans + 3).astype(int) 464 | 465 | if len(joint_pos) in [0, 1, 2]: 466 | print('too few joints extracted. Try to reduce the threshold.') 467 | continue 468 | if len(joint_pos) > 100: 469 | print('too many joints extracted. Try to increase the threshold.') 470 | continue 471 | joint_pos_cartesian, cost_matrix, parent, key = getMSTcost(joint_pos, joint_pos_cartesian, joint_pred, bone_pred, input) 472 | 473 | skel = Skel() 474 | for i in range(len(parent)): 475 | if parent[i] == -1: 476 | skel.root = TreeNode('joint_{}'.format(i), tuple(joint_pos_cartesian[i])) 477 | break 478 | loadSkel_recur(skel.root, i, joint_pos_cartesian, parent) 479 | 480 | if joint_pos_cartesian is not None: 481 | if visualize: 482 | img = show_obj_skel(mesh, skel.root) 483 | cv2.imwrite(os.path.join(out_folder, 'mst_{0}.jpg').format(model_id), img[:, 300:-300, ::-1]) 484 | if parent: 485 | skel.save(os.path.join(out_folder, 'mst_{0}.txt').format(model_id)) 486 | 487 | 488 | if __name__ == '__main__': 489 | mesh_folder = 'model_resource_data/obj/' 490 | folder_name = 'volNet/' 491 | out_folder = 'results/mst_{0}'.format(folder_name) 492 | if not os.path.isdir(out_folder): 493 | os.mkdir(out_folder) 494 | best_thred = 0.02 495 | print(folder_name, best_thred) 496 | mst_generate('results/{0}'.format(folder_name), 497 | best_thred=best_thred, sigma=15, size=11, visualize=True, out_folder=out_folder, 498 | mesh_folder=mesh_folder) 499 | -------------------------------------------------------------------------------- /run_trainval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -* 3 | 4 | """ 5 | This script is used for training and inferencing our deep neural network. 6 | """ 7 | 8 | import os 9 | import shutil 10 | import argparse 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.tensorboard import SummaryWriter 18 | 19 | import datasets 20 | import models3D 21 | 22 | from util.os_utils import mkdir_p, isfile, isdir 23 | from util.train_utils import save_checkpoint, adjust_learning_rate, AverageMeter 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | def main(args): 29 | global device 30 | lowest_loss = 1e40 31 | 32 | # create checkpoint dir and log dir 33 | if not isdir(args.checkpoint): 34 | print("Create new checkpoint folder " + args.checkpoint) 35 | mkdir_p(args.checkpoint) 36 | if not args.resume: 37 | if isdir(args.logdir): 38 | shutil.rmtree(args.logdir) 39 | mkdir_p(args.logdir) 40 | if not args.evaluate: 41 | logger = SummaryWriter(args.logdir) 42 | 43 | # create model 44 | print("==> creating model") 45 | input_channel = 1 46 | if 'curvature' in args.input_feature: 47 | input_channel += 2 48 | if 'vertex_kde' in args.input_feature: 49 | input_channel += 1 50 | if 'sd' in args.input_feature: 51 | input_channel += 1 52 | n_stack = args.num_stack 53 | 54 | model = models3D.__dict__[args.arch](input_channels=input_channel, n_stack=n_stack) 55 | 56 | model.to(device) 57 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 58 | 59 | # optionally resume from a checkpoint 60 | if args.resume: 61 | if isfile(args.resume): 62 | print("=> loading checkpoint '{}'".format(args.resume)) 63 | checkpoint = torch.load(args.resume) 64 | args.start_epoch = checkpoint['epoch'] 65 | lowest_loss = checkpoint['lowest_loss'] 66 | model.load_state_dict(checkpoint['state_dict']) 67 | optimizer.load_state_dict(checkpoint['optimizer']) 68 | print("=> loaded checkpoint '{}' (epoch {})" 69 | .format(args.resume, checkpoint['epoch'])) 70 | else: 71 | print("=> no checkpoint found at '{}'".format(args.resume)) 72 | 73 | if torch.cuda.device_count() > 1: 74 | print('Using', torch.cuda.device_count(), 'GPUs') 75 | model = nn.DataParallel(model) 76 | model.to(device) 77 | 78 | cudnn.benchmark = True 79 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 80 | train_loader = torch.utils.data.DataLoader(datasets.Heatmap3D_sdf(args.data_path, args.json_file, 'train', args.kde, args.input_feature), 81 | batch_size=args.train_batch, shuffle=True, num_workers=args.workers, pin_memory=True) 82 | val_loader = torch.utils.data.DataLoader(datasets.Heatmap3D_sdf(args.data_path, args.json_file, 'val', args.kde, args.input_feature), 83 | batch_size=args.test_batch, shuffle=True, num_workers=args.workers, pin_memory=True) 84 | test_loader = torch.utils.data.DataLoader(datasets.Heatmap3D_sdf(args.data_path, args.json_file, 'test', args.kde, args.input_feature), 85 | batch_size=args.test_batch, shuffle=True, num_workers=args.workers, pin_memory=True) 86 | if args.evaluate: 87 | print('\nEvaluation only') 88 | test_loss, test_loss_joint, test_loss_bone = test(test_loader, model, args) 89 | print('test loss: ', test_loss, 'test_loss_joint: ', test_loss_joint, 'test_loss_bone: ', test_loss_bone) 90 | #args.output_dir = args.output_dir + '_val' 91 | #val_loss, val_loss_joint, val_loss_bone = test(val_loader, model, args) 92 | #print('val loss: ', val_loss, 'val_loss_joint: ', val_loss_joint, 'val_loss_bone: ', val_loss_bone) 93 | return 94 | 95 | lr = args.lr 96 | for epoch in range(args.start_epoch, args.epochs): 97 | lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma) 98 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 99 | train_loss, train_loss_joint, train_loss_bone = train(train_loader, model, optimizer) 100 | valid_loss, val_loss_joint, val_loss_bone = validate(val_loader, model) 101 | print(valid_loss, val_loss_joint, val_loss_bone) 102 | test_loss, test_loss_joint, test_loss_bone = test(test_loader, model, args) 103 | print(test_loss, test_loss_joint, test_loss_bone) 104 | 105 | # remember best acc and save checkpoint 106 | is_best = valid_loss < lowest_loss 107 | lowest_loss = min(valid_loss, lowest_loss) 108 | if torch.cuda.device_count() > 1: 109 | states = { 110 | 'epoch': epoch + 1, 111 | 'state_dict': model.module.state_dict(), 112 | 'lowest_loss': lowest_loss, 113 | 'optimizer': optimizer.state_dict(), 114 | } 115 | else: 116 | states = { 117 | 'epoch': epoch + 1, 118 | 'state_dict': model.state_dict(), 119 | 'lowest_loss': lowest_loss, 120 | 'optimizer': optimizer.state_dict(), 121 | } 122 | 123 | save_checkpoint(states, is_best, checkpoint=args.checkpoint) 124 | 125 | info = {'train_loss': train_loss, 'train_loss_joint': train_loss_joint, 'train_loss_bone': train_loss_bone, 126 | 'val_loss': valid_loss, 'val_loss_joint': val_loss_joint, 'val_loss_bone': val_loss_bone, 127 | 'test_loss': test_loss, 'test_loss_joint': test_loss_joint, 'test_loss_bone': test_loss_bone} 128 | for tag, value in info.items(): 129 | logger.add_scalar(tag, value, epoch+1) 130 | 131 | 132 | def train(train_loader, model, optimizer): 133 | global device 134 | model.train() # switch to train mode 135 | losses = AverageMeter() 136 | losses_joint = AverageMeter() 137 | losses_bone = AverageMeter() 138 | for i, (inputs, mask, target_joint, target_bone, meta) in enumerate(train_loader): 139 | inputs_var = inputs.to(device) 140 | target_joint_var = target_joint[:, None, :, :, :].to(device) 141 | target_bone_var = target_bone[:, None, :, :, :].to(device) 142 | mask_var = mask[:, None, :, :, :].to(device) 143 | 144 | (score_map_joint, score_map_bone) = model(inputs_var, meta['min_5_fs'].view(-1, 1, 1, 1, 1).float().to(device)) 145 | loss_joint, loss_bone = 0.0, 0.0 146 | for n_stack in range(len(score_map_joint)): 147 | loss_joint += F.binary_cross_entropy_with_logits(score_map_joint[n_stack], target_joint_var, 148 | weight=mask_var, reduction='sum') / mask_var.sum() 149 | loss_bone += F.binary_cross_entropy_with_logits(score_map_bone[n_stack], target_bone_var, 150 | weight=mask_var, reduction='sum') / mask_var.sum() 151 | loss = loss_joint + loss_bone 152 | 153 | #record loss 154 | losses_joint.update(loss_joint.item()) 155 | losses_bone.update(loss_bone.item()) 156 | losses.update(loss.item()) 157 | optimizer.zero_grad() 158 | 159 | loss.backward() 160 | optimizer.step() 161 | 162 | return losses.avg, losses_joint.avg, losses_bone.avg 163 | 164 | 165 | def validate(val_loader, model): 166 | global device 167 | losses = AverageMeter() 168 | losses_joint = AverageMeter() 169 | losses_bone = AverageMeter() 170 | model.eval() # switch to test mode 171 | with torch.no_grad(): 172 | for i, (inputs, mask, target_joint, target_bone, meta) in enumerate(val_loader): 173 | inputs_var = inputs.to(device) 174 | target_joint_var = target_joint[:, None, :, :, :].to(device) 175 | target_bone_var = target_bone[:, None, :, :, :].to(device) 176 | mask_var = mask[:, None, :, :, :].to(device) 177 | 178 | (score_map_joint, score_map_bone) = model(inputs_var, meta['min_5_fs'].view(-1, 1, 1, 1, 1).float().to(device)) 179 | 180 | loss_joint, loss_bone = 0.0, 0.0 181 | for n_stack in range(len(score_map_joint)): 182 | loss_joint += F.binary_cross_entropy_with_logits(score_map_joint[n_stack], target_joint_var, 183 | weight=mask_var, reduction='sum') / mask_var.sum() 184 | loss_bone += F.binary_cross_entropy_with_logits(score_map_bone[n_stack], target_bone_var, 185 | weight=mask_var, reduction='sum') / mask_var.sum() 186 | 187 | loss = loss_joint + loss_bone 188 | # record loss 189 | losses_joint.update(loss_joint.item()) 190 | losses_bone.update(loss_bone.item()) 191 | losses.update(loss.item()) 192 | 193 | return losses.avg, losses_joint.avg, losses_bone.avg 194 | 195 | 196 | def test(test_loader, model, args): 197 | global device 198 | losses = AverageMeter() 199 | losses_joint = AverageMeter() 200 | losses_bone = AverageMeter() 201 | model.eval() # switch to test mode 202 | with torch.no_grad(): 203 | for i, (inputs, mask, target_joint, target_bone, meta) in enumerate(test_loader): 204 | inputs_var = inputs.to(device) 205 | target_joint_var = target_joint[:, None, :, :, :].to(device) 206 | target_bone_var = target_bone[:, None, :, :, :].to(device) 207 | mask_var = mask[:, None, :, :, :].to(device) 208 | 209 | (score_map_joint_raw, score_map_bone_raw) = model(inputs_var, meta['min_5_fs'].view(-1, 1, 1, 1, 1).float().to(device)) 210 | 211 | if args.evaluate: 212 | output_folder = os.path.join('results/', args.output_dir) 213 | if not os.path.isdir(output_folder): 214 | mkdir_p(output_folder) 215 | score_map_joint = torch.sigmoid(score_map_joint_raw[-1]) 216 | score_map_bone = torch.sigmoid(score_map_bone_raw[-1]) 217 | score_map_joint *= mask_var 218 | score_map_bone *= mask_var 219 | score_map_joint = score_map_joint.cpu().data.numpy().squeeze() 220 | score_map_bone = score_map_bone.cpu().data.numpy().squeeze() 221 | 222 | inputs_bin = (inputs[:, 0, ...] < 0) 223 | inputs_bin = inputs_bin.squeeze() 224 | for id in range(len(inputs)): 225 | name_id = meta['name'][id] 226 | np.save(os.path.join(output_folder, 'input_' + name_id + '.npy'), inputs_bin[id]) 227 | np.save(os.path.join(output_folder, 'joint_pred_' + name_id + '.npy'), score_map_joint[id]) 228 | np.save(os.path.join(output_folder, 'bone_pred_'+ name_id +'.npy'), score_map_bone[id]) 229 | with open(os.path.join(output_folder,'ts_' + name_id +'.txt'), 'w') as f_ts: 230 | f_ts.write('{0} {1} {2}\n'.format(meta['center_trans'][0][id].item(), 231 | meta['center_trans'][1][id].item(), 232 | meta['center_trans'][2][id].item())) 233 | f_ts.write('{0} {1} {2}\n'.format(meta['translate'][0][id].item(), 234 | meta['translate'][1][id].item(), 235 | meta['translate'][2][id].item())) 236 | f_ts.write('{0}\n'.format(meta['scale'][id].item())) 237 | 238 | loss_joint = F.binary_cross_entropy_with_logits(score_map_joint_raw[-1], target_joint_var, 239 | weight=mask_var, reduction='sum') / mask_var.sum() 240 | loss_bone = F.binary_cross_entropy_with_logits(score_map_bone_raw[-1], target_bone_var, 241 | weight=mask_var, reduction='sum') / mask_var.sum() 242 | loss = loss_joint + loss_bone 243 | 244 | # # record loss 245 | losses_joint.update(loss_joint.item()) 246 | losses_bone.update(loss_bone.item()) 247 | losses.update(loss.item()) 248 | 249 | return losses.avg, losses_joint.avg, losses_bone.avg 250 | 251 | 252 | if __name__ == '__main__': 253 | parser = argparse.ArgumentParser(description='PyTorch 3D Heatmap Training') 254 | # Training 255 | parser.add_argument('--arch', '-a', metavar='ARCH', default='v2v_hg', help='model architecture.') 256 | parser.add_argument('--num_stack', metavar='NT', type=int, default=4, help='number of hourglass module.') 257 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='epoch number') 258 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay') 259 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 260 | parser.add_argument('-j', '--workers', default=3, type=int, metavar='N', help='number of data loading workers') 261 | parser.add_argument('--epochs', default=50, type=int, metavar='N', help='number of total epochs to run') 262 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, metavar='LR', help='initial learning rate') 263 | parser.add_argument('--schedule', type=int, nargs='+', default=[], help='Decrease learning rate at these epochs.') 264 | parser.add_argument('--kde', type=float, default=10, 265 | help='sigma of gaussian around each surface vertex is kde times average edge length') 266 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model') 267 | parser.add_argument('--train-batch', default=2, type=int, metavar='N', help='train batchsize') 268 | parser.add_argument('--test-batch', default=2, type=int, metavar='N', help='test batchsize') 269 | parser.add_argument('-c', '--checkpoint', default='checkpoints/volNet', type=str, metavar='PATH', 270 | help='path to save checkpoint') 271 | parser.add_argument('--logdir', default='logs/volNet', type=str, metavar='LOG', help='directory to save logs') 272 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint') 273 | parser.add_argument('--data_path', type=str, help='h5 data file with all data', 274 | default='model_resource_data/model-resource-volumetric.h5') 275 | parser.add_argument('--json_file', type=str, help='annotation json file', 276 | default='model_resource_data/model-resource-volumetric.json') 277 | parser.add_argument('--output_dir', type=str, default='volNet', help='prediction output folder') 278 | parser.add_argument('--input_feature', type=str, nargs='+', default=['curvature', 'sd', 'vertex_kde'], 279 | help='input feature name list (curvature, sd, vertex_kde)') 280 | 281 | print(parser.parse_args()) 282 | main(parser.parse_args()) 283 | 284 | -------------------------------------------------------------------------------- /util/__pycache__/binvox_rw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/binvox_rw.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/open3d_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/open3d_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/os_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/os_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/train_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/train_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/tree_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/tree_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/vox_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/__pycache__/vox_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/binvox_rw.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | 18 | 19 | import numpy as np 20 | import struct 21 | 22 | 23 | class Voxels(object): 24 | """ Holds a binvox model. 25 | data is either a three-dimensional numpy boolean array (dense representation) 26 | or a two-dimensional numpy float array (coordinate representation). 27 | 28 | dims, translate and scale are the model metadata. 29 | 30 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 31 | 32 | scale and translate relate the voxels to the original model coordinates. 33 | 34 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 35 | 36 | x_n = (i+.5)/dims[0] 37 | y_n = (j+.5)/dims[1] 38 | z_n = (k+.5)/dims[2] 39 | x = scale*x_n + translate[0] 40 | y = scale*y_n + translate[1] 41 | z = scale*z_n + translate[2] 42 | 43 | """ 44 | 45 | def __init__(self, data, dims, translate, scale, axis_order): 46 | self.data = data 47 | self.dims = dims 48 | self.translate = translate 49 | self.scale = scale 50 | assert (axis_order in ('xzy', 'xyz')) 51 | self.axis_order = axis_order 52 | 53 | def clone(self): 54 | data = self.data.copy() 55 | dims = self.dims[:] 56 | translate = self.translate[:] 57 | return Voxels(data, dims, translate, self.scale, self.axis_order) 58 | 59 | def write(self, fp): 60 | write(self, fp) 61 | 62 | def read_header(fp): 63 | """ Read binvox header. Mostly meant for internal use. 64 | """ 65 | line = fp.readline().strip() 66 | if not line.startswith(b'#binvox'): 67 | raise IOError('Not a binvox file') 68 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 69 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 70 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 71 | line = fp.readline() 72 | 73 | return dims, translate, scale 74 | 75 | def read_as_3d_array(fp, fix_coords=True): 76 | """ Read binary binvox format as array. 77 | 78 | Returns the model with accompanying metadata. 79 | 80 | Voxels are stored in a three-dimensional numpy array, which is simple and 81 | direct, but may use a lot of memory for large models. (Storage requirements 82 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 83 | boolean arrays use a byte per element). 84 | 85 | Doesn't do any checks on input except for the '#binvox' line. 86 | """ 87 | dims, translate, scale = read_header(fp) 88 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 89 | # if just using reshape() on the raw data: 90 | # indexing the array as array[i,j,k], the indices map into the 91 | # coords as: 92 | # i -> x 93 | # j -> z 94 | # k -> y 95 | # if fix_coords is true, then data is rearranged so that 96 | # mapping is 97 | # i -> x 98 | # j -> y 99 | # k -> z 100 | values, counts = raw_data[::2], raw_data[1::2] 101 | data = np.repeat(values, counts).astype(np.bool) 102 | data = data.reshape(dims) 103 | if fix_coords: 104 | # xzy to xyz TODO the right thing 105 | data = np.transpose(data, (0, 2, 1)) 106 | axis_order = 'xyz' 107 | else: 108 | axis_order = 'xzy' 109 | return Voxels(data, dims, translate, scale, axis_order) 110 | 111 | def read_as_coord_array(fp, fix_coords=True): 112 | """ Read binary binvox format as coordinates. 113 | 114 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 115 | 3 x N array where N is the number of nonzero voxels. Each column 116 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 117 | of the voxel. (The odd ordering is due to the way binvox format lays out 118 | data). Note that coordinates refer to the binvox voxels, without any 119 | scaling or translation. 120 | 121 | Use this to save memory if your model is very sparse (mostly empty). 122 | 123 | Doesn't do any checks on input except for the '#binvox' line. 124 | """ 125 | dims, translate, scale = read_header(fp) 126 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 127 | 128 | values, counts = raw_data[::2], raw_data[1::2] 129 | 130 | sz = np.prod(dims) 131 | index, end_index = 0, 0 132 | end_indices = np.cumsum(counts) 133 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 134 | 135 | values = values.astype(np.bool) 136 | indices = indices[values] 137 | end_indices = end_indices[values] 138 | 139 | nz_voxels = [] 140 | for index, end_index in zip(indices, end_indices): 141 | nz_voxels.extend(range(index, end_index)) 142 | nz_voxels = np.array(nz_voxels) 143 | # TODO are these dims correct? 144 | # according to docs, 145 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 146 | 147 | x = nz_voxels / (dims[0]*dims[1]) 148 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 149 | z = zwpy / dims[0] 150 | y = zwpy % dims[0] 151 | if fix_coords: 152 | data = np.vstack((x, y, z)) 153 | axis_order = 'xyz' 154 | else: 155 | data = np.vstack((x, z, y)) 156 | axis_order = 'xzy' 157 | 158 | #return Voxels(data, dims, translate, scale, axis_order) 159 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 160 | 161 | def dense_to_sparse(voxel_data, dtype=np.int): 162 | """ From dense representation to sparse (coordinate) representation. 163 | No coordinate reordering. 164 | """ 165 | if voxel_data.ndim!=3: 166 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 167 | return np.asarray(np.nonzero(voxel_data), dtype) 168 | 169 | def sparse_to_dense(voxel_data, dims, dtype=np.bool): 170 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 171 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 172 | if np.isscalar(dims): 173 | dims = [dims]*3 174 | dims = np.atleast_2d(dims).T 175 | # truncate to integers 176 | xyz = voxel_data.astype(np.int) 177 | # discard voxels that fall outside dims 178 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 179 | xyz = xyz[:,valid_ix] 180 | out = np.zeros(dims.flatten(), dtype=dtype) 181 | out[tuple(xyz)] = True 182 | return out 183 | 184 | #def get_linear_index(x, y, z, dims): 185 | #""" Assuming xzy order. (y increasing fastest. 186 | #TODO ensure this is right when dims are not all same 187 | #""" 188 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 189 | 190 | def bwrite(fp,s): 191 | fp.write(s.encode()) 192 | 193 | def write_pair(fp,state, ctr): 194 | fp.write(struct.pack('B',state)) 195 | fp.write(struct.pack('B',ctr)) 196 | 197 | def write(voxel_model, fp): 198 | """ Write binary binvox format. 199 | 200 | Note that when saving a model in sparse (coordinate) format, it is first 201 | converted to dense format. 202 | 203 | Doesn't check if the model is 'sane'. 204 | 205 | """ 206 | if voxel_model.data.ndim==2: 207 | # TODO avoid conversion to dense 208 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 209 | else: 210 | dense_voxel_data = voxel_model.data 211 | 212 | bwrite(fp, '#binvox 1\n') 213 | bwrite(fp, 'dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') 214 | bwrite(fp, 'translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') 215 | bwrite(fp, 'scale ' + str(voxel_model.scale) + '\n') 216 | bwrite(fp, 'data\n') 217 | if not voxel_model.axis_order in ('xzy', 'xyz'): 218 | raise ValueError('Unsupported voxel model axis order') 219 | 220 | if voxel_model.axis_order=='xzy': 221 | voxels_flat = dense_voxel_data.flatten() 222 | elif voxel_model.axis_order=='xyz': 223 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 224 | 225 | # keep a sort of state machine for writing run length encoding 226 | state = voxels_flat[0] 227 | ctr = 0 228 | for c in voxels_flat: 229 | if c==state: 230 | ctr += 1 231 | # if ctr hits max, dump 232 | if ctr==255: 233 | write_pair(fp, state, ctr) 234 | ctr = 0 235 | else: 236 | # if switch state, dump 237 | write_pair(fp, state, ctr) 238 | state = c 239 | ctr = 1 240 | # flush out remainders 241 | if ctr > 0: 242 | write_pair(fp, state, ctr) 243 | 244 | if __name__ == '__main__': 245 | import doctest 246 | doctest.testmod() 247 | -------------------------------------------------------------------------------- /util/open3d_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | 5 | def find_lines_from_tree(root, line_list, pos_list): 6 | if not root.children: 7 | return 8 | else: 9 | for ch in root.children: 10 | pos_list.append(list(root.pos)) 11 | pos_list.append(list(ch.pos)) 12 | line_list.append([len(pos_list)-2, len(pos_list)-1]) 13 | find_lines_from_tree(ch, line_list, pos_list) 14 | 15 | 16 | def drawCube(center, radius, color=[0.0,0.0,0.0]): 17 | mesh_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius) 18 | transform_mat = np.eye(4) 19 | transform_mat[0:3, -1] = center 20 | mesh_sphere.transform(transform_mat) 21 | mesh_sphere.paint_uniform_color(color) 22 | return mesh_sphere 23 | 24 | 25 | def show_obj_skel(mesh, root): 26 | vis = o3d.visualization.Visualizer() 27 | vis.create_window() 28 | 29 | # show obj mesh 30 | line_set_mesh = o3d.geometry.LineSet() 31 | line_set_mesh.points = o3d.utility.Vector3dVector(mesh.v) 32 | lines_mesh = np.concatenate((mesh.f[:, [0, 1]] - 1, mesh.f[:, [0, 2]] - 1, mesh.f[:, [1, 2]] - 1), axis=0) 33 | line_set_mesh.lines = o3d.utility.Vector2iVector(lines_mesh) 34 | colors = [[0.8, 0.8, 0.8] for i in range(len(lines_mesh))] 35 | line_set_mesh.colors = o3d.utility.Vector3dVector(colors) 36 | vis.add_geometry(line_set_mesh) 37 | 38 | # show skel 39 | line_list_skel = [] 40 | joint_pos_list = [] 41 | find_lines_from_tree(root, line_list_skel, joint_pos_list) 42 | line_set_skel = o3d.geometry.LineSet() 43 | for joint_pos in joint_pos_list: 44 | vis.add_geometry(drawCube(joint_pos, 0.007, color=[1.0,0.0,0.0])) 45 | 46 | line_set_skel.points = o3d.utility.Vector3dVector(joint_pos_list) 47 | line_set_skel.lines = o3d.utility.Vector2iVector(line_list_skel) 48 | colors = [[0.0, 0.0, 1.0] for i in range(len(line_list_skel))] 49 | line_set_skel.colors = o3d.utility.Vector3dVector(colors) 50 | vis.add_geometry(line_set_skel) 51 | 52 | vis.run() 53 | image = vis.capture_screen_float_buffer() 54 | vis.destroy_window() 55 | image = np.asarray(image) * 255 56 | image = image.astype(np.uint8) 57 | return image 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /util/os_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import errno 5 | 6 | def mkdir_p(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | 13 | def isfile(fname): 14 | return os.path.isfile(fname) 15 | 16 | def isdir(dirname): 17 | return os.path.isdir(dirname) 18 | 19 | def join(path, *paths): 20 | return os.path.join(path, *paths) -------------------------------------------------------------------------------- /util/rigging_parser/__pycache__/obj_parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/rigging_parser/__pycache__/obj_parser.cpython-36.pyc -------------------------------------------------------------------------------- /util/rigging_parser/__pycache__/skel_parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhan-xu/AnimSkelVolNet/fbc103d3194d6c14b6276fd5be004462e7c4722f/util/rigging_parser/__pycache__/skel_parser.cpython-36.pyc -------------------------------------------------------------------------------- /util/rigging_parser/obj_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | My simple obj file parser. 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | class Mesh_obj: 9 | def __init__(self, filename=None): 10 | self.v = [] 11 | self.vt = [] 12 | self.vn = [] 13 | self.f = [] 14 | self.hasTex = False 15 | self.hasNorm = False 16 | self.mtlfile = None 17 | self.materialList = [] 18 | if filename is not None: 19 | self.load(filename) 20 | 21 | def load(self, obj_filename): 22 | obj_file = open(obj_filename, 'r') 23 | line = obj_file.readline() 24 | while line: 25 | if len(line.split()) > 1 and line.split()[0] == 'v': 26 | self.v.append([float(line.split()[1]), float(line.split()[2]), float(line.split()[3])]) 27 | elif len(line.split()) > 1 and line.split()[0] == 'vt': 28 | self.vt.append([float(line.split()[1]), float(line.split()[2])]) 29 | elif len(line.split()) > 1 and line.split()[0] == 'vn': 30 | self.vn.append([float(line.split()[1]), float(line.split()[2]), float(line.split()[3])]) 31 | elif len(line.split()) > 1 and line.split()[0] == 'f': 32 | if '//' in line.split()[1] and len(line.split()[1].split('//'))==2: 33 | self.hasNorm = True 34 | cur_face = [] 35 | for ver in line.split()[1:]: 36 | cur_face.append(int(ver.split('//')[0])) 37 | self.f.append(cur_face) 38 | elif len(line.split()[1].split('/')) ==2: 39 | self.hasTex = True 40 | cur_face = [] 41 | for ver in line.split()[1:]: 42 | cur_face.append(list(map(int, ver.split('/')))) 43 | self.f.append(cur_face) 44 | elif len(line.split()[1].split('/')) ==3: 45 | self.hasTex = True 46 | self.hasNorm = True 47 | cur_face = [] 48 | for ver in line.split()[1:]: 49 | cur_face.append(ver.split('/')) 50 | self.f.append(cur_face) 51 | else: 52 | cur_face = [] 53 | for ver in line.split()[1:]: 54 | cur_face.append(int(ver)) 55 | self.f.append(cur_face) 56 | elif 'mtllib ' in line: 57 | self.mtlfile = line.split()[1] 58 | line = obj_file.readline() 59 | obj_file.close() 60 | # print len(self.v) 61 | self.v = np.stack(self.v, axis=0) 62 | if self.f: 63 | self.f = np.stack(self.f, axis=0) 64 | if self.vt: 65 | self.vt = np.stack(self.vt, axis=0) 66 | if self.vn: 67 | self.vn = np.stack(self.vn, axis=0) 68 | 69 | def write(self, obj_filename): 70 | f_out = open(obj_filename, 'w') 71 | f_out.write('#Export Obj file with Mesh_obj\n') 72 | #if self.mtlfile != '': 73 | # f_out.write('mtllib '+ self.mtlfile + '\n') 74 | for i in range(self.v.shape[0]): 75 | f_out.write('v {0} {1} {2}\n'.format(self.v[i,0],self.v[i,1],self.v[i,2])) 76 | if self.hasTex: 77 | for i in range(self.vt.shape[0]): 78 | f_out.write('vt {0} {1}\n'.format(self.vt[i, 0], self.vt[i, 1])) 79 | if self.hasNorm: 80 | for i in range(self.vn.shape[0]): 81 | f_out.write('vn {0} {1} {2}\n'.format(self.vn[i, 0], self.vn[i, 1], self.vn[i, 2])) 82 | for f in self.f: 83 | if self.hasTex and self.hasNorm: 84 | f_out.write('f') 85 | for i in range(len(f)): 86 | f_out.write(' {0}/{1}/{2}'.format(f[i][0],f[i][1],f[i][2])) 87 | f_out.write('\n') 88 | elif self.hasTex and not self.hasNorm: 89 | f_out.write('f') 90 | for i in range(len(f)): 91 | f_out.write(' {0}/{1}'.format(f[i][0], f[i][1])) 92 | f_out.write('\n') 93 | elif self.hasNorm and not self.hasTex: 94 | f_out.write('f') 95 | for i in range(len(f)): 96 | f_out.write(' {0}//{1}'.format(f[i][0], f[i][1])) 97 | f_out.write('\n') 98 | elif not self.hasTex and not self.hasNorm: 99 | f_out.write('f') 100 | for i in range(len(f)): 101 | f_out.write(' {0}'.format(f[i])) 102 | f_out.write('\n') 103 | -------------------------------------------------------------------------------- /util/rigging_parser/skel_parser.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A parser of my skeleton file 3 | Each row of skeleton file contains: hierarchical level, joint name, joint position (x, y, z), parent joint name 4 | ''' 5 | 6 | from util.tree_utils import TreeNode 7 | 8 | try: 9 | import Queue as Q # ver. < 3.0 10 | except ImportError: 11 | import queue as Q 12 | 13 | 14 | class Skel: 15 | def __init__(self, filename=None): 16 | self.root = None 17 | if filename is not None: 18 | self.load(filename) 19 | 20 | def load(self, filename): 21 | with open(filename, 'r') as fin: 22 | lines = fin.readlines() 23 | for li in lines: 24 | words = li.split() 25 | if words[5] == "None": 26 | self.root = TreeNode(words[1], (float(words[2]), float(words[3]), float(words[4]))) 27 | if len(words) == 7: 28 | has_order = True 29 | self.root.order = int(words[6]) 30 | else: 31 | has_order = False 32 | break 33 | self.loadSkel_recur(self.root, lines, has_order) 34 | 35 | def loadSkel_recur(self, node, lines, has_order): 36 | if has_order: 37 | ch_queue = Q.PriorityQueue() 38 | for li in lines: 39 | words = li.split() 40 | if words[5] == node.name: 41 | ch_queue.put((int(li.split()[6]), li)) 42 | while not ch_queue.empty(): 43 | item = ch_queue.get() 44 | # print(item[0]) 45 | li = item[1] 46 | ch_node = TreeNode(li.split()[1], (float(li.split()[2]), float(li.split()[3]), float(li.split()[4]))) 47 | ch_node.order = int(li.split()[6]) 48 | node.children.append(ch_node) 49 | ch_node.parent = node 50 | self.loadSkel_recur(ch_node, lines, has_order) 51 | else: 52 | for li in lines: 53 | words = li.split() 54 | if words[5] == node.name: 55 | ch_node = TreeNode(words[1], (float(words[2]), float(words[3]), float(words[4]))) 56 | node.children.append(ch_node) 57 | ch_node.parent = node 58 | self.loadSkel_recur(ch_node, lines, has_order) 59 | 60 | def save(self, filename): 61 | fout = open(filename, 'w') 62 | this_level = [self.root] 63 | hier_level = 1 64 | while this_level: 65 | next_level = [] 66 | for p_node in this_level: 67 | pos = p_node.pos 68 | parent = p_node.parent.name if p_node.parent is not None else 'None' 69 | if not p_node.order: 70 | line = '{0} {1} {2:8f} {3:8f} {4:8f} {5}\n'.format(hier_level, p_node.name, pos[0], pos[1], pos[2], parent) 71 | else: 72 | line = '{0} {1} {2:8f} {3:8f} {4:8f} {5} {6}\n'.format(hier_level, p_node.name, pos[0], pos[1], pos[2], 73 | parent, p_node.order) 74 | fout.write(line) 75 | for c_node in p_node.children: 76 | next_level.append(c_node) 77 | this_level = next_level 78 | hier_level += 1 79 | fout.close() 80 | 81 | def get_joint_pos(self): 82 | joint_pos = {} 83 | this_level = [self.root] 84 | while this_level: 85 | next_level = [] 86 | for node in this_level: 87 | joint_pos[node.name] = node.pos 88 | next_level += node.children 89 | this_level = next_level 90 | return joint_pos 91 | 92 | def normalize(self, scale, trans): 93 | this_level = [self.root] 94 | while this_level: 95 | next_level = [] 96 | for node in this_level: 97 | node.pos /= scale 98 | node.pos = (node.pos[0] - trans[0, 0], node.pos[1] - trans[0, 1], node.pos[2] - trans[0, 2]) 99 | for ch in node.children: 100 | next_level.append(ch) 101 | this_level = next_level 102 | -------------------------------------------------------------------------------- /util/train_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import shutil 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def to_numpy(tensor): 10 | if torch.is_tensor(tensor): 11 | return tensor.cpu().numpy() 12 | elif type(tensor).__module__ != 'numpy': 13 | raise ValueError("Cannot convert {} to numpy array" 14 | .format(type(tensor))) 15 | return tensor 16 | 17 | 18 | def im_to_numpy(img): 19 | img = to_numpy(img) 20 | img = np.transpose(img, (1, 2, 0)) # H*W*C 21 | return img 22 | 23 | 24 | def to_torch(ndarray): 25 | if type(ndarray).__module__ == 'numpy': 26 | return torch.from_numpy(ndarray) 27 | elif not torch.is_tensor(ndarray): 28 | raise ValueError("Cannot convert {} to torch tensor" 29 | .format(type(ndarray))) 30 | return ndarray 31 | 32 | 33 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None): 34 | filepath = os.path.join(checkpoint, filename) 35 | torch.save(state, filepath) 36 | 37 | if snapshot and state['epoch'] % snapshot == 0: 38 | shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))) 39 | 40 | if is_best: 41 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 42 | 43 | 44 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): 45 | """Sets the learning rate to the initial LR decayed by schedule""" 46 | if epoch in schedule: 47 | lr *= gamma 48 | for param_group in optimizer.param_groups: 49 | param_group['lr'] = param_group['lr']*gamma 50 | return lr 51 | 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | def __init__(self): 56 | self.reset() 57 | 58 | def reset(self): 59 | self.val = 0.0 60 | self.avg = 0.0 61 | self.sum = 0.0 62 | self.count = 0.0 63 | 64 | def update(self, val, n=1): 65 | self.val = val 66 | self.sum += val * n 67 | self.count += n 68 | self.avg = self.sum / self.count 69 | 70 | def accumulate(self, val, n=1): 71 | self.val = val 72 | self.sum += val 73 | self.count += n 74 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /util/tree_utils.py: -------------------------------------------------------------------------------- 1 | class Node(object): 2 | def __init__(self, name, pos): 3 | self.name = name 4 | self.pos = pos 5 | 6 | 7 | class TreeNode(Node): 8 | def __init__(self, name, pos): 9 | super(TreeNode, self).__init__(name, pos) 10 | self.children = [] 11 | self.parent = None 12 | self.order = None 13 | self.is_middle = False 14 | -------------------------------------------------------------------------------- /util/vox_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy as np 3 | 4 | 5 | def Cartesian2Voxcoord(v, translate, scale, resolution=82): 6 | vc = (v - translate) / scale * resolution 7 | vc = np.round(vc).astype(int) 8 | return vc[0], vc[1], vc[2] 9 | 10 | 11 | def Voxcoord2Cartesian(vc, translate, scale, resolution=82): 12 | v = vc / resolution * scale + translate 13 | return v[0], v[1], v[2] 14 | --------------------------------------------------------------------------------