├── .gitignore ├── GeGnn_standalong.py ├── GnnDist.py ├── configs └── gnndist_test.yaml ├── dataset_ps.py ├── hgraph ├── __init__.py ├── hgraph.py ├── models │ ├── __init__.py │ └── graph_unet.py └── modules │ ├── decide_edge_type.py │ ├── modules.py │ └── resblocks.py ├── img ├── bunny.jpg ├── teaser.png └── teaser.v1.png ├── pretrained └── ours00500.solver.tar ├── readme.md ├── requirements.txt └── utils ├── __init__.py ├── dataset_prepare ├── py_data_generater.py └── simplifiy.py ├── ocnn ├── __init__.py ├── dataset.py ├── models │ ├── __init__.py │ ├── autoencoder.py │ ├── hrnet.py │ ├── lenet.py │ ├── octree_unet.py │ ├── resnet.py │ └── segnet.py ├── modules │ ├── __init__.py │ ├── modules.py │ └── resblocks.py ├── nn │ ├── __init__.py │ ├── octree2col.py │ ├── octree2vox.py │ ├── octree_conv.py │ ├── octree_drop.py │ ├── octree_dwconv.py │ ├── octree_interp.py │ ├── octree_norm.py │ ├── octree_pad.py │ └── octree_pool.py ├── octree │ ├── __init__.py │ ├── octree.py │ ├── points.py │ └── shuffled_key.py └── utils.py └── thsolver ├── __init__.py ├── config.py ├── dataset.py ├── default_settings.py ├── lr_scheduler.py ├── sampler.py ├── solver.py └── tracker.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | .idea 4 | .vscode 5 | __pycache__ 6 | sphere.obj 7 | runs 8 | out 9 | configs/superhyper 10 | logs 11 | 12 | 13 | /utils/dataset_prepare/inputs/ 14 | /utils/dataset_prepare/inputs_/ 15 | /utils/dataset_prepare/outputs/ 16 | /utils/dataset_prepare/simplified_meshes/ 17 | /utils/dataset_prepare/shapenet.13.zip 18 | /utils/dataset_prepare/data_complete/ 19 | /utils/dataset_prepare/original_obj/ 20 | /utils/dataset_prepare/simplified_obj/ 21 | /utils/.ipynb_checkpoints 22 | /data/ 23 | /utils/dataset_prepare/geodesic/ 24 | /logs 25 | -------------------------------------------------------------------------------- /GeGnn_standalong.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import hgraph 4 | import trimesh 5 | 6 | import time 7 | import heat_method 8 | import argparse 9 | 10 | 11 | # a function that reads a triangular mesh, and generates its corresponding graph 12 | def read_mesh(path, to_tensor=True): 13 | # :param path: the path to the mesh 14 | # :param to_tensor: whether to convert the numpy array to torch tensor 15 | # :return: a dict containing the vertices, edges, normals, faces, face_normals, face_areas 16 | 17 | # read the mesh 18 | mesh = trimesh.load(path) 19 | # get the vertices 20 | vertices = mesh.vertices 21 | # get the edges 22 | edges = mesh.edges_unique 23 | edges_reversed = np.concatenate([edges[:, 1:], edges[:, :1]], 1) 24 | edges = np.concatenate([edges, edges_reversed], 0) 25 | edges = np.transpose(edges) 26 | # get the normals 27 | normals = mesh.vertex_normals 28 | # normalize the normal 29 | norm_normals = np.linalg.norm(normals, axis=1) 30 | normals = normals / norm_normals[:, np.newaxis] 31 | 32 | # get the faces 33 | faces = mesh.faces 34 | # get the face normals 35 | face_normals = mesh.face_normals 36 | # get the face areas 37 | face_areas = mesh.area_faces 38 | # convert to tensor, if needed 39 | if to_tensor: 40 | vertices = torch.from_numpy(vertices).float().cuda() 41 | edges = torch.from_numpy(edges).long().cuda() 42 | normals = torch.from_numpy(np.array(normals)).float().cuda() 43 | faces = torch.from_numpy(faces).long().cuda() 44 | face_normals = 
torch.from_numpy(np.array(face_normals)).float().cuda()
 45 |         face_areas = torch.from_numpy(np.array(face_areas)).float().cuda()
 46 | 
 47 |     # generate a dict
 48 |     dic = {
 49 |         "vertices": vertices,
 50 |         "edges": edges,
 51 |         "normals": normals,
 52 |         "faces": faces,
 53 |         "face_normals": face_normals,
 54 |         "face_areas": face_areas,
 55 |     }
 56 | 
 57 |     return dic
 58 | 
 59 | 
 60 | # a wrapper of the pretrained model
 61 | class PretrainedModel(torch.nn.Module):
 62 | 
 63 |     def __init__(self, ckpt_path):
 64 |         super(PretrainedModel, self).__init__()
 65 |         self.model = hgraph.models.graph_unet.GraphUNet(
 66 |             6, 256, None, None).cuda()
 67 |         self.embds = None
 68 |         # load the pretrained model
 69 |         ckpt = torch.load(ckpt_path, map_location=torch.device('cuda'))
 70 |         model_dict = ckpt['model_dict']
 71 |         self.model.load_state_dict(model_dict)
 72 | 
 73 |     def embd_decoder_func(self, i, j, embedding):
 74 |         i = i.long()
 75 |         j = j.long()
 76 |         embd_i = embedding[i].squeeze(-1)
 77 |         embd_j = embedding[j].squeeze(-1)
 78 |         embd = (embd_i - embd_j) ** 2
 79 |         pred = self.model.embedding_decoder_mlp(embd)
 80 |         pred = pred.squeeze(-1)
 81 |         return pred
 82 | 
 83 |     def precompute(self, mesh):
 84 |         with torch.no_grad():
 85 |             # calculate the vertex-wise embeddings
 86 |             # 1. construct the graph tree
 87 |             vertices = mesh['vertices']   # [N, 3]
 88 |             normals = mesh['normals']     # [N, 3]
 89 |             edges = mesh['edges']         # [2, M]
 90 | 
 91 |             tree = hgraph.hgraph.HGraph()
 92 |             tree.build_single_hgraph(
 93 |                 hgraph.hgraph.Data(x=torch.cat([vertices, normals], dim=1), edge_index=edges)
 94 |             )
 95 | 
 96 |             # 2. feed the graph tree into the model & get the vertex-wise embedding
 97 |             self.embds = self.model(
 98 |                 torch.cat([vertices, normals], dim=1),
 99 |                 tree,
100 |                 tree.depth,
101 |                 dist=None,
102 |                 only_embd=True)
103 |             self.embds = self.embds.detach()
104 | 
105 | 
106 |     def forward(self, p_vertices=None, q_vertices=None):
107 |         # given a mesh and two sets of vertices, calculate the geodesic distances between the corresponding pairs
108 |         # :param p_vertices: [N], indices of the vertices in the first set
109 |         # :param q_vertices: [N], indices of the vertices in the second set
110 |         assert self.embds is not None, "Please call precompute() first!"
111 |         with torch.no_grad():
112 |             ans = self.embd_decoder_func(p_vertices, q_vertices, self.embds)
113 |         return ans
114 | 
115 |     def SSAD(self, source: list):
116 |         # given a mesh, calculate the geodesic distances from the source to all other vertices
117 |         assert self.embds is not None, "Please call precompute() first!"
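        # SSAD = single-source, all-destinations: each source index is tiled
        # against all N vertex indices, and the distances are decoded from the
        # precomputed embeddings with one MLP pass per source; no graph
        # traversal happens at query time. The Python loop below could in
        # principle be batched into a single call; it is kept simple here.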
118 | 119 | 120 | with torch.no_grad(): 121 | ret = [] 122 | ss, tt = [], [] 123 | for i in range(len(source)): 124 | s = torch.tensor([source[i]]).repeat(self.embds.shape[0]).cuda() 125 | t = torch.arange(self.embds.shape[0]).cuda() 126 | ss.append(s) 127 | tt.append(t) 128 | ans = self.embd_decoder_func(s, t, self.embds) 129 | ret.append(ans) 130 | 131 | return ret 132 | 133 | 134 | 135 | 136 | 137 | 138 | # a wrapper of pretrained model, so that it can be called directly from the command line 139 | def main(): 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('--mode', type=str, default='SSAD', 142 | help='only SSAD available for now') 143 | parser.add_argument('--test_file', type=str, default=None, help='path to the obj file') 144 | parser.add_argument('--ckpt_path', type=str, default=None, help='path to the checkpoint') 145 | parser.add_argument('--start_pts', type=str, default=None, help='an int is expected.') 146 | parser.add_argument('--output', type=str, default=None, help='path to the output file') 147 | args = parser.parse_args() 148 | 149 | if args.mode == "SSAD": 150 | obj_dic = read_mesh(args.test_file) 151 | # print the vertex and face number 152 | print("Vertex number: ", obj_dic['vertices'].shape[0], "Face number: ", obj_dic['faces'].shape[0]) 153 | start_pts = torch.tensor(int(args.start_pts)).cuda() 154 | 155 | model = PretrainedModel(args.ckpt_path).cuda() 156 | model.precompute(obj_dic) 157 | dist_pred = model.SSAD([start_pts])[0] 158 | 159 | np.save(args.output, dist_pred.detach().cpu().numpy()) 160 | 161 | # save the colored mesh for visualization 162 | # given the vertices, faces of a mesh, save it as obj file 163 | def save_mesh_as_obj(vertices, faces, scalar=None, path="out/our_mesh.obj"): 164 | with open(path, 'w') as f: 165 | f.write('# mesh\n') # header of LittleRender 166 | for v in vertices: 167 | f.write('v ' + str(v[0]) + ' ' + str(v[1]) + ' ' + str(v[2]) + '\n') 168 | for face in faces: 169 | f.write('f ' + str(face[0]+1) + ' ' + str(face[1]+1) + ' ' + str(face[2]+1) + '\n') 170 | if scalar is not None: 171 | # normalize the scalar to [0, 1] 172 | scalar = (scalar - np.min(scalar)) / (np.max(scalar) - np.min(scalar)) 173 | for c in scalar: 174 | f.write('c ' + str(c) + ' ' + str(c) + ' ' + str(c) + '\n') 175 | 176 | print("Saved mesh as obj file:", path, end="") 177 | if scalar is not None: 178 | print(" (with color) ") 179 | else: 180 | print(" (without color)") 181 | 182 | save_mesh_as_obj(obj_dic['vertices'].detach().cpu().numpy(), 183 | obj_dic['faces'].detach().cpu().numpy(), 184 | dist_pred.detach().cpu().numpy()) 185 | 186 | 187 | else: 188 | print("Invalid mode! 
(" + args.mode + ")") 189 | 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | 195 | ################################### 196 | # visualization via polyscope starts 197 | # comment out the following lines if you are using ssh 198 | ################################### 199 | import polyscope as ps 200 | import numpy as np 201 | import trimesh 202 | 203 | # load mesh 204 | mesh = trimesh.load_mesh("out/our_mesh.obj", process=False) 205 | vertices = mesh.vertices 206 | faces = mesh.faces 207 | 208 | # load numpy array 209 | colors = np.load("out/ssad_ours.npy") 210 | print(colors.shape) 211 | 212 | # Initialize polyscope 213 | ps.init() 214 | ps_cloud = ps.register_point_cloud("my mesh", vertices) 215 | ps_cloud.add_scalar_quantity("geo_distance", colors, enabled=True) 216 | ps.show() 217 | ################################### 218 | # visualization via polyscope ends 219 | ################################### 220 | -------------------------------------------------------------------------------- /GnnDist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.thsolver import default_settings 3 | # initialize global settings 4 | default_settings._init() 5 | from utils.thsolver.config import parse_args 6 | FLAGS = parse_args() 7 | default_settings.set_global_values(FLAGS) 8 | 9 | 10 | from utils import thsolver 11 | import hgraph 12 | 13 | from dataset_ps import get_dataset 14 | 15 | 16 | 17 | def get_parameter_number(model): 18 | """print the number of parameters in a model on terminal """ 19 | total_num = sum(p.numel() for p in model.parameters()) 20 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | print(f"\nTotal Parameters: {total_num}, trainable: {trainable_num}") 22 | return {'Total': total_num, 'Trainable': trainable_num} 23 | 24 | 25 | class GnnDistSolver(thsolver.Solver): 26 | 27 | def get_model(self, flags): 28 | 29 | if flags.name.lower() == 'unet': 30 | model = hgraph.models.graph_unet.GraphUNet( 31 | flags.channel, flags.nout, flags.interp, flags.nempty) 32 | 33 | 34 | # model = ocnn.models.octree_unet.OctreeUnet(flags.channel, flags.nout, flags.interp, flags.nempty) 35 | else: 36 | raise ValueError 37 | 38 | 39 | # overall num parameters 40 | get_parameter_number(model) 41 | return model 42 | 43 | def get_dataset(self, flags): 44 | return get_dataset(flags) 45 | 46 | 47 | def model_forward(self, batch): 48 | """equivalent to `self.get_embd` + `self.embd_decoder_func` """ 49 | 50 | data = batch["feature"].cuda() 51 | hgraph = batch['hgraph'].cuda() 52 | dist = batch['dist'].cuda() 53 | 54 | pred = self.model(data, hgraph, hgraph.depth, dist) 55 | return pred 56 | 57 | def get_embd(self, batch): 58 | """only used in visualization!""" 59 | data = batch["feature"].cuda() 60 | hgraph = batch['hgraph'].cuda() 61 | dist = batch['dist'].cuda() 62 | 63 | embedding = self.model(data, hgraph, hgraph.depth, dist, only_embd=True) 64 | return embedding 65 | 66 | def embd_decoder_func(self, i, j, embedding): 67 | """only used in visualization!""" 68 | i = i.long() 69 | j = j.long() 70 | embd_i = embedding[i].squeeze(-1) 71 | embd_j = embedding[j].squeeze(-1) 72 | embd = (embd_i - embd_j) ** 2 73 | pred = self.model.embedding_decoder_mlp(embd) 74 | pred = pred.squeeze(-1) 75 | return pred 76 | 77 | def train_step(self, batch): 78 | pred = self.model_forward(batch) 79 | loss = self.loss_function(batch, pred) 80 | return {'train/loss': loss} 81 | 82 | def test_step(self, batch): 83 | pred = 
self.model_forward(batch)
 84 |         loss = self.loss_function(batch, pred)
 85 |         return {'test/loss': loss}
 86 | 
 87 |     def loss_function(self, batch, pred):
 88 |         dist = batch['dist'].cuda()
 89 |         gt = dist[:, 2]
 90 | 
 91 |         # there are many kinds of losses that may apply:
 92 | 
 93 |         # option 1: Mean Absolute Error, MAE
 94 |         #loss = torch.abs(pred - gt).mean()
 95 | 
 96 |         # option 2: relative MAE
 97 |         loss = (torch.abs(pred - gt) / (gt + 1e-3)).mean()
 98 | 
 99 |         # option 3: Mean Squared Error, MSE
100 |         #loss = torch.square(pred - gt).mean()
101 | 
102 |         # option 4: relative MSE
103 |         #loss = torch.square((pred - gt) / (gt + 1e-3)).mean()
104 | 
105 |         # option 5: Root Mean Squared Error, RMSE
106 |         #loss = torch.sqrt(torch.square(pred - gt).mean())
107 | 
108 |         # clamp the loss to suppress outliers
109 |         loss = torch.clamp(loss, -10, 10)
110 | 
111 |         return loss
112 | 
113 | 
114 | #def visualization(ret):
115 |     # open the render to visualize the result
116 |     # print("Establishing WebSocket and Visualization!")
117 |     # asyncio.run(interactive.main(ret))
118 | 
119 | 
120 | 
121 | if __name__ == "__main__":
122 |     ret = GnnDistSolver.main()
123 |     #visualization(ret)
124 | 
--------------------------------------------------------------------------------
/configs/gnndist_test.yaml:
--------------------------------------------------------------------------------
  1 | SOLVER:
  2 |   gpu: 0,
  3 |   run: train
  4 | 
  5 |   logdir: logs/my_test
  6 |   max_epoch: 500
  7 |   test_every_epoch: 20
  8 |   log_per_iter: 10
  9 |   ckpt_num: 5
 10 |   dist_url: tcp://localhost:10266
 11 | 
 12 |   # optimizer
 13 |   type: adamw
 14 |   weight_decay: 0.01  # default value of adamw
 15 |   lr: 0.00025
 16 | 
 17 |   # learning rate
 18 |   lr_type: poly
 19 |   lr_power: 0.9
 20 | 
 21 | DATA:
 22 |   train:
 23 |     # octree building
 24 |     depth: 6
 25 |     full_depth: 2  # the octree layers with a depth smaller than `full_depth` are forced to be full
 26 | 
 27 |     # data augmentations
 28 |     distort: False
 29 | 
 30 |     # data loading
 31 |     location: ./
 32 |     filelist: ./data/tiny/filelist/filelist_train.txt
 33 |     batch_size: 1
 34 |     shuffle: True
 35 |     num_workers: 0
 36 | 
 37 |   test:
 38 |     # octree building
 39 |     depth: 6
 40 |     full_depth: 2
 41 | 
 42 |     # data augmentations
 43 |     distort: False
 44 | 
 45 |     # data loading
 46 |     location: ./
 47 |     filelist: ./data/tiny/filelist/filelist_test.txt
 48 |     batch_size: 4
 49 |     shuffle: True
 50 |     num_workers: 10
 51 | 
 52 | MODEL:
 53 |   name: unet
 54 |   feature: PN   # N -> Normal (3 channels);
 55 |                 # P -> Points (3 channels)
 56 |   channel: 6
 57 | 
 58 |   nout: 256     # the final embedding dimension of each vertex
 59 |   nempty: True
 60 | 
 61 |   num_edge_types: 7  # deprecated
 62 | 
 63 |   # SAGE, GAT, Edge, DirConv, DistConv, my
 64 |   conv_type: my
 65 |   include_distance: True  # only applicable when using dist conv; if true, the distance between points will be concatenated to the feature.
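  # note: of the conv types listed above, only "sage" and "my" are wired up
  # in hgraph/modules/modules.py; other values currently raise
  # NotImplementedError there.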
66 | 67 | normal_aware_pooling: True # when grid pooling, consider normal or not 68 | 69 | 70 | # visualization, will not affect the training/testing process, only visualization 71 | 72 | get_test_stat: False # if true, evaluate the test set before visualization 73 | # test_mesh: 2323 # NOT finished: specify a mesh to evaluate in visualization system 74 | 75 | -------------------------------------------------------------------------------- /dataset_ps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | from utils import ocnn 4 | 5 | import numpy as np 6 | from utils.thsolver import Dataset 7 | 8 | 9 | from hgraph.hgraph import Data 10 | from hgraph.hgraph import HGraph 11 | 12 | 13 | class Transform(utils.ocnn.dataset.Transform): 14 | 15 | def __call__(self, sample: dict, idx: int): 16 | vertices = torch.from_numpy(sample['vertices'].astype(np.float32)) 17 | normals = torch.from_numpy(sample['normals'].astype(np.float32)) 18 | edges = torch.from_numpy(sample['edges'].astype(np.float32)).t().contiguous().long() 19 | dist_idx = sample['dist_idx'].astype(np.float32) 20 | dist_val = sample['dist_val'].astype(np.float32) 21 | # breakpoint() 22 | dist = np.concatenate([dist_idx, dist_val], -1) 23 | dist = torch.from_numpy(dist) 24 | 25 | 26 | rnd_idx = torch.randint(low=0, high=dist.shape[0], size=(100000,)) 27 | dist = dist[rnd_idx] 28 | 29 | 30 | # normalize 31 | norm2 = torch.sqrt(torch.sum(normals ** 2, dim=1, keepdim=True)) 32 | normals = normals / torch.clamp(norm2, min=1.0e-12) 33 | 34 | # construct hierarchical graph 35 | h_graph = HGraph() 36 | 37 | h_graph.build_single_hgraph(Data(x=torch.cat([vertices, normals], dim=1), edge_index=edges)) 38 | 39 | return {'hgraph': h_graph, 40 | 'vertices': vertices, 'normals': normals, 41 | 'dist': dist, 'edges': edges} 42 | 43 | 44 | def collate_batch(batch: list): 45 | # batch: list of single samples. 
 46 |     # each sample is a dict with keys:
 47 |     #   edges, vertices, normals, dist
 48 | 
 49 |     # output: one merged sample
 50 |     assert type(batch) == list
 51 | 
 52 |     # merge many hgraphs into one super hgraph
 53 | 
 54 | 
 55 |     outputs = {}
 56 |     for key in batch[0].keys():
 57 |         outputs[key] = [b[key] for b in batch]
 58 | 
 59 |     pts_num = torch.tensor([pts.shape[0] for pts in outputs['vertices']])
 60 |     cum_sum = utils.ocnn.utils.cumsum(pts_num, dim=0, exclusive=True)
 61 |     for i, dist in enumerate(outputs['dist']):
 62 |         dist[:, :2] += cum_sum[i]
 63 |     #for i, edge in enumerate(outputs['edges']):
 64 |     #    edge += cum_sum[i]
 65 |     outputs['dist'] = torch.cat(outputs['dist'], dim=0)
 66 | 
 67 | 
 68 |     # input feature
 69 |     vertices = torch.cat(outputs['vertices'], dim=0)
 70 |     normals = torch.cat(outputs['normals'], dim=0)
 71 |     feature = torch.cat([vertices, normals], dim=1)
 72 |     outputs['feature'] = feature
 73 | 
 74 |     # merge a batch of hgraphs into one super hgraph
 75 |     hgraph_super = HGraph(batch_size=len(batch))
 76 |     hgraph_super.merge_hgraph(outputs['hgraph'])
 77 |     outputs['hgraph'] = hgraph_super
 78 | 
 79 |     #if (outputs['dist'].max() >= len(vertices)):
 80 |     #    print("!!!!!!!")
 81 | 
 82 | 
 83 |     # Merge a batch of octrees into one super octree
 84 |     #octree = ocnn.octree.merge_octrees(outputs['octree'])
 85 |     #octree.construct_all_neigh()
 86 |     #outputs['octree'] = octree
 87 | 
 88 |     # Merge a batch of points
 89 |     #outputs['points'] = ocnn.octree.merge_points(outputs['points'])
 90 |     return outputs
 91 | 
 92 | 
 93 | def get_dataset(flags):
 94 |     transform = Transform(**flags)
 95 |     dataset = Dataset(flags.location, flags.filelist, transform,
 96 |                       read_file=np.load, take=flags.take)
 97 |     return dataset, collate_batch
--------------------------------------------------------------------------------
/hgraph/__init__.py:
--------------------------------------------------------------------------------
  1 | from . import models
  2 | from . import modules
  3 | from . import hgraph
--------------------------------------------------------------------------------
/hgraph/hgraph.py:
--------------------------------------------------------------------------------
  1 | from __future__ import annotations
  2 | from typing import Callable, Optional, Tuple
  3 | 
  4 | 
  5 | import torch
  6 | import torch.nn as nn
  7 | import torch.nn.functional as F
  8 | 
  9 | import torch_geometric
 10 | from torch_geometric.nn import SAGEConv
 11 | #from torch_geometric.nn import avg_pool, voxel_grid
 12 | 
 13 | #from utils.thsolver import default_settings
 14 | 
 15 | from .modules.decide_edge_type import *
 16 | 
 17 | """
 18 | a coarse-to-fine data structure for graph neural networks;
 19 | keeps records of a "graph tree"
 20 | """
 21 | 
 22 | 
 23 | class Data:
 24 |     def __init__(self, x=None, edge_index=None, edge_attr=None):
 25 |         """
 26 |         a rewrite of torch_geometric.data.Data,
 27 |         without its unwanted automatic re-indexing
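        :param x: node features, one row per vertex (here the xyz coordinates
            concatenated with the normals, shape [N, 6])
        :param edge_index: graph connectivity in COO format, shape [2, E]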
 28 |         :param edge_attr: the type of edges.
 29 |         """
 30 |         self.x = x
 31 |         self.edge_index = edge_index
 32 |         self.edge_attr = edge_attr
 33 | 
 34 |     def to(self, target):
 35 |         # move data from gpu to cpu or vice versa
 36 |         self.x = self.x.to(target)
 37 |         self.edge_index = self.edge_index.to(target)
 38 |         if self.edge_attr is not None:
 39 |             self.edge_attr = self.edge_attr.to(target)
 40 |         return self
 41 | 
 42 |     def cuda(self):
 43 |         return self.to("cuda")
 44 | 
 45 |     def cpu(self):
 46 |         return self.to("cpu")
 47 | 
 48 | 
 49 | 
 50 | def avg_pool(
 51 |     cluster: torch.Tensor,
 52 |     data: Data,
 53 |     transform: Optional[Callable] = None,
 54 | ) -> Data:
 55 |     """a wrapper of torch_geometric.nn.avg_pool"""
 56 |     data_torch_geometric = torch_geometric.data.Data(x=data.x, edge_index=data.edge_index)
 57 |     new_data = torch_geometric.nn.avg_pool(cluster, data_torch_geometric, transform=transform)
 58 |     ret = Data(x=new_data.x, edge_index=new_data.edge_index)
 59 |     return ret
 60 | 
 61 | 
 62 | def avg_pool_maintain_old(
 63 |     cluster: torch.Tensor,
 64 |     data: Data,
 65 |     transform: Optional[Callable] = None,
 66 | ):
 67 |     """a wrapper of torch_geometric.nn.avg_pool that also maintains the old graph (unfinished and currently unused)"""
 68 |     data_torch_geometric = torch_geometric.data.Data(x=data.x, edge_index=data.edge_index)
 69 |     new_layer = torch_geometric.nn.avg_pool(cluster, data_torch_geometric, transform=transform)
 70 |     # connect the corresponding node in the two layers
 71 | 
 72 | 
 73 | 
 74 | def pooling(data: Data, size: float, normal_aware_pooling):
 75 |     """
 76 |     Do pooling according to x; a new graph (x, edges) is generated after pooling.
 77 |     This function wraps several functions of pytorch_geometric. It assumes the
 78 |     object's coordinates range from -1 to 1.
 79 | 
 80 |     normal_aware_pooling: if False, only data.x[..., :3] (xyz) is used for grid pooling; if True, the normals in data.x[..., 3:6] participate as well.
 81 |     """
 82 |     assert type(size) == float
 83 |     if normal_aware_pooling == False:
 84 |         x = data.x[..., :3]
 85 |     else:
 86 |         x = data.x[..., :6]
 87 |         # we assume x has 6 feature channels: first 3 xyz, then 3 nxnynz;
 88 |         # the grid size here waits for fine-tuning
 89 |         n_size = size * 3.  # a hyper parameter, controlling how important the normal vecs are, compared to xyz coords
 90 |         size = [size, size, size, n_size, n_size, n_size]
 91 | 
 92 |     edges = data.edge_index
 93 |     cluster = torch_geometric.nn.voxel_grid(pos=x, size=size, batch=None,
 94 |                                             start=[-1, -1, -1.], end=[1, 1, 1.])
 95 | 
 96 |     # keep the max index smaller than the number of unique values (e.g., [4,5,4,3] --> [0,1,0,2])
 97 |     mapping = cluster.unique()
 98 |     mapping += mapping.shape[0]
 99 |     cluster += mapping.shape[0]
100 | 
101 |     for i in range(int(mapping.shape[0])):  # maybe some optimization here to remove the for loop
102 |         cluster[cluster == mapping[i]] = i
103 | 
104 |     return cluster  #.contiguous()
105 | 
106 | 
107 | 
108 | 
109 | def add_self_loop(data: Data):
110 |     # avg_pool clears the self-loops in the graph; here we add them back
111 |     device = data.x.device
112 |     n_vert = data.x.shape[0]
113 |     self_loops = torch.tensor([[i for i in range(int(n_vert))]])
114 |     self_loops = self_loops.repeat(2, 1)
115 |     new_edges = torch.zeros([2, data.edge_index.shape[1] + n_vert], dtype=torch.int64)
116 |     new_edges[:, :data.edge_index.shape[1]] = data.edge_index
117 |     new_edges[:, data.edge_index.shape[1]:] = self_loops
118 | 
119 |     return Data(x=data.x, edge_index=new_edges).to(device)
120 | 
121 | 
122 | 
123 | 
124 | class HGraph:
125 |     """
126 |     HGraph stands for "Hierarchical Graph"
127 | 
128 |     notes:
129 |     - coordinates of input vertices should be in [-1, 1]
130 |     - this class handles xyz coordinates/normals, not input features
131 | 
132 |     """
133 |     def __init__(self, depth: int=5,
134 |                  smallest_grid=2/2**5,
135 |                  batch_size: int=1,
136 |                  adj_layer_connected=False):
137 |         """
138 | 
139 |         suppose depth=3:
140 |         self.treedict[0] = original graph
141 |         self.treedict[1] = merge verts in voxel with edge length = smallest_grid
142 |         self.treedict[2] = merge verts in voxel with edge length = smallest_grid * (2**1)
143 |         self.treedict[3] = merge verts in voxel with edge length = smallest_grid * (2**2)
144 | 
145 |         the __init__ method only specifies the hyperparameters of an hgraph;
146 |         the real construction happens in build_single_hgraph() or merge_hgraph()
147 |         """
148 | 
149 |         assert smallest_grid * (2**(depth-1)) <= 2
150 |         assert depth >= 0
151 | 
152 |         self.device = "cuda"
153 |         self.depth = depth
154 |         self.batch_size = batch_size
155 |         self.smallest_grid = smallest_grid
156 |         self.normal_aware_pooling = True
157 | 
158 |         self.vertices_sizes = {}
159 |         self.edges_sizes = {}
160 |         #
161 |         self.treedict = {}
162 |         self.cluster = {}
163 | 
164 |     def build_single_hgraph(self, original_graph: Data):
165 |         """
166 |         build a graph-tree of **one** graph
167 |         """
168 |         assert type(original_graph) == Data
169 | 
170 | 
171 |         graphtree = {}
172 |         cluster = {}
173 |         vertices_size = {}
174 |         edges_size = {}
175 | 
176 |         for i in range(self.depth+1):
177 |             if i == 0:
178 |                 original_graph = add_self_loop(original_graph)
179 |                 # if the original graph does not have edge types, assign them
180 |                 if original_graph.edge_attr is None:
181 |                     edges = original_graph.x[original_graph.edge_index[0]] \
182 |                             - original_graph.x[original_graph.edge_index[1]]
183 |                     edges_attr = decide_edge_type_distance(edges, return_edge_length=False)
184 |                     original_graph.edge_attr = edges_attr
185 |                 graphtree[0] = original_graph
186 |                 cluster[0] = None
187 |                 edges_size[0] = original_graph.edge_index.shape[1]
188 |                 vertices_size[0] = original_graph.x.shape[0]
189 |                 continue
190 | 
191 |             clst = pooling(graphtree[i-1], self.smallest_grid * (2**(i-1)), normal_aware_pooling=self.normal_aware_pooling)
192 |             new_graph = avg_pool(cluster=clst, data=graphtree[i-1], transform=None)
193 |             new_graph = add_self_loop(new_graph)
194 |             # assign edge types
195 |             edges = new_graph.x[new_graph.edge_index[0]] \
196 |                     - new_graph.x[new_graph.edge_index[1]]
197 |             edges_attr = decide_edge_type_distance(edges, return_edge_length=False)
198 |             new_graph.edge_attr = edges_attr
199 | 
200 |             graphtree[i] = new_graph
201 |             cluster[i] = clst
202 |             edges_size[i] = new_graph.edge_index.shape[1]
203 |             vertices_size[i] = new_graph.x.shape[0]
204 | 
205 |         self.treedict = graphtree
206 |         self.cluster = cluster
207 |         self.vertices_sizes = vertices_size
208 |         self.edges_sizes = edges_size
209 | 
210 |         #self.export_obj()
211 | 
212 | 
213 |     # @staticmethod
214 |     def merge_hgraph(self, original_graphs: list[HGraph], debug_report=False):
215 |         """
216 |         merge multiple hgraphs into one large hgraph
217 | 
218 |         """
219 |         assert len(self.cluster) == 0 and len(self.treedict) == 0, "please call this function on a new instance"
220 |         assert len(original_graphs) == self.batch_size, "please make sure the batch size is correct"
221 | 
222 |         # re-indexing
223 |         for d in range(self.depth+1):
224 |             # merge vertices for every layer
225 |             num_vertices = [0]
226 |             for i, each in enumerate(original_graphs):
227 |                 num_vertices.append(each.vertices_sizes[d])
228 |             cum_sum = torch.cumsum(torch.tensor(num_vertices), dim=0)
229 |             for i in range(len(original_graphs)):
230 |                 original_graphs[i].treedict[d].edge_index += cum_sum[i]
231 |                 # cluster is None at d=0
232 |                 if d != 0:
233 |                     original_graphs[i].cluster[d] += cum_sum[i]
234 | 
235 |         # merge
236 |         for d in range(self.depth+1):
237 |             graphtrees_x, graphtrees_e, graphtrees_e_type, clusters = [], [], [], []
238 |             for i in range(len(original_graphs)):
239 |                 graphtrees_x.append(original_graphs[i].treedict[d].x)
240 |                 graphtrees_e.append(original_graphs[i].treedict[d].edge_index)
241 |                 graphtrees_e_type.append(original_graphs[i].treedict[d].edge_attr)
242 |                 clusters.append(original_graphs[i].cluster[d])
243 |             # construct the new graph
244 |             temp_data = Data(x=torch.cat(graphtrees_x, dim=0),
245 |                              edge_index=torch.cat(graphtrees_e, dim=1),
246 |                              edge_attr=torch.cat(graphtrees_e_type, dim=0)  # edge_attr shape: [E]
247 |                              )
248 |             # construct the new cluster
249 |             if d != 0:
250 |                 temp_clst = torch.cat(clusters, dim=0)
251 |             else:
252 |                 temp_clst = None
253 |             self.treedict[d] = temp_data
254 |             self.cluster[d] = temp_clst
255 |             self.edges_sizes = temp_data.edge_index.shape[1]
256 |             self.vertices_sizes = len(temp_data.x)
257 | 
258 |         # sanity check
259 |         if debug_report == True:
260 |             # a simple unit test
261 |             for d in range(self.depth+1):
262 |                 num_edges_before = 0
263 |                 for i in range(len(original_graphs)):
264 |                     num_edges_before += original_graphs[i].treedict[d].edge_index.shape[1]
265 |                 num_edges_after = self.treedict[d].edge_index.shape[1]
266 |                 print(f"Before merge, at d={d} there are {num_edges_before} edges; {num_edges_after} afterwards")
267 | 
268 | 
269 | 
270 | 
271 |     #####################################################
272 |     # Util
273 |     #####################################################
274 | 
275 |     def cuda(self):
276 |         # move all tensors to cuda
277 |         for each in self.treedict.keys():
278 |             self.treedict[each] = self.treedict[each].cuda()
279 |         for each in self.cluster.keys():
280 |             if self.cluster[each] is None:
281 |                 continue
282 |             self.cluster[each] = self.cluster[each].cuda()
283 |         return self
284 | 
285 | 
--------------------------------------------------------------------------------
/hgraph/models/__init__.py:
--------------------------------------------------------------------------------
  1 | from . import graph_unet
--------------------------------------------------------------------------------
/hgraph/models/graph_unet.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import torch
  3 | import torch.nn
  4 | from typing import Dict, Optional
  5 | 
  6 | 
  7 | from hgraph.hgraph import Data
  8 | from hgraph.hgraph import HGraph
  9 | 
 10 | from hgraph.modules.resblocks import GraphResBlocks, GraphResBlock2, GraphResBlock
 11 | from hgraph.modules import modules
 12 | 
 13 | 
 14 | class GraphUNet(torch.nn.Module):
 15 |     r'''
 16 |     A U-Net-like graph neural network that uses an HGraph (hierarchical graph) as
 17 |     its data structure.
 18 |     '''
 19 | 
 20 |     def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
 21 |                  nempty: bool = False, **kwargs):
 22 |         super(GraphUNet, self).__init__()
 23 |         self.in_channels = in_channels
 24 |         self.out_channels = out_channels
 25 |         self.config_network()
 26 |         self.encoder_stages = len(self.encoder_blocks)
 27 |         self.decoder_stages = len(self.decoder_blocks)
 28 | 
 29 |         # encoder
 30 |         self.conv1 = modules.GraphConvBnRelu(in_channels, self.encoder_channel[0])
 31 | 
 32 |         self.downsample = torch.nn.ModuleList(
 33 |             [modules.PoolingGraph() for i in range(self.encoder_stages)]
 34 |         )
 35 |         self.encoder = torch.nn.ModuleList(
 36 |             [GraphResBlocks(self.encoder_channel[i], self.encoder_channel[i+1],
 37 |                             resblk_num=self.encoder_blocks[i], resblk=self.resblk)
 38 |              for i in range(self.encoder_stages)]
 39 |         )
 40 | 
 41 |         # decoder
 42 |         channel = [self.decoder_channel[i] + self.encoder_channel[-i-2]
 43 |                    for i in range(self.decoder_stages)]
 44 |         self.upsample = torch.nn.ModuleList(
 45 |             [modules.UnpoolingGraph() for i in range(self.decoder_stages)]
 46 |         )
 47 |         self.decoder = torch.nn.ModuleList(
 48 |             [GraphResBlocks(channel[i], self.decoder_channel[i+1],
 49 |                             resblk_num=self.decoder_blocks[i], resblk=self.resblk, bottleneck=self.bottleneck)
 50 |              for i in range(self.decoder_stages)]
 51 |         )
 52 | 
 53 |         # header
 54 |         # channel = self.decoder_channel[self.decoder_stages]
 55 |         #self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
 56 |         self.header = torch.nn.Sequential(
 57 |             modules.Conv1x1BnRelu(self.decoder_channel[-1], self.decoder_channel[-1]),
 58 |             modules.Conv1x1(self.decoder_channel[-1], self.out_channels, use_bias=True))
 59 | 
 60 |         # an embedding decoder MLP
 61 |         self.embedding_decoder_mlp = torch.nn.Sequential(
 62 |             torch.nn.Linear(self.out_channels, self.out_channels, bias=True),
 63 |             torch.nn.ReLU(),
 64 |             torch.nn.Linear(self.out_channels, self.out_channels, bias=True),
 65 |             torch.nn.ReLU(),
 66 |             torch.nn.Linear(self.out_channels, 1, bias=True)
 67 |         )
 68 | 
 69 |     def config_network(self):
 70 |         r''' Configure the network channels and Resblock numbers.
 71 |         '''
 72 |         self.encoder_blocks = [2, 3, 3, 3, 2]
 73 |         self.decoder_blocks = [2, 3, 3, 3, 2]
 74 |         self.encoder_channel = [256, 256, 256, 256, 256, 256]
 75 |         self.decoder_channel = [256, 256, 256, 256, 256, 256]
 76 | 
 77 |         # self.encoder_blocks = [4, 9, 9, 3]
 78 |         # self.decoder_blocks = [4, 9, 9, 3]
 79 |         #self.encoder_channel = [512, 512, 512, 512, 512,]
 80 |         #self.decoder_channel = [512, 512, 512, 512, 512]
 81 | 
 82 |         self.bottleneck = 1
 83 |         self.resblk = GraphResBlock2
 84 | 
 85 |     def unet_encoder(self, data: torch.Tensor, hgraph: HGraph, depth: int):
 86 |         r''' The encoder of the U-Net.
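        Returns a dict that maps each level, from `depth` down to
        `depth - encoder_stages`, to the feature tensor computed at that level.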
 87 |         '''
 88 | 
 89 |         convd = dict()
 90 |         convd[depth] = self.conv1(data, hgraph, depth)
 91 |         for i in range(self.encoder_stages):
 92 |             d = depth - i
 93 |             conv = self.downsample[i](convd[d], hgraph, i+1)
 94 |             convd[d-1] = self.encoder[i](conv, hgraph, d-1)
 95 |         return convd
 96 | 
 97 |     def unet_decoder(self, convd: Dict[int, torch.Tensor], hgraph: HGraph, depth: int):
 98 |         r''' The decoder of the U-Net.
 99 |         '''
100 | 
101 |         deconv = convd[depth]
102 |         for i in range(self.decoder_stages):
103 |             d = depth + i
104 |             deconv = self.upsample[i](deconv, hgraph, self.decoder_stages-i)
105 |             deconv = torch.cat([convd[d+1], deconv], dim=1)  # skip connections
106 |             deconv = self.decoder[i](deconv, hgraph, d+1)
107 |         return deconv
108 | 
109 |     def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int, dist: torch.Tensor, only_embd=False):
110 |         """Run the U-Net and decode geodesic distances for vertex pairs.
111 | 
112 |         Args:
113 |             data (torch.Tensor): input vertex features, shape [N, in_channels].
114 |             hgraph (HGraph): the hierarchical graph built from the input mesh.
115 |             depth (int): the depth of the hierarchical graph.
116 |             dist (torch.Tensor): query pairs of shape [K, 3]; columns 0 and 1 are vertex indices, column 2 is the ground-truth distance. May be None if only_embd is True.
117 |             only_embd (bool, optional): If True, the return value will be the embeddings of vertices; if False,
118 |                 return the estimated distance of point pairs. Defaults to False.
119 |         """
120 |         convd = self.unet_encoder(data, hgraph, depth)
121 |         deconv = self.unet_decoder(convd, hgraph, depth - self.encoder_stages)
122 | 
123 |         embedding = self.header(deconv)
124 | 
125 |         if dist is None and only_embd:
126 |             return embedding
127 | 
128 |         # calculate the distance
129 |         i, j = dist[:, 0].long(), dist[:, 1].long()
130 | 
131 |         embd_i = embedding[i].squeeze(-1)
132 |         embd_j = embedding[j].squeeze(-1)
133 | 
134 |         embd = (embd_i - embd_j) ** 2
135 |         # alternative way... bad
136 |         #embd_1 = (embd_i[..., :64] - embd_j[..., :64]) ** 2
137 |         #embd_2 = (embd_i[..., 64:] - embd_j[..., 64:]) ** 4
138 |         #embd = torch.cat([embd_1, embd_2], dim=-1)
139 | 
140 |         pred = self.embedding_decoder_mlp(embd)
141 |         pred = pred.squeeze(-1)
142 | 
143 |         if only_embd:
144 |             return embedding
145 |         else:
146 |             return pred
147 | 
148 | 
149 | 
--------------------------------------------------------------------------------
/hgraph/modules/decide_edge_type.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import torch
  3 | 
  4 | # helper functions that decide the type of an edge
  5 | 
  6 | 
  7 | ###############################################################
  8 | # For distance graph conv
  9 | ###############################################################
 10 | 
 11 | def decide_edge_type_distance(vec: torch.tensor,
 12 |                               method="predefined",
 13 |                               return_edge_length=True):
 14 |     """
 15 |     classify each vec into one of several categories.
 16 |     vec: N x 3
 17 |     ret: N
 18 |     """
 19 |     if method == "predefined":
 20 |         return decide_edge_type_predefined_distance(vec, return_edge_length=return_edge_length)
 21 |     else:
 22 |         raise NotImplementedError
 23 | 
 24 | 
 25 | def decide_edge_type_predefined_distance(vec: torch.tensor, epsilon=0.00001, return_edge_length=True):
 26 |     """
 27 |     classify each vec into N categories, according to the length of the vector;
 28 |     the last category is the self-loop.
 29 |     vec: N x 3
 30 |     ret: N
 31 |     """
 32 |     positive_x = torch.maximum(vec[..., 0], torch.zeros_like(vec[..., 0]))
 33 |     positive_y = torch.maximum(vec[..., 1], torch.zeros_like(vec[..., 0]))
 34 |     positive_z = torch.maximum(vec[..., 2], torch.zeros_like(vec[..., 0]))
 35 |     negative_x = - torch.minimum(vec[..., 0], torch.zeros_like(vec[..., 0]))
 36 |     negative_y = - torch.minimum(vec[..., 1], torch.zeros_like(vec[..., 0]))
 37 |     negative_z = - torch.minimum(vec[..., 2], torch.zeros_like(vec[..., 0]))
 38 | 
 39 |     ary = torch.stack([positive_x, positive_y, positive_z, negative_x, negative_y, negative_z])
 40 |     ary = ary.transpose(0, 1)
 41 | 
 42 |     device = vec.device
 43 | 
 44 |     edge_type = torch.ones([len(vec), 1]).to(device) * 999
 45 |     vec_length = torch.norm(ary, dim=1)
 46 | 
 47 |     # edge_length > eps  --> type 0 edge
 48 |     # edge_length <= eps --> type 1 edge (self-loop)
 49 | 
 50 |     # print(thres_1)
 51 |     dist_threshold = [epsilon]
 52 | 
 53 | 
 54 |     for i in range(len(dist_threshold)-1, -1, -1):
 55 |         dist_mask = vec_length > dist_threshold[i]
 56 |         edge_type[dist_mask] = i
 57 | 
 58 |     # self-loop
 59 |     self_loops_mask = vec_length <= epsilon
 60 |     edge_type[self_loops_mask] = len(dist_threshold)
 61 | 
 62 |     # squeeze to a 1d tensor
 63 |     edge_type = edge_type.squeeze(-1)
 64 |     edge_type = edge_type.long()
 65 | 
 66 | 
 67 |     # assertion: suppose there are N thresholds (epsilon included); make sure
 68 |     # that the edge types take at most N+1 distinct values
 69 |     # breakpoint()
 70 |     assert edge_type.max() == len(dist_threshold)  # there must be edges indexed N, since self-loops exist in all meshes
 71 |     assert edge_type.min() >= 0  # note: edge types indexed N-1 may not exist, since edges of that length may not exist in this mesh
 72 | 
 73 |     if return_edge_length == True:
 74 |         return edge_type, vec_length
 75 |     else:
 76 |         return edge_type
 77 | 
--------------------------------------------------------------------------------
/hgraph/modules/modules.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import torch
  3 | import math
  4 | import torch.utils.checkpoint
  5 | from typing import List, Optional
  6 | 
  7 | 
  8 | from hgraph.hgraph import Data
  9 | #from torch_geometric.nn import avg_pool
 10 | from hgraph.hgraph import avg_pool
 11 | from hgraph.hgraph import HGraph
 12 | 
 13 | 
 14 | bn_momentum, bn_eps = 0.01, 0.001   # the default values of Tensorflow 1.x
 15 | # bn_momentum, bn_eps = 0.1, 1e-05  # the default values of pytorch
 16 | 
 17 | 
 18 | ###############################################################
 19 | # Util funcs
 20 | ###############################################################
 21 | 
 22 | from .decide_edge_type import *
 23 | 
 24 | 
 25 | 
 26 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
 27 |     r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
 28 |     library `pytorch_scatter`.
 29 |     '''
 30 | 
 31 |     if dim < 0:
 32 |         dim = other.dim() + dim
 33 | 
 34 |     if src.dim() == 1:
 35 |         for _ in range(0, dim):
 36 |             src = src.unsqueeze(0)
 37 |         for _ in range(src.dim(), other.dim()):
 38 |             src = src.unsqueeze(-1)
 39 | 
 40 |     src = src.expand_as(other)
 41 |     return src
 42 | 
 43 | 
 44 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
 45 |                 out: Optional[torch.Tensor] = None,
 46 |                 dim_size: Optional[int] = None,) -> torch.Tensor:
 47 |     r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
 48 |     indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
 49 |     This is just a wrapper of :meth:`torch.Tensor.scatter_add_` in a broadcasting fashion.
 50 | 
 51 |     Args:
 52 |         src (torch.Tensor): The source tensor.
 53 |         index (torch.Tensor): The indices of elements to scatter.
 54 |         dim (torch.Tensor): The axis along which to index, (default: :obj:`-1`).
 55 |         out (torch.Tensor or None): The destination tensor.
 56 |         dim_size (int or None): If :attr:`out` is not given, automatically create
 57 |             output with size :attr:`dim_size` at dimension :attr:`dim`. If
 58 |             :attr:`dim_size` is not given, a minimal sized output tensor according
 59 |             to :obj:`index.max() + 1` is returned.
 60 |     '''
 61 | 
 62 |     index = broadcast(index, src, dim)
 63 | 
 64 |     if out is None:
 65 |         size = list(src.size())
 66 |         if dim_size is not None:
 67 |             size[dim] = dim_size
 68 |         elif index.numel() == 0:
 69 |             size[dim] = 0
 70 |         else:
 71 |             size[dim] = int(index.max()) + 1
 72 |         out = torch.zeros(size, dtype=src.dtype, device=src.device)
 73 | 
 74 |     return out.scatter_add_(dim, index, src)
 75 | 
 76 | 
 77 | ###############################################################
 78 | # Basic Operators on graphs
 79 | ###############################################################
 80 | 
 81 | from torch_geometric.nn import GraphSAGE
 82 | 
 83 | class GraphSAGEConv(torch.nn.Module):
 84 |     def __init__(self, in_channels: int, out_channels: int,
 85 |                  ):
 86 |         super().__init__()
 87 |         self.conv = GraphSAGE(in_channels, out_channels, num_layers=1)
 88 | 
 89 |     def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
 90 |         graph = hgraph.treedict[hgraph.depth - depth]
 91 |         assert input_feature.shape[0] == graph.x.shape[0]
 92 |         out = self.conv(input_feature, graph.edge_index)
 93 |         return out
 94 | 
 95 | 
 96 | from torch_geometric.nn import GATv2Conv
 97 | 
 98 | class GraphAttentionConv(torch.nn.Module):
 99 |     # Graph Attention Network (GATv2)
100 |     def __init__(self, in_channels: int, out_channels: int,
101 |                  ):
102 |         super().__init__()
103 |         self.conv = GATv2Conv(in_channels, out_channels)
104 | 
105 |     def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
106 |         graph = hgraph.treedict[hgraph.depth - depth]
107 |         assert input_feature.shape[0] == graph.x.shape[0]
108 |         out = self.conv(input_feature, graph.edge_index)
109 |         return out
110 | 
111 | 
112 | ############################################################
113 | from torch.nn import Sequential as Seq, Linear, ReLU
114 | from torch_geometric.nn import MessagePassing, HeteroLinear
115 | 
116 | 
117 | ############################################################
118 | 
119 | class MyConvOp(MessagePassing):
120 |     # this implementation follows the "message passing" tutorial of torch_geometric
121 |     def __init__(self, in_channels, out_channels, include_distance):
122 |         super().__init__(aggr='max')  # "Max" aggregation.
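        # one linear map per edge type: the HeteroLinear below applies a
        # separate weight matrix for each of the two edge types
        # (regular edge vs. self-loop; see decide_edge_type.py)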
123 |         #self.mlp = Seq(Linear(2 * in_channels, out_channels),
124 |         #               ReLU(),
125 |         #               Linear(out_channels, out_channels))
126 |         # self.self_loop = Linear(in_channels, out_channels)
127 |         # self.neighbor_matrix_1 = Linear(in_channels, out_channels)
128 |         # self.neighbor_matrix_2 = Linear(in_channels, out_channels)
129 | 
130 |         self.num_types = 2
131 |         self.lin = HeteroLinear(in_channels, out_channels, num_types=self.num_types,)
132 |         self.include_distance = include_distance
133 | 
134 | 
135 |     def forward(self, x, edge_index, edge_attr=None):
136 |         # x has shape [N, in_channels]
137 |         # edge_index has shape [2, E]
138 | 
139 |         # edge_attr has shape [1, E]
140 |         # see https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html#implementing-the-gcn-layer
141 | 
142 |         if edge_attr is None:
143 |             # dynamically calculate the edge types
144 |             raise NotImplementedError
145 |         else:
146 |             # use the predefined edge types
147 |             if edge_attr.min() != 0:
148 |                 breakpoint()
149 |             assert edge_attr.min() == 0
150 |             assert edge_attr.max() == self.num_types - 1
151 | 
152 | 
153 |         return self.propagate(edge_index, x=x, edge_attr=edge_attr)
154 | 
155 |     def message(self, x_i, x_j, edge_attr):
156 |         # x_i has shape [E, in_channels]
157 |         # x_j has shape [E, in_channels]
158 | 
159 | 
160 |         # the last 3 dims of the feature are the xyz positions of the vertices;
161 |         # compute the relative position and its norm
162 |         if self.include_distance == True:
163 |             abs_dist = torch.norm(x_i[..., -3:] - x_j[..., -3:], dim=1, keepdim=True)  # shape: [E, 1]
164 |             x_j = torch.cat([x_j, x_i[..., -3:] - x_j[..., -3:], abs_dist], dim=-1)
165 | 
166 |         out = self.lin(x_j, edge_attr)
167 |         return out
168 | 
169 | 
170 | class MyConv(torch.nn.Module):
171 |     # MyConv: considers neighbor features (the relative position and its norm)
172 |     # and several separate edge categories
173 |     def __init__(self, in_channels: int, out_channels: int,
174 |                  ):
175 |         super().__init__()
176 |         self.include_distance = True  # default_settings.get_global_value("include_distance")
177 |         if self.include_distance == False:
178 |             self.conv = MyConvOp(in_channels, out_channels, include_distance=False)
179 |         else:
180 |             self.conv = MyConvOp(in_channels + 3 + 3 + 1, out_channels, include_distance=True)
181 | 
182 |     def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
183 |         graph = hgraph.treedict[hgraph.depth - depth]
184 |         assert input_feature.shape[0] == graph.x.shape[0]
185 |         # concat the input feature with the xyz coordinates
186 |         if self.include_distance == True:
187 |             input_feature = torch.cat([input_feature, graph.x[..., :3]], dim=-1)
188 |         # do the conv
189 |         out = self.conv(input_feature, graph.edge_index, graph.edge_attr)
190 |         return out
191 | 
192 | 
193 | ########################################################
194 | 
195 | 
196 | 
197 | 
198 | ###############################################################
199 | # Complex Components
200 | ###############################################################
201 | 
202 | conv_type = "my"  # default_settings.get_global_value("conv_type").lower()
203 | if conv_type == "sage":
204 |     GraphConv = GraphSAGEConv
205 | elif conv_type == "my":
206 |     GraphConv = MyConv
207 | else:
208 |     raise NotImplementedError
209 | 
210 | # group norm
211 | normalization = lambda x: torch.nn.GroupNorm(num_groups=4, num_channels=x, eps=bn_eps)
212 | 
213 | 
214 | 
215 | class GraphConvBn(torch.nn.Module):
216 |     def __init__(self, in_channels: int, out_channels: int,
217 |                  ):
218 |         super().__init__()
219 |         self.conv = GraphConv(in_channels, out_channels)
220 |         self.bn = normalization(out_channels)
221 | 
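    # note: despite the "Bn" in the class names here, `normalization` above is
    # a GroupNorm(num_groups=4) rather than a BatchNorm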
222 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int): 223 | out = self.conv(data, hgraph, depth) 224 | out = self.bn(out) 225 | return out 226 | 227 | 228 | class GraphConvBnRelu(torch.nn.Module): 229 | def __init__(self, in_channels: int, out_channels: int, 230 | ): 231 | super().__init__() 232 | self.conv = GraphConv(in_channels, out_channels) 233 | self.bn = normalization(out_channels) 234 | self.relu = torch.nn.ReLU() # inplace=True 235 | 236 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int): 237 | out = self.conv(data, hgraph, depth) 238 | out = self.bn(out) 239 | out = self.relu(out) 240 | return out 241 | 242 | 243 | class PoolingGraph(torch.nn.Module): 244 | def __init__(self): 245 | super().__init__() 246 | pass 247 | 248 | def forward(self, x: torch.Tensor, hgraph: HGraph, depth: int): 249 | cluster = hgraph.cluster[depth] 250 | out = avg_pool(cluster=cluster, data=Data(x=x, edge_index=torch.zeros([2,1]).long())) # fake edges XD 251 | return out.x 252 | 253 | class UnpoolingGraph(torch.nn.Module): 254 | def __init__(self): 255 | super().__init__() 256 | pass 257 | 258 | def forward(self, x: torch.Tensor, hgraph: HGraph, depth: int): 259 | assert depth != 0 260 | # 261 | feature_dim = x.shape[1] 262 | cluster = hgraph.cluster[depth][..., None].repeat(1, feature_dim).long() 263 | out = torch.gather(input=x, dim=0, index=cluster) 264 | return out 265 | 266 | 267 | 268 | 269 | class Conv1x1(torch.nn.Module): 270 | r''' Performs a convolution with kernel :obj:`(1,1,1)`. 271 | 272 | The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node 273 | number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be 274 | implemented with :class:`torch.nn.Linear`. 275 | ''' 276 | 277 | def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False): 278 | super().__init__() 279 | self.linear = torch.nn.Linear(in_channels, out_channels, use_bias) 280 | 281 | def forward(self, data: torch.Tensor): 282 | r'''''' 283 | 284 | return self.linear(data) 285 | 286 | 287 | class Conv1x1Bn(torch.nn.Module): 288 | r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`. 289 | ''' 290 | 291 | def __init__(self, in_channels: int, out_channels: int): 292 | super().__init__() 293 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False) 294 | self.bn = normalization(out_channels) 295 | 296 | def forward(self, data: torch.Tensor): 297 | r'''''' 298 | 299 | out = self.conv(data) 300 | out = self.bn(out) 301 | return out 302 | 303 | 304 | class Conv1x1BnRelu(torch.nn.Module): 305 | r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`. 306 | ''' 307 | 308 | def __init__(self, in_channels: int, out_channels: int): 309 | super().__init__() 310 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False) 311 | self.bn = normalization(out_channels) 312 | self.relu = torch.nn.ReLU(inplace=True) 313 | 314 | def forward(self, data: torch.Tensor): 315 | r'''''' 316 | 317 | out = self.conv(data) 318 | out = self.bn(out) 319 | out = self.relu(out) 320 | return out 321 | 322 | 323 | class FcBnRelu(torch.nn.Module): 324 | r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`. 
325 | ''' 326 | 327 | def __init__(self, in_channels: int, out_channels: int): 328 | super().__init__() 329 | self.flatten = torch.nn.Flatten(start_dim=1) 330 | self.fc = torch.nn.Linear(in_channels, out_channels, bias=False) 331 | self.bn = normalization(out_channels) 332 | self.relu = torch.nn.ReLU(inplace=True) 333 | 334 | def forward(self, data): 335 | r'''''' 336 | 337 | out = self.flatten(data) 338 | out = self.fc(out) 339 | out = self.bn(out) 340 | out = self.relu(out) 341 | return out -------------------------------------------------------------------------------- /hgraph/modules/resblocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.checkpoint 3 | 4 | from hgraph.hgraph import HGraph 5 | from hgraph.modules.modules import Conv1x1BnRelu, Conv1x1, Conv1x1Bn, \ 6 | GraphConv, GraphConvBnRelu, GraphConvBn, FcBnRelu, \ 7 | UnpoolingGraph, PoolingGraph 8 | 9 | 10 | class GraphResBlock(torch.nn.Module): 11 | def __init__(self, in_channels: int, out_channels: int, 12 | bottleneck: int=4): 13 | super().__init__() 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | self.bottleneck = bottleneck 17 | channelb = int(out_channels / bottleneck) 18 | 19 | self.conv1x1a = Conv1x1BnRelu(in_channels, channelb) 20 | self.conv3x3 = GraphConvBnRelu(channelb, channelb) 21 | self.conv1x1b = Conv1x1Bn(channelb, out_channels) 22 | 23 | if self.in_channels != self.out_channels: 24 | self.conv1x1c = Conv1x1Bn(in_channels, out_channels) 25 | self.relu = torch.nn.ReLU(inplace=True) 26 | 27 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int): 28 | conv1 = self.conv1x1a(data) 29 | conv2 = self.conv3x3(conv1, hgraph, depth) 30 | conv3 = self.conv1x1b(conv2) 31 | if self.in_channels != self.out_channels: 32 | data = self.conv1x1c(data) 33 | out = self.relu(conv3 + data) 34 | return out 35 | 36 | 37 | 38 | 39 | class GraphResBlock2(torch.nn.Module): 40 | def __init__(self, in_channels: int, out_channels: int, 41 | bottleneck: int=4): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | self.out_channels = out_channels 45 | self.bottleneck = bottleneck 46 | channelb = int(out_channels / bottleneck) 47 | 48 | self.conv3x3a = GraphConvBnRelu(in_channels, channelb) 49 | self.conv3x3b = GraphConvBn(channelb, out_channels) 50 | 51 | if self.in_channels != self.out_channels: 52 | self.conv1x1 = Conv1x1Bn(in_channels, out_channels) 53 | self.relu = torch.nn.ReLU(inplace=True) 54 | 55 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int): 56 | conv1 = self.conv3x3a(data, hgraph, depth) 57 | conv2 = self.conv3x3b(conv1, hgraph, depth) 58 | if self.in_channels != self.out_channels: 59 | data = self.conv1x1(data) 60 | out = self.relu(conv2 + data) 61 | return out 62 | 63 | 64 | 65 | class GraphResBlocks(torch.nn.Module): 66 | def __init__(self, in_channels, out_channels, 67 | resblk_num, bottleneck=4, 68 | resblk=GraphResBlock): 69 | super().__init__() 70 | self.resblk_num = resblk_num 71 | channels = [in_channels] + [out_channels] * resblk_num 72 | 73 | self.resblks = torch.nn.ModuleList( 74 | [resblk(channels[i], channels[i+1], bottleneck=bottleneck) for i in range(self.resblk_num)] 75 | ) 76 | 77 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int): 78 | for i in range(self.resblk_num): 79 | data = self.resblks[i](data, hgraph, depth) 80 | 81 | return data 82 | 83 | -------------------------------------------------------------------------------- /img/bunny.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/bunny.jpg -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/teaser.png -------------------------------------------------------------------------------- /img/teaser.v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/teaser.v1.png -------------------------------------------------------------------------------- /pretrained/ours00500.solver.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/pretrained/ours00500.solver.tar -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Learning the Geodesic Embedding with Graph Neural Networks 2 | 3 | 4 | 5 | [**Learning the Geodesic Embedding with Graph Neural Networks**](https://arxiv.org/abs/2309.05613)
6 | [Bo Pang](https://github.com/skinboC), [Zhongtian Zheng](https://github.com/zzttzz), Guoping Wang, and [Peng-Shuai Wang](https://wang-ps.github.io/)
  7 | ACM Transactions on Graphics (SIGGRAPH Asia), 42(6), 2023
  8 | 
  9 | ![](img/teaser.v1.png)
 10 | 
 11 | - [Learning the Geodesic Embedding with Graph Neural Networks](#learning-the-geodesic-embedding-with-graph-neural-networks)
 12 |   - [1. Environment](#1-environment)
 13 |   - [2. Prepare Data](#2-prepare-data)
 14 |   - [3. Train](#3-train)
 15 |   - [4. Test and Visualization](#4-test-and-visualization)
 16 |   - [5. Citation](#5-citation)
 17 | 
 18 | 
 19 | ## 1. Environment
 20 | 
 21 | First, please install a PyTorch build that matches your CUDA version.
 22 | 
 23 | Then, install PyTorch Geometric:
 24 | 
 25 | ```
 26 | conda install pyg -c pyg
 27 | ```
 28 | 
 29 | Then install the packages required by this project:
 30 | 
 31 | ```
 32 | pip3 install -r requirements.txt
 33 | ```
 34 | 
 35 | ## 2. Prepare Data
 36 | 
 37 | Before training, you have to generate training data, as described in Section 4.1 of the paper. Please note that the model may not generalize well to shapes very different from the training data, as discussed in the paper.
 38 | 
 39 | Suppose your meshes are in `path/to/meshes`; we provide a script to generate training data from them. Please open `utils/dataset_prepare/py_data_generater.py` and change the following lines:
 40 | 
 41 | ```python
 42 | PATH_TO_MESH = "path/to/your/mesh/folder/"
 43 | PATH_TO_OUTPUT_NPZ = "path/to/your/output/folder/"
 44 | PATH_TO_OUTPUT_FILELIST = "path/to/your/another/output/folder"
 45 | ```
 46 | 
 47 | Then, please run the script:
 48 | 
 49 | ```
 50 | python utils/dataset_prepare/py_data_generater.py
 51 | ```
 52 | 
 53 | This will load the meshes (`.obj` files) and generate the processed `.npz` files. You can change `a, b, c, d` in that file to adjust the properties of the training data.
 54 | 
 55 | We will upload the processed data we used in our paper to Google Drive soon.
 56 | 
 57 | There is also a mesh processing script in `utils/dataset_prepare/simplifiy.py`, a utility we used to simplify the meshes.
 58 | 
 59 | ## 3. Train
 60 | 
 61 | To train the network, run:
 62 | 
 63 | ```shell
 64 | python3 GnnDist.py --config configs/gnndist_test.yaml
 65 | ```
 66 | 
 67 | We train our model on Ubuntu 20.04 (4 Nvidia 3090 GPUs, 24 GB VRAM each) with batch size 10. If you run out of GPU memory, please reduce the batch size in the config file.
 68 | 
 69 | A checkpoint of our model is provided in `pretrained/ours00500.solver.tar`.
 70 | 
 71 | ## 4. Test and Visualization
 72 | 
 73 | ![](img/bunny.jpg)
 74 | 
 75 | We provide a script to test the network and visualize the results. The following command tests our method on the specified mesh and opens a polyscope window for visualization. (If you are using ssh and cannot open the window, please delete the polyscope-related code in `GeGnn_standalong.py`.)
 76 | 
 77 | Feel free to change `--test_file` to test on other meshes and `--start_pts` to test different source points.
 78 | 
 79 | ```shell
 80 | python3 GeGnn_standalong.py --mode SSAD --test_file data/test_mesh/bunny.obj --ckpt_path pretrained/ours00500.solver.tar --start_pts 0 --output out/ssad_ours.npy
 81 | ```
 82 | 
 83 | This will open a polyscope window and show the results.
 84 | 
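You can also query the pretrained model directly from Python. The following is a minimal sketch built on `PretrainedModel` and `read_mesh` from `GeGnn_standalong.py`; the mesh and checkpoint paths are placeholders, and a CUDA device is assumed (the model is hard-coded to run on the GPU). Note that the polyscope block at the bottom of `GeGnn_standalong.py` runs at import time, so comment it out before importing from that script.

```python
import torch
from GeGnn_standalong import PretrainedModel, read_mesh

# placeholder paths; substitute your own mesh and checkpoint
mesh = read_mesh("data/test_mesh/bunny.obj")
model = PretrainedModel("pretrained/ours00500.solver.tar").cuda()
model.precompute(mesh)  # one-off per mesh: computes the vertex-wise embeddings

# geodesic distances between vertex pairs (0, 10) and (1, 42)
p = torch.tensor([0, 1]).cuda()
q = torch.tensor([10, 42]).cuda()
print(model(p, q))

# single-source, all-destinations distances from vertex 0
dist_to_all = model.SSAD([0])[0]
```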
85 | ## 5. Citation
86 |
87 | If you find this project useful for your research, please kindly cite our paper:
88 |
89 | ```bibtex
90 | @article{pang2023gegnn,
91 |   title={Learning the Geodesic Embedding with Graph Neural Networks},
92 |   author={Pang, Bo and Zheng, Zhongtian and Wang, Guoping and Wang, Peng-Shuai},
93 |   journal={ACM Transactions on Graphics (SIGGRAPH Asia)},
94 |   volume={42},
95 |   number={6},
96 |   year={2023}
97 | }
98 | ```
99 |
100 | If you have any questions, please feel free to contact us at bo98@stu.pku.edu.cn.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.6.1
2 | numpy==1.23.4
3 | PyMCubes==0.1.2
4 | scikit_learn==1.1.3
5 | tqdm==4.64.1
6 | websockets==10.3
7 | tensorboard
8 | pygeodesic
9 | trimesh
10 | yacs
11 | polyscope
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataset_prepare/py_data_generater.py:
--------------------------------------------------------------------------------
1 | #
2 | import trimesh
3 | import pygeodesic.geodesic as geodesic
4 | import numpy as np
5 | import os
6 | import multiprocessing as mp
7 | from threading import Thread
8 | from tqdm import tqdm
9 |
10 |
11 |
12 | PATH_TO_MESH = "./data/tiny/meshes/"
13 | PATH_TO_OUTPUT_NPZ = "./data/tiny/npz/"
14 | PATH_TO_OUTPUT_FILELIST = "./data/tiny/filelist/"
15 |
16 | TRAINING_SPLIT_RATIO = 0.8
17 |
18 |
19 | def visualize_ssad(vertices: np.ndarray, triangles: np.ndarray, source_index: int):
20 |     # Initialise the PyGeodesicAlgorithmExact class instance
21 |     geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
22 |
23 |     # Define the source and target point ids with respect to the points array
24 |     source_indices = np.array([source_index])
25 |     target_indices = None
26 |     distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
27 |     return distances
28 |
29 |
30 | def visualize_two_pts(vertices: np.ndarray, triangles: np.ndarray, source_index: int, dest_index: int):
31 |     # Initialise the PyGeodesicAlgorithmExact class instance
32 |     geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
33 |
34 |     # Define the source and target point ids with respect to the points array
35 |     source_indices = np.array([source_index])
36 |     target_indices = np.array([dest_index])
37 |     distances, best_source = geoalg.geodesicDistance(source_indices, target_indices)
38 |     return distances
39 |
40 |
41 | def data_prepare_ssad(object_file: str, output_path: str, source_index: int):
42 |     vertices = []
43 |     triangles = []
44 |
45 |     with open(object_file, "r") as f:
46 |         lines = f.readlines()
47 |         for each in lines:
48 |             if len(each) < 2:
49 |                 continue
50 |             if each[0:2] == "v ":
51 |                 temp = each.split()
52 |                 vertices.append([float(temp[1]), float(temp[2]), float(temp[3])])
53 |             if each[0:2] == "f ":
54 |                 temp = each.split()
55 |                 # faces may be in the "f v/vt/vn" format; keep only the vertex index
56 |                 temp[3] = temp[3].split("/")[0]
57 |                 temp[1] = temp[1].split("/")[0]
58 |                 temp[2] = temp[2].split("/")[0]
59 |                 triangles.append([int(temp[1]) - 1, int(temp[2]) - 1, int(temp[3]) - 1])
60 |     vertices = np.array(vertices)
61 |     triangles = np.array(triangles)
62 |
63 |
64 |     # Initialise the PyGeodesicAlgorithmExact class instance
65 |     geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
66 |
67 |     # Define the source and target point ids with respect to the points array
68 |     source_indices = np.array([source_index])
69 |     target_indices = None
70 |     distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
71 |
72 |
73 | def data_prepare_gen_dataset(object_file: str, output_path: str, num_sources, num_each_dest, tqdm_on=True):
74 |     vertices = []
75 |     triangles = []
76 |
77 |     with open(object_file, "r") as f:
78 |         lines = f.readlines()
79 |         for each in lines:
80 |             if len(each) < 2:
81 |                 continue
82 |             if each[0:2] == "v ":
83 |                 temp = each.split()
84 |                 vertices.append([float(temp[1]), float(temp[2]), float(temp[3])])
85 |             if each[0:2] == "f ":
86 |                 temp = each.split()
87 |                 # faces may be in the "f v/vt/vn" format; keep only the vertex index
88 |                 temp[3] = temp[3].split("/")[0]
89 |                 temp[1] = temp[1].split("/")[0]
90 |                 temp[2] = temp[2].split("/")[0]
91 |                 triangles.append([int(temp[1]) - 1, int(temp[2]) - 1, int(temp[3]) - 1])
92 |     vertices = np.array(vertices)
93 |     triangles = np.array(triangles)
94 |
95 |
96 |     # Initialise the PyGeodesicAlgorithmExact class instance
97 |     geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
98 |
99 |     result = np.array([[0, 0, 0]])
100 |
101 |     sources = np.random.randint(low=0, high=len(vertices), size=[num_sources])
102 |
103 |     # iterate over the sampled sources
104 |     # only the progress bar of process #0 will be displayed
105 |     # this should not be problematic or confusing on most homogeneous CPUs
106 |     it = tqdm(range(num_sources)) if tqdm_on else range(num_sources)
107 |     for i in it:
108 |         source_indices = np.array([sources[i]])
109 |         target_indices = np.random.randint(low=0, high=len(vertices), size=[num_each_dest])
110 |         if source_indices.max() >= len(vertices):
111 |             print("source index out of range:", source_indices.max(), len(vertices))
112 |         if target_indices.max() >= len(vertices):
113 |             print("target index out of range:", target_indices.max(), len(vertices))
114 |
115 |         distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
116 |
117 |         a = source_indices.repeat([num_each_dest]).reshape([-1, 1])
118 |         b = target_indices.reshape([-1, 1])
119 |         c = distances.reshape([-1, 1])
120 |         new = np.concatenate([a, b, c], -1)
121 |         result = np.concatenate([result, new])
122 |
123 |     np.savetxt(output_path, result)
124 |
125 |
126 | #############################################
127 | def computation_thread(filename, object_name, a, b, c, d, idx=None):
128 |     assert idx is not None, "an idx has to be given"
129 |     tqdm_on = False
130 |     if idx == 0:
131 |         tqdm_on = True
132 |     print(filename, object_name)
133 |     data_prepare_gen_dataset(filename, object_name + "_train_" + str(idx), a, b, tqdm_on=tqdm_on)
134 |     # data_prepare_gen_dataset(filename, "utils/dataset_prepare/outputs/" + object_name + "_test_" + str(idx), c, d, tqdm_on=tqdm_on)
135 |
136 |
137 | ##############################################
138 |
139 |
140 |
141 |
142 | if __name__ == "__main__":
143 |     '''
144 |
145 |
146 |     a: on the training set, how many sources to randomly sample.
147 |     b: on the training set, for each source, how many destinations to randomly sample.
148 |     c: on the testing set, how many sources to randomly sample.
149 |     d: on the testing set, for each source, how many destinations to randomly sample.
150 |     threads: how many threads to use. 0 means all cores.
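    For example, with the defaults below (a = 300 and b = 800), each mesh
    contributes 300 * 800 = 240,000 source-destination training pairs.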
151 |
152 |     '''
153 |
154 |     #############################################################
155 |     object_name = None
156 |     a = 300
157 |     b = 800
158 |     c = 400
159 |     d = 60
160 |     file_size_threshold = 12_048_576  # a threshold to filter out large meshes
161 |     threads = 1
162 |     #############################################################
163 |
164 |     assert threads >= 0 and isinstance(threads, int)
165 |     if threads == 0:
166 |         threads = mp.cpu_count()
167 |         print(f"Automatically utilizing all CPU cores ({threads})")
168 |     else:
169 |         print(f"{threads} CPU cores are utilized!")
170 |
171 |     # make the output dirs if they do not exist
172 |     if not os.path.exists(PATH_TO_OUTPUT_NPZ):
173 |         os.mkdir(PATH_TO_OUTPUT_NPZ)
174 |     if not os.path.exists(PATH_TO_OUTPUT_FILELIST):
175 |         os.mkdir(PATH_TO_OUTPUT_FILELIST)
176 |
177 |     all_files = []
178 |     for mesh in os.listdir(PATH_TO_MESH):
179 |         # skip files that are too large
180 |         if os.path.getsize(PATH_TO_MESH + mesh) < file_size_threshold:
181 |             all_files.append(PATH_TO_MESH + mesh)
182 |
183 |
184 |
185 |     #all_files = os.listdir("./inputs")
186 |
187 |     object_names = all_files
188 |     for i in range(len(object_names)):
189 |         if object_names[i][-4:] == ".obj":
190 |             object_names[i] = object_names[i][:-4]
191 |
192 |
193 |     print(f"Current dir: {os.getcwd()}, objects to be processed: {len(object_names)}")
194 |
195 |
196 |     # process each mesh; skip it if its output file already exists
197 |     for i in tqdm(range(len(object_names))):
198 |         object_name = object_names[i]
199 |         if object_name.split("/")[-1][0] == ".":
200 |             continue  # not an obj file
201 |
202 |         filename_out = PATH_TO_OUTPUT_NPZ + object_name + ".npz"
203 |         if os.path.exists(filename_out):
204 |             continue
205 |
206 |         filename = object_name + ".obj"
207 |
208 |
209 |         train_data_filename_list = []
210 |         test_data_filename_list = []
211 |
212 |
213 |         pool = []
214 |
215 |         for t in range(threads):
216 |             task = mp.Process(target=computation_thread, args=(filename, object_name, a // threads, b, c // threads, d, t,))
217 |             task.start()
218 |             pool.append(task)
219 |         for t, task in enumerate(pool):
220 |             task.join()
221 |             train_data_filename_list.append(object_name + "_train_" + str(t))
222 |             #test_data_filename_list.append("./utils/dataset_prepare/outputs/" + object_name + "_test_" + str(t))
223 |
224 |         # merge the results of all worker processes
225 |         try:
226 |             for j in range(len(train_data_filename_list)):
227 |                 # train data
228 |                 with open(object_name + "_train_" + str(j), "r") as f:
229 |                     data = f.read()
230 |                 with open(object_name + "_train", "a") as f:
231 |                     f.write(data)
232 |         except Exception:
233 |             # mostly caused by a non-manifold mesh (the PyGeodesicAlgorithmExact class instance failed to initialise)
234 |             continue
235 |
236 |
237 |         # clean up the intermediate files
238 |         for each in (train_data_filename_list + test_data_filename_list):
239 |             os.remove(each)
240 |
241 |         filename_in = object_name + ".obj"
242 |         dist_in = object_name + "_train"
243 |         filename_out = PATH_TO_OUTPUT_NPZ + object_name.split("/")[-1] + ".npz"
244 |         try:
245 |             # Due to the issue below, the trimesh loader sometimes returns a wrong number of vertices
246 |             # (which would cause out-of-range indices during training):
247 |             # https://github.com/mikedh/trimesh/issues/489
248 |             # See the sanity check below: for meshes that trigger this error, we take the simplest
249 |             # solution and discard the mesh. The problem seems more likely when the mesh was
250 |             # simplified with QEM; the failure rate is roughly 0.1% to 1%.
251 |             # The following mesh is an example: (1.3 remeshing + 0.85 QEM)
252 |             # /mnt/sdb/pangbo/gnn-dist/utils/dataset_prepare/simplified_obj/02828884/26f583c91e815e8fcb2a965e75be701c.obj
253 |             mesh = trimesh.load_mesh(filename_in)
254 |             dist = np.loadtxt(dist_in)
255 |         except Exception:
256 |             print(f"load {filename_in} or {dist_in} failed...")
257 |             continue
258 |
259 |         # delete the merged intermediate file
260 |         os.remove(dist_in)
261 |
262 |         # additionally, save the graph topology (edges)
263 |         # trimesh's .edges_unique gives all edges of the mesh, without duplicates
264 |         aa = mesh.edges_unique
265 |         # we need bidirectional edges, so a little extra processing is required
266 |         bb = np.concatenate([aa[:, 1:], aa[:, :1]], 1)
267 |         cc = np.concatenate([aa, bb])
268 |
269 |         # sanity check
270 |         vertices = mesh.vertices
271 |         if dist.max() > 100000000:
272 |             print("inf encountered!!")
273 |         elif ((dist.astype(np.float32).max()) >= vertices.shape[0]):
274 |             print(object_name, "encountered the trimesh loading error!!")
275 |             continue
276 |
277 |         np.savez(filename_out,
278 |                  edges=cc,
279 |                  vertices=mesh.vertices.astype(np.float32),
280 |                  normals=mesh.vertex_normals.astype(np.float32),
281 |                  faces=mesh.faces.astype(np.float32),
282 |                  dist_val=dist[:, 2:].astype(np.float32),
283 |                  dist_idx=dist[:, :2].astype(np.uint16),
284 |                  )
285 |
286 |
287 |     print("\nnpz data generation finished. Now generating filelist...\n")
288 |     # finally, generate the filelists
289 |     lines = []
290 |     for each in tqdm(object_names):
291 |         filename_out = PATH_TO_OUTPUT_NPZ + each.split("/")[-1] + ".npz"
292 |         try:
293 |             dist = np.load(filename_out)
294 |             # sanity check
295 |             if dist['dist_val'].max() != np.inf and dist['dist_val'].max() < 1000000000:
296 |                 lines.append(filename_out + "\n")
297 |             else:
298 |                 print(f"{filename_out} not good, contains inf!")
299 |                 continue
300 |         except Exception:
301 |             print(f"load {filename_out} failed for unknown reason.")
302 |             continue
303 |
304 |     import random
305 |     random.shuffle(lines)
306 |     # split into train and test
307 |     train_num = int(len(lines) * TRAINING_SPLIT_RATIO)
308 |     test_num = len(lines) - train_num
309 |     train_lines = lines[:train_num]
310 |     test_lines = lines[train_num:]
311 |
312 |     with open(PATH_TO_OUTPUT_FILELIST + 'filelist_train.txt', 'w') as f:
313 |         f.writelines(train_lines)
314 |
315 |     with open(PATH_TO_OUTPUT_FILELIST + 'filelist_test.txt', 'w') as f:
316 |         f.writelines(test_lines)
--------------------------------------------------------------------------------
/utils/dataset_prepare/simplifiy.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import pymeshlab
5 | from tqdm import tqdm
6 | import multiprocessing as mp
7 | from threading import Thread
8 |
9 | def walk_dir(dir):
10 |     full_path = []
11 |     filename = []
12 |     for root, dirs, files in os.walk(dir):
13 |         for file in files:
14 |             if file.endswith(".obj"):
15 |                 full_path.append(os.path.join(root, file))
16 |                 filename.append(file)
17 |     return [full_path, filename]
18 |
19 |
20 | def simplify_mesh(input_mesh, output_mesh):
21 |     if not os.path.exists(input_mesh):
22 |         return
23 |     if os.path.exists(output_mesh):
24 |         return
25 |
26 |     ms = pymeshlab.MeshSet()
27 |     ms.load_new_mesh(input_mesh)
28 |     m = ms.current_mesh()
29 |     num_v = m.vertex_matrix().shape[0]
30 |     num_f = m.face_matrix().shape[0]
31 |
32 |     # choose the simplification strategy
33 |     simplify_method = "combined"
34 |
35 |     if simplify_method == "QEM":
36 |         # not recommended: QEM may produce very strange topology/degenerate triangles, which is not good for training
37 |         ms.meshing_decimation_quadric_edge_collapse(targetfacenum=2000 + num_f // 4)
38 |     elif simplify_method == "remesh":
39 |         # the distribution of triangles is uniform; not ideal, since we want the network to be "exposed" to more complex triangulations
40 |         ms.meshing_isotropic_explicit_remeshing(iterations=6, targetlen=pymeshlab.Percentage(1.5))
41 |     elif simplify_method == "combined":
42 |         # used in our paper
43 |         ms.meshing_isotropic_explicit_remeshing(iterations=6, targetlen=pymeshlab.Percentage(1.35))
44 |         ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(num_f * 0.85))
45 |     else:
46 |         raise NotImplementedError
47 |
48 |     ms.save_current_mesh(output_mesh)
49 |
50 |
51 |
52 | if __name__ == "__main__":
53 |
54 |     threads = 4
55 |     dir = "path/to/your/dataset"
56 |     output_dir = "path/to/your/output/directory"
57 |
58 |
59 |     [full_path, filename] = walk_dir(dir)
60 |
61 |     try:
62 |         os.mkdir(output_dir)
63 |     except FileExistsError:
64 |         pass
65 |
66 |     # keep at most `threads` simplification workers in flight at a time
67 |     pool = []
68 |     for i in tqdm(range(len(full_path))):
69 |         input_mesh = full_path[i]
70 |         output_mesh = output_dir + filename[i]
71 |         task = mp.Process(target=simplify_mesh, args=(input_mesh, output_mesh))
72 |         task.start()
73 |         pool.append(task)
74 |
75 |         if len(pool) >= threads:
76 |             for task in pool:
77 |                 task.join()
78 |             pool = []
79 |
80 |     # wait for any remaining workers
81 |     for task in pool:
82 |         task.join()
--------------------------------------------------------------------------------
/utils/ocnn/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from . import octree
9 | from . import nn
10 | from . import modules
11 | from . import models
12 | from . import dataset
13 | from . import utils
14 |
15 |
16 | __version__ = '2.1.8'
17 |
18 | __all__ = [
19 |     'octree',
20 |     'nn',
21 |     'modules',
22 |     'models',
23 |     'dataset',
24 |     'utils'
25 | ]
26 |
--------------------------------------------------------------------------------
/utils/ocnn/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | from utils import ocnn
11 | from utils.ocnn.octree import Octree, Points
12 |
13 |
14 | __all__ = ['Transform', 'CollateBatch']
15 | classes = __all__
16 |
17 |
18 | class Transform:
19 |   r''' A boilerplate class which transforms input data for :obj:`ocnn`.
20 |   The input data is first converted to :class:`Points`, then randomly transformed
21 |   (if enabled), and finally converted to an :class:`Octree`.
22 |
23 |   Args:
24 |     depth (int): The octree depth.
25 |     full_depth (int): The octree layers with a depth smaller than
26 |       :attr:`full_depth` are forced to be full.
27 |     distort (bool): If true, performs data augmentation.
28 |     angle (list): A list of 3 float values to generate random rotation angles.
29 |     interval (list): A list of 3 float values to represent the interval of
30 |       rotation angles.
31 |     scale (float): The maximum relative scale factor.
32 |     uniform (bool): If true, performs uniform scaling.
33 |     jitter (float): The maximum jitter value.
34 |     orient_normal (str): Orient point normals along the specified axis, which is
35 |       useful when normals are not oriented.
36 |   '''
37 |
38 |   def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
39 |                interval: list, scale: float, uniform: bool, jitter: float,
40 |                orient_normal: str = '', **kwargs):
41 |     super().__init__()
42 |
43 |     # for octree building
44 |     self.depth = depth
45 |     self.full_depth = full_depth
46 |
47 |     # for data augmentation
48 |     self.distort = distort
49 |     self.angle = angle
50 |     self.interval = interval
51 |     self.scale = scale
52 |     self.uniform = uniform
53 |     self.jitter = jitter
54 |
55 |     # for other transformations
56 |     self.orient_normal = orient_normal
57 |
58 |   def __call__(self, sample: dict, idx: int):
59 |     r''''''
60 |
61 |     points = self.preprocess(sample, idx)
62 |     output = self.transform(points, idx)
63 |     output['octree'] = self.points2octree(output['points'])
64 |     return output
65 |
66 |   def preprocess(self, sample: dict, idx: int):
67 |     r''' Transforms :attr:`sample` to :class:`Points` and performs some specific
68 |     transformations, like normalization.
69 |     '''
70 |
71 |     xyz = torch.from_numpy(sample['points'])
72 |     normals = torch.from_numpy(sample['normals'])
73 |     points = Points(xyz, normals)
74 |     return points
75 |
76 |   def transform(self, points: Points, idx: int):
77 |     r''' Applies the general transformations provided by :obj:`ocnn`.
78 |     '''
79 |
80 |     # The augmentations include rotation, scaling, and jittering.
81 |     if self.distort:
82 |       rng_angle, rng_scale, rng_jitter = self.rnd_parameters()
83 |       points.rotate(rng_angle)
84 |       points.translate(rng_jitter)
85 |       points.scale(rng_scale)
86 |
87 |     if self.orient_normal:
88 |       points.orient_normal(self.orient_normal)
89 |
90 |     # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
91 |     inbox_mask = points.clip(min=-1, max=1)
92 |     return {'points': points, 'inbox_mask': inbox_mask}
93 |
94 |   def points2octree(self, points: Points):
95 |     r''' Converts the input :attr:`points` to an octree.
96 |     '''
97 |
98 |     octree = Octree(self.depth, self.full_depth)
99 |     octree.build_octree(points)
100 |     return octree
101 |
102 |   def rnd_parameters(self):
103 |     r''' Generates random parameters for data augmentation.
104 |     '''
105 |
106 |     rnd_angle = [None] * 3
107 |     for i in range(3):
108 |       rot_num = self.angle[i] // self.interval[i]
109 |       rnd = torch.randint(low=-rot_num, high=rot_num+1, size=(1,))
110 |       rnd_angle[i] = rnd * self.interval[i] * (3.14159265 / 180.0)
111 |     rnd_angle = torch.cat(rnd_angle)
112 |
113 |     rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
114 |     if self.uniform:
115 |       rnd_scale[1] = rnd_scale[0]
116 |       rnd_scale[2] = rnd_scale[0]
117 |
118 |     rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
119 |     return rnd_angle, rnd_scale, rnd_jitter
120 |
121 |
122 | class CollateBatch:
123 |   r''' Merge a list of octrees and points into a batch.
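
  A minimal usage sketch (assuming a dataset whose samples are dicts with
  :obj:`octree`, :obj:`points`, and :obj:`label` entries)::

      from torch.utils.data import DataLoader
      loader = DataLoader(dataset, batch_size=8,
                          collate_fn=CollateBatch(merge_points=True))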
124 |   '''
125 |
126 |   def __init__(self, merge_points: bool = False):
127 |     self.merge_points = merge_points
128 |
129 |   def __call__(self, batch: list):
130 |     assert isinstance(batch, list)
131 |
132 |     outputs = {}
133 |     for key in batch[0].keys():
134 |       outputs[key] = [b[key] for b in batch]
135 |
136 |       # Merge a batch of octrees into one super octree
137 |       if 'octree' in key:
138 |         octree = ocnn.octree.merge_octrees(outputs[key])
139 |         # NOTE: remember to construct the neighbor indices
140 |         octree.construct_all_neigh()
141 |         outputs[key] = octree
142 |
143 |       # Merge a batch of points
144 |       if 'points' in key and self.merge_points:
145 |         outputs[key] = ocnn.octree.merge_points(outputs[key])
146 |
147 |       # Convert the labels to a Tensor
148 |       if 'label' in key:
149 |         outputs['label'] = torch.tensor(outputs[key])
150 |
151 |     return outputs
152 |
--------------------------------------------------------------------------------
/utils/ocnn/models/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .lenet import LeNet
9 | from .resnet import ResNet
10 | from .segnet import SegNet
11 | from .octree_unet import OctreeUnet
12 | from .hrnet import HRNet
13 | from .autoencoder import AutoEncoder
14 |
15 | __all__ = [
16 |     'LeNet',
17 |     'ResNet',
18 |     'SegNet',
19 |     'OctreeUnet',
20 |     'HRNet',
21 |     'AutoEncoder',
22 | ]
23 |
24 | classes = __all__
25 |
--------------------------------------------------------------------------------
/utils/ocnn/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 |
11 | from utils import ocnn
12 | from utils.ocnn.octree import Octree
13 |
14 |
15 | class AutoEncoder(torch.nn.Module):
16 |   r''' Octree-based AutoEncoder for shape encoding and decoding.
17 |
18 |   Args:
19 |     channel_in (int): The channel of the input signal.
20 |     channel_out (int): The channel of the output signal.
21 |     depth (int): The depth of the octree.
22 |     full_depth (int): The full depth of the octree.
23 |     feature (str): The feature type of the input signal. For details of this
24 |       argument, please refer to :class:`ocnn.modules.InputFeature`.
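      For example, with the default :obj:`feature='ND'`, the expected
      :attr:`channel_in` is 4 (3 channels for the normals plus 1 for the
      local displacement).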
25 | ''' 26 | 27 | def __init__(self, channel_in: int, channel_out: int, depth: int, 28 | full_depth: int = 2, feature: str = 'ND'): 29 | super().__init__() 30 | self.channel_in = channel_in 31 | self.channel_out = channel_out 32 | self.depth = depth 33 | self.full_depth = full_depth 34 | self.feature = feature 35 | self.resblk_num = 2 36 | self.shape_code_channel = 128 37 | self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16] 38 | 39 | # encoder 40 | self.conv1 = ocnn.modules.OctreeConvBnRelu( 41 | channel_in, self.channels[depth], nempty=False) 42 | self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks( 43 | self.channels[d], self.channels[d], self.resblk_num, nempty=False) 44 | for d in range(depth, full_depth-1, -1)]) 45 | self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu( 46 | self.channels[d], self.channels[d-1], kernel_size=[2], stride=2, 47 | nempty=False) for d in range(depth, full_depth, -1)]) 48 | self.proj = torch.nn.Linear( 49 | self.channels[full_depth], self.shape_code_channel, bias=True) 50 | 51 | # decoder 52 | self.channels[full_depth] = self.shape_code_channel # update `channels` 53 | self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu( 54 | self.channels[d-1], self.channels[d], kernel_size=[2], stride=2, 55 | nempty=False) for d in range(full_depth+1, depth+1)]) 56 | self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks( 57 | self.channels[d], self.channels[d], self.resblk_num, nempty=False) 58 | for d in range(full_depth, depth+1)]) 59 | 60 | # header 61 | self.predict = torch.nn.ModuleList([self._make_predict_module( 62 | self.channels[d], 2) for d in range(full_depth, depth + 1)]) 63 | self.header = self._make_predict_module(self.channels[depth], channel_out) 64 | 65 | def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64): 66 | return torch.nn.Sequential( 67 | ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden), 68 | ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True)) 69 | 70 | def get_input_feature(self, octree: Octree): 71 | r''' Get the input feature from the input `octree`. 72 | ''' 73 | 74 | octree_feature = ocnn.modules.InputFeature(self.feature, nempty=False) 75 | out = octree_feature(octree) 76 | assert out.size(1) == self.channel_in 77 | return out 78 | 79 | def ae_encoder(self, octree: Octree): 80 | r''' The encoder network of the AutoEncoder. 81 | ''' 82 | 83 | convs = dict() 84 | depth, full_depth = self.depth, self.full_depth 85 | data = self.get_input_feature(octree) 86 | convs[depth] = self.conv1(data, octree, depth) 87 | for i, d in enumerate(range(depth, full_depth-1, -1)): 88 | convs[d] = self.encoder_blks[i](convs[d], octree, d) 89 | if d > full_depth: 90 | convs[d-1] = self.downsample[i](convs[d], octree, d) 91 | 92 | # NOTE: here tanh is used to constrain the shape code in [-1, 1] 93 | shape_code = self.proj(convs[full_depth]).tanh() 94 | return shape_code 95 | 96 | def ae_decoder(self, shape_code: torch.Tensor, octree: Octree, 97 | update_octree: bool = False): 98 | r''' The decoder network of the AutoEncoder. 
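
    When :attr:`update_octree` is true, the splitting labels predicted at each
    depth are used to grow the octree level by level, so the decoder can build
    an output octree from the shape code alone (as done in :obj:`decode_code`).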
99 |     '''
100 |
101 |     logits = dict()
102 |     deconv = shape_code
103 |     depth, full_depth = self.depth, self.full_depth
104 |     for i, d in enumerate(range(full_depth, depth+1)):
105 |       if d > full_depth:
106 |         deconv = self.upsample[i-1](deconv, octree, d-1)
107 |       deconv = self.decoder_blks[i](deconv, octree, d)
108 |
109 |       # predict the splitting label
110 |       logit = self.predict[i](deconv)
111 |       logits[d] = logit
112 |
113 |       # update the octree according to predicted labels
114 |       if update_octree:
115 |         split = logit.argmax(1).int()
116 |         octree.octree_split(split, d)
117 |         if d < depth:
118 |           octree.octree_grow(d + 1)
119 |
120 |       # predict the signal
121 |       if d == depth:
122 |         signal = self.header(deconv)
123 |         signal = torch.tanh(signal)
124 |         signal = ocnn.nn.octree_depad(signal, octree, depth)
125 |         if update_octree:
126 |           octree.features[depth] = signal
127 |
128 |     return {'logits': logits, 'signal': signal, 'octree_out': octree}
129 |
130 |   def decode_code(self, shape_code: torch.Tensor):
131 |     r''' Decodes the shape code to an output octree.
132 |
133 |     Args:
134 |       shape_code (torch.Tensor): The shape code for decoding.
135 |     '''
136 |
137 |     octree_out = self.init_octree(shape_code)
138 |     out = self.ae_decoder(shape_code, octree_out, update_octree=True)
139 |     return out
140 |
141 |   def init_octree(self, shape_code: torch.Tensor):
142 |     r''' Initializes a full octree for decoding.
143 |
144 |     Args:
145 |       shape_code (torch.Tensor): The shape code for decoding, used to get
146 |         the `batch_size` and `device` to initialize the output octree.
147 |     '''
148 |
149 |     device = shape_code.device
150 |     node_num = 2 ** (3 * self.full_depth)
151 |     batch_size = shape_code.size(0) // node_num
152 |     octree = Octree(self.depth, self.full_depth, batch_size, device)
153 |     for d in range(self.full_depth+1):
154 |       octree.octree_grow_full(depth=d)
155 |     return octree
156 |
157 |   def forward(self, octree: Octree, update_octree: bool):
158 |     r''''''
159 |
160 |     shape_code = self.ae_encoder(octree)
161 |     if update_octree:
162 |       octree = self.init_octree(shape_code)
163 |     out = self.ae_decoder(shape_code, octree, update_octree)
164 |     return out
165 |
--------------------------------------------------------------------------------
/utils/ocnn/models/hrnet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import List
10 |
11 | from utils import ocnn
12 | from utils.ocnn.octree import Octree
13 |
14 |
15 | class Branches(torch.nn.Module):
16 |
17 |   def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
18 |     super().__init__()
19 |     self.channels = channels
20 |     self.resblk_num = resblk_num
21 |     bottlenecks = [4 if c < 256 else 8 for c in channels]  # to save parameters
22 |     self.resblocks = torch.nn.ModuleList([
23 |         ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
24 |         for ch, bnk in zip(channels, bottlenecks)])
25 |
26 |   def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
27 |     num = len(self.channels)
28 |     torch._assert(len(datas) == num, 'the number of inputs must equal the number of branches')
29 |
30 |     out = [None] * num
31 |     for i in range(num):
32 |       depth_i = depth - i
33 |       out[i] = self.resblocks[i](datas[i], octree, depth_i)
34 |     return out
35 |
36 |
37 | class TransFunc(torch.nn.Module):
38 |
39 |   def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
40 |     super().__init__()
41 |     self.in_channels = in_channels
42 |     self.out_channels = out_channels
43 |     self.nempty = nempty
44 |     if in_channels != out_channels:
45 |       self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)
46 |
47 |   def forward(self, data: torch.Tensor, octree: Octree,
48 |               in_depth: int, out_depth: int):
49 |     out = data
50 |     if in_depth > out_depth:
51 |       for d in range(in_depth, out_depth, -1):
52 |         out = ocnn.nn.octree_max_pool(out, octree, d, self.nempty)
53 |       if self.in_channels != self.out_channels:
54 |         out = self.conv1x1(out)
55 |
56 |     if in_depth < out_depth:
57 |       if self.in_channels != self.out_channels:
58 |         out = self.conv1x1(out)
59 |       for d in range(in_depth, out_depth, 1):
60 |         out = ocnn.nn.octree_upsample(out, octree, d, self.nempty)
61 |     return out
62 |
63 |
64 | class Transitions(torch.nn.Module):
65 |
66 |   def __init__(self, channels: List[int], nempty: bool = False):
67 |     super().__init__()
68 |     self.channels = channels
69 |     self.nempty = nempty
70 |
71 |     num = len(self.channels)
72 |     self.trans_func = torch.nn.ModuleList()
73 |     for i in range(num - 1):
74 |       for j in range(num):
75 |         self.trans_func.append(TransFunc(channels[i], channels[j], nempty))
76 |
77 |   def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
78 |     num = len(self.channels)
79 |     features = [[None] * (num - 1) for _ in range(num)]
80 |     for i in range(num - 1):
81 |       for j in range(num):
82 |         k = i * num + j
83 |         in_depth = depth - i
84 |         out_depth = depth - j
85 |         features[j][i] = self.trans_func[k](
86 |             data[i], octree, in_depth, out_depth)
87 |
88 |     out = [None] * num
89 |     for j in range(num):
90 |       # In the original tensorflow implementation, the relu is added here,
91 |       # instead of Line 77
92 |       out[j] = torch.stack(features[j], dim=0).sum(dim=0)
93 |     return out
94 |
95 |
96 | class FrontLayer(torch.nn.Module):
97 |
98 |   def __init__(self, channels: List[int], nempty: bool = False):
99 |     super().__init__()
100 |     self.channels = channels
101 |     self.num = len(channels) - 1
102 |     self.nempty = nempty
103 |
104 |     self.conv = torch.nn.ModuleList([
105 |         ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty)
106 |         for i in range(self.num)])
107 |     self.maxpool = torch.nn.ModuleList([
108 |         ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])
109 |
110 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
111 |     out = data
112 |     for i in range(self.num - 1):
113 |       depth_i = depth - i
114 |       out = self.conv[i](out, octree, depth_i)
115 |       out = self.maxpool[i](out, octree, depth_i)
116 |     out = self.conv[-1](out, octree, depth - self.num + 1)
117 |     return out
118 |
119 |
120 | class ClsHeader(torch.nn.Module):
121 |
122 |   def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
123 |     super().__init__()
124 |     self.channels = channels
125 |     self.out_channels = out_channels
126 |     self.nempty = nempty
127 |
128 |     in_channels = int(torch.Tensor(channels).sum())
129 |     self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
130 |     self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
131 |     self.header = torch.nn.Sequential(
132 |         torch.nn.Flatten(start_dim=1),
133 |         torch.nn.Linear(1024, out_channels, bias=True))
134 |     # self.header = torch.nn.Sequential(
135 |     #   ocnn.modules.FcBnRelu(512, 256),
136 |     #
torch.nn.Dropout(p=0.5), 137 | # torch.nn.Linear(256, out_channels)) 138 | 139 | def forward(self, data: List[torch.Tensor], octree: Octree, depth: int): 140 | full_depth = 2 141 | num = len(data) 142 | for i in range(num): 143 | depth_i = depth - i 144 | for d in range(depth_i, full_depth, -1): 145 | data[i] = ocnn.nn.octree_max_pool(data[i], octree, d, self.nempty) 146 | 147 | out = torch.cat(data, dim=1) 148 | out = self.conv1x1(out) 149 | out = self.global_pool(out, octree, full_depth) 150 | logit = self.header(out) 151 | return logit 152 | 153 | 154 | class HRNet(torch.nn.Module): 155 | r''' Octree-based HRNet for classification and segmentation. ''' 156 | 157 | def __init__(self, in_channels: int, out_channels: int, stages: int = 3, 158 | interp: str = 'linear', nempty: bool = False): 159 | super().__init__() 160 | self.in_channels = in_channels 161 | self.out_channels = out_channels 162 | self.interp = interp 163 | self.nempty = nempty 164 | self.stages = stages 165 | 166 | self.resblk_num = 3 167 | self.channels = [128, 256, 512, 512] 168 | 169 | self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty) 170 | self.branches = torch.nn.ModuleList([ 171 | Branches(self.channels[:i+1], self.resblk_num, nempty) 172 | for i in range(stages)]) 173 | self.transitions = torch.nn.ModuleList([ 174 | Transitions(self.channels[:i+2], nempty) 175 | for i in range(stages-1)]) 176 | 177 | self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty) 178 | 179 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 180 | r'''''' 181 | convs = [self.front(data, octree, depth)] 182 | depth = depth - 1 # the data is downsampled in `front` 183 | for i in range(self.stages): 184 | convs = self.branches[i](convs, octree, depth) 185 | if i < self.stages - 1: 186 | convs = self.transitions[i](convs, octree, depth) 187 | 188 | logits = self.cls_header(convs, octree, depth) 189 | 190 | return logits 191 | -------------------------------------------------------------------------------- /utils/ocnn/models/lenet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from utils import ocnn 10 | from utils.ocnn.octree import Octree 11 | 12 | 13 | class LeNet(torch.nn.Module): 14 | r''' Octree-based LeNet for classification. 
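
  The convolution channels follow :obj:`2 ** max(i+7-stages, 2)`; for example,
  with :attr:`stages` = 3, the three stages use 16, 32, and 64 channels.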
15 | ''' 16 | 17 | def __init__(self, in_channels: int, out_channels: int, stages: int, 18 | nempty: bool = False): 19 | super().__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.stages = stages 23 | self.nempty = nempty 24 | channels = [in_channels] + [2 ** max(i+7-stages, 2) for i in range(stages)] 25 | 26 | self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu( 27 | channels[i], channels[i+1], nempty=nempty) for i in range(stages)]) 28 | self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool( 29 | nempty) for i in range(stages)]) 30 | self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty) 31 | self.header = torch.nn.Sequential( 32 | torch.nn.Dropout(p=0.5), # drop1 33 | ocnn.modules.FcBnRelu(64 * 64, 128), # fc1 34 | torch.nn.Dropout(p=0.5), # drop2 35 | torch.nn.Linear(128, out_channels)) # fc2 36 | 37 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 38 | r'''''' 39 | 40 | for i in range(self.stages): 41 | d = depth - i 42 | data = self.convs[i](data, octree, d) 43 | data = self.pools[i](data, octree, d) 44 | data = self.octree2voxel(data, octree, depth-self.stages) 45 | data = self.header(data) 46 | return data 47 | -------------------------------------------------------------------------------- /utils/ocnn/models/octree_unet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn 10 | from typing import Dict 11 | 12 | from utils import ocnn 13 | from utils.ocnn.octree import Octree 14 | 15 | 16 | class OctreeUnet(torch.nn.Module): 17 | r''' Octree-based UNet for segmentation. 
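
  With the default configuration from :obj:`config_network`, the encoder
  channels are [32, 32, 64, 128, 256] with [2, 3, 4, 6] residual blocks per
  stage, and the decoder channels are [256, 256, 128, 96, 96] with
  [2, 2, 2, 2] blocks.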
18 | ''' 19 | 20 | def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear', 21 | nempty: bool = False, **kwargs): 22 | super(OctreeUnet, self).__init__() 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.nempty = nempty 26 | self.config_network() 27 | self.encoder_stages = len(self.encoder_blocks) 28 | self.decoder_stages = len(self.decoder_blocks) 29 | 30 | # encoder 31 | self.conv1 = ocnn.modules.OctreeConvBnRelu( 32 | in_channels, self.encoder_channel[0], nempty=nempty) 33 | self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu( 34 | self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2], 35 | stride=2, nempty=nempty) for i in range(self.encoder_stages)]) 36 | self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks( 37 | self.encoder_channel[i+1], self.encoder_channel[i + 1], 38 | self.encoder_blocks[i], self.bottleneck, nempty, self.resblk) 39 | for i in range(self.encoder_stages)]) 40 | 41 | # decoder 42 | channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2] 43 | for i in range(self.decoder_stages)] 44 | self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu( 45 | self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2], 46 | stride=2, nempty=nempty) for i in range(self.decoder_stages)]) 47 | self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks( 48 | channel[i], self.decoder_channel[i+1], 49 | self.decoder_blocks[i], self.bottleneck, nempty, self.resblk) 50 | for i in range(self.decoder_stages)]) 51 | 52 | # header 53 | # channel = self.decoder_channel[self.decoder_stages] 54 | self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty) 55 | self.header = torch.nn.Sequential( 56 | ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], 64), 57 | ocnn.modules.Conv1x1(64, self.out_channels, use_bias=True)) 58 | 59 | def config_network(self): 60 | r''' Configure the network channels and Resblock numbers. 61 | ''' 62 | 63 | self.encoder_channel = [32, 32, 64, 128, 256] 64 | self.decoder_channel = [256, 256, 128, 96, 96] 65 | self.encoder_blocks = [2, 3, 4, 6] 66 | self.decoder_blocks = [2, 2, 2, 2] 67 | self.bottleneck = 1 68 | self.resblk = ocnn.modules.OctreeResBlock2 69 | 70 | def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int): 71 | r''' The encoder of the U-Net. 72 | ''' 73 | 74 | convd = dict() 75 | convd[depth] = self.conv1(data, octree, depth) 76 | for i in range(self.encoder_stages): 77 | d = depth - i 78 | conv = self.downsample[i](convd[d], octree, d) 79 | convd[d-1] = self.encoder[i](conv, octree, d-1) 80 | return convd 81 | 82 | def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int): 83 | r''' The decoder of the U-Net. 
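
    At each stage, the upsampled features are concatenated with the encoder
    features at the same depth (skip connections) before being processed by
    the decoder blocks.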
84 | ''' 85 | 86 | deconv = convd[depth] 87 | for i in range(self.decoder_stages): 88 | d = depth + i 89 | deconv = self.upsample[i](deconv, octree, d) 90 | deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections 91 | deconv = self.decoder[i](deconv, octree, d+1) 92 | return deconv 93 | 94 | def forward(self, data: torch.Tensor, octree: Octree, depth: int, 95 | query_pts: torch.Tensor): 96 | r'''''' 97 | 98 | convd = self.unet_encoder(data, octree, depth) 99 | deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages) 100 | 101 | interp_depth = depth - self.encoder_stages + self.decoder_stages 102 | feature = self.octree_interp(deconv, octree, interp_depth, query_pts) 103 | logits = self.header(feature) 104 | return logits 105 | -------------------------------------------------------------------------------- /utils/ocnn/models/resnet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from utils import ocnn 10 | from utils.ocnn.octree import Octree 11 | 12 | 13 | class ResNet(torch.nn.Module): 14 | r''' Octree-based ResNet for classification. 15 | ''' 16 | 17 | def __init__(self, in_channels: int, out_channels: int, resblock_num: int, 18 | stages: int, nempty: bool = False): 19 | super().__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.resblk_num = resblock_num 23 | self.stages = stages 24 | self.nempty = nempty 25 | channels = [2 ** max(i+9-stages, 2) for i in range(stages)] 26 | 27 | self.conv1 = ocnn.modules.OctreeConvBnRelu( 28 | in_channels, channels[0], nempty=nempty) 29 | self.pool1 = ocnn.nn.OctreeMaxPool(nempty) 30 | self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks( 31 | channels[i], channels[i+1], resblock_num, nempty=nempty) 32 | for i in range(stages-1)]) 33 | self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool( 34 | nempty) for i in range(stages-1)]) 35 | self.global_pool = ocnn.nn.OctreeGlobalPool(nempty) 36 | # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True) 37 | self.header = torch.nn.Sequential( 38 | ocnn.modules.FcBnRelu(channels[-1], 512), 39 | torch.nn.Dropout(p=0.5), 40 | torch.nn.Linear(512, out_channels)) 41 | 42 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 43 | r'''''' 44 | 45 | data = self.conv1(data, octree, depth) 46 | data = self.pool1(data, octree, depth) 47 | for i in range(self.stages-1): 48 | d = depth - i - 1 49 | data = self.resblocks[i](data, octree, d) 50 | data = self.pools[i](data, octree, d) 51 | data = self.global_pool(data, octree, depth-self.stages) 52 | data = self.header(data) 53 | return data 54 | -------------------------------------------------------------------------------- /utils/ocnn/models/segnet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from utils import 
ocnn 10 | from utils.ocnn.octree import Octree 11 | 12 | 13 | class SegNet(torch.nn.Module): 14 | r''' Octree-based SegNet for segmentation. 15 | ''' 16 | 17 | def __init__(self, in_channels: int, out_channels: int, stages: int, 18 | interp: str = 'linear', nempty: bool = False, **kwargs): 19 | super().__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.stages = stages 23 | self.nempty = nempty 24 | return_indices = True 25 | 26 | channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)] 27 | channels = [in_channels] + channels_stages 28 | self.convs = torch.nn.ModuleList( 29 | [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty) 30 | for i in range(stages)]) 31 | self.pools = torch.nn.ModuleList( 32 | [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)]) 33 | 34 | self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1]) 35 | 36 | channels = channels_stages[::-1] + [channels_stages[0]] 37 | self.deconvs = torch.nn.ModuleList( 38 | [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty) 39 | for i in range(0, stages)]) 40 | self.unpools = torch.nn.ModuleList( 41 | [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)]) 42 | 43 | self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty) 44 | self.header = torch.nn.Sequential( 45 | ocnn.modules.Conv1x1BnRelu(channels[-1], 64), 46 | ocnn.modules.Conv1x1(64, out_channels, use_bias=True)) 47 | 48 | def forward(self, data: torch.Tensor, octree: Octree, depth: int, 49 | query_pts: torch.Tensor): 50 | r'''''' 51 | 52 | # encoder 53 | indices = dict() 54 | for i in range(self.stages): 55 | d = depth - i 56 | data = self.convs[i](data, octree, d) 57 | data, indices[d] = self.pools[i](data, octree, d) 58 | 59 | # bottleneck 60 | data = self.bottleneck(data, octree, depth-self.stages) 61 | 62 | # decoder 63 | for i in range(self.stages): 64 | d = depth - self.stages + i 65 | data = self.unpools[i](data, indices[d + 1], octree, d) 66 | data = self.deconvs[i](data, octree, d + 1) 67 | 68 | # header 69 | feature = self.octree_interp(data, octree, depth, query_pts) 70 | logits = self.header(feature) 71 | 72 | return logits 73 | -------------------------------------------------------------------------------- /utils/ocnn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from .modules import (InputFeature, 9 | OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu, 10 | Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,) 11 | from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks 12 | 13 | __all__ = [ 14 | 'InputFeature', 15 | 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu', 16 | 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu', 17 | 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks', 18 | ] 19 | 20 | classes = __all__ 21 | -------------------------------------------------------------------------------- /utils/ocnn/modules/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 
| # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from typing import List 11 | 12 | from utils import ocnn 13 | from utils.ocnn.nn import OctreeConv, OctreeDeconv 14 | from utils.ocnn.octree import Octree 15 | 16 | 17 | bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x 18 | # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch 19 | 20 | 21 | def ckpt_conv_wrapper(conv_op, data, octree): 22 | # The dummy tensor is a workaround when the checkpoint is used for the first conv layer: 23 | # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11 24 | dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) 25 | 26 | def conv_wrapper(data, octree, dummy_tensor): 27 | return conv_op(data, octree) 28 | 29 | return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy) 30 | 31 | 32 | class OctreeConvBn(torch.nn.Module): 33 | r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`. 34 | 35 | Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters. 36 | ''' 37 | 38 | def __init__(self, in_channels: int, out_channels: int, 39 | kernel_size: List[int] = [3], stride: int = 1, 40 | nempty: bool = False): 41 | super().__init__() 42 | self.conv = OctreeConv( 43 | in_channels, out_channels, kernel_size, stride, nempty) 44 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum) 45 | 46 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 47 | r'''''' 48 | 49 | out = self.conv(data, octree, depth) 50 | out = self.bn(out) 51 | return out 52 | 53 | 54 | class OctreeConvBnRelu(torch.nn.Module): 55 | r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`. 56 | 57 | Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters. 58 | ''' 59 | 60 | def __init__(self, in_channels: int, out_channels: int, 61 | kernel_size: List[int] = [3], stride: int = 1, 62 | nempty: bool = False): 63 | super().__init__() 64 | self.conv = OctreeConv( 65 | in_channels, out_channels, kernel_size, stride, nempty) 66 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum) 67 | self.relu = torch.nn.ReLU(inplace=True) 68 | 69 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 70 | r'''''' 71 | 72 | out = self.conv(data, octree, depth) 73 | out = self.bn(out) 74 | out = self.relu(out) 75 | return out 76 | 77 | 78 | class OctreeDeconvBnRelu(torch.nn.Module): 79 | r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`. 80 | 81 | Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters. 82 | ''' 83 | 84 | def __init__(self, in_channels: int, out_channels: int, 85 | kernel_size: List[int] = [3], stride: int = 1, 86 | nempty: bool = False): 87 | super().__init__() 88 | self.deconv = OctreeDeconv( 89 | in_channels, out_channels, kernel_size, stride, nempty) 90 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum) 91 | self.relu = torch.nn.ReLU(inplace=True) 92 | 93 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 94 | r'''''' 95 | 96 | out = self.deconv(data, octree, depth) 97 | out = self.bn(out) 98 | out = self.relu(out) 99 | return out 100 | 101 | 102 | class Conv1x1(torch.nn.Module): 103 | r''' Performs a convolution with kernel :obj:`(1,1,1)`. 
104 |
105 |   The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
106 |   number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
107 |   implemented with :class:`torch.nn.Linear`.
108 |   '''
109 |
110 |   def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
111 |     super().__init__()
112 |     self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
113 |
114 |   def forward(self, data: torch.Tensor):
115 |     r''''''
116 |
117 |     return self.linear(data)
118 |
119 |
120 | class Conv1x1Bn(torch.nn.Module):
121 |   r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
122 |   '''
123 |
124 |   def __init__(self, in_channels: int, out_channels: int):
125 |     super().__init__()
126 |     self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
127 |     self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
128 |
129 |   def forward(self, data: torch.Tensor):
130 |     r''''''
131 |
132 |     out = self.conv(data)
133 |     out = self.bn(out)
134 |     return out
135 |
136 |
137 | class Conv1x1BnRelu(torch.nn.Module):
138 |   r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
139 |   '''
140 |
141 |   def __init__(self, in_channels: int, out_channels: int):
142 |     super().__init__()
143 |     self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
144 |     self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
145 |     self.relu = torch.nn.ReLU(inplace=True)
146 |
147 |   def forward(self, data: torch.Tensor):
148 |     r''''''
149 |
150 |     out = self.conv(data)
151 |     out = self.bn(out)
152 |     out = self.relu(out)
153 |     return out
154 |
155 |
156 | class FcBnRelu(torch.nn.Module):
157 |   r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
158 |   '''
159 |
160 |   def __init__(self, in_channels: int, out_channels: int):
161 |     super().__init__()
162 |     self.flatten = torch.nn.Flatten(start_dim=1)
163 |     self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
164 |     self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
165 |     self.relu = torch.nn.ReLU(inplace=True)
166 |
167 |   def forward(self, data):
168 |     r''''''
169 |
170 |     out = self.flatten(data)
171 |     out = self.fc(out)
172 |     out = self.bn(out)
173 |     out = self.relu(out)
174 |     return out
175 |
176 |
177 | class InputFeature(torch.nn.Module):
178 |   r''' Returns the initial input feature stored in the octree.
179 |
180 |   Args:
181 |     feature (str): A string used to indicate which features to extract from the
182 |       input octree. If the character :obj:`N` is in :attr:`feature`, the
183 |       normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
184 |       :attr:`feature`, the local displacement is extracted (1 channel). If
185 |       :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
186 |       points in each octree node are extracted (3 channels). If :attr:`P` is in
187 |       :attr:`feature`, the global coordinates are extracted (3 channels). If
188 |       :attr:`F` is in :attr:`feature`, other features (like colors) are
189 |       extracted (k channels).
190 |     nempty (bool): If false, gets the features of all octree nodes.
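
  For example, :obj:`feature='NDL'` yields a 7-channel input: 3 channels for
  the normals, 1 for the local displacement, and 3 for the local coordinates.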
191 |   '''
192 |
193 |   def __init__(self, feature: str = 'NDF', nempty: bool = False):
194 |     super().__init__()
195 |     self.nempty = nempty
196 |     self.feature = feature.upper()
197 |
198 |   def forward(self, octree: Octree):
199 |     r''''''
200 |
201 |     features = list()
202 |     depth = octree.depth
203 |     if 'N' in self.feature:
204 |       features.append(octree.normals[depth])
205 |
206 |     if 'L' in self.feature or 'D' in self.feature:
207 |       local_points = octree.points[depth].frac() - 0.5
208 |
209 |     if 'D' in self.feature:
210 |       dis = torch.sum(local_points * octree.normals[depth], dim=1, keepdim=True)
211 |       features.append(dis)
212 |
213 |     if 'L' in self.feature:
214 |       features.append(local_points)
215 |
216 |     if 'P' in self.feature:
217 |       scale = 2 ** (1 - depth)  # normalize [0, 2^depth] -> [-1, 1]
218 |       global_points = octree.points[depth] * scale - 1.0
219 |       features.append(global_points)
220 |
221 |     if 'F' in self.feature:
222 |       features.append(octree.features[depth])
223 |
224 |     out = torch.cat(features, dim=1)
225 |     if not self.nempty:
226 |       out = ocnn.nn.octree_pad(out, octree, depth)
227 |     return out
228 |
229 |   def extra_repr(self) -> str:
230 |     r''''''
231 |     return 'feature={}, nempty={}'.format(self.feature, self.nempty)
232 |
--------------------------------------------------------------------------------
/utils/ocnn/modules/resblocks.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 |
11 | from utils.ocnn.octree import Octree
12 | from utils.ocnn.nn import OctreeMaxPool
13 | from utils.ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
14 |
15 |
16 | class OctreeResBlock(torch.nn.Module):
17 |   r''' Octree-based ResNet block in a bottleneck style. The block is composed of
18 |   a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`.
19 |
20 |   Args:
21 |     in_channels (int): Number of input channels.
22 |     out_channels (int): Number of output channels.
23 |     stride (int): The stride of the block (:obj:`1` or :obj:`2`).
24 |     bottleneck (int): The input and output channels of the :obj:`Conv3x3` are
25 |       equal to :attr:`out_channels` divided by :attr:`bottleneck`.
26 |     nempty (bool): If True, only performs the convolution on non-empty
27 |       octree nodes.
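
  For example, with :attr:`out_channels` = 256 and :attr:`bottleneck` = 4,
  the inner :obj:`Conv3x3` operates on 256 / 4 = 64 channels.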
28 |   '''
29 | 
30 |   def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
31 |                bottleneck: int = 4, nempty: bool = False):
32 |     super().__init__()
33 |     self.in_channels = in_channels
34 |     self.out_channels = out_channels
35 |     self.bottleneck = bottleneck
36 |     self.stride = stride
37 |     channelb = int(out_channels / bottleneck)
38 | 
39 |     if self.stride == 2:
40 |       self.max_pool = OctreeMaxPool(nempty)
41 |     self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
42 |     self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
43 |     self.conv1x1b = Conv1x1Bn(channelb, out_channels)
44 |     if self.in_channels != self.out_channels:
45 |       self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
46 |     self.relu = torch.nn.ReLU(inplace=True)
47 | 
48 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
49 |     r''''''
50 | 
51 |     if self.stride == 2:
52 |       data = self.max_pool(data, octree, depth)
53 |       depth = depth - 1
54 |     conv1 = self.conv1x1a(data)
55 |     conv2 = self.conv3x3(conv1, octree, depth)
56 |     conv3 = self.conv1x1b(conv2)
57 |     if self.in_channels != self.out_channels:
58 |       data = self.conv1x1c(data)
59 |     out = self.relu(conv3 + data)
60 |     return out
61 | 
62 | 
63 | class OctreeResBlock2(torch.nn.Module):
64 |   r''' Basic Octree-based ResNet block. The block is composed of
65 |   a series of :obj:`Conv3x3` and :obj:`Conv3x3`.
66 | 
67 |   Refer to :class:`OctreeResBlock` for the details of arguments.
68 |   '''
69 | 
70 |   def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
71 |                nempty=False):
72 |     super().__init__()
73 |     self.in_channels = in_channels
74 |     self.out_channels = out_channels
75 |     self.stride = stride
76 |     channelb = int(out_channels / bottleneck)
77 | 
78 |     if self.stride == 2:
79 |       self.maxpool = OctreeMaxPool(nempty)
80 |     self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
81 |     self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
82 |     if self.in_channels != self.out_channels:
83 |       self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
84 |     self.relu = torch.nn.ReLU(inplace=True)
85 | 
86 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
87 |     r''''''
88 | 
89 |     if self.stride == 2:
90 |       data = self.maxpool(data, octree, depth)
91 |       depth = depth - 1
92 |     conv1 = self.conv3x3a(data, octree, depth)
93 |     conv2 = self.conv3x3b(conv1, octree, depth)
94 |     if self.in_channels != self.out_channels:
95 |       data = self.conv1x1(data)
96 |     out = self.relu(conv2 + data)
97 |     return out
98 | 
99 | 
100 | class OctreeResBlocks(torch.nn.Module):
101 |   r''' A sequence of :attr:`resblk_num` ResNet blocks.
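  A sketch of stacking three residual blocks (assuming prepared inputs)::

    blocks = OctreeResBlocks(64, 64, resblk_num=3)
    out = blocks(data, octree, depth)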
102 | ''' 103 | 104 | def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4, 105 | nempty=False, resblk=OctreeResBlock, use_checkpoint=False): 106 | super().__init__() 107 | self.resblk_num = resblk_num 108 | self.use_checkpoint = use_checkpoint 109 | channels = [in_channels] + [out_channels] * resblk_num 110 | 111 | self.resblks = torch.nn.ModuleList( 112 | [resblk(channels[i], channels[i+1], 1, bottleneck, nempty) 113 | for i in range(self.resblk_num)]) 114 | 115 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 116 | r'''''' 117 | 118 | for i in range(self.resblk_num): 119 | if self.use_checkpoint: 120 | data = torch.utils.checkpoint.checkpoint( 121 | self.resblks[i], data, octree, depth) 122 | else: 123 | data = self.resblks[i](data, octree, depth) 124 | return data 125 | -------------------------------------------------------------------------------- /utils/ocnn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from .octree2vox import octree2voxel, Octree2Voxel 9 | from .octree2col import octree2col, col2octree 10 | from .octree_pad import octree_pad, octree_depad 11 | from .octree_interp import (octree_nearest_pts, octree_linear_pts, 12 | OctreeInterp, OctreeUpsample) 13 | from .octree_pool import (octree_max_pool, OctreeMaxPool, 14 | octree_max_unpool, OctreeMaxUnpool, 15 | octree_global_pool, OctreeGlobalPool, 16 | octree_avg_pool, OctreeAvgPool,) 17 | from .octree_conv import OctreeConv, OctreeDeconv 18 | from .octree_dwconv import OctreeDWConv 19 | from .octree_norm import OctreeInstanceNorm, OctreeBatchNorm 20 | from .octree_drop import OctreeDropPath 21 | 22 | 23 | __all__ = [ 24 | 'octree2voxel', 25 | 'octree2col', 'col2octree', 26 | 'octree_pad', 'octree_depad', 27 | 'octree_nearest_pts', 'octree_linear_pts', 28 | 'octree_max_pool', 'octree_max_unpool', 29 | 'octree_global_pool', 'octree_avg_pool', 30 | 'Octree2Voxel', 31 | 'OctreeMaxPool', 'OctreeMaxUnpool', 32 | 'OctreeGlobalPool', 'OctreeAvgPool', 33 | 'OctreeConv', 'OctreeDeconv', 34 | 'OctreeDWConv', 35 | 'OctreeInterp', 'OctreeUpsample', 36 | 'OctreeInstanceNorm', 'OctreeBatchNorm', 37 | 'OctreeDropPath', 38 | ] 39 | 40 | classes = __all__ 41 | -------------------------------------------------------------------------------- /utils/ocnn/nn/octree2col.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn 10 | 11 | from utils.ocnn.octree import Octree 12 | from utils.ocnn.utils import scatter_add 13 | 14 | 15 | def octree2col(data: torch.Tensor, octree: Octree, depth: int, 16 | kernel_size: str = '333', stride: int = 1, nempty: bool = False): 17 | r''' Gathers the neighboring features for convolutions. 18 | 19 | Args: 20 | data (torch.Tensor): The input data. 21 | octree (Octree): The corresponding octree. 
22 | depth (int): The depth of current octree. 23 | kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`, 24 | :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and 25 | :obj:`313`. 26 | stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the 27 | stride is :obj:`2`, it always returns the neighborhood of the first 28 | siblings, and the number of elements of output tensor is 29 | :obj:`octree.nnum[depth] / 8`. 30 | nempty (bool): If True, only returns the neighborhoods of the non-empty 31 | octree nodes. 32 | ''' 33 | 34 | neigh = octree.get_neigh(depth, kernel_size, stride, nempty) 35 | size = (neigh.shape[0], neigh.shape[1], data.shape[1]) 36 | out = torch.zeros(size, dtype=data.dtype, device=data.device) 37 | valid = neigh >= 0 38 | out[valid] = data[neigh[valid]] # (N, K, C) 39 | return out 40 | 41 | 42 | def col2octree(data: torch.Tensor, octree: Octree, depth: int, 43 | kernel_size: str = '333', stride: int = 1, nempty: bool = False): 44 | r''' Scatters the convolution features to an octree. 45 | 46 | Please refer to :func:`octree2col` for the usage of function parameters. 47 | ''' 48 | 49 | neigh = octree.get_neigh(depth, kernel_size, stride, nempty) 50 | valid = neigh >= 0 51 | dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth] 52 | out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size) 53 | return out 54 | -------------------------------------------------------------------------------- /utils/ocnn/nn/octree2vox.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | from ..octree import Octree, key2xyz 11 | from .octree_pad import octree_depad 12 | 13 | 14 | def octree2voxel(data: torch.Tensor, octree: Octree, depth: int, 15 | nempty: bool = False): 16 | r''' Converts the input feature to full-voxel-based representation. 17 | 18 | Args: 19 | data (torch.Tensor): The input feature. 20 | octree (Octree): The corresponding octree. 21 | depth (int): The depth of current octree. 22 | nempty (bool): If True, :attr:`data` only contains the features of non-empty 23 | octree nodes. 24 | ''' 25 | 26 | key = octree.keys[depth] 27 | if nempty: 28 | key = octree_depad(key, octree, depth) 29 | x, y, z, b = key2xyz(key, depth) 30 | 31 | num = 1 << depth 32 | channel = data.shape[1] 33 | batch_size = octree.batch_size 34 | size = (batch_size, num, num, num, channel) 35 | vox = torch.zeros(size, dtype=data.dtype, device=data.device) 36 | vox[b, x, y, z] = data 37 | return vox 38 | 39 | 40 | class Octree2Voxel(torch.nn.Module): 41 | r''' Converts the input feature to full-voxel-based representation 42 | 43 | Please refer to :func:`octree2voxel` for details. 
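  A sketch (assuming ``data`` holds the features of all octree nodes at
  ``depth``)::

    to_voxel = Octree2Voxel(nempty=False)
    vox = to_voxel(data, octree, depth)  # (batch_size, 2^d, 2^d, 2^d, C)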
44 |   '''
45 | 
46 |   def __init__(self, nempty: bool = False):
47 |     super().__init__()
48 |     self.nempty = nempty
49 | 
50 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
51 |     r''''''
52 | 
53 |     return octree2voxel(data, octree, depth, self.nempty)
54 | 
55 |   def extra_repr(self) -> str:
56 |     return 'nempty={}'.format(self.nempty)
57 | 
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_conv.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import torch.nn
10 | from torch.autograd import Function
11 | from typing import List
12 | 
13 | from utils.ocnn.octree import Octree
14 | from utils.ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15 | from .octree2col import octree2col, col2octree
16 | from .octree_pad import octree_pad, octree_depad
17 | 
18 | 
19 | class OctreeConvBase:
20 | 
21 |   def __init__(self, in_channels: int, out_channels: int,
22 |                kernel_size: List[int] = [3], stride: int = 1,
23 |                nempty: bool = False, max_buffer: int = int(2e8)):
24 |     super().__init__()
25 |     self.in_channels = in_channels
26 |     self.out_channels = out_channels
27 |     self.kernel_size = resize_with_last_val(kernel_size)
28 |     self.kernel = list2str(self.kernel_size)
29 |     self.stride = stride
30 |     self.nempty = nempty
31 |     self.max_buffer = max_buffer  # about 200M
32 | 
33 |     self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34 |     self.in_conv = in_channels if self.is_conv_layer() else out_channels
35 |     self.out_conv = out_channels if self.is_conv_layer() else in_channels
36 |     self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37 | 
38 |   def is_conv_layer(self):
39 |     r''' Returns :obj:`True` to indicate this is a convolution layer.
40 |     '''
41 | 
42 |     raise NotImplementedError
43 | 
44 |   def setup(self, octree: Octree, depth: int):
45 |     r''' Sets up the shapes of each tensor.
46 |     This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47 |     and :obj:`weight_gemm`.
48 |     '''
49 | 
50 |     # The depth of tensors:
51 |     # in_depth and out_depth are the octree depths of the input and output
52 |     # data; neigh_depth is the octree depth of the neighborhood information
53 |     # as well as of the `col` data, and it always equals the depth of the
54 |     # larger of the two when doing octree2col or col2octree.
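    # For example, with stride 2 at input depth 5: a convolution outputs at
    # depth 4 and gathers neighborhoods at depth 5, while a deconvolution
    # outputs at depth 6 and gathers neighborhoods at depth 6.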
55 |     self.in_depth = depth
56 |     self.out_depth = depth
57 |     self.neigh_depth = depth
58 |     if self.stride == 2:
59 |       if self.is_conv_layer():
60 |         self.out_depth = depth - 1
61 |       else:
62 |         self.out_depth = depth + 1
63 |         self.neigh_depth = depth + 1
64 | 
65 |     # The height of tensors
66 |     if self.nempty:
67 |       self.in_h = octree.nnum_nempty[self.in_depth]
68 |       self.out_h = octree.nnum_nempty[self.out_depth]
69 |     else:
70 |       self.in_h = octree.nnum[self.in_depth]
71 |       self.out_h = octree.nnum[self.out_depth]
72 |       if self.stride == 2:
73 |         if self.is_conv_layer():
74 |           self.out_h = octree.nnum_nempty[self.out_depth]
75 |         else:
76 |           self.in_h = octree.nnum_nempty[self.in_depth]
77 |     self.in_shape = (self.in_h, self.in_channels)
78 |     self.out_shape = (self.out_h, self.out_channels)
79 | 
80 |     # The neighborhood indices
81 |     self.neigh = octree.get_neigh(
82 |         self.neigh_depth, self.kernel, self.stride, self.nempty)
83 | 
84 |     # The height and number of the temporary buffer
85 |     self.buffer_n = 1
86 |     self.buffer_h = self.neigh.shape[0]
87 |     ideal_size = self.buffer_h * self.kdim * self.in_conv
88 |     if ideal_size > self.max_buffer:
89 |       kc = self.kdim * self.in_conv            # make `max_buffer` divisible
90 |       max_buffer = self.max_buffer // kc * kc  # by `kc` with no remainder
91 |       self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92 |       self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93 |     self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94 | 
95 |   def check_and_init(self, data: torch.Tensor):
96 |     r''' Checks the input data and initializes the shape of output data.
97 |     '''
98 | 
99 |     # Check the shape of input data
100 |     check = tuple(data.shape) == self.in_shape
101 |     assert check, 'The shape of input data is wrong.'
102 | 
103 |     # Init the output data
104 |     out = data.new_zeros(self.out_shape)
105 |     return out
106 | 
107 |   def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
108 |                    weights: torch.Tensor):
109 |     r''' Performs the forward pass of octree-based convolution.
110 |     '''
111 | 
112 |     # Initialize the buffer
113 |     buffer = data.new_empty(self.buffer_shape)
114 | 
115 |     # Loop over each sub-matrix
116 |     for i in range(self.buffer_n):
117 |       start = i * self.buffer_h
118 |       end = (i + 1) * self.buffer_h
119 | 
120 |       # The boundary case in the last iteration
121 |       if end > self.neigh.shape[0]:
122 |         dis = end - self.neigh.shape[0]
123 |         end = self.neigh.shape[0]
124 |         buffer, _ = buffer.split([self.buffer_h-dis, dis])
125 | 
126 |       # Perform octree2col
127 |       neigh_i = self.neigh[start:end]
128 |       valid = neigh_i >= 0
129 |       buffer.fill_(0)
130 |       buffer[valid] = data[neigh_i[valid]]
131 | 
132 |       # The sub-matrix gemm
133 |       out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
134 | 
135 |     return out
136 | 
137 |   def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
138 |                     weights: torch.Tensor):
139 |     r''' Performs the backward pass of octree-based convolution.
140 | ''' 141 | 142 | # Loop over each sub-matrix 143 | for i in range(self.buffer_n): 144 | start = i * self.buffer_h 145 | end = (i + 1) * self.buffer_h 146 | 147 | # The boundary case in the last iteration 148 | if end > self.neigh.shape[0]: 149 | end = self.neigh.shape[0] 150 | 151 | # The sub-matrix gemm 152 | buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t()) 153 | buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2]) 154 | buffer = buffer.to(out.dtype) # for pytorch.amp 155 | 156 | # Performs col2octree 157 | neigh_i = self.neigh[start:end] 158 | valid = neigh_i >= 0 159 | out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out) 160 | 161 | return out 162 | 163 | def weight_gemm( 164 | self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor): 165 | r''' Computes the gradient of the weight matrix. 166 | ''' 167 | 168 | # Record the shape of out 169 | out_shape = out.shape 170 | out = out.flatten(0, 1) 171 | 172 | # Initialize the buffer 173 | buffer = data.new_empty(self.buffer_shape) 174 | 175 | # Loop over each sub-matrix 176 | for i in range(self.buffer_n): 177 | start = i * self.buffer_h 178 | end = (i + 1) * self.buffer_h 179 | 180 | # The boundary case in the last iteration 181 | if end > self.neigh.shape[0]: 182 | d = end - self.neigh.shape[0] 183 | end = self.neigh.shape[0] 184 | buffer, _ = buffer.split([self.buffer_h-d, d]) 185 | 186 | # Perform octree2col 187 | neigh_i = self.neigh[start:end] 188 | valid = neigh_i >= 0 189 | buffer.fill_(0) 190 | buffer[valid] = data[neigh_i[valid]] 191 | 192 | # Accumulate the gradient via gemm 193 | out.addmm_(buffer.flatten(1, 2).t(), grad[start:end]) 194 | 195 | return out.view(out_shape) 196 | 197 | 198 | class _OctreeConv(OctreeConvBase): 199 | r''' Instantiates _OctreeConvBase by overriding `is_conv_layer` 200 | ''' 201 | 202 | def is_conv_layer(self): return True 203 | 204 | 205 | class _OctreeDeconv(OctreeConvBase): 206 | r''' Instantiates _OctreeConvBase by overriding `is_conv_layer` 207 | ''' 208 | 209 | def is_conv_layer(self): return False 210 | 211 | 212 | class OctreeConvFunction(Function): 213 | r''' Wrap the octree convolution for auto-diff. 214 | ''' 215 | 216 | @staticmethod 217 | def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree, 218 | depth: int, in_channels: int, out_channels: int, 219 | kernel_size: List[int] = [3, 3, 3], stride: int = 1, 220 | nempty: bool = False, max_buffer: int = int(2e8)): 221 | octree_conv = _OctreeConv( 222 | in_channels, out_channels, kernel_size, stride, nempty, max_buffer) 223 | octree_conv.setup(octree, depth) 224 | out = octree_conv.check_and_init(data) 225 | out = octree_conv.forward_gemm(out, data, weights) 226 | 227 | ctx.save_for_backward(data, weights) 228 | ctx.octree_conv = octree_conv 229 | return out 230 | 231 | @staticmethod 232 | def backward(ctx, grad): 233 | data, weights = ctx.saved_tensors 234 | octree_conv = ctx.octree_conv 235 | 236 | grad_out = None 237 | if ctx.needs_input_grad[0]: 238 | grad_out = torch.zeros_like(data) 239 | grad_out = octree_conv.backward_gemm(grad_out, grad, weights) 240 | 241 | grad_w = None 242 | if ctx.needs_input_grad[1]: 243 | grad_w = torch.zeros_like(weights) 244 | grad_w = octree_conv.weight_gemm(grad_w, data, grad) 245 | 246 | return (grad_out, grad_w) + (None,) * 8 247 | 248 | 249 | class OctreeDeconvFunction(Function): 250 | r''' Wrap the octree deconvolution for auto-diff. 
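  Deconvolution is the transpose of convolution, so the forward pass below
  reuses :obj:`backward_gemm` of the shared base class, and the backward pass
  reuses :obj:`forward_gemm`.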
251 |   '''
252 | 
253 |   @staticmethod
254 |   def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
255 |               depth: int, in_channels: int, out_channels: int,
256 |               kernel_size: List[int] = [3, 3, 3], stride: int = 1,
257 |               nempty: bool = False, max_buffer: int = int(2e8)):
258 |     octree_deconv = _OctreeDeconv(
259 |         in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
260 |     octree_deconv.setup(octree, depth)
261 |     out = octree_deconv.check_and_init(data)
262 |     out = octree_deconv.backward_gemm(out, data, weights)
263 | 
264 |     ctx.save_for_backward(data, weights)
265 |     ctx.octree_deconv = octree_deconv
266 |     return out
267 | 
268 |   @staticmethod
269 |   def backward(ctx, grad):
270 |     data, weights = ctx.saved_tensors
271 |     octree_deconv = ctx.octree_deconv
272 | 
273 |     grad_out = None
274 |     if ctx.needs_input_grad[0]:
275 |       grad_out = torch.zeros_like(data)
276 |       grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
277 | 
278 |     grad_w = None
279 |     if ctx.needs_input_grad[1]:
280 |       grad_w = torch.zeros_like(weights)
281 |       grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
282 | 
283 |     return (grad_out, grad_w) + (None,) * 8
284 | 
285 | 
286 | # alias
287 | octree_conv = OctreeConvFunction.apply
288 | octree_deconv = OctreeDeconvFunction.apply
289 | 
290 | 
291 | class OctreeConv(OctreeConvBase, torch.nn.Module):
292 |   r''' Performs octree convolution.
293 | 
294 |   Args:
295 |     in_channels (int): Number of input channels.
296 |     out_channels (int): Number of output channels.
297 |     kernel_size (List[int]): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
298 |       :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
299 |       :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
300 |     stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
301 |     nempty (bool): If True, only performs the convolution on non-empty
302 |       octree nodes.
303 |     direct_method (bool): If True, directly performs the convolution via gemm
304 |       and octree2col/col2octree. The octree2col/col2octree needs to
305 |       construct a large matrix, which may consume a lot of memory. If False,
306 |       performs the convolution in a sub-matrix manner, which can save the
307 |       required runtime memory.
308 |     use_bias (bool): If True, adds a bias term to the convolution.
309 |     max_buffer (int): The maximum number of elements in the buffer, used when
310 |       :attr:`direct_method` is False.
311 |   '''
312 | 
313 |   def __init__(self, in_channels: int, out_channels: int,
314 |                kernel_size: List[int] = [3], stride: int = 1,
315 |                nempty: bool = False, direct_method: bool = False,
316 |                use_bias: bool = False, max_buffer: int = int(2e8)):
317 |     super().__init__(
318 |         in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
319 | 
320 |     self.direct_method = direct_method
321 |     self.use_bias = use_bias
322 |     self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
323 |     if self.use_bias:
324 |       self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
325 |     self.reset_parameters()
326 | 
327 |   def reset_parameters(self):
328 |     xavier_uniform_(self.weights)
329 |     if self.use_bias:
330 |       torch.nn.init.zeros_(self.bias)
331 | 
332 |   def is_conv_layer(self): return True
333 | 
334 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
335 |     r''' Defines the octree convolution.
336 | 
337 |     Args:
338 |       data (torch.Tensor): The input data.
339 |       octree (Octree): The corresponding octree.
340 |       depth (int): The depth of current octree.
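    A minimal sketch (assuming an octree built at depth 5 with per-node
    features ``feat`` of shape ``(N, 3)``)::

      conv = OctreeConv(3, 32, kernel_size=[3], stride=1)
      out = conv(feat, octree, depth=5)  # -> (N, 32)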
341 |     '''
342 | 
343 |     if self.direct_method:
344 |       col = octree2col(
345 |           data, octree, depth, self.kernel, self.stride, self.nempty)
346 |       out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
347 |     else:
348 |       out = octree_conv(
349 |           data, self.weights, octree, depth, self.in_channels,
350 |           self.out_channels, self.kernel_size, self.stride, self.nempty,
351 |           self.max_buffer)
352 | 
353 |     if self.use_bias:
354 |       out += self.bias
355 | 
356 |     if self.stride == 2 and not self.nempty:
357 |       out = octree_pad(out, octree, depth-1)
358 |     return out
359 | 
360 |   def extra_repr(self) -> str:
361 |     r''' Sets the extra representation of the module.
362 |     '''
363 | 
364 |     return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
365 |             'nempty={}, bias={}').format(self.in_channels, self.out_channels,
366 |              self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
367 | 
368 | 
369 | class OctreeDeconv(OctreeConv):
370 |   r''' Performs octree deconvolution.
371 | 
372 |   Please refer to :class:`OctreeConv` for the meaning of the arguments.
373 |   '''
374 | 
375 |   def is_conv_layer(self): return False
376 | 
377 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
378 |     r''' Defines the octree deconvolution.
379 | 
380 |     Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
381 |     '''
382 | 
383 |     depth_col = depth
384 |     if self.stride == 2:
385 |       depth_col = depth + 1
386 |       if not self.nempty:
387 |         data = octree_depad(data, octree, depth)
388 | 
389 |     if self.direct_method:
390 |       col = torch.mm(data, self.weights.flatten(0, 1).t())
391 |       col = col.view(col.shape[0], self.kdim, -1)
392 |       out = col2octree(
393 |           col, octree, depth_col, self.kernel, self.stride, self.nempty)
394 |     else:
395 |       out = octree_deconv(
396 |           data, self.weights, octree, depth, self.in_channels,
397 |           self.out_channels, self.kernel_size, self.stride, self.nempty,
398 |           self.max_buffer)
399 | 
400 |     if self.use_bias:
401 |       out += self.bias
402 |     return out
403 | 
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_drop.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | from typing import Optional
10 | 
11 | from utils.ocnn.octree import Octree
12 | 
13 | 
14 | class OctreeDropPath(torch.nn.Module):
15 |   r'''Drop paths (Stochastic Depth) per sample when applied in the main path of
16 |   residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
17 | 
18 |   Args:
19 |     drop_prob (float): The probability of dropping paths.
20 |     nempty (bool): Indicates whether the input data only contains features of
21 |       the non-empty octree nodes or not.
22 |     scale_by_keep (bool): Whether to scale the kept features proportionally.
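  A sketch (assuming ``data`` matches the octree nodes at ``depth``)::

    drop_path = OctreeDropPath(drop_prob=0.1)
    out = drop_path(data, octree, depth)  # identity in eval mode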
23 |   '''
24 | 
25 |   def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
26 |                scale_by_keep: bool = True):
27 |     super().__init__()
28 | 
29 |     self.drop_prob = drop_prob
30 |     self.nempty = nempty
31 |     self.scale_by_keep = scale_by_keep
32 | 
33 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int,
34 |               batch_id: Optional[torch.Tensor] = None):
35 |     r''''''
36 | 
37 |     if self.drop_prob <= 0.0 or not self.training:
38 |       return data
39 | 
40 |     batch_size = octree.batch_size
41 |     keep_prob = 1 - self.drop_prob
42 |     rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
43 |     rnd_tensor = torch.floor(rnd_tensor + keep_prob)
44 |     if keep_prob > 0.0 and self.scale_by_keep:
45 |       rnd_tensor.div_(keep_prob)
46 | 
47 |     if batch_id is None:
48 |       batch_id = octree.batch_id(depth, self.nempty)
49 |     drop_mask = rnd_tensor[batch_id]
50 |     output = data * drop_mask
51 |     return output
52 | 
53 |   def extra_repr(self) -> str:
54 |     return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
55 |         self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
56 | 
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_dwconv.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import torch.nn
10 | from torch.autograd import Function
11 | from typing import List
12 | 
13 | from utils.ocnn.octree import Octree
14 | from utils.ocnn.utils import scatter_add, xavier_uniform_
15 | from .octree_pad import octree_pad
16 | from .octree_conv import OctreeConvBase
17 | 
18 | 
19 | class OctreeDWConvBase(OctreeConvBase):
20 | 
21 |   def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22 |                stride: int = 1, nempty: bool = False,
23 |                max_buffer: int = int(2e8)):
24 |     super().__init__(
25 |         in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26 |     self.weights_shape = (self.kdim, 1, self.out_channels)
27 | 
28 |   def is_conv_layer(self): return True
29 | 
30 |   def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31 |                    weights: torch.Tensor):
32 |     r''' Performs the forward pass of octree-based depth-wise convolution.
33 |     '''
34 | 
35 |     # Initialize the buffer
36 |     buffer = data.new_empty(self.buffer_shape)
37 | 
38 |     # Loop over each sub-matrix
39 |     for i in range(self.buffer_n):
40 |       start = i * self.buffer_h
41 |       end = (i + 1) * self.buffer_h
42 | 
43 |       # The boundary case in the last iteration
44 |       if end > self.neigh.shape[0]:
45 |         dis = end - self.neigh.shape[0]
46 |         end = self.neigh.shape[0]
47 |         buffer, _ = buffer.split([self.buffer_h-dis, dis])
48 | 
49 |       # Perform octree2col
50 |       neigh_i = self.neigh[start:end]
51 |       valid = neigh_i >= 0
52 |       buffer.fill_(0)
53 |       buffer[valid] = data[neigh_i[valid]]
54 | 
55 |       # The sub-matrix gemm
56 |       # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
57 |       out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
58 |     return out
59 | 
60 |   def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
61 |                     weights: torch.Tensor):
62 |     r''' Performs the backward pass of octree-based depth-wise convolution.
63 | ''' 64 | 65 | # Loop over each sub-matrix 66 | for i in range(self.buffer_n): 67 | start = i * self.buffer_h 68 | end = (i + 1) * self.buffer_h 69 | 70 | # The boundary case in the last iteration 71 | if end > self.neigh.shape[0]: 72 | end = self.neigh.shape[0] 73 | 74 | # The sub-matrix gemm 75 | # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t()) 76 | # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2]) 77 | buffer = torch.einsum( 78 | 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1)) 79 | 80 | # Performs col2octree 81 | neigh_i = self.neigh[start:end] 82 | valid = neigh_i >= 0 83 | out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out) 84 | 85 | return out 86 | 87 | def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor): 88 | r''' Computes the gradient of the weight matrix. 89 | ''' 90 | 91 | # Record the shape of out 92 | out_shape = out.shape 93 | out = out.flatten(0, 1) 94 | 95 | # Initialize the buffer 96 | buffer = data.new_empty(self.buffer_shape) 97 | 98 | # Loop over each sub-matrix 99 | for i in range(self.buffer_n): 100 | start = i * self.buffer_h 101 | end = (i + 1) * self.buffer_h 102 | 103 | # The boundary case in the last iteration 104 | if end > self.neigh.shape[0]: 105 | d = end - self.neigh.shape[0] 106 | end = self.neigh.shape[0] 107 | buffer, _ = buffer.split([self.buffer_h-d, d]) 108 | 109 | # Perform octree2col 110 | neigh_i = self.neigh[start:end] 111 | valid = neigh_i >= 0 112 | buffer.fill_(0) 113 | buffer[valid] = data[neigh_i[valid]] 114 | 115 | # Accumulate the gradient via gemm 116 | # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end]) 117 | out += torch.einsum('ikc,ic->kc', buffer, grad[start:end]) 118 | return out.view(out_shape) 119 | 120 | 121 | class OctreeDWConvFunction(Function): 122 | r''' Wrap the octree convolution for auto-diff. 123 | ''' 124 | 125 | @staticmethod 126 | def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree, 127 | depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3], 128 | stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)): 129 | octree_conv = OctreeDWConvBase( 130 | in_channels, kernel_size, stride, nempty, max_buffer) 131 | octree_conv.setup(octree, depth) 132 | out = octree_conv.check_and_init(data) 133 | out = octree_conv.forward_gemm(out, data, weights) 134 | 135 | ctx.save_for_backward(data, weights) 136 | ctx.octree_conv = octree_conv 137 | return out 138 | 139 | @staticmethod 140 | def backward(ctx, grad): 141 | data, weights = ctx.saved_tensors 142 | octree_conv = ctx.octree_conv 143 | 144 | grad_out = None 145 | if ctx.needs_input_grad[0]: 146 | grad_out = torch.zeros_like(data) 147 | grad_out = octree_conv.backward_gemm(grad_out, grad, weights) 148 | 149 | grad_w = None 150 | if ctx.needs_input_grad[1]: 151 | grad_w = torch.zeros_like(weights) 152 | grad_w = octree_conv.weight_gemm(grad_w, data, grad) 153 | 154 | return (grad_out, grad_w) + (None,) * 7 155 | 156 | 157 | # alias 158 | octree_dwconv = OctreeDWConvFunction.apply 159 | 160 | 161 | class OctreeDWConv(OctreeDWConvBase, torch.nn.Module): 162 | r''' Performs octree-based depth-wise convolution. 163 | 164 | Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments. 165 | 166 | .. note:: 167 | This implementation uses the :func:`torch.einsum` and I find that the speed 168 | is relatively slow. Further optimization is needed to speed it up. 
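  A sketch (assuming prepared inputs)::

    dwconv = OctreeDWConv(in_channels=32, kernel_size=[3])
    out = dwconv(data, octree, depth)  # per-channel 3x3x3 convolution, (N, 32)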
169 |   '''
170 | 
171 |   def __init__(self, in_channels: int, kernel_size: List[int] = [3],
172 |                stride: int = 1, nempty: bool = False, use_bias: bool = False,
173 |                max_buffer: int = int(2e8)):
174 |     super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
175 | 
176 |     self.use_bias = use_bias
177 |     self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
178 |     if self.use_bias:
179 |       self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
180 |     self.reset_parameters()
181 | 
182 |   def reset_parameters(self):
183 |     xavier_uniform_(self.weights)
184 |     if self.use_bias:
185 |       torch.nn.init.zeros_(self.bias)
186 | 
187 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
188 |     r''''''
189 | 
190 |     out = octree_dwconv(
191 |         data, self.weights, octree, depth, self.in_channels,
192 |         self.kernel_size, self.stride, self.nempty, self.max_buffer)
193 | 
194 |     if self.use_bias:
195 |       out += self.bias
196 | 
197 |     if self.stride == 2 and not self.nempty:
198 |       out = octree_pad(out, octree, depth-1)
199 |     return out
200 | 
201 |   def extra_repr(self) -> str:
202 |     return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
203 |             'nempty={}, bias={}').format(self.in_channels, self.out_channels,
204 |              self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_interp.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import torch.sparse
10 | from typing import Optional
11 | 
12 | from utils import ocnn
13 | from utils.ocnn.octree import Octree
14 | 
15 | 
16 | def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
17 |                        pts: torch.Tensor, nempty: bool = False,
18 |                        bound_check: bool = False):
19 |   ''' The nearest-neighbor interpolation with input points.
20 | 
21 |   Args:
22 |     data (torch.Tensor): The input data.
23 |     octree (Octree): The octree to interpolate.
24 |     depth (int): The depth of the data.
25 |     pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
26 |       i.e. :obj:`N x (x, y, z, batch)`.
27 |     nempty (bool): If true, the :attr:`data` only contains features of non-empty
28 |       octree nodes.
29 |     bound_check (bool): If true, check whether the point is in :obj:`[0, 2^depth)`.
30 | 
31 |   .. note::
32 |     The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
33 |   '''
34 | 
35 |   nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
36 |   assert data.shape[0] == nnum, 'The shape of input data is wrong.'
37 | 
38 |   idx = octree.search_xyzb(pts, depth, nempty)
39 |   valid = idx > -1  # valid indices
40 |   if bound_check:
41 |     bound = torch.logical_and(pts[:, :3] >= 0, pts[:, :3] < 2**depth).all(1)
42 |     valid = torch.logical_and(valid, bound)
43 | 
44 |   size = (pts.shape[0], data.shape[1])
45 |   out = torch.zeros(size, device=data.device, dtype=data.dtype)
46 |   out[valid] = data.index_select(0, idx[valid])
47 |   return out
48 | 
49 | 
50 | def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
51 |                       pts: torch.Tensor, nempty: bool = False,
52 |                       bound_check: bool = False):
53 |   ''' Linear interpolation with input points.
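  The result is a trilinear blend over the eight octree nodes surrounding each
  query point, with the weights renormalized so that empty neighbors do not
  contribute.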
54 | 
55 |   Refer to :func:`octree_nearest_pts` for the meaning of the arguments.
56 |   '''
57 | 
58 |   nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
59 |   assert data.shape[0] == nnum, 'The shape of input data is wrong.'
60 | 
61 |   device = data.device
62 |   grid = torch.tensor(
63 |       [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
64 |        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)
65 | 
66 |   # 1. Neighborhood searching
67 |   xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
68 |   xyzi = xyzf.floor()      # the integer part (N, 3)
69 |   frac = xyzf - xyzi       # the fraction part (N, 3)
70 | 
71 |   xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)
72 |   batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
73 |   idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
74 |   valid = idx > -1  # valid indices
75 |   if bound_check:
76 |     bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
77 |     valid = torch.logical_and(valid, bound)
78 |   idx = idx[valid]
79 | 
80 |   # 2. Build the sparse matrix
81 |   npt = pts.shape[0]
82 |   ids = torch.arange(npt, device=idx.device)
83 |   ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
84 |   ids = ids[valid]
85 |   indices = torch.stack([ids, idx], dim=0).long()
86 | 
87 |   frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
88 |   weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
89 |   weight = weight[valid]
90 | 
91 |   h = data.shape[0]
92 |   mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)
93 | 
94 |   # 3. Interpolation
95 |   output = torch.sparse.mm(mat, data)
96 |   ones = torch.ones(h, 1, dtype=data.dtype, device=device)
97 |   norm = torch.sparse.mm(mat, ones)
98 |   output = torch.div(output, norm + 1e-12)
99 |   return output
100 | 
101 | 
102 | class OctreeInterp(torch.nn.Module):
103 |   r''' Interpolates the points with an octree feature.
104 | 
105 |   Refer to :func:`octree_nearest_pts` for a description of arguments.
106 |   '''
107 | 
108 |   def __init__(self, method: str = 'linear', nempty: bool = False,
109 |                bound_check: bool = False, rescale_pts: bool = True):
110 |     super().__init__()
111 |     self.method = method
112 |     self.nempty = nempty
113 |     self.bound_check = bound_check
114 |     self.rescale_pts = rescale_pts
115 |     self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
116 | 
117 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int,
118 |               pts: torch.Tensor):
119 |     r''''''
120 | 
121 |     # rescale points from [-1, 1] to [0, 2^depth]
122 |     if self.rescale_pts:
123 |       scale = 2 ** (depth - 1)
124 |       pts[:, :3] = (pts[:, :3] + 1.0) * scale
125 | 
126 |     return self.func(data, octree, depth, pts, self.nempty, self.bound_check)
127 | 
128 |   def extra_repr(self) -> str:
129 |     r''' Sets the extra representation of the module.
130 |     '''
131 | 
132 |     return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
133 |         self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
134 | 
135 | 
136 | def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
137 |                             nempty: bool = False):
138 |   r''' Upsamples the octree node features from :attr:`depth` to :attr:`(depth+1)`
139 |   with the nearest-neighbor interpolation.
140 | 
141 |   Args:
142 |     data (torch.Tensor): The input data.
143 |     octree (Octree): The octree to interpolate.
144 |     depth (int): The depth of the data.
145 |     nempty (bool): If true, the :attr:`data` only contains features of non-empty
146 |       octree nodes.
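  In effect, the feature of every non-empty node at :attr:`depth` is copied to
  its eight children. A sketch (assuming matching inputs)::

    up = octree_nearest_upsample(data, octree, depth=4)  # features at depth 5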
147 |   '''
148 | 
149 |   nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
150 |   assert data.shape[0] == nnum, 'The shape of input data is wrong.'
151 | 
152 |   out = data
153 |   if not nempty:
154 |     out = ocnn.nn.octree_depad(out, octree, depth)
155 |   out = out.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
156 |   if nempty:
157 |     out = ocnn.nn.octree_depad(out, octree, depth + 1)  # !!! depth+1
158 |   return out
159 | 
160 | 
161 | class OctreeUpsample(torch.nn.Module):
162 |   r''' Upsamples the octree node features from :attr:`depth` to
163 |   :attr:`(target_depth)`.
164 | 
165 |   Refer to :func:`octree_nearest_pts` for details.
166 |   '''
167 | 
168 |   def __init__(self, method: str = 'linear', nempty: bool = False):
169 |     super().__init__()
170 |     self.method = method
171 |     self.nempty = nempty
172 |     self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
173 | 
174 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int,
175 |               target_depth: Optional[int] = None):
176 |     r''''''
177 | 
178 |     if target_depth is None:
179 |       target_depth = depth + 1
180 |     if target_depth == depth:
181 |       return data  # return, do nothing
182 |     assert target_depth >= depth, 'target_depth must be larger than depth'
183 | 
184 |     if target_depth == depth + 1 and self.method == 'nearest':
185 |       return octree_nearest_upsample(data, octree, depth, self.nempty)
186 | 
187 |     xyzb = octree.xyzb(target_depth, self.nempty)
188 |     pts = torch.stack(xyzb, dim=1).float()
189 |     pts[:, :3] = (pts[:, :3] + 0.5) * (2**(depth - target_depth))  # !!! rescale
190 |     return self.func(data, octree, depth, pts, self.nempty)
191 | 
192 |   def extra_repr(self) -> str:
193 |     r''' Sets the extra representation of the module.
194 |     '''
195 | 
196 |     return ('method={}, nempty={}').format(self.method, self.nempty)
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_norm.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import torch.nn
10 | 
11 | from utils.ocnn.octree import Octree
12 | from utils.ocnn.utils import scatter_add
13 | 
14 | 
15 | OctreeBatchNorm = torch.nn.BatchNorm1d
16 | 
17 | 
18 | class OctreeInstanceNorm(torch.nn.Module):
19 |   r''' An instance normalization layer for the octree.
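  The mean and variance are computed per shape in the batch by scattering over
  the batch indices. A sketch (assuming prepared inputs)::

    norm = OctreeInstanceNorm(in_channels=32)
    out = norm(data, octree, depth)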
20 |   '''
21 | 
22 |   def __init__(self, in_channels: int, nempty: bool = False):
23 |     super().__init__()
24 | 
25 |     self.eps = 1e-5
26 |     self.nempty = nempty
27 |     self.in_channels = in_channels
28 | 
29 |     self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
30 |     self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
31 |     self.reset_parameters()
32 | 
33 |   def reset_parameters(self):
34 |     torch.nn.init.ones_(self.weights)
35 |     torch.nn.init.zeros_(self.bias)
36 | 
37 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
38 |     r''''''
39 | 
40 |     batch_size = octree.batch_size
41 |     batch_id = octree.batch_id(depth, self.nempty)
42 |     ones = data.new_ones([data.shape[0], 1])
43 |     count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
44 |     norm = 1.0 / (count + self.eps)  # there might be 0 elements in some shapes
45 | 
46 |     mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * norm
47 |     out = data - mean.index_select(0, batch_id)
48 |     var = scatter_add(out * out, batch_id, dim=0, dim_size=batch_size) * norm
49 |     inv_std = 1.0 / (var + self.eps).sqrt()
50 |     out = out * inv_std.index_select(0, batch_id)
51 | 
52 |     out = out * self.weights + self.bias
53 |     return out
54 | 
55 |   def extra_repr(self) -> str:
56 |     return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_pad.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | 
10 | from ..octree import Octree
11 | 
12 | 
13 | def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
14 |   r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
15 |   the octree node number.
16 | 
17 |   Args:
18 |     data (torch.Tensor): The input tensor with its number of elements equal to the
19 |       non-empty octree node number.
20 |     octree (Octree): The corresponding octree.
21 |     depth (int): The depth of current octree.
22 |     val (float): The padding value. (Default: :obj:`0.0`)
23 |   '''
24 | 
25 |   mask = octree.nempty_mask(depth)
26 |   size = (octree.nnum[depth], data.shape[1])  # (N, C)
27 |   out = torch.full(size, val, dtype=data.dtype, device=data.device)
28 |   out[mask] = data
29 |   return out
30 | 
31 | 
32 | def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
33 |   r''' The reverse operation of :func:`octree_pad`.
34 | 
35 |   Please refer to :func:`octree_pad` for the meaning of the arguments.
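  A round-trip sketch (assuming ``feat`` holds the features of the non-empty
  nodes at ``depth``)::

    full = octree_pad(feat, octree, depth)    # (nnum, C), zeros at empty nodes
    back = octree_depad(full, octree, depth)  # (nnum_nempty, C) again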
36 |   '''
37 | 
38 |   mask = octree.nempty_mask(depth)
39 |   return data[mask]
-------------------------------------------------------------------------------- /utils/ocnn/nn/octree_pool.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import torch.nn
10 | from typing import List
11 | 
12 | from utils.ocnn.octree import Octree
13 | from utils.ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
14 | from . import octree_pad, octree_depad
15 | 
16 | 
17 | def octree_max_pool(data: torch.Tensor, octree: Octree, depth: int,
18 |                     nempty: bool = False, return_indices: bool = False):
19 |   r''' Performs octree max pooling with kernel size 2 and stride 2.
20 | 
21 |   Args:
22 |     data (torch.Tensor): The input tensor.
23 |     octree (Octree): The corresponding octree.
24 |     depth (int): The depth of current octree. After pooling, the corresponding
25 |       depth decreases by 1.
26 |     nempty (bool): If True, :attr:`data` contains only features of non-empty
27 |       octree nodes.
28 |     return_indices (bool): If True, returns the indices, which can be used in
29 |       :func:`octree_max_unpool`.
30 |   '''
31 | 
32 |   if nempty:
33 |     data = octree_pad(data, octree, depth, float('-inf'))
34 |   data = data.view(-1, 8, data.shape[1])
35 |   out, indices = data.max(dim=1)
36 |   if not nempty:
37 |     out = octree_pad(out, octree, depth-1)
38 |   return (out, indices) if return_indices else out
39 | 
40 | 
41 | def octree_max_unpool(data: torch.Tensor, indices: torch.Tensor, octree: Octree,
42 |                       depth: int, nempty: bool = False):
43 |   r''' Performs octree max unpooling.
44 | 
45 |   Args:
46 |     data (torch.Tensor): The input tensor.
47 |     indices (torch.Tensor): The indices returned by :func:`octree_max_pool`. The
48 |       depth of :attr:`indices` is larger by 1 than :attr:`data`.
49 |     octree (Octree): The corresponding octree.
50 |     depth (int): The depth of current data. After unpooling, the corresponding
51 |       depth increases by 1.
52 |   '''
53 | 
54 |   if not nempty:
55 |     data = octree_depad(data, octree, depth)
56 |   num, channel = data.shape
57 |   out = torch.zeros(num, 8, channel, dtype=data.dtype, device=data.device)
58 |   i = torch.arange(num, dtype=indices.dtype, device=indices.device)
59 |   k = torch.arange(channel, dtype=indices.dtype, device=indices.device)
60 |   i, k = meshgrid(i, k, indexing='ij')
61 |   out[i, indices, k] = data
62 |   out = out.view(-1, channel)
63 |   if nempty:
64 |     out = octree_depad(out, octree, depth+1)
65 |   return out
66 | 
67 | 
68 | def octree_avg_pool(data: torch.Tensor, octree: Octree, depth: int,
69 |                     kernel: str, stride: int = 2, nempty: bool = False):
70 |   r''' Performs octree average pooling.
71 | 
72 |   Args:
73 |     data (torch.Tensor): The input tensor.
74 |     octree (Octree): The corresponding octree.
75 |     depth (int): The depth of current octree.
76 |     kernel (str): The kernel size, like '333', '222'.
77 |     stride (int): The stride of the pooling.
78 |     nempty (bool): If True, :attr:`data` contains only features of non-empty
79 |       octree nodes.
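  A sketch of a stride-2 average pooling (assuming prepared inputs)::

    out = octree_avg_pool(data, octree, depth, kernel='222', stride=2)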
80 | ''' 81 | 82 | neigh = octree.get_neigh(depth, kernel, stride, nempty) 83 | 84 | N1 = data.shape[0] 85 | N2 = neigh.shape[0] 86 | K = neigh.shape[1] 87 | 88 | mask = neigh >= 0 89 | val = 1.0 / (torch.sum(mask, dim=1) + 1e-8) 90 | mask = mask.view(-1) 91 | val = val.unsqueeze(1).repeat(1, K).reshape(-1) 92 | val = val[mask] 93 | 94 | row = torch.arange(N2, device=neigh.device) 95 | row = row.unsqueeze(1).repeat(1, K).view(-1) 96 | col = neigh.view(-1) 97 | indices = torch.stack([row[mask], col[mask]], dim=0).long() 98 | 99 | mat = torch.sparse_coo_tensor(indices, val, [N2, N1], device=data.device) 100 | out = torch.sparse.mm(mat, data) 101 | return out 102 | 103 | 104 | def octree_global_pool(data: torch.Tensor, octree: Octree, depth: int, 105 | nempty: bool = False): 106 | r''' Performs octree global average pooling. 107 | 108 | Args: 109 | data (torch.Tensor): The input tensor. 110 | octree (Octree): The corresponding octree. 111 | depth (int): The depth of current octree. 112 | nempty (bool): If True, :attr:`data` contains only features of non-empty 113 | octree nodes. 114 | ''' 115 | 116 | batch_size = octree.batch_size 117 | batch_id = octree.batch_id(depth, nempty) 118 | ones = data.new_ones(data.shape[0], 1) 119 | count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size) 120 | count[count < 1] = 1 # there might be 0 element in some shapes 121 | 122 | out = scatter_add(data, batch_id, dim=0, dim_size=batch_size) 123 | out = out / count 124 | return out 125 | 126 | 127 | class OctreePoolBase(torch.nn.Module): 128 | r''' The base class for octree-based pooling. 129 | ''' 130 | 131 | def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False): 132 | super().__init__() 133 | self.kernel_size = resize_with_last_val(kernel_size) 134 | self.kernel = list2str(self.kernel_size) 135 | self.stride = stride 136 | self.nempty = nempty 137 | 138 | def extra_repr(self) -> str: 139 | return ('kernel={}, stride={}, nempty={}').format( 140 | self.kernel, self.stride, self.nempty) # noqa 141 | 142 | 143 | class OctreeMaxPool(OctreePoolBase): 144 | r''' Performs octree max pooling. 145 | 146 | Please refer to :func:`octree_max_pool` for details. 147 | ''' 148 | 149 | def __init__(self, nempty: bool = False, return_indices: bool = False): 150 | super().__init__(kernel_size=[2], stride=2, nempty=nempty) 151 | self.return_indices = return_indices 152 | 153 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 154 | r'''''' 155 | 156 | return octree_max_pool(data, octree, depth, self.nempty, self.return_indices) 157 | 158 | 159 | class OctreeMaxUnpool(OctreePoolBase): 160 | r''' Performs octree max unpooling. 161 | 162 | Please refer to :func:`octree_max_unpool` for details. 163 | ''' 164 | 165 | def forward(self, data: torch.Tensor, indices: torch.Tensor, octree: Octree, 166 | depth: int): 167 | r'''''' 168 | 169 | return octree_max_unpool(data, indices, octree, depth, self.nempty) 170 | 171 | 172 | class OctreeGlobalPool(OctreePoolBase): 173 | r''' Performs octree global pooling. 174 | 175 | Please refer to :func:`octree_global_pool` for details. 176 | ''' 177 | 178 | def __init__(self, nempty: bool = False): 179 | super().__init__(kernel_size=[-1], stride=-1, nempty=nempty) 180 | 181 | def forward(self, data: torch.Tensor, octree: Octree, depth: int): 182 | r'''''' 183 | 184 | return octree_global_pool(data, octree, depth, self.nempty) 185 | 186 | 187 | class OctreeAvgPool(OctreePoolBase): 188 | r''' Performs octree average pooling. 
189 | 
190 |   Please refer to :func:`octree_avg_pool` for details.
191 |   '''
192 | 
193 |   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
194 |     r''''''
195 | 
196 |     return octree_avg_pool(
197 |         data, octree, depth, self.kernel, self.stride, self.nempty)
-------------------------------------------------------------------------------- /utils/ocnn/octree/__init__.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | from .shuffled_key import key2xyz, xyz2key
 9 | from .points import Points, merge_points
10 | from .octree import Octree, merge_octrees
11 | 
12 | __all__ = [
13 |   'key2xyz',
14 |   'xyz2key',
15 |   'Points',
16 |   'Octree',
17 |   'merge_points',
18 |   'merge_octrees',
19 | ]
20 | 
21 | classes = __all__
-------------------------------------------------------------------------------- /utils/ocnn/octree/points.py: --------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Octree-based Sparse Convolutional Neural Networks
 3 | # Copyright (c) 2022 Peng-Shuai Wang
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 | 
 8 | import torch
 9 | import numpy as np
10 | from typing import Optional, Union, List
11 | 
12 | 
13 | class Points:
14 |   r''' Represents a point cloud and contains some elementary transformations.
15 | 
16 |   Args:
17 |     points (torch.Tensor): The coordinates of the points with a shape of
18 |       :obj:`(N, 3)`, where :obj:`N` is the number of points.
19 |     normals (torch.Tensor or None): The point normals with a shape of
20 |       :obj:`(N, 3)`.
21 |     features (torch.Tensor or None): The point features with a shape of
22 |       :obj:`(N, C)`, where :obj:`C` is the channel of features.
23 |     labels (torch.Tensor or None): The point labels with a shape of
24 |       :obj:`(N, K)`, where :obj:`K` is the channel of labels.
25 |     batch_id (torch.Tensor or None): The batch indices for each point with a
26 |       shape of :obj:`(N, 1)`.
27 |     batch_size (int): The batch size.
28 |   '''
29 | 
30 |   def __init__(self, points: torch.Tensor,
31 |                normals: Optional[torch.Tensor] = None,
32 |                features: Optional[torch.Tensor] = None,
33 |                labels: Optional[torch.Tensor] = None,
34 |                batch_id: Optional[torch.Tensor] = None,
35 |                batch_size: int = 1):
36 |     super().__init__()
37 |     self.points = points
38 |     self.normals = normals
39 |     self.features = features
40 |     self.labels = labels
41 |     self.batch_id = batch_id
42 |     self.batch_size = batch_size
43 |     self.device = points.device
44 |     self.batch_npt = None  # valid after `merge_points`
45 | 
46 |   def orient_normal(self, axis: str = 'x'):
47 |     r''' Orients the point normals along a given axis.
48 | 
49 |     Args:
50 |       axis (str): The coordinate axis, chosen from :obj:`x`, :obj:`y`,
51 |         :obj:`z` and :obj:`xyz` (default: :obj:`x`).
52 |     '''
53 | 
54 |     if self.normals is None:
55 |       return
56 | 
57 |     axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
58 |     idx = axis_map[axis]
59 |     if idx < 3:
60 |       flags = self.normals[:, idx] > 0
61 |       flags = flags.float() * 2.0 - 1.0  # [0, 1] -> [-1, 1]
62 |       self.normals = self.normals * flags.unsqueeze(1)
63 |     else:
64 |       self.normals.abs_()
65 | 
66 |   def scale(self, factor: torch.Tensor):
67 |     r''' Rescales the point cloud.
68 | 
69 |     Args:
70 |       factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
71 |     '''
72 | 
73 |     non_zero = (factor != 0).all()
74 |     all_ones = (factor == 1.0).all()
75 |     non_uniform = (factor != factor[0]).any()
76 |     assert non_zero, 'The scale factor must not contain 0.'
77 |     if all_ones: return
78 | 
79 |     factor = factor.to(self.device)
80 |     self.points = self.points * factor
81 |     if self.normals is not None and non_uniform:
82 |       ifactor = 1.0 / factor
83 |       self.normals = self.normals * ifactor
84 |       norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
85 |       self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)
86 | 
87 |   def rotate(self, angle: torch.Tensor):
88 |     r''' Rotates the point cloud.
89 | 
90 |     Args:
91 |       angle (torch.Tensor): The rotation angles in radians with shape :obj:`(3,)`.
92 |     '''
93 | 
94 |     cos, sin = angle.cos(), angle.sin()
95 |     # rotx, roty, rotz are actually the transpose of the rotation matrices
96 |     rotx = torch.Tensor([[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]])
97 |     roty = torch.Tensor([[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]])
98 |     rotz = torch.Tensor([[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]])
99 |     rot = rotx @ roty @ rotz
100 | 
101 |     rot = rot.to(self.device)
102 |     self.points = self.points @ rot
103 |     if self.normals is not None:
104 |       self.normals = self.normals @ rot
105 | 
106 |   def translate(self, dis: torch.Tensor):
107 |     r''' Translates the point cloud.
108 | 
109 |     Args:
110 |       dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
111 |     '''
112 | 
113 |     dis = dis.to(self.device)
114 |     self.points = self.points + dis
115 | 
116 |   def flip(self, axis: str):
117 |     r''' Flips the point cloud along the given :attr:`axis`.
118 | 
119 |     Args:
120 |       axis (str): The flipping axis, chosen from :obj:`x`, :obj:`y`, and :obj:`z`.
121 |     '''
122 | 
123 |     axis_map = {'x': 0, 'y': 1, 'z': 2}
124 |     idx = axis_map[axis]
125 |     self.points[:, idx] *= -1.0
126 |     if self.normals is not None:
127 |       self.normals[:, idx] *= -1.0
128 | 
129 |   def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
130 |     r''' Clips the point cloud to :obj:`[min+esp, max-esp]` and returns the mask.
131 | 
132 |     Args:
133 |       min (float): The minimum value to clip.
134 |       max (float): The maximum value to clip.
135 |       esp (float): The margin.
136 |     '''
137 | 
138 |     mask = self.inbox_mask(min + esp, max - esp)
139 |     tmp = self.__getitem__(mask)
140 |     self.__dict__.update(tmp.__dict__)
141 |     return mask
142 | 
143 |   def __getitem__(self, mask: torch.Tensor):
144 |     r''' Slices the point cloud according to a given :attr:`mask`.
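    A sketch combined with :func:`inbox_mask`::

      mask = points.inbox_mask(-1.0, 1.0)
      inner = points[mask]  # points, normals, features, and labels stay in sync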
145 | ''' 146 | 147 | dummy_pts = torch.zeros(1, 3, device=self.device) 148 | out = Points(dummy_pts, batch_size=self.batch_size) 149 | 150 | out.points = self.points[mask] 151 | if self.normals is not None: 152 | out.normals = self.normals[mask] 153 | if self.features is not None: 154 | out.features = self.features[mask] 155 | if self.labels is not None: 156 | out.labels = self.labels[mask] 157 | if self.batch_id is not None: 158 | out.batch_id = self.batch_id[mask] 159 | return out 160 | 161 | def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0, 162 | bbmax: Union[float, torch.Tensor] = 1.0): 163 | r''' Returns a mask indicating whether the points are within the specified 164 | bounding box or not. 165 | ''' 166 | 167 | mask_min = torch.all(self.points > bbmin, dim=1) 168 | mask_max = torch.all(self.points < bbmax, dim=1) 169 | mask = torch.logical_and(mask_min, mask_max) 170 | return mask 171 | 172 | def bbox(self): 173 | r''' Returns the bounding box. 174 | ''' 175 | 176 | # torch.min and torch.max return (value, indices) 177 | bbmin = self.points.min(dim=0) 178 | bbmax = self.points.max(dim=0) 179 | return bbmin[0], bbmax[0] 180 | 181 | def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor, 182 | scale: float = 1.0): 183 | r''' Normalizes the point cloud to :obj:`[-scale, scale]`. 184 | 185 | Args: 186 | bbmin (torch.Tensor): The minimum coordinates of the bounding box. 187 | bbmax (torch.Tensor): The maximum coordinates of the bounding box. 188 | scale (float): The scale factor. 189 | ''' 190 | 191 | center = (bbmin + bbmax) * 0.5 192 | box_size = (bbmax - bbmin).max() + 1.0e-6 193 | self.points = (self.points - center) * (2.0 * scale / box_size) 194 | 195 | def to(self, device: Union[torch.device, str], non_blocking: bool = False): 196 | r''' Moves the Points to a specified device. 197 | 198 | Args: 199 | device (torch.device or str): The destination device. 200 | non_blocking (bool): If True and the source is in pinned memory, the copy 201 | will be asynchronous with respect to the host. Otherwise, the argument 202 | has no effect. Default: False. 203 | ''' 204 | 205 | if isinstance(device, str): 206 | device = torch.device(device) 207 | 208 | # If on the same device, directly return self 209 | if self.device == device: 210 | return self 211 | 212 | # Construct a new Points on the specified device 213 | points = Points(torch.zeros(1, 3, device=device)) 214 | points.batch_npt = self.batch_npt 215 | points.points = self.points.to(device, non_blocking=non_blocking) 216 | if self.normals is not None: 217 | points.normals = self.normals.to(device, non_blocking=non_blocking) 218 | if self.features is not None: 219 | points.features = self.features.to(device, non_blocking=non_blocking) 220 | if self.labels is not None: 221 | points.labels = self.labels.to(device, non_blocking=non_blocking) 222 | if self.batch_id is not None: 223 | points.batch_id = self.batch_id.to(device, non_blocking=non_blocking) 224 | return points 225 | 226 | def cuda(self, non_blocking: bool = False): 227 | r''' Moves the Points to the GPU. ''' 228 | 229 | return self.to('cuda', non_blocking) 230 | 231 | def cpu(self): 232 | r''' Moves the Points to the CPU. ''' 233 | 234 | return self.to('cpu') 235 | 236 | def save(self, filename: str, info: str = 'PNFL'): 237 | r''' Saves the Points into npz or xyz files. 238 | 239 | Args: 240 | filename (str): The output filename.
241 | info (str): The information for saving: 'P' -> 'points', 'N' -> 'normals', 242 | 'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. 243 | ''' 244 | 245 | mapping = { 246 | 'P': ('points', self.points), 'N': ('normals', self.normals), 247 | 'F': ('features', self.features), 'L': ('labels', self.labels), 248 | 'B': ('batch_id', self.batch_id), } 249 | 250 | names, outs = [], [] 251 | for key in info.upper(): 252 | name, out = mapping[key] 253 | if out is not None: 254 | names.append(name) 255 | if out.dim() == 1: 256 | out = out.unsqueeze(1) 257 | outs.append(out.cpu().numpy()) 258 | 259 | if filename.endswith('npz'): 260 | out_dict = dict(zip(names, outs)) 261 | np.savez(filename, **out_dict) 262 | elif filename.endswith('xyz'): 263 | out_array = np.concatenate(outs, axis=1) 264 | np.savetxt(filename, out_array, fmt='%.6f') 265 | else: 266 | raise ValueError('Unsupported file format: ' + filename) 267 | 268 | 269 | def merge_points(points: List['Points'], update_batch_info: bool = True): 270 | r''' Merges a list of points into one batch. 271 | 272 | Args: 273 | points (List[Points]): A list of points to merge. The batch size of each 274 | points in the list is assumed to be 1, and the :obj:`batch_size`, 275 | :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored. 276 | ''' 277 | 278 | out = Points(torch.zeros(1, 3)) 279 | out.points = torch.cat([p.points for p in points], dim=0) 280 | if points[0].normals is not None: 281 | out.normals = torch.cat([p.normals for p in points], dim=0) 282 | if points[0].features is not None: 283 | out.features = torch.cat([p.features for p in points], dim=0) 284 | if points[0].labels is not None: 285 | out.labels = torch.cat([p.labels for p in points], dim=0) 286 | out.device = points[0].device 287 | 288 | if update_batch_info: 289 | out.batch_size = len(points) 290 | out.batch_npt = torch.Tensor([p.points.shape[0] for p in points]) 291 | out.batch_id = torch.cat([p.points.new_full((p.points.shape[0], 1), i) 292 | for i, p in enumerate(points)], dim=0) 293 | return out 294 | -------------------------------------------------------------------------------- /utils/ocnn/octree/shuffled_key.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional, Union 10 | 11 | 12 | class KeyLUT: 13 | 14 | def __init__(self): 15 | r256 = torch.arange(256, dtype=torch.int64) 16 | r512 = torch.arange(512, dtype=torch.int64) 17 | zero = torch.zeros(256, dtype=torch.int64) 18 | device = torch.device('cpu') 19 | 20 | self._encode = {device: (self.xyz2key(r256, zero, zero, 8), 21 | self.xyz2key(zero, r256, zero, 8), 22 | self.xyz2key(zero, zero, r256, 8))} 23 | self._decode = {device: self.key2xyz(r512, 9)} 24 | 25 | def encode_lut(self, device=torch.device('cpu')): 26 | if device not in self._encode: 27 | cpu = torch.device('cpu') 28 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) 29 | return self._encode[device] 30 | 31 | def decode_lut(self, device=torch.device('cpu')): 32 | if device not in self._decode: 33 | cpu = torch.device('cpu') 34 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) 35 | return self._decode[device] 36 | 37 | def xyz2key(self, x, y, z, depth):
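# The loop below interleaves the coordinate bits into a Morton-style
# shuffled key: bit i of x, y, and z lands at key bit 3*i+2, 3*i+1, and
# 3*i+0 respectively; e.g. (x, y, z) = (1, 0, 1) at depth 1 gives key 0b101 = 5.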
38 | key = torch.zeros_like(x) 39 | for i in range(depth): 40 | mask = 1 << i 41 | key = (key | ((x & mask) << (2 * i + 2)) | 42 | ((y & mask) << (2 * i + 1)) | 43 | ((z & mask) << (2 * i + 0))) 44 | return key 45 | 46 | def key2xyz(self, key, depth): 47 | x = torch.zeros_like(key) 48 | y = torch.zeros_like(key) 49 | z = torch.zeros_like(key) 50 | for i in range(depth): 51 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) 52 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) 53 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) 54 | return x, y, z 55 | 56 | 57 | _key_lut = KeyLUT() 58 | 59 | 60 | def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, 61 | b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16): 62 | r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys 63 | based on pre-computed look-up tables. It is much faster than a naive 64 | for-loop implementation. 65 | 66 | Args: 67 | x (torch.Tensor): The x coordinate. 68 | y (torch.Tensor): The y coordinate. 69 | z (torch.Tensor): The z coordinate. 70 | b (torch.Tensor or int): The batch index of the coordinates, and should be 71 | smaller than 1024. If :attr:`b` is :obj:`torch.Tensor`, the size of 72 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. 73 | depth (int): The depth of the shuffled key, and must be smaller than 17. 74 | ''' 75 | 76 | EX, EY, EZ = _key_lut.encode_lut(x.device) 77 | x, y, z = x.long(), y.long(), z.long() 78 | 79 | mask = 255 if depth > 8 else (1 << depth) - 1 80 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask] 81 | if depth > 8: 82 | mask = (1 << (depth-8)) - 1 83 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] 84 | key = key16 << 24 | key 85 | 86 | if b is not None: 87 | b = b.long() 88 | key = b << 48 | key 89 | 90 | return key 91 | 92 | 93 | def key2xyz(key: torch.Tensor, depth: int = 16): 94 | r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates 95 | and the batch index based on pre-computed look-up tables. 96 | 97 | Args: 98 | key (torch.Tensor): The shuffled key. 99 | depth (int): The depth of the shuffled key, and must be smaller than 17. 100 | ''' 101 | 102 | DX, DY, DZ = _key_lut.decode_lut(key.device) 103 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) 104 | 105 | b = key >> 48 106 | key = key & ((1 << 48) - 1) 107 | 108 | n = (depth + 2) // 3 109 | for i in range(n): 110 | k = key >> (i * 9) & 511 111 | x = x | (DX[k] << (i * 3)) 112 | y = y | (DY[k] << (i * 3)) 113 | z = z | (DZ[k] << (i * 3)) 114 | 115 | return x, y, z, b 116 | -------------------------------------------------------------------------------- /utils/ocnn/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import torch 10 | from typing import Optional 11 | 12 | __all__ = ['trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_', 13 | 'resize_with_last_val', 'list2str'] 14 | classes = __all__ 15 | 16 | 17 | def trunc_div(input, other): 18 | r''' Wraps :func:`torch.div` for compatibility.
It rounds the results of the 19 | division towards zero and is equivalent to C-style integer division. 20 | ''' 21 | 22 | version = torch.__version__.split('.') 23 | larger_than_170 = int(version[0]) > 1 or (int(version[0]) == 1 and int(version[1]) > 7) 24 | 25 | if larger_than_170: 26 | return torch.div(input, other, rounding_mode='trunc') 27 | else: 28 | return torch.floor_divide(input, other) 29 | 30 | 31 | def meshgrid(*tensors, indexing: Optional[str] = None): 32 | r''' Wraps :func:`torch.meshgrid` for compatibility. 33 | ''' 34 | 35 | version = torch.__version__.split('.') 36 | larger_than_190 = int(version[0]) > 1 or (int(version[0]) == 1 and int(version[1]) > 9) 37 | 38 | if larger_than_190: 39 | return torch.meshgrid(*tensors, indexing=indexing) 40 | else: 41 | return torch.meshgrid(*tensors) 42 | 43 | 44 | def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False): 45 | r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`. 46 | 47 | Args: 48 | data (torch.Tensor): The input data. 49 | dim (int): The dimension to do the operation over. 50 | exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`; 51 | if true, returns the cumulative sum exclusively. Note that if true, 52 | the shape of the output tensor is larger by 1 than :attr:`data` in the 53 | dimension where the computation occurs. 54 | ''' 55 | 56 | out = torch.cumsum(data, dim) 57 | 58 | if exclusive: 59 | size = list(data.size()) 60 | size[dim] = 1 61 | zeros = out.new_zeros(size) 62 | out = torch.cat([zeros, out], dim) 63 | return out 64 | 65 | 66 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 67 | r''' Broadcasts :attr:`src` according to :attr:`other`, originally from the 68 | library `pytorch_scatter`. 69 | ''' 70 | 71 | if dim < 0: 72 | dim = other.dim() + dim 73 | 74 | if src.dim() == 1: 75 | for _ in range(0, dim): 76 | src = src.unsqueeze(0) 77 | for _ in range(src.dim(), other.dim()): 78 | src = src.unsqueeze(-1) 79 | 80 | src = src.expand_as(other) 81 | return src 82 | 83 | 84 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 85 | out: Optional[torch.Tensor] = None, 86 | dim_size: Optional[int] = None,) -> torch.Tensor: 87 | r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the 88 | indices specified in the :attr:`index` tensor along a given axis :attr:`dim`. 89 | This is just a wrapper of :meth:`torch.Tensor.scatter_add_` in a broadcasting fashion. 90 | 91 | Args: 92 | src (torch.Tensor): The source tensor. 93 | index (torch.Tensor): The indices of elements to scatter. 94 | dim (int): The axis along which to index. (default: :obj:`-1`) 95 | out (torch.Tensor or None): The destination tensor. 96 | dim_size (int or None): If :attr:`out` is not given, automatically create 97 | output with size :attr:`dim_size` at dimension :attr:`dim`. If 98 | :attr:`dim_size` is not given, a minimal sized output tensor according 99 | to :obj:`index.max() + 1` is returned. 100 | ''' 101 | 102 | index = broadcast(index, src, dim) 103 | 104 | if out is None: 105 | size = list(src.size()) 106 | if dim_size is not None: 107 | size[dim] = dim_size 108 | elif index.numel() == 0: 109 | size[dim] = 0 110 | else: 111 | size[dim] = int(index.max()) + 1 112 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 113 | 114 | return out.scatter_add_(dim, index, src) 115 | 116 | 117 | def xavier_uniform_(weights: torch.Tensor): 118 | r''' Initializes convolution weights with the same method as 119 | :obj:`torch.nn.init.xavier_uniform_`.
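The weights are drawn from :obj:`U(-a, a)` with :obj:`a = sqrt(6 / (fan_in + fan_out))`, where :obj:`fan_in = kdim * in_c` and :obj:`fan_out = kdim * out_c`.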
120 | 121 | :obj:`torch.nn.init.xavier_uniform_` initializes a tensor with shape 122 | :obj:`(out_c, in_c, kdim)`. It cannot be used in :class:`ocnn.nn.OctreeConv` 123 | since the shape of :attr:`OctreeConv.weights` is :obj:`(kdim, in_c, 124 | out_c)`. 125 | ''' 126 | 127 | shape = weights.shape # (kdim, in_c, out_c) 128 | fan_in = shape[0] * shape[1] 129 | fan_out = shape[0] * shape[2] 130 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 131 | a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 132 | 133 | torch.nn.init.uniform_(weights, -a, a) 134 | 135 | 136 | def resize_with_last_val(list_in: list, num: int = 3): 137 | r''' Resizes the number of elements of :attr:`list_in` to :attr:`num` with 138 | the last element of :attr:`list_in` if its number of elements is smaller 139 | than :attr:`num`. 140 | ''' 141 | 142 | assert (type(list_in) is list and len(list_in) < num + 1) 143 | for i in range(len(list_in), num): 144 | list_in.append(list_in[-1]) 145 | return list_in 146 | 147 | 148 | def list2str(list_in: list): 149 | r''' Returns a string representation of :attr:`list_in`. 150 | ''' 151 | 152 | out = [str(x) for x in list_in] 153 | return ''.join(out) 154 | -------------------------------------------------------------------------------- /utils/thsolver/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from . import config 9 | from .config import get_config, parse_args 10 | 11 | from . import solver 12 | from .solver import Solver 13 | 14 | from .
import dataset 15 | from .dataset import Dataset 16 | 17 | 18 | __all__ = [ 19 | 'config', 'get_config', 'parse_args', 20 | 'solver', 'Solver', 21 | 'dataset', 'Dataset' 22 | ] 23 | -------------------------------------------------------------------------------- /utils/thsolver/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | # autopep8: off 9 | import os 10 | import sys 11 | import shutil 12 | import argparse 13 | from datetime import datetime 14 | from yacs.config import CfgNode as CN 15 | 16 | _C = CN() 17 | 18 | # SOLVER related parameters 19 | _C.SOLVER = CN() 20 | _C.SOLVER.alias = '' # The experiment alias 21 | _C.SOLVER.gpu = (0,) # The gpu ids 22 | _C.SOLVER.run = 'train' # Choose from train or test 23 | 24 | _C.SOLVER.logdir = 'logs' # Directory where to write event logs 25 | _C.SOLVER.ckpt = '' # Restore weights from checkpoint file 26 | _C.SOLVER.ckpt_num = 10 # The number of checkpoints kept 27 | 28 | _C.SOLVER.type = 'sgd' # Choose from sgd or adam 29 | _C.SOLVER.weight_decay = 0.0005 # The weight decay on model weights 30 | _C.SOLVER.clip_grad = -1.0 # Clip gradient norm (-1: disable) 31 | _C.SOLVER.max_epoch = 300 # Maximum training epoch 32 | _C.SOLVER.warmup_epoch = 20 # The warmup epoch number 33 | _C.SOLVER.warmup_init = 0.001 # The initial ratio of the warmup 34 | _C.SOLVER.eval_epoch = 1 # Maximum evaluating epoch 35 | _C.SOLVER.eval_step = -1 # Maximum evaluating steps 36 | _C.SOLVER.test_every_epoch = 10 # Test model every n training epochs 37 | _C.SOLVER.log_per_iter = -1 # Output log every k training iterations 38 | 39 | _C.SOLVER.lr_type = 'step' # Learning rate type: step or cos 40 | _C.SOLVER.lr = 0.1 # Initial learning rate 41 | _C.SOLVER.lr_min = 0.0001 # The minimum learning rate 42 | _C.SOLVER.gamma = 0.1 # Learning rate step-wise decay 43 | _C.SOLVER.milestones = (120,180,) # Learning rate milestones 44 | _C.SOLVER.lr_power = 0.9 # Used in poly learning rate 45 | 46 | _C.SOLVER.dist_url = 'tcp://localhost:10001' 47 | _C.SOLVER.progress_bar = True # Enable the progress_bar or not 48 | _C.SOLVER.rand_seed = -1 # Fix the random seed if larger than 0 49 | _C.SOLVER.empty_cache = True # Empty cuda cache periodically 50 | 51 | # DATA related parameters 52 | _C.DATA = CN() 53 | _C.DATA.train = CN() 54 | _C.DATA.train.name = '' # The name of the dataset 55 | _C.DATA.train.disable = False # Disable this dataset or not 56 | 57 | # For octree building 58 | _C.DATA.train.depth = 5 # The octree depth 59 | _C.DATA.train.full_depth = 2 # The full depth 60 | _C.DATA.train.adaptive = False # Build the adaptive octree 61 | 62 | # For transformation 63 | _C.DATA.train.orient_normal = '' # Used to re-orient normal directions 64 | 65 | # For data augmentation 66 | _C.DATA.train.distort = False # Whether to apply data augmentation 67 | _C.DATA.train.scale = 0.0 # Scale the points 68 | _C.DATA.train.uniform = False # Generate uniform scales 69 | _C.DATA.train.jitter = 0.0 # Jitter the points 70 | _C.DATA.train.interval = (1, 1, 1) # Use interval & angle to generate random angles 71 | _C.DATA.train.angle = (180, 180, 180) 72 | 73 | # For data loading 74 | _C.DATA.train.location = '' # The data location 75 |
_C.DATA.train.filelist = '' # The data filelist 76 | _C.DATA.train.batch_size = 32 # Training data batch size 77 | _C.DATA.train.take = -1 # Number of samples used for training 78 | _C.DATA.train.num_workers = 4 # Number of workers to load the data 79 | _C.DATA.train.shuffle = False # Shuffle the input data 80 | _C.DATA.train.in_memory = False # Load the training data into memory 81 | 82 | 83 | _C.DATA.test = _C.DATA.train.clone() 84 | _C.DATA.test.num_workers = 2 85 | 86 | # MODEL related parameters 87 | _C.MODEL = CN() 88 | _C.MODEL.name = '' # The name of the model 89 | _C.MODEL.feature = 'ND' # The input features 90 | _C.MODEL.channel = 3 # The input feature channel 91 | _C.MODEL.nout = 40 # The output feature channel 92 | _C.MODEL.nempty = False # Perform Octree Conv on non-empty octree nodes 93 | 94 | _C.MODEL.stages = 3 95 | _C.MODEL.resblock_num = 3 # The resblock number 96 | _C.MODEL.resblock_type = 'bottleneck' # Choose from 'bottleneck' and 'basic' 97 | _C.MODEL.bottleneck = 4 # The bottleneck factor of one resblock 98 | 99 | _C.MODEL.upsample = 'nearest' # The method used for upsampling 100 | _C.MODEL.interp = 'linear' # The interpolation method: linear or nearest 101 | 102 | _C.MODEL.sync_bn = False # Use sync_bn when training the network 103 | _C.MODEL.use_checkpoint = False # Use checkpoint to save memory 104 | _C.MODEL.find_unused_parameters = False # Used in DistributedDataParallel 105 | 106 | _C.MODEL.num_edge_types = 7 107 | _C.MODEL.conv_type = "SAGE" 108 | _C.MODEL.include_distance = False 109 | _C.MODEL.normal_aware_pooling = False 110 | _C.MODEL.test_mesh = "TBD" 111 | _C.MODEL.get_test_stat = False 112 | 113 | 114 | # LOSS related parameters 115 | _C.LOSS = CN() 116 | _C.LOSS.name = '' # The name of the loss 117 | _C.LOSS.num_class = 40 # The class number for the cross-entropy loss 118 | _C.LOSS.weights = (1.0, 1.0) # The weight factors for different losses 119 | _C.LOSS.label_smoothing = 0.0 # The factor of label smoothing 120 | 121 | 122 | # backup the commands 123 | _C.SYS = CN() 124 | _C.SYS.cmds = '' # Used to backup the commands 125 | 126 | FLAGS = _C 127 | 128 | 129 | def _update_config(FLAGS, args): 130 | FLAGS.defrost() 131 | if args.config: 132 | FLAGS.merge_from_file(args.config) 133 | if args.opts: 134 | FLAGS.merge_from_list(args.opts) 135 | FLAGS.SYS.cmds = ' '.join(sys.argv) 136 | 137 | # update logdir 138 | alias = FLAGS.SOLVER.alias.lower() 139 | if 'time' in alias: # 'time' is a special keyword 140 | alias = alias.replace('time', datetime.now().strftime('%m%d%H%M')) 141 | if alias != '': 142 | FLAGS.SOLVER.logdir += '_' + alias 143 | FLAGS.freeze() 144 | 145 | 146 | def _backup_config(FLAGS, args): 147 | logdir = FLAGS.SOLVER.logdir 148 | if not os.path.exists(logdir): 149 | os.makedirs(logdir) 150 | # copy the file to logdir 151 | if args.config: 152 | shutil.copy2(args.config, logdir) 153 | # dump all configs 154 | filename = os.path.join(logdir, 'all_configs.yaml') 155 | with open(filename, 'w') as fid: 156 | fid.write(FLAGS.dump()) 157 | 158 | 159 | def _set_env_var(FLAGS): 160 | gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu]) 161 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 162 | 163 | 164 | def get_config(): 165 | return FLAGS 166 | 167 | def parse_args(backup=True): 168 | parser = argparse.ArgumentParser(description='The configs') 169 | parser.add_argument('--config', type=str, 170 | help='experiment configuration file name') 171 | parser.add_argument('opts', nargs=argparse.REMAINDER, 172 | help="Modify config options using the
command-line") 173 | 174 | args = parser.parse_args() 175 | _update_config(FLAGS, args) 176 | if backup: 177 | _backup_config(FLAGS, args) 178 | # _set_env_var(FLAGS) 179 | return FLAGS 180 | 181 | 182 | if __name__ == '__main__': 183 | flags = parse_args(backup=False) 184 | print(flags) 185 | -------------------------------------------------------------------------------- /utils/thsolver/dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.utils.data 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | 15 | def read_file(filename): 16 | points = np.fromfile(filename, dtype=np.uint8) 17 | return torch.from_numpy(points) # convert it to torch.tensor 18 | 19 | 20 | class Dataset(torch.utils.data.Dataset): 21 | 22 | def __init__(self, root, filelist, transform, read_file=read_file, 23 | in_memory=False, take: int = -1): 24 | super(Dataset, self).__init__() 25 | self.root = root 26 | self.filelist = filelist 27 | self.transform = transform 28 | self.in_memory = in_memory 29 | self.read_file = read_file 30 | self.take = take 31 | 32 | self.filenames, self.labels = self.load_filenames() 33 | if self.in_memory: 34 | print('Load files into memory from ' + self.filelist) 35 | self.samples = [self.read_file(os.path.join(self.root, f)) 36 | for f in tqdm(self.filenames, ncols=80, leave=False)] 37 | 38 | def __len__(self): 39 | return len(self.filenames) 40 | 41 | def __getitem__(self, idx): 42 | sample = self.samples[idx] if self.in_memory else \ 43 | self.read_file(os.path.join(self.root, self.filenames[idx])) # noqa 44 | output = self.transform(sample, idx) # data augmentation + build octree 45 | output['label'] = self.labels[idx] 46 | output['filename'] = self.filenames[idx] 47 | return output 48 | 49 | def load_filenames(self): 50 | filenames, labels = [], [] 51 | with open(self.filelist) as fid: 52 | lines = fid.readlines() 53 | for line in lines: 54 | tokens = line.split() 55 | filename = tokens[0] 56 | label = tokens[1] if len(tokens) == 2 else 0 57 | filenames.append(filename) 58 | labels.append(int(label)) 59 | 60 | num = len(filenames) 61 | if self.take > num or self.take < 1: 62 | self.take = num 63 | 64 | return filenames[:self.take], labels[:self.take] 65 | 66 | 67 | class DatasetGraph(torch.utils.data.Dataset): 68 | # 对 Dataset 做了一些小改动:原先dataset只能读取单一类型的文件,无法把图和GT一起读入 69 | # 现在能了 70 | def __init__(self, root, filelist, transform, read_file=read_file, 71 | in_memory=False, take: int = -1): 72 | super(Dataset, self).__init__() 73 | self.root = root 74 | self.filelist = filelist 75 | self.transform = transform 76 | self.in_memory = in_memory 77 | self.read_file = read_file 78 | self.take = take 79 | 80 | self.filenames, self.labels = self.load_filenames() 81 | if self.in_memory: 82 | print('Load files into memory from ' + self.filelist) 83 | self.samples = [self.read_file(os.path.join(self.root, f)) 84 | for f in tqdm(self.filenames, ncols=80, leave=False)] 85 | 86 | def __len__(self): 87 | return len(self.filenames) 88 | 89 | def __getitem__(self, idx): 90 | # GT 91 | gts = self.samples[idx] if self.in_memory else \ 92 | self.read_file(os.path.join(self.root, 
self.filenames[idx])) 93 | vertices = torch.from_numpy(gts['vertices']) 94 | normals = torch.from_numpy(gts['normals']) 95 | dist = torch.from_numpy(gts['dist']) 96 | 97 | rnd_idx = torch.randint(low=0, high=dist.shape[0], size=(100000,)) 98 | dist = dist[rnd_idx] 99 | 100 | points = Points(points=vertices, normals=normals, ) 101 | octree = self.points2octree(points) 102 | return {'points': points, 'octree': octree, 'dist': dist} 103 | 104 | 105 | # graph 106 | graph = self.samples[idx] if self.in_memory else \ 107 | self.read_file(os.path.join(self.root, self.filenames[idx])) 108 | 109 | output['label'] = self.labels[idx] 110 | output['filename'] = self.filenames[idx] 111 | return output 112 | 113 | def load_filenames(self): 114 | filenames, labels = [], [] 115 | with open(self.filelist) as fid: 116 | lines = fid.readlines() 117 | for line in lines: 118 | tokens = line.split() 119 | filename = tokens[0] 120 | label = tokens[1] if len(tokens) == 2 else 0 121 | filenames.append(filename) 122 | labels.append(int(label)) 123 | 124 | num = len(filenames) 125 | if self.take > num or self.take < 1: 126 | self.take = num 127 | 128 | return filenames[:self.take], labels[:self.take] 129 | -------------------------------------------------------------------------------- /utils/thsolver/default_settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | build a dict containing all settings which can be accessed globally. 3 | 4 | Usage: 5 | all settings defined in gnndist.yaml can be accessed anywhere like this: 6 | 7 | from utils.thsolver import default_settings 8 | default_settings.get_global_value("max_epoch") 9 | 10 | """ 11 | 12 | _global_dict = None 13 | 14 | def _init(FLAGS=None): 15 | global _global_dict 16 | _global_dict = { 17 | # "normal_aware_pooling": True, 18 | # "num_edge_types": 7, 19 | } 20 | 21 | def set_global_value(key, value): 22 | _global_dict[key] = value 23 | 24 | def set_global_values(FLAGS): 25 | """set global values from the FLAGS in thsolver.solver 26 | """ 27 | for each in FLAGS: 28 | for it in FLAGS[each]: 29 | _global_dict[it] = FLAGS[each][it] 30 | 31 | def get_global_value(key): 32 | return _global_dict[key] -------------------------------------------------------------------------------- /utils/thsolver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | from bisect import bisect_right 10 | import torch.optim.lr_scheduler as LR 11 | 12 | 13 | def multi_step(optimizer, flags): 14 | return LR.MultiStepLR(optimizer, flags.milestones, flags.gamma) 15 | 16 | 17 | def cos(optimizer, flags): 18 | return LR.CosineAnnealingLR(optimizer, flags.max_epoch, eta_min=flags.lr_min) 19 | 20 | 21 | def poly(optimizer, flags): 22 | lr_lambda = lambda epoch: (1 - epoch / flags.max_epoch) ** flags.lr_power 23 | return LR.LambdaLR(optimizer, lr_lambda) 24 | 25 | 26 | def constant(optimizer, flags): 27 | lr_lambda = lambda epoch: 1 28 | return LR.LambdaLR(optimizer, lr_lambda) 29 | 30 | 31 | def cos_warmup(optimizer, flags): 32 | def lr_lambda(epoch): 33 | warmup = flags.warmup_epoch 34 | warmup_init = flags.warmup_init 35 | if epoch <= warmup: 36 | return (1 - 
warmup_init) * epoch / warmup + warmup_init 37 | else: 38 | lr_min = flags.lr_min 39 | ratio = (epoch - warmup) / (flags.max_epoch - warmup) 40 | return lr_min + 0.5 * (1.0 - lr_min) * (1 + math.cos(math.pi * ratio)) 41 | return LR.LambdaLR(optimizer, lr_lambda) 42 | 43 | 44 | def poly_warmup(optimizer, flags): 45 | def lr_lambda(epoch): 46 | warmup = flags.warmup_epoch 47 | warmup_init = flags.warmup_init 48 | if epoch <= warmup: 49 | return (1 - warmup_init) * epoch / warmup + warmup_init 50 | else: 51 | ratio = (epoch - warmup) / (flags.max_epoch - warmup) 52 | return (1 - ratio) ** flags.lr_power 53 | return LR.LambdaLR(optimizer, lr_lambda) 54 | 55 | 56 | def step_warmup(optimizer, flags): 57 | def lr_lambda(epoch): 58 | warmup = flags.warmup_epoch 59 | warmup_init = flags.warmup_init 60 | if epoch <= warmup: 61 | return (1 - warmup_init) * epoch / warmup + warmup_init 62 | else: 63 | milestones = sorted(flags.milestones) 64 | return flags.gamma ** bisect_right(milestones, epoch) 65 | return LR.LambdaLR(optimizer, lr_lambda) 66 | 67 | 68 | def get_lr_scheduler(optimizer, flags): 69 | lr_dict = {'cos': cos, 'step': multi_step, 'poly': poly, 70 | 'constant': constant, 'cos_warmup': cos_warmup, 71 | 'poly_warmup': poly_warmup, 'step_warmup': step_warmup} 72 | lr_func = lr_dict[flags.lr_type] 73 | return lr_func(optimizer, flags) 74 | -------------------------------------------------------------------------------- /utils/thsolver/sampler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from torch.utils.data import Sampler, DistributedSampler, Dataset 10 | 11 | 12 | class InfSampler(Sampler): 13 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 14 | self.dataset = dataset 15 | self.shuffle = shuffle 16 | self.reset_sampler() 17 | 18 | def reset_sampler(self): 19 | num = len(self.dataset) 20 | indices = torch.randperm(num) if self.shuffle else torch.arange(num) 21 | self.indices = indices.tolist() 22 | self.iter_num = 0 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def __next__(self): 28 | value = self.indices[self.iter_num] 29 | self.iter_num = self.iter_num + 1 30 | 31 | if self.iter_num >= len(self.indices): 32 | self.reset_sampler() 33 | return value 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | 39 | class DistributedInfSampler(DistributedSampler): 40 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 41 | super().__init__(dataset, shuffle=shuffle) 42 | self.reset_sampler() 43 | 44 | def reset_sampler(self): 45 | self.indices = list(super().__iter__()) 46 | self.iter_num = 0 47 | 48 | def __iter__(self): 49 | return self 50 | 51 | def __next__(self): 52 | value = self.indices[self.iter_num] 53 | self.iter_num = self.iter_num + 1 54 | 55 | if self.iter_num >= len(self.indices): 56 | self.reset_sampler() 57 | return value 58 | -------------------------------------------------------------------------------- /utils/thsolver/tracker.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3
| # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | import torch 10 | import torch.distributed 11 | from datetime import datetime 12 | from tqdm import tqdm 13 | from typing import Dict 14 | 15 | 16 | class AverageTracker: 17 | 18 | def __init__(self): 19 | self.value = None 20 | self.num = 0.0 21 | self.max_len = 76 22 | self.start_time = time.time() 23 | 24 | def update(self, value: Dict[str, torch.Tensor]): 25 | if not value: 26 | return # empty input, return 27 | 28 | value = {key: val.detach() for key, val in value.items()} 29 | if self.value is None: 30 | self.value = value 31 | else: 32 | for key, val in value.items(): 33 | self.value[key] += val 34 | self.num += 1 35 | 36 | def average(self): 37 | return {key: val.item() / self.num for key, val in self.value.items()} 38 | 39 | @torch.no_grad() 40 | def average_all_gather(self): 41 | for key, tensor in self.value.items(): 42 | if not tensor.is_cuda: continue 43 | tensors_gather = [torch.ones_like(tensor) 44 | for _ in range(torch.distributed.get_world_size())] 45 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 46 | tensors = torch.stack(tensors_gather, dim=0) 47 | self.value[key] = torch.mean(tensors) 48 | 49 | def log(self, epoch, summary_writer=None, log_file=None, msg_tag='->', 50 | notes='', print_time=True, print_memory=False): 51 | if not self.value: return # empty, return 52 | 53 | avg = self.average() 54 | msg = 'Epoch: %d' % epoch 55 | for key, val in avg.items(): 56 | msg += ', %s: %.3f' % (key, val) 57 | if summary_writer is not None: 58 | summary_writer.add_scalar(key, val, epoch) 59 | 60 | # if the log_file is provided, save the log 61 | if log_file is not None: 62 | with open(log_file, 'a') as fid: 63 | fid.write(msg + '\n') 64 | 65 | # memory 66 | memory = '' 67 | if print_memory and torch.cuda.is_available(): 68 | size = torch.cuda.memory_reserved() 69 | # size = torch.cuda.memory_allocated() 70 | memory = ', memory: {:.3f}GB'.format(size / 2**30) 71 | 72 | # time 73 | time_str = '' 74 | if print_time: 75 | curr_time = ', time: ' + datetime.now().strftime("%Y/%m/%d %H:%M:%S") 76 | duration = ', duration: {:.2f}s'.format(time.time() - self.start_time) 77 | time_str = curr_time + duration 78 | 79 | # other notes 80 | if notes: 81 | notes = ', ' + notes 82 | 83 | # concatenate all messages 84 | msg += memory + time_str + notes 85 | 86 | # split the msg for better display 87 | chunks = [msg[i:i+self.max_len] for i in range(0, len(msg), self.max_len)] 88 | msg = (msg_tag + ' ') + ('\n' + len(msg_tag) * ' ' + ' ').join(chunks) 89 | tqdm.write(msg) 90 | --------------------------------------------------------------------------------
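A minimal usage sketch of `AverageTracker` from utils/thsolver/tracker.py (illustration only, not a file in the repository; the metric values are made up, and the import assumes the repo root is on sys.path):

    import torch
    from utils.thsolver.tracker import AverageTracker

    tracker = AverageTracker()
    for step in range(10):
      # `update` expects a dict of tensors keyed by metric name
      tracker.update({'loss': torch.tensor(1.0 / (step + 1))})
    print(tracker.average())  # {'loss': 0.29289...}, the running mean over 10 updates
    tracker.log(epoch=1)      # prints something like 'Epoch: 1, loss: 0.293, time: ...'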