├── .gitignore ├── CMakeLists.txt ├── README.md ├── data ├── __init__.py ├── augmentation.py ├── modelnet.py └── shapenet.py ├── models ├── __init__.py └── splatnet.py ├── train_seg.py └── utils ├── CMakeLists.txt ├── __init__.py ├── include ├── cuda_headers.h ├── cuda_utils.h └── permutohedral_lattice.h ├── permutohedral_lattice_gpu.cu ├── permutohedral_lattice_layer.py └── permutohedral_lattice_wrapper.cc /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### CLION IDE 3 | .idea 4 | ide_helper.cc 5 | ### C++ template 6 | # Prerequisites 7 | *.d 8 | 9 | # Compiled Object files 10 | *.slo 11 | *.lo 12 | *.o 13 | *.obj 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Compiled Dynamic libraries 20 | *.so 21 | *.dylib 22 | *.dll 23 | 24 | # Fortran module files 25 | *.mod 26 | *.smod 27 | 28 | # Compiled Static libraries 29 | *.lai 30 | *.la 31 | *.a 32 | *.lib 33 | 34 | # Executables 35 | *.exe 36 | *.out 37 | *.app 38 | ### CMake template 39 | CMakeCache.txt 40 | CMakeFiles 41 | CMakeScripts 42 | Testing 43 | Makefile 44 | cmake_install.cmake 45 | install_manifest.txt 46 | compile_commands.json 47 | CTestTestfile.cmake 48 | ### C template 49 | # Prerequisites 50 | *.d 51 | 52 | # Object files 53 | *.o 54 | *.ko 55 | *.obj 56 | *.elf 57 | 58 | # Linker output 59 | *.ilk 60 | *.map 61 | *.exp 62 | 63 | # Precompiled Headers 64 | *.gch 65 | *.pch 66 | 67 | # Libraries 68 | *.lib 69 | *.a 70 | *.la 71 | *.lo 72 | 73 | # Shared objects (inc. Windows DLLs) 74 | *.dll 75 | *.so 76 | *.so.* 77 | *.dylib 78 | 79 | # Executables 80 | *.exe 81 | *.out 82 | *.app 83 | *.i*86 84 | *.x86_64 85 | *.hex 86 | 87 | # Debug files 88 | *.dSYM/ 89 | *.su 90 | *.idb 91 | *.pdb 92 | 93 | # Kernel Module Compile Results 94 | *.mod* 95 | *.cmd 96 | .tmp_versions/ 97 | modules.order 98 | Module.symvers 99 | Mkfile.old 100 | dkms.conf 101 | ### Python template 102 | # Byte-compiled / optimized / DLL files 103 | __pycache__/ 104 | *.py[cod] 105 | *$py.class 106 | 107 | # C extensions 108 | *.so 109 | 110 | # Distribution / packaging 111 | .Python 112 | build/ 113 | develop-eggs/ 114 | dist/ 115 | downloads/ 116 | eggs/ 117 | .eggs/ 118 | lib/ 119 | lib64/ 120 | parts/ 121 | sdist/ 122 | var/ 123 | wheels/ 124 | *.egg-info/ 125 | .installed.cfg 126 | *.egg 127 | MANIFEST 128 | 129 | # PyInstaller 130 | # Usually these files are written by a python script from a template 131 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
132 | *.manifest 133 | *.spec 134 | 135 | # Installer logs 136 | pip-log.txt 137 | pip-delete-this-directory.txt 138 | 139 | # Unit test / coverage reports 140 | htmlcov/ 141 | .tox/ 142 | .coverage 143 | .coverage.* 144 | .cache 145 | nosetests.xml 146 | coverage.xml 147 | *.cover 148 | .hypothesis/ 149 | .pytest_cache/ 150 | 151 | # Translations 152 | *.mo 153 | *.pot 154 | 155 | # Django stuff: 156 | *.log 157 | local_settings.py 158 | db.sqlite3 159 | 160 | # Flask stuff: 161 | instance/ 162 | .webassets-cache 163 | 164 | # Scrapy stuff: 165 | .scrapy 166 | 167 | # Sphinx documentation 168 | docs/_build/ 169 | 170 | # PyBuilder 171 | target/ 172 | 173 | # Jupyter Notebook 174 | .ipynb_checkpoints 175 | 176 | # pyenv 177 | .python-version 178 | 179 | # celery beat schedule file 180 | celerybeat-schedule 181 | 182 | # SageMath parsed files 183 | *.sage.py 184 | 185 | # Environments 186 | .env 187 | .venv 188 | env/ 189 | venv/ 190 | ENV/ 191 | env.bak/ 192 | venv.bak/ 193 | 194 | # Spyder project settings 195 | .spyderproject 196 | .spyproject 197 | 198 | # Rope project settings 199 | .ropeproject 200 | 201 | # mkdocs documentation 202 | /site 203 | 204 | # mypy 205 | .mypy_cache/ 206 | ### CUDA template 207 | *.i 208 | *.ii 209 | *.gpu 210 | *.ptx 211 | *.cubin 212 | *.fatbin 213 | 214 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR)  # CUDA as a first-class language requires CMake >= 3.8 2 | project(splatnet LANGUAGES CXX CUDA) 3 | 4 | set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/utils/lib) 5 | set(CMAKE_CXX_COMPILER "g++") 6 | set(CMAKE_CXX_FLAGS "-std=c++11") 7 | set(CMAKE_CUDA_FLAGS "-std=c++11 --gpu-architecture=sm_61") 8 | 9 | set(INCLUDE_DIRS 10 | "/usr/local/cuda-8.0/include" 11 | "/home/lyc/virtualenv/pytorch/lib/python3.5/site-packages/torch/lib/include" 12 | "/home/lyc/virtualenv/pytorch/lib/python3.5/site-packages/torch/lib/include/TH" 13 | "/home/lyc/virtualenv/pytorch/lib/python3.5/site-packages/torch/lib/include/THC" 14 | "/home/lyc/virtualenv/pytorch/include/python3.5m" 15 | "${PROJECT_SOURCE_DIR}/utils/include") 16 | set(LINK_LIBRARIES "") 17 | include_directories(${INCLUDE_DIRS}) 18 | add_subdirectory(${PROJECT_SOURCE_DIR}/utils) 19 | 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Unofficial partial implementation of [SplatNet](https://github.com/NVlabs/splatnet), written in PyTorch. 2 | Original paper by Hang Su et al.: [SPLATNet: Sparse Lattice Networks for Point Cloud Processing](https://arxiv.org/abs/1712.06760) 3 | Reference papers: 4 | Andrew B. Adams. [High-Dimensional Gaussian Filtering for Computational Photography.](https://people.csail.mit.edu/abadams/thesis.pdf), 2011 5 | V. Jampani, M. Kiefel and P. V. Gehler. 
[Learning Sparse High-Dimensional Filters: Image Filtering, Dense CRFs and Bilateral Neural Networks.](https://arxiv.org/abs/1503.04949) CVPR, 2016 [Github](https://github.com/MPI-IS/bilateralNN) 6 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftdlyc/splatnet_pytorch/0fbc60004d21a2c1dcc6beae065831f5f44add3a/data/__init__.py -------------------------------------------------------------------------------- /data/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def rotate_point_cloud(data): 5 | """ 6 | Rotate the point cloud about the Y axis by a uniform random angle. 7 | :param data: Nx3 array 8 | :return: rotated_data: Nx3 array 9 | """ 10 | angle = np.random.uniform() * 2 * np.pi 11 | cosval = np.cos(angle) 12 | sinval = np.sin(angle) 13 | R = np.array([[cosval, 0, sinval], 14 | [0, 1, 0], 15 | [-sinval, 0, cosval]]) 16 | rotated_data = np.dot(data, R) 17 | return rotated_data 18 | 19 | 20 | def random_rotate_point_cloud(data, angle_sigma=0.06, angle_clip=0.18): 21 | """ 22 | Apply small random rotations about all three axes (angles drawn with std angle_sigma, clipped at angle_clip). 23 | :param data: Nx3 array 24 | :return: rotated_data: Nx3 array 25 | """ 26 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 27 | Rx = np.array([[1, 0, 0], 28 | [0, np.cos(angles[0]), -np.sin(angles[0])], 29 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 30 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 31 | [0, 1, 0], 32 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 33 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 34 | [np.sin(angles[2]), np.cos(angles[2]), 0], 35 | [0, 0, 1]]) 36 | R = np.dot(Rz, np.dot(Ry, Rx)) 37 | rotated_data = np.dot(data, R) 38 | 39 | return rotated_data 40 | 41 | 42 | def jitter_point_cloud(data, sigma=0.01, clip=0.05): 43 | """ 44 | Add per-point Gaussian noise with std sigma, clipped to [-clip, clip]. 45 | :param data: Nx3 array 46 | :return: jittered_data: Nx3 array 47 | """ 48 | N, C = data.shape 49 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip) 50 | jittered_data += data 51 | 52 | return jittered_data 53 | 54 | 55 | def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25): 56 | """ 57 | Scale the whole cloud by a single factor drawn from [scale_low, scale_high]. 58 | :param data: Nx3 array 59 | :return: scaled_data: Nx3 array 60 | """ 61 | scale = np.random.uniform(scale_low, scale_high) 62 | scaled_data = data * scale 63 | 64 | return scaled_data 65 | 66 | 67 | def random_dropout_point_cloud(data, p=0.9): 68 | """ 69 | Replace a random subset of points with the first point (keeps the array shape). 70 | :param data: Nx3 array 71 | :return: dropout_data: Nx3 array 72 | """ 73 | N, C = data.shape 74 | dropout_ratio = np.random.random() * p 75 | drop_idx = np.where(np.random.random(N) <= dropout_ratio)[0] 76 | dropout_data = data.copy() 77 | if len(drop_idx) > 0: 78 | dropout_data[drop_idx, :] = data[0, :]  # dropped points collapse onto the first point 79 | 80 | return dropout_data 81 | 82 | 83 | def shift_point_cloud(data, shift_range=0.1): 84 | """ 85 | Translate the whole cloud by a single random offset in [-shift_range, shift_range]^3. 86 | :param data: Nx3 array 87 | :return: shift_data: Nx3 array 88 | """ 89 | N, C = data.shape 90 | shifts = np.random.uniform(-shift_range, shift_range, 3) 91 | shift_data = data + shifts 92 | return shift_data 93 | 
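The transforms above are plain NumPy functions over an (N, 3) array and compose directly; a minimal sketch on a synthetic cloud (the random input is a stand-in, not data from the repo's loaders):

import numpy as np
from data.augmentation import (random_rotate_point_cloud, random_scale_point_cloud,
                               shift_point_cloud, jitter_point_cloud)

cloud = np.random.rand(2048, 3).astype(np.float32)  # synthetic Nx3 stand-in
cloud = random_rotate_point_cloud(cloud)  # small random rotations about all three axes
cloud = random_scale_point_cloud(cloud)   # one isotropic scale factor in [0.8, 1.25]
cloud = shift_point_cloud(cloud)          # one global shift in [-0.1, 0.1]^3
cloud = jitter_point_cloud(cloud)         # clipped per-point Gaussian noise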
-------------------------------------------------------------------------------- /data/modelnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | from .augmentation import * 6 | 7 | 8 | class ModelNetDataset(data.Dataset): 9 | 10 | def __init__(self, root, point_nums=2048, train=True, augmentation=True): 11 | self.root = root 12 | self.point_nums = point_nums 13 | self.train = train 14 | self.augmentation = augmentation 15 | self.dataset = [] 16 | 17 | file = open(os.path.join(root, 'modelnet10_shape_names.txt'), 'r') 18 | self.shape_list = [line.rstrip() for line in file.readlines()] 19 | file.close() 20 | self.class_nums = len(self.shape_list) 21 | 22 | if train: 23 | file = open(os.path.join(root, 'modelnet10_train.txt'), 'r') 24 | else: 25 | file = open(os.path.join(root, 'modelnet10_test.txt'), 'r') 26 | for line in file.readlines(): 27 | line = line.rstrip() 28 | name = line[0:-5]  # strip the '_XXXX' instance suffix to recover the shape name 29 | label = self.shape_list.index(name) 30 | self.dataset.append((os.path.join(os.path.join(root, name), line + '.txt'), label)) 31 | file.close() 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __getitem__(self, index): 37 | file_path, label = self.dataset[index] 38 | data = np.loadtxt(file_path, dtype=np.float32, delimiter=',', usecols=(0, 1, 2)) 39 | data = data[np.random.choice(data.shape[0], self.point_nums, replace=False), :] 40 | 41 | if self.train and self.augmentation: 42 | # data = rotate_point_cloud(data) 43 | data = random_rotate_point_cloud(data) 44 | data = random_scale_point_cloud(data) 45 | data = shift_point_cloud(data) 46 | data = jitter_point_cloud(data) 47 | 48 | pc = torch.from_numpy(data.transpose().astype(np.float32)) 49 | 50 | return pc, label 51 | 
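ModelNetDataset returns a (3, point_nums) float tensor plus an integer class label, matching the Conv1d layout used by the models; a minimal loading sketch (the dataset root is a hypothetical path, not one shipped with the repo):

import torch
from data.modelnet import ModelNetDataset

trainset = ModelNetDataset('/path/to/modelnet_resampled', point_nums=2048, train=True)  # hypothetical root
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
pc, label = trainset[0]  # pc.shape == (3, 2048)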
-------------------------------------------------------------------------------- /data/shapenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import torch.utils.data as data 5 | import numpy as np 6 | from .augmentation import * 7 | 8 | class ShapeNetDataset(data.Dataset): 9 | 10 | def __init__(self, root, point_nums=2048, split='train', augmentation=True): 11 | self.root = root 12 | self.point_nums = point_nums 13 | self.split = split 14 | self.augmentation = augmentation 15 | self.dataset = [] 16 | 17 | categories = {} 18 | with open(os.path.join(self.root, 'synsetoffset2category.txt'), 'r') as file: 19 | i = 0 20 | for line in file: 21 | ls = line.strip().split()[1] 22 | categories[ls] = i 23 | i = i + 1 24 | self.category_nums = len(categories) 25 | self.class_nums = 6  # at most 6 part labels for any single category in this benchmark 26 | 27 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as file: 28 | train_idxs = [(d.split('/')[1], d.split('/')[2]) for d in json.load(file)] 29 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as file: 30 | val_idxs = [(d.split('/')[1], d.split('/')[2]) for d in json.load(file)] 31 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as file: 32 | test_idxs = [(d.split('/')[1], d.split('/')[2]) for d in json.load(file)] 33 | 34 | if split == 'train': 35 | for category, hash in train_idxs: 36 | self.dataset.append((os.path.join(self.root, category, 'points', hash + '.pts'), 37 | os.path.join(self.root, category, 'points_label', hash + '.seg'), 38 | categories[category])) 39 | elif split == 'val': 40 | for category, hash in val_idxs: 41 | self.dataset.append((os.path.join(self.root, category, 'points', hash + '.pts'), 42 | os.path.join(self.root, category, 'points_label', hash + '.seg'), 43 | categories[category])) 44 | elif split == 'test': 45 | for category, hash in test_idxs: 46 | self.dataset.append((os.path.join(self.root, category, 'points', hash + '.pts'), 47 | os.path.join(self.root, category, 'points_label', hash + '.seg'), 48 | categories[category])) 49 | elif split == 'train&val': 50 | for category, hash in train_idxs: 51 | self.dataset.append((os.path.join(self.root, category, 'points', hash + '.pts'), 52 | os.path.join(self.root, category, 'points_label', hash + '.seg'), 53 | categories[category])) 54 | for category, hash in val_idxs: 55 | self.dataset.append((os.path.join(self.root, category, 'points', hash + '.pts'), 56 | os.path.join(self.root, category, 'points_label', hash + '.seg'), 57 | categories[category])) 58 | else: 59 | raise ValueError('unknown split: ' + split) 60 | 61 | def __len__(self): 62 | return len(self.dataset) 63 | 64 | def __getitem__(self, index): 65 | pc_path, seg_path, label = self.dataset[index] 66 | pc_data = np.loadtxt(pc_path, dtype=np.float32) 67 | seg_data = np.loadtxt(seg_path, dtype=np.int64) 68 | if pc_data.shape[0] >= self.point_nums: 69 | sampling_idxs = np.random.choice(pc_data.shape[0], self.point_nums, replace=False) 70 | else: 71 | sampling_idxs = np.random.choice(pc_data.shape[0], self.point_nums - pc_data.shape[0], replace=True) 72 | sampling_idxs = np.append(sampling_idxs, [i for i in range(0, pc_data.shape[0])]) 73 | np.random.shuffle(sampling_idxs) 74 | pc_data = pc_data[sampling_idxs, :] 75 | seg_data = seg_data[sampling_idxs] - 1 76 | 77 | if self.augmentation: 78 | # pc_data = random_rotate_point_cloud(pc_data) 79 | pc_data = random_scale_point_cloud(pc_data) 80 | pc_data = shift_point_cloud(pc_data) 81 | pc_data = jitter_point_cloud(pc_data) 82 | 83 | pc = torch.from_numpy(pc_data.transpose().astype(np.float32)) 84 | seg = torch.from_numpy(seg_data.astype(np.int64)) 85 | return pc, seg, label 86 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftdlyc/splatnet_pytorch/0fbc60004d21a2c1dcc6beae065831f5f44add3a/models/__init__.py -------------------------------------------------------------------------------- /models/splatnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(BASE_DIR) 6 | sys.path.append(os.path.join(BASE_DIR, "../utils")) 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import numpy as np 13 | from utils.permutohedral_lattice_layer import PermutohedralLattice 14 | 15 | 16 | class SplatNetSegment(nn.Module): 17 | 18 | def __init__(self, class_nums, category_nums, pos_lambda=64, device_id=0, initial_weights=True): 19 | super(SplatNetSegment, self).__init__() 20 | 21 | self.class_nums = class_nums 22 | self.category_nums = category_nums 23 | self.device_id = device_id 24 | 25 | self.mlp1 = nn.Sequential( 26 | nn.Conv1d(3, 32, 1, bias=False), 27 | nn.BatchNorm1d(32), 28 | nn.ReLU(True) 29 | ) 30 | self.pl1 = PermutohedralLattice(32, 64, 3, pos_lambda, bias=False) 31 | self.pl1_bn = nn.Sequential( 32 | nn.BatchNorm1d(64), 33 | nn.ReLU(True) 34 | ) 35 | self.pl2 = PermutohedralLattice(64, 128, 3, pos_lambda / 2, bias=False) 36 | self.pl2_bn = nn.Sequential( 37 | nn.BatchNorm1d(128), 38 | nn.ReLU(True) 39 | ) 40 | self.pl3 = PermutohedralLattice(128, 256, 3, pos_lambda / 4, bias=False) 41 | self.pl3_bn = nn.Sequential( 42 | nn.BatchNorm1d(256), 43 | nn.ReLU(True) 44 | ) 45 | self.pl4 = 
PermutohedralLattice(256, 256, 3, pos_lambda / 8, bias=False) 46 | self.pl4_bn = nn.Sequential( 47 | nn.BatchNorm1d(256), 48 | nn.ReLU(True) 49 | ) 50 | self.pl5 = PermutohedralLattice(256, 256, 3, pos_lambda / 16, bias=False) 51 | self.pl5_bn = nn.Sequential( 52 | nn.BatchNorm1d(256), 53 | nn.ReLU(True) 54 | ) 55 | self.mlp2 = nn.Sequential( 56 | nn.Conv1d(64 + 128 + 3 * 256 + category_nums, 256, 1, bias=False), 57 | nn.BatchNorm1d(256), 58 | nn.ReLU(True), 59 | nn.Conv1d(256, 128, 1, bias=False), 60 | nn.BatchNorm1d(128), 61 | nn.ReLU(True), 62 | nn.Dropout(0.3), 63 | nn.Conv1d(128, class_nums, 1) 64 | ) 65 | 66 | if initial_weights: 67 | self.initialize_weights() 68 | 69 | self.criterion = nn.CrossEntropyLoss() 70 | self.optimizer = optim.Adam(self.parameters(), lr=0.01, weight_decay=0.001) 71 | self.schedule = optim.lr_scheduler.StepLR(self.optimizer, 20, 0.5) 72 | 73 | self.cuda(device_id) 74 | 75 | def forward(self, x, labels, position): 76 | x = self.mlp1(x) 77 | x1 = self.pl1(x, position) 78 | x1 = self.pl1_bn(x1) 79 | x2 = self.pl2(x1, position) 80 | x2 = self.pl2_bn(x2) 81 | x3 = self.pl3(x2, position) 82 | x3 = self.pl3_bn(x3) 83 | x4 = self.pl4(x3, position) 84 | x4 = self.pl4_bn(x4) 85 | x5 = self.pl5(x4, position) 86 | x5 = self.pl5_bn(x5) 87 | index = labels.unsqueeze(1).repeat([1, x.size(2)]).unsqueeze(1) 88 | one_hot = torch.zeros([x.size(0), self.category_nums, x.size(2)]) 89 | one_hot = one_hot.cuda(self.device_id) 90 | one_hot = one_hot.scatter_(1, index, 1) 91 | x = torch.cat([x1, x2, x3, x4, x5, one_hot], dim=1) 92 | x = self.mlp2(x) 93 | return x 94 | 95 | def initialize_weights(self): 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv1d): 98 | m.weight.data.normal_(0, 0.01) 99 | if m.bias is not None: 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.weight.data.normal_(0, 0.01) 103 | if m.bias is not None: 104 | m.bias.data.zero_() 105 | elif isinstance(m, nn.BatchNorm1d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | elif isinstance(m, PermutohedralLattice): 109 | m.weights.data.normal_(0, 0.01) 110 | 111 | def loss(self, outputs, targets): 112 | return self.criterion(outputs, targets) 113 | 114 | def fit(self, dataloader, epoch): 115 | self.train() 116 | batch_loss = 0. 117 | epoch_loss = 0. 118 | batch_nums = 0 119 | if self.schedule is not None: 120 | self.schedule.step() 121 | 122 | print('----------epoch %d start train----------' % epoch) 123 | 124 | for batch_idx, (inputs, targets, labels) in enumerate(dataloader): 125 | inputs = inputs.cuda(self.device_id) 126 | targets = targets.cuda(self.device_id) 127 | labels = labels.cuda(self.device_id) 128 | position = inputs.transpose(1, 2).contiguous() 129 | self.optimizer.zero_grad() 130 | 131 | outputs = self(inputs, labels, position.detach()) 132 | losses = self.loss(outputs, targets) 133 | losses.backward() 134 | self.optimizer.step() 135 | 136 | batch_loss += losses.item() 137 | epoch_loss += losses.item() 138 | batch_nums += 1 139 | if (batch_idx + 1) % 4 == 0: 140 | print('[%d, %5d] loss %.3f' % (epoch, batch_idx, batch_loss / 4)) 141 | batch_loss = 0. 142 | 143 | print('-----------epoch %d end train-----------' % epoch) 144 | print('epoch %d loss %.3f' % (epoch, epoch_loss / batch_nums)) 145 | 146 | return epoch_loss / batch_nums 147 | 148 | def score(self, dataloader): 149 | self.eval() 150 | correct = 0. 
151 | total = 0 152 | 153 | with torch.no_grad(): 154 | for batch_idx, (inputs, targets, labels) in enumerate(dataloader): 155 | inputs = inputs.cuda(self.device_id) 156 | targets = targets.cuda(self.device_id) 157 | labels = labels.cuda(self.device_id) 158 | position = inputs.transpose(1, 2).contiguous() 159 | 160 | outputs = self(inputs, labels, position.detach()) 161 | _, predicted = torch.max(outputs.data, 1) 162 | total += targets.size(0) * targets.size(1) 163 | correct += (predicted == targets).sum().item() 164 | 165 | print('Accuracy of the network on the test set: %.2f %%' % (100 * correct / total)) 166 | 167 | return correct / total 168 | -------------------------------------------------------------------------------- /train_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.shapenet import ShapeNetDataset 3 | from models.splatnet import SplatNetSegment 4 | 5 | trainset = ShapeNetDataset('/opt/shapenetcore_partanno_segmentation_benchmark_v0', split='train', augmentation=True) 6 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2) 7 | testset = ShapeNetDataset('/opt/shapenetcore_partanno_segmentation_benchmark_v0', split='test', augmentation=False) 8 | testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2) 9 | valset = ShapeNetDataset('/opt/shapenetcore_partanno_segmentation_benchmark_v0', split='val', augmentation=False) 10 | valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False, num_workers=2) 11 | 12 | net = SplatNetSegment(trainset.class_nums, trainset.category_nums) 13 | for epoch in range(1, 400): 14 | net.fit(trainloader, epoch) 15 | if epoch % 20 == 0: 16 | net.score(valloader) 17 | net.score(testloader) 18 | torch.save(net.state_dict(), 'model.pkl') 19 | 
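The checkpoint written by train_seg.py can be reloaded for stand-alone evaluation; a minimal sketch (assumes the same dataset root and a CUDA device, since SplatNetSegment moves itself to the GPU in __init__):

import torch
from data.shapenet import ShapeNetDataset
from models.splatnet import SplatNetSegment

testset = ShapeNetDataset('/opt/shapenetcore_partanno_segmentation_benchmark_v0', split='test', augmentation=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

net = SplatNetSegment(testset.class_nums, testset.category_nums)
net.load_state_dict(torch.load('model.pkl'))  # weights saved by train_seg.py
net.score(testloader)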
-------------------------------------------------------------------------------- /utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(permutohedral_lattice SHARED permutohedral_lattice_gpu.cu permutohedral_lattice_wrapper.cc) 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftdlyc/splatnet_pytorch/0fbc60004d21a2c1dcc6beae065831f5f44add3a/utils/__init__.py -------------------------------------------------------------------------------- /utils/include/cuda_headers.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Palash on 19-02-2018. 3 | // Modified by ftdlyc 4 | // 5 | 6 | #ifndef CUDA_BASE_CUDAHEADERS_H 7 | #define CUDA_BASE_CUDAHEADERS_H 8 | 9 | #define TOTAL_THREADS 1024 10 | 11 | inline int opt_n_threads(int work_size) { 12 | const int pow_2 = static_cast<int>(floorf(log2f(work_size))); 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | 21 | return block_config; 22 | } 23 | 24 | #ifdef __JETBRAINS_IDE__ 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | #include "math.h" 35 | #define __CUDACC__ 1 36 | #define __host__ 37 | #define __device__ 38 | #define __global__ 39 | #define __noinline__ 40 | #define __forceinline__ 41 | #define __shared__ 42 | #define __constant__ 43 | #define __managed__ 44 | #define __restrict__ 45 | // CUDA Synchronization 46 | inline void __syncthreads() {}; 47 | inline void __threadfence_block() {}; 48 | inline void __threadfence() {}; 49 | inline void __threadfence_system() {}; 50 | inline int __syncthreads_count(int predicate) { return predicate; }; 51 | inline int __syncthreads_and(int predicate) { return predicate; }; 52 | inline int __syncthreads_or(int predicate) { return predicate; }; 53 | template <typename T> 54 | inline T __clz(const T val) { return val; } 55 | template <typename T> 56 | inline T __ldg(const T *address) { return *address; }; 57 | // CUDA TYPES 58 | typedef unsigned char uchar; 59 | typedef unsigned short ushort; 60 | typedef unsigned int uint; 61 | typedef unsigned long ulong; 62 | typedef unsigned long long ulonglong; 63 | typedef long long longlong; 64 | 65 | typedef struct uchar1 { 66 | uchar x; 67 | } uchar1; 68 | 69 | typedef struct uchar2 { 70 | uchar x; 71 | uchar y; 72 | } uchar2; 73 | 74 | typedef struct uchar3 { 75 | uchar x; 76 | uchar y; 77 | uchar z; 78 | } uchar3; 79 | 80 | typedef struct uchar4 { 81 | uchar x; 82 | uchar y; 83 | uchar z; 84 | uchar w; 85 | } uchar4; 86 | 87 | typedef struct char1 { 88 | char x; 89 | } char1; 90 | 91 | typedef struct char2 { 92 | char x; 93 | char y; 94 | } char2; 95 | 96 | typedef struct char3 { 97 | char x; 98 | char y; 99 | char z; 100 | } char3; 101 | 102 | typedef struct char4 { 103 | char x; 104 | char y; 105 | char z; 106 | char w; 107 | } char4; 108 | 109 | typedef struct ushort1 { 110 | ushort x; 111 | } ushort1; 112 | 113 | typedef struct ushort2 { 114 | ushort x; 115 | ushort y; 116 | } ushort2; 117 | 118 | typedef struct ushort3 { 119 | ushort x; 120 | ushort y; 121 | ushort z; 122 | } ushort3; 123 | 124 | typedef struct ushort4 { 125 | ushort x; 126 | ushort y; 127 | ushort z; 128 | ushort w; 129 | } ushort4; 130 | 131 | typedef struct short1 { 132 | short x; 133 | } short1; 134 | 135 | typedef struct short2 { 136 | short x; 137 | short y; 138 | } short2; 139 | 140 | typedef struct short3 { 141 | short x; 142 | short y; 143 | short z; 144 | } short3; 145 | 146 | typedef struct short4 { 147 | short x; 148 | short y; 149 | short z; 150 | short w; 151 | } short4; 152 | 153 | typedef struct uint1 { 154 | uint x; 155 | } uint1; 156 | 157 | typedef struct uint2 { 158 | uint x; 159 | uint y; 160 | } uint2; 161 | 162 | typedef struct uint3 { 163 | uint x; 164 | uint y; 165 | uint z; 166 | } uint3; 167 | 168 | typedef struct uint4 { 169 | uint x; 170 | uint y; 171 | uint z; 172 | uint w; 173 | } uint4; 174 | 175 | typedef struct int1 { 176 | 
int x; 177 | } int1; 178 | 179 | typedef struct int2 { 180 | int x; 181 | int y; 182 | } int2; 183 | 184 | typedef struct int3 { 185 | int x; 186 | int y; 187 | int z; 188 | } int3; 189 | 190 | typedef struct int4 { 191 | int x; 192 | int y; 193 | int z; 194 | int w; 195 | } int4; 196 | 197 | typedef struct ulong1 { 198 | ulong x; 199 | } ulong1; 200 | 201 | typedef struct ulong2 { 202 | ulong x; 203 | ulong y; 204 | } ulong2; 205 | 206 | typedef struct ulong3 { 207 | ulong x; 208 | ulong y; 209 | ulong z; 210 | } ulong3; 211 | 212 | typedef struct ulong4 { 213 | ulong x; 214 | ulong y; 215 | ulong z; 216 | ulong w; 217 | } ulong4; 218 | 219 | typedef struct long1 { 220 | long x; 221 | } long1; 222 | 223 | typedef struct long2 { 224 | long x; 225 | long y; 226 | } long2; 227 | 228 | typedef struct long3 { 229 | long x; 230 | long y; 231 | long z; 232 | } long3; 233 | 234 | typedef struct long4 { 235 | long x; 236 | long y; 237 | long z; 238 | long w; 239 | } long4; 240 | 241 | typedef struct ulonglong1 { 242 | ulonglong x; 243 | } ulonglong1; 244 | 245 | typedef struct ulonglong2 { 246 | ulonglong x; 247 | ulonglong y; 248 | } ulonglong2; 249 | 250 | typedef struct ulonglong3 { 251 | ulonglong x; 252 | ulonglong y; 253 | ulonglong z; 254 | } ulonglong3; 255 | 256 | typedef struct ulonglong4 { 257 | ulonglong x; 258 | ulonglong y; 259 | ulonglong z; 260 | ulonglong w; 261 | } ulonglong4; 262 | 263 | typedef struct longlong1 { 264 | longlong x; 265 | } longlong1; 266 | 267 | typedef struct longlong2 { 268 | longlong x; 269 | longlong y; 270 | } longlong2; 271 | 272 | typedef struct float1 { 273 | float x; 274 | } float1; 275 | 276 | typedef struct float2 { 277 | float x; 278 | float y; 279 | } float2; 280 | 281 | typedef struct float3 { 282 | float x; 283 | float y; 284 | float z; 285 | } float3; 286 | 287 | typedef struct float4 { 288 | float x; 289 | float y; 290 | float z; 291 | float w; 292 | } float4; 293 | 294 | typedef struct double1 { 295 | double x; 296 | } double1; 297 | 298 | typedef struct double2 { 299 | double x; 300 | double y; 301 | } double2; 302 | 303 | typedef uint3 dim3; 304 | 305 | extern dim3 gridDim; 306 | extern uint3 blockIdx; 307 | extern dim3 blockDim; 308 | extern uint3 threadIdx; 309 | extern int warpSize; 310 | #endif 311 | 312 | #endif //CUDA_BASE_CUDAHEADERS_H 313 | 
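cuda_utils.h below wraps every CUDA/cuBLAS call in the extension; a minimal usage sketch of its checking macros (the buffer and kernel here are hypothetical, not part of the repo):

// hypothetical usage of the macros defined in cuda_utils.h:
float *buf;
CUDA_CHECK_ERROR(cudaMalloc((void **) &buf, 1024 * sizeof(float)));
some_kernel<<<1, 256>>>(buf);  // any __global__ kernel
CUDA_CHECK_KERNEL_ERROR();     // synchronizes, then reports the kernel's error state
CUDA_CHECK_ERROR(cudaFree(buf));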
-------------------------------------------------------------------------------- /utils/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_UTILS_H 2 | #define CUDA_UTILS_H 3 | 4 | #include <cstdio> 5 | #include <cstdlib> 6 | #include <cuda_runtime.h> 7 | #include <cublas_v2.h> 8 | #include <torch/torch.h>  // these five includes are inferred from the usage below 9 | 10 | #define CUDA_CHECK_ERROR(err) gpuAssert(err, __FILE__, __LINE__) 11 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { 12 | if (code != cudaSuccess) { 13 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 14 | if (abort) exit(code); 15 | } 16 | } 17 | #define CUDA_CHECK_KERNEL_ERROR() do { cudaDeviceSynchronize(); gpuAssert(cudaGetLastError(), __FILE__, __LINE__); } while (0) 18 | 19 | inline const char *cublasGetErrorString(cublasStatus_t stat) { 20 | switch (stat) { 21 | case CUBLAS_STATUS_NOT_INITIALIZED: return "The cuBLAS library was not initialized"; 22 | case CUBLAS_STATUS_ALLOC_FAILED: return "Resource allocation failed inside the cuBLAS library"; 23 | case CUBLAS_STATUS_INVALID_VALUE: return "An unsupported value or parameter was passed to the function"; 24 | case CUBLAS_STATUS_ARCH_MISMATCH: return "The function requires a feature absent from the device architecture"; 25 | case CUBLAS_STATUS_MAPPING_ERROR: return "An access to GPU memory space failed"; 26 | case CUBLAS_STATUS_EXECUTION_FAILED: return "The GPU program failed to execute"; 27 | case CUBLAS_STATUS_INTERNAL_ERROR: return "An internal cuBLAS operation failed"; 28 | case CUBLAS_STATUS_NOT_SUPPORTED: return "The functionality requested is not supported"; 29 | case CUBLAS_STATUS_LICENSE_ERROR: return "The functionality requested requires some license"; 30 | } 31 | return "Unknown error"; 32 | }; 33 | 34 | #define CUBLAS_CHECK_ERROR(err) cublasAssert(err, __FILE__, __LINE__) 35 | inline void cublasAssert(cublasStatus_t stat, const char *file, int line, bool abort = true) { 36 | if (stat != CUBLAS_STATUS_SUCCESS) { 37 | fprintf(stderr, "CUBLASassert: %s %s %d\n", cublasGetErrorString(stat), file, line); 38 | if (abort) exit(stat); 39 | } 40 | } 41 | 42 | #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor") 43 | #define CHECK_CPU(x) AT_ASSERT(!(x.type().is_cuda()), #x " must be a CPU tensor") 44 | #define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous") 45 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 46 | #define CHECK_INPUT_CPU(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) 47 | #define CHECK_INPUT_TYPE(x, y) AT_ASSERT(x.type().scalarType() == y, #x " must be " #y) 48 | 49 | #endif //CUDA_UTILS_H 50 | -------------------------------------------------------------------------------- /utils/include/permutohedral_lattice.h: -------------------------------------------------------------------------------- 1 | #ifndef PERMUTOHEDRAL_LATTICE_H 2 | #define PERMUTOHEDRAL_LATTICE_H 3 | 4 | int compute( 5 | int b, int n_points_in, int n_points_out, int df_in, int df_out, int d, int n, bool skip_conv, bool keep_position, 6 | const float *position_in, 7 | const float *position_out, 8 | const float *features_in, 9 | const float *norm_features, 10 | const float *weights, 11 | float *features_out, 12 | float *&lattices_in, 13 | int *offset_in, 14 | int *offset_out, 15 | float *barycentric_in, 16 | float *barycentric_out, 17 | int *&conv_neighbors, 18 | float *norm); 19 | 20 | void compute_grad( 21 | int b, int n_points_in, int n_points_out, int n_filled, int df_in, int df_out, int d, int n, bool skip_conv, 22 | const float *grad_out, 23 | const float *weights_transpose, 24 | const float *lattices_in, 25 | const int *offset_in, 26 | const int *offset_out, 27 | const float *barycentric_in, 28 | const float *barycentric_out, 29 | const int *conv_neighbors, 30 | const float *norm, 31 | float *grad_in, 32 | float *grad_weights_transpose); 33 | 34 | #endif -------------------------------------------------------------------------------- /utils/permutohedral_lattice_gpu.cu: -------------------------------------------------------------------------------- 1 | #include <vector>   // inferred: std::vector is used below 2 | #include <cassert>  // inferred: assert is used below 3 | #include "cuda_utils.h" 4 | #include "cuda_headers.h" 5 | #include "permutohedral_lattice.h" 6 | 7 | //----------------------------------------------------------------------------------------// 8 | //-------------------------------------- Hash Table --------------------------------------// 9 | //----------------------------------------------------------------------------------------// 10 | 11 | typedef struct HashTable { 12 | size_t *key_size; 13 | size_t *filled; 14 | size_t *capacity; 15 | int16_t *keys; 16 | int *table; 17 | } HashTable; 18 | 19 | __host__ void init_hash_table(int b, size_t key_size, size_t capacity, 20 | HashTable 
**hash_tables_host, HashTable **hash_tables_gpu) { 21 | size_t filled = 0; 22 | *hash_tables_host = new HashTable[b]; 23 | for (int i = 0; i < b; ++i) { 24 | CUDA_CHECK_ERROR(cudaMalloc((void **) (&(*hash_tables_host)[i].key_size), sizeof(size_t))); 25 | CUDA_CHECK_ERROR(cudaMalloc((void **) (&(*hash_tables_host)[i].filled), sizeof(size_t))); 26 | CUDA_CHECK_ERROR(cudaMalloc((void **) (&(*hash_tables_host)[i].capacity), sizeof(size_t))); 27 | CUDA_CHECK_ERROR(cudaMalloc((void **) (&(*hash_tables_host)[i].keys), capacity * key_size * sizeof(int16_t))); 28 | CUDA_CHECK_ERROR(cudaMalloc((void **) (&(*hash_tables_host)[i].table), capacity * sizeof(int))); 29 | 30 | CUDA_CHECK_ERROR(cudaMemcpy((*hash_tables_host)[i].key_size, &key_size, sizeof(size_t), cudaMemcpyHostToDevice)); 31 | CUDA_CHECK_ERROR(cudaMemcpy((*hash_tables_host)[i].filled, &filled, sizeof(size_t), cudaMemcpyHostToDevice)); 32 | CUDA_CHECK_ERROR(cudaMemcpy((*hash_tables_host)[i].capacity, &capacity, sizeof(size_t), cudaMemcpyHostToDevice)); 33 | CUDA_CHECK_ERROR(cudaMemset((*hash_tables_host)[i].keys, 0, capacity * key_size * sizeof(int16_t))); 34 | CUDA_CHECK_ERROR(cudaMemset((*hash_tables_host)[i].table, -1, capacity * sizeof(int))); 35 | } 36 | CUDA_CHECK_ERROR(cudaMalloc((void **) hash_tables_gpu, b * sizeof(HashTable))); 37 | CUDA_CHECK_ERROR(cudaMemcpy(*hash_tables_gpu, *hash_tables_host, b * sizeof(HashTable), cudaMemcpyHostToDevice)); 38 | }; 39 | 40 | __host__ void deinit_hash_table(int b, HashTable *hash_tables_host, HashTable *hash_tables_gpu) { 41 | for (int i = 0; i < b; ++i) { 42 | CUDA_CHECK_ERROR(cudaFree(hash_tables_host[i].key_size)); 43 | CUDA_CHECK_ERROR(cudaFree(hash_tables_host[i].filled)); 44 | CUDA_CHECK_ERROR(cudaFree(hash_tables_host[i].capacity)); 45 | CUDA_CHECK_ERROR(cudaFree(hash_tables_host[i].keys)); 46 | CUDA_CHECK_ERROR(cudaFree(hash_tables_host[i].table)); 47 | } 48 | CUDA_CHECK_ERROR(cudaFree(hash_tables_gpu)); 49 | delete[] hash_tables_host; 50 | }; 51 | 52 | __device__ __host__ size_t hash(HashTable &hash_table, const int16_t *key) { 53 | size_t s = 0; 54 | for (int i = 0; i < *hash_table.key_size; ++i) { 55 | s += key[i]; 56 | s *= 2531011; 57 | } 58 | return s; 59 | } 60 | 61 | __device__ __host__ int find_hash(HashTable &hash_table, const int16_t *key, bool create = false) { 62 | size_t &key_size = *hash_table.key_size; 63 | size_t &filled = *hash_table.filled; 64 | size_t &capacity = *hash_table.capacity; 65 | 66 | size_t h = hash(hash_table, key) % capacity; 67 | while (true) { 68 | int e = hash_table.table[h]; 69 | if (e == -1) { 70 | if (create && filled < capacity) { 71 | for (int i = 0; i < key_size; i++) { 72 | hash_table.keys[filled * key_size + i] = key[i]; 73 | } 74 | hash_table.table[h] = static_cast(filled++); 75 | return hash_table.table[h]; 76 | } 77 | return -1; 78 | } 79 | 80 | bool good = true; 81 | for (int i = 0; i < key_size; ++i) { 82 | good &= (hash_table.keys[e * key_size + i] == key[i]); 83 | } 84 | if (good) { 85 | return e; 86 | } 87 | 88 | h = h < capacity - 1 ? 
h + 1 : 0; 89 | } 90 | } 91 | 92 | 93 | //----------------------------------------------------------------------------------------// 94 | //------------------------- Compute Neighbors and Gauss Weights -------------------------// 95 | //----------------------------------------------------------------------------------------// 96 | 97 | class PermutationGridCorCallBack { 98 | public: 99 | PermutationGridCorCallBack(int d1, int n_grid_cor, int16_t *grid_) 100 | : n_(0), d1_(d1), n_grid_cor_(n_grid_cor), grid_(grid_) {} 101 | 102 | void operator()(int16_t *grid_cor) { 103 | memcpy(grid_, grid_cor, d1_ * sizeof(int16_t)); 104 | ++n_; 105 | if (n_ < n_grid_cor_) { 106 | grid_ += d1_; 107 | } 108 | } 109 | 110 | int n_; 111 | int d1_; 112 | int n_grid_cor_; 113 | int16_t *grid_; 114 | }; 115 | 116 | __host__ void walk_in_dimension(int d1, int dimension, int step, int16_t *key) { 117 | for (int i = 0; i < d1; ++i) { 118 | key[i] -= step; 119 | } 120 | key[dimension] += step * d1; 121 | } 122 | 
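// Note: walk_in_dimension receives the homogeneous dimension d1 = d + 1 (as at its
// call sites below) and keeps keys on the zero-sum hyperplane of the lattice. Worked
// example for d = 2 (d1 = 3), step = 1, dimension = 0, starting from the origin:
//   (0, 0, 0) -> subtract step from every coordinate -> (-1, -1, -1)
//             -> add step * d1 = 3 to coordinate 0   -> ( 2, -1, -1), sum is 0 again.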
123 | __host__ void permutation_grid_cor(int start, int end, int n, int16_t *grid_cor, PermutationGridCorCallBack &yield) { 124 | if (start == end) { 125 | yield(grid_cor); 126 | return; 127 | } 128 | for (int i = 0; i < n + 1; ++i) { 129 | grid_cor[start] = i; 130 | permutation_grid_cor(start + 1, end, n, grid_cor, yield); 131 | } 132 | } 133 | 134 | __host__ void compute_neighbors_and_gauss(int d, int n, int16_t *neighbors, float *gauss_weights) { 135 | // compute convolution filter relative position 136 | int d1 = d + 1; 137 | int n_grid_cor = static_cast<int>(pow(n + 1, d1)); 138 | int16_t *grid = new int16_t[n_grid_cor * d1]; 139 | int16_t *grid_cor = new int16_t[d1]; 140 | PermutationGridCorCallBack yield(d1, n_grid_cor, grid); 141 | permutation_grid_cor(0, d1, n, grid_cor, yield); 142 | 143 | HashTable hash_table{new size_t[1], new size_t[1], new size_t[1], new int16_t[n_grid_cor * d1], new int[n_grid_cor]}; 144 | *hash_table.key_size = d1; 145 | *hash_table.filled = 0; 146 | *hash_table.capacity = n_grid_cor; 147 | memset(hash_table.keys, 0, n_grid_cor * d1 * sizeof(int16_t)); 148 | memset(hash_table.table, -1, n_grid_cor * sizeof(int)); 149 | 150 | int16_t *lattice_cor = new int16_t[d1]; 151 | int16_t *b = new int16_t[d1 * d1]; 152 | 153 | memset(b, -1, d1 * d1 * sizeof(int16_t)); 154 | for (int i = 0; i < d1; ++i) { 155 | b[i * d1 + i] = d1 - 1; 156 | } 157 | for (int i = 0; i < n_grid_cor; ++i) { 158 | memset(lattice_cor, 0, d1 * sizeof(int16_t)); 159 | for (int j = 0; j < d1; ++j) { 160 | for (int k = 0; k < d1; ++k) { 161 | lattice_cor[j] += b[(j * d1 + k)] * grid[i * d1 + k]; 162 | } 163 | } 164 | if (find_hash(hash_table, lattice_cor) == -1) { 165 | memcpy(neighbors + (*hash_table.filled) * d1, lattice_cor, d1 * sizeof(int16_t)); 166 | find_hash(hash_table, lattice_cor, true); 167 | } 168 | } 169 | 170 | int n_neighbors = static_cast<int>(pow(n + 1, d1)) - static_cast<int>(pow(n, d1)); 171 | assert((*hash_table.filled) == n_neighbors); 172 | 173 | // compute gauss filter weights 174 | std::vector<float> filter{1.0f, 0.5f}; 175 | float *gauss_weights_tmp = new float[n_neighbors]; 176 | int16_t *walking_key_up = new int16_t[d1]; 177 | int16_t *walking_key_down = new int16_t[d1]; 178 | memset(gauss_weights, 0, n_neighbors * sizeof(float)); 179 | gauss_weights[0] = 1; 180 | for (int i = 0; i < d1; ++i) { 181 | memset(gauss_weights_tmp, 0, n_neighbors * sizeof(float)); 182 | 183 | for (int j = 0; j < n_neighbors; ++j) { 184 | const int16_t *key = hash_table.keys + j * d1; 185 | memcpy(walking_key_up, key, d1 * sizeof(int16_t)); 186 | memcpy(walking_key_down, key, d1 * sizeof(int16_t)); 187 | 188 | float &v = gauss_weights_tmp[j]; 189 | v += gauss_weights[j] * filter[0]; 190 | for (int k = 1; k <= n && k <= 2; ++k) { 191 | walk_in_dimension(d1, i, 1, walking_key_up); 192 | walk_in_dimension(d1, i, -1, walking_key_down); 193 | 194 | int h1 = find_hash(hash_table, walking_key_up); 195 | int h2 = find_hash(hash_table, walking_key_down); 196 | 197 | v += ((h1 >= 0 ? gauss_weights[h1] : 0) + 198 | (h2 >= 0 ? gauss_weights[h2] : 0)) * (k < filter.size() ? filter[k] : 0); 199 | } 200 | } 201 | memmove(gauss_weights, gauss_weights_tmp, n_neighbors * sizeof(float)); 202 | } 203 | float norm_coef = gauss_weights[0]; 204 | for (int i = 0; i < n_neighbors; ++i) { 205 | gauss_weights[i] /= norm_coef; 206 | } 207 | 208 | delete[] hash_table.key_size; delete[] hash_table.filled; delete[] hash_table.capacity; delete[] hash_table.keys; delete[] hash_table.table; 209 | delete[] grid; delete[] grid_cor; delete[] lattice_cor; delete[] b; 210 | delete[] gauss_weights_tmp; delete[] walking_key_up; delete[] walking_key_down; 211 | } 212 | 213 | 
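// The loop above blurs along one lattice direction at a time with the truncated
// binomial kernel [0.5, 1, 0.5] (filter = {1.0f, 0.5f}); after all d + 1 passes,
// gauss_weights holds a separable Gaussian profile over the neighbor offsets,
// normalized so that the center offset (index 0) has weight 1.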
214 | //----------------------------------------------------------------------------------------// 215 | //-------------------------------------- Initialize --------------------------------------// 216 | //----------------------------------------------------------------------------------------// 217 | 218 | // inputs: position_in (b, n_points, d) 219 | // keys (b, n_points, d + 1, d) 220 | // barycentric (b, n_points, d + 2) 221 | __global__ void init1_kernel(int d, int n_points, 222 | const float *__restrict__ position, 223 | int16_t *__restrict__ keys, 224 | float *__restrict__ barycentric) { 225 | int batch_idx = blockIdx.x; 226 | int thread_idx = threadIdx.x; 227 | int d1 = d + 1; 228 | 229 | for (int point_idx = thread_idx; point_idx < n_points; point_idx += blockDim.x) { 230 | const float *position_ = position + (batch_idx * n_points + point_idx) * d; 231 | int16_t *key_ = keys + (batch_idx * n_points + point_idx) * (d + 1) * d; 232 | float *barycentric_ = barycentric + (batch_idx * n_points + point_idx) * (d + 2); 233 | 234 | int16_t *canonical = new int16_t[d1 * d1]; 235 | float *scale_factor = new float[d]; 236 | float *elevated = new float[d1]; 237 | int16_t *rem0 = new int16_t[d1]; 238 | int16_t *rank = new int16_t[d1]{0}; 239 | 240 | for (int i = 0; i < d1; ++i) { 241 | for (int j = 0; j < d1 - i; ++j) { 242 | canonical[i * d1 + j] = i; 243 | } 244 | for (int j = d1 - i; j < d1; ++j) { 245 | canonical[i * d1 + j] = -d1 + i; 246 | } 247 | } 248 | 249 | float inv_std_dev = sqrtf(2.0 / 3.0) * d1; 250 | for (int i = 0; i < d; ++i) { 251 | scale_factor[i] = 1.0 / sqrtf(((i + 1) * (i + 2))) * inv_std_dev; 252 | } 253 | 254 | // Elevate point into Hd using the rotation matrix E 255 | // see p.30 in [Adams et al. 2011] 256 | float sm = 0; 257 | for (int i = d; i > 0; --i) { 258 | float cf = position_[i - 1] * scale_factor[i - 1]; 259 | elevated[i] = sm - i * cf; 260 | sm += cf; 261 | } 262 | elevated[0] = sm; 263 | 264 | // Find the closest 0-colored simplex through rounding 265 | int sum = 0; 266 | for (int i = 0; i < d1; ++i) { 267 | int rd = static_cast<int>(round(elevated[i] / d1)); 268 | rem0[i] = static_cast<int16_t>(rd * d1); 269 | sum += rd; 270 | } 271 | 272 | // Find the simplex we are in and store it in rank 273 | // (where rank describes what position coordinate i has in the sorted order of the features values) 274 | for (int i = 0; i < d; ++i) { 275 | float di = elevated[i] - rem0[i]; 276 | for (int j = i + 1; j < d1; ++j) { 277 | if (di < elevated[j] - rem0[j]) { 278 | ++rank[i]; 279 | } else { 280 | ++rank[j]; 281 | } 282 | } 283 | } 284 | 285 | // If the point doesn't lie on the plane (sum != 0) bring it back 286 | for (int i = 0; i < d1; ++i) { 287 | rank[i] += sum; 288 | if (rank[i] < 0) { 289 | rank[i] += d + 1; 290 | rem0[i] += d + 1; 291 | } else if (rank[i] > d) { 292 | rank[i] -= d + 1; 293 | rem0[i] -= d + 1; 294 | } 295 | } 296 | 297 | // Compute all vertices 298 | for (int remainder = 0; remainder < d1; ++remainder) { 299 | // all but the last coordinate - it's redundant because they sum to zero 300 | for (int i = 0; i < d; ++i) { 301 | key_[remainder * d + i] = rem0[i] + canonical[remainder * (d + 1) + rank[i]]; 302 | } 303 | } 304 | 305 | // compute the barycentric coordinates 306 | // see p.31 in [Adams et al. 2011] 307 | for (int i = 0; i < d + 2; ++i) { 308 | barycentric_[i] = 0; 309 | } 310 | for (int i = 0; i < d1; ++i) { 311 | float v = (elevated[i] - rem0[i]) / d1; 312 | barycentric_[d - rank[i]] += v; 313 | barycentric_[d - rank[i] + 1] -= v; 314 | } 315 | barycentric_[0] += 1.0 + barycentric_[d + 1]; 316 | 317 | delete[] canonical; 318 | delete[] scale_factor; 319 | delete[] elevated; 320 | delete[] rem0; 321 | delete[] rank; 322 | } 323 | } 324 | 325 | // inputs: hash_table (b, d) 326 | // keys (b, n_points, d + 1, d) 327 | // offset (b, n_points, d + 1) 328 | __global__ void init2_kernel(int d, int n_points, 329 | HashTable *hash_tables, 330 | const int16_t *__restrict__ keys, 331 | int *__restrict__ offset) { 332 | int batch_idx = blockIdx.x; 333 | HashTable &hash_table_ = hash_tables[batch_idx]; 334 | int d1 = d + 1; 335 | 336 | // Compute all offsets 337 | for (int i = 0; i < n_points; ++i) { 338 | for (int j = 0; j < d1; ++j) { 339 | const int16_t *key_ = keys + ((batch_idx * n_points + i) * (d + 1) + j) * d; 340 | int h = find_hash(hash_table_, key_, true); 341 | offset[(batch_idx * n_points + i) * (d + 1) + j] = h; 342 | } 343 | } 344 | } 345 | 346 | // inputs: hash_table (b, d) 347 | // neighbors (n_neighbors) 348 | // conv_neighbors (b, n_neighbors, n_filled) 349 | __global__ void init3_kernel(int d, int n_neighbors, int n_filled, 350 | HashTable *hash_tables, 351 | const int16_t *__restrict__ neighbors, 352 | int *__restrict__ conv_neighbors) { 353 | int batch_idx = blockIdx.x; 354 | HashTable &hash_table_ = hash_tables[batch_idx]; 355 | int d1 = d + 1; 356 | 357 | // Compute all convolution neighbors 358 | size_t key_size = (*hash_table_.key_size); 359 | int n_filled_ = (*hash_table_.filled); 360 | for (int lattice_idx = 0; lattice_idx < n_filled_; ++lattice_idx) { 361 | int16_t *center = new int16_t[d1]; 362 | int16_t sum = 0; 363 | for (int i = 0; i < d; ++i) { 364 | center[i] = hash_table_.keys[lattice_idx * key_size + i]; 365 | sum += center[i]; 366 | } 367 | center[d] = -sum; 368 | 369 | int16_t *neighbor_key = new int16_t[d1]; 370 | for (int i = 0; i < n_neighbors; ++i) { 371 | for (int j = 0; j < d1; ++j) { 372 | neighbor_key[j] = center[j] + neighbors[i * d1 + j]; 373 | } 374 | conv_neighbors[(batch_idx * n_neighbors + i) * n_filled + lattice_idx] = find_hash(hash_table_, neighbor_key); 375 | } 376 | 377 | delete[] center; 378 | delete[] neighbor_key; 379 | } 380 | for (int lattice_idx = n_filled_; lattice_idx < n_filled; ++lattice_idx) { 381 | for (int i = 0; i < n_neighbors; ++i) { 382 | conv_neighbors[(batch_idx * n_neighbors + i) * n_filled + lattice_idx] = -1; 383 | } 384 | } 385 | } 386 | 
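// For reference: the neighborhood enumerated by init3_kernel has
// n_neighbors = (n + 1)^(d + 1) - n^(d + 1) offsets, the same count computed in
// compute_neighbors_and_gauss above; e.g. d = 3 with a first-order filter (n = 1)
// gives 2^4 - 1^4 = 15 lattice offsets per vertex.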
387 | // inputs: position_in (b, n_points, d) 388 | // hash_table (b, d) 389 | // offset (b, n_points, d + 1) 390 | // barycentric (b, n_points, d + 2) 391 | __global__ void init4_kernel(int d, int n_points, 392 | const float *__restrict__ position, 393 | HashTable *hash_tables, 394 | int *__restrict__ offset, 395 | float *__restrict__ barycentric) { 396 | int batch_idx = blockIdx.x; 397 | int thread_idx = threadIdx.x; 398 | HashTable &hash_table_ = hash_tables[batch_idx]; 399 | int d1 = d + 1; 400 | 401 | for (int point_idx = thread_idx; point_idx < n_points; point_idx += blockDim.x) { 402 | const float *position_ = position + (batch_idx * n_points + point_idx) * d; 403 | float *barycentric_ = barycentric + (batch_idx * n_points + point_idx) * (d + 2); 404 | 405 | int16_t *canonical = new int16_t[d1 * d1]; 406 | float *scale_factor = new float[d]; 407 | float *elevated = new float[d1]; 408 | int16_t *rem0 = new int16_t[d1]; 409 | int16_t *rank = new int16_t[d1]{0}; 410 | int16_t *key = new int16_t[d]; 411 | 412 | for (int i = 0; i < d1; ++i) { 413 | for (int j = 0; j < d1 - i; ++j) { 414 | canonical[i * d1 + j] = i; 415 | } 416 | for (int j = d1 - i; j < d1; ++j) { 417 | canonical[i * d1 + j] = -d1 + i; 418 | } 419 | } 420 | 421 | float inv_std_dev = sqrtf(2.0 / 3.0) * d1; 422 | for (int i = 0; i < d; ++i) { 423 | scale_factor[i] = 1.0 / sqrtf(((i + 1) * (i + 2))) * inv_std_dev; 424 | } 425 | 426 | // Elevate point into Hd using the rotation matrix E 427 | // see p.30 in [Adams et al. 2011] 428 | float sm = 0; 429 | for (int i = d; i > 0; --i) { 430 | float cf = position_[i - 1] * scale_factor[i - 1]; 431 | elevated[i] = sm - i * cf; 432 | sm += cf; 433 | } 434 | elevated[0] = sm; 435 | 436 | // Find the closest 0-colored simplex through rounding 437 | int sum = 0; 438 | for (int i = 0; i < d1; ++i) { 439 | int rd = static_cast<int>(round(elevated[i] / d1)); 440 | rem0[i] = static_cast<int16_t>(rd * d1); 441 | sum += rd; 442 | } 443 | 444 | // Find the simplex we are in and store it in rank 445 | // (where rank describes what position coordinate i has in the sorted order of the features values) 446 | for (int i = 0; i < d; ++i) { 447 | float di = elevated[i] - rem0[i]; 448 | for (int j = i + 1; j < d1; ++j) { 449 | if (di < elevated[j] - rem0[j]) { 450 | ++rank[i]; 451 | } else { 452 | ++rank[j]; 453 | } 454 | } 455 | } 456 | 457 | // If the point doesn't lie on the plane (sum != 0) bring it back 458 | for (int i = 0; i < d1; ++i) { 459 | rank[i] += sum; 460 | if (rank[i] < 0) { 461 | rank[i] += d + 1; 462 | rem0[i] += d + 1; 463 | } else if (rank[i] > d) { 464 | rank[i] -= d + 1; 465 | rem0[i] -= d + 1; 466 | } 467 | } 468 | 469 | // Compute all offsets 470 | for (int remainder = 0; remainder < d1; ++remainder) { 471 | // all but the last coordinate - it's redundant because they sum to zero 472 | for (int i = 0; i < d; ++i) { 473 | key[i] = rem0[i] + canonical[remainder * (d + 1) + rank[i]]; 474 | } 475 | int h = find_hash(hash_table_, key); 476 | offset[(batch_idx * n_points + point_idx) * (d + 1) + remainder] = h; 477 | } 478 | 479 | // compute the barycentric coordinates 480 | // see p.31 in [Adams et al. 2011] 481 | for (int i = 0; i < d + 2; ++i) { 482 | barycentric_[i] = 0; 483 | } 484 | for (int i = 0; i < d1; ++i) { 485 | float v = (elevated[i] - rem0[i]) / d1; 486 | barycentric_[d - rank[i]] += v; 487 | barycentric_[d - rank[i] + 1] -= v; 488 | } 489 | barycentric_[0] += 1.0 + barycentric_[d + 1]; 490 | 491 | delete[] canonical; 492 | delete[] scale_factor; 493 | delete[] elevated; 494 | delete[] rem0; 495 | delete[] rank; 496 | delete[] key; 497 | } 498 | } 499 | 500 | 
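// The kernels below implement the splat-conv-slice pipeline on the lattice built by
// the init kernels:
//   splat - scatter each point's features onto its d + 1 enclosing lattice vertices,
//           weighted by the barycentric coordinates;
//   conv  - filter each lattice vertex over its neighbor offsets with learned weights
//           (optionally lowered to a cuBLAS matmul via img2col);
//   slice - gather filtered vertex values back to output points with the same
//           barycentric interpolation.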
//----------------------------------------------------------------------------------------//
501 | //-------------------------------------- Operation ---------------------------------------//
502 | //----------------------------------------------------------------------------------------//
503 | 
504 | // inputs:  features    (b, df, n_points)
505 | //          offset      (b, n_points, d + 1)
506 | //          barycentric (b, n_points, d + 2)
507 | // output:  lattices    (b, df, n_filled)
508 | __global__ void splat_kernel(int d, int df, int n_points, int n_filled,
509 |                              const float *__restrict__ features,
510 |                              const int *__restrict__ offset,
511 |                              const float *__restrict__ barycentric,
512 |                              float *__restrict__ lattices) {
513 |     int batch_idx = blockIdx.x;
514 | 
515 |     for (int i = threadIdx.y; i < df; i += blockDim.y) {
516 |         for (int j = threadIdx.x; j < n_points; j += blockDim.x) {
517 |             float feature_ = features[(batch_idx * df + i) * n_points + j];
518 | 
519 |             for (int k = 0; k < (d + 1); ++k) {  // d + 1 enclosing simplex vertices
520 |                 int offset_ = offset[(batch_idx * n_points + j) * (d + 1) + k];
521 |                 float barycentric_ = barycentric[(batch_idx * n_points + j) * (d + 2) + k];
522 |                 atomicAdd(lattices + (batch_idx * df + i) * n_filled + offset_, feature_ * barycentric_);
523 |             }
524 |         }
525 |     }
526 | }
527 | 
528 | // inputs:  lattices_in    (b, df_in, n_filled)
529 | //          conv_weights   (df_out, df_in, n_neighbors)
530 | //          conv_neighbors (b, n_neighbors, n_filled)
531 | // output:  lattices_out   (b, df_out, n_filled)
532 | __global__ void conv_kernel(int df_in, int df_out, int n_filled, int n_neighbors,
533 |                             const float *__restrict__ lattices_in,
534 |                             const float *__restrict__ conv_weights,
535 |                             const int *__restrict__ conv_neighbors,
536 |                             float *__restrict__ lattices_out) {
537 |     int batch_idx = blockIdx.x;
538 | 
539 |     for (int i = threadIdx.y; i < df_out; i += blockDim.y) {
540 |         for (int j = threadIdx.x; j < n_filled; j += blockDim.x) {
541 | 
542 |             float sum = 0;
543 |             for (int k = 0; k < df_in; ++k) {
544 |                 const float *conv_weights_ = conv_weights + (i * df_in + k) * n_neighbors;
545 |                 const float *lattices_in_ = lattices_in + (batch_idx * df_in + k) * n_filled;
546 |                 for (int l = 0; l < n_neighbors; ++l) {
547 |                     int h = conv_neighbors[(batch_idx * n_neighbors + l) * n_filled + j];
548 |                     if (h == -1) { continue; }  // -1 marks a missing neighbor
549 |                     sum += conv_weights_[l] * lattices_in_[h];
550 |                 }
551 |             }
552 |             lattices_out[(batch_idx * df_out + i) * n_filled + j] = sum;
553 |         }
554 |     }
555 | }
556 | 
557 | // inputs:  lattices       (b, df_in, n_filled)
558 | //          conv_neighbors (b, n_neighbors, n_filled)
559 | // output:  col            (df_in * n_neighbors, n_filled)
560 | __global__ void img2col_kernel(int b, int df_in, int n_filled, int n_neighbors,
561 |                                const float *__restrict__ lattices,
562 |                                const int *__restrict__ conv_neighbors,
563 |                                float *__restrict__ col) {
564 |     int df_in_index = blockIdx.x;  // here b is the batch index, not the batch size
565 |     const float *lattices_ = lattices + (b * df_in + df_in_index) * n_filled;
566 | 
567 |     for (int i = threadIdx.x; i < n_filled; i += blockDim.x) {
568 |         for (int j = 0; j < n_neighbors; ++j) {
569 |             int h = conv_neighbors[(b * n_neighbors + j) * n_filled + i];
570 |             if (h == -1) { continue; }
571 |             col[(df_in_index * n_neighbors + j) * n_filled + i] = lattices_[h];
572 |         }
573 |     }
574 | }
575 | 
576 | // inputs:  lattices    (b, df, n_filled)
577 | //          offset      (b, n_points, d + 1)
578 | //          barycentric (b, n_points, d + 2)
579 | // output:  features    (b, df, n_points)
580 | __global__ void slice_kernel(int d, int df, int n_points, int n_filled,
581 |                              const float *__restrict__ lattices,
582 |                              const int *__restrict__ offset,
583 |                              const float *__restrict__ barycentric,
584 |                              float *__restrict__ features) {
585 |     int batch_idx = blockIdx.x;
586 | 
587 |     for (int i = threadIdx.y; i < df; i += blockDim.y) {
588 |         for (int j = threadIdx.x; j < n_points; j += blockDim.x) {
589 | 
590 |             for (int k = 0; k < (d + 1); ++k) {
591 |                 int offset_ = offset[(batch_idx * n_points + j) * (d + 1) + k];
592 |                 if (offset_ == -1) { continue; }
593 |                 float barycentric_ = barycentric[(batch_idx * n_points + j) * (d + 2) + k];
594 |                 float lattices_ = lattices[(batch_idx * df + i) * n_filled + offset_];
595 |                 atomicAdd(features + (batch_idx * df + i) * n_points + j, lattices_ * barycentric_);
596 |             }
597 |         }
598 |     }
599 | }
600 | 
601 | // inputs:  grad_out    (b, df, n_filled)
602 | //          offset      (b, n_points, d + 1)
603 | //          barycentric (b, n_points, d + 2)
604 | // output:  grad_in     (b, df, n_points)
605 | __global__ void splat_grad_kernel(int d, int df, int n_points, int n_filled,
606 |                                   const float *__restrict__ grad_out,
607 |                                   const int *__restrict__ offset,
608 |                                   const float *__restrict__ barycentric,
609 |                                   float *__restrict__ grad_in) {
610 |     int batch_idx = blockIdx.x;
611 | 
612 |     for (int i = threadIdx.y; i < df; i += blockDim.y) {
613 |         for (int j = threadIdx.x; j < n_points; j += blockDim.x) {
614 | 
615 |             float sum = 0;
616 |             for (int k = 0; k < (d + 1); ++k) {
617 |                 int offset_ = offset[(batch_idx * n_points + j) * (d + 1) + k];
618 |                 float barycentric_ = barycentric[(batch_idx * n_points + j) * (d + 2) + k];
619 |                 float grad_out_ = grad_out[(batch_idx * df + i) * n_filled + offset_];
620 |                 sum += barycentric_ * grad_out_;
621 |             }
622 |             grad_in[(batch_idx * df + i) * n_points + j] = sum;
623 |         }
624 |     }
625 | }
626 | 
627 | // inputs:  grad_out       (b, df_out, n_filled)
628 | //          conv_weights   (df_out, df_in, n_neighbors)
629 | //          conv_neighbors (b, n_neighbors, n_filled)
630 | // output:  grad_in        (b, df_in, n_filled)
631 | __global__ void conv_grad_kernel(int df_in, int df_out, int n_filled, int n_neighbors,
632 |                                  const float *__restrict__ grad_out,
633 |                                  const float *__restrict__ conv_weights,
634 |                                  const int *__restrict__ conv_neighbors,
635 |                                  float *__restrict__ grad_in) {
636 |     int batch_idx = blockIdx.x;
637 | 
638 |     for (int i = threadIdx.y; i < df_out; i += blockDim.y) {
639 |         for (int j = threadIdx.x; j < n_filled; j += blockDim.x) {
640 |             float grad_out_ = grad_out[(batch_idx * df_out + i) * n_filled + j];
641 | 
642 |             for (int k = 0; k < df_in; ++k) {
643 |                 const float *conv_weights_ = conv_weights + (i * df_in + k) * n_neighbors;
644 |                 float *grad_in_ = grad_in + (batch_idx * df_in + k) * n_filled;
645 |                 for (int l = 0; l < n_neighbors; ++l) {
646 |                     int h = conv_neighbors[(batch_idx * n_neighbors + l) * n_filled + j];
647 |                     if (h == -1) { continue; }
648 |                     atomicAdd(grad_in_ + h, conv_weights_[l] * grad_out_);
649 |                 }
650 |             }
651 |         }
652 |     }
653 | }
654 | 
655 | // inputs:  grad_out       (b, df_out, n_filled)
656 | //          conv_neighbors (b, n_neighbors, n_filled)
657 | // output:  col_grad       (df_out * n_neighbors, n_filled)
658 | __global__ void img2col_grad_kernel(int b, int df_out, int n_filled, int n_neighbors,
659 |                                     const float *__restrict__ grad_out,
660 |                                     const int *__restrict__ conv_neighbors,
661 |                                     float *__restrict__ col_grad) {
662 |     int df_out_index = blockIdx.x;
663 |     const float *grad_out_ = grad_out + (b * df_out + df_out_index) * n_filled;
664 | 
665 |     for (int i = threadIdx.x; i < n_filled; i += blockDim.x) {
666 |         for (int j = 0; j < n_neighbors; ++j) {
667 |             int h = conv_neighbors[(b * n_neighbors + j) * n_filled + i];
668 |             if (h == -1) { continue; }
669 |             col_grad[(df_out_index * n_neighbors + j) * n_filled + i] = grad_out_[h];
670 |         }
671 |     }
672 | }
673 | 
674 | // inputs:  grad_out    (b, df, n_points)
675 | //          offset      (b, n_points, d + 1)
676 | //          barycentric (b, n_points, d + 2)
677 | // output:  grad_in     (b, df, n_filled)
678 | __global__ void slice_grad_kernel(int d, int df, int n_points, int n_filled,
679 |                                   const float *__restrict__ grad_out,
680 |                                   const int *__restrict__ offset,
681 |                                   const float *__restrict__ barycentric,
682 |                                   float *__restrict__ grad_in) {
683 |     int batch_idx = blockIdx.x;
684 | 
685 |     for (int i = threadIdx.y; i < df; i += blockDim.y) {
686 |         for (int j = threadIdx.x; j < n_points; j += blockDim.x) {
687 | 
688 |             for (int k = 0; k < (d + 1); ++k) {
689 |                 int offset_ = offset[(batch_idx * n_points + j) * (d + 1) + k];
690 |                 if (offset_ == -1) { continue; }
691 |                 float barycentric_ = barycentric[(batch_idx * n_points + j) * (d + 2) + k];
692 |                 float grad_out_ = grad_out[(batch_idx * df + i) * n_points + j];
693 |                 atomicAdd(grad_in + (batch_idx * df + i) * n_filled + offset_, grad_out_ * barycentric_);
694 |             }
695 |         }
696 |     }
697 | }
698 | 
699 | // inputs:  grad_out       (b, df_out, n_filled)
700 | //          lattices_in    (b, df_in, n_filled)
701 | //          conv_neighbors (b, n_neighbors, n_filled)
702 | // output:  grad_weights   (df_out, df_in, n_neighbors)
703 | __global__ void weights_grad_kernel(int df_in, int df_out, int n_filled, int n_neighbors,
704 |                                     const float *__restrict__ grad_out,
705 |                                     const float *__restrict__ lattices_in,
706 |                                     const int *__restrict__ conv_neighbors,
707 |                                     float *__restrict__ grad_weights) {
708 |     int batch_idx = blockIdx.x;
709 | 
710 |     for (int i = threadIdx.y; i < df_out; i += blockDim.y) {
711 |         for (int j = threadIdx.x; j < n_filled; j += blockDim.x) {
712 |             float grad_out_ = grad_out[(batch_idx * df_out + i) * n_filled + j];
713 | 
714 |             for (int k = 0; k < df_in; ++k) {
715 |                 const float *lattices_in_ = lattices_in + (batch_idx * df_in + k) * n_filled;
716 |                 for (int l = 0; l < n_neighbors; ++l) {
717 |                     int h = conv_neighbors[(batch_idx * n_neighbors + l) * n_filled + j];
718 |                     if (h == -1) { continue; }
719 |                     atomicAdd(grad_weights + (i * df_in + k) * n_neighbors + l, lattices_in_[h] * grad_out_);
720 |                 }
721 |             }
722 |         }
723 |     }
724 | }
725 | 
726 | 
727 | //----------------------------------------------------------------------------------------//
728 | //-------------------------------------- Interface ---------------------------------------//
729 | //----------------------------------------------------------------------------------------//
730 | 
731 | int compute(
732 |     int b, int n_points_in, int n_points_out, int df_in, int df_out, int d, int n, bool skip_conv, bool keep_position,
733 |     const float *position_in,
734 |     const float *position_out,
735 |     const float *features_in,
736 |     const float *norm_features,
737 |     const float *weights,
738 |     float *features_out,
739 |     float *&lattices_in,
740 |     int *offset_in,
741 |     int *offset_out,
742 |     float *barycentric_in,
743 |     float *barycentric_out,
744 |     int *&conv_neighbors,
745 |     float *norm) {
746 |     int d1 = d + 1;
747 |     int n_neighbors = static_cast<int>(pow(n + 1, d1)) - static_cast<int>(pow(n, d1));
748 |     int n_filled = 0;
749 |     size_t n_filled_tmp = 0;
750 |     HashTable *hash_tables, *hash_tables_host;
751 |     int16_t *keys_in, *neighbors, *neighbors_host;
752 |     float *gauss_weights, *gauss_weights_host, *lattices_out;
753 |     float *norm_lattices_in, *norm_lattices_out;
754 | 
755 |     // initialize
756 |     neighbors_host = new int16_t[n_neighbors * d1];
757 |     gauss_weights_host = new float[n_neighbors];
758 |     compute_neighbors_and_gauss(d, n, neighbors_host, gauss_weights_host);
759 | 
760 |     init_hash_table(b, d, n_points_in * d1 * 10, &hash_tables_host, &hash_tables);
761 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&keys_in), b * n_points_in * d1 * d * sizeof(int16_t)));
762 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&neighbors), n_neighbors * d1 * sizeof(int16_t)));
763 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&gauss_weights), n_neighbors * sizeof(float)));
764 |     CUDA_CHECK_ERROR(cudaMemcpy(neighbors, neighbors_host, n_neighbors * d1 * sizeof(int16_t), cudaMemcpyHostToDevice));
765 |     CUDA_CHECK_ERROR(cudaMemcpy(gauss_weights, gauss_weights_host, n_neighbors * sizeof(float), cudaMemcpyHostToDevice));
766 | 
767 |     init1_kernel<<<b, opt_n_threads(n_points_in)>>>(d, n_points_in, position_in, keys_in, barycentric_in);
768 |     CUDA_CHECK_KERNEL_ERROR();
769 |     init2_kernel<<<b, 1>>>(d, n_points_in, hash_tables, keys_in, offset_in);
770 |     CUDA_CHECK_KERNEL_ERROR();
771 |     for (int i = 0; i < b; ++i) {
772 |         CUDA_CHECK_ERROR(cudaMemcpy(&n_filled_tmp, hash_tables_host[i].filled, sizeof(size_t), cudaMemcpyDeviceToHost));
773 |         if (static_cast<int>(n_filled_tmp) > n_filled) {
774 |             n_filled = static_cast<int>(n_filled_tmp);
775 |         }
776 |     }
777 | 
778 |     CUDA_CHECK_ERROR(cudaMallocManaged((void **) (&conv_neighbors), b * n_filled * n_neighbors * sizeof(int)));
779 |     CUDA_CHECK_ERROR(cudaMallocManaged((void **) (&lattices_in), b * df_in * n_filled * sizeof(float)));
780 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&lattices_out), b * df_out * n_filled * sizeof(float)));
781 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&norm_lattices_in), b * 1 * n_filled * sizeof(float)));
782 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&norm_lattices_out), b * 1 * n_filled * sizeof(float)));
783 | 
784 |     init3_kernel<<<b, 1>>>(d, n_neighbors, n_filled, hash_tables, neighbors, conv_neighbors);
785 |     CUDA_CHECK_KERNEL_ERROR();
786 |     if (keep_position) {
787 |         offset_out = offset_in;
788 |         barycentric_out = barycentric_in;
789 |     } else {
790 |         init4_kernel<<<b, opt_n_threads(n_points_out)>>>
791 |             (d, n_points_out, position_out, hash_tables, offset_out, barycentric_out);
792 |         CUDA_CHECK_KERNEL_ERROR();
793 |     }
794 | 
795 |     // splat-conv-slice
796 |     splat_kernel<<<b, opt_block_config(n_points_in, df_in)>>>
797 |         (d, df_in, n_points_in, n_filled, features_in, offset_in, barycentric_in, lattices_in);
798 |     CUDA_CHECK_KERNEL_ERROR();
799 |     if (skip_conv) {
800 |         CUDA_CHECK_ERROR(
801 |             cudaMemcpy(lattices_out, lattices_in, b * df_out * n_filled * sizeof(float), cudaMemcpyDeviceToDevice));
802 |     } else {
803 |         // the original convolution operation is too slow
804 |         // conv_kernel<<<b, opt_block_config(n_filled, df_out)>>>
805 |         //     (df_in, df_out, n_filled, n_neighbors, lattices_in, weights, conv_neighbors, lattices_out);
806 |         // CUDA_CHECK_KERNEL_ERROR();
807 | 
808 |         // conv -> matmul (im2col + GEMM)
809 |         float alpha = 1.0;
810 |         float beta = 0;
811 |         float *col;
812 |         cublasHandle_t cublas_handle;
813 |         CUBLAS_CHECK_ERROR(cublasCreate(&cublas_handle));
814 |         CUDA_CHECK_ERROR(cudaMalloc((void **) (&col), df_in * n_neighbors * n_filled * sizeof(float)));
815 |         for (int i = 0; i < b; ++i) {
816 |             CUDA_CHECK_ERROR(cudaMemset(col, 0, df_in * n_neighbors * n_filled * sizeof(float)));
817 |             img2col_kernel<<<df_in, opt_n_threads(n_filled)>>>
818 |                 (i, df_in, n_filled, n_neighbors, lattices_in, conv_neighbors, col);
819 |             CUDA_CHECK_KERNEL_ERROR();
820 |             CUBLAS_CHECK_ERROR(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n_filled, df_out, df_in * n_neighbors,
821 |                                            &alpha, col, n_filled, weights, df_in * n_neighbors,
822 |                                            &beta, lattices_out + i * df_out * n_filled, n_filled));
823 |         }
824 |         CUDA_CHECK_ERROR(cudaDeviceSynchronize());
825 |         CUDA_CHECK_ERROR(cudaFree(col));
826 |         CUBLAS_CHECK_ERROR(cublasDestroy(cublas_handle));
827 |     }
828 |     slice_kernel<<<b, opt_block_config(n_points_out, df_out)>>>
829 |         (d, df_out, n_points_out, n_filled, lattices_out, offset_out, barycentric_out, features_out);
830 |     CUDA_CHECK_KERNEL_ERROR();
831 | 
832 |     // norm: splat-conv-slice an all-ones feature through the Gaussian filter weights
833 |     splat_kernel<<<b, opt_block_config(n_points_in, 1)>>>
834 |         (d, 1, n_points_in, n_filled, norm_features, offset_in, barycentric_in, norm_lattices_in);
835 |     CUDA_CHECK_KERNEL_ERROR();
836 |     conv_kernel<<<b, opt_block_config(n_filled, 1)>>>
837 |         (1, 1, n_filled, n_neighbors, norm_lattices_in, gauss_weights, conv_neighbors, norm_lattices_out);
838 |     CUDA_CHECK_KERNEL_ERROR();
839 |     slice_kernel<<<b, opt_block_config(n_points_out, 1)>>>
840 |         (d, 1, n_points_out, n_filled, norm_lattices_out, offset_out, barycentric_out, norm);
841 |     CUDA_CHECK_KERNEL_ERROR();
842 | 
843 |     // deinitialize
844 |     delete[] neighbors_host; delete[] gauss_weights_host;  // note: "delete[] a, b;" would only free a
845 |     deinit_hash_table(b, hash_tables_host, hash_tables);
846 |     CUDA_CHECK_ERROR(cudaFree(keys_in));
847 |     CUDA_CHECK_ERROR(cudaFree(neighbors));
848 |     CUDA_CHECK_ERROR(cudaFree(gauss_weights));
849 |     CUDA_CHECK_ERROR(cudaFree(lattices_out));
850 |     CUDA_CHECK_ERROR(cudaFree(norm_lattices_in));
851 |     CUDA_CHECK_ERROR(cudaFree(norm_lattices_out));
852 | 
853 |     return n_filled;
854 | }
855 | 
856 | void compute_grad(
857 |     int b, int n_points_in, int n_points_out, int n_filled, int df_in, int df_out, int d, int n, bool skip_conv,
858 |     const float *grad_out,
859 |     const float *weights_transpose,
860 |     const float *lattices_in,
861 |     const int *offset_in,
862 |     const int *offset_out,
863 |     const float *barycentric_in,
864 |     const float *barycentric_out,
865 |     const int *conv_neighbors,
866 |     const float *norm,
867 |     float *grad_in,
868 |     float *grad_weights_transpose) {
869 |     int d1 = d + 1;
870 |     int n_neighbors = static_cast<int>(pow(n + 1, d1)) - static_cast<int>(pow(n, d1));
871 |     float *lattices_grad_out, *lattices_grad_in;
872 | 
873 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&lattices_grad_in), b * df_in * n_filled * sizeof(float)));
874 |     CUDA_CHECK_ERROR(cudaMalloc((void **) (&lattices_grad_out), b * df_out * n_filled * sizeof(float)));
875 | 
876 |     slice_grad_kernel<<<b, opt_block_config(n_points_out, df_out)>>>
877 |         (d, df_out, n_points_out, n_filled, grad_out, offset_out, barycentric_out, lattices_grad_out);
878 |     CUDA_CHECK_KERNEL_ERROR();
879 |     if (skip_conv) {
880 |         CUDA_CHECK_ERROR(
881 |             cudaMemcpy(lattices_grad_in,
882 |                        lattices_grad_out,
883 |                        b * df_in * n_filled * sizeof(float),
884 |                        cudaMemcpyDeviceToDevice));
885 |     } else {
886 |         // the original convolution operation is too slow
887 |         // conv_grad_kernel<<<b, opt_block_config(n_filled, df_out)>>>
888 |         //     (df_in, df_out, n_filled, n_neighbors, lattices_grad_out, weights, conv_neighbors, lattices_grad_in);
889 |         // CUDA_CHECK_KERNEL_ERROR();
890 |         // weights_grad_kernel<<<b, opt_block_config(n_filled, df_out)>>>
891 |         //     (df_in, df_out, n_filled, n_neighbors, lattices_grad_out, lattices_in, conv_neighbors, grad_weights);
892 |         // CUDA_CHECK_KERNEL_ERROR();
893 | 
894 |         // conv -> matmul (im2col + GEMM)
895 |         float alpha = 1.0;
896 |         float beta = 0;
897 |         float *col_grad;
898 |         cublasHandle_t cublas_handle;
899 |         CUBLAS_CHECK_ERROR(cublasCreate(&cublas_handle));
900 |         CUDA_CHECK_ERROR(cudaMalloc((void **) (&col_grad), df_out * n_neighbors * n_filled * sizeof(float)));
901 |         CUDA_CHECK_KERNEL_ERROR();
902 |         for (int i = 0; i < b; ++i) {
903 |             CUDA_CHECK_ERROR(cudaMemset(col_grad, 0, df_out * n_neighbors * n_filled * sizeof(float)));
904 |             img2col_grad_kernel<<<df_out, opt_n_threads(n_filled)>>>
905 |                 (i, df_out, n_filled, n_neighbors, lattices_grad_out, conv_neighbors, col_grad);
906 |             CUDA_CHECK_KERNEL_ERROR();
907 |             CUBLAS_CHECK_ERROR(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n_filled, df_in, df_out * n_neighbors,
908 |                                            &alpha, col_grad, n_filled, weights_transpose, df_out * n_neighbors,
909 |                                            &beta, lattices_grad_in + i * df_in * n_filled, n_filled));
910 | 
911 |             CUBLAS_CHECK_ERROR(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, df_in, df_out * n_neighbors, n_filled,
912 |                                            &alpha, lattices_in + i * df_in * n_filled, n_filled, col_grad, n_filled,
913 |                                            &alpha, grad_weights_transpose, df_in));
914 |         }
915 |         CUDA_CHECK_ERROR(cudaDeviceSynchronize());
916 |         CUDA_CHECK_ERROR(cudaFree(col_grad));
917 |         CUBLAS_CHECK_ERROR(cublasDestroy(cublas_handle));
918 |     }
919 |     splat_grad_kernel<<<b, opt_block_config(n_points_in, df_in)>>>
920 |         (d, df_in, n_points_in, n_filled, lattices_grad_in, offset_in, barycentric_in, grad_in);
921 |     CUDA_CHECK_KERNEL_ERROR();
922 | 
923 |     CUDA_CHECK_ERROR(cudaFree(lattices_grad_in));
924 |     CUDA_CHECK_ERROR(cudaFree(lattices_grad_out));
925 | }
926 | 
--------------------------------------------------------------------------------
/utils/permutohedral_lattice_layer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imp
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | import time
7 | 
8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9 | file, path, description = imp.find_module('libpermutohedral_lattice', [os.path.join(BASE_DIR, 'lib')])
10 | C_ = imp.load_module('C_', file, path, description)
11 | 
12 | 
13 | class PermutohedralLatticeFunction(Function):
14 |     """
15 |     Autograd function wrapping the C_ CUDA extension (splat -> conv -> slice).
16 | 
17 |     :param
18 |         position_in: (b, n_points_in, d)
19 |         position_out: (b, n_points_out, d) or None
20 |         features_in: (b, df_in, n_points_in)
21 |         weights: (df_out, df_in, n_neighbors)
22 |     :return:
23 |         features_out: (b, df_out, n_points_out)
24 |         norm: (b, 1, n_points_out)
25 |     """
26 | 
27 |     @staticmethod
28 |     def forward(ctx, features_in, weights, position_in, position_out,
29 |                 neighbor_size=1, skip_conv=False):
30 |         if position_out is None:
31 |             features_out, lattices_in, offset_in, barycentric_in, conv_neighbors, norm = \
32 |                 C_.permutohedral_lattice_keep_position(position_in, features_in, weights, neighbor_size, skip_conv)
33 |             ctx.save_for_backward(weights, lattices_in, offset_in, barycentric_in, conv_neighbors, norm)
34 |             ctx.keep_position = True
35 | 
36 |         else:
37 |             features_out, lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm = \
38 |                 C_.permutohedral_lattice(position_in, position_out, features_in, weights, neighbor_size, skip_conv)
39 |             ctx.save_for_backward(weights, lattices_in, offset_in, offset_out, barycentric_in, barycentric_out,
40 |                                   conv_neighbors, norm)
41 |             ctx.keep_position = False
42 |         ctx.neighbor_size = neighbor_size
43 |         ctx.skip_conv = skip_conv
44 |         return features_out, norm
45 | 
46 |     @staticmethod
47 |     def backward(ctx, grad_out, *args):
48 |         if ctx.keep_position:
49 |             weights, lattices_in, offset_in, barycentric_in, conv_neighbors, norm = ctx.saved_tensors
50 |             grad_in, grad_weights = \
51 |                 C_.permutohedral_lattice_keep_position_grad(grad_out.contiguous(), weights, lattices_in, offset_in,
52 |                                                             barycentric_in, conv_neighbors, norm, ctx.neighbor_size,
53 |                                                             ctx.skip_conv)
54 |         else:
55 |             weights, lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm = ctx.saved_tensors
56 |             grad_in, grad_weights = \
57 |                 C_.permutohedral_lattice_grad(grad_out.contiguous(), weights, lattices_in, offset_in, offset_out,
58 |                                               barycentric_in, barycentric_out, conv_neighbors, norm, ctx.neighbor_size,
59 |                                               ctx.skip_conv)
60 |         return grad_in, grad_weights, None, None, None, None
61 | 
62 | 
63 | permutohedral_lattice_ = PermutohedralLatticeFunction.apply
64 | 
65 | 
66 | class PermutohedralLattice(nn.Module):
67 | 
68 |     def __init__(self, df_in, df_out, d, pos_lambda, n=1, bias=True, skip_conv=False):
69 |         super(PermutohedralLattice, self).__init__()
70 |         n_neighbors = int(pow((n + 1), (d + 1)) - pow(n, (d + 1)))
71 | 
72 |         self.df_in = df_in
73 |         self.df_out = df_out
74 |         self.d = d
75 |         self.pos_lambda = pos_lambda
76 |         self.n = n
77 |         self.skip_conv = skip_conv
78 |         self.n_neighbors = n_neighbors
79 |         self.weights = nn.Parameter(torch.Tensor(df_out, df_in, n_neighbors))
80 |         if bias:
81 |             self.bias = nn.Parameter(torch.Tensor(1, df_out, 1))
82 |         else:
83 |             self.register_parameter('bias', None)
84 |         self.reset_parameters()
85 | 
86 |     def reset_parameters(self):
87 |         self.weights.data.normal_(0, 0.01)
88 |         if self.bias is not None:
89 |             self.bias.data.fill_(0)
90 |         return
91 | 
92 |     def forward(self, features_in, position_in, position_out=None):
93 |         position_in = position_in * self.pos_lambda
94 |         assert features_in.size(1) == self.df_in
95 |         if position_out is None:
96 |             assert position_in.size(2) == self.d
97 |         else:
98 |             assert position_in.size(2) == position_out.size(2) == self.d
99 |             position_out = position_out * self.pos_lambda
100 | 
101 |         features_out, norm = \
102 |             permutohedral_lattice_(features_in, self.weights, position_in, position_out, self.n, self.skip_conv)
103 |         features_out = features_out / norm.detach()
104 |         if self.bias is not None:
105 |             features_out = features_out + self.bias
106 |         return features_out
107 | 
108 | 
109 | if __name__ == '__main__':
110 |     # speed test; the module must live on the GPU, since the extension is CUDA-only
111 |     m = PermutohedralLattice(512, 256, 3, 1, n=1).cuda()
112 |     start = time.time()
113 |     for i in range(10):
114 |         init_time = time.time()
115 |         position_in = torch.randn([32, 2048, 3]).cuda()
116 |         position_out = torch.randn([32, 2048, 3]).cuda()
117 |         features_in = torch.randn([32, 512, 2048]).cuda()
118 |         features_in.requires_grad = True
119 |         start = start + time.time() - init_time  # exclude data generation from the timing
120 |         features_out = m(features_in, position_in, position_out)
121 |         loss = features_out.sum()
122 |         loss.backward()
123 |         print(i)
124 |     end = time.time()
125 |     print((end - start) / 10)
126 | 
--------------------------------------------------------------------------------
/utils/permutohedral_lattice_wrapper.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <vector>
3 | #include <functional>
4 | #include <cmath>
5 | #include "cuda_utils.h"
6 | #include "permutohedral_lattice.h"
7 | 
8 | void free_cuda_data_callback(void *p, void *data) {
9 |     CUDA_CHECK_ERROR(cudaFree(data));
10 | }
11 | 
12 | // inputs:
13 | //     position_in_tensor     (b, n_points_in, d)
14 | //     position_out_tensor    (b, n_points_out, d)
15 | //     features_in_tensor     (b, df_in, n_points_in)
16 | //     weights_tensor         (df_out, df_in, n_neighbors)
17 | //     neighbor_size          neighbor_size
18 | //     skip_conv              true: splat-conv-slice
19 | //                            false: splat-slice
20 | //
21 | // outputs:
22 | //     features_out_tensor    (b, df_out, n_points_out)
23 | //     lattices_in_tensor     (b, df_in, n_filled)
24 | //     offset_in_tensor       (b, n_points_in, d + 1)
25 | //     offset_out_tensor      (b, n_points_out, d + 1)
26 | //     barycentric_in_tensor  (b, n_points_in, d + 2)
27 | //     barycentric_out_tensor (b, n_points_out, d + 2)
28 | //     conv_neighbors_tensor  (b, n_neighbors, n_filled)
29 | //     norm_tensor            (b, 1, n_points_out)
30 | std::vector<at::Tensor> permutohedral_lattice(at::Tensor position_in_tensor,
31 |                                               at::Tensor position_out_tensor,
32 |                                               at::Tensor features_in_tensor,
33 |                                               at::Tensor weights_tensor,
34 |                                               int neighbor_size,
35 |                                               bool skip_conv) {
36 |     CHECK_INPUT(position_in_tensor);
37 |     CHECK_INPUT_TYPE(position_in_tensor, at::ScalarType::Float);
38 |     CHECK_INPUT(position_out_tensor);
39 |     CHECK_INPUT_TYPE(position_out_tensor, at::ScalarType::Float);
40 |     CHECK_INPUT(features_in_tensor);
41 |     CHECK_INPUT_TYPE(features_in_tensor, at::ScalarType::Float);
42 |     CHECK_INPUT(weights_tensor);
43 |     CHECK_INPUT_TYPE(weights_tensor, at::ScalarType::Float);
44 | 
45 |     int b = position_in_tensor.size(0);
46 |     int n_points_in = position_in_tensor.size(1);
47 |     int n_points_out = position_out_tensor.size(1);
48 |     int d = position_in_tensor.size(2);
49 |     int df_in = weights_tensor.size(1);
50 |     int df_out = weights_tensor.size(0);
51 |     int n_neighbors = static_cast<int>(pow(neighbor_size + 1, d + 1)) - static_cast<int>(pow(neighbor_size, d + 1));
52 | 
53 |     at::Tensor norm_features_tensor = torch::CUDA(at::kFloat).ones({b, 1, n_points_in});
54 |     at::Tensor features_out_tensor = torch::CUDA(at::kFloat).zeros({b, df_out, n_points_out});
55 |     at::Tensor offset_in_tensor = torch::CUDA(at::kInt).zeros({b, n_points_in, d + 1});
56 |     at::Tensor offset_out_tensor = torch::CUDA(at::kInt).zeros({b, n_points_out, d + 1});
57 |     at::Tensor barycentric_in_tensor = torch::CUDA(at::kFloat).zeros({b, n_points_in, d + 2});
58 |     at::Tensor barycentric_out_tensor = torch::CUDA(at::kFloat).zeros({b, n_points_out, d + 2});
59 |     at::Tensor norm_tensor = torch::CUDA(at::kFloat).zeros({b, 1, n_points_out});
60 | 
61 |     const float *position_in = position_in_tensor.data<float>();
62 |     const float *position_out = position_out_tensor.data<float>();
63 |     const float *features_in = features_in_tensor.data<float>();
64 |     const float *norm_features = norm_features_tensor.data<float>();
65 |     const float *weights = weights_tensor.data<float>();
66 |     float *features_out = features_out_tensor.data<float>();
67 |     int *offset_in = offset_in_tensor.data<int>();
68 |     int *offset_out = offset_out_tensor.data<int>();
69 |     float *barycentric_in = barycentric_in_tensor.data<float>();
70 |     float *barycentric_out = barycentric_out_tensor.data<float>();
71 |     float *norm = norm_tensor.data<float>();
72 | 
73 |     float *lattices_in;
74 |     int *conv_neighbors;
75 | 
76 |     int n_filled = compute(b, n_points_in, n_points_out, df_in, df_out, d, neighbor_size, skip_conv, false,
77 |                            position_in, position_out, features_in, norm_features, weights, features_out,
78 |                            lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm);
79 | 
80 |     at::Tensor lattices_in_tensor =
81 |         torch::CUDA(at::kFloat).tensorFromBlob(lattices_in, {b, df_in, n_filled},
82 |                                                std::bind(free_cuda_data_callback, std::placeholders::_1, lattices_in));
83 |     at::Tensor conv_neighbors_tensor =
84 |         torch::CUDA(at::kInt).tensorFromBlob(conv_neighbors, {b, n_neighbors, n_filled},
85 |                                              std::bind(free_cuda_data_callback, std::placeholders::_1, conv_neighbors));
86 | 
87 |     return {features_out_tensor,
88 |             lattices_in_tensor,
89 |             offset_in_tensor,
90 |             offset_out_tensor,
91 |             barycentric_in_tensor,
92 |             barycentric_out_tensor,
93 |             conv_neighbors_tensor,
94 |             norm_tensor};
95 | }
96 | 
97 | // inputs:
98 | //     grad_out_tensor        (b, df_out, n_points_out)
99 | //     weights_tensor         (df_out, df_in, n_neighbors)
100 | //     lattices_in_tensor     (b, df_in, n_filled)
101 | //     offset_in_tensor       (b, n_points_in, d + 1)
102 | //     offset_out_tensor      (b, n_points_out, d + 1)
103 | //     barycentric_in_tensor  (b, n_points_in, d + 2)
104 | //     barycentric_out_tensor (b, n_points_out, d + 2)
105 | //     conv_neighbors_tensor  (b, n_neighbors, n_filled)
106 | //     norm_tensor            (b, 1, n_points_out)
107 | //     neighbor_size          neighbor_size
108 | // outputs:
109 | //     grad_in_tensor         (b, df_in, n_points_in)
110 | //     grad_weights_tensor    (df_out, df_in, n_neighbors)
111 | std::vector<at::Tensor> permutohedral_lattice_grad(at::Tensor grad_out_tensor,
112 |                                                    at::Tensor weights_tensor,
113 |                                                    at::Tensor lattices_in_tensor,
114 |                                                    at::Tensor offset_in_tensor,
115 |                                                    at::Tensor offset_out_tensor,
116 |                                                    at::Tensor barycentric_in_tensor,
117 |                                                    at::Tensor barycentric_out_tensor,
118 |                                                    at::Tensor conv_neighbors_tensor,
119 |                                                    at::Tensor norm_tensor,
120 |                                                    int neighbor_size,
121 |                                                    bool skip_conv) {
122 |     CHECK_INPUT(grad_out_tensor);
123 |     CHECK_INPUT_TYPE(grad_out_tensor, at::ScalarType::Float);
124 |     CHECK_INPUT(weights_tensor);
125 |     CHECK_INPUT_TYPE(weights_tensor, at::ScalarType::Float);
126 |     CHECK_INPUT(lattices_in_tensor);
127 |     CHECK_INPUT_TYPE(lattices_in_tensor, at::ScalarType::Float);
128 |     CHECK_INPUT(offset_in_tensor);
129 |     CHECK_INPUT_TYPE(offset_in_tensor, at::ScalarType::Int);
130 |     CHECK_INPUT(offset_out_tensor);
131 |     CHECK_INPUT_TYPE(offset_out_tensor, at::ScalarType::Int);
132 |     CHECK_INPUT(barycentric_in_tensor);
133 |     CHECK_INPUT_TYPE(barycentric_in_tensor, at::ScalarType::Float);
134 |     CHECK_INPUT(barycentric_out_tensor);
135 |     CHECK_INPUT_TYPE(barycentric_out_tensor, at::ScalarType::Float);
136 |     CHECK_INPUT(conv_neighbors_tensor);
137 |     CHECK_INPUT_TYPE(conv_neighbors_tensor, at::ScalarType::Int);
138 |     CHECK_INPUT(norm_tensor);
139 |     CHECK_INPUT_TYPE(norm_tensor, at::ScalarType::Float);
140 | 
141 |     int b = grad_out_tensor.size(0);
142 |     int n_points_in = offset_in_tensor.size(1);
143 |     int n_points_out = grad_out_tensor.size(2);
144 |     int d = offset_in_tensor.size(2) - 1;
145 |     int df_in = lattices_in_tensor.size(1);
146 |     int df_out = grad_out_tensor.size(1);
147 |     int n_filled = lattices_in_tensor.size(2);
148 |     int n_neighbors = static_cast<int>(pow(neighbor_size + 1, d + 1)) - static_cast<int>(pow(neighbor_size, d + 1));
149 | 
150 |     at::Tensor weights_transpose_tensor = weights_tensor.transpose(0, 1).contiguous();
151 |     at::Tensor grad_in_tensor = torch::CUDA(at::kFloat).zeros({b, df_in, n_points_in});
152 |     at::Tensor grad_weights_transpose_tensor = torch::CUDA(at::kFloat).zeros({df_out, n_neighbors, df_in});
153 | 
154 |     const float *grad_out = grad_out_tensor.data<float>();
155 |     const float *weights_transpose = weights_transpose_tensor.data<float>();
156 |     const float *lattices_in = lattices_in_tensor.data<float>();
157 |     const int *offset_in = offset_in_tensor.data<int>();
158 |     const int *offset_out = offset_out_tensor.data<int>();
159 |     const float *barycentric_in = barycentric_in_tensor.data<float>();
160 |     const float *barycentric_out = barycentric_out_tensor.data<float>();
161 |     const int *conv_neighbors = conv_neighbors_tensor.data<int>();
162 |     const float *norm = norm_tensor.data<float>();
163 |     float *grad_in = grad_in_tensor.data<float>();
164 |     float *grad_weights_transpose = grad_weights_transpose_tensor.data<float>();
165 | 
166 |     compute_grad(b, n_points_in, n_points_out, n_filled, df_in, df_out, d, neighbor_size, skip_conv,
167 |                  grad_out, weights_transpose,
168 |                  lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm,
169 |                  grad_in, grad_weights_transpose);
170 | 
171 |     return {grad_in_tensor, grad_weights_transpose_tensor.transpose(1, 2).contiguous()};
172 | }
173 | 
174 | // inputs:
175 | //     position_tensor       (b, n_points_in, d)
176 | //     features_in_tensor    (b, df_in, n_points_in)
177 | //     weights_tensor        (df_out, df_in, n_neighbors)
178 | //     neighbor_size         neighbor_size
179 | //     skip_conv             true: splat-conv-slice
180 | //                           false: splat-slice
181 | //
182 | // outputs:
183 | //     features_out_tensor   (b, df_out, n_points_out)
184 | //     lattices_in_tensor    (b, df_in, n_filled)
185 | //     offset_in_tensor      (b, n_points_in, d + 1)
186 | //     barycentric_in_tensor (b, n_points_in, d + 2)
187 | //     conv_neighbors_tensor (b, n_neighbors, n_filled)
188 | //     norm_tensor           (b, 1, n_points_out)
189 | std::vector<at::Tensor> permutohedral_lattice_keep_position(at::Tensor position_in_tensor,
190 |                                                             at::Tensor features_in_tensor,
191 |                                                             at::Tensor weights_tensor,
192 |                                                             int neighbor_size,
193 |                                                             bool skip_conv) {
194 |     CHECK_INPUT(position_in_tensor);
195 |     CHECK_INPUT_TYPE(position_in_tensor, at::ScalarType::Float);
196 |     CHECK_INPUT(features_in_tensor);
197 |     CHECK_INPUT_TYPE(features_in_tensor, at::ScalarType::Float);
198 |     CHECK_INPUT(weights_tensor);
199 |     CHECK_INPUT_TYPE(weights_tensor, at::ScalarType::Float);
200 | 
201 |     int b = position_in_tensor.size(0);
202 |     int n_points_in = position_in_tensor.size(1);
203 |     int n_points_out = n_points_in;
204 |     int d = position_in_tensor.size(2);
205 |     int df_in = weights_tensor.size(1);
206 |     int df_out = weights_tensor.size(0);
207 |     int n_neighbors = static_cast<int>(pow(neighbor_size + 1, d + 1)) - static_cast<int>(pow(neighbor_size, d + 1));
208 | 
209 |     at::Tensor norm_features_tensor = torch::CUDA(at::kFloat).ones({b, 1, n_points_in});
210 |     at::Tensor features_out_tensor = torch::CUDA(at::kFloat).zeros({b, df_out, n_points_out});
211 |     at::Tensor offset_in_tensor = torch::CUDA(at::kInt).zeros({b, n_points_in, d + 1});
212 |     at::Tensor barycentric_in_tensor = torch::CUDA(at::kFloat).zeros({b, n_points_in, d + 2});
213 |     at::Tensor norm_tensor = torch::CUDA(at::kFloat).zeros({b, 1, n_points_out});
214 | 
215 |     const float *position_in = position_in_tensor.data<float>();
216 |     const float *position_out = 0;
217 |     const float *features_in = features_in_tensor.data<float>();
218 |     const float *norm_features = norm_features_tensor.data<float>();
219 |     const float *weights = weights_tensor.data<float>();
220 |     float *features_out = features_out_tensor.data<float>();
221 |     int *offset_in = offset_in_tensor.data<int>();
222 |     int *offset_out = 0;
223 |     float *barycentric_in = barycentric_in_tensor.data<float>();
224 |     float *barycentric_out = 0;
225 |     float *norm = norm_tensor.data<float>();
226 | 
227 |     float *lattices_in;
228 |     int *conv_neighbors;
229 | 
230 |     int n_filled = compute(b, n_points_in, n_points_out, df_in, df_out, d, neighbor_size, skip_conv, true,
231 |                            position_in, position_out, features_in, norm_features, weights, features_out,
232 |                            lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm);
233 | 
234 |     at::Tensor lattices_in_tensor =
235 |         torch::CUDA(at::kFloat).tensorFromBlob(lattices_in, {b, df_in, n_filled},
236 |                                                std::bind(free_cuda_data_callback, std::placeholders::_1, lattices_in));
237 |     at::Tensor conv_neighbors_tensor =
238 |         torch::CUDA(at::kInt).tensorFromBlob(conv_neighbors, {b, n_neighbors, n_filled},
239 |                                              std::bind(free_cuda_data_callback, std::placeholders::_1, conv_neighbors));
240 | 
241 |     return {features_out_tensor,
242 |             lattices_in_tensor,
243 |             offset_in_tensor,
244 |             barycentric_in_tensor,
245 |             conv_neighbors_tensor,
246 |             norm_tensor};
247 | }
248 | 
249 | // inputs:
250 | //     grad_out_tensor       (b, df_out, n_points_out)
251 | //     weights_tensor        (df_out, df_in, n_neighbors)
252 | //     lattices_in_tensor    (b, df_in, n_filled)
253 | //     offset_in_tensor      (b, n_points_in, d + 1)
254 | //     barycentric_in_tensor (b, n_points_in, d + 2)
255 | //     conv_neighbors_tensor (b, n_neighbors, n_filled)
256 | //     norm_tensor           (b, 1, n_points_out)
257 | //     neighbor_size         neighbor_size
258 | // outputs:
259 | //     grad_in_tensor        (b, df_in, n_points_in)
260 | //     grad_weights_tensor   (df_out, df_in, n_neighbors)
261 | std::vector<at::Tensor> permutohedral_lattice_keep_position_grad(at::Tensor grad_out_tensor,
262 |                                                                  at::Tensor weights_tensor,
263 |                                                                  at::Tensor lattices_in_tensor,
264 |                                                                  at::Tensor offset_in_tensor,
265 |                                                                  at::Tensor barycentric_in_tensor,
266 |                                                                  at::Tensor conv_neighbors_tensor,
267 |                                                                  at::Tensor norm_tensor,
268 |                                                                  int neighbor_size,
269 |                                                                  bool skip_conv) {
270 |     CHECK_INPUT(grad_out_tensor);
271 |     CHECK_INPUT_TYPE(grad_out_tensor, at::ScalarType::Float);
272 |     CHECK_INPUT(weights_tensor);
273 |     CHECK_INPUT_TYPE(weights_tensor, at::ScalarType::Float);
274 |     CHECK_INPUT(lattices_in_tensor);
275 |     CHECK_INPUT_TYPE(lattices_in_tensor, at::ScalarType::Float);
276 |     CHECK_INPUT(offset_in_tensor);
277 |     CHECK_INPUT_TYPE(offset_in_tensor, at::ScalarType::Int);
278 |     CHECK_INPUT(barycentric_in_tensor);
279 |     CHECK_INPUT_TYPE(barycentric_in_tensor, at::ScalarType::Float);
280 |     CHECK_INPUT(conv_neighbors_tensor);
281 |     CHECK_INPUT_TYPE(conv_neighbors_tensor, at::ScalarType::Int);
282 |     CHECK_INPUT(norm_tensor);
283 |     CHECK_INPUT_TYPE(norm_tensor, at::ScalarType::Float);
284 | 
285 |     int b = grad_out_tensor.size(0);
286 |     int n_points_in = offset_in_tensor.size(1);
287 |     int n_points_out = n_points_in;
288 |     int d = offset_in_tensor.size(2) - 1;
289 |     int df_in = lattices_in_tensor.size(1);
290 |     int df_out = grad_out_tensor.size(1);
291 |     int n_filled = lattices_in_tensor.size(2);
292 |     int n_neighbors = static_cast<int>(pow(neighbor_size + 1, d + 1)) - static_cast<int>(pow(neighbor_size, d + 1));
293 | 
294 |     at::Tensor weights_transpose_tensor = weights_tensor.transpose(0, 1).contiguous();
295 |     at::Tensor grad_in_tensor = torch::CUDA(at::kFloat).zeros({b, df_in, n_points_in});
296 |     at::Tensor grad_weights_transpose_tensor = torch::CUDA(at::kFloat).zeros({df_out, n_neighbors, df_in});
297 | 
298 |     const float *grad_out = grad_out_tensor.data<float>();
299 |     const float *weights_transpose = weights_transpose_tensor.data<float>();
300 |     const float *lattices_in = lattices_in_tensor.data<float>();
301 |     const int *offset_in = offset_in_tensor.data<int>();
302 |     const int *offset_out = offset_in;
303 |     const float *barycentric_in = barycentric_in_tensor.data<float>();
304 |     const float *barycentric_out = barycentric_in;
305 |     const int *conv_neighbors = conv_neighbors_tensor.data<int>();
306 |     const float *norm = norm_tensor.data<float>();
307 |     float *grad_in = grad_in_tensor.data<float>();
308 |     float *grad_weights_transpose = grad_weights_transpose_tensor.data<float>();
309 | 
310 |     compute_grad(b, n_points_in, n_points_out, n_filled, df_in, df_out, d, neighbor_size, skip_conv,
311 |                  grad_out, weights_transpose,
312 |                  lattices_in, offset_in, offset_out, barycentric_in, barycentric_out, conv_neighbors, norm,
313 |                  grad_in, grad_weights_transpose);
314 | 
315 |     return {grad_in_tensor, grad_weights_transpose_tensor.transpose(1, 2).contiguous()};
316 | }
317 | 
318 | PYBIND11_MODULE(C_, m) {
319 |     m.def("permutohedral_lattice", &permutohedral_lattice, "permutohedral lattice");
320 |     m.def("permutohedral_lattice_keep_position", &permutohedral_lattice_keep_position, "permutohedral lattice");
321 |     m.def("permutohedral_lattice_grad", &permutohedral_lattice_grad, "permutohedral lattice grad");
322 |     m.def("permutohedral_lattice_keep_position_grad", &permutohedral_lattice_keep_position_grad, "permutohedral lattice grad");
323 | }
324 | 
--------------------------------------------------------------------------------
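
For readers of the CUDA kernels above, here is a dense PyTorch sketch of what `splat_kernel`, `conv_kernel` and `slice_kernel` compute. The helper names (`splat_ref`, `conv_ref`, `slice_ref`) are illustrative, not part of the extension; the sketch assumes `offset`, `barycentric` and `conv_neighbors` tensors shaped as in the kernel comments (e.g. produced by the extension itself), with `-1` marking a missing neighbor or offset, as the kernels do.

```python
# Dense reference for the splat / conv / slice CUDA kernels (illustrative only).
import torch

def splat_ref(features, offset, barycentric):
    # features (b, df, n_points), offset (b, n_points, d+1),
    # barycentric (b, n_points, d+2) -> lattices (b, df, n_filled)
    b, df, n_points = features.shape
    d1 = offset.shape[2]
    n_filled = int(offset.max().item()) + 1
    lattices = features.new_zeros(b, df, n_filled)
    for k in range(d1):
        idx = offset[:, :, k].unsqueeze(1).expand(b, df, n_points).long()
        w = barycentric[:, :, k].unsqueeze(1)          # (b, 1, n_points)
        lattices.scatter_add_(2, idx, features * w)    # mirrors the atomicAdd
    return lattices

def conv_ref(lattices_in, conv_weights, conv_neighbors):
    # lattices_in (b, df_in, n_filled), conv_weights (df_out, df_in, n_neighbors),
    # conv_neighbors (b, n_neighbors, n_filled) -> (b, df_out, n_filled)
    b, df_in, n_filled = lattices_in.shape
    df_out, _, n_neighbors = conv_weights.shape
    out = lattices_in.new_zeros(b, df_out, n_filled)
    for l in range(n_neighbors):
        h = conv_neighbors[:, l, :].long()                   # (b, n_filled)
        valid = (h != -1).float().unsqueeze(1)               # mask missing neighbors
        g = torch.gather(lattices_in, 2,
                         h.clamp(min=0).unsqueeze(1).expand(b, df_in, n_filled))
        out += torch.einsum('oi,bif->bof', conv_weights[:, :, l], g * valid)
    return out

def slice_ref(lattices, offset, barycentric):
    # inverse of splat: gather lattice values back to the points
    b, df, n_filled = lattices.shape
    n_points, d1 = offset.shape[1], offset.shape[2]
    features = lattices.new_zeros(b, df, n_points)
    for k in range(d1):
        o = offset[:, :, k].long()
        valid = (o != -1).float().unsqueeze(1)
        g = torch.gather(lattices, 2,
                         o.clamp(min=0).unsqueeze(1).expand(b, df, n_points))
        features += g * valid * barycentric[:, :, k].unsqueeze(1)
    return features
```

The `conv_ref` loop over neighbors also makes clear why the extension replaces `conv_kernel` with im2col + `cublasSgemm`: gathering each neighbor once and folding the channel contraction into a single GEMM is the same computation with far better arithmetic intensity.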
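The backward path exposed above can be smoke-tested against finite differences. `torch.autograd.gradcheck` normally wants float64 inputs, but the extension is float32-only (`CHECK_INPUT_TYPE` enforces `at::kFloat`), so this sketch runs the check in single precision with loosened tolerances; treat a marginal failure as possible numerical noise rather than proof of a gradient bug. It assumes the extension has been built and that the script is run from the repository root.

```python
# Single-precision gradient smoke test for PermutohedralLatticeFunction (a sketch).
import torch
from utils.permutohedral_lattice_layer import permutohedral_lattice_

b, n_points, d, df_in, df_out, n = 2, 64, 3, 4, 5, 1
n_neighbors = (n + 1) ** (d + 1) - n ** (d + 1)

position = torch.randn(b, n_points, d).cuda()
features = torch.randn(b, df_in, n_points).cuda().requires_grad_()
weights = (0.01 * torch.randn(df_out, df_in, n_neighbors)).cuda().requires_grad_()

# position_out=None exercises the keep_position path; [0] selects features_out
# (norm carries no gradient of interest here).
torch.autograd.gradcheck(
    lambda f, w: permutohedral_lattice_(f, w, position, None, n, False)[0],
    (features, weights), eps=1e-2, atol=1e-2, rtol=1e-2)
```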