├── .gitignore
├── README.md
├── config
│   ├── rio.yaml
│   └── scannet.yaml
├── docs
│   ├── rerun_interface.png
│   └── system.001.png
├── requirements.txt
├── setup.py
├── sgreg
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── senet.py
│   │   └── shape_encoder.py
│   ├── bert
│   │   ├── bertwarper.py
│   │   └── get_tokenizer.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataset_factory.py
│   │   ├── generate_gt_association.py
│   │   ├── prepare_semantics.py
│   │   ├── scene_graph.py
│   │   ├── scene_pair_dataset.py
│   │   └── stack_mode.py
│   ├── extensions
│   │   ├── README.md
│   │   ├── common
│   │   │   └── torch_helper.h
│   │   ├── cpu
│   │   │   ├── grid_subsampling
│   │   │   │   ├── grid_subsampling.cpp
│   │   │   │   ├── grid_subsampling.h
│   │   │   │   ├── grid_subsampling_cpu.cpp
│   │   │   │   └── grid_subsampling_cpu.h
│   │   │   └── radius_neighbors
│   │   │       ├── radius_neighbors.cpp
│   │   │       ├── radius_neighbors.h
│   │   │       ├── radius_neighbors_cpu.cpp
│   │   │       └── radius_neighbors_cpu.h
│   │   ├── extra
│   │   │   ├── cloud
│   │   │   │   ├── cloud.cpp
│   │   │   │   └── cloud.h
│   │   │   └── nanoflann
│   │   │       └── nanoflann.hpp
│   │   └── pybind.cpp
│   ├── gnn
│   │   ├── __init__.py
│   │   ├── gnn.py
│   │   ├── nodes_init_layer.py
│   │   ├── spatial_attention.py
│   │   └── triplet_gnn.py
│   ├── kpconv
│   │   ├── __init__.py
│   │   ├── dispositions
│   │   │   └── k_015_center_3D.ply
│   │   ├── functional.py
│   │   ├── kernel_points.py
│   │   ├── kpconv.py
│   │   └── modules.py
│   ├── loss
│   │   ├── eval.py
│   │   └── loss.py
│   ├── match
│   │   ├── __init__.py
│   │   ├── learnable_sinkhorn.py
│   │   └── match.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── grid_subsample.py
│   │   ├── index_select.py
│   │   ├── instance_partition.py
│   │   ├── pairwise_distance.py
│   │   ├── radius_search.py
│   │   └── transformation.py
│   ├── registration
│   │   ├── __init__.py
│   │   ├── hybrid_reg.py
│   │   ├── local_global_registration.py
│   │   ├── metrics.py
│   │   ├── offline_registration.py
│   │   └── procrustes.py
│   ├── sg_reg.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── io.py
│   │   ├── tictoc.py
│   │   ├── torch.py
│   │   ├── utils.py
│   │   └── viz_tools.py
│   ├── val.py
│   └── visualize.py
└── tutorials
    ├── explicit_sg.png
    └── rag_data.ipynb


/.gitignore:
--------------------------------------------------------------------------------
build
data
output
checkpoints
sgreg/__pycache__
sgreg/*/__pycache__
.vscode

*.egg-info
*.so
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SG-Reg: Generalizable and Efficient Scene Graph Registration

Accepted by IEEE T-RO

Chuhao Liu¹, Zhijian Qiao¹, Jieqi Shi²\*, Ke Wang³, Peize Liu¹ and Shaojie Shen¹

¹HKUST Aerial Robotics Group, ²Nanjing University, ³Chang'an University

\*Corresponding Author

[![T-RO](https://img.shields.io/badge/IEEE-T--RO-004c99)](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=8860)
[![Arxiv](https://img.shields.io/badge/arXiv-2504.14440-990000)](https://arxiv.org/abs/2504.14440)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/s3P1FvbQGhs)
[![Bilibili](https://img.shields.io/badge/Video-Bilibili-pink)](https://www.bilibili.com/video/BV1ymLWzaEMo/)
[![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/glennliu/sgnet)

### News
* [21 Apr 2025] Published the initial version of the code.
* [19 Apr 2025] Our paper is accepted by [IEEE T-RO](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=8860) as a regular paper.
* [8 Oct 2024] Paper submitted to IEEE T-RO.

In this work, we **learn to register two semantic scene graphs**. This is an essential capability when an autonomous agent needs to register its map against a remote agent's map or against a prior map. To achieve generalizable registration in the real world, we design a scene graph network that encodes multiple modalities of each semantic node: open-set semantic features, local topology with spatial awareness, and shape features. SG-Reg represents a dense indoor scene with **coarse node features** and **dense point features**. In multi-agent SLAM systems, this representation supports both coarse-to-fine localization and bandwidth-efficient communication.

We generate semantic scene graphs using [vision foundation models](https://github.com/IDEA-Research/Grounded-Segment-Anything) and the semantic mapping module [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion). This eliminates the need for ground-truth semantic annotations and enables **fully self-supervised network training**.

We evaluate our method on real-world RGB-D sequences: [ScanNet](https://github.com/ScanNet/ScanNet), [3RScan](https://github.com/WaldJohannaU/3RScan), and self-collected data recorded with a [Realsense i-435](https://www.intelrealsense.com/lidar-camera-l515/).

## 1. Install
Create a virtual environment,
```bash
conda create -n sgreg python=3.9
conda activate sgreg
```
Install PyTorch 2.1.2 and the other dependencies.
```bash
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
```
```bash
pip install -r requirements.txt
python setup.py build develop
```

## 2. Download Dataset
Download the *3RScan (RIO)* data from this [Nutstore (坚果云) link](https://www.jianguoyun.com/p/DVNIaZYQmcSyDRjX8PQFIAA). It contains 50 pairs of scene graphs. In ```RIO_DATAROOT```, the data are organized as follows:
```
|--val
    |--scenexxxx_00a % each individual scene graph
    |-- ....
|--splits
    |-- val.txt
|--gt
    |-- SRCSCENE-REFSCENE.txt % T_ref_src
|--matches
    |-- SRCSCENE-REFSCENE.pth % ground-truth node matches
|--output
    |--CHECKPOINT_NAME % default: sgnet_scannet_0080
        |--SRCSCENE-REFSCENE % results of scene pair
```

We also provide another 50 pairs of *ScanNet* scenes. Download the ScanNet data from this [Nutstore (坚果云) link](https://www.jianguoyun.com/p/DSJqTN8QmcSyDRjZ8PQFIAA). They are organized in the same structure as the 3RScan data.

*Note: We did not use any ground-truth semantic annotations from [3RScan](https://github.com/WaldJohannaU/3RScan) or [ScanNet](https://github.com/ScanNet/ScanNet). The downloaded scene graphs are reconstructed using [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion). You can also download the original RGB-D sequences and build your own scene graphs with FM-Fusion. If you want to try this, the ScanNet sequences are easier to start with.*

## 3. Inference on 3RScan Scenes
In [config/rio.yaml](config/rio.yaml), set ```dataset/dataroot``` to the ```RIO_DATAROOT``` directory on your machine. Then run the inference program,
```bash
python sgreg/val.py --cfg_file config/rio.yaml
```
It runs inference on all of the downloaded 3RScan scene pairs. The registration results are saved at ```RIO_DATAROOT/output/CHECKPOINT_NAME/SRCSCENE-REFSCENE```. They include the matched nodes, the point correspondences, and the predicted transformation.
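The ground truth can also be inspected outside the provided scripts. The sketch below lists the scene pairs under `gt/` and applies a stored `T_ref_src` to a point. It is a hypothetical helper, not part of SG-Reg, and it makes two assumptions. First, each `SRCSCENE-REFSCENE.txt` is assumed to hold a row-major 4x4 matrix written as 16 whitespace-separated numbers. Second, scene names are assumed to contain no `-`. Please check one downloaded file before relying on it.

```python
from pathlib import Path
from typing import List, Sequence, Tuple

# Hypothetical helpers for illustration; they are not part of the SG-Reg code base.
Matrix4 = List[List[float]]


def list_scene_pairs(dataroot: str) -> List[Tuple[str, str]]:
    """Return (src, ref) names parsed from gt/SRCSCENE-REFSCENE.txt file names.

    Assumes the scene names themselves contain no '-'.
    """
    pairs = []
    for path in sorted(Path(dataroot, "gt").glob("*.txt")):
        src, sep, ref = path.stem.partition("-")
        if sep:  # skip files that do not follow the SRC-REF naming
            pairs.append((src, ref))
    return pairs


def load_transform(path: str) -> Matrix4:
    """Read T_ref_src, assumed to be 16 whitespace-separated numbers (row-major 4x4)."""
    values = [float(v) for v in Path(path).read_text().split()]
    if len(values) != 16:
        raise ValueError(f"expected 16 values in {path}, got {len(values)}")
    return [values[row * 4:(row + 1) * 4] for row in range(4)]


def transform_point(T: Matrix4, p: Sequence[float]) -> Tuple[float, ...]:
    """Map a source-frame point into the reference frame: p_ref = R @ p_src + t."""
    x, y, z = p
    return tuple(T[r][0] * x + T[r][1] * y + T[r][2] * z + T[r][3] for r in range(3))
```

For example, `list_scene_pairs(RIO_DATAROOT)` returns `(src, ref)` tuples. Under the usual `T_ref_src` convention, `transform_point(load_transform(path), p)` expresses a source-scene point in the reference-scene frame.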
You can visualize the registration results,
```bash
python sgreg/visualize.py --dataroot $RIO_DATAROOT --viz_mode 1 --find_gt --viz_translation "[3.0,5.0,0.0]"
```
The results should be displayed as shown below,

*[Figure: registration results displayed in the rerun viewer]*

In the left column, you can select the entities to visualize.

If you run the program on a remote server, rerun supports remote visualization (see [rerun connect_tcp](https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun)). Check the argument descriptions in [visualize.py](sgreg/visualize.py) to customize your visualization.

*[Optional]* To evaluate SG-Reg on the ScanNet sequences, adjust the commands as below,
```bash
python sgreg/val.py --cfg_file config/scannet.yaml
python sgreg/visualize.py --dataroot $SCANNET_DATAROOT --viz_mode 1 --augment_transform --viz_translation "[3.0,5.0,0.0]"
```

## 4. Evaluate on your own data
We think generalization remains a key challenge in 3D semantic perception. If you are interested in this task, we encourage you to collect your own RGB-D sequences for evaluation.
This requires three tools:
- [VINS-Mono](https://github.com/HKUST-Aerial-Robotics/VINS-Mono) to compute camera poses.
- [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to generate semantic labels.
- [FM-Fusion](https://github.com/HKUST-Aerial-Robotics/FM-Fusion) to reconstruct a semantic scene graph.

We will add detailed instructions on building your own data later.

## 5. Development Log
- [x] Release the scene graph network code and verify its inference.
- [x] Remove unnecessary dependencies.
- [x] Clean the data structure.
- [x] Visualize the results.
- [x] Provide RIO scene graph data for download.
- [x] Provide network weights for download.
- [x] Publish the checkpoint on the Hugging Face Hub and support reloading it.
- [ ] Registration back-end with a Python interface. (The version used in the paper is implemented in C++.)
- [ ] Validate the entire system on a new computer.
- [x] A tutorial for running the validation.

We will continue to maintain this repo.
If you encounter any problems using it, feel free to open an issue. We'll try to help.

## 6. Acknowledgements
We use code from [GeoTransformer](https://github.com/qinzheng93/GeoTransformer), [SG-PGM](https://github.com/dfki-av/sg-pgm) and [LightGlue](https://github.com/cvg/LightGlue). [SkyLand](https://www.futureis3d.com) provided a lidar-camera suite that allowed us to evaluate SG-Reg in large-scale scenes (as demonstrated at the end of the [video](https://youtu.be/k_kPFKcj-jo)).

## 7. License
The source code is released under the [GPLv3](https://www.gnu.org/licenses/) license.
For technical issues, please contact Chuhao LIU (cliuci@connect.ust.hk).

--------------------------------------------------------------------------------
/config/rio.yaml:
--------------------------------------------------------------------------------
1 | model: 2 | fix_modules: ['backbone','optimal_transport','shape_backbone','sgnn'] 3 | dataset: 4 | dataroot: /data2/RioGraph 5 | load_transform: True 6 | min_iou: 0.2 # for gt match 7 | global_edges_dist: 2.0 8 | online_bert: True 9 | train: 10 | batch_size: 1 11 | num_workers: 8 12 | epochs: 80 13 | optimizer: Adam # Adam, SGD 14 | lr: 0.01 15 | weight_decay: 0.0001 16 | momentum: 0.9 17 | save_interval: 10 18 | val_interval: 1 19 | registration_in_train: True 20 | backbone: 21 | num_stages: 4 22 | init_voxel_size: 0.05 23 | base_radius: 2.5 24 | base_sigma: 2.0 25 | input_dim: 1 26 | output_dim: 256 27 | init_dim: 64 28 | kernel_size: 15 29 | group_norm: 32 30 | shape_encoder: 31 | input_from_stages: 1 32 | output_dim: 1024 33 | kernel_size: 15 34 | init_radius: 2.5 # 0.125 35 | init_sigma: 2.0 # 0.1 36 | group_norm: 32 37 | point_limit: 2048 38 | scenegraph: 39 | bert_dim: 768 40 | semantic_dim: 64 41 | pos_dim: -1 42 | box_dim: 8 43 | fuse_shape: True 44 | fuse_stage: late 45 | node_dim: 64 46 | encode_method: 'gnn' 47 | gnn: 48 | all_self_edges: False 49 | position_encoding: true 50 |
layers: [ltriplet] # sage,self,gtop 51 | triplet_mlp: concat # concat, ffn or projector 52 | triplet_activation: gelu 53 | triplet_number: 20 54 | heads: 1 55 | hidden_dim : 16 56 | enable_dist_embedding: true 57 | enable_angle_embedding: True 58 | reduce: 'mean' 59 | se_layer: False 60 | instance_matching: 61 | match_layers: [0,1,2] 62 | topk: 3 63 | min_score: 0.1 64 | multiply_matchability: false 65 | fine_matching: 66 | min_nodes: 3 67 | num_points_per_instance: 256 68 | num_sinkhorn_iterations: 100 69 | topk: 3 70 | acceptance_radius: 0.1 71 | max_instance_selection: true 72 | mutual: True 73 | confidence_threshold: 0.05 74 | use_dustbin: True 75 | ignore_semantics: [floor, carpet, ceiling,] #[floor, carpet, ceiling] 76 | loss: 77 | instance_match: 78 | loss_func: nllv2 # overlap, nll 79 | nll_negative: 0.0 80 | gt_alpha: 2.0 81 | fine_loss_weight: -1.0 82 | shape_contrast_weight: -1.0 83 | gnode_match_weight: 1.0 84 | start_epoch: -1 85 | positive_point_radius: 0.1 86 | contrastive_temp: 0.1 87 | contrastive_postive_overlap: 0.3 88 | # nll_matchability: False 89 | eval: 90 | acceptance_overlap: 0.0 91 | acceptance_radius: 0.1 92 | rmse_threshold: 0.2 93 | gt_node_iou: 0.3 -------------------------------------------------------------------------------- /config/scannet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # pretrain_weight: chkt2.0/pretrain_finetuned5/sgnet-0064.pth 3 | fix_modules: ['backbone','optimal_transport','shape_backbone','sgnn'] 4 | dataset: 5 | dataroot: /data2/ScanNetGraph 6 | min_iou: 0.3 # for gt match 7 | global_edges_dist: 1.0 8 | online_bert: False 9 | load_transform: False 10 | train: 11 | batch_size: 1 12 | num_workers: 8 13 | epochs: 64 14 | optimizer: Adam # Adam, SGD 15 | lr: 0.01 16 | weight_decay: 0.0001 17 | momentum: 0.9 18 | save_interval: 32 19 | val_interval: 2 20 | registration_in_train: true 21 | backbone: 22 | num_stages: 4 23 | init_voxel_size: 0.05 24 | 
base_radius: 2.5 25 | base_sigma: 2.0 26 | input_dim: 1 27 | output_dim: 256 28 | init_dim: 64 29 | kernel_size: 15 30 | group_norm: 32 31 | shape_encoder: 32 | input_from_stages: 1 33 | output_dim: 1024 34 | kernel_size: 15 35 | init_radius: 2.5 # 0.125 36 | init_sigma: 2.0 # 0.1 37 | group_norm: 32 38 | point_limit: 2048 39 | scenegraph: 40 | bert_dim: 768 41 | semantic_dim: 64 42 | pos_dim: -1 43 | box_dim: 8 44 | fuse_shape: true 45 | fuse_stage: late 46 | node_dim: 64 47 | encode_method: 'gnn' # 'geotransformer' or 'gnn' 48 | geotransformer: 49 | num_heads: 1 50 | blocks: [self] 51 | sigma_d: 0.2 52 | sigma_a: 15 53 | angle_k: 3 54 | reduction_a: max 55 | gnn: 56 | all_self_edges: false 57 | position_encoding: true 58 | layers: [ltriplet] # sage,self,gtop,sattn 59 | triplet_mlp: concat # concat, ffn or projector 60 | triplet_activation: relu 61 | heads: 1 62 | hidden_dim : 16 63 | enable_dist_embedding: true 64 | enable_angle_embedding: true 65 | reduce: 'mean' 66 | se_layer: False 67 | instance_matching: 68 | match_layers: [0,1,2] 69 | topk: 3 70 | min_score: 0.1 71 | multiply_matchability: false 72 | fine_matching: 73 | min_nodes: 3 74 | num_points_per_instance: 256 75 | num_sinkhorn_iterations: 100 76 | topk: 3 77 | acceptance_radius: 0.1 78 | max_instance_selection: true 79 | mutual: True 80 | confidence_threshold: 0.05 81 | use_dustbin: false 82 | ignore_semantics: [floor, carpet] 83 | loss: 84 | instance_match: 85 | loss_func: nllv2 # overlap, nll 86 | nll_negative: 0.0 87 | gt_alpha: 2.0 88 | fine_loss_weight: -1.0 89 | shape_contrast_weight: -1.0 90 | gnode_match_weight: 1.0 91 | start_epoch: -1 92 | positive_point_radius: 0.1 93 | contrastive_temp: 0.1 94 | contrastive_postive_overlap: 0.3 95 | eval: 96 | acceptance_overlap: 0.0 97 | acceptance_radius: 0.1 98 | rmse_threshold: 0.2 99 | gt_node_iou: 0.3 -------------------------------------------------------------------------------- /docs/rerun_interface.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/docs/rerun_interface.png
--------------------------------------------------------------------------------
/docs/system.001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/docs/system.001.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch_geometric
numpy==1.26.2
omegaconf
transformers
open3d
matplotlib
scipy
tensorboard
rerun-sdk
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
''' This is a module adapted from SG-PGM:
git@github.com:dfki-av/sg-pgm.git
It maintains the instance label while downsampling the point cloud.
'''

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='kpconv_extension',
    version='1.0.0',
    ext_modules=[
        CUDAExtension(
            name='.ext',
            sources=[
                'sgreg/extensions/extra/cloud/cloud.cpp',
                'sgreg/extensions/cpu/grid_subsampling/grid_subsampling.cpp',
                'sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp',
                'sgreg/extensions/cpu/radius_neighbors/radius_neighbors.cpp',
                'sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp',
                'sgreg/extensions/pybind.cpp',
            ],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

--------------------------------------------------------------------------------
/sgreg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/__init__.py
--------------------------------------------------------------------------------
/sgreg/backbone/__init__.py:
--------------------------------------------------------------------------------
from sgreg.backbone.backbone import KPConvFPN, encode_batch_scenes_points
from sgreg.backbone.shape_encoder import KPConvShape, encode_batch_scenes_instances
from sgreg.backbone.senet import SEModule
--------------------------------------------------------------------------------
/sgreg/backbone/backbone.py:
--------------------------------------------------------------------------------
import time
import torch
import torch.nn as nn
from sgreg.utils.tictoc import TicToc
from sgreg.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample

def encode_batch_scenes_points(backbone:nn.Module,
                               batch_graph_dict:dict):
    ''' Read a batch of scene pairs.
    Encode all of their fine points and features.
    '''
    stack_feats_f = []
    stack_feats_f_batch = [torch.tensor(0).long()]

    for scene_id in torch.arange(batch_graph_dict['batch_size']):
        tictoc = TicToc()
        feats = batch_graph_dict['batch_features'][scene_id].detach()
        data_dict = batch_graph_dict['batch_points'][scene_id]

        # 2. KPFCNN Encoder
        feats_list = backbone(feats,
                              data_dict['points'],
                              data_dict['neighbors'],
                              data_dict['subsampling'],
                              data_dict['upsampling'])
        duration = tictoc.toc()
        # Fine-level points and features
        points_f = data_dict['points'][1]
        feats_f = feats_list[0]
        assert feats_f.shape[0] == points_f.shape[0], 'feats_f and points_f shapes do not match'

        stack_feats_f.append(feats_f) # (P+Q,C)
        stack_feats_f_batch.append(stack_feats_f_batch[-1]+feats_f.shape[0])

    # Stack all scenes
    stack_feats_f = torch.cat(stack_feats_f,dim=0) # (P+Q,C)
    stack_feats_f_batch = torch.stack(stack_feats_f_batch,dim=0) # (B+1,)

    return {'feats_f':stack_feats_f,
            'feats_batch':stack_feats_f_batch}, duration

class KPConvFPN(nn.Module):
    def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
        super(KPConvFPN, self).__init__()

        self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)
        self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)

        self.encoder2_1 = ResidualBlock(
            init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
        )
        self.encoder2_2 = ResidualBlock(
            init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
        )
        self.encoder2_3 = ResidualBlock(
            init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
        )

        self.encoder3_1 = ResidualBlock(
            init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True
        )
        self.encoder3_2 = ResidualBlock(
            init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
        )
        self.encoder3_3 = ResidualBlock(
            init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
        )

        self.encoder4_1 = ResidualBlock(
            init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm, strided=True
        )
        self.encoder4_2 = ResidualBlock(
            init_dim * 8, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
        )
        self.encoder4_3 = ResidualBlock(
            init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
        )

        self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm)
        self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim)

    def forward(self, feats,
                points_list,
                neighbors_list,
                subsampling_list,
                upsampling_list):
        '''
        Read a scene pair.
        Encode all the fine points and features.
        '''
        feats_list = []

        feats_s1 = feats
        feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0])
        feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0])

        feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0])
        feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1])
        feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1])

        feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1])
        feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2])
        feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2])

        feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2])
        feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3])
        feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3])

        latent_s4 = feats_s4
        feats_list.append(feats_s4)

        latent_s3 = nearest_upsample(latent_s4, upsampling_list[2])
        latent_s3 = torch.cat([latent_s3, feats_s3], dim=1)
        latent_s3 = self.decoder3(latent_s3)
        feats_list.append(latent_s3)

        latent_s2 = nearest_upsample(latent_s3, upsampling_list[1])
        latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
        latent_s2 = self.decoder2(latent_s2)
        feats_list.append(latent_s2)

        feats_list.reverse()

        return feats_list


# This is a copy of the KPConvFPN. It is used to compute the FLOPS of the backbone.
# It takes every input tensor as a separate argument instead of nested lists.
class TmpKPConvFPN(KPConvFPN):
    # the init function is the same as the KPConvFPN
    def forward(self, feats,
                points0,
                points1,
                points2,
                points3,
                neighbors0,
                neighbors1,
                neighbors2,
                neighbors3,
                subsampling0,
                subsampling1,
                subsampling2,
                upsampling0,
                upsampling1,
                upsampling2):
        '''
        Read a scene pair.
        Encode all the fine points and features.
        '''
        feats_list = []

        feats_s1 = feats
        feats_s1 = self.encoder1_1(feats_s1, points0, points0, neighbors0)
        feats_s1 = self.encoder1_2(feats_s1, points0, points0, neighbors0)

        feats_s2 = self.encoder2_1(feats_s1, points1, points0, subsampling0)
        feats_s2 = self.encoder2_2(feats_s2, points1, points1, neighbors1)
        feats_s2 = self.encoder2_3(feats_s2, points1, points1, neighbors1)

        feats_s3 = self.encoder3_1(feats_s2, points2, points1, subsampling1)
        feats_s3 = self.encoder3_2(feats_s3, points2, points2, neighbors2)
        feats_s3 = self.encoder3_3(feats_s3, points2, points2, neighbors2)

        feats_s4 = self.encoder4_1(feats_s3, points3, points2, subsampling2)
        feats_s4 = self.encoder4_2(feats_s4, points3, points3, neighbors3)
        feats_s4 = self.encoder4_3(feats_s4, points3, points3, neighbors3)

        latent_s4 = feats_s4
        feats_list.append(feats_s4)

        latent_s3 = nearest_upsample(latent_s4, upsampling2)
        latent_s3 = torch.cat([latent_s3, feats_s3], dim=1)
        latent_s3 = self.decoder3(latent_s3)
        feats_list.append(latent_s3)

        latent_s2 = nearest_upsample(latent_s3, upsampling1)
        latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
        latent_s2 = self.decoder2(latent_s2)
        feats_list.append(latent_s2)

        feats_list.reverse()
        return feats_list

if __name__=='__main__':
    print('Try init a KPConvFPN')
    model = KPConvFPN(input_dim=3,
                      output_dim=64,
188 | init_dim=16, 189 | kernel_size=15, 190 | init_radius=0.2, 191 | init_sigma=0.1, 192 | group_norm=8) 193 | 194 | -------------------------------------------------------------------------------- /sgreg/backbone/senet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SEModule(nn.Module): 7 | def __init__(self, channels, reduction=16): 8 | super(SEModule, self).__init__() 9 | self.squeeze = nn.AdaptiveAvgPool1d(1) 10 | self.execitation = nn.Sequential( 11 | nn.Linear(channels, channels // reduction, bias=False), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(channels // reduction, channels, bias=False), 14 | nn.Sigmoid() 15 | ) 16 | 17 | def forward(self, x): 18 | ''' 19 | Input, 20 | - x: (N,C) 21 | ''' 22 | n, c = x.size() 23 | y = self.squeeze(x.t()).t() # (C) 24 | y = self.execitation(y) # (C,1) 25 | out = x * y.expand_as(x) # (N,C) 26 | return out 27 | 28 | 29 | if __name__=='__main__': 30 | se = SEModule(32) 31 | x = torch.randn(4,32) 32 | out = se(x) 33 | print(out.shape) -------------------------------------------------------------------------------- /sgreg/backbone/shape_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | from sgreg.ops import sample_instance_from_points 5 | from sgreg.kpconv import ResidualBlock, LastUnaryBlock 6 | from sgreg.utils.tictoc import TicToc 7 | 8 | class KPConvShape(nn.Module): 9 | def __init__(self, 10 | input_dim, 11 | output_dim, 12 | kernel_size, 13 | init_radius, 14 | init_sigma, 15 | group_norm, 16 | K_shape=1024, 17 | K_match=256, 18 | decoder=True): 19 | super(KPConvShape, self).__init__() 20 | 21 | self.encoder1_1 = ResidualBlock( 22 | input_dim, output_dim, kernel_size, init_radius, init_sigma, group_norm, strided=True 23 | ) 24 | 25 | if decoder: 26 | self.decoder = 
LastUnaryBlock(output_dim+input_dim, input_dim) 27 | 28 | self.input_dim = input_dim 29 | self.output_dim = output_dim 30 | self.K_shape = K_shape 31 | self.K_match = K_match 32 | 33 | def forward(self, f_points, 34 | f_feats, 35 | f_instances, 36 | instances_centroids, 37 | instances_points_indices, 38 | decode_points): 39 | ''' 40 | Input: 41 | f_points: (P, 3), 42 | f_feats: (P, C), 43 | f_instances: (P, 1), 44 | instances_centroids: (N, 3), instance centroid poistion 45 | instances_points_indices: (N, H), instance knn indices 46 | ''' 47 | 48 | #! encoder reduce feature dimension C->C/4 49 | instance_feats = self.encoder1_1(f_feats, instances_centroids, f_points, instances_points_indices) # (N, C) 50 | assert instance_feats.shape[0] == instances_centroids.shape[0] 51 | if decode_points: # This is only used in pre-train constrastive learning. In validation and deployment, skip it. 52 | f_instance_feats = instance_feats[f_instances.squeeze()] # (P, C) 53 | feats_decoded = self.decoder(torch.cat([f_instance_feats, f_feats], dim=1)) 54 | assert feats_decoded.shape[0] == f_points.shape[0] 55 | return instance_feats, feats_decoded 56 | else: 57 | return instance_feats, None 58 | 59 | def concat_instance_points(ref_instance_number:torch.Tensor, 60 | src_instance_number:torch.Tensor, 61 | ref_instance_points:torch.Tensor, 62 | src_instance_points:torch.Tensor, 63 | device:torch.device): 64 | ''' build instance-labeled points and instance list. 
65 | Return: 66 | instance_list: (N+M,), instance_points: (P+Q, 3) 67 | ''' 68 | ref_instance_list = torch.arange(ref_instance_number).to(device) 69 | src_instance_list = torch.arange(src_instance_number).to(device)+ref_instance_number 70 | src_instance_points = src_instance_points + ref_instance_number 71 | instance_list = torch.cat([ref_instance_list,src_instance_list],dim=0) # (N+M,) 72 | instance_points = torch.cat([ref_instance_points,src_instance_points],dim=0) # (P+Q, 3) 73 | 74 | return instance_list, instance_points 75 | 76 | def encode_batch_scenes_instances(shape_backbone:nn.Module, 77 | batch_graph_pair:dict, 78 | instance_f_feats_dict:dict, 79 | decode_points=True, 80 | verify=True): 81 | tictoc = TicToc() 82 | stack_instances_shape = [] 83 | stack_instance_pts_match = [] 84 | stack_instances_batch = [torch.tensor(0).long()] 85 | stack_feats_f_decoded = [] 86 | ref_graph_batch = batch_graph_pair['ref_graph']['batch'] 87 | src_graph_batch = batch_graph_pair['src_graph']['batch'] 88 | invalid_instance_exist = False 89 | duration_list = [] 90 | 91 | for scene_id in torch.arange(batch_graph_pair['batch_size']): 92 | # Extract scene data 93 | data_dict = batch_graph_pair['batch_points'][scene_id] 94 | f_pts_length = data_dict['lengths'][1] # [P,Q] 95 | assert f_pts_length.device == data_dict['insts'][1].device, 'f_pts_length and insts device are not match' 96 | assert data_dict['insts'][1].is_contiguous(), 'insts[1] is not contiguous' 97 | tmp_instances_f = data_dict['insts'][1] 98 | # todo: this step is slow. approx. 
50ms 99 | ref_instances_f = tmp_instances_f[:f_pts_length[0]] # (P,) 100 | src_instances_f = tmp_instances_f[f_pts_length[0]:] # (Q,) 101 | points_f = data_dict['points'][1] # (P+Q, 3) 102 | duration_list.append(tictoc.toc()) 103 | 104 | feats_b0 = instance_f_feats_dict['feats_batch'][scene_id] 105 | feats_b1 = instance_f_feats_dict['feats_batch'][scene_id+1] 106 | feats_f = instance_f_feats_dict['feats_f'][feats_b0:feats_b1] # (P+Q, C) 107 | 108 | instance_list, instances_f = concat_instance_points( 109 | ref_graph_batch[scene_id+1]-ref_graph_batch[scene_id], 110 | src_graph_batch[scene_id+1]-src_graph_batch[scene_id], 111 | ref_instances_f, 112 | src_instances_f, 113 | feats_f.device) 114 | assert points_f.shape[0] == instances_f.shape[0] \ 115 | and points_f.shape[0]==feats_f.shape[0], 'points, feats, and instances shape are not match' 116 | 117 | # 118 | instance_fpts_indxs_shape, invalid_instance_mask = \ 119 | sample_instance_from_points(instances_f, instance_list, 120 | shape_backbone.K_shape, points_f.shape[0]) # (N+M,Ks), (N+M,) 121 | instance_fpts_indxs_match, _ =\ 122 | sample_instance_from_points(instances_f, instance_list, 123 | shape_backbone.K_match, points_f.shape[0]) # (N+M,Km), (N+M,) 124 | if(torch.any(invalid_instance_mask)): # involves invalid instances 125 | print('instance without points in {}'.format(batch_graph_pair['src_scan'][scene_id])) 126 | # assert False, 'Some instances are not assigned to any fine points.' 
127 | instances_shape = torch.zeros((instance_list.shape[0], 128 | shape_backbone.output_dim)).to(feats_f.device) 129 | feats_f_decoded = torch.zeros((feats_f.shape[0], 130 | feats_f.shape[1])).to(feats_f.device) 131 | invalid_instance_exist = True 132 | else: # shape instance-wise shapes and points 133 | # Extract instance centroids 134 | instance_f_points = points_f[instance_fpts_indxs_shape] # (N+M, Ks, 3) 135 | instance_f_centers = instance_f_points.mean(dim=1) # (N+M, 3) 136 | 137 | if verify: 138 | ref_instance_centroids = batch_graph_pair['ref_graph']['centroids'][batch_graph_pair['ref_graph']['scene_mask']==scene_id] 139 | src_instance_centroids = batch_graph_pair['src_graph']['centroids'][batch_graph_pair['src_graph']['scene_mask']==scene_id] 140 | instance_centroids = torch.cat([ref_instance_centroids,src_instance_centroids],dim=0) # (M+N, 3) 141 | dist = torch.abs((instance_centroids - instance_f_centers).mean(dim=1)) # (M+N,) 142 | assert dist.max()<1.0, 'Center distance is too large {:.4f}m'.format(dist.max()) 143 | 144 | instances_shape, feats_f_decoded = shape_backbone(points_f, 145 | feats_f, 146 | instances_f, 147 | instance_f_centers, 148 | instance_fpts_indxs_shape, 149 | decode_points) 150 | duration_list.append(tictoc.toc()) 151 | 152 | stack_instances_shape.append(instances_shape) 153 | stack_instance_pts_match.append(instance_fpts_indxs_match) 154 | stack_instances_batch.append(stack_instances_batch[-1]+instance_list.shape[0]) 155 | 156 | if decode_points: 157 | assert feats_f_decoded is not None, 'feats_f_decoded is None' 158 | stack_feats_f_decoded.append(feats_f_decoded) # (P+Q, C0) 159 | 160 | # Concatenate all scenes 161 | stack_instances_shape = torch.cat(stack_instances_shape,dim=0) # (M+N, C1) 162 | stack_instances_shape = nn.functional.normalize(stack_instances_shape, dim=1) # (M+N, C1) 163 | instance_f_feats_dict['instances_shape'] = stack_instances_shape 164 | instance_f_feats_dict['instances_f_indices_match'] = 
torch.cat(stack_instance_pts_match,dim=0) 165 | instance_f_feats_dict['instances_batch'] = torch.stack(stack_instances_batch,dim=0) # (B+1,) 166 | instance_f_feats_dict['invalid_instance_exist'] = invalid_instance_exist 167 | 168 | if decode_points: 169 | stack_feats_f_decoded = torch.cat(stack_feats_f_decoded,dim=0) # (P+Q, C0) 170 | instance_f_feats_dict['feats_f_decoded'] = stack_feats_f_decoded 171 | duration_list.append(tictoc.toc()) 172 | 173 | # timing 174 | # msg = ['{:.2f}'.format(1000*t) for t in duration_list] 175 | # print('Shape Encoder: ', ' '.join(msg)) 176 | 177 | return instance_f_feats_dict, duration_list[1] -------------------------------------------------------------------------------- /sgreg/bert/get_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast 2 | 3 | 4 | def get_tokenlizer(text_encoder_type): 5 | if not isinstance(text_encoder_type, str): 6 | # print("text_encoder_type is not a str") 7 | if hasattr(text_encoder_type, "text_encoder_type"): 8 | text_encoder_type = text_encoder_type.text_encoder_type 9 | elif text_encoder_type.get("text_encoder_type", False): 10 | text_encoder_type = text_encoder_type.get("text_encoder_type") 11 | else: 12 | raise ValueError( 13 | "Unknown type of text_encoder_type: {}".format(type(text_encoder_type)) 14 | ) 15 | print("final text_encoder_type: {}".format(text_encoder_type)) 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_type) 18 | return tokenizer 19 | 20 | 21 | def get_pretrained_language_model(text_encoder_type): 22 | if text_encoder_type == "bert-base-uncased": 23 | return BertModel.from_pretrained(text_encoder_type) 24 | if text_encoder_type == "roberta-base": 25 | return RobertaModel.from_pretrained(text_encoder_type) 26 | raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) 27 | 
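`get_tokenlizer` accepts either a model name or a config-like object that carries a `text_encoder_type` field. The resolution order can be sketched in plain Python as below. `resolve_text_encoder_type` is a hypothetical helper for illustration only, not part of the repository. It mirrors the branching without loading any Hugging Face model, and it checks explicitly for a mapping instead of calling `.get` on arbitrary objects.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any


def resolve_text_encoder_type(spec: Any) -> str:
    """Hypothetical helper mirroring the argument handling in get_tokenlizer.

    Accepts a model name (e.g. "bert-base-uncased"), an object exposing a
    `text_encoder_type` attribute (e.g. a config namespace), or a mapping
    with a "text_encoder_type" key.
    """
    if isinstance(spec, str):
        return spec
    if hasattr(spec, "text_encoder_type"):
        return spec.text_encoder_type
    if isinstance(spec, Mapping) and spec.get("text_encoder_type"):
        return spec["text_encoder_type"]
    raise ValueError("Unknown type of text_encoder_type: {}".format(type(spec)))


if __name__ == "__main__":
    print(resolve_text_encoder_type("bert-base-uncased"))
    print(resolve_text_encoder_type(SimpleNamespace(text_encoder_type="roberta-base")))
    print(resolve_text_encoder_type({"text_encoder_type": "bert-base-uncased"}))
```

In the repository, the resolved name is then passed to `AutoTokenizer.from_pretrained`. `get_pretrained_language_model` only accepts `"bert-base-uncased"` and `"roberta-base"`.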
-------------------------------------------------------------------------------- /sgreg/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/dataset/__init__.py -------------------------------------------------------------------------------- /sgreg/dataset/dataset_factory.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import torch 3 | import numpy as np 4 | import open3d as o3d 5 | from sgreg.dataset.scene_pair_dataset import ( 6 | ScenePairDataset, 7 | ) 8 | from sgreg.dataset.stack_mode import ( 9 | calibrate_neighbors_stack_mode, 10 | registration_collate_fn_stack_mode, 11 | sgreg_collate_fn_stack_mode, 12 | build_dataloader_stack_mode 13 | ) 14 | 15 | def train_data_loader(cfg,distributed=False): 16 | train_dataset = ScenePairDataset(cfg.dataset.dataroot, 'train', cfg) 17 | neighbor_limits = calibrate_neighbors_stack_mode( 18 | train_dataset, 19 | registration_collate_fn_stack_mode, 20 | cfg.backbone.num_stages, 21 | cfg.backbone.init_voxel_size, 22 | cfg.backbone.init_radius, 23 | ) 24 | train_dataset.neighbor_limits = neighbor_limits 25 | if 'verify_instance_points' in cfg.dataset: 26 | verify_instance_points = cfg.dataset.verify_instance_points 27 | else: 28 | verify_instance_points = False 29 | 30 | train_loader = build_dataloader_stack_mode( 31 | train_dataset, 32 | sgreg_collate_fn_stack_mode, 33 | cfg.backbone.num_stages, 34 | cfg.backbone.init_voxel_size, 35 | cfg.backbone.init_radius, 36 | neighbor_limits, 37 | verify_instance_points, 38 | batch_size=cfg.train.batch_size, 39 | num_workers=cfg.train.num_workers, 40 | shuffle=True, 41 | distributed=distributed, 42 | ) 43 | return train_loader,neighbor_limits 44 | 45 | def val_data_loader(cfg,distributed=False): 46 | val_dataset = ScenePairDataset(cfg.dataset.dataroot, 'val', cfg) 47 | 
neighbor_limits = calibrate_neighbors_stack_mode( 48 | val_dataset, 49 | registration_collate_fn_stack_mode, 50 | cfg.backbone.num_stages, 51 | cfg.backbone.init_voxel_size, 52 | cfg.backbone.init_radius, 53 | ) 54 | print('Calibrated neighbor limits:',neighbor_limits) 55 | # neighbor_limits = [38, 36, 36, 38] # default setting in 3DMatch 56 | val_dataset.neighbor_limits = neighbor_limits 57 | 58 | if 'verify_instance_points' in cfg.dataset: 59 | verify_instance_points = cfg.dataset.verify_instance_points 60 | else: 61 | verify_instance_points = False 62 | 63 | val_loader = build_dataloader_stack_mode( 64 | val_dataset, 65 | sgreg_collate_fn_stack_mode, 66 | cfg.backbone.num_stages, 67 | cfg.backbone.init_voxel_size, 68 | cfg.backbone.init_radius, 69 | neighbor_limits, 70 | verify_instance_points, 71 | batch_size=cfg.train.batch_size, 72 | num_workers=cfg.train.num_workers, 73 | shuffle=False, 74 | distributed=distributed, 75 | ) 76 | return val_loader,neighbor_limits 77 | 78 | if __name__=='__main__': 79 | import torch 80 | from omegaconf import OmegaConf 81 | conf = OmegaConf.load('config/pretrain.yaml') 82 | conf.backbone.init_radius = 0.5 * conf.backbone.base_radius * conf.backbone.init_voxel_size # 0.0625 83 | conf.backbone.init_sigma = 0.5 * conf.backbone.base_sigma * conf.backbone.init_voxel_size # 0.05 84 | 85 | train_loader, train_neighbor_limits = train_data_loader(conf) 86 | val_loader, val_neighbor_limits = val_data_loader(conf) 87 | 88 | # print('train neighbor limits',train_neighbor_limits) 89 | # print('val neighbor limits',val_neighbor_limits) 90 | # exit(0) 91 | # data_dict = next(iter(train_loader)) 92 | # src_graph = data_dict['src_graph'] 93 | # print(data_dict.keys()) 94 | # print(src_graph.keys()) 95 | 96 | warning_scans = [] 97 | valid_scans = [] 98 | 99 | for data_dict in val_loader: 100 | src_graph = data_dict['src_graph'] 101 | src_scan = data_dict['src_scan'][0] 102 | ref_scan = data_dict['ref_scan'][0] 103 | msg = '{}-{}: 
'.format(src_scan,ref_scan) 104 | points_dict= data_dict['batch_points'][0] 105 | # print(data_dict.keys()) 106 | lengths = points_dict['lengths'] # [(Pl,Ql)] 107 | 108 | # if 'debug_flag' in data_dict: 109 | # print('{} inconsistent instance {}-{} !!!'.format( 110 | # data_dict['debug_flag'], src_scan,ref_scan, )) 111 | if data_dict['instance_matches'][0] is None: 112 | print('{}-{} no gt instance matches!!!'.format(src_scan,ref_scan)) 113 | warning_scans.append('{} {}'.format(src_scan,ref_scan)) 114 | else: 115 | valid_scans.append('{} {}'.format(src_scan,ref_scan)) 116 | print(msg) 117 | # break 118 | 119 | print('************ Finished ************') 120 | print('{}/{} warning scans'.format(len(warning_scans), 121 | len(valid_scans)+len(warning_scans))) 122 | 123 | 124 | # outdir = '/data2/sgalign_data/splits/val_ours.txt' 125 | # with open(outdir,'w') as f: 126 | # f.write('\n'.join(valid_scans)) 127 | # f.close() 128 | -------------------------------------------------------------------------------- /sgreg/dataset/prepare_semantics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Generate intermediate features for the scene graph: 3 | - Semantic embeddings, encoded with BERT. 4 | - Hierarchical point features, encoded for each pair of sub-scenes.
5 | ''' 6 | 7 | 8 | import os 9 | import torch 10 | import torch.utils.data 11 | import pandas as pd 12 | import time 13 | # import numpy as np 14 | from omegaconf import OmegaConf 15 | from scipy.spatial.transform import Rotation as R 16 | from sgreg.bert.get_tokenizer import get_tokenlizer, get_pretrained_language_model 17 | from sgreg.bert.bertwarper import generate_bert_fetures 18 | 19 | # from model.ops.transformation import apply_transform 20 | from sgreg.dataset.scene_pair_dataset import ScenePairDataset 21 | # from model.dataset.DatasetFactory import train_data_loader, val_data_loader 22 | from sgreg.utils.config import create_cfg 23 | # from model.dataset.ScanNetPairDataset import read_scans,load_scene_graph 24 | 25 | def generate_semantic_embeddings(data_dict): 26 | src_labels = data_dict['src_graph']['labels'] 27 | ref_labels = data_dict['ref_graph']['labels'] 28 | src_idxs = data_dict['src_graph']['idx2name'] 29 | ref_idxs = data_dict['ref_graph']['idx2name'] 30 | assert len(src_labels) == src_idxs.shape[0] and len(ref_labels) == ref_idxs.shape[0] 31 | 32 | labels = src_labels + ref_labels 33 | 34 | t0 = time.time() 35 | semantic_embeddings = generate_bert_fetures(tokenizer,bert,labels) 36 | t1 = time.time() 37 | print('encode {} labels takes {:.3f} msecs'.format(len(labels),1000*(t1-t0))) 38 | assert len(semantic_embeddings) == len(labels) 39 | 40 | src_semantic_embeddings = semantic_embeddings[:len(src_labels)].detach() 41 | ref_semantic_embeddings = semantic_embeddings[len(src_labels):].detach() 42 | 43 | return {'instance_idxs':src_idxs,'semantic_embeddings':src_semantic_embeddings},\ 44 | {'instance_idxs':ref_idxs,'semantic_embeddings':ref_semantic_embeddings} 45 | 46 | if __name__=='__main__': 47 | print('Save the semantic embeddings to accelerate the training process.') 48 | ## 49 | dataroot = '/data2/RioGraph' # '/data2/ScanNetGraph' 50 | split = 'val' 51 | cfg_file = '/home/cliuci/code_ws/SceneGraphNet/config/rio.yaml' 52 | middle_feat_folder = 
os.path.join(dataroot,'matches') 53 | from sgreg.dataset.DatasetFactory import prepare_hierarchy_points_feats 54 | 55 | ## 56 | tokenizer = get_tokenlizer('bert-base-uncased') 57 | bert = get_pretrained_language_model('bert-base-uncased') 58 | bert.eval() 59 | bert.pooler.dense.weight.requires_grad = False 60 | bert.pooler.dense.bias.requires_grad = False 61 | 62 | # 63 | conf = OmegaConf.load(cfg_file) 64 | conf = create_cfg(conf) 65 | conf.dataset.online_bert = True 66 | dataset = ScenePairDataset(dataroot,split,conf) 67 | 68 | neighbor_limits = [38, 36, 36, 38] 69 | print('neighbor_limits:',neighbor_limits) 70 | cfg_dict= {'num_stages': conf.backbone.num_stages, 71 | 'voxel_size': conf.backbone.init_voxel_size, 72 | 'search_radius': conf.backbone.init_radius, 73 | 'neighbor_limits': neighbor_limits} 74 | 75 | # 76 | N = len(dataset) 77 | print('Dataset size:',N) 78 | 79 | SEMANTIC_ON = True 80 | POINTS_ON = False 81 | CHECK_EDGES = False 82 | max_fine_points = [] 83 | 84 | warn_scans = [] 85 | 86 | for i in range(N): 87 | data_dict = dataset[i] 88 | scene_name = data_dict['src_scan'][:-1] 89 | src_subname = data_dict['src_scan'][-1] 90 | ref_subname = data_dict['ref_scan'][-1] 91 | print('---processing {} and {} -----'.format(data_dict['src_scan'],data_dict['ref_scan'])) 92 | if data_dict['instance_ious'] is None: 93 | warn_scans.append((data_dict['src_scan'],data_dict['ref_scan'])) 94 | 95 | if CHECK_EDGES: 96 | src_global_edges = data_dict['src_graph']['global_edge_indices'] 97 | ref_global_edges = data_dict['ref_graph']['global_edge_indices'] 98 | print('{} global src edges'.format(src_global_edges.shape[0])) 99 | 100 | if src_global_edges.shape[0]<1 or ref_global_edges.shape[0]<1: 101 | print('global_edges is None!') 102 | warn_scans.append((data_dict['src_scan'],data_dict['ref_scan'])) 103 | 104 | if SEMANTIC_ON: 105 | src_semantic_embeddings, ref_semantic_embeddings = generate_semantic_embeddings(data_dict) 106 | torch.save(src_semantic_embeddings, 107 | 
os.path.join(dataroot,split,data_dict['src_scan'],'semantic_embeddings.pth')) 108 | torch.save(ref_semantic_embeddings, 109 | os.path.join(dataroot,split,data_dict['ref_scan'],'semantic_embeddings.pth')) 110 | if POINTS_ON: 111 | out_dict = prepare_hierarchy_points_feats(cfg_dict=cfg_dict, 112 | ref_points = data_dict['ref_points'], 113 | src_points=data_dict['src_points'], 114 | ref_instances=data_dict['ref_instances'], 115 | src_instances=data_dict['src_instances'], 116 | ref_feats=data_dict['ref_feats'], 117 | src_feats=data_dict['src_feats']) 118 | import numpy as np 119 | 120 | ref_instance_count = np.histogram(data_dict['ref_instances'],bins=np.unique(data_dict['ref_instances']))[0] 121 | print('ref instance list:',np.unique(data_dict['ref_instances'])) 122 | print('ref_instance points count:',ref_instance_count) 123 | print('points number: ', data_dict['ref_points'].shape[0]) 124 | 125 | ref_instance_list = np.unique(out_dict['ref_points_f_instances']) 126 | ref_instance_fine_count = np.histogram(out_dict['ref_points_f_instances'],bins=ref_instance_list)[0] 127 | print('ref instance list:',ref_instance_list) 128 | print('ref_instance fine points count:',ref_instance_fine_count) 129 | print('max fine points: ',ref_instance_fine_count.max()) 130 | max_fine_points.append(ref_instance_fine_count.max()) 131 | 132 | points_dict = out_dict['points_dict'] 133 | feats = out_dict['feats'] 134 | ref_points_f_instances= out_dict['ref_points_f_instances'] 135 | src_points_f_instances= out_dict['src_points_f_instances'] 136 | assert 'instance_matches' in data_dict 137 | assert isinstance(ref_points_f_instances,torch.Tensor) and isinstance(src_points_f_instances,torch.Tensor) 138 | assert points_dict['points'][0].shape[0] == points_dict['lengths'][0].sum() 139 | # Export 140 | data_dict.update(out_dict) 141 | torch.save(data_dict,os.path.join(middle_feat_folder,scene_name,'data_dict_{}.pth'.format(src_subname+ref_subname))) 142 | 143 | # break 144 | print(warn_scans) 145 | 
print('finished {} scan pairs'.format(N)) 146 | 147 | if len(max_fine_points)>0: 148 | max_fine_points = np.array(max_fine_points) 149 | print(max_fine_points) 150 | print(max_fine_points.max()) 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /sgreg/dataset/scene_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | import numpy as np 5 | import open3d as o3d 6 | from scipy.spatial.transform import Rotation as R 7 | 8 | from open3d.geometry import OrientedBoundingBox as OBB 9 | from os.path import join as osp 10 | 11 | class Instance: 12 | def __init__(self,idx:int, 13 | cloud:o3d.geometry.PointCloud|np.ndarray, 14 | label:str, 15 | score:float): 16 | self.idx = idx 17 | self.label = label 18 | self.score = score 19 | self.cloud = cloud 20 | self.cloud_dir = None 21 | def load_box(self,box:o3d.geometry.OrientedBoundingBox): 22 | self.box = box 23 | 24 | def load_raw_scene_graph(folder_dir:str, 25 | voxel_size:float=0.02, 26 | ignore_types:list=['ceiling']): 27 | ''' graph: {'nodes':{idx:Instance},'edges':{idx:idx}} 28 | ''' 29 | # load scene graph 30 | nodes = {} 31 | boxes = {} 32 | invalid_nodes = [] 33 | xyzi = [] 34 | global_cloud = o3d.geometry.PointCloud() 35 | # IGNORE_TYPES = ['floor','carpet','wall'] 36 | # IGNORE_TYPES = ['ceiling'] 37 | 38 | # load instance boxes 39 | with open(os.path.join(folder_dir,'instance_box.txt')) as f: 40 | count=0 41 | for line in f.readlines(): 42 | line = line.strip() 43 | if'#' in line:continue 44 | parts = line.split(';') 45 | idx = int(parts[0]) 46 | center = np.array([float(x) for x in parts[1].split(',')]) 47 | rotation = np.array([float(x) for x in parts[2].split(',')]) 48 | extent = np.array([float(x) for x in parts[3].split(',')]) 49 | o3d_box = o3d.geometry.OrientedBoundingBox(center,rotation.reshape(3,3),extent) 50 | o3d_box.color = (0,0,0) 51 | # if'nan' in 
line:invalid_nodes.append(idx) 52 | if 'nan' not in line: 53 | boxes[idx] = o3d_box 54 | # nodes[idx].load_box(o3d_box) 55 | count+=1 56 | f.close() 57 | print('load {} boxes'.format(count)) 58 | 59 | # load instance info 60 | with open(os.path.join(folder_dir,'instance_info.txt')) as f: 61 | for line in f.readlines(): 62 | line = line.strip() 63 | if'#' in line:continue 64 | parts = line.split(';') 65 | idx = int(parts[0]) 66 | if idx not in boxes: continue 67 | label_score_vec = parts[1].split('(') 68 | label = label_score_vec[0] 69 | score = float(label_score_vec[1].split(')')[0]) 70 | if label in ignore_types: continue 71 | # print('load {}:{}, {}'.format(idx,label,score)) 72 | 73 | cloud = o3d.io.read_point_cloud(os.path.join(folder_dir,'{}.ply'.format(parts[0]))) 74 | cloud = cloud.voxel_down_sample(voxel_size) 75 | xyz = np.asarray(cloud.points) 76 | # if xyz.shape[0]<50: continue 77 | xyzi.append(np.concatenate([xyz,idx*np.ones((len(xyz),1))], 78 | axis=1)) 79 | global_cloud = global_cloud + cloud 80 | nodes[idx] = Instance(idx,cloud,label,score) 81 | nodes[idx].cloud_dir = '{}.ply'.format(parts[0]) 82 | nodes[idx].load_box(boxes[idx]) 83 | 84 | f.close() 85 | print('Load {} instances '.format(len(nodes))) 86 | if len(xyzi)>0: 87 | xyzi = np.concatenate(xyzi,axis=0) 88 | 89 | return {'nodes':nodes, 90 | 'edges':[], 91 | 'global_cloud':global_cloud, 92 | 'xyzi':xyzi} 93 | 94 | def load_processed_scene_graph(scan_dir:str): 95 | 96 | instance_nodes = {} # {idx:node_info} 97 | 98 | # Nodes 99 | nodes_data = pd.read_csv(os.path.join(scan_dir,'nodes.csv')) 100 | max_node_id = 0 101 | 102 | for idx, label, score, center, quat, extent, _ in zip(nodes_data['node_id'], 103 | nodes_data['label'], 104 | nodes_data['score'], 105 | nodes_data['center'], 106 | nodes_data['quaternion'], 107 | nodes_data['extent'], 108 | nodes_data['cloud_dir']): 109 | centroid = np.fromstring(center, dtype=float, sep=',') 110 | quaternion = np.fromstring(quat, dtype=float, sep=',') # 
(x,y,z,w) 111 | rot = R.from_quat(quaternion) 112 | extent = np.fromstring(extent, dtype=float, sep=',') 113 | 114 | if np.isnan(extent).any() or np.isnan(quaternion).any() or np.isnan(centroid).any() or np.isnan(idx): 115 | continue 116 | if '_' in label: label = label.replace('_',' ') 117 | pcd_dir = osp(scan_dir, str(idx).zfill(4)+'.ply') 118 | assert os.path.exists(pcd_dir), 'File not found: {}'.format(pcd_dir) 119 | 120 | instance_nodes[idx] = Instance(idx, 121 | o3d.io.read_point_cloud(pcd_dir), 122 | label, 123 | score) 124 | instance_nodes[idx].load_box(OBB(centroid, 125 | rot.as_matrix(), 126 | extent)) 127 | 128 | if idx>max_node_id: 129 | max_node_id = idx 130 | print('Load {} instances'.format(len(instance_nodes))) 131 | 132 | # Instance Point Cloud 133 | xyzi = torch.load(os.path.join(scan_dir,'xyzi.pth')).numpy() 134 | instances = xyzi[:,-1].astype(np.int32) 135 | xyz = xyzi[:,:3].astype(np.float32) 136 | assert max_node_id == instances.max(), 'Instance ID mismatch' 137 | assert np.unique(instances).shape[0] == len(instance_nodes), 'Instance ID mismatch' 138 | 139 | # Global Point Cloud 140 | # colors = np.zeros_like(xyz) 141 | # for idx, instance in instance_nodes.items(): 142 | # inst_mask = instances== idx 143 | # assert inst_mask.sum()>0 144 | # inst_color = 255*np.random.rand(3) 145 | # colors[inst_mask] = np.floor(inst_color).astype(np.int32) 146 | # instance.cloud = o3d.geometry.PointCloud( 147 | # o3d.utility.Vector3dVector(xyz[inst_mask])) 148 | 149 | # global_pcd = o3d.geometry.PointCloud( 150 | # o3d.utility.Vector3dVector(xyz)) 151 | # global_pcd.colors = o3d.utility.Vector3dVector(colors) 152 | # global_pcd = o3d.io.read_point_cloud(os.path.join(scan_dir,'instance_map.ply')) 153 | 154 | global_pcd = o3d.geometry.PointCloud() 155 | for idx, instance in instance_nodes.items(): 156 | global_pcd += instance.cloud 157 | 158 | return {'nodes':instance_nodes, 159 | 'edges':[], 160 | 'global_cloud':global_pcd} 161 | 162 | 163 | def 
transform_scene_graph(scene_graph:dict, 164 | transformation:np.ndarray): 165 | 166 | scene_graph['global_cloud'].transform(transformation) 167 | for idx, instance in scene_graph['nodes'].items(): 168 | instance.cloud.transform(transformation) 169 | # Open3D rotates a box about its own center by default, which leaves the center 170 | # unmoved; rotate about the world origin so the center maps to R @ c + t. 171 | # instance.box = o3d.geometry.OrientedBoundingBox.create_from_points( 172 | # o3d.utility.Vector3dVector(np.asarray(instance.cloud.points))) 173 | instance.box.rotate(R=transformation[:3,:3], center=np.zeros(3)) 174 | instance.box.translate(transformation[:3,3]) 175 | 176 | 177 | -------------------------------------------------------------------------------- /sgreg/extensions/README.md: -------------------------------------------------------------------------------- 1 | # GeoTransformer with grid subsampling for point clouds with semantic labels 2 | ## Grid Subsampling (modified to carry instance labels) 3 | 4 | This code implements grid subsampling for point clouds. It subsamples each batch of points independently and returns the subsampled points, their instance labels, and the lengths of the subsampled point cloud batches. Each output point is the centroid of the points that fall in one voxel, and its instance label is the most frequent label among those points. 5 | 6 | It provides the function `grid_subsampling_cpu`, which takes a set of 3D points, **their corresponding instance labels**, and the lengths of the point cloud batches. 7 | 8 | ## Radius-based Neighbor Search 9 | The main function, `radius_neighbors_cpu`, performs a batched search over two sets of 3D points. For each point in the first (query) set, it stores the indices of the points in the second (support) set that lie within a given radius. Rows are padded to the largest neighbor count, and the total number of support points is used as the padding index.
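The grid subsampling described above can be sketched in pure Python as follows. This is an illustrative re-implementation, not the extension itself. `single_grid_subsample` and `grid_subsample` are hypothetical names, points are plain `(x, y, z)` tuples rather than tensors, and voxels are keyed by integer triples instead of the flattened `mapIdx` used in `grid_subsampling_cpu.cpp`. As in the C++ code, each voxel produces the centroid of its points and the most frequent instance label, with ties going to the smallest id.

```python
import math
from collections import Counter, defaultdict
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float, float]
Key = Tuple[int, int, int]


def single_grid_subsample(points: Sequence[Point], insts: Sequence[int],
                          voxel_size: float) -> Tuple[List[Point], List[int]]:
    """Voxel-grid subsampling of one cloud with majority-vote instance labels."""
    # Grid origin: the minimum corner snapped down to a multiple of voxel_size.
    origin = [math.floor(min(p[d] for p in points) / voxel_size) * voxel_size
              for d in range(3)]
    sums: Dict[Key, List[float]] = defaultdict(lambda: [0.0, 0.0, 0.0])
    counts: Dict[Key, int] = defaultdict(int)
    labels: Dict[Key, Counter] = defaultdict(Counter)
    for p, inst in zip(points, insts):
        key = tuple(int(math.floor((p[d] - origin[d]) / voxel_size)) for d in range(3))
        for d in range(3):
            sums[key][d] += p[d]
        counts[key] += 1
        labels[key][inst] += 1
    s_points: List[Point] = []
    s_insts: List[int] = []
    for key, total in sums.items():  # first-seen order (C++ order is unspecified)
        n = counts[key]
        s_points.append(tuple(c / n for c in total))  # voxel centroid
        # Most frequent label; max() over sorted ids breaks ties toward the smallest id.
        s_insts.append(max(sorted(labels[key]), key=lambda i: labels[key][i]))
    return s_points, s_insts


def grid_subsample(points: Sequence[Point], insts: Sequence[int],
                   lengths: Sequence[int], voxel_size: float):
    """Subsample each batch (given by `lengths`) independently and concatenate."""
    s_points, s_insts, s_lengths = [], [], []
    start = 0
    for n in lengths:
        cur_p, cur_i = single_grid_subsample(points[start:start + n],
                                             insts[start:start + n], voxel_size)
        s_points.extend(cur_p)
        s_insts.extend(cur_i)
        s_lengths.append(len(cur_p))
        start += n
    return s_points, s_insts, s_lengths
```

Majority voting keeps one well-defined instance id per subsampled point, which is what lets the instance labels survive each downsampling stage alongside the coordinates.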
10 | 11 | 12 | -------------------------------------------------------------------------------- /sgreg/extensions/common/torch_helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDA(x) \ 7 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 8 | 9 | #define CHECK_CPU(x) \ 10 | TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor") 11 | 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | 15 | #define CHECK_INPUT(x) \ 16 | CHECK_CUDA(x); \ 17 | CHECK_CONTIGUOUS(x) 18 | 19 | #define CHECK_IS_INT(x) \ 20 | do { \ 21 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 22 | #x " must be an int tensor"); \ 23 | } while (0) 24 | 25 | #define CHECK_IS_LONG(x) \ 26 | do { \ 27 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Long, \ 28 | #x " must be a long tensor"); \ 29 | } while (0) 30 | 31 | #define CHECK_IS_FLOAT(x) \ 32 | do { \ 33 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 34 | #x " must be a float tensor"); \ 35 | } while (0) 36 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/grid_subsampling/grid_subsampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "grid_subsampling.h" 3 | #include "grid_subsampling_cpu.h" 4 | 5 | std::vector grid_subsampling( 6 | at::Tensor points, 7 | at::Tensor insts, 8 | at::Tensor lengths, 9 | float voxel_size 10 | ) { 11 | CHECK_CPU(points); 12 | CHECK_CPU(lengths); 13 | CHECK_IS_FLOAT(points); 14 | CHECK_IS_LONG(lengths); 15 | CHECK_CONTIGUOUS(points); 16 | CHECK_CONTIGUOUS(lengths); 17 | 18 | CHECK_CPU(insts); 19 | CHECK_IS_INT(insts); 20 | CHECK_CONTIGUOUS(insts); 21 | 22 | std::size_t batch_size = lengths.size(0); 23 | std::size_t total_points = points.size(0); 24 | 25 | std::vector vec_points = std::vector( 26 |
reinterpret_cast(points.data_ptr()), 27 | reinterpret_cast(points.data_ptr()) + total_points 28 | ); 29 | std::vector vec_s_points; 30 | 31 | std::vector vec_lengths = std::vector( 32 | lengths.data_ptr(), 33 | lengths.data_ptr() + batch_size 34 | ); 35 | std::vector vec_s_lengths; 36 | 37 | std::vector vec_insts = std::vector( 38 | reinterpret_cast(insts.data_ptr()), 39 | reinterpret_cast(insts.data_ptr()) + total_points 40 | ); 41 | std::vector vec_s_insts; 42 | 43 | grid_subsampling_cpu( 44 | vec_points, 45 | vec_s_points, 46 | vec_insts, 47 | vec_s_insts, 48 | vec_lengths, 49 | vec_s_lengths, 50 | voxel_size 51 | ); 52 | 53 | std::size_t total_s_points = vec_s_points.size(); 54 | at::Tensor s_points = torch::zeros( 55 | {total_s_points, 3}, 56 | at::device(points.device()).dtype(at::ScalarType::Float) 57 | ); 58 | at::Tensor s_lengths = torch::zeros( 59 | {batch_size}, 60 | at::device(lengths.device()).dtype(at::ScalarType::Long) 61 | ); 62 | 63 | at::Tensor s_insts = torch::zeros( 64 | {total_s_points, 1}, 65 | at::device(insts.device()).dtype(at::ScalarType::Int) 66 | ); 67 | 68 | std::memcpy( 69 | s_points.data_ptr(), 70 | reinterpret_cast(vec_s_points.data()), 71 | sizeof(float) * total_s_points * 3 72 | ); 73 | std::memcpy( 74 | s_lengths.data_ptr(), 75 | vec_s_lengths.data(), 76 | sizeof(long) * batch_size 77 | ); 78 | 79 | std::memcpy( 80 | s_insts.data_ptr(), 81 | reinterpret_cast(vec_s_insts.data()), 82 | sizeof(int) * total_s_points * 1 83 | ); 84 | 85 | 86 | return {s_points, s_insts, s_lengths}; 87 | } 88 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/grid_subsampling/grid_subsampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "../../common/torch_helper.h" 5 | 6 | std::vector grid_subsampling( 7 | at::Tensor points, 8 | at::Tensor insts, 9 | at::Tensor lengths, 10 | float voxel_size 11 | ); 12 | 
-------------------------------------------------------------------------------- /sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "grid_subsampling_cpu.h" 2 | 3 | void single_grid_subsampling_cpu( 4 | std::vector& points, 5 | std::vector& s_points, 6 | std::vector& inst, 7 | std::vector& s_inst, 8 | float voxel_size 9 | ) { 10 | // float sub_scale = 1. / voxel_size; 11 | PointXYZ minCorner = min_point(points); 12 | PointXYZ maxCorner = max_point(points); 13 | PointXYZ originCorner = floor(minCorner * (1. / voxel_size)) * voxel_size; 14 | 15 | std::size_t sampleNX = static_cast( 16 | // floor((maxCorner.x - originCorner.x) * sub_scale) + 1 17 | floor((maxCorner.x - originCorner.x) / voxel_size) + 1 18 | ); 19 | std::size_t sampleNY = static_cast( 20 | // floor((maxCorner.y - originCorner.y) * sub_scale) + 1 21 | floor((maxCorner.y - originCorner.y) / voxel_size) + 1 22 | ); 23 | 24 | std::size_t iX = 0; 25 | std::size_t iY = 0; 26 | std::size_t iZ = 0; 27 | std::size_t mapIdx = 0; 28 | std::unordered_map data; 29 | std::unordered_map data_inst; 30 | 31 | int i = 0; 32 | for (auto& p : points) { 33 | // iX = static_cast(floor((p.x - originCorner.x) * sub_scale)); 34 | // iY = static_cast(floor((p.y - originCorner.y) * sub_scale)); 35 | // iZ = static_cast(floor((p.z - originCorner.z) * sub_scale)); 36 | iX = static_cast(floor((p.x - originCorner.x) / voxel_size)); 37 | iY = static_cast(floor((p.y - originCorner.y) / voxel_size)); 38 | iZ = static_cast(floor((p.z - originCorner.z) / voxel_size)); 39 | mapIdx = iX + sampleNX * iY + sampleNX * sampleNY * iZ; 40 | 41 | if (!data.count(mapIdx)) { 42 | data.emplace(mapIdx, SampledData()); 43 | data_inst.emplace(mapIdx, SampledInst()); 44 | 45 | } 46 | 47 | data[mapIdx].update(p); 48 | data_inst[mapIdx].update(inst[i]); 49 | i = i + 1; 50 | } 51 | 52 | s_points.reserve(data.size()); 53 | for (auto& v : data) { 54 | 
s_points.push_back(v.second.point * (1.0 / v.second.count)); 55 | } 56 | 57 | s_inst.reserve(data_inst.size()); 58 | for (auto& v : data_inst) { 59 | 60 | std::map histo_map; 61 | for (int i : v.second.inst_ids) { 62 | histo_map[i]++; 63 | } 64 | auto maxElem = std::max_element(histo_map.begin(), histo_map.end(), 65 | [](const std::pair& a, const std::pair& b) { 66 | return a.second < b.second; 67 | }); 68 | s_inst.push_back(maxElem->first); 69 | } 70 | 71 | 72 | } 73 | 74 | void grid_subsampling_cpu( 75 | std::vector& points, 76 | std::vector& s_points, 77 | std::vector& insts, 78 | std::vector& s_insts, 79 | std::vector& lengths, 80 | std::vector& s_lengths, 81 | float voxel_size 82 | ) { 83 | std::size_t start_index = 0; 84 | std::size_t batch_size = lengths.size(); 85 | for (std::size_t b = 0; b < batch_size; b++) { 86 | std::vector cur_points = std::vector( 87 | points.begin() + start_index, 88 | points.begin() + start_index + lengths[b] 89 | ); 90 | std::vector cur_s_points; 91 | 92 | std::vector cur_insts = std::vector( 93 | insts.begin() + start_index, 94 | insts.begin() + start_index + lengths[b] 95 | ); 96 | std::vector cur_s_insts; 97 | 98 | single_grid_subsampling_cpu(cur_points, cur_s_points, cur_insts, cur_s_insts, voxel_size); 99 | 100 | s_points.insert(s_points.end(), cur_s_points.begin(), cur_s_points.end()); 101 | s_insts.insert(s_insts.end(), cur_s_insts.begin(), cur_s_insts.end()); 102 | s_lengths.push_back(cur_s_points.size()); 103 | 104 | start_index += lengths[b]; 105 | } 106 | 107 | return; 108 | } 109 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/grid_subsampling/grid_subsampling_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../../extra/cloud/cloud.h" 6 | 7 | class SampledData { 8 | public: 9 | int count; 10 | PointXYZ point; 11 | 12 | SampledData() { 13 | count = 0; 14 | point = 
PointXYZ(); 15 | } 16 | 17 | void update(const PointXYZ& p) { 18 | count += 1; 19 | point += p; 20 | } 21 | }; 22 | 23 | class SampledInst { 24 | public: 25 | int count; 26 | std::vector inst_ids; 27 | 28 | SampledInst() { 29 | count = 0; 30 | } 31 | 32 | void update(const int inst_id) { 33 | count += 1; 34 | inst_ids.push_back(inst_id); 35 | } 36 | }; 37 | 38 | void single_grid_subsampling_cpu( 39 | std::vector& o_points, 40 | std::vector& s_points, 41 | std::vector& inst, 42 | std::vector& s_inst, 43 | float voxel_size 44 | ); 45 | 46 | void grid_subsampling_cpu( 47 | std::vector& o_points, 48 | std::vector& s_points, 49 | std::vector& o_insts, 50 | std::vector& s_insts, 51 | std::vector& o_lengths, 52 | std::vector& s_lengths, 53 | float voxel_size 54 | ); 55 | 56 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/radius_neighbors/radius_neighbors.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "radius_neighbors.h" 3 | #include "radius_neighbors_cpu.h" 4 | 5 | at::Tensor radius_neighbors( 6 | at::Tensor q_points, 7 | at::Tensor s_points, 8 | at::Tensor q_lengths, 9 | at::Tensor s_lengths, 10 | float radius 11 | ) { 12 | CHECK_CPU(q_points); 13 | CHECK_CPU(s_points); 14 | CHECK_CPU(q_lengths); 15 | CHECK_CPU(s_lengths); 16 | CHECK_IS_FLOAT(q_points); 17 | CHECK_IS_FLOAT(s_points); 18 | CHECK_IS_LONG(q_lengths); 19 | CHECK_IS_LONG(s_lengths); 20 | CHECK_CONTIGUOUS(q_points); 21 | CHECK_CONTIGUOUS(s_points); 22 | CHECK_CONTIGUOUS(q_lengths); 23 | CHECK_CONTIGUOUS(s_lengths); 24 | 25 | std::size_t total_q_points = q_points.size(0); 26 | std::size_t total_s_points = s_points.size(0); 27 | std::size_t batch_size = q_lengths.size(0); 28 | 29 | std::vector vec_q_points = std::vector( 30 | reinterpret_cast(q_points.data_ptr()), 31 | reinterpret_cast(q_points.data_ptr()) + total_q_points 32 | ); 33 | std::vector vec_s_points = std::vector( 34 | 
reinterpret_cast(s_points.data_ptr()), 35 | reinterpret_cast(s_points.data_ptr()) + total_s_points 36 | ); 37 | std::vector vec_q_lengths = std::vector( 38 | q_lengths.data_ptr(), q_lengths.data_ptr() + batch_size 39 | ); 40 | std::vector vec_s_lengths = std::vector( 41 | s_lengths.data_ptr(), s_lengths.data_ptr() + batch_size 42 | ); 43 | std::vector vec_neighbor_indices; 44 | 45 | radius_neighbors_cpu( 46 | vec_q_points, 47 | vec_s_points, 48 | vec_q_lengths, 49 | vec_s_lengths, 50 | vec_neighbor_indices, 51 | radius 52 | ); 53 | 54 | std::size_t max_neighbors = vec_neighbor_indices.size() / total_q_points; 55 | 56 | at::Tensor neighbor_indices = torch::zeros( 57 | {total_q_points, max_neighbors}, 58 | at::device(q_points.device()).dtype(at::ScalarType::Long) 59 | ); 60 | 61 | std::memcpy( 62 | neighbor_indices.data_ptr(), 63 | vec_neighbor_indices.data(), 64 | sizeof(long) * total_q_points * max_neighbors 65 | ); 66 | 67 | return neighbor_indices; 68 | } 69 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/radius_neighbors/radius_neighbors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../../common/torch_helper.h" 4 | 5 | at::Tensor radius_neighbors( 6 | at::Tensor q_points, 7 | at::Tensor s_points, 8 | at::Tensor q_lengths, 9 | at::Tensor s_lengths, 10 | float radius 11 | ); 12 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "radius_neighbors_cpu.h" 2 | 3 | 4 | void radius_neighbors_cpu( 5 | std::vector& q_points, 6 | std::vector& s_points, 7 | std::vector& q_lengths, 8 | std::vector& s_lengths, 9 | std::vector& neighbor_indices, 10 | float radius 11 | ) { 12 | std::size_t i0 = 0; 13 | float r2 = radius * radius; 14 | 15 | std::size_t max_count = 
0; 16 | std::vector<std::vector<std::pair<std::size_t, float>>> all_inds_dists( 17 | q_points.size() 18 | ); 19 | 20 | std::size_t b = 0; 21 | std::size_t q_start_index = 0; 22 | std::size_t s_start_index = 0; 23 | 24 | PointCloud current_cloud; 25 | current_cloud.pts = std::vector<PointXYZ>( 26 | s_points.begin() + s_start_index, 27 | s_points.begin() + s_start_index + s_lengths[b] 28 | ); 29 | 30 | nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10); 31 | my_kd_tree_t* index = new my_kd_tree_t(3, current_cloud, tree_params); 32 | index->buildIndex(); 33 | 34 | nanoflann::SearchParams search_params; 35 | search_params.sorted = true; 36 | 37 | for (auto& p0 : q_points) { 38 | if (i0 == q_start_index + q_lengths[b]) { 39 | q_start_index += q_lengths[b]; 40 | s_start_index += s_lengths[b]; 41 | b++; 42 | 43 | current_cloud.pts.clear(); 44 | current_cloud.pts = std::vector<PointXYZ>( 45 | s_points.begin() + s_start_index, 46 | s_points.begin() + s_start_index + s_lengths[b] 47 | ); 48 | 49 | delete index; 50 | index = new my_kd_tree_t(3, current_cloud, tree_params); 51 | index->buildIndex(); 52 | } 53 | 54 | all_inds_dists[i0].reserve(max_count); 55 | float query_pt[3] = {p0.x, p0.y, p0.z}; 56 | std::size_t nMatches = index->radiusSearch( 57 | query_pt, r2, all_inds_dists[i0], search_params 58 | ); 59 | 60 | if (nMatches > max_count) { 61 | max_count = nMatches; 62 | } 63 | 64 | i0++; 65 | } 66 | 67 | delete index; 68 | 69 | neighbor_indices.resize(q_points.size() * max_count); 70 | i0 = 0; 71 | s_start_index = 0; 72 | q_start_index = 0; 73 | b = 0; 74 | for (auto& inds_dists : all_inds_dists) { 75 | if (i0 == q_start_index + q_lengths[b]) { 76 | q_start_index += q_lengths[b]; 77 | s_start_index += s_lengths[b]; 78 | b++; 79 | } 80 | 81 | for (std::size_t j = 0; j < max_count; j++) { 82 | std::size_t i = i0 * max_count + j; 83 | if (j < inds_dists.size()) { 84 | neighbor_indices[i] = inds_dists[j].first + s_start_index; 85 | } else { 86 | neighbor_indices[i] = s_points.size(); 87 | } 88 | } 89 | 90 | i0++; 91 | } 92 | }
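The search above builds one nanoflann KD-tree per batch, runs a sorted radius query for every query point, and pads each row of the stacked result to the largest neighbor count with the sentinel index `s_points.size()`. That sentinel is why `KPConv.forward` and `knn_interpolate` append one extra row to the support tensors before calling `index_select`. For checking the extension on small inputs, a brute-force NumPy equivalent might look like the sketch below. `radius_neighbors_reference` is a hypothetical helper, not part of SG-Reg; it keeps points whose squared distance is strictly below `radius**2`, which is our reading of nanoflann's radius result set, so points exactly on the boundary and the order of equidistant neighbors may differ from the C++ output.

```python
import numpy as np


def radius_neighbors_reference(q_points, s_points, q_lengths, s_lengths, radius):
    """Hypothetical brute-force oracle for the stacked-batch radius search.

    q_points: (Nq, 3) and s_points: (Ns, 3) arrays, batches stacked along axis 0.
    q_lengths / s_lengths: number of query / support points in each batch.
    Returns an (Nq, max_count) int64 array of global support indices, nearest
    first; unused slots hold the sentinel len(s_points).
    """
    r2 = radius * radius
    per_query = []
    q_start, s_start = 0, 0
    for q_len, s_len in zip(q_lengths, s_lengths):
        support = s_points[s_start:s_start + s_len]
        for q in q_points[q_start:q_start + q_len]:
            d2 = np.sum((support - q) ** 2, axis=1)
            # Assumption: strict "< r^2", as we understand nanoflann's RadiusResultSet.
            inside = np.nonzero(d2 < r2)[0]
            # Mirrors search_params.sorted = true: nearest neighbor first.
            order = inside[np.argsort(d2[inside], kind="stable")]
            per_query.append(order + s_start)  # batch-local -> global index
        q_start += q_len
        s_start += s_len

    max_count = max((len(n) for n in per_query), default=0)
    sentinel = len(s_points)  # same padding value as neighbor_indices[i] = s_points.size()
    out = np.full((len(q_points), max_count), sentinel, dtype=np.int64)
    for i, n in enumerate(per_query):
        out[i, :len(n)] = n
    return out
```

The double loop costs O(Nq·Ns) per batch, so it is only meant as a test oracle on tiny clouds, not as a replacement for the KD-tree.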
93 | -------------------------------------------------------------------------------- /sgreg/extensions/cpu/radius_neighbors/radius_neighbors_cpu.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../extra/cloud/cloud.h" 3 | #include "../../extra/nanoflann/nanoflann.hpp" 4 | 5 | typedef nanoflann::KDTreeSingleIndexAdaptor< 6 | nanoflann::L2_Simple_Adaptor<float, PointCloud>, PointCloud, 3 7 | > my_kd_tree_t; 8 | 9 | void radius_neighbors_cpu( 10 | std::vector<PointXYZ>& q_points, 11 | std::vector<PointXYZ>& s_points, 12 | std::vector<long>& q_lengths, 13 | std::vector<long>& s_lengths, 14 | std::vector<long>& neighbor_indices, 15 | float radius 16 | ); 17 | -------------------------------------------------------------------------------- /sgreg/extensions/extra/cloud/cloud.cpp: -------------------------------------------------------------------------------- 1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch 2 | #include "cloud.h" 3 | 4 | PointXYZ max_point(std::vector<PointXYZ> points) { 5 | PointXYZ maxP(points[0]); 6 | 7 | for (auto p : points) { 8 | if (p.x > maxP.x) { 9 | maxP.x = p.x; 10 | } 11 | if (p.y > maxP.y) { 12 | maxP.y = p.y; 13 | } 14 | if (p.z > maxP.z) { 15 | maxP.z = p.z; 16 | } 17 | } 18 | 19 | return maxP; 20 | } 21 | 22 | PointXYZ min_point(std::vector<PointXYZ> points) { 23 | PointXYZ minP(points[0]); 24 | 25 | for (auto p : points) { 26 | if (p.x < minP.x) { 27 | minP.x = p.x; 28 | } 29 | if (p.y < minP.y) { 30 | minP.y = p.y; 31 | } 32 | if (p.z < minP.z) { 33 | minP.z = p.z; 34 | } 35 | } 36 | 37 | return minP; 38 | } -------------------------------------------------------------------------------- /sgreg/extensions/extra/cloud/cloud.h: -------------------------------------------------------------------------------- 1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch 2 | #pragma once 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | class
PointXYZ { 15 | public: 16 | float x, y, z; 17 | 18 | PointXYZ() { 19 | x = 0; 20 | y = 0; 21 | z = 0; 22 | } 23 | 24 | PointXYZ(float x0, float y0, float z0) { 25 | x = x0; 26 | y = y0; 27 | z = z0; 28 | } 29 | 30 | float operator [] (int i) const { 31 | if (i == 0) { 32 | return x; 33 | } 34 | else if (i == 1) { 35 | return y; 36 | } 37 | else { 38 | return z; 39 | } 40 | } 41 | 42 | float dot(const PointXYZ P) const { 43 | return x * P.x + y * P.y + z * P.z; 44 | } 45 | 46 | float sq_norm() { 47 | return x * x + y * y + z * z; 48 | } 49 | 50 | PointXYZ cross(const PointXYZ P) const { 51 | return PointXYZ(y * P.z - z * P.y, z * P.x - x * P.z, x * P.y - y * P.x); 52 | } 53 | 54 | PointXYZ& operator+=(const PointXYZ& P) { 55 | x += P.x; 56 | y += P.y; 57 | z += P.z; 58 | return *this; 59 | } 60 | 61 | PointXYZ& operator-=(const PointXYZ& P) { 62 | x -= P.x; 63 | y -= P.y; 64 | z -= P.z; 65 | return *this; 66 | } 67 | 68 | PointXYZ& operator*=(const float& a) { 69 | x *= a; 70 | y *= a; 71 | z *= a; 72 | return *this; 73 | } 74 | }; 75 | 76 | inline PointXYZ operator + (const PointXYZ A, const PointXYZ B) { 77 | return PointXYZ(A.x + B.x, A.y + B.y, A.z + B.z); 78 | } 79 | 80 | inline PointXYZ operator - (const PointXYZ A, const PointXYZ B) { 81 | return PointXYZ(A.x - B.x, A.y - B.y, A.z - B.z); 82 | } 83 | 84 | inline PointXYZ operator * (const PointXYZ P, const float a) { 85 | return PointXYZ(P.x * a, P.y * a, P.z * a); 86 | } 87 | 88 | inline PointXYZ operator * (const float a, const PointXYZ P) { 89 | return PointXYZ(P.x * a, P.y * a, P.z * a); 90 | } 91 | 92 | inline std::ostream& operator << (std::ostream& os, const PointXYZ P) { 93 | return os << "[" << P.x << ", " << P.y << ", " << P.z << "]"; 94 | } 95 | 96 | inline bool operator == (const PointXYZ A, const PointXYZ B) { 97 | return A.x == B.x && A.y == B.y && A.z == B.z; 98 | } 99 | 100 | inline PointXYZ floor(const PointXYZ P) { 101 | return PointXYZ(std::floor(P.x), std::floor(P.y), std::floor(P.z)); 
102 | } 103 | 104 | PointXYZ max_point(std::vector<PointXYZ> points); 105 | 106 | PointXYZ min_point(std::vector<PointXYZ> points); 107 | 108 | struct PointCloud { 109 | std::vector<PointXYZ> pts; 110 | 111 | inline size_t kdtree_get_point_count() const { 112 | return pts.size(); 113 | } 114 | 115 | // Returns the dim'th component of the idx'th point in the class: 116 | // Since this is inlined and the "dim" argument is typically an immediate value, the 117 | // "if/else's" are actually solved at compile time. 118 | inline float kdtree_get_pt(const size_t idx, const size_t dim) const { 119 | if (dim == 0) { 120 | return pts[idx].x; 121 | } 122 | else if (dim == 1) { 123 | return pts[idx].y; 124 | } 125 | else { 126 | return pts[idx].z; 127 | } 128 | } 129 | 130 | // Optional bounding-box computation: return false to default to a standard bbox computation loop. 131 | // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. 132 | // Look at bb.size() to find out the expected dimensionality (e.g.
2 or 3 for point clouds) 133 | template <class BBOX> 134 | bool kdtree_get_bbox(BBOX& /* bb */) const { 135 | return false; 136 | } 137 | }; 138 | -------------------------------------------------------------------------------- /sgreg/extensions/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cpu/radius_neighbors/radius_neighbors.h" 4 | #include "cpu/grid_subsampling/grid_subsampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // CPU extensions 8 | m.def( 9 | "radius_neighbors", 10 | &radius_neighbors, 11 | "Radius neighbors (CPU)" 12 | ); 13 | m.def( 14 | "grid_subsampling", 15 | &grid_subsampling, 16 | "Grid subsampling (CPU)" 17 | ); 18 | } 19 | -------------------------------------------------------------------------------- /sgreg/gnn/__init__.py: -------------------------------------------------------------------------------- 1 | from sgreg.gnn.gnn import CrossGraphLayer, SelfGAT, GraphNeuralNetwork 2 | from sgreg.gnn.triplet_gnn import TripletGNN 3 | from sgreg.gnn.nodes_init_layer import NodesInitLayer -------------------------------------------------------------------------------- /sgreg/gnn/nodes_init_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from sgreg.bert.get_tokenizer import get_tokenlizer, get_pretrained_language_model 4 | from sgreg.bert.bertwarper import generate_bert_fetures 5 | 6 | class NodesInitLayer(nn.Module): 7 | ''' Initialize semantic, bounding-box and shape embeddings of each node. 8 | Output fused node embeddings.
9 | ''' 10 | def __init__(self, sg_conf, 11 | shape_emb_dim, 12 | online_bert): 13 | super(NodesInitLayer, self).__init__() 14 | self.sg_conf = sg_conf 15 | self.shape_emb_dim = shape_emb_dim 16 | 17 | # BERT 18 | self.online_bert = online_bert 19 | if online_bert: 20 | self.tokenizer = get_tokenlizer('bert-base-uncased') 21 | self.bert = get_pretrained_language_model('bert-base-uncased') 22 | self.bert.pooler.dense.weight.requires_grad_(False) 23 | self.bert.pooler.dense.bias.requires_grad_(False) 24 | 25 | # Semantic 26 | self.mlp_semantic = torch.nn.Sequential( 27 | torch.nn.Linear(sg_conf.bert_dim,256), 28 | torch.nn.ReLU(), 29 | torch.nn.Linear(256,sg_conf.semantic_dim)) 30 | self.mlp_box = torch.nn.Sequential( 31 | torch.nn.Linear(3,16), 32 | torch.nn.ReLU(), 33 | torch.nn.Linear(16,sg_conf.box_dim)) 34 | self.instance_input_dim = sg_conf.semantic_dim + sg_conf.box_dim 35 | 36 | if sg_conf.fuse_shape and sg_conf.fuse_stage=='early': 37 | self.instance_input_dim += sg_conf.semantic_dim 38 | self.mlp_shape = nn.Linear(shape_emb_dim, 39 | sg_conf.semantic_dim) 40 | assert sg_conf.fuse_stage in ['early','late'], 'Invalid fuse stage' 41 | self.feat_projector = nn.Linear(self.instance_input_dim, 42 | sg_conf.node_dim) 43 | 44 | def forward(self, data_dict:dict, 45 | shape_embeddings:torch.Tensor): 46 | ''' Initialize nodes with instance features. 
47 | ''' 48 | if 'semantic_embeddings' in data_dict and self.online_bert==False: 49 | semantic_embeddings = data_dict['semantic_embeddings'] 50 | else: 51 | semantic_embeddings = generate_bert_fetures(self.tokenizer, 52 | self.bert, 53 | data_dict['labels'], 54 | CUDA=True) 55 | semantic_feats = self.mlp_semantic(semantic_embeddings) 56 | box_feats = self.mlp_box(data_dict['boxes']) 57 | output_feat = [semantic_feats, box_feats] 58 | if self.sg_conf.fuse_shape and self.sg_conf.fuse_stage=='early': 59 | output_feat.append(self.mlp_shape(shape_embeddings)) 60 | output_feat = torch.cat(output_feat,dim=1) 61 | output_feat = self.feat_projector(output_feat) 62 | 63 | return output_feat 64 | 65 | -------------------------------------------------------------------------------- /sgreg/gnn/spatial_attention.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import Callable, List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from omegaconf import OmegaConf 9 | from torch import nn 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | # FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") 13 | 14 | torch.backends.cudnn.deterministic = True 15 | 16 | def rotate_half(x: torch.Tensor) -> torch.Tensor: 17 | x = x.unflatten(-1, (-1, 2)) 18 | x1, x2 = x.unbind(dim=-1) 19 | return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) 20 | 21 | def apply_cached_rotary_emb(encoding: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 22 | return (x * encoding[0]) + (rotate_half(x) * encoding[1]) 23 | 24 | class LearnableFourierPositionalEncoding(nn.Module): 25 | def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: 26 | super().__init__() 27 | F_dim = F_dim if F_dim is not None else dim 28 | self.gamma = gamma 29 | self.Wr = nn.Linear(M, F_dim // 2, bias=False) 30 | 
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | """encode position vector 34 | x: (b,n,3) 35 | return (2,b,n,d), where d = F_dim 36 | """ 37 | projected = self.Wr(x) # (b,n,d/2) 38 | cosines, sines = torch.cos(projected), torch.sin(projected) 39 | emb = torch.stack([cosines, sines], 0).unsqueeze(-3) # (2,b,1,n,d/2) 40 | emb = emb.repeat_interleave(2, dim=-1) # (2,b,1,n,d) 41 | return emb.squeeze(2) 42 | 43 | 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, allow_flash: bool) -> None: 47 | super().__init__() 48 | self.FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") 49 | if allow_flash and not self.FLASH_AVAILABLE: 50 | warnings.warn( 51 | "FlashAttention is not available. For optimal speed, " 52 | "consider installing torch >= 2.0 or flash-attn.", 53 | stacklevel=2, 54 | ) 55 | self.enable_flash = allow_flash and self.FLASH_AVAILABLE 56 | 57 | if self.FLASH_AVAILABLE: 58 | torch.backends.cuda.enable_flash_sdp(allow_flash) 59 | 60 | def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 61 | ''' 62 | q, k, v: (b,h,n,d) 63 | ''' 64 | if self.enable_flash and q.device.type == "cuda": 65 | # use torch 2.0 scaled_dot_product_attention with flash 66 | if self.FLASH_AVAILABLE: 67 | args = [x.half().contiguous() for x in [q, k, v]] 68 | # v_raw = v.clone() 69 | v = F.scaled_dot_product_attention(args[0],args[1],args[2], attn_mask=mask).to(q.dtype) 70 | if mask is not None: 71 | valid_mask = mask.sum(-1)>0 # (n,k) 72 | nan_mask = torch.isnan(v.sum(-1)) # (n,k) 73 | assert nan_mask[valid_mask].sum()<1, 'nan detected in flash attention' 74 | return v if mask is None else v.nan_to_num() 75 | elif self.FLASH_AVAILABLE: 76 | args = [x.contiguous() for x in [q, k, v]] 77 | v = F.scaled_dot_product_attention(args[0],args[1],args[2], attn_mask=mask) 78 | return v if mask is None else v.nan_to_num() 79 | else: 80 | s = q.shape[-1] **
-0.5 81 | sim = torch.einsum("...id,...jd->...ij", q, k) * s 82 | if mask is not None: 83 | sim = sim.masked_fill(~mask, -float("inf")) 84 | attn = F.softmax(sim, -1) 85 | return torch.einsum("...ij,...jd->...id", attn, v) 86 | 87 | 88 | class SelfBlock(nn.Module): 89 | def __init__( 90 | self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True 91 | ) -> None: 92 | super().__init__() 93 | self.embed_dim = embed_dim 94 | self.num_heads = num_heads 95 | assert self.embed_dim % num_heads == 0 96 | self.head_dim = self.embed_dim // num_heads 97 | self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) 98 | self.inner_attn = Attention(flash) 99 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 100 | self.ffn = nn.Sequential( 101 | nn.Linear(2 * embed_dim, 2 * embed_dim), 102 | nn.LayerNorm(2 * embed_dim, elementwise_affine=True), 103 | nn.GELU(), 104 | nn.Linear(2 * embed_dim, embed_dim), 105 | ) 106 | 107 | def forward( 108 | self, 109 | x: torch.Tensor, 110 | encoding: torch.Tensor=None, 111 | mask: Optional[torch.Tensor] = None, 112 | ) -> torch.Tensor: 113 | ''' 114 | x: (b,n,d) 115 | encoding: (2,b,n,d) 116 | mask: (b,h,n,n) 117 | ''' 118 | b,n,d = x.shape 119 | assert b==1, 'spatial GAT currently only supports batch size 1' 120 | qkv = self.Wqkv(x) # (b,n,3*h*d) 121 | qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) # (b,h,n,d,3) 122 | q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] # (b,h,n,d) 123 | if encoding is not None: 124 | q = apply_cached_rotary_emb(encoding, q) 125 | k = apply_cached_rotary_emb(encoding, k) 126 | if mask is not None: 127 | # nodes_edge_number = mask.sum(-1) 128 | valid_nodes = mask.sum(-1)>0 # (b,h,n) 129 | # valid_nodes_number = valid_nodes.sum(-1).squeeze() # (b) 130 | # valid_nodes_indices = torch.nonzero(valid_nodes.squeeze(1))[:,1] # (numer_valid,2),[batch_id,node_id] 131 | q = q[valid_nodes].view(b,self.num_heads,-1,self.head_dim).contiguous() # (b,h,n',d) 132 | #
k[valid_nodes].view(b,self.num_heads,-1,self.head_dim).contiguous() # (b,h,n',d) 133 | # v = v[valid_nodes].view(b,self.num_heads,-1,self.head_dim).contiguous() # (b,h,n',d) 134 | 135 | filtered_mask = mask[valid_nodes] # (b,h,n',n) 136 | # filtered_mask = mask[valid_nodes_indices[:,0],0,valid_nodes_indices[:,1],valid_nodes_indices[:,1].unsqueeze(1)] # (b,h,n',n') 137 | assert torch.all(filtered_mask.sum(-1)>0), 'filtered mask should have at least one valid node' 138 | while filtered_mask.ndim<4: 139 | filtered_mask = filtered_mask.unsqueeze(0) 140 | # if valid_number.sum() None: 155 | super().__init__() 156 | self.heads = num_heads 157 | dim_head = embed_dim // num_heads 158 | self.scale = dim_head**-0.5 159 | inner_dim = dim_head * num_heads 160 | self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) 161 | self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) 162 | self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) 163 | self.ffn = nn.Sequential( 164 | nn.Linear(2 * embed_dim, 2 * embed_dim), 165 | nn.LayerNorm(2 * embed_dim, elementwise_affine=True), 166 | nn.GELU(), 167 | nn.Linear(2 * embed_dim, embed_dim), 168 | ) 169 | self.FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") 170 | if flash and self.FLASH_AVAILABLE: 171 | self.flash = Attention(True) 172 | else: 173 | self.flash = None 174 | 175 | def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): 176 | return func(x0), func(x1) 177 | 178 | def forward( 179 | self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None 180 | ) -> List[torch.Tensor]: 181 | qk0, qk1 = self.map_(self.to_qk, x0, x1) 182 | v0, v1 = self.map_(self.to_v, x0, x1) 183 | qk0, qk1, v0, v1 = map( 184 | lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), 185 | (qk0, qk1, v0, v1), 186 | ) 187 | if self.flash is not None and qk0.device.type == "cuda": 188 | m0 = self.flash(qk0, qk1, v1, mask) 189 | m1 = self.flash( 190 | qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None 
else None 191 | ) 192 | else: 193 | qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 194 | sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) 195 | if mask is not None: 196 | sim = sim.masked_fill(~mask, -float("inf")) 197 | attn01 = F.softmax(sim, dim=-1) 198 | attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) 199 | m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) 200 | m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) 201 | if mask is not None: 202 | m0, m1 = m0.nan_to_num(), m1.nan_to_num() 203 | m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) 204 | m0, m1 = self.map_(self.to_out, m0, m1) 205 | x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) 206 | x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) 207 | return x0, x1 208 | 209 | 210 | class SpatialTransformer(nn.Module): 211 | def __init__(self, embed_dim, heads, position_encoding, all_self_edges) -> None: 212 | super().__init__() 213 | self.embed_dim = embed_dim 214 | self.heads = heads 215 | self.head_dim = embed_dim // heads 216 | self.all_self_edges = all_self_edges 217 | self.position_encoding = position_encoding 218 | self.posenc = LearnableFourierPositionalEncoding(3, None, self.head_dim) 219 | self.self_attn = SelfBlock(self.embed_dim, self.heads, True) 220 | 221 | 222 | def forward(self, x: torch.Tensor, pos:torch.Tensor, edge_index:torch.Tensor, graph_batch:torch.Tensor) -> torch.Tensor: 223 | ''' 224 | x: (b,n,d) 225 | pos: (b,n,3) 226 | edge: (b,2,e) 227 | mask: (b,n,n) 228 | ''' 229 | if x.ndim==2: 230 | x = x.unsqueeze(0) 231 | pos = pos.unsqueeze(0) 232 | edge_index = edge_index.unsqueeze(0) 233 | if torch.isnan(x).any(): 234 | assert False, 'x nan detected' 235 | 236 | b,n,d = x.shape 237 | e = edge_index.shape[-1] 238 | if self.all_self_edges: 239 | mask = torch.zeros((b,self.heads,graph_batch[-1],graph_batch[-1])).bool().to(x.device) 240 | B_ = graph_batch.shape[0]-1 241 | for batch_id in torch.arange(B_): 242 | start_id = 
graph_batch[batch_id] 243 | stop_id = graph_batch[batch_id+1] 244 | mask[:, :, start_id:stop_id, start_id:stop_id] = True 245 | else: # mark every (src, dst) edge as attendable 246 | mask = torch.zeros(b,self.heads,n,n).bool().to(x.device) 247 | mask[torch.arange(b, device=x.device).repeat_interleave(e), :, edge_index[:,0,:].reshape(-1), edge_index[:,1,:].reshape(-1)] = True # same edges for every head 248 | # print('mask shape',mask.shape) 249 | 250 | if self.position_encoding: 251 | encoding = self.posenc(pos) # (2,b,n,d) 252 | else: 253 | encoding = None 254 | out = self.self_attn(x, encoding, mask) # (b,n,d) 255 | out = out.squeeze(0) 256 | 257 | # check nan 258 | if torch.isnan(out).any(): 259 | assert False, 'nan detected in self-gat output' 260 | 261 | return out 262 | 263 | 264 | if __name__=='__main__': 265 | 266 | print('test the self attention block') 267 | # B = 1 268 | N = 8 269 | heads = 1 270 | embed_dim = 128 271 | 272 | x = torch.randn(N,embed_dim).float() 273 | pos = torch.randn(N,3).float() 274 | edge_index = torch.randint(0, N, (2, 10)) 275 | x = x.cuda() 276 | pos = pos.cuda() 277 | edge_index = edge_index.cuda() 278 | 279 | # print(edge_index) 280 | 281 | stransformer = SpatialTransformer(embed_dim=embed_dim,heads=heads,all_self_edges=False,position_encoding=False) 282 | stransformer = stransformer.cuda() 283 | 284 | out = stransformer(x,pos,edge_index,None) 285 | print(out.shape) 286 | 287 | exit(0) 288 | self_block = SelfBlock(embed_dim, heads, True) 289 | out = self_block(x) 290 | print(out.shape) 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /sgreg/kpconv/__init__.py: -------------------------------------------------------------------------------- 1 | from sgreg.kpconv.kpconv import KPConv 2 | from sgreg.kpconv.modules import ( 3 | ConvBlock, 4 | ResidualBlock, 5 | UnaryBlock, 6 | LastUnaryBlock, 7 | GroupNorm, 8 | KNNInterpolate, 9 | GlobalAvgPool, 10 | MaxPool, 11 | ) 12 | from sgreg.kpconv.functional import nearest_upsample,
global_avgpool, maxpool 13 | -------------------------------------------------------------------------------- /sgreg/kpconv/dispositions/k_015_center_3D.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/kpconv/dispositions/k_015_center_3D.ply -------------------------------------------------------------------------------- /sgreg/kpconv/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from sgreg.ops import index_select 4 | 5 | @torch.jit.script 6 | def nearest_upsample(x, upsample_indices): 7 | """Pools features from the closest neighbors. 8 | 9 | WARNING: this function assumes the neighbors are ordered. 10 | 11 | Args: 12 | x: [n1, d] features matrix 13 | upsample_indices: [n2, max_num] Only the first column is used for pooling 14 | 15 | Returns: 16 | x: [n2, d] pooled features matrix 17 | """ 18 | # Add a last row with minimum features for shadow pools 19 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 20 | # Get features for each pooling location [n2, d] 21 | x = index_select(x, upsample_indices[:, 0], dim=0) 22 | return x 23 | 24 | 25 | def knn_interpolate(s_feats, q_points, s_points, neighbor_indices, k, eps=1e-8): 26 | r"""K-NN interpolate. 27 | 28 | WARNING: this function assumes the neighbors are ordered. 
29 | 30 | Args: 31 | s_feats (Tensor): (M, C) 32 | q_points (Tensor): (N, 3) 33 | s_points (Tensor): (M, 3) 34 | neighbor_indices (LongTensor): (N, X) 35 | k (int) 36 | eps (float) 37 | 38 | Returns: 39 | q_feats (Tensor): (N, C) 40 | """ 41 | s_points = torch.cat((s_points, torch.zeros_like(s_points[:1, :])), 0) # (M + 1, 3) 42 | s_feats = torch.cat((s_feats, torch.zeros_like(s_feats[:1, :])), 0) # (M + 1, C) 43 | knn_indices = neighbor_indices[:, :k].contiguous() 44 | knn_points = index_select(s_points, knn_indices, dim=0) # (N, k, 3) 45 | knn_feats = index_select(s_feats, knn_indices, dim=0) # (N, k, C) 46 | knn_sq_distances = (q_points.unsqueeze(1) - knn_points).pow(2).sum(dim=-1) # (N, k) 47 | knn_masks = torch.ne(knn_indices, s_points.shape[0] - 1).float() # (N, k) 48 | knn_weights = knn_masks / (knn_sq_distances + eps) # (N, k) 49 | knn_weights = knn_weights / (knn_weights.sum(dim=1, keepdim=True) + eps) # (N, k) 50 | q_feats = (knn_feats * knn_weights.unsqueeze(-1)).sum(dim=1) # (N, C) 51 | return q_feats 52 | 53 | 54 | def maxpool(x, neighbor_indices): 55 | """Max pooling from neighbors. 56 | 57 | Args: 58 | x: [n1, d] features matrix 59 | neighbor_indices: [n2, max_num] pooling indices 60 | 61 | Returns: 62 | pooled_feats: [n2, d] pooled features matrix 63 | """ 64 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 65 | neighbor_feats = index_select(x, neighbor_indices, dim=0) 66 | pooled_feats = neighbor_feats.max(1)[0] 67 | return pooled_feats 68 | 69 | 70 | def global_avgpool(x, batch_lengths): 71 | """Global average pooling over batch. 
72 | 73 | Args: 74 | x: [N, D] input features 75 | batch_lengths: [B] list of batch lengths 76 | 77 | Returns: 78 | x: [B, D] averaged features 79 | """ 80 | # Loop over the clouds of the batch 81 | averaged_features = [] 82 | i0 = 0 83 | for b_i, length in enumerate(batch_lengths): 84 | # Average features for each batch cloud 85 | averaged_features.append(torch.mean(x[i0 : i0 + length], dim=0)) 86 | # Increment for next cloud 87 | i0 += length 88 | # Average features in each batch 89 | x = torch.stack(averaged_features) 90 | return x 91 | -------------------------------------------------------------------------------- /sgreg/kpconv/kpconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from sgreg.ops import index_select 7 | from sgreg.kpconv.kernel_points import load_kernels 8 | 9 | 10 | class KPConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | radius, 17 | sigma, 18 | bias=False, 19 | dimension=3, 20 | inf=1e6, 21 | eps=1e-9, 22 | ): 23 | """Initialize parameters for KPConv. 24 | 25 | Modified from [KPConv-PyTorch](https://github.com/HuguesTHOMAS/KPConv-PyTorch). 26 | 27 | Deformable KPConv is not supported. 28 | 29 | Args: 30 | in_channels: dimension of input features. 31 | out_channels: dimension of output features. 32 | kernel_size: Number of kernel points. 33 | radius: radius used for kernel point init. 34 | sigma: influence radius of each kernel point. 35 | bias: use bias or not (default: False) 36 | dimension: dimension of the point space. 
37 | inf: value of infinity to generate the padding point 38 | eps: epsilon for gaussian influence 39 | """ 40 | super(KPConv, self).__init__() 41 | 42 | # Save parameters 43 | self.kernel_size = kernel_size 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.radius = radius 47 | self.sigma = sigma 48 | self.dimension = dimension 49 | 50 | self.inf = inf 51 | self.eps = eps 52 | 53 | # Initialize weights 54 | self.weights = nn.Parameter(torch.zeros(self.kernel_size, in_channels, out_channels)) 55 | if bias: 56 | self.bias = nn.Parameter(torch.zeros(self.out_channels)) 57 | else: 58 | self.register_parameter('bias', None) 59 | 60 | # Reset parameters 61 | self.reset_parameters() 62 | 63 | # Initialize kernel points 64 | kernel_points = self.initialize_kernel_points() # (N, 3) 65 | self.register_buffer('kernel_points', kernel_points) 66 | 67 | def reset_parameters(self): 68 | nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) 69 | if self.bias is not None: 70 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights) 71 | bound = 1 / math.sqrt(fan_in) 72 | nn.init.uniform_(self.bias, -bound, bound) 73 | 74 | def initialize_kernel_points(self): 75 | """Initialize the kernel point positions in a sphere.""" 76 | kernel_points = load_kernels(self.radius, self.kernel_size, dimension=self.dimension, fixed='center') 77 | return torch.from_numpy(kernel_points).float() 78 | 79 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 80 | r"""KPConv forward. 
81 | 82 | Args: 83 | s_feats (Tensor): (N, C_in) 84 | q_points (Tensor): (M, 3) 85 | s_points (Tensor): (N, 3) 86 | neighbor_indices (LongTensor): (M, H) 87 | 88 | Returns: 89 | q_feats (Tensor): (M, C_out) 90 | """ 91 | s_points = torch.cat([s_points, torch.zeros_like(s_points[:1, :]) + self.inf], 0) # (N, 3) -> (N+1, 3) 92 | neighbors = index_select(s_points, neighbor_indices, dim=0) # (N+1, 3) -> (M, H, 3) 93 | neighbors = neighbors - q_points.unsqueeze(1) # (M, H, 3) 94 | 95 | # Get Kernel point influences 96 | neighbors = neighbors.unsqueeze(2) # (M, H, 3) -> (M, H, 1, 3) 97 | differences = neighbors - self.kernel_points # (M, H, 1, 3) x (K, 3) -> (M, H, K, 3) 98 | sq_distances = torch.sum(differences ** 2, dim=3) # (M, H, K) 99 | neighbor_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.sigma, min=0.0) # (M, H, K) 100 | neighbor_weights = torch.transpose(neighbor_weights, 1, 2) # (M, H, K) -> (M, K, H) 101 | 102 | # apply neighbor weights 103 | s_feats = torch.cat((s_feats, torch.zeros_like(s_feats[:1, :])), 0) # (N, C) -> (N+1, C) 104 | neighbor_feats = index_select(s_feats, neighbor_indices, dim=0) # (N+1, C) -> (M, H, C) 105 | weighted_feats = torch.matmul(neighbor_weights, neighbor_feats) # (M, K, H) x (M, H, C) -> (M, K, C) 106 | 107 | # apply convolutional weights 108 | weighted_feats = weighted_feats.permute(1, 0, 2) # (M, K, C) -> (K, M, C) 109 | kernel_outputs = torch.matmul(weighted_feats, self.weights) # (K, M, C) x (K, C, C_out) -> (K, M, C_out) 110 | output_feats = torch.sum(kernel_outputs, dim=0, keepdim=False) # (K, M, C_out) -> (M, C_out) 111 | 112 | # normalization 113 | neighbor_feats_sum = torch.sum(neighbor_feats, dim=-1) 114 | neighbor_num = torch.sum(torch.gt(neighbor_feats_sum, 0.0), dim=-1) 115 | neighbor_num = torch.max(neighbor_num, torch.ones_like(neighbor_num)) 116 | output_feats = output_feats / neighbor_num.unsqueeze(1) 117 | 118 | # add bias 119 | if self.bias is not None: 120 | output_feats = output_feats + self.bias 
121 | 122 | return output_feats 123 | 124 | def __repr__(self): 125 | format_string = self.__class__.__name__ + '(' 126 | format_string += 'kernel_size: {}'.format(self.kernel_size) 127 | format_string += ', in_channels: {}'.format(self.in_channels) 128 | format_string += ', out_channels: {}'.format(self.out_channels) 129 | format_string += ', radius: {:g}'.format(self.radius) 130 | format_string += ', sigma: {:g}'.format(self.sigma) 131 | format_string += ', bias: {}'.format(self.bias is not None) 132 | format_string += ')' 133 | return format_string 134 | -------------------------------------------------------------------------------- /sgreg/kpconv/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from sgreg.kpconv.functional import maxpool, nearest_upsample, global_avgpool, knn_interpolate 5 | from sgreg.kpconv.kpconv import KPConv 6 | 7 | 8 | class KNNInterpolate(nn.Module): 9 | def __init__(self, k, eps=1e-8): 10 | super(KNNInterpolate, self).__init__() 11 | self.k = k 12 | self.eps = eps 13 | 14 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 15 | if self.k == 1: 16 | return nearest_upsample(s_feats, neighbor_indices) 17 | else: 18 | return knn_interpolate(s_feats, q_points, s_points, neighbor_indices, self.k, eps=self.eps) 19 | 20 | 21 | class MaxPool(nn.Module): 22 | @staticmethod 23 | def forward(s_feats, neighbor_indices): 24 | return maxpool(s_feats, neighbor_indices) 25 | 26 | 27 | class GlobalAvgPool(nn.Module): 28 | @staticmethod 29 | def forward(feats, lengths): 30 | return global_avgpool(feats, lengths) 31 | 32 | 33 | class GroupNorm(nn.Module): 34 | def __init__(self, num_groups, num_channels): 35 | r"""Initialize a group normalization block. 
36 | 37 | Args: 38 | num_groups: number of groups 39 | num_channels: feature dimension 40 | """ 41 | super(GroupNorm, self).__init__() 42 | self.num_groups = num_groups 43 | self.num_channels = num_channels 44 | self.norm = nn.GroupNorm(self.num_groups, self.num_channels) 45 | 46 | def forward(self, x): 47 | x = x.transpose(0, 1).unsqueeze(0) # (N, C) -> (1, C, N) 48 | x = self.norm(x) 49 | x = x.squeeze(0).transpose(0, 1) # (1, C, N) -> (N, C) 50 | return x.squeeze() 51 | 52 | 53 | class UnaryBlock(nn.Module): 54 | def __init__(self, in_channels, out_channels, group_norm, has_relu=True, bias=True, layer_norm=False): 55 | r"""Initialize a standard unary block with GroupNorm and LeakyReLU. 56 | 57 | Args: 58 | in_channels: dimension of input features 59 | out_channels: dimension of output features 60 | group_norm: number of groups in group normalization 61 | bias: If True, use bias 62 | layer_norm: If True, use LayerNorm instead of GroupNorm 63 | """ 64 | super(UnaryBlock, self).__init__() 65 | self.in_channels = in_channels 66 | self.out_channels = out_channels 67 | self.group_norm = group_norm 68 | self.mlp = nn.Linear(in_channels, out_channels, bias=bias) 69 | if layer_norm: 70 | self.norm = nn.LayerNorm(out_channels) 71 | else: 72 | self.norm = GroupNorm(group_norm, out_channels) 73 | if has_relu: 74 | self.leaky_relu = nn.LeakyReLU(0.1) 75 | else: 76 | self.leaky_relu = None 77 | 78 | def forward(self, x): 79 | x = self.mlp(x) 80 | x = self.norm(x) 81 | if self.leaky_relu is not None: 82 | x = self.leaky_relu(x) 83 | return x 84 | 85 | 86 | class LastUnaryBlock(nn.Module): 87 | def __init__(self, in_channels, out_channels, bias=True): 88 | r"""Initialize a standard last_unary block without GN, ReLU.
89 | 90 | Args: 91 | in_channels: dimension of input features 92 | out_channels: dimension of output features 93 | """ 94 | super(LastUnaryBlock, self).__init__() 95 | self.in_channels = in_channels 96 | self.out_channels = out_channels 97 | self.mlp = nn.Linear(in_channels, out_channels, bias=bias) 98 | 99 | def forward(self, x): 100 | x = self.mlp(x) 101 | return x 102 | 103 | 104 | class ConvBlock(nn.Module): 105 | def __init__( 106 | self, 107 | in_channels, 108 | out_channels, 109 | kernel_size, 110 | radius, 111 | sigma, 112 | group_norm, 113 | negative_slope=0.1, 114 | bias=True, 115 | layer_norm=False, 116 | ): 117 | r"""Initialize a KPConv block with GroupNorm (or LayerNorm) and LeakyReLU. 118 | 119 | Args: 120 | in_channels: dimension of input features 121 | out_channels: dimension of output features 122 | kernel_size: number of kernel points 123 | radius: convolution radius 124 | sigma: influence radius of each kernel point 125 | group_norm: number of groups for GroupNorm 126 | negative_slope: negative slope of the LeakyReLU 127 | bias: If True, use bias in KPConv 128 | layer_norm: If True, use LayerNorm instead of GroupNorm 129 | """ 130 | super(ConvBlock, self).__init__() 131 | 132 | self.in_channels = in_channels 133 | self.out_channels = out_channels 134 | 135 | self.KPConv = KPConv(in_channels, out_channels, kernel_size, radius, sigma, bias=bias) 136 | if layer_norm: 137 | self.norm = nn.LayerNorm(out_channels) 138 | else: 139 | self.norm = GroupNorm(group_norm, out_channels) 140 | self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope) 141 | 142 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 143 | x = self.KPConv(s_feats, q_points, s_points, neighbor_indices) 144 | x = self.norm(x) 145 | x = self.leaky_relu(x) 146 | return x 147 | 148 | 149 | class ResidualBlock(nn.Module): 150 | def __init__( 151 | self, 152 | in_channels, 153 | out_channels, 154 | kernel_size, 155 | radius, 156 | sigma, 157 | group_norm, 158 | strided=False, 159 | bias=True, 160 |
layer_norm=False, 161 | ): 162 | r"""Initialize a ResNet-style bottleneck block built around KPConv. 163 | 164 | Args: 165 | in_channels: dimension of input features 166 | out_channels: dimension of output features 167 | kernel_size: number of kernel points 168 | radius: convolution radius 169 | sigma: influence radius of each kernel point 170 | group_norm: number of groups for GroupNorm 171 | strided: If True, max-pool the shortcut features over neighbor_indices 172 | bias: If True, use bias in KPConv 173 | layer_norm: If True, use LayerNorm instead of GroupNorm 174 | """ 175 | super(ResidualBlock, self).__init__() 176 | 177 | self.in_channels = in_channels 178 | self.out_channels = out_channels 179 | self.strided = strided 180 | 181 | mid_channels = out_channels // 4 182 | 183 | if in_channels != mid_channels: 184 | self.unary1 = UnaryBlock(in_channels, mid_channels, group_norm, bias=bias, layer_norm=layer_norm) 185 | else: 186 | self.unary1 = nn.Identity() 187 | 188 | self.KPConv = KPConv(mid_channels, mid_channels, kernel_size, radius, sigma, bias=bias) 189 | if layer_norm: 190 | self.norm_conv = nn.LayerNorm(mid_channels) 191 | else: 192 | self.norm_conv = GroupNorm(group_norm, mid_channels) 193 | 194 | self.unary2 = UnaryBlock( 195 | mid_channels, out_channels, group_norm, has_relu=False, bias=bias, layer_norm=layer_norm 196 | ) 197 | 198 | if in_channels != out_channels: 199 | self.unary_shortcut = UnaryBlock( 200 | in_channels, out_channels, group_norm, has_relu=False, bias=bias, layer_norm=layer_norm 201 | ) 202 | else: 203 | self.unary_shortcut = nn.Identity() 204 | 205 | self.leaky_relu = nn.LeakyReLU(0.1) 206 | 207 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 208 | x = self.unary1(s_feats) 209 | 210 | x = self.KPConv(x, q_points, s_points, neighbor_indices) 211 | x = self.norm_conv(x) 212 | x = self.leaky_relu(x) 213 | 214 | x = self.unary2(x) 215 | 216 | if self.strided: 217 | shortcut = maxpool(s_feats, neighbor_indices) 218 | else: 219 | shortcut = s_feats 220 | shortcut = self.unary_shortcut(shortcut) 221 |
222 | x = x + shortcut 223 | x = self.leaky_relu(x) 224 | 225 | return x 226 | -------------------------------------------------------------------------------- /sgreg/loss/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from sgreg.ops import apply_transform 5 | from sgreg.registration.metrics import isotropic_transform_error 6 | 7 | 8 | class Evaluator(nn.Module): 9 | def __init__(self, cfg): 10 | super(Evaluator, self).__init__() 11 | self.acceptance_overlap = cfg.eval.acceptance_overlap 12 | self.acceptance_radius = cfg.eval.acceptance_radius 13 | self.acceptance_rmse = cfg.eval.rmse_threshold 14 | 15 | @torch.no_grad() 16 | def evaluate_coarse(self, output_dict): 17 | ref_length_c = output_dict["ref_points_c"].shape[0] 18 | src_length_c = output_dict["src_points_c"].shape[0] 19 | gt_node_corr_overlaps = output_dict["gt_node_corr_overlaps"] 20 | gt_node_corr_indices = output_dict["gt_node_corr_indices"] 21 | masks = torch.gt(gt_node_corr_overlaps, self.acceptance_overlap) 22 | gt_node_corr_indices = gt_node_corr_indices[masks] 23 | gt_ref_node_corr_indices = gt_node_corr_indices[:, 0] 24 | gt_src_node_corr_indices = gt_node_corr_indices[:, 1] 25 | gt_node_corr_map = torch.zeros(ref_length_c, src_length_c).cuda() 26 | gt_node_corr_map[gt_ref_node_corr_indices, gt_src_node_corr_indices] = 1.0 27 | 28 | ref_node_corr_indices = output_dict["ref_node_corr_indices"] 29 | src_node_corr_indices = output_dict["src_node_corr_indices"] 30 | 31 | precision = gt_node_corr_map[ 32 | ref_node_corr_indices, src_node_corr_indices 33 | ].mean() 34 | 35 | return precision 36 | 37 | @torch.no_grad() 38 | def evaluate_fine(self, output_dict, data_dict): 39 | transform = data_dict["transform"] 40 | ref_corr_points = output_dict["ref_corr_points"] 41 | src_corr_points = output_dict["src_corr_points"] 42 | src_corr_points = apply_transform(src_corr_points, transform) 43 | 
corr_distances = torch.linalg.norm(ref_corr_points - src_corr_points, dim=1) 44 | precision = torch.lt(corr_distances, self.acceptance_radius).float().mean() 45 | return precision, corr_distances 46 | 47 | @torch.no_grad() 48 | def evaluate_registration(self, output_dict, data_dict): 49 | transform = data_dict["transform"] 50 | est_transform = output_dict["estimated_transform"] 51 | src_points = output_dict["src_points"] 52 | 53 | rre, rte = isotropic_transform_error(transform, est_transform) 54 | 55 | realignment_transform = torch.matmul(torch.inverse(transform), est_transform) 56 | realigned_src_points_f = apply_transform(src_points, realignment_transform) 57 | rmse = torch.linalg.norm(realigned_src_points_f - src_points, dim=1).mean() 58 | recall = torch.lt(rmse, self.acceptance_rmse).float() 59 | 60 | return rre, rte, rmse, recall 61 | 62 | def forward(self, output_dict, data_dict): 63 | # c_precision = self.evaluate_coarse(output_dict) 64 | # f_precision = self.evaluate_fine(output_dict, data_dict) 65 | rre, rte, rmse, recall = self.evaluate_registration(output_dict, data_dict) 66 | 67 | return { 68 | # 'PIR': c_precision, 69 | # 'IR': f_precision, 70 | "RRE": rre, 71 | "RTE": rte, 72 | "RMSE": rmse, 73 | "RR": recall, 74 | } 75 | 76 | 77 | def eval_instance_match(pred: torch.Tensor, gt: torch.Tensor): 78 | """Background instances (floors, carpets) are also evaluated. 79 | pred: (a,2), gt: (b,2) 80 | return: tp, fp, true_pos_mask (int32 array of shape (a,)) 81 | """ 82 | true_pos_mask = torch.zeros_like(pred[:, 0]).bool() 83 | 84 | for row, pred_pair in enumerate(pred): 85 | check_equal = (gt - pred_pair) == 0 86 | if check_equal.sum(dim=1).max() == 2: 87 | true_pos_mask[row] = True 88 | 89 | tp = true_pos_mask.sum().cpu().numpy() 90 | fp = pred.shape[0] - tp 91 | 92 | return tp, fp, true_pos_mask.detach().cpu().numpy().astype(np.int32) 93 | 94 | def eval_instance_match_new(gt_matrix: torch.Tensor, 95 | pred: torch.Tensor, 96 | min_iou: float): 97 | """Background instances (floors, carpets) are also evaluated.
98 | - gt_matrix: (n,m) 99 | - pred: (a,2), [i,j] where i in [0,n), j in [0,m) 100 | - return: tp, fp, gt_pairs, true_pos_mask 101 | """ 102 | 103 | true_pos_mask = torch.zeros_like(pred[:, 0]).bool() 104 | 105 | # gt matches 106 | row_max = torch.zeros_like(gt_matrix).to(torch.int32) 107 | col_max = torch.zeros_like(gt_matrix).to(torch.int32) 108 | row_max[torch.arange(gt_matrix.shape[0]), gt_matrix.argmax(dim=1)] = 1 109 | col_max[gt_matrix.argmax(dim=0), torch.arange(gt_matrix.shape[1])] = 1 110 | valid_mask = row_max * col_max 111 | gt_tp_matrix = torch.gt(gt_matrix, min_iou) * valid_mask 112 | 113 | # pred matches. 114 | # If several predicted pairs reach min_iou, all of them count as true positives. 115 | for row, pred_pair in enumerate(pred): 116 | pred_iou = gt_matrix[pred_pair[0], pred_pair[1]] 117 | if pred_iou >= min_iou: 118 | true_pos_mask[row] = True 119 | 120 | tp = true_pos_mask.sum().cpu().numpy() 121 | fp = pred.shape[0] - tp 122 | true_pos_mask = true_pos_mask.detach().cpu().numpy().astype(np.int32) 123 | gt_pairs = gt_tp_matrix.sum().item() 124 | 125 | return tp, fp, gt_pairs, true_pos_mask 126 | 127 | 128 | 129 | 130 | def is_recall( 131 | source: np.ndarray, T_est: np.ndarray, T_gt: np.ndarray, threshold: float 132 | ): 133 | """Check whether the registration is successful, i.e. the mean point error after realignment is below threshold. 134 | source: (N,3), T_est: (4,4), T_gt: (4,4); returns (recall, rmse) 135 | """ 136 | 137 | source = np.hstack([source, np.ones((source.shape[0], 1))]) 138 | realignment_transform = np.linalg.inv(T_gt) @ T_est 139 | realigned_src_points_f = source @ realignment_transform.T 140 | rmse = np.linalg.norm(realigned_src_points_f[:, :3] - source[:, :3], axis=1).mean() 141 | recall = rmse < threshold 142 | return recall, rmse 143 | 144 | def compute_node_matching(metric_dict:dict): 145 | recall = metric_dict['nodes_tp'] / metric_dict['nodes_gt'] 146 | precision = metric_dict['nodes_tp'] / (metric_dict['nodes_tp'] + metric_dict['nodes_fp']) 147 | 148 | recall = 100 * recall 149 | precision = 100 * precision 150 | return recall, precision
151 | 152 | def compute_registration(metric_dict:dict): 153 | rmse = metric_dict['rmse'] / metric_dict['scenes'] 154 | recall = metric_dict['recall'] / metric_dict['scenes'] 155 | recall = 100 * recall 156 | return rmse, recall 157 | -------------------------------------------------------------------------------- /sgreg/match/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/match/__init__.py -------------------------------------------------------------------------------- /sgreg/match/learnable_sinkhorn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LearnableLogOptimalTransport(nn.Module): 6 | ''' 7 | Adapted from https://github.com/qinzheng93/GeoTransformer 8 | 9 | ''' 10 | 11 | def __init__(self, num_iterations, inf=1e12): 12 | r"""Sinkhorn Optimal transport with dustbin parameter (SuperGlue style).""" 13 | super(LearnableLogOptimalTransport, self).__init__() 14 | self.num_iterations = num_iterations 15 | self.register_parameter('alpha', torch.nn.Parameter(torch.tensor(1.0))) 16 | self.inf = inf 17 | 18 | def log_sinkhorn_normalization(self, scores, log_mu, log_nu): 19 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 20 | for _ in range(self.num_iterations): 21 | u = log_mu - torch.logsumexp(scores + v.unsqueeze(1), dim=2) 22 | v = log_nu - torch.logsumexp(scores + u.unsqueeze(2), dim=1) 23 | return scores + u.unsqueeze(2) + v.unsqueeze(1) 24 | 25 | def forward(self, scores, row_masks=None, col_masks=None): 26 | r"""Sinkhorn Optimal Transport (SuperGlue style) forward.
27 | 28 | Args: 29 | scores: torch.Tensor (B, M, N) 30 | row_masks: torch.Tensor (B, M) 31 | col_masks: torch.Tensor (B, N) 32 | 33 | Returns: 34 | matching_scores: torch.Tensor (B, M+1, N+1) 35 | """ 36 | batch_size, num_row, num_col = scores.shape 37 | 38 | if row_masks is None: 39 | row_masks = torch.ones(size=(batch_size, num_row), dtype=torch.bool).cuda() 40 | if col_masks is None: 41 | col_masks = torch.ones(size=(batch_size, num_col), dtype=torch.bool).cuda() 42 | 43 | padded_row_masks = torch.zeros(size=(batch_size, num_row + 1), dtype=torch.bool).cuda() 44 | padded_row_masks[:, :num_row] = ~row_masks 45 | padded_col_masks = torch.zeros(size=(batch_size, num_col + 1), dtype=torch.bool).cuda() 46 | padded_col_masks[:, :num_col] = ~col_masks 47 | padded_score_masks = torch.logical_or(padded_row_masks.unsqueeze(2), padded_col_masks.unsqueeze(1)) 48 | 49 | padded_col = self.alpha.expand(batch_size, num_row, 1) 50 | padded_row = self.alpha.expand(batch_size, 1, num_col + 1) 51 | padded_scores = torch.cat([torch.cat([scores, padded_col], dim=-1), padded_row], dim=1) 52 | padded_scores.masked_fill_(padded_score_masks, -self.inf) 53 | 54 | num_valid_row = row_masks.float().sum(1) 55 | num_valid_col = col_masks.float().sum(1) 56 | norm = -torch.log(num_valid_row + num_valid_col) # (B,) 57 | 58 | log_mu = torch.empty(size=(batch_size, num_row + 1)).cuda() 59 | log_mu[:, :num_row] = norm.unsqueeze(1) 60 | log_mu[:, num_row] = torch.log(num_valid_col) + norm 61 | log_mu[padded_row_masks] = -self.inf 62 | 63 | log_nu = torch.empty(size=(batch_size, num_col + 1)).cuda() 64 | log_nu[:, :num_col] = norm.unsqueeze(1) 65 | log_nu[:, num_col] = torch.log(num_valid_row) + norm 66 | log_nu[padded_col_masks] = -self.inf 67 | 68 | outputs = self.log_sinkhorn_normalization(padded_scores, log_mu, log_nu) 69 | outputs = outputs - norm.unsqueeze(1).unsqueeze(2) 70 | 71 | return outputs 72 | 73 | def __repr__(self): 74 | format_string = self.__class__.__name__ + 
'(num_iterations={})'.format(self.num_iterations) 75 | return format_string 76 | -------------------------------------------------------------------------------- /sgreg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from sgreg.ops.transformation import( 2 | apply_transform, 3 | apply_rotation, 4 | inverse_transform, 5 | skew_symmetric_matrix, 6 | rodrigues_rotation_matrix, 7 | rodrigues_alignment_matrix, 8 | get_transform_from_rotation_translation, 9 | get_rotation_translation_from_transform, 10 | ) 11 | 12 | from sgreg.ops.instance_partition import( 13 | point_to_instance_partition, 14 | instance_f_points_batch, 15 | sample_instance_from_points, 16 | sample_all_instance_points, 17 | ) 18 | 19 | from sgreg.ops.index_select import index_select 20 | 21 | from sgreg.ops.pairwise_distance import pairwise_distance 22 | 23 | from sgreg.ops.grid_subsample import grid_subsample 24 | from sgreg.ops.radius_search import radius_search -------------------------------------------------------------------------------- /sgreg/ops/grid_subsample.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | ext_module = importlib.import_module('ext') 5 | 6 | 7 | def grid_subsample(points, insts, lengths, voxel_size): 8 | """Grid subsampling in stack mode. 9 | 10 | This function is implemented on CPU. 11 | 12 | Args: 13 | points (Tensor): stacked points. (N, 3) 14 | lengths (Tensor): number of points in the stacked batch. (B,) 15 | voxel_size (float): voxel size. 16 | 17 | Returns: 18 | s_points (Tensor): stacked subsampled points (M, 3) 19 | s_lengths (Tensor): numbers of subsampled points in the batch. 
(B,) 20 | """ 21 | s_points, s_insts, s_lengths = ext_module.grid_subsampling(points, insts, lengths, voxel_size) 22 | return s_points, s_insts, s_lengths 23 | -------------------------------------------------------------------------------- /sgreg/ops/index_select.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | @torch.jit.script 4 | def index_select(data: torch.Tensor, index: torch.LongTensor, dim: int) -> torch.Tensor: 5 | r"""Advanced index select. 6 | 7 | Returns a tensor `output` which indexes the `data` tensor along dimension `dim` 8 | using the entries in `index` which is a `LongTensor`. 9 | 10 | Different from `torch.index_select`, `index` does not have to be 1-D. The `dim`-th 11 | dimension of `data` will be expanded to the number of dimensions in `index`. 12 | 13 | For example, suppose the shape of `data` is $(a_0, a_1, ..., a_{n-1})$, the shape of `index` is 14 | $(b_0, b_1, ..., b_{m-1})$, and `dim` is $i$, then `output` is an $(n+m-1)$-d tensor, whose shape is 15 | $(a_0, ..., a_{i-1}, b_0, b_1, ..., b_{m-1}, a_{i+1}, ..., a_{n-1})$.
16 | 17 | Args: 18 | data (Tensor): (a_0, a_1, ..., a_{n-1}) 19 | index (LongTensor): (b_0, b_1, ..., b_{m-1}) 20 | dim: int 21 | 22 | Returns: 23 | output (Tensor): (a_0, ..., a_{dim-1}, b_0, ..., b_{m-1}, a_{dim+1}, ..., a_{n-1}) 24 | """ 25 | output = data.index_select(dim, index.view(-1)) 26 | 27 | if index.ndim > 1: 28 | output_shape = data.shape[:dim] + index.shape + data.shape[dim:][1:] 29 | output = output.view(output_shape) 30 | # passing the List[int] shape directly compiles under TorchScript (unlike view(*output_shape)) and handles any output rank 31 | 32 | return output 33 | -------------------------------------------------------------------------------- /sgreg/ops/instance_partition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | @torch.no_grad() 4 | def point_to_instance_partition( 5 | points: torch.Tensor, 6 | points_instance: torch.Tensor, 7 | instances: torch.Tensor, 8 | point_limit: int = 128, 9 | ): 10 | ''' 11 | Sample up to point_limit fine points for each instance; unused slots are filled with N. 12 | Input: 13 | - points: (N, 3), fine-level points 14 | - points_instance: (N,), instance index of each fine point 15 | - instances: (M,), instance indices 16 | ''' 17 | 18 | assert points.shape[0] == points_instance.shape[0] 19 | M = instances.shape[0] 20 | instance_points_masks = instances.unsqueeze(1).repeat(1, points.shape[0]) - points_instance.unsqueeze(0).repeat(M,1) # (M, N) 21 | instance_points_masks = instance_points_masks == 0 22 | instance_points_count = instance_points_masks.sum(dim=1) 23 | 24 | instance_masks = instance_points_count> point_limit # (M,) 25 | assert torch.all(instance_points_count>1), 'Some instances have fewer than two fine points.'
26 | 27 | instance_knn_indices = torch.tensor(points.shape[0]).repeat((M, point_limit)).to(points.device) # (M, K) 28 | small_instances = torch.where(instance_masks==False)[0] 29 | for idx in small_instances: 30 | f_points_count = instance_points_masks[idx,:].sum() 31 | instance_knn_indices[idx,:f_points_count] = torch.where(instance_points_masks[idx,:])[0] 32 | instance_knn_indices[instance_masks] = torch.multinomial(instance_points_masks[instance_masks].float(), point_limit) # (M, K) 33 | instance_knn_masks = instance_knn_indices < points.shape[0] # (M, K) 34 | instance_knn_count = instance_knn_masks.sum(1) # (M,) 35 | 36 | assert instance_knn_indices.min() < points.shape[0], 'Some instances are not assigned to any points.' 37 | 38 | return instance_knn_indices.to(torch.int64), instance_knn_masks, instance_knn_count 39 | 40 | @torch.no_grad() 41 | def instance_f_points_batch( 42 | points_instance: torch.Tensor, 43 | instances_list: torch.Tensor, 44 | point_limit: int = 1024, 45 | ): 46 | ''' 47 | Input: 48 | - points_instance: (N,), instance index of each fine point 49 | - instances_list: (M,), instance indices 50 | Output: 51 | - instance_f_points_indices: (M, K), fine points indices sampled with replacement (K = point_limit) 52 | ''' 53 | M = instances_list.shape[0] 54 | # instance_count = torch.histogram(points_instance.clone().detach().cpu(), 55 | # bins=M,range=(instances_list.min().item(),instances_list.max().item()+1))[0] 56 | # K = instance_count.max().int().item() 57 | K = point_limit 58 | instance_f_points_indices = torch.zeros((M, K), dtype=torch.int64).cuda() 59 | 60 | for id in instances_list: 61 | instance_f_points_masks = points_instance == id # (N,) 62 | count = instance_f_points_masks.sum() 63 | assert count>1, 'An instance has fewer than two fine points.'
64 | instance_f_points_indices[id] = torch.multinomial(instance_f_points_masks.float(), K, replacement=True) 65 | 66 | # instance_f_points_indices[id][:count] = torch.where(instance_f_points_masks)[0] 67 | # if count1), 'Some instances are not assigned to any fine points.' 151 | # print(instance_points_count.clone().detach().cpu().numpy()) 152 | for id in torch.arange(instances_number): 153 | # mask_count = instance_points_count[id]1), 'Some instances are not assigned to any fine points.' 178 | 179 | for scene_id in torch.arange(instances_number): 180 | sampling_number = max(int(instance_points_count[scene_id]*sample_ratio),1) 181 | valid_indices = torch.multinomial( 182 | instance_points_mask[scene_id].float(), min(sampling_number, K)) 183 | out_instance_f_points_indices[scene_id] = instance_f_points_indices[scene_id][valid_indices] 184 | 185 | return out_instance_f_points_indices 186 | 187 | def extract_instance_f_feats( 188 | fused_feats_dict:dict, 189 | instances_knn_dict: dict, 190 | batch_size: int, 191 | ): 192 | ''' 193 | replace the f_feats in instances_knn_dict with the fused f_feats 194 | 195 | fused_feats_dict: 196 | - features: (P, C), fine-level features 197 | - features_batch: (B+1,), features batch ranges 198 | instances_knn_dict: 199 | - instances_batch: (N, 1), instance batch 200 | - instances_f_indices: (N, T), fine points indices 201 | ''' 202 | features = fused_feats_dict['feats_f'] 203 | features_batch = fused_feats_dict['feats_f_batch'] 204 | fused_instance_knn_feats = [] # (N, T, C) 205 | 206 | for scene_id in torch.arange(batch_size): 207 | scene_mask = instances_knn_dict['instances_batch']==scene_id 208 | instance_knn_indices = instances_knn_dict['instances_f_indices'][scene_mask,:] # (N_i, T) 209 | scene_features = features[features_batch[scene_id]:features_batch[scene_id+1],:] # (P_i, C) 210 | padded_scene_features = torch.cat( 211 | [scene_features, torch.zeros((1, scene_features.shape[1])).to(scene_features.device)], dim=0) # (P_i+1, 
C) 212 | 213 | # instance_valid = instance_knn_indices torch.Tensor: 7 | r"""Pairwise distance of two (batched) point clouds. 8 | 9 | Args: 10 | x (Tensor): (*, N, C) or (*, C, N) 11 | y (Tensor): (*, M, C) or (*, C, M) 12 | normalized (bool=False): if the points are normalized, we have "x2 = y2 = 1", so "d2 = 2 - 2xy". 13 | channel_first (bool=False): if True, the points shape is (*, C, N). 14 | 15 | Returns: 16 | sq_distances: torch.Tensor (*, N, M), squared pairwise distances 17 | """ 18 | if channel_first: 19 | channel_dim = -2 20 | xy = torch.matmul(x.transpose(-1, -2), y) # [(*, C, N) -> (*, N, C)] x (*, C, M) 21 | else: 22 | channel_dim = -1 23 | xy = torch.matmul(x, y.transpose(-1, -2)) # (*, N, C) x [(*, M, C) -> (*, C, M)] 24 | if normalized: 25 | sq_distances = 2.0 - 2.0 * xy 26 | else: 27 | x2 = torch.sum(x ** 2, dim=channel_dim).unsqueeze(-1) # (*, N, C) or (*, C, N) -> (*, N) -> (*, N, 1) 28 | y2 = torch.sum(y ** 2, dim=channel_dim).unsqueeze(-2) # (*, M, C) or (*, C, M) -> (*, M) -> (*, 1, M) 29 | sq_distances = x2 - 2 * xy + y2 30 | sq_distances = sq_distances.clamp(min=0.0) 31 | return sq_distances 32 | -------------------------------------------------------------------------------- /sgreg/ops/radius_search.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | ext_module = importlib.import_module('ext') 5 | 6 | 7 | def radius_search(q_points, s_points, q_lengths, s_lengths, radius, neighbor_limit): 8 | r"""Find the neighbors of each q_point among s_points within radius (batched radius search in stack mode). 9 | 10 | This function is implemented on CPU.
11 | 12 | Args: 13 | q_points (Tensor): the query points (N, 3) 14 | s_points (Tensor): the support points (M, 3) 15 | q_lengths (Tensor): the list of lengths of batch elements in q_points 16 | s_lengths (Tensor): the list of lengths of batch elements in s_points 17 | radius (float): maximum distance of neighbors 18 | neighbor_limit (int): maximum number of neighbors kept per query (no limit if <= 0) 19 | 20 | Returns: 21 | neighbors (Tensor): indices of the neighbors of q_points in s_points within radius (N, k). 22 | Filled with M if there are fewer than k neighbors. 23 | """ 24 | neighbor_indices = ext_module.radius_neighbors(q_points, s_points, q_lengths, s_lengths, radius) 25 | if neighbor_limit > 0: 26 | neighbor_indices = neighbor_indices[:, :neighbor_limit] 27 | return neighbor_indices 28 | -------------------------------------------------------------------------------- /sgreg/ops/transformation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def apply_transform(points: torch.Tensor, transform: torch.Tensor, normals: Optional[torch.Tensor] = None): 8 | r"""Apply a rigid transform to points and (optionally) normals. 9 | 10 | Given a point cloud P(3, N), normals V(3, N) and a transform matrix T in the form of 11 | | R t | 12 | | 0 1 |, 13 | the output point cloud Q = RP + t, V' = RV. 14 | 15 | In the implementation, P and V are (N, 3), so R should be transposed: Q = PR^T + t, V' = VR^T. 16 | 17 | There are two cases supported: 18 | 1. points and normals are (*, 3), transform is (4, 4), the output points are (*, 3). 19 | In this case, the transform is applied to all points. 20 | 2. points and normals are (B, N, 3), transform is (B, 4, 4), the output points are (B, N, 3). 21 | In this case, the transform is applied batch-wise. The points can be broadcast if B=1. 22 | 23 | Args: 24 | points (Tensor): (*, 3) or (B, N, 3) 25 | normals (optional[Tensor]=None): same shape as points.
26 | transform (Tensor): (4, 4) or (B, 4, 4) 27 | 28 | Returns: 29 | points (Tensor): same shape as points. 30 | normals (Tensor, optional): same shape as points; returned only if normals is given. 31 | """ 32 | if normals is not None: 33 | assert points.shape == normals.shape 34 | if transform.ndim == 2: 35 | rotation = transform[:3, :3] 36 | translation = transform[:3, 3] 37 | points_shape = points.shape 38 | points = points.reshape(-1, 3) 39 | points = torch.matmul(points, rotation.transpose(-1, -2)) + translation 40 | points = points.reshape(*points_shape) 41 | if normals is not None: 42 | normals = normals.reshape(-1, 3) 43 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 44 | normals = normals.reshape(*points_shape) 45 | elif transform.ndim == 3 and points.ndim == 3: 46 | rotation = transform[:, :3, :3] # (B, 3, 3) 47 | translation = transform[:, None, :3, 3] # (B, 1, 3) 48 | points = torch.matmul(points, rotation.transpose(-1, -2)) + translation 49 | if normals is not None: 50 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 51 | else: 52 | raise ValueError( 53 | 'Incompatible shapes between points {} and transform {}.'.format( 54 | tuple(points.shape), tuple(transform.shape) 55 | ) 56 | ) 57 | if normals is not None: 58 | return points, normals 59 | else: 60 | return points 61 | 62 | 63 | def apply_rotation(points: torch.Tensor, rotation: torch.Tensor, normals: Optional[torch.Tensor] = None): 64 | r"""Rotate points and (optionally) normals about the origin. 65 | 66 | Given a point cloud P(3, N), normals V(3, N) and a rotation matrix R, the output point cloud Q = RP, V' = RV. 67 | 68 | In the implementation, P and V are (N, 3), so R should be transposed: Q = PR^T, V' = VR^T. 69 | 70 | There are two cases supported: 71 | 1. points and normals are (*, 3), rotation is (3, 3), the output points are (*, 3). 72 | In this case, the rotation is applied to all points. 73 | 2. points and normals are (B, N, 3), rotation is (B, 3, 3), the output points are (B, N, 3).
74 | In this case, the rotation is applied batch-wise. The points can be broadcast if B=1. 75 | 76 | Args: 77 | points (Tensor): (*, 3) or (B, N, 3) 78 | normals (optional[Tensor]=None): same shape as points. 79 | rotation (Tensor): (3, 3) or (B, 3, 3) 80 | 81 | Returns: 82 | points (Tensor): same shape as points. 83 | normals (Tensor, optional): same shape as points; returned only if normals is given. 84 | """ 85 | if normals is not None: 86 | assert points.shape == normals.shape 87 | if rotation.ndim == 2: 88 | points_shape = points.shape 89 | points = points.reshape(-1, 3) 90 | points = torch.matmul(points, rotation.transpose(-1, -2)) 91 | points = points.reshape(*points_shape) 92 | if normals is not None: 93 | normals = normals.reshape(-1, 3) 94 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 95 | normals = normals.reshape(*points_shape) 96 | elif rotation.ndim == 3 and points.ndim == 3: 97 | points = torch.matmul(points, rotation.transpose(-1, -2)) 98 | if normals is not None: 99 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 100 | else: 101 | raise ValueError( 102 | 'Incompatible shapes between points {} and rotation {}.'.format(tuple(points.shape), tuple(rotation.shape)) 103 | ) 104 | if normals is not None: 105 | return points, normals 106 | else: 107 | return points 108 | 109 | 110 | def get_rotation_translation_from_transform(transform): 111 | r"""Decompose transformation matrix into rotation matrix and translation vector. 112 | 113 | Args: 114 | transform (Tensor): (*, 4, 4) 115 | 116 | Returns: 117 | rotation (Tensor): (*, 3, 3) 118 | translation (Tensor): (*, 3) 119 | """ 120 | rotation = transform[..., :3, :3] 121 | translation = transform[..., :3, 3] 122 | return rotation, translation 123 | 124 | 125 | def get_transform_from_rotation_translation(rotation, translation): 126 | r"""Compose transformation matrix from rotation matrix and translation vector.
127 | 128 | Args: 129 | rotation (Tensor): (*, 3, 3) 130 | translation (Tensor): (*, 3) 131 | 132 | Returns: 133 | transform (Tensor): (*, 4, 4) 134 | """ 135 | input_shape = rotation.shape 136 | rotation = rotation.view(-1, 3, 3) 137 | translation = translation.view(-1, 3) 138 | transform = torch.eye(4).to(rotation).unsqueeze(0).repeat(rotation.shape[0], 1, 1) 139 | transform[:, :3, :3] = rotation 140 | transform[:, :3, 3] = translation 141 | output_shape = input_shape[:-2] + (4, 4) 142 | transform = transform.view(*output_shape) 143 | return transform 144 | 145 | 146 | def inverse_transform(transform): 147 | r"""Inverse rigid transform. 148 | 149 | Args: 150 | transform (Tensor): (*, 4, 4) 151 | 152 | Returns: 153 | inv_transform (Tensor): (*, 4, 4) 154 | """ 155 | rotation, translation = get_rotation_translation_from_transform(transform) # (*, 3, 3), (*, 3) 156 | inv_rotation = rotation.transpose(-1, -2) # (*, 3, 3) 157 | inv_translation = -torch.matmul(inv_rotation, translation.unsqueeze(-1)).squeeze(-1) # (*, 3) 158 | inv_transform = get_transform_from_rotation_translation(inv_rotation, inv_translation) # (*, 4, 4) 159 | return inv_transform 160 | 161 | 162 | def skew_symmetric_matrix(inputs): 163 | r"""Compute Skew-symmetric Matrix. 164 | 165 | [v]_{\times} = 0 -z y 166 | z 0 -x 167 | -y x 0 168 | 169 | Args: 170 | inputs (Tensor): input vectors (*, 3) 171 | 172 | Returns: 173 | skews (Tensor): output skew-symmetric matrix (*, 3, 3) 174 | """ 175 | input_shape = inputs.shape 176 | output_shape = input_shape[:-1] + (3, 3) 177 | skews = torch.zeros(size=output_shape, dtype=inputs.dtype, device=inputs.device) 178 | skews[..., 0, 1] = -inputs[..., 2] 179 | skews[..., 0, 2] = inputs[..., 1] 180 | skews[..., 1, 0] = inputs[..., 2] 181 | skews[..., 1, 2] = -inputs[..., 0] 182 | skews[..., 2, 0] = -inputs[..., 1] 183 | skews[..., 2, 1] = inputs[..., 0] 184 | return skews 185 | 186 | 187 | def rodrigues_rotation_matrix(axes, angles): 188 | r"""Compute Rodrigues Rotation Matrix.
189 | 190 | R = I + \sin{\theta} K + (1 - \cos{\theta}) K^2, 191 | where K is the skew-symmetric matrix of the axis vector. 192 | 193 | Args: 194 | axes (Tensor): axis vectors (*, 3) 195 | angles (Tensor): rotation angles in radians, right-hand rule. (*) 196 | 197 | Returns: 198 | rotations (Tensor): Rodrigues rotation matrix (*, 3, 3) 199 | """ 200 | input_shape = axes.shape 201 | axes = axes.view(-1, 3) 202 | angles = angles.view(-1) 203 | axes = F.normalize(axes, p=2, dim=1) 204 | skews = skew_symmetric_matrix(axes) # (B, 3, 3) 205 | sin_values = torch.sin(angles).view(-1, 1, 1) # (B, 1, 1) 206 | cos_values = torch.cos(angles).view(-1, 1, 1) # (B, 1, 1) 207 | eyes = torch.eye(3).to(skews).unsqueeze(0).expand_as(skews) # (B, 3, 3), same dtype/device as skews 208 | rotations = eyes + sin_values * skews + (1.0 - cos_values) * torch.matmul(skews, skews) 209 | output_shape = input_shape[:-1] + (3, 3) 210 | rotations = rotations.view(*output_shape) 211 | return rotations 212 | 213 | 214 | def rodrigues_alignment_matrix(src_vectors, tgt_vectors): 215 | r"""Compute the Rodrigues rotation matrix aligning source vectors to target vectors.
216 | 217 | Args: 218 | src_vectors (Tensor): source vectors (*, 3) 219 | tgt_vectors (Tensor): target vectors (*, 3) 220 | 221 | Returns: 222 | rotations (Tensor): rotation matrix (*, 3, 3) 223 | """ 224 | input_shape = src_vectors.shape 225 | src_vectors = src_vectors.view(-1, 3) # (B, 3) 226 | tgt_vectors = tgt_vectors.view(-1, 3) # (B, 3) 227 | 228 | # compute axes 229 | src_vectors = F.normalize(src_vectors, dim=-1, p=2) # (B, 3) 230 | tgt_vectors = F.normalize(tgt_vectors, dim=-1, p=2) # (B, 3) 231 | src_skews = skew_symmetric_matrix(src_vectors) # (B, 3, 3) 232 | axes = torch.matmul(src_skews, tgt_vectors.unsqueeze(-1)).squeeze(-1) # (B, 3) 233 | 234 | # compute rodrigues rotation matrix 235 | sin_values = torch.linalg.norm(axes, dim=-1) # (B,) 236 | cos_values = (src_vectors * tgt_vectors).sum(dim=-1) # (B,) 237 | axes = F.normalize(axes, dim=-1, p=2) # (B, 3) 238 | skews = skew_symmetric_matrix(axes) # (B, 3, 3) 239 | eyes = torch.eye(3).to(skews).unsqueeze(0).expand_as(skews) # (B, 3, 3), same dtype and device as skews 240 | sin_values = sin_values.view(-1, 1, 1) 241 | cos_values = cos_values.view(-1, 1, 1) 242 | rotations = eyes + sin_values * skews + (1.0 - cos_values) * torch.matmul(skews, skews) 243 | 244 | # handle opposite direction (src == -tgt); note that -I maps src onto tgt but is a reflection (det = -1), not a proper rotation 245 | sin_values = sin_values.view(-1) 246 | cos_values = cos_values.view(-1) 247 | masks = torch.logical_and(torch.eq(sin_values, 0.0), torch.lt(cos_values, 0.0)) 248 | rotations[masks] *= -1 249 | 250 | output_shape = input_shape[:-1] + (3, 3) 251 | rotations = rotations.view(*output_shape) 252 | 253 | return rotations 254 | -------------------------------------------------------------------------------- /sgreg/registration/__init__.py: -------------------------------------------------------------------------------- 1 | from sgreg.registration.procrustes import( 2 | weighted_procrustes, 3 | WeightedProcrustes 4 | ) -------------------------------------------------------------------------------- /sgreg/registration/hybrid_reg.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import teaserpp_python 3 | import open3d as o3d 4 | from scipy.spatial import KDTree 5 | from copy import deepcopy 6 | import pygicp 7 | 8 | 9 | class HybridReg: 10 | def __init__( 11 | self, 12 | src_pcd: o3d.geometry.PointCloud, 13 | tgt_pcd: o3d.geometry.PointCloud, 14 | refine=None, # 'icp' or 'vgicp' 15 | use_pagor=False, 16 | only_yaw=False, 17 | ins_wise=False, 18 | max_ins=256, 19 | max_pts=256, 20 | ): 21 | self.src_pcd = src_pcd 22 | self.tgt_pcd = tgt_pcd 23 | self.refine = refine 24 | self.use_pagor = use_pagor 25 | self.only_yaw = only_yaw 26 | self.ins_wise = ins_wise 27 | self.instance_corres = [] 28 | self.max_inst = max_ins 29 | self.max_points = max_pts 30 | 31 | voxel_size = 0.05 32 | self.src_pcd = self.src_pcd.voxel_down_sample(voxel_size) 33 | self.tgt_pcd = self.tgt_pcd.voxel_down_sample(voxel_size) 34 | if self.refine == "icp": 35 | if np.array(self.tgt_pcd.normals).shape[0] == 0: 36 | self.tgt_pcd.estimate_normals( 37 | o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=100) 38 | ) 39 | 40 | if self.use_pagor or self.ins_wise: 41 | tgt_points = np.asarray(self.tgt_pcd.points) 42 | self.tgt_kdtree = KDTree(tgt_points) 43 | if self.use_pagor: 44 | self.noise_bounds = [0.1, 0.2, 0.3] 45 | else: 46 | self.noise_bounds = [0.2] 47 | 48 | def set_model_pred(self, model_pred: dict): 49 | 50 | corres_scores = model_pred["corres_scores"] # (C,) 51 | score_mask = corres_scores > 1e-3 52 | corres_src = model_pred["corres_src_points"][score_mask] # (C,3) 53 | corres_ref = model_pred["corres_ref_points"][score_mask] # (C,3) 54 | corres_instances = model_pred["corres_instances"][ 55 | score_mask 56 | ] # (C,) values in [0,I-1] 57 | 58 | # pred_nodes = model_pred["pred_nodes"] # (I,) 59 | pred_scores = model_pred["pred_scores"] 60 | corres_src_centroids = model_pred["corres_src_centroids"] # (I,3) 61 | corres_ref_centroids =
model_pred["corres_ref_centroids"] # (I,3) 62 | if pred_scores.shape[0] > self.max_inst: 63 | inst_pair_indices = np.argsort(pred_scores)[ 64 | ::-1 65 | ] # (I,), descending order according to pred_scores 66 | inst_pair_indices = inst_pair_indices[: self.max_inst] 67 | else: 68 | inst_pair_indices = np.arange(pred_scores.shape[0]) 69 | 70 | instance_corres = [] 71 | for i in inst_pair_indices: 72 | pred_score = pred_scores[i] 73 | corres_src_centroid = corres_src_centroids[i] 74 | corres_ref_centroid = corres_ref_centroids[i] 75 | 76 | matching = { 77 | "centorid_corr": np.hstack([corres_src_centroid, corres_ref_centroid]), 78 | "score": pred_score, 79 | } 80 | ins_mask = corres_instances == i 81 | if ins_mask.sum() > self.max_points: # number of correspondences in this instance 82 | descending_indices = np.argsort(corres_scores[score_mask][ins_mask])[::-1] 83 | points_src = corres_src[ins_mask][descending_indices[: self.max_points]] 84 | points_ref = corres_ref[ins_mask][descending_indices[: self.max_points]] 85 | point_corr = np.hstack([points_src, points_ref]) 86 | elif ins_mask.sum() > 0: 87 | point_corr = np.hstack([corres_src[ins_mask], corres_ref[ins_mask]]) 88 | else: 89 | point_corr = np.empty((0, 6)) 90 | matching["point_corr"] = point_corr 91 | instance_corres.append(matching) 92 | 93 | self.instance_corres = instance_corres 94 | return instance_corres 95 | 96 | def solve(self): 97 | if len(self.instance_corres) == 0: 98 | return np.eye(4) 99 | 100 | if self.ins_wise: 101 | # Handle each instance individually 102 | A_corr_list, B_corr_list = self.front_end() 103 | tf_candidates = [] 104 | for A_corr, B_corr in zip(A_corr_list, B_corr_list): 105 | tf = self.back_end(A_corr, B_corr) 106 | tf_candidates.append(tf) 107 | 108 | scores = [self.chamfer_score(tf) for tf in tf_candidates] 109 | tf = tf_candidates[np.argmax(scores)] 110 | else: 111 | A_corr, B_corr = self.front_end() 112 | tf = self.back_end(A_corr, B_corr) 113 | 114 | return tf 115 | 116 | def front_end(self): 117 | # Extract the correspondences of each instance 118 | A_corr_list,
B_corr_list = [], [] 119 | for ins_corr in self.instance_corres: 120 | # Load point correspondences 121 | A_corr = ins_corr["point_corr"][:, :3].T 122 | A_centroid = ins_corr["centorid_corr"][:3].reshape(3, 1) 123 | A_corr = np.hstack([A_corr, A_centroid]) 124 | A_corr_list.append(A_corr) 125 | 126 | B_corr = ins_corr["point_corr"][:, 3:].T 127 | B_centroid = ins_corr["centorid_corr"][3:].reshape(3, 1) 128 | B_corr = np.hstack([B_corr, B_centroid]) 129 | B_corr_list.append(B_corr) 130 | if self.ins_wise: 131 | return A_corr_list, B_corr_list 132 | else: 133 | # Merge correspondences 134 | A_corr = np.hstack(A_corr_list) 135 | B_corr = np.hstack(B_corr_list) 136 | return A_corr, B_corr 137 | 138 | def back_end(self, A_corr, B_corr): 139 | 140 | tf_candidates = [] 141 | for noise_bound in self.noise_bounds: 142 | tf = self.solve_by_teaser(A_corr, B_corr, noise_bound) 143 | if tf_candidates == []: 144 | tf_candidates.append(tf) 145 | else: 146 | similar = [ 147 | np.linalg.norm(tf_candidate[:3, 3] - tf[:3, 3]) < 0.1 148 | for tf_candidate in tf_candidates 149 | ] 150 | if not any(similar): 151 | tf_candidates.append(tf) 152 | 153 | # Verification 154 | if len(tf_candidates) > 1: 155 | scores = [self.chamfer_score(tf) for tf in tf_candidates] 156 | tf = tf_candidates[np.argmax(scores)] 157 | else: 158 | tf = tf_candidates[0] 159 | 160 | return tf 161 | 162 | def chamfer_score(self, tf): 163 | src_pcd = deepcopy(self.src_pcd) 164 | src_pcd.transform(tf) 165 | src_points = np.asarray(src_pcd.points) 166 | nearest_dists = self.tgt_kdtree.query(src_points, k=1)[0] 167 | nearest_dists = np.clip(nearest_dists, 0, 0.5) 168 | return -np.mean(nearest_dists) 169 | 170 | def solve_by_teaser(self, A_corr, B_corr, noise_bound): 171 | 172 | teaser_solver = self.get_teaser_solver(noise_bound=noise_bound) 173 | teaser_solver.solve(A_corr, B_corr) 174 | solution = teaser_solver.getSolution() 175 | tf = np.identity(4) 176 | tf[:3, :3] = solution.rotation 177 | tf[:3, 3] = 
solution.translation 178 | 179 | if self.refine is not None: 180 | if self.refine == "vgicp": 181 | source = np.asarray(self.src_pcd.points) 182 | target = np.asarray(self.tgt_pcd.points) 183 | tf = pygicp.align_points( 184 | target, 185 | source, 186 | initial_guess=tf, 187 | max_correspondence_distance=1.0, 188 | voxel_resolution=0.5, 189 | method="VGICP", 190 | ) 191 | else: 192 | loss = o3d.pipelines.registration.GMLoss(k=0.1) 193 | p2l = o3d.pipelines.registration.TransformationEstimationPointToPlane( 194 | loss 195 | ) 196 | reg_p2l = o3d.pipelines.registration.registration_icp( 197 | self.src_pcd, 198 | self.tgt_pcd, 199 | 0.5, 200 | tf, 201 | p2l, 202 | ) 203 | tf = reg_p2l.transformation 204 | 205 | return tf 206 | 207 | def get_teaser_solver(self, noise_bound): 208 | solver_params = teaserpp_python.RobustRegistrationSolver.Params() 209 | solver_params.cbar2 = 1.0 210 | solver_params.noise_bound = noise_bound 211 | solver_params.estimate_scaling = False 212 | solver_params.inlier_selection_mode = ( 213 | teaserpp_python.RobustRegistrationSolver.INLIER_SELECTION_MODE.PMC_EXACT 214 | ) 215 | solver_params.rotation_tim_graph = ( 216 | teaserpp_python.RobustRegistrationSolver.INLIER_GRAPH_FORMULATION.CHAIN 217 | ) 218 | if self.only_yaw: 219 | solver_params.rotation_estimation_algorithm = ( 220 | teaserpp_python.RobustRegistrationSolver.ROTATION_ESTIMATION_ALGORITHM.QUATRO 221 | ) 222 | else: 223 | solver_params.rotation_estimation_algorithm = ( 224 | teaserpp_python.RobustRegistrationSolver.ROTATION_ESTIMATION_ALGORITHM.GNC_TLS 225 | ) 226 | solver_params.rotation_gnc_factor = 1.4 227 | solver_params.rotation_max_iterations = 10000 228 | solver_params.rotation_cost_threshold = 1e-16 229 | solver = teaserpp_python.RobustRegistrationSolver(solver_params) 230 | return solver 231 | -------------------------------------------------------------------------------- /sgreg/registration/metrics.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from sgreg.ops.transformation import apply_transform, get_rotation_translation_from_transform 5 | from sgreg.ops.pairwise_distance import pairwise_distance 6 | # from geotransformer.utils.registration import compute_transform_mse_and_mae 7 | 8 | 9 | def modified_chamfer_distance(raw_points, ref_points, src_points, gt_transform, transform, reduction='mean'): 10 | r"""Compute the modified chamfer distance. 11 | 12 | Args: 13 | raw_points (Tensor): (B, N_raw, 3) 14 | ref_points (Tensor): (B, N_ref, 3) 15 | src_points (Tensor): (B, N_src, 3) 16 | gt_transform (Tensor): (B, 4, 4) 17 | transform (Tensor): (B, 4, 4) 18 | reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 19 | 20 | Returns: 21 | chamfer_distance 22 | """ 23 | assert reduction in ['mean', 'sum', 'none'] 24 | 25 | # P_t -> Q_raw 26 | aligned_src_points = apply_transform(src_points, transform) # (B, N_src, 3) 27 | sq_dist_mat_p_q = pairwise_distance(aligned_src_points, raw_points) # (B, N_src, N_raw) 28 | nn_sq_distances_p_q = sq_dist_mat_p_q.min(dim=-1)[0] # (B, N_src) 29 | chamfer_distance_p_q = torch.sqrt(nn_sq_distances_p_q).mean(dim=-1) # (B) 30 | 31 | # Q -> P_raw 32 | composed_transform = torch.matmul(transform, torch.inverse(gt_transform)) # (B, 4, 4) 33 | aligned_raw_points = apply_transform(raw_points, composed_transform) # (B, N_raw, 3) 34 | sq_dist_mat_q_p = pairwise_distance(ref_points, aligned_raw_points) # (B, N_ref, N_raw) 35 | nn_sq_distances_q_p = sq_dist_mat_q_p.min(dim=-1)[0] # (B, N_ref) 36 | chamfer_distance_q_p = torch.sqrt(nn_sq_distances_q_p).mean(dim=-1) # (B) 37 | 38 | # sum up 39 | chamfer_distance = chamfer_distance_p_q + chamfer_distance_q_p # (B) 40 | 41 | if reduction == 'mean': 42 | chamfer_distance = chamfer_distance.mean() 43 | elif reduction == 'sum': 44 | chamfer_distance = chamfer_distance.sum() 45 | return chamfer_distance 46 | 
47 | 48 | def relative_rotation_error(gt_rotations, rotations): 49 | r"""Isotropic Relative Rotation Error. 50 | 51 | RRE = acos((trace(R^T \cdot \bar{R}) - 1) / 2) 52 | 53 | Args: 54 | gt_rotations (Tensor): ground truth rotation matrix (*, 3, 3) 55 | rotations (Tensor): estimated rotation matrix (*, 3, 3) 56 | 57 | Returns: 58 | rre (Tensor): relative rotation errors in degrees (*) 59 | """ 60 | mat = torch.matmul(rotations.transpose(-1, -2), gt_rotations) 61 | trace = mat[..., 0, 0] + mat[..., 1, 1] + mat[..., 2, 2] 62 | x = 0.5 * (trace - 1.0) 63 | x = x.clamp(min=-1.0, max=1.0) 64 | x = torch.arccos(x) 65 | rre = 180.0 * x / np.pi 66 | return rre 67 | 68 | 69 | def relative_translation_error(gt_translations, translations): 70 | r"""Isotropic Relative Translation Error. 71 | 72 | RTE = \lVert t - \bar{t} \rVert_2 73 | 74 | Args: 75 | gt_translations (Tensor): ground truth translation vector (*, 3) 76 | translations (Tensor): estimated translation vector (*, 3) 77 | 78 | Returns: 79 | rte (Tensor): relative translation errors (*) 80 | """ 81 | rte = torch.linalg.norm(gt_translations - translations, dim=-1) 82 | return rte 83 | 84 | 85 | def isotropic_transform_error(gt_transforms, transforms, reduction='mean'): 86 | r"""Compute the isotropic Relative Rotation Error and Relative Translation Error. 87 | 88 | Args: 89 | gt_transforms (Tensor): ground truth transformation matrix (*, 4, 4) 90 | transforms (Tensor): estimated transformation matrix (*, 4, 4) 91 | reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 92 | 93 | Returns: 94 | rre (Tensor): relative rotation error in degrees. 95 | rte (Tensor): relative translation error.
96 | """ 97 | assert reduction in ['mean', 'sum', 'none'] 98 | 99 | gt_rotations, gt_translations = get_rotation_translation_from_transform(gt_transforms) 100 | rotations, translations = get_rotation_translation_from_transform(transforms) 101 | 102 | rre = relative_rotation_error(gt_rotations, rotations) # (*) 103 | rte = relative_translation_error(gt_translations, translations) # (*) 104 | 105 | if reduction == 'mean': 106 | rre = rre.mean() 107 | rte = rte.mean() 108 | elif reduction == 'sum': 109 | rre = rre.sum() 110 | rte = rte.sum() 111 | 112 | return rre, rte 113 | 114 | 115 | # def anisotropic_transform_error(gt_transforms, transforms, reduction='mean'): 116 | # r"""Compute the anisotropic Relative Rotation Error and Relative Translation Error. 117 | 118 | # This function calls numpy-based implementation to achieve batch-wise computation and thus is non-differentiable. 119 | 120 | # Args: 121 | # gt_transforms (Tensor): ground truth transformation matrix (B, 4, 4) 122 | # transforms (Tensor): estimated transformation matrix (B, 4, 4) 123 | # reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 124 | 125 | # Returns: 126 | # r_mse (Tensor): rotation mse. 127 | # r_mae (Tensor): rotation mae. 128 | # t_mse (Tensor): translation mse. 129 | # t_mae (Tensor): translation mae. 
130 | # """ 131 | # assert reduction in ['mean', 'sum', 'none'] 132 | 133 | # batch_size = gt_transforms.shape[0] 134 | # gt_transforms_array = gt_transforms.detach().cpu().numpy() 135 | # transforms_array = transforms.detach().cpu().numpy() 136 | 137 | # all_r_mse = [] 138 | # all_r_mae = [] 139 | # all_t_mse = [] 140 | # all_t_mae = [] 141 | # for i in range(batch_size): 142 | # r_mse, r_mae, t_mse, t_mae = compute_transform_mse_and_mae(gt_transforms_array[i], transforms_array[i]) 143 | # all_r_mse.append(r_mse) 144 | # all_r_mae.append(r_mae) 145 | # all_t_mse.append(t_mse) 146 | # all_t_mae.append(t_mae) 147 | # r_mse = torch.as_tensor(all_r_mse).to(gt_transforms) 148 | # r_mae = torch.as_tensor(all_r_mae).to(gt_transforms) 149 | # t_mse = torch.as_tensor(all_t_mse).to(gt_transforms) 150 | # t_mae = torch.as_tensor(all_t_mae).to(gt_transforms) 151 | 152 | # if reduction == 'mean': 153 | # r_mse = r_mse.mean() 154 | # r_mae = r_mae.mean() 155 | # t_mse = t_mse.mean() 156 | # t_mae = t_mae.mean() 157 | # elif reduction == 'sum': 158 | # r_mse = r_mse.sum() 159 | # r_mae = r_mae.sum() 160 | # t_mse = t_mse.sum() 161 | # t_mae = t_mae.sum() 162 | 163 | # return r_mse, r_mae, t_mse, t_mae 164 | -------------------------------------------------------------------------------- /sgreg/registration/offline_registration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import open3d as o3d 5 | from omegaconf import OmegaConf 6 | 7 | from sgreg.loss.eval import Evaluator 8 | from sgreg.utils.utils import read_scan_pairs 9 | from sgreg.utils.io import read_pred_nodes, read_corr_scores 10 | 11 | from os.path import join as osp 12 | 13 | def read_points(dir:str): 14 | pcd = o3d.io.read_point_cloud(dir) 15 | points = np.asarray(pcd.points) # (N,3) 16 | return points 17 | 18 | if __name__=='__main__': 19 | print('*'*60) 20 | print('This script reads the data association from neural 
network.') 21 | print('It estimates a relative transformation between scene graphs.') 22 | print('Lastly, save and evaluate the results.') 23 | print('*'*60) 24 | 25 | ############ Args ############ 26 | DATAROOT = '/data2/RioGraph' 27 | SPLIT = 'val' 28 | RESULT_FOLDER = osp(DATAROOT, 'output', 'sgnet_init_layer2') 29 | ############################## 30 | 31 | cfg = OmegaConf.create({'eval': {'acceptance_overlap': 0.0, 32 | 'acceptance_radius': 0.1, 33 | 'rmse_threshold': 0.2}}) 34 | eval = Evaluator(cfg) 35 | summary_rmse = [] 36 | summary_recall = [] 37 | 38 | scene_pairs = read_scan_pairs(osp(DATAROOT, 'splits', SPLIT+'.txt')) 39 | print('Read {} scene pairs'.format(len(scene_pairs))) 40 | 41 | for scene_pair in scene_pairs: 42 | print('----------- {}-{} ------------'.format(scene_pair[0],scene_pair[1])) 43 | scene_result_folder = os.path.join(RESULT_FOLDER, '{}-{}'.format(scene_pair[0],scene_pair[1])) 44 | 45 | # 1. load src points, node_matches, and correspondence points 46 | src_points = read_points(osp(scene_result_folder,'src_instances.ply')) 47 | pred_nodes = read_pred_nodes(osp(scene_result_folder,'node_matches.txt')) # dict with 'src_centroids', 'ref_centroids', ... 48 | corr_src_points = read_points(osp(scene_result_folder,'corr_src.ply')) 49 | corr_ref_points = read_points(osp(scene_result_folder,'corr_ref.ply')) 50 | _, corr_scores, _ = read_corr_scores(osp(scene_result_folder,'point_matches.txt')) 51 | assert corr_src_points.shape[0] == corr_scores.shape[0] 52 | print('Read {} node matches and {} point corrs'.format(pred_nodes['src_centroids'].shape[0], 53 | corr_src_points.shape[0])) 54 | 55 | # 2. load gt 56 | T_ref_src = np.loadtxt(osp(scene_result_folder, 'gt_transform.txt')) 57 | 58 | # TODO 3. estimate transformation 59 | T_fake_pose = np.eye(4) 60 | 61 | # 4.
Eval 62 | precision, corr_errors \ 63 | = eval.evaluate_fine({'src_corr_points':torch.from_numpy(corr_src_points).float(), 64 | 'ref_corr_points':torch.from_numpy(corr_ref_points).float()}, 65 | {'transform':torch.from_numpy(T_ref_src).float()}) 66 | rre, rte, rmse, recall =\ 67 | eval.evaluate_registration({'src_points':torch.from_numpy(src_points).float(), 68 | 'estimated_transform':torch.from_numpy(T_fake_pose).float()}, 69 | {'transform':torch.from_numpy(T_ref_src).float()}) 70 | 71 | # print('Inlier ratio: {:.2f}'.format(precision.item())) 72 | msg = 'Inlier ratio: {:.3f}%, '.format(precision.item()*100) 73 | msg += 'RRE: {:.2f} deg, RTE: {:.2f}m'.format(rre.item(), rte.item()) 74 | msg += ', RMSE: {:.2f}m'.format(rmse.item()) 75 | print(msg) 76 | 77 | summary_rmse.append(rmse.item()) 78 | summary_recall.append(recall.item()) 79 | 80 | # break 81 | 82 | print('************************** Summary ***************************') 83 | # In the C++ version of registration code (https://github.com/glennliu/OpensetFusion/blob/master/src/Test3RscanRegister.cpp). 
84 | # The result is, Registration recall: 0.790(79/100), RMSE: 0.102m 85 | if len(summary_rmse)>0: 86 | summary_rmse = np.array(summary_rmse) 87 | summary_recall = np.array(summary_recall) 88 | print('Average RMSE: {:.2f}m'.format(np.mean(summary_rmse))) 89 | print('Registration Recall: {:.2f}% ({}/{})'.format(np.mean(summary_recall)*100, 90 | np.sum(summary_recall).astype(int), 91 | summary_recall.shape[0])) -------------------------------------------------------------------------------- /sgreg/registration/procrustes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # import ipdb 4 | 5 | 6 | def weighted_procrustes( 7 | src_points, 8 | ref_points, 9 | weights=None, 10 | weight_thresh=0.0, 11 | eps=1e-5, 12 | return_transform=False, 13 | ): 14 | r"""Compute rigid transformation from `src_points` to `ref_points` using weighted SVD. 15 | 16 | Modified from [PointDSC](https://github.com/XuyangBai/PointDSC/blob/master/models/common.py). 17 | 18 | Args: 19 | src_points: torch.Tensor (B, N, 3) or (N, 3) 20 | ref_points: torch.Tensor (B, N, 3) or (N, 3) 21 | weights: torch.Tensor (B, N) or (N,) (default: None) 22 | weight_thresh: float (default: 0.) 
23 | eps: float (default: 1e-5) 24 | return_transform: bool (default: False) 25 | 26 | Returns: 27 | R: torch.Tensor (B, 3, 3) or (3, 3) 28 | t: torch.Tensor (B, 3) or (3,) 29 | transform: torch.Tensor (B, 4, 4) or (4, 4) 30 | """ 31 | if src_points.ndim == 2: 32 | src_points = src_points.unsqueeze(0) 33 | ref_points = ref_points.unsqueeze(0) 34 | if weights is not None: 35 | weights = weights.unsqueeze(0) 36 | squeeze_first = True 37 | else: 38 | squeeze_first = False 39 | 40 | batch_size = src_points.shape[0] 41 | if weights is None: 42 | weights = torch.ones_like(src_points[:, :, 0]) 43 | weights = torch.where(torch.lt(weights, weight_thresh), torch.zeros_like(weights), weights) 44 | weights = weights / (torch.sum(weights, dim=1, keepdim=True) + eps) 45 | weights = weights.unsqueeze(2) # (B, N, 1) 46 | 47 | src_centroid = torch.sum(src_points * weights, dim=1, keepdim=True) # (B, 1, 3) 48 | ref_centroid = torch.sum(ref_points * weights, dim=1, keepdim=True) # (B, 1, 3) 49 | src_points_centered = src_points - src_centroid # (B, N, 3) 50 | ref_points_centered = ref_points - ref_centroid # (B, N, 3) 51 | 52 | H = src_points_centered.permute(0, 2, 1) @ (weights * ref_points_centered) 53 | U, _, V = torch.svd(H.cpu()) # H = USV^T, decomposed on the CPU 54 | Ut, V = U.transpose(1, 2).to(H.device), V.to(H.device) # move back to the input device 55 | eye = torch.eye(3, dtype=H.dtype, device=H.device).unsqueeze(0).repeat(batch_size, 1, 1) 56 | eye[:, -1, -1] = torch.sign(torch.det(V @ Ut)) 57 | R = V @ eye @ Ut 58 | 59 | t = ref_centroid.permute(0, 2, 1) - R @ src_centroid.permute(0, 2, 1) 60 | t = t.squeeze(2) 61 | 62 | if return_transform: 63 | transform = torch.eye(4, dtype=R.dtype, device=R.device).unsqueeze(0).repeat(batch_size, 1, 1) 64 | transform[:, :3, :3] = R 65 | transform[:, :3, 3] = t 66 | if squeeze_first: 67 | transform = transform.squeeze(0) 68 | return transform 69 | else: 70 | if squeeze_first: 71 | R = R.squeeze(0) 72 | t = t.squeeze(0) 73 | return R, t 74 | 75 | 76 | class WeightedProcrustes(nn.Module): 77 | def __init__(self, weight_thresh=0.0, eps=1e-5,
return_transform=False): 78 | super(WeightedProcrustes, self).__init__() 79 | self.weight_thresh = weight_thresh 80 | self.eps = eps 81 | self.return_transform = return_transform 82 | 83 | def forward(self, src_points, tgt_points, weights=None): 84 | return weighted_procrustes( 85 | src_points, 86 | tgt_points, 87 | weights=weights, 88 | weight_thresh=self.weight_thresh, 89 | eps=self.eps, 90 | return_transform=self.return_transform, 91 | ) 92 | -------------------------------------------------------------------------------- /sgreg/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/sgreg/utils/__init__.py -------------------------------------------------------------------------------- /sgreg/utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from yaml import Loader 4 | 5 | def get_parser(): 6 | parser = argparse.ArgumentParser(description='GNN for scene graph matching') 7 | parser.add_argument('--config', type=str, default='config/scannet.yaml', help='path to config file') 8 | 9 | args = parser.parse_args() 10 | assert args.config is not None 11 | with open(args.config, 'r') as f: 12 | config = yaml.load(f, Loader=Loader) 13 | for key in config: 14 | for k, v in config[key].items(): 15 | setattr(args, k, v) 16 | 17 | # setattr(args, 'exp_path', os.path.join('exp', args.dataset, args.model_name, args.config.split('/')[-1][:-5])) 18 | 19 | return args 20 | 21 | def create_cfg(cfg): 22 | 23 | cfg.backbone.init_radius = 0.5 * cfg.backbone.base_radius * cfg.backbone.init_voxel_size # 0.0625 24 | cfg.backbone.init_sigma = 0.5 * cfg.backbone.base_sigma * cfg.backbone.init_voxel_size # 0.05 25 | 26 | return cfg 27 | 28 | -------------------------------------------------------------------------------- /sgreg/utils/io.py: 
-------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | def write_pred_nodes(dir:str, 6 | pred_src_instances:np.ndarray, 7 | pred_ref_instances:np.ndarray, 8 | pred_src_centroids:np.ndarray, 9 | pred_ref_centroids:np.ndarray, 10 | gt_mask:np.ndarray): 11 | 12 | with open(dir,'w') as f: 13 | f.write('# src_id, ref_id, gt_mask, src_centroid, ref_centroid\n') 14 | for i in range(pred_src_instances.shape[0]): 15 | f.write('{} {} {} '.format(pred_src_instances[i], 16 | pred_ref_instances[i], 17 | gt_mask[i])) 18 | f.write('{:.3f} {:.3f} {:.3f} '.format(pred_src_centroids[i,0], 19 | pred_src_centroids[i,1], 20 | pred_src_centroids[i,2])) 21 | f.write('{:.3f} {:.3f} {:.3f}\n'.format(pred_ref_centroids[i,0], 22 | pred_ref_centroids[i,1], 23 | pred_ref_centroids[i,2])) 24 | f.close() 25 | # print('write pred nodes to {}'.format(dir)) 26 | 27 | def write_registration_results(registration_dict:dict, 28 | dir:str): 29 | corres_points = registration_dict['points'].detach().cpu().numpy() # (C,6) 30 | corres_instances = registration_dict['instances'].detach().cpu().numpy() # (C,1), [m] 31 | corres_scores = registration_dict['scores'].detach().cpu().numpy() # (C,) 32 | # corres_rmse = registration_dict['errors'].detach().cpu().numpy() # (C,) 33 | corres_masks = registration_dict['corres_masks'].int().detach().cpu().numpy() # (C,) 34 | estimated_transforms = registration_dict['estimated_transform'].squeeze().detach().cpu().numpy() # (4,4) 35 | 36 | # correspondences 37 | corre_ref_pcd = o3d.geometry.PointCloud( 38 | points=o3d.utility.Vector3dVector(corres_points[:,:3])) 39 | corre_src_pcd = o3d.geometry.PointCloud( 40 | points=o3d.utility.Vector3dVector(corres_points[:,3:])) 41 | o3d.io.write_point_cloud(os.path.join(dir,'corr_ref.ply'),corre_ref_pcd) 42 | o3d.io.write_point_cloud(os.path.join(dir,'corr_src.ply'),corre_src_pcd) 43 | 44 | # correspondences info 45 | with 
open(os.path.join(dir,'point_matches.txt'),'w') as f: 46 | f.write('# match, score, tp_mask\n') 47 | for i in range(corres_points.shape[0]): 48 | f.write('{} {:.2f} {}\n'.format(corres_instances[i], 49 | corres_scores[i], 50 | corres_masks[i])) 51 | f.close() 52 | 53 | # estimated transform 54 | np.savetxt(os.path.join(dir,'svds_estimation.txt'), 55 | estimated_transforms, 56 | fmt='%.6f') 57 | 58 | def read_pred_nodes(dir:str): 59 | pred_src_instances = [] 60 | pred_ref_instances = [] 61 | pred_src_centroids = [] 62 | pred_ref_centroids = [] 63 | gt_mask = [] 64 | 65 | with open(dir,'r') as f: 66 | lines = f.readlines() 67 | for line in lines[1:]: 68 | line = line.strip().split(' ') 69 | pred_src_instances.append(int(line[0])) 70 | pred_ref_instances.append(int(line[1])) 71 | gt_mask.append(int(line[2])) 72 | pred_src_centroids.append([float(line[3]),float(line[4]),float(line[5])]) 73 | pred_ref_centroids.append([float(line[6]),float(line[7]),float(line[8])]) 74 | f.close() 75 | 76 | return {'src_instances':np.array(pred_src_instances), 77 | 'ref_instances':np.array(pred_ref_instances), 78 | 'src_centroids':np.array(pred_src_centroids), 79 | 'ref_centroids':np.array(pred_ref_centroids), 80 | 'gt_mask':np.array(gt_mask)} 81 | 82 | # return np.array(pred_src_instances), \ 83 | # np.array(pred_ref_instances), \ 84 | # np.array(pred_src_centroids), \ 85 | # np.array(pred_ref_centroids), \ 86 | # np.array(gt_mask) 87 | 88 | def read_corr_scores(dir:str): 89 | corr_instances = [] 90 | corr_scores = [] 91 | corr_masks = [] 92 | 93 | with open(dir,'r') as f: 94 | lines = f.readlines() 95 | for line in lines: 96 | if '#' in line: 97 | continue 98 | line = line.strip().split(' ') 99 | corr_instances.append(int(line[0])) 100 | corr_scores.append(float(line[1])) 101 | corr_masks.append(int(line[2])) 102 | f.close() 103 | return np.array(corr_instances), \ 104 | np.array(corr_scores), \ 105 | np.array(corr_masks) 106 | 107 | def read_dense_correspondences(src_corr_dir:str, 108 
| ref_corr_dir:str, 109 | score_file_dir:str): 110 | src_corr_pcd = o3d.io.read_point_cloud(src_corr_dir) 111 | ref_corr_pcd = o3d.io.read_point_cloud(ref_corr_dir) 112 | src_corr_points = np.asarray(src_corr_pcd.points) # (N,3) 113 | ref_corr_points = np.asarray(ref_corr_pcd.points) # (N,3) 114 | 115 | corr_inst, corr_scores, corr_masks = read_corr_scores(score_file_dir) 116 | assert src_corr_points.shape[0] == corr_scores.shape[0] 117 | assert src_corr_points.shape[0] == ref_corr_points.shape[0] 118 | assert src_corr_points.shape[0] == corr_masks.shape[0] 119 | 120 | return {'src_corrs':src_corr_points, 121 | 'ref_corrs':ref_corr_points, 122 | 'corr_scores':corr_scores, 123 | 'corr_masks':corr_masks} -------------------------------------------------------------------------------- /sgreg/utils/tictoc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class TicToc: 4 | def __init__(self): 5 | self.tic() 6 | 7 | def tic(self): 8 | self.t0 = time.time() 9 | 10 | def toc(self): 11 | duration = time.time() - self.t0 12 | self.tic() 13 | return duration 14 | 15 | -------------------------------------------------------------------------------- /sgreg/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import math 3 | import random 4 | from typing import Callable 5 | from collections import OrderedDict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.utils.data 11 | import torch.backends.cudnn as cudnn 12 | from torch_geometric.graphgym.checkpoint import get_ckpt_path, get_ckpt_epoch 13 | 14 | 15 | # Distributed Data Parallel Utilities 16 | def all_reduce_tensor(tensor, world_size=1): 17 | r"""Average reduce a tensor across all workers.""" 18 | reduced_tensor = tensor.clone() 19 | dist.all_reduce(reduced_tensor) 20 | reduced_tensor /= world_size 21 | return reduced_tensor 22 | 23 | 24 | def 
all_reduce_tensors(x, world_size=1): 25 | r"""Average reduce all tensors across all workers.""" 26 | if isinstance(x, list): 27 | x = [all_reduce_tensors(item, world_size=world_size) for item in x] 28 | elif isinstance(x, tuple): 29 | x = tuple(all_reduce_tensors(item, world_size=world_size) for item in x) 30 | elif isinstance(x, dict): 31 | x = {key: all_reduce_tensors(value, world_size=world_size) for key, value in x.items()} 32 | elif isinstance(x, torch.Tensor): 33 | x = all_reduce_tensor(x, world_size=world_size) 34 | return x 35 | 36 | 37 | # Dataloader Utilities 38 | 39 | 40 | def reset_seed_worker_init_fn(worker_id): 41 | r"""Reset seed for data loader worker.""" 42 | seed = torch.initial_seed() % (2 ** 32) 43 | # print(worker_id, seed) 44 | np.random.seed(seed) 45 | random.seed(seed) 46 | 47 | 48 | def build_dataloader( 49 | dataset, 50 | batch_size=1, 51 | num_workers=1, 52 | shuffle=None, 53 | collate_fn=None, 54 | pin_memory=False, 55 | drop_last=False, 56 | distributed=False, 57 | ): 58 | if distributed: 59 | sampler = torch.utils.data.DistributedSampler(dataset) 60 | shuffle = False 61 | else: 62 | sampler = None 63 | shuffle = shuffle 64 | 65 | data_loader = torch.utils.data.DataLoader( 66 | dataset, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | shuffle=shuffle, 70 | sampler=sampler, 71 | collate_fn=collate_fn, 72 | worker_init_fn=reset_seed_worker_init_fn, 73 | pin_memory=pin_memory, 74 | drop_last=drop_last, 75 | ) 76 | 77 | return data_loader 78 | 79 | 80 | # Common Utilities 81 | 82 | 83 | def initialize(seed=None, cudnn_deterministic=True, autograd_anomaly_detection=False): 84 | if seed is not None: 85 | random.seed(seed) 86 | torch.manual_seed(seed) 87 | np.random.seed(seed) 88 | if cudnn_deterministic: 89 | cudnn.benchmark = False 90 | cudnn.deterministic = True 91 | else: 92 | cudnn.benchmark = True 93 | cudnn.deterministic = False 94 | torch.autograd.set_detect_anomaly(autograd_anomaly_detection) 95 | 96 | 97 | def
release_cuda(x): 98 | r"""Convert all tensors to Python scalars or numpy arrays.""" 99 | if isinstance(x, list): 100 | x = [release_cuda(item) for item in x] 101 | elif isinstance(x, tuple): 102 | x = tuple(release_cuda(item) for item in x) 103 | elif isinstance(x, dict): 104 | x = {key: release_cuda(value) for key, value in x.items()} 105 | elif isinstance(x, torch.Tensor): 106 | if x.numel() == 1: 107 | x = x.item() 108 | else: 109 | x = x.detach().cpu().numpy() 110 | return x 111 | 112 | 113 | def to_cuda(x): 114 | r"""Move all tensors to cuda.""" 115 | if isinstance(x, list): 116 | x = [to_cuda(item) for item in x] 117 | elif isinstance(x, tuple): 118 | x = tuple(to_cuda(item) for item in x) 119 | elif isinstance(x, dict): 120 | x = {key: to_cuda(value) for key, value in x.items()} 121 | elif isinstance(x, torch.Tensor): 122 | x = x.cuda() 123 | return x 124 | 125 | 126 | def load_weights(model: torch.nn.Module, 127 | snapshot_dir: str): 128 | r"""Load weights and check keys.""" 129 | state_dict = torch.load(snapshot_dir) 130 | model_dict = state_dict['model_state'] 131 | missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False) 132 | 133 | # snapshot_keys = set(model_dict.keys()) 134 | # model_keys = set(model.state_dict().keys()) 135 | # missing_keys = model_keys - snapshot_keys 136 | # unexpected_keys = snapshot_keys - model_keys 137 | # if len(missing_keys) > 0: 138 | # print('Missing keys:', missing_keys) 139 | # if len(unexpected_keys) > 0: 140 | # print('Unexpected keys:', unexpected_keys) 141 | 142 | return missing_keys, unexpected_keys 143 | 144 | def load_checkpoints_from_folder(model, folder, epoch=-1): 145 | if epoch<0: 146 | ckpt_folder = os.path.join(folder,'0','ckpt') 147 | weights_files = glob.glob(os.path.join(ckpt_folder,'*.ckpt')) 148 | weights_files = sorted(weights_files, key=lambda p: int(os.path.basename(p).split('.')[0])) 149 | weight_file = weights_files[-1] 150 | else: 151 | weight_file = 
os.path.join(folder,'0','ckpt',str(epoch)+'.ckpt') 152 | assert os.path.exists(weight_file), 'weight
file {} not exists'.format(weight_file) 153 | file_name = os.path.basename(weight_file) 154 | epoch = int(file_name.split('.')[0]) 155 | load_weights(model, weight_file) 156 | return epoch 157 | 158 | def checkpoint_restore(model, 159 | exp_path, 160 | exp_name, 161 | epoch=0, 162 | dist=False, 163 | f='', 164 | gpu=0, 165 | optimizer: torch.optim.Optimizer = None): 166 | # Find file and epoch 167 | if not f: 168 | if epoch > 0: 169 | f = os.path.join(exp_path, exp_name + '-%04d'%epoch + '.pth') 170 | assert os.path.isfile(f) 171 | else: 172 | f = sorted(glob.glob(os.path.join(exp_path, exp_name + '-*.pth'))) 173 | if len(f) > 0: 174 | f = f[-1] 175 | epoch = int(f[len(exp_path) + len(exp_name) + 2 : -4]) 176 | 177 | # Load checkpoint and optimizer 178 | if len(f) > 0: 179 | map_location = {'cuda:0': 'cuda:{}'.format(gpu)} if gpu > 0 else None 180 | state = torch.load(f, map_location=map_location) 181 | checkpoint = state if not (isinstance(state, dict) and 'state_dict' in state) else state['state_dict'] 182 | 183 | for k, v in checkpoint.items(): 184 | if 'module.' 
in k: 185 | checkpoint = {k[len('module.'):]: v for k, v in checkpoint.items()} 186 | break 187 | if dist: 188 | model.module.load_state_dict(checkpoint) 189 | else: 190 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint, 191 | strict=False) 192 | 193 | if optimizer is not None: 194 | if isinstance(state, dict) and 'optimizer' in state: 195 | optimizer.load_state_dict(state['optimizer']) 196 | 197 | print('Restored checkpoint from {} at epoch {}'.format(f, epoch)) 198 | 199 | if epoch>0: epoch = epoch + 1 200 | return epoch, f 201 | 202 | def is_power2(num): 203 | return num != 0 and ((num & (num - 1)) == 0) 204 | 205 | 206 | def is_multiple(num, multiple): 207 | return num > 0 and num % multiple == 0 208 | 209 | 210 | def is_last(num, total_num, ratio=0.95): 211 | return num > int(total_num * ratio) 212 | 213 | def copy_state_dict(state_dict: OrderedDict, 214 | ignore_keys: list = []): 215 | new_state_dict = OrderedDict() 216 | for k, v in state_dict.items(): 217 | skip = False 218 | for ignore_key in ignore_keys: 219 | if ignore_key in k: 220 | skip = True 221 | break 222 | 223 | if not skip: 224 | new_state_dict[k] = v 225 | return new_state_dict 226 | 227 | def checkpoint_save(model, 228 | optimizer, 229 | exp_path, 230 | exp_name, 231 | epoch, 232 | save_freq=16, 233 | ignore_keys=['bert']): 234 | f = os.path.join(exp_path, exp_name + '-%04d'%epoch + '.pth') 235 | state_to_save = copy_state_dict(model.state_dict(), 236 | ignore_keys=ignore_keys) 237 | 238 | state = { 239 | 'state_dict': state_to_save, #model.state_dict(), 240 | 'optimizer': optimizer.state_dict(), 241 | } 242 | torch.save(state, f) 243 | 244 | # remove the previous checkpoint unless its epoch is a power of 2 or a multiple of save_freq, to save disk space 245 | epoch = epoch - 1 246 | fd = os.path.join(exp_path, exp_name + '-%04d'%epoch + '.pth') 247 | if os.path.isfile(fd): 248 | if not is_multiple(epoch, save_freq) and not is_power2(epoch): 249 | os.remove(fd) 250 | 
251 | return f 252 | 253 |
254 | def realase_cuda(input_dict): 255 | for key in input_dict: 256 | if isinstance(input_dict[key],torch.Tensor): 257 | input_dict[key] = input_dict[key].detach().cpu().numpy() 258 | return input_dict 259 | 260 | def fix_network_modules(model, fixed_modules=[]): 261 | # msg = 'fixed modules: ' 262 | for name, params in model.named_parameters(): 263 | if name.split('.')[0] in fixed_modules: 264 | params.requires_grad = False 265 | return model 266 | 267 | class CosineAnnealingFunction(Callable): 268 | def __init__(self, max_epoch, eta_min=0.0): 269 | self.max_epoch = max_epoch 270 | self.eta_min = eta_min 271 | 272 | def __call__(self, last_epoch): 273 | next_epoch = last_epoch + 1 274 | return self.eta_min + 0.5 * (1.0 - self.eta_min) * (1.0 + math.cos(math.pi * next_epoch / self.max_epoch)) 275 | 276 | 277 | class WarmUpCosineAnnealingFunction(Callable): 278 | def __init__(self, total_steps, warmup_steps, eta_init=0.1, eta_min=0.1): 279 | self.total_steps = total_steps 280 | self.warmup_steps = warmup_steps 281 | self.normal_steps = total_steps - warmup_steps 282 | self.eta_init = eta_init 283 | self.eta_min = eta_min 284 | 285 | def __call__(self, last_step): 286 | # last_step starts from -1, which means last_steps=0 indicates the first call of lr annealing. 
287 | next_step = last_step + 1 288 | if next_step < self.warmup_steps: 289 | return self.eta_init + (1.0 - self.eta_init) / self.warmup_steps * next_step 290 | else: 291 | if next_step > self.total_steps: 292 | return self.eta_min 293 | next_step -= self.warmup_steps 294 | return self.eta_min + 0.5 * (1.0 - self.eta_min) * (1 + np.cos(np.pi * next_step / self.normal_steps)) 295 | 296 | 297 | def build_warmup_cosine_lr_scheduler(optimizer, total_steps, warmup_steps, eta_init=0.1, eta_min=0.1, grad_acc_steps=1): 298 | total_steps //= grad_acc_steps 299 | warmup_steps //= grad_acc_steps 300 | cosine_func = WarmUpCosineAnnealingFunction(total_steps, warmup_steps, eta_init=eta_init, eta_min=eta_min) 301 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_func) 302 | return scheduler 303 | 304 | def save_final_ckpt( 305 | model: torch.nn.Module, 306 | dir:str, 307 | optimizer= None, 308 | scheduler= None, 309 | ): 310 | r"""Saves the model checkpoint at a given epoch.""" 311 | MODEL_STATE = 'model_state' 312 | OPTIMIZER_STATE = 'optimizer_state' 313 | SCHEDULER_STATE = 'scheduler_state' 314 | ckpt = {} 315 | ckpt[MODEL_STATE] = model.state_dict() 316 | if optimizer is not None: 317 | ckpt[OPTIMIZER_STATE] = optimizer.state_dict() 318 | if scheduler is not None: 319 | ckpt[SCHEDULER_STATE] = scheduler.state_dict() 320 | 321 | torch.save(ckpt,dir) 322 | 323 | # os.makedirs(get_ckpt_dir(), exist_ok=True) 324 | # torch.save(ckpt, get_ckpt_path(get_ckpt_epoch(epoch))) -------------------------------------------------------------------------------- /sgreg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from time import perf_counter 6 | 7 | 8 | class RecallMetrics: 9 | def __init__(self): 10 | self.num_pos = 0 11 | self.num_neg = 0 12 | self.num_gt = 0 13 | 14 | # def update(self,tp_,fp_,gt_): 15 | def update(self, 
data_dict): 16 | self.num_pos += data_dict["tp"] 17 | self.num_neg += data_dict["fp"] 18 | self.num_gt += data_dict["gt"] 19 | 20 | def get_metrics(self, precentage=True, recall_only=False): 21 | if self.num_pos < 1: 22 | recall = 0.0 23 | precision = 0.0 24 | else: 25 | recall = self.num_pos / (self.num_gt + 1e-6) 26 | precision = self.num_pos / (self.num_pos + self.num_neg + 1e-6) 27 | 28 | if precentage: 29 | recall = recall * 100 30 | precision = precision * 100 31 | if recall_only: 32 | return {"R": recall} 33 | else: 34 | return {"R": recall, "P": precision} 35 | 36 | 37 | class TicToc: 38 | def __init__(self): 39 | self.time_dict = {} 40 | 41 | def tic(self, key_name): 42 | if key_name not in self.time_dict: 43 | self.time_dict[key_name] = {"time": 0, "number": 0, "t0": perf_counter()} 44 | else: 45 | self.time_dict[key_name]["t0"] = perf_counter() 46 | 47 | def toc(self, key_name, verbose=False): 48 | if key_name not in self.time_dict: 49 | raise ValueError(f"No timer started for {key_name}") 50 | t1 = perf_counter() 51 | elapsed_time = (t1 - self.time_dict[key_name]["t0"]) * 1000 52 | self.time_dict[key_name]["time"] += elapsed_time 53 | self.time_dict[key_name]["number"] += 1 54 | if verbose: 55 | print(f"Time for {key_name}: {elapsed_time:.2f} ms") 56 | return elapsed_time 57 | 58 | def print_summary(self): 59 | for k, v in self.time_dict.items(): 60 | average_time = v["time"] / v["number"] if v["number"] > 0 else 0 61 | print(f"Average Time for {k}: {average_time:.2f} ms") 62 | 63 | 64 | timer = TicToc() 65 | 66 | def summary_dict(input_dict: dict, output_dict: dict): 67 | for k, v in input_dict.items(): 68 | if k not in output_dict: 69 | output_dict[k] = v 70 | else: 71 | output_dict[k] += v 72 | 73 | def update_dict(total_dict: dict, sub_dict: dict, name: str): 74 | for k, v in sub_dict.items(): 75 | total_dict[name + "_" + k] = v 76 | return total_dict 77 | 78 | 79 | def read_scans(dir): 80 | with open(dir, "r") as f: 81 | scans = [line.strip() for 
line in f.readlines()] 82 | print("Find {} scans to load".format(len(scans))) 83 | return scans 84 | 85 | 86 | def create_mask_from_edges(edge_index, min_nodes, max_nodes): 87 | nodes_number = max_nodes - min_nodes 88 | mask = torch.zeros(nodes_number, nodes_number).bool().to(edge_index.device) 89 | valid_edges = edge_index[ 90 | :, 91 | (edge_index[0, :] >= min_nodes) 92 | & (edge_index[0, :] < max_nodes) 93 | & (edge_index[1, :] >= min_nodes) 94 | & (edge_index[1, :] < max_nodes), 95 | ] # (2,e) 96 | 97 | valid_edges = valid_edges - min_nodes 98 | mask[valid_edges[0, :], valid_edges[1, :]] = True 99 | return mask 100 | 101 | def mask_valid_labels(src_labels:list, 102 | ignore_labels:str): 103 | mask = np.ones(len(src_labels),dtype=bool) 104 | for i, label in enumerate(src_labels): 105 | if label in ignore_labels: mask[i] = False 106 | 107 | return mask 108 | 109 | def read_scan_pairs(dir): 110 | with open(dir, "r") as f: 111 | pairs = [line.strip().split(" ") for line in f.readlines()] 112 | return pairs 113 | 114 | def scanpairs_2_scans(scan_pairs): 115 | scans = [] 116 | for pair in scan_pairs: 117 | scans.append(pair[0]) 118 | scans.append(pair[1]) 119 | return scans 120 | 121 | 122 | def write_scan_pairs(scan_pairs, dir): 123 | with open(dir, "w") as f: 124 | n = len(scan_pairs) 125 | for i, pair in enumerate(scan_pairs): 126 | f.write("{} {}".format(pair[0], pair[1])) 127 | if i < n - 1: 128 | f.write("\n") 129 | 130 | 131 | def load_checkpoint(path: str, model: torch.nn.Module, keyname: str = "pointnet2"): 132 | if os.path.exists(path): 133 | print("load checkpoint {} from {}".format(keyname, path)) 134 | ckpt = torch.load(path) 135 | model.load_state_dict(ckpt[keyname]) 136 | return True 137 | else: 138 | print("checkpoint {} not found".format(path)) 139 | return False 140 | 141 | MODEL_STATE = "model_state" 142 | OPTIMIZER_STATE = "optimizer_state" 143 | SCHEDULER_STATE = "scheduler_state" 144 | # ckpt_weights ={'pointnet2':ckpt[MODEL_STATE]} 145 | # 
torch.save(ckpt_weights, path.replace('64','model_epoch64')) 146 | 147 | # if optimizer is not None and OPTIMIZER_STATE in ckpt: 148 | # optimizer.load_state_dict(ckpt[OPTIMIZER_STATE]) 149 | # if scheduler is not None and SCHEDULER_STATE in ckpt: 150 | # scheduler.load_state_dict(ckpt[SCHEDULER_STATE]) 151 | 152 | # epoch = os.path.basename(path).split('.')[0] 153 | # print('load checkpoint from {}, epoch {}'.format(path,epoch)) 154 | 155 | # return int(epoch) + 1 156 | 157 | def load_pretrain_weight(dir): 158 | model_dict = torch.load(dir) 159 | if "model" in model_dict: 160 | model_dict = model_dict["model"] 161 | elif "model_state" in model_dict: 162 | model_dict = model_dict["model_state"] 163 | else: 164 | raise ValueError("No model state found in the checkpoint") 165 | 166 | return model_dict 167 | 168 | def load_submodule_state_dict(model_state_dir, submodule_names=["backbone"]): 169 | model_dict = torch.load(model_state_dir) 170 | if "model" in model_dict: 171 | model_dict = model_dict["model"] 172 | elif "model_state" in model_dict: 173 | model_dict = model_dict["model_state"] 174 | elif 'state_dict' in model_dict: 175 | model_dict = model_dict['state_dict'] 176 | 177 | submodule_dicts = {} 178 | for submodule_name in submodule_names: 179 | 180 | submodule_dicts[submodule_name] = { 181 | k[len(submodule_name) + 1 :]: v 182 | for k, v in model_dict.items() 183 | if submodule_name==k.split(".")[0] 184 | } 185 | 186 | return submodule_dicts 187 | -------------------------------------------------------------------------------- /sgreg/utils/viz_tools.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import torch 4 | 5 | def build_o3d_points(points:np.ndarray,colors:np.ndarray=None): 6 | pcd = o3d.geometry.PointCloud() 7 | pcd.points = o3d.utility.Vector3dVector(points) 8 | if colors is not None: 9 | pcd.colors = o3d.utility.Vector3dVector(colors) 10 | return pcd 11 | 12 | 
def build_correspondences_lines(corres_s, corres_t, corres_pos=None): 13 | line_points = np.concatenate([corres_s,corres_t],axis=0) # (2C,3) 14 | line_indices = np.stack([np.arange(len(corres_s)),np.arange(len(corres_s),2*len(corres_s))],axis=0) # (2,C) 15 | line_colors = np.zeros((corres_s.shape[0],3)) 16 | if corres_pos is None: 17 | line_colors += np.array([0,0,1]) 18 | else: 19 | line_colors[corres_pos] = np.array([0,1,0]) 20 | line_colors[~corres_pos] = np.array([1,0,0]) 21 | 22 | line_set = o3d.geometry.LineSet( 23 | points = o3d.utility.Vector3dVector(line_points), 24 | lines = o3d.utility.Vector2iVector(line_indices.T)) 25 | line_set.colors = o3d.utility.Vector3dVector(line_colors) 26 | return line_set 27 | 28 | def build_instance_centroids(graph:dict,pos_indices=np.array([]),neg_indices=np.array([]),radius=0.1): 29 | centroids = [] 30 | for idx, instance in graph['nodes'].items(): 31 | centroid = o3d.geometry.TriangleMesh.create_sphere(radius=radius) 32 | if instance.idx in pos_indices: 33 | centroid.paint_uniform_color(np.array([0,1,0])) 34 | elif instance.idx in neg_indices: 35 | centroid.paint_uniform_color(np.array([1,0,0])) 36 | # else: 37 | # continue 38 | centroid.translate(instance.cloud.get_center()) 39 | centroids.append(centroid) 40 | 41 | return centroids 42 | 43 | def build_centroids_from_points(points:np.ndarray,radius=0.1): 44 | centroids = [] 45 | for point in points: 46 | centroid = o3d.geometry.TriangleMesh.create_sphere(radius=radius) 47 | centroid.paint_uniform_color(np.array([0,1,0])) 48 | centroid.translate(point) 49 | centroids.append(centroid) 50 | return centroids 51 | 52 | def generate_instance_color(instances): 53 | # ref_instances = data_dict['ref_points_f_instances'] 54 | ref_instance_list = torch.unique(instances) 55 | instance_colors = np.zeros((instances.shape[0],3)) 56 | for idx in ref_instance_list: 57 | color = np.random.uniform(0,1,3) 58 | instance_mask = instances==idx 59 | instance_colors[instance_mask] = color 60 | 
return instance_colors -------------------------------------------------------------------------------- /sgreg/val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml, argparse 3 | from omegaconf import OmegaConf 4 | import numpy as np 5 | import torch 6 | 7 | from sgreg.dataset.dataset_factory import val_data_loader 8 | from sgreg.utils.config import create_cfg 9 | from sgreg.utils.torch import to_cuda, checkpoint_restore 10 | from sgreg.sg_reg import SGNet, SGNetDecorator 11 | from train import val_epoch 12 | 13 | def test_epoch(data_loader, model, test_model_fn, save_dir=''): 14 | # todo: inference without gt evaluation 15 | print('test start') 16 | memory_array = [] 17 | time_record = [] 18 | time_analysis = [] 19 | count = 0 20 | 21 | for i, batch_data in enumerate(data_loader): 22 | batch_data = to_cuda(batch_data) 23 | output_dict, eval_dict = test_model_fn(batch_data,model) 24 | 25 | 26 | assert batch_data['batch_size'] == 1 27 | with torch.no_grad(): 28 | if save_dir != '': 29 | src_scan = batch_data['src_scan'][0] 30 | ref_scan = batch_data['ref_scan'][0] 31 | src_subix = src_scan[-1] 32 | ref_subix = ref_scan[-1] 33 | scene_name = src_scan[:-1] 34 | if os.path.exists(os.path.join(save_dir,scene_name)) == False: 35 | os.makedirs(os.path.join(save_dir,scene_name)) 36 | 37 | src_idx2name = batch_data['src_graph']['idx2name'].squeeze() 38 | ref_idx2name = batch_data['ref_graph']['idx2name'].squeeze() 39 | pred_nodes = output_dict['instance_matches']['pred_nodes'][:,1:] # (M,2) 40 | src_centroids = batch_data['src_graph']['centroids'].squeeze() 41 | ref_centroids = batch_data['ref_graph']['centroids'].squeeze() 42 | pred_src_centroids = src_centroids[pred_nodes[:,0].long()] # (M,3) 43 | pred_ref_centroids = ref_centroids[pred_nodes[:,1].long()] # (M,3) 44 | 45 | pred_nodes[:,0] = src_idx2name[pred_nodes[:,0].long()] 46 | pred_nodes[:,1] = ref_idx2name[pred_nodes[:,1].long()] 47 | 48 | 
scan_output_dict = { 49 | 'ref_instances':ref_idx2name.detach().cpu().numpy(), 50 | 'src_instances':src_idx2name.detach().cpu().numpy(), 51 | 'pred_nodes':pred_nodes.detach().cpu().numpy(), 52 | 'gt_transform': batch_data['transform'].squeeze().detach().cpu().numpy(), 53 | 'pred_scores': output_dict['instance_matches']['pred_scores'].squeeze().detach().cpu().numpy(), 54 | } 55 | print('{},{}: pred nodes {}'.format(src_scan,ref_scan,scan_output_dict['pred_nodes'].shape[0])) 56 | 57 | if 'registration' in output_dict: 58 | corres_points = output_dict['registration']['points'].detach().cpu().numpy() # (C,6) 59 | scan_output_dict['corres_ref_points'] = corres_points[:,:3] 60 | scan_output_dict['corres_src_points'] = corres_points[:,3:] 61 | scan_output_dict['corres_ref_centroids'] = pred_ref_centroids.detach().cpu().numpy() 62 | scan_output_dict['corres_src_centroids'] = pred_src_centroids.detach().cpu().numpy() 63 | scan_output_dict['corres_scores'] = output_dict['registration']['scores'].detach().cpu().numpy() # (C,) 64 | scan_output_dict['corres_rmse'] = output_dict['registration']['errors'].detach().cpu().numpy() # (C,) 65 | scan_output_dict['estimated_transforms'] = output_dict['registration']['estimated_transform'].squeeze().detach().cpu().numpy() # (4,4) 66 | # scan_output_dict['ref_instance_points'] = output_dict['ref_instance_points'].squeeze().detach().cpu().numpy() 67 | # scan_output_dict['src_instance_points'] = output_dict['src_instance_points'].squeeze().detach().cpu().numpy() 68 | print('PIR:{:.3f}% in {} correspondence points'.format(100*eval_dict['PIR'].squeeze(),corres_points.shape[0])) 69 | 70 | torch.save(scan_output_dict,os.path.join(save_dir,scene_name,'output_dict_{}{}.pth'.format(src_subix,ref_subix))) 71 | 72 | print('finished test') 73 | 74 | def run_rag_epoch(data_loader, model, model_fn, save_dir, epoch=0): 75 | assert save_dir != '' 76 | print('Running inference on the dataset for RAG.
Save result to {}'.format(save_dir)) 77 | 78 | scene_pairs = [] 79 | scene_size = [] 80 | sum_metric_dict = {} 81 | count_scenes = 0 82 | 83 | for i, batch_data in enumerate(data_loader): 84 | batch_data = to_cuda(batch_data) 85 | _, _, output_dict, _ = model_fn(batch_data, 86 | model, 87 | epoch, 88 | False) 89 | 90 | with torch.no_grad(): 91 | assert 'x_src1' in output_dict and 'f_src' in output_dict 92 | 93 | # collect data 94 | src_scene_name = batch_data['src_scan'][0] 95 | ref_scene_name = batch_data['ref_scan'][0] 96 | ref_instances = batch_data['ref_graph']['idx2name'] 97 | src_instances = batch_data['src_graph']['idx2name'] 98 | ref_labels = batch_data['ref_graph']['labels'] 99 | src_labels = batch_data['src_graph']['labels'] 100 | x_src = output_dict['x_src1'] 101 | x_ref = output_dict['x_ref1'] 102 | f_src = output_dict['f_src'] 103 | f_ref = output_dict['f_ref'] 104 | 105 | assert isinstance(src_labels, list) and isinstance(src_labels[0],str) 106 | 107 | data2save_src = {'instances':src_instances, 108 | 'labels':src_labels, 109 | 'x':x_src, 110 | 'f':f_src} 111 | data2save_ref = {'instances':ref_instances, 112 | 'labels':ref_labels, 113 | 'x':x_ref, 114 | 'f':f_ref} 115 | 116 | msg = '{}-{}'.format(src_scene_name,ref_scene_name) 117 | 118 | # save data 119 | src_out_dir = os.path.join(save_dir,src_scene_name) 120 | ref_out_dir = os.path.join(save_dir,ref_scene_name) 121 | os.makedirs(src_out_dir,exist_ok=True) 122 | os.makedirs(ref_out_dir,exist_ok=True) 123 | 124 | torch.save(data2save_src,os.path.join(src_out_dir,'features.pth')) 125 | torch.save(data2save_ref,os.path.join(ref_out_dir,'features.pth')) 126 | print(msg) 127 | 128 | 129 | print('*********** finished RAG ***********') 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser(description='GNN for scene graph matching') 133 | parser.add_argument('--cfg_file', type=str, default='config/scannet.yaml', 134 | help='path to config file') 135 | 
parser.add_argument('--checkpoint', type=str, 136 | help='folder to load the checkpoint from; if unset, pretrained weights are loaded from the Hugging Face Hub') 137 | parser.add_argument('--test', action='store_true', help='test mode') 138 | parser.add_argument('--epoch',type=int,help='Load checkpoint at assigned epoch',default=-1) 139 | parser.add_argument('--output', type=str, default='sgnet_scannet_0080', help='name of the output folder for saving results') 140 | parser.add_argument('--push2hf', action='store_true', help='push model to huggingface hub') 141 | args = parser.parse_args() 142 | 143 | # Parameters 144 | assert os.path.exists(args.cfg_file), 'config file {} not found'.format(args.cfg_file) 145 | conf = OmegaConf.load(args.cfg_file) 146 | conf = create_cfg(conf) 147 | 148 | # Model 149 | if args.checkpoint is None: 150 | print('Loading checkpoint from huggingface hub') 151 | model = SGNet.from_pretrained('glennliu/sgnet', 152 | conf=conf) 153 | model = model.cuda() 154 | print('Loaded checkpoint') 155 | else: 156 | model = SGNet(conf=conf) 157 | model = model.cuda() 158 | epoch, f = checkpoint_restore(model, 159 | args.checkpoint, 160 | 'sgnet', 161 | args.epoch) 162 | 163 | for name, params in model.named_parameters(): 164 | if name.split('.')[0] in conf.model.fix_modules: 165 | params.requires_grad = False 166 | model.eval() 167 | model_fn = SGNetDecorator(conf=conf) 168 | if args.push2hf: 169 | model.push_to_hub('glennliu/sgnet') 170 | print('Pushed model to huggingface glennliu/sgnet') 171 | exit(0) 172 | 173 | # Dataset 174 | val_loader, _ = val_data_loader(conf) 175 | print('Validation dataset: {} batches'.format(len(val_loader))) 176 | 177 | # Saving directory 178 | if args.output=='': 179 | pred_output = os.path.join(conf.dataset.dataroot, 180 | 'output', 181 | os.path.basename(args.checkpoint)) 182 | else: 183 | pred_output = os.path.join(conf.dataset.dataroot, 184 | 'output', 185 | args.output) 186 | 
os.makedirs(pred_output,exist_ok=True) 187 | OmegaConf.save(conf, os.path.join(pred_output,'config.yaml')) 188 | 189 | # Val 190 |
loss_metrics, eval_metrics =val_epoch( 191 | val_loader, model, model_fn,save_dir=pred_output) 192 | 193 | -------------------------------------------------------------------------------- /sgreg/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open3d as o3d 4 | import rerun as rr 5 | import argparse, json 6 | from scipy.spatial.transform import Rotation as R 7 | 8 | from sgreg.dataset.scene_graph import load_processed_scene_graph, transform_scene_graph 9 | from sgreg.utils.io import read_pred_nodes, read_dense_correspondences 10 | 11 | from os.path import join as osp 12 | 13 | def render_point_cloud(entity_name:str, 14 | cloud:o3d.geometry.PointCloud, 15 | radius=0.1, 16 | color=None): 17 | """ 18 | Render a point cloud with a specific color and point size. 19 | """ 20 | if color is not None: 21 | viz_colors = color 22 | else: 23 | viz_colors = np.asarray(cloud.colors) 24 | 25 | rr.log(entity_name, 26 | rr.Points3D( 27 | np.asarray(cloud.points), 28 | colors=viz_colors, 29 | radii=radius, 30 | ) 31 | ) 32 | 33 | def render_node_centers(entity_name:str, 34 | nodes:dict, 35 | radius=0.1, 36 | color=[0,0,0]): 37 | """ 38 | Render the centers of nodes in the scene graph. 
39 | """ 40 | 41 | centers = [] 42 | semantic_labels = [] 43 | for node in nodes.values(): 44 | if isinstance(node, o3d.geometry.OrientedBoundingBox): 45 | centers.append(node.center) 46 | else: 47 | centers.append(node.cloud.get_center()) 48 | semantic_labels.append(node.label) 49 | centers = np.array(centers) 50 | if color is None: 51 | raise NotImplementedError('todo: color by node pcd') 52 | else: 53 | viz_colors = color 54 | rr.log(entity_name, 55 | rr.Points3D( 56 | centers, 57 | colors=viz_colors, 58 | radii=radius, 59 | labels=semantic_labels, 60 | show_labels=False 61 | ) 62 | ) 63 | 64 | def render_node_bboxes(entity_name:str, 65 | nodes:dict, 66 | show_labels:bool=True, 67 | radius=0.01): 68 | 69 | for idx, node in nodes.items(): 70 | quad = R.from_matrix(node.box.R).as_quat() 71 | # print(node.box.R) 72 | rr.log('{}/{}'.format(entity_name,idx), 73 | rr.Boxes3D(half_sizes=0.5*node.box.extent, 74 | centers=node.box.center, 75 | quaternions=rr.Quaternion(xyzw=quad), 76 | radii=radius, 77 | labels=node.label, 78 | show_labels=show_labels) 79 | ) 80 | 81 | 82 | def render_semantic_scene_graph(scene_name:str, 83 | scene_graph:dict, 84 | voxel_size:float=0.05, 85 | origin:np.ndarray=np.eye(4), 86 | box:bool=False 87 | ): 88 | render_point_cloud(scene_name+'/global_cloud', 89 | scene_graph['global_cloud'], 90 | voxel_size) 91 | render_node_centers(scene_name+'/centroids', 92 | scene_graph['nodes']) 93 | 94 | if box: 95 | render_node_bboxes(scene_name+'/nodes', 96 | scene_graph['nodes'], 97 | show_labels=True) 98 | 99 | quad = R.from_matrix(origin[:3,:3]).as_quat() 100 | rr.log(scene_name+'/local_origin', 101 | rr.Transform3D(translation=origin[:3,3], 102 | quaternion=quad) 103 | ) 104 | 105 | def render_correspondences(entity_name:str, 106 | src_points:np.ndarray, 107 | ref_points:np.ndarray, 108 | transform:np.ndarray=None, 109 | gt_mask:np.ndarray=None, 110 | radius=0.01): 111 | 112 | N = src_points.shape[0] 113 | assert N==ref_points.shape[0], 'src and ref 
points should have the same number of points' 114 | line_points = [] 115 | line_colors = [] 116 | 117 | for i in range(N): 118 | src = src_points[i] 119 | ref = ref_points[i] 120 | if transform is not None: 121 | src = transform[:3,:3] @ src + transform[:3,3] 122 | 123 | if gt_mask[i]: 124 | line_colors.append([0,255,0]) 125 | else: 126 | line_colors.append([255,0,0]) 127 | 128 | line_points.append([src,ref]) 129 | 130 | 131 | line_points = np.concatenate(line_points,axis=0) 132 | line_points = line_points.reshape(-1,2,3) 133 | line_colors = np.array(line_colors) 134 | rr.log(entity_name, 135 | rr.LineStrips3D(line_points, 136 | radii=radius, 137 | colors=line_colors) 138 | ) 139 | 140 | def render_registration(entity_name:str, 141 | src_cloud:o3d.geometry.PointCloud, 142 | ref_cloud:o3d.geometry.PointCloud, 143 | transform:np.ndarray): 144 | 145 | src_cloud.transform(transform) 146 | src_points = np.asarray(src_cloud.points) 147 | ref_points = np.asarray(ref_cloud.points) 148 | src_color = [0,180,180] 149 | ref_color = [180,180,0] 150 | rr.log(entity_name+'/src', 151 | rr.Points3D(src_points, 152 | colors=src_color, 153 | radii=0.01) 154 | ) 155 | rr.log(entity_name+'/ref', 156 | rr.Points3D(ref_points, 157 | colors=ref_color, 158 | radii=0.01) 159 | ) 160 | 161 | 162 | def get_parser_args(): 163 | def float_list(string): 164 | return [float(x) for x in string.split(',')] 165 | 166 | parser = argparse.ArgumentParser(description='Visualize scene graph') 167 | parser.add_argument('--dataroot', type=str, required=True, 168 | help='path to dataset root') 169 | parser.add_argument('--src_scene', type=str, default='scene0108_00c', 170 | help='source scene name') 171 | parser.add_argument('--ref_scene', type=str, default='scene0108_00a', 172 | help='reference scene name') 173 | parser.add_argument('--result_folder', type=str, default='output/sgnet_scannet_0080', 174 | help='a relative path to the result folder') 175 | parser.add_argument('--viz_mode', type=int, 
required=True, 176 | help='0: no viz, 1: local viz, 2: remote viz, 3: save rrd') 177 | parser.add_argument('--remote_rerun_add', type=str, help='IP:PORT') 178 | parser.add_argument('--find_gt', action='store_true', 179 | help='align the scene graphs for better visualization') 180 | parser.add_argument('--augment_transform', action='store_true', 181 | help='only enable it for ScanNet scenes.') 182 | parser.add_argument('--viz_translation', type=json.loads, 183 | default='[0,0,0]', 184 | help='translation to viz the scene graphs') 185 | 186 | parser.add_argument('--voxel_size', type=float, default=0.05, 187 | help='voxel size for downsampling') 188 | 189 | return parser.parse_args() 190 | 191 | if __name__=='__main__': 192 | print('*'*60) 193 | print('This script reads the data association and registration results.') 194 | print('*'*60) 195 | 196 | ############ Args ############ 197 | args = get_parser_args() 198 | SPLIT = 'val' 199 | RESULT_FOLDER = osp(args.dataroot, 200 | args.result_folder, 201 | '{}-{}'.format(args.src_scene,args.ref_scene)) 202 | print('Visualize {}-{} scene graph'.format(args.src_scene,args.ref_scene)) 203 | ############################## 204 | 205 | # Load scene graphs 206 | src_sg = load_processed_scene_graph(osp(args.dataroot,SPLIT,args.src_scene)) 207 | ref_sg = load_processed_scene_graph(osp(args.dataroot,SPLIT,args.ref_scene)) 208 | src_cloud = o3d.geometry.PointCloud(src_sg['global_cloud'].points) 209 | 210 | # Load SG-Reg results 211 | if os.path.exists(RESULT_FOLDER): 212 | node_matches = read_pred_nodes(osp(RESULT_FOLDER,'node_matches.txt')) 213 | point_correspondences = read_dense_correspondences(osp(RESULT_FOLDER,'corr_src.ply'), 214 | osp(RESULT_FOLDER,'corr_ref.ply'), 215 | osp(RESULT_FOLDER,'point_matches.txt')) 216 | pred_transformation = np.loadtxt(osp(RESULT_FOLDER,'svds_estimation.txt')) 217 | 218 | # Transform for better visualization 219 | transform = np.eye(4) 220 | if args.find_gt: 221 | if 'ScanNet' in args.dataroot:
222 | assert False, 'ScanNet scenes are already aligned. Remove the option.' 223 | gt_dir = osp(args.dataroot, 'gt', '{}-{}.txt'.format(args.src_scene,args.ref_scene)) 224 | assert os.path.exists(gt_dir), 'gt file not found' 225 | gt = np.loadtxt(gt_dir) 226 | print('Loaded gt transformation from ',gt_dir) 227 | transform = gt 228 | 229 | transform[:3,3] += args.viz_translation 230 | transform_scene_graph(src_sg, transform) 231 | 232 | # Stream to rerun 233 | rr.init("SGReg") 234 | render_semantic_scene_graph('src',src_sg, 235 | args.voxel_size, 236 | transform) 237 | render_semantic_scene_graph('ref',ref_sg, 238 | args.voxel_size, 239 | np.eye(4), 240 | True) 241 | 242 | if args.augment_transform: 243 | if 'RioGraph' in args.dataroot: 244 | assert False, 'RIO dataset does not require augment transform. Remove the option.' 245 | drift_dir = osp(RESULT_FOLDER,'gt_transform.txt') 246 | assert os.path.exists(drift_dir), 'drift transform file not found' 247 | transform = np.loadtxt(drift_dir) 248 | src_cloud.transform(np.linalg.inv(transform)) 249 | transform[:3,3] += args.viz_translation 250 | 251 | if os.path.exists(RESULT_FOLDER): 252 | render_correspondences('node_matches', 253 | node_matches['src_centroids'], 254 | node_matches['ref_centroids'], 255 | transform=transform, 256 | gt_mask=node_matches['gt_mask']) 257 | render_correspondences('point_correspondences', 258 | point_correspondences['src_corrs'], 259 | point_correspondences['ref_corrs'], 260 | transform=transform, 261 | gt_mask=point_correspondences['corr_masks']) 262 | render_registration('registration', 263 | src_cloud, 264 | ref_sg['global_cloud'], 265 | pred_transformation) 266 | 267 | # Eval message 268 | msg = '{}/{} TP node matches, '.format(node_matches['gt_mask'].sum(), 269 | node_matches['gt_mask'].shape[0]) 270 | msg += 'Inlier ratio: {:.2f}%'.format(point_correspondences['corr_masks'].mean()*100) 271 | print(msg) 272 | 273 | # Render on rerun 274 | if args.viz_mode==1: 275 | rr.spawn() 276 |
elif args.viz_mode==2: 277 | assert args.remote_rerun_add is not None, \ 278 | 'a remote address is required for remote rendering (e.g. 143.89.38.169:9876)' 279 | print('--- Render rerun at a remote machine ',args.remote_rerun_add, '---') 280 | rr.connect_tcp(args.remote_rerun_add) 281 | elif args.viz_mode==3: 282 | rr.save(osp(RESULT_FOLDER,'result.rrd')) 283 | print('Save rerun data to ',osp(RESULT_FOLDER,'result.rrd')) 284 | else: 285 | print('No visualization') -------------------------------------------------------------------------------- /tutorials/explicit_sg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-Aerial-Robotics/SG-Reg/c164198cec84be11dc53101755b0d9f7a4bc5082/tutorials/explicit_sg.png -------------------------------------------------------------------------------- /tutorials/rag_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Using SG-Net encoded features for RAG\n", 8 | "In SG-Reg, we encode multiple modalities of each semantic node:\n", 9 | "- Local topology feature $\\mathbf{x}$.\n", 10 | "- Shape feature $\\mathbf{f}$.\n", 11 | "- Dense point feature $\\mathbf{z}$.\n", 12 | "\n", 13 | "\n", 14 | "

\n", 15 | " \n", 16 | "

\n", 17 | "\n", 18 | "We encoded these features on the ScanNet dataset. Use the following code to load them.\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "### Read implicit features\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 5, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "Load 20 features from /data2/ScanNetRag/val/scene0025_00c\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "import os\n", 43 | "import torch\n", 44 | "\n", 45 | "scene_dir = '/data2/ScanNetRag/val/scene0025_00c'\n", 46 | "feature_dict = torch.load(scene_dir + '/features.pth')\n", 47 | "x = feature_dict['x'] # (N,d)\n", 48 | "f = feature_dict['f'] # (N,ds)\n", 49 | "labels = feature_dict['labels'] # (N,)\n", 50 | "N = x.shape[0]\n", 51 | "\n", 52 | "print('Load {} features from {}'.format(N, scene_dir))\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "### Read tags from RAM" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 6, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "The scene has tags: armchair, black, bureau, cart, chair, couch, table, drawer, dresser, file cabinet, floor, nightstand, room, stool, waiting room, attach, board, bottle, wall, bulletin board, classroom, note, office supply, whiteboard, write, writing, magnet, white, carpet, closet, corner, pad, gray, hassock, living room, pillow, throw pillow, footrest, hospital room, air conditioner, blue, office, radiator, ceiling fan, door, fan, computer, doorway, electronic, equipment, floor fan, shelf, speaker, cabinet, computer desk, file, office chair, office desk, shelve, swivel chair, hide, bin, computer monitor, lamp, hang, mark, office building, computer screen, desktop, desktop computer, monitor, office cubicle, pen, \n", 72 |
"\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "# Read the comma-separated tags from the first line of tags.txt.\n", 78 | "# Use a distinct handle name so the shape features `f` loaded above are not overwritten.\n", 79 | "with open(os.path.join(scene_dir, 'tags.txt'), 'r') as tag_file:\n", 80 | "    tags = tag_file.readline()\n", 81 | "\n", 82 | "print('The scene has tags: {}'.format(tags))" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Visualize a scene (untested)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "import open3d as o3d\n", 99 | "from open3d.web_visualizer import draw\n", 100 | "\n", 101 | "explicit_scene_folder = '/data2/ScanNetGraph/val/scene0064_00c'\n", 102 | "\n", 103 | "instance_map_dir = os.path.join(explicit_scene_folder, '0002.ply')\n", 104 | "pcd = o3d.io.read_point_cloud(instance_map_dir)\n", 105 | "print('Load {} points from {}'.format(len(pcd.points), instance_map_dir))\n", 106 | "# draw([pcd, pcd])\n" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "sgnet", 113 | "language": "python", 114 | "name": "python3" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.9.18" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 2 131 | } 132 | --------------------------------------------------------------------------------
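In `visualize.py`, the source scene graph is drawn with one 4x4 matrix that combines an optional ground-truth alignment (`--find_gt`) with a display offset (`--viz_translation`) added to the translation column, and the `--augment_transform` branch undoes a drift by applying its inverse. Below is a minimal NumPy sketch of that composition. It assumes points are stored as an (N, 3) array. `make_viz_transform` and `apply_transform` are hypothetical helpers for illustration, not functions from the SG-Reg code base.

```python
import numpy as np

def make_viz_transform(base: np.ndarray, offset) -> np.ndarray:
    """Copy a 4x4 rigid transform and add a display offset to its translation.

    Hypothetical helper: mirrors `transform[:3,3] += args.viz_translation`,
    but works on a copy so the caller's matrix is left untouched.
    """
    T = np.array(base, dtype=float, copy=True)
    T[:3, 3] += np.asarray(offset, dtype=float)
    return T

def apply_transform(points: np.ndarray, T: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous transform to an (N, 3) array: p' = R p + t."""
    return points @ T[:3, :3].T + T[:3, 3]

# A 90-degree yaw stands in for the ground-truth alignment (assumed example data).
theta = np.pi / 2
gt = np.eye(4)
gt[:3, :3] = [[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]]

# Shift the aligned source 2 m along x so it does not overlap the reference.
viz_T = make_viz_transform(gt, [2.0, 0.0, 0.0])

src_points = np.array([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.5]])
shown = apply_transform(src_points, viz_T)

# Undoing a transform with its inverse recovers the original points.
restored = apply_transform(apply_transform(src_points, gt), np.linalg.inv(gt))
```

Because the offset is added to the translation column, it is applied after the rotation, i.e. in the reference frame; the rotated source is simply slid sideways for display.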
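Judging from the printed notebook output, `tags.txt` appears to hold all tags on a single comma-separated line that ends in a trailing separator. The sketch below turns such a line into a clean Python list and keeps multi-word tags such as "file cabinet" intact. It assumes that format holds for every scene. `parse_tags` is a hypothetical helper, not part of the repository.

```python
def parse_tags(line: str) -> list[str]:
    """Split a comma-separated tag line into a list of non-empty, trimmed tags.

    Hypothetical helper; assumes the single-line, comma-separated format
    suggested by the notebook's printed output.
    """
    return [tag.strip() for tag in line.split(',') if tag.strip()]

# Example line in the same shape as the notebook output (trailing ", " and newline).
example_line = 'armchair, file cabinet, whiteboard, \n'
tags = parse_tags(example_line)
```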
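The notebook loads the per-node features `x` (N, d) and `f` (N, ds) but stops before using them for retrieval. One common retrieval approach is nearest-neighbour search by cosine similarity. The sketch below assumes the tensors have first been converted to NumPy (for example `x.detach().cpu().numpy()`) and that cosine similarity is a sensible metric for these features, which the source does not confirm. `top_k_nodes` is hypothetical, and the data are random stand-ins rather than SG-Reg output.

```python
import numpy as np

def top_k_nodes(query: np.ndarray, features: np.ndarray, k: int = 3) -> np.ndarray:
    """Return indices of the k rows of `features` most cosine-similar to `query`.

    features: (N, d) array of node features.
    query:    (d,) feature vector in the same space.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    feats = features / np.maximum(norms, 1e-12)           # unit-length rows; zero rows stay zero
    q = query / max(float(np.linalg.norm(query)), 1e-12)   # unit-length query
    scores = feats @ q                                     # (N,) cosine similarities
    k = min(k, scores.shape[0])                            # never ask for more nodes than exist
    return np.argsort(-scores)[:k]                         # most similar first

# Random stand-in for a (20, 8) feature matrix; not real SG-Reg features.
rng = np.random.default_rng(0)
toy_features = rng.standard_normal((20, 8))
query = toy_features[5] + 0.01 * rng.standard_normal(8)   # slightly perturbed copy of node 5
neighbours = top_k_nodes(query, toy_features, k=3)
```

Normalising both sides turns the dot product into cosine similarity, so nodes are ranked by feature direction rather than magnitude.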