├── .gitignore
├── GeGnn_standalong.py
├── GnnDist.py
├── configs
└── gnndist_test.yaml
├── dataset_ps.py
├── hgraph
├── __init__.py
├── hgraph.py
├── models
│ ├── __init__.py
│ └── graph_unet.py
└── modules
│ ├── decide_edge_type.py
│ ├── modules.py
│ └── resblocks.py
├── img
├── bunny.jpg
├── teaser.png
└── teaser.v1.png
├── pretrained
└── ours00500.solver.tar
├── readme.md
├── requirements.txt
└── utils
├── __init__.py
├── dataset_prepare
├── py_data_generater.py
└── simplifiy.py
├── ocnn
├── __init__.py
├── dataset.py
├── models
│ ├── __init__.py
│ ├── autoencoder.py
│ ├── hrnet.py
│ ├── lenet.py
│ ├── octree_unet.py
│ ├── resnet.py
│ └── segnet.py
├── modules
│ ├── __init__.py
│ ├── modules.py
│ └── resblocks.py
├── nn
│ ├── __init__.py
│ ├── octree2col.py
│ ├── octree2vox.py
│ ├── octree_conv.py
│ ├── octree_drop.py
│ ├── octree_dwconv.py
│ ├── octree_interp.py
│ ├── octree_norm.py
│ ├── octree_pad.py
│ └── octree_pool.py
├── octree
│ ├── __init__.py
│ ├── octree.py
│ ├── points.py
│ └── shuffled_key.py
└── utils.py
└── thsolver
├── __init__.py
├── config.py
├── dataset.py
├── default_settings.py
├── lr_scheduler.py
├── sampler.py
├── solver.py
└── tracker.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 | .idea
4 | .vscode
5 | __pycache__
6 | sphere.obj
7 | runs
8 | out
9 | configs/superhyper
10 | logs
11 |
12 |
13 | /utils/dataset_prepare/inputs/
14 | /utils/dataset_prepare/inputs_/
15 | /utils/dataset_prepare/outputs/
16 | /utils/dataset_prepare/simplified_meshes/
17 | /utils/dataset_prepare/shapenet.13.zip
18 | /utils/dataset_prepare/data_complete/
19 | /utils/dataset_prepare/original_obj/
20 | /utils/dataset_prepare/simplified_obj/
21 | /utils/.ipynb_checkpoints
22 | /data/
23 | /utils/dataset_prepare/geodesic/
24 | /logs
25 |
--------------------------------------------------------------------------------
/GeGnn_standalong.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import hgraph
4 | import trimesh
5 |
6 | import time
7 | import heat_method
8 | import argparse
9 |
10 |
11 | # a function that reads a triangular mesh, and generates its corresponding graph
12 | def read_mesh(path, to_tensor=True):
13 | # :param path: the path to the mesh
14 | # :param to_tensor: whether to convert the numpy array to torch tensor
15 | # :return: a dict containing the vertices, edges, normals, faces, face_normals, face_areas
16 |
17 | # read the mesh
18 | mesh = trimesh.load(path)
19 | # get the vertices
20 | vertices = mesh.vertices
21 | # get the edges
22 | edges = mesh.edges_unique
23 | edges_reversed = np.concatenate([edges[:, 1:], edges[:, :1]], 1)
24 | edges = np.concatenate([edges, edges_reversed], 0)
25 | edges = np.transpose(edges)
26 | # get the normals
27 | normals = mesh.vertex_normals
28 | # normalize the normal
29 | norm_normals = np.linalg.norm(normals, axis=1)
30 | normals = normals / norm_normals[:, np.newaxis]
31 |
32 | # get the faces
33 | faces = mesh.faces
34 | # get the face normals
35 | face_normals = mesh.face_normals
36 | # get the face areas
37 | face_areas = mesh.area_faces
38 | # convert to tensor, if needed
39 | if to_tensor:
40 | vertices = torch.from_numpy(vertices).float().cuda()
41 | edges = torch.from_numpy(edges).long().cuda()
42 | normals = torch.from_numpy(np.array(normals)).float().cuda()
43 | faces = torch.from_numpy(faces).long().cuda()
44 | face_normals = torch.from_numpy(np.array(face_normals)).float().cuda()
45 | face_areas = torch.from_numpy(np.array(face_areas)).float().cuda()
46 |
47 | # generate a dict
48 | dic = {
49 | "vertices": vertices,
50 | "edges": edges,
51 | "normals": normals,
52 | "faces": faces,
53 | "face_normals": face_normals,
54 | "face_areas": face_areas,
55 | }
56 |
57 | return dic
58 |
59 |
60 | # a wrapper of pretrained model
61 | class PretrainedModel(torch.nn.Module):
62 |
63 | def __init__(self, ckpt_path):
64 | super(PretrainedModel, self).__init__()
65 | self.model = hgraph.models.graph_unet.GraphUNet(
66 | 6, 256, None, None).cuda()
67 | self.embds = None
68 | # load the pretrained model
69 | ckpt = torch.load(ckpt_path, map_location=torch.device('cuda'))
70 | model_dict=ckpt['model_dict']
71 | self.model.load_state_dict(model_dict)
72 |
73 | def embd_decoder_func(self, i, j, embedding):
74 | i = i.long()
75 | j = j.long()
76 | embd_i = embedding[i].squeeze(-1)
77 | embd_j = embedding[j].squeeze(-1)
78 | embd = (embd_i - embd_j) ** 2
79 | pred = self.model.embedding_decoder_mlp(embd)
80 | pred = pred.squeeze(-1)
81 | return pred
82 |
83 | def precompute(self, mesh):
84 | with torch.no_grad():
85 | # calculate vertex wise embd
86 | # 1. construct the graph tree
87 | vertices = mesh['vertices'] # [N, 3]
88 | normals = mesh['normals'] # [N, 3]
89 | edges = mesh['edges'] # [2, M]
90 |
91 | tree = hgraph.hgraph.HGraph()
92 | tree.build_single_hgraph(
93 | hgraph.hgraph.Data(x=torch.cat([vertices, normals], dim=1), edge_index=edges)
94 | )
95 |
96 | # 2. feed the graph tree into the model & get the vertex-wise embedding
97 | self.embds = self.model(
98 | torch.cat([vertices, normals], dim=1),
99 | tree,
100 | tree.depth,
101 | dist=None,
102 | only_embd=True)
103 | self.embds = self.embds.detach()
104 |
105 |
106 | def forward(self, p_vertices=None, q_vertices=None):
107 | # given a mesh, and two sets of vertices, calculate the geodesic distances the pairs
108 | # :param p_vertices: [N], index of the vertices in the first set
109 | # :param q_vertices: [N], index of the vertices in the second set
110 | assert self.embds is not None, "Please call precompute() first!"
111 | with torch.no_grad():
112 | ans = self.embd_decoder_func(p_vertices, q_vertices, self.embds)
113 | return ans
114 |
115 | def SSAD(self, source: list):
116 | # given a mesh, calculate the geodesic distances from the source to all other vertices
117 | assert self.embds is not None, "Please call precompute() first!"
118 |
119 |
120 | with torch.no_grad():
121 | ret = []
122 | ss, tt = [], []
123 | for i in range(len(source)):
124 | s = torch.tensor([source[i]]).repeat(self.embds.shape[0]).cuda()
125 | t = torch.arange(self.embds.shape[0]).cuda()
126 | ss.append(s)
127 | tt.append(t)
128 | ans = self.embd_decoder_func(s, t, self.embds)
129 | ret.append(ans)
130 |
131 | return ret
132 |
133 |
134 |
135 |
136 |
137 |
138 | # a wrapper of pretrained model, so that it can be called directly from the command line
139 | def main():
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument('--mode', type=str, default='SSAD',
142 | help='only SSAD available for now')
143 | parser.add_argument('--test_file', type=str, default=None, help='path to the obj file')
144 | parser.add_argument('--ckpt_path', type=str, default=None, help='path to the checkpoint')
145 | parser.add_argument('--start_pts', type=str, default=None, help='an int is expected.')
146 | parser.add_argument('--output', type=str, default=None, help='path to the output file')
147 | args = parser.parse_args()
148 |
149 | if args.mode == "SSAD":
150 | obj_dic = read_mesh(args.test_file)
151 | # print the vertex and face number
152 | print("Vertex number: ", obj_dic['vertices'].shape[0], "Face number: ", obj_dic['faces'].shape[0])
153 | start_pts = torch.tensor(int(args.start_pts)).cuda()
154 |
155 | model = PretrainedModel(args.ckpt_path).cuda()
156 | model.precompute(obj_dic)
157 | dist_pred = model.SSAD([start_pts])[0]
158 |
159 | np.save(args.output, dist_pred.detach().cpu().numpy())
160 |
161 | # save the colored mesh for visualization
162 | # given the vertices, faces of a mesh, save it as obj file
163 | def save_mesh_as_obj(vertices, faces, scalar=None, path="out/our_mesh.obj"):
164 | with open(path, 'w') as f:
165 | f.write('# mesh\n') # header of LittleRender
166 | for v in vertices:
167 | f.write('v ' + str(v[0]) + ' ' + str(v[1]) + ' ' + str(v[2]) + '\n')
168 | for face in faces:
169 | f.write('f ' + str(face[0]+1) + ' ' + str(face[1]+1) + ' ' + str(face[2]+1) + '\n')
170 | if scalar is not None:
171 | # normalize the scalar to [0, 1]
172 | scalar = (scalar - np.min(scalar)) / (np.max(scalar) - np.min(scalar))
173 | for c in scalar:
174 | f.write('c ' + str(c) + ' ' + str(c) + ' ' + str(c) + '\n')
175 |
176 | print("Saved mesh as obj file:", path, end="")
177 | if scalar is not None:
178 | print(" (with color) ")
179 | else:
180 | print(" (without color)")
181 |
182 | save_mesh_as_obj(obj_dic['vertices'].detach().cpu().numpy(),
183 | obj_dic['faces'].detach().cpu().numpy(),
184 | dist_pred.detach().cpu().numpy())
185 |
186 |
187 | else:
188 | print("Invalid mode! (" + args.mode + ")")
189 |
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
194 |
195 | ###################################
196 | # visualization via polyscope starts
197 | # comment out the following lines if you are using ssh
198 | ###################################
199 | import polyscope as ps
200 | import numpy as np
201 | import trimesh
202 |
203 | # load mesh
204 | mesh = trimesh.load_mesh("out/our_mesh.obj", process=False)
205 | vertices = mesh.vertices
206 | faces = mesh.faces
207 |
208 | # load numpy array
209 | colors = np.load("out/ssad_ours.npy")
210 | print(colors.shape)
211 |
212 | # Initialize polyscope
213 | ps.init()
214 | ps_cloud = ps.register_point_cloud("my mesh", vertices)
215 | ps_cloud.add_scalar_quantity("geo_distance", colors, enabled=True)
216 | ps.show()
217 | ###################################
218 | # visualization via polyscope ends
219 | ###################################
220 |
--------------------------------------------------------------------------------
/GnnDist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.thsolver import default_settings
3 | # initialize global settings
4 | default_settings._init()
5 | from utils.thsolver.config import parse_args
6 | FLAGS = parse_args()
7 | default_settings.set_global_values(FLAGS)
8 |
9 |
10 | from utils import thsolver
11 | import hgraph
12 |
13 | from dataset_ps import get_dataset
14 |
15 |
16 |
17 | def get_parameter_number(model):
18 | """print the number of parameters in a model on terminal """
19 | total_num = sum(p.numel() for p in model.parameters())
20 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
21 | print(f"\nTotal Parameters: {total_num}, trainable: {trainable_num}")
22 | return {'Total': total_num, 'Trainable': trainable_num}
23 |
24 |
25 | class GnnDistSolver(thsolver.Solver):
26 |
27 | def get_model(self, flags):
28 |
29 | if flags.name.lower() == 'unet':
30 | model = hgraph.models.graph_unet.GraphUNet(
31 | flags.channel, flags.nout, flags.interp, flags.nempty)
32 |
33 |
34 | # model = ocnn.models.octree_unet.OctreeUnet(flags.channel, flags.nout, flags.interp, flags.nempty)
35 | else:
36 | raise ValueError
37 |
38 |
39 | # overall num parameters
40 | get_parameter_number(model)
41 | return model
42 |
43 | def get_dataset(self, flags):
44 | return get_dataset(flags)
45 |
46 |
47 | def model_forward(self, batch):
48 | """equivalent to `self.get_embd` + `self.embd_decoder_func` """
49 |
50 | data = batch["feature"].cuda()
51 | hgraph = batch['hgraph'].cuda()
52 | dist = batch['dist'].cuda()
53 |
54 | pred = self.model(data, hgraph, hgraph.depth, dist)
55 | return pred
56 |
57 | def get_embd(self, batch):
58 | """only used in visualization!"""
59 | data = batch["feature"].cuda()
60 | hgraph = batch['hgraph'].cuda()
61 | dist = batch['dist'].cuda()
62 |
63 | embedding = self.model(data, hgraph, hgraph.depth, dist, only_embd=True)
64 | return embedding
65 |
66 | def embd_decoder_func(self, i, j, embedding):
67 | """only used in visualization!"""
68 | i = i.long()
69 | j = j.long()
70 | embd_i = embedding[i].squeeze(-1)
71 | embd_j = embedding[j].squeeze(-1)
72 | embd = (embd_i - embd_j) ** 2
73 | pred = self.model.embedding_decoder_mlp(embd)
74 | pred = pred.squeeze(-1)
75 | return pred
76 |
77 | def train_step(self, batch):
78 | pred = self.model_forward(batch)
79 | loss = self.loss_function(batch, pred)
80 | return {'train/loss': loss}
81 |
82 | def test_step(self, batch):
83 | pred = self.model_forward(batch)
84 | loss = self.loss_function(batch, pred)
85 | return {'test/loss': loss}
86 |
87 | def loss_function(self, batch, pred):
88 | dist = batch['dist'].cuda()
89 | gt = dist[:, 2]
90 |
91 | # there are many kind of losses that may apply:
92 |
93 | # option 1: Mean Absolute Error, MAE
94 | #loss = torch.abs(pred - gt).mean()
95 |
96 | # option 2: relative MAE
97 | loss = (torch.abs(pred - gt) / (gt + 1e-3)).mean()
98 |
99 | # option 3: Mean Squared Error, MSE
100 | #loss = torch.square(pred - gt).mean()
101 |
102 | # option 4: relative MSE
103 | #loss = torch.square((pred - gt) / (gt + 1e-3)).mean()
104 |
105 | # option 5: root mean squared error, RMSE
106 | #loss = torch.sqrt(torch.square(pred - gt).mean())
107 |
108 | # clamp
109 | loss = torch.clamp(loss, -10, 10)
110 |
111 | return loss
112 |
113 |
114 | #def visualization(ret):
115 | # open the render to visualize the result
116 | # print("Establishing WebSocket and Visualization!")
117 | # asyncio.run(interactive.main(ret))
118 |
119 |
120 |
121 | if __name__ == "__main__":
122 | ret = GnnDistSolver.main()
123 | #visualization(ret)
124 |
125 |
--------------------------------------------------------------------------------
/configs/gnndist_test.yaml:
--------------------------------------------------------------------------------
1 | SOLVER:
2 | gpu: 0,
3 | run: train
4 |
5 | logdir: logs/my_test
6 | max_epoch: 500
7 | test_every_epoch: 20
8 | log_per_iter: 10
9 | ckpt_num: 5
10 | dist_url: tcp://localhost:10266
11 |
12 | # optimizer
13 | type: adamw
14 | weight_decay: 0.01 # default value of adamw
15 | lr: 0.00025
16 |
17 | # learning rate
18 | lr_type: poly
19 | lr_power: 0.9
20 |
21 | DATA:
22 | train:
23 | # octree building
24 | depth: 6
25 | full_depth: 2 #The octree layers with a depth smaller than `full_depth` are forced to be full.
26 |
27 | # data augmentations
28 | distort: False
29 |
30 | # data loading
31 | location: ./
32 | filelist: ./data/tiny/filelist/filelist_train.txt
33 | batch_size: 1
34 | shuffle: True
35 | num_workers: 0
36 |
37 | test:
38 | # octree building
39 | depth: 6
40 | full_depth: 2
41 |
42 | # data augmentations
43 | distort: False
44 |
45 | # data loading
46 | location: ./
47 | filelist: ./data/tiny/filelist/filelist_test.txt
48 | batch_size: 4
49 | shuffle: True
50 | num_workers: 10
51 |
52 | MODEL:
53 | name: unet
54 | feature: PN # N -> Normal(3 channels);
55 | # P -> Points(3 channel)
56 | channel: 6
57 |
58 | nout: 256 # the final embedding dimension of each vertices
59 | nempty: True
60 |
61 | num_edge_types: 7 # deprecated
62 |
63 | # SAGE, GAT, Edge, DirConv, DistConv, my
64 | conv_type: my
65 | include_distance: True # only appliable when use dist conv. if true, the distance between points will be concated to the feature.
66 |
67 | normal_aware_pooling: True # when grid pooling, consider normal or not
68 |
69 |
70 | # visualization, will not affect the training/testing process, only visualization
71 |
72 | get_test_stat: False # if true, evaluate the test set before visualization
73 | # test_mesh: 2323 # NOT finished: specify a mesh to evaluate in visualization system
74 |
75 |
--------------------------------------------------------------------------------
/dataset_ps.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 | from utils import ocnn
4 |
5 | import numpy as np
6 | from utils.thsolver import Dataset
7 |
8 |
9 | from hgraph.hgraph import Data
10 | from hgraph.hgraph import HGraph
11 |
12 |
13 | class Transform(utils.ocnn.dataset.Transform):
14 |
15 | def __call__(self, sample: dict, idx: int):
16 | vertices = torch.from_numpy(sample['vertices'].astype(np.float32))
17 | normals = torch.from_numpy(sample['normals'].astype(np.float32))
18 | edges = torch.from_numpy(sample['edges'].astype(np.float32)).t().contiguous().long()
19 | dist_idx = sample['dist_idx'].astype(np.float32)
20 | dist_val = sample['dist_val'].astype(np.float32)
21 | # breakpoint()
22 | dist = np.concatenate([dist_idx, dist_val], -1)
23 | dist = torch.from_numpy(dist)
24 |
25 |
26 | rnd_idx = torch.randint(low=0, high=dist.shape[0], size=(100000,))
27 | dist = dist[rnd_idx]
28 |
29 |
30 | # normalize
31 | norm2 = torch.sqrt(torch.sum(normals ** 2, dim=1, keepdim=True))
32 | normals = normals / torch.clamp(norm2, min=1.0e-12)
33 |
34 | # construct hierarchical graph
35 | h_graph = HGraph()
36 |
37 | h_graph.build_single_hgraph(Data(x=torch.cat([vertices, normals], dim=1), edge_index=edges))
38 |
39 | return {'hgraph': h_graph,
40 | 'vertices': vertices, 'normals': normals,
41 | 'dist': dist, 'edges': edges}
42 |
43 |
44 | def collate_batch(batch: list):
45 | # batch: list of single samples.
46 | # each sample is a dict with keys:
47 | # edges, vertices, normals, dist
48 |
49 | # output: a big sample
50 | assert type(batch) == list
51 |
52 | # merge many hgraphs into one super hgraph
53 |
54 |
55 | outputs = {}
56 | for key in batch[0].keys():
57 | outputs[key] = [b[key] for b in batch]
58 |
59 | pts_num = torch.tensor([pts.shape[0] for pts in outputs['vertices']])
60 | cum_sum = utils.ocnn.utils.cumsum(pts_num, dim=0, exclusive=True)
61 | for i, dist in enumerate(outputs['dist']):
62 | dist[:, :2] += cum_sum[i]
63 | #for i, edge in enumerate(outputs['edges']):
64 | # edge += cum_sum[i]
65 | outputs['dist'] = torch.cat(outputs['dist'], dim=0)
66 |
67 |
68 | # input feature
69 | vertices = torch.cat(outputs['vertices'], dim=0)
70 | normals = torch.cat(outputs['normals'], dim=0)
71 | feature = torch.cat([vertices, normals], dim=1)
72 | outputs['feature'] = feature
73 |
74 | # merge a batch of hgraphs into one super hgraph
75 | hgraph_super = HGraph(batch_size=len(batch))
76 | hgraph_super.merge_hgraph(outputs['hgraph'])
77 | outputs['hgraph'] = hgraph_super
78 |
79 | #if (outputs['dist'].max() >= len(vertices)):
80 | # print("!!!!!!!")
81 |
82 |
83 | # Merge a batch of octrees into one super octree
84 | #octree = ocnn.octree.merge_octrees(outputs['octree'])
85 | #octree.construct_all_neigh()
86 | #outputs['octree'] = octree
87 |
88 | # Merge a batch of points
89 | #outputs['points'] = ocnn.octree.merge_points(outputs['points'])
90 | return outputs
91 |
92 |
93 | def get_dataset(flags):
94 | transform = Transform(**flags)
95 | dataset = Dataset(flags.location, flags.filelist, transform,
96 | read_file=np.load, take=flags.take)
97 | return dataset, collate_batch
98 |
--------------------------------------------------------------------------------
/hgraph/__init__.py:
--------------------------------------------------------------------------------
1 | from . import models
2 | from . import modules
3 | from . import hgraph
--------------------------------------------------------------------------------
/hgraph/hgraph.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable, Optional, Tuple
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import torch_geometric
10 | from torch_geometric.nn import SAGEConv
11 | #from torch_geometric.nn import avg_pool, voxel_grid
12 |
13 | #from utils.thsolver import default_settings
14 |
15 | from .modules.decide_edge_type import *
16 |
17 | """
18 | graph neural network coarse to fine data structure
19 | keep records of a "graph tree"
20 | """
21 |
22 |
23 | class Data:
24 | def __init__(self, x=None, edge_index=None, edge_attr=None):
25 | """
26 | a rewrite of torch_geometric.data.Data
27 | get rid of its self-aleck re-indexing
28 | :param edge_attr: the type of edges.
29 | """
30 | self.x = x
31 | self.edge_index = edge_index
32 | self.edge_attr = edge_attr
33 |
34 | def to(self, target):
35 | # trans data from gpu to cpu or vice versa
36 | self.x = self.x.to(target)
37 | self.edge_index = self.edge_index.to(target)
38 | if self.edge_attr != None:
39 | self.edge_attr = self.edge_attr.to(target)
40 | return self
41 |
42 | def cuda(self):
43 | return self.to("cuda")
44 |
45 | def cpu(self):
46 | return self.to("cpu")
47 |
48 |
49 |
50 | def avg_pool(
51 | cluster: torch.Tensor,
52 | data: Data,
53 | transform: Optional[Callable] = None,
54 | ) -> Data:
55 | """a wrapper of torch_geometric.nn.avg_pool"""
56 | data_torch_geometric = torch_geometric.data.Data(x=data.x, edge_index=data.edge_index)
57 | new_data = torch_geometric.nn.avg_pool(cluster, data_torch_geometric, transform=transform)
58 | ret = Data(x=new_data.x, edge_index=new_data.edge_index)
59 | return ret
60 |
61 |
62 | def avg_pool_maintain_old(
63 | cluster: torch.Tensor,
64 | data: Data,
65 | transform: Optional[Callable] = None,
66 | ):
67 | """a wrapper of torch_geometric.nn.avg_pool, but maintain the old graph"""
68 | data_torch_geometric = torch_geometric.data.Data(x=data.x, edge_index=data.edge_index)
69 | new_layer = torch_geometric.nn.avg_pool(cluster, data_torch_geometric, transform=transform)
70 | # connect the corresponding node in the two layers
71 |
72 |
73 |
74 | def pooling(data: Data, size: float, normal_aware_pooling):
75 | """
76 | do pooling according to x. a new graph (x, edges) will be generated after pooling.
77 | This function is a wrapper of some funcs in pytorch_geometric. It assumes the
78 | object's coordinates range from -1 to 1.
79 |
80 | normal_aware_pooling: if True, only data.x[..., :3] will be used for grid pooling.
81 | """
82 | assert type(size) == float
83 | if normal_aware_pooling == False:
84 | x = data.x[..., :3]
85 | else:
86 | x = data.x[..., :6]
87 | # we assume x has 6 feature channels, first 3 xyz, then 3 nxnynz
88 | # grid size here waits for fine-tuning
89 | n_size = size * 3. # a hyper parameter, controling "how important the normal vecs are, compared to xyz coords"
90 | size = [size, size, size, n_size, n_size, n_size]
91 |
92 | edges = data.edge_index
93 | cluster = torch_geometric.nn.voxel_grid(pos=x, size=size, batch=None,
94 | start=[-1, -1, -1.], end=[1, 1, 1.])
95 |
96 | # keep max index smaller than # of unique (e.g., [4,5,4,3] --> [0,1,0,2])
97 | mapping = cluster.unique()
98 | mapping += mapping.shape[0]
99 | cluster += mapping.shape[0]
100 |
101 | for i in range(int(mapping.shape[0])): # maybe some optimization here to remove the for loop
102 | cluster[cluster == mapping[i]] = i
103 |
104 | return cluster #.contiguous()
105 |
106 |
107 |
108 |
109 | def add_self_loop(data: Data):
110 | # avg_pool will clear the self loop in the graph. here we add it back
111 | device = data.x.device
112 | n_vert = data.x.shape[0]
113 | self_loops = torch.tensor([[i for i in range(int(n_vert))]])
114 | self_loops = self_loops.repeat(2, 1)
115 | new_edges = torch.zeros([2, data.edge_index.shape[1] + n_vert], dtype=torch.int64)
116 | new_edges[:, :data.edge_index.shape[1]] = data.edge_index
117 | new_edges[:, data.edge_index.shape[1]:] = self_loops
118 |
119 | return Data(x=data.x, edge_index=new_edges).to(device)
120 |
121 |
122 |
123 |
124 | class HGraph:
125 | """
126 | HGraph stands for "Hierarchical Graph"
127 |
128 | notes:
129 | - coordinates of input vertices should be in [-1, 1]
130 | - this class handles xyz coordinates/normals, not input feature
131 |
132 | """
133 | def __init__(self, depth: int=5,
134 | smallest_grid=2/2**5,
135 | batch_size: int=1,
136 | adj_layer_connected=False):
137 | """
138 |
139 | suppose depth=3:
140 | self.treedict[0] = original graph
141 | self.treedict[1] = merge verts in voxel with edge length = smallest_grid
142 | self.treedict[2] = merge verts in voxel with edge length = smallest_grid * (2**1)
143 | self.treedict[3] = merge verts in voxel with edge length = smallest_grid * (2**2)
144 |
145 | __init__ method only specify hyper parameters of a hgraph.
146 | The real construction of hgraph happens in build_single_hgraph() or merge_hgraph()
147 | """
148 |
149 | assert smallest_grid * (2**(depth-1)) <= 2
150 | assert depth >= 0
151 |
152 | self.device = "cuda"
153 | self.depth = depth
154 | self.batch_size = batch_size
155 | self.smallest_grid = smallest_grid
156 | self.normal_aware_pooling = True
157 |
158 | self.vertices_sizes = {}
159 | self.edges_sizes = {}
160 | #
161 | self.treedict = {}
162 | self.cluster = {}
163 |
164 | def build_single_hgraph(self, original_graph: Data):
165 | """
166 | build a graph-tree of **one** graph
167 | """
168 | assert type(original_graph) == Data
169 |
170 |
171 | graphtree = {}
172 | cluster = {}
173 | vertices_size = {}
174 | edges_size = {}
175 |
176 | for i in range(self.depth+1):
177 | if i == 0:
178 | original_graph = add_self_loop(original_graph)
179 | # if original graph do not have edge types, assign it
180 | if original_graph.edge_attr == None:
181 | edges = original_graph.x[original_graph.edge_index[0]] \
182 | - original_graph.x[original_graph.edge_index[1]]
183 | edges_attr = decide_edge_type_distance(edges, return_edge_length=False)
184 | original_graph.edge_attr = edges_attr
185 | graphtree[0] = original_graph
186 | cluster[0] = None
187 | edges_size[0] = original_graph.edge_index.shape[1]
188 | vertices_size[0] = original_graph.x.shape[0]
189 | continue
190 |
191 | clst = pooling(graphtree[i-1], self.smallest_grid * (2**(i-1)), normal_aware_pooling=self.normal_aware_pooling)
192 | new_graph = avg_pool(cluster=clst, data=graphtree[i-1], transform=None)
193 | new_graph = add_self_loop(new_graph)
194 | # assign edge type
195 | edges = new_graph.x[new_graph.edge_index[0]] \
196 | - new_graph.x[new_graph.edge_index[1]]
197 | edges_attr = decide_edge_type_distance(edges, return_edge_length=False)
198 | new_graph.edge_attr = edges_attr
199 |
200 | graphtree[i] = new_graph
201 | cluster[i] = clst
202 | edges_size[i] = new_graph.edge_index.shape[1]
203 | vertices_size[i] = new_graph.x.shape[0]
204 |
205 | self.treedict = graphtree
206 | self.cluster = cluster
207 | self.vertices_sizes = vertices_size
208 | self.edges_sizes = edges_size
209 |
210 | #self.export_obj()
211 |
212 |
213 | # @staticmethod
214 | def merge_hgraph(self, original_graphs: list[HGraph], debug_report=False):
215 | """
216 | merge multi hgraph into a large hgraph
217 |
218 | """
219 | assert len(self.cluster) == 0 and len(self.treedict) == 0, "please call this function on a new instance"
220 | assert original_graphs.__len__() == self.batch_size, "please make sure the batch size is correct"
221 |
222 | # re-indexing
223 | for d in range(self.depth+1):
224 | # merge vertices for every layer
225 | num_vertices = [0]
226 | for i, each in enumerate(original_graphs):
227 | num_vertices.append(each.vertices_sizes[d])
228 | cum_sum = torch.cumsum(torch.tensor(num_vertices), dim=0)
229 | for i in range(original_graphs.__len__()):
230 | original_graphs[i].treedict[d].edge_index += cum_sum[i]
231 | # cluster is None at d=0
232 | if d != 0:
233 | original_graphs[i].cluster[d] += cum_sum[i]
234 |
235 | # merge
236 | for d in range(self.depth+1):
237 | graphtrees_x, graphtrees_e, graphtrees_e_type, clusters = [], [], [], []
238 | for i in range(original_graphs.__len__()):
239 | graphtrees_x.append(original_graphs[i].treedict[d].x)
240 | graphtrees_e.append(original_graphs[i].treedict[d].edge_index)
241 | graphtrees_e_type.append(original_graphs[i].treedict[d].edge_attr)
242 | clusters.append(original_graphs[i].cluster[d])
243 | # construct new graph
244 | temp_data = Data(x=torch.cat(graphtrees_x, dim=0),
245 | edge_index=torch.cat(graphtrees_e, dim=1),
246 | edge_attr=torch.cat(graphtrees_e_type, dim=0) # edge_attr shape: [E]
247 | )
248 | # construct new cluster
249 | if d != 0:
250 | temp_clst = torch.cat(clusters, dim=0)
251 | else:
252 | temp_clst = None
253 | self.treedict[d] = temp_data
254 | self.cluster[d] = temp_clst
255 | self.edges_sizes = temp_data.edge_index.shape[1]
256 | self.vertices_sizes = len(temp_data.x)
257 |
258 | # sanity check
259 | if debug_report == True:
260 | # a simple unit test
261 | for d in range(self.depth+1):
262 | num_edges_before = 0
263 | for i in range(original_graphs.__len__()):
264 | num_edges_before += original_graphs[i].treedict[d].edge_index.shape[1]
265 | num_edges_after = self.treedict[d].edge_index.shape[1]
266 | print(f"Before merge, at d={d} there's {num_edges_before} edges; {num_edges_after} afterwards")
267 |
268 |
269 |
270 |
271 | #####################################################
272 | # Util
273 | #####################################################
274 |
275 | def cuda(self):
276 | # move all tensors to cuda
277 | for each in self.treedict.keys():
278 | self.treedict[each] = self.treedict[each].cuda()
279 | for each in self.cluster.keys():
280 | if self.cluster[each] is None:
281 | continue
282 | self.cluster[each] = self.cluster[each].cuda()
283 | return self
284 |
285 |
286 |
--------------------------------------------------------------------------------
/hgraph/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import graph_unet
2 | from . import simple_resnet
--------------------------------------------------------------------------------
/hgraph/models/graph_unet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn
4 | from typing import Dict, Optional
5 |
6 |
7 | from hgraph.hgraph import Data
8 | from hgraph.hgraph import HGraph
9 |
10 | from hgraph.modules.resblocks import GraphResBlocks, GraphResBlock2, GraphResBlock
11 | from hgraph.modules import modules
12 |
13 |
14 | class GraphUNet(torch.nn.Module):
15 | r'''
16 | A U-Net like network with graph neural network, utilizing HGraph (hierarchical graph) as
17 | the data structure.
18 | '''
19 |
20 | def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21 | nempty: bool = False, **kwargs):
22 | super(GraphUNet, self).__init__()
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.config_network()
26 | self.encoder_stages = len(self.encoder_blocks)
27 | self.decoder_stages = len(self.decoder_blocks)
28 |
29 | # encoder
30 | self.conv1 = modules.GraphConvBnRelu(in_channels, self.encoder_channel[0])
31 |
32 | self.downsample = torch.nn.ModuleList(
33 | [modules.PoolingGraph() for i in range(self.encoder_stages)]
34 | )
35 | self.encoder = torch.nn.ModuleList(
36 | [GraphResBlocks(self.encoder_channel[i], self.encoder_channel[i+1],
37 | resblk_num=self.encoder_blocks[i], resblk=self.resblk)
38 | for i in range(self.encoder_stages)]
39 | )
40 |
41 | # decoder
42 | channel = [self.decoder_channel[i] + self.encoder_channel[-i-2]
43 | for i in range(self.decoder_stages)]
44 | self.upsample = torch.nn.ModuleList(
45 | [modules.UnpoolingGraph() for i in range(self.decoder_stages)]
46 | )
47 | self.decoder = torch.nn.ModuleList(
48 | [GraphResBlocks(channel[i], self.decoder_channel[i+1],
49 | resblk_num=self.decoder_blocks[i], resblk=self.resblk, bottleneck=self.bottleneck)
50 | for i in range(self.decoder_stages)]
51 | )
52 |
53 | # header
54 | # channel = self.decoder_channel[self.decoder_stages]
55 | #self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
56 | self.header = torch.nn.Sequential(
57 | modules.Conv1x1BnRelu(self.decoder_channel[-1], self.decoder_channel[-1]),
58 | modules.Conv1x1(self.decoder_channel[-1], self.out_channels, use_bias=True))
59 |
60 | # a embedding decoder function
61 | self.embedding_decoder_mlp = torch.nn.Sequential(
62 | torch.nn.Linear(self.out_channels, self.out_channels, bias=True),
63 | torch.nn.ReLU(),
64 | torch.nn.Linear(self.out_channels, self.out_channels, bias=True),
65 | torch.nn.ReLU(),
66 | torch.nn.Linear(self.out_channels, 1, bias=True)
67 | )
68 |
69 | def config_network(self):
70 | r''' Configure the network channels and Resblock numbers.
71 | '''
72 | self.encoder_blocks = [2, 3, 3, 3, 2]
73 | self.decoder_blocks = [2, 3, 3, 3, 2]
74 | self.encoder_channel = [256, 256, 256, 256, 256, 256]
75 | self.decoder_channel = [256, 256, 256, 256, 256, 256]
76 |
77 | # self.encoder_blocks = [4, 9, 9, 3]
78 | # self.decoder_blocks = [4, 9, 9, 3]
79 | #self.encoder_channel = [512, 512, 512, 512, 512,]
80 | #self.decoder_channel = [512, 512, 512, 512, 512]
81 |
82 | self.bottleneck = 1
83 | self.resblk = GraphResBlock2
84 |
85 | def unet_encoder(self, data: torch.Tensor, hgraph: HGraph, depth: int):
86 | r''' The encoder of the U-Net.
87 | '''
88 |
89 | convd = dict()
90 | convd[depth] = self.conv1(data, hgraph, depth)
91 | for i in range(self.encoder_stages):
92 | d = depth - i
93 | conv = self.downsample[i](convd[d], hgraph, i+1)
94 | convd[d-1] = self.encoder[i](conv, hgraph, d-1)
95 | return convd
96 |
97 | def unet_decoder(self, convd: Dict[int, torch.Tensor], hgraph: HGraph, depth: int):
98 | r''' The decoder of the U-Net.
99 | '''
100 |
101 | deconv = convd[depth]
102 | for i in range(self.decoder_stages):
103 | d = depth + i
104 | deconv = self.upsample[i](deconv, hgraph, self.decoder_stages-i)
105 | deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
106 | deconv = self.decoder[i](deconv, hgraph, d+1)
107 | return deconv
108 |
109 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int, dist: torch.Tensor, only_embd=False):
110 | """_summary_
111 |
112 | Args:
113 | data (torch.Tensor): _description_
114 | hgraph (HGraph): _description_
115 | depth (int): _description_
116 | dist (torch.Tensor): _description_
117 | only_embd (bool, optional): If True, the return value will be the embeddings of vertices; if False,
118 | return the estimated distance of point pairs. Defaults to False.
119 | """
120 | convd = self.unet_encoder(data, hgraph, depth)
121 | deconv = self.unet_decoder(convd, hgraph, depth - self.encoder_stages)
122 |
123 | embedding = self.header(deconv)
124 |
125 | if dist == None and only_embd:
126 | return embedding
127 |
128 | # calculate the distance
129 | i, j= dist[:, 0].long(), dist[:, 1].long()
130 |
131 | embd_i = embedding[i].squeeze(-1)
132 | embd_j = embedding[j].squeeze(-1)
133 |
134 | embd = (embd_i - embd_j) ** 2
135 | # alternative way... bad
136 | #embd_1 = (embd_i[..., :64] - embd_j[..., :64]) ** 2
137 | #embd_2 = (embd_i[..., 64:] - embd_j[..., 64:]) ** 4
138 | #embd = torch.cat([embd_1, embd_2], dim=-1)
139 |
140 | pred = self.embedding_decoder_mlp(embd)
141 | pred = pred.squeeze(-1)
142 |
143 | if only_embd:
144 | return embedding
145 | else:
146 | return pred
147 |
148 |
149 |
--------------------------------------------------------------------------------
/hgraph/modules/decide_edge_type.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | # helper functions, decide the type of a edge
5 |
6 |
7 | ###############################################################
8 | # For distance graph conv
9 | ###############################################################
10 |
11 | def decide_edge_type_distance(vec: torch.tensor,
12 | method="predefined",
13 | return_edge_length=True):
14 | """
15 | classify each vec into many categories.
16 | vec: N x 3
17 | ret: N
18 | """
19 | if method == "predefined":
20 | return decide_edge_type_predefined_distance(vec, return_edge_length=return_edge_length)
21 | else:
22 | raise NotImplementedError
23 |
24 |
25 | def decide_edge_type_predefined_distance(vec: torch.tensor, epsilon=0.00001, return_edge_length=True):
26 | """
27 | classify each vec into N categories, according to the length of the vcector.
28 | the last category is self-loop
29 | vec: N x 3
30 | ret: N
31 | """
32 | positive_x = torch.maximum(vec[..., 0], torch.zeros_like(vec[..., 0]))
33 | positive_y = torch.maximum(vec[..., 1], torch.zeros_like(vec[..., 0]))
34 | positive_z = torch.maximum(vec[..., 2], torch.zeros_like(vec[..., 0]))
35 | negative_x = - torch.minimum(vec[..., 0], torch.zeros_like(vec[..., 0]))
36 | negative_y = - torch.minimum(vec[..., 1], torch.zeros_like(vec[..., 0]))
37 | negative_z = - torch.minimum(vec[..., 2], torch.zeros_like(vec[..., 0]))
38 |
39 | ary = torch.stack([positive_x, positive_y, positive_z, negative_x, negative_y, negative_z])
40 | ary = ary.transpose(0, 1)
41 |
42 | device = vec.device
43 |
44 | edge_type = torch.ones([len(vec), 1]).to(device) * 999
45 | vec_length = torch.norm(ary, dim=1)
46 |
47 | # edge_length > eps --> type 0 edge
48 | # edge_length <= eps --> type 1 edge (self-loop)
49 |
50 | # print(thres_1)
51 | dist_threshold = [epsilon]
52 |
53 |
54 | for i in range(len(dist_threshold)-1, -1, -1):
55 | dist_mask = vec_length > dist_threshold[i]
56 | edge_type[dist_mask] = i
57 |
58 | # self-loop
59 | self_loops_mask = vec_length <= epsilon
60 | edge_type[self_loops_mask] = len(dist_threshold)
61 |
62 | # squeeze to 1d tensor
63 | edge_type = edge_type.squeeze(-1)
64 | edge_type = edge_type.long()
65 |
66 |
67 | # assertion, suppose there are N thresholds (epsilon included), make sure that
68 | # the type of different edges are at most N+1
69 | # breakpoint()
70 | assert edge_type.max() == len(dist_threshold) # there must be edges indexed N, since self-loop exists in all meshes
71 | assert edge_type.min() >= 0 # note edge_type indexed N-1 may not exists, since edges of that length may not exists in this mesh
72 |
73 | if return_edge_length == True:
74 | return edge_type, vec_length
75 | else:
76 | return edge_type
77 |
78 |
--------------------------------------------------------------------------------
/hgraph/modules/modules.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import math
4 | import torch.utils.checkpoint
5 | from typing import List, Optional
6 |
7 |
8 | from hgraph.hgraph import Data
9 | #from torch_geometric.nn import avg_pool
10 | from hgraph.hgraph import avg_pool
11 | from hgraph.hgraph import HGraph
12 |
13 |
14 | bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
15 | # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
16 |
17 |
18 | ###############################################################
19 | # Util funcs
20 | ###############################################################
21 |
22 | from .decide_edge_type import *
23 |
24 |
25 |
26 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
27 | r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
28 | library `pytorch_scatter`.
29 | '''
30 |
31 | if dim < 0:
32 | dim = other.dim() + dim
33 |
34 | if src.dim() == 1:
35 | for _ in range(0, dim):
36 | src = src.unsqueeze(0)
37 | for _ in range(src.dim(), other.dim()):
38 | src = src.unsqueeze(-1)
39 |
40 | src = src.expand_as(other)
41 | return src
42 |
43 |
44 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
45 | out: Optional[torch.Tensor] = None,
46 | dim_size: Optional[int] = None,) -> torch.Tensor:
47 | r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
48 | indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
49 | This is just a wrapper of :func:`torch.scatter` in a boardcasting fashion.
50 |
51 | Args:
52 | src (torch.Tensor): The source tensor.
53 | index (torch.Tensor): The indices of elements to scatter.
54 | dim (torch.Tensor): The axis along which to index, (default: :obj:`-1`).
55 | out (torch.Tensor or None): The destination tensor.
56 | dim_size (int or None): If :attr:`out` is not given, automatically create
57 | output with size :attr:`dim_size` at dimension :attr:`dim`. If
58 | :attr:`dim_size` is not given, a minimal sized output tensor according
59 | to :obj:`index.max() + 1` is returned.
60 | '''
61 |
62 | index = broadcast(index, src, dim)
63 |
64 | if out is None:
65 | size = list(src.size())
66 | if dim_size is not None:
67 | size[dim] = dim_size
68 | elif index.numel() == 0:
69 | size[dim] = 0
70 | else:
71 | size[dim] = int(index.max()) + 1
72 | out = torch.zeros(size, dtype=src.dtype, device=src.device)
73 |
74 | return out.scatter_add_(dim, index, src)
75 |
76 |
77 | ###############################################################
78 | # Basic Operators on graphs
79 | ###############################################################
80 |
81 | from torch_geometric.nn import GraphSAGE
82 |
83 | class GraphSAGEConv(torch.nn.Module):
84 | def __init__(self, in_channels: int, out_channels: int,
85 | ):
86 | super().__init__()
87 | self.conv = GraphSAGE(in_channels, out_channels, num_layers=1)
88 |
89 | def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
90 | graph = hgraph.treedict[hgraph.depth - depth]
91 | assert input_feature.shape[0] == graph.x.shape[0]
92 | out = self.conv(input_feature, graph.edge_index)
93 | return out
94 |
95 |
96 | from torch_geometric.nn import GATv2Conv
97 |
98 | class GraphAttentionConv(torch.nn.Module):
99 | # Graph Attention Network ()
100 | def __init__(self, in_channels: int, out_channels: int,
101 | ):
102 | super().__init__()
103 | self.conv = GATv2Conv(in_channels, out_channels, num_layers=1)
104 |
105 | def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
106 | graph = hgraph.treedict[hgraph.depth - depth]
107 | assert input_feature.shape[0] == graph.x.shape[0]
108 | out = self.conv(input_feature, graph.edge_index)
109 | return out
110 |
111 |
112 | ############################################################
113 | from torch.nn import Sequential as Seq, Linear, ReLU
114 | from torch_geometric.nn import MessagePassing, HeteroLinear
115 |
116 |
117 | ############################################################
118 |
119 | class MyConvOp(MessagePassing):
120 | # this implementation is from tutorial "message passing" of torch_geometric
121 | def __init__(self, in_channels, out_channels, include_distance):
122 | super().__init__(aggr='max') # "Max" aggregation.
123 | #self.mlp = Seq(Linear(2 * in_channels, out_channels),
124 | # ReLU(),
125 | # Linear(out_channels, out_channels))
126 | # self.self_loop = Linear(in_channels, out_channels)
127 | # self.neighbor_matrix_1 = Linear(in_channels, out_channels)
128 | # self.neighbor_matrix_2 = Linear(in_channels, out_channels)
129 |
130 | self.num_types = 2
131 | self.lin = HeteroLinear(in_channels, out_channels, num_types=self.num_types,)
132 | self.include_distance = include_distance
133 |
134 |
135 | def forward(self, x, edge_index, edge_attr=None):
136 | # x has shape [N, in_channels]
137 | # edge_index has shape [2, E]
138 |
139 | # edge_attr has shape [1, E]
140 | # referred to https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html#implementing-the-gcn-layer
141 |
142 | if edge_attr == None:
143 | # dynamically calculate edge_type
144 | raise NotImplementedError
145 | else:
146 | # use predefined edge_type
147 | if edge_attr.min() != 0:
148 | breakpoint()
149 | assert edge_attr.min() == 0
150 | assert edge_attr.max() == self.num_types - 1
151 |
152 |
153 | return self.propagate(edge_index, x=x, edge_attr=edge_attr)
154 |
155 | def message(self, x_i, x_j, edge_attr):
156 | # x_i has shape [E, in_channels]
157 | # x_j has shape [E, in_channels]
158 |
159 |
160 | # last 3 dim is the xyz position of vertices
161 | # compute the abs value of them
162 | if self.include_distance == True:
163 | abs_dist = torch.norm(x_i[..., -3:] - x_j[..., -3:], dim=1, keepdim=True) # shape: [E, 1]
164 | x_j = torch.cat([x_j, x_i[..., -3:] - x_j[..., -3:], abs_dist], dim=-1)
165 |
166 | out = self.lin(x_j, edge_attr)
167 | return out
168 |
169 |
170 | class MyConv(torch.nn.Module):
171 | # My Conv, consider neighbor features (relative pos and its abs)
172 | # and many seperated catagories
173 | def __init__(self, in_channels: int, out_channels: int,
174 | ):
175 | super().__init__()
176 | self.include_distance = True#default_settings.get_global_value("include_distance")
177 | if self.include_distance == False:
178 | self.conv = MyConvOp(in_channels, out_channels, include_distance=False)
179 | else:
180 | self.conv = MyConvOp(in_channels + 3 + 3 + 1, out_channels, include_distance=True)
181 |
182 | def forward(self, input_feature: torch.Tensor, hgraph: HGraph, depth: int):
183 | graph = hgraph.treedict[hgraph.depth - depth]
184 | assert input_feature.shape[0] == graph.x.shape[0]
185 | # concat input feature
186 | if self.include_distance == True:
187 | input_feature = torch.cat([input_feature, graph.x[..., :3]], dim=-1)
188 | # do the conv
189 | out = self.conv(input_feature, graph.edge_index, graph.edge_attr)
190 | return out
191 |
192 |
193 | ########################################################
194 |
195 |
196 |
197 |
198 | ###############################################################
199 | # Complex Components
200 | ###############################################################
201 |
202 | conv_type = "my"#default_settings.get_global_value("conv_type").lower()
203 | if conv_type == "sage":
204 | GraphConv = GraphSAGEConv
205 | elif conv_type == "my":
206 | GraphConv = MyConv
207 | else:
208 | raise NotImplementedError
209 |
210 | # group norm
211 | normalization = lambda x: torch.nn.GroupNorm(num_groups=4, num_channels=x, eps=bn_eps)
212 |
213 |
214 |
215 | class GraphConvBn(torch.nn.Module):
216 | def __init__(self, in_channels: int, out_channels: int,
217 | ):
218 | super().__init__()
219 | self.conv = GraphConv(in_channels, out_channels)
220 | self.bn = normalization(out_channels)
221 |
222 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int):
223 | out = self.conv(data, hgraph, depth)
224 | out = self.bn(out)
225 | return out
226 |
227 |
228 | class GraphConvBnRelu(torch.nn.Module):
229 | def __init__(self, in_channels: int, out_channels: int,
230 | ):
231 | super().__init__()
232 | self.conv = GraphConv(in_channels, out_channels)
233 | self.bn = normalization(out_channels)
234 | self.relu = torch.nn.ReLU() # inplace=True
235 |
236 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int):
237 | out = self.conv(data, hgraph, depth)
238 | out = self.bn(out)
239 | out = self.relu(out)
240 | return out
241 |
242 |
243 | class PoolingGraph(torch.nn.Module):
244 | def __init__(self):
245 | super().__init__()
246 | pass
247 |
248 | def forward(self, x: torch.Tensor, hgraph: HGraph, depth: int):
249 | cluster = hgraph.cluster[depth]
250 | out = avg_pool(cluster=cluster, data=Data(x=x, edge_index=torch.zeros([2,1]).long())) # fake edges XD
251 | return out.x
252 |
253 | class UnpoolingGraph(torch.nn.Module):
254 | def __init__(self):
255 | super().__init__()
256 | pass
257 |
258 | def forward(self, x: torch.Tensor, hgraph: HGraph, depth: int):
259 | assert depth != 0
260 | #
261 | feature_dim = x.shape[1]
262 | cluster = hgraph.cluster[depth][..., None].repeat(1, feature_dim).long()
263 | out = torch.gather(input=x, dim=0, index=cluster)
264 | return out
265 |
266 |
267 |
268 |
269 | class Conv1x1(torch.nn.Module):
270 | r''' Performs a convolution with kernel :obj:`(1,1,1)`.
271 |
272 | The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
273 | number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
274 | implemented with :class:`torch.nn.Linear`.
275 | '''
276 |
277 | def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
278 | super().__init__()
279 | self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
280 |
281 | def forward(self, data: torch.Tensor):
282 | r''''''
283 |
284 | return self.linear(data)
285 |
286 |
287 | class Conv1x1Bn(torch.nn.Module):
288 | r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
289 | '''
290 |
291 | def __init__(self, in_channels: int, out_channels: int):
292 | super().__init__()
293 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
294 | self.bn = normalization(out_channels)
295 |
296 | def forward(self, data: torch.Tensor):
297 | r''''''
298 |
299 | out = self.conv(data)
300 | out = self.bn(out)
301 | return out
302 |
303 |
304 | class Conv1x1BnRelu(torch.nn.Module):
305 | r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
306 | '''
307 |
308 | def __init__(self, in_channels: int, out_channels: int):
309 | super().__init__()
310 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
311 | self.bn = normalization(out_channels)
312 | self.relu = torch.nn.ReLU(inplace=True)
313 |
314 | def forward(self, data: torch.Tensor):
315 | r''''''
316 |
317 | out = self.conv(data)
318 | out = self.bn(out)
319 | out = self.relu(out)
320 | return out
321 |
322 |
323 | class FcBnRelu(torch.nn.Module):
324 | r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
325 | '''
326 |
327 | def __init__(self, in_channels: int, out_channels: int):
328 | super().__init__()
329 | self.flatten = torch.nn.Flatten(start_dim=1)
330 | self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
331 | self.bn = normalization(out_channels)
332 | self.relu = torch.nn.ReLU(inplace=True)
333 |
334 | def forward(self, data):
335 | r''''''
336 |
337 | out = self.flatten(data)
338 | out = self.fc(out)
339 | out = self.bn(out)
340 | out = self.relu(out)
341 | return out
--------------------------------------------------------------------------------
/hgraph/modules/resblocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.checkpoint
3 |
4 | from hgraph.hgraph import HGraph
5 | from hgraph.modules.modules import Conv1x1BnRelu, Conv1x1, Conv1x1Bn, \
6 | GraphConv, GraphConvBnRelu, GraphConvBn, FcBnRelu, \
7 | UnpoolingGraph, PoolingGraph
8 |
9 |
10 | class GraphResBlock(torch.nn.Module):
11 | def __init__(self, in_channels: int, out_channels: int,
12 | bottleneck: int=4):
13 | super().__init__()
14 | self.in_channels = in_channels
15 | self.out_channels = out_channels
16 | self.bottleneck = bottleneck
17 | channelb = int(out_channels / bottleneck)
18 |
19 | self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
20 | self.conv3x3 = GraphConvBnRelu(channelb, channelb)
21 | self.conv1x1b = Conv1x1Bn(channelb, out_channels)
22 |
23 | if self.in_channels != self.out_channels:
24 | self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
25 | self.relu = torch.nn.ReLU(inplace=True)
26 |
27 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int):
28 | conv1 = self.conv1x1a(data)
29 | conv2 = self.conv3x3(conv1, hgraph, depth)
30 | conv3 = self.conv1x1b(conv2)
31 | if self.in_channels != self.out_channels:
32 | data = self.conv1x1c(data)
33 | out = self.relu(conv3 + data)
34 | return out
35 |
36 |
37 |
38 |
39 | class GraphResBlock2(torch.nn.Module):
40 | def __init__(self, in_channels: int, out_channels: int,
41 | bottleneck: int=4):
42 | super().__init__()
43 | self.in_channels = in_channels
44 | self.out_channels = out_channels
45 | self.bottleneck = bottleneck
46 | channelb = int(out_channels / bottleneck)
47 |
48 | self.conv3x3a = GraphConvBnRelu(in_channels, channelb)
49 | self.conv3x3b = GraphConvBn(channelb, out_channels)
50 |
51 | if self.in_channels != self.out_channels:
52 | self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
53 | self.relu = torch.nn.ReLU(inplace=True)
54 |
55 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int):
56 | conv1 = self.conv3x3a(data, hgraph, depth)
57 | conv2 = self.conv3x3b(conv1, hgraph, depth)
58 | if self.in_channels != self.out_channels:
59 | data = self.conv1x1(data)
60 | out = self.relu(conv2 + data)
61 | return out
62 |
63 |
64 |
65 | class GraphResBlocks(torch.nn.Module):
66 | def __init__(self, in_channels, out_channels,
67 | resblk_num, bottleneck=4,
68 | resblk=GraphResBlock):
69 | super().__init__()
70 | self.resblk_num = resblk_num
71 | channels = [in_channels] + [out_channels] * resblk_num
72 |
73 | self.resblks = torch.nn.ModuleList(
74 | [resblk(channels[i], channels[i+1], bottleneck=bottleneck) for i in range(self.resblk_num)]
75 | )
76 |
77 | def forward(self, data: torch.Tensor, hgraph: HGraph, depth: int):
78 | for i in range(self.resblk_num):
79 | data = self.resblks[i](data, hgraph, depth)
80 |
81 | return data
82 |
83 |
--------------------------------------------------------------------------------
/img/bunny.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/bunny.jpg
--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/teaser.png
--------------------------------------------------------------------------------
/img/teaser.v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/img/teaser.v1.png
--------------------------------------------------------------------------------
/pretrained/ours00500.solver.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/pretrained/ours00500.solver.tar
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Learning the Geodesic Embedding with Graph Neural Networks
2 |
3 |
4 |
5 | [**Learning the Geodesic Embedding with Graph Neural Networks**](https://arxiv.org/abs/2309.05613)
6 | [Bo Pang](https://github.com/skinboC), [Zhongtian Zheng](https://github.com/zzttzz), Guoping Wang, and [Peng-Shuai Wang](https://wang-ps.github.io/)
7 | ACM Transactions on Graphics (SIGGRAPH Asia), 42(6), 2023
8 |
9 | 
10 |
11 | - [Learning the Geodesic Embedding with Graph Neural Networks](#learning-the-geodesic-embedding-with-graph-neural-networks)
12 | - [1. Environment](#1-environment)
13 | - [2. Prepare Data](#2-prepare-data)
14 | - [3. Train](#3-train)
15 | - [4. Test and Visualization](#4-test-and-visualization)
16 | - [5. Citation](#5-citation)
17 |
18 |
19 | ## 1. Environment
20 |
21 | First, please install pytorch that fits your cuda version.
22 |
23 | Then, install torch geometric:
24 |
25 | ```
26 | conda install pyg -c pyg
27 | ```
28 |
29 | Then install the packages required by this project:
30 |
31 | ```
32 | pip3 install -r requirements.txt
33 | ```
34 |
35 | ## 2. Prepare Data
36 |
37 | Before training, you have to generate training data, as described in 4.1 of the paper. Please note the model may not generalize well on shapes very different from the training data, as suggested in the paper.
38 |
39 | Suppose you have your meshes in `path/to/meshes`, we provide a script to generate training data from these meshes. Please open `utils/dataset_prepare/py_data_generator.py`, and change the following lines:
40 |
41 | ```python
42 | PATH_TO_MESH = "path/to/your/mesh/folder/"
43 | PATH_TO_OUTPUT_NPZ = "path/to/your/output/folder/"
44 | PATH_TO_OUTPUT_FILELIST = "path/to/your/another/output/folder"
45 | ```
46 |
47 | Then, please run the script:
48 |
49 | ```
50 | python utils/dataset_prepare/py_data_generator.py
51 | ```
52 |
53 | This will load meshes (.obj files) and generate the processed `.npz` files. You can change `a, b, c, d` in that file to adjust the property of training data.
54 |
55 | We will upload the processed data we used in our paper to Google Drive soon.
56 |
57 | There is also a mesh processing script in `utils/dataset_prepare/simplifiy.py`, which is a utility tool we used to process the meshes.
58 |
59 | ## 3. Train
60 |
61 | To train the network, you can type:
62 |
63 | ```shell
64 | python3 GnnDist.py --config configs/gnndist_test.yaml
65 | ```
66 |
67 | We train our model on Ubuntu 20.04 (4 Nvidia 3090 GPUs, 24GB VRAM) with batch size 10. If your GPU memory is not enough, please reduce the batch size in the config file.
68 |
69 | A checkpoint of our model is provided in `pretrained/ours00500.solver.tar`.
70 |
71 | ## 4. Test and Visualization
72 |
73 | 
74 |
75 | We provide a script to test the network and visualize the results. The following command tests our method on the specified mesh and will open a polyscope window for visualization. (If you are using ssh and cannot open the window, please delete polyscope-related code in `GeGnn_standalong.py`.)
76 |
77 | Feel free to change `--test_file` to test on other meshes and change `--start_pts` to test on different source points.
78 |
79 | ```shell
80 | python3 GeGnn_standalong.py --mode SSAD --test_file data/test_mesh/bunny.obj --ckpt_path pretrained/ours00500.solver.tar --start_pts 0 --output out/ssad_ours.npy
81 | ```
82 |
83 | This will open a polyscope window and show the results.
84 |
85 | ## 5. Citation
86 |
87 | If you find this project useful for your research, please kindly cite our paper:
88 |
89 | ```bibtex
90 | @article{pang2023gegnn,
91 | title={Learning the Geodesic Embedding with Graph Neural Networks},
92 | author={Pang, Bo and Zheng, Zhongtian and Wang, Guoping and Wang, Peng-Shuai},
93 | journal={ACM Transactions on Graphics (SIGGRAPH Asia)},
94 | year={2023}
95 |
96 | }
97 | ```
98 |
99 | If you have any questions, please feel free to contact us at bo98@stu.pku.edu.cn
100 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.6.1
2 | numpy==1.23.4
3 | PyMCubes==0.1.2
4 | scikit_learn==1.1.3
5 | tqdm==4.64.1
6 | websockets==10.3
7 | tensorboard
8 | pygeodesic
9 | trimesh
10 | yacs
11 | polyscope
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelligentGeometry/GeGnn/e9c1e79e521cb97690791af3340cae64899f9d54/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataset_prepare/py_data_generater.py:
--------------------------------------------------------------------------------
1 | #
2 | import trimesh
3 | import pygeodesic.geodesic as geodesic
4 | import numpy as np
5 | import os
6 | import multiprocessing as mp
7 | from threading import Thread
8 | from tqdm import tqdm
9 |
10 |
11 |
12 | PATH_TO_MESH = "./data/tiny/meshes/"
13 | PATH_TO_OUTPUT_NPZ = "./data/tiny/npz/"
14 | PATH_TO_OUTPUT_FILELIST = "./data/tiny/filelist/"
15 |
16 | TRAINING_SPLIT_RATIO = 0.8
17 |
18 |
19 | def visualize_ssad(vertices: np.ndarray, triangles: np.ndarray, source_index: int):
20 | # Initialise the PyGeodesicAlgorithmExact class instance
21 | geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
22 |
23 | # Define the source and target point ids with respect to the points array
24 | source_indices = np.array([source_index])
25 | target_indices = None
26 | distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
27 | return distances
28 |
29 |
30 | def visualize_two_pts(vertices: np.ndarray, triangles: np.ndarray, source_index: int, dest_index: int):
31 | # Initialise the PyGeodesicAlgorithmExact class instance
32 | geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
33 |
34 | # Define the source and target point ids with respect to the points array
35 | source_indices = np.array([source_index])
36 | target_indices = np.array([dest_index])
37 | distances, best_source = geoalg.geodesicDistance(source_indices, target_indices)
38 | return distances
39 |
40 |
41 | def data_prepare_ssad(object_file: str, output_path: str, source_index: int):
42 | vertices = []
43 | triangles = []
44 |
45 | with open(object_file, "r") as f:
46 | lines = f.readlines()
47 | for each in lines:
48 | if len(each) < 2:
49 | continue
50 | if each[0:2] == "v ":
51 | temp = each.split()
52 | vertices.append([float(temp[1]), float(temp[2]), float(temp[3])])
53 | if each[0:2] == "f ":
54 | temp = each.split()
55 | #
56 | temp[3] = temp[3].split("/")[0]
57 | temp[1] = temp[1].split("/")[0]
58 | temp[2] = temp[2].split("/")[0]
59 | triangles.append([int(temp[1]) - 1, int(temp[2]) - 1, int(temp[3]) - 1])
60 | vertices = np.array(vertices)
61 | triangles = np.array(triangles)
62 |
63 |
64 | # Initialise the PyGeodesicAlgorithmExact class instance
65 | geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
66 |
67 | # Define the source and target point ids with respect to the points array
68 | source_indices = np.array([source_index])
69 | target_indices = None
70 | distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
71 |
72 |
73 | def data_prepare_gen_dataset(object_file: str, output_path: str, num_sources, num_each_dest, tqdm_on=True):
74 | vertices = []
75 | triangles = []
76 |
77 | with open(object_file, "r") as f:
78 | lines = f.readlines()
79 | for each in lines:
80 | if len(each) < 2:
81 | continue
82 | if each[0:2] == "v ":
83 | temp = each.split()
84 | vertices.append([float(temp[1]), float(temp[2]), float(temp[3])])
85 | if each[0:2] == "f ":
86 | temp = each.split()
87 | #
88 | temp[3] = temp[3].split("/")[0]
89 | temp[1] = temp[1].split("/")[0]
90 | temp[2] = temp[2].split("/")[0]
91 | triangles.append([int(temp[1]) - 1, int(temp[2]) - 1, int(temp[3]) - 1])
92 | vertices = np.array(vertices)
93 | triangles = np.array(triangles)
94 |
95 |
96 | # Initialise the PyGeodesicAlgorithmExact class instance
97 | geoalg = geodesic.PyGeodesicAlgorithmExact(vertices, triangles)
98 |
99 | result = np.array([[0,0,0]])
100 |
101 | sources = np.random.randint(low=0, high=len(vertices), size=[num_sources])
102 |
103 | # iterate
104 | # only the process on process #0 will be displayed
105 | # this should not be problematic or confusing on most homogeneous CPUs
106 | it = tqdm(range(num_sources)) if tqdm_on else range(num_sources)
107 | for i in it:
108 | source_indices = np.array([sources[i]])
109 | target_indices = np.random.randint(low=0, high=len(vertices), size=[num_each_dest])
110 | if source_indices.max() >= len(vertices):
111 | print("!!!!!!!", source_indices.max(), len(vertices))
112 | if target_indices.max() >= len(vertices):
113 | print("!!!!!!!", target_indices.max(), len(vertices))
114 |
115 | distances, best_source = geoalg.geodesicDistances(source_indices, target_indices)
116 |
117 | a = source_indices.repeat([num_each_dest]).reshape([-1,1])
118 | b = target_indices.reshape([-1,1])
119 | c = distances.reshape([-1,1])
120 | new = np.concatenate([a, b, c], -1)
121 | result = np.concatenate([result, new])
122 |
123 | np.savetxt(output_path, result)
124 |
125 |
126 | #############################################
127 | def computation_thread(filename, object_name, a, b, c, d, idx=None):
128 | assert idx != None, "an idx has to be given"
129 | tqdm_on = False
130 | if idx == 0:
131 | tqdm_on = True
132 | print(filename, object_name)
133 | data_prepare_gen_dataset(filename, object_name + "_train_" + str(idx), a, b, tqdm_on=tqdm_on)
134 | # data_prepare_gen_dataset(filename, "utils/dataset_prepare/outputs/" + object_name + "_test_" + str(idx), c, d, tqdm_on=tqdm_on)
135 |
136 |
137 | ##############################################
138 |
139 |
140 |
141 |
142 | if __name__ == "__main__":
143 | '''
144 |
145 |
146 | a: on training set, how many sources to randomly sample.
147 | b: on training set, for each source, how many dest to randomly sample.
148 | c: on testing set, how many sources to randomly sample.
149 | d: on testing set, for each source, how many dest to randomly sample.
150 | threads: how many threads to use. 0 means all cores.
151 |
152 | '''
153 |
154 | #############################################################
155 | object_name = None
156 | a = 300
157 | b = 800
158 | c = 400
159 | d = 60
160 | file_size_threshold = 12_048_576 # a threshold to filter out large meshes
161 | threads = 1
162 | #############################################################
163 |
164 | assert threads >= 0 and type(threads) == int
165 | if threads == 0:
166 | threads = mp.cpu_count()
167 | print(f"Automatically utilize all CPU cores ({threads})")
168 | else:
169 | print(f"{threads} CPU cores are utilized!")
170 |
171 | # make dirs, if not exist
172 | if os.path.exists(PATH_TO_OUTPUT_NPZ) == False:
173 | os.mkdir(PATH_TO_OUTPUT_NPZ)
174 | if os.path.exists(PATH_TO_OUTPUT_FILELIST) == False:
175 | os.mkdir(PATH_TO_OUTPUT_FILELIST)
176 |
177 | all_files = []
178 | for mesh in os.listdir(PATH_TO_MESH):
179 | # check if the file is too large
180 | if os.path.getsize(PATH_TO_MESH + mesh) < file_size_threshold:
181 | all_files.append(PATH_TO_MESH + mesh)
182 |
183 |
184 |
185 | #all_files = os.listdir("./inputs")
186 |
187 | object_names = all_files
188 | for i in range(len(object_names)):
189 | if object_names[i][-4:] == ".obj":
190 | object_names[i] = object_names[i][:-4]
191 |
192 |
193 | print(f"Current dir: {os.getcwd()}, object to be processed: {len(object_names)}")
194 |
195 |
196 | # handle the case when the output file already exists
197 | for i in tqdm(range(len(object_names))):
198 | object_name = object_names[i]
199 | if object_name.split("/")[-1][0] == ".":
200 | continue # not an obj file
201 |
202 | filename_out = PATH_TO_OUTPUT_NPZ + object_name + ".npz"
203 | if os.path.exists(filename_out):
204 | continue
205 |
206 | filename = object_name + ".obj"
207 |
208 |
209 | train_data_filename_list = []
210 | test_data_filename_list = []
211 |
212 |
213 | pool = []
214 |
215 | for t in range(threads):
216 | task = mp.Process(target=computation_thread, args=(filename, object_name, a//threads,b,c//threads,d, t,))
217 | task.start()
218 | pool.append(task)
219 | for t, task in enumerate(pool):
220 | task.join()
221 | train_data_filename_list.append(object_name + "_train_" + str(t))
222 | #test_data_filename_list.append("./utils/dataset_prepare/outputs/" + object_name + "_test_" + str(t))
223 | #breakpoint()
224 | # 整合多线程的结果到一起
225 | #print(object_name)
226 | try:
227 | for i in range(len(train_data_filename_list)):
228 | # train data
229 | with open(object_name + "_train_" + str(i), "r") as f:
230 | data = f.read()
231 | with open(object_name + "_train", "a") as f:
232 | f.write(data)
233 | except:
234 | #print("Error on " + object_name + ", this is mostly due to non-manifold (failed to initialise the PyGeodesicAlgorithmExact class instance)")
235 | continue
236 |
237 |
238 | # 清理掉中间过程文件
239 | #print("qqqqqq")
240 | #breakpoint()
241 | for each in (train_data_filename_list + test_data_filename_list):
242 | os.remove(each)
243 |
244 | filename_in = object_name + ".obj"
245 | dist_in = object_name + "_train"
246 | filename_out = PATH_TO_OUTPUT_NPZ + object_name.split("/")[-1] + ".npz"
247 | try:
248 | # 由于下述问题,有时trimesh loader会返回错误的顶点数量(并在训练时导致数组越界)
249 | # https://github.com/mikedh/trimesh/issues/489
250 | # 参考sanity check部分,对于出现这种错误的mesh,我们采取最直接的解决方案:放弃该mesh
251 | # 当使用QEM方法简化网格的时候,似乎更可能出现这个问题,出现率估计约百分之一到千分之一。
252 | # 下面的mesh就是一个例子:(1.3 remeshing + 0.85 QEM)
253 | # /mnt/sdb/pangbo/gnn-dist/utils/dataset_prepare/simplified_obj/02828884/26f583c91e815e8fcb2a965e75be701c.obj
254 | mesh = trimesh.load_mesh(filename_in)
255 | dist = np.loadtxt(dist_in)
256 | except Exception:
257 | print(f"load {filename_in} or {dist_in} failed...")
258 | continue
259 |
260 | # delete the dist_in
261 | os.remove(dist_in)
262 |
263 | # 额外保存图的拓扑结构(edges)
264 | # trimesh 的 .edge_unique 可以得到该网格的所有边(不重复)
265 | aa = mesh.edges_unique
266 | # 我们需要双向边,所以还需一点点额外处理
267 | bb = np.concatenate([aa[:, 1:], aa[:, :1]], 1)
268 | cc = np.concatenate([aa, bb])
269 |
270 | # breakpoint()
271 | # sanity check
272 | vertices = mesh.vertices
273 | if dist.max() > 100000000:
274 | print("inf encountered!!")
275 | elif ((dist.astype(np.float32).max()) >= vertices.shape[0]):
276 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
277 | print(object_name, "encountered trimesh loading error!!")
278 | continue
279 |
280 | np.savez(filename_out,
281 | edges=cc,
282 | vertices=mesh.vertices.astype(np.float32),
283 | normals=mesh.vertex_normals.astype(np.float32),
284 | faces=mesh.faces.astype(np.float32),
285 | dist_val=dist[:, 2:].astype(np.float32),
286 | dist_idx=dist[:, :2].astype(np.uint16),
287 | )
288 |
289 |
290 |
291 |
292 |
293 |
294 | print("\nnpz data generation finished. Now generating filelist...\n")
295 | # 最后生成 filelist
296 | lines = []
297 | breakpoint()
298 | for each in tqdm(object_names):
299 | filename_out = PATH_TO_OUTPUT_NPZ + each.split("/")[-1] + ".npz"
300 | try:
301 | dist = np.load(filename_out)
302 | # sanity check
303 | if dist['dist_val'].max() != np.inf and dist['dist_val'].max() < 1000000000:
304 | lines.append(filename_out + "\n")
305 | else:
306 | print(f"{filename_out} not good, contains inf!")
307 | continue
308 | except Exception:
309 | print(f"load {filename_out} failed for unknown reason.")
310 | continue
311 |
312 | import random
313 | random.shuffle(lines)
314 | # split
315 | train_num = int(len(lines) * TRAINING_SPLIT_RATIO)
316 | test_num = len(lines) - train_num
317 | train_lines = lines[:train_num]
318 | test_lines = lines[train_num:]
319 |
320 | with open(PATH_TO_OUTPUT_FILELIST + 'filelist_train.txt', 'w') as f:
321 | f.writelines(train_lines)
322 |
323 | with open(PATH_TO_OUTPUT_FILELIST + 'filelist_test.txt', 'w') as f:
324 | f.writelines(test_lines)
--------------------------------------------------------------------------------
/utils/dataset_prepare/simplifiy.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import pymeshlab
5 | from tqdm import tqdm
6 | import multiprocessing as mp
7 | from threading import Thread
8 |
9 | def walk_dir(dir):
10 | full_path = []
11 | filename = []
12 | for root, dirs, files in os.walk(dir):
13 | for file in files:
14 | if file.endswith(".obj"):
15 | full_path.append(os.path.join(root, file))
16 | filename.append(file)
17 | return [full_path, filename]
18 |
19 |
20 | def simplify_mesh(input_mesh, output_mesh):
21 | if os.path.exists(input_mesh) == False:
22 | return
23 | if os.path.exists(output_mesh):
24 | return
25 |
26 | ms = pymeshlab.MeshSet()
27 | ms.load_new_mesh(input_mesh)
28 | m = ms.current_mesh()
29 | num_v = m.vertex_matrix().shape[0]
30 | num_f = m.face_matrix().shape[0]
31 |
32 | #
33 | simplify_method = "combined"
34 |
35 | if simplify_method == "QEM":
36 | # not recommended. QEM may produce very strange topology/degenerate triangles, which is not good for training
37 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=2000+num_f//4)
38 | elif simplify_method == "remesh":
39 | # the distribution of triangles is uniform. Not so good since we want to make the network "exposed" to more complex triangulations
40 | ms.meshing_isotropic_explicit_remeshing(iterations=6, targetlen=pymeshlab.Percentage(1.5))
41 | elif simplify_method == "combined":
42 | # used in our paper
43 | ms.meshing_isotropic_explicit_remeshing(iterations=6, targetlen=pymeshlab.Percentage(1.35))
44 | ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(num_f*0.85))
45 | else:
46 | raise NotImplementedError
47 |
48 | ms.save_current_mesh(output_mesh)
49 | # try:
50 | # pass
51 | #
52 | # except Exception:
53 | # print(f"Mesh {input_mesh} failed")
54 |
55 |
56 |
57 | if __name__ == "__main__":
58 |
59 | threads = 4
60 | dir = "path/to/your/dataset"
61 | output_dir = "path/to/your/output/directory"
62 |
63 |
64 | [full_path, filename] = walk_dir(dir)
65 |
66 | try:
67 | os.mkdir(output_dir)
68 | except:
69 | pass
70 |
71 | for i in tqdm(range(len(full_path))):
72 | r = i % threads
73 | pool = []
74 |
75 | input_mesh = full_path[i]
76 | output_mesh = output_dir+filename[i]
77 | task = mp.Process(target=simplify_mesh, args=(input_mesh, output_mesh))
78 | task.start()
79 | pool.append(task)
80 |
81 | if r == 0:
82 | for t, task in enumerate(pool):
83 | task.join()
84 |
85 |
86 |
--------------------------------------------------------------------------------
/utils/ocnn/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from . import octree
9 | from . import nn
10 | from . import modules
11 | from . import models
12 | from . import dataset
13 | from . import utils
14 |
15 |
16 | __version__ = '2.1.8'
17 |
18 | __all__ = [
19 | 'octree',
20 | 'nn',
21 | 'modules',
22 | 'models',
23 | 'dataset',
24 | 'utils'
25 | ]
26 |
--------------------------------------------------------------------------------
/utils/ocnn/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | from utils import ocnn
11 | from utils.ocnn.octree import Octree, Points
12 |
13 |
14 | __all__ = ['Transform', 'CollateBatch']
15 | classes = __all__
16 |
17 |
18 | class Transform:
19 | r''' A boilerplate class which transforms an input data for :obj:`ocnn`.
20 | The input data is first converted to :class:`Points`, then randomly transformed
21 | (if enabled), and converted to an :class:`Octree`.
22 |
23 | Args:
24 | depth (int): The octree depth.
25 | full_depth (int): The octree layers with a depth small than
26 | :attr:`full_depth` are forced to be full.
27 | distort (bool): If true, performs the data augmentation.
28 | angle (list): A list of 3 float values to generate random rotation angles.
29 | interval (list): A list of 3 float values to represent the interval of
30 | rotation angles.
31 | scale (float): The maximum relative scale factor.
32 | uniform (bool): If true, performs uniform scaling.
33 | jittor (float): The maximum jitter values.
34 | orient_normal (str): Orient point normals along the specified axis, which is
35 | useful when normals are not oriented.
36 | '''
37 |
38 | def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
39 | interval: list, scale: float, uniform: bool, jitter: float,
40 | orient_normal: str = '', **kwargs):
41 | super().__init__()
42 |
43 | # for octree building
44 | self.depth = depth
45 | self.full_depth = full_depth
46 |
47 | # for data augmentation
48 | self.distort = distort
49 | self.angle = angle
50 | self.interval = interval
51 | self.scale = scale
52 | self.uniform = uniform
53 | self.jitter = jitter
54 |
55 | # for other transformations
56 | self.orient_normal = orient_normal
57 |
58 | def __call__(self, sample: dict, idx: int):
59 | r''''''
60 |
61 | points = self.preprocess(sample, idx)
62 | output = self.transform(points, idx)
63 | output['octree'] = self.points2octree(output['points'])
64 | return output
65 |
66 | def preprocess(self, sample: dict, idx: int):
67 | r''' Transforms :attr:`sample` to :class:`Points` and performs some specific
68 | transformations, like normalization.
69 | '''
70 |
71 | xyz = torch.from_numpy(sample['points'])
72 | normals = torch.from_numpy(sample['normals'])
73 | points = Points(xyz, normals)
74 | return points
75 |
76 | def transform(self, points: Points, idx: int):
77 | r''' Applies the general transformations provided by :obj:`ocnn`.
78 | '''
79 |
80 | # The augmentations including rotation, scaling, and jittering.
81 | if self.distort:
82 | rng_angle, rng_scale, rng_jitter = self.rnd_parameters()
83 | points.rotate(rng_angle)
84 | points.translate(rng_jitter)
85 | points.scale(rng_scale)
86 |
87 | if self.orient_normal:
88 | points.orient_normal(self.orient_normal)
89 |
90 | # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
91 | inbox_mask = points.clip(min=-1, max=1)
92 | return {'points': points, 'inbox_mask': inbox_mask}
93 |
94 | def points2octree(self, points: Points):
95 | r''' Converts the input :attr:`points` to an octree.
96 | '''
97 |
98 | octree = Octree(self.depth, self.full_depth)
99 | octree.build_octree(points)
100 | return octree
101 |
102 | def rnd_parameters(self):
103 | r''' Generates random parameters for data augmentation.
104 | '''
105 |
106 | rnd_angle = [None] * 3
107 | for i in range(3):
108 | rot_num = self.angle[i] // self.interval[i]
109 | rnd = torch.randint(low=-rot_num, high=rot_num+1, size=(1,))
110 | rnd_angle[i] = rnd * self.interval[i] * (3.14159265 / 180.0)
111 | rnd_angle = torch.cat(rnd_angle)
112 |
113 | rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
114 | if self.uniform:
115 | rnd_scale[1] = rnd_scale[0]
116 | rnd_scale[2] = rnd_scale[0]
117 |
118 | rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
119 | return rnd_angle, rnd_scale, rnd_jitter
120 |
121 |
122 | class CollateBatch:
123 | r''' Merge a list of octrees and points into a batch.
124 | '''
125 |
126 | def __init__(self, merge_points: bool = False):
127 | self.merge_points = merge_points
128 |
129 | def __call__(self, batch: list):
130 | assert type(batch) == list
131 |
132 | outputs = {}
133 | for key in batch[0].keys():
134 | outputs[key] = [b[key] for b in batch]
135 |
136 | # Merge a batch of octrees into one super octree
137 | if 'octree' in key:
138 | octree = ocnn.octree.merge_octrees(outputs[key])
139 | # NOTE: remember to construct the neighbor indices
140 | octree.construct_all_neigh()
141 | outputs[key] = octree
142 |
143 | # Merge a batch of points
144 | if 'points' in key and self.merge_points:
145 | outputs[key] = ocnn.octree.merge_points(outputs[key])
146 |
147 | # Convert the labels to a Tensor
148 | if 'label' in key:
149 | outputs['label'] = torch.tensor(outputs[key])
150 |
151 | return outputs
152 |
--------------------------------------------------------------------------------
/utils/ocnn/models/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .lenet import LeNet
9 | from .resnet import ResNet
10 | from .segnet import SegNet
11 | from .octree_unet import OctreeUnet
12 | from .hrnet import HRNet
13 | from .autoencoder import AutoEncoder
14 |
15 | __all__ = [
16 | 'LeNet',
17 | 'ResNet',
18 | 'SegNet',
19 | 'UNet',
20 | 'HRNet',
21 | 'AutoEncoder',
22 | ]
23 |
24 | classes = __all__
25 |
--------------------------------------------------------------------------------
/utils/ocnn/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 |
11 | from utils import ocnn
12 | from utils.ocnn.octree import Octree
13 |
14 |
15 | class AutoEncoder(torch.nn.Module):
16 | r''' Octree-based AutoEncoder for shape encoding and decoding.
17 |
18 | Args:
19 | channel_in (int): The channel of the input signal.
20 | channel_out (int): The channel of the output signal.
21 | depth (int): The depth of the octree.
22 | full_depth (int): The full depth of the octree.
23 | feature (str): The feature type of the input signal. For details of this
24 | argument, please refer to :class:`ocnn.modules.InputFeature`.
25 | '''
26 |
27 | def __init__(self, channel_in: int, channel_out: int, depth: int,
28 | full_depth: int = 2, feature: str = 'ND'):
29 | super().__init__()
30 | self.channel_in = channel_in
31 | self.channel_out = channel_out
32 | self.depth = depth
33 | self.full_depth = full_depth
34 | self.feature = feature
35 | self.resblk_num = 2
36 | self.shape_code_channel = 128
37 | self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]
38 |
39 | # encoder
40 | self.conv1 = ocnn.modules.OctreeConvBnRelu(
41 | channel_in, self.channels[depth], nempty=False)
42 | self.encoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
43 | self.channels[d], self.channels[d], self.resblk_num, nempty=False)
44 | for d in range(depth, full_depth-1, -1)])
45 | self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
46 | self.channels[d], self.channels[d-1], kernel_size=[2], stride=2,
47 | nempty=False) for d in range(depth, full_depth, -1)])
48 | self.proj = torch.nn.Linear(
49 | self.channels[full_depth], self.shape_code_channel, bias=True)
50 |
51 | # decoder
52 | self.channels[full_depth] = self.shape_code_channel # update `channels`
53 | self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
54 | self.channels[d-1], self.channels[d], kernel_size=[2], stride=2,
55 | nempty=False) for d in range(full_depth+1, depth+1)])
56 | self.decoder_blks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
57 | self.channels[d], self.channels[d], self.resblk_num, nempty=False)
58 | for d in range(full_depth, depth+1)])
59 |
60 | # header
61 | self.predict = torch.nn.ModuleList([self._make_predict_module(
62 | self.channels[d], 2) for d in range(full_depth, depth + 1)])
63 | self.header = self._make_predict_module(self.channels[depth], channel_out)
64 |
65 | def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
66 | return torch.nn.Sequential(
67 | ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
68 | ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))
69 |
70 | def get_input_feature(self, octree: Octree):
71 | r''' Get the input feature from the input `octree`.
72 | '''
73 |
74 | octree_feature = ocnn.modules.InputFeature(self.feature, nempty=False)
75 | out = octree_feature(octree)
76 | assert out.size(1) == self.channel_in
77 | return out
78 |
79 | def ae_encoder(self, octree: Octree):
80 | r''' The encoder network of the AutoEncoder.
81 | '''
82 |
83 | convs = dict()
84 | depth, full_depth = self.depth, self.full_depth
85 | data = self.get_input_feature(octree)
86 | convs[depth] = self.conv1(data, octree, depth)
87 | for i, d in enumerate(range(depth, full_depth-1, -1)):
88 | convs[d] = self.encoder_blks[i](convs[d], octree, d)
89 | if d > full_depth:
90 | convs[d-1] = self.downsample[i](convs[d], octree, d)
91 |
92 | # NOTE: here tanh is used to constrain the shape code in [-1, 1]
93 | shape_code = self.proj(convs[full_depth]).tanh()
94 | return shape_code
95 |
96 | def ae_decoder(self, shape_code: torch.Tensor, octree: Octree,
97 | update_octree: bool = False):
98 | r''' The decoder network of the AutoEncoder.
99 | '''
100 |
101 | logits = dict()
102 | deconv = shape_code
103 | depth, full_depth = self.depth, self.full_depth
104 | for i, d in enumerate(range(full_depth, depth+1)):
105 | if d > full_depth:
106 | deconv = self.upsample[i-1](deconv, octree, d-1)
107 | deconv = self.decoder_blks[i](deconv, octree, d)
108 |
109 | # predict the splitting label
110 | logit = self.predict[i](deconv)
111 | logits[d] = logit
112 |
113 | # update the octree according to predicted labels
114 | if update_octree:
115 | split = logit.argmax(1).int()
116 | octree.octree_split(split, d)
117 | if d < depth:
118 | octree.octree_grow(d + 1)
119 |
120 | # predict the signal
121 | if d == depth:
122 | signal = self.header(deconv)
123 | signal = torch.tanh(signal)
124 | signal = ocnn.nn.octree_depad(signal, octree, depth)
125 | if update_octree:
126 | octree.features[depth] = signal
127 |
128 | return {'logits': logits, 'signal': signal, 'octree_out': octree}
129 |
130 | def decode_code(self, shape_code: torch.Tensor):
131 | r''' Decodes the shape code to an output octree.
132 |
133 | Args:
134 | shape_code (torch.Tensor): The shape code for decoding.
135 | '''
136 |
137 | octree_out = self.init_octree(shape_code)
138 | out = self.ae_decoder(shape_code, octree_out, update_octree=True)
139 | return out
140 |
141 | def init_octree(self, shape_code: torch.Tensor):
142 | r''' Initialize a full octree for decoding.
143 |
144 | Args:
145 | shape_code (torch.Tensor): The shape code for decoding, used to getting
146 | the `batch_size` and `device` to initialize the output octree.
147 | '''
148 |
149 | device = shape_code.device
150 | node_num = 2 ** (3 * self.full_depth)
151 | batch_size = shape_code.size(0) // node_num
152 | octree = Octree(self.depth, self.full_depth, batch_size, device)
153 | for d in range(self.full_depth+1):
154 | octree.octree_grow_full(depth=d)
155 | return octree
156 |
157 | def forward(self, octree: Octree, update_octree: bool):
158 | r''''''
159 |
160 | shape_code = self.ae_encoder(octree)
161 | if update_octree:
162 | octree = self.init_octree(shape_code)
163 | out = self.ae_decoder(shape_code, octree, update_octree)
164 | return out
165 |
--------------------------------------------------------------------------------
/utils/ocnn/models/hrnet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import List
10 |
11 | from utils import ocnn
12 | from utils.ocnn.octree import Octree
13 |
14 |
15 | class Branches(torch.nn.Module):
16 |
17 | def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
18 | super().__init__()
19 | self.channels = channels
20 | self.resblk_num = resblk_num
21 | bottlenecks = [4 if c < 256 else 8 for c in channels] # to save parameters
22 | self.resblocks = torch.nn.ModuleList([
23 | ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
24 | for ch, bnk in zip(channels, bottlenecks)])
25 |
26 | def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
27 | num = len(self.channels)
28 | torch._assert(len(datas) == num, 'Error')
29 |
30 | out = [None] * num
31 | for i in range(num):
32 | depth_i = depth - i
33 | out[i] = self.resblocks[i](datas[i], octree, depth_i)
34 | return out
35 |
36 |
37 | class TransFunc(torch.nn.Module):
38 |
39 | def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
40 | super().__init__()
41 | self.in_channels = in_channels
42 | self.out_channels = out_channels
43 | self.nempty = nempty
44 | if in_channels != out_channels:
45 | self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)
46 |
47 | def forward(self, data: torch.Tensor, octree: Octree,
48 | in_depth: int, out_depth: int):
49 | out = data
50 | if in_depth > out_depth:
51 | for d in range(in_depth, out_depth, -1):
52 | out = ocnn.nn.octree_max_pool(out, octree, d, self.nempty)
53 | if self.in_channels != self.out_channels:
54 | out = self.conv1x1(out)
55 |
56 | if in_depth < out_depth:
57 | if self.in_channels != self.out_channels:
58 | out = self.conv1x1(out)
59 | for d in range(in_depth, out_depth, 1):
60 | out = ocnn.nn.octree_upsample(out, octree, d, self.nempty)
61 | return out
62 |
63 |
64 | class Transitions(torch.nn.Module):
65 |
66 | def __init__(self, channels: List[int], nempty: bool = False):
67 | super().__init__()
68 | self.channels = channels
69 | self.nempty = nempty
70 |
71 | num = len(self.channels)
72 | self.trans_func = torch.nn.ModuleList()
73 | for i in range(num - 1):
74 | for j in range(num):
75 | self.trans_func.append(TransFunc(channels[i], channels[j], nempty))
76 |
77 | def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
78 | num = len(self.channels)
79 | features = [[None] * (num - 1) for _ in range(num)]
80 | for i in range(num - 1):
81 | for j in range(num):
82 | k = i * num + j
83 | in_depth = depth - i
84 | out_depth = depth - j
85 | features[j][i] = self.trans_func[k](
86 | data[i], octree, in_depth, out_depth)
87 |
88 | out = [None] * num
89 | for j in range(num):
90 | # In the original tensorflow implmentation, the relu is added here,
91 | # instead of Line 77
92 | out[j] = torch.stack(features[j], dim=0).sum(dim=0)
93 | return out
94 |
95 |
96 | class FrontLayer(torch.nn.Module):
97 |
98 | def __init__(self, channels: List[int], nempty: bool = False):
99 | super().__init__()
100 | self.channels = channels
101 | self.num = len(channels) - 1
102 | self.nempty = nempty
103 |
104 | self.conv = torch.nn.ModuleList([
105 | ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty)
106 | for i in range(self.num)])
107 | self.maxpool = torch.nn.ModuleList([
108 | ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])
109 |
110 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
111 | out = data
112 | for i in range(self.num - 1):
113 | depth_i = depth - i
114 | out = self.conv[i](out, octree, depth_i)
115 | out = self.maxpool[i](out, octree, depth_i)
116 | out = self.conv[-1](out, octree, depth - self.num + 1)
117 | return out
118 |
119 |
120 | class ClsHeader(torch.nn.Module):
121 |
122 | def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
123 | super().__init__()
124 | self.channels = channels
125 | self.out_channels = out_channels
126 | self.nempty = nempty
127 |
128 | in_channels = int(torch.Tensor(channels).sum())
129 | self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
130 | self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
131 | self.header = torch.nn.Sequential(
132 | torch.nn.Flatten(start_dim=1),
133 | torch.nn.Linear(1024, out_channels, bias=True))
134 | # self.header = torch.nn.Sequential(
135 | # ocnn.modules.FcBnRelu(512, 256),
136 | # torch.nn.Dropout(p=0.5),
137 | # torch.nn.Linear(256, out_channels))
138 |
139 | def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
140 | full_depth = 2
141 | num = len(data)
142 | for i in range(num):
143 | depth_i = depth - i
144 | for d in range(depth_i, full_depth, -1):
145 | data[i] = ocnn.nn.octree_max_pool(data[i], octree, d, self.nempty)
146 |
147 | out = torch.cat(data, dim=1)
148 | out = self.conv1x1(out)
149 | out = self.global_pool(out, octree, full_depth)
150 | logit = self.header(out)
151 | return logit
152 |
153 |
154 | class HRNet(torch.nn.Module):
155 | r''' Octree-based HRNet for classification and segmentation. '''
156 |
157 | def __init__(self, in_channels: int, out_channels: int, stages: int = 3,
158 | interp: str = 'linear', nempty: bool = False):
159 | super().__init__()
160 | self.in_channels = in_channels
161 | self.out_channels = out_channels
162 | self.interp = interp
163 | self.nempty = nempty
164 | self.stages = stages
165 |
166 | self.resblk_num = 3
167 | self.channels = [128, 256, 512, 512]
168 |
169 | self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty)
170 | self.branches = torch.nn.ModuleList([
171 | Branches(self.channels[:i+1], self.resblk_num, nempty)
172 | for i in range(stages)])
173 | self.transitions = torch.nn.ModuleList([
174 | Transitions(self.channels[:i+2], nempty)
175 | for i in range(stages-1)])
176 |
177 | self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty)
178 |
179 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
180 | r''''''
181 | convs = [self.front(data, octree, depth)]
182 | depth = depth - 1 # the data is downsampled in `front`
183 | for i in range(self.stages):
184 | convs = self.branches[i](convs, octree, depth)
185 | if i < self.stages - 1:
186 | convs = self.transitions[i](convs, octree, depth)
187 |
188 | logits = self.cls_header(convs, octree, depth)
189 |
190 | return logits
191 |
--------------------------------------------------------------------------------
/utils/ocnn/models/lenet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from utils import ocnn
10 | from utils.ocnn.octree import Octree
11 |
12 |
13 | class LeNet(torch.nn.Module):
14 | r''' Octree-based LeNet for classification.
15 | '''
16 |
17 | def __init__(self, in_channels: int, out_channels: int, stages: int,
18 | nempty: bool = False):
19 | super().__init__()
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.stages = stages
23 | self.nempty = nempty
24 | channels = [in_channels] + [2 ** max(i+7-stages, 2) for i in range(stages)]
25 |
26 | self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
27 | channels[i], channels[i+1], nempty=nempty) for i in range(stages)])
28 | self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
29 | nempty) for i in range(stages)])
30 | self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)
31 | self.header = torch.nn.Sequential(
32 | torch.nn.Dropout(p=0.5), # drop1
33 | ocnn.modules.FcBnRelu(64 * 64, 128), # fc1
34 | torch.nn.Dropout(p=0.5), # drop2
35 | torch.nn.Linear(128, out_channels)) # fc2
36 |
37 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
38 | r''''''
39 |
40 | for i in range(self.stages):
41 | d = depth - i
42 | data = self.convs[i](data, octree, d)
43 | data = self.pools[i](data, octree, d)
44 | data = self.octree2voxel(data, octree, depth-self.stages)
45 | data = self.header(data)
46 | return data
47 |
--------------------------------------------------------------------------------
/utils/ocnn/models/octree_unet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 | from typing import Dict
11 |
12 | from utils import ocnn
13 | from utils.ocnn.octree import Octree
14 |
15 |
16 | class OctreeUnet(torch.nn.Module):
17 | r''' Octree-based UNet for segmentation.
18 | '''
19 |
20 | def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
21 | nempty: bool = False, **kwargs):
22 | super(OctreeUnet, self).__init__()
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.nempty = nempty
26 | self.config_network()
27 | self.encoder_stages = len(self.encoder_blocks)
28 | self.decoder_stages = len(self.decoder_blocks)
29 |
30 | # encoder
31 | self.conv1 = ocnn.modules.OctreeConvBnRelu(
32 | in_channels, self.encoder_channel[0], nempty=nempty)
33 | self.downsample = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
34 | self.encoder_channel[i], self.encoder_channel[i+1], kernel_size=[2],
35 | stride=2, nempty=nempty) for i in range(self.encoder_stages)])
36 | self.encoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
37 | self.encoder_channel[i+1], self.encoder_channel[i + 1],
38 | self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
39 | for i in range(self.encoder_stages)])
40 |
41 | # decoder
42 | channel = [self.decoder_channel[i+1] + self.encoder_channel[-i-2]
43 | for i in range(self.decoder_stages)]
44 | self.upsample = torch.nn.ModuleList([ocnn.modules.OctreeDeconvBnRelu(
45 | self.decoder_channel[i], self.decoder_channel[i+1], kernel_size=[2],
46 | stride=2, nempty=nempty) for i in range(self.decoder_stages)])
47 | self.decoder = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
48 | channel[i], self.decoder_channel[i+1],
49 | self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
50 | for i in range(self.decoder_stages)])
51 |
52 | # header
53 | # channel = self.decoder_channel[self.decoder_stages]
54 | self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
55 | self.header = torch.nn.Sequential(
56 | ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], 64),
57 | ocnn.modules.Conv1x1(64, self.out_channels, use_bias=True))
58 |
59 | def config_network(self):
60 | r''' Configure the network channels and Resblock numbers.
61 | '''
62 |
63 | self.encoder_channel = [32, 32, 64, 128, 256]
64 | self.decoder_channel = [256, 256, 128, 96, 96]
65 | self.encoder_blocks = [2, 3, 4, 6]
66 | self.decoder_blocks = [2, 2, 2, 2]
67 | self.bottleneck = 1
68 | self.resblk = ocnn.modules.OctreeResBlock2
69 |
70 | def unet_encoder(self, data: torch.Tensor, octree: Octree, depth: int):
71 | r''' The encoder of the U-Net.
72 | '''
73 |
74 | convd = dict()
75 | convd[depth] = self.conv1(data, octree, depth)
76 | for i in range(self.encoder_stages):
77 | d = depth - i
78 | conv = self.downsample[i](convd[d], octree, d)
79 | convd[d-1] = self.encoder[i](conv, octree, d-1)
80 | return convd
81 |
82 | def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree, depth: int):
83 | r''' The decoder of the U-Net.
84 | '''
85 |
86 | deconv = convd[depth]
87 | for i in range(self.decoder_stages):
88 | d = depth + i
89 | deconv = self.upsample[i](deconv, octree, d)
90 | deconv = torch.cat([convd[d+1], deconv], dim=1) # skip connections
91 | deconv = self.decoder[i](deconv, octree, d+1)
92 | return deconv
93 |
94 | def forward(self, data: torch.Tensor, octree: Octree, depth: int,
95 | query_pts: torch.Tensor):
96 | r''''''
97 |
98 | convd = self.unet_encoder(data, octree, depth)
99 | deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)
100 |
101 | interp_depth = depth - self.encoder_stages + self.decoder_stages
102 | feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
103 | logits = self.header(feature)
104 | return logits
105 |
--------------------------------------------------------------------------------
/utils/ocnn/models/resnet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from utils import ocnn
10 | from utils.ocnn.octree import Octree
11 |
12 |
13 | class ResNet(torch.nn.Module):
14 | r''' Octree-based ResNet for classification.
15 | '''
16 |
17 | def __init__(self, in_channels: int, out_channels: int, resblock_num: int,
18 | stages: int, nempty: bool = False):
19 | super().__init__()
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.resblk_num = resblock_num
23 | self.stages = stages
24 | self.nempty = nempty
25 | channels = [2 ** max(i+9-stages, 2) for i in range(stages)]
26 |
27 | self.conv1 = ocnn.modules.OctreeConvBnRelu(
28 | in_channels, channels[0], nempty=nempty)
29 | self.pool1 = ocnn.nn.OctreeMaxPool(nempty)
30 | self.resblocks = torch.nn.ModuleList([ocnn.modules.OctreeResBlocks(
31 | channels[i], channels[i+1], resblock_num, nempty=nempty)
32 | for i in range(stages-1)])
33 | self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
34 | nempty) for i in range(stages-1)])
35 | self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
36 | # self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
37 | self.header = torch.nn.Sequential(
38 | ocnn.modules.FcBnRelu(channels[-1], 512),
39 | torch.nn.Dropout(p=0.5),
40 | torch.nn.Linear(512, out_channels))
41 |
42 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
43 | r''''''
44 |
45 | data = self.conv1(data, octree, depth)
46 | data = self.pool1(data, octree, depth)
47 | for i in range(self.stages-1):
48 | d = depth - i - 1
49 | data = self.resblocks[i](data, octree, d)
50 | data = self.pools[i](data, octree, d)
51 | data = self.global_pool(data, octree, depth-self.stages)
52 | data = self.header(data)
53 | return data
54 |
--------------------------------------------------------------------------------
/utils/ocnn/models/segnet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from utils import ocnn
10 | from utils.ocnn.octree import Octree
11 |
12 |
13 | class SegNet(torch.nn.Module):
14 | r''' Octree-based SegNet for segmentation.
15 | '''
16 |
17 | def __init__(self, in_channels: int, out_channels: int, stages: int,
18 | interp: str = 'linear', nempty: bool = False, **kwargs):
19 | super().__init__()
20 | self.in_channels = in_channels
21 | self.out_channels = out_channels
22 | self.stages = stages
23 | self.nempty = nempty
24 | return_indices = True
25 |
26 | channels_stages = [2 ** max(i+8-stages, 2) for i in range(stages)]
27 | channels = [in_channels] + channels_stages
28 | self.convs = torch.nn.ModuleList(
29 | [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty)
30 | for i in range(stages)])
31 | self.pools = torch.nn.ModuleList(
32 | [ocnn.nn.OctreeMaxPool(nempty, return_indices) for i in range(stages)])
33 |
34 | self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
35 |
36 | channels = channels_stages[::-1] + [channels_stages[0]]
37 | self.deconvs = torch.nn.ModuleList(
38 | [ocnn.modules.OctreeConvBnRelu(channels[i], channels[i + 1], nempty=nempty)
39 | for i in range(0, stages)])
40 | self.unpools = torch.nn.ModuleList(
41 | [ocnn.nn.OctreeMaxUnpool(nempty) for i in range(stages)])
42 |
43 | self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
44 | self.header = torch.nn.Sequential(
45 | ocnn.modules.Conv1x1BnRelu(channels[-1], 64),
46 | ocnn.modules.Conv1x1(64, out_channels, use_bias=True))
47 |
48 | def forward(self, data: torch.Tensor, octree: Octree, depth: int,
49 | query_pts: torch.Tensor):
50 | r''''''
51 |
52 | # encoder
53 | indices = dict()
54 | for i in range(self.stages):
55 | d = depth - i
56 | data = self.convs[i](data, octree, d)
57 | data, indices[d] = self.pools[i](data, octree, d)
58 |
59 | # bottleneck
60 | data = self.bottleneck(data, octree, depth-self.stages)
61 |
62 | # decoder
63 | for i in range(self.stages):
64 | d = depth - self.stages + i
65 | data = self.unpools[i](data, indices[d + 1], octree, d)
66 | data = self.deconvs[i](data, octree, d + 1)
67 |
68 | # header
69 | feature = self.octree_interp(data, octree, depth, query_pts)
70 | logits = self.header(feature)
71 |
72 | return logits
73 |
--------------------------------------------------------------------------------
/utils/ocnn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .modules import (InputFeature,
9 | OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10 | Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11 | from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
12 |
13 | __all__ = [
14 | 'InputFeature',
15 | 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16 | 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17 | 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
18 | ]
19 |
20 | classes = __all__
21 |
--------------------------------------------------------------------------------
/utils/ocnn/modules/modules.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from typing import List
11 |
12 | from utils import ocnn
13 | from utils.ocnn.nn import OctreeConv, OctreeDeconv
14 | from utils.ocnn.octree import Octree
15 |
16 |
17 | bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
18 | # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
19 |
20 |
21 | def ckpt_conv_wrapper(conv_op, data, octree):
22 | # The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
23 | # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
24 | dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
25 |
26 | def conv_wrapper(data, octree, dummy_tensor):
27 | return conv_op(data, octree)
28 |
29 | return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
30 |
31 |
32 | class OctreeConvBn(torch.nn.Module):
33 | r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
34 |
35 | Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
36 | '''
37 |
38 | def __init__(self, in_channels: int, out_channels: int,
39 | kernel_size: List[int] = [3], stride: int = 1,
40 | nempty: bool = False):
41 | super().__init__()
42 | self.conv = OctreeConv(
43 | in_channels, out_channels, kernel_size, stride, nempty)
44 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
45 |
46 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
47 | r''''''
48 |
49 | out = self.conv(data, octree, depth)
50 | out = self.bn(out)
51 | return out
52 |
53 |
54 | class OctreeConvBnRelu(torch.nn.Module):
55 | r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
56 |
57 | Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
58 | '''
59 |
60 | def __init__(self, in_channels: int, out_channels: int,
61 | kernel_size: List[int] = [3], stride: int = 1,
62 | nempty: bool = False):
63 | super().__init__()
64 | self.conv = OctreeConv(
65 | in_channels, out_channels, kernel_size, stride, nempty)
66 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
67 | self.relu = torch.nn.ReLU(inplace=True)
68 |
69 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
70 | r''''''
71 |
72 | out = self.conv(data, octree, depth)
73 | out = self.bn(out)
74 | out = self.relu(out)
75 | return out
76 |
77 |
78 | class OctreeDeconvBnRelu(torch.nn.Module):
79 | r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
80 |
81 | Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
82 | '''
83 |
84 | def __init__(self, in_channels: int, out_channels: int,
85 | kernel_size: List[int] = [3], stride: int = 1,
86 | nempty: bool = False):
87 | super().__init__()
88 | self.deconv = OctreeDeconv(
89 | in_channels, out_channels, kernel_size, stride, nempty)
90 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
91 | self.relu = torch.nn.ReLU(inplace=True)
92 |
93 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
94 | r''''''
95 |
96 | out = self.deconv(data, octree, depth)
97 | out = self.bn(out)
98 | out = self.relu(out)
99 | return out
100 |
101 |
102 | class Conv1x1(torch.nn.Module):
103 | r''' Performs a convolution with kernel :obj:`(1,1,1)`.
104 |
105 | The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
106 | number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
107 | implemented with :class:`torch.nn.Linear`.
108 | '''
109 |
110 | def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
111 | super().__init__()
112 | self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
113 |
114 | def forward(self, data: torch.Tensor):
115 | r''''''
116 |
117 | return self.linear(data)
118 |
119 |
120 | class Conv1x1Bn(torch.nn.Module):
121 | r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
122 | '''
123 |
124 | def __init__(self, in_channels: int, out_channels: int):
125 | super().__init__()
126 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
127 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
128 |
129 | def forward(self, data: torch.Tensor):
130 | r''''''
131 |
132 | out = self.conv(data)
133 | out = self.bn(out)
134 | return out
135 |
136 |
137 | class Conv1x1BnRelu(torch.nn.Module):
138 | r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
139 | '''
140 |
141 | def __init__(self, in_channels: int, out_channels: int):
142 | super().__init__()
143 | self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
144 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
145 | self.relu = torch.nn.ReLU(inplace=True)
146 |
147 | def forward(self, data: torch.Tensor):
148 | r''''''
149 |
150 | out = self.conv(data)
151 | out = self.bn(out)
152 | out = self.relu(out)
153 | return out
154 |
155 |
156 | class FcBnRelu(torch.nn.Module):
157 | r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
158 | '''
159 |
160 | def __init__(self, in_channels: int, out_channels: int):
161 | super().__init__()
162 | self.flatten = torch.nn.Flatten(start_dim=1)
163 | self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
164 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
165 | self.relu = torch.nn.ReLU(inplace=True)
166 |
167 | def forward(self, data):
168 | r''''''
169 |
170 | out = self.flatten(data)
171 | out = self.fc(out)
172 | out = self.bn(out)
173 | out = self.relu(out)
174 | return out
175 |
176 |
177 | class InputFeature(torch.nn.Module):
178 | r''' Returns the initial input feature stored in octree.
179 |
180 | Args:
181 | feature (str): A string used to indicate which features to extract from the
182 | input octree. If the character :obj:`N` is in :attr:`feature`, the
183 | normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
184 | :attr:`feature`, the local displacement is extracted (1 channels). If
185 | :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
186 | points in each octree node is extracted (3 channels). If :attr:`P` is in
187 | :attr:`feature`, the global coordinates are extracted (3 channels). If
188 | :attr:`F` is in :attr:`feature`, other features (like colors) are
189 | extracted (k channels).
190 | nempty (bool): If false, gets the features of all octree nodes.
191 | '''
192 |
193 | def __init__(self, feature: str = 'NDF', nempty: bool = False):
194 | super().__init__()
195 | self.nempty = nempty
196 | self.feature = feature.upper()
197 |
198 | def forward(self, octree: Octree):
199 | r''''''
200 |
201 | features = list()
202 | depth = octree.depth
203 | if 'N' in self.feature:
204 | features.append(octree.normals[depth])
205 |
206 | if 'L' in self.feature or 'D' in self.feature:
207 | local_points = octree.points[depth].frac() - 0.5
208 |
209 | if 'D' in self.feature:
210 | dis = torch.sum(local_points * octree.normals[depth], dim=1, keepdim=True)
211 | features.append(dis)
212 |
213 | if 'L' in self.feature:
214 | features.append(local_points)
215 |
216 | if 'P' in self.feature:
217 | scale = 2 ** (1 - depth) # normalize [0, 2^depth] -> [-1, 1]
218 | global_points = octree.points[depth] * scale - 1.0
219 | features.append(global_points)
220 |
221 | if 'F' in self.feature:
222 | features.append(octree.features[depth])
223 |
224 | out = torch.cat(features, dim=1)
225 | if not self.nempty:
226 | out = ocnn.nn.octree_pad(out, octree, depth)
227 | return out
228 |
229 | def extra_repr(self) -> str:
230 | r''''''
231 | return 'feature={}, nempty={}'.format(self.feature, self.nempty)
232 |
--------------------------------------------------------------------------------
/utils/ocnn/modules/resblocks.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 |
11 | from utils.ocnn.octree import Octree
12 | from utils.ocnn.nn import OctreeMaxPool
13 | from utils.ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
14 |
15 |
16 | class OctreeResBlock(torch.nn.Module):
17 | r''' Octree-based ResNet block in a bottleneck style. The block is composed of
18 | a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`.
19 |
20 | Args:
21 | in_channels (int): Number of input channels.
22 | out_channels (int): Number of output channels.
23 | stride (int): The stride of the block (:obj:`1` or :obj:`2`).
24 | bottleneck (int): The input and output channels of the :obj:`Conv3x3` is
25 | equal to the input channel divided by :attr:`bottleneck`.
26 | nempty (bool): If True, only performs the convolution on non-empty
27 | octree nodes.
28 | '''
29 |
30 | def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
31 | bottleneck: int = 4, nempty: bool = False):
32 | super().__init__()
33 | self.in_channels = in_channels
34 | self.out_channels = out_channels
35 | self.bottleneck = bottleneck
36 | self.stride = stride
37 | channelb = int(out_channels / bottleneck)
38 |
39 | if self.stride == 2:
40 | self.max_pool = OctreeMaxPool(nempty)
41 | self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
42 | self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
43 | self.conv1x1b = Conv1x1Bn(channelb, out_channels)
44 | if self.in_channels != self.out_channels:
45 | self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
46 | self.relu = torch.nn.ReLU(inplace=True)
47 |
48 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
49 | r''''''
50 |
51 | if self.stride == 2:
52 | data = self.max_pool(data, octree, depth)
53 | depth = depth - 1
54 | conv1 = self.conv1x1a(data)
55 | conv2 = self.conv3x3(conv1, octree, depth)
56 | conv3 = self.conv1x1b(conv2)
57 | if self.in_channels != self.out_channels:
58 | data = self.conv1x1c(data)
59 | out = self.relu(conv3 + data)
60 | return out
61 |
62 |
63 | class OctreeResBlock2(torch.nn.Module):
64 | r''' Basic Octree-based ResNet block. The block is composed of
65 | a series of :obj:`Conv3x3` and :obj:`Conv3x3`.
66 |
67 | Refer to :class:`OctreeResBlock` for the details of arguments.
68 | '''
69 |
70 | def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
71 | nempty=False):
72 | super().__init__()
73 | self.in_channels = in_channels
74 | self.out_channels = out_channels
75 | self.stride = stride
76 | channelb = int(out_channels / bottleneck)
77 |
78 | if self.stride == 2:
79 | self.maxpool = OctreeMaxPool(self.depth)
80 | self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
81 | self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
82 | if self.in_channels != self.out_channels:
83 | self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
84 | self.relu = torch.nn.ReLU(inplace=True)
85 |
86 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
87 | r''''''
88 |
89 | if self.stride == 2:
90 | data = self.maxpool(data, octree, depth)
91 | depth = depth - 1
92 | conv1 = self.conv3x3a(data, octree, depth)
93 | conv2 = self.conv3x3b(conv1, octree, depth)
94 | if self.in_channels != self.out_channels:
95 | data = self.conv1x1(data)
96 | out = self.relu(conv2 + data)
97 | return out
98 |
99 |
100 | class OctreeResBlocks(torch.nn.Module):
101 | r''' A sequence of :attr:`resblk_num` ResNet blocks.
102 | '''
103 |
104 | def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
105 | nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
106 | super().__init__()
107 | self.resblk_num = resblk_num
108 | self.use_checkpoint = use_checkpoint
109 | channels = [in_channels] + [out_channels] * resblk_num
110 |
111 | self.resblks = torch.nn.ModuleList(
112 | [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
113 | for i in range(self.resblk_num)])
114 |
115 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
116 | r''''''
117 |
118 | for i in range(self.resblk_num):
119 | if self.use_checkpoint:
120 | data = torch.utils.checkpoint.checkpoint(
121 | self.resblks[i], data, octree, depth)
122 | else:
123 | data = self.resblks[i](data, octree, depth)
124 | return data
125 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .octree2vox import octree2voxel, Octree2Voxel
9 | from .octree2col import octree2col, col2octree
10 | from .octree_pad import octree_pad, octree_depad
11 | from .octree_interp import (octree_nearest_pts, octree_linear_pts,
12 | OctreeInterp, OctreeUpsample)
13 | from .octree_pool import (octree_max_pool, OctreeMaxPool,
14 | octree_max_unpool, OctreeMaxUnpool,
15 | octree_global_pool, OctreeGlobalPool,
16 | octree_avg_pool, OctreeAvgPool,)
17 | from .octree_conv import OctreeConv, OctreeDeconv
18 | from .octree_dwconv import OctreeDWConv
19 | from .octree_norm import OctreeInstanceNorm, OctreeBatchNorm
20 | from .octree_drop import OctreeDropPath
21 |
22 |
23 | __all__ = [
24 | 'octree2voxel',
25 | 'octree2col', 'col2octree',
26 | 'octree_pad', 'octree_depad',
27 | 'octree_nearest_pts', 'octree_linear_pts',
28 | 'octree_max_pool', 'octree_max_unpool',
29 | 'octree_global_pool', 'octree_avg_pool',
30 | 'Octree2Voxel',
31 | 'OctreeMaxPool', 'OctreeMaxUnpool',
32 | 'OctreeGlobalPool', 'OctreeAvgPool',
33 | 'OctreeConv', 'OctreeDeconv',
34 | 'OctreeDWConv',
35 | 'OctreeInterp', 'OctreeUpsample',
36 | 'OctreeInstanceNorm', 'OctreeBatchNorm',
37 | 'OctreeDropPath',
38 | ]
39 |
40 | classes = __all__
41 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree2col.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 |
11 | from utils.ocnn.octree import Octree
12 | from utils.ocnn.utils import scatter_add
13 |
14 |
15 | def octree2col(data: torch.Tensor, octree: Octree, depth: int,
16 | kernel_size: str = '333', stride: int = 1, nempty: bool = False):
17 | r''' Gathers the neighboring features for convolutions.
18 |
19 | Args:
20 | data (torch.Tensor): The input data.
21 | octree (Octree): The corresponding octree.
22 | depth (int): The depth of current octree.
23 | kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
24 | :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
25 | :obj:`313`.
26 | stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
27 | stride is :obj:`2`, it always returns the neighborhood of the first
28 | siblings, and the number of elements of output tensor is
29 | :obj:`octree.nnum[depth] / 8`.
30 | nempty (bool): If True, only returns the neighborhoods of the non-empty
31 | octree nodes.
32 | '''
33 |
34 | neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
35 | size = (neigh.shape[0], neigh.shape[1], data.shape[1])
36 | out = torch.zeros(size, dtype=data.dtype, device=data.device)
37 | valid = neigh >= 0
38 | out[valid] = data[neigh[valid]] # (N, K, C)
39 | return out
40 |
41 |
42 | def col2octree(data: torch.Tensor, octree: Octree, depth: int,
43 | kernel_size: str = '333', stride: int = 1, nempty: bool = False):
44 | r''' Scatters the convolution features to an octree.
45 |
46 | Please refer to :func:`octree2col` for the usage of function parameters.
47 | '''
48 |
49 | neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
50 | valid = neigh >= 0
51 | dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
52 | out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size)
53 | return out
54 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree2vox.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | from ..octree import Octree, key2xyz
11 | from .octree_pad import octree_depad
12 |
13 |
14 | def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
15 | nempty: bool = False):
16 | r''' Converts the input feature to full-voxel-based representation.
17 |
18 | Args:
19 | data (torch.Tensor): The input feature.
20 | octree (Octree): The corresponding octree.
21 | depth (int): The depth of current octree.
22 | nempty (bool): If True, :attr:`data` only contains the features of non-empty
23 | octree nodes.
24 | '''
25 |
26 | key = octree.keys[depth]
27 | if nempty:
28 | key = octree_depad(key, octree, depth)
29 | x, y, z, b = key2xyz(key, depth)
30 |
31 | num = 1 << depth
32 | channel = data.shape[1]
33 | batch_size = octree.batch_size
34 | size = (batch_size, num, num, num, channel)
35 | vox = torch.zeros(size, dtype=data.dtype, device=data.device)
36 | vox[b, x, y, z] = data
37 | return vox
38 |
39 |
40 | class Octree2Voxel(torch.nn.Module):
41 | r''' Converts the input feature to full-voxel-based representation
42 |
43 | Please refer to :func:`octree2voxel` for details.
44 | '''
45 |
46 | def __init__(self, nempty: bool = False):
47 | super().__init__()
48 | self.nempty = nempty
49 |
50 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
51 | r''''''
52 |
53 | return octree2voxel(data, octree, depth, self.nempty)
54 |
55 | def extra_repr(self) -> str:
56 | return 'nempty={}'.format(self.nempty)
57 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_conv.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 | from torch.autograd import Function
11 | from typing import List
12 |
13 | from utils.ocnn.octree import Octree
14 | from utils.ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15 | from .octree2col import octree2col, col2octree
16 | from .octree_pad import octree_pad, octree_depad
17 |
18 |
19 | class OctreeConvBase:
20 |
21 | def __init__(self, in_channels: int, out_channels: int,
22 | kernel_size: List[int] = [3], stride: int = 1,
23 | nempty: bool = False, max_buffer: int = int(2e8)):
24 | super().__init__()
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 | self.kernel_size = resize_with_last_val(kernel_size)
28 | self.kernel = list2str(self.kernel_size)
29 | self.stride = stride
30 | self.nempty = nempty
31 | self.max_buffer = max_buffer # about 200M
32 |
33 | self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34 | self.in_conv = in_channels if self.is_conv_layer() else out_channels
35 | self.out_conv = out_channels if self.is_conv_layer() else in_channels
36 | self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37 |
38 | def is_conv_layer(self):
39 | r''' Returns :obj:`True` to indicate this is a convolution layer.
40 | '''
41 |
42 | raise NotImplementedError
43 |
44 | def setup(self, octree: Octree, depth: int):
45 | r''' Setup the shapes of each tensor.
46 | This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47 | and :obj:`weight_gemm`.
48 | '''
49 |
50 | # The depth of tensors:
51 | # The in_depth and out_depth are the octree depth of the input and output
52 | # data; neigh_depth is the octree depth of the neighborhood information, as
53 | # well as `col` data, neigh_depth is always the same as the depth of larger
54 | # data when doing octree2col or col2octree.
55 | self.in_depth = depth
56 | self.out_depth = depth
57 | self.neigh_depth = depth
58 | if self.stride == 2:
59 | if self.is_conv_layer():
60 | self.out_depth = depth - 1
61 | else:
62 | self.out_depth = depth + 1
63 | self.neigh_depth = depth + 1
64 |
65 | # The height of tensors
66 | if self.nempty:
67 | self.in_h = octree.nnum_nempty[self.in_depth]
68 | self.out_h = octree.nnum_nempty[self.out_depth]
69 | else:
70 | self.in_h = octree.nnum[self.in_depth]
71 | self.out_h = octree.nnum[self.out_depth]
72 | if self.stride == 2:
73 | if self.is_conv_layer():
74 | self.out_h = octree.nnum_nempty[self.out_depth]
75 | else:
76 | self.in_h = octree.nnum_nempty[self.in_depth]
77 | self.in_shape = (self.in_h, self.in_channels)
78 | self.out_shape = (self.out_h, self.out_channels)
79 |
80 | # The neighborhood indices
81 | self.neigh = octree.get_neigh(
82 | self.neigh_depth, self.kernel, self.stride, self.nempty)
83 |
84 | # The heigh and number of the temporary buffer
85 | self.buffer_n = 1
86 | self.buffer_h = self.neigh.shape[0]
87 | ideal_size = self.buffer_h * self.kdim * self.in_conv
88 | if ideal_size > self.max_buffer:
89 | kc = self.kdim * self.in_conv # make `max_buffer` be divided
90 | max_buffer = self.max_buffer // kc * kc # by `kc` with no remainder
91 | self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92 | self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93 | self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94 |
95 | def check_and_init(self, data: torch.Tensor):
96 | r''' Checks the input data and initializes the shape of output data.
97 | '''
98 |
99 | # Check the shape of input data
100 | check = tuple(data.shape) == self.in_shape
101 | assert check, 'The shape of input data is wrong.'
102 |
103 | # Init the output data
104 | out = data.new_zeros(self.out_shape)
105 | return out
106 |
107 | def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
108 | weights: torch.Tensor):
109 | r''' Peforms the forward pass of octree-based convolution.
110 | '''
111 |
112 | # Initialize the buffer
113 | buffer = data.new_empty(self.buffer_shape)
114 |
115 | # Loop over each sub-matrix
116 | for i in range(self.buffer_n):
117 | start = i * self.buffer_h
118 | end = (i + 1) * self.buffer_h
119 |
120 | # The boundary case in the last iteration
121 | if end > self.neigh.shape[0]:
122 | dis = end - self.neigh.shape[0]
123 | end = self.neigh.shape[0]
124 | buffer, _ = buffer.split([self.buffer_h-dis, dis])
125 |
126 | # Perform octree2col
127 | neigh_i = self.neigh[start:end]
128 | valid = neigh_i >= 0
129 | buffer.fill_(0)
130 | buffer[valid] = data[neigh_i[valid]]
131 |
132 | # The sub-matrix gemm
133 | out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
134 |
135 | return out
136 |
137 | def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
138 | weights: torch.Tensor):
139 | r''' Performs the backward pass of octree-based convolution.
140 | '''
141 |
142 | # Loop over each sub-matrix
143 | for i in range(self.buffer_n):
144 | start = i * self.buffer_h
145 | end = (i + 1) * self.buffer_h
146 |
147 | # The boundary case in the last iteration
148 | if end > self.neigh.shape[0]:
149 | end = self.neigh.shape[0]
150 |
151 | # The sub-matrix gemm
152 | buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
153 | buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
154 | buffer = buffer.to(out.dtype) # for pytorch.amp
155 |
156 | # Performs col2octree
157 | neigh_i = self.neigh[start:end]
158 | valid = neigh_i >= 0
159 | out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
160 |
161 | return out
162 |
163 | def weight_gemm(
164 | self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
165 | r''' Computes the gradient of the weight matrix.
166 | '''
167 |
168 | # Record the shape of out
169 | out_shape = out.shape
170 | out = out.flatten(0, 1)
171 |
172 | # Initialize the buffer
173 | buffer = data.new_empty(self.buffer_shape)
174 |
175 | # Loop over each sub-matrix
176 | for i in range(self.buffer_n):
177 | start = i * self.buffer_h
178 | end = (i + 1) * self.buffer_h
179 |
180 | # The boundary case in the last iteration
181 | if end > self.neigh.shape[0]:
182 | d = end - self.neigh.shape[0]
183 | end = self.neigh.shape[0]
184 | buffer, _ = buffer.split([self.buffer_h-d, d])
185 |
186 | # Perform octree2col
187 | neigh_i = self.neigh[start:end]
188 | valid = neigh_i >= 0
189 | buffer.fill_(0)
190 | buffer[valid] = data[neigh_i[valid]]
191 |
192 | # Accumulate the gradient via gemm
193 | out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
194 |
195 | return out.view(out_shape)
196 |
197 |
198 | class _OctreeConv(OctreeConvBase):
199 | r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
200 | '''
201 |
202 | def is_conv_layer(self): return True
203 |
204 |
205 | class _OctreeDeconv(OctreeConvBase):
206 | r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
207 | '''
208 |
209 | def is_conv_layer(self): return False
210 |
211 |
212 | class OctreeConvFunction(Function):
213 | r''' Wrap the octree convolution for auto-diff.
214 | '''
215 |
216 | @staticmethod
217 | def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
218 | depth: int, in_channels: int, out_channels: int,
219 | kernel_size: List[int] = [3, 3, 3], stride: int = 1,
220 | nempty: bool = False, max_buffer: int = int(2e8)):
221 | octree_conv = _OctreeConv(
222 | in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
223 | octree_conv.setup(octree, depth)
224 | out = octree_conv.check_and_init(data)
225 | out = octree_conv.forward_gemm(out, data, weights)
226 |
227 | ctx.save_for_backward(data, weights)
228 | ctx.octree_conv = octree_conv
229 | return out
230 |
231 | @staticmethod
232 | def backward(ctx, grad):
233 | data, weights = ctx.saved_tensors
234 | octree_conv = ctx.octree_conv
235 |
236 | grad_out = None
237 | if ctx.needs_input_grad[0]:
238 | grad_out = torch.zeros_like(data)
239 | grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
240 |
241 | grad_w = None
242 | if ctx.needs_input_grad[1]:
243 | grad_w = torch.zeros_like(weights)
244 | grad_w = octree_conv.weight_gemm(grad_w, data, grad)
245 |
246 | return (grad_out, grad_w) + (None,) * 8
247 |
248 |
249 | class OctreeDeconvFunction(Function):
250 | r''' Wrap the octree deconvolution for auto-diff.
251 | '''
252 |
253 | @staticmethod
254 | def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
255 | depth: int, in_channels: int, out_channels: int,
256 | kernel_size: List[int] = [3, 3, 3], stride: int = 1,
257 | nempty: bool = False, max_buffer: int = int(2e8)):
258 | octree_deconv = _OctreeDeconv(
259 | in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
260 | octree_deconv.setup(octree, depth)
261 | out = octree_deconv.check_and_init(data)
262 | out = octree_deconv.backward_gemm(out, data, weights)
263 |
264 | ctx.save_for_backward(data, weights)
265 | ctx.octree_deconv = octree_deconv
266 | return out
267 |
268 | @staticmethod
269 | def backward(ctx, grad):
270 | data, weights = ctx.saved_tensors
271 | octree_deconv = ctx.octree_deconv
272 |
273 | grad_out = None
274 | if ctx.needs_input_grad[0]:
275 | grad_out = torch.zeros_like(data)
276 | grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
277 |
278 | grad_w = None
279 | if ctx.needs_input_grad[1]:
280 | grad_w = torch.zeros_like(weights)
281 | grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
282 |
283 | return (grad_out, grad_w) + (None,) * 8
284 |
285 |
286 | # alias
287 | octree_conv = OctreeConvFunction.apply
288 | octree_deconv = OctreeDeconvFunction.apply
289 |
290 |
291 | class OctreeConv(OctreeConvBase, torch.nn.Module):
292 | r''' Performs octree convolution.
293 |
294 | Args:
295 | in_channels (int): Number of input channels.
296 | out_channels (int): Number of output channels.
297 | kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
298 | :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
299 | :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
300 | stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
301 | nempty (bool): If True, only performs the convolution on non-empty
302 | octree nodes.
303 | direct_method (bool): If True, directly performs the convolution via using
304 | gemm and octree2col/col2octree. The octree2col/col2octree needs to
305 | construct a large matrix, which may consume a lot of memory. If False,
306 | performs the convolution in a sub-matrix manner, which can save the
307 | requied runtime memory.
308 | use_bias (bool): If True, add a bias term to the convolution.
309 | max_buffer (int): The maximum number of elements in the buffer, used when
310 | :attr:`direct_method` is False.
311 | '''
312 |
313 | def __init__(self, in_channels: int, out_channels: int,
314 | kernel_size: List[int] = [3], stride: int = 1,
315 | nempty: bool = False, direct_method: bool = False,
316 | use_bias: bool = False, max_buffer: int = int(2e8)):
317 | super().__init__(
318 | in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
319 |
320 | self.direct_method = direct_method
321 | self.use_bias = use_bias
322 | self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
323 | if self.use_bias:
324 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
325 | self.reset_parameters()
326 |
327 | def reset_parameters(self):
328 | xavier_uniform_(self.weights)
329 | if self.use_bias:
330 | torch.nn.init.zeros_(self.bias)
331 |
332 | def is_conv_layer(self): return True
333 |
334 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
335 | r''' Defines the octree convolution.
336 |
337 | Args:
338 | data (torch.Tensor): The input data.
339 | octree (Octree): The corresponding octree.
340 | depth (int): The depth of current octree.
341 | '''
342 |
343 | if self.direct_method:
344 | col = octree2col(
345 | data, octree, depth, self.kernel, self.stride, self.nempty)
346 | out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
347 | else:
348 | out = octree_conv(
349 | data, self.weights, octree, depth, self.in_channels,
350 | self.out_channels, self.kernel_size, self.stride, self.nempty,
351 | self.max_buffer)
352 |
353 | if self.use_bias:
354 | out += self.bias
355 |
356 | if self.stride == 2 and not self.nempty:
357 | out = octree_pad(out, octree, depth-1)
358 | return out
359 |
360 | def extra_repr(self) -> str:
361 | r''' Sets the extra representation of the module.
362 | '''
363 |
364 | return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
365 | 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
366 | self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
367 |
368 |
369 | class OctreeDeconv(OctreeConv):
370 | r''' Performs octree deconvolution.
371 |
372 | Please refer to :class:`OctreeConv` for the meaning of the arguments.
373 | '''
374 |
375 | def is_conv_layer(self): return False
376 |
377 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
378 | r''' Defines the octree deconvolution.
379 |
380 | Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
381 | '''
382 |
383 | depth_col = depth
384 | if self.stride == 2:
385 | depth_col = depth + 1
386 | if not self.nempty:
387 | data = octree_depad(data, octree, depth)
388 |
389 | if self.direct_method:
390 | col = torch.mm(data, self.weights.flatten(0, 1).t())
391 | col = col.view(col.shape[0], self.kdim, -1)
392 | out = col2octree(
393 | col, octree, depth_col, self.kernel, self.stride, self.nempty)
394 | else:
395 | out = octree_deconv(
396 | data, self.weights, octree, depth, self.in_channels,
397 | self.out_channels, self.kernel_size, self.stride, self.nempty,
398 | self.max_buffer)
399 |
400 | if self.use_bias:
401 | out += self.bias
402 | return out
403 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_drop.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import Optional
10 |
11 | from utils.ocnn.octree import Octree
12 |
13 |
14 | class OctreeDropPath(torch.nn.Module):
15 | r'''Drop paths (Stochastic Depth) per sample when applied in main path of
16 | residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
17 |
18 | Args:
19 | drop_prob (int): The probability of drop paths.
20 | nempty (bool): Indicate whether the input data only contains features of the
21 | non-empty octree nodes or not.
22 | scale_by_keep (bool): Whether to scale the kept features proportionally.
23 | '''
24 |
25 | def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
26 | scale_by_keep: bool = True):
27 | super().__init__()
28 |
29 | self.drop_prob = drop_prob
30 | self.nempty = nempty
31 | self.scale_by_keep = scale_by_keep
32 |
33 | def forward(self, data: torch.Tensor, octree: Octree, depth: int,
34 | batch_id: Optional[torch.Tensor] = None):
35 | r''''''
36 |
37 | if self.drop_prob <= 0.0 or not self.training:
38 | return data
39 |
40 | batch_size = octree.batch_size
41 | keep_prob = 1 - self.drop_prob
42 | rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
43 | rnd_tensor = torch.floor(rnd_tensor + keep_prob)
44 | if keep_prob > 0.0 and self.scale_by_keep:
45 | rnd_tensor.div_(keep_prob)
46 |
47 | if batch_id is None:
48 | batch_id = octree.batch_id(depth, self.nempty)
49 | drop_mask = rnd_tensor[batch_id]
50 | output = data * drop_mask
51 | return output
52 |
53 | def extra_repr(self) -> str:
54 | return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
55 | self.drop_prob, self.nempty, self.scale_by_keep) # noqa
56 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_dwconv.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 | from torch.autograd import Function
11 | from typing import List
12 |
13 | from utils.ocnn.octree import Octree
14 | from utils.ocnn.utils import scatter_add, xavier_uniform_
15 | from .octree_pad import octree_pad
16 | from .octree_conv import OctreeConvBase
17 |
18 |
19 | class OctreeDWConvBase(OctreeConvBase):
20 |
21 | def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22 | stride: int = 1, nempty: bool = False,
23 | max_buffer: int = int(2e8)):
24 | super().__init__(
25 | in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26 | self.weights_shape = (self.kdim, 1, self.out_channels)
27 |
28 | def is_conv_layer(self): return True
29 |
30 | def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31 | weights: torch.Tensor):
32 | r''' Peforms the forward pass of octree-based convolution.
33 | '''
34 |
35 | # Initialize the buffer
36 | buffer = data.new_empty(self.buffer_shape)
37 |
38 | # Loop over each sub-matrix
39 | for i in range(self.buffer_n):
40 | start = i * self.buffer_h
41 | end = (i + 1) * self.buffer_h
42 |
43 | # The boundary case in the last iteration
44 | if end > self.neigh.shape[0]:
45 | dis = end - self.neigh.shape[0]
46 | end = self.neigh.shape[0]
47 | buffer, _ = buffer.split([self.buffer_h-dis, dis])
48 |
49 | # Perform octree2col
50 | neigh_i = self.neigh[start:end]
51 | valid = neigh_i >= 0
52 | buffer.fill_(0)
53 | buffer[valid] = data[neigh_i[valid]]
54 |
55 | # The sub-matrix gemm
56 | # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
57 | out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
58 | return out
59 |
60 | def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
61 | weights: torch.Tensor):
62 | r''' Performs the backward pass of octree-based convolution.
63 | '''
64 |
65 | # Loop over each sub-matrix
66 | for i in range(self.buffer_n):
67 | start = i * self.buffer_h
68 | end = (i + 1) * self.buffer_h
69 |
70 | # The boundary case in the last iteration
71 | if end > self.neigh.shape[0]:
72 | end = self.neigh.shape[0]
73 |
74 | # The sub-matrix gemm
75 | # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
76 | # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
77 | buffer = torch.einsum(
78 | 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
79 |
80 | # Performs col2octree
81 | neigh_i = self.neigh[start:end]
82 | valid = neigh_i >= 0
83 | out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
84 |
85 | return out
86 |
87 | def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
88 | r''' Computes the gradient of the weight matrix.
89 | '''
90 |
91 | # Record the shape of out
92 | out_shape = out.shape
93 | out = out.flatten(0, 1)
94 |
95 | # Initialize the buffer
96 | buffer = data.new_empty(self.buffer_shape)
97 |
98 | # Loop over each sub-matrix
99 | for i in range(self.buffer_n):
100 | start = i * self.buffer_h
101 | end = (i + 1) * self.buffer_h
102 |
103 | # The boundary case in the last iteration
104 | if end > self.neigh.shape[0]:
105 | d = end - self.neigh.shape[0]
106 | end = self.neigh.shape[0]
107 | buffer, _ = buffer.split([self.buffer_h-d, d])
108 |
109 | # Perform octree2col
110 | neigh_i = self.neigh[start:end]
111 | valid = neigh_i >= 0
112 | buffer.fill_(0)
113 | buffer[valid] = data[neigh_i[valid]]
114 |
115 | # Accumulate the gradient via gemm
116 | # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
117 | out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
118 | return out.view(out_shape)
119 |
120 |
121 | class OctreeDWConvFunction(Function):
122 | r''' Wrap the octree convolution for auto-diff.
123 | '''
124 |
125 | @staticmethod
126 | def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
127 | depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
128 | stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
129 | octree_conv = OctreeDWConvBase(
130 | in_channels, kernel_size, stride, nempty, max_buffer)
131 | octree_conv.setup(octree, depth)
132 | out = octree_conv.check_and_init(data)
133 | out = octree_conv.forward_gemm(out, data, weights)
134 |
135 | ctx.save_for_backward(data, weights)
136 | ctx.octree_conv = octree_conv
137 | return out
138 |
139 | @staticmethod
140 | def backward(ctx, grad):
141 | data, weights = ctx.saved_tensors
142 | octree_conv = ctx.octree_conv
143 |
144 | grad_out = None
145 | if ctx.needs_input_grad[0]:
146 | grad_out = torch.zeros_like(data)
147 | grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
148 |
149 | grad_w = None
150 | if ctx.needs_input_grad[1]:
151 | grad_w = torch.zeros_like(weights)
152 | grad_w = octree_conv.weight_gemm(grad_w, data, grad)
153 |
154 | return (grad_out, grad_w) + (None,) * 7
155 |
156 |
157 | # alias
158 | octree_dwconv = OctreeDWConvFunction.apply
159 |
160 |
161 | class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
162 | r''' Performs octree-based depth-wise convolution.
163 |
164 | Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
165 |
166 | .. note::
167 | This implementation uses the :func:`torch.einsum` and I find that the speed
168 | is relatively slow. Further optimization is needed to speed it up.
169 | '''
170 |
171 | def __init__(self, in_channels: int, kernel_size: List[int] = [3],
172 | stride: int = 1, nempty: bool = False, use_bias: bool = False,
173 | max_buffer: int = int(2e8)):
174 | super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
175 |
176 | self.use_bias = use_bias
177 | self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
178 | if self.use_bias:
179 | self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
180 | self.reset_parameters()
181 |
182 | def reset_parameters(self):
183 | xavier_uniform_(self.weights)
184 | if self.use_bias:
185 | torch.nn.init.zeros_(self.bias)
186 |
187 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
188 | r''''''
189 |
190 | out = octree_dwconv(
191 | data, self.weights, octree, depth, self.in_channels,
192 | self.kernel_size, self.stride, self.nempty, self.max_buffer)
193 |
194 | if self.use_bias:
195 | out += self.bias
196 |
197 | if self.stride == 2 and not self.nempty:
198 | out = octree_pad(out, octree, depth-1)
199 | return out
200 |
201 | def extra_repr(self) -> str:
202 | return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
203 | 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
204 | self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
205 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_interp.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.sparse
10 | from typing import Optional
11 |
12 | from utils import ocnn
13 | from utils.ocnn.octree import Octree
14 |
15 |
16 | def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
17 | pts: torch.Tensor, nempty: bool = False,
18 | bound_check: bool = False):
19 | ''' The nearest-neighbor interpolatation with input points.
20 |
21 | Args:
22 | data (torch.Tensor): The input data.
23 | octree (Octree): The octree to interpolate.
24 | depth (int): The depth of the data.
25 | pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
26 | i.e. :obj:`N x (x, y, z, batch)`.
27 | nempty (bool): If true, the :attr:`data` only contains features of non-empty
28 | octree nodes
29 | bound_check (bool): If true, check whether the point is in :obj:`[0, 2^depth)`.
30 |
31 | .. note::
32 | The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
33 | '''
34 |
35 | nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
36 | assert data.shape[0] == nnum, 'The shape of input data is wrong.'
37 |
38 | idx = octree.search_xyzb(pts, depth, nempty)
39 | valid = idx > -1 # valid indices
40 | if bound_check:
41 | bound = torch.logical_and(pts[:, :3] >= 0, pts[:, :3] < 2**depth).all(1)
42 | valid = torch.logical_and(valid, bound)
43 |
44 | size = (pts.shape[0], data.shape[1])
45 | out = torch.zeros(size, device=data.device, dtype=data.dtype)
46 | out[valid] = data.index_select(0, idx[valid])
47 | return out
48 |
49 |
50 | def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
51 | pts: torch.Tensor, nempty: bool = False,
52 | bound_check: bool = False):
53 | ''' Linear interpolatation with input points.
54 |
55 | Refer to :func:`octree_nearest_pts` for the meaning of the arguments.
56 | '''
57 |
58 | nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
59 | assert data.shape[0] == nnum, 'The shape of input data is wrong.'
60 |
61 | device = data.device
62 | grid = torch.tensor(
63 | [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
64 | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)
65 |
66 | # 1. Neighborhood searching
67 | xyzf = pts[:, :3] - 0.5 # the value is defined on the center of each voxel
68 | xyzi = xyzf.floor() # the integer part (N, 3)
69 | frac = xyzf - xyzi # the fraction part (N, 3)
70 |
71 | xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)
72 | batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
73 | idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
74 | valid = idx > -1 # valid indices
75 | if bound_check:
76 | bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
77 | valid = torch.logical_and(valid, bound)
78 | idx = idx[valid]
79 |
80 | # 2. Build the sparse matrix
81 | npt = pts.shape[0]
82 | ids = torch.arange(npt, device=idx.device)
83 | ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
84 | ids = ids[valid]
85 | indices = torch.stack([ids, idx], dim=0).long()
86 |
87 | frac = (1.0 - grid) - frac.unsqueeze(dim=1) # (8, 3) - (N, 1, 3) -> (N, 8, 3)
88 | weight = frac.prod(dim=2).abs().view(-1) # (8*N,)
89 | weight = weight[valid]
90 |
91 | h = data.shape[0]
92 | mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)
93 |
94 | # 3. Interpolatation
95 | output = torch.sparse.mm(mat, data)
96 | ones = torch.ones(h, 1, dtype=data.dtype, device=device)
97 | norm = torch.sparse.mm(mat, ones)
98 | output = torch.div(output, norm + 1e-12)
99 | return output
100 |
101 |
102 | class OctreeInterp(torch.nn.Module):
103 | r''' Interpolates the points with an octree feature.
104 |
105 | Refer to :func:`octree_nearest_pts` for a description of arguments.
106 | '''
107 |
108 | def __init__(self, method: str = 'linear', nempty: bool = False,
109 | bound_check: bool = False, rescale_pts: bool = True):
110 | super().__init__()
111 | self.method = method
112 | self.nempty = nempty
113 | self.bound_check = bound_check
114 | self.rescale_pts = rescale_pts
115 | self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
116 |
117 | def forward(self, data: torch.Tensor, octree: Octree, depth: int,
118 | pts: torch.Tensor):
119 | r''''''
120 |
121 | # rescale points from [-1, 1] to [0, 2^depth]
122 | if self.rescale_pts:
123 | scale = 2 ** (depth - 1)
124 | pts[:, :3] = (pts[:, :3] + 1.0) * scale
125 |
126 | return self.func(data, octree, depth, pts, self.nempty, self.bound_check)
127 |
128 | def extra_repr(self) -> str:
129 | r''' Sets the extra representation of the module.
130 | '''
131 |
132 | return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
133 | self.method, self.nempty, self.bound_check, self.rescale_pts) # noqa
134 |
135 |
136 | def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
137 | nempty: bool = False):
138 | r''' Upsamples the octree node features from :attr:`depth` to :attr:`(depth+1)`
139 | with the nearest-neighbor interpolation.
140 |
141 | Args:
142 | data (torch.Tensor): The input data.
143 | octree (Octree): The octree to interpolate.
144 | depth (int): The depth of the data.
145 | nempty (bool): If true, the :attr:`data` only contains features of non-empty
146 | octree nodes.
147 | '''
148 |
149 | nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
150 | assert data.shape[0] == nnum, 'The shape of input data is wrong.'
151 |
152 | out = data
153 | if not nempty:
154 | out = ocnn.nn.octree_depad(out, octree, depth)
155 | out = out.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
156 | if nempty:
157 | out = ocnn.nn.octree_depad(out, octree, depth + 1) # !!! depth+1
158 | return out
159 |
160 |
161 | class OctreeUpsample(torch.nn.Module):
162 | r''' Upsamples the octree node features from :attr:`depth` to
163 | :attr:`(target_depth)`.
164 |
165 | Refer to :class:`octree_nearest_pts` for details.
166 | '''
167 |
168 | def __init__(self, method: str = 'linear', nempty: bool = False):
169 | super().__init__()
170 | self.method = method
171 | self.nempty = nempty
172 | self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
173 |
174 | def forward(self, data: torch.Tensor, octree: Octree, depth: int,
175 | target_depth: Optional[int] = None):
176 | r''''''
177 |
178 | if target_depth is None:
179 | target_depth = depth + 1
180 | if target_depth == depth:
181 | return data # return, do nothing
182 | assert target_depth >= depth, 'target_depth must be larger than depth'
183 |
184 | if target_depth == depth + 1 and self.method == 'nearest':
185 | return octree_nearest_upsample(data, octree, depth, self.nempty)
186 |
187 | xyzb = octree.xyzb(target_depth, self.nempty)
188 | pts = torch.stack(xyzb, dim=1).float()
189 | pts[:, :3] = (pts[:, :3] + 0.5) * (2**(depth - target_depth)) # !!! rescale
190 | return self.func(data, octree, depth, pts, self.nempty)
191 |
192 | def extra_repr(self) -> str:
193 | r''' Sets the extra representation of the module.
194 | '''
195 |
196 | return ('method={}, nempty={}').format(self.method, self.nempty)
197 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_norm.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 |
11 | from utils.ocnn.octree import Octree
12 | from utils.ocnn.utils import scatter_add
13 |
14 |
15 | OctreeBatchNorm = torch.nn.BatchNorm1d
16 |
17 |
18 | class OctreeInstanceNorm(torch.nn.Module):
19 | r''' An instance normalization layer for the octree.
20 | '''
21 |
22 | def __init__(self, in_channels: int, nempty: bool = False):
23 | super().__init__()
24 |
25 | self.eps = 1e-5
26 | self.nempty = nempty
27 | self.in_channels = in_channels
28 |
29 | self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
30 | self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
31 | self.reset_parameters()
32 |
33 | def reset_parameters(self):
34 | torch.nn.init.ones_(self.weights)
35 | torch.nn.init.zeros_(self.bias)
36 |
37 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
38 | r''''''
39 |
40 | batch_size = octree.batch_size
41 | batch_id = octree.batch_id(depth, self.nempty)
42 | ones = data.new_ones([data.shape[0], 1])
43 | count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
44 | norm = 1.0 / (count + self.eps) # there might be 0 element in some shapes
45 |
46 | mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * norm
47 | out = data - mean.index_select(0, batch_id)
48 | var = scatter_add(out * out, batch_id, dim=0, dim_size=batch_size) * norm
49 | inv_std = 1.0 / (var + self.eps).sqrt()
50 | out = out * inv_std.index_select(0, batch_id)
51 |
52 | out = out * self.weights + self.bias
53 | return out
54 |
55 | def extra_repr(self) -> str:
56 | return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
57 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_pad.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | from ..octree import Octree
11 |
12 |
13 | def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
14 | r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
15 | the octree node number.
16 |
17 | Args:
18 | data (torch.Tensor): The input tensor with its number of elements equal to the
19 | non-empty octree node number.
20 | octree (Octree): The corresponding octree.
21 | depth (int): The depth of current octree.
22 | val (float): The padding value. (Default: :obj:`0.0`)
23 | '''
24 |
25 | mask = octree.nempty_mask(depth)
26 | size = (octree.nnum[depth], data.shape[1]) # (N, C)
27 | out = torch.full(size, val, dtype=data.dtype, device=data.device)
28 | out[mask] = data
29 | return out
30 |
31 |
32 | def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
33 | r''' Reverse operation of :func:`octree_depad`.
34 |
35 | Please refer to :func:`octree_depad` for the meaning of the arguments.
36 | '''
37 |
38 | mask = octree.nempty_mask(depth)
39 | return data[mask]
40 |
--------------------------------------------------------------------------------
/utils/ocnn/nn/octree_pool.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 | from typing import List
11 |
12 | from utils.ocnn.octree import Octree
13 | from utils.ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
14 | from . import octree_pad, octree_depad
15 |
16 |
17 | def octree_max_pool(data: torch.Tensor, octree: Octree, depth: int,
18 | nempty: bool = False, return_indices: bool = False):
19 | r''' Performs octree max pooling with kernel size 2 and stride 2.
20 |
21 | Args:
22 | data (torch.Tensor): The input tensor.
23 | octree (Octree): The corresponding octree.
24 | depth (int): The depth of current octree. After pooling, the corresponding
25 | depth decreased by 1.
26 | nempty (bool): If True, :attr:`data` contains only features of non-empty
27 | octree nodes.
28 | return_indices (bool): If True, returns the indices, which can be used in
29 | :func:`octree_max_unpool`.
30 | '''
31 |
32 | if nempty:
33 | data = octree_pad(data, octree, depth, float('-inf'))
34 | data = data.view(-1, 8, data.shape[1])
35 | out, indices = data.max(dim=1)
36 | if not nempty:
37 | out = octree_pad(out, octree, depth-1)
38 | return (out, indices) if return_indices else out
39 |
40 |
41 | def octree_max_unpool(data: torch.Tensor, indices: torch.Tensor, octree: Octree,
42 | depth: int, nempty: bool = False):
43 | r''' Performs octree max unpooling.
44 |
45 | Args:
46 | data (torch.Tensor): The input tensor.
47 | indices (torch.Tensor): The indices returned by :func:`octree_max_pool`. The
48 | depth of :attr:`indices` is larger by 1 than :attr:`data`.
49 | octree (Octree): The corresponding octree.
50 | depth (int): The depth of current data. After unpooling, the corresponding
51 | depth increases by 1.
52 | '''
53 |
54 | if not nempty:
55 | data = octree_depad(data, octree, depth)
56 | num, channel = data.shape
57 | out = torch.zeros(num, 8, channel, dtype=data.dtype, device=data.device)
58 | i = torch.arange(num, dtype=indices.dtype, device=indices.device)
59 | k = torch.arange(channel, dtype=indices.dtype, device=indices.device)
60 | i, k = meshgrid(i, k, indexing='ij')
61 | out[i, indices, k] = data
62 | out = out.view(-1, channel)
63 | if nempty:
64 | out = octree_depad(out, octree, depth+1)
65 | return out
66 |
67 |
68 | def octree_avg_pool(data: torch.Tensor, octree: Octree, depth: int,
69 | kernel: str, stride: int = 2, nempty: bool = False):
70 | r''' Performs octree average pooling.
71 |
72 | Args:
73 | data (torch.Tensor): The input tensor.
74 | octree (Octree): The corresponding octree.
75 | depth (int): The depth of current octree.
76 | kernel (str): The kernel size, like '333', '222'.
77 | stride (int): The stride of the pooling.
78 | nempty (bool): If True, :attr:`data` contains only features of non-empty
79 | octree nodes.
80 | '''
81 |
82 | neigh = octree.get_neigh(depth, kernel, stride, nempty)
83 |
84 | N1 = data.shape[0]
85 | N2 = neigh.shape[0]
86 | K = neigh.shape[1]
87 |
88 | mask = neigh >= 0
89 | val = 1.0 / (torch.sum(mask, dim=1) + 1e-8)
90 | mask = mask.view(-1)
91 | val = val.unsqueeze(1).repeat(1, K).reshape(-1)
92 | val = val[mask]
93 |
94 | row = torch.arange(N2, device=neigh.device)
95 | row = row.unsqueeze(1).repeat(1, K).view(-1)
96 | col = neigh.view(-1)
97 | indices = torch.stack([row[mask], col[mask]], dim=0).long()
98 |
99 | mat = torch.sparse_coo_tensor(indices, val, [N2, N1], device=data.device)
100 | out = torch.sparse.mm(mat, data)
101 | return out
102 |
103 |
104 | def octree_global_pool(data: torch.Tensor, octree: Octree, depth: int,
105 | nempty: bool = False):
106 | r''' Performs octree global average pooling.
107 |
108 | Args:
109 | data (torch.Tensor): The input tensor.
110 | octree (Octree): The corresponding octree.
111 | depth (int): The depth of current octree.
112 | nempty (bool): If True, :attr:`data` contains only features of non-empty
113 | octree nodes.
114 | '''
115 |
116 | batch_size = octree.batch_size
117 | batch_id = octree.batch_id(depth, nempty)
118 | ones = data.new_ones(data.shape[0], 1)
119 | count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
120 | count[count < 1] = 1 # there might be 0 element in some shapes
121 |
122 | out = scatter_add(data, batch_id, dim=0, dim_size=batch_size)
123 | out = out / count
124 | return out
125 |
126 |
127 | class OctreePoolBase(torch.nn.Module):
128 | r''' The base class for octree-based pooling.
129 | '''
130 |
131 | def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False):
132 | super().__init__()
133 | self.kernel_size = resize_with_last_val(kernel_size)
134 | self.kernel = list2str(self.kernel_size)
135 | self.stride = stride
136 | self.nempty = nempty
137 |
138 | def extra_repr(self) -> str:
139 | return ('kernel={}, stride={}, nempty={}').format(
140 | self.kernel, self.stride, self.nempty) # noqa
141 |
142 |
143 | class OctreeMaxPool(OctreePoolBase):
144 | r''' Performs octree max pooling.
145 |
146 | Please refer to :func:`octree_max_pool` for details.
147 | '''
148 |
149 | def __init__(self, nempty: bool = False, return_indices: bool = False):
150 | super().__init__(kernel_size=[2], stride=2, nempty=nempty)
151 | self.return_indices = return_indices
152 |
153 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
154 | r''''''
155 |
156 | return octree_max_pool(data, octree, depth, self.nempty, self.return_indices)
157 |
158 |
159 | class OctreeMaxUnpool(OctreePoolBase):
160 | r''' Performs octree max unpooling.
161 |
162 | Please refer to :func:`octree_max_unpool` for details.
163 | '''
164 |
165 | def forward(self, data: torch.Tensor, indices: torch.Tensor, octree: Octree,
166 | depth: int):
167 | r''''''
168 |
169 | return octree_max_unpool(data, indices, octree, depth, self.nempty)
170 |
171 |
172 | class OctreeGlobalPool(OctreePoolBase):
173 | r''' Performs octree global pooling.
174 |
175 | Please refer to :func:`octree_global_pool` for details.
176 | '''
177 |
178 | def __init__(self, nempty: bool = False):
179 | super().__init__(kernel_size=[-1], stride=-1, nempty=nempty)
180 |
181 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
182 | r''''''
183 |
184 | return octree_global_pool(data, octree, depth, self.nempty)
185 |
186 |
187 | class OctreeAvgPool(OctreePoolBase):
188 | r''' Performs octree average pooling.
189 |
190 | Please refer to :func:`octree_avg_pool` for details.
191 | '''
192 |
193 | def forward(self, data: torch.Tensor, octree: Octree, depth: int):
194 | r''''''
195 |
196 | return octree_avg_pool(
197 | data, octree, depth, self.kernel, self.stride, self.nempty)
198 |
--------------------------------------------------------------------------------
/utils/ocnn/octree/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .shuffled_key import key2xyz, xyz2key
9 | from .points import Points, merge_points
10 | from .octree import Octree, merge_octrees
11 |
12 | __all__ = [
13 | 'key2xyz',
14 | 'xyz2key',
15 | 'Points',
16 | 'Octree',
17 | 'merge_points',
18 | 'merge_octrees',
19 | ]
20 |
21 | classes = __all__
22 |
--------------------------------------------------------------------------------
/utils/ocnn/octree/points.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import numpy as np
10 | from typing import Optional, Union, List
11 |
12 |
13 | class Points:
14 | r''' Represents a point cloud and contains some elementary transformations.
15 |
16 | Args:
17 | points (torch.Tensor): The coordinates of the points with a shape of
18 | :obj:`(N, 3)`, where :obj:`N` is the number of points.
19 | normals (torch.Tensor or None): The point normals with a shape of
20 | :obj:`(N, 3)`.
21 | features (torch.Tensor or None): The point features with a shape of
22 | :obj:`(N, C)`, where :obj:`C` is the channel of features.
23 | labels (torch.Tensor or None): The point labels with a shape of
24 | :obj:`(N, K)`, where :obj:`K` is the channel of labels.
25 | batch_id (torch.Tensor or None): The batch indices for each point with a
26 | shape of :obj:`(N, 1)`.
27 | batch_size (int): The batch size.
28 | '''
29 |
30 | def __init__(self, points: torch.Tensor,
31 | normals: Optional[torch.Tensor] = None,
32 | features: Optional[torch.Tensor] = None,
33 | labels: Optional[torch.Tensor] = None,
34 | batch_id: Optional[torch.Tensor] = None,
35 | batch_size: int = 1):
36 | super().__init__()
37 | self.points = points
38 | self.normals = normals
39 | self.features = features
40 | self.labels = labels
41 | self.batch_id = batch_id
42 | self.batch_size = batch_size
43 | self.device = points.device
44 | self.batch_npt = None # valid after `merge_points`
45 |
46 | def orient_normal(self, axis: str = 'x'):
47 | r''' Orients the point normals along a given axis.
48 |
49 | Args:
50 | axis (int): The coordinate axes, choose from :obj:`x`, :obj:`y` and
51 | :obj:`z`. (default: :obj:`x`)
52 | '''
53 |
54 | if self.normals is None:
55 | return
56 |
57 | axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
58 | idx = axis_map[axis]
59 | if idx < 3:
60 | flags = self.normals[:, idx] > 0
61 | flags = flags.float() * 2.0 - 1.0 # [0, 1] -> [-1, 1]
62 | self.normals = self.normals * flags.unsqueeze(1)
63 | else:
64 | self.normals.abs_()
65 |
66 | def scale(self, factor: torch.Tensor):
67 | r''' Rescales the point cloud.
68 |
69 | Args:
70 | factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
71 | '''
72 |
73 | non_zero = (factor != 0).all()
74 | all_ones = (factor == 1.0).all()
75 | non_uniform = (factor != factor[0]).any()
76 | assert non_zero, 'The scale factor must not constain 0.'
77 | if all_ones: return
78 |
79 | factor = factor.to(self.device)
80 | self.points = self.points * factor
81 | if self.normals is not None and non_uniform:
82 | ifactor = 1.0 / factor
83 | self.normals = self.normals * ifactor
84 | norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
85 | self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)
86 |
87 | def rotate(self, angle: torch.Tensor):
88 | r''' Rotates the point cloud.
89 |
90 | Args:
91 | angle (torch.Tensor): The rotation angles in radian with shape :obj:`(3,)`.
92 | '''
93 |
94 | cos, sin = angle.cos(), angle.sin()
95 | # rotx, roty, rotz are actually the transpose of the rotation matrices
96 | rotx = torch.Tensor([[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]])
97 | roty = torch.Tensor([[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]])
98 | rotz = torch.Tensor([[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]])
99 | rot = rotx @ roty @ rotz
100 |
101 | rot = rot.to(self.device)
102 | self.points = self.points @ rot
103 | if self.normals is not None:
104 | self.normals = self.normals @ rot
105 |
106 | def translate(self, dis: torch.Tensor):
107 | r''' Translates the point cloud.
108 |
109 | Args:
110 | dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
111 | '''
112 |
113 | dis = dis.to(self.device)
114 | self.points = self.points + dis
115 |
116 | def flip(self, axis: str):
117 | r''' Flips the point cloud along the given :attr:`axis`.
118 |
119 | Args:
120 | axis (str): The flipping axis, choosen from :obj:`x`, :obj:`y`, and :obj`z`.
121 | '''
122 |
123 | axis_map = {'x': 0, 'y': 1, 'z': 2}
124 | idx = axis_map[axis]
125 | self.points[:, idx] *= -1.0
126 | if self.normals is not None:
127 | self.normals[:, idx] *= -1.0
128 |
129 | def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
130 | r''' Clips the point cloud to :obj:`[min+esp, max-esp]` and returns the mask.
131 |
132 | Args:
133 | min (float): The minimum value to clip.
134 | max (float): The maximum value to clip.
135 | esp (float): The margin.
136 | '''
137 |
138 | mask = self.inbox_mask(min + esp, max - esp)
139 | tmp = self.__getitem__(mask)
140 | self.__dict__.update(tmp.__dict__)
141 | return mask
142 |
143 | def __getitem__(self, mask: torch.Tensor):
144 | r''' Slices the point cloud according a given :attr:`mask`.
145 | '''
146 |
147 | dummy_pts = torch.zeros(1, 3, device=self.device)
148 | out = Points(dummy_pts, batch_size=self.batch_size)
149 |
150 | out.points = self.points[mask]
151 | if self.normals is not None:
152 | out.normals = self.normals[mask]
153 | if self.features is not None:
154 | out.features = self.features[mask]
155 | if self.labels is not None:
156 | out.labels = self.labels[mask]
157 | if self.batch_id is not None:
158 | out.batch_id = self.batch_id[mask]
159 | return out
160 |
161 | def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
162 | bbmax: Union[float, torch.Tensor] = 1.0):
163 | r''' Returns a mask indicating whether the points are within the specified
164 | bounding box or not.
165 | '''
166 |
167 | mask_min = torch.all(self.points > bbmin, dim=1)
168 | mask_max = torch.all(self.points < bbmax, dim=1)
169 | mask = torch.logical_and(mask_min, mask_max)
170 | return mask
171 |
172 | def bbox(self):
173 | r''' Returns the bounding box.
174 | '''
175 |
176 | # torch.min and torch.max return (value, indices)
177 | bbmin = self.points.min(dim=0)
178 | bbmax = self.points.max(dim=0)
179 | return bbmin[0], bbmax[0]
180 |
181 | def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
182 | scale: float = 1.0):
183 | r''' Normalizes the point cloud to :obj:`[-scale, scale]`.
184 |
185 | Args:
186 | bbmin (torch.Tensor): The minimum coordinates of the bounding box.
187 | bbmax (torch.Tensor): The maximum coordinates of the bounding box.
188 | scale (float): The scale factor
189 | '''
190 |
191 | center = (bbmin + bbmax) * 0.5
192 | box_size = (bbmax - bbmin).max() + 1.0e-6
193 | self.points = (self.points - center) * (2.0 * scale / box_size)
194 |
195 | def to(self, device: Union[torch.device, str], non_blocking: bool = False):
196 | r''' Moves the Points to a specified device.
197 |
198 | Args:
199 | device (torch.device or str): The destination device.
200 | non_blocking (bool): If True and the source is in pinned memory, the copy
201 | will be asynchronous with respect to the host. Otherwise, the argument
202 | has no effect. Default: False.
203 | '''
204 |
205 | if isinstance(device, str):
206 | device = torch.device(device)
207 |
208 | # If on the save device, directly retrun self
209 | if self.device == device:
210 | return self
211 |
212 | # Construct a new Points on the specified device
213 | points = Points(torch.zeros(1, 3, device=device))
214 | points.batch_npt = self.batch_npt
215 | points.points = self.points.to(device, non_blocking=non_blocking)
216 | if self.normals is not None:
217 | points.normals = self.normals.to(device, non_blocking=non_blocking)
218 | if self.features is not None:
219 | points.features = self.features.to(device, non_blocking=non_blocking)
220 | if self.labels is not None:
221 | points.labels = self.labels.to(device, non_blocking=non_blocking)
222 | if self.batch_id is not None:
223 | points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
224 | return points
225 |
226 | def cuda(self, non_blocking: bool = False):
227 | r''' Moves the Points to the GPU. '''
228 |
229 | return self.to('cuda', non_blocking)
230 |
231 | def cpu(self):
232 | r''' Moves the Points to the CPU. '''
233 |
234 | return self.to('cpu')
235 |
236 | def save(self, filename: str, info: str = 'PNFL'):
237 | r''' Save the Points into npz or xyz files.
238 |
239 | Args:
240 | filename (str): The output filename.
241 | info (str): The infomation for saving: 'P' -> 'points', 'N' -> 'normals',
242 | 'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'.
243 | '''
244 |
245 | mapping = {
246 | 'P': ('points', self.points), 'N': ('normals', self.normals),
247 | 'F': ('features', self.features), 'L': ('labels', self.labels),
248 | 'B': ('batch_id', self.batch_id), }
249 |
250 | names, outs = [], []
251 | for key in info.upper():
252 | name, out = mapping[key]
253 | if out is not None:
254 | names.append(name)
255 | if out.dim() == 1:
256 | out = out.unsqueeze(1)
257 | outs.append(out.cpu().numpy())
258 |
259 | if filename.endswith('npz'):
260 | out_dict = dict(zip(names, outs))
261 | np.savez(filename, **out_dict)
262 | elif filename.endswith('xyz'):
263 | out_array = np.concatenate(outs, axis=1)
264 | np.savetxt(filename, out_array, fmt='%.6f')
265 | else:
266 | raise ValueError
267 |
268 |
269 | def merge_points(points: List['Points'], update_batch_info: bool = True):
270 | r''' Merges a list of points into one batch.
271 |
272 | Args:
273 | points (List[Octree]): A list of points to merge. The batch size of each
274 | points in the list is assumed to be 1, and the :obj:`batch_size`,
275 | :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
276 | '''
277 |
278 | out = Points(torch.zeros(1, 3))
279 | out.points = torch.cat([p.points for p in points], dim=0)
280 | if points[0].normals is not None:
281 | out.normals = torch.cat([p.normals for p in points], dim=0)
282 | if points[0].features is not None:
283 | out.features = torch.cat([p.features for p in points], dim=0)
284 | if points[0].labels is not None:
285 | out.labels = torch.cat([p.labels for p in points], dim=0)
286 | out.device = points[0].device
287 |
288 | if update_batch_info:
289 | out.batch_size = len(points)
290 | out.batch_npt = torch.Tensor([p.points.shape[0] for p in points])
291 | out.batch_id = torch.cat([p.points.new_full((p.points.shape[0], 1), i)
292 | for i, p in enumerate(points)], dim=0)
293 | return out
294 |
--------------------------------------------------------------------------------
/utils/ocnn/octree/shuffled_key.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import Optional, Union
10 |
11 |
12 | class KeyLUT:
13 |
14 | def __init__(self):
15 | r256 = torch.arange(256, dtype=torch.int64)
16 | r512 = torch.arange(512, dtype=torch.int64)
17 | zero = torch.zeros(256, dtype=torch.int64)
18 | device = torch.device('cpu')
19 |
20 | self._encode = {device: (self.xyz2key(r256, zero, zero, 8),
21 | self.xyz2key(zero, r256, zero, 8),
22 | self.xyz2key(zero, zero, r256, 8))}
23 | self._decode = {device: self.key2xyz(r512, 9)}
24 |
25 | def encode_lut(self, device=torch.device('cpu')):
26 | if device not in self._encode:
27 | cpu = torch.device('cpu')
28 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
29 | return self._encode[device]
30 |
31 | def decode_lut(self, device=torch.device('cpu')):
32 | if device not in self._decode:
33 | cpu = torch.device('cpu')
34 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
35 | return self._decode[device]
36 |
37 | def xyz2key(self, x, y, z, depth):
38 | key = torch.zeros_like(x)
39 | for i in range(depth):
40 | mask = 1 << i
41 | key = (key | ((x & mask) << (2 * i + 2)) |
42 | ((y & mask) << (2 * i + 1)) |
43 | ((z & mask) << (2 * i + 0)))
44 | return key
45 |
46 | def key2xyz(self, key, depth):
47 | x = torch.zeros_like(key)
48 | y = torch.zeros_like(key)
49 | z = torch.zeros_like(key)
50 | for i in range(depth):
51 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
52 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
53 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
54 | return x, y, z
55 |
56 |
57 | _key_lut = KeyLUT()
58 |
59 |
60 | def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor,
61 | b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16):
62 | r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
63 | based on pre-computed look up tables. The speed of this function is much
64 | faster than the method based on for-loop.
65 |
66 | Args:
67 | x (torch.Tensor): The x coordinate.
68 | y (torch.Tensor): The y coordinate.
69 | z (torch.Tensor): The z coordinate.
70 | b (torch.Tensor or int): The batch index of the coordinates, and should be
71 | smaller than 1024. If :attr:`b` is :obj:`torch.Tensor`, the size of
72 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
73 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
74 | '''
75 |
76 | EX, EY, EZ = _key_lut.encode_lut(x.device)
77 | x, y, z = x.long(), y.long(), z.long()
78 |
79 | mask = 255 if depth > 8 else (1 << depth) - 1
80 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
81 | if depth > 8:
82 | mask = (1 << (depth-8)) - 1
83 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
84 | key = key16 << 24 | key
85 |
86 | if b is not None:
87 | b = b.long()
88 | key = b << 48 | key
89 |
90 | return key
91 |
92 |
93 | def key2xyz(key: torch.Tensor, depth: int = 16):
94 | r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
95 | and the batch index based on pre-computed look up tables.
96 |
97 | Args:
98 | key (torch.Tensor): The shuffled key.
99 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
100 | '''
101 |
102 | DX, DY, DZ = _key_lut.decode_lut(key.device)
103 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
104 |
105 | b = key >> 48
106 | key = key & ((1 << 48) - 1)
107 |
108 | n = (depth + 2) // 3
109 | for i in range(n):
110 | k = key >> (i * 9) & 511
111 | x = x | (DX[k] << (i * 3))
112 | y = y | (DY[k] << (i * 3))
113 | z = z | (DZ[k] << (i * 3))
114 |
115 | return x, y, z, b
116 |
--------------------------------------------------------------------------------
/utils/ocnn/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import math
9 | import torch
10 | from typing import Optional
11 |
12 | __all__ = ['trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_',
13 | 'resize_with_last_val', 'list2str']
14 | classes = __all__
15 |
16 |
17 | def trunc_div(input, other):
18 | r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
19 | division towards zero and is equivalent to C-style integer division.
20 | '''
21 |
22 | version = torch.__version__.split('.')
23 | larger_than_170 = int(version[0]) > 0 and int(version[1]) > 7
24 |
25 | if larger_than_170:
26 | return torch.div(input, other, rounding_mode='trunc')
27 | else:
28 | return torch.floor_divide(input, other)
29 |
30 |
31 | def meshgrid(*tensors, indexing: Optional[str] = None):
32 | r''' Wraps :func:`torch.meshgrid` for compatibility.
33 | '''
34 |
35 | version = torch.__version__.split('.')
36 | larger_than_190 = int(version[0]) > 0 and int(version[1]) > 9
37 |
38 | if larger_than_190:
39 | return torch.meshgrid(*tensors, indexing=indexing)
40 | else:
41 | return torch.meshgrid(*tensors)
42 |
43 |
44 | def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
45 | r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.
46 |
47 | Args:
48 | data (torch.Tensor): The input data.
49 | dim (int): The dimension to do the operation over.
50 | exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`;
51 | if true, returns the cumulative sum exclusively. Note that if ture,
52 | the shape of output tensor is larger by 1 than :attr:`data` in the
53 | dimension where the computation occurs.
54 | '''
55 |
56 | out = torch.cumsum(data, dim)
57 |
58 | if exclusive:
59 | size = list(data.size())
60 | size[dim] = 1
61 | zeros = out.new_zeros(size)
62 | out = torch.cat([zeros, out], dim)
63 | return out
64 |
65 |
66 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
67 | r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
68 | library `pytorch_scatter`.
69 | '''
70 |
71 | if dim < 0:
72 | dim = other.dim() + dim
73 |
74 | if src.dim() == 1:
75 | for _ in range(0, dim):
76 | src = src.unsqueeze(0)
77 | for _ in range(src.dim(), other.dim()):
78 | src = src.unsqueeze(-1)
79 |
80 | src = src.expand_as(other)
81 | return src
82 |
83 |
84 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
85 | out: Optional[torch.Tensor] = None,
86 | dim_size: Optional[int] = None,) -> torch.Tensor:
87 | r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
88 | indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
89 | This is just a wrapper of :func:`torch.scatter` in a boardcasting fashion.
90 |
91 | Args:
92 | src (torch.Tensor): The source tensor.
93 | index (torch.Tensor): The indices of elements to scatter.
94 | dim (torch.Tensor): The axis along which to index, (default: :obj:`-1`).
95 | out (torch.Tensor or None): The destination tensor.
96 | dim_size (int or None): If :attr:`out` is not given, automatically create
97 | output with size :attr:`dim_size` at dimension :attr:`dim`. If
98 | :attr:`dim_size` is not given, a minimal sized output tensor according
99 | to :obj:`index.max() + 1` is returned.
100 | '''
101 |
102 | index = broadcast(index, src, dim)
103 |
104 | if out is None:
105 | size = list(src.size())
106 | if dim_size is not None:
107 | size[dim] = dim_size
108 | elif index.numel() == 0:
109 | size[dim] = 0
110 | else:
111 | size[dim] = int(index.max()) + 1
112 | out = torch.zeros(size, dtype=src.dtype, device=src.device)
113 |
114 | return out.scatter_add_(dim, index, src)
115 |
116 |
117 | def xavier_uniform_(weights: torch.Tensor):
118 | r''' Initialize convolution weights with the same method as
119 | :obj:`torch.nn.init.xavier_uniform_`.
120 |
121 | :obj:`torch.nn.init.xavier_uniform_` initialize a tensor with shape
122 | :obj:`(out_c, in_c, kdim)`. It can not be used in :class:`ocnn.nn.OctreeConv`
123 | since the the shape of :attr:`OctreeConv.weights` is :obj:`(kdim, in_c,
124 | out_c)`.
125 | '''
126 |
127 | shape = weights.shape # (kernel_dim, in_conv, out_conv)
128 | fan_in = shape[0] * shape[1]
129 | fan_out = shape[0] * shape[2]
130 | std = math.sqrt(2.0 / float(fan_in + fan_out))
131 | a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
132 |
133 | torch.nn.init.uniform_(weights, -a, a)
134 |
135 |
136 | def resize_with_last_val(list_in: list, num: int = 3):
137 | r''' Resizes the number of elements of :attr:`list_in` to :attr:`num` with
138 | the last element of :attr:`list_in` if its number of elements is smaller
139 | than :attr:`num`.
140 | '''
141 |
142 | assert (type(list_in) is list and len(list_in) < num + 1)
143 | for i in range(len(list_in), num):
144 | list_in.append(list_in[-1])
145 | return list_in
146 |
147 |
148 | def list2str(list_in: list):
149 | r''' Returns a string representation of :attr:`list_in`
150 | '''
151 |
152 | out = [str(x) for x in list_in]
153 | return ''.join(out)
154 |
--------------------------------------------------------------------------------
/utils/thsolver/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from . import config
9 | from .config import get_config, parse_args
10 |
11 | from . import solver
12 | from .solver import Solver
13 |
14 | from . import dataset
15 | from .dataset import Dataset
16 |
17 |
18 | __all__ = [
19 | 'config', 'get_config', 'parse_args',
20 | 'solver', 'Solver',
21 | 'dataset', 'Dataset'
22 | ]
23 |
--------------------------------------------------------------------------------
/utils/thsolver/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | # autopep8: off
9 | import os
10 | import sys
11 | import shutil
12 | import argparse
13 | from datetime import datetime
14 | from yacs.config import CfgNode as CN
15 |
16 | _C = CN()
17 |
18 | # SOLVER related parameters
19 | _C.SOLVER = CN()
20 | _C.SOLVER.alias = '' # The experiment alias
21 | _C.SOLVER.gpu = (0,) # The gpu ids
22 | _C.SOLVER.run = 'train' # Choose from train or test
23 |
24 | _C.SOLVER.logdir = 'logs' # Directory where to write event logs
25 | _C.SOLVER.ckpt = '' # Restore weights from checkpoint file
26 | _C.SOLVER.ckpt_num = 10 # The number of checkpoint kept
27 |
28 | _C.SOLVER.type = 'sgd' # Choose from sgd or adam
29 | _C.SOLVER.weight_decay = 0.0005 # The weight decay on model weights
30 | _C.SOLVER.clip_grad = -1.0 # Clip gradient norm (-1: disable)
31 | _C.SOLVER.max_epoch = 300 # Maximum training epoch
32 | _C.SOLVER.warmup_epoch = 20 # The warmup epoch number
33 | _C.SOLVER.warmup_init = 0.001 # The initial ratio of the warmup
34 | _C.SOLVER.eval_epoch = 1 # Maximum evaluating epoch
35 | _C.SOLVER.eval_step = -1 # Maximum evaluating steps
36 | _C.SOLVER.test_every_epoch = 10 # Test model every n training epochs
37 | _C.SOLVER.log_per_iter = -1 # Output log every k training iteration
38 |
39 | _C.SOLVER.lr_type = 'step' # Learning rate type: step or cos
40 | _C.SOLVER.lr = 0.1 # Initial learning rate
41 | _C.SOLVER.lr_min = 0.0001 # The minimum learning rate
42 | _C.SOLVER.gamma = 0.1 # Learning rate step-wise decay
43 | _C.SOLVER.milestones = (120,180,) # Learning rate milestones
44 | _C.SOLVER.lr_power = 0.9 # Used in poly learning rate
45 |
46 | _C.SOLVER.dist_url = 'tcp://localhost:10001'
47 | _C.SOLVER.progress_bar = True # Enable the progress_bar or not
48 | _C.SOLVER.rand_seed = -1 # Fix the random seed if larger than 0
49 | _C.SOLVER.empty_cache = True # Empty cuda cache periodically
50 |
51 | # DATA related parameters
52 | _C.DATA = CN()
53 | _C.DATA.train = CN()
54 | _C.DATA.train.name = '' # The name of the dataset
55 | _C.DATA.train.disable = False # Disable this dataset or not
56 |
57 | # For octree building
58 | _C.DATA.train.depth = 5 # The octree depth
59 | _C.DATA.train.full_depth = 2 # The full depth
60 | _C.DATA.train.adaptive = False # Build the adaptive octree
61 |
62 | # For transformation
63 | _C.DATA.train.orient_normal = '' # Used to re-orient normal directions
64 |
65 | # For data augmentation
66 | _C.DATA.train.distort = False # Whether to apply data augmentation
67 | _C.DATA.train.scale = 0.0 # Scale the points
68 | _C.DATA.train.uniform = False # Generate uniform scales
69 | _C.DATA.train.jitter = 0.0 # Jitter the points
70 | _C.DATA.train.interval = (1, 1, 1) # Use interval&angle to generate random angle
71 | _C.DATA.train.angle = (180, 180, 180)
72 |
73 | # For data loading
74 | _C.DATA.train.location = '' # The data location
75 | _C.DATA.train.filelist = '' # The data filelist
76 | _C.DATA.train.batch_size = 32 # Training data batch size
77 | _C.DATA.train.take = -1 # Number of samples used for training
78 | _C.DATA.train.num_workers = 4 # Number of workers to load the data
79 | _C.DATA.train.shuffle = False # Shuffle the input data
80 | _C.DATA.train.in_memory = False # Load the training data into memory
81 |
82 |
83 | _C.DATA.test = _C.DATA.train.clone()
84 | _C.DATA.test.num_workers = 2
85 |
86 | # MODEL related parameters
87 | _C.MODEL = CN()
88 | _C.MODEL.name = '' # The name of the model
89 | _C.MODEL.feature = 'ND' # The input features
90 | _C.MODEL.channel = 3 # The input feature channel
91 | _C.MODEL.nout = 40 # The output feature channel
92 | _C.MODEL.nempty = False # Perform Octree Conv on non-empty octree nodes
93 |
94 | _C.MODEL.stages = 3
95 | _C.MODEL.resblock_num = 3 # The resblock number
96 | _C.MODEL.resblock_type = 'bottleneck'# Choose from 'bottleneck' and 'basic
97 | _C.MODEL.bottleneck = 4 # The bottleneck factor of one resblock
98 |
99 | _C.MODEL.upsample = 'nearest' # The method used for upsampling
100 | _C.MODEL.interp = 'linear' # The interplation method: linear or nearest
101 |
102 | _C.MODEL.sync_bn = False # Use sync_bn when training the network
103 | _C.MODEL.use_checkpoint = False # Use checkpoint to save memory
104 | _C.MODEL.find_unused_parameters = False # Used in DistributedDataParallel
105 |
106 | _C.MODEL.num_edge_types = 7
107 | _C.MODEL.conv_type = "SAGE"
108 | _C.MODEL.include_distance = False
109 | _C.MODEL.normal_aware_pooling = False
110 | _C.MODEL.test_mesh = "TBD"
111 | _C.MODEL.get_test_stat = False
112 |
113 |
114 | # loss related parameters
115 | _C.LOSS = CN()
116 | _C.LOSS.name = '' # The name of the loss
117 | _C.LOSS.num_class = 40 # The class number for the cross-entropy loss
118 | _C.LOSS.weights = (1.0, 1.0) # The weight factors for different losses
119 | _C.LOSS.label_smoothing = 0.0 # The factor of label smoothing
120 |
121 |
122 | # backup the commands
123 | _C.SYS = CN()
124 | _C.SYS.cmds = '' # Used to backup the commands
125 |
126 | FLAGS = _C
127 |
128 |
129 | def _update_config(FLAGS, args):
130 | FLAGS.defrost()
131 | if args.config:
132 | FLAGS.merge_from_file(args.config)
133 | if args.opts:
134 | FLAGS.merge_from_list(args.opts)
135 | FLAGS.SYS.cmds = ' '.join(sys.argv)
136 |
137 | # update logdir
138 | alias = FLAGS.SOLVER.alias.lower()
139 | if 'time' in alias: # 'time' is a special keyword
140 | alias = alias.replace('time', datetime.now().strftime('%m%d%H%M')) #%S
141 | if alias != '':
142 | FLAGS.SOLVER.logdir += '_' + alias
143 | FLAGS.freeze()
144 |
145 |
146 | def _backup_config(FLAGS, args):
147 | logdir = FLAGS.SOLVER.logdir
148 | if not os.path.exists(logdir):
149 | os.makedirs(logdir)
150 | # copy the file to logdir
151 | if args.config:
152 | shutil.copy2(args.config, logdir)
153 | # dump all configs
154 | filename = os.path.join(logdir, 'all_configs.yaml')
155 | with open(filename, 'w') as fid:
156 | fid.write(FLAGS.dump())
157 |
158 |
159 | def _set_env_var(FLAGS):
160 | gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu])
161 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
162 |
163 |
164 | def get_config():
165 | return FLAGS
166 |
167 | def parse_args(backup=True):
168 | parser = argparse.ArgumentParser(description='The configs')
169 | parser.add_argument('--config', type=str,
170 | help='experiment configure file name')
171 | parser.add_argument('opts', nargs=argparse.REMAINDER,
172 | help="Modify config options using the command-line")
173 |
174 | args = parser.parse_args()
175 | _update_config(FLAGS, args)
176 | if backup:
177 | _backup_config(FLAGS, args)
178 | # _set_env_var(FLAGS)
179 | return FLAGS
180 |
181 |
182 | if __name__ == '__main__':
183 | flags = parse_args(backup=False)
184 | print(flags)
185 |
--------------------------------------------------------------------------------
/utils/thsolver/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import torch.utils.data
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 |
15 | def read_file(filename):
16 | points = np.fromfile(filename, dtype=np.uint8)
17 | return torch.from_numpy(points) # convert it to torch.tensor
18 |
19 |
20 | class Dataset(torch.utils.data.Dataset):
21 |
22 | def __init__(self, root, filelist, transform, read_file=read_file,
23 | in_memory=False, take: int = -1):
24 | super(Dataset, self).__init__()
25 | self.root = root
26 | self.filelist = filelist
27 | self.transform = transform
28 | self.in_memory = in_memory
29 | self.read_file = read_file
30 | self.take = take
31 |
32 | self.filenames, self.labels = self.load_filenames()
33 | if self.in_memory:
34 | print('Load files into memory from ' + self.filelist)
35 | self.samples = [self.read_file(os.path.join(self.root, f))
36 | for f in tqdm(self.filenames, ncols=80, leave=False)]
37 |
38 | def __len__(self):
39 | return len(self.filenames)
40 |
41 | def __getitem__(self, idx):
42 | sample = self.samples[idx] if self.in_memory else \
43 | self.read_file(os.path.join(self.root, self.filenames[idx])) # noqa
44 | output = self.transform(sample, idx) # data augmentation + build octree
45 | output['label'] = self.labels[idx]
46 | output['filename'] = self.filenames[idx]
47 | return output
48 |
49 | def load_filenames(self):
50 | filenames, labels = [], []
51 | with open(self.filelist) as fid:
52 | lines = fid.readlines()
53 | for line in lines:
54 | tokens = line.split()
55 | filename = tokens[0]
56 | label = tokens[1] if len(tokens) == 2 else 0
57 | filenames.append(filename)
58 | labels.append(int(label))
59 |
60 | num = len(filenames)
61 | if self.take > num or self.take < 1:
62 | self.take = num
63 |
64 | return filenames[:self.take], labels[:self.take]
65 |
66 |
67 | class DatasetGraph(torch.utils.data.Dataset):
68 | # 对 Dataset 做了一些小改动:原先dataset只能读取单一类型的文件,无法把图和GT一起读入
69 | # 现在能了
70 | def __init__(self, root, filelist, transform, read_file=read_file,
71 | in_memory=False, take: int = -1):
72 | super(Dataset, self).__init__()
73 | self.root = root
74 | self.filelist = filelist
75 | self.transform = transform
76 | self.in_memory = in_memory
77 | self.read_file = read_file
78 | self.take = take
79 |
80 | self.filenames, self.labels = self.load_filenames()
81 | if self.in_memory:
82 | print('Load files into memory from ' + self.filelist)
83 | self.samples = [self.read_file(os.path.join(self.root, f))
84 | for f in tqdm(self.filenames, ncols=80, leave=False)]
85 |
86 | def __len__(self):
87 | return len(self.filenames)
88 |
89 | def __getitem__(self, idx):
90 | # GT
91 | gts = self.samples[idx] if self.in_memory else \
92 | self.read_file(os.path.join(self.root, self.filenames[idx]))
93 | vertices = torch.from_numpy(gts['vertices'])
94 | normals = torch.from_numpy(gts['normals'])
95 | dist = torch.from_numpy(gts['dist'])
96 |
97 | rnd_idx = torch.randint(low=0, high=dist.shape[0], size=(100000,))
98 | dist = dist[rnd_idx]
99 |
100 | points = Points(points=vertices, normals=normals, )
101 | octree = self.points2octree(points)
102 | return {'points': points, 'octree': octree, 'dist': dist}
103 |
104 |
105 | # graph
106 | graph = self.samples[idx] if self.in_memory else \
107 | self.read_file(os.path.join(self.root, self.filenames[idx]))
108 |
109 | output['label'] = self.labels[idx]
110 | output['filename'] = self.filenames[idx]
111 | return output
112 |
113 | def load_filenames(self):
114 | filenames, labels = [], []
115 | with open(self.filelist) as fid:
116 | lines = fid.readlines()
117 | for line in lines:
118 | tokens = line.split()
119 | filename = tokens[0]
120 | label = tokens[1] if len(tokens) == 2 else 0
121 | filenames.append(filename)
122 | labels.append(int(label))
123 |
124 | num = len(filenames)
125 | if self.take > num or self.take < 1:
126 | self.take = num
127 |
128 | return filenames[:self.take], labels[:self.take]
129 |
--------------------------------------------------------------------------------
/utils/thsolver/default_settings.py:
--------------------------------------------------------------------------------
1 | """
2 | build a dict containing all settings which can be accessed globally.
3 |
4 | Usage:
5 | all settings defined in gnndist.yaml can be accessed anywhere like this:
6 |
7 | from utils.thsolver import default_settings
8 | default_settings.get_global_value("max_epoch")
9 |
10 | """
11 |
12 | _global_dict = None
13 |
14 | def _init(FLAGS=None):
15 | global _global_dict
16 | _global_dict = {
17 | # "normal_aware_pooling": True,
18 | # "num_edge_types": 7,
19 | }
20 |
21 | def set_global_value(key, value):
22 | _global_dict[key] = value
23 |
24 | def set_global_values(FLAGS):
25 | """set global values from the FLAGS in thsolver.solver
26 | """
27 | for each in FLAGS:
28 | for it in FLAGS[each]:
29 | _global_dict[it] = FLAGS[each][it]
30 |
31 | def get_global_value(key):
32 | return _global_dict[key]
--------------------------------------------------------------------------------
/utils/thsolver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import math
9 | from bisect import bisect_right
10 | import torch.optim.lr_scheduler as LR
11 |
12 |
13 | def multi_step(optimizer, flags):
14 | return LR.MultiStepLR(optimizer, flags.milestones, flags.gamma)
15 |
16 |
17 | def cos(optimizer, flags):
18 | return LR.CosineAnnealingLR(optimizer, flags.max_epoch, eta_min=flags.lr_min)
19 |
20 |
21 | def poly(optimizer, flags):
22 | lr_lambda = lambda epoch: (1 - epoch / flags.max_epoch) ** flags.lr_power
23 | return LR.LambdaLR(optimizer, lr_lambda)
24 |
25 |
26 | def constant(optimizer, flags):
27 | lr_lambda = lambda epoch: 1
28 | return LR.LambdaLR(optimizer, lr_lambda)
29 |
30 |
31 | def cos_warmup(optimizer, flags):
32 | def lr_lambda(epoch):
33 | warmup = flags.warmup_epoch
34 | warmup_init = flags.warmup_init
35 | if epoch <= warmup:
36 | return (1 - warmup_init) * epoch / warmup + warmup_init
37 | else:
38 | lr_min = flags.lr_min
39 | ratio = (epoch - warmup) / (flags.max_epoch - warmup)
40 | return lr_min + 0.5 * (1.0 - lr_min) * (1 + math.cos(math.pi * ratio))
41 | return LR.LambdaLR(optimizer, lr_lambda)
42 |
43 |
44 | def poly_warmup(optimizer, flags):
45 | def lr_lambda(epoch):
46 | warmup = flags.warmup_epoch
47 | warmup_init = flags.warmup_init
48 | if epoch <= warmup:
49 | return (1 - warmup_init) * epoch / warmup + warmup_init
50 | else:
51 | ratio = (epoch - warmup) / (flags.max_epoch - warmup)
52 | return (1 - ratio) ** flags.lr_power
53 | return LR.LambdaLR(optimizer, lr_lambda)
54 |
55 |
56 | def step_warmup(optimizer, flags):
57 | def lr_lambda(epoch):
58 | warmup = flags.warmup_epoch
59 | warmup_init = flags.warmup_init
60 | if epoch <= warmup:
61 | return (1 - warmup_init) * epoch / warmup + warmup_init
62 | else:
63 | milestones = sorted(flags.milestones)
64 | return flags.gamma ** bisect_right(milestones, epoch)
65 | return LR.LambdaLR(optimizer, lr_lambda)
66 |
67 |
68 | def get_lr_scheduler(optimizer, flags):
69 | lr_dict = {'cos': cos, 'step': multi_step, 'cos': cos, 'poly': poly,
70 | 'constant': constant, 'cos_warmup': cos_warmup,
71 | 'poly_warmup': poly_warmup, 'step_warmup': step_warmup}
72 | lr_func = lr_dict[flags.lr_type]
73 | return lr_func(optimizer, flags)
74 |
--------------------------------------------------------------------------------
/utils/thsolver/sampler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from torch.utils.data import Sampler, DistributedSampler, Dataset
10 |
11 |
12 | class InfSampler(Sampler):
13 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
14 | self.dataset = dataset
15 | self.shuffle = shuffle
16 | self.reset_sampler()
17 |
18 | def reset_sampler(self):
19 | num = len(self.dataset)
20 | indices = torch.randperm(num) if self.shuffle else torch.arange(num)
21 | self.indices = indices.tolist()
22 | self.iter_num = 0
23 |
24 | def __iter__(self):
25 | return self
26 |
27 | def __next__(self):
28 | value = self.indices[self.iter_num]
29 | self.iter_num = self.iter_num + 1
30 |
31 | if self.iter_num >= len(self.indices):
32 | self.reset_sampler()
33 | return value
34 |
35 | def __len__(self):
36 | return len(self.dataset)
37 |
38 |
39 | class DistributedInfSampler(DistributedSampler):
40 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
41 | super().__init__(dataset, shuffle=shuffle)
42 | self.reset_sampler()
43 |
44 | def reset_sampler(self):
45 | self.indices = list(super().__iter__())
46 | self.iter_num = 0
47 |
48 | def __iter__(self):
49 | return self
50 |
51 | def __next__(self):
52 | value = self.indices[self.iter_num]
53 | self.iter_num = self.iter_num + 1
54 |
55 | if self.iter_num >= len(self.indices):
56 | self.reset_sampler()
57 | return value
58 |
--------------------------------------------------------------------------------
/utils/thsolver/tracker.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import time
9 | import torch
10 | import torch.distributed
11 | from datetime import datetime
12 | from tqdm import tqdm
13 | from typing import Dict
14 |
15 |
16 | class AverageTracker:
17 |
18 | def __init__(self):
19 | self.value = None
20 | self.num = 0.0
21 | self.max_len = 76
22 | self.start_time = time.time()
23 |
24 | def update(self, value: Dict[str, torch.Tensor]):
25 | if not value:
26 | return # empty input, return
27 |
28 | value = {key: val.detach() for key, val in value.items()}
29 | if self.value is None:
30 | self.value = value
31 | else:
32 | for key, val in value.items():
33 | self.value[key] += val
34 | self.num += 1
35 |
36 | def average(self):
37 | return {key: val.item() / self.num for key, val in self.value.items()}
38 |
39 | @torch.no_grad()
40 | def average_all_gather(self):
41 | for key, tensor in self.value.items():
42 | if not tensor.is_cuda: continue
43 | tensors_gather = [torch.ones_like(tensor)
44 | for _ in range(torch.distributed.get_world_size())]
45 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
46 | tensors = torch.stack(tensors_gather, dim=0)
47 | self.value[key] = torch.mean(tensors)
48 |
49 | def log(self, epoch, summary_writer=None, log_file=None, msg_tag='->',
50 | notes='', print_time=True, print_memory=False):
51 | if not self.value: return # empty, return
52 |
53 | avg = self.average()
54 | msg = 'Epoch: %d' % epoch
55 | for key, val in avg.items():
56 | msg += ', %s: %.3f' % (key, val)
57 | if summary_writer is not None:
58 | summary_writer.add_scalar(key, val, epoch)
59 |
60 | # if the log_file is provided, save the log
61 | if log_file is not None:
62 | with open(log_file, 'a') as fid:
63 | fid.write(msg + '\n')
64 |
65 | # memory
66 | memory = ''
67 | if print_memory and torch.cuda.is_available():
68 | size = torch.cuda.memory_reserved()
69 | # size = torch.cuda.memory_allocated()
70 | memory = ', memory: {:.3f}GB'.format(size / 2**30)
71 |
72 | # time
73 | time_str = ''
74 | if print_time:
75 | curr_time = ', time: ' + datetime.now().strftime("%Y/%m/%d %H:%M:%S")
76 | duration = ', duration: {:.2f}s'.format(time.time() - self.start_time)
77 | time_str = curr_time + duration
78 |
79 | # other notes
80 | if notes:
81 | notes = ', ' + notes
82 |
83 | # concatenate all messages
84 | msg += memory + time_str + notes
85 |
86 | # split the msg for better display
87 | chunks = [msg[i:i+self.max_len] for i in range(0, len(msg), self.max_len)]
88 | msg = (msg_tag + ' ') + ('\n' + len(msg_tag) * ' ' + ' ').join(chunks)
89 | tqdm.write(msg)
90 |
--------------------------------------------------------------------------------