├── models ├── __init__.py ├── conv │ ├── __init__.py │ ├── message_passing.py │ └── cheb_conv.py ├── inits.py ├── arap.py ├── saldnet.py ├── meshnet.py ├── meshnet_base.py └── asap.py ├── utils ├── __init__.py ├── time_utils.py ├── read.py ├── scheduler_utils.py ├── diff_operators.py ├── utils.py ├── writer.py ├── ddp_utils.py ├── geom_utils.py ├── implicit_utils.py └── mesh_sampling.py ├── datasets ├── __init__.py ├── smal.py ├── dfaust.py └── dfaust_hybrid.py ├── registration_dfaust ├── multiCorres_sync_dfaust1k │ ├── Params.m │ ├── io │ │ ├── write_ply.m │ │ ├── get_directory.m │ │ ├── load_folder.m │ │ ├── write_obj.m │ │ ├── read_ply.m │ │ ├── read_obj.m │ │ └── plyread.m │ ├── geodesic │ │ ├── fastmarchmex.mexa64 │ │ ├── fastmarchmex.mexw32 │ │ ├── fastmarchmex.mexw64 │ │ ├── fastmarchmex.mexmaci64 │ │ ├── compute_dist_matrix.m │ │ └── read_obj_nm.m │ ├── batch_icp.sh │ ├── condor.sh │ ├── multi_condor │ │ └── gen_script.sh │ ├── icp_interpolation.m │ ├── compute_GDM.m │ ├── embedded_deformation.m │ ├── main.m │ └── non_rigid_icp2.m ├── preprocess_registration_data.py ├── gen_edges.py └── gen_graph.py ├── .gitignore ├── interp ├── condor.sh └── batch_interp.sh ├── scripts ├── slurm_titans.sh ├── dist_train.sh └── slurm_titans_ddp.sh ├── pyutils.py ├── config ├── dfaust │ ├── ivae_dfaustJSM1k.yaml │ └── admesh_dfaustJSM1k.yaml └── smal │ ├── ivae_smalJSM.yaml │ └── admesh_smalJSM.yaml └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/Params.m: -------------------------------------------------------------------------------- 1 | classdef Params 2 | properties 3 | lambda 4 | beta 5 | end 6 | end -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.json 3 | __pycache__ 4 | */__pycache__ 5 | */*/__pycache__ 6 | registration/multiCorres_sync_dfaust1k/multi_condor/* 7 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/write_ply.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanghtr/GenCorres/HEAD/registration_dfaust/multiCorres_sync_dfaust1k/io/write_ply.m -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/get_directory.m: -------------------------------------------------------------------------------- 1 | function [path] = get_directory(path) 2 | if not(isfolder(path)) 3 | mkdir(path) 4 | end 5 | end -------------------------------------------------------------------------------- /models/conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .message_passing import MessagePassing 2 | from .cheb_conv import ChebConv 3 | 4 | 5 | __all__ = [ 6 | 'MessagePassing', 7 | 'ChebConv', 8 | ] 9 | 
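A minimal usage sketch for the re-exports above: the two names in __all__ form the public interface of models.conv, and downstream modules (e.g. models/meshnet.py in the tree) are expected to import them from the package rather than from the submodules. Illustration only, assuming the repo root is on PYTHONPATH:

# Sketch: consuming the interface fixed by models/conv/__init__.py.
from models.conv import MessagePassing, ChebConv
print(ChebConv.__module__)        # -> models.conv.cheb_conv
print(MessagePassing.__module__)  # -> models.conv.message_passing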
-------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanghtr/GenCorres/HEAD/registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexa64 -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexw32: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanghtr/GenCorres/HEAD/registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexw32 -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanghtr/GenCorres/HEAD/registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexw64 -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanghtr/GenCorres/HEAD/registration_dfaust/multiCorres_sync_dfaust1k/geodesic/fastmarchmex.mexmaci64 -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/batch_icp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | offset=$1 3 | interval=10 # 6213 4 | start_idx=$((offset * interval + 1)) 5 | echo "start_idx=${start_idx}, interval=${interval}" 6 | /lusr/share/software/matlab-r2018b/bin/matlab -nodesktop -nosplash -r "main ${start_idx} ${interval}" 7 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/load_folder.m: -------------------------------------------------------------------------------- 1 | function [shapes] = load_folder(foldername, numObjects) 2 | % 3 | for id = 1 : numObjects 4 | filename = sprintf("%s/interp_%d_30.ply", foldername, id-1); 5 | fprintf('%d : load %s \n', id-1, filename); 6 | shapes{id} = read_ply(filename); 7 | end 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/condor.sh: -------------------------------------------------------------------------------- 1 | +Group = "GRAD" 2 | +Project = "GRAPHICS_VISUALIZATION" 3 | +ProjectDescription = "corres" 4 | 5 | Universe = vanilla 6 | requirements = InMastodon 7 | Executable = ./batch_icp.sh 8 | Output = ./log/$(Process).out 9 | Error = ./log/$(Process).err 10 | Log = ./log/$(Process).log 11 | arguments = $(Process) 12 | 13 | Queue 622 14 | 15 | -------------------------------------------------------------------------------- /interp/condor.sh: -------------------------------------------------------------------------------- 1 | +Group = "GRAD" 2 | +Project = "GRAPHICS_VISUALIZATION" 3 | +ProjectDescription = "compute correspondence" 4 | +GPUJob = true 5 | Universe = vanilla 6 | requirements = Eldar 7 | request_GPUs = 1 8 | Executable = ./interp/batch_interp.sh 9 | Output = ./interp/log/$(Process).out 10 | Error = ./interp/log/$(Process).err 11 | Log = 
./interp/log/$(Process).log 12 | arguments = $(Process) 13 | Queue 346 14 | 15 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/write_obj.m: -------------------------------------------------------------------------------- 1 | function [] = write_obj(shape, filename) 2 | % 3 | f_id = fopen(filename, 'w'); 4 | for vId = 1 : size(shape.vertexPoss,2) 5 | pos = shape.vertexPoss(:, vId); 6 | fprintf(f_id, 'v %f %f %f\n', pos(1), pos(2), pos(3)); 7 | end 8 | for fId = 1 : size(shape.faceVIds,2) 9 | vids = shape.faceVIds(:,fId); 10 | fprintf(f_id, 'f %d %d %d\n', vids(1), vids(2), vids(3)); 11 | end 12 | fclose(f_id); -------------------------------------------------------------------------------- /scripts/slurm_titans.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | GPUS=$3 8 | PY_ARGS=${@:4} 9 | 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | 13 | srun -p ${PARTITION} \ 14 | --job-name=${JOB_NAME} \ 15 | --nodes=1 \ 16 | --ntasks-per-node=1 \ 17 | --gres=gpu:${GPUS} \ 18 | --mem-per-cpu=6G \ 19 | --cpus-per-task=$((CPUS_PER_TASK * GPUS)) \ 20 | --kill-on-bad-exit=1 \ 21 | ${SRUN_ARGS} \ 22 | ${PY_ARGS} 23 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/read_ply.m: -------------------------------------------------------------------------------- 1 | function [mesh] = read_ply(filename) 2 | % 3 | temp = plyread(filename); 4 | numV = length(temp.vertex.x); 5 | numF = length(temp.face.vertex_indices); 6 | % 7 | mesh.vertexPoss = zeros(3, numV); 8 | mesh.faceVIds = zeros(3, numF); 9 | % 10 | for id = 1 : numV 11 | mesh.vertexPoss(1, id) = temp.vertex.x(id); 12 | mesh.vertexPoss(2, id) = temp.vertex.y(id); 13 | mesh.vertexPoss(3, id) = temp.vertex.z(id); 14 | end 15 | for id = 1 : numF 16 | mesh.faceVIds(:, id) = temp.face.vertex_indices{id}'+1; 17 | end 18 | % 19 | -------------------------------------------------------------------------------- /scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | NGPUS=$1 5 | PY_ARGS=${@:2} 6 | 7 | while true 8 | do 9 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 10 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 11 | if [ "${status}" != "0" ]; then 12 | break; 13 | fi 14 | done 15 | echo $PORT 16 | 17 | # python -m torch.distributed.launch --nproc_per_node=${NGPUS} --rdzv_endpoint=localhost:${PORT} train.py --launcher pytorch ${PY_ARGS} 18 | python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port ${PORT} main.py --launcher pytorch ${PY_ARGS} 19 | 20 | -------------------------------------------------------------------------------- /interp/batch_interp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PATH=/scratch/cluster/yanght/Software/miniconda3/bin/:$PATH 3 | source /scratch/cluster/yanght/Software/miniconda3/etc/profile.d/conda.sh 4 | conda activate torch13; 5 | which python 6 | echo "setup finished!" 
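# NOTE: interp/condor.sh queues 346 of these jobs and passes the HTCondor $(Process)
# index as the single argument, so each run handles one chunk of latent-space edges:
# with interval=100, offset=$1 presumably maps to edge indices [offset*100, offset*100+99],
# which main.py receives through --parallel_idx / --parallel_interval below.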
7 | 8 | offset=$1 9 | interval=100 # 10 10 | start_idx=$((offset * interval)) 11 | echo "start_idx=${start_idx}, interval=${interval}" 12 | 13 | python main.py --config ./config/dfaust/ivae_dfaustJSM1k.yaml --mode interp_edges --rep sdf --continue_from 6499 --split test --edge_ids_path ./work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf/edge_ids/test_6499_edge_ids_K25.npy --parallel_idx ${offset} --parallel_interval ${interval} 14 | -------------------------------------------------------------------------------- /utils/time_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import time 4 | from contextlib import contextmanager 5 | 6 | @contextmanager 7 | def timer(task_name): 8 | t = time.perf_counter() 9 | try: 10 | yield 11 | finally: 12 | print(task_name, ": ", time.perf_counter() - t, " s.") 13 | 14 | def timing(f): 15 | def wrap(*args, **kwargs): 16 | t1 = time.perf_counter() 17 | ret = f(*args, **kwargs) 18 | t2 = time.perf_counter() 19 | print('{:s} function took {:.6f} s'.format(f.__name__, (t2 - t1))) 20 | return ret 21 | return wrap 22 | 23 | 24 | if __name__ == '__main__': 25 | with timer('test'): 26 | time.sleep(1) 27 | -------------------------------------------------------------------------------- /scripts/slurm_titans_ddp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | GPUS=$3 8 | PY_ARGS=${@:4} 9 | 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | 13 | while true 14 | do 15 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 16 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 17 | if [ "${status}" != "0" ]; then 18 | break; 19 | fi 20 | done 21 | echo $PORT 22 | 23 | srun -p ${PARTITION} \ 24 | --job-name=${JOB_NAME} \ 25 | --nodes=1 \ 26 | --ntasks-per-node=${GPUS} \ 27 | --gres=gpu:${GPUS} \ 28 | --mem-per-cpu=6G \ 29 | --cpus-per-task=${CPUS_PER_TASK} \ 30 | --kill-on-bad-exit=1 \ 31 | ${SRUN_ARGS} \ 32 | python -u main.py --launcher slurm --tcp_port $PORT ${PY_ARGS} 33 | -------------------------------------------------------------------------------- /utils/read.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | from torch_geometric.utils import to_undirected 4 | import openmesh as om 5 | 6 | def read_mesh(path, data_id, pose=None, return_face=False): 7 | mesh = om.read_trimesh(path) 8 | points = mesh.points() 9 | face = torch.from_numpy(mesh.face_vertex_indices()).T.type(torch.long) 10 | 11 | x = torch.tensor(points.astype('float32')) 12 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) 13 | edge_index = to_undirected(edge_index) 14 | if return_face==True and pose is not None: 15 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id, pose=pose) 16 | if pose is not None: 17 | return Data(x=x, edge_index=edge_index, data_id=data_id, pose=pose) 18 | if return_face==True: 19 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id) 20 | return Data(x=x, edge_index=edge_index, data_id=data_id) 21 | 22 | -------------------------------------------------------------------------------- /models/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is 
not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | if tensor is not None: 12 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | if tensor is not None: 18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 | _reset(nn) 43 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/read_obj.m: -------------------------------------------------------------------------------- 1 | function [Shape] = read_obj(filename) 2 | % 3 | numV = 100000; 4 | numF = 200000; 5 | vPos = zeros(3, numV); 6 | vFace = zeros(3, numF); 7 | vId = 0; 8 | fId = 0; 9 | f_id = fopen(filename, 'r'); 10 | while 1 11 | tline = fgetl(f_id); 12 | if tline == -1 13 | break; 14 | end 15 | if length(tline) < 2 16 | continue; 17 | end 18 | if tline(1) == 'v' && tline(2) == ' ' 19 | vId = vId + 1; 20 | p = str2num(tline(3:length(tline))); 21 | vPos(:, vId) = p'; 22 | end 23 | if tline(1) == 'f' && tline(2) == ' ' 24 | fId = fId + 1; 25 | v = str2num(tline(3:length(tline))); 26 | vFace(:, fId) = v'; 27 | end 28 | end 29 | fclose(f_id); 30 | % 31 | vPos = vPos(:,1:vId); 32 | vFace = vFace(:,1:fId); 33 | v1 = vFace(1,:); 34 | v2 = vFace(2,:); 35 | v3 = vFace(3,:); 36 | edges = [v1,v2,v3;v2,v3,v1]; 37 | G = sparse(edges(1,:), edges(2,:), ones(1,size(edges,2)), vId, vId); 38 | G = max(G,G'); 39 | [rows, cols, vals] = find(G); 40 | edges = [rows';cols']; 41 | Shape.vertexPoss = vPos; 42 | Shape.faceVIds = vFace; 43 | Shape.edges = edges; -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/multi_condor/gen_script.sh: -------------------------------------------------------------------------------- 1 | num_edges=31189 2 | num_machine=10 3 | interval=10 4 | num_queue=$((num_edges/num_machine/interval + 1)) 5 | num_jobs_per_mac=$((num_queue * interval)) 6 | for i in $(seq 1 ${num_machine}); do 7 | echo ${i}; 8 | mkdir log_${i}; 9 | echo -e "#!/bin/bash 10 | start=\$1 11 | offset=$((1 + (i-1) * num_jobs_per_mac)) 12 | start_idx=\$((start * ${interval} + offset)) 13 | echo \"start_idx=\${start_idx}, interval=${interval}\" 14 | /lusr/share/software/matlab-r2018b/bin/matlab -c /lusr/share/software/matlab-r2018a/licenses/network.lic -nodesktop -nosplash -r \"main \${start_idx} ${interval}\" " >> batch_icp_${i}.sh; 15 | 16 | echo -e "+Group = \"GRAD\" 17 | +Project = \"GRAPHICS_VISUALIZATION\" 18 | +ProjectDescription = \"corres\" 19 | Universe = vanilla 20 | requirements = InMastodon 21 | Executable = ./multi_condor/batch_icp_${i}.sh 22 | Output = ./multi_condor/log_${i}/\$(Process).out 23 | Error = ./multi_condor/log_${i}/\$(Process).err 24 | Log = ./multi_condor/log_${i}/\$(Process).log 25 | arguments = \$(Process) 26 | Queue ${num_queue}" >> condor_${i}.sh; 27 | chmod 777 batch_icp_${i}.sh; 28 | chmod 777 condor_${i}.sh; 29 | done 30 | 
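For concreteness, the partitioning arithmetic hard-coded in gen_script.sh works out as below (a sketch that just replays the bash integer arithmetic with the values above; not part of the repo):

# Replays the job-splitting arithmetic of multi_condor/gen_script.sh (illustration only).
num_edges, num_machine, interval = 31189, 10, 10
num_queue = num_edges // num_machine // interval + 1   # 312 condor jobs queued per machine
num_jobs_per_mac = num_queue * interval                # 3120 edge indices assigned per machine
for i in range(1, num_machine + 1):                    # 1-based machine index, as in the script
    offset = 1 + (i - 1) * num_jobs_per_mac            # first edge index handled by machine i
    # condor process p on machine i then runs: main <p*interval + offset> <interval>
assert num_machine * num_jobs_per_mac >= num_edges     # 31200 >= 31189, every edge is covered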
-------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/icp_interpolation.m: -------------------------------------------------------------------------------- 1 | %% 2 | function [mesh_src_def] = icp_interpolation(mesh_src, mesh_tgt, meshes_interp, params) 3 | num_interp = length(meshes_interp); 4 | poss_vec = mesh_src.vertexPoss; 5 | % nonrigid registration to interpolation 6 | for i_interp = 1 : num_interp 7 | fprintf('i_interp = %d\n', i_interp); 8 | lambda_w = params.lambda; 9 | beta = params.beta; 10 | outerIterMax = 8; 11 | innerIterMax = 1; 12 | poss_vec = non_rigid_icp2(mesh_src, meshes_interp{i_interp}, poss_vec, lambda_w, beta, outerIterMax, innerIterMax); 13 | % DEBUG 14 | % mesh_tmp = mesh_src; mesh_tmp.vertexPoss = poss_vec; 15 | % write_obj(mesh_tmp, ['/mnt/yanghaitao/Projects/GenCorres/gencorres/vis/registration/dfaust1k/mesh_def/tmp', num2str(i_interp), '.obj']); 16 | end 17 | % nonrigid registration to target 18 | lambda_w = params.lambda * 0.1; 19 | beta = params.beta; 20 | outerIterMax = 10; 21 | innerIterMax = 1; 22 | poss_vec = non_rigid_icp2(mesh_src, mesh_tgt, poss_vec, lambda_w, beta, outerIterMax, innerIterMax); 23 | 24 | mesh_src_def = mesh_src; 25 | mesh_src_def.vertexPoss = poss_vec; 26 | end 27 | 28 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/compute_GDM.m: -------------------------------------------------------------------------------- 1 | function compute_GDM(start_idx, interval) 2 | 3 | addpath('io/'); 4 | addpath("geodesic/"); 5 | meta_data_path = '../dfaust1k/meta_test_6499_K5.mat'; 6 | mesh_raw_dir = '../dfaust1k/mesh_raw/'; 7 | dist_mat_dir = get_directory('/media/yanghaitao/HaitaoYang/Graphicsai_Backup/mnt/yanghaitao/Dataset/DFAUST/dfaust1k/distance_matrix'); 8 | 9 | meta_data = load(meta_data_path); 10 | fids = meta_data.fids; 11 | num_meshes = size(fids, 1); 12 | template_idx = meta_data.template_idx; % starts from 0 13 | template_fid = meta_data.template_fid; 14 | 15 | if isstring(start_idx) || ischar(start_idx) 16 | start_idx = str2num(start_idx) 17 | end 18 | if isstring(interval) || ischar(interval) 19 | interval = str2num(interval) 20 | end 21 | end_idx = min(start_idx + interval - 1, num_meshes); 22 | for idx = start_idx : end_idx 23 | fprintf("start compute_dist_matrix: %d", idx); 24 | fid = strtrim(fids(idx, :)); % matlab starts from 1 25 | 26 | [X.vert, X.triv] = read_obj_nm([mesh_raw_dir, '/', fid, '.obj']); % NeuroMorph API to load obj 27 | X.vert = X.vert'; 28 | X.triv = X.triv'; 29 | X.n = size(X.vert, 1); 30 | X.m = size(X.triv, 1); 31 | 32 | D = compute_dist_matrix(X); 33 | D = single(D); 34 | 35 | save([dist_mat_dir, '/', num2str(idx - 1), '.mat'], 'D'); 36 | fprintf("finish compute_dist_matrix: %d (matlab), %d (python)\n", idx, idx-1); 37 | end 38 | 39 | end 40 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/compute_dist_matrix.m: -------------------------------------------------------------------------------- 1 | function D = compute_dist_matrix(M, samples) 2 | 3 | % Copyright (c) Facebook, Inc. and its affiliates. 4 | % 5 | % This source code is licensed under the MIT license found in the 6 | % LICENSE file in the root directory of this source tree. 
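% NOTE: this routine returns the pairwise geodesic distance matrix D between the sample
% vertices of mesh M, one fast-marching sweep per source vertex. When called with a single
% argument, as in compute_GDM.m above, all M.n vertices are sources and D is n-by-n.
% If the fastmarchmex binary is not on the path, precompiled mex files are downloaded first,
% and D is symmetrised as 0.5*(D + D') before returning.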
7 | 8 | M.n = size(M.vert, 1); 9 | M.m = size(M.triv, 1); 10 | 11 | if nargin == 1 || isempty(samples) 12 | samples = 1:M.n; 13 | end 14 | 15 | if ~exist('fastmarchmex') 16 | % Use precomputed binaries from https://github.com/abbasloo/dnnAuto/ 17 | base_url = "https://github.com/abbasloo/dnnAuto/raw/37ce4320bc90a75b07a7ec1d862484d6576cec4c/preprocessing/isc"; 18 | for ext = {'mexa64', 'mexmaci64', 'mexw32', 'mexw64'} 19 | urlwrite(... 20 | base_url + "/fastmarchmex." + ext, ... 21 | "fastmarchmex." + ext); 22 | end 23 | rehash 24 | end 25 | 26 | % Calls legacy fast marching code 27 | march = fastmarchmex('init', int32(M.triv - 1), double(M.vert(:, 1)), double(M.vert(:, 2)), double(M.vert(:, 3))); 28 | 29 | D = zeros(length(samples)); 30 | 31 | for i = 1:length(samples) 32 | source = inf(M.n, 1); 33 | source(samples(i)) = 0; 34 | d = fastmarchmex('march', march, double(source)); 35 | D(:, i) = d(samples); 36 | end 37 | 38 | fastmarchmex('deinit', march); 39 | 40 | % Ensures that the distance matrix is exactly symmetric 41 | D = 0.5 * (D + D'); 42 | end 43 | -------------------------------------------------------------------------------- /pyutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | import torch 5 | import random 6 | import importlib 7 | import numpy as np 8 | 9 | 10 | def set_random_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | 19 | def get_directory(path): 20 | if not os.path.exists(path): 21 | os.makedirs(path, exist_ok=True) 22 | return path 23 | 24 | 25 | def to_device(d, device): 26 | if isinstance(d, dict): 27 | return {k: to_device(v, device) for k, v in d.items()} 28 | elif isinstance(d, list): 29 | return [to_device(v, device) for v in d] 30 | else: 31 | assert(isinstance(d, torch.Tensor) or isinstance(d, np.ndarray)) 32 | return d.to(device) 33 | 34 | 35 | def to_numpy(d): 36 | if isinstance(d, torch.Tensor): 37 | return d.detach().cpu().numpy() 38 | if isinstance(d, dict): 39 | return {k: to_numpy(v) for k, v in d.items()} 40 | elif isinstance(d, list): 41 | return [to_numpy(v) for v in d] 42 | else: 43 | return d 44 | 45 | 46 | def load_module(module_name, class_name=None): 47 | module = importlib.import_module(module_name) 48 | if class_name is not None: 49 | return getattr(module, class_name) 50 | else: 51 | return module 52 | 53 | 54 | def update_config_from_args(config, args): 55 | for k, v in args.__dict__.items(): 56 | config[k] = v 57 | 58 | 59 | -------------------------------------------------------------------------------- /utils/scheduler_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | import torch 5 | 6 | 7 | class StepLRSchedule(): 8 | def __init__(self, lr_group_init, gamma, step_size): 9 | self.lr_group_init = lr_group_init 10 | self.gamma = gamma 11 | self.step_size = step_size 12 | 13 | def get_lr_group(self, epoch): 14 | e = epoch // self.step_size 15 | lr_group = [] 16 | for lr in self.lr_group_init: 17 | lr_group.append(lr * (self.gamma ** e)) 18 | return lr_group 19 | 20 | 21 | class MultiplicativeLRSchedule(): 22 | def __init__(self, lr_group_init, gammas, milestones): 23 | ''' 24 | Args: 25 | milestones: list, epoch in increasing order 26 | gammas: list 27 | ''' 28 | 
''' 28 |
assert(len(gammas)==len(milestones)) 29 | self.lr_group_init = lr_group_init 30 | self.gammas = gammas 31 | self.milestones = milestones 32 | 33 | def get_lr_group(self, epoch): 34 | factor = 1. 35 | for g, m in zip(self.gammas, self.milestones): 36 | if epoch >= m: 37 | factor *= g 38 | else: 39 | break 40 | 41 | lr_group = [] 42 | for lr in self.lr_group_init: 43 | lr_group.append(lr * factor) 44 | return lr_group 45 | 46 | 47 | def adjust_learning_rate(lr_scheduler, optimizer, epoch): 48 | lr_group = lr_scheduler.get_lr_group(epoch) 49 | for i, param_group in enumerate(optimizer.param_groups): 50 | param_group["lr"] = lr_group[i] 51 | return lr_group 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/geodesic/read_obj_nm.m: -------------------------------------------------------------------------------- 1 | function [vertex, faces, normal] = read_obj_nm(filename) 2 | 3 | % read_obj - load a .obj file. 4 | % 5 | % [vertex,face,normal] = read_obj(filename); 6 | % 7 | % faces : list of facesangle elements 8 | % vertex : node vertexinatates 9 | % normal : normal vector list 10 | % 11 | % Copyright (c) 2003 Gabriel Peyré 12 | 13 | fid = fopen(filename); 14 | 15 | if fid < 0 16 | error(['Cannot open ' filename '.']); 17 | end 18 | 19 | frewind(fid); 20 | a = fscanf(fid, '%c', 1); 21 | 22 | if strcmp(a, 'P') 23 | % This is the montreal neurological institute (MNI) specific ASCII facesangular mesh data structure. 24 | % For FreeSurfer software, a slightly different data input coding is 25 | % needed. It will be provided upon request. 26 | fscanf(fid, '%f', 5); 27 | n_points = fscanf(fid, '%i', 1); 28 | vertex = fscanf(fid, '%f', [3, n_points]); 29 | normal = fscanf(fid, '%f', [3, n_points]); 30 | n_faces = fscanf(fid, '%i', 1); 31 | fscanf(fid, '%i', 5 + n_faces); 32 | faces = fscanf(fid, '%i', [3, n_faces])' + 1; 33 | fclose(fid); 34 | return; 35 | end 36 | 37 | frewind(fid); 38 | vertex = []; 39 | faces = []; 40 | 41 | while 1 42 | s = fgetl(fid); 43 | 44 | if ~ischar(s), 45 | break; 46 | end 47 | 48 | if ~isempty(s) && strcmp(s(1), 'f') 49 | % face 50 | faces(:, end + 1) = sscanf(s(3:end), '%d %d %d'); 51 | end 52 | 53 | if ~isempty(s) && strcmp(s(1), 'v') 54 | % vertex 55 | vertex(:, end + 1) = sscanf(s(3:end), '%f %f %f'); 56 | end 57 | 58 | end 59 | 60 | fclose(fid); 61 | 62 | end 63 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/embedded_deformation.m: -------------------------------------------------------------------------------- 1 | function [deformed_mesh] = embedded_deformation(mesh, mesh_sim, mesh_sim_deformed) 2 | % Compute the deformation using mesh_sim and mesh_sim_deformed 3 | % Use the resulting deformation to deform mesh to obtain deformed_mesh 4 | [~, nIds] = compute_neighbors(mesh_sim); 5 | cur_trans = vertex_trans_fitting(mesh_sim.vertexPoss, mesh_sim_deformed.vertexPoss, nIds); 6 | [IDX, DIS] = knnsearch(mesh_sim.vertexPoss', mesh.vertexPoss', 'k', 20); 7 | sigma = median(DIS(:,2)); 8 | Weights = exp(-(DIS.*DIS)/2/sigma/sigma); 9 | sumOfWeights = sum(Weights')'; 10 | Weights = Weights./(sumOfWeights*ones(1,20)); 11 | % 12 | deformed_mesh = mesh; 13 | for id = 1 : size(mesh.vertexPoss, 2) 14 | tPos = zeros(3,1); 15 | sPos = mesh.vertexPoss(:, id); 16 | for i = 1 : size(IDX, 2) 17 | transId = IDX(id, i); 18 | w = Weights(id, i); 19 | A = cur_trans{transId}(:,1:3); 20 | b = 
cur_trans{transId}(:,4); 21 | tPos = tPos + w*(A*sPos + b); 22 | end 23 | deformed_mesh.vertexPoss(:, id) = tPos; 24 | end 25 | 26 | %% 27 | function [A, nIds] = compute_neighbors(mesh) 28 | numV = size(mesh.vertexPoss, 2); 29 | edges = [mesh.faceVIds(1, :), mesh.faceVIds(2, :), mesh.faceVIds(3, :); 30 | mesh.faceVIds(2, :), mesh.faceVIds(3, :), mesh.faceVIds(1, :)]; 31 | A = sparse(edges(1, :), edges(2, :), ones(1, size(edges, 2))); 32 | nIds = cell(1, numV); 33 | for vId = 1 : numV 34 | nIds{vId} = find(A(vId,:)); 35 | end 36 | 37 | %% 38 | function [vertex_trans] = vertex_trans_fitting(fixed_poss, opt_poss, nIds) 39 | % 40 | numV = size(fixed_poss, 2); 41 | for vId = 1 : numV 42 | ids = nIds{vId}; 43 | valence = length(ids); 44 | P = double(fixed_poss(:,ids) - fixed_poss(:,vId)*ones(1,valence)); 45 | Q = double(opt_poss(:,ids) - opt_poss(:,vId)*ones(1,valence)); 46 | A = (Q*P')*pinv(P*P'); 47 | b = opt_poss(:,vId) - A*fixed_poss(:,vId); 48 | vertex_trans{vId} = [A,b]; 49 | end 50 | % 51 | -------------------------------------------------------------------------------- /registration_dfaust/preprocess_registration_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import trimesh 5 | import argparse 6 | import numpy as np 7 | import scipy.io as sio 8 | 9 | 10 | def get_directory(path): 11 | if not os.path.exists(path): 12 | os.makedirs(path, exist_ok=True) 13 | return path 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | ##### convert python file to matlab file 19 | template_fid = '50009-running_on_spot-running_on_spot.000366' 20 | analysis_root = '../work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf/' 21 | dump_root = './dfaust1k/' 22 | epoch = 6499 23 | num_neighbors = 25 # 25 24 | 25 | edge_ids = np.load(f'{analysis_root}/edge_ids/test_{epoch}_edge_ids_K{num_neighbors}.npy') 26 | 27 | pkl = pickle.load(open(f"{analysis_root}/latents_all_test_{epoch}.pkl", 'rb')) 28 | N = len(pkl) 29 | fids = [pkl[i]['fid'] for i in range(N)] 30 | template_idx = fids.index(template_fid) 31 | assert(template_idx == 374) 32 | 33 | metadata = { 34 | 'edge_ids': edge_ids, 35 | 'fids': fids, 36 | 'template_idx': template_idx, 37 | 'template_fid': template_fid, 38 | } 39 | sio.savemat(f'{dump_root}/meta_test_{epoch}_K{num_neighbors}.mat', mdict=metadata) 40 | 41 | ##### aggregate dataset 42 | data_root = '/scratch/cluster/yanght/Dataset/Human/DFAUST/registrations/' 43 | mesh_raw_dir = get_directory( f'{dump_root}/mesh_raw/' ) 44 | mesh_sim_dir = get_directory( f'{dump_root}/mesh_sim/' ) 45 | 46 | for i, fid in enumerate(fids): 47 | print(i, fid) 48 | fname = '/'.join(fid.split('-')) + '.obj' 49 | fpath = f'{data_root}/{fname}' 50 | os.system(f"cp {fpath} {mesh_raw_dir}/{fid}.obj") 51 | 52 | mesh = trimesh.load(fpath, process=False, maintain_order=True) 53 | mesh_sim = mesh.simplify_quadratic_decimation(2000) 54 | mesh_sim.export(f"{mesh_sim_dir}/{fid}_sim.obj") 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /utils/diff_operators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import grad 3 | 4 | 5 | def hessian(y, x): 6 | ''' hessian of y wrt x 7 | y: shape (meta_batch_size, num_observations, channels) 8 | x: shape (meta_batch_size, num_observations, 2) 9 | ''' 10 | meta_batch_size, num_observations = y.shape[:2] 11 | grad_y = torch.ones_like(y[..., 0]).to(y.device) 12 | h = 
torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1], x.shape[-1]).to(y.device) 13 | for i in range(y.shape[-1]): 14 | # calculate dydx over batches for each feature value of y 15 | dydx = grad(y[..., i], x, grad_y, create_graph=True)[0] 16 | 17 | # calculate hessian on y for each x value 18 | for j in range(x.shape[-1]): 19 | h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0][..., :] 20 | 21 | status = 0 22 | if torch.any(torch.isnan(h)): 23 | status = -1 24 | return h, status 25 | 26 | 27 | def laplace(y, x): 28 | grad = gradient(y, x) 29 | return divergence(grad, x) 30 | 31 | 32 | def divergence(y, x): 33 | div = 0. 34 | for i in range(y.shape[-1]): 35 | div += grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1] 36 | return div 37 | 38 | 39 | def gradient(y, x, grad_outputs=None): 40 | if grad_outputs is None: 41 | grad_outputs = torch.ones_like(y) 42 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] 43 | return grad 44 | 45 | 46 | def jacobian(y, x): 47 | ''' jacobian of y wrt x ''' 48 | meta_batch_size, num_observations = y.shape[:2] 49 | jac = torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1]).to(y.device) # (meta_batch_size*num_points, 2, 2) 50 | for i in range(y.shape[-1]): 51 | # calculate dydx over batches for each feature value of y 52 | y_flat = y[...,i].reshape(-1, 1) 53 | jac[:, :, i, :] = grad(y_flat, x, torch.ones_like(y_flat), create_graph=True)[0] 54 | 55 | # status = 0 56 | # if torch.any(torch.isnan(jac)): 57 | # status = -1 58 | # return jac, status 59 | return jac 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from glob import glob 5 | from scipy.spatial import cKDTree as KDTree 6 | from matplotlib import cm 7 | 8 | 9 | 10 | def makedirs(folder): 11 | if not os.path.exists(folder): 12 | os.makedirs(folder) 13 | 14 | 15 | def to_sparse(spmat): 16 | return torch.sparse.FloatTensor( 17 | torch.LongTensor([spmat.tocoo().row, 18 | spmat.tocoo().col]), 19 | torch.FloatTensor(spmat.tocoo().data), torch.Size(spmat.tocoo().shape)) 20 | 21 | 22 | def to_edge_index(mat): 23 | return torch.LongTensor(np.vstack(mat.nonzero())) 24 | 25 | def get_colors_from_diff_pc(diff_pc, min_error, max_error): 26 | colors = np.zeros((diff_pc.shape[0],3)) 27 | mix = (diff_pc-min_error)/(max_error-min_error) 28 | mix = np.clip(mix, 0,1) #point_num 29 | cmap=cm.get_cmap('coolwarm') 30 | colors = cmap(mix)[:,0:3] 31 | return colors 32 | 33 | 34 | def save_pc_with_color_into_ply(template_ply, pc, color, fn): 35 | plydata=template_ply 36 | #pc = pc.copy()*pc_std + pc_mean 37 | plydata['vertex']['x']=pc[:,0] 38 | plydata['vertex']['y']=pc[:,1] 39 | plydata['vertex']['z']=pc[:,2] 40 | 41 | plydata['vertex']['red']=color[:,0] 42 | plydata['vertex']['green']=color[:,1] 43 | plydata['vertex']['blue']=color[:,2] 44 | 45 | plydata.write(fn) 46 | plydata['vertex']['red']=plydata['vertex']['red']*0+0.7*255 47 | plydata['vertex']['green']=plydata['vertex']['red']*0+0.7*255 48 | plydata['vertex']['blue']=plydata['vertex']['red']*0+0.7*255 49 | 50 | def compute_trimesh_chamfer(gt_points, gen_points): 51 | """ 52 | This function computes a symmetric chamfer distance, i.e. the sum of both chamfers. 
53 | gt_points: trimesh.points.PointCloud of just poins, sampled from the surface (see 54 | compute_metrics.ply for more documentation) 55 | gen_mesh: trimesh.base.Trimesh of output mesh from whichever autoencoding reconstruction 56 | method (see compute_metrics.py for more) 57 | """ 58 | # only need numpy array of points 59 | # gt_points_np = gt_points.vertices 60 | gt_points_np = gt_points.detach().cpu().numpy() 61 | gen_points_sampled = gen_points.detach().cpu().numpy() 62 | 63 | # one direction 64 | gen_points_kd_tree = KDTree(gen_points_sampled) 65 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points_np) 66 | gt_to_gen_chamfer = np.mean(np.square(one_distances)) 67 | 68 | # other direction 69 | gt_points_kd_tree = KDTree(gt_points_np) 70 | two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points_sampled) 71 | gen_to_gt_chamfer = np.mean(np.square(two_distances)) 72 | 73 | return gt_to_gen_chamfer + gen_to_gt_chamfer 74 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | 5 | from pyutils import * 6 | from loguru import logger 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | def get_state_info_str(state_info): 11 | state_info_str = f"epoch={state_info['epoch']:05d} it={state_info['b']:04d} " 12 | for k, v in state_info.items(): 13 | if 'loss' in k: 14 | state_info_str = f"{state_info_str}{k}={v:.6f} " 15 | for k, v in state_info.items(): 16 | if 'err' in k: 17 | state_info_str = f"{state_info_str}{k}={v:.6f} " 18 | state_info_str = f"{state_info_str}|lr= " 19 | for lr in state_info['lr']: 20 | state_info_str = f"{state_info_str}{lr:.6f} " 21 | return state_info_str 22 | 23 | 24 | class Writer(): 25 | def __init__(self, log_dir, config): 26 | self.log_dir = log_dir 27 | if config.mode == 'train': 28 | self.summary_writer = SummaryWriter(get_directory(f"{log_dir}/summary/{config.mode}/{config.rep}")) 29 | elif config.mode == 'test_opt': 30 | self.summary_writer = SummaryWriter(get_directory(f"{log_dir}/summary/{config.mode}/{config.rep}/{config.epoch_continue}")) 31 | else: 32 | self.summary_writer = None 33 | 34 | 35 | def log_state_info(self, state_info): 36 | state_info_str = get_state_info_str(state_info) 37 | logger.info(state_info_str) 38 | 39 | 40 | def log_summary(self, state_info, global_step, mode): 41 | for k, v in state_info.items(): 42 | if 'loss' in k: 43 | self.summary_writer.add_scalar(f'{mode}/{k}', v, global_step) 44 | for i, lr in enumerate(state_info['lr']): 45 | self.summary_writer.add_scalar(f'{mode}/lr{i}', lr, global_step) 46 | 47 | 48 | def save_checkpoint(self, ckpt_path, epoch, model, latent_vecs, optimizer): 49 | torch.save( 50 | { 51 | 'epoch': epoch, 52 | 'model_state_dict': model.state_dict(), 53 | 'train_latent_vecs': latent_vecs.state_dict() if latent_vecs is not None else None, 54 | 'optimizer_state_dict': optimizer.state_dict(), 55 | }, 56 | ckpt_path 57 | ) 58 | logger.info(ckpt_path) 59 | 60 | 61 | def load_checkpoint(self, ckpt_path, model=None, latent_vecs=None, optimizer=None): 62 | # in-place load 63 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 64 | if model is not None: 65 | logger.info(f"load model from ${ckpt_path}") 66 | model.load_state_dict(ckpt["model_state_dict"]) 67 | if latent_vecs is not None: 68 | logger.info(f"load lat_vecs from ${ckpt_path}") 69 | latent_vecs.load_state_dict(ckpt["train_latent_vecs"]) 70 
| if optimizer is not None: 71 | logger.info(f"load optimizer from ${ckpt_path}") 72 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 73 | logger.info("loaded!") 74 | return ckpt["epoch"] 75 | 76 | -------------------------------------------------------------------------------- /utils/ddp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import shutil 5 | import subprocess 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | 12 | 13 | def init_dist_slurm(tcp_port, local_rank, backend='nccl'): 14 | """ 15 | modified from https://github.com/open-mmlab/mmdetection 16 | Args: 17 | tcp_port: 18 | backend: 19 | 20 | Returns: 21 | 22 | """ 23 | proc_id = int(os.environ['SLURM_PROCID']) 24 | ntasks = int(os.environ['SLURM_NTASKS']) 25 | node_list = os.environ['SLURM_NODELIST'] 26 | num_gpus = torch.cuda.device_count() 27 | torch.cuda.set_device(proc_id % num_gpus) 28 | addr = subprocess.getoutput('scontrol show hostname {} | head -n1'.format(node_list)) 29 | os.environ['MASTER_PORT'] = str(tcp_port) 30 | os.environ['MASTER_ADDR'] = addr 31 | os.environ['WORLD_SIZE'] = str(ntasks) 32 | os.environ['RANK'] = str(proc_id) 33 | dist.init_process_group(backend=backend) 34 | 35 | total_gpus = dist.get_world_size() 36 | rank = dist.get_rank() 37 | return total_gpus, rank 38 | 39 | 40 | def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): 41 | if mp.get_start_method(allow_none=True) is None: 42 | mp.set_start_method('spawn') 43 | # os.environ['MASTER_PORT'] = str(tcp_port) 44 | # os.environ['MASTER_ADDR'] = 'localhost' 45 | num_gpus = torch.cuda.device_count() 46 | torch.cuda.set_device(local_rank % num_gpus) 47 | 48 | dist.init_process_group( 49 | backend=backend, 50 | # init_method='tcp://127.0.0.1:%d' % tcp_port, 51 | # rank=local_rank, 52 | # world_size=num_gpus 53 | ) 54 | rank = dist.get_rank() 55 | return num_gpus, rank 56 | 57 | 58 | def get_dist_info(return_gpu_per_machine=False): 59 | if torch.__version__ < '1.0': 60 | initialized = dist._initialized 61 | else: 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | 73 | if return_gpu_per_machine: 74 | gpu_per_machine = torch.cuda.device_count() 75 | return rank, world_size, gpu_per_machine 76 | 77 | return rank, world_size 78 | 79 | 80 | def merge_results_dist(result_part, size, tmpdir): 81 | rank, world_size = get_dist_info() 82 | os.makedirs(tmpdir, exist_ok=True) 83 | 84 | dist.barrier() 85 | pickle.dump(result_part, open(os.path.join(tmpdir, 'result_part_{}.pkl'.format(rank)), 'wb')) 86 | dist.barrier() 87 | 88 | if rank != 0: 89 | return None 90 | 91 | part_list = [] 92 | for i in range(world_size): 93 | part_file = os.path.join(tmpdir, 'result_part_{}.pkl'.format(i)) 94 | part_list.append(pickle.load(open(part_file, 'rb'))) 95 | 96 | ordered_results = [] 97 | for res in zip(*part_list): 98 | ordered_results.extend(list(res)) 99 | ordered_results = ordered_results[:size] 100 | shutil.rmtree(tmpdir) 101 | return ordered_results 102 | 103 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/main.m: -------------------------------------------------------------------------------- 1 | function [] = 
main(start_idx, interval) 2 | %% start_idx starts from 1 3 | %% load data 4 | addpath('io/'); 5 | meta_data_path = '../dfaust1k/meta_test_6499_K25.mat'; 6 | mesh_interp_dir = '../../work_dir/dfaust/ivae_dfaustJSM1k/results/test/interp_edges_sdf/6499/'; 7 | mesh_raw_dir = '../dfaust1k/mesh_raw/'; 8 | mesh_sim_dir = '../dfaust1k/mesh_sim/'; 9 | mesh_def_dir = get_directory('../dfaust1k/mesh_def/'); 10 | % log_dir = get_directory('../dfaust1k/log/'); 11 | 12 | meta_data = load(meta_data_path); 13 | edge_ids = meta_data.edge_ids; % starts from 0 14 | fids = meta_data.fids; 15 | num_edges = size(edge_ids, 1); 16 | num_meshes = size(fids, 1); 17 | template_idx = meta_data.template_idx; % starts from 0 18 | template_fid = meta_data.template_fid; 19 | assert(max(max(edge_ids)) == num_meshes - 1); 20 | 21 | template_raw_mesh = read_obj([mesh_raw_dir, template_fid, '.obj']); 22 | template_sim_mesh = read_obj([mesh_sim_dir, template_fid, '_sim.obj']); 23 | 24 | %% Hyperparameters 25 | params = Params; 26 | params.lambda = 10; 27 | params.beta = 0.05; 28 | 29 | %% 30 | if isstring(start_idx) || ischar(start_idx) 31 | start_idx = str2num(start_idx) 32 | end 33 | if isstring(interval) || ischar(interval) 34 | interval = str2num(interval) 35 | end 36 | end_idx = min(start_idx + interval - 1, num_edges); 37 | for eid = start_idx : end_idx 38 | % get pair id 39 | sid = edge_ids(eid, 1); 40 | tid = edge_ids(eid, 2); 41 | sfid = strtrim(fids(sid + 1, :)); % matlab starts from 1 42 | tfid = strtrim(fids(tid + 1, :)); % matlab starts from 1 43 | precheck_mesh_path = sprintf('./%s/meshdef_%d_%d.obj', mesh_def_dir, sid, tid); 44 | if isfile(precheck_mesh_path) 45 | continue; 46 | end 47 | mesh_src_sim = read_obj([mesh_sim_dir, sfid, '_sim.obj']); 48 | mesh_tgt_sim = read_obj([mesh_sim_dir, tfid, '_sim.obj']); 49 | mesh_src = read_obj([mesh_raw_dir, sfid, '.obj']); 50 | mesh_tgt = read_obj([mesh_raw_dir, tfid, '.obj']); 51 | % log file 52 | % log_path = sprintf('%s/%d.log', log_dir, eid); 53 | % fileID = fopen(log_path, 'w'); 54 | % fprintf(fileID, '----- interp: eid=%d, sid=%d, tid=%d, sfid=%s, tfid=%s -----\n', eid, sid, tid, sfid, tfid); 55 | fprintf('----- interp: eid=%d, sid=%d, tid=%d, sfid=%s, tfid=%s -----\n', eid, sid, tid, sfid, tfid); 56 | % load interpolation meshes 57 | meshes_interp = cell(1, 9); 58 | for i_interp = 1 : 9 59 | fpath = sprintf('%s/%d_%d/%d_%d_%02d.obj', mesh_interp_dir, sid, tid, sid, tid, i_interp); 60 | mesh_interp = read_obj(fpath); 61 | meshes_interp{i_interp} = mesh_interp; 62 | end 63 | % interpolate meshes 64 | mesh_src_sim_def = icp_interpolation(mesh_src_sim, mesh_tgt_sim, meshes_interp, params); 65 | % DEBUG: write_obj(mesh_src_sim_def, '/mnt/yanghaitao/Projects/GenCorres/gencorres/vis/registration/dfaust1k/mesh_def/mesh_src_sim_def.obj'); 66 | % fprintf(fileID, 'ED + refine start: eid = %d, sid = %d, tid = %d\n', eid, sid, tid); 67 | fprintf('ED + refine start: eid = %d, sid = %d, tid = %d\n', eid, sid, tid); 68 | 69 | mesh_src_def = embedded_deformation(mesh_src, mesh_src_sim, mesh_src_sim_def); 70 | 71 | % DEBUG: write_obj(mesh_src_def, '/mnt/yanghaitao/Projects/GenCorres/gencorres/vis/registration/dfaust1k/mesh_def/ed.obj'); 72 | 73 | poss_vec = non_rigid_icp2(mesh_src, mesh_tgt, mesh_src_def.vertexPoss, 1, 0.05, 10, 1); 74 | mesh_src_def.vertexPoss = poss_vec; 75 | 76 | dump_mesh_path = sprintf('./%s/meshdef_%d_%d.obj', mesh_def_dir, sid, tid); % starts from 0 77 | write_obj(mesh_src_def, dump_mesh_path); 78 | 79 | % fprintf(fileID, 'Done: eid = %d, sid = %d, tid = 
%d\n', eid, sid, tid); 80 | fprintf('Done: eid = %d, sid = %d, tid = %d\n', eid, sid, tid); 81 | % fclose(fileID); 82 | end 83 | 84 | end 85 | 86 | 87 | -------------------------------------------------------------------------------- /models/conv/message_passing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | 4 | import torch 5 | from torch_scatter import scatter 6 | 7 | special_args = [ 8 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j' 9 | ] 10 | __size_error_msg__ = ('All tensors which should get mapped to the same source ' 11 | 'or target nodes must be of same size in dimension 0.') 12 | 13 | is_python2 = sys.version_info[0] < 3 14 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec 15 | 16 | 17 | class MessagePassing(torch.nn.Module): 18 | def __init__(self, aggr='add', flow='source_to_target'): 19 | super(MessagePassing, self).__init__() 20 | 21 | self.aggr = aggr 22 | assert self.aggr in ['add', 'mean', 'max'] 23 | 24 | self.flow = flow 25 | assert self.flow in ['source_to_target', 'target_to_source'] 26 | 27 | self.__message_args__ = getargspec(self.message)[0][1:] 28 | self.__special_args__ = [(i, arg) 29 | for i, arg in enumerate(self.__message_args__) 30 | if arg in special_args] 31 | self.__message_args__ = [ 32 | arg for arg in self.__message_args__ if arg not in special_args 33 | ] 34 | self.__update_args__ = getargspec(self.update)[0][2:] 35 | 36 | def propagate(self, edge_index, size=None, dim=0, **kwargs): 37 | dim = 1 # aggregate messages wrt nodes for batched_data: [batch_size, nodes, features] 38 | size = [None, None] if size is None else list(size) 39 | assert len(size) == 2 40 | 41 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) 42 | ij = {"_i": i, "_j": j} 43 | 44 | message_args = [] 45 | for arg in self.__message_args__: 46 | if arg[-2:] in ij.keys(): 47 | tmp = kwargs.get(arg[:-2], None) 48 | if tmp is None: # pragma: no cover 49 | message_args.append(tmp) 50 | else: 51 | idx = ij[arg[-2:]] 52 | if isinstance(tmp, tuple) or isinstance(tmp, list): 53 | assert len(tmp) == 2 54 | if tmp[1 - idx] is not None: 55 | if size[1 - idx] is None: 56 | size[1 - idx] = tmp[1 - idx].size(dim) 57 | if size[1 - idx] != tmp[1 - idx].size(dim): 58 | raise ValueError(__size_error_msg__) 59 | tmp = tmp[idx] 60 | 61 | if tmp is None: 62 | message_args.append(tmp) 63 | else: 64 | if size[idx] is None: 65 | size[idx] = tmp.size(dim) 66 | if size[idx] != tmp.size(dim): 67 | raise ValueError(__size_error_msg__) 68 | 69 | tmp = torch.index_select(tmp, dim, edge_index[idx]) 70 | message_args.append(tmp) 71 | else: 72 | message_args.append(kwargs.get(arg, None)) 73 | 74 | size[0] = size[1] if size[0] is None else size[0] 75 | size[1] = size[0] if size[1] is None else size[1] 76 | 77 | kwargs['edge_index'] = edge_index 78 | kwargs['size'] = size 79 | 80 | for (idx, arg) in self.__special_args__: 81 | if arg[-2:] in ij.keys(): 82 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]]) 83 | else: 84 | message_args.insert(idx, kwargs[arg]) 85 | 86 | update_args = [kwargs[arg] for arg in self.__update_args__] 87 | 88 | out = self.message(*message_args) 89 | out = scatter(out, edge_index[i], dim=dim, dim_size=size[i], reduce=self.aggr) 90 | out = self.update(out, *update_args) 91 | 92 | return out 93 | 94 | def message(self, x_j): # pragma: no cover 95 | return x_j 96 | 97 | def update(self, aggr_out): # pragma: no cover 98 | return aggr_out 99 | 
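A minimal sketch of a concrete layer built on the MessagePassing base above (this is the pattern cheb_conv.py is expected to follow; the averaging layer itself is purely illustrative and not part of the repo). Since propagate hard-codes dim=1, node features are batched as [batch_size, num_nodes, channels]:

# Illustrative subclass of the MessagePassing base from models/conv/message_passing.py.
import torch
from models.conv import MessagePassing

class MeanNeighbor(MessagePassing):
    '''Toy layer: replace each node feature by the mean of its neighbours.'''
    def __init__(self):
        super().__init__(aggr='mean', flow='source_to_target')

    def forward(self, x, edge_index):
        # x: [batch_size, num_nodes, channels]; edge_index: [2, num_edges] LongTensor
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: source-node features gathered per edge (index_select along dim=1)
        return x_j

    def update(self, aggr_out):
        # aggr_out: per-target-node aggregate produced by torch_scatter.scatter
        return aggr_out

# e.g. one triangle 0-1-2 (both edge directions), batch of 2 meshes, 8 channels:
# x = torch.randn(2, 3, 8)
# edge_index = torch.tensor([[0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]])
# out = MeanNeighbor()(x, edge_index)   # shape (2, 3, 8)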
-------------------------------------------------------------------------------- /config/dfaust/ivae_dfaustJSM1k.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | dataset_exp_name: dfaust 3 | data_dir: /scratch/cluster/yanght/Dataset/Human/DFAUST/ 4 | work_dir: ./work_dir/ 5 | latent_dim: 256 6 | latent_dim_mesh: 72 7 | num_workers: 4 # 16 8 | 9 | dataset: 10 | module_name: dfaust 11 | class_name: DFaustDataSet 12 | data_dir: ${data_dir} 13 | 14 | sdf_dir: ${data_dir}/registrations_processed_sal_sigma03/ 15 | 16 | with_raw_mesh: True 17 | raw_mesh_dir: ${data_dir}/registrations/ 18 | raw_mesh_file_type: obj 19 | 20 | with_registration: False 21 | registration_dir: ${data_dir}/registrations/ 22 | 23 | # with_sim_mesh: True 24 | # sim_mesh_dir: ${data_dir}/registrations_sim/ 25 | 26 | # init_mesh_dir: ./TODO/mesh_corres/ 27 | # use_vert_pca: True 28 | # pca_n_comp: ${latent_dim_mesh} 29 | 30 | template_path: ${data_dir}/registrations/50009/running_on_spot/running_on_spot.000366.obj # template 31 | num_samples: 8192 32 | split_cfg: 33 | train: test_fps1k.json # In JSM, only test_fps1k.json is available. In shape space, use train_fps1k.json 34 | test: test_fps1k.json # split 35 | 36 | sdf_asap_start_epoch: 3000 37 | 38 | model: 39 | mesh: 40 | module_name: meshnet 41 | class_name: MeshNet 42 | auto_decoder: True 43 | in_channels: 3 44 | out_channels: [32, 32, 32, 64] 45 | latent_channels: ${latent_dim_mesh} 46 | K: 6 47 | ds_factors: [2, 2, 2, 2] 48 | 49 | sdf: 50 | module_name: implicit_vae 51 | class_name: ImplicitGenerator 52 | auto_decoder: False 53 | encoder: 54 | with_normals: False 55 | decoder: 56 | latent_size: ${latent_dim} 57 | dims : [ 512, 512, 512, 512, 512, 512, 512, 512 ] 58 | norm_layers : [0, 1, 2, 3, 4, 5, 6, 7] 59 | latent_in : [4] 60 | weight_norm : True 61 | xyz_dim : 3 62 | 63 | loss: 64 | ###### mesh ###### 65 | # Mesh ARAP loss 66 | mesh_arap_weight: 0.1 # 5e-4 67 | use_mesh_arap_epoch: 500 68 | use_mesh_arap_with_asap: True 69 | mesh_weight_asap: 0.1 70 | nz_max: 60 # random sample nz_max latent channels to compute ARAP energy 71 | chamfer_loss_weight: 1 72 | point2point_loss_weight: 1 73 | point2plane_loss_weight: 0 74 | ###### sdf ###### 75 | # SDF loss 76 | sdf_weight: 1.0 77 | sdf_loss_type: L1 78 | # VAE latent reg loss 79 | vae_latent_reg_weight: 0.001 80 | # AD latent reg loss 81 | ad_latent_reg_weight: 0.0 82 | # sdf grad loss 83 | grad_loss_weight: 1.0 84 | # sdf surfafe ARAP loss 85 | use_sdf_asap_epoch: ${sdf_asap_start_epoch} 86 | simplify_mesh: True 87 | implicit_reg_type: 'dense_inverse' 88 | sample_latent_space: True 89 | sample_latent_space_type: 'line' # normal, line 90 | sample_latent_space_detach: False 91 | sdf_asap_weight: 0.001 92 | weight_asap: 0.1 93 | mu_asap: 0.0001 94 | add_mu_diag_to_hessian: True 95 | sdf_grid_size: 64 96 | # cyc regularization 97 | use_cyc_reg: True 98 | eps_cyc: 0.001 99 | sdf_cyc_weight: 0.0001 100 | 101 | 102 | optimization: 103 | mesh: 104 | batch_size: 32 # 64 105 | lr: 0.001 106 | lr_decay: 0.99 107 | decay_step: 5 108 | num_epochs: 4500001 109 | 110 | lat_vecs: 111 | lr: 0.001 112 | test_lr: 0.01 113 | test_lr_decay: 0.99 114 | test_decay_step: 5 115 | num_test_epochs: 2500001 116 | 117 | sdf: 118 | batch_size: 8 119 | lr: 0.0005 120 | gammas: [ 0.5, 0.5, 1, 0.5, 0.5, 0.5] 121 | milestones: [1000, 2000, 3000, 4000, 5000, 6000] 122 | num_epochs: 5000001 # 5001 123 | 124 | lat_vecs: 125 | lr: 0.001 126 | test_lr: 0.001 # 0.005 127 | test_lr_decay: 0.1 128 | 
test_decay_step: 400 129 | num_test_epochs: 801 130 | 131 | log: 132 | log_batch_interval: 100 133 | save_epoch_interval: 500 134 | save_latest_epoch_interval: 100 135 | 136 | 137 | -------------------------------------------------------------------------------- /config/dfaust/admesh_dfaustJSM1k.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | dataset_exp_name: dfaust 3 | data_dir: /scratch/cluster/yanght/Dataset/Human/DFAUST/ 4 | work_dir: ./work_dir/ 5 | latent_dim: 256 6 | latent_dim_mesh: 72 7 | num_workers: 4 # 16 8 | 9 | dataset: 10 | module_name: dfaust 11 | class_name: DFaustDataSet 12 | data_dir: ${data_dir} 13 | 14 | sdf_dir: ${data_dir}/registrations_processed_sal_sigma03/ 15 | 16 | with_raw_mesh: True 17 | raw_mesh_dir: ${data_dir}/registrations/ 18 | raw_mesh_file_type: obj 19 | 20 | with_registration: False 21 | registration_dir: ${data_dir}/registrations/ 22 | 23 | # with_sim_mesh: True 24 | # sim_mesh_dir: ${data_dir}/registrations_sim/ 25 | 26 | init_mesh_dir: ./registration_dfaust/dfaust1k/mesh_corres/ # We use mesh_corres_K25 27 | use_vert_pca: True 28 | pca_n_comp: ${latent_dim_mesh} 29 | 30 | template_path: ${data_dir}/registrations/50009/running_on_spot/running_on_spot.000366.obj # template 31 | num_samples: 8192 32 | split_cfg: 33 | train: test_fps1k.json # In JSM, only test_fps1k.json is available. In shape space, use train_fps1k.json 34 | test: test_fps1k.json # split 35 | 36 | sdf_asap_start_epoch: 3000 37 | 38 | model: 39 | mesh: 40 | module_name: meshnet 41 | class_name: MeshNet 42 | auto_decoder: True 43 | in_channels: 3 44 | out_channels: [32, 32, 32, 64] 45 | latent_channels: ${latent_dim_mesh} 46 | K: 6 47 | ds_factors: [2, 2, 2, 2] 48 | 49 | sdf: 50 | module_name: implicit_vae 51 | class_name: ImplicitGenerator 52 | auto_decoder: False 53 | encoder: 54 | with_normals: False 55 | decoder: 56 | latent_size: ${latent_dim} 57 | dims : [ 512, 512, 512, 512, 512, 512, 512, 512 ] 58 | norm_layers : [0, 1, 2, 3, 4, 5, 6, 7] 59 | latent_in : [4] 60 | weight_norm : True 61 | xyz_dim : 3 62 | 63 | loss: 64 | ###### mesh ###### 65 | # Mesh ARAP loss 66 | mesh_arap_weight: 0.1 # 5e-4 67 | use_mesh_arap_epoch: 500 68 | use_mesh_arap_with_asap: True 69 | mesh_weight_asap: 0.1 70 | nz_max: 60 # random sample nz_max latent channels to compute ARAP energy 71 | chamfer_loss_weight: 1 72 | point2point_loss_weight: 1 73 | point2plane_loss_weight: 0 74 | ###### sdf ###### 75 | # SDF loss 76 | sdf_weight: 1.0 77 | sdf_loss_type: L1 78 | # VAE latent reg loss 79 | vae_latent_reg_weight: 0.001 80 | # AD latent reg loss 81 | ad_latent_reg_weight: 0.0 82 | # sdf grad loss 83 | grad_loss_weight: 1.0 84 | # sdf surfafe ARAP loss 85 | use_sdf_asap_epoch: ${sdf_asap_start_epoch} 86 | simplify_mesh: True 87 | implicit_reg_type: 'dense_inverse' 88 | sample_latent_space: True 89 | sample_latent_space_type: 'line' # normal, line 90 | sample_latent_space_detach: False 91 | sdf_asap_weight: 0.001 92 | weight_asap: 0.1 93 | mu_asap: 0.0001 94 | add_mu_diag_to_hessian: True 95 | sdf_grid_size: 64 96 | # cyc regularization 97 | use_cyc_reg: True 98 | eps_cyc: 0.001 99 | sdf_cyc_weight: 0.0001 100 | 101 | 102 | optimization: 103 | mesh: 104 | batch_size: 32 # 64 105 | lr: 0.001 106 | lr_decay: 0.99 107 | decay_step: 5 108 | num_epochs: 4500001 109 | 110 | lat_vecs: 111 | lr: 0.001 112 | test_lr: 0.01 113 | test_lr_decay: 0.99 114 | test_decay_step: 5 115 | num_test_epochs: 2500001 116 | 117 | sdf: 118 | batch_size: 8 119 | lr: 0.0005 120 | 
gammas: [ 0.5, 0.5, 1, 0.5, 0.5, 0.5] 121 | milestones: [1000, 2000, 3000, 4000, 5000, 6000] 122 | num_epochs: 5000001 # 5001 123 | 124 | lat_vecs: 125 | lr: 0.001 126 | test_lr: 0.001 # 0.005 127 | test_lr_decay: 0.1 128 | test_decay_step: 400 129 | num_test_epochs: 801 130 | 131 | log: 132 | log_batch_interval: 10 133 | save_epoch_interval: 50 134 | save_latest_epoch_interval: 10 135 | 136 | 137 | -------------------------------------------------------------------------------- /config/smal/ivae_smalJSM.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | dataset_exp_name: smal 3 | data_dir: /scratch/cluster/yanght/Dataset/Human/SMAL/ 4 | work_dir: ./work_dir/ 5 | latent_dim: 128 6 | latent_dim_mesh: 64 7 | num_workers: 4 # 16 8 | 9 | dataset: 10 | module_name: smal 11 | class_name: SMALDataSet 12 | data_dir: ${data_dir} 13 | 14 | sdf_dir: ${data_dir}/registrations_processed_sal_sigma03/ 15 | 16 | with_raw_mesh: True 17 | raw_mesh_dir: ${data_dir}/registrations/ 18 | raw_mesh_file_type: obj 19 | 20 | with_registration: False 21 | registration_dir: ${data_dir}/registrations/ 22 | 23 | # with_sim_mesh: True 24 | # sim_mesh_dir: ${data_dir}/registrations_sim/ 25 | 26 | # init_mesh_dir: ./TODO/mesh_corres/ 27 | # use_vert_pca: True 28 | # pca_n_comp: ${latent_dim_mesh} 29 | 30 | template_path: ${data_dir}/registrations/smal401/pose/000.obj # template 31 | num_samples: 8192 32 | split_cfg: 33 | train: test_94.json # In JSM, only test_94.json is available. In shape space, use train_289.json 34 | test: test_94.json # split 35 | 36 | sdf_asap_start_epoch: 4000 37 | 38 | model: 39 | mesh: 40 | module_name: meshnet 41 | class_name: MeshNet 42 | auto_decoder: True 43 | in_channels: 3 44 | out_channels: [16, 16, 16, 32] 45 | latent_channels: ${latent_dim_mesh} 46 | K: 6 47 | ds_factors: [1, 1, 1, 1] 48 | 49 | sdf: 50 | module_name: implicit_vae 51 | class_name: ImplicitGenerator 52 | auto_decoder: False 53 | encoder: 54 | with_normals: False 55 | decoder: 56 | latent_size: ${latent_dim} 57 | dims : [ 512, 512, 512, 512, 512, 512, 512, 512 ] 58 | norm_layers : [0, 1, 2, 3, 4, 5, 6, 7] 59 | latent_in : [4] 60 | weight_norm : True 61 | xyz_dim : 3 62 | 63 | loss: 64 | ###### mesh ###### 65 | # Mesh ARAP loss 66 | mesh_arap_weight: 0.0005 # 5e-4 67 | use_mesh_arap_epoch: 1000 68 | use_mesh_arap_with_asap: True 69 | mesh_weight_asap: 0.1 70 | nz_max: 64 # random sample nz_max latent channels to compute ARAP energy 71 | chamfer_loss_weight: 1 72 | point2point_loss_weight: 1 73 | point2plane_loss_weight: 0 74 | ###### sdf ###### 75 | # SDF loss 76 | sdf_weight: 1.0 77 | sdf_loss_type: L1 78 | # VAE latent reg loss 79 | vae_latent_reg_weight: 0.001 80 | # AD latent reg loss 81 | ad_latent_reg_weight: 0.0 82 | # sdf grad loss 83 | grad_loss_weight: 0.1 84 | # sdf surfafe ARAP loss 85 | use_sdf_asap_epoch: ${sdf_asap_start_epoch} 86 | simplify_mesh: True 87 | implicit_reg_type: 'dense_inverse' 88 | sample_latent_space: True 89 | sample_latent_space_type: 'line' # normal, line 90 | sample_latent_space_detach: False 91 | sdf_asap_weight: 0.001 92 | weight_asap: 0.1 93 | mu_asap: 0.0001 94 | add_mu_diag_to_hessian: True 95 | sdf_grid_size: 50 96 | x_range: [-1.4, 0.95] 97 | y_range: [-0.75, 0.7] 98 | z_range: [-0.85, 1.2] 99 | # cyc regularization 100 | use_cyc_reg: True 101 | eps_cyc: 0.001 102 | sdf_cyc_weight: 0.0001 103 | 104 | 105 | optimization: 106 | mesh: 107 | batch_size: 8 # 64 108 | lr: 0.01 109 | lr_decay: 0.99 110 | decay_step: 3 111 | num_epochs: 
2000 112 | 113 | lat_vecs: 114 | lr: 0.01 115 | test_lr: 0.01 116 | test_lr_decay: 0.99 117 | test_decay_step: 1 118 | num_test_epochs: 2500001 119 | 120 | sdf: 121 | batch_size: 8 122 | lr: 0.005 123 | gammas: [ 0.5, 0.5, 0.5, 2, 0.5, 0.5, 0.5] 124 | milestones: [1000, 2000, 3000, 4000, 5000, 6000, 7000] 125 | num_epochs: 5000001 # 5001 126 | 127 | lat_vecs: 128 | lr: 0.01 129 | test_lr: 0.001 # 0.005 130 | test_lr_decay: 0.1 131 | test_decay_step: 400 132 | num_test_epochs: 801 133 | 134 | log: 135 | log_batch_interval: 100 136 | save_epoch_interval: 500 137 | save_latest_epoch_interval: 100 138 | 139 | 140 | -------------------------------------------------------------------------------- /config/smal/admesh_smalJSM.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | dataset_exp_name: smal 3 | data_dir: /scratch/cluster/yanght/Dataset/Human/SMAL/ 4 | work_dir: ./work_dir/ 5 | latent_dim: 128 6 | latent_dim_mesh: 64 7 | num_workers: 4 # 16 8 | 9 | dataset: 10 | module_name: smal 11 | class_name: SMALDataSet 12 | data_dir: ${data_dir} 13 | 14 | sdf_dir: ${data_dir}/registrations_processed_sal_sigma03/ 15 | 16 | with_raw_mesh: True 17 | raw_mesh_dir: ${data_dir}/registrations/ 18 | raw_mesh_file_type: obj 19 | 20 | with_registration: False 21 | registration_dir: ${data_dir}/registrations/ 22 | 23 | # with_sim_mesh: True 24 | # sim_mesh_dir: ${data_dir}/registrations_sim/ 25 | 26 | init_mesh_dir: ./registration_smal/smal/mesh_corres 27 | use_vert_pca: True 28 | pca_n_comp: ${latent_dim_mesh} 29 | 30 | template_path: ${data_dir}/registrations/smal401/pose/000.obj # template 31 | num_samples: 8192 32 | split_cfg: 33 | train: test_94.json # In JSM, only test_94.json is available. In shape space, use train_289.json 34 | test: test_94.json # split 35 | 36 | sdf_asap_start_epoch: 4000 37 | 38 | model: 39 | mesh: 40 | module_name: meshnet 41 | class_name: MeshNet 42 | auto_decoder: True 43 | in_channels: 3 44 | out_channels: [16, 16, 16, 32] 45 | latent_channels: ${latent_dim_mesh} 46 | K: 6 47 | ds_factors: [1, 1, 1, 1] 48 | 49 | sdf: 50 | module_name: implicit_vae 51 | class_name: ImplicitGenerator 52 | auto_decoder: False 53 | encoder: 54 | with_normals: False 55 | decoder: 56 | latent_size: ${latent_dim} 57 | dims : [ 512, 512, 512, 512, 512, 512, 512, 512 ] 58 | norm_layers : [0, 1, 2, 3, 4, 5, 6, 7] 59 | latent_in : [4] 60 | weight_norm : True 61 | xyz_dim : 3 62 | 63 | loss: 64 | ###### mesh ###### 65 | # Mesh ARAP loss 66 | mesh_arap_weight: 0.0005 # 5e-4 67 | use_mesh_arap_epoch: 1000 68 | use_mesh_arap_with_asap: True 69 | mesh_weight_asap: 0.1 70 | nz_max: 64 # random sample nz_max latent channels to compute ARAP energy 71 | chamfer_loss_weight: 1 72 | point2point_loss_weight: 1 73 | point2plane_loss_weight: 0 74 | ###### sdf ###### 75 | # SDF loss 76 | sdf_weight: 1.0 77 | sdf_loss_type: L1 78 | # VAE latent reg loss 79 | vae_latent_reg_weight: 0.001 80 | # AD latent reg loss 81 | ad_latent_reg_weight: 0.0 82 | # sdf grad loss 83 | grad_loss_weight: 0.1 84 | # sdf surfafe ARAP loss 85 | use_sdf_asap_epoch: ${sdf_asap_start_epoch} 86 | simplify_mesh: True 87 | implicit_reg_type: 'dense_inverse' 88 | sample_latent_space: True 89 | sample_latent_space_type: 'line' # normal, line 90 | sample_latent_space_detach: False 91 | sdf_asap_weight: 0.001 92 | weight_asap: 0.1 93 | mu_asap: 0.0001 94 | add_mu_diag_to_hessian: True 95 | sdf_grid_size: 50 96 | x_range: [-1.4, 0.95] 97 | y_range: [-0.75, 0.7] 98 | z_range: [-0.85, 1.2] 99 | # cyc 
regularization 100 | use_cyc_reg: True 101 | eps_cyc: 0.001 102 | sdf_cyc_weight: 0.0001 103 | 104 | 105 | optimization: 106 | mesh: 107 | batch_size: 8 # 64 108 | lr: 0.01 109 | lr_decay: 0.99 110 | decay_step: 3 111 | num_epochs: 2000 112 | 113 | lat_vecs: 114 | lr: 0.01 115 | test_lr: 0.01 116 | test_lr_decay: 0.99 117 | test_decay_step: 1 118 | num_test_epochs: 2500001 119 | 120 | sdf: 121 | batch_size: 8 122 | lr: 0.005 123 | gammas: [ 0.5, 0.5, 0.5, 2, 0.5, 0.5, 0.5] 124 | milestones: [1000, 2000, 3000, 4000, 5000, 6000, 7000] 125 | num_epochs: 5000001 # 5001 126 | 127 | lat_vecs: 128 | lr: 0.01 129 | test_lr: 0.001 # 0.005 130 | test_lr_decay: 0.1 131 | test_decay_step: 400 132 | num_test_epochs: 801 133 | 134 | log: 135 | log_batch_interval: 100 136 | save_epoch_interval: 500 137 | save_latest_epoch_interval: 100 138 | 139 | 140 | -------------------------------------------------------------------------------- /models/arap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree, get_laplacian 3 | import torch_sparse as ts 4 | import numpy as np 5 | import sys 6 | from loguru import logger 7 | 8 | def get_laplacian_kron3x3(edge_index, edge_weights, N): 9 | edge_index, edge_weight = get_laplacian(edge_index, edge_weights, num_nodes=N) 10 | edge_weight *= 2 11 | e0, e1 = edge_index 12 | i0 = [e0*3, e0*3+1, e0*3+2] 13 | i1 = [e1*3, e1*3+1, e1*3+2] 14 | vals = [edge_weight, edge_weight, edge_weight] 15 | i0 = torch.cat(i0, 0) 16 | i1 = torch.cat(i1, 0) 17 | vals = torch.cat(vals, 0) 18 | indices, vals = ts.coalesce([i0, i1], vals, N*3, N*3) 19 | return indices, vals 20 | 21 | class ARAP(torch.nn.Module): 22 | def __init__(self, template_face, num_points): 23 | super(ARAP, self).__init__() 24 | N = num_points 25 | self.template_face = template_face 26 | adj = np.zeros((num_points, num_points)) 27 | adj[template_face[:, 0], template_face[:, 1]] = 1 28 | adj[template_face[:, 1], template_face[:, 2]] = 1 29 | adj[template_face[:, 0], template_face[:, 2]] = 1 30 | adj = adj + adj.T 31 | edge_index = torch.as_tensor(np.stack(np.where(adj > 0), 0), 32 | dtype=torch.long) 33 | e0, e1 = edge_index 34 | deg = degree(e0, N) 35 | edge_weight = torch.ones_like(e0) 36 | 37 | L_indices, L_vals = get_laplacian_kron3x3(edge_index, edge_weight, N) 38 | self.register_buffer('L_indices', L_indices) 39 | self.register_buffer('L_vals', L_vals) 40 | self.register_buffer('edge_weight', edge_weight) 41 | self.register_buffer('edge_index', edge_index) 42 | 43 | def forward(self, x, J, k=0, **kwargs): 44 | """ 45 | x: [B, N, 3] point locations. 46 | J: [B, N*3, D] Jacobian of generator. 
47 | J_eigvals: [B, D] 48 | """ 49 | num_batches, N = x.shape[:2] 50 | e0, e1 = self.edge_index 51 | edge_vecs = x[:, e0, :] - x[:, e1, :] 52 | trace_ = [] 53 | 54 | for i in range(num_batches): 55 | LJ = ts.spmm(self.L_indices, self.L_vals, N*3, N*3, J[i]) 56 | JTLJ = J[i].T.matmul(LJ) 57 | 58 | B0, B1, B_vals = [], [], [] 59 | B0.append(e0*3 ); B1.append(e1*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 60 | B0.append(e0*3 ); B1.append(e1*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 61 | B0.append(e0*3+1); B1.append(e1*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 62 | B0.append(e0*3+1); B1.append(e1*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 63 | B0.append(e0*3+2); B1.append(e1*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 64 | B0.append(e0*3+2); B1.append(e1*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 65 | 66 | B0.append(e0*3 ); B1.append(e0*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 67 | B0.append(e0*3 ); B1.append(e0*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 68 | B0.append(e0*3+1); B1.append(e0*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 69 | B0.append(e0*3+1); B1.append(e0*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 70 | B0.append(e0*3+2); B1.append(e0*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 71 | B0.append(e0*3+2); B1.append(e0*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 72 | B0 = torch.cat(B0, 0) 73 | B1 = torch.cat(B1, 0) 74 | B_vals = torch.cat(B_vals, 0) 75 | B_indices, B_vals = ts.coalesce([B0, B1], B_vals, N*3, N*3) 76 | BT_indices, BT_vals = ts.transpose(B_indices, B_vals, N*3, N*3) 77 | 78 | C0, C1, C_vals = [], [], [] 79 | edge_vecs_sq = (edge_vecs[i] * edge_vecs[i]).sum(-1) 80 | evi = edge_vecs[i] 81 | for di in range(3): 82 | for dj in range(3): 83 | C0.append(e0*3+di); C1.append(e0*3+dj); C_vals.append(-evi[:, di]*evi[:, dj]*self.edge_weight) 84 | C0.append(e0*3+di); C1.append(e0*3+di); C_vals.append(edge_vecs_sq*self.edge_weight) 85 | C0 = torch.cat(C0, 0) 86 | C1 = torch.cat(C1, 0) 87 | C_vals = torch.cat(C_vals, 0) 88 | C_indices, C_vals = ts.coalesce([C0, C1], C_vals, N*3, N*3) 89 | try: 90 | C_vals = C_vals.view(N, 3, 3).inverse().reshape(-1) 91 | except: 92 | logger.debug('C_vals error: use pinv') 93 | C_vals = torch.linalg.pinv(C_vals.view(N, 3, 3)).reshape(-1) 94 | BTJ = ts.spmm(BT_indices, BT_vals, N*3, N*3, J[i]) 95 | CBTJ = ts.spmm(C_indices, C_vals, N*3, N*3, BTJ) 96 | JTBCBTJ = BTJ.T.mm(CBTJ) 97 | 98 | e = torch.linalg.eigvalsh(JTLJ-JTBCBTJ).clip(0) 99 | 100 | e = e ** 0.5 101 | 102 | trace = e.sum() 103 | 104 | trace_.append(trace) 105 | 106 | trace_ = torch.stack(trace_, ) 107 | return trace_.mean() 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GenCorres 2 | Code for ICLR 2024 paper: [GenCorres: Consistent Shape Matching via Coupled Implicit-Explicit Shape Generative Models](https://openreview.net/pdf?id=dGH4kHFKFj). 3 | 4 | A cleaner version of the code for training the implicit generator is in the [Supplementary Material](https://openreview.net/attachment?id=dGH4kHFKFj&name=supplementary_material). 5 | 6 | ## Dataset 7 | 8 | We use the data processing code of [SALD](https://github.com/matanatz/SALD/tree/main). The processed dataset is in [this link](https://drive.google.com/drive/folders/1JvPRxcuqeUV9evtUNKMKgQWnhs0uzlg-?usp=drive_link). 
9 | 10 | Unzip the dataset: 11 | ``` 12 | . 13 | ├── DFAUST 14 | │   ├── registrations 15 | │   └── registrations_processed_sal_sigma03 16 | └── SMAL 17 | ├── registrations 18 | └── registrations_processed_sal_sigma03 19 | ``` 20 | Change the `data_dir` in the config file (e.g. `./config/dfaust/ivae_dfaustJSM1k.yaml`). 21 | 22 | Below we show the example of JSM for the DFAUST dataset (1k shapes). Pretrained model is in [work_dir.zip](https://drive.google.com/drive/folders/1JvPRxcuqeUV9evtUNKMKgQWnhs0uzlg-?usp=drive_link). 23 | 24 | ## Stage 1 25 | Train the implicit network with regularization to fit the input shapes: 26 | ``` 27 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main.py --launcher pytorch --config ./config/dfaust/ivae_dfaustJSM1k.yaml --mode train --rep sdf # (--continue_from 2999) 28 | ``` 29 | 30 | To visualize the interpolation of a pair of shapes in the shape space: 31 | ``` 32 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./config/dfaust/ivae_dfaustJSM1k.yaml --mode interp --rep sdf --continue_from 6499 --split train --interp_src_fid 50009-running_on_spot-running_on_spot.000366 --interp_tgt_fid 50002-chicken_wings-chicken_wings.004011 33 | ``` 34 | 35 | ## Stage 2 36 | 37 | ### Latent space interpolation 38 | 39 | #### Generate latents for each shape 40 | ``` 41 | CUDA_VISIBLE_DEVICES=0 python main.py --config ./config/dfaust/ivae_dfaustJSM1k.yaml --mode analysis --rep sdf --continue_from 6499 --split test 42 | ``` 43 | The outputs are `latents_all_test_6499.npy` and `latents_all_test_6499.pkl` in `work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf`. 44 | 45 | #### Create KNN graph 46 | Create a KNN graph according to the latents, also add edges from the template to all the remaining shapes. 47 | ``` 48 | cd ./registration_dfaust 49 | python gen_edges.py --epoch 6499 --split test --data_root ../work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf/ 50 | ``` 51 | The outputs are: `work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf/edge_ids/test_6499_edge_ids_K25.npy`. 52 | 53 | #### Interpolate shapes according to edge_ids 54 | The command is in `interp/batch_interp.sh`. We utilize [HTCondor](https://htcondor.org/) to accelerate the execution. 55 | ``` 56 | condor_submit interp/condor.sh 57 | ``` 58 | The outputs are in: `work_dir/dfaust/ivae_dfaustJSM1k/results/test/interp_edges_sdf/6499/`. Each folder stores the interpolation results of a pair of shapes. 
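For intuition, here is a minimal sketch (not the actual logic of `main.py` / `interp/batch_interp.sh`) of what one interpolation job does for a single edge of the KNN graph, assuming straight-line interpolation between the two latent codes; the step count and the decoding call are placeholders:
```
import numpy as np

analysis_dir = "work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf"
latents = np.load(f"{analysis_dir}/latents_all_test_6499.npy")             # (N, latent_dim)
edge_ids = np.load(f"{analysis_dir}/edge_ids/test_6499_edge_ids_K25.npy")  # (E, 2) index pairs

src, tgt = edge_ids[0]          # one (source, target) shape pair
num_steps = 10                  # placeholder; see interp/batch_interp.sh for the actual setting
for step in range(num_steps + 1):
    t = step / num_steps
    z = (1 - t) * latents[src] + t * latents[tgt]   # walk along the line between the two codes
    # Decode z with the trained implicit generator (e.g. marching cubes on its SDF)
    # and save the mesh into the per-edge folder under interp_edges_sdf/6499/.
```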
59 | 60 | 61 | ### Nonrigid registration between pairs 62 | 63 | #### Prepare data for MATLAB 64 | ``` 65 | cd ./registration_dfaust 66 | python preprocess_registration_data.py 67 | ``` 68 | The outputs are: 69 | ``` 70 | ├── mesh_raw # raw mesh 71 | ├── mesh_sim # simplified mesh, each about 1k vertices 72 | └── meta_test_6499_K5.mat # edge_ids and fids 73 | ``` 74 | 75 | #### Nonrigid registration 76 | We use MATLAB to solve the following optimization problem: 77 | ``` 78 | Energy = L_point2point * beta + L_point2plane * (1 - beta) + lambda * L_arap 79 | ``` 80 | 81 | To utilize [HTCondor](https://htcondor.org/ ) to accelerate the execution: 82 | ``` 83 | cd ./registration_dfaust/multiCorres_sync_dfaust1k 84 | condor_submit condor.sh 85 | ``` 86 | The registered meshes are in: `./dfaust1k/mesh_def` 87 | 88 | ### Propagate correspondences for initialization 89 | ``` 90 | cd ./registration_dfaust 91 | python gen_graph.py --epoch 6499 --split test --edge_ids_path ../work_dir/dfaust/ivae_dfaustJSM1k/results/test/analysis_sdf/edge_ids/test_6499_edge_ids_K25.npy 92 | ``` 93 | The outputs are in: `./dfaust1k/mesh_corres/` 94 | 95 | ## Stage 3 96 | 97 | ### Initialize the mesh generator 98 | ``` 99 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config ./config/dfaust/admesh_dfaustJSM1k.yaml --rep mesh --mode train --data_parallel 100 | ``` 101 | 102 | ### Refinement 103 | After training 999 epochs, change the hyperparameter `mesh_arap_weight` in the yaml file to `0.001` and resume training: 104 | ``` 105 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config ./config/dfaust/admesh_dfaustJSM1k.yaml --rep mesh --mode train --data_parallel --continue_from 999 --batch_size 28 106 | ``` 107 | 108 | Stop training at epoch 1500. To generate the final correspondences: 109 | ``` 110 | python main.py --config ./config/dfaust/admesh_dfaustJSM1k.yaml --rep mesh --mode eval --continue_from 1499 --split train --parallel_idx 0 --parallel_interval 1000 111 | ``` 112 | The outputs are in `work_dir/dfaust/ivae_dfaustJSM1k/results/train/eval_mesh`. 113 | 114 | 115 | ## Contact 116 | If you have any questions, you can contact Haitao Yang (yanghtr [AT] outlook [DOT] com). 117 | 118 | -------------------------------------------------------------------------------- /models/saldnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from loguru import logger 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import distributions as dist 8 | 9 | 10 | def maxpool(x, dim=-1, keepdim=False): 11 | out, _ = x.max(dim=dim, keepdim=keepdim) 12 | return out 13 | 14 | 15 | class SimplePointnet(nn.Module): 16 | ''' PointNet-based encoder network. 
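Maps a point set p of shape (B, N, dim) to a pair of (B, c_dim) outputs from the fc_mean and fc_std heads, i.e. the parameters of the per-shape latent code.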
17 | Args: 18 | c_dim (int): dimension of latent code c 19 | dim (int): input points dimension 20 | hidden_dim (int): hidden dimension of the network 21 | ''' 22 | 23 | def __init__(self, c_dim=128, dim=3, hidden_dim=128): 24 | super().__init__() 25 | self.c_dim = c_dim 26 | 27 | self.fc_pos = nn.Linear(dim, 2 * hidden_dim) 28 | self.fc_0 = nn.Linear(2 * hidden_dim, hidden_dim) 29 | self.fc_1 = nn.Linear(2 * hidden_dim, hidden_dim) 30 | self.fc_2 = nn.Linear(2 * hidden_dim, hidden_dim) 31 | self.fc_3 = nn.Linear(2 * hidden_dim, hidden_dim) 32 | 33 | self.fc_mean = nn.Linear(hidden_dim, c_dim) 34 | self.fc_std = nn.Linear(hidden_dim, c_dim) 35 | 36 | torch.nn.init.constant_(self.fc_mean.weight, 0) 37 | torch.nn.init.constant_(self.fc_mean.bias, 0) 38 | 39 | torch.nn.init.constant_(self.fc_std.weight, 0) 40 | torch.nn.init.constant_(self.fc_std.bias, -10) 41 | 42 | self.actvn = nn.ReLU() 43 | self.pool = maxpool 44 | 45 | def forward(self, p): 46 | net = self.fc_pos(p) 47 | net = self.fc_0(self.actvn(net)) 48 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 49 | net = torch.cat([net, pooled], dim=2) 50 | 51 | net = self.fc_1(self.actvn(net)) 52 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 53 | net = torch.cat([net, pooled], dim=2) 54 | 55 | net = self.fc_2(self.actvn(net)) 56 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 57 | net = torch.cat([net, pooled], dim=2) 58 | 59 | net = self.fc_3(self.actvn(net)) 60 | 61 | net = self.pool(net, dim=1) 62 | 63 | c_mean = self.fc_mean(self.actvn(net)) 64 | c_std = self.fc_std(self.actvn(net)) 65 | 66 | return c_mean, c_std 67 | 68 | 69 | class ImplicitMap(nn.Module): 70 | def __init__( 71 | self, 72 | latent_size, 73 | dims, 74 | norm_layers=(), 75 | latent_in=(), 76 | weight_norm=False, 77 | activation=None, 78 | xyz_dim=3, 79 | geometric_init=True, 80 | beta=100, 81 | **kwargs 82 | ): 83 | super().__init__() 84 | 85 | bias = 1.0 86 | self.latent_size = latent_size 87 | last_out_dim = 1 88 | dims = [latent_size + xyz_dim] + list(dims) + [last_out_dim] 89 | self.d_in = latent_size + xyz_dim 90 | self.latent_in = latent_in 91 | self.num_layers = len(dims) 92 | 93 | for l in range(0, self.num_layers - 1): 94 | if l + 1 in latent_in: 95 | out_dim = dims[l + 1] - dims[0] 96 | else: 97 | out_dim = dims[l + 1] 98 | 99 | lin = nn.Linear(dims[l], out_dim) 100 | 101 | if geometric_init: 102 | if l == self.num_layers - 2: 103 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 104 | torch.nn.init.constant_(lin.bias, -bias) 105 | else: 106 | torch.nn.init.constant_(lin.bias, 0.0) 107 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 108 | 109 | if weight_norm: 110 | lin = nn.utils.weight_norm(lin) 111 | 112 | setattr(self, "lin" + str(l), lin) 113 | 114 | self.softplus = nn.Softplus(beta=beta) 115 | 116 | 117 | def forward(self, inputs, latent): 118 | ''' 119 | Args: 120 | inputs: (B, N, 3) or (N1+...+NB, 3) 121 | latent: (B, din) or (N1+...+NB, din) 122 | return: 123 | x: (B, N, 1) or (N1+...+NB, 1) 124 | ''' 125 | assert(self.latent_size > 0) 126 | assert(len(latent.shape) == 2) 127 | assert(latent.shape[0] == inputs.shape[0]) 128 | if len(inputs.shape) == 3: 129 | # inputs: (B, N, 3), latent: (B, din) 130 | B, N = inputs.shape[0], inputs.shape[1] 131 | inputs_con = latent.unsqueeze(1).repeat(1, N, 1) # (B, N, din) 132 | elif len(inputs.shape) == 2: 133 | # inputs: (N1+...+NB, 3), latent: (N1+...+NB, din) 134 | inputs_con = latent 135 | else: 136 
| raise AssertionError 137 | 138 | x = torch.cat([inputs, inputs_con], dim=-1) # (B, N, din + 3) or (N1+...+NB, din+3) 139 | 140 | to_cat = x 141 | 142 | for l in range(0, self.num_layers - 1): 143 | lin = getattr(self, "lin" + str(l)) 144 | 145 | if l in self.latent_in: 146 | x = torch.cat([x, to_cat], -1) / np.sqrt(2) 147 | 148 | x = lin(x) 149 | 150 | if l < self.num_layers - 2: 151 | x = self.softplus(x) 152 | 153 | return x 154 | 155 | 156 | -------------------------------------------------------------------------------- /utils/geom_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import trimesh 7 | import numpy as np 8 | import networkx as nx 9 | import point_cloud_utils as pcu 10 | 11 | def index(x, idxs, dim): 12 | ''' Index a tensor along a given dimension using an index tensor, replacing 13 | the shape along the given dimension with the shape of the index tensor. 14 | Example: 15 | x: [8, 6890, 3] 16 | idxs: [13776, 3] 17 | y = index(x, idxs, dim=1) -> y: [B, 13776, 3, 3] 18 | with each y[b, i, j, k] = x[b, idxs[i, j], k] 19 | ''' 20 | target_shape = [*x.shape] 21 | del target_shape[dim] 22 | target_shape[dim:dim] = [*idxs.shape] 23 | return x.index_select(dim, idxs.view(-1)).reshape(target_shape) 24 | 25 | 26 | def compute_face_normals(v, vi): 27 | ''' 28 | @Args: 29 | v: (B, V, 3) 30 | vi: (B, F, 3) 31 | @Returns: 32 | face_normals: (B, F, 3) 33 | ''' 34 | B = v.shape[0] 35 | vi = vi.expand(B, -1, -1) 36 | 37 | # p0 = torch.stack([index(v[i], vi[i, :, 0], 0) for i in range(b)]) 38 | # p1 = torch.stack([index(v[i], vi[i, :, 1], 0) for i in range(b)]) 39 | # p2 = torch.stack([index(v[i], vi[i, :, 2], 0) for i in range(b)]) 40 | p0 = torch.stack([v[i].index_select(0, vi[i, :, 0]) for i in range(B)]) 41 | p1 = torch.stack([v[i].index_select(0, vi[i, :, 1]) for i in range(B)]) 42 | p2 = torch.stack([v[i].index_select(0, vi[i, :, 2]) for i in range(B)]) 43 | v0 = p1 - p0 44 | v1 = p2 - p0 45 | n = torch.cross(v0, v1, dim=-1) 46 | return F.normalize(n, dim=-1) 47 | 48 | 49 | def compute_vertex_normals(v, vi, fn): 50 | ''' 51 | @Args: 52 | v: (B, V, 3), vertex coordinates 53 | vi: (B, F, 3), vertex indices 54 | fn: (B, F, 3), face normals 55 | @Returns: 56 | vn: (B, V, 3), vertex normals 57 | ''' 58 | fn_exp = fn[:, :, None, :].expand(-1, -1, 3, -1).reshape(fn.shape[0], -1, 3) # repeat 3 times for 3 vertices of a face 59 | vi_flat = vi.view(vi.shape[0], -1).expand(v.shape[0], -1) 60 | vn = torch.zeros_like(v) 61 | 62 | for j in range(3): 63 | vn[..., j].scatter_add_(1, vi_flat, fn_exp[..., j]) 64 | norm = torch.norm(vn, dim=-1, keepdim=True) 65 | vn = vn / norm.clamp(min=1e-8) 66 | return vn 67 | 68 | 69 | def subdivide_mesh(mesh): 70 | ''' 71 | Args: 72 | mesh: open3d.geometry.TriangleMesh 73 | ''' 74 | mesh_sub = mesh.subdivide_midpoint(number_of_iterations=2) 75 | return mesh_sub 76 | 77 | 78 | def sample_mesh(mesh): 79 | ''' 80 | Args: 81 | mesh: open3d.geometry.TriangleMesh 82 | pcd: open3d.geometry.PointCloud 83 | ''' 84 | pcd = mesh.sample_points_uniformly(2000) 85 | return pcd 86 | 87 | 88 | ###################### Geometry ###################### 89 | 90 | def get_neighs_1ring(mesh): 91 | ''' 92 | Args: 93 | mesh: trimesh.Trimesh 94 | Returns: 95 | neighs_1ring/(nIds): list of list. 
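The i-th inner list holds the indices of the vertices adjacent to vertex i (its 1-ring), derived from mesh.edges_unique.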
96 | ''' 97 | g = nx.from_edgelist(mesh.edges_unique) 98 | neighs_1ring = [list(g[i].keys()) for i in range(len(mesh.vertices))] 99 | return neighs_1ring 100 | 101 | 102 | def vertex_trans_fitting(fixed_poss, opt_poss, nIds): 103 | ''' 104 | Args: 105 | fixed_poss: (n, 3) 106 | opt_poss: (n, 3) 107 | nIds: neighs_1ring, list of list 108 | ''' 109 | assert(fixed_poss.shape == opt_poss.shape) 110 | numV = fixed_poss.shape[0] 111 | 112 | vertex_trans = [] 113 | for vId in range(numV): 114 | ids = nIds[vId] 115 | valence = len(ids) 116 | 117 | P = fixed_poss[ids, :] - fixed_poss[vId] # (v, 3) 118 | Q = opt_poss[ids, :] - opt_poss[vId] # (v, 3) 119 | A = (Q.T @ P) @ np.linalg.pinv(P.T @ P) # (3, 3) 120 | b = opt_poss[vId].reshape(-1, 1) - A @ fixed_poss[vId].reshape(-1, 1) # (3, 1) 121 | vertex_trans.append( np.concatenate((A, b), axis=-1) ) # (3, 4) 122 | 123 | return vertex_trans 124 | 125 | 126 | def embedded_deformation(mesh, mesh_sim, mesh_sim_def, k=20): 127 | ''' 128 | Args: 129 | mesh: trimesh, (n, 3) 130 | mesh_sim: trimesh, (ns, 3) 131 | mesh_sim_def: trimesh, (ns, 3) 132 | Returns: 133 | mesh_def: trimesh, (n, 3) 134 | ''' 135 | 136 | nIds = get_neighs_1ring(mesh_sim) 137 | 138 | cur_trans = vertex_trans_fitting(mesh_sim.vertices, mesh_sim_def.vertices, nIds) 139 | 140 | DIS, IDX = pcu.k_nearest_neighbors(mesh.vertices, mesh_sim.vertices, k=k) # dense to sparse, (n, k), (n, k) 141 | 142 | sigma = np.median(DIS[:, 1]) # median of all NN distance. NOTE: 1 is better than 0 143 | weights = np.exp(-(DIS * DIS / 2 / sigma / sigma)) # (n, k) 144 | weights = weights / np.sum(weights, axis=-1, keepdims=True) # (n, k) 145 | 146 | mesh_def_verts = np.zeros_like(mesh.vertices) # (n, 3) 147 | for iv in range(mesh.vertices.shape[0]): 148 | tPos = np.zeros((3, 1)) 149 | sPos = mesh.vertices[iv].reshape(3, 1) # (3, 1) 150 | for ik in range(k): 151 | transId = IDX[iv, ik] 152 | w = weights[iv, ik] 153 | A = cur_trans[transId][:, 0:3] # (3, 3) 154 | b = cur_trans[transId][:, 3:4] # (3, 1) 155 | tPos = tPos + w * (A @ sPos + b) # (3, 1) 156 | mesh_def_verts[iv] = tPos.reshape(-1) 157 | 158 | mesh_def = trimesh.Trimesh(vertices=mesh_def_verts, faces=mesh.faces, process=False) 159 | return mesh_def 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | # import io_utils 165 | # v, vt, vi, vti = io_utils.read_obj_uv('../smpl_uv.obj') 166 | # fn = compute_face_normals(torch.FloatTensor(v[None, ...]), torch.LongTensor(vi[None, ...])) 167 | 168 | from IPython import embed; embed() 169 | 170 | -------------------------------------------------------------------------------- /registration_dfaust/gen_edges.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | import pickle 5 | import trimesh 6 | import argparse 7 | import numpy as np 8 | import scipy.io as sio 9 | import networkx as nx 10 | 11 | import open3d as o3d 12 | import matplotlib.pyplot as plt 13 | from sklearn.neighbors import NearestNeighbors 14 | 15 | 16 | def get_template_idx(pkl, template_fid): 17 | for k, v in pkl.items(): 18 | if v['fid'] == template_fid: 19 | template_idx = k 20 | return template_idx 21 | 22 | 23 | def vis_edges_from_adj(A, pkl, mesh_root): 24 | import vis_utils 25 | # vis who can reach current shape 26 | num = A.shape[0] 27 | nbatch = 20 28 | for i in range((num//nbatch)): 29 | A_chunk = A[:, i * nbatch : (i+1) * nbatch].T # (nbatch, num) 30 | mesh_list = [] 31 | for eid, edges in enumerate(A_chunk): 32 | src_idx = i * nbatch + eid 
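# `edges` is the adjacency-matrix column for src_idx (taken as a row via the transpose above);
# its nonzero entries are the shapes with an edge into src_idx, i.e. the shapes that reach it in one step.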
33 | interp_ids = np.where(edges)[1].tolist() 34 | interp_ids = [src_idx] + interp_ids 35 | print(i, eid, interp_ids) 36 | for ii, idx in enumerate(interp_ids): 37 | fid = pkl[idx]['fid'] 38 | fname = '/'.join(fid.split('-')) + '.obj' 39 | # print(fname) 40 | mesh = trimesh.load(f"{mesh_root}/{fname}", process=False) 41 | mesh.vertices = mesh.vertices + np.array([0, 0, 2]) * ii + np.array([2, 0, 0]) * eid 42 | mesh_list.append(vis_utils.create_triangle_mesh(mesh.vertices, mesh.faces)) 43 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame() 44 | o3d.visualization.draw_geometries(mesh_list + [coord]) 45 | from IPython import embed; embed() 46 | 47 | 48 | def vis_edges_from_ids(vis_ids, edges_list, pkl, mesh_root): 49 | import vis_utils 50 | mesh_list = [] 51 | for ii, idx in enumerate(vis_ids): 52 | fid = pkl[idx]['fid'] 53 | fname = '/'.join(fid.split('-')) + '.obj' 54 | mesh = trimesh.load(f"{mesh_root}/{fname}", process=False) 55 | mesh.vertices = mesh.vertices + np.array([2, 0, 0]) * ii 56 | mesh_list.append(vis_utils.create_triangle_mesh(mesh.vertices, mesh.faces)) 57 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame() 58 | o3d.visualization.draw_geometries(mesh_list + [coord]) 59 | 60 | 61 | if __name__ == '__main__': 62 | ''' 63 | Example commands: python gen_edges.py --epoch 5499 --split test 64 | ''' 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--epoch", type=int, required=True, default=None, help='e.g. 2999 or 5499') 67 | parser.add_argument("--split", type=str, required=True, default='test', help='{train, test}, use train or test dataset') 68 | parser.add_argument("--data_root", type=str,\ 69 | default='./data/dfaust1k/arapTest1k/ivae_dfaustTest1k_8B8_lr1k_arap_8B8_SE3k_inv_SLS_w1e-3/results/test/analysis_sdf/',\ 70 | help='dir to e.g. 
latents_all_test_5499.npy and latents_all_test_5499.pkl') 71 | args = parser.parse_args() 72 | 73 | assert(args.split == 'test') 74 | 75 | template_fid = '50009-running_on_spot-running_on_spot.000366' 76 | 77 | pkl_path = f'{args.data_root}/latents_all_{args.split}_{args.epoch}.pkl' 78 | npy_path = f'{args.data_root}/latents_all_{args.split}_{args.epoch}.npy' 79 | pkl = pickle.load(open(pkl_path, 'rb')) 80 | 81 | template_idx = get_template_idx(pkl, template_fid) 82 | 83 | latents_all = np.array([v['latent'] for k, v in pkl.items()]) 84 | latents_all_tmp = np.load(npy_path) 85 | assert(np.all(latents_all_tmp == latents_all)) 86 | 87 | # fit neigh 88 | num_neighbors = 25 # 5: 6193 edges, 10: 12617 edges 89 | neigh = NearestNeighbors(n_neighbors=num_neighbors) 90 | neigh.fit(latents_all) 91 | 92 | _, neigh_ids = neigh.kneighbors(latents_all) 93 | assert(np.all(neigh_ids[:, 0] == np.arange(latents_all.shape[0]))) 94 | src_ids = np.tile(neigh_ids[:, 0:1], (1, num_neighbors - 1)) # (N, num_neighbors - 1) 95 | tgt_ids = neigh_ids[:, 1:] # (N, num_neighbors - 1) 96 | 97 | edge_ids = np.stack((src_ids, tgt_ids), axis=-1).reshape(-1, 2) 98 | edge_ids = np.concatenate((edge_ids, edge_ids[:, [1, 0]]), axis=0) # (E, 2) 99 | 100 | # template to all 101 | template_ids = np.array([template_idx] * latents_all.shape[0]) 102 | edge_template_ids = np.concatenate((template_ids[:, None], neigh_ids[:, 0:1]), axis=-1) # (N, 2) 103 | 104 | # template KNN 105 | # num_neighbors_temp = 500 106 | # _, neigh_ids_temp = neigh.kneighbors(latents_all[template_idx, :][None, :], n_neighbors=num_neighbors_temp) 107 | # src_temp_ids = np.array([template_idx] * (num_neighbors_temp - 1)) 108 | # tgt_temp_ids = neigh_ids_temp[0, 1:] 109 | # edge_template_ids = np.stack((src_temp_ids, tgt_temp_ids), axis=-1).reshape(-1, 2) 110 | 111 | edge_ids = np.concatenate((edge_ids, edge_template_ids), axis=0) # (E+N, 2) 112 | 113 | G = nx.DiGraph() 114 | G.add_nodes_from(neigh_ids[:, 0]) # use indices as labels 115 | G.add_edges_from(edge_ids) 116 | G.remove_edges_from(nx.selfloop_edges(G)) # the only self loop is template to template 117 | 118 | edge_ids_new = np.array(G.edges) 119 | print(f"\n num of edges: {edge_ids_new.shape[0]} \n") 120 | dump_root = f"{args.data_root}/edge_ids/" 121 | if not os.path.exists(dump_root): 122 | os.makedirs(dump_root) 123 | np.save(f"{dump_root}/{args.split}_{args.epoch}_edge_ids_K{num_neighbors}.npy", edge_ids_new) 124 | # np.save(f"{dump_root}/{args.split}_{args.epoch}_edge_ids_K{num_neighbors}_tempKNN.npy", edge_ids_new) 125 | 126 | ############################## visialuzation ############################## 127 | # from IPython import embed; embed() 128 | # mesh_root = '/media/yanghaitao/HaitaoYang/Graphicsai_Backup/mnt/yanghaitao/Dataset/DFAUST/registrations/' 129 | # for each mesh, visualize: from which mesh we can reach the current mesh with 1 step 130 | # A = nx.adjacency_matrix(G).todense() 131 | # vis_edges_from_adj(A, pkl, mesh_root) 132 | 133 | # draw G 134 | # nx.draw(G, with_labels=True, font_weight='bold') 135 | # plt.show() 136 | 137 | # draw specific shapes 138 | # vis_ids = [249, 805, 808] 139 | # vis_edges_from_ids(vis_ids, edge_ids, pkl, mesh_root) 140 | 141 | # from IPython import embed; embed() 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /models/conv/cheb_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from 
torch_geometric.utils import remove_self_loops, add_self_loops 4 | from torch_geometric.utils import get_laplacian 5 | from .message_passing import MessagePassing 6 | 7 | from ..inits import glorot, zeros 8 | 9 | 10 | class ChebConv(MessagePassing): 11 | """ 12 | Args: 13 | in_channels (int): Size of each input sample. 14 | out_channels (int): Size of each output sample. 15 | K (int): Chebyshev filter size, *i.e.* number of hops :math:`K`. 16 | normalization (str, optional): The normalization scheme for the graph 17 | Laplacian (default: :obj:`"sym"`): 18 | 19 | 1. :obj:`None`: No normalization 20 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 21 | 22 | 2. :obj:`"sym"`: Symmetric normalization 23 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} 24 | \mathbf{D}^{-1/2}` 25 | 26 | 3. :obj:`"rw"`: Random-walk normalization 27 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` 28 | 29 | You need to pass :obj:`lambda_max` to the :meth:`forward` method of 30 | this operator in case the normalization is non-symmetric. 31 | :obj:`\lambda_max` should be a :class:`torch.Tensor` of size 32 | :obj:`[num_graphs]` in a mini-batch scenario and a scalar when 33 | operating on single graphs. 34 | You can pre-compute :obj:`lambda_max` via the 35 | :class:`torch_geometric.transforms.LaplacianLambdaMax` transform. 36 | cached (bool, optional): If set to :obj:`True`, the layer will cache 37 | the computation of the scaled and normalized Laplacian 38 | :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}` on first execution, 39 | and will use the cached version for further executions. 40 | This parameter should only be set to :obj:`True` in 41 | fixed graph scenarios. (default: :obj:`True`) 42 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 43 | an additive bias. (default: :obj:`True`) 44 | **kwargs (optional): Additional arguments of 45 | :class:`torch_geometric.nn.conv.MessagePassing`. 
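Example (hypothetical tensors; shapes follow the batched :meth:`forward` below):

    conv = ChebConv(in_channels=3, out_channels=32, K=6)
    out = conv(x, edge_index)  # x: (B, N, 3), edge_index: (2, E) -> out: (B, N, 32)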
46 | """ 47 | 48 | def __init__(self, 49 | in_channels, 50 | out_channels, 51 | K, 52 | normalization='sym', 53 | cached=True, 54 | bias=True, 55 | **kwargs): 56 | super(ChebConv, self).__init__(aggr='add', **kwargs) 57 | 58 | assert K > 0 59 | assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' 60 | 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.normalization = normalization 64 | self.cached = cached 65 | self.weight = Parameter(torch.Tensor(K, in_channels, out_channels)) 66 | 67 | if bias: 68 | self.bias = Parameter(torch.Tensor(out_channels)) 69 | else: 70 | self.register_parameter('bias', None) 71 | 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | glorot(self.weight) 76 | zeros(self.bias) 77 | self.cached_result = None 78 | self.cached_num_edges = None 79 | 80 | @staticmethod 81 | def norm(edge_index, 82 | num_nodes, 83 | edge_weight, 84 | normalization, 85 | lambda_max, 86 | dtype=None, 87 | batch=None): 88 | 89 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 90 | 91 | edge_index, edge_weight = get_laplacian(edge_index, edge_weight, 92 | normalization, dtype, 93 | num_nodes) 94 | 95 | if batch is not None and torch.is_tensor(lambda_max): 96 | lambda_max = lambda_max[batch[edge_index[0]]] 97 | 98 | edge_weight = (2.0 * edge_weight) / lambda_max 99 | edge_weight[edge_weight == float('inf')] = 0 100 | 101 | edge_index, edge_weight = add_self_loops(edge_index, 102 | edge_weight, 103 | fill_value=-1, 104 | num_nodes=num_nodes) 105 | 106 | return edge_index, edge_weight 107 | 108 | def forward(self, 109 | x, 110 | edge_index, 111 | edge_weight=None, 112 | batch=None, 113 | lambda_max=None): 114 | """""" 115 | if self.normalization != 'sym' and lambda_max is None: 116 | raise ValueError('You need to pass `lambda_max` to `forward() in`' 117 | 'case the normalization is non-symmetric.') 118 | lambda_max = 2.0 if lambda_max is None else lambda_max 119 | 120 | if not self.cached or self.cached_result is None: 121 | edge_index, norm = self.norm(edge_index, 122 | x.size(1), 123 | edge_weight, 124 | self.normalization, 125 | lambda_max, 126 | dtype=x.dtype, 127 | batch=batch) 128 | self.cached_result = edge_index, norm 129 | 130 | edge_index, norm = self.cached_result 131 | Tx_0 = x 132 | out = torch.matmul(Tx_0, self.weight[0]) 133 | 134 | if self.weight.size(0) > 1: 135 | Tx_1 = self.propagate(edge_index, x=x, norm=norm) 136 | out = out + torch.matmul(Tx_1, self.weight[1]) 137 | 138 | for k in range(2, self.weight.size(0)): 139 | Tx_2 = 2 * self.propagate(edge_index, x=Tx_1, norm=norm) - Tx_0 140 | out = out + torch.matmul(Tx_2, self.weight[k]) 141 | Tx_0, Tx_1 = Tx_1, Tx_2 142 | 143 | if self.bias is not None: 144 | out = out + self.bias 145 | 146 | return out 147 | 148 | def message(self, x_j, norm): 149 | return norm.view(-1, 1) * x_j 150 | 151 | def __repr__(self): 152 | return '{}({}, {}, K={}, normalization={})'.format( 153 | self.__class__.__name__, self.in_channels, self.out_channels, 154 | self.weight.size(0), self.normalization) 155 | -------------------------------------------------------------------------------- /models/meshnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from loguru import logger 7 | 8 | from models.arap import ARAP 9 | from models.asap import ASAP 10 | from models.meshnet_base import MeshDecoder 11 | 12 | from utils 
import geom_utils 13 | from torch_cluster import knn 14 | import pytorch3d.loss 15 | 16 | class MeshNet(nn.Module): 17 | def __init__(self, 18 | config, 19 | dataset, 20 | edge_index, 21 | down_transform, 22 | up_transform, 23 | ): 24 | super().__init__() 25 | self.config = config 26 | self.dataset = dataset 27 | self.decoder = MeshDecoder(edge_index=edge_index, 28 | down_transform=down_transform, 29 | up_transform=up_transform, 30 | **config.model.mesh) 31 | 32 | self.use_mesh_arap_with_asap = config.loss.get('use_mesh_arap_with_asap', False) 33 | if self.use_mesh_arap_with_asap: 34 | self.arap = ASAP(dataset.template_faces, dataset.template_faces.max() + 1) 35 | logger.info("\nuse_mesh_arap_with_asap, weight_asap=0.05\n") 36 | else: 37 | self.arap = ARAP(dataset.template_faces, dataset.template_faces.max() + 1) 38 | 39 | self.data_mean_gpu = torch.from_numpy(dataset.mean_init).float() 40 | self.data_std_gpu = torch.from_numpy(dataset.std_init).float() 41 | 42 | def forward(self, lat_vecs, batch_dict, config, state_info=None): 43 | ''' 44 | Args: 45 | lat_vecs: (B, latent_dim) 46 | ''' 47 | mesh_out_pred = self.decoder(lat_vecs) # (B, N, 3), normalized coordinates 48 | batch_dict["mesh_verts_nml_pred"] = mesh_out_pred 49 | 50 | if state_info is not None: 51 | self.get_loss(lat_vecs, batch_dict, config, state_info) 52 | 53 | return batch_dict 54 | 55 | @staticmethod 56 | def get_point2plane_loss(pred_shape, gt_shape, gt_faces): 57 | ''' 58 | Args: 59 | gt_shape: (B, Vx, 3) 60 | gt_faces: (B, F, 3) 61 | pred_shape: (B, Vy, 3) 62 | ''' 63 | assert(gt_shape.shape[0] == pred_shape.shape[0]) 64 | gt_fnormals = geom_utils.compute_face_normals(gt_shape, gt_faces) 65 | gt_vnormals = geom_utils.compute_vertex_normals(gt_shape, gt_faces, gt_fnormals) 66 | batch_size, num_x, num_y = gt_shape.shape[0], gt_shape.shape[1], pred_shape.shape[1] 67 | x = gt_shape.reshape(-1, 3) # (B*Vx, 3) 68 | y = pred_shape.reshape(-1, 3) # (B*Vy, 3) 69 | batch_x = torch.arange(batch_size)[:, None].expand(-1, num_x).reshape(-1).cuda() # (B*Vx) 70 | batch_y = torch.arange(batch_size)[:, None].expand(-1, num_y).reshape(-1).cuda() # (B*Vy) 71 | corres = knn(x, y, 1, batch_x, batch_y) 72 | diff = (x[corres[1]] - y[corres[0]]).reshape(batch_size, num_y, 3) # (B, Vy, 3) 73 | vnorm_corres = gt_vnormals.reshape(-1, 3)[corres[1]].reshape(batch_size, num_y, 3) # (B, Vy, 3) 74 | dist_plane = torch.abs(torch.sum(diff * vnorm_corres, dim=-1)) # (B, Vy) 75 | l1_loss = torch.mean(dist_plane) 76 | return l1_loss 77 | 78 | 79 | @staticmethod 80 | def get_jacobian_rand(cur_shape, z, data_mean_gpu, data_std_gpu, model, device, epsilon=[1e-3], nz_max=60): 81 | nb, nz = z.size() 82 | _, n_vert, nc = cur_shape.size() 83 | if nz >= nz_max: 84 | rand_idx = np.random.permutation(nz)[:nz_max] 85 | nz = nz_max 86 | else: 87 | rand_idx = np.arange(nz) 88 | 89 | jacobian = torch.zeros((nb, n_vert*nc, nz)).to(device) 90 | for i, idx in enumerate(rand_idx): 91 | dz = torch.zeros(z.size()).to(device) 92 | dz[:, idx] = epsilon 93 | z_new = z + dz 94 | out_new = model(z_new) 95 | shape_new = out_new * data_std_gpu + data_mean_gpu 96 | dout = (shape_new - cur_shape).view(nb, -1) 97 | jacobian[:, :, i] = dout/epsilon 98 | return jacobian 99 | 100 | 101 | def get_loss(self, lat_vecs, batch_dict, config, state_info): 102 | epoch = state_info['epoch'] 103 | device = batch_dict['verts_init_nml'].device 104 | loss = torch.zeros(1, device=device) 105 | self.data_std_gpu = self.data_std_gpu.to(device) 106 | self.data_mean_gpu = self.data_mean_gpu.to(device) 107 | 
assert(config.rep in ['mesh']) 108 | 109 | verts_pred = batch_dict['mesh_verts_nml_pred'] * self.data_std_gpu + self.data_mean_gpu # (B, V, 3) 110 | verts_init = batch_dict['verts_init_nml'] * self.data_std_gpu + self.data_mean_gpu 111 | verts_raw = batch_dict['verts_raw'] 112 | faces_raw = batch_dict['faces_raw'] 113 | assert(verts_raw.shape[0] == faces_raw.shape[0]) 114 | verts_raw_lengths = batch_dict['verts_raw_lengths'] if 'verts_raw_lengths' in batch_dict else None 115 | 116 | # mesh init loss 117 | if config.use_point2point_loss: 118 | point2point_loss = F.l1_loss(verts_pred, verts_init, reduction='mean') * config.loss.point2point_loss_weight 119 | loss += point2point_loss 120 | batch_dict['point2point_loss'] = point2point_loss 121 | state_info['point2point_loss'] = point2point_loss.item() 122 | 123 | if config.use_point2plane_loss: 124 | raise NotImplementedError("Unbatched operation is not implemented") 125 | point2plane_loss = self.get_point2plane_loss(verts_pred, verts_raw, faces_raw) * config.loss.point2plane_loss_weight 126 | loss += point2plane_loss 127 | batch_dict['point2plane_loss'] = point2plane_loss 128 | state_info['point2plane_loss'] = point2plane_loss.item() 129 | 130 | if config.use_chamfer_loss: 131 | chamfer_loss, _ = pytorch3d.loss.chamfer_distance(x=verts_pred, y=verts_raw, x_lengths=None, y_lengths=verts_raw_lengths) * config.loss.chamfer_loss_weight 132 | loss += chamfer_loss 133 | batch_dict['chamfer_loss'] = chamfer_loss 134 | state_info['chamfer_loss'] = chamfer_loss.item() 135 | 136 | if config.use_mesh_arap: 137 | jacob = self.get_jacobian_rand( 138 | verts_pred, lat_vecs, self.data_mean_gpu, self.data_std_gpu, 139 | self.decoder, device, epsilon=0.1, nz_max=config.loss.nz_max) 140 | try: 141 | arap_energy = self.arap(verts_pred, jacob, weight_asap=config.loss.mesh_weight_asap) / jacob.shape[-1] 142 | except: 143 | from IPython import embed; embed() 144 | mesh_arap_loss = arap_energy * config.loss.mesh_arap_weight 145 | 146 | loss += mesh_arap_loss 147 | batch_dict['mesh_arap_loss'] = mesh_arap_loss 148 | state_info['mesh_arap_loss'] = mesh_arap_loss.item() 149 | 150 | batch_dict["loss"] = loss 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /registration_dfaust/gen_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import os 4 | import pickle 5 | import trimesh 6 | import argparse 7 | import numpy as np 8 | import scipy.io as sio 9 | import networkx as nx 10 | 11 | import open3d as o3d 12 | import matplotlib.pyplot as plt 13 | from sklearn.neighbors import NearestNeighbors 14 | 15 | import point_cloud_utils as pcu 16 | 17 | def get_template_idx(pkl, template_fid): 18 | for k, v in pkl.items(): 19 | if v['fid'] == template_fid: 20 | template_idx = k 21 | return template_idx 22 | 23 | 24 | def vis_edges_from_adj(A, pkl, mesh_root): 25 | import vis_utils 26 | # vis who can reach current shape 27 | num = A.shape[0] 28 | nbatch = 20 29 | for i in range((num//nbatch)): 30 | A_chunk = A[:, i * nbatch : (i+1) * nbatch].T # (nbatch, num) 31 | mesh_list = [] 32 | for eid, edges in enumerate(A_chunk): 33 | src_idx = i * nbatch + eid 34 | interp_ids = np.where(edges)[1].tolist() 35 | interp_ids = [src_idx] + interp_ids 36 | print(i, eid, interp_ids) 37 | for ii, idx in enumerate(interp_ids): 38 | fid = pkl[idx]['fid'] 39 | fname = '/'.join(fid.split('-')) + '.obj' 40 | # print(fname) 41 | mesh = 
trimesh.load(f"{mesh_root}/{fname}", process=False) 42 | mesh.vertices = mesh.vertices + np.array([0, 0, 2]) * ii + np.array([2, 0, 0]) * eid 43 | mesh_list.append(vis_utils.create_triangle_mesh(mesh.vertices, mesh.faces)) 44 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame() 45 | o3d.visualization.draw_geometries(mesh_list + [coord]) 46 | from IPython import embed; embed() 47 | 48 | 49 | def vis_edges_from_ids(vis_ids, edges_list, pkl, mesh_root): 50 | import vis_utils 51 | mesh_list = [] 52 | for ii, idx in enumerate(vis_ids): 53 | fid = pkl[idx]['fid'] 54 | fname = '/'.join(fid.split('-')) + '.obj' 55 | mesh = trimesh.load(f"{mesh_root}/{fname}", process=False) 56 | mesh.vertices = mesh.vertices + np.array([2, 0, 0]) * ii 57 | mesh_list.append(vis_utils.create_triangle_mesh(mesh.vertices, mesh.faces)) 58 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame() 59 | o3d.visualization.draw_geometries(mesh_list + [coord]) 60 | 61 | 62 | if __name__ == '__main__': 63 | ''' 64 | Example commands: python gen_graph.py --epoch 6499 --split test 65 | ''' 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--epoch", type=int, required=True, default=None, help='e.g. 2999 or 6499') 68 | parser.add_argument("--split", type=str, required=True, default='test', help='{train, test}, use train or test dataset') 69 | parser.add_argument("--analysis_dir", type=str,\ 70 | default='../work_dir/dfaust/dfaust1kBak/ivae_dfaustTest1k_8B8_lr1k_arap_8B8_SE3k_inv_SLS_w1e-3/results/test/analysis_sdf/',\ 71 | help='dir to e.g. latents_all_test_5499.npy and latents_all_test_5499.pkl') 72 | parser.add_argument("--interp_dir", type=str, default='./dfaust1k/mesh_def/', help='dir to interpolated meshes') 73 | parser.add_argument("--dump_dir", type=str, default='./dfaust1k/mesh_corres/', help='dir to mesh init and correspondence') 74 | parser.add_argument("--data_dir", type=str, default='/scratch/cluster/yanght/Dataset/Human/DFAUST/registrations/', help='dir to interpolated meshes') 75 | parser.add_argument("--edge_ids_path", type=str, required=True, default='{args.analysis_dir}/edge_ids/{args.split}_{args.epoch}_edge_ids_K5.npy', help='path to edge_ids npy') 76 | args = parser.parse_args() 77 | 78 | assert(args.split == 'test') 79 | 80 | template_fid = '50009-running_on_spot-running_on_spot.000366' 81 | 82 | pkl_path = f'{args.analysis_dir}/latents_all_{args.split}_{args.epoch}.pkl' 83 | npy_path = f'{args.analysis_dir}/latents_all_{args.split}_{args.epoch}.npy' 84 | # edge_ids_path = f'{args.analysis_dir}/edge_ids/{args.split}_{args.epoch}_edge_ids_K5.npy' 85 | pkl = pickle.load(open(pkl_path, 'rb')) 86 | edge_ids = np.load(args.edge_ids_path) 87 | 88 | template_idx = get_template_idx(pkl, template_fid) 89 | assert(template_idx == 374) 90 | num_nodes = len(pkl) 91 | assert(num_nodes == 1000) 92 | template_fname = '/'.join(template_fid.split('-')) + '.obj' 93 | mesh_template = trimesh.load(f"{args.data_dir}/{template_fname}", process=False, maintain_order=True) 94 | 95 | corres_dict = {} 96 | dists_dict = {} 97 | for eid, (sid, tid) in enumerate(edge_ids): 98 | sfid = pkl[sid]['fid'] 99 | tfid = pkl[tid]['fid'] 100 | mesh_def = trimesh.load(f"{args.interp_dir}/meshdef_{sid}_{tid}.obj", process=False, maintain_order=True) 101 | 102 | sfname = '/'.join(sfid.split('-')) + '.obj' 103 | mesh_src = trimesh.load(f"{args.data_dir}/{sfname}", process=False, maintain_order=True) 104 | tfname = '/'.join(tfid.split('-')) + '.obj' 105 | mesh_tgt = trimesh.load(f"{args.data_dir}/{tfname}", process=False, 
maintain_order=True) 106 | 107 | dists_def_to_tgt, corres_def_to_tgt = pcu.k_nearest_neighbors(mesh_def.vertices, mesh_tgt.vertices, k=1) 108 | dists_tgt_to_def, corres_tgt_to_def = pcu.k_nearest_neighbors(mesh_tgt.vertices, mesh_def.vertices, k=1) 109 | # corres = corres_tgt_to_def[corres_def_to_tgt] 110 | # diff = np.linalg.norm(mesh_template.vertices[corres] - mesh_template.vertices, axis=-1) 111 | dists_dict[(sid, tid)] = dists_tgt_to_def 112 | corres_dict[(sid, tid)] = corres_def_to_tgt 113 | 114 | print(f"eid={eid}, sid={sid}, tid={tid}, sfid={sfid}, tfid={tfid}") 115 | 116 | dists_mean_dict = {} 117 | for eid, (sid, tid) in enumerate(edge_ids): 118 | dists_mean_dict[(sid, tid)] = dists_dict[(sid, tid)].mean() 119 | dists_min = np.min([v for k, v in dists_mean_dict.items()]) 120 | 121 | weights_dict = {} 122 | # for eid, (sid, tid) in enumerate(edge_ids): 123 | # w_diff = dists_dict[(sid, tid)].mean() - dists_min 124 | # weights_dict[(sid, tid)] = w_diff * w_diff 125 | # IMPORTANT NOTE: instead of using mean, use distribution to filter out bad edges 126 | for eid, (sid, tid) in enumerate(edge_ids): 127 | w_diff = dists_dict[(sid, tid)].copy() 128 | w_diff.sort() 129 | weights_dict[(sid, tid)] = w_diff.mean() if w_diff[6810] < 0.02 else 100 130 | 131 | G = nx.DiGraph() 132 | G.add_nodes_from(np.arange(num_nodes)) # use indices as labels 133 | G.add_edges_from(edge_ids) 134 | nx.set_edge_attributes(G, values = weights_dict, name = 'weight') 135 | 136 | lengths, paths = nx.single_source_dijkstra(G, template_idx) 137 | 138 | if not os.path.exists(args.dump_dir): 139 | os.makedirs(args.dump_dir) 140 | for mesh_idx in range(num_nodes): 141 | print(mesh_idx) 142 | path = paths[mesh_idx] 143 | assert(path[0] == template_idx) 144 | 145 | corres = np.arange(mesh_template.vertices.shape[0]) 146 | if mesh_idx != template_idx: 147 | for ii in range(len(path[:-1])): 148 | sid = path[ii] 149 | tid = path[ii + 1] 150 | corres = corres_dict[(sid, tid)][corres] 151 | 152 | fid = pkl[mesh_idx]['fid'] 153 | fname = '/'.join(fid.split('-')) + '.obj' 154 | mesh = trimesh.load(f"{args.data_dir}/{fname}", process=False, maintain_order=True) 155 | verts_corres = mesh.vertices[corres] 156 | faces_corres = mesh_template.faces 157 | mesh_corres = trimesh.Trimesh(vertices=verts_corres, faces=faces_corres, process=False) 158 | 159 | mesh_corres.export(f"{args.dump_dir}/{fid}_init.obj") 160 | np.save(f"{args.dump_dir}/{fid}_corres.npy", corres) 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/non_rigid_icp2.m: -------------------------------------------------------------------------------- 1 | function [optimized_poss] = non_rigid_icp2(template, target, init_poss_vec, lambda, beta, outerIterMax, innerIterMax) 2 | % Optimize the alignment between the template model and the target model 3 | % Using the initial pose_vec 4 | cur_poss = init_poss_vec; 5 | numV = size(init_poss_vec, 2); 6 | % 7 | ids1 = template.faceVIds(1,:); 8 | ids2 = template.faceVIds(2,:); 9 | ids3 = template.faceVIds(3,:); 10 | ROWs = [ids1;ids2;ids3]; 11 | COLs = [ids2;ids3;ids1]; 12 | VALs = ones(3,1)*ones(1,length(ids1)); 13 | A = sparse(ROWs, COLs, VALs, numV,numV); 14 | A = max(A, A'); 15 | [i,j,~] = find(A); 16 | edges = [i';j']; 17 | for vId = 1 : numV 18 | nIds{vId} = find(A(vId,:)); 19 | end 20 | cur_rotations = vertex_rotation_fitting(template.vertexPoss, cur_poss,... 
21 | nIds); 22 | % 23 | target.vertexNors = compute_vertex_normal(target); 24 | for outerIter = 1 : outerIterMax % 16 25 | % Compute bidirectional correspondences 26 | corres = bidirectional_corres(cur_poss, target.vertexPoss); 27 | fprintf(' outerIter = %d\n', outerIter); 28 | for innerIter = 1 : innerIterMax 29 | % Perform Gauss-Newton optimization to solve the induced 30 | % optimization problem 31 | [cur_poss, cur_rotations, flag] = one_step_non_rigid_icp(... 32 | corres, target.vertexPoss, target.vertexNors, edges, template.vertexPoss, cur_poss,... 33 | cur_rotations, lambda, beta); 34 | if flag == 0 35 | break; 36 | end 37 | end 38 | end 39 | optimized_poss = cur_poss; 40 | 41 | % Perform one-step non-rigid ICP 42 | function [next_poss, next_rotations, flag] = one_step_non_rigid_icp(... 43 | corres, target_poss, target_nors, edges, ori_poss, cur_poss,... 44 | cur_rotations, lambda, beta) 45 | % 46 | % 47 | [H_data, g_data] = data_term(corres, target_poss, target_nors, cur_poss, beta); 48 | [H_def, g_def] = deformation_term(edges, ori_poss, cur_poss,... 49 | cur_rotations); 50 | H = H_data + H_def*lambda; 51 | g = g_data + g_def*lambda; 52 | % 53 | e_cur = energy(corres, target_poss, target_nors, edges, ori_poss, cur_poss,... 54 | cur_rotations, lambda, beta); 55 | dx = H\g; 56 | if norm(dx) < 1e-6 57 | flag = 0; 58 | next_poss = cur_poss; 59 | next_rotations = cur_rotations; 60 | return; 61 | end 62 | [next_poss, next_rotations] = update_variables(cur_poss,... 63 | cur_rotations, dx); 64 | e_next = energy(corres, target_poss, target_nors, edges, ori_poss, next_poss,... 65 | next_rotations, lambda, beta); 66 | if e_next < e_cur 67 | fprintf(' e_cur = %f, e_next = %f.\n', e_cur, e_next); 68 | flag = 1; 69 | else 70 | flag = 0; 71 | s = mean(diag(H))*1e-6; 72 | dim2 = size(H,2); 73 | for iter = 1 : 12 74 | dx = (H+s*sparse(1:dim2,1:dim2,ones(1,dim2)))\g; 75 | [next_poss, next_rotations] = update_variables(cur_poss,... 76 | cur_rotations, dx); 77 | e_next = energy(corres, target_poss, target_nors, edges, ori_poss, next_poss,... 78 | next_rotations, lambda, beta); 79 | if e_next < e_cur 80 | flag = 1; 81 | fprintf(' e_cur = %f, e_next = %f.\n', e_cur, e_next); 82 | break; 83 | end 84 | s = s*4; 85 | end 86 | end 87 | 88 | % Update the solution based on the current solution 89 | function [next_poss, next_rotations] = update_variables(cur_poss,... 90 | cur_rotations, dx) 91 | % 92 | numV = size(cur_poss, 2); 93 | next_poss = cur_poss + reshape(dx(1:(3*numV)), [3,numV]); 94 | for vId = 1 : numV 95 | rowIds = 3*numV + ((3*vId-2):(3*vId)); 96 | dc = dx(rowIds); 97 | dR = expm([0 -dc(3) dc(2); 98 | dc(3) 0 -dc(1); 99 | -dc(2) dc(1) 0]); 100 | next_rotations{vId} = dR*cur_rotations{vId}; 101 | end 102 | 103 | % The data-term 104 | function [H_data, g_data] = data_term(corres, target_poss, target_nors, cur_poss, beta) 105 | % 106 | numV = size(cur_poss, 2); 107 | dim = 3*numV; 108 | numCorres = size(corres, 2); 109 | t = double(reshape(target_poss(:, corres(2,:)) - cur_poss(:,corres(1,:)),... 110 | [3*numCorres,1])); 111 | tp = 3*kron(corres(1,:), ones(1,3))... 112 | + kron(ones(1,numCorres),[-2,-1,0]); 113 | J = sparse(1:(3*numCorres), tp, ones(1,3*numCorres), 3*numCorres, 2*dim); % 2*dim: (delta_p, c) \in R^6*numV 114 | 115 | % point-2-point term 116 | H_data = J'*J; 117 | g_data = J'*t; 118 | % point-2-plane term 119 | plane_dis = double(sum((target_poss(:, corres(2,:)) - cur_poss(:,corres(1,:))).*target_nors(:,corres(2,:)))); 120 | J2 = sparse((1:numCorres)'*ones(1,3),... 
121 | 3*kron(corres(1,:)',ones(1,3)) + kron(ones(numCorres,1), [-2,-1,0]),... 122 | target_nors(:,corres(2,:))', numCorres, 2*dim); 123 | H_data = H_data*beta + (J2'*J2)*(1-beta); 124 | g_data = g_data*beta + (J2'*plane_dis')*(1-beta); 125 | 126 | % The deformation term 127 | function [H_def, g_def] = deformation_term(edges, ori_poss, cur_poss,... 128 | cur_rotations) 129 | % 130 | numV = size(ori_poss,2); 131 | numE = size(edges, 2); 132 | rowsJ = (1:(3*numE))'*ones(1,5); 133 | colsJ = zeros(3*numE, 5); 134 | valsJ = zeros(3*numE, 5); 135 | % dxi - dxj + [] x c - (R_i^c(pi0-pj0) - (pic-pjc)) 136 | colsJ(:,1) = 3*kron(edges(1,:)', ones(3,1))+kron(ones(numE,1), [-2,-1,0]'); 137 | valsJ(:,1) = ones(3*numE,1); 138 | colsJ(:,2) = 3*kron(edges(2,:)', ones(3,1))+kron(ones(numE,1), [-2,-1,0]'); 139 | valsJ(:,2) = -ones(3*numE,1); 140 | colsJ(:,3:5) = 3*numV + 3*kron(edges(1,:)', ones(3,3))... 141 | + kron(ones(numE,1), ones(3,1)*[-2,-1,0]); 142 | vec_d = zeros(3*numE,1); 143 | for eId = 1 : numE 144 | rowIds = (3*eId-2):(3*eId); 145 | sId = edges(1, eId); 146 | tId = edges(2, eId); 147 | vec_o_trans = cur_rotations{sId}*(ori_poss(:,sId) - ori_poss(:,tId)); 148 | vec_cur = (cur_poss(:,sId) - cur_poss(:,tId)); 149 | vec_d(rowIds) = vec_o_trans - vec_cur; 150 | valsJ(rowIds,3:5) = [0 -vec_o_trans(3) vec_o_trans(2); 151 | vec_o_trans(3) 0 -vec_o_trans(1); 152 | -vec_o_trans(2) vec_o_trans(1) 0]; 153 | end 154 | J = sparse(rowsJ, colsJ, valsJ, 3*numE, 6*numV); 155 | H_def = J'*J; 156 | g_def = J'*vec_d; 157 | 158 | % Compute the energy for non-rigid registration 159 | function [e] = energy(corres, target_poss, target_nors, edges, ori_poss, cur_poss, cur_rotations, lambda, beta) 160 | % Compute the cumulative squared distance 161 | dif = cur_poss(:, corres(1,:)) - target_poss(:, corres(2,:)); 162 | dis_plane = sum(dif.*target_nors(:,corres(2,:))); 163 | e = sum(sum(dif.*dif))*beta + sum(dis_plane.*dis_plane)*(1-beta); 164 | % Compute the rotation fitting residuals 165 | for eId = 1 : size(edges,2) 166 | sId = edges(1, eId); 167 | tId = edges(2, eId); 168 | P = ori_poss(:,sId) - ori_poss(:,tId); 169 | Q = cur_poss(:,sId) - cur_poss(:,tId); 170 | dif = cur_rotations{sId}*P - Q; 171 | e = e + lambda*(dif'*dif); 172 | end 173 | 174 | 175 | % Compute bi-directional correspondences 176 | function [corres] = bidirectional_corres(opt_poss, target_poss) 177 | % 178 | IDX1 = knnsearch(target_poss', opt_poss'); 179 | IDX2 = knnsearch(opt_poss', target_poss'); 180 | corres = [1:length(IDX1),IDX2';IDX1',1:length(IDX2)]; 181 | 182 | % 183 | function [vertex_rotations] = vertex_rotation_fitting(fixed_poss, opt_poss, nIds) 184 | % Use ARAP as initialization for R instead of using Identity 185 | numV = size(fixed_poss, 2); 186 | for vId = 1 : numV 187 | ids = nIds{vId}; 188 | valence = length(ids); 189 | P = double(fixed_poss(:,ids) - fixed_poss(:,vId)*ones(1,valence)); 190 | Q = double(opt_poss(:,ids) - opt_poss(:,vId)*ones(1,valence)); 191 | S = P*Q'; 192 | [u,v,w] = svd(S); 193 | R = w*u'; 194 | if det(S) < 0 195 | R = w*diag([1,1,-1])*u'; 196 | end 197 | vertex_rotations{vId} = R; 198 | end 199 | % 200 | function [vertex_normal] = compute_vertex_normal(mesh) 201 | % 202 | p1 = mesh.vertexPoss(:,mesh.faceVIds(1,:)); 203 | p2 = mesh.vertexPoss(:,mesh.faceVIds(2,:)); 204 | p3 = mesh.vertexPoss(:,mesh.faceVIds(3,:)); 205 | e12 = p1 - p2; 206 | e13 = p1 - p3; 207 | facenors = cross(e12, e13); 208 | lens = sqrt(sum(facenors.*facenors)) + 1e-10; 209 | facenors = facenors./(ones(3,1)*lens); 210 | numV = 
size(mesh.vertexPoss, 2); 211 | vertex_normal = zeros(3, numV); 212 | for fId = 1 : size(facenors,2) 213 | ids = mesh.faceVIds(:, fId); 214 | vertex_normal(:,ids(1)) = vertex_normal(:,ids(1)) + facenors(:,fId); 215 | vertex_normal(:,ids(2)) = vertex_normal(:,ids(2)) + facenors(:,fId); 216 | vertex_normal(:,ids(3)) = vertex_normal(:,ids(3)) + facenors(:,fId); 217 | end 218 | lens = sqrt(sum(vertex_normal.*vertex_normal)) + 1e-10; 219 | vertex_normal = vertex_normal./(ones(3,1)*lens); 220 | -------------------------------------------------------------------------------- /models/meshnet_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.conv import ChebConv 5 | 6 | from torch_scatter import scatter_add 7 | 8 | 9 | def Pool(x, trans, dim=1): 10 | row, col = trans._indices() 11 | value = trans._values().unsqueeze(-1) 12 | out = torch.index_select(x, dim, col) * value 13 | out = scatter_add(out, row, dim, dim_size=trans.size(0)) 14 | return out 15 | 16 | 17 | class Enblock(nn.Module): 18 | def __init__(self, in_channels, out_channels, K, **kwargs): 19 | super(Enblock, self).__init__() 20 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs) 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | for name, param in self.conv.named_parameters(): 25 | if 'bias' in name: 26 | nn.init.constant_(param, 0) 27 | else: 28 | nn.init.xavier_uniform_(param) 29 | 30 | def forward(self, x, edge_index, down_transform): 31 | out = F.elu(self.conv(x, edge_index)) 32 | out = Pool(out, down_transform) 33 | return out 34 | 35 | 36 | class Deblock(nn.Module): 37 | def __init__(self, in_channels, out_channels, K, **kwargs): 38 | super(Deblock, self).__init__() 39 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs) 40 | self.reset_parameters() 41 | 42 | def reset_parameters(self): 43 | for name, param in self.conv.named_parameters(): 44 | if 'bias' in name: 45 | nn.init.constant_(param, 0) 46 | else: 47 | nn.init.xavier_uniform_(param) 48 | 49 | def forward(self, x, edge_index, up_transform): 50 | out = Pool(x, up_transform) 51 | out = F.elu(self.conv(out, edge_index)) 52 | return out 53 | 54 | 55 | class MeshDecoder(nn.Module): 56 | def __init__(self, in_channels, out_channels, latent_channels, 57 | edge_index, down_transform, up_transform, K, **kwargs): 58 | super().__init__() 59 | self.in_channels = in_channels 60 | self.out_channels = out_channels 61 | #self.edge_index = edge_index 62 | self.num_edge_index = len(edge_index) 63 | for i in range(self.num_edge_index): 64 | self.register_buffer(f'edge_index_{i}', edge_index[i]) 65 | setattr(self, f'edge_index_{i}', edge_index[i]) 66 | 67 | #self.down_transform = down_transform 68 | self.num_down_transform = len(down_transform) 69 | for i in range(self.num_down_transform): 70 | self.register_buffer(f'down_transform_{i}', down_transform[i]) 71 | setattr(self, f'down_transform_{i}', down_transform[i]) 72 | 73 | #self.up_transform = up_transform 74 | self.num_up_transform = len(up_transform) 75 | for i in range(self.num_up_transform): 76 | self.register_buffer(f'up_transform_{i}', up_transform[i]) 77 | setattr(self, f'up_transform_{i}', up_transform[i]) 78 | # self.num_vert used in the last and the first layer of encoder and decoder 79 | self.num_vert = down_transform[-1].size(0) 80 | 81 | # encoder 82 | #self.en_layers = nn.ModuleList() 83 | #for idx in range(len(out_channels)): 84 | # if idx == 0: 85 | 
# self.en_layers.append( 86 | # Enblock(in_channels, out_channels[idx], K, **kwargs)) 87 | # else: 88 | # self.en_layers.append( 89 | # Enblock(out_channels[idx - 1], out_channels[idx], K, 90 | # **kwargs)) 91 | #self.en_layers.append( 92 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels)) 93 | 94 | # decoder 95 | self.de_layers = nn.ModuleList() 96 | self.de_layers.append( 97 | nn.Linear(latent_channels, self.num_vert * out_channels[-1])) 98 | for idx in range(len(out_channels)): 99 | if idx == 0: 100 | self.de_layers.append( 101 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K,)) 102 | # **kwargs)) 103 | else: 104 | self.de_layers.append( 105 | Deblock(out_channels[-idx], out_channels[-idx - 1], K,)) 106 | # **kwargs)) 107 | # reconstruction 108 | self.de_layers.append( 109 | ChebConv(out_channels[0], in_channels, K)) 110 | # ChebConv(out_channels[0], in_channels, K, **kwargs)) 111 | 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | for name, param in self.named_parameters(): 116 | if 'bias' in name: 117 | nn.init.constant_(param, 0) 118 | else: 119 | nn.init.xavier_uniform_(param) 120 | 121 | def encoder(self, x): 122 | for i, layer in enumerate(self.en_layers): 123 | if i != len(self.en_layers) - 1: 124 | x = layer(x, getattr(self, f'edge_index_{i}'), 125 | getattr(self, f'down_transform_{i}')) 126 | else: 127 | x = x.view(-1, layer.weight.size(1)) 128 | x = layer(x) 129 | return x 130 | 131 | def decoder(self, x): 132 | num_layers = len(self.de_layers) 133 | num_deblocks = num_layers - 2 134 | for i, layer in enumerate(self.de_layers): 135 | if i == 0: 136 | x = layer(x) 137 | x = x.view(-1, self.num_vert, self.out_channels[-1]) 138 | elif i != num_layers - 1: 139 | x = layer(x, getattr(self, f'edge_index_{num_deblocks - i}'), 140 | getattr(self, f'up_transform_{num_deblocks - i}')) 141 | else: 142 | # last layer 143 | x = layer(x, getattr(self, 'edge_index_0')) 144 | return x 145 | 146 | def forward(self, x): 147 | # x - batched feature matrix 148 | #z = self.encoder(x) 149 | out = self.decoder(x) 150 | return out 151 | 152 | class MeshDecoder_single(nn.Module): 153 | def __init__(self, in_channels, out_channels, latent_channels, 154 | edge_index, down_transform, up_transform, K, **kwargs): 155 | super().__init__() 156 | self.in_channels = in_channels 157 | self.out_channels = out_channels 158 | self.edge_index = edge_index 159 | self.down_transform = down_transform 160 | self.up_transform = up_transform 161 | # self.num_vert used in the last and the first layer of encoder and decoder 162 | self.num_vert = self.down_transform[-1].size(0) 163 | 164 | # encoder 165 | #self.en_layers = nn.ModuleList() 166 | #for idx in range(len(out_channels)): 167 | # if idx == 0: 168 | # self.en_layers.append( 169 | # Enblock(in_channels, out_channels[idx], K, **kwargs)) 170 | # else: 171 | # self.en_layers.append( 172 | # Enblock(out_channels[idx - 1], out_channels[idx], K, 173 | # **kwargs)) 174 | #self.en_layers.append( 175 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels)) 176 | 177 | # decoder 178 | self.de_layers = nn.ModuleList() 179 | self.de_layers.append( 180 | nn.Linear(latent_channels, self.num_vert * out_channels[-1])) 181 | for idx in range(len(out_channels)): 182 | if idx == 0: 183 | self.de_layers.append( 184 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K,)) 185 | # **kwargs)) 186 | else: 187 | self.de_layers.append( 188 | Deblock(out_channels[-idx], out_channels[-idx - 1], K,)) 189 | # **kwargs)) 190 | # 
reconstruction 191 | self.de_layers.append( 192 | ChebConv(out_channels[0], in_channels, K)) 193 | # ChebConv(out_channels[0], in_channels, K, **kwargs)) 194 | 195 | self.reset_parameters() 196 | 197 | def reset_parameters(self): 198 | for name, param in self.named_parameters(): 199 | if 'bias' in name: 200 | nn.init.constant_(param, 0) 201 | else: 202 | nn.init.xavier_uniform_(param) 203 | 204 | def encoder(self, x): 205 | for i, layer in enumerate(self.en_layers): 206 | if i != len(self.en_layers) - 1: 207 | x = layer(x, self.edge_index[i], self.down_transform[i]) 208 | else: 209 | x = x.view(-1, layer.weight.size(1)) 210 | x = layer(x) 211 | return x 212 | 213 | def decoder(self, x): 214 | num_layers = len(self.de_layers) 215 | num_deblocks = num_layers - 2 216 | for i, layer in enumerate(self.de_layers): 217 | if i == 0: 218 | x = layer(x) 219 | x = x.view(-1, self.num_vert, self.out_channels[-1]) 220 | elif i != num_layers - 1: 221 | x = layer(x, self.edge_index[num_deblocks - i], 222 | self.up_transform[num_deblocks - i]) 223 | else: 224 | # last layer 225 | x = layer(x, self.edge_index[0]) 226 | return x 227 | 228 | def forward(self, x): 229 | # x - batched feature matrix 230 | #z = self.encoder(x) 231 | out = self.decoder(x) 232 | return out 233 | 234 | -------------------------------------------------------------------------------- /utils/implicit_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import numpy as np 5 | from skimage import measure 6 | from tqdm import tqdm 7 | from .diff_operators import gradient 8 | 9 | 10 | def sdf_decode_mesh_from_single_lat(model, latent_vec, resolution=256, voxel_size=None, max_batch=int(2 ** 18), offset=None, scale=None, points_for_bound=None, verbose=False, x_range=[-1, 1], y_range=[-0.7, 1.7], z_range=[-1.1, 0.9]): 11 | ''' 12 | Args: 13 | model: only model.decoder is used 14 | resolution: the resolution of the shortest_axis 15 | latent_vec: (d, ) 16 | ''' 17 | if resolution is not None: 18 | assert(voxel_size is None) 19 | if points_for_bound is not None: 20 | grid = get_grid_YXZ(points_for_bound, resolution) 21 | else: 22 | grid = get_grid_uniform_YXZ(resolution, x_range=x_range, y_range=y_range, z_range=z_range) 23 | else: 24 | assert(voxel_size is not None) 25 | grid = get_grid_from_size_YXZ(points_for_bound, voxel_size) 26 | 27 | sdf_volume_yxz = [] 28 | ptn_samples_list = torch.split(grid['grid_points'], max_batch, dim=0) 29 | if verbose: 30 | ptn_samples_list = tqdm(ptn_samples_list) 31 | for ptn_samples in ptn_samples_list: 32 | ptn_samples.requires_grad = False 33 | sdf_samples = model.decoder(ptn_samples[None, :, :].to(latent_vec.device), latent_vec[None, :]) 34 | sdf_samples = sdf_samples[0, :, 0].detach().cpu().numpy() 35 | sdf_volume_yxz.append(sdf_samples) 36 | 37 | sdf_volume = np.concatenate(sdf_volume_yxz, axis=0).reshape(grid['ysize'], grid['xsize'], grid['zsize']).transpose([1, 0, 2]) # XYZ 38 | 39 | assert(np.min(sdf_volume) < 0 and np.max(sdf_volume) > 0) 40 | 41 | verts, faces = convert_sdf_volume_to_verts_faces(sdf_volume, grid['voxel_grid_origin'], grid['voxel_size']) 42 | 43 | return verts, faces 44 | 45 | 46 | def convert_sdf_volume_to_verts_faces(sdf_volume, voxel_grid_origin, voxel_size, offset=None, scale=None): 47 | """ 48 | Args: 49 | sdf_volume: (X, Y, Z) 50 | voxel_grid_origin: (3,) bottom, left, down origin of the voxel grid 51 | voxel_size: float, the size of the voxels 52 | """ 53 | 54 | verts, faces, normals, 
values = measure.marching_cubes( 55 | volume=sdf_volume, level=0.0, spacing=[voxel_size] * 3 56 | ) 57 | verts = verts + voxel_grid_origin 58 | 59 | # apply additional offset and scale to the extracted vertices. based on preprocess_dfaust.py 60 | if scale is not None: 61 | verts = verts * scale 62 | if offset is not None: 63 | verts = verts + offset 64 | 65 | return verts, faces 66 | 67 | 68 | def get_grid_YXZ(points, resolution): 69 | ''' For x, y, z, the voxel sizes are the same but grid sizes are different 70 | Args: 71 | points: (S, 3) 72 | resolution: the resolution of the shortest_axis 73 | Returns: 74 | grid_points: (Ny * Nx * Nz, 3), the order is Y, X, Z instead of X, Y, Z 75 | ''' 76 | eps = 0.1 77 | input_min = torch.min(points, dim=0)[0].squeeze().detach().cpu().numpy() 78 | input_max = torch.max(points, dim=0)[0].squeeze().detach().cpu().numpy() 79 | 80 | bounding_box = input_max - input_min 81 | shortest_axis = np.argmin(bounding_box) 82 | if (shortest_axis == 0): 83 | x = np.linspace(input_min[shortest_axis] - eps, 84 | input_max[shortest_axis] + eps, resolution) 85 | length = np.max(x) - np.min(x) 86 | voxel_size = length / (x.shape[0] - 1) 87 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 88 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 89 | elif (shortest_axis == 1): 90 | y = np.linspace(input_min[shortest_axis] - eps, 91 | input_max[shortest_axis] + eps, resolution) 92 | length = np.max(y) - np.min(y) 93 | voxel_size = length / (y.shape[0] - 1) 94 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 95 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 96 | elif (shortest_axis == 2): 97 | z = np.linspace(input_min[shortest_axis] - eps, 98 | input_max[shortest_axis] + eps, resolution) 99 | length = np.max(z) - np.min(z) 100 | voxel_size = length / (z.shape[0] - 1) 101 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 102 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 103 | 104 | xx, yy, zz = np.meshgrid(x, y, z) # default: indexing='xy', return shape is (N2, N1, N3) instead of (N1, N2, N3) 105 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 106 | 107 | xsize, ysize, zsize = x.shape[0], y.shape[0], z.shape[0] 108 | voxel_grid_origin = np.array([x[0], y[0], z[0]]) 109 | 110 | return {"grid_points":grid_points, 111 | "shortest_axis_length":length, 112 | "xyz":[x,y,z], 113 | "xsize": xsize, 114 | "ysize": ysize, 115 | "zsize": zsize, 116 | "voxel_grid_origin": voxel_grid_origin, 117 | "voxel_size": voxel_size, 118 | "shortest_axis_index":shortest_axis} 119 | 120 | 121 | def get_grid_uniform_YXZ(resolution, x_range=[-2, 2], y_range=[-2, 2], z_range=[-2, 2]): 122 | ''' For x, y, z, the voxel sizes are the same but grid sizes are different 123 | Args: 124 | resolution: the resolution of the shortest_axis, i.e.
x axis 125 | Returns: 126 | grid_points: (Ny * Nx * Nz, 3), the order is Y, X, Z instead of X, Y, Z 127 | ''' 128 | range_len_list = [x_range[1] - x_range[0], y_range[1] - y_range[0], z_range[1] - z_range[0]] 129 | shortest_axis = np.argmin(range_len_list) 130 | 131 | if shortest_axis == 0: 132 | x = np.linspace(x_range[0], x_range[1], resolution) 133 | shortest_axis_length = x.max() - x.min() 134 | voxel_size = shortest_axis_length / (x.shape[0] - 1) 135 | y = np.arange(y_range[0], y_range[1] + voxel_size, voxel_size) 136 | z = np.arange(z_range[0], z_range[1] + voxel_size, voxel_size) 137 | elif shortest_axis == 1: 138 | y = np.linspace(y_range[0], y_range[1], resolution) 139 | shortest_axis_length = y.max() - y.min() 140 | voxel_size = shortest_axis_length / (y.shape[0] - 1) 141 | x = np.arange(x_range[0], x_range[1] + voxel_size, voxel_size) 142 | z = np.arange(z_range[0], z_range[1] + voxel_size, voxel_size) 143 | elif shortest_axis == 2: 144 | z = np.linspace(z_range[0], z_range[1], resolution) 145 | shortest_axis_length = z.max() - z.min() 146 | voxel_size = shortest_axis_length / (z.shape[0] - 1) 147 | x = np.arange(x_range[0], x_range[1] + voxel_size, voxel_size) 148 | y = np.arange(y_range[0], y_range[1] + voxel_size, voxel_size) 149 | else: 150 | raise NotImplementedError 151 | 152 | xx, yy, zz = np.meshgrid(x, y, z) # default: indexing='xy', return shape is (N2, N1, N3) instead of (N1, N2, N3) 153 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 154 | 155 | xsize, ysize, zsize = x.shape[0], y.shape[0], z.shape[0] 156 | voxel_grid_origin = np.array([x[0], y[0], z[0]]) 157 | 158 | return {"grid_points": grid_points, 159 | "shortest_axis_length": shortest_axis_length, 160 | "xyz": [x, y, z], 161 | "xsize": xsize, 162 | "ysize": ysize, 163 | "zsize": zsize, 164 | "voxel_grid_origin": voxel_grid_origin, 165 | "voxel_size": voxel_size, 166 | "shortest_axis_index": shortest_axis} 167 | 168 | 169 | def get_grid_from_size_YXZ(points, voxel_size): 170 | ''' For x, y, z, the voxel sizes are the same but grid sizes are different 171 | Args: 172 | points: (S, 3) 173 | voxel_size: float 174 | Returns: 175 | grid_points: (Ny * Nx * Nz, 3), the order is Y, X, Z instead of X, Y, Z 176 | ''' 177 | eps = 0.1 178 | input_min = torch.min(points, dim=0)[0].squeeze().detach().cpu().numpy() 179 | input_max = torch.max(points, dim=0)[0].squeeze().detach().cpu().numpy() 180 | 181 | bounding_box = input_max - input_min 182 | shortest_axis = np.argmin(bounding_box) 183 | 184 | x = np.arange(input_min[0] - eps, input_max[0] + voxel_size + eps, voxel_size) 185 | y = np.arange(input_min[1] - eps, input_max[1] + voxel_size + eps, voxel_size) 186 | z = np.arange(input_min[2] - eps, input_max[2] + voxel_size + eps, voxel_size) 187 | 188 | if (shortest_axis == 0): 189 | length = np.max(x) - np.min(x) 190 | elif (shortest_axis == 1): 191 | length = np.max(y) - np.min(y) 192 | elif (shortest_axis == 2): 193 | length = np.max(z) - np.min(z) 194 | 195 | xx, yy, zz = np.meshgrid(x, y, z) # default: indexing='xy', return shape is (N2, N1, N3) instead of (N1, N2, N3) 196 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 197 | 198 | xsize, ysize, zsize = x.shape[0], y.shape[0], z.shape[0] 199 | voxel_grid_origin = np.array([x[0], y[0], z[0]]) 200 | 201 | return {"grid_points":grid_points, 202 | "shortest_axis_length":length, 203 | "xyz":[x,y,z], 204 | "xsize": xsize, 205 | "ysize": ysize, 206 | "zsize": zsize, 207 | 
"voxel_grid_origin": voxel_grid_origin, 208 | "voxel_size": voxel_size, 209 | "shortest_axis_index":shortest_axis} 210 | 211 | 212 | 213 | def langevin_dynamics(model, latent_vec, xyz, num_iters=1): 214 | ''' 215 | Args: 216 | model: only model.decoder is used 217 | latent_vec: (d, ) 218 | xyz: (N, 3) 219 | ''' 220 | device = latent_vec.device 221 | assert(len(latent_vec.shape) == 1 and len(xyz.shape) == 2 and xyz.shape[1] == 3) 222 | 223 | xyz = torch.from_numpy(xyz).float().to(device) 224 | latents = latent_vec[None, :].repeat(xyz.shape[0], 1) 225 | 226 | for _ in range(num_iters): 227 | xyz = xyz.clone().detach().requires_grad_(True) 228 | sdf = model.decoder(xyz, latents) # (N, 1) 229 | fxyz = gradient(sdf, xyz) # (N, 3) 230 | xyz = xyz - sdf * fxyz 231 | 232 | return xyz 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /datasets/smal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | import json 5 | import pickle 6 | 7 | import torch 8 | import trimesh 9 | import numpy as np 10 | from tqdm import tqdm 11 | from omegaconf import OmegaConf 12 | from loguru import logger 13 | from sklearn.decomposition import PCA 14 | 15 | 16 | def remove_nans(tensor): 17 | tensor_nan = torch.isnan(tensor[:, 3]) 18 | return tensor[~tensor_nan, :] 19 | 20 | 21 | class SMALDataSet(torch.utils.data.Dataset): 22 | 23 | def __init__(self, 24 | mode, 25 | rep, 26 | config, 27 | **kwargs): 28 | ''' 29 | Args: 30 | sdf_dir: raw sdf dir 31 | raw_mesh_dir: raw mesh dir, might not have consistent topology 32 | registration_dir: registered mesh dir, must have consistent topology 33 | num_samples: num of samples used to train sdfnet 34 | ''' 35 | super().__init__() 36 | 37 | self.rep = rep 38 | self.config = config 39 | self.mode = mode 40 | if self.mode == 'train': 41 | split = 'train' 42 | elif self.mode == 'test': 43 | split = 'test' 44 | else: 45 | raise ValueError('invalid mode') 46 | 47 | self.data_dir = config.data_dir 48 | self.sdf_dir = config.sdf_dir 49 | self.raw_mesh_dir = config.raw_mesh_dir 50 | self.registration_dir = config.registration_dir 51 | self.num_samples = config.num_samples 52 | self.template_path = config.template_path 53 | 54 | # load data split 55 | split_cfg_fname = config.split_cfg[split] 56 | current_dir = os.path.dirname(os.path.realpath(__file__)) 57 | split_path = f"{current_dir}/splits/smal/{split_cfg_fname}" 58 | with open(split_path, "r") as f: 59 | split_names = json.load(f) 60 | 61 | self.fid_list = self.get_fid_list(split_names) 62 | self.num_data = len(self.fid_list) 63 | 64 | self.raw_mesh_file_type = config.get('raw_mesh_file_type', 'ply') 65 | logger.info(f"dataset mode = {mode}, split = {split}, len = {self.num_data}\n") 66 | 67 | # load temlate mesh for meshnet. Share topology. 
NOTE: used for meshnet, different from temp(late) in sdfnet 68 | template_mesh = trimesh.load(self.template_path, process=False, maintain_order=True) 69 | self.template_points = torch.from_numpy(template_mesh.vertices) 70 | self.template_faces = np.asarray(template_mesh.faces) 71 | self.num_nodes = self.template_points.shape[0] 72 | 73 | # load sim mesh data if exists 74 | self.sim_mesh_dir = config.get('sim_mesh_dir', None) 75 | if self.sim_mesh_dir is not None: 76 | self.verts_sim_list = [] 77 | self.faces_sim_list = [] 78 | for fid in self.fid_list: 79 | fname = '/'.join(fid.split('-')) 80 | sim_mesh_pkl = pickle.load(open(f"{self.sim_mesh_dir}/{fname}_sim.pkl", 'rb')) 81 | self.verts_sim_list.append(sim_mesh_pkl['verts_sim'].astype(np.float32)) 82 | self.faces_sim_list.append(sim_mesh_pkl['faces_sim']) 83 | 84 | # load init mesh data if exists 85 | self.init_mesh_dir = config.get('init_mesh_dir', None) 86 | if self.init_mesh_dir is not None: 87 | verts_init_list = [] 88 | for fid in tqdm(self.fid_list): 89 | mesh_init = trimesh.load(f"{self.init_mesh_dir}/{fid}_init.obj", process=False, maintain_order=True) 90 | verts_init_list.append(mesh_init.vertices.astype(np.float32)) 91 | self.verts_init = np.stack(verts_init_list) # (1000, 6890, 3) 92 | print(f'Finish loading verts_init, shape = {self.verts_init.shape}') 93 | assert(self.verts_init.shape[0] == self.num_data) 94 | self.mean_init = self.verts_init.mean(axis=0) # only verts_init always has consistent correspondence 95 | self.std_init = self.verts_init.std(axis=0) 96 | # IMPORTANT TODO: if SMAL, set self.std_init = 0.2 97 | 98 | # Normalize mesh data 99 | # NOTE: the target of the prediction: verts_init is normalized 100 | self.verts_init_nml = (self.verts_init - self.mean_init) / self.std_init 101 | 102 | self.use_vert_pca = config.get('use_vert_pca', True) 103 | self.pca = PCA(n_components=config.pca_n_comp) 104 | self.pca.fit(self.verts_init_nml.reshape(self.num_data, -1)) 105 | self.pca_axes = self.pca.components_ 106 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), self.pca_axes.transpose()) 107 | self.pca_sv_mean = np.mean(pca_sv, axis=0) 108 | self.pca_sv_std = np.std(pca_sv, axis=0) 109 | print(f'Finish computing PCA') 110 | 111 | # load raw mesh 112 | if self.rep == 'mesh': 113 | self.verts_raw_list = [] 114 | self.faces_raw_list = [] 115 | for fid in self.fid_list: 116 | fname = '/'.join(fid.split('-')) 117 | mesh_raw = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 118 | self.verts_raw_list.append(mesh_raw.vertices.astype(np.float32)) 119 | self.faces_raw_list.append(mesh_raw.faces) 120 | 121 | 122 | def get_fid_list(self, split_names): 123 | fid_list = [] 124 | assert(len(split_names) == 1) 125 | for dataset in split_names: 126 | for class_name in split_names[dataset]: 127 | for instance_name in split_names[dataset][class_name]: 128 | for shape in split_names[dataset][class_name][instance_name]: 129 | fid = f"{class_name}-{instance_name}-{shape}" 130 | fid_list.append(fid) 131 | return fid_list 132 | 133 | 134 | def update_pca_sv(self, train_pca_axes, train_pca_sv_mean, train_pca_sv_std): 135 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), train_pca_axes.transpose()) 136 | self.pca_sv = (pca_sv - train_pca_sv_mean) / train_pca_sv_std 137 | 138 | 139 | def __len__(self): 140 | return self.num_data 141 | 142 | 143 | def __getitem__(self, idx): 144 | data_dict = {} 145 | data_dict['idx'] = idx 146 | fid = self.fid_list[idx] 
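# Illustrative note (the identifiers below are made up, not taken from the dataset):
# get_fid_list() above joins "<class_name>-<instance_name>-<shape>" into a single fid,
# so a hypothetical fid like "smal_class-animal_00-pose_01" is split back into the
# relative path "smal_class/animal_00/pose_01" by the line that follows and used to
# locate the per-shape sdf / raw-mesh files under sdf_dir and raw_mesh_dir.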
147 | fname = '/'.join(fid.split('-')) 148 | 149 | if self.rep in ['mesh']: 150 | # no sdf, only load mesh. TODO: verts num diff, need to use PyG dataloader 151 | data_dict['verts_init_nml'] = self.verts_init_nml[idx] 152 | data_dict['verts_raw'] = self.verts_raw_list[idx] 153 | data_dict['faces_raw'] = self.faces_raw_list[idx] 154 | 155 | elif self.rep in ['sdf']: 156 | # load sdf data 157 | 158 | point_set_mnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}.npy")).float() 159 | samples_nonmnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}_dist_triangle.npy")).float() 160 | 161 | random_idx = (torch.rand(self.num_samples) * point_set_mnfld.shape[0]).long() 162 | point_set_mnfld = torch.index_select(point_set_mnfld, 0, random_idx) 163 | normal_set_mnfld = point_set_mnfld[:, 3:] 164 | point_set_mnfld = point_set_mnfld[:, :3] # currently all center == [0, 0, 0], scale == 1 165 | 166 | random_idx = (torch.rand(self.num_samples) * samples_nonmnfld.shape[0]).long() 167 | samples_nonmnfld = torch.index_select(samples_nonmnfld, 0, random_idx) 168 | 169 | data_dict['points_mnfld'] = point_set_mnfld 170 | data_dict['normals_mnfld'] = normal_set_mnfld 171 | data_dict['samples_nonmnfld'] = samples_nonmnfld 172 | 173 | # load mesh data 174 | raw_mesh = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 175 | data_dict['raw_mesh_verts'] = np.asarray(raw_mesh.vertices).astype(np.float32) 176 | data_dict['raw_mesh_faces'] = np.asarray(raw_mesh.faces) 177 | 178 | return data_dict 179 | 180 | 181 | if __name__ == '__main__': 182 | import sys 183 | sys.path.append('../') 184 | from pyutils import * 185 | 186 | import argparse 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--rep", type=str, help='sdf or mesh') 189 | parser.add_argument("--config", type=str, required=True, help='config yaml file path, e.g. 
../config/dfaust.yaml') 190 | args = parser.parse_args() 191 | 192 | config = OmegaConf.load(args.config) 193 | OmegaConf.resolve(config) 194 | update_config_from_args(config, args) 195 | 196 | train_dataset = DFaustDataSet(mode='train', rep=config.rep, config=config.dataset) 197 | test_dataset = DFaustDataSet(mode='test', rep=config.rep, config=config.dataset) 198 | 199 | batch_size = 16 200 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 201 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 202 | raw_mesh_list = [] 203 | # for batch_idx, batch_dict in enumerate(test_loader): 204 | for batch_idx, batch_dict in enumerate(train_loader): 205 | for i in range(batch_size): 206 | if args.rep == 'sdf': 207 | print(i, batch_dict['points_mnfld'].shape) 208 | print(i, batch_dict['normals_mnfld'].shape) 209 | if args.rep == 'mesh': 210 | raise NotImplementedError 211 | 212 | import open3d as o3d 213 | import vis_utils 214 | 215 | starts_mnfld = batch_dict['points_mnfld'][i].numpy() 216 | ends_mnfld = batch_dict['points_mnfld'][i].numpy() + batch_dict['normals_mnfld'][i].numpy() * 0.1 217 | vf_mnfld = vis_utils.create_vector_field(starts_mnfld, ends_mnfld, [0, 1, 0]) 218 | pcd_mnfld = vis_utils.create_pointcloud_from_points(starts_mnfld, [1, 0, 0]) 219 | 220 | starts_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] 221 | ends_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] + batch_dict['samples_nonmnfld'][i].numpy()[:, 3:6] * 0.03 222 | vf_nonmnfld = vis_utils.create_vector_field(starts_nonmnfld, ends_nonmnfld, [0, 0, 1]) 223 | pcd_nonmnfld = vis_utils.create_pointcloud_from_points(starts_nonmnfld, [1, 0, 0]) 224 | 225 | raw_mesh = vis_utils.create_triangle_mesh(batch_dict['raw_mesh_verts'][i].numpy(), batch_dict['raw_mesh_faces'][i].numpy()) 226 | raw_mesh_list.append(raw_mesh) 227 | 228 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(0.1) 229 | o3d.visualization.draw_geometries([raw_mesh, coord, vf_mnfld, pcd_mnfld]) 230 | o3d.visualization.draw_geometries([coord, vf_nonmnfld, pcd_nonmnfld]) 231 | # from IPython import embed; embed() 232 | 233 | break 234 | from IPython import embed; embed() 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /datasets/dfaust.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | import json 5 | import pickle 6 | 7 | import torch 8 | import trimesh 9 | import numpy as np 10 | from tqdm import tqdm 11 | from omegaconf import OmegaConf 12 | from loguru import logger 13 | from sklearn.decomposition import PCA 14 | 15 | 16 | def remove_nans(tensor): 17 | tensor_nan = torch.isnan(tensor[:, 3]) 18 | return tensor[~tensor_nan, :] 19 | 20 | 21 | class DFaustDataSet(torch.utils.data.Dataset): 22 | 23 | def __init__(self, 24 | mode, 25 | rep, 26 | config, 27 | **kwargs): 28 | ''' 29 | Args: 30 | sdf_dir: raw sdf dir 31 | raw_mesh_dir: raw mesh dir, might not have consistent topology 32 | registration_dir: registered mesh dir, must have consistent topology 33 | num_samples: num of samples used to train sdfnet 34 | ''' 35 | super().__init__() 36 | 37 | self.rep = rep 38 | self.config = config 39 | self.mode = mode 40 | if self.mode == 'train': 41 | split = 'train' 42 | elif self.mode == 'test': 43 | split = 'test' 44 | else: 45 | raise ValueError('invalid mode') 46 | 47 | self.data_dir = 
config.data_dir 48 | self.sdf_dir = config.sdf_dir 49 | self.raw_mesh_dir = config.raw_mesh_dir 50 | self.registration_dir = config.registration_dir 51 | self.num_samples = config.num_samples 52 | self.template_path = config.template_path 53 | 54 | # load data split 55 | split_cfg_fname = config.split_cfg[split] 56 | current_dir = os.path.dirname(os.path.realpath(__file__)) 57 | split_path = f"{current_dir}/splits/dfaust/{split_cfg_fname}" 58 | with open(split_path, "r") as f: 59 | split_names = json.load(f) 60 | 61 | self.fid_list = self.get_fid_list(split_names) 62 | self.num_data = len(self.fid_list) 63 | 64 | self.raw_mesh_file_type = config.get('raw_mesh_file_type', 'ply') 65 | logger.info(f"dataset mode = {mode}, split = {split}, len = {self.num_data}\n") 66 | 67 | # load temlate mesh for meshnet. Share topology. NOTE: used for meshnet, different from temp(late) in sdfnet 68 | template_mesh = trimesh.load(self.template_path, process=False, maintain_order=True) 69 | self.template_points = torch.from_numpy(template_mesh.vertices) 70 | self.template_faces = np.asarray(template_mesh.faces) 71 | self.num_nodes = self.template_points.shape[0] 72 | 73 | # load sim mesh data if exists 74 | self.sim_mesh_dir = config.get('sim_mesh_dir', None) 75 | if self.sim_mesh_dir is not None: 76 | self.verts_sim_list = [] 77 | self.faces_sim_list = [] 78 | for fid in self.fid_list: 79 | fname = '/'.join(fid.split('-')) 80 | sim_mesh_pkl = pickle.load(open(f"{self.sim_mesh_dir}/{fname}_sim.pkl", 'rb')) 81 | self.verts_sim_list.append(sim_mesh_pkl['verts_sim'].astype(np.float32)) 82 | self.faces_sim_list.append(sim_mesh_pkl['faces_sim']) 83 | 84 | # load init mesh data if exists 85 | self.init_mesh_dir = config.get('init_mesh_dir', None) 86 | if self.init_mesh_dir is not None: 87 | verts_init_list = [] 88 | for fid in tqdm(self.fid_list): 89 | mesh_init = trimesh.load(f"{self.init_mesh_dir}/{fid}_init.obj", process=False, maintain_order=True) 90 | verts_init_list.append(mesh_init.vertices.astype(np.float32)) 91 | self.verts_init = np.stack(verts_init_list) # (1000, 6890, 3) 92 | print(f'Finish loading verts_init, shape = {self.verts_init.shape}') 93 | assert(self.verts_init.shape[0] == self.num_data) 94 | self.mean_init = self.verts_init.mean(axis=0) # only verts_init always has consistent correspondence 95 | self.std_init = self.verts_init.std(axis=0) 96 | # IMPORTANT TODO: if SMAL, set self.std_init = 0.2 97 | 98 | # Normalize mesh data 99 | # NOTE: the target of the prediction: verts_init is normalized 100 | self.verts_init_nml = (self.verts_init - self.mean_init) / self.std_init 101 | 102 | self.use_vert_pca = config.get('use_vert_pca', True) 103 | self.pca = PCA(n_components=config.pca_n_comp) 104 | self.pca.fit(self.verts_init_nml.reshape(self.num_data, -1)) 105 | self.pca_axes = self.pca.components_ 106 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), self.pca_axes.transpose()) 107 | self.pca_sv_mean = np.mean(pca_sv, axis=0) 108 | self.pca_sv_std = np.std(pca_sv, axis=0) 109 | print(f'Finish computing PCA') 110 | 111 | # load raw mesh 112 | if self.rep == 'mesh': 113 | self.verts_raw_list = [] 114 | self.faces_raw_list = [] 115 | for fid in self.fid_list: 116 | fname = '/'.join(fid.split('-')) 117 | mesh_raw = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 118 | self.verts_raw_list.append(mesh_raw.vertices.astype(np.float32)) 119 | self.faces_raw_list.append(mesh_raw.faces) 120 | 121 | 122 | def get_fid_list(self, 
split_names): 123 | fid_list = [] 124 | assert(len(split_names) == 1) 125 | for dataset in split_names: 126 | for class_name in split_names[dataset]: 127 | for instance_name in split_names[dataset][class_name]: 128 | for shape in split_names[dataset][class_name][instance_name]: 129 | fid = f"{class_name}-{instance_name}-{shape}" 130 | fid_list.append(fid) 131 | return fid_list 132 | 133 | 134 | def update_pca_sv(self, train_pca_axes, train_pca_sv_mean, train_pca_sv_std): 135 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), train_pca_axes.transpose()) 136 | self.pca_sv = (pca_sv - train_pca_sv_mean) / train_pca_sv_std 137 | 138 | 139 | def __len__(self): 140 | return self.num_data 141 | 142 | 143 | def __getitem__(self, idx): 144 | data_dict = {} 145 | data_dict['idx'] = idx 146 | fid = self.fid_list[idx] 147 | fname = '/'.join(fid.split('-')) 148 | 149 | if self.rep in ['mesh']: 150 | # no sdf, only load mesh. TODO: verts num diff, need to use PyG dataloader 151 | data_dict['verts_init_nml'] = self.verts_init_nml[idx] 152 | data_dict['verts_raw'] = self.verts_raw_list[idx] 153 | data_dict['faces_raw'] = self.faces_raw_list[idx] 154 | 155 | elif self.rep in ['sdf']: 156 | # load sdf data 157 | 158 | point_set_mnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}.npy")).float() 159 | samples_nonmnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}_dist_triangle.npy")).float() 160 | 161 | random_idx = (torch.rand(self.num_samples) * point_set_mnfld.shape[0]).long() 162 | point_set_mnfld = torch.index_select(point_set_mnfld, 0, random_idx) 163 | normal_set_mnfld = point_set_mnfld[:, 3:] 164 | point_set_mnfld = point_set_mnfld[:, :3] # currently all center == [0, 0, 0], scale == 1 165 | 166 | random_idx = (torch.rand(self.num_samples) * samples_nonmnfld.shape[0]).long() 167 | samples_nonmnfld = torch.index_select(samples_nonmnfld, 0, random_idx) 168 | 169 | data_dict['points_mnfld'] = point_set_mnfld 170 | data_dict['normals_mnfld'] = normal_set_mnfld 171 | data_dict['samples_nonmnfld'] = samples_nonmnfld 172 | 173 | # load mesh data 174 | raw_mesh = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 175 | data_dict['raw_mesh_verts'] = np.asarray(raw_mesh.vertices).astype(np.float32) 176 | data_dict['raw_mesh_faces'] = np.asarray(raw_mesh.faces) 177 | 178 | return data_dict 179 | 180 | 181 | if __name__ == '__main__': 182 | import sys 183 | sys.path.append('../') 184 | from pyutils import * 185 | 186 | import argparse 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--rep", type=str, help='sdf or mesh') 189 | parser.add_argument("--config", type=str, required=True, help='config yaml file path, e.g. 
../config/dfaust.yaml') 190 | args = parser.parse_args() 191 | 192 | config = OmegaConf.load(args.config) 193 | OmegaConf.resolve(config) 194 | update_config_from_args(config, args) 195 | 196 | train_dataset = DFaustDataSet(mode='train', rep=config.rep, config=config.dataset) 197 | test_dataset = DFaustDataSet(mode='test', rep=config.rep, config=config.dataset) 198 | 199 | batch_size = 16 200 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 201 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 202 | raw_mesh_list = [] 203 | # for batch_idx, batch_dict in enumerate(test_loader): 204 | for batch_idx, batch_dict in enumerate(train_loader): 205 | for i in range(batch_size): 206 | if args.rep == 'sdf': 207 | print(i, batch_dict['points_mnfld'].shape) 208 | print(i, batch_dict['normals_mnfld'].shape) 209 | if args.rep == 'mesh': 210 | raise NotImplementedError 211 | 212 | import open3d as o3d 213 | import vis_utils 214 | 215 | starts_mnfld = batch_dict['points_mnfld'][i].numpy() 216 | ends_mnfld = batch_dict['points_mnfld'][i].numpy() + batch_dict['normals_mnfld'][i].numpy() * 0.1 217 | vf_mnfld = vis_utils.create_vector_field(starts_mnfld, ends_mnfld, [0, 1, 0]) 218 | pcd_mnfld = vis_utils.create_pointcloud_from_points(starts_mnfld, [1, 0, 0]) 219 | 220 | starts_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] 221 | ends_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] + batch_dict['samples_nonmnfld'][i].numpy()[:, 3:6] * 0.03 222 | vf_nonmnfld = vis_utils.create_vector_field(starts_nonmnfld, ends_nonmnfld, [0, 0, 1]) 223 | pcd_nonmnfld = vis_utils.create_pointcloud_from_points(starts_nonmnfld, [1, 0, 0]) 224 | 225 | raw_mesh = vis_utils.create_triangle_mesh(batch_dict['raw_mesh_verts'][i].numpy(), batch_dict['raw_mesh_faces'][i].numpy()) 226 | raw_mesh_list.append(raw_mesh) 227 | 228 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(0.1) 229 | o3d.visualization.draw_geometries([raw_mesh, coord, vf_mnfld, pcd_mnfld]) 230 | o3d.visualization.draw_geometries([coord, vf_nonmnfld, pcd_nonmnfld]) 231 | # from IPython import embed; embed() 232 | 233 | break 234 | from IPython import embed; embed() 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /utils/mesh_sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import heapq 3 | import numpy as np 4 | import os 5 | import scipy.sparse as sp 6 | from psbody.mesh import Mesh 7 | from scipy.spatial import KDTree 8 | 9 | def row(A): 10 | return A.reshape((1, -1)) 11 | 12 | def col(A): 13 | return A.reshape((-1, 1)) 14 | 15 | def get_vert_connectivity(mesh_v, mesh_f): 16 | """Returns a sparse matrix (of size #verts x #verts) where each nonzero 17 | element indicates a neighborhood relation. For example, if there is a 18 | nonzero element in position (15,12), that means vertex 15 is connected 19 | by an edge to vertex 12.""" 20 | 21 | vpv = sp.csc_matrix((len(mesh_v), len(mesh_v))) 22 | 23 | # for each column in the faces... 
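# Worked example (face indices chosen arbitrarily for illustration): iteration i pairs
# column i of mesh_f with column (i + 1) % 3, so a face [3, 7, 9] contributes the
# directed pairs (3, 7), (7, 9), (9, 3); accumulating mtx + mtx.T below also marks the
# reversed pairs, yielding a symmetric vertex-adjacency matrix.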
24 | for i in range(3): 25 | IS = mesh_f[:, i] 26 | JS = mesh_f[:, (i + 1) % 3] 27 | data = np.ones(len(IS)) 28 | ij = np.vstack((row(IS.ravel()), row(JS.ravel()))) 29 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape) 30 | vpv = vpv + mtx + mtx.T 31 | 32 | return vpv 33 | 34 | 35 | def get_vertices_per_edge(mesh_v, mesh_f): 36 | """Returns an Ex2 array of adjacencies between vertices, where 37 | each element in the array is a vertex index. Each edge is included 38 | only once. If output of get_faces_per_edge is provided, this is used to 39 | avoid call to get_vert_connectivity()""" 40 | 41 | vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f)) 42 | result = np.hstack((col(vc.row), col(vc.col))) 43 | result = result[result[:, 0] < result[:, 1]] # for uniqueness 44 | 45 | return result 46 | 47 | 48 | def vertex_quadrics(mesh): 49 | """Computes a quadric for each vertex in the Mesh. 50 | 51 | Returns: 52 | v_quadrics: an (N x 4 x 4) array, where N is # vertices. 53 | """ 54 | 55 | # Allocate quadrics 56 | v_quadrics = np.zeros(( 57 | len(mesh.v), 58 | 4, 59 | 4, 60 | )) 61 | 62 | # For each face... 63 | for f_idx in range(len(mesh.f)): 64 | 65 | # Compute normalized plane equation for that face 66 | vert_idxs = mesh.f[f_idx] 67 | verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 68 | 1]).reshape(-1, 1))) 69 | u, s, v = np.linalg.svd(verts) 70 | eq = v[-1, :].reshape(-1, 1) 71 | eq = eq / (np.linalg.norm(eq[0:3])) 72 | 73 | # Add the outer product of the plane equation to the 74 | # quadrics of the vertices for this face 75 | for k in range(3): 76 | v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq) 77 | 78 | return v_quadrics 79 | 80 | 81 | def setup_deformation_transfer(source, target, use_normals=False): 82 | rows = np.zeros(3 * target.v.shape[0]) 83 | cols = np.zeros(3 * target.v.shape[0]) 84 | coeffs_v = np.zeros(3 * target.v.shape[0]) 85 | coeffs_n = np.zeros(3 * target.v.shape[0]) 86 | 87 | nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree( 88 | ).nearest(target.v, True) 89 | nearest_faces = nearest_faces.ravel().astype(np.int64) 90 | nearest_parts = nearest_parts.ravel().astype(np.int64) 91 | nearest_vertices = nearest_vertices.ravel() 92 | 93 | for i in range(target.v.shape[0]): 94 | # Closest triangle index 95 | f_id = nearest_faces[i] 96 | # Closest triangle vertex ids 97 | nearest_f = source.f[f_id] 98 | 99 | # Closest surface point 100 | nearest_v = nearest_vertices[3 * i:3 * i + 3] 101 | # Distance vector to the closest surface point 102 | dist_vec = target.v[i] - nearest_v 103 | 104 | rows[3 * i:3 * i + 3] = i * np.ones(3) 105 | cols[3 * i:3 * i + 3] = nearest_f 106 | 107 | n_id = nearest_parts[i] 108 | if n_id == 0: 109 | # Closest surface point in triangle 110 | A = np.vstack((source.v[nearest_f])).T 111 | coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v, 112 | rcond=-1)[0] 113 | elif n_id > 0 and n_id <= 3: 114 | # Closest surface point on edge 115 | A = np.vstack((source.v[nearest_f[n_id - 1]], 116 | source.v[nearest_f[n_id % 3]])).T 117 | tmp_coeffs = np.linalg.lstsq(A, target.v[i], rcond=-1)[0] 118 | coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0] 119 | coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1] 120 | else: 121 | # Closest surface point a vertex 122 | coeffs_v[3 * i + n_id - 4] = 1.0 123 | 124 | matrix = sp.csc_matrix((coeffs_v, (rows, cols)), 125 | shape=(target.v.shape[0], source.v.shape[0])) 126 | return matrix 127 | 128 | 129 | def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None): 130 | """Return a simplified 
version of this mesh. 131 | 132 | A Qslim-style approach is used here. 133 | 134 | :param factor: fraction of the original vertices to retain 135 | :param n_verts_desired: number of the original vertices to retain 136 | :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix 137 | """ 138 | 139 | if factor is None and n_verts_desired is None: 140 | raise Exception('Need either factor or n_verts_desired.') 141 | 142 | if n_verts_desired is None: 143 | n_verts_desired = math.ceil(len(mesh.v) * factor) 144 | 145 | Qv = vertex_quadrics(mesh) 146 | 147 | # fill out a sparse matrix indicating vertex-vertex adjacency 148 | # from psbody.mesh.topology.connectivity import get_vertices_per_edge 149 | vert_adj = get_vertices_per_edge(mesh.v, mesh.f) 150 | # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v))) 151 | # for f_idx in range(len(mesh.f)): 152 | # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1 153 | 154 | vert_adj = sp.csc_matrix( 155 | (vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), 156 | shape=(len(mesh.v), len(mesh.v))) 157 | vert_adj = vert_adj + vert_adj.T 158 | vert_adj = vert_adj.tocoo() 159 | 160 | def collapse_cost(Qv, r, c, v): 161 | Qsum = Qv[r, :, :] + Qv[c, :, :] 162 | p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 163 | p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 164 | 165 | destroy_c_cost = p1.T.dot(Qsum).dot(p1) 166 | destroy_r_cost = p2.T.dot(Qsum).dot(p2) 167 | result = { 168 | 'destroy_c_cost': destroy_c_cost, 169 | 'destroy_r_cost': destroy_r_cost, 170 | 'collapse_cost': min([destroy_c_cost, destroy_r_cost]), 171 | 'Qsum': Qsum 172 | } 173 | return result 174 | 175 | # construct a queue of edges with costs 176 | queue = [] 177 | for k in range(vert_adj.nnz): 178 | r = vert_adj.row[k] 179 | c = vert_adj.col[k] 180 | 181 | if r > c: 182 | continue 183 | 184 | cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost'] 185 | heapq.heappush(queue, (cost, (r, c))) 186 | 187 | # decimate 188 | collapse_list = [] 189 | nverts_total = len(mesh.v) 190 | faces = mesh.f.copy() 191 | while nverts_total > n_verts_desired: 192 | e = heapq.heappop(queue) 193 | r = e[1][0] 194 | c = e[1][1] 195 | if r == c: 196 | continue 197 | 198 | cost = collapse_cost(Qv, r, c, mesh.v) 199 | if cost['collapse_cost'] > e[0]: 200 | heapq.heappush(queue, (cost['collapse_cost'], e[1])) 201 | # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost']) 202 | continue 203 | else: 204 | 205 | # update old vert idxs to new one, 206 | # in queue and in face list 207 | if cost['destroy_c_cost'] < cost['destroy_r_cost']: 208 | to_destroy = c 209 | to_keep = r 210 | else: 211 | to_destroy = r 212 | to_keep = c 213 | 214 | collapse_list.append([to_keep, to_destroy]) 215 | 216 | # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx 217 | np.place(faces, faces == to_destroy, to_keep) 218 | 219 | # same for queue 220 | which1 = [ 221 | idx for idx in range(len(queue)) 222 | if queue[idx][1][0] == to_destroy 223 | ] 224 | which2 = [ 225 | idx for idx in range(len(queue)) 226 | if queue[idx][1][1] == to_destroy 227 | ] 228 | for k in which1: 229 | queue[k] = (queue[k][0], (to_keep, queue[k][1][1])) 230 | for k in which2: 231 | queue[k] = (queue[k][0], (queue[k][1][0], to_keep)) 232 | 233 | Qv[r, :, :] = cost['Qsum'] 234 | Qv[c, :, :] = cost['Qsum'] 235 | 236 | a = faces[:, 0] == faces[:, 1] 237 | b = faces[:, 1] == faces[:, 2] 238 | c = faces[:, 2] == faces[:, 0] 239 | 240 | # remove degenerate faces 241 | def 
logical_or3(x, y, z): 242 | return np.logical_or(x, np.logical_or(y, z)) 243 | 244 | faces_to_keep = np.logical_not(logical_or3(a, b, c)) 245 | faces = faces[faces_to_keep, :].copy() 246 | 247 | nverts_total = (len(np.unique(faces.flatten()))) 248 | 249 | new_faces, mtx = _get_sparse_transform(faces, len(mesh.v)) 250 | return new_faces, mtx 251 | 252 | 253 | def _get_sparse_transform(faces, num_original_verts): 254 | verts_left = np.unique(faces.flatten()) 255 | IS = np.arange(len(verts_left)) 256 | JS = verts_left 257 | data = np.ones(len(JS)) 258 | 259 | mp = np.arange(0, np.max(faces.flatten()) + 1) 260 | mp[JS] = IS 261 | new_faces = mp[faces.copy().flatten()].reshape((-1, 3)) 262 | 263 | ij = np.vstack((IS.flatten(), JS.flatten())) 264 | mtx = sp.csc_matrix((data, ij), 265 | shape=(len(verts_left), num_original_verts)) 266 | 267 | return (new_faces, mtx) 268 | 269 | 270 | def generate_transform_matrices(mesh, factors): 271 | """Generates len(factors) meshes, each of them is scaled by factors[i] and 272 | computes the transformations between them. 273 | 274 | Returns: 275 | M: a set of meshes downsampled from mesh by a factor specified in factors. 276 | A: Adjacency matrix for each of the meshes 277 | D: csc_matrix Downsampling transforms between each of the meshes 278 | U: Upsampling transforms between each of the meshes 279 | F: a list of faces 280 | """ 281 | 282 | factors = map(lambda x: 1.0 / x, factors) 283 | M, A, D, U, F = [], [], [], [], [] 284 | F.append(mesh.f) # F[0] 285 | A.append(get_vert_connectivity(mesh.v, mesh.f).astype('float32')) # A[0] 286 | M.append(mesh) # M[0] 287 | 288 | for factor in factors: 289 | ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor) 290 | D.append(ds_D.astype('float32')) 291 | new_mesh_v = ds_D.dot(M[-1].v) 292 | new_mesh = Mesh(v=new_mesh_v, f=ds_f) 293 | F.append(new_mesh.f) 294 | M.append(new_mesh) 295 | A.append( 296 | get_vert_connectivity(new_mesh.v, new_mesh.f).astype('float32')) 297 | U.append(setup_deformation_transfer(M[-1], M[-2]).astype('float32')) 298 | 299 | return M, A, D, U, F 300 | -------------------------------------------------------------------------------- /datasets/dfaust_hybrid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | import json 5 | import pickle 6 | from collections import defaultdict 7 | 8 | import torch 9 | import trimesh 10 | import numpy as np 11 | from tqdm import tqdm 12 | from omegaconf import OmegaConf 13 | from loguru import logger 14 | from sklearn.decomposition import PCA 15 | 16 | 17 | def remove_nans(tensor): 18 | tensor_nan = torch.isnan(tensor[:, 3]) 19 | return tensor[~tensor_nan, :] 20 | 21 | 22 | class DFaustHybridDataSet(torch.utils.data.Dataset): 23 | 24 | def __init__(self, 25 | mode, 26 | rep, 27 | config, 28 | **kwargs): 29 | ''' 30 | Args: 31 | sdf_dir: raw sdf dir 32 | raw_mesh_dir: raw mesh dir, might not have consistent topology 33 | registration_dir: registered mesh dir, must have consistent topology 34 | num_samples: num of samples used to train sdfnet 35 | ''' 36 | super().__init__() 37 | 38 | self.rep = rep 39 | self.config = config 40 | self.mode = mode 41 | if self.mode == 'train': 42 | split = 'train' 43 | elif self.mode == 'test': 44 | split = 'test' 45 | else: 46 | raise ValueError('invalid mode') 47 | 48 | self.data_dir = config.data_dir 49 | self.sdf_dir = config.sdf_dir 50 | self.raw_mesh_dir = config.raw_mesh_dir 51 | self.registration_dir = config.registration_dir 
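# Note (summarizing __getitem__ further below, not the config schema): with rep == 'mesh'
# an item carries the normalized registered vertices plus the raw scan (padded later in
# collate_batch), while with rep == 'sdf' it carries sampled surface points/normals and
# off-surface SDF samples drawn from sdf_dir.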
52 | self.num_samples = config.num_samples 53 | self.template_path = config.template_path 54 | 55 | # load data split 56 | split_cfg_fname = config.split_cfg[split] 57 | current_dir = os.path.dirname(os.path.realpath(__file__)) 58 | split_path = f"{current_dir}/splits/dfaust_hybrid/{split_cfg_fname}" 59 | with open(split_path, "r") as f: 60 | split_names = json.load(f) 61 | 62 | self.fid_list = self.get_fid_list(split_names) 63 | self.num_data = len(self.fid_list) 64 | 65 | self.raw_mesh_file_type = config.get('raw_mesh_file_type', 'ply') 66 | logger.info(f"dataset mode = {mode}, split = {split}, len = {self.num_data}\n") 67 | 68 | # load temlate mesh for meshnet. Share topology. NOTE: used for meshnet, different from temp(late) in sdfnet 69 | template_mesh = trimesh.load(self.template_path, process=False, maintain_order=True) 70 | self.template_points = torch.from_numpy(template_mesh.vertices) 71 | self.template_faces = np.asarray(template_mesh.faces) 72 | self.num_nodes = self.template_points.shape[0] 73 | 74 | # load sim mesh data if exists 75 | self.sim_mesh_dir = config.get('sim_mesh_dir', None) 76 | if self.sim_mesh_dir is not None: 77 | self.verts_sim_list = [] 78 | self.faces_sim_list = [] 79 | for fid in self.fid_list: 80 | fname = '/'.join(fid.split('-')) 81 | sim_mesh_pkl = pickle.load(open(f"{self.sim_mesh_dir}/{fname}_sim.pkl", 'rb')) 82 | self.verts_sim_list.append(sim_mesh_pkl['verts_sim'].astype(np.float32)) 83 | self.faces_sim_list.append(sim_mesh_pkl['faces_sim']) 84 | 85 | # load init mesh data if exists 86 | self.init_mesh_dir = config.get('init_mesh_dir', None) 87 | if self.init_mesh_dir is not None: 88 | verts_init_list = [] 89 | for fid in tqdm(self.fid_list): 90 | mesh_init = trimesh.load(f"{self.init_mesh_dir}/{fid}_init.obj", process=False, maintain_order=True) 91 | verts_init_list.append(mesh_init.vertices.astype(np.float32)) 92 | self.verts_init = np.stack(verts_init_list) # (1000, 6890, 3) 93 | print(f'Finish loading verts_init, shape = {self.verts_init.shape}') 94 | assert(self.verts_init.shape[0] == self.num_data) 95 | self.mean_init = self.verts_init.mean(axis=0) # only verts_init always has consistent correspondence 96 | self.std_init = self.verts_init.std(axis=0) 97 | # IMPORTANT TODO: if SMAL, set self.std_init = 0.2 98 | 99 | # Normalize mesh data 100 | # NOTE: the target of the prediction: verts_init is normalized 101 | self.verts_init_nml = (self.verts_init - self.mean_init) / self.std_init 102 | 103 | self.use_vert_pca = config.get('use_vert_pca', True) 104 | self.pca = PCA(n_components=config.pca_n_comp) 105 | self.pca.fit(self.verts_init_nml.reshape(self.num_data, -1)) 106 | self.pca_axes = self.pca.components_ 107 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), self.pca_axes.transpose()) 108 | self.pca_sv_mean = np.mean(pca_sv, axis=0) 109 | self.pca_sv_std = np.std(pca_sv, axis=0) 110 | print(f'Finish computing PCA') 111 | 112 | # load raw mesh 113 | if self.rep == 'mesh': 114 | self.verts_raw_list = [] 115 | self.faces_raw_list = [] 116 | for fid in self.fid_list: 117 | fname = '/'.join(fid.split('-')) 118 | mesh_raw = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 119 | self.verts_raw_list.append(mesh_raw.vertices.astype(np.float32)) 120 | self.faces_raw_list.append(mesh_raw.faces) 121 | 122 | 123 | def get_fid_list(self, split_names): 124 | fid_list = [] 125 | assert(len(split_names) == 1) 126 | for dataset in split_names: 127 | for class_name in 
split_names[dataset]: 128 | for instance_name in split_names[dataset][class_name]: 129 | for shape in split_names[dataset][class_name][instance_name]: 130 | fid = f"{class_name}-{instance_name}-{shape}" 131 | fid_list.append(fid) 132 | return fid_list 133 | 134 | 135 | def update_pca_sv(self, train_pca_axes, train_pca_sv_mean, train_pca_sv_std): 136 | pca_sv = np.matmul(self.verts_init_nml.reshape(self.num_data, -1), train_pca_axes.transpose()) 137 | self.pca_sv = (pca_sv - train_pca_sv_mean) / train_pca_sv_std 138 | 139 | 140 | def __len__(self): 141 | return self.num_data 142 | 143 | 144 | def __getitem__(self, idx): 145 | data_dict = {} 146 | data_dict['idx'] = torch.tensor(idx, dtype=torch.long) 147 | fid = self.fid_list[idx] 148 | fname = '/'.join(fid.split('-')) 149 | 150 | if self.rep in ['mesh']: 151 | # no sdf, only load mesh. TODO: verts num diff, need to use PyG dataloader 152 | data_dict['verts_init_nml'] = torch.from_numpy(self.verts_init_nml[idx]).float() 153 | data_dict['verts_raw'] = torch.from_numpy(self.verts_raw_list[idx]).float() 154 | data_dict['faces_raw'] = torch.from_numpy(self.faces_raw_list[idx]).long() 155 | 156 | elif self.rep in ['sdf']: 157 | # load sdf data 158 | 159 | point_set_mnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}.npy")).float() 160 | samples_nonmnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}_dist_triangle.npy")).float() 161 | 162 | random_idx = (torch.rand(self.num_samples) * point_set_mnfld.shape[0]).long() 163 | point_set_mnfld = torch.index_select(point_set_mnfld, 0, random_idx) 164 | normal_set_mnfld = point_set_mnfld[:, 3:] 165 | point_set_mnfld = point_set_mnfld[:, :3] # currently all center == [0, 0, 0], scale == 1 166 | 167 | random_idx = (torch.rand(self.num_samples) * samples_nonmnfld.shape[0]).long() 168 | samples_nonmnfld = torch.index_select(samples_nonmnfld, 0, random_idx) 169 | 170 | data_dict['points_mnfld'] = point_set_mnfld 171 | data_dict['normals_mnfld'] = normal_set_mnfld 172 | data_dict['samples_nonmnfld'] = samples_nonmnfld 173 | 174 | # load mesh data 175 | # raw_mesh = trimesh.load(f"{self.raw_mesh_dir}/{fname}.{self.raw_mesh_file_type}", process=False, maintain_order=True) 176 | # data_dict['raw_mesh_verts'] = np.asarray(raw_mesh.vertices).astype(np.float32) 177 | # data_dict['raw_mesh_faces'] = np.asarray(raw_mesh.faces) 178 | 179 | return data_dict 180 | 181 | @staticmethod 182 | def collate_batch(batch_list): 183 | data_dict = defaultdict(list) 184 | for cur_sample in batch_list: 185 | for key, val in cur_sample.items(): 186 | data_dict[key].append(val) 187 | batch_size = len(batch_list) 188 | ret = {} 189 | 190 | for key, val in data_dict.items(): 191 | try: 192 | # TODO: should use torch instead of numpy here 193 | # if key in ['verts_raw', 'faces_raw']: # (\sum_{N_i}, d) 194 | # coors = [] 195 | # for i, coor in enumerate(val): 196 | # coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i) 197 | # coors.append(coor_pad) 198 | # ret[key] = np.concatenate(coors, axis=0) 199 | if key in ['verts_raw', 'faces_raw']: # (B, N_max, d) 200 | max_raw = max([len(x) for x in val]) 201 | batch_raw = torch.zeros((batch_size, max_raw, val[0].shape[-1])).float() 202 | batch_raw_lengths = torch.zeros((batch_size)).long() 203 | for k in range(batch_size): 204 | batch_raw[k, :val[k].__len__(), :] = val[k] 205 | batch_raw_lengths[k] = val[k].__len__() 206 | ret[key] = batch_raw 207 | ret[key + '_lengths'] = batch_raw_lengths 208 | else: 209 | ret[key] = torch.stack(val, dim=0) 210 | except: 
211 | print('Error in collate_batch: key=%s' % key) 212 | raise TypeError 213 | 214 | # ret['batch_size'] = batch_size 215 | return ret 216 | 217 | 218 | if __name__ == '__main__': 219 | import sys 220 | sys.path.append('../') 221 | from pyutils import * 222 | 223 | import argparse 224 | parser = argparse.ArgumentParser() 225 | parser.add_argument("--rep", type=str, help='sdf or mesh') 226 | parser.add_argument("--config", type=str, required=True, help='config yaml file path, e.g. ../config/dfaust.yaml') 227 | args = parser.parse_args() 228 | 229 | config = OmegaConf.load(args.config) 230 | OmegaConf.resolve(config) 231 | update_config_from_args(config, args) 232 | 233 | train_dataset = DFaustDataSet(mode='train', rep=config.rep, config=config.dataset) 234 | test_dataset = DFaustDataSet(mode='test', rep=config.rep, config=config.dataset) 235 | 236 | batch_size = 16 237 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 238 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 239 | raw_mesh_list = [] 240 | # for batch_idx, batch_dict in enumerate(test_loader): 241 | for batch_idx, batch_dict in enumerate(train_loader): 242 | for i in range(batch_size): 243 | if args.rep == 'sdf': 244 | print(i, batch_dict['points_mnfld'].shape) 245 | print(i, batch_dict['normals_mnfld'].shape) 246 | if args.rep == 'mesh': 247 | raise NotImplementedError 248 | 249 | import open3d as o3d 250 | import vis_utils 251 | 252 | starts_mnfld = batch_dict['points_mnfld'][i].numpy() 253 | ends_mnfld = batch_dict['points_mnfld'][i].numpy() + batch_dict['normals_mnfld'][i].numpy() * 0.1 254 | vf_mnfld = vis_utils.create_vector_field(starts_mnfld, ends_mnfld, [0, 1, 0]) 255 | pcd_mnfld = vis_utils.create_pointcloud_from_points(starts_mnfld, [1, 0, 0]) 256 | 257 | starts_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] 258 | ends_nonmnfld = batch_dict['samples_nonmnfld'][i].numpy()[:, :3] + batch_dict['samples_nonmnfld'][i].numpy()[:, 3:6] * 0.03 259 | vf_nonmnfld = vis_utils.create_vector_field(starts_nonmnfld, ends_nonmnfld, [0, 0, 1]) 260 | pcd_nonmnfld = vis_utils.create_pointcloud_from_points(starts_nonmnfld, [1, 0, 0]) 261 | 262 | raw_mesh = vis_utils.create_triangle_mesh(batch_dict['raw_mesh_verts'][i].numpy(), batch_dict['raw_mesh_faces'][i].numpy()) 263 | raw_mesh_list.append(raw_mesh) 264 | 265 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(0.1) 266 | o3d.visualization.draw_geometries([raw_mesh, coord, vf_mnfld, pcd_mnfld]) 267 | o3d.visualization.draw_geometries([coord, vf_nonmnfld, pcd_nonmnfld]) 268 | # from IPython import embed; embed() 269 | 270 | break 271 | from IPython import embed; embed() 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /models/asap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import trimesh 3 | import numpy as np 4 | import scipy.io as sio 5 | from loguru import logger 6 | 7 | import torch 8 | from torch_geometric.utils import degree, get_laplacian 9 | import torch_sparse as ts 10 | 11 | 12 | def get_laplacian_kron3x3(edge_index, edge_weights, N): 13 | edge_index, edge_weight = get_laplacian(edge_index, edge_weights, num_nodes=N) # (2, V+2E), (V+2E,) 14 | edge_weight *= 2 15 | e0, e1 = edge_index 16 | i0 = [e0*3, e0*3+1, e0*3+2] 17 | i1 = [e1*3, e1*3+1, e1*3+2] 18 | vals = [edge_weight, edge_weight, edge_weight] 19 | i0 = torch.cat(i0, 0) 20 | i1 = 
torch.cat(i1, 0) 21 | vals = torch.cat(vals, 0) 22 | indices, vals = ts.coalesce([i0, i1], vals, N*3, N*3) # (2, 3(V+2E)), (3(V+2E),) 23 | return indices, vals 24 | 25 | 26 | def compute_asap3d_sparse(verts, faces, weight_asap=0.1): 27 | """ compute normalized: (L_arap + weight_asap * L_asap) / (1 + weight_asap) 28 | Args: 29 | verts: (N, 3) 30 | faces: (F, 3) 31 | Returns: 32 | Hessian: (3N, 3N), sparse 33 | """ 34 | N = verts.shape[0] 35 | device = verts.device 36 | adj = torch.zeros((N, N), device=device) 37 | adj[faces[:, 0], faces[:, 1]] = 1 38 | adj[faces[:, 1], faces[:, 2]] = 1 39 | adj[faces[:, 0], faces[:, 2]] = 1 40 | adj = adj + adj.T 41 | edge_index = torch.as_tensor(torch.stack(torch.where(adj > 0), 0), dtype=torch.long) # (2, 2E) 42 | # edge_index = torch.cat((faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]), dim=0).T # (2, 2E) 43 | 44 | e0, e1 = edge_index # (2E,), (2E,) 45 | deg = degree(e0, N) # (V,) 46 | edge_weight = torch.ones_like(e0) # (2E,) 47 | edge_vecs = verts[e0, :] - verts[e1, :] # (2E, 3) 48 | edge_vecs_sq = (edge_vecs * edge_vecs).sum(-1) # (2E,) 49 | 50 | # XXXXXX COMPUTE L XXXXXX 51 | L_indices, L_vals = get_laplacian_kron3x3(edge_index, edge_weight, N) 52 | L = torch.sparse_coo_tensor(L_indices, L_vals, (N*3, N*3)) 53 | 54 | # XXXXXX COMPUTE B XXXXXX 55 | B0, B1, B_vals = [], [], [] 56 | # off-diagonal: use e0 and e1 57 | B0.append(e0*3 ); B1.append(e1*3+1); B_vals.append(-edge_vecs[:, 2]*edge_weight) 58 | B0.append(e0*3 ); B1.append(e1*3+2); B_vals.append( edge_vecs[:, 1]*edge_weight) 59 | B0.append(e0*3+1); B1.append(e1*3+0); B_vals.append( edge_vecs[:, 2]*edge_weight) 60 | B0.append(e0*3+1); B1.append(e1*3+2); B_vals.append(-edge_vecs[:, 0]*edge_weight) 61 | B0.append(e0*3+2); B1.append(e1*3+0); B_vals.append(-edge_vecs[:, 1]*edge_weight) 62 | B0.append(e0*3+2); B1.append(e1*3+1); B_vals.append( edge_vecs[:, 0]*edge_weight) 63 | 64 | # in-diagonal: use e0 and e0 65 | B0.append(e0*3 ); B1.append(e0*3+1); B_vals.append(-edge_vecs[:, 2]*edge_weight) 66 | B0.append(e0*3 ); B1.append(e0*3+2); B_vals.append( edge_vecs[:, 1]*edge_weight) 67 | B0.append(e0*3+1); B1.append(e0*3+0); B_vals.append( edge_vecs[:, 2]*edge_weight) 68 | B0.append(e0*3+1); B1.append(e0*3+2); B_vals.append(-edge_vecs[:, 0]*edge_weight) 69 | B0.append(e0*3+2); B1.append(e0*3+0); B_vals.append(-edge_vecs[:, 1]*edge_weight) 70 | B0.append(e0*3+2); B1.append(e0*3+1); B_vals.append( edge_vecs[:, 0]*edge_weight) 71 | 72 | B0 = torch.cat(B0, 0); B1 = torch.cat(B1, 0); B_vals = torch.cat(B_vals, 0) 73 | B = torch.sparse_coo_tensor(torch.stack([B0, B1]), B_vals, (N*3, N*3)) 74 | 75 | # XXXXXX COMPUTE H XXXXXX 76 | H0, H1, H_vals = [], [], [] 77 | # i==j 78 | H0.append(e0*3 ); H1.append(e0); H_vals.append(-edge_vecs[:, 0]) 79 | H0.append(e0*3+1); H1.append(e0); H_vals.append(-edge_vecs[:, 1]) 80 | H0.append(e0*3+2); H1.append(e0); H_vals.append(-edge_vecs[:, 2]) 81 | # (i, j) \in E 82 | H0.append(e0*3 ); H1.append(e1); H_vals.append(-edge_vecs[:, 0]) 83 | H0.append(e0*3+1); H1.append(e1); H_vals.append(-edge_vecs[:, 1]) 84 | H0.append(e0*3+2); H1.append(e1); H_vals.append(-edge_vecs[:, 2]) 85 | 86 | H0 = torch.cat(H0, 0); H1 = torch.cat(H1, 0); H_vals = torch.cat(H_vals, 0) 87 | H = torch.sparse_coo_tensor(torch.stack([H0, H1]), H_vals, (N*3, N)) 88 | 89 | # XXXXXX COMPUTE C XXXXXX 90 | C0, C1, C_vals = [], [], [] 91 | for di in range(3): 92 | for dj in range(3): 93 | C0.append(e0*3+di); C1.append(e0*3+dj); C_vals.append(-edge_vecs[:, di]*edge_vecs[:, dj]*edge_weight) 94 | C0.append(e0*3+di);
C1.append(e0*3+di); C_vals.append(edge_vecs_sq*edge_weight) 95 | C0 = torch.cat(C0, 0); C1 = torch.cat(C1, 0); C_vals = torch.cat(C_vals, 0) 96 | C_indices, C_vals = ts.coalesce([C0, C1], C_vals, N*3, N*3) 97 | Cinv_indices = C_indices 98 | try: 99 | Cinv_vals = C_vals.view(N, 3, 3).inverse().reshape(-1) 100 | except: 101 | logger.debug('Cinv_vals error: use pinv') 102 | Cinv_vals = torch.linalg.pinv(C_vals.view(N, 3, 3)).reshape(-1) 103 | Cinv = torch.sparse_coo_tensor(Cinv_indices, Cinv_vals, (N*3, N*3)) 104 | 105 | # XXXXXX COMPUTE G XXXXXX 106 | G0, G1, G_vals = [], [], [] 107 | # NOTE: DO NOT use: G_vals.append(1/(edge_vecs_sq*edge_weight)). Have to create the matrix first then inverse since 1/sum(v) != sum(1/v) 108 | G0.append(e0); G1.append(e0); G_vals.append(edge_vecs_sq*edge_weight) 109 | G0 = torch.cat(G0, 0); G1 = torch.cat(G1, 0); G_vals = torch.cat(G_vals, 0) 110 | G_indices, G_vals = ts.coalesce([G0, G1], G_vals, N, N) 111 | Ginv_indices = G_indices 112 | Ginv_vals = 1 / G_vals 113 | Ginv_vals[G_vals < 1e-6] = 0 # remove 1/0 = inf 114 | Ginv = torch.sparse_coo_tensor(Ginv_indices, Ginv_vals, (N, N)) 115 | 116 | # XXXXXX COMPUTE Hessian XXXXXX 117 | BCinv = torch.sparse.mm(B, Cinv) 118 | BCinvBT = torch.sparse.mm(BCinv, B.t()) 119 | 120 | HGinv = torch.sparse.mm(H, Ginv) 121 | HGinvHT = torch.sparse.mm(HGinv, H.t()) 122 | 123 | Hessian_sparse = L - BCinvBT - weight_asap / (1 + weight_asap) * HGinvHT 124 | 125 | return Hessian_sparse 126 | 127 | 128 | class ASAP(torch.nn.Module): 129 | def __init__(self, template_face, num_points): 130 | super(ASAP, self).__init__() 131 | N = num_points 132 | self.template_face = template_face # (F=13776, 3) 133 | adj = np.zeros((num_points, num_points)) 134 | adj[template_face[:, 0], template_face[:, 1]] = 1 135 | adj[template_face[:, 1], template_face[:, 2]] = 1 136 | adj[template_face[:, 0], template_face[:, 2]] = 1 137 | adj = adj + adj.T 138 | edge_index = torch.as_tensor(np.stack(np.where(adj > 0), 0), 139 | dtype=torch.long) # (2, 2E=41328) 140 | e0, e1 = edge_index # (2E,), (2E,) 141 | deg = degree(e0, N) # (V,) 142 | edge_weight = torch.ones_like(e0) # (2E,) 143 | 144 | L_indices, L_vals = get_laplacian_kron3x3(edge_index, edge_weight, N) 145 | self.register_buffer('L_indices', L_indices) 146 | self.register_buffer('L_vals', L_vals) 147 | self.register_buffer('edge_weight', edge_weight) 148 | self.register_buffer('edge_index', edge_index) 149 | 150 | def forward(self, x, J, k=0, weight_asap=0.05): 151 | """ compute normalized: (L_arap + weight_asap * L_asap) / (1 + weight_asap) 152 | x: [B, N, 3] point locations. 153 | J: [B, N*3, D] Jacobian of generator. 
154 | J_eigvals: [B, D] 155 | """ 156 | num_batches, N = x.shape[:2] 157 | e0, e1 = self.edge_index 158 | edge_vecs = x[:, e0, :] - x[:, e1, :] # (B, 2E, 3) 159 | trace_ = [] 160 | 161 | for i in range(num_batches): 162 | LJ = ts.spmm(self.L_indices, self.L_vals, N*3, N*3, J[i]) # (3N, D) 163 | JTLJ = J[i].T.matmul(LJ) 164 | 165 | # XXXXXX COMPUTE B XXXXXX 166 | B0, B1, B_vals = [], [], [] 167 | # off-diagonal: use e0 and e1 168 | B0.append(e0*3 ); B1.append(e1*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 169 | B0.append(e0*3 ); B1.append(e1*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 170 | B0.append(e0*3+1); B1.append(e1*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 171 | B0.append(e0*3+1); B1.append(e1*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 172 | B0.append(e0*3+2); B1.append(e1*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 173 | B0.append(e0*3+2); B1.append(e1*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 174 | 175 | # in-diagonal: use e0 and e0 176 | B0.append(e0*3 ); B1.append(e0*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 177 | B0.append(e0*3 ); B1.append(e0*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 178 | B0.append(e0*3+1); B1.append(e0*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 179 | B0.append(e0*3+1); B1.append(e0*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 180 | B0.append(e0*3+2); B1.append(e0*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 181 | B0.append(e0*3+2); B1.append(e0*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 182 | 183 | B0 = torch.cat(B0, 0) 184 | B1 = torch.cat(B1, 0) 185 | B_vals = torch.cat(B_vals, 0) 186 | B_indices, B_vals = ts.coalesce([B0, B1], B_vals, N*3, N*3) 187 | BT_indices, BT_vals = ts.transpose(B_indices, B_vals, N*3, N*3) 188 | 189 | # XXXXXX COMPUTE H XXXXXX 190 | H0, H1, H_vals = [], [], [] 191 | # i==j 192 | H0.append(e0*3 ); H1.append(e0); H_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 193 | H0.append(e0*3+1); H1.append(e0); H_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 194 | H0.append(e0*3+2); H1.append(e0); H_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 195 | # (i, j) \in E 196 | H0.append(e0*3 ); H1.append(e1); H_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 197 | H0.append(e0*3+1); H1.append(e1); H_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 198 | H0.append(e0*3+2); H1.append(e1); H_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 199 | 200 | H0 = torch.cat(H0, 0); H1 = torch.cat(H1, 0); H_vals = torch.cat(H_vals, 0); 201 | H_indices, H_vals = ts.coalesce([H0, H1], H_vals, N*3, N) 202 | HT_indices, HT_vals = ts.transpose(H_indices, H_vals, N*3, N) 203 | 204 | # XXXXXX COMPUTE C XXXXXX 205 | C0, C1, C_vals = [], [], [] 206 | edge_vecs_sq = (edge_vecs[i] * edge_vecs[i]).sum(-1) 207 | evi = edge_vecs[i] # (2E, 3) 208 | for di in range(3): 209 | for dj in range(3): 210 | C0.append(e0*3+di); C1.append(e0*3+dj); C_vals.append(-evi[:, di]*evi[:, dj]*self.edge_weight) 211 | C0.append(e0*3+di); C1.append(e0*3+di); C_vals.append(edge_vecs_sq*self.edge_weight) 212 | C0 = torch.cat(C0, 0); C1 = torch.cat(C1, 0); C_vals = torch.cat(C_vals, 0) 213 | C_indices, C_vals = ts.coalesce([C0, C1], C_vals, N*3, N*3) 214 | Cinv_indices = C_indices 215 | try: 216 | Cinv_vals = C_vals.view(N, 3, 3).inverse().reshape(-1) 217 | except: 218 | logger.debug('C_vals error: use pinv') 219 | Cinv_vals = torch.linalg.pinv(C_vals.view(N, 3, 3)).reshape(-1) 220 | 221 | # XXXXXX COMPUTE G 
XXXXXX 222 | G0, G1, G_vals = [], [], [] 223 | # NOTE: DO NOT use: G_vals.append(1/(edge_vecs_sq*edge_weight)). Have to create the matrix first then inverse since 1/sum(v) != sum(1/v) 224 | G0.append(e0); G1.append(e0); G_vals.append(edge_vecs_sq*self.edge_weight) 225 | G0 = torch.cat(G0, 0); G1 = torch.cat(G1, 0); G_vals = torch.cat(G_vals, 0) 226 | G_indices, G_vals = ts.coalesce([G0, G1], G_vals, N, N) 227 | Ginv_indices = G_indices 228 | Ginv_vals = 1 / G_vals 229 | Ginv_vals[G_vals < 1e-6] = 0 # remove 1/0 = inf 230 | 231 | # XXXXXX COMPUTE Hessian XXXXXX 232 | BTJ = ts.spmm(BT_indices, BT_vals, N*3, N*3, J[i]) 233 | CinvBTJ = ts.spmm(Cinv_indices, Cinv_vals, N*3, N*3, BTJ) 234 | JTBCinvBTJ = BTJ.T.mm(CinvBTJ) 235 | 236 | HTJ = ts.spmm(HT_indices, HT_vals, N, N*3, J[i]) 237 | GinvHTJ = ts.spmm(Ginv_indices, Ginv_vals, N, N, HTJ) 238 | JTHGinvHTJ = HTJ.T.mm(GinvHTJ) 239 | 240 | Rm = JTLJ - JTBCinvBTJ - weight_asap / (1 + weight_asap) * JTHGinvHTJ 241 | 242 | e = torch.linalg.eigvalsh(Rm).clip(0) 243 | e = e ** 0.5 244 | 245 | trace = e.sum() 246 | trace_.append(trace) 247 | 248 | trace_ = torch.stack(trace_, ) 249 | return trace_.mean() 250 | 251 | 252 | if __name__ == '__main__': 253 | 254 | mesh = trimesh.load('./mesh.obj', process=False) 255 | # mesh = trimesh.load('./meshsim.obj', process=False) 256 | template_faces = np.asarray(mesh.faces) 257 | x = np.asarray(mesh.vertices) 258 | N = x.shape[0] 259 | 260 | x = torch.from_numpy(x)[None, ...] 261 | J = torch.randn(1, N*3, 16) 262 | arap = ASAP(template_faces, N) 263 | arap(x, J) 264 | 265 | 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /registration_dfaust/multiCorres_sync_dfaust1k/io/plyread.m: -------------------------------------------------------------------------------- 1 | function [Elements,varargout] = plyread(Path,Str) 2 | %PLYREAD Read a PLY 3D data file. 3 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 4 | % FILENAME and returns a structure DATA. The fields in this structure 5 | % are defined by the PLY header; each element type is a field and each 6 | % element property is a subfield. If the file contains any comments, 7 | % they are returned in a cell string array COMMENTS. 8 | % 9 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 10 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 11 | % and face data into triangular connectivity and vertex arrays. The 12 | % mesh can then be displayed using the TRISURF command. 13 | % 14 | % Note: This function is slow for large mesh files (+50K faces), 15 | % especially when reading data with list type properties. 
16 | % 17 | % Example: 18 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 19 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 20 | % colormap(gray); axis equal; 21 | % 22 | % See also: PLYWRITE 23 | 24 | % Pascal Getreuer 2004 25 | 26 | [fid,Msg] = fopen(Path,'rt'); % open file in read text mode 27 | 28 | if fid == -1, error(Msg); end 29 | 30 | Buf = fscanf(fid,'%s',1); 31 | if ~strcmp(Buf,'ply') 32 | fclose(fid); 33 | error('Not a PLY file.'); 34 | end 35 | 36 | 37 | %%% read header %%% 38 | 39 | Position = ftell(fid); 40 | Format = ''; 41 | NumComments = 0; 42 | Comments = {}; % for storing any file comments 43 | NumElements = 0; 44 | NumProperties = 0; 45 | Elements = []; % structure for holding the element data 46 | ElementCount = []; % number of each type of element in file 47 | PropertyTypes = []; % corresponding structure recording property types 48 | ElementNames = {}; % list of element names in the order they are stored in the file 49 | PropertyNames = []; % structure of lists of property names 50 | 51 | while 1 52 | Buf = fgetl(fid); % read one line from file 53 | BufRem = Buf; 54 | Token = {}; 55 | Count = 0; 56 | 57 | while ~isempty(BufRem) % split line into tokens 58 | [tmp,BufRem] = strtok(BufRem); 59 | 60 | if ~isempty(tmp) 61 | Count = Count + 1; % count tokens 62 | Token{Count} = tmp; 63 | end 64 | end 65 | 66 | if Count % parse line 67 | switch lower(Token{1}) 68 | case 'format' % read data format 69 | if Count >= 2 70 | Format = lower(Token{2}); 71 | 72 | if Count == 3 & ~strcmp(Token{3},'1.0') 73 | fclose(fid); 74 | error('Only PLY format version 1.0 supported.'); 75 | end 76 | end 77 | case 'comment' % read file comment 78 | NumComments = NumComments + 1; 79 | Comments{NumComments} = ''; 80 | for i = 2:Count 81 | Comments{NumComments} = [Comments{NumComments},Token{i},' ']; 82 | end 83 | case 'element' % element name 84 | if Count >= 3 85 | if isfield(Elements,Token{2}) 86 | fclose(fid); 87 | error(['Duplicate element name, ''',Token{2},'''.']); 88 | end 89 | 90 | NumElements = NumElements + 1; 91 | NumProperties = 0; 92 | Elements = setfield(Elements,Token{2},[]); 93 | PropertyTypes = setfield(PropertyTypes,Token{2},[]); 94 | ElementNames{NumElements} = Token{2}; 95 | PropertyNames = setfield(PropertyNames,Token{2},{}); 96 | CurElement = Token{2}; 97 | ElementCount(NumElements) = str2double(Token{3}); 98 | 99 | if isnan(ElementCount(NumElements)) 100 | fclose(fid); 101 | error(['Bad element definition: ',Buf]); 102 | end 103 | else 104 | error(['Bad element definition: ',Buf]); 105 | end 106 | case 'property' % element property 107 | if ~isempty(CurElement) & Count >= 3 108 | NumProperties = NumProperties + 1; 109 | eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 110 | 'fclose(fid);error([''Error reading property: '',Buf])'); 111 | 112 | if tmp 113 | error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); 114 | end 115 | 116 | % add property subfield to Elements 117 | eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 118 | 'fclose(fid);error([''Error reading property: '',Buf])'); 119 | % add property subfield to PropertyTypes and save type 120 | eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 121 | 'fclose(fid);error([''Error reading property: '',Buf])'); 122 | % record property name order 123 | eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 
124 | 'fclose(fid);error([''Error reading property: '',Buf])'); 125 | else 126 | fclose(fid); 127 | 128 | if isempty(CurElement) 129 | error(['Property definition without element definition: ',Buf]); 130 | else 131 | error(['Bad property definition: ',Buf]); 132 | end 133 | end 134 | case 'end_header' % end of header, break from while loop 135 | break; 136 | end 137 | end 138 | end 139 | 140 | %%% set reading for specified data format %%% 141 | 142 | if isempty(Format) 143 | warning('Data format unspecified, assuming ASCII.'); 144 | Format = 'ascii'; 145 | end 146 | 147 | switch Format 148 | case 'ascii' 149 | Format = 0; 150 | case 'binary_little_endian' 151 | Format = 1; 152 | case 'binary_big_endian' 153 | Format = 2; 154 | otherwise 155 | fclose(fid); 156 | error(['Data format ''',Format,''' not supported.']); 157 | end 158 | 159 | if ~Format 160 | Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data 161 | BufOff = 1; 162 | else 163 | % reopen the file in read binary mode 164 | fclose(fid); 165 | 166 | if Format == 1 167 | fid = fopen(Path,'r','ieee-le.l64'); % little endian 168 | else 169 | fid = fopen(Path,'r','ieee-be.l64'); % big endian 170 | end 171 | 172 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 173 | BufSize = 8192; 174 | Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; 175 | i = []; 176 | tmp = -11; 177 | 178 | while isempty(i) 179 | i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF 180 | i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF 181 | 182 | if isempty(i) 183 | tmp = tmp + BufSize; 184 | Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; 185 | end 186 | end 187 | 188 | % seek to just after the line feed 189 | fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); 190 | end 191 | 192 | 193 | %%% read element data %%% 194 | 195 | % PLY and MATLAB data types (for fread) 196 | PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 
197 | 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; 198 | MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; 199 | SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type 200 | 201 | for i = 1:NumElements 202 | % get current element property information 203 | eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); 204 | eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); 205 | NumProperties = size(CurPropertyNames,2); 206 | 207 | %fprintf('Reading %s...\n',ElementNames{i}); 208 | 209 | if ~Format %%% read ASCII data %%% 210 | for j = 1:NumProperties 211 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 212 | 213 | if strcmpi(Token{1},'list') 214 | Type(j) = 1; 215 | else 216 | Type(j) = 0; 217 | end 218 | end 219 | 220 | % parse buffer 221 | if ~any(Type) 222 | % no list types 223 | Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; 224 | BufOff = BufOff + ElementCount(i)*NumProperties; 225 | else 226 | ListData = cell(NumProperties,1); 227 | 228 | for k = 1:NumProperties 229 | ListData{k} = cell(ElementCount(i),1); 230 | end 231 | 232 | % list type 233 | for j = 1:ElementCount(i) 234 | for k = 1:NumProperties 235 | if ~Type(k) 236 | Data(j,k) = Buf(BufOff); 237 | BufOff = BufOff + 1; 238 | else 239 | tmp = Buf(BufOff); 240 | ListData{k}{j} = Buf(BufOff+(1:tmp))'; 241 | BufOff = BufOff + tmp + 1; 242 | end 243 | end 244 | end 245 | end 246 | else %%% read binary data %%% 247 | % translate PLY data type names to MATLAB data type names 248 | ListFlag = 0; % = 1 if there is a list type 249 | SameFlag = 1; % = 1 if all types are the same 250 | 251 | for j = 1:NumProperties 252 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 253 | 254 | if ~strcmp(Token{1},'list') % non-list type 255 | tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; 256 | 257 | if ~isempty(tmp) 258 | TypeSize(j) = SizeOf(tmp); 259 | Type{j} = MatlabTypeNames{tmp}; 260 | TypeSize2(j) = 0; 261 | Type2{j} = ''; 262 | 263 | SameFlag = SameFlag & strcmp(Type{1},Type{j}); 264 | else 265 | fclose(fid); 266 | error(['Unknown property data type, ''',Token{1},''', in ', ... 267 | ElementNames{i},'.',CurPropertyNames{j},'.']); 268 | end 269 | else % list type 270 | if length(Token) == 3 271 | ListFlag = 1; 272 | SameFlag = 0; 273 | tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; 274 | tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; 275 | 276 | if ~isempty(tmp) & ~isempty(tmp2) 277 | TypeSize(j) = SizeOf(tmp); 278 | Type{j} = MatlabTypeNames{tmp}; 279 | TypeSize2(j) = SizeOf(tmp2); 280 | Type2{j} = MatlabTypeNames{tmp2}; 281 | else 282 | fclose(fid); 283 | error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... 
284 | ElementNames{i},'.',CurPropertyNames{j},'.']); 285 | end 286 | else 287 | fclose(fid); 288 | error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); 289 | end 290 | end 291 | end 292 | 293 | % read file 294 | if ~ListFlag 295 | if SameFlag 296 | % no list types, all the same type (fast) 297 | Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; 298 | else 299 | % no list types, mixed type 300 | Data = zeros(ElementCount(i),NumProperties); 301 | 302 | for j = 1:ElementCount(i) 303 | for k = 1:NumProperties 304 | Data(j,k) = fread(fid,1,Type{k}); 305 | end 306 | end 307 | end 308 | else 309 | ListData = cell(NumProperties,1); 310 | 311 | for k = 1:NumProperties 312 | ListData{k} = cell(ElementCount(i),1); 313 | end 314 | 315 | if NumProperties == 1 316 | BufSize = 512; 317 | SkipNum = 4; 318 | j = 0; 319 | 320 | % list type, one property (fast if lists are usually the same length) 321 | while j < ElementCount(i) 322 | BufSize = min(ElementCount(i)-j,BufSize); 323 | Position = ftell(fid); 324 | % read in BufSize count values, assuming all counts = SkipNum 325 | [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); 326 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 327 | fseek(fid,Position + TypeSize(1),-1); % seek back to after first count 328 | 329 | if isempty(Miss) % all counts are SkipNum 330 | Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 331 | fseek(fid,-TypeSize(1),0); % undo last skip 332 | 333 | for k = 1:BufSize 334 | ListData{1}{j+k} = Buf(k,:); 335 | end 336 | 337 | j = j + BufSize; 338 | BufSize = floor(1.5*BufSize); 339 | else 340 | if Miss(1) > 1 % some counts are SkipNum 341 | Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 342 | 343 | for k = 1:Miss(1)-1 344 | ListData{1}{j+k} = Buf2(k,:); 345 | end 346 | 347 | j = j + k; 348 | end 349 | 350 | % read in the list with the missed count 351 | SkipNum = Buf(Miss(1)); 352 | j = j + 1; 353 | ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); 354 | BufSize = ceil(0.6*BufSize); 355 | end 356 | end 357 | else 358 | % list type(s), multiple properties (slow) 359 | Data = zeros(ElementCount(i),NumProperties); 360 | 361 | for j = 1:ElementCount(i) 362 | for k = 1:NumProperties 363 | if isempty(Type2{k}) 364 | Data(j,k) = fread(fid,1,Type{k}); 365 | else 366 | tmp = fread(fid,1,Type{k}); 367 | ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); 368 | end 369 | end 370 | end 371 | end 372 | end 373 | end 374 | 375 | % put data into Elements structure 376 | for k = 1:NumProperties 377 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 378 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); 379 | else 380 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); 381 | end 382 | end 383 | end 384 | 385 | clear Data ListData; 386 | fclose(fid); 387 | 388 | if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 389 | % find vertex element field 390 | Name = {'vertex','Vertex','point','Point','pts','Pts'}; 391 | Names = []; 392 | 393 | for i = 1:length(Name) 394 | if any(strcmp(ElementNames,Name{i})) 395 | Names = getfield(PropertyNames,Name{i}); 396 | Name = Name{i}; 397 | break; 398 | end 399 | end 400 | 401 | if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) 402 | eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); 403 | else 404 | varargout{1} = zeros(1,3); 405 | end 406 | 407 | varargout{2} = 
Elements; 408 | varargout{3} = Comments; 409 | Elements = []; 410 | 411 | % find face element field 412 | Name = {'face','Face','poly','Poly','tri','Tri'}; 413 | Names = []; 414 | 415 | for i = 1:length(Name) 416 | if any(strcmp(ElementNames,Name{i})) 417 | Names = getfield(PropertyNames,Name{i}); 418 | Name = Name{i}; 419 | break; 420 | end 421 | end 422 | 423 | if ~isempty(Names) 424 | % find vertex indices property subfield 425 | PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; 426 | 427 | for i = 1:length(PropertyName) 428 | if any(strcmp(Names,PropertyName{i})) 429 | PropertyName = PropertyName{i}; 430 | break; 431 | end 432 | end 433 | 434 | if ~iscell(PropertyName) 435 | % convert face index lists to triangular connectivity 436 | eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); 437 | N = length(FaceIndices); 438 | Elements = zeros(N*2,3); 439 | Extra = 0; 440 | 441 | for k = 1:N 442 | Elements(k,:) = FaceIndices{k}(1:3); 443 | 444 | for j = 4:length(FaceIndices{k}) 445 | Extra = Extra + 1; 446 | Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; 447 | end 448 | end 449 | Elements = Elements(1:N+Extra,:) + 1; 450 | end 451 | end 452 | else 453 | varargout{1} = Comments; 454 | end --------------------------------------------------------------------------------
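
A minimal usage sketch for models/asap.py: the `ASAP` module is called with a batch of decoded vertex positions `x` of shape (B, N, 3) and the generator Jacobian `J` of shape (B, N*3, D), as in its `__main__` check above. The sketch below is hypothetical and not part of this repository: `ToyDecoder`, `asap_regularizer`, the MLP sizes, and the batch-summed `torch.autograd.functional.jacobian` trick are illustrative assumptions; only `models.asap.ASAP` and the `./mesh.obj` test asset are taken from the code above.

# NOTE: illustrative sketch only -- ToyDecoder and asap_regularizer are hypothetical;
# only models.asap.ASAP comes from this repository.
import numpy as np
import torch
import trimesh

from models.asap import ASAP


class ToyDecoder(torch.nn.Module):
    """Hypothetical decoder mapping a latent code z of size D to N mesh vertices."""
    def __init__(self, latent_dim, num_verts):
        super().__init__()
        self.num_verts = num_verts
        self.net = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, num_verts * 3),
        )

    def forward(self, z):  # z: (B, D)
        return self.net(z).view(z.shape[0], self.num_verts, 3)  # (B, N, 3)


def asap_regularizer(decoder, asap, z):
    """Evaluate the ASAP penalty at latent codes z of shape (B, D)."""
    x = decoder(z)  # (B, N, 3) decoded vertex positions
    B = z.shape[0]
    # Per-sample Jacobian d(vertices)/dz. Summing the flattened outputs over the
    # batch is a standard trick: sample b depends only on z[b], so slice [:, b, :]
    # of the result is exactly that sample's Jacobian.
    jac = torch.autograd.functional.jacobian(
        lambda zz: decoder(zz).reshape(B, -1).sum(dim=0), z, create_graph=True
    )  # (N*3, B, D)
    J = jac.permute(1, 0, 2)  # (B, N*3, D), the layout expected by ASAP.forward
    return asap(x, J)  # scalar: per-sample traces averaged over the batch


if __name__ == '__main__':
    mesh = trimesh.load('./mesh.obj', process=False)  # same test asset as asap.py's __main__
    faces = np.asarray(mesh.faces)
    num_verts, latent_dim = np.asarray(mesh.vertices).shape[0], 16

    decoder = ToyDecoder(latent_dim, num_verts)
    asap = ASAP(faces, num_verts)
    z = torch.randn(4, latent_dim)
    print('ASAP regularizer:', asap_regularizer(decoder, asap, z).item())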