├── .gitignore
├── README.md
├── config
│   ├── rio.yaml
│   └── scannet.yaml
├── docs
│   ├── rerun_interface.png
│   └── system.001.png
├── requirements.txt
├── setup.py
├── sgreg
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── senet.py
│   │   └── shape_encoder.py
│   ├── bert
│   │   ├── bertwarper.py
│   │   └── get_tokenizer.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataset_factory.py
│   │   ├── generate_gt_association.py
│   │   ├── prepare_semantics.py
│   │   ├── scene_graph.py
│   │   ├── scene_pair_dataset.py
│   │   └── stack_mode.py
│   ├── extensions
│   │   ├── README.md
│   │   ├── common
│   │   │   └── torch_helper.h
│   │   ├── cpu
│   │   │   ├── grid_subsampling
│   │   │   │   ├── grid_subsampling.cpp
│   │   │   │   ├── grid_subsampling.h
│   │   │   │   ├── grid_subsampling_cpu.cpp
│   │   │   │   └── grid_subsampling_cpu.h
│   │   │   └── radius_neighbors
│   │   │       ├── radius_neighbors.cpp
│   │   │       ├── radius_neighbors.h
│   │   │       ├── radius_neighbors_cpu.cpp
│   │   │       └── radius_neighbors_cpu.h
│   │   ├── extra
│   │   │   ├── cloud
│   │   │   │   ├── cloud.cpp
│   │   │   │   └── cloud.h
│   │   │   └── nanoflann
│   │   │       └── nanoflann.hpp
│   │   └── pybind.cpp
│   ├── gnn
│   │   ├── __init__.py
│   │   ├── gnn.py
│   │   ├── nodes_init_layer.py
│   │   ├── spatial_attention.py
│   │   └── triplet_gnn.py
│   ├── kpconv
│   │   ├── __init__.py
│   │   ├── dispositions
│   │   │   └── k_015_center_3D.ply
│   │   ├── functional.py
│   │   ├── kernel_points.py
│   │   ├── kpconv.py
│   │   └── modules.py
│   ├── loss
│   │   ├── eval.py
│   │   └── loss.py
│   ├── match
│   │   ├── __init__.py
│   │   ├── learnable_sinkhorn.py
│   │   └── match.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── grid_subsample.py
│   │   ├── index_select.py
│   │   ├── instance_partition.py
│   │   ├── pairwise_distance.py
│   │   ├── radius_search.py
│   │   └── transformation.py
│   ├── registration
│   │   ├── __init__.py
│   │   ├── hybrid_reg.py
│   │   ├── local_global_registration.py
│   │   ├── metrics.py
│   │   ├── offline_registration.py
│   │   └── procrustes.py
│   ├── sg_reg.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── io.py
│   │   ├── tictoc.py
│   │   ├── torch.py
│   │   ├── utils.py
│   │   └── viz_tools.py
│   ├── val.py
│   └── visualize.py
└── tutorials
    ├── explicit_sg.png
    └── rag_data.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | data
3 | output
4 | checkpoints
5 | sgreg/__pycache__
6 | sgreg/*/__pycache__
7 | .vscode
8 |
9 | *.egg-info
10 | *.so
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | SG-Reg: Generalizable and Efficient Scene Graph Registration
4 | Accepted by IEEE T-RO
5 |
6 | Chuhao Liu1,
7 | Zhijian Qiao1,
8 | Jieqi Shi2,*,
9 | Ke Wang3,
10 | Peize Liu 1
11 | and Shaojie Shen1
12 |
13 |
14 | 1HKUST Aerial Robotics Group
15 | 2 Nanjing University
16 | 3Chang'an University
17 |
18 |
19 | *Corresponding Author
20 |
21 |
22 |
23 | [](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=8860)
24 | [](https://arxiv.org/abs/2504.14440)
25 | [](https://youtu.be/s3P1FvbQGhs)
26 | [](https://www.bilibili.com/video/BV1ymLWzaEMo/)
27 | [](https://huggingface.co/glennliu/sgnet)
28 |
29 |
30 |
31 |
32 |
33 | ### News
34 | * [21 Apr 2025] Published the initial version of the code.
35 | * [19 Apr 2025] Our paper is accepted by [IEEE T-RO](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=8860) as a regular paper.
36 | * [8 Oct 2024] Paper submitted to IEEE T-RO.
37 |
38 | In this work, we **learn to register two semantic scene graphs**, an essential capability when an autonomous agent needs to register its map against a remote agent or against a prior map. To achieve generalizable registration in the real world, we design a scene graph network that encodes multiple modalities of each semantic node: open-set semantic features, local topology with spatial awareness, and shape features. SG-Reg represents a dense indoor scene with **coarse node features** and **dense point features**. In multi-agent SLAM systems, this representation supports both coarse-to-fine localization and bandwidth-efficient communication.
39 | We generate semantic scene graphs using [vision foundation models](https://github.com/IDEA-Research/Grounded-Segment-Anything) and the semantic mapping module [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion). This eliminates the need for ground-truth semantic annotations, enabling **fully self-supervised network training**.
40 | We evaluate our method on real-world RGB-D sequences: [ScanNet](https://github.com/ScanNet/ScanNet), [3RScan](https://github.com/WaldJohannaU/3RScan), and self-collected data captured with a [Realsense i-435](https://www.intelrealsense.com/lidar-camera-l515/).
41 | ## 1. Install
42 | Create virtual environment,
43 | ```bash
44 | conda create -n sgreg python=3.9
45 | ```
46 | Install PyTorch 2.1.2 and other dependencies.
47 | ```bash
48 | conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
49 | ```
50 | ```bash
51 | pip install -r requirements.txt
52 | python setup.py build develop
53 | ```
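After the build finishes, a quick sanity check (a minimal sketch; exact version strings depend on your machine) confirms that PyTorch and CUDA are visible from the new environment:
```python
# Minimal environment check (assumes the sgreg conda env above is active).
import torch

print(torch.__version__)          # expect 2.1.2
print(torch.cuda.is_available())  # expect True with the CUDA 11.8 build
```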
54 |
55 | ## 2. Download Dataset
56 | Download the *3RScan (RIO)* data from the [Nutstore (坚果云) link](https://www.jianguoyun.com/p/DVNIaZYQmcSyDRjX8PQFIAA). It contains 50 pairs of scene graphs. Under ```RIO_DATAROOT```, the data is organized in the following structure.
57 | ```
58 | |--val
59 | |--scenexxxx_00a % each individual scene graph
60 | |-- ....
61 | |--splits
62 | |-- val.txt
63 | |--gt
64 | |-- SRCSCENE-REFSCENE.txt % T_ref_src
65 | |--matches
66 | |-- SRCSCENE-REFSCENE.pth % ground-truth node matches
67 | |--output
68 | |--CHECKPOINT_NAME % default: sgnet_scannet_0080
69 | |--SRCSCENE-REFSCENE % results of scene pair
70 | ```
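As a reference for how this layout can be traversed, here is a minimal sketch. It assumes ```splits/val.txt``` lists one ```SRCSCENE REFSCENE``` pair per line and that each ```gt/*.txt``` stores a 4x4 homogeneous matrix ```T_ref_src```; the variable names below are illustrative only.
```python
import os
import numpy as np
import torch

dataroot = '/data2/RioGraph'  # set this to your RIO_DATAROOT

# splits/val.txt: one "SRCSCENE REFSCENE" pair per line (assumed format)
with open(os.path.join(dataroot, 'splits', 'val.txt')) as f:
    pairs = [line.strip().split() for line in f if line.strip()]

for src_scene, ref_scene in pairs:
    # gt/SRCSCENE-REFSCENE.txt: ground-truth transformation T_ref_src (assumed 4x4)
    T_ref_src = np.loadtxt(os.path.join(dataroot, 'gt', f'{src_scene}-{ref_scene}.txt'))

    # matches/SRCSCENE-REFSCENE.pth: ground-truth node matches
    gt_matches = torch.load(os.path.join(dataroot, 'matches', f'{src_scene}-{ref_scene}.pth'))
    print(src_scene, ref_scene, T_ref_src.shape)
```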
71 |
72 |
73 | We also provide another 50 pairs of *ScanNet* scenes. Please download the ScanNet data from this [Nutstore (坚果云) link](https://www.jianguoyun.com/p/DSJqTN8QmcSyDRjZ8PQFIAA). They are organized in the same data structure as the 3RScan data.
74 |
75 | *Note: We did not use any ground-truth semantic annotations from [3RScan](https://github.com/WaldJohannaU/3RScan) or [ScanNet](https://github.com/ScanNet/ScanNet). The downloaded scene graphs are reconstructed using [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion). You can also download the original RGB-D sequences and build your own scene graphs with FM-Fusion. If you want to try, the ScanNet sequences should be easier to start with.*
76 |
77 | ## 3. Inference 3RScan Scenes
78 | Open [config/rio.yaml](config/rio.yaml) and set ```dataset/dataroot``` to the ```RIO_DATAROOT``` directory on your machine. Then, run the inference program,
79 | ```bash
80 | python sgreg/val.py --cfg_file config/rio.yaml
81 | ```
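For reference, the configuration is loaded with OmegaConf, and the KPConv search radii are derived from the voxel size. A minimal sketch, following the usage in [sgreg/dataset/dataset_factory.py](sgreg/dataset/dataset_factory.py):
```python
from omegaconf import OmegaConf

# Load the same YAML that val.py consumes.
conf = OmegaConf.load('config/rio.yaml')

# Derived backbone parameters, as computed in sgreg/dataset/dataset_factory.py:
conf.backbone.init_radius = 0.5 * conf.backbone.base_radius * conf.backbone.init_voxel_size  # 0.0625
conf.backbone.init_sigma = 0.5 * conf.backbone.base_sigma * conf.backbone.init_voxel_size    # 0.05

print(conf.dataset.dataroot)  # should print your RIO_DATAROOT
```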
82 | It runs inference on all of the downloaded 3RScan scene pairs. The registration results, including matched nodes, point correspondences, and the predicted transformation, are saved at ```RIO_DATAROOT/output/CHECKPOINT_NAME/SRCSCENE-REFSCENE```. You can visualize the registration results,
83 | ```bash
84 | python sgreg/visualize.py --dataroot $RIO_DATAROOT$ --viz_mode 1 --find_gt --viz_translation [3.0,5.0,0.0]
85 | ```
86 | It should visualize the results as below,
87 |
88 |
89 |
90 | In the left column, you can select the entities you want to visualize.
91 |
92 | If you run the program on a remote server, Rerun supports remote visualization (see [rerun connect_tcp](https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun)). Check the argument instructions in [visualize.py](sgreg/visualize.py) to customize your visualization.
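The idea is to run a Rerun viewer on your local machine and let the remote script stream to it over TCP. A minimal sketch (assuming rerun-sdk 0.22; the address is a placeholder, and the actual logging calls live in [visualize.py](sgreg/visualize.py)):
```python
import rerun as rr

rr.init('sgreg_visualization')
# Stream to a viewer already running on your local machine
# (launch it there with the `rerun` command); 9876 is the default port.
rr.connect_tcp('YOUR_LOCAL_IP:9876')
```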
93 |
94 | *[Optional]* If you want to evaluate SG-Reg on ScanNet sequences, adjust the running options as below,
95 | ```bash
96 | python sgreg/val.py --cfg_file config/scannet.yaml
97 | python sgreg/visualize.py --dataroot $SCANNET_DATAROOT$ --viz_mode 1 --augment_transform --viz_translation [3.0,5.0,0.0]
98 | ```
99 |
100 | ## 4. Evaluate on your own data
101 | We think generalization remains a key challenge in 3D semantic perception. If you are interested in this task, we encourage you to collect your own RGB-D sequences for evaluation.
102 | It requires [VINS-Mono](https://github.com/HKUST-Aerial-Robotics/VINS-Mono) to compute camera poses, [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to generate semantic labels, and [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion) to reconstruct a semantic scene graph.
103 | We will add detailed instructions later on how to build your own data.
104 |
105 | ## 5. Development Log
106 | - [x] Scene graph network code and verify its inference.
107 | - [x] Remove unnecessary dependencies.
108 | - [x] Clean the data structure.
109 | - [x] Visualize the results.
110 | - [x] Provide RIO scene graph data for download.
111 | - [x] Provide network weight for download.
112 | - [x] Publish checkpoint on Huggingface Hub and reload.
113 | - [ ] Registration back-end in a Python interface. (The version used in the paper is a C++ implementation.)
114 | - [ ] Validate the entire system on a new computer.
115 | - [x] A tutorial for running the validation.
116 |
117 | We will continue to maintain this repo. If you encounter any problems using it, feel free to open an issue. We'll try to help.
118 |
119 | ## 6. Acknowledgements
120 | We used some code from [GeoTransformer](https://github.com/qinzheng93/GeoTransformer), [SG-PGM](https://github.com/dfki-av/sg-pgm) and [LightGlue](https://github.com/cvg/LightGlue). [SkyLand](https://www.futureis3d.com) provided a LiDAR-camera suite that allowed us to evaluate SG-Reg in large-scale scenes (as demonstrated at the end of the [video](https://youtu.be/k_kPFKcj-jo)).
121 |
122 | ## 7. License
123 | The source code is released under [GPLv3](https://www.gnu.org/licenses/) license.
124 | For technical issues, please contact Chuhao LIU (cliuci@connect.ust.hk).
125 |
--------------------------------------------------------------------------------
/config/rio.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | fix_modules: ['backbone','optimal_transport','shape_backbone','sgnn']
3 | dataset:
4 | dataroot: /data2/RioGraph
5 | load_transform: True
6 | min_iou: 0.2 # for gt match
7 | global_edges_dist: 2.0
8 | online_bert: True
9 | train:
10 | batch_size: 1
11 | num_workers: 8
12 | epochs: 80
13 | optimizer: Adam # Adam, SGD
14 | lr: 0.01
15 | weight_decay: 0.0001
16 | momentum: 0.9
17 | save_interval: 10
18 | val_interval: 1
19 | registration_in_train: True
20 | backbone:
21 | num_stages: 4
22 | init_voxel_size: 0.05
23 | base_radius: 2.5
24 | base_sigma: 2.0
25 | input_dim: 1
26 | output_dim: 256
27 | init_dim: 64
28 | kernel_size: 15
29 | group_norm: 32
30 | shape_encoder:
31 | input_from_stages: 1
32 | output_dim: 1024
33 | kernel_size: 15
34 | init_radius: 2.5 # 0.125
35 | init_sigma: 2.0 # 0.1
36 | group_norm: 32
37 | point_limit: 2048
38 | scenegraph:
39 | bert_dim: 768
40 | semantic_dim: 64
41 | pos_dim: -1
42 | box_dim: 8
43 | fuse_shape: True
44 | fuse_stage: late
45 | node_dim: 64
46 | encode_method: 'gnn'
47 | gnn:
48 | all_self_edges: False
49 | position_encoding: true
50 | layers: [ltriplet] # sage,self,gtop
51 | triplet_mlp: concat # concat, ffn or projector
52 | triplet_activation: gelu
53 | triplet_number: 20
54 | heads: 1
55 | hidden_dim : 16
56 | enable_dist_embedding: true
57 | enable_angle_embedding: True
58 | reduce: 'mean'
59 | se_layer: False
60 | instance_matching:
61 | match_layers: [0,1,2]
62 | topk: 3
63 | min_score: 0.1
64 | multiply_matchability: false
65 | fine_matching:
66 | min_nodes: 3
67 | num_points_per_instance: 256
68 | num_sinkhorn_iterations: 100
69 | topk: 3
70 | acceptance_radius: 0.1
71 | max_instance_selection: true
72 | mutual: True
73 | confidence_threshold: 0.05
74 | use_dustbin: True
75 | ignore_semantics: [floor, carpet, ceiling,] #[floor, carpet, ceiling]
76 | loss:
77 | instance_match:
78 | loss_func: nllv2 # overlap, nll
79 | nll_negative: 0.0
80 | gt_alpha: 2.0
81 | fine_loss_weight: -1.0
82 | shape_contrast_weight: -1.0
83 | gnode_match_weight: 1.0
84 | start_epoch: -1
85 | positive_point_radius: 0.1
86 | contrastive_temp: 0.1
87 | contrastive_postive_overlap: 0.3
88 | # nll_matchability: False
89 | eval:
90 | acceptance_overlap: 0.0
91 | acceptance_radius: 0.1
92 | rmse_threshold: 0.2
93 | gt_node_iou: 0.3
--------------------------------------------------------------------------------
/config/scannet.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | # pretrain_weight: chkt2.0/pretrain_finetuned5/sgnet-0064.pth
3 | fix_modules: ['backbone','optimal_transport','shape_backbone','sgnn']
4 | dataset:
5 | dataroot: /data2/ScanNetGraph
6 | min_iou: 0.3 # for gt match
7 | global_edges_dist: 1.0
8 | online_bert: False
9 | load_transform: False
10 | train:
11 | batch_size: 1
12 | num_workers: 8
13 | epochs: 64
14 | optimizer: Adam # Adam, SGD
15 | lr: 0.01
16 | weight_decay: 0.0001
17 | momentum: 0.9
18 | save_interval: 32
19 | val_interval: 2
20 | registration_in_train: true
21 | backbone:
22 | num_stages: 4
23 | init_voxel_size: 0.05
24 | base_radius: 2.5
25 | base_sigma: 2.0
26 | input_dim: 1
27 | output_dim: 256
28 | init_dim: 64
29 | kernel_size: 15
30 | group_norm: 32
31 | shape_encoder:
32 | input_from_stages: 1
33 | output_dim: 1024
34 | kernel_size: 15
35 | init_radius: 2.5 # 0.125
36 | init_sigma: 2.0 # 0.1
37 | group_norm: 32
38 | point_limit: 2048
39 | scenegraph:
40 | bert_dim: 768
41 | semantic_dim: 64
42 | pos_dim: -1
43 | box_dim: 8
44 | fuse_shape: true
45 | fuse_stage: late
46 | node_dim: 64
47 | encode_method: 'gnn' # 'geotransformer' or 'gnn'
48 | geotransformer:
49 | num_heads: 1
50 | blocks: [self]
51 | sigma_d: 0.2
52 | sigma_a: 15
53 | angle_k: 3
54 | reduction_a: max
55 | gnn:
56 | all_self_edges: false
57 | position_encoding: true
58 | layers: [ltriplet] # sage,self,gtop,sattn
59 | triplet_mlp: concat # concat, ffn or projector
60 | triplet_activation: relu
61 | heads: 1
62 | hidden_dim : 16
63 | enable_dist_embedding: true
64 | enable_angle_embedding: true
65 | reduce: 'mean'
66 | se_layer: False
67 | instance_matching:
68 | match_layers: [0,1,2]
69 | topk: 3
70 | min_score: 0.1
71 | multiply_matchability: false
72 | fine_matching:
73 | min_nodes: 3
74 | num_points_per_instance: 256
75 | num_sinkhorn_iterations: 100
76 | topk: 3
77 | acceptance_radius: 0.1
78 | max_instance_selection: true
79 | mutual: True
80 | confidence_threshold: 0.05
81 | use_dustbin: false
82 | ignore_semantics: [floor, carpet]
83 | loss:
84 | instance_match:
85 | loss_func: nllv2 # overlap, nll
86 | nll_negative: 0.0
87 | gt_alpha: 2.0
88 | fine_loss_weight: -1.0
89 | shape_contrast_weight: -1.0
90 | gnode_match_weight: 1.0
91 | start_epoch: -1
92 | positive_point_radius: 0.1
93 | contrastive_temp: 0.1
94 | contrastive_postive_overlap: 0.3
95 | eval:
96 | acceptance_overlap: 0.0
97 | acceptance_radius: 0.1
98 | rmse_threshold: 0.2
99 | gt_node_iou: 0.3
--------------------------------------------------------------------------------
/docs/rerun_interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/docs/rerun_interface.png
--------------------------------------------------------------------------------
/docs/system.001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/docs/system.001.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch_geometric
2 | numpy==1.26.2
3 | omegaconf
4 | transformers
5 | open3d
6 | matplotlib
7 | scipy
8 | tensorboard
9 | rerun-sdk
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | ''' This is a module adapted from SG-PGM:
2 | git@github.com:dfki-av/sg-pgm.git
3 | It maintains the instance label while downsampling the point cloud.
4 | '''
5 |
6 | from setuptools import setup
7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8 |
9 |
10 | setup(
11 | name='kpconv_extension',
12 | version='1.0.0',
13 | ext_modules=[
14 | CUDAExtension(
15 | name='.ext',
16 | sources=[
17 | 'sgreg/extensions/extra/cloud/cloud.cpp',
18 | 'sgreg/extensions/cpu/grid_subsampling/grid_subsampling.cpp',
19 | 'sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp',
20 | 'sgreg/extensions/cpu/radius_neighbors/radius_neighbors.cpp',
21 | 'sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp',
22 | 'sgreg/extensions/pybind.cpp',
23 | ],
24 | ),
25 | ],
26 | cmdclass={'build_ext': BuildExtension},
27 | )
28 |
--------------------------------------------------------------------------------
/sgreg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/__init__.py
--------------------------------------------------------------------------------
/sgreg/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from sgreg.backbone.backbone import KPConvFPN, encode_batch_scenes_points
2 | from sgreg.backbone.shape_encoder import KPConvShape, encode_batch_scenes_instances
3 | from sgreg.backbone.senet import SEModule
--------------------------------------------------------------------------------
/sgreg/backbone/backbone.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | from sgreg.utils.tictoc import TicToc
5 | from sgreg.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample
6 |
7 | def encode_batch_scenes_points(backbone:nn.Module,
8 | batch_graph_dict:dict):
9 | ''' Read a batch of scene pairs.
10 | Encode all of their fine points and features.
11 | '''
12 | stack_feats_f = []
13 | stack_feats_f_batch = [torch.tensor(0).long()]
14 |
15 | for scene_id in torch.arange(batch_graph_dict['batch_size']):
16 | tictoc = TicToc()
17 | feats = batch_graph_dict['batch_features'][scene_id].detach()
18 | data_dict = batch_graph_dict['batch_points'][scene_id]
19 |
20 | # 2. KPFCNN Encoder
21 | feats_list = backbone(feats,
22 | data_dict['points'],
23 | data_dict['neighbors'],
24 | data_dict['subsampling'],
25 | data_dict['upsampling'])
26 | duration = tictoc.toc()
27 | #
28 | points_f = data_dict['points'][1]
29 | feats_f = feats_list[0]
30 |         assert feats_f.shape[0] == points_f.shape[0], 'feats_f and points_f shapes do not match'
31 |
32 | stack_feats_f.append(feats_f) # (P+Q,C)
33 | stack_feats_f_batch.append(stack_feats_f_batch[-1]+feats_f.shape[0])
34 |
35 | #
36 | stack_feats_f = torch.cat(stack_feats_f,dim=0) # (P+Q,C)
37 | stack_feats_f_batch = torch.stack(stack_feats_f_batch,dim=0) # (B+1,)
38 |
39 | return {'feats_f':stack_feats_f,
40 | 'feats_batch':stack_feats_f_batch}, duration
41 |
42 | class KPConvFPN(nn.Module):
43 | def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
44 | super(KPConvFPN, self).__init__()
45 |
46 | self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)
47 | self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)
48 |
49 | self.encoder2_1 = ResidualBlock(
50 | init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
51 | )
52 | self.encoder2_2 = ResidualBlock(
53 | init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
54 | )
55 | self.encoder2_3 = ResidualBlock(
56 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
57 | )
58 |
59 | self.encoder3_1 = ResidualBlock(
60 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True
61 | )
62 | self.encoder3_2 = ResidualBlock(
63 | init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
64 | )
65 | self.encoder3_3 = ResidualBlock(
66 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
67 | )
68 |
69 | self.encoder4_1 = ResidualBlock(
70 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm, strided=True
71 | )
72 | self.encoder4_2 = ResidualBlock(
73 | init_dim * 8, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
74 | )
75 | self.encoder4_3 = ResidualBlock(
76 | init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
77 | )
78 |
79 | self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm)
80 | self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim)
81 |
82 | def forward(self, feats,
83 | points_list,
84 | neighbors_list,
85 | subsampling_list,
86 | upsampling_list):
87 | '''
88 | Read a scene pair.
89 | Encode all the fine points and features.
90 | '''
91 | feats_list = []
92 |
93 | feats_s1 = feats
94 | feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0])
95 | feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0])
96 |
97 | feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0])
98 | feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1])
99 | feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1])
100 |
101 | feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1])
102 | feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2])
103 | feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2])
104 |
105 | feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2])
106 | feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3])
107 | feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3])
108 |
109 | latent_s4 = feats_s4
110 | feats_list.append(feats_s4)
111 |
112 | latent_s3 = nearest_upsample(latent_s4, upsampling_list[2])
113 | latent_s3 = torch.cat([latent_s3, feats_s3], dim=1)
114 | latent_s3 = self.decoder3(latent_s3)
115 | feats_list.append(latent_s3)
116 |
117 | latent_s2 = nearest_upsample(latent_s3, upsampling_list[1])
118 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
119 | latent_s2 = self.decoder2(latent_s2)
120 | feats_list.append(latent_s2)
121 |
122 | feats_list.reverse()
123 |
124 | return feats_list
125 |
126 |
127 | # This a copy of the KPConvFPN. It is used to compute the FLOPS of the backbone.
128 | # It accepts input in one list.
129 | class TmpKPConvFPN(KPConvFPN):
130 | # the init function is the same as the KPConvFPN
131 | def forward(self, feats,
132 | points0,
133 | points1,
134 | points2,
135 | points3,
136 | neighbors0,
137 | neighbors1,
138 | neighbors2,
139 | neighbors3,
140 | subsampling0,
141 | subsampling1,
142 | subsampling2,
143 | upsampling0,
144 | upsampling1,
145 | upsampling2):
146 | '''
147 | Read a scene pair.
148 | Encode all the fine points and features.
149 | '''
150 | feats_list = []
151 |
152 | feats_s1 = feats
153 | feats_s1 = self.encoder1_1(feats_s1, points0, points0, neighbors0)
154 | feats_s1 = self.encoder1_2(feats_s1, points0, points0, neighbors0)
155 |
156 | feats_s2 = self.encoder2_1(feats_s1, points1, points0, subsampling0)
157 | feats_s2 = self.encoder2_2(feats_s2, points1, points1, neighbors1)
158 | feats_s2 = self.encoder2_3(feats_s2, points1, points1, neighbors1)
159 |
160 | feats_s3 = self.encoder3_1(feats_s2, points2, points1, subsampling1)
161 | feats_s3 = self.encoder3_2(feats_s3, points2, points2, neighbors2)
162 | feats_s3 = self.encoder3_3(feats_s3, points2, points2, neighbors2)
163 |
164 | feats_s4 = self.encoder4_1(feats_s3, points3, points2, subsampling2)
165 | feats_s4 = self.encoder4_2(feats_s4, points3, points3, neighbors3)
166 | feats_s4 = self.encoder4_3(feats_s4, points3, points3, neighbors3)
167 |
168 | latent_s4 = feats_s4
169 | feats_list.append(feats_s4)
170 |
171 | latent_s3 = nearest_upsample(latent_s4, upsampling2)
172 | latent_s3 = torch.cat([latent_s3, feats_s3], dim=1)
173 | latent_s3 = self.decoder3(latent_s3)
174 | feats_list.append(latent_s3)
175 |
176 | latent_s2 = nearest_upsample(latent_s3, upsampling1)
177 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
178 | latent_s2 = self.decoder2(latent_s2)
179 | feats_list.append(latent_s2)
180 |
181 | feats_list.reverse()
182 | return feats_list
183 |
184 | if __name__=='__main__':
185 | print('Try init a KPConvFPN')
186 | model = KPConvFPN(input_dim=3,
187 | output_dim=64,
188 | init_dim=16,
189 | kernel_size=15,
190 | init_radius=0.2,
191 | init_sigma=0.1,
192 | group_norm=8)
193 |
194 |
--------------------------------------------------------------------------------
/sgreg/backbone/senet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SEModule(nn.Module):
7 | def __init__(self, channels, reduction=16):
8 | super(SEModule, self).__init__()
9 | self.squeeze = nn.AdaptiveAvgPool1d(1)
10 | self.execitation = nn.Sequential(
11 | nn.Linear(channels, channels // reduction, bias=False),
12 | nn.ReLU(inplace=True),
13 | nn.Linear(channels // reduction, channels, bias=False),
14 | nn.Sigmoid()
15 | )
16 |
17 | def forward(self, x):
18 | '''
19 | Input,
20 | - x: (N,C)
21 | '''
22 | n, c = x.size()
23 | y = self.squeeze(x.t()).t() # (C)
24 | y = self.execitation(y) # (C,1)
25 | out = x * y.expand_as(x) # (N,C)
26 | return out
27 |
28 |
29 | if __name__=='__main__':
30 | se = SEModule(32)
31 | x = torch.randn(4,32)
32 | out = se(x)
33 | print(out.shape)
--------------------------------------------------------------------------------
/sgreg/backbone/shape_encoder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | from sgreg.ops import sample_instance_from_points
5 | from sgreg.kpconv import ResidualBlock, LastUnaryBlock
6 | from sgreg.utils.tictoc import TicToc
7 |
8 | class KPConvShape(nn.Module):
9 | def __init__(self,
10 | input_dim,
11 | output_dim,
12 | kernel_size,
13 | init_radius,
14 | init_sigma,
15 | group_norm,
16 | K_shape=1024,
17 | K_match=256,
18 | decoder=True):
19 | super(KPConvShape, self).__init__()
20 |
21 | self.encoder1_1 = ResidualBlock(
22 | input_dim, output_dim, kernel_size, init_radius, init_sigma, group_norm, strided=True
23 | )
24 |
25 | if decoder:
26 | self.decoder = LastUnaryBlock(output_dim+input_dim, input_dim)
27 |
28 | self.input_dim = input_dim
29 | self.output_dim = output_dim
30 | self.K_shape = K_shape
31 | self.K_match = K_match
32 |
33 | def forward(self, f_points,
34 | f_feats,
35 | f_instances,
36 | instances_centroids,
37 | instances_points_indices,
38 | decode_points):
39 | '''
40 | Input:
41 | f_points: (P, 3),
42 | f_feats: (P, C),
43 | f_instances: (P, 1),
44 |             instances_centroids: (N, 3), instance centroid position
45 | instances_points_indices: (N, H), instance knn indices
46 | '''
47 |
48 | #! encoder reduce feature dimension C->C/4
49 | instance_feats = self.encoder1_1(f_feats, instances_centroids, f_points, instances_points_indices) # (N, C)
50 | assert instance_feats.shape[0] == instances_centroids.shape[0]
51 |         if decode_points: # Only used in pre-training contrastive learning. In validation and deployment, skip it.
52 | f_instance_feats = instance_feats[f_instances.squeeze()] # (P, C)
53 | feats_decoded = self.decoder(torch.cat([f_instance_feats, f_feats], dim=1))
54 | assert feats_decoded.shape[0] == f_points.shape[0]
55 | return instance_feats, feats_decoded
56 | else:
57 | return instance_feats, None
58 |
59 | def concat_instance_points(ref_instance_number:torch.Tensor,
60 | src_instance_number:torch.Tensor,
61 | ref_instance_points:torch.Tensor,
62 | src_instance_points:torch.Tensor,
63 | device:torch.device):
64 | ''' build instance-labeled points and instance list.
65 | Return:
66 | instance_list: (N+M,), instance_points: (P+Q, 3)
67 | '''
68 | ref_instance_list = torch.arange(ref_instance_number).to(device)
69 | src_instance_list = torch.arange(src_instance_number).to(device)+ref_instance_number
70 | src_instance_points = src_instance_points + ref_instance_number
71 | instance_list = torch.cat([ref_instance_list,src_instance_list],dim=0) # (N+M,)
72 | instance_points = torch.cat([ref_instance_points,src_instance_points],dim=0) # (P+Q, 3)
73 |
74 | return instance_list, instance_points
75 |
76 | def encode_batch_scenes_instances(shape_backbone:nn.Module,
77 | batch_graph_pair:dict,
78 | instance_f_feats_dict:dict,
79 | decode_points=True,
80 | verify=True):
81 | tictoc = TicToc()
82 | stack_instances_shape = []
83 | stack_instance_pts_match = []
84 | stack_instances_batch = [torch.tensor(0).long()]
85 | stack_feats_f_decoded = []
86 | ref_graph_batch = batch_graph_pair['ref_graph']['batch']
87 | src_graph_batch = batch_graph_pair['src_graph']['batch']
88 | invalid_instance_exist = False
89 | duration_list = []
90 |
91 | for scene_id in torch.arange(batch_graph_pair['batch_size']):
92 | # Extract scene data
93 | data_dict = batch_graph_pair['batch_points'][scene_id]
94 | f_pts_length = data_dict['lengths'][1] # [P,Q]
95 | assert f_pts_length.device == data_dict['insts'][1].device, 'f_pts_length and insts device are not match'
96 | assert data_dict['insts'][1].is_contiguous(), 'insts[1] is not contiguous'
97 | tmp_instances_f = data_dict['insts'][1]
98 | # todo: this step is slow. approx. 50ms
99 | ref_instances_f = tmp_instances_f[:f_pts_length[0]] # (P,)
100 | src_instances_f = tmp_instances_f[f_pts_length[0]:] # (Q,)
101 | points_f = data_dict['points'][1] # (P+Q, 3)
102 | duration_list.append(tictoc.toc())
103 |
104 | feats_b0 = instance_f_feats_dict['feats_batch'][scene_id]
105 | feats_b1 = instance_f_feats_dict['feats_batch'][scene_id+1]
106 | feats_f = instance_f_feats_dict['feats_f'][feats_b0:feats_b1] # (P+Q, C)
107 |
108 | instance_list, instances_f = concat_instance_points(
109 | ref_graph_batch[scene_id+1]-ref_graph_batch[scene_id],
110 | src_graph_batch[scene_id+1]-src_graph_batch[scene_id],
111 | ref_instances_f,
112 | src_instances_f,
113 | feats_f.device)
114 | assert points_f.shape[0] == instances_f.shape[0] \
115 | and points_f.shape[0]==feats_f.shape[0], 'points, feats, and instances shape are not match'
116 |
117 | #
118 | instance_fpts_indxs_shape, invalid_instance_mask = \
119 | sample_instance_from_points(instances_f, instance_list,
120 | shape_backbone.K_shape, points_f.shape[0]) # (N+M,Ks), (N+M,)
121 | instance_fpts_indxs_match, _ =\
122 | sample_instance_from_points(instances_f, instance_list,
123 | shape_backbone.K_match, points_f.shape[0]) # (N+M,Km), (N+M,)
124 | if(torch.any(invalid_instance_mask)): # involves invalid instances
125 | print('instance without points in {}'.format(batch_graph_pair['src_scan'][scene_id]))
126 | # assert False, 'Some instances are not assigned to any fine points.'
127 | instances_shape = torch.zeros((instance_list.shape[0],
128 | shape_backbone.output_dim)).to(feats_f.device)
129 | feats_f_decoded = torch.zeros((feats_f.shape[0],
130 | feats_f.shape[1])).to(feats_f.device)
131 | invalid_instance_exist = True
132 | else: # shape instance-wise shapes and points
133 | # Extract instance centroids
134 | instance_f_points = points_f[instance_fpts_indxs_shape] # (N+M, Ks, 3)
135 | instance_f_centers = instance_f_points.mean(dim=1) # (N+M, 3)
136 |
137 | if verify:
138 | ref_instance_centroids = batch_graph_pair['ref_graph']['centroids'][batch_graph_pair['ref_graph']['scene_mask']==scene_id]
139 | src_instance_centroids = batch_graph_pair['src_graph']['centroids'][batch_graph_pair['src_graph']['scene_mask']==scene_id]
140 | instance_centroids = torch.cat([ref_instance_centroids,src_instance_centroids],dim=0) # (M+N, 3)
141 | dist = torch.abs((instance_centroids - instance_f_centers).mean(dim=1)) # (M+N,)
142 | assert dist.max()<1.0, 'Center distance is too large {:.4f}m'.format(dist.max())
143 |
144 | instances_shape, feats_f_decoded = shape_backbone(points_f,
145 | feats_f,
146 | instances_f,
147 | instance_f_centers,
148 | instance_fpts_indxs_shape,
149 | decode_points)
150 | duration_list.append(tictoc.toc())
151 |
152 | stack_instances_shape.append(instances_shape)
153 | stack_instance_pts_match.append(instance_fpts_indxs_match)
154 | stack_instances_batch.append(stack_instances_batch[-1]+instance_list.shape[0])
155 |
156 | if decode_points:
157 | assert feats_f_decoded is not None, 'feats_f_decoded is None'
158 | stack_feats_f_decoded.append(feats_f_decoded) # (P+Q, C0)
159 |
160 | # Concatenate all scenes
161 | stack_instances_shape = torch.cat(stack_instances_shape,dim=0) # (M+N, C1)
162 | stack_instances_shape = nn.functional.normalize(stack_instances_shape, dim=1) # (M+N, C1)
163 | instance_f_feats_dict['instances_shape'] = stack_instances_shape
164 | instance_f_feats_dict['instances_f_indices_match'] = torch.cat(stack_instance_pts_match,dim=0)
165 | instance_f_feats_dict['instances_batch'] = torch.stack(stack_instances_batch,dim=0) # (B+1,)
166 | instance_f_feats_dict['invalid_instance_exist'] = invalid_instance_exist
167 |
168 | if decode_points:
169 | stack_feats_f_decoded = torch.cat(stack_feats_f_decoded,dim=0) # (P+Q, C0)
170 | instance_f_feats_dict['feats_f_decoded'] = stack_feats_f_decoded
171 | duration_list.append(tictoc.toc())
172 |
173 | # timing
174 | # msg = ['{:.2f}'.format(1000*t) for t in duration_list]
175 | # print('Shape Encoder: ', ' '.join(msg))
176 |
177 | return instance_f_feats_dict, duration_list[1]
--------------------------------------------------------------------------------
/sgreg/bert/get_tokenizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
2 |
3 |
4 | def get_tokenlizer(text_encoder_type):
5 | if not isinstance(text_encoder_type, str):
6 | # print("text_encoder_type is not a str")
7 | if hasattr(text_encoder_type, "text_encoder_type"):
8 | text_encoder_type = text_encoder_type.text_encoder_type
9 | elif text_encoder_type.get("text_encoder_type", False):
10 | text_encoder_type = text_encoder_type.get("text_encoder_type")
11 | else:
12 | raise ValueError(
13 | "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
14 | )
15 | print("final text_encoder_type: {}".format(text_encoder_type))
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
18 | return tokenizer
19 |
20 |
21 | def get_pretrained_language_model(text_encoder_type):
22 | if text_encoder_type == "bert-base-uncased":
23 | return BertModel.from_pretrained(text_encoder_type)
24 | if text_encoder_type == "roberta-base":
25 | return RobertaModel.from_pretrained(text_encoder_type)
26 | raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
27 |
--------------------------------------------------------------------------------
/sgreg/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/dataset/__init__.py
--------------------------------------------------------------------------------
/sgreg/dataset/dataset_factory.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import torch
3 | import numpy as np
4 | import open3d as o3d
5 | from sgreg.dataset.scene_pair_dataset import (
6 | ScenePairDataset,
7 | )
8 | from sgreg.dataset.stack_mode import (
9 | calibrate_neighbors_stack_mode,
10 | registration_collate_fn_stack_mode,
11 | sgreg_collate_fn_stack_mode,
12 | build_dataloader_stack_mode
13 | )
14 |
15 | def train_data_loader(cfg,distributed=False):
16 | train_dataset = ScenePairDataset(cfg.dataset.dataroot, 'train', cfg)
17 | neighbor_limits = calibrate_neighbors_stack_mode(
18 | train_dataset,
19 | registration_collate_fn_stack_mode,
20 | cfg.backbone.num_stages,
21 | cfg.backbone.init_voxel_size,
22 | cfg.backbone.init_radius,
23 | )
24 | train_dataset.neighbor_limits = neighbor_limits
25 | if 'verify_instance_points' in cfg.dataset:
26 | verify_instance_points = cfg.dataset.verify_instance_points
27 | else:
28 | verify_instance_points = False
29 |
30 | train_loader = build_dataloader_stack_mode(
31 | train_dataset,
32 | sgreg_collate_fn_stack_mode,
33 | cfg.backbone.num_stages,
34 | cfg.backbone.init_voxel_size,
35 | cfg.backbone.init_radius,
36 | neighbor_limits,
37 | verify_instance_points,
38 | batch_size=cfg.train.batch_size,
39 | num_workers=cfg.train.num_workers,
40 | shuffle=True,
41 | distributed=distributed,
42 | )
43 | return train_loader,neighbor_limits
44 |
45 | def val_data_loader(cfg,distributed=False):
46 | val_dataset = ScenePairDataset(cfg.dataset.dataroot, 'val', cfg)
47 | neighbor_limits = calibrate_neighbors_stack_mode(
48 | val_dataset,
49 | registration_collate_fn_stack_mode,
50 | cfg.backbone.num_stages,
51 | cfg.backbone.init_voxel_size,
52 | cfg.backbone.init_radius,
53 | )
54 | print('Calibrated neighbor limits:',neighbor_limits)
55 | # neighbor_limits = [38, 36, 36, 38] # default setting in 3DMatch
56 | val_dataset.neighbor_limits = neighbor_limits
57 |
58 | if 'verify_instance_points' in cfg.dataset:
59 | verify_instance_points = cfg.dataset.verify_instance_points
60 | else:
61 | verify_instance_points = False
62 |
63 | val_loader = build_dataloader_stack_mode(
64 | val_dataset,
65 | sgreg_collate_fn_stack_mode,
66 | cfg.backbone.num_stages,
67 | cfg.backbone.init_voxel_size,
68 | cfg.backbone.init_radius,
69 | neighbor_limits,
70 | verify_instance_points,
71 | batch_size=cfg.train.batch_size,
72 | num_workers=cfg.train.num_workers,
73 | shuffle=False,
74 | distributed=distributed,
75 | )
76 | return val_loader,neighbor_limits
77 |
78 | if __name__=='__main__':
79 | import torch
80 | from omegaconf import OmegaConf
81 | conf = OmegaConf.load('config/pretrain.yaml')
82 | conf.backbone.init_radius = 0.5 * conf.backbone.base_radius * conf.backbone.init_voxel_size # 0.0625
83 | conf.backbone.init_sigma = 0.5 * conf.backbone.base_sigma * conf.backbone.init_voxel_size # 0.05
84 |
85 | train_loader, train_neighbor_limits = train_data_loader(conf)
86 | val_loader, val_neighbor_limits = val_data_loader(conf)
87 |
88 | # print('train neighbor limits',train_neighbor_limits)
89 | # print('val neighbor limits',val_neighbor_limits)
90 | # exit(0)
91 | # data_dict = next(iter(train_loader))
92 | # src_graph = data_dict['src_graph']
93 | # print(data_dict.keys())
94 | # print(src_graph.keys())
95 |
96 | warning_scans = []
97 | valid_scans = []
98 |
99 | for data_dict in val_loader:
100 | src_graph = data_dict['src_graph']
101 | src_scan = data_dict['src_scan'][0]
102 | ref_scan = data_dict['ref_scan'][0]
103 | msg = '{}-{}: '.format(src_scan,ref_scan)
104 | points_dict= data_dict['batch_points'][0]
105 | # print(data_dict.keys())
106 | lengths = points_dict['lengths'] # [(Pl,Ql)]
107 |
108 | # if 'debug_flag' in data_dict:
109 | # print('{} inconsistent instance {}-{} !!!'.format(
110 | # data_dict['debug_flag'], src_scan,ref_scan, ))
111 | if data_dict['instance_matches'][0] is None:
112 | print('{}-{} no gt instance matches!!!'.format(src_scan,ref_scan))
113 | warning_scans.append('{} {}'.format(src_scan,ref_scan))
114 | else:
115 | valid_scans.append('{} {}'.format(src_scan,ref_scan))
116 | print(msg)
117 | # break
118 |
119 | print('************ Finished ************')
120 | print('{}/{} warning scans'.format(len(warning_scans),
121 | len(valid_scans)+len(warning_scans)))
122 |
123 |
124 | # outdir = '/data2/sgalign_data/splits/val_ours.txt'
125 | # with open(outdir,'w') as f:
126 | # f.write('\n'.join(valid_scans))
127 | # f.close()
128 |
--------------------------------------------------------------------------------
/sgreg/dataset/prepare_semantics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Generate intermediate features for the scene graph:
3 | - Semantic embeddings, encoded by BERT.
4 | - Hierarchical point features, encoded per pair of sub-scenes.
5 | '''
6 |
7 |
8 | import os
9 | import torch
10 | import torch.utils.data
11 | import pandas as pd
12 | import time
13 | # import numpy as np
14 | from omegaconf import OmegaConf
15 | from scipy.spatial.transform import Rotation as R
16 | from sgreg.bert.get_tokenizer import get_tokenlizer, get_pretrained_language_model
17 | from sgreg.bert.bertwarper import generate_bert_fetures
18 |
19 | # from model.ops.transformation import apply_transform
20 | from sgreg.dataset.scene_pair_dataset import ScenePairDataset
21 | # from model.dataset.DatasetFactory import train_data_loader, val_data_loader
22 | from sgreg.utils.config import create_cfg
23 | # from model.dataset.ScanNetPairDataset import read_scans,load_scene_graph
24 |
25 | def generate_semantic_embeddings(data_dict):
26 | src_labels = data_dict['src_graph']['labels']
27 | ref_labels = data_dict['ref_graph']['labels']
28 | src_idxs = data_dict['src_graph']['idx2name']
29 | ref_idxs = data_dict['ref_graph']['idx2name']
30 | assert len(src_labels) == src_idxs.shape[0] and len(ref_labels) == ref_idxs.shape[0]
31 |
32 | labels = src_labels + ref_labels
33 |
34 | t0 = time.time()
35 | semantic_embeddings = generate_bert_fetures(tokenizer,bert,labels)
36 | t1 = time.time()
37 | print('encode {} labels takes {:.3f} msecs'.format(len(labels),1000*(t1-t0)))
38 | assert len(semantic_embeddings) == len(labels)
39 |
40 | src_semantic_embeddings = semantic_embeddings[:len(src_labels)].detach()
41 | ref_semantic_embeddings = semantic_embeddings[len(src_labels):].detach()
42 |
43 | return {'instance_idxs':src_idxs,'semantic_embeddings':src_semantic_embeddings},\
44 | {'instance_idxs':ref_idxs,'semantic_embeddings':ref_semantic_embeddings}
45 |
46 | if __name__=='__main__':
47 | print('Save the semantic embeddings to accelerate the training process.')
48 | ##
49 | dataroot = '/data2/RioGraph' # '/data2/ScanNetGraph'
50 | split = 'val'
51 | cfg_file = '/home/cliuci/code_ws/SceneGraphNet/config/rio.yaml'
52 | middle_feat_folder = os.path.join(dataroot,'matches')
53 | from sgreg.dataset.DatasetFactory import prepare_hierarchy_points_feats
54 |
55 | ##
56 | tokenizer = get_tokenlizer('bert-base-uncased')
57 | bert = get_pretrained_language_model('bert-base-uncased')
58 | bert.eval()
59 | bert.pooler.dense.weight.requires_grad = False
60 | bert.pooler.dense.bias.requires_grad = False
61 |
62 | #
63 | conf = OmegaConf.load(cfg_file)
64 | conf = create_cfg(conf)
65 | conf.dataset.online_bert = True
66 | dataset = ScenePairDataset(dataroot,split,conf)
67 |
68 | neighbor_limits = [38, 36, 36, 38]
69 | print('neighbor_limits:',neighbor_limits)
70 | cfg_dict= {'num_stages': conf.backbone.num_stages,
71 | 'voxel_size': conf.backbone.init_voxel_size,
72 | 'search_radius': conf.backbone.init_radius,
73 | 'neighbor_limits': neighbor_limits}
74 |
75 | #
76 | N = len(dataset)
77 | print('Dataset size:',N)
78 |
79 | SEMANTIC_ON = True
80 | POINTS_ON = False
81 | CHECK_EDGES = False
82 | max_fine_points = []
83 |
84 | warn_scans = []
85 |
86 | for i in range(N):
87 | data_dict = dataset[i]
88 | scene_name = data_dict['src_scan'][:-1]
89 | src_subname = data_dict['src_scan'][-1]
90 | ref_subname = data_dict['ref_scan'][-1]
91 | print('---processing {} and {} -----'.format(data_dict['src_scan'],data_dict['ref_scan']))
92 | if data_dict['instance_ious'] is None:
93 | warn_scans.append((data_dict['src_scan'],data_dict['ref_scan']))
94 |
95 | if CHECK_EDGES:
96 | src_global_edges = data_dict['src_graph']['global_edge_indices']
97 | ref_global_edges = data_dict['ref_graph']['global_edge_indices']
98 | print('{} global src edges'.format(src_global_edges.shape[0]))
99 |
100 | if src_global_edges.shape[0]<1 or ref_global_edges.shape[0]<1:
101 | print('global_edges is None!')
102 | warn_scans.append((data_dict['src_scan'],data_dict['ref_scan']))
103 |
104 | if SEMANTIC_ON:
105 | src_semantic_embeddings, ref_semantic_embeddings = generate_semantic_embeddings(data_dict)
106 | torch.save(src_semantic_embeddings,
107 | os.path.join(dataroot,split,data_dict['src_scan'],'semantic_embeddings.pth'))
108 | torch.save(ref_semantic_embeddings,
109 | os.path.join(dataroot,split,data_dict['ref_scan'],'semantic_embeddings.pth'))
110 | if POINTS_ON:
111 | out_dict = prepare_hierarchy_points_feats(cfg_dict=cfg_dict,
112 | ref_points = data_dict['ref_points'],
113 | src_points=data_dict['src_points'],
114 | ref_instances=data_dict['ref_instances'],
115 | src_instances=data_dict['src_instances'],
116 | ref_feats=data_dict['ref_feats'],
117 | src_feats=data_dict['src_feats'])
118 | import numpy as np
119 |
120 | ref_instance_count = np.histogram(data_dict['ref_instances'],bins=np.unique(data_dict['ref_instances']))[0]
121 | print('ref instance list:',np.unique(data_dict['ref_instances']))
122 | print('ref_instance points count:',ref_instance_count)
123 | print('points number: ', data_dict['ref_points'].shape[0])
124 |
125 | ref_instance_list = np.unique(out_dict['ref_points_f_instances'])
126 | ref_instance_fine_count = np.histogram(out_dict['ref_points_f_instances'],bins=ref_instance_list)[0]
127 | print('ref instance list:',ref_instance_list)
128 | print('ref_instance fine points count:',ref_instance_fine_count)
129 | print('max fine points: ',ref_instance_fine_count.max())
130 | max_fine_points.append(ref_instance_fine_count.max())
131 |
132 | points_dict = out_dict['points_dict']
133 | feats = out_dict['feats']
134 | ref_points_f_instances= out_dict['ref_points_f_instances']
135 | src_points_f_instances= out_dict['src_points_f_instances']
136 | assert 'instance_matches' in data_dict
137 | assert isinstance(ref_points_f_instances,torch.Tensor) and isinstance(src_points_f_instances,torch.Tensor)
138 | assert points_dict['points'][0].shape[0] == points_dict['lengths'][0].sum()
139 | # Export
140 | data_dict.update(out_dict)
141 | torch.save(data_dict,os.path.join(middle_feat_folder,scene_name,'data_dict_{}.pth'.format(src_subname+ref_subname)))
142 |
143 | # break
144 | print(warn_scans)
145 | print('finished {} scan pairs'.format(N))
146 |
147 | if len(max_fine_points)>0:
148 | max_fine_points = np.array(max_fine_points)
149 | print(max_fine_points)
150 | print(max_fine_points.max())
151 |
152 |
153 |
154 |
155 |
--------------------------------------------------------------------------------
/sgreg/dataset/scene_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pandas as pd
4 | import numpy as np
5 | import open3d as o3d
6 | from scipy.spatial.transform import Rotation as R
7 |
8 | from open3d.geometry import OrientedBoundingBox as OBB
9 | from os.path import join as osp
10 |
11 | class Instance:
12 | def __init__(self,idx:int,
13 | cloud:o3d.geometry.PointCloud|np.ndarray,
14 | label:str,
15 | score:float):
16 | self.idx = idx
17 | self.label = label
18 | self.score = score
19 | self.cloud = cloud
20 | self.cloud_dir = None
21 | def load_box(self,box:o3d.geometry.OrientedBoundingBox):
22 | self.box = box
23 |
24 | def load_raw_scene_graph(folder_dir:str,
25 | voxel_size:float=0.02,
26 | ignore_types:list=['ceiling']):
27 | ''' graph: {'nodes':{idx:Instance},'edges':{idx:idx}}
28 | '''
29 | # load scene graph
30 | nodes = {}
31 | boxes = {}
32 | invalid_nodes = []
33 | xyzi = []
34 | global_cloud = o3d.geometry.PointCloud()
35 | # IGNORE_TYPES = ['floor','carpet','wall']
36 | # IGNORE_TYPES = ['ceiling']
37 |
38 | # load instance boxes
39 | with open(os.path.join(folder_dir,'instance_box.txt')) as f:
40 | count=0
41 | for line in f.readlines():
42 | line = line.strip()
43 | if'#' in line:continue
44 | parts = line.split(';')
45 | idx = int(parts[0])
46 | center = np.array([float(x) for x in parts[1].split(',')])
47 | rotation = np.array([float(x) for x in parts[2].split(',')])
48 | extent = np.array([float(x) for x in parts[3].split(',')])
49 | o3d_box = o3d.geometry.OrientedBoundingBox(center,rotation.reshape(3,3),extent)
50 | o3d_box.color = (0,0,0)
51 | # if'nan' in line:invalid_nodes.append(idx)
52 | if 'nan' not in line:
53 | boxes[idx] = o3d_box
54 | # nodes[idx].load_box(o3d_box)
55 | count+=1
56 | f.close()
57 | print('load {} boxes'.format(count))
58 |
59 | # load instance info
60 | with open(os.path.join(folder_dir,'instance_info.txt')) as f:
61 | for line in f.readlines():
62 | line = line.strip()
63 | if'#' in line:continue
64 | parts = line.split(';')
65 | idx = int(parts[0])
66 | if idx not in boxes: continue
67 | label_score_vec = parts[1].split('(')
68 | label = label_score_vec[0]
69 | score = float(label_score_vec[1].split(')')[0])
70 | if label in ignore_types: continue
71 | # print('load {}:{}, {}'.format(idx,label,score))
72 |
73 | cloud = o3d.io.read_point_cloud(os.path.join(folder_dir,'{}.ply'.format(parts[0])))
74 | cloud = cloud.voxel_down_sample(voxel_size)
75 | xyz = np.asarray(cloud.points)
76 | # if xyz.shape[0]<50: continue
77 | xyzi.append(np.concatenate([xyz,idx*np.ones((len(xyz),1))],
78 | axis=1))
79 | global_cloud = global_cloud + cloud
80 | nodes[idx] = Instance(idx,cloud,label,score)
81 | nodes[idx].cloud_dir = '{}.ply'.format(parts[0])
82 | nodes[idx].load_box(boxes[idx])
83 |
84 | f.close()
85 | print('Load {} instances '.format(len(nodes)))
86 | if len(xyzi)>0:
87 | xyzi = np.concatenate(xyzi,axis=0)
88 |
89 | return {'nodes':nodes,
90 | 'edges':[],
91 | 'global_cloud':global_cloud,
92 | 'xyzi':xyzi}
93 |
94 | def load_processed_scene_graph(scan_dir:str):
95 |
96 | instance_nodes = {} # {idx:node_info}
97 |
98 | # Nodes
99 | nodes_data = pd.read_csv(os.path.join(scan_dir,'nodes.csv'))
100 | max_node_id = 0
101 |
102 | for idx, label, score, center, quat, extent, _ in zip(nodes_data['node_id'],
103 | nodes_data['label'],
104 | nodes_data['score'],
105 | nodes_data['center'],
106 | nodes_data['quaternion'],
107 | nodes_data['extent'],
108 | nodes_data['cloud_dir']):
109 | centroid = np.fromstring(center, dtype=float, sep=',')
110 | quaternion = np.fromstring(quat, dtype=float, sep=',') # (x,y,z,w)
111 | rot = R.from_quat(quaternion)
112 | extent = np.fromstring(extent, dtype=float, sep=',')
113 |
114 | if np.isnan(extent).any() or np.isnan(quaternion).any() or np.isnan(centroid).any() or np.isnan(idx):
115 | continue
116 | if '_' in label: label = label.replace('_',' ')
117 | pcd_dir = osp(scan_dir, str(idx).zfill(4)+'.ply')
118 | assert os.path.exists(pcd_dir), 'File not found: {}'.format(pcd_dir)
119 |
120 | instance_nodes[idx] = Instance(idx,
121 | o3d.io.read_point_cloud(pcd_dir),
122 | label,
123 | score)
124 | instance_nodes[idx].load_box(OBB(centroid,
125 | rot.as_matrix(),
126 | extent))
127 |
128 | if idx>max_node_id:
129 | max_node_id = idx
130 | print('Load {} instances'.format(len(instance_nodes)))
131 |
132 | # Instance Point Cloud
133 | xyzi = torch.load(os.path.join(scan_dir,'xyzi.pth')).numpy()
134 | instances = xyzi[:,-1].astype(np.int32)
135 | xyz = xyzi[:,:3].astype(np.float32)
136 | assert max_node_id == instances.max(), 'Instance ID mismatch'
137 | assert np.unique(instances).shape[0] == len(instance_nodes), 'Instance ID mismatch'
138 |
139 | # Global Point Cloud
140 | # colors = np.zeros_like(xyz)
141 | # for idx, instance in instance_nodes.items():
142 | # inst_mask = instances== idx
143 | # assert inst_mask.sum()>0
144 | # inst_color = 255*np.random.rand(3)
145 | # colors[inst_mask] = np.floor(inst_color).astype(np.int32)
146 | # instance.cloud = o3d.geometry.PointCloud(
147 | # o3d.utility.Vector3dVector(xyz[inst_mask]))
148 |
149 | # global_pcd = o3d.geometry.PointCloud(
150 | # o3d.utility.Vector3dVector(xyz))
151 | # global_pcd.colors = o3d.utility.Vector3dVector(colors)
152 | # global_pcd = o3d.io.read_point_cloud(os.path.join(scan_dir,'instance_map.ply'))
153 |
154 | global_pcd = o3d.geometry.PointCloud()
155 | for idx, instance in instance_nodes.items():
156 | global_pcd += instance.cloud
157 |
158 | return {'nodes':instance_nodes,
159 | 'edges':[],
160 | 'global_cloud':global_pcd}
161 |
162 |
163 | def transform_scene_graph(scene_graph:dict,
164 | transformation:np.ndarray):
165 |
166 | scene_graph['global_cloud'].transform(transformation)
167 | for idx, instance in scene_graph['nodes'].items():
168 | tmp_center = instance.box.center
169 | instance.cloud.transform(transformation)
170 |         # todo: open3d rotates the bbox incorrectly
171 | # instance.box = o3d.geometry.OrientedBoundingBox.create_from_points(
172 | # o3d.utility.Vector3dVector(np.asarray(instance.cloud.points)))
173 | instance.box.rotate(R=transformation[:3,:3])
174 | instance.box.translate(transformation[:3,3])
175 |
176 |
177 |
--------------------------------------------------------------------------------
/sgreg/extensions/README.md:
--------------------------------------------------------------------------------
1 | # GeoTransformer extensions with grid subsampling for point clouds with semantic labels
2 | ## Grid Subsampling (Modified with semantic labels)
3 |
4 | This code implements the grid subsampling algorithm for point clouds. It performs grid subsampling on each batch of points and returns the subsampled points, subsampled instance labels, and lengths of the subsampled point cloud batches.
5 |
6 | It provides a function `grid_subsampling_cpu` that takes in a set of 3D points, **their corresponding instance labels**, and lengths of point cloud batches.
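A rough sketch of how the compiled op is driven from Python (the module import below is a placeholder; the real wrapper lives in `sgreg/ops/grid_subsample.py`, and the tensor signature follows `grid_subsampling.h`):
```python
import torch
import ext as ext_module  # placeholder import; the actual module name is set by setup.py / pybind.cpp

points = torch.rand(1000, 3, dtype=torch.float32)            # stacked points of the batch
insts = torch.randint(0, 20, (1000, 1), dtype=torch.int32)   # per-point instance labels
lengths = torch.tensor([1000], dtype=torch.int64)            # points per batch element

# Returns the subsampled points, their majority-vote instance labels, and the new lengths.
s_points, s_insts, s_lengths = ext_module.grid_subsampling(points, insts, lengths, 0.05)
```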
7 |
8 | ## Radius-based Neighbor Search
9 | The main function, `radius_neighbors_cpu`, performs a batched search on two sets of 3D points, storing the indices of the neighbors found within a given radius for each point in the first set.
10 |
11 |
12 |
--------------------------------------------------------------------------------
/sgreg/extensions/common/torch_helper.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <ATen/ATen.h>
5 |
6 | #define CHECK_CUDA(x) \
7 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
8 |
9 | #define CHECK_CPU(x) \
10 | TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor")
11 |
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 |
15 | #define CHECK_INPUT(x) \
16 | CHECK_CUDA(x); \
17 | CHECK_CONTIGUOUS(x)
18 |
19 | #define CHECK_IS_INT(x) \
20 | do { \
21 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
22 | #x " must be an int tensor"); \
23 | } while (0)
24 |
25 | #define CHECK_IS_LONG(x) \
26 | do { \
27 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Long, \
28 | #x " must be an long tensor"); \
29 | } while (0)
30 |
31 | #define CHECK_IS_FLOAT(x) \
32 | do { \
33 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
34 | #x " must be a float tensor"); \
35 | } while (0)
36 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/grid_subsampling/grid_subsampling.cpp:
--------------------------------------------------------------------------------
1 | #include <cstring>
2 | #include "grid_subsampling.h"
3 | #include "grid_subsampling_cpu.h"
4 |
5 | std::vector<at::Tensor> grid_subsampling(
6 | at::Tensor points,
7 | at::Tensor insts,
8 | at::Tensor lengths,
9 | float voxel_size
10 | ) {
11 | CHECK_CPU(points);
12 | CHECK_CPU(lengths);
13 | CHECK_IS_FLOAT(points);
14 | CHECK_IS_LONG(lengths);
15 | CHECK_CONTIGUOUS(points);
16 | CHECK_CONTIGUOUS(lengths);
17 |
18 | CHECK_CPU(insts);
19 | CHECK_IS_INT(insts);
20 | CHECK_CONTIGUOUS(insts);
21 |
22 | std::size_t batch_size = lengths.size(0);
23 | std::size_t total_points = points.size(0);
24 |
25 |   std::vector<PointXYZ> vec_points = std::vector<PointXYZ>(
26 |     reinterpret_cast<PointXYZ*>(points.data_ptr<float>()),
27 |     reinterpret_cast<PointXYZ*>(points.data_ptr<float>()) + total_points
28 |   );
29 |   std::vector<PointXYZ> vec_s_points;
30 |
31 |   std::vector<long> vec_lengths = std::vector<long>(
32 |     lengths.data_ptr<long>(),
33 |     lengths.data_ptr<long>() + batch_size
34 |   );
35 |   std::vector<long> vec_s_lengths;
36 |
37 |   std::vector<int> vec_insts = std::vector<int>(
38 |     reinterpret_cast<int*>(insts.data_ptr<int>()),
39 |     reinterpret_cast<int*>(insts.data_ptr<int>()) + total_points
40 |   );
41 |   std::vector<int> vec_s_insts;
42 |
43 | grid_subsampling_cpu(
44 | vec_points,
45 | vec_s_points,
46 | vec_insts,
47 | vec_s_insts,
48 | vec_lengths,
49 | vec_s_lengths,
50 | voxel_size
51 | );
52 |
53 | std::size_t total_s_points = vec_s_points.size();
54 | at::Tensor s_points = torch::zeros(
55 | {total_s_points, 3},
56 | at::device(points.device()).dtype(at::ScalarType::Float)
57 | );
58 | at::Tensor s_lengths = torch::zeros(
59 | {batch_size},
60 | at::device(lengths.device()).dtype(at::ScalarType::Long)
61 | );
62 |
63 | at::Tensor s_insts = torch::zeros(
64 | {total_s_points, 1},
65 | at::device(insts.device()).dtype(at::ScalarType::Int)
66 | );
67 |
68 | std::memcpy(
69 | s_points.data_ptr(),
70 | reinterpret_cast<const float*>(vec_s_points.data()),
71 | sizeof(float) * total_s_points * 3
72 | );
73 | std::memcpy(
74 | s_lengths.data_ptr(),
75 | vec_s_lengths.data(),
76 | sizeof(long) * batch_size
77 | );
78 |
79 | std::memcpy(
80 | s_insts.data_ptr(),
81 | reinterpret_cast<const int*>(vec_s_insts.data()),
82 | sizeof(int) * total_s_points * 1
83 | );
84 |
85 |
86 | return {s_points, s_insts, s_lengths};
87 | }
88 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/grid_subsampling/grid_subsampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <vector>
4 | #include "../../common/torch_helper.h"
5 |
6 | std::vector<at::Tensor> grid_subsampling(
7 | at::Tensor points,
8 | at::Tensor insts,
9 | at::Tensor lengths,
10 | float voxel_size
11 | );
12 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include "grid_subsampling_cpu.h"
2 |
3 | void single_grid_subsampling_cpu(
4 | std::vector<PointXYZ>& points,
5 | std::vector<PointXYZ>& s_points,
6 | std::vector<int>& inst,
7 | std::vector<int>& s_inst,
8 | float voxel_size
9 | ) {
10 | // float sub_scale = 1. / voxel_size;
11 | PointXYZ minCorner = min_point(points);
12 | PointXYZ maxCorner = max_point(points);
13 | PointXYZ originCorner = floor(minCorner * (1. / voxel_size)) * voxel_size;
14 |
15 | std::size_t sampleNX = static_cast<std::size_t>(
16 | // floor((maxCorner.x - originCorner.x) * sub_scale) + 1
17 | floor((maxCorner.x - originCorner.x) / voxel_size) + 1
18 | );
19 | std::size_t sampleNY = static_cast<std::size_t>(
20 | // floor((maxCorner.y - originCorner.y) * sub_scale) + 1
21 | floor((maxCorner.y - originCorner.y) / voxel_size) + 1
22 | );
23 |
24 | std::size_t iX = 0;
25 | std::size_t iY = 0;
26 | std::size_t iZ = 0;
27 | std::size_t mapIdx = 0;
28 | std::unordered_map<std::size_t, SampledData> data;
29 | std::unordered_map<std::size_t, SampledInst> data_inst;
30 |
31 | int i = 0;
32 | for (auto& p : points) {
33 | // iX = static_cast(floor((p.x - originCorner.x) * sub_scale));
34 | // iY = static_cast(floor((p.y - originCorner.y) * sub_scale));
35 | // iZ = static_cast(floor((p.z - originCorner.z) * sub_scale));
36 | iX = static_cast<std::size_t>(floor((p.x - originCorner.x) / voxel_size));
37 | iY = static_cast<std::size_t>(floor((p.y - originCorner.y) / voxel_size));
38 | iZ = static_cast<std::size_t>(floor((p.z - originCorner.z) / voxel_size));
39 | mapIdx = iX + sampleNX * iY + sampleNX * sampleNY * iZ;
40 |
41 | if (!data.count(mapIdx)) {
42 | data.emplace(mapIdx, SampledData());
43 | data_inst.emplace(mapIdx, SampledInst());
44 |
45 | }
46 |
47 | data[mapIdx].update(p);
48 | data_inst[mapIdx].update(inst[i]);
49 | i = i + 1;
50 | }
51 |
52 | s_points.reserve(data.size());
53 | for (auto& v : data) {
54 | s_points.push_back(v.second.point * (1.0 / v.second.count));  // voxel centroid
55 | }
56 |
57 | s_inst.reserve(data_inst.size());
58 | for (auto& v : data_inst) {
59 |
60 | std::map<int, int> histo_map;  // histogram of instance ids inside this voxel
61 | for (int i : v.second.inst_ids) {
62 | histo_map[i]++;
63 | }
64 | auto maxElem = std::max_element(histo_map.begin(), histo_map.end(),
65 | [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
66 | return a.second < b.second;
67 | });
68 | s_inst.push_back(maxElem->first);
69 | }
70 |
71 |
72 | }
73 |
74 | void grid_subsampling_cpu(
75 | std::vector<PointXYZ>& points,
76 | std::vector<PointXYZ>& s_points,
77 | std::vector<int>& insts,
78 | std::vector<int>& s_insts,
79 | std::vector<long>& lengths,
80 | std::vector<long>& s_lengths,
81 | float voxel_size
82 | ) {
83 | std::size_t start_index = 0;
84 | std::size_t batch_size = lengths.size();
85 | for (std::size_t b = 0; b < batch_size; b++) {
86 | std::vector<PointXYZ> cur_points = std::vector<PointXYZ>(
87 | points.begin() + start_index,
88 | points.begin() + start_index + lengths[b]
89 | );
90 | std::vector<PointXYZ> cur_s_points;
91 |
92 | std::vector<int> cur_insts = std::vector<int>(
93 | insts.begin() + start_index,
94 | insts.begin() + start_index + lengths[b]
95 | );
96 | std::vector<int> cur_s_insts;
97 |
98 | single_grid_subsampling_cpu(cur_points, cur_s_points, cur_insts, cur_s_insts, voxel_size);
99 |
100 | s_points.insert(s_points.end(), cur_s_points.begin(), cur_s_points.end());
101 | s_insts.insert(s_insts.end(), cur_s_insts.begin(), cur_s_insts.end());
102 | s_lengths.push_back(cur_s_points.size());
103 |
104 | start_index += lengths[b];
105 | }
106 |
107 | return;
108 | }
109 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <vector>
4 | #include <unordered_map>
5 | #include "../../extra/cloud/cloud.h"
6 |
7 | class SampledData {
8 | public:
9 | int count;
10 | PointXYZ point;
11 |
12 | SampledData() {
13 | count = 0;
14 | point = PointXYZ();
15 | }
16 |
17 | void update(const PointXYZ& p) {
18 | count += 1;
19 | point += p;
20 | }
21 | };
22 |
23 | class SampledInst {
24 | public:
25 | int count;
26 | std::vector<int> inst_ids;
27 |
28 | SampledInst() {
29 | count = 0;
30 | }
31 |
32 | void update(const int inst_id) {
33 | count += 1;
34 | inst_ids.push_back(inst_id);
35 | }
36 | };
37 |
38 | void single_grid_subsampling_cpu(
39 | std::vector<PointXYZ>& o_points,
40 | std::vector<PointXYZ>& s_points,
41 | std::vector<int>& inst,
42 | std::vector<int>& s_inst,
43 | float voxel_size
44 | );
45 |
46 | void grid_subsampling_cpu(
47 | std::vector<PointXYZ>& o_points,
48 | std::vector<PointXYZ>& s_points,
49 | std::vector<int>& o_insts,
50 | std::vector<int>& s_insts,
51 | std::vector<long>& o_lengths,
52 | std::vector<long>& s_lengths,
53 | float voxel_size
54 | );
55 |
56 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/radius_neighbors/radius_neighbors.cpp:
--------------------------------------------------------------------------------
1 | #include <cstring>
2 | #include "radius_neighbors.h"
3 | #include "radius_neighbors_cpu.h"
4 |
5 | at::Tensor radius_neighbors(
6 | at::Tensor q_points,
7 | at::Tensor s_points,
8 | at::Tensor q_lengths,
9 | at::Tensor s_lengths,
10 | float radius
11 | ) {
12 | CHECK_CPU(q_points);
13 | CHECK_CPU(s_points);
14 | CHECK_CPU(q_lengths);
15 | CHECK_CPU(s_lengths);
16 | CHECK_IS_FLOAT(q_points);
17 | CHECK_IS_FLOAT(s_points);
18 | CHECK_IS_LONG(q_lengths);
19 | CHECK_IS_LONG(s_lengths);
20 | CHECK_CONTIGUOUS(q_points);
21 | CHECK_CONTIGUOUS(s_points);
22 | CHECK_CONTIGUOUS(q_lengths);
23 | CHECK_CONTIGUOUS(s_lengths);
24 |
25 | std::size_t total_q_points = q_points.size(0);
26 | std::size_t total_s_points = s_points.size(0);
27 | std::size_t batch_size = q_lengths.size(0);
28 |
29 | std::vector<PointXYZ> vec_q_points = std::vector<PointXYZ>(
30 | reinterpret_cast<PointXYZ*>(q_points.data_ptr<float>()),
31 | reinterpret_cast<PointXYZ*>(q_points.data_ptr<float>()) + total_q_points
32 | );
33 | std::vector<PointXYZ> vec_s_points = std::vector<PointXYZ>(
34 | reinterpret_cast<PointXYZ*>(s_points.data_ptr<float>()),
35 | reinterpret_cast<PointXYZ*>(s_points.data_ptr<float>()) + total_s_points
36 | );
37 | std::vector<long> vec_q_lengths = std::vector<long>(
38 | q_lengths.data_ptr<long>(), q_lengths.data_ptr<long>() + batch_size
39 | );
40 | std::vector<long> vec_s_lengths = std::vector<long>(
41 | s_lengths.data_ptr<long>(), s_lengths.data_ptr<long>() + batch_size
42 | );
43 | std::vector<long> vec_neighbor_indices;
44 |
45 | radius_neighbors_cpu(
46 | vec_q_points,
47 | vec_s_points,
48 | vec_q_lengths,
49 | vec_s_lengths,
50 | vec_neighbor_indices,
51 | radius
52 | );
53 |
54 | std::size_t max_neighbors = vec_neighbor_indices.size() / total_q_points;
55 |
56 | at::Tensor neighbor_indices = torch::zeros(
57 | {total_q_points, max_neighbors},
58 | at::device(q_points.device()).dtype(at::ScalarType::Long)
59 | );
60 |
61 | std::memcpy(
62 | neighbor_indices.data_ptr(),
63 | vec_neighbor_indices.data(),
64 | sizeof(long) * total_q_points * max_neighbors
65 | );
66 |
67 | return neighbor_indices;
68 | }
69 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/radius_neighbors/radius_neighbors.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "../../common/torch_helper.h"
4 |
5 | at::Tensor radius_neighbors(
6 | at::Tensor q_points,
7 | at::Tensor s_points,
8 | at::Tensor q_lengths,
9 | at::Tensor s_lengths,
10 | float radius
11 | );
12 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include "radius_neighbors_cpu.h"
2 |
3 |
4 | void radius_neighbors_cpu(
5 | std::vector<PointXYZ>& q_points,
6 | std::vector<PointXYZ>& s_points,
7 | std::vector<long>& q_lengths,
8 | std::vector<long>& s_lengths,
9 | std::vector<long>& neighbor_indices,
10 | float radius
11 | ) {
12 | std::size_t i0 = 0;
13 | float r2 = radius * radius;
14 |
15 | std::size_t max_count = 0;
16 | std::vector<std::vector<std::pair<std::size_t, float>>> all_inds_dists(
17 | q_points.size()
18 | );
19 |
20 | std::size_t b = 0;
21 | std::size_t q_start_index = 0;
22 | std::size_t s_start_index = 0;
23 |
24 | PointCloud current_cloud;
25 | current_cloud.pts = std::vector<PointXYZ>(
26 | s_points.begin() + s_start_index,
27 | s_points.begin() + s_start_index + s_lengths[b]
28 | );
29 |
30 | nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10);
31 | my_kd_tree_t* index = new my_kd_tree_t(3, current_cloud, tree_params);
32 | index->buildIndex();
33 |
34 | nanoflann::SearchParams search_params;
35 | search_params.sorted = true;
36 |
37 | for (auto& p0 : q_points) {
38 | if (i0 == q_start_index + q_lengths[b]) {
39 | q_start_index += q_lengths[b];
40 | s_start_index += s_lengths[b];
41 | b++;
42 |
43 | current_cloud.pts.clear();
44 | current_cloud.pts = std::vector<PointXYZ>(
45 | s_points.begin() + s_start_index,
46 | s_points.begin() + s_start_index + s_lengths[b]
47 | );
48 |
49 | delete index;
50 | index = new my_kd_tree_t(3, current_cloud, tree_params);
51 | index->buildIndex();
52 | }
53 |
54 | all_inds_dists[i0].reserve(max_count);
55 | float query_pt[3] = {p0.x, p0.y, p0.z};
56 | std::size_t nMatches = index->radiusSearch(
57 | query_pt, r2, all_inds_dists[i0], search_params
58 | );
59 |
60 | if (nMatches > max_count) {
61 | max_count = nMatches;
62 | }
63 |
64 | i0++;
65 | }
66 |
67 | delete index;
68 |
69 | neighbor_indices.resize(q_points.size() * max_count);
70 | i0 = 0;
71 | s_start_index = 0;
72 | q_start_index = 0;
73 | b = 0;
74 | for (auto& inds_dists : all_inds_dists) {
75 | if (i0 == q_start_index + q_lengths[b]) {
76 | q_start_index += q_lengths[b];
77 | s_start_index += s_lengths[b];
78 | b++;
79 | }
80 |
81 | for (std::size_t j = 0; j < max_count; j++) {
82 | std::size_t i = i0 * max_count + j;
83 | if (j < inds_dists.size()) {
84 | neighbor_indices[i] = inds_dists[j].first + s_start_index;
85 | } else {
86 | neighbor_indices[i] = s_points.size();  // pad rows with fewer than max_count neighbors using an out-of-range index
87 | }
88 | }
89 |
90 | i0++;
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.h:
--------------------------------------------------------------------------------
1 | #include <vector>
2 | #include "../../extra/cloud/cloud.h"
3 | #include "../../extra/nanoflann/nanoflann.hpp"
4 |
5 | typedef nanoflann::KDTreeSingleIndexAdaptor<
6 | nanoflann::L2_Simple_Adaptor<float, PointCloud>, PointCloud, 3
7 | > my_kd_tree_t;
8 |
9 | void radius_neighbors_cpu(
10 | std::vector<PointXYZ>& q_points,
11 | std::vector<PointXYZ>& s_points,
12 | std::vector<long>& q_lengths,
13 | std::vector<long>& s_lengths,
14 | std::vector<long>& neighbor_indices,
15 | float radius
16 | );
17 |
--------------------------------------------------------------------------------
/sgreg/extensions/extra/cloud/cloud.cpp:
--------------------------------------------------------------------------------
1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch
2 | #include "cloud.h"
3 |
4 | PointXYZ max_point(std::vector<PointXYZ> points) {
5 | PointXYZ maxP(points[0]);
6 |
7 | for (auto p : points) {
8 | if (p.x > maxP.x) {
9 | maxP.x = p.x;
10 | }
11 | if (p.y > maxP.y) {
12 | maxP.y = p.y;
13 | }
14 | if (p.z > maxP.z) {
15 | maxP.z = p.z;
16 | }
17 | }
18 |
19 | return maxP;
20 | }
21 |
22 | PointXYZ min_point(std::vector<PointXYZ> points) {
23 | PointXYZ minP(points[0]);
24 |
25 | for (auto p : points) {
26 | if (p.x < minP.x) {
27 | minP.x = p.x;
28 | }
29 | if (p.y < minP.y) {
30 | minP.y = p.y;
31 | }
32 | if (p.z < minP.z) {
33 | minP.z = p.z;
34 | }
35 | }
36 |
37 | return minP;
38 | }
--------------------------------------------------------------------------------
/sgreg/extensions/extra/cloud/cloud.h:
--------------------------------------------------------------------------------
1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch
2 | #pragma once
3 |
4 | #include <cmath>
5 | #include <vector>
6 | #include <unordered_map>