├── data
│   └── __init__.py
├── checkpoints
│   └── __init__.py
├── results
│   └── __init__.py
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── misc.xml
│   ├── deployment.xml
│   └── FMNet.pytorch.iml
├── imgs
│   ├── kpnet_1.png
│   ├── kpnet_2.png
│   ├── ScreenCapture_2020-02-17-13-23-52.png
│   ├── ScreenCapture_2020-02-17-13-24-07.png
│   └── ScreenCapture_2020-02-17-13-25-17.png
├── utils
│   ├── shot
│   │   ├── CMakeLists.txt
│   │   └── shot.cpp
│   └── loss.py
├── visualize.py
├── README.md
├── train.py
├── test.py
├── preprocess.py
├── dataloader.py
└── model.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/checkpoints/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/results/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | debug*
2 | results/faust/*
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/imgs/kpnet_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/kpnet_1.png
--------------------------------------------------------------------------------
/imgs/kpnet_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/kpnet_2.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-23-52.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-23-52.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-24-07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-24-07.png
--------------------------------------------------------------------------------
/imgs/ScreenCapture_2020-02-17-13-25-17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlankCheng/FMNet.pytorch-DeepFunctionalMap/HEAD/imgs/ScreenCapture_2020-02-17-13-25-17.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/utils/shot/CMakeLists.txt:
--------------------------------------------------------------------------------
# Builds the `shot` Python extension (pybind11 + PCL) used by preprocess.py.
# 3.5 is the oldest version accepted by current CMake releases and by pybind11.
cmake_minimum_required(VERSION 3.5)
project(shot)
set(CMAKE_CXX_STANDARD 11)

find_package(PythonInterp 3.6 REQUIRED)
find_package(PythonLibs 3.6 REQUIRED)
find_package(pybind11 REQUIRED)
find_package(PCL 1.8 REQUIRED)

include_directories(${PCL_INCLUDE_DIRS})
link_directories(${PCL_LIBRARY_DIRS})
if(WIN32)
  # Boost/FLANN bundled with a local "PCL 1.8.1" install on drive D:.
  # Machine-specific: adjust these paths to your own PCL installation.
  link_directories("D:/PCL 1.8.1/3rdParty/Boost/lib" "D:/PCL 1.8.1/3rdParty/FLANN/lib")
endif()
add_definitions(${PCL_DEFINITIONS})

pybind11_add_module(shot shot.cpp)
target_link_libraries(shot PRIVATE ${PCL_LIBRARIES})
--------------------------------------------------------------------------------
/.idea/FMNet.pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class GeodesicLoss(nn.Module):
    """
    Geodesic loss.

    Measures how well a soft correspondence ``Q`` preserves geodesic distances:
    the mean squared difference between ``dist_x`` and ``Q^T @ dist_y @ Q``
    (the geodesic distances of y pulled back onto x through ``Q``).
    """

    def __init__(self):
        super(GeodesicLoss, self).__init__()

    def forward(self, Q, dist_x, dist_y):
        """
        :param Q: B * N_y * N_x soft correspondence matrix (B * 2048 * 2048 in this project)
        :param dist_x: B * N_x * N_x, geodesic distances of x
        :param dist_y: B * N_y * N_y, geodesic distances of y
        :return: scalar tensor, mean squared error over all batch elements and entries
        """
        dist_y_on_x = torch.bmm(Q.transpose(2, 1), torch.bmm(dist_y, Q))
        # `nn.MSELoss(reduce=True)` relied on a deprecated argument; reduction='mean'
        # is the equivalent (default) behaviour.
        return F.mse_loss(dist_x, dist_y_on_x, reduction='mean')
23 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pptk
3 | import open3d as o3d
4 | import scipy.io as sio
5 | import os
6 |
result_dir = './results/faust/'
if __name__ == '__main__':
    # Open3D >= 0.8 moved these classes into sub-modules and dropped the flat
    # aliases; fall back to the flat names for older releases.
    o3d_geometry = getattr(o3d, 'geometry', o3d)
    o3d_utility = getattr(o3d, 'utility', o3d)
    o3d_visualization = getattr(o3d, 'visualization', o3d)
    for root, _, files in os.walk(result_dir):
        for file in sorted(files):
            # Only the pair*.mat result files can be loaded with loadmat().
            if not file.endswith('.mat'):
                continue
            print(file)
            file_path = os.path.join(root, file)
            m = sio.loadmat(file_path)
            x = m['x'].astype(np.float64)
            y = m['y'].astype(np.float64)
            Q = m['Q']
            # Q is expected to be N_y * N_x (batch dim squeezed away, i.e. results
            # saved with batch_size=1): for every point of x take the y point with
            # the highest correspondence score.
            match = np.argmax(Q, axis=0)
            # Colour y by its coordinates and transfer the colours to x through the
            # match. preprocess.py scales shapes into the unit ball, so (y + 1) / 2
            # lies in [0, 1].
            colors = (y + 1) / 2
            x_o3d, y_o3d = o3d_geometry.PointCloud(), o3d_geometry.PointCloud()
            # Shift x by +2 so the two clouds are drawn side by side.
            x_o3d.points = o3d_utility.Vector3dVector(x + 2)
            y_o3d.points = o3d_utility.Vector3dVector(y)
            x_o3d.colors = o3d_utility.Vector3dVector(colors[match])
            y_o3d.colors = o3d_utility.Vector3dVector(colors)
            o3d_visualization.draw_geometries([x_o3d, y_o3d])

            # Alternative viewer (pptk):
            # v = pptk.viewer(np.concatenate([y, x + 2.], axis=0), np.concatenate([colors, colors[match]], axis=0))
            # v.set(point_size=0.02)
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FMNet.pytorch
2 | A PyTorch implementation of Deep Functional Maps (FMNet).
3 |
4 | ## Introduction
5 | This is a PyTorch implementation of [Deep Functional Maps](https://arxiv.org/abs/1704.08686). Ground-truth FAUST correspondence labels are not used for training. For efficiency, 2048 of the 6890 vertices of each original mesh are randomly sampled, so the predicted correspondences are not guaranteed to be bijective.
6 |
7 | **Update:** The KeyPointNet visualization pairs are post-processed with PMF to make them bijective; the FAUST pairs are not.
8 |
9 | ## Usage
10 | Build the SHOT descriptor extension (requires PCL and pybind11):
11 | ~~~
12 | cd utils/shot
13 | cmake .
14 | make
15 | ~~~
16 | Compute the Laplacian eigenvectors, geodesic distance matrices and SHOT descriptors of the training meshes, and save them in .mat format:
17 | ~~~
18 | python preprocess.py
19 | ~~~
20 | Train:
21 | ~~~
22 | python train.py --dataset=faust
23 | ~~~
24 | Test (for now the training data is reused for testing, for visualization purposes):
25 | ~~~
26 | python test.py --dataset=faust --model_name=epoch300.pth
27 | ~~~
28 | Visualize correspondence:
29 | ~~~
30 | python visualize.py
31 | ~~~
32 |
33 | ## Visualization
34 | 
35 | 
36 | 
37 | 
38 |
--------------------------------------------------------------------------------
/utils/shot/shot.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 |
8 | namespace py = pybind11;
9 | using namespace pybind11::literals;
10 |
11 | py::array_t compute(py::array_t pc, double normal_r, double shot_r)
12 | {
13 | // Object for storing the point cloud.
14 | pcl::PointCloud::Ptr cloud(new pcl::PointCloud);
15 | cloud->points.resize(pc.shape(0));
16 | float *pc_ptr = (float*)pc.request().ptr;
17 | for (int i = 0; i < pc.shape(0); ++i)
18 | {
19 | std::copy(pc_ptr, pc_ptr + 3, &cloud->points[i].data[0]);
20 | // std::cout << cloud->points[i] << std::endl;
21 | pc_ptr += 3;
22 | }
23 |
24 | // Object for storing the normals.
25 | pcl::PointCloud::Ptr normals(new pcl::PointCloud);
26 | // Object for storing the SHOT descriptors for each point.
27 | pcl::PointCloud::Ptr descriptors(new pcl::PointCloud());
28 |
29 | // Note: you would usually perform downsampling now. It has been omitted here
30 | // for simplicity, but be aware that computation can take a long time.
31 |
32 | // Estimate the normals.
33 | pcl::NormalEstimation normalEstimation;
34 | normalEstimation.setInputCloud(cloud);
35 | // normalEstimation.setRadiusSearch(normal_r);
36 | normalEstimation.setKSearch(40);
37 | pcl::search::KdTree::Ptr kdtree(new pcl::search::KdTree);
38 | normalEstimation.setSearchMethod(kdtree);
39 | normalEstimation.compute(*normals);
40 |
41 | // SHOT estimation object.
42 | pcl::SHOTEstimation shot;
43 | shot.setInputCloud(cloud);
44 | shot.setInputNormals(normals);
45 | // The radius that defines which of the keypoint's neighbors are described.
46 | // If too large, there may be clutter, and if too small, not enough points may be found.
47 | shot.setRadiusSearch(shot_r);
48 | // shot.setKSearch(40);
49 | shot.compute(*descriptors);
50 |
51 | auto result = py::array_t(descriptors->points.size() * 352);
52 | auto buf = result.request();
53 | float *ptr = (float*)buf.ptr;
54 |
55 | for (int i = 0; i < descriptors->points.size(); ++i)
56 | {
57 | std::copy(&descriptors->points[i].descriptor[0], &descriptors->points[i].descriptor[352], &ptr[i * 352]);
58 | }
59 | return result;
60 | }
61 |
62 |
63 | PYBIND11_MODULE(shot, m) {
64 | m.def("compute", &compute, py::arg("pc"), py::arg("normal_r")=0.1, py::arg("shot_r")=0.17);
65 | }
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import os
4 | import torch.nn.parallel
5 | import torch.optim as optim
6 | import torch.utils.data as DT
7 |
8 | from dataloader import FAUSTDataset, KeyPointDataset
9 | from model import FMNet
10 | from utils.loss import GeodesicLoss
11 |
12 |
def arg_parse():
    """
    Parse the command-line options for training.

    ``--dataset`` is required: the script builds data/checkpoint paths from it,
    and a missing value would otherwise crash later inside ``os.path.join``.

    :return: argparse.Namespace with the parsed options
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['faust', 'keypointnet'])
    parser.add_argument('--category', type=str, default='02691156')  # only for keypointnet
    parser.add_argument('--root_path', type=str, default='./')
    parser.add_argument('--phase', type=str, default='train')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=str, default='1e-3')  # parsed with float() by the caller
    parser.add_argument('--max_epochs', type=int, default=300)
    parser.add_argument('--seed', type=int, default=-1)  # -1: do not seed
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--num_workers', type=int, default=4)
    return parser.parse_args()
27 |
if __name__ == '__main__':
    # init
    args = arg_parse()
    if args.seed != -1:
        torch.manual_seed(args.seed)

    # load train data
    if args.dataset == 'faust':
        train_dataset = FAUSTDataset(args=args)
    elif args.dataset == 'keypointnet':
        train_dataset = KeyPointDataset(args=args)
    else:
        raise NotImplementedError("Dataset not implemented now.")
    train_dataloader = DT.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=False,
    )

    # Checkpoints go to <root>/checkpoints/<dataset>/; torch.save() does not create
    # missing directories, so create it up front.
    save_dir = os.path.join(args.root_path, 'checkpoints', args.dataset)
    os.makedirs(save_dir, exist_ok=True)

    # network config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = FMNet().to(device).float()
    # float() parses strings such as '1e-3' without executing arbitrary code (eval).
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr), betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    criterion = GeodesicLoss()

    # train
    for epoch in range(1, args.max_epochs + 1):
        net.train()
        for batch_idx, data in enumerate(tqdm(train_dataloader), start=1):
            # FAUST samples have 8 fields; KeyPointNet samples additionally carry
            # keypoints_x/keypoints_y before the two distance matrices (10 fields).
            # The point clouds (fields 0-1) are not needed for training.
            evecs_x, evecs_y, feat_x, feat_y = (t.to(device) for t in data[2:6])
            dist_x, dist_y = (t.to(device) for t in data[-2:])
            optimizer.zero_grad()
            Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
            loss = criterion(Q, dist_x, dist_y)
            loss.backward()
            optimizer.step()
            print("#epoch:{}, #batch:{}, loss:{}".format(epoch, batch_idx, loss.item()))
        # Since PyTorch 1.1 the LR scheduler must be stepped after optimizer.step().
        scheduler.step()

        if epoch % 20 == 0:
            torch.save(net.state_dict(), os.path.join(save_dir, 'epoch{}.pth'.format(epoch)))
76 |
77 |
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import os
4 | import scipy.io as sio
5 | import torch
6 | import torch.nn.parallel
7 | import torch.utils.data as DT
8 |
9 | from dataloader import FAUSTDataset, KeyPointDataset
10 | from model import FMNet
11 | from utils.loss import GeodesicLoss
12 |
13 |
def arg_parse():
    """
    Parse the command-line options for testing.

    ``--dataset`` and ``--model_name`` are required: both are joined into the
    checkpoint path, and a missing value would crash inside ``os.path.join``.

    :return: argparse.Namespace with the parsed options
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['faust', 'keypointnet'])
    parser.add_argument('--category', type=str, default='02691156')  # only for keypointnet
    parser.add_argument('--root_path', type=str, default='./')
    parser.add_argument('--model_name', type=str, required=True)  # e.g. epoch300.pth
    parser.add_argument('--phase', type=str, default='train')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--seed', type=int, default=-1)  # -1: do not seed
    parser.add_argument('--gpu', type=str, default='-1')
    parser.add_argument('--num_workers', type=int, default=4)
    return parser.parse_args()
27 |
if __name__ == '__main__':
    # init
    args = arg_parse()
    if args.seed != -1:
        torch.manual_seed(args.seed)

    # Locate the checkpoint first so a bad path fails before the data is loaded.
    if args.model_name is None:
        print("No model specified, use --model_name")
        exit(-1)
    model_path = os.path.join(args.root_path, 'checkpoints', args.dataset, args.model_name)
    if not os.path.exists(model_path):
        print("No trained model in {}".format(model_path))
        exit(-1)
    print(model_path)

    # load test data
    if args.dataset == 'faust':
        test_dataset = FAUSTDataset(args=args)
    elif args.dataset == 'keypointnet':
        test_dataset = KeyPointDataset(args=args)
    else:
        raise NotImplementedError("Dataset not implemented now.")
    test_dataloader = DT.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
    )

    # network config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = FMNet()
    # map_location lets GPU-trained checkpoints load on CPU-only machines.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device).float()
    net.eval()
    criterion = GeodesicLoss()

    # savemat() does not create missing directories.
    result_dir = os.path.join(args.root_path, 'results', args.dataset)
    os.makedirs(result_dir, exist_ok=True)

    # test
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader), start=1):
            # FAUST samples have 8 fields; KeyPointNet samples additionally carry
            # keypoints_x/keypoints_y before the two distance matrices (10 fields).
            x, y = data[0], data[1]
            evecs_x, evecs_y, feat_x, feat_y = (t.to(device) for t in data[2:6])
            dist_x, dist_y = (t.to(device) for t in data[-2:])
            Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
            loss = criterion(Q, dist_x, dist_y)
            print("loss:{}".format(loss.item()))
            # note: the pair number is not related to the .ply file number
            sio.savemat(os.path.join(result_dir, 'pair{}.mat'.format(i)),
                        {'x': x.numpy().squeeze(),
                         'y': y.numpy().squeeze(),
                         'Q': Q.cpu().numpy().squeeze(),
                         })
77 |
78 |
79 |
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocessing codes are provided by You Yang: qq456cvb@github
3 | """
4 | import scipy.io as sio
5 | from sklearn import neighbors
6 | from sklearn.utils.graph import graph_shortest_path
7 | import numpy as np
8 | from scipy.linalg import eigh
9 | from timeit import default_timer as timer
10 | import os
11 | import utils.shot.shot as shot
12 | import plyfile
13 |
NUM_EIGENS = 150  # number of Laplacian eigenvectors kept per shape
NUM_POINTS = 6890  # vertices per FAUST mesh
NUM_POINTS_SAMPLE = 2048  # points randomly sampled from each mesh
NORMAL_R = 0.1  # passed to shot.compute (the extension currently ignores it)
SHOT_R = 0.1  # SHOT support radius passed to shot.compute
KNN = 20  # neighbours per point in the kNN graph
DESCRIPTOR = 'shot'
# os.walk() does not expand '~' and silently yields nothing for such a path,
# so expand the home directory here.
MODELS_DIR = os.path.expanduser('~/MPI-FAUST/training/registrations')
SAVE_DIR = './data/faust/train'
23 |
def normalize_adj(adj):
    """
    Symmetrically normalise an adjacency matrix: ``D^{-1/2} @ adj @ D^{-1/2}``.

    ``D`` is the diagonal matrix of row sums of ``adj``. Rows whose sum is zero
    get a scaling factor of 0 instead of inf.

    :param adj: n * n array of edge weights
    :return: n * n normalised adjacency as a dense ndarray
    """
    adj = np.asarray(adj)
    rowsum = np.sum(adj, axis=1)
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(rowsum, -0.5).ravel()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    # Scaling rows and columns by broadcasting equals multiplying by the diagonal
    # matrix on both sides, but costs O(n^2) instead of two dense O(n^3) matmuls.
    return d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]


def normalized_laplacian(adj):
    """
    Symmetric normalised graph Laplacian ``I - D^{-1/2} @ adj @ D^{-1/2}``.

    :param adj: n * n array of edge weights
    :return: n * n dense Laplacian
    """
    return np.eye(adj.shape[0]) - normalize_adj(adj)
36 |
37 |
if __name__ == '__main__':
    # savemat() does not create missing directories.
    os.makedirs(SAVE_DIR, exist_ok=True)
    cnt = 0
    for root, dirs, files in os.walk(MODELS_DIR):
        dirs.sort()  # deterministic traversal order
        print(root, dirs, files)
        for file in sorted(files):
            # Only meshes are processed (the folder may also hold e.g. .png renderings).
            if not file.endswith('.ply'):
                continue
            cnt += 1
            file_path = os.path.join(root, file)
            ply = plyfile.PlyData.read(file_path)
            vertices = ply.elements[0].data
            x = np.stack((vertices['x'], vertices['y'], vertices['z']), axis=-1)
            np.random.seed(cnt)  # a different but reproducible subset for every mesh
            # Sample WITHOUT replacement: duplicated points would give zero-length
            # kNN edges and repeated rows in the descriptors and eigenvectors.
            random_indices = np.random.choice(x.shape[0], NUM_POINTS_SAMPLE, replace=False)
            x = x[random_indices]

            # Centre and scale into the unit ball.
            x = x - np.mean(x, axis=0, keepdims=True)
            x = x / np.max(np.linalg.norm(x, axis=1))

            t = timer()
            graph_x_csr = neighbors.kneighbors_graph(x, KNN, mode='distance', include_self=False)
            print('knn time:', timer() - t)

            t = timer()
            # Gaussian (heat-kernel) weights on the kNN edges only. Exponentiating the
            # dense matrix instead would turn every non-edge (stored as 0) into the
            # maximal weight exp(0) = 1. Bandwidth: mean distance to the KNN-th neighbour.
            sigma = np.mean(graph_x_csr.max(axis=1).toarray())
            affinity = graph_x_csr.copy()
            affinity.data = np.exp(-affinity.data ** 2 / sigma ** 2)
            # Symmetrise by keeping the larger weight of (i, j) and (j, i).
            graph_x = affinity.maximum(affinity.T).toarray()

            laplacian_x = normalized_laplacian(graph_x)
            print('laplacian time:', timer() - t)

            t = timer()
            # Keep the NUM_EIGENS eigenvectors with the SMALLEST eigenvalues (the
            # smoothest basis functions), in ascending order. `subset_by_index`
            # replaces the deprecated `eigvals` keyword of scipy.linalg.eigh.
            _, eigen_x = eigh(laplacian_x, subset_by_index=[0, NUM_EIGENS - 1])
            print('eigen function time:', timer() - t)

            t = timer()
            # directed=False: shortest paths may follow kNN edges in either direction.
            geodesic_x = graph_shortest_path(graph_x_csr, directed=False)
            print('geodesic matrix time:', timer() - t)

            t = timer()
            if DESCRIPTOR == 'shot':
                feature_x = shot.compute(x, NORMAL_R, SHOT_R).reshape(-1, 352)
            else:
                raise NotImplementedError('Descriptor {} not implemented.'.format(DESCRIPTOR))
            print('shot time:', timer() - t)
            # The descriptor may contain NaN entries; replace them with 0.
            feature_x[np.isnan(feature_x)] = 0

            save_path = os.path.join(SAVE_DIR, '{}.mat'.format(os.path.splitext(file)[0]))
            print(save_path)
            sio.savemat(save_path,
                        {'pcd': x,
                         'evecs': eigen_x,
                         'feat': feature_x,
                         'dist': geodesic_x,
                         })
95 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as DT
4 | import scipy.io as sio
5 |
class FAUSTDataset(DT.Dataset):
    """
    Pairs of preprocessed FAUST shapes stored as .mat files.

    Files are read (recursively) from ``<root_path>/data/<dataset>/<phase>`` and must
    contain the keys ``pcd``, ``evecs``, ``feat`` and ``dist``. They are loaded in
    sorted order and consecutive files form one pair: (0, 1), (2, 3), ...; with an
    odd number of files the last one is unused.
    """

    def __init__(self, args):
        super(FAUSTDataset, self).__init__()
        self.data_path = os.path.join(args.root_path, 'data', args.dataset, args.phase)
        self.pcds = []
        self.evecs = []
        self.feats = []
        self.dists = []  # not input into network, only for loss
        for root, dirs, files in os.walk(self.data_path):
            dirs.sort()  # deterministic traversal order -> deterministic pairing
            for file in sorted(files):
                if not file.endswith('.mat'):
                    continue  # skip stray non-data files
                f = sio.loadmat(os.path.join(root, file))
                self.pcds.append(f['pcd'])
                self.evecs.append(f['evecs'])
                self.feats.append(f['feat'])
                self.dists.append(f['dist'])

    def __getitem__(self, item):
        """
        :param item: pair index in [0, len(self)); samples 2*item and 2*item+1 are paired
        :return: (pcd_x, pcd_y, evecs_x, evecs_y, feat_x, feat_y, dist_x, dist_y), float tensors
        """
        # (2*item, 2*item+1) uses every sample exactly once and never pairs a shape
        # with itself.
        ix, iy = 2 * item, 2 * item + 1
        x, y = self.pcds[ix], self.pcds[iy]
        evecs_x, evecs_y = self.evecs[ix], self.evecs[iy]  # not rotation invariant
        feat_x, feat_y = self.feats[ix], self.feats[iy]  # not rotation invariant
        dist_x, dist_y = self.dists[ix], self.dists[iy]
        return torch.tensor(x).float(), torch.tensor(y).float(), \
            torch.tensor(evecs_x).float(), torch.tensor(evecs_y).float(), \
            torch.tensor(feat_x).float(), torch.tensor(feat_y).float(), \
            torch.tensor(dist_x).float(), torch.tensor(dist_y).float()

    def __len__(self):
        """Number of pairs (two samples per pair)."""
        return len(self.pcds) // 2
39 |
40 |
class KeyPointDataset(DT.Dataset):
    """
    Pairs of preprocessed KeyPointNet shapes stored as .mat files.

    Files are read (recursively) from
    ``<root_path>/data/<dataset>/<phase>/<category>`` and must contain the keys
    ``pcd``, ``evecs``, ``feat``, ``keypoints`` and ``dist``. They are loaded in
    sorted order and consecutive files form one pair: (0, 1), (2, 3), ...; with an
    odd number of files the last one is unused.
    """

    def __init__(self, args):
        super(KeyPointDataset, self).__init__()
        self.data_path = os.path.join(args.root_path, 'data', args.dataset, args.phase, args.category)
        self.pcds = []
        self.evecs = []
        self.feats = []
        self.keypoints = []  # not input into network, only for loss
        self.dists = []  # not input into network, only for loss
        for root, dirs, files in os.walk(self.data_path):
            dirs.sort()  # deterministic traversal order -> deterministic pairing
            for file in sorted(files):
                if not file.endswith('.mat'):
                    continue  # skip stray non-data files
                f = sio.loadmat(os.path.join(root, file))
                self.pcds.append(f['pcd'])
                self.evecs.append(f['evecs'])
                self.feats.append(f['feat'])
                self.keypoints.append(f['keypoints'])
                self.dists.append(f['dist'])

    def __getitem__(self, item):
        """
        :param item: pair index in [0, len(self)); samples 2*item and 2*item+1 are paired
        :return: (pcd_x, pcd_y, evecs_x, evecs_y, feat_x, feat_y, keypoints_x, keypoints_y,
                  dist_x, dist_y); keypoints are int tensors, everything else float tensors
        """
        # (2*item, 2*item+1) uses every sample exactly once and never pairs a shape
        # with itself.
        ix, iy = 2 * item, 2 * item + 1
        x, y = self.pcds[ix], self.pcds[iy]
        evecs_x, evecs_y = self.evecs[ix], self.evecs[iy]  # related with rotation and translation
        feat_x, feat_y = self.feats[ix], self.feats[iy]  # related with rotation and translation
        keypoints_x, keypoints_y = self.keypoints[ix], self.keypoints[iy]
        dist_x, dist_y = self.dists[ix], self.dists[iy]
        return torch.tensor(x).float(), torch.tensor(y).float(), \
            torch.tensor(evecs_x).float(), torch.tensor(evecs_y).float(), \
            torch.tensor(feat_x).float(), torch.tensor(feat_y).float(), \
            torch.tensor(keypoints_x).int(), torch.tensor(keypoints_y).int(), \
            torch.tensor(dist_x).float(), torch.tensor(dist_y).float()

    def __len__(self):
        """Number of pairs (two samples per pair)."""
        return len(self.pcds) // 2
78 |
79 | # DATA DIR STRUCTURE (directory names must match the lowercase --dataset / --category values)
80 | # --FMNet.pytorch
81 | #   --data
82 | #     --keypointnet
83 | #       --train
84 | #         --03001627
85 | #           --1.mat
86 | #           --2.mat
87 | #           --3.mat
88 | #         --02691159
89 | #           --1.mat
90 | #           --2.mat
91 | #           --3.mat
92 | #     --faust
93 | #       --train
94 | #         --1.mat
95 | #       --test
96 | #         --1.mat
97 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.data
4 | import torch.nn.functional as F
5 |
6 | NUM_POINTS = 2048
7 | NUM_EVECS = 150
8 |
class FMNet(nn.Module):
    """
    Deep functional map network.

    A shared residual feature refiner is applied to both shapes, and the refined
    features are turned into a functional map and a soft point-wise map by
    ``CorresLayer``.
    """

    def __init__(self):
        super(FMNet, self).__init__()
        self.feat_refiner = FeatRefineLayer(in_channels=352)
        self.corres = CorresLayer()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y):
        """
        :param feat_x: B * 2048 * 352, handcrafted point-wise feature of x
        :param feat_y: B * 2048 * 352, handcrafted point-wise feature of y
        :param evecs_x: B * 2048 * 150, Laplace basis of x, each column is a basis vector
        :param evecs_y: B * 2048 * 150, Laplace basis of y, each column is a basis vector
        :return: tuple (Q, C): Q (2048*2048) soft point-wise correspondence,
                 C (150*150) functional map
        """
        # Both shapes go through the same refiner (shared weights), x first.
        refined_x = self.feat_refiner(feat_x)
        refined_y = self.feat_refiner(feat_y)
        return self.corres(refined_x, refined_y, evecs_x, evecs_y)
27 |
28 |
class FeatRefineLayer(nn.Module):
    """
    Stack of residual layers that refines handcrafted point-wise descriptors.

    :param in_channels: width of the input descriptor
    :param out_channels_list: output width of each residual layer
                              (default: seven layers of 352 channels)
    """

    def __init__(self, in_channels=352, out_channels_list=None):
        super(FeatRefineLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels_list = [352] * 7 if out_channels_list is None else out_channels_list
        # Layer k consumes the output width of layer k-1 (the input width for k = 0).
        widths_in = [in_channels] + list(self.out_channels_list[:-1])
        self.res_layers = nn.ModuleList(
            ResLayer(c_in, c_out, c_out)
            for c_in, c_out in zip(widths_in, self.out_channels_list)
        )

    def forward(self, x):
        """
        :param x: B * 2048 * 352, handcrafted point-wise feature of x
        :return: B * 2048 * 352, refined point-wise feature of x
        """
        out = x
        for layer in self.res_layers:
            out = layer(out)
        return out
50 |
51 |
class ResLayer(nn.Module):
    """
    Residual block applied to point-wise features of shape B * N * C.

    Computes ``BN(fc2(relu(BN(fc1(x))))) + shortcut(x)``, where the shortcut is
    the identity when the channel counts match and a linear projection ``fc3``
    otherwise. (Subclasses ``nn.Module``; the previous ``nn.ModuleList`` base was
    never used as a list. Parameter names, and hence checkpoints, are unchanged.)
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super(ResLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fc1 = nn.Linear(in_channels, mid_channels)
        self.bn1 = nn.BatchNorm1d(num_features=mid_channels, eps=1e-3, momentum=1e-3)
        self.fc2 = nn.Linear(mid_channels, out_channels)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels, eps=1e-3, momentum=1e-3)
        self.fc3 = None
        if in_channels != out_channels:
            self.fc3 = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """
        :param x: B * N * in_channels, refining point-wise feature of x
        :return: B * N * out_channels, refining point-wise feature of x
        """
        # BatchNorm1d normalises dim 1, so channels are moved there and back.
        x_res = F.relu(self.bn1(self.fc1(x).transpose(1, 2)).transpose(1, 2))
        x_res = self.bn2(self.fc2(x_res).transpose(1, 2)).transpose(1, 2)
        shortcut = x if self.fc3 is None else self.fc3(x)
        # Out-of-place add: no in-place modification of an autograd view.
        return x_res + shortcut
77 |
78 |
79 |
class CorresLayer(nn.Module):
    """
    Functional-map correspondence layer (no learnable parameters).

    Projects refined features onto the Laplacian bases, solves for the functional
    map C in the least-squares sense and converts it into a soft point-wise map Q.
    """

    def __init__(self):
        super(CorresLayer, self).__init__()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y):
        """
        :param feat_x: B * N * D, refined point-wise feature of x (2048 * 352 here)
        :param feat_y: B * N * D, refined point-wise feature of y
        :param evecs_x: B * N * K, Laplace basis of x, each column is a basis vector
        :param evecs_y: B * N * K, Laplace basis of y, each column is a basis vector
        :return: Q (B * N_y * N_x), non-negative with every column summing to 1,
                 and C (B * K * K), the least-squares solution of
                 C @ (evecs_x^T feat_x) ~= evecs_y^T feat_y
        """
        # Spectral coefficients of the features, transposed to B * D * K.
        A = torch.bmm(evecs_x.transpose(2, 1), feat_x).transpose(2, 1)
        B = torch.bmm(evecs_y.transpose(2, 1), feat_y).transpose(2, 1)
        # Normal equations X = (A^T A)^{-1} A^T B for all batch elements at once
        # (batched matmul/inverse) instead of a per-sample loop with torch.cat.
        A_t = A.transpose(2, 1)
        C = torch.inverse(A_t @ A) @ A_t @ B
        C = C.transpose(2, 1)
        # functional map -> point-to-point map
        P = torch.bmm(torch.bmm(evecs_y, C), evecs_x.transpose(2, 1)).abs()
        # Squared L2 normalisation over dim 1 makes every column of Q sum to 1.
        Q = F.normalize(P, 2, 1) ** 2
        return Q, C
110 |
if __name__ == '__main__':
    # Smoke test on random inputs: prints the column sums of Q (expected ~1)
    # and the shape of the functional map C.
    shapes = [(8, 2048, 352), (8, 2048, 352), (8, 2048, 150), (8, 2048, 150)]
    feat_x, feat_y, evecs_x, evecs_y = (torch.rand(*shape) for shape in shapes)
    net = FMNet()
    Q, C = net(feat_x, feat_y, evecs_x, evecs_y)
    print(torch.sum(Q, 1))
    print(C.shape)
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------