├── .DS_Store ├── README.md ├── doc └── flowchart1.png ├── inference.py ├── network.py ├── preprocessing_and_test_data ├── .DS_Store ├── prepare_test_data.m ├── preprocessing.m └── test_data │ ├── .DS_Store │ ├── raw_input_example5.h5 │ └── raw_input_example6.h5 ├── result ├── check_performance.m ├── plot_mesh.m ├── raw_input_example5.mat └── raw_input_example6.mat ├── trained_model ├── .DS_Store └── README.md └── utils_distance ├── .DS_Store └── nndistance ├── .gitignore ├── README.md ├── _ext ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-35.pyc │ └── __init__.cpython-37.pyc └── my_lib │ ├── __init__.py │ ├── __init__.pyc │ └── __pycache__ │ ├── __init__.cpython-35.pyc │ └── __init__.cpython-37.pyc ├── build.py ├── functions ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-37.pyc │ ├── nnd.cpython-35.pyc │ └── nnd.cpython-37.pyc ├── nnd.py └── nnd.pyc ├── modules ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-37.pyc │ ├── nnd.cpython-35.pyc │ └── nnd.cpython-37.pyc ├── nnd.py └── nnd.pyc ├── src ├── my_lib.c ├── my_lib.h ├── my_lib_cuda.c ├── my_lib_cuda.h ├── nnd_cuda.cu └── nnd_cuda.h └── test.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 3D Face Modeling from Diverse Raw Scan Data 2 | 3 | ![](doc/flowchart1.png) 4 | 5 | ======= 6 | 7 | We propose a novel PointNet-based encoder-decoder framework that for the first time jointly learns an expressive face model from a diverse set of raw 3D scans and establishes dense correspondence among them. 
8 | 9 | ======= 10 | 11 | 3D Face Modeling from Diverse Raw Scan Data. ICCV 2019. [Paper](https://arxiv.org/pdf/1902.04943v2.pdf)
12 | [Feng Liu](http://www.face3d.org/), [Luan Tran](http://www.cse.msu.edu/~tranluan/), [Xiaoming Liu](http://cvlab.cse.msu.edu/pages/people.html)
13 | Department of Computer Science and Engineering, Michigan State University.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FaceModel(nn.Module):
    """PointNet-style encoder followed by two fully connected shape decoders.

    A shared per-point encoder (five 1x1 ``Conv1d`` + ``BatchNorm1d`` layers,
    no spatial transformer) is max-pooled over all points into a single
    1024-dimensional feature.  That feature is projected to two
    512-dimensional latent vectors (identity and expression); each latent is
    decoded to ``num_vertex * 3`` values and the two decoded shapes are added.

    Args:
        num_vertex: number of vertices produced by each decoder.
    """

    def __init__(self, num_vertex=29495):
        super(FaceModel, self).__init__()

        self.num_vertex = num_vertex

        ## encoder Point-Net without STN
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.bn2 = nn.BatchNorm1d(64)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.bn3 = nn.BatchNorm1d(64)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.bn4 = nn.BatchNorm1d(128)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        self.bn5 = nn.BatchNorm1d(1024)
        # No longer used by forward(); kept (it has no parameters) so that
        # whole-module pickles and code reading ``model.mp1`` keep working.
        self.mp1 = nn.MaxPool1d(num_vertex)

        ## identity latent vector
        self.neu_fc = nn.Linear(1024, 512)
        ## identity decoder
        self.neuDe_fc1 = nn.Linear(512, 1024)
        self.neuDe_fc2 = nn.Linear(1024, num_vertex * 3)

        ## expression latent vector
        self.exp_fc = nn.Linear(1024, 512)
        ## expression decoder
        self.expDe_fc1 = nn.Linear(512, 1024)
        self.expDe_fc2 = nn.Linear(1024, num_vertex * 3)

    def _decode(self, latent, fc1, fc2, batch_size):
        """Decode one latent vector into a ``(batch, num_vertex, 3)`` tensor."""
        out = fc2(F.relu(fc1(latent)))
        # The decoder emits the coordinates as 3 rows of num_vertex values.
        return out.view(batch_size, 3, self.num_vertex).transpose(1, 2).contiguous()

    def forward(self, x):
        """Estimate a shape from a batch of point sets.

        Args:
            x: tensor of shape ``(batch, num_points, 3)``.  ``num_points`` may
                differ from ``num_vertex`` because the encoder feature is a
                global max over the point axis.

        Returns:
            Tensor of shape ``(batch, num_vertex, 3)``: the sum of the
            identity and expression decoder outputs.
        """
        batch_size = x.size(0)
        # Conv1d expects channels first: (batch, 3, num_points).
        x = x.transpose(2, 1).contiguous()
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        # Global max over the point axis.  Equivalent to MaxPool1d(num_vertex)
        # when num_points == num_vertex, but does not fail for fewer points or
        # silently drop/misplace data for more points.
        x = torch.max(x, 2)[0]
        x = x.view(batch_size, 1024)

        neu_x = self.neu_fc(x)
        exp_x = self.exp_fc(x)

        neu_x = self._decode(neu_x, self.neuDe_fc1, self.neuDe_fc2, batch_size)
        exp_x = self._decode(exp_x, self.expDe_fc1, self.expDe_fc2, batch_size)

        return torch.add(neu_x, exp_x)


if __name__ == '__main__':
    # Smoke test on random input; use the GPU only when one is available.
    input_shape = Variable(torch.rand(5, 29495, 3))
    ShapeGen = FaceModel()
    if torch.cuda.is_available():
        input_shape = input_shape.cuda()
        ShapeGen = ShapeGen.cuda()
    esti_shape = ShapeGen(input_shape)
    print(esti_shape)
/preprocessing_and_test_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/preprocessing_and_test_data/.DS_Store -------------------------------------------------------------------------------- /preprocessing_and_test_data/prepare_test_data.m: -------------------------------------------------------------------------------- 1 | 2 | clear; 3 | %% load data 4 | load template_data.mat; %% mean_shape 5 | load raw_scan_data.mat; 6 | 7 | for i = 1:length(raw_data) 8 | raw_shape = raw_data(i).shape; 9 | raw_face = raw_data(i).face; 10 | raw_land = raw_data(i).land; % five points [right_eye;left_eye;nose;right_mouth;left_mouth]; 11 | 12 | %% preprocessing 13 | [input_shape] = preprocessing(mean_shape, mean_face, raw_shape, raw_land); 14 | h5_name = ['test_data/raw_input_example' num2str(i) '.h5']; 15 | h5create(h5_name,'/input_shape',size(input_shape), 'Datatype', 'single'); 16 | h5write(h5_name,'/input_shape', input_shape); 17 | end -------------------------------------------------------------------------------- /preprocessing_and_test_data/preprocessing.m: -------------------------------------------------------------------------------- 1 | 2 | function [input_shape] = preprocessing(mean_shape, mean_face, raw_shape, raw_land) 3 | 4 | % mean_shape nv*3 5 | % mean_face nf*3 6 | % raw_shape n*3 7 | % raw_land 5*3 8 | end -------------------------------------------------------------------------------- /preprocessing_and_test_data/test_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/preprocessing_and_test_data/test_data/.DS_Store -------------------------------------------------------------------------------- /preprocessing_and_test_data/test_data/raw_input_example5.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/preprocessing_and_test_data/test_data/raw_input_example5.h5 -------------------------------------------------------------------------------- /preprocessing_and_test_data/test_data/raw_input_example6.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/preprocessing_and_test_data/test_data/raw_input_example6.h5 -------------------------------------------------------------------------------- /result/check_performance.m: -------------------------------------------------------------------------------- 1 | 2 | clear; close all; 3 | load ../preprocessing_and_test_data/template_data.mat; 4 | load ../preprocessing_and_test_data/raw_scan_data.mat; 5 | %% 6 | load raw_input_example6.mat; 7 | raw_shape = raw_data(6).shape; raw_face = raw_data(6).face; 8 | 9 | figure, subplot(1,2,1);plot_mesh(raw_shape, raw_face); 10 | subplot(1,2,2); plot_mesh(esti_shape, mean_face); -------------------------------------------------------------------------------- /result/plot_mesh.m: -------------------------------------------------------------------------------- 1 | function h = plot_mesh(vertex,face,options) 2 | 3 | % plot_mesh - plot a 3D mesh. 4 | % 5 | % plot_mesh(vertex,face, options); 6 | % 7 | % 'options' is a structure that may contains: 8 | % - 'normal' : a (nvertx x 3) array specifying the normals at each vertex. 9 | % - 'edge_color' : a float specifying the color of the edges. 10 | % - 'face_color' : a float specifying the color of the faces. 11 | % - 'face_vertex_color' : a color per vertex or face. 12 | % - 'vertex' 13 | % 14 | % See also: mesh_previewer. 
15 | % 16 | % Copyright (c) 2004 Gabriel Peyré 17 | 18 | 19 | if nargin<2 20 | error('Not enough arguments.'); 21 | end 22 | 23 | options.null = 0; 24 | 25 | name = getoptions(options, 'name', ''); 26 | normal = getoptions(options, 'normal', []); 27 | face_color = getoptions(options, 'face_color', .7); 28 | edge_color = getoptions(options, 'edge_color', 0); 29 | normal_scaling = getoptions(options, 'normal_scaling', .8); 30 | sanity_check = getoptions(options, 'sanity_check', 1); 31 | 32 | 33 | % can flip to accept data in correct ordering 34 | vertex = vertex'; face = face'; 35 | 36 | if size(face,1)==4 37 | %%%% tet mesh %%%% 38 | 39 | % normal to the plane <=a 40 | w = getoptions(options, 'cutting_plane', [0.2 0 1]'); 41 | w = w(:)/sqrt(sum(w.^2)); 42 | t = sum(vertex.*repmat(w,[1 size(vertex,2)])); 43 | a = getoptions(options, 'cutting_offs', median(t(:)) ); 44 | b = getoptions(options, 'cutting_interactive', 0); 45 | 46 | while true; 47 | 48 | % in/out 49 | I = ( t<=a ); 50 | % trim 51 | e = sum(I(face)); 52 | J = find(e==4); 53 | facetrim = face(:,J); 54 | K = find(e==0); 55 | K = face(:,K); K = unique(K(:)); 56 | 57 | % convert to triangular mesh 58 | hold on; 59 | if not(isempty(facetrim)) 60 | face1 = tet2tri(facetrim, vertex, 1); 61 | options.method = 'fast'; 62 | face1 = perform_faces_reorientation(vertex,face1, options); 63 | h{1} = plot_mesh(vertex,face1); 64 | end 65 | view(3); camlight; 66 | shading faceted; 67 | h{2} = plot3(vertex(1,K), vertex(2,K), vertex(3,K), 'k.'); 68 | hold off; 69 | 70 | if b==0 71 | break; 72 | end 73 | 74 | [x,y,b] = ginput(1); 75 | 76 | if b==1 77 | a = a+.03; 78 | elseif b==3 79 | a = a-.03; 80 | else 81 | break; 82 | end 83 | end 84 | return; 85 | end 86 | 87 | 88 | vertex = vertex'; 89 | face = face'; 90 | 91 | if strcmp(name, 'bunny') || strcmp(name, 'pieta') 92 | % vertex = -vertex; 93 | end 94 | if strcmp(name, 'armadillo') 95 | vertex(:,3) = -vertex(:,3); 96 | end 97 | 98 | if sanity_check && (size(face,2)~=3 || 
(size(vertex,2)~=3 && size(vertex,2)~=2)) 99 | error('face or vertex does not have correct format.'); 100 | end 101 | 102 | if ~isfield(options, 'face_vertex_color') || isempty(options.face_vertex_color) 103 | options.face_vertex_color = zeros(size(vertex,1),1); 104 | end 105 | face_vertex_color = options.face_vertex_color; 106 | 107 | 108 | if isempty(face_vertex_color) 109 | h = patch('vertices',vertex,'faces',face,'facecolor',[face_color face_color face_color],'edgecolor',[edge_color edge_color edge_color]); 110 | else 111 | nverts = size(vertex,1); 112 | % vertex_color = rand(nverts,1); 113 | if size(face_vertex_color,1)==size(vertex,1) 114 | shading_type = 'interp'; 115 | else 116 | shading_type = 'flat'; 117 | end 118 | h = patch('vertices',vertex,'faces',face,'FaceVertexCData',face_vertex_color, 'FaceColor',shading_type); 119 | end 120 | colormap gray(256); 121 | lighting phong; 122 | % camlight infinite; 123 | camproj('perspective'); 124 | axis square; 125 | axis off; 126 | 127 | if ~isempty(normal) 128 | % plot the normals 129 | n = size(vertex,1); 130 | subsample_normal = getoptions(options, 'subsample_normal', min(4000/n,1) ); 131 | sel = randperm(n); sel = sel(1:floor(end*subsample_normal)); 132 | hold on; 133 | quiver3(vertex(sel,1),vertex(sel,2),vertex(sel,3),normal(1,sel),normal(2,sel),normal(3,sel),normal_scaling); 134 | hold off; 135 | end 136 | 137 | cameramenu; 138 | switch lower(name) 139 | case 'hammerheadtriang' 140 | view(150,-45); 141 | case 'horse' 142 | view(134,-61); 143 | case 'skull' 144 | view(21.5,-12); 145 | case 'mushroom' 146 | view(160,-75); 147 | case 'bunny' 148 | % view(0,-55); 149 | view(0,90); 150 | case 'david_head' 151 | view(-100,10); 152 | case 'screwdriver' 153 | view(-10,25); 154 | case 'pieta' 155 | view(15,31); 156 | case 'mannequin' 157 | view(25,15); 158 | case 'david-low' 159 | view(40,3); 160 | case 'brain' 161 | view(30,40); 162 | case 'pelvis' 163 | view(5,-15); 164 | end 165 | view_param = getoptions(options, 
'view_param', []); 166 | if not(isempty(view_param)) 167 | view(view_param(1),view_param(2)); 168 | end 169 | 170 | axis tight; 171 | axis equal; 172 | shading interp; 173 | camlight; 174 | 175 | if strcmp(name, 'david50kf') || strcmp(name, 'hand') 176 | zoom(.85); 177 | end -------------------------------------------------------------------------------- /result/raw_input_example5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/result/raw_input_example5.mat -------------------------------------------------------------------------------- /result/raw_input_example6.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/result/raw_input_example6.mat -------------------------------------------------------------------------------- /trained_model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/trained_model/.DS_Store -------------------------------------------------------------------------------- /trained_model/README.md: -------------------------------------------------------------------------------- 1 | Please download the pretrained models from: https://drive.google.com/file/d/16yV4sgXvoKnlEpaTpnkJu2eix4q61ONg/view?usp=sharing 2 | 3 | -------------------------------------------------------------------------------- /utils_distance/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/.DS_Store -------------------------------------------------------------------------------- /utils_distance/nndistance/.gitignore: 
-------------------------------------------------------------------------------- 1 | *.so 2 | *.o 3 | -------------------------------------------------------------------------------- /utils_distance/nndistance/README.md: -------------------------------------------------------------------------------- 1 | # An example C extension for PyTorch 2 | 3 | This extension adds a neural network layer that computes nearest-neighbor distances (and the matching indices) between two input point-set Tensors 4 | 5 | - src: C and CUDA source code 6 | - functions: the autograd functions 7 | - modules: code of the nn module 8 | - build.py: a small file that compiles your module to be ready to use 9 | - test.py: an example file that loads and uses the extension 10 | 11 | ```bash 12 | cd src 13 | nvcc -c -o nnd_cuda.cu.o nnd_cuda.cu -x cu -Xcompiler -fPIC -arch=sm_52 14 | cd .. 15 | python build.py 16 | python test.py 17 | ``` 18 | -------------------------------------------------------------------------------- /utils_distance/nndistance/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/_ext/__init__.py -------------------------------------------------------------------------------- /utils_distance/nndistance/_ext/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/_ext/__init__.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/_ext/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/_ext/__pycache__/__init__.cpython-35.pyc

from torch.utils.ffi import _wrap_function
from ._my_lib import lib as _lib, ffi as _ffi

__all__ = []


def _import_symbols(locals):
    """Copy every attribute of the compiled library into ``locals``.

    Callable attributes are wrapped with ``_wrap_function`` first; everything
    else is copied as is.  Each copied name is also appended to ``__all__``.
    """
    for name in dir(_lib):
        attr = getattr(_lib, name)
        locals[name] = _wrap_function(attr, _ffi) if callable(attr) else attr
        __all__.append(name)


# Populate this module's namespace at import time.
_import_symbols(locals())
"""Build script for the ``_ext.my_lib`` FFI extension (legacy torch.utils.ffi).

The CPU sources are always compiled.  When CUDA is available the CUDA wrapper
is added as well, together with the kernel object ``src/nnd_cuda.cu.o``, which
must be compiled with nvcc beforehand (see README.md).
"""
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(os.path.realpath(__file__))

sources = ['src/my_lib.c']
headers = ['src/my_lib.h']
defines = []
with_cuda = False
# Only the CUDA wrapper needs the nvcc-compiled kernels, so a CPU-only build
# must not ask the linker for an object file that was never produced.
extra_objects = []

if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/my_lib_cuda.c']
    headers += ['src/my_lib_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True
    extra_objects = [os.path.join(this_file, 'src/nnd_cuda.cu.o')]

print(this_file)

ffi = create_extension(
    '_ext.my_lib',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects
)

if __name__ == '__main__':
    # Fail early with an actionable message instead of an obscure linker error.
    missing = [obj for obj in extra_objects if not os.path.isfile(obj)]
    if missing:
        raise SystemExit(
            'Missing precompiled object(s): %s\n'
            'Compile the CUDA kernels with nvcc first (see README.md).'
            % ', '.join(missing)
        )
    ffi.build()
# functions/add.py
import torch
from torch.autograd import Function
from _ext import my_lib


class NNDFunction(Function):
    """Autograd wrapper around the ``my_lib`` nnd forward/backward routines.

    Written as a legacy (instance-method) ``Function``: everything that
    ``backward`` needs is stored on ``self`` during ``forward``.
    """

    def forward(self, xyz1, xyz2):
        """Run the forward routine on two 3-D tensors.

        The output buffers are allocated here and filled in place by the
        extension: ``dist1``/``idx1`` have shape ``(batch, n)`` and
        ``dist2``/``idx2`` have shape ``(batch, m)``, where ``n`` and ``m`` are
        the second dimensions of ``xyz1`` and ``xyz2``.
        NOTE(review): presumably xyz1/xyz2 are (batch, points, 3) point sets
        and the outputs are nearest-neighbour distances and indices; confirm
        against the C/CUDA sources.

        Returns:
            The tuple ``(dist1, dist2, idx1, idx2)``.
        """
        batch, count1, _ = xyz1.size()
        _, count2, _ = xyz2.size()
        self.xyz1, self.xyz2 = xyz1, xyz2

        out_dist1 = torch.zeros(batch, count1)
        out_dist2 = torch.zeros(batch, count2)
        self.idx1 = torch.zeros(batch, count1).type(torch.IntTensor)
        self.idx2 = torch.zeros(batch, count2).type(torch.IntTensor)

        if xyz1.is_cuda:
            # The CUDA routine needs its output buffers on the GPU as well.
            out_dist1, out_dist2 = out_dist1.cuda(), out_dist2.cuda()
            self.idx1, self.idx2 = self.idx1.cuda(), self.idx2.cuda()
            my_lib.nnd_forward_cuda(xyz1, xyz2, out_dist1, out_dist2, self.idx1, self.idx2)
        else:
            my_lib.nnd_forward(xyz1, xyz2, out_dist1, out_dist2, self.idx1, self.idx2)

        self.dist1, self.dist2 = out_dist1, out_dist2

        return out_dist1, out_dist2, self.idx1, self.idx2

    def backward(self, graddist1, graddist2, gradidx1, gradidx2):
        """Compute gradients with respect to ``xyz1`` and ``xyz2``.

        ``gradidx1`` and ``gradidx2`` are accepted but not used.  The gradient
        buffers are zero-initialised here and filled in place by the
        extension.

        Returns:
            The tuple ``(gradxyz1, gradxyz2)``, shaped like the saved inputs.
        """
        grad_d1 = graddist1.contiguous()
        grad_d2 = graddist2.contiguous()

        gradxyz1 = torch.zeros(self.xyz1.size())
        gradxyz2 = torch.zeros(self.xyz2.size())

        if grad_d1.is_cuda:
            gradxyz1, gradxyz2 = gradxyz1.cuda(), gradxyz2.cuda()
            my_lib.nnd_backward_cuda(self.xyz1, self.xyz2, gradxyz1, gradxyz2, grad_d1, grad_d2, self.idx1, self.idx2)
        else:
            my_lib.nnd_backward(self.xyz1, self.xyz2, gradxyz1, gradxyz2, grad_d1, grad_d2, self.idx1, self.idx2)

        return gradxyz1, gradxyz2
/utils_distance/nndistance/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__init__.py -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__init__.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/__pycache__/nnd.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__pycache__/nnd.cpython-35.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/__pycache__/nnd.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/__pycache__/nnd.cpython-37.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/nnd.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from functions.nnd import NNDFunction 3 | 4 | class NNDModule(Module): 5 | def forward(self, input1, input2): 6 | return NNDFunction()(input1, input2) 7 | -------------------------------------------------------------------------------- /utils_distance/nndistance/modules/nnd.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufeng2915/3DFC/8b5e4e02f6c1a78e8efa8ac2dce2cd8e14c7feb5/utils_distance/nndistance/modules/nnd.pyc -------------------------------------------------------------------------------- /utils_distance/nndistance/src/my_lib.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){ 5 | for (int i=0;isize[0]; 30 | int n = xyz1->size[1]; 31 | int m = xyz2->size[1]; 32 | 33 | //printf("%d %d %d\n", batchsize, n, m); 34 | 35 | float *xyz1_data = THFloatTensor_data(xyz1); 36 | float *xyz2_data = THFloatTensor_data(xyz2); 37 | float *dist1_data = THFloatTensor_data(dist1); 38 | float *dist2_data = THFloatTensor_data(dist2); 39 | int *idx1_data = THIntTensor_data(idx1); 40 | int *idx2_data = THIntTensor_data(idx2); 41 | 42 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 43 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 44 | 45 | return 1; 46 | } 47 | 48 | 49 | int nnd_backward(THFloatTensor *xyz1, THFloatTensor 
*xyz2, THFloatTensor *gradxyz1, THFloatTensor *gradxyz2, THFloatTensor *graddist1, THFloatTensor *graddist2, THIntTensor *idx1, THIntTensor *idx2) { 50 | 51 | int b = xyz1->size[0]; 52 | int n = xyz1->size[1]; 53 | int m = xyz2->size[1]; 54 | 55 | //printf("%d %d %d\n", batchsize, n, m); 56 | 57 | float *xyz1_data = THFloatTensor_data(xyz1); 58 | float *xyz2_data = THFloatTensor_data(xyz2); 59 | float *gradxyz1_data = THFloatTensor_data(gradxyz1); 60 | float *gradxyz2_data = THFloatTensor_data(gradxyz2); 61 | float *graddist1_data = THFloatTensor_data(graddist1); 62 | float *graddist2_data = THFloatTensor_data(graddist2); 63 | int *idx1_data = THIntTensor_data(idx1); 64 | int *idx2_data = THIntTensor_data(idx2); 65 | 66 | 67 | for (int i=0;i 2 | #include "nnd_cuda.h" 3 | 4 | 5 | 6 | extern THCState *state; 7 | 8 | 9 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) { 10 | int success = 0; 11 | success = NmDistanceKernelLauncher(xyz1->size[0], 12 | xyz1->size[1], 13 | THCudaTensor_data(state, xyz1), 14 | xyz2->size[1], 15 | THCudaTensor_data(state, xyz2), 16 | THCudaTensor_data(state, dist1), 17 | THCudaIntTensor_data(state, idx1), 18 | THCudaTensor_data(state, dist2), 19 | THCudaIntTensor_data(state, idx2), 20 | THCState_getCurrentStream(state) 21 | ); 22 | //int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream) 23 | 24 | 25 | if (!success) { 26 | THError("aborting"); 27 | } 28 | return 1; 29 | } 30 | 31 | 32 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1, 33 | THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) { 34 | 35 | int success = 0; 36 | success = NmDistanceGradKernelLauncher(xyz1->size[0], 37 | xyz1->size[1], 38 | THCudaTensor_data(state, 
xyz1), 39 | xyz2->size[1], 40 | THCudaTensor_data(state, xyz2), 41 | THCudaTensor_data(state, graddist1), 42 | THCudaIntTensor_data(state, idx1), 43 | THCudaTensor_data(state, graddist2), 44 | THCudaIntTensor_data(state, idx2), 45 | THCudaTensor_data(state, gradxyz1), 46 | THCudaTensor_data(state, gradxyz2), 47 | THCState_getCurrentStream(state) 48 | ); 49 | //int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream) 50 | 51 | if (!success) { 52 | THError("aborting"); 53 | } 54 | 55 | return 1; 56 | } 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /utils_distance/nndistance/src/my_lib_cuda.h: -------------------------------------------------------------------------------- 1 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2); 2 | 3 | 4 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1, THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2); 5 | 6 | -------------------------------------------------------------------------------- /utils_distance/nndistance/src/nnd_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "nnd_cuda.h" 3 | 4 | 5 | 6 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 7 | const int batch=512; 8 | __shared__ float buf[batch*3]; 9 | for (int i=blockIdx.x;ibest){ 121 | result[(i*n+j)]=best; 122 | result_i[(i*n+j)]=best_i; 123 | } 124 | } 125 | __syncthreads(); 126 | } 127 | } 128 | } 129 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * 
result_i,float * result2,int * result2_i, cudaStream_t stream){ 130 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 131 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 132 | 133 | cudaError_t err = cudaGetLastError(); 134 | if (err != cudaSuccess) { 135 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 136 | //THError("aborting"); 137 | return 0; 138 | } 139 | return 1; 140 | 141 | 142 | } 143 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 144 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 167 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 168 | 169 | cudaError_t err = cudaGetLastError(); 170 | if (err != cudaSuccess) { 171 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 172 | //THError("aborting"); 173 | return 0; 174 | } 175 | return 1; 176 | 177 | } 178 | 179 | -------------------------------------------------------------------------------- /utils_distance/nndistance/src/nnd_cuda.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 6 | 7 | int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif -------------------------------------------------------------------------------- /utils_distance/nndistance/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from 
torch.autograd import Variable 4 | 5 | from modules.nnd import NNDModule 6 | 7 | dist = NNDModule() 8 | 9 | p1 = torch.ones(2,100,3)*0.2 10 | p2 = torch.ones(2,150,3)*0.3 11 | p1[0,0,0] = 1 12 | p2[1,1,1] = -1 13 | #points1 = Variable(p1,requires_grad = True) 14 | #points2 = Variable(p2) 15 | #dist1, idx1, dist2, idx2 = dist(points1, points2) 16 | #print(dist1, dist2) 17 | #loss = torch.sum(dist1) 18 | #print(loss) 19 | #loss.backward() 20 | #print(points1.grad, points2.grad) 21 | 22 | 23 | points1 = Variable(p1.cuda(), requires_grad = True) 24 | points2 = Variable(p2.cuda()) 25 | dist1, dist2, idx1, idx2 = dist(points1, points2) 26 | #print(dist1, dist2) 27 | print(idx1, idx2) 28 | loss = torch.mean(dist1) + torch.mean(dist2) 29 | print(loss) 30 | loss.backward() 31 | print(points1.grad, points2.grad) --------------------------------------------------------------------------------