├── .gitignore ├── README.md ├── assets ├── architecture.png ├── attention_diagram2.png ├── box_plot_easy.png ├── box_plot_hard.png ├── lc.png ├── local_correspondences4.png └── no_lc.png ├── requirements.txt └── src ├── __init__.py ├── config ├── model.yaml └── train.yaml ├── data ├── __init__.py ├── augmentation.py ├── dataset_utils.py ├── datasets │ ├── __init__.py │ ├── alita │ │ ├── alita_raw.py │ │ ├── generate_evaluation_sets.py │ │ └── test_val_5_0.01_5.pickle │ ├── augmentation.py │ ├── base_datasets.py │ ├── dataset_utils.py │ ├── kitti │ │ ├── generate_evaluation_sets.py │ │ ├── kitti_00_eval.pickle │ │ ├── kitti_raw.py │ │ └── utils.py │ ├── kitti360 │ │ ├── generate_evaluation_sets.py │ │ ├── kitti360_09_3.0_eval.pickle │ │ ├── kitti360_raw.py │ │ └── utils.py │ ├── mulran │ │ ├── generate_evaluation_sets.py │ │ ├── generate_training_tuples.py │ │ ├── mulran_raw.py │ │ ├── mulran_train.py │ │ ├── test_DCC1_DCC2_10.0_5.pickle │ │ ├── test_Sejong1_Sejong2_0.2_20.pickle │ │ ├── test_Sejong1_Sejong2_0.2_5.pickle │ │ └── utils.py │ ├── point_clouds_utils.py │ ├── poses_utils.py │ ├── quantization.py │ ├── samplers.py │ └── southbay │ │ ├── generate_evaluation_sets.py │ │ ├── generate_training_tuples.py │ │ ├── pypcd.py │ │ ├── southbay_raw.py │ │ ├── southbay_raw_old.py │ │ └── test_SunnyvaleBigloop_1.0_5_20m.pickle └── sejong_southbay.py ├── evaluate ├── SALSA │ ├── eval_salsa_sgv.py │ └── sgv_utils.py └── pca.py ├── loss ├── __init__.py ├── global_loss.py ├── local_consistency_loss.py └── loss.py ├── misc ├── point_clouds.py ├── poses.py ├── robot_trans.py └── utils.py ├── models ├── Mixer │ ├── __init__.py │ └── mixer.py ├── SphereFormer │ ├── SparseTransformer │ │ ├── .gitignore │ │ ├── README.md │ │ ├── __init__.py │ │ ├── license │ │ ├── setup.py │ │ ├── sptr │ │ │ ├── __init__.py │ │ │ ├── functional.py │ │ │ ├── modules.py │ │ │ ├── position_embedding.py │ │ │ └── utils.py │ │ ├── src │ │ │ └── sptr │ │ │ │ ├── __init__.py │ │ │ │ ├── attention │ │ │ │ ├── attention_cuda.cpp │ │ │ │ ├── attention_cuda_kernel.cu │ │ │ │ └── attention_cuda_kernel.h │ │ │ │ ├── cuda_utils.h │ │ │ │ ├── pointops_api.cpp │ │ │ │ ├── precompute │ │ │ │ ├── precompute.cpp │ │ │ │ ├── precompute_cuda_kernel.cu │ │ │ │ └── precompute_cuda_kernel.h │ │ │ │ └── rpe │ │ │ │ ├── relative_pos_encoding_cuda.cpp │ │ │ │ ├── relative_pos_encoding_cuda_kernel.cu │ │ │ │ └── relative_pos_encoding_cuda_kernel.h │ │ └── test │ │ │ ├── pointops.py │ │ │ ├── test_attention_op_step1.py │ │ │ ├── test_attention_op_step2.py │ │ │ ├── test_precompute_all.py │ │ │ ├── test_relative_pos_encoding_op_step1.py │ │ │ ├── test_relative_pos_encoding_op_step1_all.py │ │ │ └── test_relative_pos_encoding_op_step2.py │ └── model │ │ ├── spherical_transformer.py │ │ └── unet_spherical_transformer.py ├── __init__.py ├── adappool.py ├── pca_model.py └── salsa.py ├── train.py └── utils ├── __init__.py ├── misc_utils.py └── o3d_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore large files 2 | src/checkpoints/ 3 | *.pth 4 | *__pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SALSA: Swift Adaptive Lightweight Self-Attention for Enhanced LiDAR Place Recognition 2 | 3 | 📖 Paper: [`RA-L`](https://ieeexplore.ieee.org/document/10629049) 4 | 5 | 📖 Pre-print: [``arXiv``](https://arxiv.org/abs/2407.08260) 6 | 7 | 📹 Video: 
[`Youtube`](https://www.youtube.com/watch?v=JLunemW91bQ) 8 | 9 | #### Authors: Raktim Gautam Goswami, Naman Patel, Prashanth Krishnamurthy, Farshad Khorrami 10 | 11 | #### Control/Robotics Research Laboratory (CRRL), Department of Electrical and Computer Engineering, NYU Tandon School of Engineering 12 | 13 | ### 💡 Contributions 14 | - **SALSA**: A novel, lightweight, and efficient framework for LiDAR place recognition that delivers state-of-the-art performance while maintaining real-time operational capabilities. 15 | - **SphereFormer**: Utilized for local descriptor extraction with radial and cubic window attention to boost localization performance for sparse distant points. 16 | - **Adaptive Pooling**: A self-attention adaptive pooling module to fuse local descriptors into global tokens. It can aggregate arbitrary numbers of points in a point cloud without pre-processing. 17 | 18 | - **MLP Mixer Token Aggregator**: An MLP mixer-based aggregator to iteratively incorporate global context information to generate a robust global scene descriptor. 19 | 20 | 21 |
Fig. 1: Overview of our SALSA framework to generate scene descriptors from point clouds for place recognition.
25 | 26 | 27 | ### 🔨 Environment creation 28 | 29 | ```bash 30 | conda create --name salsa python=3.10.11 31 | conda activate salsa 32 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 33 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html 34 | pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html 35 | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+cu117.html 36 | pip install -r requirements.txt 37 | pip install --no-deps timm==0.9.7 38 | ``` 39 | Install [SpTr](https://github.com/dvlab-research/SparseTransformer) from source. 40 | 41 | 42 | ### 📊💾 Dataset Download 43 | The model is trained on Mulran Sejong 01/02 sequences and Apollo Southbay (excluding Sunnyvale). Evaluation is performed on 'easy set': Apollo-Southbay (Sunnyvale), SemanticKITTI, Mulran Sejong, and 'hard set': Mulran DCC1/DCC2, KITTI360, ALITA. The datasets can be downloaded from the following links. 44 | - [MulRan](https://sites.google.com/view/mulran-pr/download) dataset: ground truth data (*.csv) and LiDAR point clouds (Ouster.zip). 45 | - [Apollo-Southbay](https://developer.apollo.auto/southbay.html) dataset. 46 | - [SemanticKITTI](http://semantic-kitti.org/dataset.html#download) dataset (velodyne point clouds and calibration data for poses). 47 | - [ALITA](https://github.com/MetaSLAM/ALITA) dataset. 48 | - [KITTI-360](https://www.cvlibs.net/datasets/kitti-360/user_login.php) dataset (raw velodyne scans, calibrations and vehicle poses). 49 | 50 | 51 | ### 📊💾 Dataset Pickle Creation 52 | 53 | Create Training Pickle 54 | 55 | ```bash 56 | cd src/data/datasets/ 57 | python southbay/generate_training_tuples.py --dataset_root 58 | python mulran/generate_training_tuples.py --dataset_root 59 | ``` 60 | 61 | Create Evaluation Pickle 62 | ```bash 63 | python mulran/generate_evaluation_sets.py --dataset_root , --sequence sejong 64 | python mulran/generate_evaluation_sets.py --dataset_root , --sequence mulran 65 | python southbay/generate_evaluation_sets.py --dataset_root 66 | python kitti/generate_evaluation_sets.py --dataset_root 67 | python kitti360/generate_evaluation_sets.py --dataset_root 68 | python alita/generate_evaluation_sets.py --dataset_root 69 | ``` 70 | 71 | ### ✈️ Training 72 | Navigate to the base, create a folder inside src named checkpoints to save the trained models directory and start training. 73 | To change the model and training parameters, change them in config/model.yaml and config/train.yaml, respectively. 74 | ```bash 75 | mkdir -p src/checkpoints/SALSA/Model src/checkpoints/SALSA/PCA 76 | python src/train.py 77 | ``` 78 | This will train the model on the generated training dataset and store the saved models for each epoch in src/checkpoints. 79 | 80 | ### ✈️ PCA 81 | Learn PCA using trained model 82 | ```bash 83 | python src/evaluate/pca.py 84 | ``` 85 | This will learn a PCA to compress and decorrelate the scene descriptor and store the learned PCA as a pytorch model in src/checkpoints. 86 | 87 | ### ✈️ Evaluation 88 | ```bash 89 | python src/evaluate/SALSA/eval_salsa_sgv.py --dataset_root --dataset_type --only_global True 90 | ``` 91 | Our pre-trained models can also be downloaded from this [link](https://drive.google.com/drive/folders/1lehq0Hki75i7U_Twhd5uxxz37WvcRzGa?usp=sharing). After downloading, copy the contents into the 'src/checkpoints' directory. 
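Before running evaluation, the generated pickles can be sanity-checked directly: each evaluation pickle stores a query list and a map list of `(timestamp, relative scan path, (x, y) position, 4x4 pose)` tuples, as written by `EvaluationSet.save` in `src/data/datasets/base_datasets.py`. Below is a minimal sketch; the pickle path is only an example (the KITTI pickle shipped with the repository), so adjust it to wherever your pickles were generated.
```python
import pickle

# Example path; point this at any evaluation pickle produced by the scripts above.
pickle_path = 'src/data/datasets/kitti/kitti_00_eval.pickle'

with open(pickle_path, 'rb') as f:
    query_list, map_list = pickle.load(f)  # saved as [query_l, map_l]

print(f'{len(map_list)} map (database) elements, {len(query_list)} query elements')

# Each element is (timestamp, relative scan filepath, (x, y) position, 4x4 pose).
timestamp, rel_scan_filepath, position, pose = map_list[0]
print('First map scan:', rel_scan_filepath, 'position:', position)
```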
92 | 93 | 94 | ### 📝 Results 95 | The spread of Recall@1 before and after re-ranking for the best-performing models is plotted in the figure below. 96 |
Fig. 2a: 'Easy' Dataset. Fig. 2b: 'Hard' Dataset.
Fig. 2: Box plot displaying Recall@1 across six datasets, with first to third quartile spans, whiskers for data variability, and internal lines as medians.
107 | 108 | 109 | ### 🌈 Visualizations 110 |
Fig. 3: Visualization of areas attended to by different tokens from the adaptive pooling layer. Each token focuses on different geometries: trees and traffic signs (green), road intersections (red), and distant points (blue).
Fig. 4: Point matches between query and target clouds using LoGG3D-Net and SALSA local descriptors. Matching colors indicate correspondences; circles highlight SALSA’s superior performance on sparse distant points.
Fig. 5a: Without Loop Detection. Fig. 5b: With Loop Detection.
Fig. 5: Comparison of LiDAR-only odometry and maps: (a) without loop detection, and (b) after online pose graph optimization from SALSA loop detections. The highlighted rectangles emphasize the map and odometry disparities due to loop closures.
135 | 136 | ## 📧 Citation 137 | 138 | If you find our work useful in your research please consider citing our publication: 139 | ```bibtex 140 | @article{goswami2024salsa, 141 | title={SALSA: Swift Adaptive Lightweight Self-Attention for Enhanced LiDAR Place Recognition}, 142 | author={Goswami, Raktim Gautam and Patel, Naman and Krishnamurthy, Prashanth and Khorrami, Farshad}, 143 | journal={IEEE Robotics and Automation Letters}, 144 | year={2024}, 145 | } 146 | ``` -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/architecture.png -------------------------------------------------------------------------------- /assets/attention_diagram2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/attention_diagram2.png -------------------------------------------------------------------------------- /assets/box_plot_easy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/box_plot_easy.png -------------------------------------------------------------------------------- /assets/box_plot_hard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/box_plot_hard.png -------------------------------------------------------------------------------- /assets/lc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/lc.png -------------------------------------------------------------------------------- /assets/local_correspondences4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/local_correspondences4.png -------------------------------------------------------------------------------- /assets/no_lc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/no_lc.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch_geometric 2 | spconv-cu117 3 | cumm-cu117 4 | tensorboard 5 | open3d 6 | python-lzf 7 | faiss-gpu -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/__init__.py -------------------------------------------------------------------------------- /src/config/model.yaml: -------------------------------------------------------------------------------- 1 | feat_extractor: 2 | feature_dim: 16 3 | patch_size: 1 4 | input_c: 3 5 | m: 32 6 | block_reps: 2 7 | layers: [32, 64, 128] 8 | window_size_sphere: [2, 2, 80] 9 | drop_path_rate: 0.3 10 | 
window_size_scale: [2.0, 1.5] 11 | sphere_layers: [1,2,3,4,5] 12 | a: 0.0125 13 | 14 | aggregator: 15 | tokens: 512 16 | out_channels: 128 17 | mix_depth: 4 18 | mlp_ratio: 1 19 | out_d: 4 20 | -------------------------------------------------------------------------------- /src/config/train.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 12 2 | writer_loc: 'runs/SALSA' 3 | cached_queries: 1000 4 | device: 'cuda' 5 | max_epoch: 50 6 | lr: 0.001 7 | max_lr: 0.005 8 | outdim: 512 -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fnv_hash_vec(arr): 5 | """ 6 | FNV64-1A 7 | """ 8 | assert arr.ndim == 2 9 | # Floor first for negative coordinates 10 | arr = arr.copy() 11 | arr = arr.astype(np.uint64, copy=False) 12 | hashed_arr = np.uint64(14695981039346656037) * np.ones(arr.shape[0], dtype=np.uint64) 13 | for j in range(arr.shape[1]): 14 | hashed_arr *= np.uint64(1099511628211) 15 | hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) 16 | return hashed_arr 17 | 18 | 19 | def ravel_hash_vec(arr): 20 | """ 21 | Ravel the coordinates after subtracting the min coordinates. 22 | """ 23 | assert arr.ndim == 2 24 | arr = arr.copy() 25 | arr -= arr.min(0) 26 | arr = arr.astype(np.uint64, copy=False) 27 | arr_max = arr.max(0).astype(np.uint64) + 1 28 | 29 | keys = np.zeros(arr.shape[0], dtype=np.uint64) 30 | # Fortran style indexing 31 | for j in range(arr.shape[1] - 1): 32 | keys += arr[:, j] 33 | keys *= arr_max[j + 1] 34 | keys += arr[:, -1] 35 | return keys 36 | 37 | 38 | def voxelize(coord, voxel_size=0.05, hash_type='fnv', mode=0): 39 | discrete_coord = np.floor(coord / np.array(voxel_size)) 40 | if hash_type == 'ravel': 41 | key = ravel_hash_vec(discrete_coord) 42 | else: 43 | key = fnv_hash_vec(discrete_coord) 44 | 45 | idx_sort = np.argsort(key) 46 | key_sort = key[idx_sort] 47 | _, count = np.unique(key_sort, return_counts=True) 48 | idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count 49 | idx_unique = idx_sort[idx_select] 50 | return idx_unique 51 | -------------------------------------------------------------------------------- /src/data/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/__init__.py -------------------------------------------------------------------------------- /src/data/datasets/alita/alita_raw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 8 | from datasets.point_clouds_utils import PointCloudLoader 9 | 10 | 11 | class ALITAPointCloudLoader(PointCloudLoader): 12 | def set_properties(self): 13 | # Set point cloud propertiers, such as ground_plane_level. Must be defined in inherited classes. 
14 | self.ground_plane_level = -1.6 15 | 16 | def read_pc(self, file_pathname): 17 | pcd = o3d.io.read_point_cloud(file_pathname) 18 | xyz = np.asarray(pcd.points) 19 | return xyz 20 | -------------------------------------------------------------------------------- /src/data/datasets/alita/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Generate evaluation sets 2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/southbay/generate_evaluation_sets.py 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import sys 8 | from typing import List 9 | 10 | import numpy as np 11 | from scipy.spatial.transform import Rotation as R 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 14 | 15 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements 16 | 17 | 18 | def get_pose_transform(pose6d): 19 | rot_matrix = R.from_euler('xyz', pose6d[3:]).as_matrix() 20 | trans_vector = pose6d[:3].reshape((3, 1)) 21 | 22 | trans_matrix = np.identity(4) 23 | trans_matrix[:3, :3] = rot_matrix 24 | trans_matrix[:3, 3:] = trans_vector 25 | 26 | return trans_matrix 27 | 28 | def get_length(poses): 29 | current_pos = poses[0][:2,3] 30 | total = 0 31 | for i in range(len(poses)-1): 32 | delta = np.linalg.norm(poses[i][:2,3] - poses[i+1][:2,3]) 33 | print(delta) 34 | total += delta 35 | print('') 36 | 37 | def get_poses_ugv(pose_files): 38 | poses = [] 39 | timestamps = [] 40 | for f in pose_files: 41 | pose = np.load(f) 42 | timestamps.append(pose[6]) 43 | T = get_pose_transform(pose[:6]) 44 | poses.append(T) 45 | return np.asarray(timestamps), np.asarray(poses) 46 | 47 | def get_scans(base_dir, area, split, min_displacement: float = 0.0) -> List[EvaluationTuple]: 48 | 49 | operating_dir = os.path.join(base_dir, split, area) 50 | pcd_files = sorted(glob.glob(os.path.join(operating_dir, '*.pcd'))) 51 | pose_files = sorted(glob.glob(os.path.join(operating_dir, '*.npy'))) 52 | timestamps, poses = get_poses_ugv(pose_files) 53 | get_length(poses) 54 | 55 | 56 | elems = [] 57 | for ndx in range(len(pcd_files)): 58 | pose = poses[ndx] 59 | position = pose[0:2, 3] # (x, y) position in global coordinate frame 60 | rel_scan_filepath = pcd_files[ndx][len(base_dir)+1:] 61 | print(rel_scan_filepath) 62 | timestamp = timestamps[ndx] 63 | 64 | item = EvaluationTuple(timestamp, rel_scan_filepath, position=position, pose=pose) 65 | elems.append(item) 66 | 67 | print(f"{len(elems)} total elements in {split} split") 68 | 69 | # Filter-out elements leaving only 1 per grid cell with min_displacement size 70 | pos = np.zeros((len(elems), 2), dtype=np.float32) 71 | for ndx, e in enumerate(elems): 72 | pos[ndx] = e.position 73 | 74 | # Quantize x-y coordinates. 
Quantized coords start from 0 75 | pos = np.floor(pos / min_displacement) 76 | pos = pos.astype(int) 77 | _, unique_ndx = np.unique(pos, axis=0, return_index=True) 78 | 79 | # Leave only unique elements 80 | elems = [elems[i] for i in unique_ndx] 81 | print(f"{len(elems)} filtered elements in {split} split with grid cell size = {min_displacement}") 82 | 83 | return elems 84 | 85 | 86 | def generate_evaluation_set(dataset_root: str, area: str, min_displacement: float = 0.0, dist_threshold=5) -> \ 87 | EvaluationSet: 88 | map_set = get_scans(dataset_root, area, 'DATABASE', min_displacement) 89 | query_set = get_scans(dataset_root, area, 'QUERY', min_displacement) 90 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 91 | print(f'Area: {area} - {len(map_set)} database elements, {len(query_set)} query elements\n') 92 | return EvaluationSet(query_set, map_set) 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser(description='Generate evaluation sets for UGV dataset') 97 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/ALITA/VAL') 98 | parser.add_argument('--min_displacement', type=float, default=0.01) 99 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 100 | parser.add_argument('--dist_threshold', type=float, default=5) 101 | 102 | args = parser.parse_args() 103 | print(f'Dataset root: {args.dataset_root}') 104 | print(f'Minimum displacement between scans in each set (map/query): {args.min_displacement}') 105 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 106 | 107 | area = 'val_5' # Evaluation area 108 | eval_set = generate_evaluation_set(dataset_root=args.dataset_root,area=area, min_displacement=args.min_displacement, 109 | dist_threshold=args.dist_threshold) 110 | pickle_name = f'test_{area}_{args.min_displacement}_{args.dist_threshold}.pickle' 111 | # file_path_name = os.path.join(args.dataset_root, pickle_name) 112 | file_path_name = os.path.join(os.path.dirname(__file__), pickle_name) 113 | print(f"Saving evaluation pickle: {file_path_name}") 114 | eval_set.save(file_path_name) 115 | -------------------------------------------------------------------------------- /src/data/datasets/alita/test_val_5_0.01_5.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/alita/test_val_5_0.01_5.pickle -------------------------------------------------------------------------------- /src/data/datasets/base_datasets.py: -------------------------------------------------------------------------------- 1 | # Base dataset classes, inherited by dataset-specific classes 2 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/base_datasets.py 3 | 4 | import os 5 | import sys 6 | import pickle 7 | from typing import List, Dict 8 | import torch 9 | import numpy as np 10 | from sklearn.neighbors import KDTree 11 | from torch.utils.data import Dataset 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 14 | from datasets.alita.alita_raw import ALITAPointCloudLoader 15 | from datasets.kitti.kitti_raw import KittiPointCloudLoader 16 | from datasets.mulran.mulran_raw import MulranPointCloudLoader 17 | from datasets.southbay.southbay_raw import SouthbayPointCloudLoader 18 | from datasets.kitti360.kitti360_raw 
import Kitti360PointCloudLoader 19 | from datasets.point_clouds_utils import PointCloudLoader 20 | 21 | 22 | class TrainingTuple: 23 | # Tuple describing an element for training/validation 24 | def __init__(self, id: int, timestamp: int, rel_scan_filepath: str, positives: np.ndarray, 25 | non_negatives: np.ndarray, pose: np, positives_poses: Dict[int, np.ndarray] = None): 26 | # id: element id (ids start from 0 and are consecutive numbers) 27 | # ts: timestamp 28 | # rel_scan_filepath: relative path to the scan 29 | # positives: sorted ndarray of positive elements id 30 | # negatives: sorted ndarray of elements id 31 | # pose: pose as 4x4 matrix 32 | # positives_poses: relative poses of positive examples refined using ICP 33 | self.id = id 34 | self.timestamp = timestamp 35 | self.rel_scan_filepath = rel_scan_filepath 36 | self.positives = positives 37 | self.non_negatives = non_negatives 38 | self.pose = pose 39 | self.positives_poses = positives_poses 40 | 41 | 42 | class EvaluationTuple: 43 | # Tuple describing an evaluation set element 44 | def __init__(self, timestamp: int, rel_scan_filepath: str, position: np.array, pose: np.array = None): 45 | # position: x, y position in meters 46 | # pose: 6 DoF pose (as 4x4 pose matrix) 47 | assert position.shape == (2,) 48 | assert pose is None or pose.shape == (4, 4) 49 | self.timestamp = timestamp 50 | self.rel_scan_filepath = rel_scan_filepath 51 | self.position = position 52 | self.pose = pose 53 | 54 | def to_tuple(self): 55 | return self.timestamp, self.rel_scan_filepath, self.position, self.pose 56 | 57 | 58 | class TrainingDataset(Dataset): 59 | def __init__(self, dataset_path: str, dataset_type: str, query_filename: str, transform=None, set_transform=None): 60 | # remove_zero_points: remove points with all zero coords 61 | assert os.path.exists(dataset_path), 'Cannot access dataset path: {}'.format(dataset_path) 62 | self.dataset_path = dataset_path 63 | self.dataset_type = dataset_type 64 | self.query_filepath = os.path.join(dataset_path, query_filename) 65 | assert os.path.exists(self.query_filepath), 'Cannot access query file: {}'.format(self.query_filepath) 66 | self.transform = transform 67 | self.set_transform = set_transform 68 | self.queries: Dict[int, TrainingTuple] = pickle.load(open(self.query_filepath, 'rb')) 69 | print('{} queries in the dataset'.format(len(self))) 70 | 71 | # pc_loader must be set in the inheriting class 72 | self.pc_loader = get_pointcloud_loader(self.dataset_type) 73 | 74 | def __len__(self): 75 | return len(self.queries) 76 | 77 | def __getitem__(self, ndx): 78 | # Load point cloud and apply transform 79 | file_pathname = os.path.join(self.dataset_path, self.queries[ndx].rel_scan_filepath) 80 | query_pc = self.pc_loader(file_pathname) 81 | query_pc = torch.tensor(query_pc, dtype=torch.float) 82 | if self.transform is not None: 83 | query_pc = self.transform(query_pc) 84 | return query_pc, ndx 85 | 86 | def get_positives(self, ndx): 87 | return self.queries[ndx].positives 88 | 89 | def get_non_negatives(self, ndx): 90 | return self.queries[ndx].non_negatives 91 | 92 | 93 | class EvaluationSet: 94 | # Evaluation set consisting of map and query elements 95 | def __init__(self, query_set: List[EvaluationTuple] = None, map_set: List[EvaluationTuple] = None): 96 | self.query_set = query_set 97 | self.map_set = map_set 98 | 99 | def save(self, pickle_filepath: str): 100 | # Pickle the evaluation set 101 | 102 | # Convert data to tuples and save as tuples 103 | query_l = [] 104 | for e in self.query_set: 105 | 
query_l.append(e.to_tuple()) 106 | 107 | map_l = [] 108 | for e in self.map_set: 109 | map_l.append(e.to_tuple()) 110 | pickle.dump([query_l, map_l], open(pickle_filepath, 'wb')) 111 | 112 | def load(self, pickle_filepath: str): 113 | # Load evaluation set from the pickle 114 | query_l, map_l = pickle.load(open(pickle_filepath, 'rb')) 115 | 116 | self.query_set = [] 117 | for e in query_l: 118 | self.query_set.append(EvaluationTuple(e[0], e[1], e[2], e[3])) 119 | 120 | self.map_set = [] 121 | for e in map_l: 122 | self.map_set.append(EvaluationTuple(e[0], e[1], e[2], e[3])) 123 | 124 | def get_map_positions(self): 125 | # Get map positions as (N, 2) array 126 | positions = np.zeros((len(self.map_set), 2), dtype=self.map_set[0].position.dtype) 127 | for ndx, pos in enumerate(self.map_set): 128 | positions[ndx] = pos.position 129 | return positions 130 | 131 | def get_query_positions(self): 132 | # Get query positions as (N, 2) array 133 | positions = np.zeros((len(self.query_set), 2), dtype=self.query_set[0].position.dtype) 134 | for ndx, pos in enumerate(self.query_set): 135 | positions[ndx] = pos.position 136 | return positions 137 | 138 | def filter_query_elements(query_set: List[EvaluationTuple], map_set: List[EvaluationTuple], 139 | dist_threshold: float) -> List[EvaluationTuple]: 140 | # Function used in evaluation dataset generation 141 | # Filters out query elements without a corresponding map element within dist_threshold threshold 142 | map_pos = np.zeros((len(map_set), 2), dtype=np.float32) 143 | for ndx, e in enumerate(map_set): 144 | map_pos[ndx] = e.position 145 | 146 | # Build a kdtree 147 | kdtree = KDTree(map_pos) 148 | 149 | filtered_query_set = [] 150 | count_ignored = 0 151 | for ndx, e in enumerate(query_set): 152 | position = e.position.reshape(1, -1) 153 | nn = kdtree.query_radius(position, dist_threshold, count_only=True)[0] 154 | if nn > 0: 155 | filtered_query_set.append(e) 156 | else: 157 | count_ignored += 1 158 | 159 | print(f"{count_ignored} query elements ignored - not having corresponding map element within {dist_threshold} [m] radius") 160 | return filtered_query_set 161 | 162 | def get_pointcloud_loader(dataset_type) -> PointCloudLoader: 163 | if dataset_type == 'mulran': 164 | return MulranPointCloudLoader() 165 | elif dataset_type == 'southbay': 166 | return SouthbayPointCloudLoader() 167 | elif dataset_type == 'kitti': 168 | return KittiPointCloudLoader() 169 | elif dataset_type == 'alita': 170 | return ALITAPointCloudLoader() 171 | elif dataset_type == 'kitti360': 172 | return Kitti360PointCloudLoader() 173 | else: 174 | raise NotImplementedError(f"Unsupported dataset type: {dataset_type}") -------------------------------------------------------------------------------- /src/data/datasets/kitti/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Test set for Kitti Sequence 00 dataset. 
2 | # Following procedures in [cite papers Kitti for place reco] we use 170 seconds of drive from sequence for map generation 3 | # and the rest is left for queries 4 | 5 | import numpy as np 6 | import argparse 7 | from typing import List 8 | import os 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 11 | 12 | from datasets.kitti.kitti_raw import KittiSequence 13 | from datasets.base_datasets import EvaluationTuple, EvaluationSet 14 | from datasets.dataset_utils import filter_query_elements 15 | 16 | 17 | MAP_TIMERANGE = (0, 170) 18 | 19 | 20 | def get_scans(sequence: KittiSequence, min_displacement: float = 0.1, ts_range: tuple = None) -> List[EvaluationTuple]: 21 | # Get a list of all point clouds from the sequence (the full sequence or test split only) 22 | 23 | elems = [] 24 | old_pos = None 25 | count_skipped = 0 26 | displacements = [] 27 | 28 | for ndx in range(len(sequence)): 29 | if ts_range is not None: 30 | if (ts_range[0] > sequence.rel_lidar_timestamps[ndx]) or (ts_range[1] < sequence.rel_lidar_timestamps[ndx]): 31 | continue 32 | pose = sequence.lidar_poses[ndx] 33 | # Kitti poses are in camera coordinates system where where y is upper axis dim 34 | position = pose[[0,2], 3] 35 | 36 | if old_pos is not None: 37 | displacements.append(np.linalg.norm(old_pos - position)) 38 | 39 | if np.linalg.norm(old_pos - position) < min_displacement: 40 | # Ignore the point cloud if the vehicle didn't move 41 | count_skipped += 1 42 | continue 43 | 44 | item = EvaluationTuple(sequence.rel_lidar_timestamps[ndx], sequence.rel_scan_filepath[ndx], position, pose) 45 | elems.append(item) 46 | old_pos = position 47 | 48 | print(f'{count_skipped} clouds skipped due to displacement smaller than {min_displacement}') 49 | print(f'mean displacement {np.mean(np.array(displacements))}') 50 | return elems 51 | 52 | 53 | def generate_evaluation_set(dataset_root: str, map_sequence: str, min_displacement: float = 0.1, 54 | dist_threshold: float = 5.) -> EvaluationSet: 55 | 56 | sequence = KittiSequence(dataset_root, map_sequence) 57 | 58 | map_set = get_scans(sequence, min_displacement, MAP_TIMERANGE) 59 | query_set = get_scans(sequence, min_displacement, (MAP_TIMERANGE[-1], sequence.rel_lidar_timestamps[-1])) 60 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 61 | print(f'{len(map_set)} database elements, {len(query_set)} query elements') 62 | return EvaluationSet(query_set, map_set) 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser(description='Generate evaluation sets for KItti dataset') 67 | parser.add_argument('--dataset_root', type=str, required=True) 68 | parser.add_argument('--min_displacement', type=float, default=0.1) 69 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 70 | parser.add_argument('--dist_threshold', type=float, default=5.) 
71 | 72 | args = parser.parse_args() 73 | 74 | # Sequences are fixed 75 | sequence = '00' 76 | print(f'Dataset root: {args.dataset_root}') 77 | print(f'Kitti sequence: {sequence}') 78 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 79 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 80 | 81 | kitti_eval_set = generate_evaluation_set(args.dataset_root, sequence, min_displacement=args.min_displacement, 82 | dist_threshold=args.dist_threshold) 83 | file_path_name = os.path.join(args.dataset_root, f'kitti_{sequence}_eval.pickle') 84 | print(f"Saving evaluation pickle: {file_path_name}") 85 | kitti_eval_set.save(file_path_name) 86 | -------------------------------------------------------------------------------- /src/data/datasets/kitti/kitti_00_eval.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/kitti/kitti_00_eval.pickle -------------------------------------------------------------------------------- /src/data/datasets/kitti/kitti_raw.py: -------------------------------------------------------------------------------- 1 | # Functions and classes operating on a raw Kitti dataset 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | sys.path.append('../../..') 11 | 12 | from misc.point_clouds import PointCloudLoader 13 | 14 | 15 | class KittiPointCloudLoader(PointCloudLoader): 16 | def set_properties(self): 17 | # Set point cloud properties, such as ground_plane_level. 18 | self.ground_plane_level = -1.5 19 | 20 | def read_pc(self, file_pathname: str) -> torch.Tensor: 21 | # Reads the point cloud without pre-processing 22 | # Returns Nx3 tensor 23 | pc = np.fromfile(file_pathname, dtype=np.float32) 24 | # PC in Mulran is of size [num_points, 4] -> x,y,z,reflectance 25 | pc = np.reshape(pc, (-1, 4))[:, :3] 26 | return pc 27 | 28 | 29 | class KittiSequence(Dataset): 30 | """ 31 | Point cloud from a sequence from a raw Mulran dataset 32 | """ 33 | def __init__(self, dataset_root: str, sequence_name: str, pose_time_tolerance: float = 1., 34 | remove_zero_points: bool = True): 35 | # pose_time_tolerance: (in seconds) skip point clouds without corresponding pose information (based on 36 | # timestamps difference) 37 | # remove_zero_points: remove (0,0,0) points 38 | 39 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}' 40 | self.dataset_root = dataset_root 41 | self.sequence_name = sequence_name 42 | # self.sequence_path = os.path.join(self.dataset_root, 'sequences') 43 | # assert os.path.exists(self.sequence_path), f'Cannot access sequence: {self.sequence_path}' 44 | self.rel_lidar_path = os.path.join('sequences', self.sequence_name, 'velodyne') 45 | # lidar_path = os.path.join(self.sequence_path, self.rel_lidar_path) 46 | # assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}' 47 | self.pose_file = os.path.join(self.dataset_root, 'poses', self.sequence_name + '.txt') 48 | assert os.path.exists(self.pose_file), f'Cannot access sequence pose file: {self.pose_file}' 49 | self.times_file = os.path.join(self.dataset_root, 'sequences', self.sequence_name, 'times.txt') 50 | assert os.path.exists(self.pose_file), f'Cannot access sequence times file: {self.times_file}' 51 | # Maximum discrepancy between timestamps of LiDAR scan 
and global pose in seconds 52 | self.pose_time_tolerance = pose_time_tolerance 53 | self.remove_zero_points = remove_zero_points 54 | 55 | self.rel_lidar_timestamps, self.lidar_poses, filenames = self._read_lidar_poses() 56 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, '%06d%s' % (e, '.bin')) for e in filenames] 57 | 58 | def __len__(self): 59 | return len(self.rel_lidar_timestamps) 60 | 61 | def __getitem__(self, ndx): 62 | scan_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx]) 63 | pc = load_pc(scan_filepath) 64 | if self.remove_zero_points: 65 | mask = np.all(np.isclose(pc, 0), axis=1) 66 | pc = pc[~mask] 67 | return {'pc': pc, 'pose': self.lidar_poses[ndx], 'ts': self.rel_lidar_timestamps[ndx]} 68 | 69 | def _read_lidar_poses(self): 70 | fnames = os.listdir(os.path.join(self.dataset_root, self.rel_lidar_path)) 71 | temp = os.path.join(self.dataset_root, self.rel_lidar_path) 72 | fnames = [e for e in fnames if os.path.isfile(os.path.join(temp, e))] 73 | assert len(fnames) > 0, f"Make sure that the path {self.rel_lidar_path}" 74 | filenames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 75 | with open(self.pose_file, "r") as h: 76 | txt_poses = h.readlines() 77 | 78 | n = len(txt_poses) 79 | poses = np.zeros((n, 4, 4), dtype=np.float64) # 4x4 pose matrix 80 | 81 | for ndx, pose in enumerate(txt_poses): 82 | # Split by comma and remove whitespaces 83 | temp = [e.strip() for e in pose.split(' ')] 84 | assert len(temp) == 12, f'Invalid line in global poses file: {temp}' 85 | # poses in kitti ar ein cam0 reference 86 | poses[ndx] = np.array([[float(temp[0]), float(temp[1]), float(temp[2]), float(temp[3])], 87 | [float(temp[4]), float(temp[5]), float(temp[6]), float(temp[7])], 88 | [float(temp[8]), float(temp[9]), float(temp[10]), float(temp[11])], 89 | [0., 0., 0., 1.]]) 90 | rel_ts = np.genfromtxt(self.times_file) 91 | 92 | return rel_ts, poses, filenames 93 | 94 | 95 | def load_pc(filepath): 96 | # Load point cloud, does not apply any transform 97 | # Returns Nx3 matrix 98 | pc = np.fromfile(filepath, dtype=np.float32) 99 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance 100 | pc = np.reshape(pc, (-1, 4))[:, :3] 101 | return pc 102 | -------------------------------------------------------------------------------- /src/data/datasets/kitti/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def velo2cam(): 5 | R = np.array([ 6 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 7 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 8 | ]).reshape(3, 3) 9 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 10 | velo2cam = np.hstack([R, T]) 11 | velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 12 | return velo2cam 13 | 14 | 15 | def get_relative_pose(pose_1, pose_2): 16 | # as seen in https://github.com/chrischoy/FCGF 17 | M = (velo2cam() @ pose_1.T @ np.linalg.inv(pose_2.T) @ np.linalg.inv(velo2cam())).T 18 | return M 19 | -------------------------------------------------------------------------------- /src/data/datasets/kitti360/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Test set for Kitti360 Sequence 09. 
2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/kitti/generate_evaluation_sets.py 3 | 4 | import argparse 5 | import os 6 | import sys 7 | from typing import List 8 | 9 | import numpy as np 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 12 | 13 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements 14 | from datasets.kitti360.kitti360_raw import Kitti360Sequence 15 | 16 | # MAP_TIMERANGE = (0, 170) 17 | MAP_TIMERANGE = (0, 300) 18 | 19 | def get_scans(sequence: Kitti360Sequence, min_displacement: float = 0.1, ts_range: tuple = None) -> List[EvaluationTuple]: 20 | # Get a list of all point clouds from the sequence (the full sequence or test split only) 21 | 22 | elems = [] 23 | old_pos = None 24 | count_skipped = 0 25 | displacements = [] 26 | 27 | for ndx in range(len(sequence)): 28 | if ts_range is not None: 29 | if (ts_range[0] > sequence.rel_lidar_timestamps[ndx]) or (ts_range[1] < sequence.rel_lidar_timestamps[ndx]): 30 | continue 31 | pose = sequence.lidar_poses[ndx] 32 | # Kitti poses are in camera coordinates system where where y is upper axis dim 33 | position = pose[[0,1], 3] 34 | 35 | if old_pos is not None: 36 | displacements.append(np.linalg.norm(old_pos - position)) 37 | 38 | if np.linalg.norm(old_pos - position) < min_displacement: 39 | # Ignore the point cloud if the vehicle didn't move 40 | count_skipped += 1 41 | continue 42 | # print(sequence.rel_scan_filepath) 43 | item = EvaluationTuple(sequence.rel_lidar_timestamps[ndx], sequence.rel_scan_filepath[ndx], position, pose) 44 | elems.append(item) 45 | old_pos = position 46 | 47 | print(f'{count_skipped} clouds skipped due to displacement smaller than {min_displacement}') 48 | print(f'mean displacement {np.mean(np.array(displacements))}') 49 | return elems 50 | 51 | 52 | def generate_evaluation_set(dataset_root: str, map_sequence: str, min_displacement: float = 0.1, 53 | dist_threshold: float = 5.) -> EvaluationSet: 54 | 55 | sequence = Kitti360Sequence(dataset_root, map_sequence) 56 | 57 | map_set = get_scans(sequence, min_displacement, MAP_TIMERANGE) 58 | query_set = get_scans(sequence, min_displacement, (MAP_TIMERANGE[-1], sequence.rel_lidar_timestamps[-1])) 59 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 60 | print(f'{len(map_set)} database elements, {len(query_set)} query elements') 61 | return EvaluationSet(query_set, map_set) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser(description='Generate evaluation sets for KItti dataset') 66 | # kitti: /mnt/088A6CBB8A6CA742/Datasets/Kitti/dataset/ 67 | # mulran: /mnt/088A6CBB8A6CA742/Datasets/MulRan/ 68 | # apollo: 69 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/KITTI360/KITTI-360/data_3d_raw') 70 | parser.add_argument('--min_displacement', type=float, default=3.0) 71 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 72 | parser.add_argument('--dist_threshold', type=float, default=5.) 
73 | 74 | args = parser.parse_args() 75 | 76 | # Sequences are fixed 77 | sequence = '09' 78 | sequence_name = '2013_05_28_drive_00'+ sequence + '_sync' 79 | print(f'Dataset root: {args.dataset_root}') 80 | print(f'Kitti sequence: {sequence}') 81 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 82 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 83 | 84 | kitti_eval_set = generate_evaluation_set(args.dataset_root, sequence, min_displacement=args.min_displacement, 85 | dist_threshold=args.dist_threshold) 86 | file_path_name = os.path.join(os.path.dirname(__file__), f'kitti360_{sequence}_{args.min_displacement}_eval.pickle') 87 | print(f"Saving evaluation pickle: {file_path_name}") 88 | kitti_eval_set.save(file_path_name) 89 | -------------------------------------------------------------------------------- /src/data/datasets/kitti360/kitti360_09_3.0_eval.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/kitti360/kitti360_09_3.0_eval.pickle -------------------------------------------------------------------------------- /src/data/datasets/kitti360/kitti360_raw.py: -------------------------------------------------------------------------------- 1 | # Functions and classes operating on a raw Kitti dataset 2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/kitti/kitti_raw.py 3 | 4 | import os 5 | import sys 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 13 | from datasets.point_clouds_utils import PointCloudLoader 14 | 15 | 16 | class Kitti360PointCloudLoader(PointCloudLoader): 17 | def set_properties(self): 18 | # Set point cloud properties, such as ground_plane_level. 
19 | self.ground_plane_level = -1.5 20 | 21 | def read_pc(self, file_pathname: str) -> torch.Tensor: 22 | # Reads the point cloud without pre-processing 23 | # Returns Nx3 tensor 24 | pc = np.fromfile(file_pathname, dtype=np.float32) 25 | # PC in Mulran is of size [num_points, 4] -> x,y,z,reflectance 26 | pc = np.reshape(pc, (-1, 4))[:, :3] 27 | # pc = np.reshape(pc, (-1, 4)) 28 | return pc 29 | 30 | 31 | class Kitti360Sequence(Dataset): 32 | """ 33 | Point cloud from a sequence from a raw Mulran dataset 34 | """ 35 | def __init__(self, dataset_root: str, sequence_name: str, pose_time_tolerance: float = 1., 36 | remove_zero_points: bool = True): 37 | # pose_time_tolerance: (in seconds) skip point clouds without corresponding pose information (based on 38 | # timestamps difference) 39 | # remove_zero_points: remove (0,0,0) points 40 | 41 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}' 42 | self.dataset_root = dataset_root 43 | self.sequence_name = '2013_05_28_drive_00'+ sequence_name + '_sync' 44 | # self.sequence_path = os.path.join(self.dataset_root, 'sequences') 45 | # assert os.path.exists(self.sequence_path), f'Cannot access sequence: {self.sequence_path}' 46 | self.rel_lidar_path = os.path.join(self.sequence_name, 'velodyne_points/data') 47 | # lidar_path = os.path.join(self.sequence_path, self.rel_lidar_path) 48 | # assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}' 49 | self.pose_file = os.path.join(self.dataset_root, self.sequence_name , 'poses.txt') 50 | self.calib_file = os.path.join(self.dataset_root, self.sequence_name , 'cam0_to_world.txt') 51 | assert os.path.exists(self.pose_file), f'Cannot access sequence pose file: {self.pose_file}' 52 | self.times_file = os.path.join(self.dataset_root, self.sequence_name, 'velodyne_points/timestamps.txt') 53 | assert os.path.exists(self.pose_file), f'Cannot access sequence times file: {self.times_file}' 54 | # Maximum discrepancy between timestamps of LiDAR scan and global pose in seconds 55 | self.pose_time_tolerance = pose_time_tolerance 56 | self.remove_zero_points = remove_zero_points 57 | 58 | self.rel_lidar_timestamps, self.lidar_poses, filenames = self._read_lidar_poses() 59 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, '%010d%s' % (e, '.bin')) for e in filenames] 60 | print('') 61 | 62 | def __len__(self): 63 | return len(self.rel_lidar_timestamps) 64 | 65 | def __getitem__(self, ndx): 66 | scan_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx]) 67 | pc = load_pc(scan_filepath) 68 | if self.remove_zero_points: 69 | mask = np.all(np.isclose(pc, 0), axis=1) 70 | pc = pc[~mask] 71 | return {'pc': pc, 'pose': self.lidar_poses[ndx], 'ts': self.rel_lidar_timestamps[ndx]} 72 | 73 | def _read_lidar_poses(self): 74 | fnames = os.listdir(os.path.join(self.dataset_root, self.rel_lidar_path)) 75 | temp = os.path.join(self.dataset_root, self.rel_lidar_path) 76 | fnames = [e for e in fnames if os.path.isfile(os.path.join(temp, e))] 77 | assert len(fnames) > 0, f"Make sure that the path {self.rel_lidar_path}" 78 | # filenames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 79 | # with open(self.calib_file, 'r') as f: 80 | # for line in f.readlines(): 81 | # data = np.array([float(x) for x in line.split()]) 82 | 83 | # cam0_to_velo = np.reshape(data, (3, 4)) 84 | # cam0_to_velo = np.vstack([cam0_to_velo, [0, 0, 0, 1]]) 85 | 86 | 87 | poses,_,_ = load_poses_from_txt(self.pose_file)#, cam0_to_velo) 88 | sorted_keys = 
sorted(poses.keys()) 89 | poses_list = [poses[k] for k in sorted_keys] 90 | filenames = sorted([int(key) for key in poses]) 91 | ts = load_timestamps(self.times_file) 92 | ts = np.asarray(ts)[filenames] 93 | rel_ts = ts - ts[0] 94 | 95 | return rel_ts, poses_list, filenames 96 | 97 | 98 | def load_pc(filepath): 99 | # Load point cloud, does not apply any transform 100 | # Returns Nx3 matrix 101 | pc = np.fromfile(filepath, dtype=np.float32) 102 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance 103 | pc = np.reshape(pc, (-1, 4))[:, :3] 104 | return pc 105 | 106 | def load_poses_from_txt(file_name):#, cam0_to_velo): 107 | f = open(file_name, 'r') 108 | s = f.readlines() 109 | f.close() 110 | transforms = {} 111 | x = [] 112 | y = [] 113 | for cnt, line in enumerate(s): 114 | P = np.eye(4) 115 | line_split = [float(i) for i in line.split(" ") if i!="" and i!="\n"] 116 | withIdx = len(line_split) >= 13 117 | for row in range(3): 118 | for col in range(4): 119 | P[row, col] = line_split[row*4 + col + withIdx] 120 | if withIdx: 121 | frame_idx = line_split[0] 122 | else: 123 | frame_idx = cnt 124 | transforms[frame_idx] = P #@ cam0_to_velo.inverse() 125 | x.append(P[0, 3]) 126 | y.append(P[1, 3]) 127 | return transforms, x, y 128 | 129 | def load_timestamps(file_name): 130 | f = open(file_name, 'r') 131 | s = f.readlines() 132 | f.close() 133 | times = [] 134 | 135 | for cnt, line in enumerate(s):#2013-05-28 11:36:55.89086054 136 | dt_obj = datetime.strptime(line[:-4], '%Y-%m-%d %H:%M:%S.%f') 137 | 138 | # nanosec = dt_obj.timestamp() * 10e9 139 | sec = dt_obj.timestamp() 140 | 141 | times.append(sec) 142 | return times 143 | -------------------------------------------------------------------------------- /src/data/datasets/kitti360/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def kitti360_calib_transform(init_T): 5 | cam_to_velo_data = [0.04307104361, -0.08829286498, 0.995162929, 0.8043914418, 6 | -0.999004371, 0.007784614041, 0.04392796942, 0.2993489574, 7 | -0.01162548558, -0.9960641394, -0.08786966659, -0.1770225824] 8 | cam0_to_velo = np.reshape(cam_to_velo_data, (3, 4)) 9 | cam0_to_velo = np.vstack([cam0_to_velo, [0, 0, 0, 1]]) 10 | 11 | cam_to_pose_data = [0.0371783278, -0.0986182135, 0.9944306009, 1.5752681039, 12 | 0.9992675562, -0.0053553387, -0.0378902567, 0.0043914093, 13 | 0.0090621821, 0.9951109327, 0.0983468786, -0.6500000000] 14 | cam0_to_pose = np.reshape(cam_to_pose_data, (3, 4)) 15 | cam0_to_pose = np.vstack([cam0_to_pose, [0, 0, 0, 1]]) 16 | 17 | return init_T @ cam0_to_pose @ np.linalg.inv(cam0_to_velo) 18 | 19 | 20 | def kitti360_relative_pose(pose_1, pose_2): 21 | pose_1 = kitti360_calib_transform(pose_1) 22 | pose_2 = kitti360_calib_transform(pose_2) 23 | return np.linalg.inv(pose_2) @ pose_1 24 | 25 | # def relative_pose(m1, m2): 26 | 27 | # return np.linalg.inv(m2) @ m1 -------------------------------------------------------------------------------- /src/data/datasets/mulran/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Test sets for Mulran dataset. 
2 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/mulran/generate_evaluation_sets.py 3 | 4 | import argparse 5 | import os 6 | import sys 7 | from typing import List 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 10 | 11 | 12 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements 13 | from datasets.mulran.mulran_raw import MulranSequence 14 | 15 | # # KAIST 02 16 | # MAP_TIMERANGE = (1566535940856033867, 1566536300000000000) 17 | # QUERY_TIMERANGE = (1566536300000000000, 1566536825534173166) 18 | # # Riverside 01: 19 | # # MAP_TIMERANGE = (1564718063503232284, 1564718300000000000) 20 | # # QUERY_TIMERANGE = (1564718300000000000, 1564718603800415528) 21 | 22 | 23 | def get_scans(sequence: MulranSequence, ts_range: tuple = None) -> List[EvaluationTuple]: 24 | # Get a list of all readings from the test area in the sequence 25 | elems = [] 26 | for ndx in range(len(sequence)): 27 | if ts_range is not None: 28 | if (ts_range[0] > sequence.timestamps[ndx]) or (ts_range[1] < sequence.timestamps[ndx]): 29 | continue 30 | pose = sequence.poses[ndx] 31 | position = pose[:2, 3] 32 | item = EvaluationTuple(sequence.timestamps[ndx], sequence.rel_scan_filepath[ndx], position=position, pose=pose) 33 | elems.append(item) 34 | return elems 35 | 36 | 37 | def generate_evaluation_set(dataset_root: str, map_sequence_name: str, query_sequence_name: str, min_displacement: float = 0.2, 38 | dist_threshold=20) -> EvaluationSet: 39 | split = 'test' 40 | map_sequence = MulranSequence(dataset_root, map_sequence_name, split=split, min_displacement=min_displacement) 41 | query_sequence = MulranSequence(dataset_root, query_sequence_name, split=split, min_displacement=min_displacement) 42 | print(min_displacement) 43 | 44 | if map_sequence_name == query_sequence_name: 45 | print('Wrong Wrong Wrong') 46 | map_set = get_scans(map_sequence, MAP_TIMERANGE) 47 | query_set = get_scans(query_sequence, QUERY_TIMERANGE) 48 | else: 49 | map_set = get_scans(map_sequence) 50 | query_set = get_scans(query_sequence) 51 | 52 | # Function used in evaluation dataset generation 53 | # Filters out query elements without a corresponding map element within dist_threshold threshold 54 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 55 | print(f'{len(map_set)} database elements, {len(query_set)} query elements') 56 | return EvaluationSet(query_set, map_set) 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Mulran dataset') 61 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/Mulran/Sejong') 62 | parser.add_argument('--sequence', type=str, required=False, default='Sejong') 63 | parser.add_argument('--min_displacement', type=float, default=0.2) 64 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 65 | parser.add_argument('--dist_threshold', type=float, default=20) 66 | args = parser.parse_args() 67 | 68 | # Sequences is a list of (map sequence, query sequence) 69 | sequences = [('Sejong1', 'Sejong2')] 70 | if args.sequence == 'DCC': 71 | sequences = [('DCC1', 'DCC2')] 72 | args.min_displacement = 10.0 73 | args.dist_threshold = 5 74 | 75 | 76 | print(f'Dataset root: {args.dataset_root}') 77 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 78 | print(f'Ignore query elements without a corresponding map element within 
a threshold [m]: {args.dist_threshold}') 79 | 80 | 81 | for map_sequence, query_sequence in sequences: 82 | print(f'Map sequence: {map_sequence}') 83 | print(f'Query sequence: {query_sequence}') 84 | 85 | test_set = generate_evaluation_set(args.dataset_root, map_sequence, query_sequence, 86 | min_displacement=args.min_displacement, dist_threshold=args.dist_threshold) 87 | 88 | pickle_name = f'test_{map_sequence}_{query_sequence}_{args.min_displacement}_{args.dist_threshold}.pickle' 89 | # file_path_name = os.path.join(args.dataset_root, pickle_name) 90 | file_path_name = os.path.join(os.path.dirname(__file__), pickle_name) 91 | print(f"Saving evaluation pickle: {file_path_name}") 92 | test_set.save(file_path_name) 93 | -------------------------------------------------------------------------------- /src/data/datasets/mulran/generate_training_tuples.py: -------------------------------------------------------------------------------- 1 | # Training tuples generation for Mulran dataset. 2 | 3 | import argparse 4 | import os 5 | import pickle 6 | import sys 7 | 8 | import numpy as np 9 | import tqdm 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 12 | 13 | from datasets.base_datasets import TrainingTuple 14 | from datasets.mulran.mulran_raw import MulranSequences 15 | from datasets.mulran.utils import relative_pose 16 | 17 | from misc.point_clouds import icp 18 | 19 | DEBUG = False 20 | 21 | 22 | def load_pc(file_pathname): 23 | # Load point cloud, clip x, y and z coords (points far away and the ground plane) 24 | # Returns Nx3 matrix 25 | pc = np.fromfile(file_pathname, dtype=np.float32) 26 | # PC in Mulran is of size [num_points, 4] -> x,y,z,reflectance 27 | pc = np.reshape(pc, (-1, 4))[:, :3] 28 | 29 | mask = np.all(np.isclose(pc, 0.), axis=1) 30 | pc = pc[~mask] 31 | mask = pc[:, 0] > -80 32 | pc = pc[mask] 33 | mask = pc[:, 0] <= 80 34 | 35 | pc = pc[mask] 36 | mask = pc[:, 1] > -80 37 | pc = pc[mask] 38 | mask = pc[:, 1] <= 80 39 | pc = pc[mask] 40 | 41 | mask = pc[:, 2] > -0.9 42 | pc = pc[mask] 43 | return pc 44 | 45 | 46 | def generate_training_tuples(ds: MulranSequences, pos_threshold: float = 10, neg_threshold: float = 50): 47 | # displacement: displacement between consecutive anchors (if None all scans are takes as anchors). 48 | # Use some small displacement to ensure there's only one scan if the vehicle does not move 49 | 50 | tuples = {} # Dictionary of training tuples: tuples[ndx] = (sef ot positives, set of non negatives) 51 | for anchor_ndx in tqdm.tqdm(range(len(ds))): 52 | anchor_pos = ds.get_xy()[anchor_ndx] 53 | 54 | # Find timestamps of positive and negative elements 55 | positives = ds.find_neighbours_ndx(anchor_pos, pos_threshold) 56 | non_negatives = ds.find_neighbours_ndx(anchor_pos, neg_threshold) 57 | # Remove anchor element from positives, but leave it in non_negatives 58 | positives = positives[positives != anchor_ndx] 59 | 60 | # Sort ascending order 61 | positives = np.sort(positives) 62 | non_negatives = np.sort(non_negatives) 63 | 64 | # ICP pose refinement 65 | fitness_l = [] 66 | inlier_rmse_l = [] 67 | positive_poses = {} 68 | 69 | if True: 70 | # Use ground truth transform without pose refinement 71 | anchor_pose = ds.poses[anchor_ndx] 72 | for positive_ndx in positives: 73 | positive_pose = ds.poses[positive_ndx] 74 | # Compute initial relative pose 75 | m, fitness, inlier_rmse = relative_pose(anchor_pose, positive_pose), 1., 1. 
76 | fitness_l.append(fitness) 77 | inlier_rmse_l.append(inlier_rmse) 78 | positive_poses[positive_ndx] = m 79 | else: 80 | anchor_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[anchor_ndx])) 81 | anchor_pose = ds.poses[anchor_ndx] 82 | for positive_ndx in positives: 83 | positive_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[positive_ndx])) 84 | positive_pose = ds.poses[positive_ndx] 85 | # Compute initial relative pose 86 | transform = relative_pose(anchor_pose, positive_pose) 87 | # Refine the pose using ICP 88 | m, fitness, inlier_rmse = icp(anchor_pc, positive_pc, transform) 89 | 90 | fitness_l.append(fitness) 91 | inlier_rmse_l.append(inlier_rmse) 92 | positive_poses[positive_ndx] = m 93 | 94 | # Tuple(id: int, timestamp: int, rel_scan_filepath: str, positives: List[int], non_negatives: List[int]) 95 | tuples[anchor_ndx] = TrainingTuple(id=anchor_ndx, timestamp=ds.timestamps[anchor_ndx], 96 | rel_scan_filepath=ds.dataset_root + '/' + ds.rel_scan_filepath[anchor_ndx], 97 | positives=positives, non_negatives=non_negatives, pose=anchor_pose, 98 | positives_poses=positive_poses) 99 | 100 | print(f'{len(tuples)} training tuples generated') 101 | print('ICP pose refimenement stats:') 102 | print(f'Fitness - min: {np.min(fitness_l):0.3f} mean: {np.mean(fitness_l):0.3f} max: {np.max(fitness_l):0.3f}') 103 | print(f'Inlier RMSE - min: {np.min(inlier_rmse_l):0.3f} mean: {np.mean(inlier_rmse_l):0.3f} max: {np.max(inlier_rmse_l):0.3f}') 104 | 105 | return tuples 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser(description='Generate training tuples') 110 | parser.add_argument('--dataset_root', type=str, default='/data/raktim/Datasets/Mulran/Sejong') 111 | parser.add_argument('--pos_threshold', default=2) 112 | parser.add_argument('--neg_threshold', default=10) 113 | parser.add_argument('--min_displacement', type=float, default=0.2) 114 | args = parser.parse_args() 115 | 116 | sequences = ['Sejong1', 'Sejong2'] 117 | if DEBUG: 118 | sequences = ['ParkingLot', 'ParkingLot'] 119 | 120 | print(f'Dataset root: {args.dataset_root}') 121 | print(f'Sequences: {sequences}') 122 | print(f'Threshold for positive examples: {args.pos_threshold}') 123 | print(f'Threshold for negative examples: {args.neg_threshold}') 124 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 125 | 126 | ds = MulranSequences(args.dataset_root, sequences, split='train', min_displacement=args.min_displacement) 127 | train_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold) 128 | pickle_name = f'train_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle' 129 | train_tuples_filepath = os.path.join(args.dataset_root, pickle_name) 130 | pickle.dump(train_tuples, open(train_tuples_filepath, 'wb')) 131 | train_tuples = None 132 | 133 | ds = MulranSequences(args.dataset_root, sequences, split='test', min_displacement=args.min_displacement) 134 | test_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold) 135 | pickle_name = f'val_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle' 136 | test_tuples_filepath = os.path.join(args.dataset_root, pickle_name) 137 | pickle.dump(test_tuples, open(test_tuples_filepath, 'wb')) 138 | -------------------------------------------------------------------------------- /src/data/datasets/mulran/mulran_train.py: -------------------------------------------------------------------------------- 1 | # Warsaw 
University of Technology 2 | # Dataset wrapper for Mulran lidar scans dataset 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | 9 | from datasets.base_datasets import TrainingDataset 10 | from datasets.quantization import Quantizer 11 | from misc.poses import apply_transform 12 | from datasets.base_datasets import TrainingDataset 13 | 14 | DEBUG = False 15 | 16 | 17 | class MulranTraining6DOFDataset(TrainingDataset): 18 | """ 19 | Dataset wrapper for Mulran dataset for 6dof estimation. 20 | """ 21 | def __init__(self, dataset_path: str, query_filename: str, quantizer: Quantizer, 22 | rot_max: float = 0., trans_max: float = 0., **vargs): 23 | dataset_type = 'mulran' 24 | super().__init__(dataset_path, dataset_type, query_filename, **vargs) 25 | self.quantizer = quantizer 26 | self.rot_max = rot_max 27 | self.trans_max = trans_max 28 | 29 | def __getitem__(self, ndx): 30 | # pose is a global coordinate system pose 3x4 R|T matrix 31 | query_pc, _ = super().__getitem__(ndx) 32 | 33 | # get random positive 34 | positives = self.get_positives(ndx) 35 | positive_idx = np.random.choice(positives, 1)[0] 36 | positive_pc, _ = super().__getitem__(positive_idx) 37 | 38 | # get relative pose taking two global poses 39 | transform = self.queries[ndx].positives_poses[positive_idx] 40 | 41 | # Apply random transform to the positive point cloud 42 | rotation_angle = np.random.uniform(low=-self.rot_max, high=self.rot_max) 43 | cosval = np.cos(rotation_angle) 44 | sinval = np.sin(rotation_angle) 45 | m = torch.eye(4, dtype=torch.float) 46 | #m[:3, :3] = np.array([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]]) 47 | m[:3, :3] = torch.tensor([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]], dtype=m.dtype) 48 | m[:2, 3] = torch.rand((1, 2)) * 2. 
* self.trans_max - self.trans_max 49 | positive_pc = apply_transform(positive_pc, m) 50 | transform = m @ transform 51 | 52 | # Find indices of unique quantized coordinates and filter out points to leave max 1 point per voxel 53 | coords1, idx1 = self.quantizer(query_pc) 54 | coords2, idx2 = self.quantizer(positive_pc) 55 | pc1_cop = query_pc[idx1, :] 56 | pc2_trans_cop = positive_pc[idx2, :] 57 | 58 | return pc1_cop, pc2_trans_cop, transform 59 | -------------------------------------------------------------------------------- /src/data/datasets/mulran/test_DCC1_DCC2_10.0_5.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_DCC1_DCC2_10.0_5.pickle -------------------------------------------------------------------------------- /src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_20.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_20.pickle -------------------------------------------------------------------------------- /src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_5.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_5.pickle -------------------------------------------------------------------------------- /src/data/datasets/mulran/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from scipy.spatial import distance_matrix 4 | 5 | # Faulty point clouds (with 0 points) 6 | FAULTY_POINTCLOUDS = [1566279795718079314] 7 | 8 | # Coordinates of test region centres (in Sejong sequence) 9 | TEST_REGION_CENTRES = np.array([[345090.0743, 4037591.323], [345090.483, 4044700.04], 10 | [350552.0308, 4041000.71], [349252.0308, 4044800.71]]) 11 | 12 | # Radius of the test region 13 | TEST_REGION_RADIUS = 500 14 | 15 | # Boundary between training and test region - to ensure there's no overlap between training and test clouds 16 | TEST_TRAIN_BOUNDARY = 50 17 | 18 | 19 | def in_train_split(pos): 20 | # returns true if pos is in train split 21 | assert pos.ndim == 2 22 | assert pos.shape[1] == 2 23 | dist = distance_matrix(pos, TEST_REGION_CENTRES) 24 | mask = (dist > TEST_REGION_RADIUS + TEST_TRAIN_BOUNDARY).all(axis=1) 25 | return mask 26 | 27 | 28 | def in_test_split(pos): 29 | # returns true if position is in evaluation split 30 | assert pos.ndim == 2 31 | assert pos.shape[1] == 2 32 | dist = distance_matrix(pos, TEST_REGION_CENTRES) 33 | mask = (dist < TEST_REGION_RADIUS).any(axis=1) 34 | return mask 35 | 36 | 37 | def find_nearest_ndx(ts, timestamps): 38 | ndx = np.searchsorted(timestamps, ts) 39 | if ndx == 0: 40 | return ndx 41 | elif ndx == len(timestamps): 42 | return ndx - 1 43 | else: 44 | assert timestamps[ndx-1] <= ts <= timestamps[ndx] 45 | if ts - timestamps[ndx-1] < timestamps[ndx] - ts: 46 | return ndx - 1 47 | else: 48 | return ndx 49 | 50 | 51 | def read_lidar_poses(poses_filepath: str, lidar_filepath: str, pose_time_tolerance: float = 1.): 52 | # Read global poses from .csv file and link each lidar_scan with the nearest pose 53 | # threshold: threshold in seconds 54 | # Returns a dictionary with (4, 4) 
pose matrix indexed by a timestamp (as integer) 55 | 56 | with open(poses_filepath, "r") as h: 57 | txt_poses = h.readlines() 58 | 59 | n = len(txt_poses) 60 | system_timestamps = np.zeros((n,), dtype=np.int64) 61 | poses = np.zeros((n, 4, 4), dtype=np.float64) # 4x4 pose matrix 62 | 63 | for ndx, pose in enumerate(txt_poses): 64 | # Split by comma and remove whitespaces 65 | temp = [e.strip() for e in pose.split(',')] 66 | assert len(temp) == 13, f'Invalid line in global poses file: {temp}' 67 | system_timestamps[ndx] = int(temp[0]) 68 | poses[ndx] = np.array([[float(temp[1]), float(temp[2]), float(temp[3]), float(temp[4])], 69 | [float(temp[5]), float(temp[6]), float(temp[7]), float(temp[8])], 70 | [float(temp[9]), float(temp[10]), float(temp[11]), float(temp[12])], 71 | [0., 0., 0., 1.]]) 72 | 73 | # Ensure timestamps and poses are sorted in ascending order 74 | sorted_ndx = np.argsort(system_timestamps, axis=0) 75 | system_timestamps = system_timestamps[sorted_ndx] 76 | poses = poses[sorted_ndx] 77 | 78 | # List LiDAR scan timestamps 79 | all_lidar_timestamps = [int(os.path.splitext(f)[0]) for f in os.listdir(lidar_filepath) if 80 | os.path.splitext(f)[1] == '.bin'] 81 | all_lidar_timestamps.sort() 82 | 83 | lidar_timestamps = [] 84 | lidar_poses = [] 85 | count_rejected = 0 86 | 87 | for ndx, lidar_ts in enumerate(all_lidar_timestamps): 88 | # Skip faulty point clouds 89 | if lidar_ts in FAULTY_POINTCLOUDS: 90 | continue 91 | 92 | # Find index of the closest timestamp 93 | closest_ts_ndx = find_nearest_ndx(lidar_ts, system_timestamps) 94 | delta = abs(system_timestamps[closest_ts_ndx] - lidar_ts) 95 | # Timestamp is in nanoseconds = 1e-9 second 96 | if delta > pose_time_tolerance * 1000000000: 97 | # Reject point cloud without corresponding pose 98 | count_rejected += 1 99 | continue 100 | 101 | lidar_timestamps.append(lidar_ts) 102 | lidar_poses.append(poses[closest_ts_ndx]) 103 | 104 | lidar_timestamps = np.array(lidar_timestamps, dtype=np.int64) 105 | lidar_poses = np.array(lidar_poses, dtype=np.float64) # (northing, easting) position 106 | 107 | print(f'{len(lidar_timestamps)} scans with valid pose, {count_rejected} rejected due to unknown pose') 108 | return lidar_timestamps, lidar_poses 109 | 110 | 111 | def relative_pose(m1, m2): 112 | # SE(3) pose is 4x 4 matrix, such that 113 | # Pw = [R | T] @ [P] 114 | # [0 | 1] [1] 115 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame 116 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame 117 | # m2: coords in camera/lidar2 coords -> world coordinate frame 118 | # returns: relative pose of the first camera with respect to the second camera 119 | # transformation matrix to convert coords in camera/lidar1 reference frame to coords in 120 | # camera/lidar2 reference frame 121 | # 122 | m = np.linalg.inv(m2) @ m1 123 | # !!!!!!!!!! Fix for relative pose !!!!!!!!!!!!! 
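# NOTE: the translation component is negated below, presumably to match the MuLRAN ground-truth
# pose convention; the generic relative_pose() in datasets/poses_utils.py and misc/poses.py omits
# this step and is explicitly marked as not to be used for MuLRAN poses.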
124 | m[:3, 3] = -m[:3, 3] 125 | return m 126 | -------------------------------------------------------------------------------- /src/data/datasets/point_clouds_utils.py: -------------------------------------------------------------------------------- 1 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/misc/point_clouds.py 2 | 3 | import copy 4 | import os 5 | 6 | import numpy as np 7 | import open3d as o3d 8 | 9 | 10 | def draw_registration_result(source, target, transformation): 11 | source_temp = copy.deepcopy(source) 12 | target_temp = copy.deepcopy(target) 13 | source_temp.paint_uniform_color([1, 0.706, 0]) 14 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 15 | source_temp.transform(transformation) 16 | o3d.visualization.draw_geometries([source_temp, target_temp], 17 | zoom=0.4459, 18 | front=[0.9288, -0.2951, -0.2242], 19 | lookat=[1.6784, 2.0612, 1.4451], 20 | up=[-0.3402, -0.9189, -0.1996]) 21 | 22 | 23 | def draw_pc(pc): 24 | pc = copy.deepcopy(pc) 25 | pc.paint_uniform_color([1, 0.706, 0]) 26 | o3d.visualization.draw_geometries([pc], 27 | zoom=0.4459, 28 | front=[0.9288, -0.2951, -0.2242], 29 | lookat=[1.6784, 2.0612, 1.4451], 30 | up=[-0.3402, -0.9189, -0.1996]) 31 | 32 | 33 | def icp(anchor_pc, positive_pc, transform: np.ndarray = None, point2plane: bool = False, 34 | inlier_dist_threshold: float = 1.2, max_iteration: int = 200): 35 | # transform: initial alignment transform 36 | if transform is not None: 37 | transform = transform.astype(float) 38 | 39 | voxel_size = 0.1 40 | pcd1 = o3d.geometry.PointCloud() 41 | pcd1.points = o3d.utility.Vector3dVector(anchor_pc) 42 | pcd1 = pcd1.voxel_down_sample(voxel_size=voxel_size) 43 | 44 | pcd2 = o3d.geometry.PointCloud() 45 | pcd2.points = o3d.utility.Vector3dVector(positive_pc) 46 | pcd2 = pcd2.voxel_down_sample(voxel_size=voxel_size) 47 | 48 | if point2plane: 49 | pcd1.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 50 | pcd2.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 51 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPlane() 52 | else: 53 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint() 54 | 55 | if transform is not None: 56 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, transform, 57 | estimation_method=transform_estimation, 58 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 59 | else: 60 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, 61 | estimation_method=transform_estimation, 62 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 63 | 64 | return reg_p2p.transformation, reg_p2p.fitness, reg_p2p.inlier_rmse 65 | 66 | 67 | def make_open3d_feature(data, dim, npts): 68 | feature = o3d.pipelines.registration.Feature() 69 | feature.resize(dim, npts) 70 | if not isinstance(data, np.ndarray): 71 | feature.data = data.cpu().numpy().astype('d').transpose() 72 | else: 73 | feature.data = data.astype('d').transpose() 74 | return feature 75 | 76 | 77 | def make_open3d_point_cloud(xyz, color=None): 78 | pcd = o3d.geometry.PointCloud() 79 | pcd.points = o3d.utility.Vector3dVector(xyz) 80 | if color is not None: 81 | pcd.colors = o3d.utility.Vector3dVector(color) 82 | return pcd 83 | 84 | def preprocess_pointcloud(pc, remove_zero_points: bool = False, 85 | min_x: float = None, max_x: float = None, 86 | min_y: float = 
None, max_y: float = None, 87 | min_z: float = None, max_z: float = None): 88 | if remove_zero_points: 89 | mask = np.all(np.isclose(pc, 0.), axis=1) 90 | pc = pc[~mask] 91 | 92 | if min_x is not None: 93 | mask = pc[:, 0] > min_x 94 | pc = pc[mask] 95 | 96 | if max_x is not None: 97 | mask = pc[:, 0] <= max_x 98 | pc = pc[mask] 99 | 100 | if min_y is not None: 101 | mask = pc[:, 1] > min_y 102 | pc = pc[mask] 103 | 104 | if max_y is not None: 105 | mask = pc[:, 1] <= max_y 106 | pc = pc[mask] 107 | 108 | if min_z is not None: 109 | mask = pc[:, 2] > min_z 110 | pc = pc[mask] 111 | 112 | if max_z is not None: 113 | mask = pc[:, 2] <= max_z 114 | pc = pc[mask] 115 | 116 | return pc 117 | 118 | 119 | class PointCloudLoader: 120 | # Generic point cloud loader class 121 | def __init__(self): 122 | # remove_zero_points: remove points with all zero coordinates 123 | # remove_ground_plane: remove points on ground plane level and below 124 | # ground_plane_level: ground plane level 125 | self.remove_zero_points = True 126 | self.remove_ground_plane = True 127 | self.ground_plane_level = None 128 | self.set_properties() 129 | 130 | def set_properties(self): 131 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes. 132 | raise NotImplementedError('set_properties must be defined in inherited classes') 133 | 134 | def __call__(self, file_pathname): 135 | # Reads the point cloud from a disk and preprocess (optional removal of zero points and points on the ground 136 | # plane and below 137 | # file_pathname: relative file path 138 | assert os.path.exists(file_pathname), f"Cannot open point cloud: {file_pathname}" 139 | pc = self.read_pc(file_pathname) 140 | # assert pc.shape[1] == 3 141 | 142 | if self.remove_zero_points: 143 | mask = np.all(np.isclose(pc, 0), axis=1) 144 | pc = pc[~mask] 145 | 146 | if self.remove_ground_plane: 147 | mask = pc[:, 2] > self.ground_plane_level 148 | pc = pc[mask] 149 | 150 | return pc 151 | 152 | def read_pc(self, file_pathname): 153 | # Reads the point cloud without pre-processing 154 | raise NotImplementedError("read_pc must be overloaded in an inheriting class") 155 | -------------------------------------------------------------------------------- /src/data/datasets/poses_utils.py: -------------------------------------------------------------------------------- 1 | # This file is directly copied from: https://github.com/jac99/Egonn/blob/main/misc/poses.py 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def q2r(q): 8 | # Rotation matrix from Hamiltonian quaternion 9 | # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles 10 | w, x, y, z = tuple(q) 11 | 12 | n = 1.0/np.sqrt(x*x+y*y+z*z+w*w) 13 | x *= n 14 | y *= n 15 | z *= n 16 | w *= n 17 | r = np.array([[1.0 - 2.0*y*y - 2.0*z*z, 2.0*x*y - 2.0*z*w, 2.0*x*z + 2.0*y*w], 18 | [2.0*x*y + 2.0*z*w, 1.0 - 2.0*x*x - 2.0*z*z, 2.0*y*z - 2.0*x*w], 19 | [2.0*x*z - 2.0*y*w, 2.0*y*z + 2.0*x*w, 1.0 - 2.0*x*x - 2.0*y*y]]) 20 | return r 21 | 22 | 23 | def m2ypr(m): 24 | # Get yaw, pitch, roll angles from 4x4 transformation matrix 25 | # Based on formulas in Section 2.5.1 in: 26 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 27 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 28 | assert m.shape == (4, 4) 29 | pitch = np.arctan2(-m[2][0], np.sqrt(m[0][0]**2 + m[1][0]**2)) 30 | # We do not handle degenerate case, when pitch is 90 degrees a.k.a. 
gimball lock 31 | assert not np.isclose(np.abs(pitch), np.pi/2) 32 | yaw = np.arctan2(m[1][0], m[0][0]) 33 | roll = np.arctan2(m[2][1], m[2][2]) 34 | return yaw, pitch, roll 35 | 36 | 37 | def m2xyz_ypr(m): 38 | # Get yaw, pitch, roll angles from 4x4 transformation matrix 39 | # Based on formulas in Section 2.5.1 in: 40 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 41 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 42 | assert m.shape == (4, 4) 43 | yaw, pitch, roll = m2ypr(m) 44 | return m[0, 3], m[1, 3], m[2, 3], yaw, pitch, roll 45 | 46 | 47 | def ypr2m(yaw, pitch, roll): 48 | # Construct 4x4 transformation matrix with rotation part set to given yaw, pitch, roll. Translation is set to 0. 49 | # Based on formulas in Section 2.2.1 50 | m = np.array([[np.cos(yaw) * np.cos(pitch), np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll), 51 | np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll), 0.], 52 | [np.sin(yaw) * np.cos(pitch), np.sin(roll) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll), 53 | np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll), 0.], 54 | [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll), 0.], 55 | [0., 0., 0., 1.]], dtype=np.float32) 56 | 57 | return m 58 | 59 | 60 | def xyz_ypr2m(x, y, z, yaw, pitch, roll): 61 | # Construct 4x4 transformation matrix with given translation and rotation part set to given yaw, pitch, roll. 62 | # Based on formulas in Section 2.2.1 63 | m = ypr2m(yaw, pitch, roll) 64 | m[0, 3] = x 65 | m[1, 3] = y 66 | m[2, 3] = z 67 | return m 68 | 69 | 70 | def apply_transform(pc: torch.Tensor, m: torch.Tensor): 71 | # Apply 4x4 SE(3) transformation matrix on (N, 3) point cloud or 3x3 transformation on (N, 2) point cloud 72 | assert pc.ndim == 2 73 | n_dim = pc.shape[1] 74 | assert n_dim == 2 or n_dim == 3 75 | assert m.shape == (n_dim + 1, n_dim + 1) 76 | # (m @ pc.t).t = pc @ m.t 77 | pc = pc @ m[:n_dim, :n_dim].transpose(1, 0) + m[:n_dim, -1] 78 | return pc 79 | 80 | 81 | def relative_pose(m1, m2): 82 | # !!! DO NOT USE THIS FUNCTION FOR MULRAN POSES !!! 
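# (For MuLRAN poses use datasets/mulran/utils.relative_pose, which additionally negates the translation component.)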
83 | # SE(3) pose is 4x 4 matrix, such that 84 | # Pw = [R | T] @ [P] 85 | # [0 | 1] [1] 86 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame 87 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame 88 | # m2: coords in camera/lidar2 coords -> world coordinate frame 89 | # returns: coords in camera/lidar1 reference frame -> coords in camera/lidar2 reference frame 90 | # relative pose of the first camera with respect to the second camera 91 | return np.linalg.inv(m2) @ m1 92 | -------------------------------------------------------------------------------- /src/data/datasets/quantization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | from abc import ABC, abstractmethod 4 | import torch 5 | # import MinkowskiEngine as ME 6 | 7 | 8 | class Quantizer(ABC): 9 | @abstractmethod 10 | def __call__(self, pc): 11 | pass 12 | 13 | @abstractmethod 14 | def dequantize(self, coords): 15 | pass 16 | 17 | @abstractmethod 18 | def keypoint_position(self, supervoxel_centers, stride, kp_offset): 19 | pass 20 | 21 | 22 | class PolarQuantizer(Quantizer): 23 | def __init__(self, quant_step: List[float]): 24 | assert len(quant_step) == 3, '3 quantization steps expected: for sector (in degrees), ring and z-coordinate (in meters)' 25 | self.quant_step = torch.tensor(quant_step, dtype=torch.float) 26 | self.theta_range = int(360. // self.quant_step[0]) 27 | self.quant_step = torch.tensor(quant_step, dtype=torch.float) 28 | 29 | def __call__(self, pc): 30 | # Convert to polar coordinates and quantize with different step size for each coordinate 31 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z) 32 | assert pc.shape[1] == 3 33 | 34 | # theta is an angle in degrees in 0..360 range 35 | theta = 180. + torch.atan2(pc[:, 1], pc[:, 0]) * 180./np.pi 36 | # dist is a distance from a coordinate origin 37 | dist = torch.sqrt(pc[:, 0]**2 + pc[:, 1]**2) 38 | z = pc[:, 2] 39 | polar_pc = torch.stack([theta, dist, z], dim=1) 40 | # Scale each coordinate so after quantization with step 1. we got the required quantization step in each dim 41 | polar_pc = polar_pc / self.quant_step 42 | quantized_polar_pc, ndx = ME.utils.sparse_quantize(polar_pc, quantization_size=1., return_index=True) 43 | # Return quantized coordinates and index of selected elements 44 | return quantized_polar_pc, ndx 45 | 46 | def to_cartesian(self, pc): 47 | # Convert to radian in -180..180 range 48 | theta = np.pi * (pc[:, 0] - 180.) / 180. 
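# NOTE: pc[:, 0] holds the polar angle in degrees in the 0..360 range produced by __call__
# (180 + atan2 in degrees), so subtracting 180 before converting to radians restores the
# original orientation for the Cartesian projection below.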
49 | x = torch.cos(theta) * pc[:, 1] 50 | y = torch.sin(theta) * pc[:, 1] 51 | z = pc[:, 2] 52 | cartesian_pc = torch.stack([x, y, z], dim=1) 53 | return cartesian_pc 54 | 55 | def dequantize(self, coords): 56 | # Dequantize coords and convert to cartesian as (N, 3) tensor of floats 57 | pc = (0.5 + coords) * self.quant_step.to(coords.device) 58 | return self.to_cartesian(pc) 59 | 60 | def keypoint_position(self, supervoxel_centres, stride, kp_offset): 61 | # Add voxel center position: 0.5 * self.voxel_size 62 | # to offset from the supervoxel centre value (in -1..1 range converted to absolute values): 63 | # self.voxel_size + features * super_voxel_size / 2 64 | device = supervoxel_centres.device 65 | supervoxel_centres = (supervoxel_centres + 0.5) * self.quant_step.to(device) 66 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * \ 67 | self.quant_step.to(device) 68 | #kp_pos = supervoxel_centres 69 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2. 70 | 71 | kp_pos = self.to_cartesian(kp_pos) 72 | return kp_pos 73 | 74 | 75 | class CartesianQuantizer(Quantizer): 76 | def __init__(self, quant_step: float): 77 | self.quant_step = quant_step 78 | 79 | def __call__(self, pc): 80 | # Converts to polar coordinates and quantizes with different step size for each coordinate 81 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z) 82 | assert pc.shape[1] == 3 83 | quantized_pc, ndx = ME.utils.sparse_quantize(pc, quantization_size=self.quant_step, return_index=True) 84 | # Return quantized coordinates and index of selected elements 85 | return quantized_pc, ndx 86 | 87 | def dequantize(self, coords): 88 | # Dequantize coords and return as (N, 3) tensor of floats 89 | # Use coords of the voxel center 90 | pc = (0.5 + coords) * self.quant_step 91 | return pc 92 | 93 | def keypoint_position(self, supervoxel_centers, stride, kp_offset): 94 | # Add voxel center position: 0.5 * self.voxel_size 95 | # to offset from the supervoxel centre value (in -1..1 range converted to absolute values): 96 | # self.voxel_size + features * super_voxel_size / 2 97 | supervoxel_centres = (supervoxel_centers + 0.5) * self.quant_step 98 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * self.quant_step 99 | if kp_offset is not None: 100 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2. 101 | else: 102 | kp_pos = supervoxel_centres 103 | return kp_pos 104 | 105 | 106 | if __name__ == "__main__": 107 | n = 1000 108 | cart = torch.rand((n, 3), dtype=torch.float) 109 | cart[:, 0] = cart[:, 0] * 200. - 100. 110 | cart[:, 1] = cart[:, 1] * 200. - 100. 111 | cart[:, 2] = cart[:, 2] * 30. - 10. 
112 | 113 | quantizer = PolarQuantizer([0.5, 0.3, 0.2]) 114 | polar_quant, ndx = quantizer(cart) 115 | back2cart = quantizer.dequantize(polar_quant) 116 | cart_filtered = cart[ndx] 117 | dist = torch.norm(back2cart - cart_filtered, dim=1) 118 | print(f'Residual error - min: {torch.min(dist):0.5f} max: {torch.max(dist):0.5f} mean: {torch.mean(dist):0.5f}') 119 | 120 | -------------------------------------------------------------------------------- /src/data/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import random 4 | import copy 5 | from torch.utils.data import Sampler 6 | 7 | from datasets.base_datasets import TrainingDataset 8 | 9 | VERBOSE = False 10 | 11 | 12 | class ListDict(object): 13 | def __init__(self, items=None): 14 | if items is not None: 15 | self.items = copy.deepcopy(items) 16 | self.item_to_position = {item: ndx for ndx, item in enumerate(items)} 17 | else: 18 | self.items = [] 19 | self.item_to_position = {} 20 | 21 | def add(self, item): 22 | if item in self.item_to_position: 23 | return 24 | self.items.append(item) 25 | self.item_to_position[item] = len(self.items)-1 26 | 27 | def remove(self, item): 28 | position = self.item_to_position.pop(item) 29 | last_item = self.items.pop() 30 | if position != len(self.items): 31 | self.items[position] = last_item 32 | self.item_to_position[last_item] = position 33 | 34 | def choose_random(self): 35 | return random.choice(self.items) 36 | 37 | def __contains__(self, item): 38 | return item in self.item_to_position 39 | 40 | def __iter__(self): 41 | return iter(self.items) 42 | 43 | def __len__(self): 44 | return len(self.items) 45 | 46 | 47 | class BatchSampler(Sampler): 48 | # Sampler returning list of indices to form a mini-batch 49 | # Samples elements in groups consisting of k=2 similar elements (positives) 50 | # Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k 51 | def __init__(self, dataset: TrainingDataset, batch_size: int, batch_size_limit: int = None, 52 | batch_expansion_rate: float = None, max_batches: int = None): 53 | if batch_expansion_rate is not None: 54 | assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 55 | assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size' 56 | 57 | self.batch_size = batch_size 58 | self.batch_size_limit = batch_size_limit 59 | self.batch_expansion_rate = batch_expansion_rate 60 | self.max_batches = max_batches 61 | self.dataset = dataset 62 | self.k = 2 # Number of positive examples per group must be 2 63 | if self.batch_size < 2 * self.k: 64 | self.batch_size = 2 * self.k 65 | print('WARNING: Batch too small. 
Batch size increased to {}.'.format(self.batch_size)) 66 | 67 | self.batch_idx = [] # Index of elements in each batch (re-generated every epoch) 68 | self.elems_ndx = list(self.dataset.queries) # List of point cloud indexes 69 | 70 | def __iter__(self): 71 | # Re-generate batches every epoch 72 | self.generate_batches() 73 | for batch in self.batch_idx: 74 | yield batch 75 | 76 | def __len(self): 77 | return len(self.batch_idx) 78 | 79 | def expand_batch(self): 80 | if self.batch_expansion_rate is None: 81 | print('WARNING: batch_expansion_rate is None') 82 | return 83 | 84 | if self.batch_size >= self.batch_size_limit: 85 | return 86 | 87 | old_batch_size = self.batch_size 88 | self.batch_size = int(self.batch_size * self.batch_expansion_rate) 89 | self.batch_size = min(self.batch_size, self.batch_size_limit) 90 | print('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size)) 91 | 92 | def generate_batches(self): 93 | # Generate training/evaluation batches. 94 | # batch_idx holds indexes of elements in each batch as a list of lists 95 | self.batch_idx = [] 96 | 97 | unused_elements_ndx = ListDict(self.elems_ndx) 98 | current_batch = [] 99 | 100 | assert self.k == 2, 'sampler can sample only k=2 elements from the same class' 101 | 102 | while True: 103 | if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0: 104 | # Flush out batch, when it has a desired size, or a smaller batch, when there's no more 105 | # elements to process 106 | if len(current_batch) >= 2*self.k: 107 | # Ensure there're at least two groups of similar elements, otherwise, it would not be possible 108 | # to find negative examples in the batch 109 | assert len(current_batch) % self.k == 0, 'Incorrect bach size: {}'.format(len(current_batch)) 110 | self.batch_idx.append(current_batch) 111 | current_batch = [] 112 | if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches): 113 | break 114 | if len(unused_elements_ndx) == 0: 115 | break 116 | 117 | # Add k=2 similar elements to the batch 118 | selected_element = unused_elements_ndx.choose_random() 119 | unused_elements_ndx.remove(selected_element) 120 | positives = self.dataset.get_positives(selected_element) 121 | if len(positives) == 0: 122 | # Broken dataset element without any positives 123 | continue 124 | 125 | unused_positives = [e for e in positives if e in unused_elements_ndx] 126 | # If there're unused elements similar to selected_element, sample from them 127 | # otherwise sample from all similar elements 128 | if len(unused_positives) > 0: 129 | second_positive = random.choice(unused_positives) 130 | unused_elements_ndx.remove(second_positive) 131 | else: 132 | second_positive = random.choice(list(positives)) 133 | 134 | current_batch += [selected_element, second_positive] 135 | 136 | for batch in self.batch_idx: 137 | assert len(batch) % self.k == 0, 'Incorrect bach size: {}'.format(len(batch)) 138 | 139 | 140 | if __name__ == '__main__': 141 | pass 142 | 143 | -------------------------------------------------------------------------------- /src/data/datasets/southbay/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Generate evaluation sets 2 | # - Map point clouds are taken from MapData folder 3 | # - Query point clouds are taken from TestData 4 | # For each area (BaylandsToSeafood, ColumbiaPark, HighWay237, MathildaAVE, SanJoseDowntown, SunnyvaleBigloop) a 5 | # separate evaluation set is crated. 
We do not match clouds from different areas. 6 | 7 | # This file is directly copied from: https://github.com/jac99/Egonn/blob/main/datasets/southbay/generate_evaluation_sets.py 8 | 9 | import argparse 10 | import numpy as np 11 | from typing import List 12 | import os 13 | import sys 14 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 15 | 16 | from datasets.southbay.southbay_raw import SouthBayDataset 17 | from datasets.base_datasets import EvaluationTuple, EvaluationSet, filter_query_elements 18 | 19 | def get_scans(ds: SouthBayDataset, split: str, area: str, min_displacement: float = 0.1) -> List[EvaluationTuple]: 20 | elems = [] 21 | for ndx in ds.location_ndx[split][area]: 22 | pose = ds.global_ndx[ndx].pose 23 | position = pose[0:2, 3] # (x, y) position in global coordinate frame 24 | rel_scan_filepath = ds.global_ndx[ndx].rel_scan_filepath 25 | timestamp = ds.global_ndx[ndx].timestamp 26 | 27 | item = EvaluationTuple(timestamp, rel_scan_filepath, position=position, pose=pose) 28 | elems.append(item) 29 | 30 | print(f"{len(elems)} total elements in {split} split") 31 | 32 | # Filter-out elements leaving only 1 per grid cell with min_displacement size 33 | pos = np.zeros((len(elems), 2), dtype=np.float32) 34 | for ndx, e in enumerate(elems): 35 | pos[ndx] = e.position 36 | 37 | # Quantize x-y coordinates. Quantized coords start from 0 38 | pos = np.floor(pos / min_displacement) 39 | pos = pos.astype(int) 40 | _, unique_ndx = np.unique(pos, axis=0, return_index=True) 41 | 42 | # Leave only unique elements 43 | elems = [elems[i] for i in unique_ndx] 44 | print(f"{len(elems)} filtered elements in {split} split with grid cell size = {min_displacement}") 45 | 46 | return elems 47 | 48 | 49 | def generate_evaluation_set(ds: SouthBayDataset, area: str, min_displacement: float = 0.1, dist_threshold=5) -> \ 50 | EvaluationSet: 51 | map_set = get_scans(ds, 'MapData', area, min_displacement) 52 | query_set = get_scans(ds, 'TestData', area, min_displacement) 53 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 54 | print(f'Area: {area} - {len(map_set)} database elements, {len(query_set)} query elements\n') 55 | return EvaluationSet(query_set, map_set) 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Apollo SouthBay dataset') 60 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/Apollo-Southbay') 61 | parser.add_argument('--min_displacement', type=float, default=1.0) 62 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 63 | parser.add_argument('--dist_threshold', type=float, default=5) 64 | 65 | args = parser.parse_args() 66 | print(f'Dataset root: {args.dataset_root}') 67 | print(f'Minimum displacement between scans in each set (map/query): {args.min_displacement}') 68 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 69 | 70 | ds = SouthBayDataset(args.dataset_root) 71 | ds.print_info() 72 | 73 | min_displacement = args.min_displacement 74 | 75 | area = 'SunnyvaleBigloop' # Evaluation area 76 | assert area in ds.location_ndx['TestData'] 77 | eval_set = generate_evaluation_set(ds, area, min_displacement=min_displacement, 78 | dist_threshold=args.dist_threshold) 79 | pickle_name = f'test_{area}_{args.min_displacement}_{args.dist_threshold}_20m.pickle' 80 | file_path_name = os.path.join(args.dataset_root, pickle_name) 
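# NOTE: this evaluation pickle is written into the dataset root, whereas the MuLRAN script above
# saves its pickle next to the script file via os.path.dirname(__file__).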
81 | eval_set.save(file_path_name) 82 | -------------------------------------------------------------------------------- /src/data/datasets/southbay/generate_training_tuples.py: -------------------------------------------------------------------------------- 1 | # Generate training triplets 2 | 3 | import argparse 4 | import os 5 | import pickle 6 | import sys 7 | 8 | import numpy as np 9 | import tqdm 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 12 | 13 | from datasets.base_datasets import TrainingTuple 14 | from datasets.southbay.southbay_raw import SouthBayDataset 15 | 16 | 17 | class Triplet: 18 | def __init__(self, anchor: int, positives: np.ndarray, non_negatives: np.ndarray): 19 | self.anchor = anchor 20 | self.positives = positives 21 | self.non_negatives = non_negatives 22 | 23 | 24 | def generate_triplets(ds: SouthBayDataset, map_split: str, query_split: str, 25 | positives_th: int = 2, negatives_th: int = 10, min_displacement: float = 0.1): 26 | # All elements (anchors, positives and negatives) are taken from both map_split and query_split 27 | assert positives_th < negatives_th 28 | 29 | # Create a master table with positions of all point clouds from the query split and in map_split 30 | pc_ids, pc_poses = ds.get_poses2([query_split, map_split]) 31 | pc_coords = pc_poses[:, :3, 3] 32 | 33 | # Quantize x-y coordinates 34 | pos = np.floor(pc_coords / min_displacement) 35 | pos = pos.astype(int) 36 | _, unique_ndx = np.unique(pos, axis=0, return_index=True) 37 | 38 | # Leave only unique elements 39 | pc_ids = pc_ids[unique_ndx] 40 | pc_coords = pc_coords[unique_ndx] 41 | print(f'{len(pc_ids)} point clouds left from {len(pc_poses)} after filtering with min_displacement={min_displacement}') 42 | 43 | triplets = [] 44 | count_zero_positives = 0 45 | for anchor_id in tqdm.tqdm(pc_ids): 46 | anchor_coords = ds.global_ndx[anchor_id].pose[:3, 3] 47 | dist = np.linalg.norm(pc_coords - anchor_coords, axis=1) 48 | positives_mask = dist <= positives_th 49 | non_negatives_mask = dist <= negatives_th 50 | 51 | positives_ndx = pc_ids[positives_mask] 52 | # remove anchor_id from positives 53 | positives_ndx = np.array([e for e in positives_ndx if e != anchor_id]) 54 | non_negatives_ndx = pc_ids[non_negatives_mask] 55 | 56 | if len(positives_ndx) == 0: 57 | # Skip examples without positives 58 | count_zero_positives += 1 59 | continue 60 | 61 | t = Triplet(anchor_id, positives_ndx, non_negatives_ndx) 62 | triplets.append(t) 63 | 64 | print(f'{count_zero_positives} filtered out due to no positives') 65 | print(f'{len(triplets)} training tuples generated') 66 | 67 | # Remove ids from positives and negatives that are not anchors 68 | anchors_set = set([e.anchor for e in triplets]) 69 | triplets = [Triplet(e.anchor, [p for p in e.positives if p in anchors_set], 70 | [nn for nn in e.non_negatives if nn in anchors_set]) for e in triplets] 71 | print(len(triplets)) 72 | # All used global ids 73 | used_ids = set() 74 | for triplet in triplets: 75 | used_ids.add(triplet.anchor) 76 | used_ids.update(list(triplet.positives)) 77 | used_ids.update(list(triplet.non_negatives)) 78 | 79 | # New ids, consecutive and starting from 0 80 | new_ids = {old_ndx: ndx for ndx, old_ndx in enumerate(used_ids)} 81 | 82 | tuples = {} 83 | for triplet in triplets: 84 | new_anchor_ndx = new_ids[triplet.anchor] 85 | pc = ds.global_ndx[triplet.anchor] 86 | positives = np.array([new_ids[e] for e in triplet.positives], dtype=np.int32) 87 | non_negatives = np.array([new_ids[e] for e in 
triplet.non_negatives], dtype=np.int32) 88 | 89 | # Sort ascending order 90 | positives = np.sort(positives) 91 | non_negatives = np.sort(non_negatives) 92 | 93 | tuple = TrainingTuple(id=new_anchor_ndx, timestamp=pc.timestamp, 94 | rel_scan_filepath= ds.dataset_root + '/' +pc.rel_scan_filepath, 95 | positives=positives, non_negatives=non_negatives, 96 | pose=pc.pose, positives_poses=None) 97 | tuples[new_anchor_ndx] = tuple 98 | 99 | return tuples 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser(description='Generate training triplets for Apollo SouthBay dataset') 104 | parser.add_argument('--dataset_root', type=str, default='/data/raktim/Datasets/Apollo-Southbay', help='Path to Apollo SouthBay root folder') 105 | parser.add_argument('--pos_th', type=float, default=2, help='Positives threshold') 106 | parser.add_argument('--neg_th', type=float, default=10, help='Negatives threshold') 107 | parser.add_argument('--min_displacement', type=float, default=1.0) 108 | 109 | query_split = 'TrainData' 110 | 111 | args = parser.parse_args() 112 | print(f'Dataset root folder: {args.dataset_root}') 113 | print(f'Split for positives/negatives: {query_split}') 114 | print(f'Positives threshold: {args.pos_th}') 115 | print(f'Negatives threshold: {args.neg_th}') 116 | print(f'Minimum displacement between consecutive scans: {args.min_displacement}') 117 | 118 | ds = SouthBayDataset(args.dataset_root) 119 | ds.print_info() 120 | 121 | triplets = generate_triplets(ds, 'MapData', query_split, positives_th=args.pos_th, negatives_th=args.neg_th, 122 | min_displacement=args.min_displacement) 123 | print(f'{len(triplets)} anchors generated') 124 | 125 | pickle_name = f'train_southbay_{args.pos_th}_{args.neg_th}.pickle' 126 | pickle_filepath = os.path.join(args.dataset_root, pickle_name) 127 | pickle.dump(triplets, open(pickle_filepath, 'wb')) 128 | -------------------------------------------------------------------------------- /src/data/datasets/southbay/test_SunnyvaleBigloop_1.0_5_20m.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/southbay/test_SunnyvaleBigloop_1.0_5_20m.pickle -------------------------------------------------------------------------------- /src/evaluate/SALSA/sgv_utils.py: -------------------------------------------------------------------------------- 1 | # Functions in this file are adapted from: https://github.com/ZhiChen902/SC2-PCR/blob/main/SC2_PCR.py 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features): 8 | # normalize: 9 | # print(src_features.shape, tgt_features.shape) 10 | src_features = torch.nn.functional.normalize(src_features, p=2.0, dim=1) 11 | tgt_features = torch.nn.functional.normalize(tgt_features, p=2.0, dim=1) 12 | distance = torch.cdist(src_features, tgt_features) 13 | min_vals, min_ids = torch.min(distance, dim=2) 14 | 15 | min_ids = min_ids.unsqueeze(-1).expand(-1, -1, 3) 16 | tgt_keypts_corr = torch.gather(tgt_keypts, 1, min_ids) 17 | src_keypts_corr = src_keypts 18 | 19 | return src_keypts_corr, tgt_keypts_corr 20 | 21 | # def match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features): 22 | # # normalize: 23 | # src_features = torch.nn.functional.normalize(src_features, p=2.0, dim=1) 24 | # tgt_features = torch.nn.functional.normalize(tgt_features, p=2.0, dim=1) 25 | # # print(src_features.shape, 
tgt_features.shape) 26 | # distance = torch.cdist(src_features, tgt_features) 27 | 28 | # min_vals, min_ids = torch.min(distance, dim=2) 29 | 30 | # # print(distance.shape) 31 | # sorted_vals, sorted_ids = torch.sort(min_vals, dim=1) 32 | # sorted_ids = sorted_ids[:,:2000] 33 | # # print(sorted_ids.shape) 34 | # # print(sorted_ids) 35 | # # print(min_ids.shape) 36 | # # print(min_vals[:,sorted_ids[0]]) 37 | # # print(min_ids[:,sorted_ids[0]]) 38 | 39 | # tgt_ids = min_ids[:,sorted_ids[0]].unsqueeze(-1).expand(-1, -1, 3) 40 | # src_ids = sorted_ids.unsqueeze(-1).expand(-1, -1, 3) 41 | 42 | # tgt_keypts_corr = torch.gather(tgt_keypts, 1, tgt_ids) 43 | # # print(tgt_keypts_corr.shape) 44 | # src_keypts_corr = torch.gather(src_keypts, 1, src_ids) 45 | # # print(src_keypts_corr.shape) 46 | 47 | # return src_keypts_corr, tgt_keypts_corr 48 | 49 | def power_iteration(M, num_iterations=5): 50 | """ 51 | Calculate the leading eigenvector using power iteration algorithm 52 | Input: 53 | - M: [bs, num_pts, num_pts] the adjacency matrix 54 | Output: 55 | - leading_eig: [bs, num_pts] leading eigenvector 56 | """ 57 | leading_eig = torch.ones_like(M[:, :, 0:1]) 58 | leading_eig_last = leading_eig 59 | for i in range(num_iterations): 60 | # print(i) 61 | leading_eig = torch.bmm(M, leading_eig) 62 | leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6) 63 | if torch.allclose(leading_eig, leading_eig_last): 64 | break 65 | leading_eig_last = leading_eig 66 | leading_eig = leading_eig.squeeze(-1) 67 | return leading_eig 68 | 69 | 70 | def cal_spatial_consistency( M, leading_eig): 71 | """ 72 | Calculate the spatial consistency based on spectral analysis. 73 | Input: 74 | - M: [bs, num_pts, num_pts] the adjacency matrix 75 | - leading_eig [bs, num_pts] the leading eigenvector of matrix M 76 | Output: 77 | - sc_score_list [bs, 1] 78 | """ 79 | spatial_consistency = leading_eig[:, None, :] @ M @ leading_eig[:, :, None] 80 | spatial_consistency = spatial_consistency.squeeze(-1) / M.shape[1] 81 | return spatial_consistency 82 | 83 | 84 | def sgv(src_keypts, tgt_keypts, src_features, tgt_features, d_thresh=5.0): 85 | """ 86 | Input: 87 | - src_keypts: [1, num_pts, 3] 88 | - tgt_keypts: [bs, num_pts, 3] 89 | - src_features: [1, num_pts, D] 90 | - tgt_features: [bs, num_pts, D] 91 | Output: 92 | - sc_score_list: [bs, 1], spatial consistency score for each candidate 93 | """ 94 | # print('src_kp', src_keypts.shape) 95 | # print('tgt_kp', tgt_keypts.shape) 96 | # Correspondence Estimation: Nearest Neighbour Matching 97 | src_keypts_corr, tgt_keypts_corr = match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features) 98 | # print('src',src_keypts_corr.shape) 99 | # print('tgt',tgt_keypts_corr.shape) 100 | 101 | # Spatial Consistency Adjacency Matrix 102 | src_dist = torch.norm((src_keypts_corr[:, :, None, :] - src_keypts_corr[:, None, :, :]), dim=-1) 103 | target_dist = torch.norm((tgt_keypts_corr[:, :, None, :] - tgt_keypts_corr[:, None, :, :]), dim=-1) 104 | cross_dist = torch.abs(src_dist - target_dist) 105 | adj_mat = torch.clamp(1.0 - cross_dist ** 2 / d_thresh ** 2, min=0) 106 | # print(adj_mat) 107 | # Spatial Consistency Score 108 | # print(adj_mat.shape) 109 | lead_eigvec = power_iteration(adj_mat) 110 | # print(lead_eigvec.shape) 111 | sc_score_list = cal_spatial_consistency(adj_mat, lead_eigvec) 112 | 113 | sc_score_list = np.squeeze(sc_score_list.cpu().detach().numpy()) 114 | return sc_score_list 115 | 116 | def sgv_fn(query_keypoints, candidate_keypoints, 
d_thresh=5.0, max_points=15000): 117 | 118 | # print(len(query_keypoints),len(candidate_keypoints)) 119 | kp1 = query_keypoints['keypoints'] 120 | kp2 = candidate_keypoints['keypoints'] 121 | f1 = query_keypoints['features'] 122 | f2 = candidate_keypoints['features'] 123 | 124 | # draw_registration_result(kp1, kp2, np.eye(4)) 125 | 126 | min_num_feat = min(len(kp1),len(kp2)) 127 | min_num_feat = min(min_num_feat, max_points) 128 | kp1 = kp1[:min_num_feat] 129 | kp2 = kp2[:min_num_feat] 130 | f1 = f1[:min_num_feat] 131 | f2 = f2[:min_num_feat] 132 | 133 | src_keypts = kp1.unsqueeze(0).cuda() 134 | tgt_keypts = kp2.unsqueeze(0).cuda() 135 | src_features = f1.unsqueeze(0).cuda() 136 | tgt_features = f2.unsqueeze(0).cuda() 137 | 138 | conf = sgv(src_keypts, tgt_keypts, src_features, tgt_features, d_thresh=d_thresh) 139 | 140 | 141 | 142 | return conf 143 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/loss/__init__.py -------------------------------------------------------------------------------- /src/loss/global_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def triplet_margin_loss(query, positive, negative, margin=0.1): 5 | distance_positive = torch.norm(query - positive, dim=1, p=2)**2 6 | distance_negative = torch.norm(query - negative, dim=1, p=2)**2 7 | losses = torch.relu(distance_positive - distance_negative + margin) 8 | loss = torch.mean(losses) 9 | return loss 10 | 11 | # def triplet_margin_loss(query, positive, negative, margin=0.1): 12 | # distance_positive = F.cosine_similarity(query, positive, dim=1) 13 | # distance_negative = F.cosine_similarity(query, negative, dim=1) 14 | # losses = torch.relu(distance_negative - distance_positive + margin) 15 | # loss = torch.mean(losses) 16 | # return loss 17 | -------------------------------------------------------------------------------- /src/loss/local_consistency_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | # sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | from utils.misc_utils import pdist, hashM 8 | 9 | 10 | def point_contrastive_loss(F0, F1, positive_pairs, point_pos_margin=0.1, point_neg_margin=2.0,point_neg_weight=1.0, 11 | num_pos=5192, 12 | num_hn_samples=2048): 13 | """ 14 | Randomly select 'num-pos' positive pairs. 15 | Find the hardest-negative (from a random subset of num_hn_samples) for each point in a positive pair. 
16 | Calculate contrastive loss on the tuple (p1,p2,hn1,hn2) 17 | Based on: https://github.com/chrischoy/FCGF/blob/master/lib/trainer.py 18 | """ 19 | N0, N1 = len(F0), len(F1) 20 | N_pos_pairs = len(positive_pairs) 21 | hash_seed = max(N0, N1) 22 | sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False) 23 | sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False) 24 | 25 | if N_pos_pairs > num_pos: 26 | pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False) 27 | sample_pos_pairs = positive_pairs[pos_sel] 28 | else: 29 | sample_pos_pairs = positive_pairs 30 | 31 | # Find negatives for all F1[positive_pairs[:, 1]] 32 | subF0, subF1 = F0[sel0], F1[sel1] 33 | 34 | pos_ind0 = sample_pos_pairs[:, 0] # .long() 35 | pos_ind1 = sample_pos_pairs[:, 1] # .long() 36 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1] 37 | 38 | D01 = pdist(posF0, subF1, dist_type='L2') 39 | D10 = pdist(posF1, subF0, dist_type='L2') 40 | 41 | D01min, D01ind = D01.min(1) 42 | D10min, D10ind = D10.min(1) 43 | 44 | if not isinstance(positive_pairs, np.ndarray): 45 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 46 | 47 | pos_keys = hashM(positive_pairs, hash_seed) 48 | 49 | D01ind = sel1[D01ind.cpu().numpy()] 50 | D10ind = sel0[D10ind.cpu().numpy()] 51 | neg_keys0 = hashM([pos_ind0, D01ind], hash_seed) 52 | neg_keys1 = hashM([D10ind, pos_ind1], hash_seed) 53 | 54 | mask0 = torch.from_numpy( 55 | np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False))) 56 | mask1 = torch.from_numpy( 57 | np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False))) 58 | pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - point_pos_margin) 59 | neg_loss0 = F.relu(point_neg_margin - D01min[mask0]).pow(2) 60 | neg_loss1 = F.relu(point_neg_margin - D10min[mask1]).pow(2) 61 | 62 | pos_loss = pos_loss.mean() 63 | neg_loss = (neg_loss0.mean() + neg_loss1.mean()) / 2 64 | loss = pos_loss + point_neg_weight * neg_loss 65 | return loss 66 | 67 | 68 | def point_infonce_loss(query_feats, pos_feats, pos_pairs, neg_pairs, config): # TODO 69 | return 0 70 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | 6 | from loss.global_loss import triplet_margin_loss 7 | from loss.local_consistency_loss import point_contrastive_loss 8 | 9 | def find_loss(local_features, global_descriptors,point_pos_pairs): 10 | global_loss = 0 11 | local_loss = 0 12 | batch_size = int(global_descriptors.shape[0]/3) 13 | for i in range(batch_size): 14 | 15 | ## Global descriptor triplet loss ################################################### 16 | q_gd = global_descriptors[i][None,...] 17 | p_gd = global_descriptors[(1*batch_size) + i][None,...] 18 | n_gd = global_descriptors[(2*batch_size) + i][None,...] 
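# NOTE: global_descriptors is assumed to be stacked as [queries | positives | negatives] along
# dim 0, so for batch_size B the i-th query, its positive and its negative sit at indices
# i, B + i and 2*B + i; local_features is expected to follow the same layout for the point loss.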
19 | 20 | 21 | global_loss += triplet_margin_loss(q_gd, p_gd, n_gd, margin=0.1) 22 | 23 | ## Local features loss ############################################################### 24 | ## Point triplet loss ################### 25 | if point_pos_pairs[i].shape[0]>0: 26 | point_loss = point_contrastive_loss(local_features[i], local_features[(1*batch_size) + i], point_pos_pairs[i]) 27 | else: 28 | point_loss = 0 29 | local_loss += point_loss 30 | ######################################### 31 | # loss = global_loss + local_loss 32 | loss = global_loss + local_loss 33 | return loss -------------------------------------------------------------------------------- /src/misc/point_clouds.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | 8 | def draw_registration_result(source, target, transformation): 9 | source_temp = copy.deepcopy(source) 10 | target_temp = copy.deepcopy(target) 11 | source_temp.paint_uniform_color([1, 0.706, 0]) 12 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 13 | source_temp.transform(transformation) 14 | o3d.visualization.draw_geometries([source_temp, target_temp], 15 | zoom=0.4459, 16 | front=[0.9288, -0.2951, -0.2242], 17 | lookat=[1.6784, 2.0612, 1.4451], 18 | up=[-0.3402, -0.9189, -0.1996]) 19 | 20 | 21 | def draw_pc(pc): 22 | pc = copy.deepcopy(pc) 23 | pc.paint_uniform_color([1, 0.706, 0]) 24 | o3d.visualization.draw_geometries([pc], 25 | zoom=0.4459, 26 | front=[0.9288, -0.2951, -0.2242], 27 | lookat=[1.6784, 2.0612, 1.4451], 28 | up=[-0.3402, -0.9189, -0.1996]) 29 | 30 | 31 | def icp(anchor_pc, positive_pc, transform: np.ndarray = None, point2plane: bool = False, 32 | inlier_dist_threshold: float = 1.2, max_iteration: int = 200): 33 | # transform: initial alignment transform 34 | if transform is not None: 35 | transform = transform.astype(float) 36 | 37 | voxel_size = 0.1 38 | pcd1 = o3d.geometry.PointCloud() 39 | pcd1.points = o3d.utility.Vector3dVector(anchor_pc) 40 | pcd1 = pcd1.voxel_down_sample(voxel_size=voxel_size) 41 | 42 | pcd2 = o3d.geometry.PointCloud() 43 | pcd2.points = o3d.utility.Vector3dVector(positive_pc) 44 | pcd2 = pcd2.voxel_down_sample(voxel_size=voxel_size) 45 | 46 | if point2plane: 47 | pcd1.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 48 | pcd2.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 49 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPlane() 50 | else: 51 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint() 52 | 53 | if transform is not None: 54 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, transform, 55 | estimation_method=transform_estimation, 56 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 57 | else: 58 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, 59 | estimation_method=transform_estimation, 60 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 61 | 62 | return reg_p2p.transformation, reg_p2p.fitness, reg_p2p.inlier_rmse 63 | 64 | 65 | def make_open3d_feature(data, dim, npts): 66 | feature = o3d.pipelines.registration.Feature() 67 | feature.resize(dim, npts) 68 | feature.data = data.cpu().numpy().astype('d').transpose() 69 | return feature 70 | 71 | 72 | def make_open3d_point_cloud(xyz, color=None): 73 | pcd = 
o3d.geometry.PointCloud() 74 | pcd.points = o3d.utility.Vector3dVector(xyz) 75 | if color is not None: 76 | pcd.colors = o3d.utility.Vector3dVector(color) 77 | return pcd 78 | 79 | 80 | class PointCloudLoader: 81 | # Generic point cloud loader class 82 | def __init__(self): 83 | # remove_zero_points: remove points with all zero coordinates 84 | # remove_ground_plane: remove points on ground plane level and below 85 | # ground_plane_level: ground plane level 86 | self.remove_zero_points = True 87 | self.remove_ground_plane = True 88 | self.ground_plane_level = None 89 | self.set_properties() 90 | 91 | def set_properties(self): 92 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes. 93 | raise NotImplementedError('set_properties must be defined in inherited classes') 94 | 95 | def __call__(self, file_pathname): 96 | # Reads the point cloud from a disk and preprocess (optional removal of zero points and points on the ground 97 | # plane and below 98 | # file_pathname: relative file path 99 | assert os.path.exists(file_pathname), f"Cannot open point cloud: {file_pathname}" 100 | pc = self.read_pc(file_pathname) 101 | # assert pc.shape[1] == 3 102 | if self.remove_zero_points: 103 | mask = np.all(np.isclose(pc, 0), axis=1) 104 | pc = pc[~mask] 105 | if self.remove_ground_plane: 106 | mask = pc[:, 2] > self.ground_plane_level 107 | pc = pc[mask] 108 | 109 | return pc 110 | 111 | def read_pc(self, file_pathname): 112 | # Reads the point cloud without pre-processing 113 | raise NotImplementedError("read_pc must be overloaded in an inheriting class") 114 | -------------------------------------------------------------------------------- /src/misc/poses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def q2r(q): 6 | # Rotation matrix from Hamiltonian quaternion 7 | # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles 8 | w, x, y, z = tuple(q) 9 | 10 | n = 1.0/np.sqrt(x*x+y*y+z*z+w*w) 11 | x *= n 12 | y *= n 13 | z *= n 14 | w *= n 15 | r = np.array([[1.0 - 2.0*y*y - 2.0*z*z, 2.0*x*y - 2.0*z*w, 2.0*x*z + 2.0*y*w], 16 | [2.0*x*y + 2.0*z*w, 1.0 - 2.0*x*x - 2.0*z*z, 2.0*y*z - 2.0*x*w], 17 | [2.0*x*z - 2.0*y*w, 2.0*y*z + 2.0*x*w, 1.0 - 2.0*x*x - 2.0*y*y]]) 18 | return r 19 | 20 | 21 | def m2ypr(m): 22 | # Get yaw, pitch, roll angles from 4x4 transformation matrix 23 | # Based on formulas in Section 2.5.1 in: 24 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 25 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 26 | assert m.shape == (4, 4) 27 | pitch = np.arctan2(-m[2][0], np.sqrt(m[0][0]**2 + m[1][0]**2)) 28 | # We do not handle degenerate case, when pitch is 90 degrees a.k.a. 
gimbal lock 29 | assert not np.isclose(np.abs(pitch), np.pi/2) 30 | yaw = np.arctan2(m[1][0], m[0][0]) 31 | roll = np.arctan2(m[2][1], m[2][2]) 32 | return yaw, pitch, roll 33 | 34 | 35 | def m2xyz_ypr(m): 36 | # Get translation (x, y, z) and yaw, pitch, roll angles from 4x4 transformation matrix 37 | # Based on formulas in Section 2.5.1 in: 38 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 39 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 40 | assert m.shape == (4, 4) 41 | yaw, pitch, roll = m2ypr(m) 42 | return m[0, 3], m[1, 3], m[2, 3], yaw, pitch, roll 43 | 44 | 45 | def ypr2m(yaw, pitch, roll): 46 | # Construct 4x4 transformation matrix with rotation part set to given yaw, pitch, roll. Translation is set to 0. 47 | # Based on formulas in Section 2.2.1 48 | m = np.array([[np.cos(yaw) * np.cos(pitch), np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll), 49 | np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll), 0.], 50 | [np.sin(yaw) * np.cos(pitch), np.sin(yaw) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll), 51 | np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll), 0.], 52 | [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll), 0.], 53 | [0., 0., 0., 1.]], dtype=np.float32) 54 | 55 | return m 56 | 57 | 58 | def xyz_ypr2m(x, y, z, yaw, pitch, roll): 59 | # Construct 4x4 transformation matrix with given translation and rotation part set to given yaw, pitch, roll. 60 | # Based on formulas in Section 2.2.1 61 | m = ypr2m(yaw, pitch, roll) 62 | m[0, 3] = x 63 | m[1, 3] = y 64 | m[2, 3] = z 65 | return m 66 | 67 | 68 | def apply_transform(pc: torch.Tensor, m: torch.Tensor): 69 | # Apply 4x4 SE(3) transformation matrix on (N, 3) point cloud or 3x3 transformation on (N, 2) point cloud 70 | assert pc.ndim == 2 71 | n_dim = pc.shape[1] 72 | assert n_dim == 2 or n_dim == 3 73 | assert m.shape == (n_dim + 1, n_dim + 1) 74 | # (m @ pc.t).t = pc @ m.t 75 | pc = pc @ m[:n_dim, :n_dim].transpose(1, 0) + m[:n_dim, -1] 76 | return pc 77 | 78 | 79 | def relative_pose(m1, m2): 80 | # !!! DO NOT USE THIS FUNCTION FOR MULRAN POSES !!!
81 | # SE(3) pose is 4x 4 matrix, such that 82 | # Pw = [R | T] @ [P] 83 | # [0 | 1] [1] 84 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame 85 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame 86 | # m2: coords in camera/lidar2 coords -> world coordinate frame 87 | # returns: coords in camera/lidar1 reference frame -> coords in camera/lidar2 reference frame 88 | # relative pose of the first camera with respect to the second camera 89 | return np.linalg.inv(m2) @ m1 90 | -------------------------------------------------------------------------------- /src/models/Mixer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/Mixer/__init__.py -------------------------------------------------------------------------------- /src/models/Mixer/mixer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 10 | 11 | 12 | class FeatureMixerLayer(nn.Module): 13 | def __init__(self, in_dim, mlp_ratio=1): 14 | super().__init__() 15 | self.mix = nn.Sequential( 16 | nn.LayerNorm(in_dim), 17 | nn.Linear(in_dim, int(in_dim * mlp_ratio)), 18 | nn.ReLU(), 19 | nn.Linear(int(in_dim * mlp_ratio), in_dim), 20 | ) 21 | 22 | for m in self.modules(): 23 | if isinstance(m, (nn.Linear)): 24 | nn.init.trunc_normal_(m.weight, std=0.02) 25 | if m.bias is not None: 26 | nn.init.zeros_(m.bias) 27 | 28 | def forward(self, x): 29 | return x + self.mix(x) 30 | 31 | 32 | class Mixer(nn.Module): 33 | def __init__(self, 34 | in_channels=35000, 35 | out_channels=1000, 36 | in_d=30, 37 | mix_depth=1, 38 | mlp_ratio=1, 39 | out_d=4, 40 | ) -> None: 41 | super().__init__() 42 | 43 | self.in_d = in_d 44 | 45 | self.out_d = out_d # row wise projection dimesion 46 | 47 | self.mix_depth = mix_depth # L the number of stacked FeatureMixers 48 | self.mlp_ratio = mlp_ratio # ratio of the mid projection layer in the mixer block 49 | 50 | self.mix = nn.Sequential(*[ 51 | FeatureMixerLayer(in_dim=in_d, mlp_ratio=mlp_ratio) 52 | for _ in range(self.mix_depth) 53 | ]) 54 | self.row_proj = nn.Linear(in_d, out_d) 55 | self.channel_proj = nn.Linear(in_channels, out_channels) 56 | 57 | def forward(self, x): 58 | # x = x.unsqueeze(0) 59 | x = self.mix(x) 60 | x = x.permute(0, 2, 1) 61 | x = self.channel_proj(x) 62 | x = x.permute(0, 2, 1) 63 | x = self.row_proj(x) 64 | x = F.normalize(x.flatten(1), p=2, dim=-1) 65 | return x 66 | 67 | 68 | # ------------------------------------------------------------------------------- 69 | 70 | def print_nb_params(m): 71 | model_parameters = filter(lambda p: p.requires_grad, m.parameters()) 72 | params = sum([np.prod(p.size()) for p in model_parameters]) 73 | print(f'Trainable parameters: {params/1e6:.3}M') 74 | 75 | 76 | def main(): 77 | 78 | model = Mixer().to('cuda') 79 | print_nb_params(model) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | build/ 3 | dist/ 4 | */__pycache__/ 5 | *.pyc 6 | 
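A minimal usage sketch for the `Mixer` aggregator defined in `src/models/Mixer/mixer.py` above. The shapes follow the defaults in that file; the import path assumes `src/` is on `PYTHONPATH`, and the example is illustrative only rather than part of the training pipeline.

```
import torch
from models.Mixer.mixer import Mixer  # assumes src/ is on PYTHONPATH

# Illustrative shapes only: a batch of 2 point clouds, each with 35000 local
# descriptors of dimension 30 (the defaults used by Mixer above).
local_feats = torch.rand(2, 35000, 30).cuda()

model = Mixer(in_channels=35000, out_channels=1000, in_d=30, out_d=4).cuda()
global_desc = model(local_feats)

# channel_proj maps 35000 -> 1000 tokens, row_proj maps 30 -> 4 dims, and the
# flattened output is L2-normalised, giving a descriptor of shape [2, 4000].
print(global_desc.shape)
```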
-------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/README.md: -------------------------------------------------------------------------------- 1 | # SpTr: PyTorch Spatially Sparse Transformer Library 2 | 3 | This library provides a **fast** and **memory-efficient** implementation of sparse transformers with **varying token numbers** (e.g., window transformers for 3D point clouds). 4 | 5 | This library has been used by the following works: 6 | 7 | * Spherical Transformer for LiDAR-based 3D Recognition (CVPR 2023): \[Paper\] [\[Code\]](https://github.com/dvlab-research/SphereFormer) 8 | 9 | * Stratified Transformer for 3D Point Cloud Segmentation (CVPR 2022): [\[Paper\]](https://openaccess.thecvf.com/content/CVPR2022/papers/Lai_Stratified_Transformer_for_3D_Point_Cloud_Segmentation_CVPR_2022_paper.pdf) [\[Code\]](https://github.com/dvlab-research/Stratified-Transformer) 10 | 11 | ## Installation 12 | ### Install Dependency 13 | ``` 14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 15 | pip install torch_scatter 16 | pip install torch_geometric 17 | ``` 18 | 19 | ### Compile sptr 20 | ``` 21 | python3 setup.py install 22 | ``` 23 | 24 | 25 | ## Usage 26 | SpTr can be easily used in most current transformer-based 3D point cloud networks, with only a few minor modifications. First, define the attention module `sptr.VarLengthMultiheadSA`. Then, wrap the input features and indices into `sptr.SparseTrTensor`, and forward it into the module. That's all. A simple example is as follows. For more complex usage, you can refer to the code of the above works (e.g., SphereFormer, StratifiedFormer). 27 | ### Example 28 | ``` 29 | import sptr 30 | 31 | # Define module 32 | dim = 48 33 | num_heads = 3 34 | indice_key = 'sptr_0' 35 | window_size = np.array([0.4, 0.4, 0.4]) # can also be integers for voxel-based methods 36 | shift_win = False # whether to adopt shifted window 37 | self.attn = sptr.VarLengthMultiheadSA( 38 | dim, 39 | num_heads, 40 | indice_key, 41 | window_size, 42 | shift_win 43 | ) 44 | 45 | # Wrap the input features and indices into SparseTrTensor. Note: indices can be either integers for voxel-based methods or floats (i.e., xyz) for point-based methods 46 | # feats: [N, C], indices: [N, 4] with batch indices in the 0-th column 47 | input_tensor = sptr.SparseTrTensor(feats, indices, spatial_shape=None, batch_size=None) 48 | output_tensor = self.attn(input_tensor) 49 | 50 | # Extract features from output tensor 51 | output_feats = output_tensor.query_feats 52 | ``` 53 | 54 | ## Authors 55 | 56 | Xin Lai (a Ph.D. student at CSE CUHK) - Initial CUDA implementation, maintenance. 57 | 58 | Fanbin Lu (a Ph.D. student at CSE CUHK) - Improved CUDA implementation, maintenance. 59 | 60 | Yukang Chen (a Ph.D. student at CSE CUHK) - Maintenance.
61 | 62 | 63 | ## Cite 64 | 65 | If you find this project useful, please consider citing 66 | ``` 67 | @inproceedings{lai2023spherical, 68 | title={Spherical Transformer for LiDAR-based 3D Recognition}, 69 | author={Lai, Xin and Chen, Yukang and Lu, Fanbin and Liu, Jianhui and Jia, Jiaya}, 70 | booktitle={CVPR}, 71 | year={2023} 72 | } 73 | ``` 74 | ``` 75 | @inproceedings{lai2022stratified, 76 | title={Stratified transformer for 3d point cloud segmentation}, 77 | author={Lai, Xin and Liu, Jianhui and Jiang, Li and Wang, Liwei and Zhao, Hengshuang and Liu, Shu and Qi, Xiaojuan and Jia, Jiaya}, 78 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 79 | pages={8500--8509}, 80 | year={2022} 81 | } 82 | ``` 83 | 84 | ## License 85 | 86 | This project is licensed under the Apache license 2.0 License. 87 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/SphereFormer/SparseTransformer/__init__.py -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/setup.py: -------------------------------------------------------------------------------- 1 | #python3 setup.py install 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import os 5 | from distutils.sysconfig import get_config_vars 6 | 7 | (opt,) = get_config_vars('OPT') 8 | os.environ['OPT'] = " ".join( 9 | flag for flag in opt.split() if flag != '-Wstrict-prototypes' 10 | ) 11 | 12 | setup( 13 | name='sptr', 14 | ext_modules=[ 15 | CUDAExtension('sptr_cuda', [ 16 | 'src/sptr/pointops_api.cpp', 17 | 'src/sptr/attention/attention_cuda.cpp', 18 | 'src/sptr/attention/attention_cuda_kernel.cu', 19 | 'src/sptr/precompute/precompute.cpp', 20 | 'src/sptr/precompute/precompute_cuda_kernel.cu', 21 | 'src/sptr/rpe/relative_pos_encoding_cuda.cpp', 22 | 'src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu', 23 | ], 24 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 25 | ) 26 | ], 27 | cmdclass={'build_ext': BuildExtension} 28 | ) 29 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/sptr/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import * 2 | from .utils import * 3 | 4 | class SparseTrTensor(object): 5 | def __init__(self, query_feats, query_indices, spatial_shape, batch_size, key_feats=None, value_feats=None, key_indices=None): 6 | """ 7 | Args: 8 | query_feats: [num_points, num_features] feature tensor 9 | indices: [num_points, ndim + 1] indice tensor. 
batch index saved in indices[:, 0] 10 | spatial_shape: spatial shape of your sparse data 11 | batch_size: batch size of your sparse data 12 | """ 13 | self.query_feats = query_feats 14 | self.key_feats = key_feats 15 | self.value_feats = value_feats 16 | self.query_indices = query_indices 17 | self.key_indices = key_indices 18 | self.spatial_shape = spatial_shape 19 | self.batch_size = batch_size 20 | self.indice_dict = {} 21 | 22 | @property 23 | def spatial_size(self): 24 | return np.prod(self.spatial_shape) 25 | 26 | def find_indice_params(self, key): 27 | if key is None: 28 | return None 29 | if key in self.indice_dict: 30 | return self.indice_dict[key] 31 | return None 32 | 33 | from .modules import VarLengthMultiheadSA, sparse_self_attention 34 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/sptr/position_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | 11 | def shift_scale_points(pred_xyz, src_range, dst_range=None): 12 | """ 13 | pred_xyz: B x N x 3 14 | src_range: [[B x 3], [B x 3]] - min and max XYZ coords 15 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords 16 | """ 17 | if dst_range is None: 18 | dst_range = [ 19 | torch.zeros((src_range[0].shape[0], 3), device=src_range[0].device), 20 | torch.ones((src_range[0].shape[0], 3), device=src_range[0].device), 21 | ] 22 | 23 | if pred_xyz.ndim == 4: 24 | src_range = [x[:, None] for x in src_range] 25 | dst_range = [x[:, None] for x in dst_range] 26 | 27 | assert src_range[0].shape[0] == pred_xyz.shape[0] 28 | assert dst_range[0].shape[0] == pred_xyz.shape[0] 29 | assert src_range[0].shape[-1] == pred_xyz.shape[-1] 30 | assert src_range[0].shape == src_range[1].shape 31 | assert dst_range[0].shape == dst_range[1].shape 32 | assert src_range[0].shape == dst_range[1].shape 33 | 34 | src_diff = src_range[1][:, None, :] - src_range[0][:, None, :] 35 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :] 36 | prop_xyz = ( 37 | ((pred_xyz - src_range[0][:, None, :]) * dst_diff) / src_diff 38 | ) + dst_range[0][:, None, :] 39 | return prop_xyz 40 | 41 | 42 | class PositionEmbeddingCoordsSine(nn.Module): 43 | def __init__( 44 | self, 45 | temperature=10000, 46 | normalize=False, 47 | scale=None, 48 | pos_type="fourier", 49 | d_pos=None, 50 | d_in=3, 51 | gauss_scale=1.0 52 | ): 53 | super().__init__() 54 | self.d_pos = d_pos 55 | self.temperature = temperature 56 | self.normalize = normalize 57 | if scale is not None and normalize is False: 58 | raise ValueError("normalize should be True if scale is passed") 59 | if scale is None: 60 | scale = 2 * math.pi 61 | assert pos_type in ["sine", "fourier"] 62 | self.pos_type = pos_type 63 | self.scale = scale 64 | if pos_type == "fourier": 65 | assert d_pos is not None 66 | assert d_pos % 2 == 0 67 | # define a gaussian matrix input_ch -> output_ch 68 | B = torch.empty((d_in, d_pos // 2)).normal_() 69 | B *= gauss_scale 70 | self.register_buffer("gauss_B", B) 71 | self.d_pos = d_pos 72 | 73 | def get_sine_embeddings(self, xyz, num_channels, input_range): 74 | num_channels = self.d_pos 75 | # clone coords so that shift/scale operations do not affect original tensor 76 | orig_xyz = xyz 77 | xyz = orig_xyz.clone() 78 | 79 | ncoords = 
xyz.shape[1] 80 | if self.normalize: 81 | xyz = shift_scale_points(xyz, src_range=input_range) 82 | 83 | ndim = num_channels // xyz.shape[2] 84 | if ndim % 2 != 0: 85 | ndim -= 1 86 | # automatically handle remainder by assiging it to the first dim 87 | rems = num_channels - (ndim * xyz.shape[2]) 88 | 89 | assert ( 90 | ndim % 2 == 0 91 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}" 92 | 93 | final_embeds = [] 94 | prev_dim = 0 95 | 96 | for d in range(xyz.shape[2]): 97 | cdim = ndim 98 | if rems > 0: 99 | # add remainder in increments of two to maintain even size 100 | cdim += 2 101 | rems -= 2 102 | 103 | if cdim != prev_dim: 104 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device) 105 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim) 106 | 107 | # create batch x cdim x nccords embedding 108 | raw_pos = xyz[:, :, d] 109 | if self.scale: 110 | raw_pos *= self.scale 111 | pos = raw_pos[:, :, None] / dim_t 112 | pos = torch.stack( 113 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3 114 | ).flatten(2) 115 | final_embeds.append(pos) 116 | prev_dim = cdim 117 | 118 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 119 | return final_embeds 120 | 121 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None): 122 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 123 | 124 | if num_channels is None: 125 | num_channels = self.gauss_B.shape[1] * 2 126 | 127 | bsize, npoints = xyz.shape[0], xyz.shape[1] 128 | assert num_channels > 0 and num_channels % 2 == 0 129 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1] 130 | d_out = num_channels // 2 131 | assert d_out <= max_d_out 132 | assert d_in == xyz.shape[-1] 133 | 134 | # clone coords so that shift/scale operations do not affect original tensor 135 | orig_xyz = xyz 136 | xyz = orig_xyz.clone() 137 | 138 | ncoords = xyz.shape[1] 139 | if self.normalize: 140 | xyz = shift_scale_points(xyz, src_range=input_range) 141 | 142 | xyz *= 2 * np.pi 143 | xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view( 144 | bsize, npoints, d_out 145 | ) 146 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()] 147 | 148 | # return batch x d_pos x npoints embedding 149 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 150 | return final_embeds 151 | 152 | def forward(self, xyz, num_channels=None, input_range=None): 153 | assert isinstance(xyz, torch.Tensor) 154 | assert xyz.ndim == 3 155 | # xyz is batch x npoints x 3 156 | if self.pos_type == "sine": 157 | with torch.no_grad(): 158 | out = self.get_sine_embeddings(xyz, num_channels, input_range) 159 | elif self.pos_type == "fourier": 160 | with torch.no_grad(): 161 | out = self.get_fourier_embeddings(xyz, num_channels, input_range) 162 | else: 163 | raise ValueError(f"Unknown {self.pos_type}") 164 | 165 | return out 166 | 167 | def extra_repr(self): 168 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}" 169 | if hasattr(self, "gauss_B"): 170 | st += ( 171 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}" 172 | ) 173 | return st -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/sptr/utils.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | import numpy as np 4 | from torch_scatter import segment_csr, gather_csr 5 | from torch_geometric.nn import 
voxel_grid 6 | from . import precompute_all 7 | 8 | 9 | def to_3d_numpy(size): 10 | if isinstance(size, numbers.Number): 11 | size = np.array([size, size, size]).astype(np.float32) 12 | elif isinstance(size, list): 13 | size = np.array(size) 14 | elif isinstance(size, np.ndarray): 15 | size = size 16 | else: 17 | raise ValueError("size is either a number, or a list, or a np.ndarray") 18 | return size 19 | 20 | def grid_sample(pos, batch, size, start, return_p2v=True, return_counts=True, return_unique=False): 21 | # pos: float [N, 3] 22 | # batch: long [N] 23 | # size: float [3, ] 24 | # start: float [3, ] / None 25 | cluster = voxel_grid(pos=pos, batch=batch, size=size, start=start) #[N, ] 26 | 27 | if return_p2v == False and return_counts == False: 28 | unique, cluster = torch.unique(cluster, sorted=True, return_inverse=True) 29 | return cluster 30 | 31 | unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True) 32 | 33 | if return_p2v == False and return_counts == True: 34 | return cluster, counts.max().item(), counts 35 | 36 | # obtain p2v_map 37 | n = unique.shape[0] 38 | k = counts.max().item() 39 | p2v_map = cluster.new_zeros(n, k) #[n, k] 40 | mask = torch.arange(k).cuda().unsqueeze(0) < counts.unsqueeze(-1) #[n, k] 41 | p2v_map[mask] = torch.argsort(cluster) 42 | 43 | if return_unique: 44 | return cluster, p2v_map, counts, unique 45 | 46 | return cluster, p2v_map, counts 47 | 48 | def get_indices_params(xyz, batch, window_size, shift_win: bool): 49 | 50 | if isinstance(window_size, list) or isinstance(window_size, np.ndarray): 51 | window_size = torch.from_numpy(window_size).type_as(xyz).to(xyz.device) 52 | else: 53 | window_size = torch.tensor([window_size]*3).type_as(xyz).to(xyz.device) 54 | 55 | if shift_win: 56 | v2p_map, k, counts = grid_sample(xyz+1/2*window_size, batch, window_size, start=xyz.min(0)[0], return_p2v=False, return_counts=True) 57 | else: 58 | v2p_map, k, counts = grid_sample(xyz, batch, window_size, start=None, return_p2v=False, return_counts=True) 59 | 60 | v2p_map, sort_idx = v2p_map.sort() 61 | 62 | n = counts.shape[0] 63 | N = v2p_map.shape[0] 64 | 65 | n_max = k 66 | index_0_offsets, index_1_offsets, index_0, index_1 = precompute_all(N, n, n_max, counts) 67 | index_0 = index_0.long() 68 | index_1 = index_1.long() 69 | 70 | return index_0, index_0_offsets, n_max, index_1, index_1_offsets, sort_idx 71 | 72 | def scatter_softmax_csr(src: torch.Tensor, indptr: torch.Tensor, dim: int = -1): 73 | ''' src: (N, C), 74 | index: (Ni+1, ), [0, n0^2, n0^2+n1^2, ...] 
75 | ''' 76 | max_value_per_index = segment_csr(src, indptr, reduce='max') 77 | max_per_src_element = gather_csr(max_value_per_index, indptr) 78 | 79 | recentered_scores = src - max_per_src_element 80 | recentered_scores_exp = recentered_scores.exp_() 81 | 82 | sum_per_index = segment_csr( 83 | recentered_scores_exp, indptr, reduce='sum') 84 | 85 | normalizing_constants = gather_csr(sum_per_index, indptr) 86 | 87 | return recentered_scores_exp.div(normalizing_constants) 88 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/SphereFormer/SparseTransformer/src/sptr/__init__.py -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "attention_cuda_kernel.h" 6 | 7 | 8 | void attention_step1_forward_cuda(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor, 9 | at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor) 10 | { 11 | const float *q = q_tensor.data_ptr(); 12 | const float *k = k_tensor.data_ptr(); 13 | const int *index0 = index0_tensor.data_ptr(); 14 | const int *index1 = index1_tensor.data_ptr(); 15 | float *attn = attn_tensor.data_ptr(); 16 | attention_step1_forward_cuda_launcher(N_q, N_k, M, h, hdim, n_max, q, k, index0, index1, attn); 17 | } 18 | 19 | void attention_step1_backward_cuda(int N, int M, int h, int hdim, const unsigned int n_max, at::Tensor grad_out_tensor, 20 | at::Tensor index0_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor index1_tensor_offsets, at::Tensor q_tensor, at::Tensor k_tensor, 21 | at::Tensor grad_q_tensor, at::Tensor grad_k_tensor) 22 | { 23 | const float *grad_out = grad_out_tensor.data_ptr(); 24 | const int *index0 = index0_tensor.data_ptr(); 25 | const int *index0_offsets = index0_tensor_offsets.data_ptr(); 26 | const int *index1 = index1_tensor.data_ptr(); 27 | const int *index1_offsets = index1_tensor_offsets.data_ptr(); 28 | const float *q = q_tensor.data_ptr(); 29 | const float *k = k_tensor.data_ptr(); 30 | float *grad_q = grad_q_tensor.data_ptr(); 31 | float *grad_k = grad_k_tensor.data_ptr(); 32 | attention_step1_backward_cuda_launcher(N, M, h, hdim, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, q, k, grad_q, grad_k); 33 | } 34 | 35 | void attention_step2_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, 36 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor output_tensor) 37 | { 38 | const float *attn = attn_tensor.data_ptr(); 39 | const float *v = v_tensor.data_ptr(); 40 | const int *index0_offsets = index0_offsets_tensor.data_ptr(); 41 | const int *index1 = index1_tensor.data_ptr(); 42 | float *output = output_tensor.data_ptr(); 43 | attention_step2_forward_cuda_launcher(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, output); 44 | } 45 | 46 | void attention_step2_backward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, 47 | at::Tensor index0_offsets_tensor, 
at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, 48 | at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor) 49 | { 50 | const float *grad_out = grad_out_tensor.data_ptr(); 51 | const int *index0 = index0_tensor.data_ptr(); 52 | const int *index0_offsets = index0_offsets_tensor.data_ptr(); 53 | const int *index1 = index1_tensor.data_ptr(); 54 | const int *index1_offsets = index1_offsets_tensor.data_ptr(); 55 | const float *attn = attn_tensor.data_ptr(); 56 | const float *v = v_tensor.data_ptr(); 57 | float *grad_attn = grad_attn_tensor.data_ptr(); 58 | float *grad_v = grad_v_tensor.data_ptr(); 59 | attention_step2_backward_cuda_launcher(N, M, h, hdim, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, grad_attn, grad_v); 60 | } 61 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "attention_cuda_kernel.h" 3 | 4 | __global__ void attention_step1_forward_cuda_kernel( // M, h, C//h 5 | int N_q, int N_k, int M, int h, int d, const float *q, const float *k, 6 | const int *index0, const int *index1, float *attn) { 7 | // q: [N, h, d], k: [h, d, N], index0: [M], index1: [M], attn: [h, M] 8 | 9 | int h_idx = blockIdx.y; 10 | int m_idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | 12 | if(m_idx >= M) return; 13 | float s = 0.0; 14 | int index_q = index0[m_idx], index_k = index1[m_idx]; 15 | for(int i = 0; i < d; i++){ 16 | s += q[h_idx * d * N_q + i * N_q + index_q] * k[h_idx * d * N_k + i * N_k + index_k]; 17 | } 18 | attn[h_idx * M + m_idx] = s; 19 | } 20 | 21 | void attention_step1_forward_cuda_launcher(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, 22 | const float *q, const float *k, const int *index0, const int *index1, float *attn) { 23 | // input: attn: (h, M), index0: (M, ), index1: (M, ) 24 | unsigned int n_threads = 512; 25 | dim3 blocks((M + n_threads - 1) / n_threads, h); 26 | attention_step1_forward_cuda_kernel<<>>(N_q, N_k, M, h, hdim, q, k, index0, index1, attn); 27 | } 28 | 29 | __global__ void attention_step1_backward_cuda_kernel( // M, h, C//h 30 | int N, int M, int h, int d, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, 31 | const float *q, const float *k, float *grad_q, float *grad_k) { 32 | // q: [N, h, d], k: [N, h, d], index0: [M], index1: [M], attn: [M, h], grad_out: [M, h] 33 | // grad_q: [N, h, hdim], grad_k: [N, h, hdim] 34 | 35 | int n_h = blockDim.x; 36 | int h_idx = blockIdx.y * n_h + threadIdx.y; 37 | int q_idx = blockIdx.x; 38 | int d_idx = threadIdx.x; 39 | int C = d * h; 40 | 41 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1]; 42 | int n = end - start; 43 | 44 | float grad_q_val = 0; 45 | for(int i = 0; i < n; i++){ 46 | int start_i = start + i; 47 | float grad_out_val = grad_out[start_i*h + h_idx]; 48 | int k_idx = index1[start_i]; 49 | grad_q_val += grad_out_val * k[k_idx*C + h_idx*d + d_idx]; 50 | } 51 | grad_q[q_idx*C + h_idx*d + d_idx] = grad_q_val; 52 | 53 | float grad_k_val = 0; 54 | int start_k = index1_offsets[q_idx]; 55 | for(int i = 0; i < n; i++){ 56 | int start_i = start_k + i*n; 57 | float grad_out_val = grad_out[start_i*h + h_idx]; 58 | int query_idx = index0[start_i]; 59 | grad_k_val += grad_out_val * q[query_idx*C + 
h_idx*d + d_idx]; 60 | } 61 | grad_k[q_idx*C + h_idx*d + d_idx] = grad_k_val; 62 | } 63 | 64 | void attention_step1_backward_cuda_launcher(int N, int M, int h, int hdim, const unsigned int n_max, 65 | const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *q, const float *k, float *grad_q, float *grad_k) { 66 | // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) 67 | 68 | unsigned int n_h = h*hdim > 512 ? 512 / hdim : h; 69 | 70 | dim3 blocks(N, h/n_h); 71 | dim3 threads(hdim, n_h); 72 | 73 | attention_step1_backward_cuda_kernel<<>>(N, M, h, hdim, grad_out, index0, index0_offsets, index1, index1_offsets, q, k, grad_q, grad_k); 74 | 75 | } 76 | 77 | __global__ void attention_step2_forward_cuda_kernel( // M, h, hdim 78 | int N, int M, const int h, int d, const float *attn, const float *v, 79 | const int *index0_offsets, const int *index1, float *output) { 80 | // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, 3, h, hdim), rel_idx: (M, 3) 81 | 82 | int q_idx = blockIdx.x; 83 | int n_h = blockDim.x; 84 | int h_idx = blockIdx.y * n_h + threadIdx.y; 85 | int d_idx = threadIdx.x; 86 | 87 | int C = h*d; 88 | 89 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1]; 90 | int n = end - start; 91 | float sum = 0; 92 | for(int i = 0; i < n; i++){ 93 | int start_i = start + i; 94 | int k_idx = index1[start_i]; 95 | float v_val = v[k_idx*C + h_idx*d + d_idx]; 96 | sum += attn[start_i*h + h_idx] * v_val; 97 | } 98 | output[q_idx*C + h_idx*d + d_idx] = sum; 99 | } 100 | 101 | void attention_step2_forward_cuda_launcher(int N, int M, const int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, 102 | const int *index1, float *output) { 103 | // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3) 104 | unsigned int n_h = h*hdim > 512 ? 
512 / hdim : h; 105 | dim3 blocks(N, h/n_h); 106 | dim3 threads(hdim, n_h); 107 | attention_step2_forward_cuda_kernel<<>>(N, M, h, hdim, attn, v, index0_offsets, index1, output); 108 | } 109 | 110 | __global__ void attention_step2_grad_v_backward_cuda_kernel( // M, h, hdim 111 | int N, int M, int h, int hdim, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v, 112 | float *grad_v) { 113 | // input: attn: (M, h), v: (h, hdim, N), index0: (M, ), index1: (M, ), rel_idx: (3, M) 114 | 115 | int q_idx = blockIdx.x; 116 | int n_h = blockDim.x; 117 | int h_idx = blockIdx.y * n_h + threadIdx.y; 118 | int d_idx = threadIdx.x; 119 | 120 | int C = h*hdim; 121 | 122 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1]; 123 | int n = end - start; 124 | int start_k = index1_offsets[q_idx]; 125 | float grad_v_val = 0; 126 | 127 | for(int i = 0; i < n; i ++){ 128 | int start_i = start_k + i*n; 129 | int query_idx = index0[start_i]; 130 | float grad_out_val = grad_out[query_idx*C + h_idx*hdim + d_idx]; 131 | 132 | float grad_val = attn[start_i*h + h_idx] * grad_out_val; 133 | grad_v_val += grad_val; 134 | } 135 | grad_v[q_idx*C + h_idx*hdim + d_idx] = grad_v_val; 136 | 137 | } 138 | 139 | void attention_step2_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0, const int *index0_offsets, 140 | const int *index1, const int *index1_offsets, const float *attn, const float *v, float *grad_attn, float *grad_v) { 141 | // input: grad_out: (N, h, hdim) 142 | 143 | unsigned int n_h = h*hdim > 512 ? 512 / hdim : h; 144 | dim3 blocks(N, h/n_h); 145 | dim3 threads(hdim, n_h); 146 | attention_step2_grad_v_backward_cuda_kernel<<>>(N, M, h, hdim, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, grad_v); 147 | 148 | unsigned int n_threads = 512; 149 | dim3 blocks_2((M + n_threads - 1) / n_threads, h); 150 | 151 | attention_step1_forward_cuda_kernel<<>>(N, N, M, h, hdim, grad_out, v, index0, index1, grad_attn); 152 | } 153 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ATTENTION_CUDA_KERNEL 2 | #define _ATTENTION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void attention_step1_forward_cuda(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor); 8 | void attention_step1_backward_cuda(int N, int M, int h, int hdim, const unsigned int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor index1_tensor_offsets, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor); 9 | 10 | void attention_step2_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor output_tensor); 11 | void attention_step2_backward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor 
grad_attn_tensor, at::Tensor grad_v_tensor); 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | void attention_step1_forward_cuda_launcher(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, const float *q, const float *k, const int *index0, const int *index1, float *attn); 18 | void attention_step1_backward_cuda_launcher(int N, int M, int h, int hdim, const unsigned int n_max, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *q, const float *k, float *grad_q, float *grad_k); 19 | 20 | void attention_step2_forward_cuda_launcher(int N, int M, const int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, const int *index1, float *output); 21 | void attention_step2_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v, float *grad_attn, float *grad_v); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | #endif 27 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 1024 8 | #define THREADS_PER_BLOCK 256 9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) 10 | 11 | inline int opt_n_threads(int work_size) { 12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/pointops_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention/attention_cuda_kernel.h" 5 | #include "rpe/relative_pos_encoding_cuda_kernel.h" 6 | #include "precompute/precompute_cuda_kernel.h" 7 | 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("attention_step1_forward_cuda", &attention_step1_forward_cuda, "attention_step1_forward_cuda"); 11 | m.def("attention_step1_backward_cuda", &attention_step1_backward_cuda, "attention_step1_backward_cuda"); 12 | m.def("attention_step2_forward_cuda", &attention_step2_forward_cuda, "attention_step2_forward_cuda"); 13 | m.def("attention_step2_backward_cuda", &attention_step2_backward_cuda, "attention_step2_backward_cuda"); 14 | m.def("precompute_all_cuda", &precompute_all_cuda, "precompute_all_cuda"); 15 | m.def("dot_prod_with_idx_forward_cuda", &dot_prod_with_idx_forward_cuda, "dot_prod_with_idx_forward_cuda"); 16 | m.def("dot_prod_with_idx_backward_cuda", &dot_prod_with_idx_backward_cuda, "dot_prod_with_idx_backward_cuda"); 17 | m.def("attention_step2_with_rel_pos_value_forward_cuda", &attention_step2_with_rel_pos_value_forward_cuda, "attention_step2_with_rel_pos_value_forward_cuda"); 18 | m.def("attention_step2_with_rel_pos_value_backward_cuda", &attention_step2_with_rel_pos_value_backward_cuda, 
"attention_step2_with_rel_pos_value_backward_cuda"); 19 | m.def("dot_prod_with_idx_all_forward_cuda", &dot_prod_with_idx_all_forward_cuda, "dot_prod_with_idx_all_forward_cuda"); 20 | } 21 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "precompute_cuda_kernel.h" 6 | 7 | void precompute_all_cuda(int N, int n, const unsigned int n_max, at::Tensor counts_tensor, at::Tensor offsets_tensor, at::Tensor sq_offsets_tensor, at::Tensor index_0_offsets_tensor, at::Tensor index_1_offsets_tensor, at::Tensor index_0_tensor, at::Tensor index_1_tensor) 8 | { 9 | const int *counts = counts_tensor.data_ptr(); 10 | const int *offsets = offsets_tensor.data_ptr(); 11 | const int *sq_offsets = sq_offsets_tensor.data_ptr(); 12 | int *index_0_offsets = index_0_offsets_tensor.data_ptr(); 13 | int *index_1_offsets = index_1_offsets_tensor.data_ptr(); 14 | int *index_0 = index_0_tensor.data_ptr(); 15 | int *index_1 = index_1_tensor.data_ptr(); 16 | precompute_all_cuda_launcher(N, n, n_max, counts, offsets, sq_offsets, index_0_offsets, index_1_offsets, index_0, index_1); 17 | } 18 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "precompute_cuda_kernel.h" 3 | 4 | __global__ void precompute_all_cuda_kernel( // M, h, C//h 5 | int N, int n, int k, const int *counts, const int *offsets, const int *sq_offsets, int *index0_offsets, int *index1_offsets, int *index0, int *index1) { 6 | // counts: (n), sq_offsets: (n), index0_offsets: (n), index1_offsets: (n) 7 | 8 | int n_idx = blockIdx.x; 9 | int thread_idx = threadIdx.x; 10 | 11 | int start = offsets[n_idx]; 12 | int start_val = sq_offsets[n_idx]; 13 | int length = counts[n_idx]; 14 | for(int t_idx = thread_idx; t_idx < length; t_idx += blockDim.x){ 15 | index0_offsets[start+t_idx] = start_val + length * t_idx; 16 | index1_offsets[start+t_idx] = start_val + t_idx; 17 | for(int i = 0; i < length; i++){ 18 | index0[start_val + i*length + t_idx] = start+i; 19 | index1[start_val + i*length + t_idx] = start+t_idx; 20 | } 21 | } 22 | } 23 | 24 | void precompute_all_cuda_launcher(int N, int n, const unsigned int n_max, const int *counts, 25 | const int *offsets, const int *sq_offsets, int *index_0_offsets, int *index_1_offsets, int *index_0, int *index_1) { 26 | // input: attn: (M, h), index0: (M, ), index1: (M, ) 27 | 28 | unsigned int blocks = n; 29 | unsigned int n_threads = opt_n_threads(n_max); 30 | n_threads = n_threads == n_max ? n_threads : n_threads * 2; 31 | n_threads = n_threads > 1024 ? 
1024 : n_threads; 32 | 33 | precompute_all_cuda_kernel<<>>(N, n, n_max, counts, offsets, sq_offsets, index_0_offsets, index_1_offsets, index_0, index_1); 34 | 35 | } -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef PRECOMPUTE_CUDA_KERNEL 2 | #define PRECOMPUTE_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void precompute_all_cuda(int N, int n, const unsigned int n_max, at::Tensor counts_tensor, at::Tensor offsets_tensor, at::Tensor sq_offsets_tensor, at::Tensor index_0_offsets_tensor, at::Tensor index_1_offsets_tensor, at::Tensor index_0, at::Tensor index_1); 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void precompute_all_cuda_launcher(int N, int n, const unsigned int n_max, const int *counts, const int *offsets, const int *sq_offsets, int *index_0_offsets, int *index_1_offsets, int *index_0, int *index_1); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/src/sptr/rpe/relative_pos_encoding_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "relative_pos_encoding_cuda_kernel.h" 6 | 7 | void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, 8 | at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, 9 | at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) 10 | { 11 | const float *q = q_tensor.data_ptr(); 12 | const int *index_q = index_q_tensor.data_ptr(); 13 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr(); 14 | const float *k = k_tensor.data_ptr(); 15 | const int *index_k = index_k_tensor.data_ptr(); 16 | const float *table_q = table_q_tensor.data_ptr(); 17 | const float *table_k = table_k_tensor.data_ptr(); 18 | const int *rel_idx = rel_idx_tensor.data_ptr(); 19 | float *output = output_tensor.data_ptr(); 20 | dot_prod_with_idx_forward_cuda_launcher(N, M, h, hdim, n_max, L, q, index_q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output); 21 | } 22 | 23 | void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor grad_out_tensor, 24 | at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_offsets_tensor, at::Tensor index_k_tensor, 25 | at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, 26 | at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor) 27 | { 28 | const float *grad_out = grad_out_tensor.data_ptr(); 29 | const float *q = q_tensor.data_ptr(); 30 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr(); 31 | const float *k = k_tensor.data_ptr(); 32 | const int *index_k_offsets = index_k_offsets_tensor.data_ptr(); 33 | const int *index_k = index_k_tensor.data_ptr(); 34 | const float *table_q = table_q_tensor.data_ptr(); 35 | const float *table_k = table_k_tensor.data_ptr(); 36 | const int *rel_idx = rel_idx_tensor.data_ptr(); 37 | float *grad_q = grad_q_tensor.data_ptr(); 38 | float *grad_k = grad_k_tensor.data_ptr(); 39 | float 
*grad_table_q = grad_table_q_tensor.data_ptr(); 40 | float *grad_table_k = grad_table_k_tensor.data_ptr(); 41 | dot_prod_with_idx_backward_cuda_launcher(N, M, h, hdim, n_max, L, grad_out, q, index_q_offsets, k, index_k_offsets, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k); 42 | } 43 | 44 | void dot_prod_with_idx_all_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, 45 | at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, 46 | at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) 47 | { 48 | const float *q = q_tensor.data_ptr(); 49 | const int *index_q = index_q_tensor.data_ptr(); 50 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr(); 51 | const float *k = k_tensor.data_ptr(); 52 | const int *index_k = index_k_tensor.data_ptr(); 53 | const float *table_q = table_q_tensor.data_ptr(); 54 | const float *table_k = table_k_tensor.data_ptr(); 55 | const int *rel_idx = rel_idx_tensor.data_ptr(); 56 | float *output = output_tensor.data_ptr(); 57 | dot_prod_with_idx_all_forward_cuda_launcher(N, M, h, hdim, n_max, L, q, index_q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output); 58 | } 59 | 60 | void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, 61 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor) 62 | { 63 | const float *attn = attn_tensor.data_ptr(); 64 | const float *v = v_tensor.data_ptr(); 65 | const int *index0_offsets = index0_offsets_tensor.data_ptr(); 66 | const int *index1 = index1_tensor.data_ptr(); 67 | const float *table = table_tensor.data_ptr(); 68 | const int *rel_idx = rel_idx_tensor.data_ptr(); 69 | float *output = output_tensor.data_ptr(); 70 | attention_step2_with_rel_pos_value_forward_cuda_launcher(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, table, rel_idx, output); 71 | } 72 | 73 | void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, int L, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, 74 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, 75 | at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor) 76 | { 77 | const float *grad_out = grad_out_tensor.data_ptr(); 78 | const int *index0 = index0_tensor.data_ptr(); 79 | const int *index0_offsets = index0_offsets_tensor.data_ptr(); 80 | const int *index1 = index1_tensor.data_ptr(); 81 | const int *index1_offsets = index1_offsets_tensor.data_ptr(); 82 | const float *attn = attn_tensor.data_ptr(); 83 | const float *v = v_tensor.data_ptr(); 84 | const float *table = table_tensor.data_ptr(); 85 | const int *rel_idx = rel_idx_tensor.data_ptr(); 86 | float *grad_attn = grad_attn_tensor.data_ptr(); 87 | float *grad_v = grad_v_tensor.data_ptr(); 88 | float *grad_table = grad_table_tensor.data_ptr(); 89 | attention_step2_with_rel_pos_value_backward_cuda_launcher(N, M, h, hdim, L, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, table, rel_idx, grad_attn, grad_v, grad_table); 90 | } 91 | -------------------------------------------------------------------------------- 
/src/models/SphereFormer/SparseTransformer/src/sptr/rpe/relative_pos_encoding_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _RPE_CUDA_KERNEL 2 | #define _RPE_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); 8 | void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor grad_out_tensor, at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_offsets_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor); 9 | 10 | void dot_prod_with_idx_all_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); 11 | 12 | void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor); 13 | void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, int L, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor); 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | void dot_prod_with_idx_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *q, const int *index_q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *output); 20 | void dot_prod_with_idx_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *grad_out, const float *q, const int *index_q_offsets, const float *k, const int *index_k_offsets, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k); 21 | 22 | void dot_prod_with_idx_all_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *q, const int *index_q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *output); 23 | 24 | void attention_step2_with_rel_pos_value_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, const int *index1, const float *table, const int *rel_idx, float *output); 25 | void attention_step2_with_rel_pos_value_backward_cuda_launcher(int N, int M, int h, int hdim, int L, int n_max, const float *grad_out, const 
int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table); 26 | 27 | 28 | #ifdef __cplusplus 29 | } 30 | #endif 31 | #endif 32 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_attention_op_step1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum 4 | import sys 5 | sys.path.append("..") 6 | import sptr 7 | 8 | torch.manual_seed(1) 9 | 10 | # M = 800000 11 | N = 35000 12 | n = 1500 13 | k = 100 14 | C = 96 15 | h = 6 16 | query = torch.rand(N, h, C//h).cuda() 17 | key = torch.rand(N, h, C//h).cuda() 18 | 19 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda() 20 | v2p_map, _ = v2p_map.sort() 21 | counts = v2p_map.bincount() 22 | M = (counts**2).sum().item() 23 | 24 | N = v2p_map.shape[0] 25 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 26 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 27 | v2p_map = v2p_map.long() 28 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map) 29 | p2v_map[mask] = torch.arange(N).cuda() 30 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 31 | 32 | index_0_counts = counts[v2p_map.long()] #[N, ] 33 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ] 34 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 35 | n_max = p2v_map.shape[1] 36 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets) 37 | index_0 = index_0.long() 38 | index_1 = index_1.long() 39 | 40 | # index_0 = torch.rand(M) 41 | # index_0[index_0 < 0] = 0 42 | # index_0 = (index_0*N).long().cuda() 43 | 44 | # index_1 = torch.rand(M) 45 | # index_1[index_1 < 0] = 0 46 | # index_1 = (index_1*N).long().cuda() 47 | 48 | query.requires_grad = True 49 | key.requires_grad = True 50 | 51 | # # rearrange index for acceleration 52 | # index_0, indices = torch.sort(index_0) #[M,] 53 | # index_1 = index_1[indices] #[M,] 54 | # index_0_counts = index_0.bincount() 55 | # n_max = index_0_counts.max() 56 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] 57 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 58 | 59 | attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) 60 | loss = attn_flat.sum() 61 | loss.backward() 62 | print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10])) 63 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 64 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 65 | # input() 66 | 67 | query_grad = query.grad.clone() 68 | key_grad = key.grad.clone() 69 | 70 | query.grad.zero_() 71 | key.grad.zero_() 72 | 73 | # # print("index_0[:100]: ", index_0[:100]) 74 | # print("n_max: ", n_max) 75 | # print("index_0_offsets.shape: ", index_0_offsets.shape) 76 | # # input() 77 | 78 | # print("index_0_offsets[:100]: ", index_0_offsets[:100]) 79 | # print("index_1[:20]: ", index_1[:20]) 80 | 81 | 82 | # attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) 83 | # loss = attn_flat.sum() 84 | # 
loss.backward() 85 | # # attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int()) 86 | # # loss = attn_flat.sum() 87 | # # loss.backward() 88 | # print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10])) 89 | # print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 90 | # print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 91 | # input() 92 | 93 | print("query.is_contiguous(): ", query.is_contiguous()) 94 | print("key.is_contiguous(): ", key.is_contiguous()) 95 | print("index_0.is_contiguous(): ", index_0.is_contiguous()) 96 | print("index_1.is_contiguous(): ", index_1.is_contiguous()) 97 | 98 | attn_flat_v2 = sptr.attention_step1(query.float(), key.float(), index_0.int(), index_0_offsets.int(), index_1.int(), ctg_index_1_offsets.int(), n_max) 99 | loss = attn_flat_v2.sum() 100 | loss.backward() 101 | 102 | # attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(), n_max) 103 | # loss = attn_flat_v2.sum() 104 | # loss.backward() 105 | 106 | print("attn_flat_v2.shape: {}, attn_flat_v2[:20,:10]: {}".format(attn_flat_v2.shape, attn_flat_v2[:20,:10])) 107 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 108 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 109 | # input() 110 | 111 | # mask = attn_flat_v2.sum(-1) != 0 112 | # print("mask.sum(): ", mask.sum()) 113 | # print("attn_flat_v2[mask] - attn_flat[mask]: ", ((attn_flat_v2[mask] - attn_flat[mask])**2).max()) 114 | 115 | 116 | print("((attn_flat-attn_flat_v2).abs()).max(): ", ((attn_flat-attn_flat_v2).abs()).max()) 117 | 118 | print("(query.grad-query_grad).abs().max(): ", (query.grad-query_grad).abs().max()) 119 | print("(key.grad-key_grad).abs().max(): ", (key.grad-key_grad).abs().max()) 120 | 121 | # selected = 10000 122 | # print("torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0): ", torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0)) 123 | 124 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_attention_op_step2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum 4 | import sys 5 | sys.path.append("..") 6 | import sptr 7 | 8 | torch.manual_seed(2) 9 | 10 | # M = 80000 11 | N = 3500 #10 #5 #10 #3500 12 | n = 150 #2 #150 13 | # k = 5 #7 #65 14 | hdim = 16 15 | h = 6 #1 #6 16 | L = 31 #2 #31 17 | v = torch.rand(N, h, hdim).cuda() 18 | # table = torch.rand(L, h, hdim, 3).cuda() 19 | 20 | # index_0 = torch.rand(M) 21 | # index_0[index_0 < 0] = 0 22 | # index_0 = (index_0*N).long().cuda() 23 | 24 | # index_1 = torch.rand(M) 25 | # index_1[index_1 < 0] = 0 26 | # index_1 = (index_1*N).long().cuda() 27 | 28 | 29 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda() 30 | v2p_map, _ = v2p_map.sort() 31 | counts = v2p_map.bincount() 32 | M = (counts**2).sum().item() 33 | k = counts.max().item() 34 | 35 | print("counts: ", counts) 36 | 37 | attn = torch.rand(M, h).cuda() 38 | 39 | # v2p_map, ctg_sort_idx = v2p_map.sort() 40 | # n, k = p2v_map.shape 41 | # N = v2p_map.shape[0] 42 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 43 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 44 | v2p_map = v2p_map.long() 45 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map) 
46 | p2v_map[mask] = torch.arange(N).cuda() 47 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 48 | 49 | print("M: ", M) 50 | # print("counts[:5]: {}, v2p_map[:5]: {}, p2v_map[:5]: {}".format(counts[:5], v2p_map[:5], p2v_map[:5])) 51 | # print("ctg_index_1_offsets[:50]: ", ctg_index_1_offsets[:50]) 52 | 53 | # print("ctg_index_1_offsets.max(): {}, ctg_index_1_offsets.min(): {}".format(ctg_index_1_offsets.max(), ctg_index_1_offsets.min())) 54 | print("ctg_index_1_offsets: ", ctg_index_1_offsets) 55 | 56 | # rel_index = torch.rand(M, 3) 57 | # rel_index[rel_index < 0] = 0 58 | # rel_index = (rel_index*L).long().cuda() 59 | 60 | # # print("rel_index.min(): {}, rel_index.max(): {}".format( 61 | # # rel_index.min(), rel_index.max() 62 | # # )) 63 | 64 | # print("rel_index: ", rel_index) 65 | 66 | index_0_counts = counts[v2p_map.long()] #[N, ] 67 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ] 68 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 69 | n_max = p2v_map.shape[1] 70 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets) 71 | index_0 = index_0.long() 72 | index_1 = index_1.long() 73 | 74 | print("index_0: {}".format(index_0)) 75 | print("index_1: {}".format(index_1)) 76 | 77 | # print("index_0.max(): {}, index_0.min(): {}".format(index_0.max(), index_0.min())) 78 | 79 | # print("index_1.max(): {}, index_1.min(): {}".format(index_1.max(), index_1.min())) 80 | 81 | # input() 82 | 83 | # # rearrange index for acceleration 84 | # index_0, indices = torch.sort(index_0) #[M,] 85 | # index_1 = index_1[indices] #[M,] 86 | # rel_index = rel_index[indices] 87 | # index_0_counts = index_0.bincount() 88 | 89 | print("index_0_counts.shape: ", index_0_counts.shape) 90 | 91 | # n_max = index_0_counts.max() 92 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] 93 | 94 | # print("v1 index_0_offsets.shape: ", index_0_offsets.shape) 95 | 96 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 97 | 98 | 99 | attn.requires_grad = True 100 | v.requires_grad = True 101 | # table.requires_grad = True 102 | 103 | 104 | # output = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int()) 105 | output = pointops.attention_step2(attn, v, index_0.int(), index_1.int()) 106 | loss = output.mean() 107 | loss.backward() 108 | 109 | # print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5])) 110 | # print("attn.grad[:5, :3]: ", attn.grad[:5, :3]) 111 | # print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) 112 | # print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) 113 | # # input() 114 | 115 | attn_grad = attn.grad.clone() 116 | v_grad = v.grad.clone() 117 | # table_grad = table.grad.clone() 118 | 119 | attn.grad.zero_() 120 | v.grad.zero_() 121 | # table.grad.zero_() 122 | 123 | # print("query.is_contiguous(): ", query.is_contiguous()) 124 | # print("key.is_contiguous(): ", key.is_contiguous()) 125 | # print("index_0.is_contiguous(): ", index_0.is_contiguous()) 126 | # print("index_1.is_contiguous(): ", index_1.is_contiguous()) 127 | 128 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v7(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int(), table, rel_index.int()) 129 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v4(attn, v, 
index_0_offsets.int(), n_max, index_1.int(), table, rel_index.int()) 130 | output_v2 = sptr.attention_step2(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int()) 131 | loss = output_v2.mean() 132 | loss.backward() 133 | 134 | # print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5])) 135 | # print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3]) 136 | # print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) 137 | # print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) 138 | # # input() 139 | 140 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) 141 | 142 | print("((attn_grad-attn.grad)**2).max(): ", ((attn_grad-attn.grad)**2).max()) 143 | 144 | print("((v_grad-v.grad)**2).max(): ", ((v_grad-v.grad)**2).max()) 145 | 146 | # print("((table_grad-table.grad)**2).max(): ", ((table_grad-table.grad)**2).max()) 147 | 148 | # print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2)) 149 | 150 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_precompute_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | import sys 4 | sys.path.append("..") 5 | import sptr 6 | 7 | torch.manual_seed(1) 8 | 9 | v2p_map = torch.IntTensor([ 10 | 1, 0, 0, 2, 0, 2, 2, 1, 2, 2, 2 11 | ]).cuda() 12 | 13 | p2v_map = torch.IntTensor([ 14 | [1, 2, 4, 0, 0, 0], 15 | [0, 7, 0, 0, 0, 0], 16 | [5, 6, 3, 9, 8, 10] 17 | ]).cuda() 18 | 19 | counts = torch.IntTensor([3, 2, 6]).cuda() 20 | 21 | # index_0_counts = counts[v2p_map.long()] #[N, ] 22 | # index_0_offsets = index_0_counts.cumsum(-1) #[N, ] 23 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 24 | 25 | # N = v2p_map.shape[0] 26 | 27 | print("v2p_map.shape: {}, p2v_map.shape: {}".format(v2p_map.shape, p2v_map.shape)) 28 | 29 | v2p_map, ctg_sort_idx = v2p_map.sort() 30 | n, k = p2v_map.shape 31 | N = v2p_map.shape[0] 32 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 33 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 34 | v2p_map = v2p_map.long() 35 | p2v_map = torch.zeros_like(p2v_map) 36 | p2v_map[mask] = torch.arange(N).int().cuda() 37 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 38 | 39 | index_params = (ctg_index_1_offsets, ctg_sort_idx) 40 | index_0_counts = counts[v2p_map.long()] #[N, ] 41 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ] 42 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 43 | n_max = p2v_map.shape[1] 44 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets) 45 | index_0 = index_0.long() 46 | index_1 = index_1.long() 47 | 48 | 49 | # index_0_offsets_, index_1_offsets_ = pointops.precompute_offsets(N, p2v_map.shape[0], p2v_map.shape[1], counts) 50 | 51 | # # assert (index_0_offsets_ == index_0_offsets).all() 52 | # print("index_0_offsets: ", index_0_offsets) 53 | # print("index_0_offsets_: ", index_0_offsets_) 54 | # print("ctg_index_1_offsets: ", ctg_index_1_offsets) 55 | # print("index_1_offsets_: ", index_1_offsets_) 56 | 57 | # index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets) 58 | 59 | # print("index_0: ", index_0) 60 | # print("index_1: ", index_1) 61 
| 62 | # print("index_0.shape: ", index_0.shape) 63 | 64 | 65 | index_0_offsets_, index_1_offsets_, index_0_, index_1_ = sptr.precompute_all(N, p2v_map.shape[0], p2v_map.shape[1], counts) 66 | 67 | assert (index_0_offsets_ == index_0_offsets).all() 68 | assert (index_1_offsets_ == ctg_index_1_offsets).all() 69 | assert (index_0_ == index_0).all() 70 | assert (index_1_ == index_1).all() 71 | 72 | print("index_0_offsets: ", index_0_offsets) 73 | print("index_0_offsets_: ", index_0_offsets_) 74 | print("ctg_index_1_offsets: ", ctg_index_1_offsets) 75 | print("index_1_offsets_: ", index_1_offsets_) 76 | 77 | print("index_0_offsets.shape: {}, index_0_offsets_.shape: {}".format(index_0_offsets.shape, index_0_offsets_.shape)) 78 | 79 | print("ctg_index_1_offsets.shape: {}, index_1_offsets_.shape: {}".format(ctg_index_1_offsets.shape, index_1_offsets_.shape)) 80 | 81 | # print("index_0: ", index_0) 82 | # print("index_0_: ", index_0_) 83 | # print("index_1: ", index_1) 84 | # print("index_1_: ", index_1_) 85 | 86 | # print("index_0.shape: ", index_0.shape) 87 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum 4 | import sys 5 | sys.path.append("..") 6 | import sptr 7 | 8 | torch.manual_seed(1) 9 | 10 | # M = 80000 11 | N = 3500 12 | n = 150 13 | k = 65 14 | # M = 80 15 | # N = 5 16 | hdim = 16 17 | h = 6 18 | L = 31 19 | query = torch.rand(N, h, hdim).cuda() 20 | table_q = torch.rand(L, h, hdim, 3).cuda() 21 | key = torch.rand(N, h, hdim).cuda() 22 | table_k = torch.rand(L, h, hdim, 3).cuda() 23 | 24 | # index_q = torch.rand(M) 25 | # index_q[index_q < 0] = 0 26 | # index_q = (index_q*N).long().cuda() 27 | 28 | # index_k = torch.rand(M) 29 | # index_k[index_k < 0] = 0 30 | # index_k = (index_k*N).long().cuda() 31 | 32 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda() 33 | v2p_map, _ = v2p_map.sort() 34 | counts = v2p_map.bincount() 35 | M = (counts**2).sum().item() 36 | 37 | # v2p_map, ctg_sort_idx = v2p_map.sort() 38 | # n, k = p2v_map.shape 39 | N = v2p_map.shape[0] 40 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 41 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 42 | v2p_map = v2p_map.long() 43 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map) 44 | p2v_map[mask] = torch.arange(N).cuda() 45 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 46 | 47 | rel_index = torch.rand(M, 3) 48 | rel_index[rel_index < 0] = 0 49 | rel_index = (rel_index*L).long().cuda() 50 | 51 | index_q_counts = counts[v2p_map.long()] #[N, ] 52 | index_q_offsets = index_q_counts.cumsum(-1) #[N, ] 53 | index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1] 54 | n_max = p2v_map.shape[1] 55 | index_q, index_k = pointops.precompute_index_pairs(p2v_map, counts, index_q_offsets) 56 | index_q = index_q.long() 57 | index_k = index_k.long() 58 | 59 | # # rearrange index for acceleration 60 | # index_q, indices = torch.sort(index_q) #[M,] 61 | # index_k = index_k[indices] #[M,] 62 | # rel_index = rel_index[indices] 63 | # index_q_counts = index_q.bincount() 64 | 65 | # print("index_q_counts.shape: ", 
index_q_counts.shape) 66 | 67 | # n_max = index_q_counts.max() 68 | # index_q_offsets = index_q_counts.cumsum(dim=-1) #[N] 69 | 70 | # print("v1 index_q_offsets.shape: ", index_q_offsets.shape) 71 | 72 | # index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1] 73 | 74 | # print("index_q[:100]: ", index_q[:100]) 75 | print("n_max: ", n_max) 76 | print("index_q_offsets.shape: ", index_q_offsets.shape) 77 | # input() 78 | 79 | print("index_q_offsets[:100]: ", index_q_offsets[:100]) 80 | print("index_k[:20]: ", index_k[:20]) 81 | 82 | query.requires_grad = True 83 | table_q.requires_grad = True 84 | key.requires_grad = True 85 | table_k.requires_grad = True 86 | 87 | output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int()) 88 | output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int()) 89 | output = output1 + output2 90 | loss = output.mean() 91 | loss.backward() 92 | 93 | print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10])) 94 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 95 | print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) 96 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 97 | print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) 98 | # input() 99 | 100 | query_grad = query.grad.clone() 101 | key_grad = key.grad.clone() 102 | table_q_grad = table_q.grad.clone() 103 | table_k_grad = table_k.grad.clone() 104 | 105 | query.grad.zero_() 106 | key.grad.zero_() 107 | table_q.grad.zero_() 108 | table_k.grad.zero_() 109 | 110 | 111 | # print("query.is_contiguous(): ", query.is_contiguous()) 112 | # print("key.is_contiguous(): ", key.is_contiguous()) 113 | # print("index_q.is_contiguous(): ", index_q.is_contiguous()) 114 | # print("index_k.is_contiguous(): ", index_k.is_contiguous()) 115 | table_q = table_q.detach().permute(0,3,1,2).contiguous() 116 | table_k = table_k.detach().permute(0,3,1,2).contiguous() 117 | table_q.requires_grad = True 118 | table_k.requires_grad = True 119 | output_v2 = sptr.dot_prod_with_idx(query, index_q.int(), index_q_offsets.int(), n_max, key, ctg_index_1_offsets.int(), index_k.int(), table_q, table_k, rel_index.int()) 120 | # output_v2 = pointops.dot_prod_with_idx_v5(query, index_q_offsets.int(), n_max, key, index_k.int(), table_q, table_k, rel_index.int()) 121 | loss = output_v2.mean() 122 | loss.backward() 123 | 124 | table_q_grad2 = table_q.grad.clone().permute(0,2,3,1).contiguous() 125 | table_k_grad2 = table_k.grad.clone().permute(0,2,3,1).contiguous() 126 | 127 | print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10])) 128 | print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 129 | print("v2 table_q_grad2[:5, :3, :5, :2]: ", table_q_grad2[:5, :3, :5, :2]) 130 | print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 131 | print("v2 table_k_grad2[:5, :3, :5, :2]: ", table_k_grad2[:5, :3, :5, :2]) 132 | # input() 133 | 134 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) 135 | 136 | print("((query.grad-query_grad)**2).max(): ", ((query.grad-query_grad)**2).max()) 137 | 138 | print("((key.grad-key_grad)**2).max(): ", ((key.grad-key_grad)**2).max()) 139 | 140 | print("((table_q_grad2-table_q_grad)**2).max(): ", ((table_q_grad2-table_q_grad)**2).max()) 141 | 142 | print("((table_k_grad2-table_k_grad)**2).max(): ", ((table_k_grad2-table_k_grad)**2).max()) 143 | 144 | 145 | 146 | 
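All of the tests in this directory build the same attention index structure before comparing the reference pointops ops against the fused sptr kernels: points are binned into windows through v2p_map, every (query, key) pair inside a window is enumerated, so the number of pairs is M = sum(counts**2); index_0_offsets marks where each query point's block of pairs starts, and ctg_index_1_offsets stores, for every point, the position of its first appearance as a key within its window's block. Below is a minimal pure-PyTorch sketch of the pair enumeration, assuming query-major ordering within each window; it is illustrative only and does not replace pointops.precompute_index_pairs.

import torch

def all_pairs_indices(p2v_map, counts):
    # p2v_map: [n_windows, k_max] padded point ids per window
    # counts:  [n_windows] number of valid points in each window
    index_0, index_1 = [], []
    for w in range(p2v_map.shape[0]):
        c = int(counts[w])
        pts = p2v_map[w, :c]
        index_0.append(pts.repeat_interleave(c))  # each query point repeated c times
        index_1.append(pts.repeat(c))             # the window's keys cycled once per query
    return torch.cat(index_0), torch.cat(index_1)
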
-------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step1_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum 4 | import sys 5 | sys.path.append("..") 6 | import sptr 7 | 8 | torch.manual_seed(1) 9 | 10 | # M = 80000 11 | N = 3500 12 | n = 150 13 | k = 65 14 | # M = 80 15 | # N = 5 16 | hdim = 16 17 | h = 6 18 | L = 31 19 | query = torch.rand(N, h, hdim).cuda() 20 | table_q = torch.rand(L, h, hdim, 3).cuda() 21 | key = torch.rand(N, h, hdim).cuda() 22 | table_k = torch.rand(L, h, hdim, 3).cuda() 23 | 24 | # index_q = torch.rand(M) 25 | # index_q[index_q < 0] = 0 26 | # index_q = (index_q*N).long().cuda() 27 | 28 | # index_k = torch.rand(M) 29 | # index_k[index_k < 0] = 0 30 | # index_k = (index_k*N).long().cuda() 31 | 32 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda() 33 | v2p_map, _ = v2p_map.sort() 34 | counts = v2p_map.bincount() 35 | M = (counts**2).sum().item() 36 | 37 | # v2p_map, ctg_sort_idx = v2p_map.sort() 38 | # n, k = p2v_map.shape 39 | N = v2p_map.shape[0] 40 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 41 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 42 | v2p_map = v2p_map.long() 43 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map) 44 | p2v_map[mask] = torch.arange(N).cuda() 45 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 46 | 47 | rel_index = torch.rand(M, 3) 48 | rel_index[rel_index < 0] = 0 49 | rel_index = (rel_index*L).long().cuda() 50 | 51 | index_q_counts = counts[v2p_map.long()] #[N, ] 52 | index_q_offsets = index_q_counts.cumsum(-1) #[N, ] 53 | index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1] 54 | n_max = p2v_map.shape[1] 55 | index_q, index_k = pointops.precompute_index_pairs(p2v_map, counts, index_q_offsets) 56 | index_q = index_q.long() 57 | index_k = index_k.long() 58 | 59 | # # rearrange index for acceleration 60 | # index_q, indices = torch.sort(index_q) #[M,] 61 | # index_k = index_k[indices] #[M,] 62 | # rel_index = rel_index[indices] 63 | # index_q_counts = index_q.bincount() 64 | 65 | # print("index_q_counts.shape: ", index_q_counts.shape) 66 | 67 | # n_max = index_q_counts.max() 68 | # index_q_offsets = index_q_counts.cumsum(dim=-1) #[N] 69 | 70 | # print("v1 index_q_offsets.shape: ", index_q_offsets.shape) 71 | 72 | # index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1] 73 | 74 | # print("index_q[:100]: ", index_q[:100]) 75 | print("n_max: ", n_max) 76 | print("index_q_offsets.shape: ", index_q_offsets.shape) 77 | # input() 78 | 79 | print("index_q_offsets[:100]: ", index_q_offsets[:100]) 80 | print("index_k[:20]: ", index_k[:20]) 81 | 82 | query.requires_grad = True 83 | table_q.requires_grad = True 84 | key.requires_grad = True 85 | table_k.requires_grad = True 86 | 87 | output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int()) 88 | output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int()) 89 | attn_flat = pointops.attention_step1(query.float(), key.float(), index_q.int(), index_k.int()) 90 | 91 | output = output1 + output2 + attn_flat 92 | loss = output.mean() 93 | loss.backward() 94 
| 95 | print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10])) 96 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 97 | print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) 98 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 99 | print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) 100 | # input() 101 | 102 | query_grad = query.grad.clone() 103 | key_grad = key.grad.clone() 104 | table_q_grad = table_q.grad.clone() 105 | table_k_grad = table_k.grad.clone() 106 | 107 | query.grad.zero_() 108 | key.grad.zero_() 109 | table_q.grad.zero_() 110 | table_k.grad.zero_() 111 | 112 | 113 | # print("query.is_contiguous(): ", query.is_contiguous()) 114 | # print("key.is_contiguous(): ", key.is_contiguous()) 115 | # print("index_q.is_contiguous(): ", index_q.is_contiguous()) 116 | # print("index_k.is_contiguous(): ", index_k.is_contiguous()) 117 | table_q = table_q.detach().permute(0,3,1,2).contiguous() 118 | table_k = table_k.detach().permute(0,3,1,2).contiguous() 119 | table_q.requires_grad = True 120 | table_k.requires_grad = True 121 | output_v2 = sptr.dot_prod_with_idx_all(query, index_q.int(), index_q_offsets.int(), n_max, key, ctg_index_1_offsets.int(), index_k.int(), table_q, table_k, rel_index.int()) 122 | loss = output_v2.mean() 123 | loss.backward() 124 | 125 | table_q_grad2 = table_q.grad.clone().permute(0,2,3,1).contiguous() 126 | table_k_grad2 = table_k.grad.clone().permute(0,2,3,1).contiguous() 127 | 128 | print("output[:5, :5]: ", output[:5, :5]) 129 | 130 | print("output_v2[:5, :5]: ", output_v2[:5, :5]) 131 | 132 | # print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10])) 133 | print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5]) 134 | print("v2 table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2]) 135 | print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5]) 136 | print("v2 table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2]) 137 | # input() 138 | 139 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) 140 | 141 | print("((query.grad-query_grad)**2).max(): ", ((query.grad-query_grad)**2).max()) 142 | 143 | print("((key.grad-key_grad)**2).max(): ", ((key.grad-key_grad)**2).max()) 144 | 145 | print("((table_q_grad2-table_q_grad)**2).max(): ", ((table_q_grad2-table_q_grad)**2).max()) 146 | 147 | print("((table_k_grad2-table_k_grad)**2).max(): ", ((table_k_grad2-table_k_grad)**2).max()) 148 | 149 | 150 | -------------------------------------------------------------------------------- /src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pointops 3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum 4 | import sys 5 | sys.path.append("..") 6 | import sptr 7 | 8 | torch.manual_seed(2) 9 | 10 | # M = 80000 11 | N = 3500 #10 #5 #10 #3500 12 | n = 150 #2 #150 13 | # k = 5 #7 #65 14 | hdim = 16 15 | h = 6 #1 #6 16 | L = 31 #2 #31 17 | v = torch.rand(N, h, hdim).cuda() 18 | table = torch.rand(L, h, hdim, 3).cuda() 19 | 20 | # index_0 = torch.rand(M) 21 | # index_0[index_0 < 0] = 0 22 | # index_0 = (index_0*N).long().cuda() 23 | 24 | # index_1 = torch.rand(M) 25 | # index_1[index_1 < 0] = 0 26 | # index_1 = (index_1*N).long().cuda() 27 | 28 | 29 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda() 30 | v2p_map, _ = v2p_map.sort() 31 | counts = 
v2p_map.bincount() 32 | M = (counts**2).sum().item() 33 | k = counts.max().item() 34 | 35 | print("counts: ", counts) 36 | 37 | attn = torch.rand(M, h).cuda() 38 | 39 | # v2p_map, ctg_sort_idx = v2p_map.sort() 40 | # n, k = p2v_map.shape 41 | # N = v2p_map.shape[0] 42 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k] 43 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask] 44 | v2p_map = v2p_map.long() 45 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map) 46 | p2v_map[mask] = torch.arange(N).cuda() 47 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add 48 | 49 | print("M: ", M) 50 | # print("counts[:5]: {}, v2p_map[:5]: {}, p2v_map[:5]: {}".format(counts[:5], v2p_map[:5], p2v_map[:5])) 51 | # print("ctg_index_1_offsets[:50]: ", ctg_index_1_offsets[:50]) 52 | 53 | # print("ctg_index_1_offsets.max(): {}, ctg_index_1_offsets.min(): {}".format(ctg_index_1_offsets.max(), ctg_index_1_offsets.min())) 54 | print("ctg_index_1_offsets: ", ctg_index_1_offsets) 55 | 56 | rel_index = torch.rand(M, 3) 57 | rel_index[rel_index < 0] = 0 58 | rel_index = (rel_index*L).long().cuda() 59 | 60 | # print("rel_index.min(): {}, rel_index.max(): {}".format( 61 | # rel_index.min(), rel_index.max() 62 | # )) 63 | 64 | print("rel_index: ", rel_index) 65 | 66 | index_0_counts = counts[v2p_map.long()] #[N, ] 67 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ] 68 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 69 | n_max = p2v_map.shape[1] 70 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets) 71 | index_0 = index_0.long() 72 | index_1 = index_1.long() 73 | 74 | print("index_0: {}".format(index_0)) 75 | print("index_1: {}".format(index_1)) 76 | 77 | # print("index_0.max(): {}, index_0.min(): {}".format(index_0.max(), index_0.min())) 78 | 79 | # print("index_1.max(): {}, index_1.min(): {}".format(index_1.max(), index_1.min())) 80 | 81 | # input() 82 | 83 | # # rearrange index for acceleration 84 | # index_0, indices = torch.sort(index_0) #[M,] 85 | # index_1 = index_1[indices] #[M,] 86 | # rel_index = rel_index[indices] 87 | # index_0_counts = index_0.bincount() 88 | 89 | print("index_0_counts.shape: ", index_0_counts.shape) 90 | 91 | # n_max = index_0_counts.max() 92 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N] 93 | 94 | # print("v1 index_0_offsets.shape: ", index_0_offsets.shape) 95 | 96 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1] 97 | 98 | 99 | attn.requires_grad = True 100 | v.requires_grad = True 101 | table.requires_grad = True 102 | 103 | 104 | output = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int()) 105 | loss = output.mean() 106 | loss.backward() 107 | 108 | # print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5])) 109 | # print("attn.grad[:5, :3]: ", attn.grad[:5, :3]) 110 | # print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) 111 | # print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) 112 | # # input() 113 | 114 | attn_grad = attn.grad.clone() 115 | v_grad = v.grad.clone() 116 | table_grad = table.grad.clone() 117 | 118 | attn.grad.zero_() 119 | v.grad.zero_() 120 | table.grad.zero_() 121 | 122 | # print("query.is_contiguous(): ", query.is_contiguous()) 123 | # print("key.is_contiguous(): ", key.is_contiguous()) 124 | # 
print("index_0.is_contiguous(): ", index_0.is_contiguous()) 125 | # print("index_1.is_contiguous(): ", index_1.is_contiguous()) 126 | table = table.detach().permute(0,3,1,2).contiguous() 127 | table.requires_grad = True 128 | output_v2 = sptr.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int(), table, rel_index.int()) 129 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v4(attn, v, index_0_offsets.int(), n_max, index_1.int(), table, rel_index.int()) 130 | loss = output_v2.mean() 131 | loss.backward() 132 | 133 | table_grad2 = table.grad.clone().permute(0,2,3,1).contiguous() 134 | 135 | # print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5])) 136 | # print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3]) 137 | # print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5]) 138 | # print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2]) 139 | # # input() 140 | 141 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max()) 142 | 143 | print("((attn_grad-attn.grad)**2).max(): ", ((attn_grad-attn.grad)**2).max()) 144 | 145 | print("((v_grad-v.grad)**2).max(): ", ((v_grad-v.grad)**2).max()) 146 | 147 | print("((table_grad-table_grad2)**2).max(): ", ((table_grad-table_grad2)**2).max()) 148 | 149 | # print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2)) 150 | 151 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/adappool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class AdaptivePooling(nn.Module): 6 | def __init__(self, feature_dim, output_channels): 7 | super().__init__() 8 | self.output_channels = output_channels 9 | self.query = nn.Parameter(torch.randn(output_channels, feature_dim)) 10 | 11 | def forward(self, x, return_weights=False): 12 | """ 13 | Args: 14 | x: Input tensor of shape (batch_size, input_channels, feature_dim) 15 | 16 | Returns: 17 | Output tensor of shape (batch_size, output_channels, feature_dim) 18 | """ 19 | query = self.query.unsqueeze(0).repeat(x.shape[0],1,1) 20 | 21 | out = F.scaled_dot_product_attention(query=query,key=x,value=x) 22 | if return_weights: 23 | attn_scores = torch.einsum('ij,bkj->bki', self.query, x) 24 | attn_weights = F.softmax(attn_scores, dim=1) 25 | return out, attn_weights 26 | 27 | return out 28 | -------------------------------------------------------------------------------- /src/models/pca_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from models.salsa import SALSA 10 | 11 | 12 | class L2Norm(nn.Module): 13 | def forward(self, x): 14 | return torch.nn.functional.normalize(x, p=2.0, dim=1, eps=1e-12) 15 | 16 | class PCAModel(nn.Module): 17 | def __init__(self,num_in_features,num_out_features): 18 | super(PCAModel, self).__init__() 19 | self.pca_conv = nn.Conv2d(num_in_features, 
num_out_features, kernel_size=(1, 1), stride=1, padding=0) 20 | self.layer = nn.Sequential(*[self.pca_conv, nn.Flatten(), L2Norm()]) 21 | 22 | def forward(self,x): 23 | return self.layer(x) 24 | 25 | 26 | class CombinedModel(nn.Module): 27 | def __init__(self, voxel_sz, num_in_features,num_out_features): 28 | super(CombinedModel, self).__init__() 29 | self.spherelpr = SALSA(voxel_sz=voxel_sz) 30 | self.pca_model = PCAModel(num_in_features, num_out_features) 31 | 32 | def forward(self,data): 33 | coord, xyz, feat, batch = data 34 | output_feats, output_desc = self.spherelpr(coord, xyz, feat, batch) 35 | output_desc = self.pca_model(output_desc[...,None][...,None]) 36 | return output_feats, output_desc 37 | # return output_desc 38 | 39 | -------------------------------------------------------------------------------- /src/models/salsa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | from time import time 5 | 6 | import tqdm 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | 10 | import numpy as np 11 | import spconv.pytorch as spconv 12 | import torch 13 | import torch.nn as nn 14 | 15 | from models.Mixer.mixer import Mixer 16 | from models.SphereFormer.model.unet_spherical_transformer import Semantic 17 | from models.adappool import AdaptivePooling 18 | from utils.misc_utils import read_yaml_config 19 | 20 | 21 | def print_model_size(model): 22 | param_size = 0 23 | for param in model.parameters(): 24 | param_size += param.nelement() * param.element_size() 25 | buffer_size = 0 26 | for buffer in model.buffers(): 27 | buffer_size += buffer.nelement() * buffer.element_size() 28 | 29 | size_all_mb = (param_size + buffer_size) / 1024**2 30 | print('model size: {:.3f}MB'.format(size_all_mb)) 31 | 32 | 33 | class SALSA(nn.Module): 34 | def __init__(self,voxel_sz): 35 | super(SALSA, self).__init__() 36 | config = read_yaml_config(os.path.join(os.path.dirname(__file__),'../config/model.yaml')) 37 | self.k = config['aggregator']['tokens'] 38 | feature_dim = config['feat_extractor']['feature_dim'] 39 | patch_size = config['feat_extractor']['patch_size'] 40 | voxel_size = [voxel_sz, voxel_sz, voxel_sz] 41 | patch_size = np.array([voxel_size[i] * patch_size for i in range(3)]).astype(np.float32) 42 | window_size = patch_size * 6 43 | self.do_pe = True 44 | self.feature_extractor = Semantic(input_c=config['feat_extractor']['input_c'], 45 | m=config['feat_extractor']['m'], 46 | classes=feature_dim, 47 | block_reps=config['feat_extractor']['block_reps'], 48 | block_residual=True, 49 | layers=config['feat_extractor']['layers'], 50 | window_size=window_size, 51 | window_size_sphere=np.array(config['feat_extractor']['window_size_sphere']), 52 | quant_size=window_size/24, 53 | quant_size_sphere= np.array(config['feat_extractor']['window_size_sphere'])/24, 54 | rel_query=True, 55 | rel_key=True, 56 | rel_value=True, 57 | drop_path_rate=config['feat_extractor']['drop_path_rate'], 58 | window_size_scale=config['feat_extractor']['window_size_scale'], 59 | grad_checkpoint_layers=[], 60 | sphere_layers=config['feat_extractor']['sphere_layers'], 61 | a=config['feat_extractor']['a'], 62 | ) 63 | 64 | self.attpool = AdaptivePooling(feature_dim=feature_dim,output_channels=self.k) 65 | 66 | self.descriptor_extractor = Mixer(in_channels=self.k, 67 | out_channels=config['aggregator']['out_channels'], 68 | in_d=feature_dim, 69 | mix_depth=config['aggregator']['mix_depth'], 70 | 
mlp_ratio=config['aggregator']['mlp_ratio'], 71 | out_d=config['aggregator']['out_d']) 72 | 73 | 74 | 75 | self.do_pe = True 76 | 77 | 78 | 79 | def forward(self, coord, xyz, feat, batch, save_attn_weights=False): 80 | ########################## Feature extractor ######################################## 81 | batch_shape = batch[-1]+1 82 | coord = torch.cat([batch.unsqueeze(-1), coord], -1) 83 | spatial_shape = np.clip((coord.max(0)[0][1:] + 1).cpu().numpy(), 128, None) 84 | 85 | sinput = spconv.SparseConvTensor(feat, coord.int(), spatial_shape, batch_shape) 86 | 87 | local_features = self.feature_extractor(sinput, xyz, batch) 88 | ##################################################################################### 89 | 90 | #################### Adaptive pooling + Mixer based aggregator ##################### 91 | padded_split_local_features = [] 92 | _, counts = torch.unique(batch, return_counts=True) 93 | split_local_features = torch.split(local_features, list(counts)) # [(N1,16),(N2,16),(N3,16),(N4,16),...] 94 | for features in split_local_features: 95 | # print(features.shape) 96 | if save_attn_weights: 97 | attval, attn_weights = self.attpool(features.unsqueeze(0), return_weights=True) 98 | self.attn_weights = attn_weights 99 | else: 100 | attval = self.attpool(features.unsqueeze(0)) 101 | padded_split_local_features.append(attval.squeeze(0)) ### Attention based pooling 102 | padded_split_local_features = torch.stack(padded_split_local_features, dim=0) 103 | global_descriptor = self.descriptor_extractor(padded_split_local_features) 104 | ##################################################################################### 105 | 106 | 107 | return split_local_features, global_descriptor 108 | 109 | 110 | if __name__=='__main__': 111 | import random 112 | import time 113 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 114 | os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" 115 | # seed = 3407 116 | seed = 1100 117 | # seed = np.random.randint(10000) 118 | print(seed) 119 | random.seed(seed) 120 | np.random.seed(seed) 121 | torch.manual_seed(seed) 122 | torch.backends.cudnn.benchmark = False 123 | torch.backends.cudnn.enabled = True 124 | torch.backends.cudnn.deterministic = True 125 | torch.cuda.manual_seed_all(seed) 126 | torch.use_deterministic_algorithms(True) 127 | 128 | model = SALSA(voxel_sz=0.5) 129 | # save_path = '/data/raktim/Projects/LPR/Main/src/checkpoints/Ablation/NewSphereMixerVoxel2/model_6.pth' 130 | # checkpoint = torch.load(save_path) # ,map_location='cuda:0') 131 | # model.load_state_dict(checkpoint) 132 | model.to('cuda') 133 | 134 | coords = torch.IntTensor(np.random.randint(0,100,size=[11000,3])).to('cuda') 135 | xyz = torch.FloatTensor(np.random.rand(11000,3)).to('cuda') 136 | feats = torch.FloatTensor(np.random.rand(11000,3)).to('cuda') 137 | batch_number = torch.IntTensor(np.ones([11000])).to('cuda') 138 | # print(coords.shape, xyz.shape, feats.shape, batch_number.shape) 139 | model.eval() 140 | 141 | N = 1000 142 | 143 | with torch.inference_mode(): 144 | # torch.cuda.synchronize() 145 | start = time.time() 146 | for i in tqdm.tqdm(range(N)): 147 | local_features, output_desc = model(coords, xyz, feats, batch_number) 148 | # torch.cuda.synchronize() 149 | end = time.time() 150 | print('Forward pass took {} seconds for {} trials'.format(end - start,N)) 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1
| import gc 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 12 | 13 | from data.sejong_southbay import SejongSouthbayTupleLoader 14 | 15 | from models.salsa import SALSA 16 | 17 | from loss.loss import find_loss 18 | from utils.misc_utils import tuple_collate_fn, read_yaml_config 19 | 20 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 21 | os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" 22 | 23 | def print_nb_params(m): 24 | model_parameters = filter(lambda p: p.requires_grad, m.parameters()) 25 | params = sum([np.prod(p.size()) for p in model_parameters]) 26 | print(f'Trainable parameters: {params/1e6:.3}M') 27 | del model_parameters, params 28 | 29 | def print_model_size(model): 30 | param_size = 0 31 | for param in model.parameters(): 32 | param_size += param.nelement() * param.element_size() 33 | buffer_size = 0 34 | for buffer in model.buffers(): 35 | buffer_size += buffer.nelement() * buffer.element_size() 36 | 37 | size_all_mb = (param_size + buffer_size) / 1024**2 38 | print('model size: {:.3f}MB'.format(size_all_mb)) 39 | 40 | 41 | def main(): 42 | config = read_yaml_config(os.path.join(os.path.dirname(__file__),'config/train.yaml')) 43 | writer = SummaryWriter(config['writer_loc']) 44 | # Get data loader 45 | batch_size = config['batch_size'] 46 | train_transform = None 47 | dataset = SejongSouthbayTupleLoader(cached_queries=config['cached_queries'], pcl_transform=train_transform) 48 | device = config['device'] 49 | 50 | model = SALSA(voxel_sz=0.5).to(device) 51 | 52 | model.train() 53 | print_nb_params(model) 54 | print_model_size(model) 55 | MAX_EPOCH = config['max_epoch'] 56 | 57 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr']) 58 | 59 | 60 | kk_batch = 0 61 | kk_subcache = 0 62 | for e in range(MAX_EPOCH): 63 | EPOCH_LOSS = [] 64 | time1 = time.time() 65 | dataset.new_epoch() 66 | steps_per_epoch = int(np.ceil(1000/batch_size)) 67 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['max_lr'],epochs=dataset.nCacheSubset, steps_per_epoch=steps_per_epoch,anneal_strategy='cos', cycle_momentum=False) 68 | lr_list = [scheduler.get_last_lr()] 69 | for ii in range(dataset.nCacheSubset): 70 | scheduler.step((ii+1)*steps_per_epoch) 71 | lr_list.append(scheduler.get_last_lr()) 72 | 73 | for current_subset in range(0,dataset.nCacheSubset): 74 | CACHE_LOSS = [] 75 | 76 | dataset.current_subset=current_subset 77 | dataset.update_subcache(model,outputdim=config['outdim']) 78 | if len(dataset.triplets)==0: 79 | continue 80 | model.train() 81 | data_loader = torch.utils.data.DataLoader(dataset=dataset,shuffle=True, batch_size=batch_size,collate_fn=tuple_collate_fn, num_workers=16) 82 | scheduler_lr = np.linspace(lr_list[current_subset],lr_list[current_subset+1],len(data_loader)) 83 | 84 | for i, batch_data in enumerate(data_loader): 85 | model.zero_grad() 86 | optimizer.zero_grad() 87 | coord, xyz, feat, batch_number, labels, point_pos_pairs = batch_data 88 | coord, xyz, feat, batch_number, labels = coord.to(device), xyz.to(device), feat.to(device), batch_number.to(device),labels.to(device) 89 | local_features, global_descriptor = model(coord, xyz, feat, batch_number) 90 | 91 | loss = find_loss(local_features, global_descriptor, point_pos_pairs) 92 | loss.backward() 93 | optimizer.step() 94 | for param_group in optimizer.param_groups: 95 | last_lr = param_group['lr'] 
96 | param_group['lr'] = scheduler_lr[i][0] 97 | writer.add_scalar("Batch Loss", loss.item(), kk_batch) 98 | writer.add_scalar("Batch LR", last_lr, kk_batch) 99 | kk_batch += 1 100 | CACHE_LOSS.append(loss.item()) 101 | sys.stdout.write('\r' + 'Epoch ' + str(e + 1) + ' / ' + str(MAX_EPOCH) + ' Subset ' + str(current_subset + 1) + ' / ' + str(dataset.nCacheSubset) + ' Progress ' + str(i+1) + ' / ' + str(len(data_loader))+ ' Loss ' + str(format(loss.item(),'.2f')) + ' time '+ str(format(time.time()-time1,'.2f'))+' seconds.') 102 | 103 | torch.save(model.state_dict(),os.path.join(os.path.dirname(__file__),'checkpoints/SALSA/Model/model_'+str(e)+'.pth')) 104 | del coord, xyz, feat, batch_number, labels, local_features, global_descriptor, point_pos_pairs 105 | gc.collect() 106 | torch.cuda.empty_cache() 107 | cache_loss_avg = sum(CACHE_LOSS)/len(CACHE_LOSS)*steps_per_epoch 108 | 109 | writer.add_scalar("Subcache Loss", cache_loss_avg, kk_subcache) 110 | writer.add_scalar("Subcache LR", last_lr, kk_subcache) 111 | kk_subcache += 1 112 | 113 | EPOCH_LOSS.append(cache_loss_avg) 114 | print(' ') 115 | print('Avg. Subcache Loss', cache_loss_avg) 116 | torch.save(model.state_dict(),os.path.join(os.path.dirname(__file__),'checkpoints/SALSA/Model/model_'+str(e)+'.pth')) 117 | epoch_loss_avg = sum(EPOCH_LOSS)/len(EPOCH_LOSS) 118 | print(' ') 119 | print('Avg. EPOCH Loss', epoch_loss_avg) 120 | writer.add_scalar("Epoch Loss", epoch_loss_avg, e) 121 | 122 | 123 | if __name__ == "__main__": 124 | seed = 1100 125 | random.seed(seed) 126 | np.random.seed(seed) 127 | torch.manual_seed(seed) 128 | torch.backends.cudnn.benchmark = False 129 | torch.backends.cudnn.enabled = True 130 | torch.backends.cudnn.deterministic = True 131 | torch.cuda.manual_seed_all(seed) 132 | torch.use_deterministic_algorithms(True) 133 | 134 | gc.collect() 135 | main() 136 | 137 | 138 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import yaml 7 | 8 | 9 | def collate_fn(batch): 10 | coord, xyz, feat = list(zip(*batch)) 11 | offset, count = [], 0 12 | 13 | new_coord, new_xyz, new_feat = [], [], [] 14 | k = 0 15 | for i, item in enumerate(xyz): 16 | 17 | count += item.shape[0] 18 | k += 1 19 | offset.append(count) 20 | new_coord.append(coord[i]) 21 | new_xyz.append(xyz[i]) 22 | new_feat.append(feat[i]) 23 | offset_ = torch.IntTensor(offset[:k]).clone() 24 | offset_[1:] = offset_[1:] - offset_[:-1] 25 | batch_number = torch.cat([torch.tensor([ii]*o) for ii,o in enumerate(offset_)], 0).long() 26 | coords,xyz,feat = torch.cat(new_coord[:k]), torch.cat(new_xyz[:k]), torch.cat(new_feat[:k]) 27 | return coords,xyz,feat,batch_number 28 | 29 | 30 | def tuple_collate_fn(batch): 31 | anchor_coords, anchor_xyz, anchor_feats, pos_coords, pos_xyz, pos_feats, neg_coords, neg_xyz, neg_feats, labels, point_pos_pairs = list(zip(*batch)) 32 | offset, count = [], 0 33 | 34 | new_coord, new_xyz, new_feat, new_label, new_point_pos_pairs = [], [], [], [], [] 35 | 36 | coord, xyz, feat = anchor_coords, anchor_xyz, anchor_feats 37 | for i, item in 
enumerate(xyz): 38 | 39 | count += item.shape[0] 40 | offset.append(count) 41 | new_coord.append(coord[i]) 42 | new_xyz.append(xyz[i]) 43 | new_feat.append(feat[i]) 44 | new_label.append(labels[i][0]) 45 | 46 | coord, xyz, feat = pos_coords, pos_xyz, pos_feats 47 | for i, item in enumerate(xyz): 48 | 49 | count += item.shape[0] 50 | offset.append(count) 51 | new_coord.append(coord[i]) 52 | new_xyz.append(xyz[i]) 53 | new_feat.append(feat[i]) 54 | new_label.append(labels[i][1]) 55 | 56 | 57 | coord, xyz, feat = neg_coords, neg_xyz, neg_feats 58 | for i, item in enumerate(xyz): 59 | 60 | count += item.shape[0] 61 | offset.append(count) 62 | new_coord.append(coord[i]) 63 | new_xyz.append(xyz[i]) 64 | new_feat.append(feat[i]) 65 | new_label.append(labels[i][2]) 66 | 67 | if point_pos_pairs!=None: 68 | for i, item in enumerate(point_pos_pairs): 69 | # item = np.array(item) + len(new_point_pos_pairs) 70 | # if i>0: 71 | # item1 = np.array(item)[:,0] + new_coord[i-1].shape[0] 72 | # item2 = np.array(item)[:,1] + new_coord[i-1].shape[0] 73 | new_point_pos_pairs.append(item) 74 | 75 | 76 | offset_ = torch.IntTensor(offset).clone() 77 | offset_[1:] = offset_[1:] - offset_[:-1] 78 | batch_number = torch.cat([torch.tensor([ii]*o) for ii,o in enumerate(offset_)], 0).long() 79 | coords,xyz,feat,labels = torch.cat(new_coord), torch.cat(new_xyz), torch.cat(new_feat), torch.Tensor(new_label) 80 | return coords,xyz,feat,batch_number,labels, new_point_pos_pairs 81 | 82 | 83 | def hashM(arr, M): 84 | if isinstance(arr, np.ndarray): 85 | N, D = arr.shape 86 | else: 87 | N, D = len(arr[0]), len(arr) 88 | 89 | hash_vec = np.zeros(N, dtype=np.int64) 90 | for d in range(D): 91 | if isinstance(arr, np.ndarray): 92 | hash_vec += arr[:, d] * M**d 93 | else: 94 | hash_vec += arr[d] * M**d 95 | return hash_vec 96 | 97 | 98 | def pdist(A, B, dist_type='L2'): 99 | if dist_type == 'L2': 100 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 101 | return torch.sqrt(D2 + 1e-7) 102 | elif dist_type == 'SquareL2': 103 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 104 | else: 105 | raise NotImplementedError('Not implemented') 106 | 107 | 108 | ##################################################################################### 109 | # Load poses 110 | ##################################################################################### 111 | 112 | 113 | def load_poses_from_csv(file_name): 114 | with open(file_name, newline='') as f: 115 | reader = csv.reader(f) 116 | data_poses = list(reader) 117 | 118 | transforms = [] 119 | positions = [] 120 | for cnt, line in enumerate(data_poses): 121 | line_f = [float(i) for i in line] 122 | P = np.vstack((np.reshape(line_f[1:], (3, 4)), [0, 0, 0, 1])) 123 | transforms.append(P) 124 | positions.append([P[0, 3], P[1, 3], P[2, 3]]) 125 | return np.asarray(transforms), np.asarray(positions) 126 | 127 | 128 | ##################################################################################### 129 | # Load timestamps 130 | ##################################################################################### 131 | 132 | 133 | def load_timestamps_csv(file_name): 134 | with open(file_name, newline='') as f: 135 | reader = csv.reader(f) 136 | data_poses = list(reader) 137 | data_poses_ts = np.asarray( 138 | [float(t)/1e9 for t in np.asarray(data_poses)[:, 0]]) 139 | return data_poses_ts 140 | 141 | def read_yaml_config(filename): 142 | with open(filename, 'r') as stream: 143 | try: 144 | # Load the YAML file 145 | config = yaml.safe_load(stream) 146 | return config 
147 | except yaml.YAMLError as exc: 148 | print(exc) 149 | return None -------------------------------------------------------------------------------- /src/utils/o3d_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | 7 | def make_open3d_point_cloud(xyz, color=None, tile=False): 8 | pcd = o3d.geometry.PointCloud() 9 | pcd.points = o3d.utility.Vector3dVector(xyz) 10 | if color is not None: 11 | if tile: 12 | if len(color) != len(xyz): 13 | color = np.tile(color, (len(xyz), 1)) 14 | pcd.colors = o3d.utility.Vector3dVector(color) 15 | return pcd 16 | 17 | 18 | def get_matching_indices(source, target, search_voxel_size, K=None): 19 | source_copy = copy.deepcopy(source) 20 | target_copy = copy.deepcopy(target) 21 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 22 | 23 | match_inds = [] 24 | for i, point in enumerate(source_copy.points): 25 | [_, idx, _] = pcd_tree.search_radius_vector_3d( 26 | point, search_voxel_size) 27 | if K is not None: 28 | idx = idx[:K] 29 | for j in idx: 30 | match_inds.append([i, j]) 31 | return np.asarray(match_inds) 32 | --------------------------------------------------------------------------------
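The two helpers above are thin Open3D wrappers: make_open3d_point_cloud converts an (N, 3) array into an Open3D point cloud, and get_matching_indices gathers every (source, target) index pair whose points fall within search_voxel_size of each other, which is how point-level positive pairs between two clouds expressed in a common frame can be collected. A minimal usage sketch follows, with a hypothetical cloud and a 0.1 m radius chosen purely for illustration.

import numpy as np

# A point cloud and a slightly perturbed copy of it (placeholder data).
anchor_xyz = np.random.rand(1000, 3)
positive_xyz = anchor_xyz + 0.01 * np.random.randn(1000, 3)

anchor_pcd = make_open3d_point_cloud(anchor_xyz)
positive_pcd = make_open3d_point_cloud(positive_xyz)

# Every pair (i, j) with points closer than 0.1 m becomes a correspondence;
# passing K would cap the number of matches kept per source point.
pairs = get_matching_indices(anchor_pcd, positive_pcd, search_voxel_size=0.1)
print(pairs.shape)  # (num_matches, 2)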