├── README.md
├── loss.py
├── model.py
└── train_test.py

/README.md:
--------------------------------------------------------------------------------
1 | # NeurIPS 2020: Weakly Supervised Deep Functional Map for Shape Matching
2 | Paper: https://arxiv.org/abs/2009.13339
3 | 
4 | # Requirements:
5 | -- TensorFlow 1.x
6 | 
7 | -- Please download the tf_ops and utils folders from the PointNet++ GitHub repository and compile tf_ops according to the instructions provided there.
8 | 
9 | # Running the code
10 | Our source code contains 3 files:
11 | 1) train_test.py
12 | 2) model.py
13 | 3) loss.py: also contains the implementations of the unsupervised loss of Halimi et al. and the supervised loss of GeomFmaps (Donati et al.).
14 | 
15 | By default, running train_test.py with suitable data runs our method and replicates all results in the main paper.
16 | 
17 | To include the supervised loss of Donati et al., replace E5 (currently set to 0) with sup_penalty_surreal.
18 | 
19 | Similarly, to include the unsupervised loss of Halimi et al., replace E5 with pointwise_corr_layer.
20 | 
21 | 
22 | # Weakly Aligned Data
23 | FAUST remeshed, aligned: https://drive.google.com/file/d/1C-9GFsTl5xwa0RUmC_m1nnj87QUguh6j/view?usp=sharing
24 | 
25 | SCAPE remeshed, aligned: https://drive.google.com/file/d/157SoRhiVQzsWbSFlaV5N-vzkxKCvTIlf/view?usp=sharing
26 | 
27 | # Partial Shape Matching Code
28 | 
29 | Coming soon
30 | 
31 | 
32 | 
33 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import tensorflow as tf
4 | 
5 | flags = tf.app.flags
6 | FLAGS = flags.FLAGS
7 | 
8 | def penalty_bijectivity(C_est_AB, C_est_BA):
9 | 
10 |     return tf.nn.l2_loss(tf.subtract(tf.matmul(C_est_AB, C_est_BA),
11 |                                      tf.eye(tf.shape(C_est_AB)[1])))
12 | 
13 | 
14 | def penalty_ortho(C_est):
15 | 
16 |     return tf.nn.l2_loss(tf.subtract(tf.matmul(tf.transpose(C_est, perm=[0, 2, 1]), C_est),
17 |                                      tf.eye(tf.shape(C_est)[1])))
18 | 
19 | 
20 | def penalty_laplacian_commutativity(C_est, source_evals, target_evals):
21 | 
22 |     # Quicker and uses less memory than building diagonal matrices
23 |     eig1 = tf.einsum('abc,ac->abc', C_est, source_evals)
24 |     eig2 = tf.einsum('ab,abc->abc', target_evals, C_est)
25 | 
26 |     return tf.nn.l2_loss(tf.subtract(eig2, eig1))
27 | 
28 | 
29 | def pointwise_corr_layer(C_est, source_evecs, target_evecs_trans, source_dist_map, target_dist_map):
30 | 
31 |     P = tf.matmul(tf.matmul(source_evecs, C_est), target_evecs_trans)
32 |     P = tf.abs(P)
33 | 
34 |     P_norm = tf.nn.l2_normalize(P, dim=1, name='soft_correspondences')
35 |     # Unsupervised loss calculation (Halimi et al.)
36 |     avg_distance_on_model_after_map = tf.einsum('kmn,kmi,knj->kij', source_dist_map, tf.pow(P_norm, 2), tf.pow(P_norm, 2))  # average geodesic distance, on the source model, between mapped vertices
37 |     avg_distortion_after_map = avg_distance_on_model_after_map - target_dist_map
38 |     unsupervised_loss = tf.nn.l2_loss(avg_distortion_after_map)
39 |     unsupervised_loss /= tf.to_float(tf.shape(P)[0] * tf.shape(P)[2] * tf.shape(P)[2])
40 | 
41 |     return P_norm, unsupervised_loss
42 | 
43 | 
44 | 
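# Editorial sketch (not part of the original loss.py): one way the soft
# correspondence matrix returned by pointwise_corr_layer above could be turned
# into hard point-to-point matches; the helper name is hypothetical. Note that
# train_test.py instead recovers matches with a nearest-neighbour search in the
# spectral domain (cKDTree on the aligned eigenvectors).
def soft_map_to_matches(P_norm):
    # P_norm has shape (batch, n_source, n_target); column j is a soft
    # distribution over source vertices for target vertex j, so the argmax over
    # axis 1 picks, for every target vertex, its best-matching source vertex.
    return tf.argmax(P_norm, axis=1)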
45 | def sup_penalty_surreal(C_est, source_evecs, target_evecs):
46 |     """
47 |     Args: source_evecs and target_evecs are batched over the first dimension.
48 |     """
49 |     fmap = tf.matrix_solve_ls(tf.transpose(target_evecs, [0, 2, 1]), tf.transpose(source_evecs, [0, 2, 1]))
50 |     return tf.nn.l2_loss(C_est - fmap)
51 | 
52 | def func_map_layer(C_est_AB, C_est_BA, source_evecs, source_evecs_trans, source_evals,
53 |                    target_evecs, target_evecs_trans, target_evals, F, G, source_dist, target_dist):
54 | 
55 |     alpha = 1     # 10**3   Bijectivity
56 |     beta = 1      # 10**3   Orthogonality
57 |     gamma = .001  # 1       Laplacian commutativity
58 |     delta = 0     # Descriptor preservation via commutativity
59 | 
60 |     E1 = (penalty_bijectivity(C_est_AB, C_est_BA) + penalty_bijectivity(C_est_BA, C_est_AB)) / 2
61 | 
62 |     E2 = (penalty_ortho(C_est_AB) + penalty_ortho(C_est_BA)) / 2
63 |     #E5 = 0
64 |     E4 = 0
65 |     E3 = (penalty_laplacian_commutativity(C_est_AB, source_evals, target_evals)
66 |           + penalty_laplacian_commutativity(C_est_BA, target_evals, source_evals)) / 2
67 | 
68 |     E5 = 0
69 |     #E5 = (sup_penalty_surreal(C_est_AB, F, G) + sup_penalty_surreal(C_est_BA, G, F)) / 2
70 |     #_, E5 = pointwise_corr_layer(C_est_AB, source_evecs, target_evecs_trans, source_dist, target_dist)
71 |     loss = tf.reduce_mean(alpha * E1 + beta * E2 + gamma * E3 + delta * E4 + E5)
72 |     # check this..
73 |     loss /= tf.to_float(tf.shape(C_est_AB)[1] * tf.shape(C_est_AB)[0])
74 | 
75 |     C_est_AB = tf.reshape(C_est_AB,
76 |                           [FLAGS.batch_size, tf.shape(C_est_AB)[1], tf.shape(C_est_AB)[2], 1])
77 |     tf.summary.image("Estimated_FuncMap_AB", C_est_AB, max_outputs=1)
78 | 
79 |     C_est_BA = tf.reshape(C_est_BA, [FLAGS.batch_size, tf.shape(C_est_BA)[1], tf.shape(C_est_BA)[2], 1])
80 |     tf.summary.image("Estimated_FuncMap_BA", C_est_BA, max_outputs=1)
81 | 
82 |     return loss, E1, E2, E3, E4
83 | 
84 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | 
7 | BASE_DIR = os.path.dirname(__file__)
8 | sys.path.append(BASE_DIR)
9 | sys.path.append(os.path.join(BASE_DIR, '../utils'))
10 | 
11 | import tensorflow as tf
12 | import numpy as np
13 | import tf_util
14 | 
15 | from loss import *
16 | from pointnet_util import pointnet_sa_module, pointnet_fp_module
17 | flags = tf.app.flags
18 | FLAGS = flags.FLAGS
19 | 
20 | 
21 | def res_layer(x_in, dims_out, scope, phase):
22 |     with tf.variable_scope(scope):
23 |         x = tf.contrib.layers.fully_connected(x_in, dims_out, activation_fn=None, scope='dense_1')
24 |         x = tf.contrib.layers.batch_norm(x, center=True, scale=True, is_training=phase, variables_collections=["batch_norm_non_trainable_variables_collection"],
25 |                                          scope='bn_1')
26 |         x = tf.nn.relu(x, 'relu')
27 |         x = tf.contrib.layers.fully_connected(x, dims_out, activation_fn=None, scope='dense_2')
28 |         x = tf.contrib.layers.batch_norm(x, center=True, scale=True, is_training=phase, variables_collections=["batch_norm_non_trainable_variables_collection"],
29 |                                          scope='bn_2')
30 |         # If dims_out changes, modify the input via a linear projection
31 |         # (as suggested in ResNet)
32 |         if not x_in.get_shape().as_list()[-1] == dims_out:
33 |             x_in = tf.contrib.layers.fully_connected(x_in, dims_out, activation_fn=None, scope='projection')
34 |         x += x_in
35 |         return tf.nn.relu(x)
36 | 
37 | def solve_ls(A, B):
38 |     # Transpose input matrices
39 |     At = tf.transpose(A, [0, 2, 1])
40 |     Bt = tf.transpose(B, [0, 2, 1])
41 | 
42 |     # Solve C via least-squares
43 |     Ct_est = tf.matrix_solve_ls(At, Bt)
44 |     #Ct_est = tf.matrix_solve_ls(At, Bt, l2_regularizer = 0.000001)
45 |     C_est = tf.transpose(Ct_est,
[0, 2, 1], name='C_est') 46 | 47 | # Calculate error for safeguarding 48 | safeguard_inverse = tf.nn.l2_loss(tf.matmul(At, Ct_est) - Bt) 49 | safeguard_inverse /= tf.to_float(tf.reduce_prod(tf.shape(A))) 50 | 51 | return C_est, safeguard_inverse 52 | 53 | def get_model(phase, pc1, pc2, source_evecs, source_evecs_trans, source_evals, 54 | target_evecs, target_evecs_trans, target_evals, src_dist=None, tar_dist=None,bn_decay=None): 55 | dims=FLAGS.dims 56 | """ Semantic segmentation PointNet, input is BxNx3, output Bxnum_class """ 57 | batch_size = pc1.get_shape()[0].value 58 | print(batch_size) 59 | num_point = pc1.get_shape()[1].value 60 | print(num_point) 61 | end_points = {} 62 | l0_xyz_s = pc1 63 | l0_xyz_t = pc2 64 | l0_points_s,l0_points_t = (None,None) 65 | 66 | end_points['l0_xyz_s'] = l0_xyz_s 67 | end_points['l0_xyz_t'] = l0_xyz_t 68 | # Layer 1 69 | with tf.variable_scope('layer_1') as scope: 70 | l1_xyz_s, l1_points_s, l1_indices_s = pointnet_sa_module(l0_xyz_s, l0_points_s, npoint=1024, radius=0.1, nsample=32, mlp=[32,32,64], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 71 | scope.reuse_variables() 72 | l1_xyz_t, l1_points_t, l1_indices_t = pointnet_sa_module(l0_xyz_t, l0_points_t, npoint=1024, radius=0.1, nsample=32, mlp=[32,32,64], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 73 | 74 | with tf.variable_scope('layer_2') as scope: 75 | l2_xyz_s, l2_points_s, l2_indices_s = pointnet_sa_module(l1_xyz_s, l1_points_s, npoint=256, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 76 | scope.reuse_variables() 77 | l2_xyz_t, l2_points_t, l2_indices_t = pointnet_sa_module(l1_xyz_t, l1_points_t, npoint=256, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 78 | 79 | with tf.variable_scope('layer_3') as scope: 80 | l3_xyz_s, l3_points_s, l3_indices_s = pointnet_sa_module(l2_xyz_s, l2_points_s, npoint=64, radius=0.4, nsample=32, mlp=[128,128,256], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 81 | scope.reuse_variables() 82 | l3_xyz_t, l3_points_t, l3_indices_t = pointnet_sa_module(l2_xyz_t, l2_points_t, npoint=64, radius=0.4, nsample=32, mlp=[128,128,256], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 83 | 84 | with tf.variable_scope('layer_4') as scope: 85 | l4_xyz_s, l4_points_s, l4_indices_s = pointnet_sa_module(l3_xyz_s, l3_points_s, npoint=16, radius=0.8, nsample=32, mlp=[256,256,512], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 86 | scope.reuse_variables() 87 | l4_xyz_t, l4_points_t, l4_indices_t = pointnet_sa_module(l3_xyz_t, l3_points_t, npoint=16, radius=0.8, nsample=32, mlp=[256,256,512], mlp2=None, group_all=False, is_training=phase, bn_decay=bn_decay, scope=scope) 88 | 89 | # Feature Propagation layers 90 | with tf.variable_scope('fa_layer1') as scope: 91 | l3_points_s = pointnet_fp_module(l3_xyz_s, l4_xyz_s, l3_points_s, l4_points_s, [256,256], phase, bn_decay, scope=scope) 92 | scope.reuse_variables() 93 | l3_points_t = pointnet_fp_module(l3_xyz_t, l4_xyz_t, l3_points_t, l4_points_t, [256,256], phase, bn_decay, scope=scope) 94 | 95 | with tf.variable_scope('fa_layer2') as scope: 96 | l2_points_s = pointnet_fp_module(l2_xyz_s, l3_xyz_s, l2_points_s, l3_points_s, [256,256], phase, bn_decay, scope=scope) 97 | scope.reuse_variables() 98 | l2_points_t = 
pointnet_fp_module(l2_xyz_t, l3_xyz_t, l2_points_t, l3_points_t, [256,256], phase, bn_decay, scope=scope) 99 | 100 | with tf.variable_scope('fa_layer3') as scope: 101 | l1_points_s = pointnet_fp_module(l1_xyz_s, l2_xyz_s, l1_points_s, l2_points_s, [256,128], phase, bn_decay, scope=scope) 102 | scope.reuse_variables() 103 | l1_points_t = pointnet_fp_module(l1_xyz_t, l2_xyz_t, l1_points_t, l2_points_t, [256,128], phase, bn_decay, scope=scope) 104 | 105 | with tf.variable_scope('fa_layer4') as scope: 106 | l0_points_s = pointnet_fp_module(l0_xyz_s, l1_xyz_s, l0_points_s, l1_points_s, [128,128,128], phase, bn_decay, scope=scope) 107 | scope.reuse_variables() 108 | l0_points_t = pointnet_fp_module(l0_xyz_t, l1_xyz_t, l0_points_t, l1_points_t, [128,128,128], phase, bn_decay, scope=scope) 109 | net = {} 110 | for i_layer in range(FLAGS.num_fclayers): 111 | with tf.variable_scope("layer_%d" % i_layer) as scope: 112 | if i_layer == 0: 113 | net['fclayer_%d_s' % i_layer] = res_layer(l0_points_s, dims_out=128, scope=scope,phase=phase) 114 | scope.reuse_variables() 115 | net['fclayer_%d_t' % i_layer] = res_layer(l0_points_t, dims_out=128, scope=scope,phase=phase) 116 | else: 117 | net['fclayer_%d_s' % i_layer] = res_layer(net['fclayer_%d_s' % (i_layer-1)], dims_out=int(dims[i_layer]),scope=scope, phase=phase) 118 | scope.reuse_variables() 119 | net['fclayer_%d_t' % i_layer] = res_layer(net['fclayer_%d_t' % (i_layer-1)],dims_out=int(dims[i_layer]),scope=scope,phase=phase) 120 | 121 | # Project output features on the shape Laplacian eigen functions 122 | layer_C_est = i_layer + 1 # Grab current layer index 123 | F = net['fclayer_%d_s' % (layer_C_est-1)] 124 | A = tf.matmul(source_evecs_trans, F) 125 | net['A'] = A 126 | 127 | G = net['fclayer_%d_t' % (layer_C_est-1)] 128 | B = tf.matmul(target_evecs_trans, G) 129 | net['B'] = B 130 | # FM-layer: evaluate C_est 131 | net['C_est_AB'], safeguard_inverse = solve_ls(A, B) 132 | net['C_est_BA'], safeguard_inverse = solve_ls(B, A) 133 | 134 | # Evaluate loss without any ground-truth or geodesic distance matrix 135 | with tf.variable_scope("func_map_loss"): 136 | net_loss, E1, E2, E3, E4 = func_map_layer(net['C_est_AB'], net['C_est_BA'], source_evecs, source_evecs_trans, source_evals,target_evecs, target_evecs_trans, target_evals,A, B, src_dist,tar_dist) 137 | 138 | tf.summary.scalar('net_loss_Bijectivity', E1) 139 | tf.summary.scalar('net_loss_Orthogonality', E2) 140 | tf.summary.scalar('net_loss_LaplacianCommutativity', E3) 141 | 142 | tf.summary.scalar('net_loss', net_loss) 143 | merged = tf.summary.merge_all() 144 | return net_loss, safeguard_inverse, net, end_points, net['C_est_AB'], merged 145 | 146 | def get_loss(pred, label, smpw): 147 | """ pred: BxNxC, 148 | label: BxN, 149 | smpw: BxN """ 150 | classify_loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=pred, weights=smpw) 151 | tf.summary.scalar('classify loss', classify_loss) 152 | tf.add_to_collection('losses', classify_loss) 153 | 154 | return classify_loss 155 | 156 | if __name__=='__main__': 157 | with tf.Graph().as_default(): 158 | inputs = tf.zeros((32,2048,3)) 159 | net, _ = get_model(inputs, tf.constant(True), 10) 160 | print(net) 161 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | #!/usr/bin/env python3 4 | # -*- coding: utf-8 -*- 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | from scipy.spatial import cKDTree 10 | 11 
| os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 13 | 14 | import tensorflow as tf 15 | import scipy.io as sio 16 | from model import * 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -* 19 | 20 | flags = tf.app.flags 21 | FLAGS = flags.FLAGS 22 | # Training parameterss 23 | flags.DEFINE_float('learning_rate', 1e-4, 'initial learning rate.') 24 | flags.DEFINE_integer('batch_size', 8, 'batch size.') 25 | # Architecture parameters 26 | flags.DEFINE_integer('num_fclayers', 2, 'network depth') 27 | flags.DEFINE_integer('num_evecs', 30, "number of eigenvectors used for representation") 28 | 29 | flags.DEFINE_string('targets_dir', '../../Unsupervised_FMnet/Shapes/Surreal/MAT/','directory with shapes') 30 | flags.DEFINE_string('feat_dir', '../../Unsupervised_FMnet/Shapes/Surreal/MAT/','directory with shapes') 31 | flags.DEFINE_string('files_name', 'surreal_', 'name common to all the shapes') 32 | 33 | flags.DEFINE_string('te_dir_F', '../../Unsupervised_FMnet/Shapes/Faust_r_aligned/MAT/','directory with shapes') 34 | flags.DEFINE_string('feat_dir_te_F', '../../Unsupervised_FMnet/Shapes/Faust_r_aligned/faust_mat/','directory with shapes') 35 | flags.DEFINE_string('files_name_F', '', 'name common to all the shapes') 36 | 37 | flags.DEFINE_string('te_dir_S', '../../Unsupervised_FMnet/Shapes/Scape_r_aligned/MAT/','directory with shapes') 38 | flags.DEFINE_string('feat_dir_te_S', '../../Unsupervised_FMnet/Shapes/Scape_r_aligned/scape_mat/','directory with shapes') 39 | flags.DEFINE_string('files_name_S', '', 'name common to all the shapes') 40 | 41 | flags.DEFINE_integer('max_train_iter', 10000, '') 42 | flags.DEFINE_integer('num_vertices', 3000, '') 43 | flags.DEFINE_integer('save_summaries_secs', 500, '') 44 | flags.DEFINE_integer('save_model_secs', 500, '') 45 | flags.DEFINE_string('log_dir_', 'Training/SCAPE_r_aligned/pointnet_fmnet_b8_lr_-4_30evec_.001_6k_2fc_aligned_tr-Sur_100_rand_sup_both_sides_1_1.001_sbb-/', 46 | 'directory to save models and results') 47 | flags.DEFINE_string('matches_dir_F', './Matches/FAUST_r_aligned/ptnet_surreal_80rand_3k_F_30_1_1_.001_b8_both_te_20/','directory to save models and results') 48 | flags.DEFINE_string('matches_dir_S', './Matches/SCAPE_r_aligned/ptnet_surreal_80_rand_3k_S_30_1_1_.001_b8_both_te_20/','directory to save models and results') 49 | flags.DEFINE_integer('dim_', 128,'') 50 | flags.DEFINE_integer('decay_step', 200000, help='Decay step for lr decay [default: 200000]') 51 | 52 | flags.DEFINE_float('decay_rate', 0.7, help='Decay rate for lr decay [default: 0.7]') 53 | 54 | # Globals 55 | dim_=FLAGS.dim_ 56 | flags.DEFINE_list('dims', [dim_,dim_,dim_,dim_, dim_, dim_, dim_], '') 57 | 58 | dim_1_layer = int(FLAGS.dims[0]) 59 | flags.DEFINE_integer('dim_shot', dim_1_layer, '') 60 | no_layers = FLAGS.num_fclayers 61 | 62 | last_layer = int(FLAGS.dims[no_layers-1]) 63 | flags.DEFINE_integer('dim_out', last_layer, '') 64 | 65 | vert_dir = FLAGS.feat_dir 66 | vert_dir_te_S = FLAGS.feat_dir_te_S 67 | vert_dir_te_F = FLAGS.feat_dir_te_F 68 | n_tr = 100 69 | 70 | #train_subjects = list(range(n_tr)) 71 | train_subjects = np.random.choice(1000,n_tr) 72 | test_subjects_F,test_subjects_S = (range(80,100),range(52,70)) 73 | main_dir = FLAGS.targets_dir 74 | files_name =FLAGS.files_name 75 | te_files_name_F = FLAGS.files_name_F 76 | te_files_name_S = FLAGS.files_name_S 77 | #te_files_name = '' 78 | test_dir_F=FLAGS.te_dir_F 79 | test_dir_S=FLAGS.te_dir_S 80 | 81 | 82 | DECAY_STEP = FLAGS.decay_step 83 | DECAY_RATE = 
FLAGS.decay_rate 84 | BN_INIT_DECAY = 0.5 85 | BN_DECAY_DECAY_RATE = 0.5 86 | BN_DECAY_DECAY_STEP = float(DECAY_STEP) 87 | BN_DECAY_CLIP = 0.99 88 | BATCH_SIZE = FLAGS.batch_size 89 | 90 | def get_input_pair(batch_size, num_vertices, dataset): 91 | 92 | batch_input = { 93 | 'source_evecs': np.zeros((batch_size, num_vertices, FLAGS.num_evecs)), 94 | 'target_evecs': np.zeros((batch_size, num_vertices, FLAGS.num_evecs)), 95 | 'source_evecs_trans': np.zeros((batch_size,FLAGS.num_evecs,num_vertices)), 96 | 'target_evecs_trans': np.zeros((batch_size,FLAGS.num_evecs,num_vertices)), 97 | 'source_shot': np.zeros((batch_size, num_vertices, 3)), 98 | 'target_shot': np.zeros((batch_size, num_vertices, 3)), 99 | 'source_evals': np.zeros((batch_size, FLAGS.num_evecs)), 100 | 'target_evals': np.zeros((batch_size, FLAGS.num_evecs)) 101 | } 102 | for i_batch in range(batch_size): 103 | i_source = train_subjects[np.random.choice(range(n_tr))] 104 | i_target = train_subjects[np.random.choice(range(n_tr))] 105 | 106 | batch_input_ = get_pair_from_ram(i_target, i_source, dataset) 107 | 108 | batch_input_['source_labels'] = range(np.shape(batch_input_['source_evecs'])[0]) 109 | batch_input_['target_labels'] = range(np.shape(batch_input_['target_evecs'])[0]) 110 | 111 | joint_lbls = np.intersect1d(batch_input_['source_labels'],batch_input_['target_labels']) 112 | #print(joint_lbls) 113 | joint_labels_source = np.random.permutation(joint_lbls)[:num_vertices] 114 | joint_labels_target = np.random.permutation(joint_lbls)[:num_vertices] 115 | 116 | ind_dict_source = {value: ind for ind, value in enumerate(batch_input_['source_labels'])} 117 | ind_source = [ind_dict_source[x] for x in joint_labels_source] 118 | 119 | ind_dict_target = {value: ind for ind, value in enumerate(batch_input_['target_labels'])} 120 | ind_target = [ind_dict_target[x] for x in joint_labels_target] 121 | 122 | message = "number of indices must be equal" 123 | assert len(ind_source) == len(ind_target), message 124 | 125 | evecs_full = batch_input_['source_evecs'] 126 | #print(evecs_full.shape) 127 | evecs= evecs_full[ind_source, :] 128 | evecs_trans = batch_input_['source_evecs_trans'][:, ind_source] 129 | shot = batch_input_['source_shot'][ind_source, :] 130 | #print(batch_input_['target_shot'].shape) 131 | evals = [item for sublist in batch_input_['source_evals'] for item in sublist] # what? 
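# Note (added comment): the comprehension above answers the "# what?" question.
# scipy.io.loadmat typically returns a MATLAB vector as a 2-D (nested) array, so
# the double loop flattens the eigenvalues into a plain list of num_evecs values
# that can be copied into the (batch_size, num_evecs) 'source_evals' buffer below.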
132 | batch_input['source_evecs'][i_batch] = evecs 133 | batch_input['source_evecs_trans'][i_batch] = evecs_trans 134 | batch_input['source_shot'][i_batch] = shot 135 | batch_input['source_evals'][i_batch] = evals 136 | 137 | evecs = batch_input_['target_evecs'][ind_target, :] 138 | evecs_trans = batch_input_['target_evecs_trans'][:, ind_target] 139 | shot = batch_input_['target_shot'][ind_target, :] 140 | evals = [item for sublist in batch_input_['target_evals'] for item in sublist] 141 | batch_input['target_evecs'][i_batch] = evecs 142 | batch_input['target_evecs_trans'][i_batch] = evecs_trans 143 | batch_input['target_shot'][i_batch] = shot 144 | batch_input['target_evals'][i_batch] = evals 145 | return batch_input 146 | 147 | 148 | 149 | 150 | def get_input_pair_test(i_target, i_source, sub_): 151 | batch_input = {} 152 | batch_input_ = get_pair_from_ram(i_target, i_source, sub_) 153 | 154 | evecs = batch_input_['source_evecs'] 155 | evecs_trans = batch_input_['source_evecs_trans'] 156 | shot = batch_input_['source_shot'] 157 | evals = [item for sublist in batch_input_['source_evals'] for item in sublist] 158 | batch_input['source_evecs'] = evecs 159 | batch_input['source_evecs_trans'] = evecs_trans 160 | batch_input['source_shot'] = shot 161 | batch_input['source_evals'] = evals 162 | 163 | evecs = batch_input_['target_evecs'] 164 | evecs_trans = batch_input_['target_evecs_trans'] 165 | shot = batch_input_['target_shot'] 166 | evals = [item for sublist in batch_input_['target_evals'] for item in sublist] 167 | batch_input['target_evecs'] = evecs 168 | batch_input['target_evecs_trans'] = evecs_trans 169 | batch_input['target_shot'] = shot 170 | batch_input['target_evals'] = evals 171 | return batch_input 172 | 173 | def get_pair_from_ram(i_target, i_source, sub_): 174 | 175 | input_data = {} 176 | if sub_ == 'train': 177 | targets_= targets_train 178 | elif sub_== 'te_F': 179 | targets_=targets_test_F 180 | else: 181 | targets_=targets_test_S 182 | 183 | evecs = targets_[i_source]['target_evecs'] 184 | evecs_trans = targets_[i_source]['target_evecs_trans'] 185 | shot = targets_[i_source]['target_shot'] 186 | evals = targets_[i_source]['target_evals'] 187 | input_data['source_evecs'] = evecs 188 | input_data['source_evecs_trans'] = evecs_trans 189 | input_data['source_shot'] = shot 190 | input_data['source_evals'] = evals 191 | input_data.update(targets_[i_target]) 192 | 193 | return input_data 194 | 195 | def load_targets_to_ram(): 196 | global targets_train,targets_test_F,targets_test_S 197 | targets_train,targets_test_F,targets_test_S = ({},{},{}) 198 | 199 | targets_train = load_subs(train_subjects, main_dir, vert_dir,files_name) 200 | targets_test_F= load_subs(test_subjects_F,test_dir_F, vert_dir_te_F, te_files_name_F) 201 | targets_test_S =load_subs(test_subjects_S,test_dir_S, vert_dir_te_S,te_files_name_S) 202 | 203 | def load_subs(subjects_list, dir_name,v_dir,f_name): 204 | targets = {} 205 | 206 | for i_target in subjects_list: 207 | target_file = dir_name +f_name +'%.4d.mat' % (i_target) 208 | vert_file = v_dir +f_name +'%.4d.mat' % (i_target) 209 | #print(vert_file) 210 | input_data = sio.loadmat(target_file) 211 | evecs = input_data['target_evecs'][:, 0:FLAGS.num_evecs] 212 | evecs_trans = input_data['target_evecs_trans'][0:FLAGS.num_evecs,:] 213 | evals = input_data['target_evals'][0:FLAGS.num_evecs] 214 | input_data['target_evecs'] = evecs 215 | input_data['target_evecs_trans'] = evecs_trans 216 | input_data['target_evals'] = evals 217 | p_feat = sio.loadmat(vert_file) 
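# Note (added comment): despite the '_shot' naming used throughout this file,
# the feature loaded here is the raw vertex coordinate array p_feat['VERT']
# (N x 3), which is what the PointNet++ branch in model.py consumes as its
# B x N x 3 point-cloud input.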
218 |         input_data['target_shot'] = []
219 |         input_data['target_shot'] = p_feat['VERT']
220 |         targets[i_target] = input_data
221 | 
222 |     return targets
223 | 
224 | 
225 | def get_bn_decay(batch):
226 |     bn_momentum = tf.train.exponential_decay(
227 |         BN_INIT_DECAY,
228 |         batch * BATCH_SIZE,
229 |         BN_DECAY_DECAY_STEP,
230 |         BN_DECAY_DECAY_RATE,
231 |         staircase=True)
232 |     bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
233 |     return bn_decay
234 | 
235 | def run_training():
236 | 
237 |     print('log_dir=%s' % FLAGS.log_dir_)
238 |     if not os.path.isdir(FLAGS.log_dir_):
239 |         os.makedirs(FLAGS.log_dir_)
240 | 
241 |     print('building graph...')
242 | 
243 |     with tf.Graph().as_default():
244 |         # Set placeholders for inputs
245 |         source_shot = tf.placeholder(tf.float32, shape=(None, None, 3), name='source_shot')
246 |         target_shot = tf.placeholder(tf.float32, shape=(None, None, 3), name='target_shot')
247 | 
248 |         source_evecs = tf.placeholder(tf.float32, shape=(None, None, FLAGS.num_evecs), name='source_evecs')
249 |         source_evecs_trans = tf.placeholder(tf.float32, shape=(None, FLAGS.num_evecs, None), name='source_evecs_trans')
250 |         source_evals = tf.placeholder(tf.float32, shape=(None, FLAGS.num_evecs), name='source_evals')
251 |         target_evecs = tf.placeholder(tf.float32, shape=(None, None, FLAGS.num_evecs), name='target_evecs')
252 |         target_evecs_trans = tf.placeholder(tf.float32, shape=(None, FLAGS.num_evecs, None), name='target_evecs_trans')
253 |         target_evals = tf.placeholder(tf.float32, shape=(None, FLAGS.num_evecs), name='target_evals')
254 |         # train/test switch flag
255 |         phase = tf.placeholder(dtype=tf.bool, name='phase')
256 | 
257 |         #is_training_pl = tf.placeholder(tf.bool, shape=())
258 |         #print(is_training_pl)
259 |         batch = tf.Variable(0)
260 |         bn_decay = get_bn_decay(batch)
261 | 
262 |         net_loss, safeguard_inverse, net, end_points, C, merged = get_model(phase, source_shot, target_shot, source_evecs, source_evecs_trans,
263 |                                                                             source_evals, target_evecs, target_evecs_trans, target_evals, bn_decay=bn_decay)
264 | 
265 |         summary = tf.summary.scalar("num_evecs", float(FLAGS.num_evecs))
266 | 
267 |         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
268 | 
269 |         with tf.control_dependencies(update_ops):
270 |             global_step = tf.Variable(0, name='global_step', trainable=False)
271 |             optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
272 |             train_op = optimizer.minimize(net_loss, global_step=global_step, aggregation_method=2)
273 | 
274 |         #saver = tf.train.Saver(max_to_keep=40)
275 |         saver = tf.train.Saver(tf.global_variables())
276 | 
277 |         sv = tf.train.Supervisor(logdir=FLAGS.log_dir_, init_op=tf.global_variables_initializer(), local_init_op=tf.local_variables_initializer(),
278 |                                  global_step=global_step, save_summaries_secs=FLAGS.save_summaries_secs,
279 |                                  save_model_secs=FLAGS.save_model_secs, summary_op=None, saver=saver)
280 | 
281 |         writer = sv.summary_writer
282 |         config = tf.ConfigProto()
283 |         config.gpu_options.allow_growth = True
284 | 
285 |         print('starting session...')
286 |         iteration = 0
287 | 
288 |         with sv.managed_session(config=config) as sess:
289 |             print('loading data to ram...')
290 |             load_targets_to_ram()
291 | 
292 |             print('starting training loop...')
293 |             while not sv.should_stop() and iteration < FLAGS.max_train_iter:
294 | 
295 |                 iteration += 1
296 |                 start_time = time.time()
297 |                 input_data = get_input_pair(FLAGS.batch_size, FLAGS.num_vertices, 'train')
298 | 
299 |                 feed_dict = {phase: False, source_shot: input_data['source_shot'], target_shot: input_data['target_shot'],
300 |                              source_evecs:
input_data['source_evecs'], source_evecs_trans: input_data['source_evecs_trans'], 301 | source_evals: input_data['source_evals'], target_evecs: input_data['target_evecs'], 302 | target_evecs_trans: input_data['target_evecs_trans'], target_evals: input_data['target_evals']} 303 | 304 | summaries, step, my_loss, safeguard, _ = sess.run([merged, global_step, net_loss, safeguard_inverse, train_op], 305 | feed_dict=feed_dict) 306 | 307 | writer.add_summary(summaries, step) 308 | summary_ = sess.run(summary) 309 | writer.add_summary(summary_, step) 310 | duration = time.time() - start_time 311 | print('train - step %d: loss = %.2f (%.3f sec)'% (step, my_loss, duration)) 312 | 313 | if iteration%1000==0: 314 | for i_source in range(52,69): 315 | for i_target in range(i_source+1,70): 316 | 317 | t = time.time() 318 | 319 | input_data = get_input_pair_test(i_target, i_source, 'te_S') 320 | source_evecs_ = input_data['source_evecs'][:, 0:FLAGS.num_evecs] 321 | target_evecs_ = input_data['target_evecs'][:, 0:FLAGS.num_evecs] 322 | 323 | feed_dict = { 324 | phase: False, 325 | source_shot: [input_data['source_shot']], 326 | target_shot: [input_data['target_shot']], 327 | source_evecs: [input_data['source_evecs']], 328 | source_evecs_trans: [input_data['source_evecs_trans']], 329 | source_evals: [input_data['source_evals']], 330 | target_evecs: [input_data['target_evecs']], 331 | target_evecs_trans: [input_data['target_evecs_trans']], 332 | target_evals: [input_data['target_evals']] 333 | } 334 | 335 | C_est_ = sess.run([C], feed_dict=feed_dict) 336 | Ct = np.squeeze(C_est_).T #Keep transposed 337 | 338 | kdt = cKDTree(np.matmul(source_evecs_, Ct)) 339 | 340 | dist, indices = kdt.query(target_evecs_, n_jobs=-1) 341 | indices = indices + 1 342 | 343 | print("Computed correspondences for pair: %s, %s." % (i_source, i_target) + 344 | " Took %f seconds" % (time.time() - t)) 345 | params_to_save = {} 346 | params_to_save['matches'] = indices 347 | #params_to_save['C'] = Ct.T 348 | # For Matlab where index start at 1 349 | new_dir = FLAGS.matches_dir_S + '%.3d-' % iteration + '/' 350 | 351 | if not os.path.isdir(new_dir): 352 | print('matches_dir=%s' % new_dir) 353 | os.makedirs(new_dir) 354 | 355 | sio.savemat(new_dir + '%.4d-' % i_source + '%.4d.mat' % i_target, params_to_save) 356 | 357 | for i_source in range(80,99): 358 | for i_target in range(i_source+1,100): 359 | 360 | t = time.time() 361 | 362 | input_data = get_input_pair_test(i_target, i_source, 'te_F') 363 | source_evecs_ = input_data['source_evecs'][:, 0:FLAGS.num_evecs] 364 | target_evecs_ = input_data['target_evecs'][:, 0:FLAGS.num_evecs] 365 | 366 | feed_dict = { 367 | phase: False, 368 | source_shot: [input_data['source_shot']], 369 | target_shot: [input_data['target_shot']], 370 | source_evecs: [input_data['source_evecs']], 371 | source_evecs_trans: [input_data['source_evecs_trans']], 372 | source_evals: [input_data['source_evals']], 373 | target_evecs: [input_data['target_evecs']], 374 | target_evecs_trans: [input_data['target_evecs_trans']], 375 | target_evals: [input_data['target_evals']] 376 | } 377 | 378 | C_est_ = sess.run([C], feed_dict=feed_dict) 379 | Ct = np.squeeze(C_est_).T #Keep transposed 380 | 381 | kdt = cKDTree(np.matmul(source_evecs_, Ct)) 382 | 383 | dist, indices = kdt.query(target_evecs_, n_jobs=-1) 384 | indices = indices + 1 385 | 386 | print("Computed correspondences for pair: %s, %s." 
% (i_source, i_target) + 387 | " Took %f seconds" % (time.time() - t)) 388 | params_to_save = {} 389 | params_to_save['matches'] = indices 390 | #params_to_save['C'] = Ct.T 391 | # For Matlab where index start at 1 392 | new_dir = FLAGS.matches_dir_F + '%.3d-' % iteration + '/' 393 | 394 | if not os.path.isdir(new_dir): 395 | print('matches_dir=%s' % new_dir) 396 | os.makedirs(new_dir) 397 | 398 | sio.savemat(new_dir + '%.4d-' % i_source + '%.4d.mat' % i_target, params_to_save) 399 | 400 | 401 | def main(_): 402 | import time 403 | start_time = time.time() 404 | run_training() 405 | print("--- %s seconds ---" % (time.time() - start_time)) 406 | 407 | 408 | if __name__ == '__main__': 409 | tf.app.run() 410 | 411 | 412 | 413 | --------------------------------------------------------------------------------
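Note on the E5 substitution described in the README: the two options correspond to the lines already present, commented out, inside func_map_layer in loss.py. A minimal sketch using that function's own variable names (keep exactly one of the two assignments); the Halimi et al. variant additionally requires the geodesic distance matrices source_dist and target_dist to be supplied:

    # Supervised loss of Donati et al. (GeomFmaps); F and G receive the spectral
    # descriptor coefficients A and B computed in model.py
    E5 = (sup_penalty_surreal(C_est_AB, F, G) + sup_penalty_surreal(C_est_BA, G, F)) / 2

    # Unsupervised loss of Halimi et al. (uses geodesic distance matrices)
    _, E5 = pointwise_corr_layer(C_est_AB, source_evecs, target_evecs_trans, source_dist, target_dist)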