├── data
│   └── __init__.py
├── checkpoints
│   └── __init__.py
├── results
│   └── __init__.py
├── .gitignore
├── imgs
│   ├── kpnet_1.png
│   ├── kpnet_2.png
│   ├── ScreenCapture_2020-02-17-13-23-52.png
│   ├── ScreenCapture_2020-02-17-13-24-07.png
│   └── ScreenCapture_2020-02-17-13-25-17.png
├── utils
│   ├── shot
│   │   ├── CMakeLists.txt
│   │   └── shot.cpp
│   └── loss.py
├── visualize.py
├── README.md
├── train.py
├── test.py
├── preprocess.py
├── dataloader.py
└── model.py

/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/checkpoints/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/results/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
debug*
results/faust/*
--------------------------------------------------------------------------------
/imgs/kpnet_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/kpnet_1.png
--------------------------------------------------------------------------------
/imgs/kpnet_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/kpnet_2.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-23-52.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-23-52.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-24-07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-24-07.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-25-17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-25-17.png
--------------------------------------------------------------------------------
/utils/shot/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 2.8.12)
project(shot)
set (CMAKE_CXX_STANDARD 11)

find_package( PythonInterp 3.6 REQUIRED )
find_package( PythonLibs 3.6 REQUIRED )
find_package(pybind11 REQUIRED)
find_package( PCL 1.8 REQUIRED )

include_directories( ${PCL_INCLUDE_DIRS} )
link_directories( ${PCL_LIBRARY_DIRS} "D:\\PCL 1.8.1\\3rdParty\\Boost\\lib" "D:\\PCL 1.8.1\\3rdParty\\FLANN\\lib" )
add_definitions(${PCL_DEFINITIONS})

pybind11_add_module(shot shot.cpp)
target_link_libraries(shot PRIVATE ${PCL_LIBRARIES})
--------------------------------------------------------------------------------
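After `cmake . && make` succeed, the compiled extension can be smoke-tested directly from Python. A minimal sketch (not part of the repo), assuming the built `shot` module is importable from the current directory; the radii mirror the defaults bound at the bottom of shot.cpp:

~~~
# Hypothetical smoke test; run from utils/shot after building.
import numpy as np
import shot  # the pybind11 module produced by the CMake build above

pc = np.random.rand(2048, 3).astype(np.float32)  # toy point cloud
desc = shot.compute(pc, normal_r=0.1, shot_r=0.17)
desc = desc.reshape(-1, 352)  # SHOT352: 352 floats per point
print(desc.shape)             # (2048, 352)
~~~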
/utils/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class GeodesicLoss(nn.Module):
    """
    Geodesic loss. Measures the total geodesic-distance discrepancy between
    estimated corresponding pairs: dist_x is compared against dist_y pulled
    back through the soft correspondence Q.
    """
    def __init__(self):
        super(GeodesicLoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')  # 'reduce=True' is deprecated

    def forward(self, Q, dist_x, dist_y):
        """
        :param Q: B * 2048 * 2048, soft correspondence matrix
        :param dist_x: B * 2048 * 2048, geodesic distances of x
        :param dist_y: B * 2048 * 2048, geodesic distances of y
        :return: a scalar as loss
        """
        loss = self.criterion(dist_x, torch.bmm(Q.transpose(2, 1), torch.bmm(dist_y, Q)))
        return loss
--------------------------------------------------------------------------------
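A quick way to convince yourself the loss is sensible (a hedged sketch, not part of the repo): when Q is an exact permutation and the two distance matrices agree under it, the loss vanishes.

~~~
# Hypothetical sanity check for GeodesicLoss; run from the repo root.
import torch
from utils.loss import GeodesicLoss

n = 16
Q = torch.eye(n)[torch.randperm(n)].unsqueeze(0)  # hard permutation, batch of 1
dist_y = torch.rand(1, n, n)
dist_y = (dist_y + dist_y.transpose(2, 1)) / 2    # symmetric toy "geodesics"
dist_x = torch.bmm(Q.transpose(2, 1), torch.bmm(dist_y, Q))  # consistent by construction
print(GeodesicLoss()(Q, dist_x, dist_y))          # tensor(0.)
~~~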
/visualize.py:
--------------------------------------------------------------------------------
import numpy as np
import open3d as o3d
import scipy.io as sio
import os

result_dir = './results/faust/'
if __name__ == '__main__':
    for root, _, files in os.walk(result_dir):
        for file in files:
            print(file)
            file_path = os.path.join(root, file)
            m = sio.loadmat(file_path)
            x, y, Q = m['x'], m['y'], m['Q']
            Q = np.argmax(Q, axis=0)  # for each point of x, its best match on y
            colors = (y + 1) / 2      # color y by position; transfer colors to x via Q
            # open3d >= 0.8 namespaced API; older releases used o3d.PointCloud etc.
            x_o3d, y_o3d = o3d.geometry.PointCloud(), o3d.geometry.PointCloud()
            x_o3d.points, y_o3d.points = o3d.utility.Vector3dVector(x + 2), o3d.utility.Vector3dVector(y)
            x_o3d.colors, y_o3d.colors = o3d.utility.Vector3dVector(colors[Q]), o3d.utility.Vector3dVector(colors)
            o3d.visualization.draw_geometries([x_o3d, y_o3d])

            # Alternative viewer via pptk:
            # import pptk
            # v = pptk.viewer(np.concatenate([y, x + 2.], axis=0),
            #                 np.concatenate([colors[Q], colors], axis=0))
            # v.set(point_size=0.02)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FMNet.pytorch
A PyTorch implementation of Deep Functional Maps (FMNet).

## Introduction
This is a PyTorch implementation of [Deep Functional Maps](https://arxiv.org/abs/1704.08686). Ground-truth FAUST correspondence labels are not used. For efficiency, 2048 points are randomly sampled from the 6890 points of each original mesh. The resulting correspondences may not be bijective.

**Update:** the visualized KeypointNet pairs are post-processed with PMF to make them bijective, while the FAUST pairs are not.

## Usage
Build the SHOT descriptor module:
~~~
cd utils/shot
cmake .
make
~~~
Compute Laplacian eigenvectors, geodesic distance matrices and SHOT descriptors of the training shapes, and save them in .mat format:
~~~
python preprocess.py
~~~
Train:
~~~
python train.py --dataset=faust
~~~
Test (for now the training data is reused at test time, for visualization):
~~~
python test.py --dataset=faust --model_name=epoch300.pth
~~~
Visualize the correspondences:
~~~
python visualize.py
~~~

## Visualization
![pair1](https://github.com/BlankCheng/FMNet.pytorch/raw/master/imgs/ScreenCapture_2020-02-17-13-23-52.png)
![pair2](https://github.com/BlankCheng/FMNet.pytorch/raw/master/imgs/ScreenCapture_2020-02-17-13-25-17.png)
![pair3](https://github.com/BlankCheng/FMNet.pytorch/raw/master/imgs/kpnet_1.png)
![pair4](https://github.com/BlankCheng/FMNet.pytorch/raw/master/imgs/kpnet_2.png)
--------------------------------------------------------------------------------
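The saved `pair{i}.mat` files also make the bijectivity caveat above easy to quantify. A hedged sketch (not part of the repo; the pair file name is an assumption) that extracts the point-to-point map the same way visualize.py does and reports how much of the target shape it covers:

~~~
# Hypothetical bijectivity check on one saved result.
import numpy as np
import scipy.io as sio

m = sio.loadmat('./results/faust/pair1.mat')  # file name is an assumption
Q = m['Q']                   # 2048 x 2048 soft correspondence
p2p = np.argmax(Q, axis=0)   # for each point of x, its match on y
coverage = np.unique(p2p).size / Q.shape[0]
print('fraction of y covered: {:.3f}'.format(coverage))  # 1.0 only if bijective
~~~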
/utils/shot/shot.cpp:
--------------------------------------------------------------------------------
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pcl/point_types.h>
#include <pcl/features/normal_3d.h>
#include <pcl/features/shot.h>
#include <pcl/search/kdtree.h>

namespace py = pybind11;
using namespace pybind11::literals;

py::array_t<float> compute(py::array_t<float> pc, double normal_r, double shot_r)
{
    // Object for storing the point cloud.
    pcl::PointCloud<pcl::PointXYZ>::Ptr cloud(new pcl::PointCloud<pcl::PointXYZ>);
    cloud->points.resize(pc.shape(0));
    float *pc_ptr = (float*)pc.request().ptr;
    for (int i = 0; i < pc.shape(0); ++i)
    {
        std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
        // std::cout << cloud->points[i] << std::endl;
        pc_ptr += 3;
    }

    // Object for storing the normals.
    pcl::PointCloud<pcl::Normal>::Ptr normals(new pcl::PointCloud<pcl::Normal>);
    // Object for storing the SHOT descriptors for each point.
    pcl::PointCloud<pcl::SHOT352>::Ptr descriptors(new pcl::PointCloud<pcl::SHOT352>());

    // Note: you would usually perform downsampling now. It has been omitted here
    // for simplicity, but be aware that computation can take a long time.

    // Estimate the normals.
    pcl::NormalEstimation<pcl::PointXYZ, pcl::Normal> normalEstimation;
    normalEstimation.setInputCloud(cloud);
    // normalEstimation.setRadiusSearch(normal_r);
    normalEstimation.setKSearch(40);
    pcl::search::KdTree<pcl::PointXYZ>::Ptr kdtree(new pcl::search::KdTree<pcl::PointXYZ>);
    normalEstimation.setSearchMethod(kdtree);
    normalEstimation.compute(*normals);

    // SHOT estimation object.
    pcl::SHOTEstimation<pcl::PointXYZ, pcl::Normal, pcl::SHOT352> shot;
    shot.setInputCloud(cloud);
    shot.setInputNormals(normals);
    // The radius that defines which of the keypoint's neighbors are described.
    // If too large, there may be clutter, and if too small, not enough points may be found.
    shot.setRadiusSearch(shot_r);
    // shot.setKSearch(40);
    shot.compute(*descriptors);

    // Flatten the descriptors into a single (N * 352)-element float array for Python.
    auto result = py::array_t<float>(descriptors->points.size() * 352);
    auto buf = result.request();
    float *ptr = (float*)buf.ptr;

    for (size_t i = 0; i < descriptors->points.size(); ++i)
    {
        std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[352], &ptr[i * 352]);
    }
    return result;
}


PYBIND11_MODULE(shot, m) {
    m.def("compute", &compute, py::arg("pc"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17);
}
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
from tqdm import tqdm
import os
import torch
import torch.optim as optim
import torch.utils.data as DT

from dataloader import FAUSTDataset, KeyPointDataset
from model import FMNet
from utils.loss import GeodesicLoss


def arg_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, choices=['faust', 'keypointnet'])
    parser.add_argument('--category', type=str, default='02691156')  # only for keypointnet
    parser.add_argument('--root_path', type=str, default='./')
    parser.add_argument('--phase', type=str, default='train')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--max_epochs', type=int, default=300)
    parser.add_argument('--seed', type=int, default=-1)
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # init
    args = arg_parse()
    if args.seed != -1:
        torch.manual_seed(args.seed)
    save_dir = os.path.join(args.root_path, 'checkpoints', args.dataset)
    os.makedirs(save_dir, exist_ok=True)

    # load train data
    if args.dataset == 'faust':
        train_dataset = FAUSTDataset(args=args)
    elif args.dataset == 'keypointnet':
        train_dataset = KeyPointDataset(args=args)
    else:
        raise NotImplementedError("Dataset not implemented now.")
    train_dataloader = DT.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False
    )

    # network config
    net = FMNet()
    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    criterion = GeodesicLoss()
    net = net.cuda().float()

    # train
    for epoch in range(1, args.max_epochs + 1):
        i = 0
        net.train()
        for data in tqdm(train_dataloader):
            x, y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y = data  # FAUST has no keypoints
            x, y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y = \
                x.cuda(), y.cuda(), evecs_x.cuda(), evecs_y.cuda(), feat_x.cuda(), feat_y.cuda(), dist_x.cuda(), dist_y.cuda()
            optimizer.zero_grad()
            Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
            loss = criterion(Q, dist_x, dist_y)
            loss.backward()
            optimizer.step()
            i += 1
            print("#epoch:{}, #batch:{}, loss:{}".format(epoch, i, loss.item()))
        scheduler.step()  # step once per epoch, after the optimizer updates

        if epoch % 20 == 0:
            torch.save(net.state_dict(), os.path.join(save_dir, 'epoch{}.pth'.format(epoch)))
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
from tqdm import tqdm
import os
import scipy.io as sio
import torch
import torch.utils.data as DT

from dataloader import FAUSTDataset, KeyPointDataset
from model import FMNet
from utils.loss import GeodesicLoss


def arg_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, choices=['faust', 'keypointnet'])
    parser.add_argument('--category', type=str, default='02691156')  # only for keypointnet
    parser.add_argument('--root_path', type=str, default='./')
    parser.add_argument('--model_name', type=str, default=None)
    parser.add_argument('--phase', type=str, default='train')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--seed', type=int, default=-1)
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # init
    args = arg_parse()
    if args.seed != -1:
        torch.manual_seed(args.seed)

    # load test data
    if args.dataset == 'faust':
        test_dataset = FAUSTDataset(args=args)
    elif args.dataset == 'keypointnet':
        test_dataset = KeyPointDataset(args=args)
    else:
        raise NotImplementedError("Dataset not implemented now.")
    test_dataloader = DT.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False
    )

    # network config
    model_path = os.path.join(args.root_path, 'checkpoints', args.dataset, args.model_name)
    if not os.path.exists(model_path):
        print("No trained model in {}".format(model_path))
        exit(-1)
    print(model_path)
    net = FMNet()
    net.load_state_dict(torch.load(model_path))
    net = net.cuda().float()
    criterion = GeodesicLoss()

    # test
    result_dir = os.path.join(args.root_path, 'results', args.dataset)
    os.makedirs(result_dir, exist_ok=True)
    i = 0
    net.eval()
    with torch.no_grad():
        for data in tqdm(test_dataloader):
            i += 1
            x, y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y = data  # FAUST has no keypoints
            x, y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y = \
                x.cuda(), y.cuda(), evecs_x.cuda(), evecs_y.cuda(), feat_x.cuda(), feat_y.cuda(), dist_x.cuda(), dist_y.cuda()
            Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
            loss = criterion(Q, dist_x, dist_y)
            print("loss:{}".format(loss.item()))
            sio.savemat(os.path.join(result_dir, 'pair{}.mat'.format(i)),  # note: the pair index is unrelated to the .ply index
                        {'x': x.cpu().numpy().squeeze(),
                         'y': y.cpu().numpy().squeeze(),
                         'Q': Q.cpu().numpy().squeeze(),
                         })
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
"""
Preprocessing code provided by You Yang: qq456cvb@github
"""
import scipy.io as sio
from sklearn import neighbors
from scipy.sparse.csgraph import shortest_path  # sklearn.utils.graph.graph_shortest_path has been removed from sklearn
import numpy as np
from scipy.linalg import eigh
from timeit import default_timer as timer
import os
import utils.shot.shot as shot
import plyfile

NUM_EIGENS = 150
NUM_POINTS = 6890
NUM_POINTS_SAMPLE = 2048
NORMAL_R = 0.1
SHOT_R = 0.1
KNN = 20
DESCRIPTOR = 'shot'
MODELS_DIR = os.path.expanduser('~/MPI-FAUST/training/registrations')  # '~' must be expanded for os.walk
SAVE_DIR = './data/faust/train'


def normalize_adj(adj):
    rowsum = np.sum(adj, axis=1)
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = np.diag(d_inv_sqrt)
    return d_mat_inv_sqrt @ adj @ d_mat_inv_sqrt


def normalized_laplacian(adj):
    # L = I - D^{-1/2} A D^{-1/2}
    adj_normalized = normalize_adj(adj)
    norm_laplacian = np.eye(adj.shape[0]) - adj_normalized
    return norm_laplacian


if __name__ == '__main__':
    os.makedirs(SAVE_DIR, exist_ok=True)
    cnt = 0
    for root, dirs, files in os.walk(MODELS_DIR):
        print(root, dirs, files)
        for file in files:
            if file[-4:] == '.png':
                continue
            cnt += 1
            file_path = os.path.join(root, file)
            ply = plyfile.PlyData.read(file_path)
            x = ply.elements[0].data
            x = np.stack((x['x'], x['y'], x['z']), axis=-1)
            np.random.seed(cnt)  # different but reproducible sampling per mesh
            random_indices = np.random.choice(NUM_POINTS, NUM_POINTS_SAMPLE, replace=False)  # sample without duplicates
            x = x[random_indices]

            # center and scale to the unit sphere
            x = x - np.mean(x, axis=0, keepdims=True)
            x = x / np.max(np.linalg.norm(x, axis=1))

            t = timer()
            graph_x_csr = neighbors.kneighbors_graph(x, KNN, mode='distance', include_self=False)
            print('knn time:', timer() - t)

            t = timer()
            graph_x = graph_x_csr.toarray()  # n*n matrix, each element is the distance (if kNN) or 0 (if not)
            graph_x = np.exp(-graph_x ** 2 / np.mean(graph_x[:, -1]) ** 2)  # Gaussian affinity
            graph_x = graph_x + (graph_x.T - graph_x) * np.greater(graph_x.T, graph_x).astype(float)  # symmetrize by max

            laplacian_x = normalized_laplacian(graph_x)
            print('laplacian time:', timer() - t)

            t = timer()
            _, eigen_x = eigh(laplacian_x,
                              subset_by_index=(laplacian_x.shape[0] - NUM_EIGENS, laplacian_x.shape[0] - 1))  # 'eigvals=' is deprecated
            print('eigen function time:', timer() - t)

            t = timer()
            geodesic_x = shortest_path(graph_x_csr,
                                       directed=False)  # graph shortest paths approximate geodesic, not l2, distances
            print('geodesic matrix time:', timer() - t)

            t = timer()
            if DESCRIPTOR == 'shot':
                feature_x = shot.compute(x, NORMAL_R, SHOT_R).reshape(-1, 352)
            print('shot time:', timer() - t)
            feature_x[np.where(np.isnan(feature_x))] = 0

            print(os.path.join(SAVE_DIR, '{}.mat'.format(file[:-4])))
            sio.savemat(os.path.join(SAVE_DIR, '{}.mat'.format(file[:-4])),
                        {'pcd': x,
                         'evecs': eigen_x,
                         'feat': feature_x,
                         'dist': geodesic_x,
                         })
--------------------------------------------------------------------------------
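A hedged sanity check on the preprocessing output (not part of the repo; the FAUST registration file name is an assumption) — the arrays saved above should come back with these shapes:

~~~
# Hypothetical shape check for one preprocessed FAUST mesh.
import scipy.io as sio

m = sio.loadmat('./data/faust/train/tr_reg_000.mat')  # file name is an assumption
print(m['pcd'].shape)    # (2048, 3)     sampled, normalized point cloud
print(m['evecs'].shape)  # (2048, 150)   Laplacian eigenvectors
print(m['feat'].shape)   # (2048, 352)   SHOT descriptors
print(m['dist'].shape)   # (2048, 2048)  geodesic distance matrix
~~~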
/dataloader.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.utils.data as DT
import scipy.io as sio


class FAUSTDataset(DT.Dataset):
    def __init__(self, args):
        super(FAUSTDataset, self).__init__()
        self.data_path = os.path.join(args.root_path, 'data', args.dataset, args.phase)
        self.pcds = []
        self.evecs = []
        self.feats = []
        self.dists = []  # not fed into the network, only used by the loss
        for root, dirs, files in os.walk(self.data_path):
            for file in files:
                f = sio.loadmat(os.path.join(root, file))
                self.pcds.append(f['pcd'])
                self.evecs.append(f['evecs'])
                self.feats.append(f['feat'])
                self.dists.append(f['dist'])

    def __getitem__(self, item):
        """
        :param item: index; the shapes at index item and item*2 are paired
        :return: (pcd_x, pcd_y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y)
        """
        x, y = self.pcds[item], self.pcds[item*2]
        evecs_x, evecs_y = self.evecs[item], self.evecs[item*2]  # not rotation-invariant
        feat_x, feat_y = self.feats[item], self.feats[item*2]  # not rotation-invariant
        dist_x, dist_y = self.dists[item], self.dists[item*2]
        return torch.tensor(x).float(), torch.tensor(y).float(), \
               torch.tensor(evecs_x).float(), torch.tensor(evecs_y).float(), \
               torch.tensor(feat_x).float(), torch.tensor(feat_y).float(), \
               torch.tensor(dist_x).float(), torch.tensor(dist_y).float()

    def __len__(self):
        return len(self.pcds) // 2


class KeyPointDataset(DT.Dataset):
    def __init__(self, args):
        super(KeyPointDataset, self).__init__()
        self.data_path = os.path.join(args.root_path, 'data', args.dataset, args.phase, args.category)
        self.pcds = []
        self.evecs = []
        self.feats = []
        self.keypoints = []  # not fed into the network, only used by the loss
        self.dists = []  # not fed into the network, only used by the loss
        for root, dirs, files in os.walk(self.data_path):
            for file in files:
                f = sio.loadmat(os.path.join(root, file))
                self.pcds.append(f['pcd'])
                self.evecs.append(f['evecs'])
                self.feats.append(f['feat'])
                self.keypoints.append(f['keypoints'])
                self.dists.append(f['dist'])

    def __getitem__(self, item):
        """
        :param item: index; the shapes at index item and item*2 are paired
        :return: (pcd_x, pcd_y, evecs_x, evecs_y, feat_x, feat_y, keypoints_x, keypoints_y, dist_x, dist_y)
        """
        x, y = self.pcds[item], self.pcds[item*2]
        evecs_x, evecs_y = self.evecs[item], self.evecs[item*2]  # not invariant to rotation and translation
        feat_x, feat_y = self.feats[item], self.feats[item*2]  # not invariant to rotation and translation
        keypoints_x, keypoints_y = self.keypoints[item], self.keypoints[item*2]
        dist_x, dist_y = self.dists[item], self.dists[item*2]
        return torch.tensor(x).float(), torch.tensor(y).float(), \
               torch.tensor(evecs_x).float(), torch.tensor(evecs_y).float(), \
               torch.tensor(feat_x).float(), torch.tensor(feat_y).float(), \
               torch.tensor(keypoints_x).int(), torch.tensor(keypoints_y).int(), \
               torch.tensor(dist_x).float(), torch.tensor(dist_y).float()

    def __len__(self):
        return len(self.pcds) // 2

# DATA DIR STRUCTURE
# --FMNet.pytorch
#   --data
#     --keypointnet
#       --train
#         --03001627
#           --1.mat
#           --2.mat
#           --3.mat
#         --02691156
#           --1.mat
#           --2.mat
#           --3.mat
#     --faust
#       --train
#         --1.mat
#       --test
#         --1.mat
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_POINTS = 2048
NUM_EVECS = 150


class FMNet(nn.Module):
    def __init__(self):
        super(FMNet, self).__init__()
        self.feat_refiner = FeatRefineLayer(in_channels=352)
        self.corres = CorresLayer()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y):
        """
        :param feat_x: B * 2048 * 352, handcrafted point-wise feature of x
        :param feat_y: B * 2048 * 352, handcrafted point-wise feature of y
        :param evecs_x: B * 2048 * 150, Laplace basis of x, each column is a basis vector
        :param evecs_y: B * 2048 * 150, Laplace basis of y, each column is a basis vector
        :return: Q (2048*2048), the soft point-wise correspondence, and C (150*150), the functional map
        """
        feat_x, feat_y = self.feat_refiner(feat_x), self.feat_refiner(feat_y)
        Q, C = self.corres(feat_x, feat_y, evecs_x, evecs_y)
        return Q, C


class FeatRefineLayer(nn.Module):
    def __init__(self, in_channels=352, out_channels_list=None):
        super(FeatRefineLayer, self).__init__()
        self.in_channels = in_channels
        if out_channels_list is None:
            out_channels_list = [352, 352, 352, 352, 352, 352, 352]
        self.out_channels_list = out_channels_list
        self.res_layers = nn.ModuleList()
        for out_channels in self.out_channels_list:
            self.res_layers.append(ResLayer(in_channels, out_channels, out_channels))
            in_channels = out_channels

    def forward(self, x):
        """
        :param x: B * 2048 * 352, handcrafted point-wise feature of x
        :return: B * 2048 * 352, refined point-wise feature of x
        """
        for res_layer in self.res_layers:
            x = res_layer(x)
        return x


class ResLayer(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels):
        super(ResLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fc1 = nn.Linear(in_channels, mid_channels)
        self.bn1 = nn.BatchNorm1d(num_features=mid_channels, eps=1e-3, momentum=1e-3)
        self.fc2 = nn.Linear(mid_channels, out_channels)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels, eps=1e-3, momentum=1e-3)
        self.fc3 = None
        if in_channels != out_channels:
            self.fc3 = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """
        :param x: B * 2048 * 352, point-wise feature of x
        :return: B * 2048 * 352, refined point-wise feature of x
        """
        # BatchNorm1d expects B * C * N, hence the transposes around each linear layer
        x_res = F.relu(self.bn1(self.fc1(x).transpose(1, 2)).transpose(1, 2))
        x_res = self.bn2(self.fc2(x_res).transpose(1, 2)).transpose(1, 2)
        if self.in_channels != self.out_channels:
            x = self.fc3(x)
        x_res += x
        return x_res
class CorresLayer(nn.Module):
    def __init__(self):
        super(CorresLayer, self).__init__()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y):
        """
        :param feat_x: B * 2048 * 352, refined point-wise feature of x
        :param feat_y: B * 2048 * 352, refined point-wise feature of y
        :param evecs_x: B * 2048 * 150, Laplace basis of x, each column is a basis vector
        :param evecs_y: B * 2048 * 150, Laplace basis of y, each column is a basis vector
        :return: Q and C
        """
        # Solve the least-squares problem C*A = B, i.e. A.T*C.T = B.T,
        # where A and B hold the spectral coefficients of the two feature sets.
        batch_size = feat_x.size(0)
        A = torch.bmm(evecs_x.transpose(2, 1), feat_x)
        B = torch.bmm(evecs_y.transpose(2, 1), feat_y)
        A, B = A.transpose(2, 1), B.transpose(2, 1)
        C_list = []
        for i in range(batch_size):
            C_i = torch.inverse(A[i].transpose(1, 0) @ A[i]) @ A[i].transpose(1, 0) @ B[i]  # C = (A.T*A)^{-1}*A.T*B
            C_list.append(C_i[:NUM_EVECS, :])
        C = torch.stack(C_list, dim=0).transpose(2, 1)
        # functional map -> soft point-to-point map
        P = torch.abs(torch.bmm(torch.bmm(evecs_y, C), evecs_x.transpose(2, 1)))
        Q = F.normalize(P, 2, 1) ** 2  # each column of Q sums to 1
        return Q, C


if __name__ == '__main__':
    feat_x = torch.rand(8, 2048, 352)
    feat_y = torch.rand(8, 2048, 352)
    evecs_x = torch.rand(8, 2048, 150)
    evecs_y = torch.rand(8, 2048, 150)
    net = FMNet()
    Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
    print(torch.sum(Q, 1))  # should print (approximately) all ones
    print(C.shape)
--------------------------------------------------------------------------------
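The per-sample normal-equation solve in CorresLayer can be numerically fragile when A.T*A is close to singular. A hedged alternative sketch (not part of the repo; assumes PyTorch >= 1.9) that solves the same least-squares problem with one batched call:

~~~
# Hypothetical replacement for the Python loop in CorresLayer.forward.
# A, B: B x 352 x 150 after the transposes above; solves A @ X = B per batch element.
C = torch.linalg.lstsq(A, B).solution  # B x 150 x 150
C = C[:, :NUM_EVECS, :].transpose(2, 1)
~~~

On CUDA this dispatches to the 'gels' driver, which assumes A has full column rank — generically true here, since 352 spectral coefficient vectors are projected onto a 150-dimensional basis.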