├── .gitignore
├── README.md
├── assets
│   ├── architecture.png
│   ├── attention_diagram2.png
│   ├── box_plot_easy.png
│   ├── box_plot_hard.png
│   ├── lc.png
│   ├── local_correspondences4.png
│   └── no_lc.png
├── requirements.txt
└── src
    ├── __init__.py
    ├── config
    │   ├── model.yaml
    │   └── train.yaml
    ├── data
    │   ├── __init__.py
    │   ├── augmentation.py
    │   ├── dataset_utils.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── alita
    │   │   │   ├── alita_raw.py
    │   │   │   ├── generate_evaluation_sets.py
    │   │   │   └── test_val_5_0.01_5.pickle
    │   │   ├── augmentation.py
    │   │   ├── base_datasets.py
    │   │   ├── dataset_utils.py
    │   │   ├── kitti
    │   │   │   ├── generate_evaluation_sets.py
    │   │   │   ├── kitti_00_eval.pickle
    │   │   │   ├── kitti_raw.py
    │   │   │   └── utils.py
    │   │   ├── kitti360
    │   │   │   ├── generate_evaluation_sets.py
    │   │   │   ├── kitti360_09_3.0_eval.pickle
    │   │   │   ├── kitti360_raw.py
    │   │   │   └── utils.py
    │   │   ├── mulran
    │   │   │   ├── generate_evaluation_sets.py
    │   │   │   ├── generate_training_tuples.py
    │   │   │   ├── mulran_raw.py
    │   │   │   ├── mulran_train.py
    │   │   │   ├── test_DCC1_DCC2_10.0_5.pickle
    │   │   │   ├── test_Sejong1_Sejong2_0.2_20.pickle
    │   │   │   ├── test_Sejong1_Sejong2_0.2_5.pickle
    │   │   │   └── utils.py
    │   │   ├── point_clouds_utils.py
    │   │   ├── poses_utils.py
    │   │   ├── quantization.py
    │   │   ├── samplers.py
    │   │   └── southbay
    │   │       ├── generate_evaluation_sets.py
    │   │       ├── generate_training_tuples.py
    │   │       ├── pypcd.py
    │   │       ├── southbay_raw.py
    │   │       ├── southbay_raw_old.py
    │   │       └── test_SunnyvaleBigloop_1.0_5_20m.pickle
    │   └── sejong_southbay.py
    ├── evaluate
    │   ├── SALSA
    │   │   ├── eval_salsa_sgv.py
    │   │   └── sgv_utils.py
    │   └── pca.py
    ├── loss
    │   ├── __init__.py
    │   ├── global_loss.py
    │   ├── local_consistency_loss.py
    │   └── loss.py
    ├── misc
    │   ├── point_clouds.py
    │   ├── poses.py
    │   ├── robot_trans.py
    │   └── utils.py
    ├── models
    │   ├── Mixer
    │   │   ├── __init__.py
    │   │   └── mixer.py
    │   ├── SphereFormer
    │   │   ├── SparseTransformer
    │   │   │   ├── .gitignore
    │   │   │   ├── README.md
    │   │   │   ├── __init__.py
    │   │   │   ├── license
    │   │   │   ├── setup.py
    │   │   │   ├── sptr
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── functional.py
    │   │   │   │   ├── modules.py
    │   │   │   │   ├── position_embedding.py
    │   │   │   │   └── utils.py
    │   │   │   ├── src
    │   │   │   │   └── sptr
    │   │   │   │       ├── __init__.py
    │   │   │   │       ├── attention
    │   │   │   │       │   ├── attention_cuda.cpp
    │   │   │   │       │   ├── attention_cuda_kernel.cu
    │   │   │   │       │   └── attention_cuda_kernel.h
    │   │   │   │       ├── cuda_utils.h
    │   │   │   │       ├── pointops_api.cpp
    │   │   │   │       ├── precompute
    │   │   │   │       │   ├── precompute.cpp
    │   │   │   │       │   ├── precompute_cuda_kernel.cu
    │   │   │   │       │   └── precompute_cuda_kernel.h
    │   │   │   │       └── rpe
    │   │   │   │           ├── relative_pos_encoding_cuda.cpp
    │   │   │   │           ├── relative_pos_encoding_cuda_kernel.cu
    │   │   │   │           └── relative_pos_encoding_cuda_kernel.h
    │   │   │   └── test
    │   │   │       ├── pointops.py
    │   │   │       ├── test_attention_op_step1.py
    │   │   │       ├── test_attention_op_step2.py
    │   │   │       ├── test_precompute_all.py
    │   │   │       ├── test_relative_pos_encoding_op_step1.py
    │   │   │       ├── test_relative_pos_encoding_op_step1_all.py
    │   │   │       └── test_relative_pos_encoding_op_step2.py
    │   │   └── model
    │   │       ├── spherical_transformer.py
    │   │       └── unet_spherical_transformer.py
    │   ├── __init__.py
    │   ├── adappool.py
    │   ├── pca_model.py
    │   └── salsa.py
    ├── train.py
    └── utils
        ├── __init__.py
        ├── misc_utils.py
        └── o3d_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore large files
2 | src/checkpoints/
3 | *.pth
4 | *__pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SALSA: Swift Adaptive Lightweight Self-Attention for Enhanced LiDAR Place Recognition
2 |
3 | 📖 Paper: [`RA-L`](https://ieeexplore.ieee.org/document/10629049)
4 |
5 | 📖 Pre-print: [`arXiv`](https://arxiv.org/abs/2407.08260)
6 |
7 | 📹 Video: [`YouTube`](https://www.youtube.com/watch?v=JLunemW91bQ)
8 |
9 | #### Authors: Raktim Gautam Goswami, Naman Patel, Prashanth Krishnamurthy, Farshad Khorrami
10 |
11 | #### Control/Robotics Research Laboratory (CRRL), Department of Electrical and Computer Engineering, NYU Tandon School of Engineering
12 |
13 | ### 💡 Contributions
14 | - **SALSA**: A novel, lightweight, and efficient framework for LiDAR place recognition that delivers state-of-the-art performance while maintaining real-time operational capabilities.
15 | - **SphereFormer**: Utilized for local descriptor extraction with radial and cubic window attention to boost localization performance for sparse distant points.
16 | - **Adaptive Pooling**: A self-attention adaptive pooling module to fuse local descriptors into global tokens. It can aggregate arbitrary numbers of points in a point cloud without pre-processing.
17 |
18 | - **MLP Mixer Token Aggregator**: An MLP mixer-based aggregator to iteratively incorporate global context information to generate a robust global scene descriptor.
19 |
20 |
21 |
22 |
23 | Fig. 1: Overview of our SALSA framework to generate scene descriptors from point clouds for place recognition.
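Illustrative sketch (not the actual modules in src/models): the adaptive pooling layer can be thought of as a set of learned query tokens that cross-attend over a variable number of SphereFormer local descriptors, producing a fixed-size token set that the MLP-mixer aggregator turns into the global scene descriptor. All class and variable names below are placeholders.

```python
import torch
import torch.nn as nn

class AdaptivePoolSketch(nn.Module):
    """Toy sketch: T learned tokens attend over N local descriptors (N may vary per scan)."""
    def __init__(self, dim: int = 16, tokens: int = 512):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(tokens, dim))   # learned global tokens
        self.attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)

    def forward(self, local_feats: torch.Tensor) -> torch.Tensor:
        # local_feats: (1, N, dim) local descriptors; N can be arbitrary, no pre-processing needed
        q = self.queries.unsqueeze(0)                            # (1, T, dim)
        fused, _ = self.attn(q, local_feats, local_feats)        # tokens attend over all points
        return fused                                             # (1, T, dim) fixed-size output

# A global descriptor is then obtained by feeding the fixed-size token set to the
# MLP-mixer aggregator (see src/models/adappool.py and src/models/Mixer/mixer.py for the real code).
```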
24 |
25 |
26 |
27 | ### 🔨 Environment creation
28 |
29 | ```bash
30 | conda create --name salsa python=3.10.11
31 | conda activate salsa
32 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
33 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
34 | pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
35 | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
36 | pip install -r requirements.txt
37 | pip install --no-deps timm==0.9.7
38 | ```
39 | Install [SpTr](https://github.com/dvlab-research/SparseTransformer) from source.
40 |
41 |
42 | ### 📊💾 Dataset Download
43 | The model is trained on the MulRan Sejong 01/02 sequences and Apollo-Southbay (excluding Sunnyvale). Evaluation is performed on an 'easy' set (Apollo-Southbay Sunnyvale, SemanticKITTI, MulRan Sejong) and a 'hard' set (MulRan DCC1/DCC2, KITTI-360, ALITA). The datasets can be downloaded from the following links.
44 | - [MulRan](https://sites.google.com/view/mulran-pr/download) dataset: ground truth data (*.csv) and LiDAR point clouds (Ouster.zip).
45 | - [Apollo-Southbay](https://developer.apollo.auto/southbay.html) dataset.
46 | - [SemanticKITTI](http://semantic-kitti.org/dataset.html#download) dataset (velodyne point clouds and calibration data for poses).
47 | - [ALITA](https://github.com/MetaSLAM/ALITA) dataset.
48 | - [KITTI-360](https://www.cvlibs.net/datasets/kitti-360/user_login.php) dataset (raw velodyne scans, calibrations and vehicle poses).
49 |
50 |
51 | ### 📊💾 Dataset Pickle Creation
52 |
53 | Create Training Pickle
54 |
55 | ```bash
56 | cd src/data/datasets/
57 | python southbay/generate_training_tuples.py --dataset_root <path_to_southbay_dataset_root>
58 | python mulran/generate_training_tuples.py --dataset_root <path_to_mulran_dataset_root>
59 | ```
60 |
61 | Create Evaluation Pickle
62 | ```bash
63 | python mulran/generate_evaluation_sets.py --dataset_root <path_to_mulran_dataset_root> --sequence Sejong
64 | python mulran/generate_evaluation_sets.py --dataset_root <path_to_mulran_dataset_root> --sequence DCC
65 | python southbay/generate_evaluation_sets.py --dataset_root <path_to_southbay_dataset_root>
66 | python kitti/generate_evaluation_sets.py --dataset_root <path_to_kitti_dataset_root>
67 | python kitti360/generate_evaluation_sets.py --dataset_root <path_to_kitti360_dataset_root>
68 | python alita/generate_evaluation_sets.py --dataset_root <path_to_alita_dataset_root>
69 | ```
70 |
71 | ### ✈️ Training
72 | Navigate to the repository root, create a checkpoints directory inside src to store the trained models, and start training.
73 | To change the model and training parameters, edit src/config/model.yaml and src/config/train.yaml, respectively.
74 | ```bash
75 | mkdir -p src/checkpoints/SALSA/Model src/checkpoints/SALSA/PCA
76 | python src/train.py
77 | ```
78 | This will train the model on the generated training dataset and store the saved models for each epoch in src/checkpoints.
79 |
80 | ### ✈️ PCA
81 | Learn PCA using the trained model
82 | ```bash
83 | python src/evaluate/pca.py
84 | ```
85 | This will learn a PCA to compress and decorrelate the scene descriptors and store the learned PCA as a PyTorch model in src/checkpoints.
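As a rough sketch of what this step produces (the exact code lives in src/evaluate/pca.py and src/models/pca_model.py; the function names and dimensions below are illustrative assumptions):

```python
import torch

def fit_pca(descriptors: torch.Tensor, out_dim: int):
    # descriptors: (M, D) matrix of SALSA scene descriptors collected from the training data
    mean = descriptors.mean(dim=0)
    centered = descriptors - mean
    # Right singular vectors of the centered data give the principal directions
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    components = vh[:out_dim]                     # (out_dim, D) projection matrix
    return mean, components

def apply_pca(desc: torch.Tensor, mean: torch.Tensor, components: torch.Tensor) -> torch.Tensor:
    # Compressed, decorrelated descriptor used at evaluation time
    return (desc - mean) @ components.T
```

The mean and projection matrix can be stored as buffers of a small torch.nn.Module, so that evaluation applies the PCA as a single linear projection.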
86 |
87 | ### ✈️ Evaluation
88 | ```bash
89 | python src/evaluate/SALSA/eval_salsa_sgv.py --dataset_root <path_to_dataset_root> --dataset_type <dataset_type> --only_global True
90 | ```
91 | Our pre-trained models can also be downloaded from this [link](https://drive.google.com/drive/folders/1lehq0Hki75i7U_Twhd5uxxz37WvcRzGa?usp=sharing). After downloading, copy the contents into the 'src/checkpoints' directory.
92 |
93 |
94 | ### 📝 Results
95 | The spread of Recall@1 before and after re-ranking for the best-performing models is plotted in the following figures.
96 |
97 |
98 |

99 |
Fig. 2a: 'Easy' Dataset
100 |
101 |
102 |

103 |
Fig. 2b: 'Hard' Dataset
104 |
105 |
Fig. 2: Box plots of Recall@1 across six datasets; boxes span the first to third quartiles, whiskers indicate data variability, and internal lines mark the medians.
106 |
107 |
108 |
109 | ### 🌈 Visualizations
110 |
111 |
112 | Fig. 3: Visualization of areas attended to by different tokens from the adaptive pooling layer. Each token focuses on different geometries: trees and traffic signs (green), road intersections (red), and distant points (blue).
113 |
114 |
115 |
116 |
117 |
118 |
119 | Fig. 4: Point matches between query and target clouds using LoGG3D-Net and SALSA local descriptors. Matching colors indicate correspondences; circles highlight SALSA’s superior performance on sparse distant points.
120 |
121 |
122 |
123 |
124 |
125 |
126 |

127 |
Fig. 5a: Without Loop Detection.
128 |
129 |
130 |

131 |
Fig. 5b: With Loop Detection.
132 |
133 |
Fig. 5: Comparison of LiDAR-only odometry and maps: (a) without loop detection, and (b) after online pose graph optimization from SALSA loop detections. The highlighted rectangles emphasize the map and odometry disparities due to loop closures.
134 |
135 |
136 | ## 📧 Citation
137 |
138 | If you find our work useful in your research, please consider citing our publication:
139 | ```bibtex
140 | @article{goswami2024salsa,
141 | title={SALSA: Swift Adaptive Lightweight Self-Attention for Enhanced LiDAR Place Recognition},
142 | author={Goswami, Raktim Gautam and Patel, Naman and Krishnamurthy, Prashanth and Khorrami, Farshad},
143 | journal={IEEE Robotics and Automation Letters},
144 | year={2024},
145 | }
146 | ```
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/architecture.png
--------------------------------------------------------------------------------
/assets/attention_diagram2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/attention_diagram2.png
--------------------------------------------------------------------------------
/assets/box_plot_easy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/box_plot_easy.png
--------------------------------------------------------------------------------
/assets/box_plot_hard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/box_plot_hard.png
--------------------------------------------------------------------------------
/assets/lc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/lc.png
--------------------------------------------------------------------------------
/assets/local_correspondences4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/local_correspondences4.png
--------------------------------------------------------------------------------
/assets/no_lc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/assets/no_lc.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch_geometric
2 | spconv-cu117
3 | cumm-cu117
4 | tensorboard
5 | open3d
6 | python-lzf
7 | faiss-gpu
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/__init__.py
--------------------------------------------------------------------------------
/src/config/model.yaml:
--------------------------------------------------------------------------------
1 | feat_extractor:
2 | feature_dim: 16
3 | patch_size: 1
4 | input_c: 3
5 | m: 32
6 | block_reps: 2
7 | layers: [32, 64, 128]
8 | window_size_sphere: [2, 2, 80]
9 | drop_path_rate: 0.3
10 | window_size_scale: [2.0, 1.5]
11 | sphere_layers: [1,2,3,4,5]
12 | a: 0.0125
13 |
14 | aggregator:
15 | tokens: 512
16 | out_channels: 128
17 | mix_depth: 4
18 | mlp_ratio: 1
19 | out_d: 4
20 |
--------------------------------------------------------------------------------
/src/config/train.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 12
2 | writer_loc: 'runs/SALSA'
3 | cached_queries: 1000
4 | device: 'cuda'
5 | max_epoch: 50
6 | lr: 0.001
7 | max_lr: 0.005
8 | outdim: 512
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def fnv_hash_vec(arr):
5 | """
6 | FNV64-1A
7 | """
8 | assert arr.ndim == 2
9 | # Floor first for negative coordinates
10 | arr = arr.copy()
11 | arr = arr.astype(np.uint64, copy=False)
12 | hashed_arr = np.uint64(14695981039346656037) * np.ones(arr.shape[0], dtype=np.uint64)
13 | for j in range(arr.shape[1]):
14 | hashed_arr *= np.uint64(1099511628211)
15 | hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
16 | return hashed_arr
17 |
18 |
19 | def ravel_hash_vec(arr):
20 | """
21 | Ravel the coordinates after subtracting the min coordinates.
22 | """
23 | assert arr.ndim == 2
24 | arr = arr.copy()
25 | arr -= arr.min(0)
26 | arr = arr.astype(np.uint64, copy=False)
27 | arr_max = arr.max(0).astype(np.uint64) + 1
28 |
29 | keys = np.zeros(arr.shape[0], dtype=np.uint64)
30 | # Fortran style indexing
31 | for j in range(arr.shape[1] - 1):
32 | keys += arr[:, j]
33 | keys *= arr_max[j + 1]
34 | keys += arr[:, -1]
35 | return keys
36 |
37 |
38 | def voxelize(coord, voxel_size=0.05, hash_type='fnv', mode=0):
39 | discrete_coord = np.floor(coord / np.array(voxel_size))
40 | if hash_type == 'ravel':
41 | key = ravel_hash_vec(discrete_coord)
42 | else:
43 | key = fnv_hash_vec(discrete_coord)
44 |
45 | idx_sort = np.argsort(key)
46 | key_sort = key[idx_sort]
47 | _, count = np.unique(key_sort, return_counts=True)
48 | idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count
49 | idx_unique = idx_sort[idx_select]
50 | return idx_unique
51 |
--------------------------------------------------------------------------------
/src/data/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/__init__.py
--------------------------------------------------------------------------------
/src/data/datasets/alita/alita_raw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import open3d as o3d
6 |
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
8 | from datasets.point_clouds_utils import PointCloudLoader
9 |
10 |
11 | class ALITAPointCloudLoader(PointCloudLoader):
12 | def set_properties(self):
13 |         # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes.
14 | self.ground_plane_level = -1.6
15 |
16 | def read_pc(self, file_pathname):
17 | pcd = o3d.io.read_point_cloud(file_pathname)
18 | xyz = np.asarray(pcd.points)
19 | return xyz
20 |
--------------------------------------------------------------------------------
/src/data/datasets/alita/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Generate evaluation sets
2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/southbay/generate_evaluation_sets.py
3 |
4 | import argparse
5 | import glob
6 | import os
7 | import sys
8 | from typing import List
9 |
10 | import numpy as np
11 | from scipy.spatial.transform import Rotation as R
12 |
13 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
14 |
15 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements
16 |
17 |
18 | def get_pose_transform(pose6d):
19 | rot_matrix = R.from_euler('xyz', pose6d[3:]).as_matrix()
20 | trans_vector = pose6d[:3].reshape((3, 1))
21 |
22 | trans_matrix = np.identity(4)
23 | trans_matrix[:3, :3] = rot_matrix
24 | trans_matrix[:3, 3:] = trans_vector
25 |
26 | return trans_matrix
27 |
28 | def get_length(poses):
29 | current_pos = poses[0][:2,3]
30 | total = 0
31 | for i in range(len(poses)-1):
32 | delta = np.linalg.norm(poses[i][:2,3] - poses[i+1][:2,3])
33 | print(delta)
34 | total += delta
35 | print('')
36 |
37 | def get_poses_ugv(pose_files):
38 | poses = []
39 | timestamps = []
40 | for f in pose_files:
41 | pose = np.load(f)
42 | timestamps.append(pose[6])
43 | T = get_pose_transform(pose[:6])
44 | poses.append(T)
45 | return np.asarray(timestamps), np.asarray(poses)
46 |
47 | def get_scans(base_dir, area, split, min_displacement: float = 0.0) -> List[EvaluationTuple]:
48 |
49 | operating_dir = os.path.join(base_dir, split, area)
50 | pcd_files = sorted(glob.glob(os.path.join(operating_dir, '*.pcd')))
51 | pose_files = sorted(glob.glob(os.path.join(operating_dir, '*.npy')))
52 | timestamps, poses = get_poses_ugv(pose_files)
53 | get_length(poses)
54 |
55 |
56 | elems = []
57 | for ndx in range(len(pcd_files)):
58 | pose = poses[ndx]
59 | position = pose[0:2, 3] # (x, y) position in global coordinate frame
60 | rel_scan_filepath = pcd_files[ndx][len(base_dir)+1:]
61 | print(rel_scan_filepath)
62 | timestamp = timestamps[ndx]
63 |
64 | item = EvaluationTuple(timestamp, rel_scan_filepath, position=position, pose=pose)
65 | elems.append(item)
66 |
67 | print(f"{len(elems)} total elements in {split} split")
68 |
69 | # Filter-out elements leaving only 1 per grid cell with min_displacement size
70 | pos = np.zeros((len(elems), 2), dtype=np.float32)
71 | for ndx, e in enumerate(elems):
72 | pos[ndx] = e.position
73 |
74 | # Quantize x-y coordinates. Quantized coords start from 0
75 | pos = np.floor(pos / min_displacement)
76 | pos = pos.astype(int)
77 | _, unique_ndx = np.unique(pos, axis=0, return_index=True)
78 |
79 | # Leave only unique elements
80 | elems = [elems[i] for i in unique_ndx]
81 | print(f"{len(elems)} filtered elements in {split} split with grid cell size = {min_displacement}")
82 |
83 | return elems
84 |
85 |
86 | def generate_evaluation_set(dataset_root: str, area: str, min_displacement: float = 0.0, dist_threshold=5) -> \
87 | EvaluationSet:
88 | map_set = get_scans(dataset_root, area, 'DATABASE', min_displacement)
89 | query_set = get_scans(dataset_root, area, 'QUERY', min_displacement)
90 | query_set = filter_query_elements(query_set, map_set, dist_threshold)
91 | print(f'Area: {area} - {len(map_set)} database elements, {len(query_set)} query elements\n')
92 | return EvaluationSet(query_set, map_set)
93 |
94 |
95 | if __name__ == '__main__':
96 | parser = argparse.ArgumentParser(description='Generate evaluation sets for UGV dataset')
97 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/ALITA/VAL')
98 | parser.add_argument('--min_displacement', type=float, default=0.01)
99 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters)
100 | parser.add_argument('--dist_threshold', type=float, default=5)
101 |
102 | args = parser.parse_args()
103 | print(f'Dataset root: {args.dataset_root}')
104 | print(f'Minimum displacement between scans in each set (map/query): {args.min_displacement}')
105 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}')
106 |
107 | area = 'val_5' # Evaluation area
108 | eval_set = generate_evaluation_set(dataset_root=args.dataset_root,area=area, min_displacement=args.min_displacement,
109 | dist_threshold=args.dist_threshold)
110 | pickle_name = f'test_{area}_{args.min_displacement}_{args.dist_threshold}.pickle'
111 | # file_path_name = os.path.join(args.dataset_root, pickle_name)
112 | file_path_name = os.path.join(os.path.dirname(__file__), pickle_name)
113 | print(f"Saving evaluation pickle: {file_path_name}")
114 | eval_set.save(file_path_name)
115 |
--------------------------------------------------------------------------------
/src/data/datasets/alita/test_val_5_0.01_5.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/alita/test_val_5_0.01_5.pickle
--------------------------------------------------------------------------------
/src/data/datasets/base_datasets.py:
--------------------------------------------------------------------------------
1 | # Base dataset classes, inherited by dataset-specific classes
2 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/base_datasets.py
3 |
4 | import os
5 | import sys
6 | import pickle
7 | from typing import List, Dict
8 | import torch
9 | import numpy as np
10 | from sklearn.neighbors import KDTree
11 | from torch.utils.data import Dataset
12 |
13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
14 | from datasets.alita.alita_raw import ALITAPointCloudLoader
15 | from datasets.kitti.kitti_raw import KittiPointCloudLoader
16 | from datasets.mulran.mulran_raw import MulranPointCloudLoader
17 | from datasets.southbay.southbay_raw import SouthbayPointCloudLoader
18 | from datasets.kitti360.kitti360_raw import Kitti360PointCloudLoader
19 | from datasets.point_clouds_utils import PointCloudLoader
20 |
21 |
22 | class TrainingTuple:
23 | # Tuple describing an element for training/validation
24 | def __init__(self, id: int, timestamp: int, rel_scan_filepath: str, positives: np.ndarray,
25 |                  non_negatives: np.ndarray, pose: np.ndarray, positives_poses: Dict[int, np.ndarray] = None):
26 | # id: element id (ids start from 0 and are consecutive numbers)
27 | # ts: timestamp
28 | # rel_scan_filepath: relative path to the scan
29 | # positives: sorted ndarray of positive elements id
30 |         # non_negatives: sorted ndarray of non-negative elements id
31 | # pose: pose as 4x4 matrix
32 | # positives_poses: relative poses of positive examples refined using ICP
33 | self.id = id
34 | self.timestamp = timestamp
35 | self.rel_scan_filepath = rel_scan_filepath
36 | self.positives = positives
37 | self.non_negatives = non_negatives
38 | self.pose = pose
39 | self.positives_poses = positives_poses
40 |
41 |
42 | class EvaluationTuple:
43 | # Tuple describing an evaluation set element
44 | def __init__(self, timestamp: int, rel_scan_filepath: str, position: np.array, pose: np.array = None):
45 | # position: x, y position in meters
46 | # pose: 6 DoF pose (as 4x4 pose matrix)
47 | assert position.shape == (2,)
48 | assert pose is None or pose.shape == (4, 4)
49 | self.timestamp = timestamp
50 | self.rel_scan_filepath = rel_scan_filepath
51 | self.position = position
52 | self.pose = pose
53 |
54 | def to_tuple(self):
55 | return self.timestamp, self.rel_scan_filepath, self.position, self.pose
56 |
57 |
58 | class TrainingDataset(Dataset):
59 | def __init__(self, dataset_path: str, dataset_type: str, query_filename: str, transform=None, set_transform=None):
60 | # remove_zero_points: remove points with all zero coords
61 | assert os.path.exists(dataset_path), 'Cannot access dataset path: {}'.format(dataset_path)
62 | self.dataset_path = dataset_path
63 | self.dataset_type = dataset_type
64 | self.query_filepath = os.path.join(dataset_path, query_filename)
65 | assert os.path.exists(self.query_filepath), 'Cannot access query file: {}'.format(self.query_filepath)
66 | self.transform = transform
67 | self.set_transform = set_transform
68 | self.queries: Dict[int, TrainingTuple] = pickle.load(open(self.query_filepath, 'rb'))
69 | print('{} queries in the dataset'.format(len(self)))
70 |
71 | # pc_loader must be set in the inheriting class
72 | self.pc_loader = get_pointcloud_loader(self.dataset_type)
73 |
74 | def __len__(self):
75 | return len(self.queries)
76 |
77 | def __getitem__(self, ndx):
78 | # Load point cloud and apply transform
79 | file_pathname = os.path.join(self.dataset_path, self.queries[ndx].rel_scan_filepath)
80 | query_pc = self.pc_loader(file_pathname)
81 | query_pc = torch.tensor(query_pc, dtype=torch.float)
82 | if self.transform is not None:
83 | query_pc = self.transform(query_pc)
84 | return query_pc, ndx
85 |
86 | def get_positives(self, ndx):
87 | return self.queries[ndx].positives
88 |
89 | def get_non_negatives(self, ndx):
90 | return self.queries[ndx].non_negatives
91 |
92 |
93 | class EvaluationSet:
94 | # Evaluation set consisting of map and query elements
95 | def __init__(self, query_set: List[EvaluationTuple] = None, map_set: List[EvaluationTuple] = None):
96 | self.query_set = query_set
97 | self.map_set = map_set
98 |
99 | def save(self, pickle_filepath: str):
100 | # Pickle the evaluation set
101 |
102 | # Convert data to tuples and save as tuples
103 | query_l = []
104 | for e in self.query_set:
105 | query_l.append(e.to_tuple())
106 |
107 | map_l = []
108 | for e in self.map_set:
109 | map_l.append(e.to_tuple())
110 | pickle.dump([query_l, map_l], open(pickle_filepath, 'wb'))
111 |
112 | def load(self, pickle_filepath: str):
113 | # Load evaluation set from the pickle
114 | query_l, map_l = pickle.load(open(pickle_filepath, 'rb'))
115 |
116 | self.query_set = []
117 | for e in query_l:
118 | self.query_set.append(EvaluationTuple(e[0], e[1], e[2], e[3]))
119 |
120 | self.map_set = []
121 | for e in map_l:
122 | self.map_set.append(EvaluationTuple(e[0], e[1], e[2], e[3]))
123 |
124 | def get_map_positions(self):
125 | # Get map positions as (N, 2) array
126 | positions = np.zeros((len(self.map_set), 2), dtype=self.map_set[0].position.dtype)
127 | for ndx, pos in enumerate(self.map_set):
128 | positions[ndx] = pos.position
129 | return positions
130 |
131 | def get_query_positions(self):
132 | # Get query positions as (N, 2) array
133 | positions = np.zeros((len(self.query_set), 2), dtype=self.query_set[0].position.dtype)
134 | for ndx, pos in enumerate(self.query_set):
135 | positions[ndx] = pos.position
136 | return positions
137 |
138 | def filter_query_elements(query_set: List[EvaluationTuple], map_set: List[EvaluationTuple],
139 | dist_threshold: float) -> List[EvaluationTuple]:
140 | # Function used in evaluation dataset generation
141 | # Filters out query elements without a corresponding map element within dist_threshold threshold
142 | map_pos = np.zeros((len(map_set), 2), dtype=np.float32)
143 | for ndx, e in enumerate(map_set):
144 | map_pos[ndx] = e.position
145 |
146 | # Build a kdtree
147 | kdtree = KDTree(map_pos)
148 |
149 | filtered_query_set = []
150 | count_ignored = 0
151 | for ndx, e in enumerate(query_set):
152 | position = e.position.reshape(1, -1)
153 | nn = kdtree.query_radius(position, dist_threshold, count_only=True)[0]
154 | if nn > 0:
155 | filtered_query_set.append(e)
156 | else:
157 | count_ignored += 1
158 |
159 | print(f"{count_ignored} query elements ignored - not having corresponding map element within {dist_threshold} [m] radius")
160 | return filtered_query_set
161 |
162 | def get_pointcloud_loader(dataset_type) -> PointCloudLoader:
163 | if dataset_type == 'mulran':
164 | return MulranPointCloudLoader()
165 | elif dataset_type == 'southbay':
166 | return SouthbayPointCloudLoader()
167 | elif dataset_type == 'kitti':
168 | return KittiPointCloudLoader()
169 | elif dataset_type == 'alita':
170 | return ALITAPointCloudLoader()
171 | elif dataset_type == 'kitti360':
172 | return Kitti360PointCloudLoader()
173 | else:
174 | raise NotImplementedError(f"Unsupported dataset type: {dataset_type}")
--------------------------------------------------------------------------------
/src/data/datasets/kitti/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Test set for Kitti Sequence 00 dataset.
2 | # Following procedures in [cite papers Kitti for place reco], we use the first 170 seconds of the drive from the sequence for map generation
3 | # and the rest is left for queries
4 |
5 | import numpy as np
6 | import argparse
7 | from typing import List
8 | import os
9 | import sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
11 |
12 | from datasets.kitti.kitti_raw import KittiSequence
13 | from datasets.base_datasets import EvaluationTuple, EvaluationSet
14 | from datasets.dataset_utils import filter_query_elements
15 |
16 |
17 | MAP_TIMERANGE = (0, 170)
18 |
19 |
20 | def get_scans(sequence: KittiSequence, min_displacement: float = 0.1, ts_range: tuple = None) -> List[EvaluationTuple]:
21 | # Get a list of all point clouds from the sequence (the full sequence or test split only)
22 |
23 | elems = []
24 | old_pos = None
25 | count_skipped = 0
26 | displacements = []
27 |
28 | for ndx in range(len(sequence)):
29 | if ts_range is not None:
30 | if (ts_range[0] > sequence.rel_lidar_timestamps[ndx]) or (ts_range[1] < sequence.rel_lidar_timestamps[ndx]):
31 | continue
32 | pose = sequence.lidar_poses[ndx]
33 |         # Kitti poses are in the camera coordinate system, where y is the up axis
34 | position = pose[[0,2], 3]
35 |
36 | if old_pos is not None:
37 | displacements.append(np.linalg.norm(old_pos - position))
38 |
39 | if np.linalg.norm(old_pos - position) < min_displacement:
40 | # Ignore the point cloud if the vehicle didn't move
41 | count_skipped += 1
42 | continue
43 |
44 | item = EvaluationTuple(sequence.rel_lidar_timestamps[ndx], sequence.rel_scan_filepath[ndx], position, pose)
45 | elems.append(item)
46 | old_pos = position
47 |
48 | print(f'{count_skipped} clouds skipped due to displacement smaller than {min_displacement}')
49 | print(f'mean displacement {np.mean(np.array(displacements))}')
50 | return elems
51 |
52 |
53 | def generate_evaluation_set(dataset_root: str, map_sequence: str, min_displacement: float = 0.1,
54 | dist_threshold: float = 5.) -> EvaluationSet:
55 |
56 | sequence = KittiSequence(dataset_root, map_sequence)
57 |
58 | map_set = get_scans(sequence, min_displacement, MAP_TIMERANGE)
59 | query_set = get_scans(sequence, min_displacement, (MAP_TIMERANGE[-1], sequence.rel_lidar_timestamps[-1]))
60 | query_set = filter_query_elements(query_set, map_set, dist_threshold)
61 | print(f'{len(map_set)} database elements, {len(query_set)} query elements')
62 | return EvaluationSet(query_set, map_set)
63 |
64 |
65 | if __name__ == '__main__':
66 |     parser = argparse.ArgumentParser(description='Generate evaluation sets for Kitti dataset')
67 | parser.add_argument('--dataset_root', type=str, required=True)
68 | parser.add_argument('--min_displacement', type=float, default=0.1)
69 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters)
70 | parser.add_argument('--dist_threshold', type=float, default=5.)
71 |
72 | args = parser.parse_args()
73 |
74 | # Sequences are fixed
75 | sequence = '00'
76 | print(f'Dataset root: {args.dataset_root}')
77 | print(f'Kitti sequence: {sequence}')
78 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}')
79 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}')
80 |
81 | kitti_eval_set = generate_evaluation_set(args.dataset_root, sequence, min_displacement=args.min_displacement,
82 | dist_threshold=args.dist_threshold)
83 | file_path_name = os.path.join(args.dataset_root, f'kitti_{sequence}_eval.pickle')
84 | print(f"Saving evaluation pickle: {file_path_name}")
85 | kitti_eval_set.save(file_path_name)
86 |
--------------------------------------------------------------------------------
/src/data/datasets/kitti/kitti_00_eval.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/kitti/kitti_00_eval.pickle
--------------------------------------------------------------------------------
/src/data/datasets/kitti/kitti_raw.py:
--------------------------------------------------------------------------------
1 | # Functions and classes operating on a raw Kitti dataset
2 |
3 | import os
4 | import sys
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 |
10 | sys.path.append('../../..')
11 |
12 | from misc.point_clouds import PointCloudLoader
13 |
14 |
15 | class KittiPointCloudLoader(PointCloudLoader):
16 | def set_properties(self):
17 | # Set point cloud properties, such as ground_plane_level.
18 | self.ground_plane_level = -1.5
19 |
20 | def read_pc(self, file_pathname: str) -> torch.Tensor:
21 | # Reads the point cloud without pre-processing
22 | # Returns Nx3 tensor
23 | pc = np.fromfile(file_pathname, dtype=np.float32)
24 |         # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance
25 | pc = np.reshape(pc, (-1, 4))[:, :3]
26 | return pc
27 |
28 |
29 | class KittiSequence(Dataset):
30 | """
31 |     Point clouds from a sequence of the raw Kitti dataset
32 | """
33 | def __init__(self, dataset_root: str, sequence_name: str, pose_time_tolerance: float = 1.,
34 | remove_zero_points: bool = True):
35 | # pose_time_tolerance: (in seconds) skip point clouds without corresponding pose information (based on
36 | # timestamps difference)
37 | # remove_zero_points: remove (0,0,0) points
38 |
39 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}'
40 | self.dataset_root = dataset_root
41 | self.sequence_name = sequence_name
42 | # self.sequence_path = os.path.join(self.dataset_root, 'sequences')
43 | # assert os.path.exists(self.sequence_path), f'Cannot access sequence: {self.sequence_path}'
44 | self.rel_lidar_path = os.path.join('sequences', self.sequence_name, 'velodyne')
45 | # lidar_path = os.path.join(self.sequence_path, self.rel_lidar_path)
46 | # assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}'
47 | self.pose_file = os.path.join(self.dataset_root, 'poses', self.sequence_name + '.txt')
48 | assert os.path.exists(self.pose_file), f'Cannot access sequence pose file: {self.pose_file}'
49 | self.times_file = os.path.join(self.dataset_root, 'sequences', self.sequence_name, 'times.txt')
50 | assert os.path.exists(self.pose_file), f'Cannot access sequence times file: {self.times_file}'
51 | # Maximum discrepancy between timestamps of LiDAR scan and global pose in seconds
52 | self.pose_time_tolerance = pose_time_tolerance
53 | self.remove_zero_points = remove_zero_points
54 |
55 | self.rel_lidar_timestamps, self.lidar_poses, filenames = self._read_lidar_poses()
56 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, '%06d%s' % (e, '.bin')) for e in filenames]
57 |
58 | def __len__(self):
59 | return len(self.rel_lidar_timestamps)
60 |
61 | def __getitem__(self, ndx):
62 | scan_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx])
63 | pc = load_pc(scan_filepath)
64 | if self.remove_zero_points:
65 | mask = np.all(np.isclose(pc, 0), axis=1)
66 | pc = pc[~mask]
67 | return {'pc': pc, 'pose': self.lidar_poses[ndx], 'ts': self.rel_lidar_timestamps[ndx]}
68 |
69 | def _read_lidar_poses(self):
70 | fnames = os.listdir(os.path.join(self.dataset_root, self.rel_lidar_path))
71 | temp = os.path.join(self.dataset_root, self.rel_lidar_path)
72 | fnames = [e for e in fnames if os.path.isfile(os.path.join(temp, e))]
73 |         assert len(fnames) > 0, f"Make sure that the path {self.rel_lidar_path} is correct"
74 | filenames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames])
75 | with open(self.pose_file, "r") as h:
76 | txt_poses = h.readlines()
77 |
78 | n = len(txt_poses)
79 | poses = np.zeros((n, 4, 4), dtype=np.float64) # 4x4 pose matrix
80 |
81 | for ndx, pose in enumerate(txt_poses):
82 | # Split by comma and remove whitespaces
83 | temp = [e.strip() for e in pose.split(' ')]
84 | assert len(temp) == 12, f'Invalid line in global poses file: {temp}'
85 |             # poses in kitti are in cam0 reference
86 | poses[ndx] = np.array([[float(temp[0]), float(temp[1]), float(temp[2]), float(temp[3])],
87 | [float(temp[4]), float(temp[5]), float(temp[6]), float(temp[7])],
88 | [float(temp[8]), float(temp[9]), float(temp[10]), float(temp[11])],
89 | [0., 0., 0., 1.]])
90 | rel_ts = np.genfromtxt(self.times_file)
91 |
92 | return rel_ts, poses, filenames
93 |
94 |
95 | def load_pc(filepath):
96 | # Load point cloud, does not apply any transform
97 | # Returns Nx3 matrix
98 | pc = np.fromfile(filepath, dtype=np.float32)
99 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance
100 | pc = np.reshape(pc, (-1, 4))[:, :3]
101 | return pc
102 |
--------------------------------------------------------------------------------
/src/data/datasets/kitti/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def velo2cam():
5 | R = np.array([
6 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04,
7 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02
8 | ]).reshape(3, 3)
9 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1)
10 | velo2cam = np.hstack([R, T])
11 | velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T
12 | return velo2cam
13 |
14 |
15 | def get_relative_pose(pose_1, pose_2):
16 | # as seen in https://github.com/chrischoy/FCGF
17 | M = (velo2cam() @ pose_1.T @ np.linalg.inv(pose_2.T) @ np.linalg.inv(velo2cam())).T
18 | return M
19 |
--------------------------------------------------------------------------------
/src/data/datasets/kitti360/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Test set for Kitti360 Sequence 09.
2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/kitti/generate_evaluation_sets.py
3 |
4 | import argparse
5 | import os
6 | import sys
7 | from typing import List
8 |
9 | import numpy as np
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
12 |
13 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements
14 | from datasets.kitti360.kitti360_raw import Kitti360Sequence
15 |
16 | # MAP_TIMERANGE = (0, 170)
17 | MAP_TIMERANGE = (0, 300)
18 |
19 | def get_scans(sequence: Kitti360Sequence, min_displacement: float = 0.1, ts_range: tuple = None) -> List[EvaluationTuple]:
20 | # Get a list of all point clouds from the sequence (the full sequence or test split only)
21 |
22 | elems = []
23 | old_pos = None
24 | count_skipped = 0
25 | displacements = []
26 |
27 | for ndx in range(len(sequence)):
28 | if ts_range is not None:
29 | if (ts_range[0] > sequence.rel_lidar_timestamps[ndx]) or (ts_range[1] < sequence.rel_lidar_timestamps[ndx]):
30 | continue
31 | pose = sequence.lidar_poses[ndx]
32 |         # Kitti poses are in the camera coordinate system, where y is the up axis
33 | position = pose[[0,1], 3]
34 |
35 | if old_pos is not None:
36 | displacements.append(np.linalg.norm(old_pos - position))
37 |
38 | if np.linalg.norm(old_pos - position) < min_displacement:
39 | # Ignore the point cloud if the vehicle didn't move
40 | count_skipped += 1
41 | continue
42 | # print(sequence.rel_scan_filepath)
43 | item = EvaluationTuple(sequence.rel_lidar_timestamps[ndx], sequence.rel_scan_filepath[ndx], position, pose)
44 | elems.append(item)
45 | old_pos = position
46 |
47 | print(f'{count_skipped} clouds skipped due to displacement smaller than {min_displacement}')
48 | print(f'mean displacement {np.mean(np.array(displacements))}')
49 | return elems
50 |
51 |
52 | def generate_evaluation_set(dataset_root: str, map_sequence: str, min_displacement: float = 0.1,
53 | dist_threshold: float = 5.) -> EvaluationSet:
54 |
55 | sequence = Kitti360Sequence(dataset_root, map_sequence)
56 |
57 | map_set = get_scans(sequence, min_displacement, MAP_TIMERANGE)
58 | query_set = get_scans(sequence, min_displacement, (MAP_TIMERANGE[-1], sequence.rel_lidar_timestamps[-1]))
59 | query_set = filter_query_elements(query_set, map_set, dist_threshold)
60 | print(f'{len(map_set)} database elements, {len(query_set)} query elements')
61 | return EvaluationSet(query_set, map_set)
62 |
63 |
64 | if __name__ == '__main__':
65 |     parser = argparse.ArgumentParser(description='Generate evaluation sets for Kitti360 dataset')
66 | # kitti: /mnt/088A6CBB8A6CA742/Datasets/Kitti/dataset/
67 | # mulran: /mnt/088A6CBB8A6CA742/Datasets/MulRan/
68 | # apollo:
69 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/KITTI360/KITTI-360/data_3d_raw')
70 | parser.add_argument('--min_displacement', type=float, default=3.0)
71 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters)
72 | parser.add_argument('--dist_threshold', type=float, default=5.)
73 |
74 | args = parser.parse_args()
75 |
76 | # Sequences are fixed
77 | sequence = '09'
78 | sequence_name = '2013_05_28_drive_00'+ sequence + '_sync'
79 | print(f'Dataset root: {args.dataset_root}')
80 | print(f'Kitti sequence: {sequence}')
81 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}')
82 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}')
83 |
84 | kitti_eval_set = generate_evaluation_set(args.dataset_root, sequence, min_displacement=args.min_displacement,
85 | dist_threshold=args.dist_threshold)
86 | file_path_name = os.path.join(os.path.dirname(__file__), f'kitti360_{sequence}_{args.min_displacement}_eval.pickle')
87 | print(f"Saving evaluation pickle: {file_path_name}")
88 | kitti_eval_set.save(file_path_name)
89 |
--------------------------------------------------------------------------------
/src/data/datasets/kitti360/kitti360_09_3.0_eval.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/kitti360/kitti360_09_3.0_eval.pickle
--------------------------------------------------------------------------------
/src/data/datasets/kitti360/kitti360_raw.py:
--------------------------------------------------------------------------------
1 | # Functions and classes operating on a raw Kitti dataset
2 | # This script is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/kitti/kitti_raw.py
3 |
4 | import os
5 | import sys
6 | from datetime import datetime
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
13 | from datasets.point_clouds_utils import PointCloudLoader
14 |
15 |
16 | class Kitti360PointCloudLoader(PointCloudLoader):
17 | def set_properties(self):
18 | # Set point cloud properties, such as ground_plane_level.
19 | self.ground_plane_level = -1.5
20 |
21 | def read_pc(self, file_pathname: str) -> torch.Tensor:
22 | # Reads the point cloud without pre-processing
23 | # Returns Nx3 tensor
24 | pc = np.fromfile(file_pathname, dtype=np.float32)
25 |         # PC in Kitti360 is of size [num_points, 4] -> x,y,z,reflectance
26 | pc = np.reshape(pc, (-1, 4))[:, :3]
27 | # pc = np.reshape(pc, (-1, 4))
28 | return pc
29 |
30 |
31 | class Kitti360Sequence(Dataset):
32 | """
33 |     Point clouds from a sequence of the raw Kitti360 dataset
34 | """
35 | def __init__(self, dataset_root: str, sequence_name: str, pose_time_tolerance: float = 1.,
36 | remove_zero_points: bool = True):
37 | # pose_time_tolerance: (in seconds) skip point clouds without corresponding pose information (based on
38 | # timestamps difference)
39 | # remove_zero_points: remove (0,0,0) points
40 |
41 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}'
42 | self.dataset_root = dataset_root
43 | self.sequence_name = '2013_05_28_drive_00'+ sequence_name + '_sync'
44 | # self.sequence_path = os.path.join(self.dataset_root, 'sequences')
45 | # assert os.path.exists(self.sequence_path), f'Cannot access sequence: {self.sequence_path}'
46 | self.rel_lidar_path = os.path.join(self.sequence_name, 'velodyne_points/data')
47 | # lidar_path = os.path.join(self.sequence_path, self.rel_lidar_path)
48 | # assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}'
49 | self.pose_file = os.path.join(self.dataset_root, self.sequence_name , 'poses.txt')
50 | self.calib_file = os.path.join(self.dataset_root, self.sequence_name , 'cam0_to_world.txt')
51 | assert os.path.exists(self.pose_file), f'Cannot access sequence pose file: {self.pose_file}'
52 | self.times_file = os.path.join(self.dataset_root, self.sequence_name, 'velodyne_points/timestamps.txt')
53 | assert os.path.exists(self.pose_file), f'Cannot access sequence times file: {self.times_file}'
54 | # Maximum discrepancy between timestamps of LiDAR scan and global pose in seconds
55 | self.pose_time_tolerance = pose_time_tolerance
56 | self.remove_zero_points = remove_zero_points
57 |
58 | self.rel_lidar_timestamps, self.lidar_poses, filenames = self._read_lidar_poses()
59 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, '%010d%s' % (e, '.bin')) for e in filenames]
60 | print('')
61 |
62 | def __len__(self):
63 | return len(self.rel_lidar_timestamps)
64 |
65 | def __getitem__(self, ndx):
66 | scan_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx])
67 | pc = load_pc(scan_filepath)
68 | if self.remove_zero_points:
69 | mask = np.all(np.isclose(pc, 0), axis=1)
70 | pc = pc[~mask]
71 | return {'pc': pc, 'pose': self.lidar_poses[ndx], 'ts': self.rel_lidar_timestamps[ndx]}
72 |
73 | def _read_lidar_poses(self):
74 | fnames = os.listdir(os.path.join(self.dataset_root, self.rel_lidar_path))
75 | temp = os.path.join(self.dataset_root, self.rel_lidar_path)
76 | fnames = [e for e in fnames if os.path.isfile(os.path.join(temp, e))]
77 |         assert len(fnames) > 0, f"Make sure that the path {self.rel_lidar_path} is correct"
78 | # filenames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames])
79 | # with open(self.calib_file, 'r') as f:
80 | # for line in f.readlines():
81 | # data = np.array([float(x) for x in line.split()])
82 |
83 | # cam0_to_velo = np.reshape(data, (3, 4))
84 | # cam0_to_velo = np.vstack([cam0_to_velo, [0, 0, 0, 1]])
85 |
86 |
87 | poses,_,_ = load_poses_from_txt(self.pose_file)#, cam0_to_velo)
88 | sorted_keys = sorted(poses.keys())
89 | poses_list = [poses[k] for k in sorted_keys]
90 | filenames = sorted([int(key) for key in poses])
91 | ts = load_timestamps(self.times_file)
92 | ts = np.asarray(ts)[filenames]
93 | rel_ts = ts - ts[0]
94 |
95 | return rel_ts, poses_list, filenames
96 |
97 |
98 | def load_pc(filepath):
99 | # Load point cloud, does not apply any transform
100 | # Returns Nx3 matrix
101 | pc = np.fromfile(filepath, dtype=np.float32)
102 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance
103 | pc = np.reshape(pc, (-1, 4))[:, :3]
104 | return pc
105 |
106 | def load_poses_from_txt(file_name):#, cam0_to_velo):
107 | f = open(file_name, 'r')
108 | s = f.readlines()
109 | f.close()
110 | transforms = {}
111 | x = []
112 | y = []
113 | for cnt, line in enumerate(s):
114 | P = np.eye(4)
115 | line_split = [float(i) for i in line.split(" ") if i!="" and i!="\n"]
116 | withIdx = len(line_split) >= 13
117 | for row in range(3):
118 | for col in range(4):
119 | P[row, col] = line_split[row*4 + col + withIdx]
120 | if withIdx:
121 | frame_idx = line_split[0]
122 | else:
123 | frame_idx = cnt
124 | transforms[frame_idx] = P #@ cam0_to_velo.inverse()
125 | x.append(P[0, 3])
126 | y.append(P[1, 3])
127 | return transforms, x, y
128 |
129 | def load_timestamps(file_name):
130 | f = open(file_name, 'r')
131 | s = f.readlines()
132 | f.close()
133 | times = []
134 |
135 | for cnt, line in enumerate(s):#2013-05-28 11:36:55.89086054
136 | dt_obj = datetime.strptime(line[:-4], '%Y-%m-%d %H:%M:%S.%f')
137 |
138 | # nanosec = dt_obj.timestamp() * 10e9
139 | sec = dt_obj.timestamp()
140 |
141 | times.append(sec)
142 | return times
143 |
--------------------------------------------------------------------------------
/src/data/datasets/kitti360/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def kitti360_calib_transform(init_T):
5 | cam_to_velo_data = [0.04307104361, -0.08829286498, 0.995162929, 0.8043914418,
6 | -0.999004371, 0.007784614041, 0.04392796942, 0.2993489574,
7 | -0.01162548558, -0.9960641394, -0.08786966659, -0.1770225824]
8 | cam0_to_velo = np.reshape(cam_to_velo_data, (3, 4))
9 | cam0_to_velo = np.vstack([cam0_to_velo, [0, 0, 0, 1]])
10 |
11 | cam_to_pose_data = [0.0371783278, -0.0986182135, 0.9944306009, 1.5752681039,
12 | 0.9992675562, -0.0053553387, -0.0378902567, 0.0043914093,
13 | 0.0090621821, 0.9951109327, 0.0983468786, -0.6500000000]
14 | cam0_to_pose = np.reshape(cam_to_pose_data, (3, 4))
15 | cam0_to_pose = np.vstack([cam0_to_pose, [0, 0, 0, 1]])
16 |
17 | return init_T @ cam0_to_pose @ np.linalg.inv(cam0_to_velo)
18 |
19 |
20 | def kitti360_relative_pose(pose_1, pose_2):
21 | pose_1 = kitti360_calib_transform(pose_1)
22 | pose_2 = kitti360_calib_transform(pose_2)
23 | return np.linalg.inv(pose_2) @ pose_1
24 |
25 | # def relative_pose(m1, m2):
26 |
27 | # return np.linalg.inv(m2) @ m1
--------------------------------------------------------------------------------
/src/data/datasets/mulran/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Test sets for Mulran dataset.
2 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/datasets/mulran/generate_evaluation_sets.py
3 |
4 | import argparse
5 | import os
6 | import sys
7 | from typing import List
8 |
9 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
10 |
11 |
12 | from datasets.base_datasets import EvaluationSet, EvaluationTuple, filter_query_elements
13 | from datasets.mulran.mulran_raw import MulranSequence
14 |
15 | # # KAIST 02
16 | # MAP_TIMERANGE = (1566535940856033867, 1566536300000000000)
17 | # QUERY_TIMERANGE = (1566536300000000000, 1566536825534173166)
18 | # # Riverside 01:
19 | # # MAP_TIMERANGE = (1564718063503232284, 1564718300000000000)
20 | # # QUERY_TIMERANGE = (1564718300000000000, 1564718603800415528)
21 |
22 |
23 | def get_scans(sequence: MulranSequence, ts_range: tuple = None) -> List[EvaluationTuple]:
24 | # Get a list of all readings from the test area in the sequence
25 | elems = []
26 | for ndx in range(len(sequence)):
27 | if ts_range is not None:
28 | if (ts_range[0] > sequence.timestamps[ndx]) or (ts_range[1] < sequence.timestamps[ndx]):
29 | continue
30 | pose = sequence.poses[ndx]
31 | position = pose[:2, 3]
32 | item = EvaluationTuple(sequence.timestamps[ndx], sequence.rel_scan_filepath[ndx], position=position, pose=pose)
33 | elems.append(item)
34 | return elems
35 |
36 |
37 | def generate_evaluation_set(dataset_root: str, map_sequence_name: str, query_sequence_name: str, min_displacement: float = 0.2,
38 | dist_threshold=20) -> EvaluationSet:
39 | split = 'test'
40 | map_sequence = MulranSequence(dataset_root, map_sequence_name, split=split, min_displacement=min_displacement)
41 | query_sequence = MulranSequence(dataset_root, query_sequence_name, split=split, min_displacement=min_displacement)
42 | print(min_displacement)
43 |
44 | if map_sequence_name == query_sequence_name:
45 | print('Wrong Wrong Wrong')
46 | map_set = get_scans(map_sequence, MAP_TIMERANGE)
47 | query_set = get_scans(query_sequence, QUERY_TIMERANGE)
48 | else:
49 | map_set = get_scans(map_sequence)
50 | query_set = get_scans(query_sequence)
51 |
52 | # Function used in evaluation dataset generation
53 | # Filters out query elements without a corresponding map element within dist_threshold threshold
54 | query_set = filter_query_elements(query_set, map_set, dist_threshold)
55 | print(f'{len(map_set)} database elements, {len(query_set)} query elements')
56 | return EvaluationSet(query_set, map_set)
57 |
58 |
59 | if __name__ == '__main__':
60 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Mulran dataset')
61 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/Mulran/Sejong')
62 | parser.add_argument('--sequence', type=str, required=False, default='Sejong')
63 | parser.add_argument('--min_displacement', type=float, default=0.2)
64 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters)
65 | parser.add_argument('--dist_threshold', type=float, default=20)
66 | args = parser.parse_args()
67 |
68 | # Sequences is a list of (map sequence, query sequence)
69 | sequences = [('Sejong1', 'Sejong2')]
70 | if args.sequence == 'DCC':
71 | sequences = [('DCC1', 'DCC2')]
72 | args.min_displacement = 10.0
73 | args.dist_threshold = 5
74 |
75 |
76 | print(f'Dataset root: {args.dataset_root}')
77 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}')
78 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}')
79 |
80 |
81 | for map_sequence, query_sequence in sequences:
82 | print(f'Map sequence: {map_sequence}')
83 | print(f'Query sequence: {query_sequence}')
84 |
85 | test_set = generate_evaluation_set(args.dataset_root, map_sequence, query_sequence,
86 | min_displacement=args.min_displacement, dist_threshold=args.dist_threshold)
87 |
88 | pickle_name = f'test_{map_sequence}_{query_sequence}_{args.min_displacement}_{args.dist_threshold}.pickle'
89 | # file_path_name = os.path.join(args.dataset_root, pickle_name)
90 | file_path_name = os.path.join(os.path.dirname(__file__), pickle_name)
91 | print(f"Saving evaluation pickle: {file_path_name}")
92 | test_set.save(file_path_name)
93 |
--------------------------------------------------------------------------------
/src/data/datasets/mulran/generate_training_tuples.py:
--------------------------------------------------------------------------------
1 | # Training tuples generation for Mulran dataset.
2 |
3 | import argparse
4 | import os
5 | import pickle
6 | import sys
7 |
8 | import numpy as np
9 | import tqdm
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
12 |
13 | from datasets.base_datasets import TrainingTuple
14 | from datasets.mulran.mulran_raw import MulranSequences
15 | from datasets.mulran.utils import relative_pose
16 |
17 | from misc.point_clouds import icp
18 |
19 | DEBUG = False
20 |
21 |
22 | def load_pc(file_pathname):
23 | # Load point cloud, clip x, y and z coords (points far away and the ground plane)
24 | # Returns Nx3 matrix
25 | pc = np.fromfile(file_pathname, dtype=np.float32)
26 | # PC in Mulran is of size [num_points, 4] -> x,y,z,reflectance
27 | pc = np.reshape(pc, (-1, 4))[:, :3]
28 |
29 | mask = np.all(np.isclose(pc, 0.), axis=1)
30 | pc = pc[~mask]
31 | mask = pc[:, 0] > -80
32 | pc = pc[mask]
33 | mask = pc[:, 0] <= 80
34 |
35 | pc = pc[mask]
36 | mask = pc[:, 1] > -80
37 | pc = pc[mask]
38 | mask = pc[:, 1] <= 80
39 | pc = pc[mask]
40 |
41 | mask = pc[:, 2] > -0.9
42 | pc = pc[mask]
43 | return pc
44 |
45 |
46 | def generate_training_tuples(ds: MulranSequences, pos_threshold: float = 10, neg_threshold: float = 50):
47 |     # displacement: displacement between consecutive anchors (if None, all scans are taken as anchors).
48 | # Use some small displacement to ensure there's only one scan if the vehicle does not move
49 |
50 |     tuples = {}    # Dictionary of training tuples: tuples[ndx] = (set of positives, set of non-negatives)
51 | for anchor_ndx in tqdm.tqdm(range(len(ds))):
52 | anchor_pos = ds.get_xy()[anchor_ndx]
53 |
54 | # Find timestamps of positive and negative elements
55 | positives = ds.find_neighbours_ndx(anchor_pos, pos_threshold)
56 | non_negatives = ds.find_neighbours_ndx(anchor_pos, neg_threshold)
57 | # Remove anchor element from positives, but leave it in non_negatives
58 | positives = positives[positives != anchor_ndx]
59 |
60 | # Sort ascending order
61 | positives = np.sort(positives)
62 | non_negatives = np.sort(non_negatives)
63 |
64 | # ICP pose refinement
65 | fitness_l = []
66 | inlier_rmse_l = []
67 | positive_poses = {}
68 |
69 |         if True:   # True: use raw ground-truth relative poses; False: refine each positive pose with ICP (slower)
70 | # Use ground truth transform without pose refinement
71 | anchor_pose = ds.poses[anchor_ndx]
72 | for positive_ndx in positives:
73 | positive_pose = ds.poses[positive_ndx]
74 | # Compute initial relative pose
75 | m, fitness, inlier_rmse = relative_pose(anchor_pose, positive_pose), 1., 1.
76 | fitness_l.append(fitness)
77 | inlier_rmse_l.append(inlier_rmse)
78 | positive_poses[positive_ndx] = m
79 | else:
80 | anchor_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[anchor_ndx]))
81 | anchor_pose = ds.poses[anchor_ndx]
82 | for positive_ndx in positives:
83 | positive_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[positive_ndx]))
84 | positive_pose = ds.poses[positive_ndx]
85 | # Compute initial relative pose
86 | transform = relative_pose(anchor_pose, positive_pose)
87 | # Refine the pose using ICP
88 | m, fitness, inlier_rmse = icp(anchor_pc, positive_pc, transform)
89 |
90 | fitness_l.append(fitness)
91 | inlier_rmse_l.append(inlier_rmse)
92 | positive_poses[positive_ndx] = m
93 |
94 | # Tuple(id: int, timestamp: int, rel_scan_filepath: str, positives: List[int], non_negatives: List[int])
95 | tuples[anchor_ndx] = TrainingTuple(id=anchor_ndx, timestamp=ds.timestamps[anchor_ndx],
96 | rel_scan_filepath=ds.dataset_root + '/' + ds.rel_scan_filepath[anchor_ndx],
97 | positives=positives, non_negatives=non_negatives, pose=anchor_pose,
98 | positives_poses=positive_poses)
99 |
100 | print(f'{len(tuples)} training tuples generated')
101 |     print('ICP pose refinement stats:')
102 | print(f'Fitness - min: {np.min(fitness_l):0.3f} mean: {np.mean(fitness_l):0.3f} max: {np.max(fitness_l):0.3f}')
103 | print(f'Inlier RMSE - min: {np.min(inlier_rmse_l):0.3f} mean: {np.mean(inlier_rmse_l):0.3f} max: {np.max(inlier_rmse_l):0.3f}')
104 |
105 | return tuples
106 |
107 |
108 | if __name__ == '__main__':
109 | parser = argparse.ArgumentParser(description='Generate training tuples')
110 | parser.add_argument('--dataset_root', type=str, default='/data/raktim/Datasets/Mulran/Sejong')
111 |     parser.add_argument('--pos_threshold', type=float, default=2)
112 |     parser.add_argument('--neg_threshold', type=float, default=10)
113 | parser.add_argument('--min_displacement', type=float, default=0.2)
114 | args = parser.parse_args()
115 |
116 | sequences = ['Sejong1', 'Sejong2']
117 | if DEBUG:
118 | sequences = ['ParkingLot', 'ParkingLot']
119 |
120 | print(f'Dataset root: {args.dataset_root}')
121 | print(f'Sequences: {sequences}')
122 | print(f'Threshold for positive examples: {args.pos_threshold}')
123 | print(f'Threshold for negative examples: {args.neg_threshold}')
124 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}')
125 |
126 | ds = MulranSequences(args.dataset_root, sequences, split='train', min_displacement=args.min_displacement)
127 | train_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold)
128 | pickle_name = f'train_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle'
129 | train_tuples_filepath = os.path.join(args.dataset_root, pickle_name)
130 | pickle.dump(train_tuples, open(train_tuples_filepath, 'wb'))
131 | train_tuples = None
132 |
133 | ds = MulranSequences(args.dataset_root, sequences, split='test', min_displacement=args.min_displacement)
134 | test_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold)
135 | pickle_name = f'val_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle'
136 | test_tuples_filepath = os.path.join(args.dataset_root, pickle_name)
137 | pickle.dump(test_tuples, open(test_tuples_filepath, 'wb'))
138 |
--------------------------------------------------------------------------------
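find_neighbours_ndx (defined in mulran_raw.py, not shown here) returns the indices of all scans within a radius of the anchor; the tuple generator then drops the anchor from the positives but keeps it among the non-negatives. A minimal stand-in using scipy's KDTree, shown only to illustrate those radius queries (the helper below is hypothetical, not the repo's implementation):

import numpy as np
from scipy.spatial import KDTree

xy = np.random.rand(100, 2) * 50.            # stand-in for ds.get_xy()
tree = KDTree(xy)

def find_neighbours_ndx(anchor_pos, radius):
    # Indices of all scans within `radius` meters of the anchor position.
    return np.array(tree.query_ball_point(anchor_pos, r=radius), dtype=np.int64)

anchor_ndx = 0
positives = find_neighbours_ndx(xy[anchor_ndx], 2.0)        # pos_threshold
non_negatives = find_neighbours_ndx(xy[anchor_ndx], 10.0)   # neg_threshold
positives = np.sort(positives[positives != anchor_ndx])     # anchor stays only in non_negatives
non_negatives = np.sort(non_negatives)
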
/src/data/datasets/mulran/mulran_train.py:
--------------------------------------------------------------------------------
1 | # Warsaw University of Technology
2 | # Dataset wrapper for Mulran lidar scans dataset
3 |
4 | import os
5 | import random
6 | import numpy as np
7 | import torch
8 |
9 | from datasets.base_datasets import TrainingDataset
10 | from datasets.quantization import Quantizer
11 | from misc.poses import apply_transform
12 |
13 |
14 | DEBUG = False
15 |
16 |
17 | class MulranTraining6DOFDataset(TrainingDataset):
18 | """
19 | Dataset wrapper for Mulran dataset for 6dof estimation.
20 | """
21 | def __init__(self, dataset_path: str, query_filename: str, quantizer: Quantizer,
22 | rot_max: float = 0., trans_max: float = 0., **vargs):
23 | dataset_type = 'mulran'
24 | super().__init__(dataset_path, dataset_type, query_filename, **vargs)
25 | self.quantizer = quantizer
26 | self.rot_max = rot_max
27 | self.trans_max = trans_max
28 |
29 | def __getitem__(self, ndx):
30 | # pose is a global coordinate system pose 3x4 R|T matrix
31 | query_pc, _ = super().__getitem__(ndx)
32 |
33 | # get random positive
34 | positives = self.get_positives(ndx)
35 | positive_idx = np.random.choice(positives, 1)[0]
36 | positive_pc, _ = super().__getitem__(positive_idx)
37 |
38 | # get relative pose taking two global poses
39 | transform = self.queries[ndx].positives_poses[positive_idx]
40 |
41 | # Apply random transform to the positive point cloud
42 | rotation_angle = np.random.uniform(low=-self.rot_max, high=self.rot_max)
43 | cosval = np.cos(rotation_angle)
44 | sinval = np.sin(rotation_angle)
45 | m = torch.eye(4, dtype=torch.float)
46 | #m[:3, :3] = np.array([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]])
47 | m[:3, :3] = torch.tensor([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]], dtype=m.dtype)
48 | m[:2, 3] = torch.rand((1, 2)) * 2. * self.trans_max - self.trans_max
49 | positive_pc = apply_transform(positive_pc, m)
50 | transform = m @ transform
51 |
52 | # Find indices of unique quantized coordinates and filter out points to leave max 1 point per voxel
53 | coords1, idx1 = self.quantizer(query_pc)
54 | coords2, idx2 = self.quantizer(positive_pc)
55 | pc1_cop = query_pc[idx1, :]
56 | pc2_trans_cop = positive_pc[idx2, :]
57 |
58 | return pc1_cop, pc2_trans_cop, transform
59 |
--------------------------------------------------------------------------------
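The augmentation in __getitem__ above composes a random yaw rotation and a random planar translation into one 4x4 matrix and applies it to the positive cloud. A self-contained sketch of that step (apply_transform re-implemented inline so the snippet runs on its own; rot_max and trans_max values are arbitrary):

import numpy as np
import torch

def apply_transform(pc, m):
    # (N, 3) cloud, 4x4 SE(3) matrix: rotate, then translate.
    return pc @ m[:3, :3].T + m[:3, 3]

rot_max, trans_max = np.pi / 6, 2.0
pc = torch.rand(1000, 3)

angle = np.random.uniform(-rot_max, rot_max)
c, s = np.cos(angle), np.sin(angle)
m = torch.eye(4)
m[:3, :3] = torch.tensor([[c, s, 0.], [-s, c, 0.], [0., 0., 1.]])
m[:2, 3] = torch.rand(2) * 2. * trans_max - trans_max     # random x-y shift
pc_aug = apply_transform(pc, m)
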
/src/data/datasets/mulran/test_DCC1_DCC2_10.0_5.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_DCC1_DCC2_10.0_5.pickle
--------------------------------------------------------------------------------
/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_20.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_20.pickle
--------------------------------------------------------------------------------
/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_5.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/mulran/test_Sejong1_Sejong2_0.2_5.pickle
--------------------------------------------------------------------------------
/src/data/datasets/mulran/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from scipy.spatial import distance_matrix
4 |
5 | # Faulty point clouds (with 0 points)
6 | FAULTY_POINTCLOUDS = [1566279795718079314]
7 |
8 | # Coordinates of test region centres (in Sejong sequence)
9 | TEST_REGION_CENTRES = np.array([[345090.0743, 4037591.323], [345090.483, 4044700.04],
10 | [350552.0308, 4041000.71], [349252.0308, 4044800.71]])
11 |
12 | # Radius of the test region
13 | TEST_REGION_RADIUS = 500
14 |
15 | # Boundary between training and test region - to ensure there's no overlap between training and test clouds
16 | TEST_TRAIN_BOUNDARY = 50
17 |
18 |
19 | def in_train_split(pos):
20 | # returns true if pos is in train split
21 | assert pos.ndim == 2
22 | assert pos.shape[1] == 2
23 | dist = distance_matrix(pos, TEST_REGION_CENTRES)
24 | mask = (dist > TEST_REGION_RADIUS + TEST_TRAIN_BOUNDARY).all(axis=1)
25 | return mask
26 |
27 |
28 | def in_test_split(pos):
29 | # returns true if position is in evaluation split
30 | assert pos.ndim == 2
31 | assert pos.shape[1] == 2
32 | dist = distance_matrix(pos, TEST_REGION_CENTRES)
33 | mask = (dist < TEST_REGION_RADIUS).any(axis=1)
34 | return mask
35 |
36 |
37 | def find_nearest_ndx(ts, timestamps):
38 | ndx = np.searchsorted(timestamps, ts)
39 | if ndx == 0:
40 | return ndx
41 | elif ndx == len(timestamps):
42 | return ndx - 1
43 | else:
44 | assert timestamps[ndx-1] <= ts <= timestamps[ndx]
45 | if ts - timestamps[ndx-1] < timestamps[ndx] - ts:
46 | return ndx - 1
47 | else:
48 | return ndx
49 |
50 |
51 | def read_lidar_poses(poses_filepath: str, lidar_filepath: str, pose_time_tolerance: float = 1.):
52 | # Read global poses from .csv file and link each lidar_scan with the nearest pose
53 | # threshold: threshold in seconds
54 | # Returns a dictionary with (4, 4) pose matrix indexed by a timestamp (as integer)
55 |
56 | with open(poses_filepath, "r") as h:
57 | txt_poses = h.readlines()
58 |
59 | n = len(txt_poses)
60 | system_timestamps = np.zeros((n,), dtype=np.int64)
61 | poses = np.zeros((n, 4, 4), dtype=np.float64) # 4x4 pose matrix
62 |
63 | for ndx, pose in enumerate(txt_poses):
64 | # Split by comma and remove whitespaces
65 | temp = [e.strip() for e in pose.split(',')]
66 | assert len(temp) == 13, f'Invalid line in global poses file: {temp}'
67 | system_timestamps[ndx] = int(temp[0])
68 | poses[ndx] = np.array([[float(temp[1]), float(temp[2]), float(temp[3]), float(temp[4])],
69 | [float(temp[5]), float(temp[6]), float(temp[7]), float(temp[8])],
70 | [float(temp[9]), float(temp[10]), float(temp[11]), float(temp[12])],
71 | [0., 0., 0., 1.]])
72 |
73 | # Ensure timestamps and poses are sorted in ascending order
74 | sorted_ndx = np.argsort(system_timestamps, axis=0)
75 | system_timestamps = system_timestamps[sorted_ndx]
76 | poses = poses[sorted_ndx]
77 |
78 | # List LiDAR scan timestamps
79 | all_lidar_timestamps = [int(os.path.splitext(f)[0]) for f in os.listdir(lidar_filepath) if
80 | os.path.splitext(f)[1] == '.bin']
81 | all_lidar_timestamps.sort()
82 |
83 | lidar_timestamps = []
84 | lidar_poses = []
85 | count_rejected = 0
86 |
87 | for ndx, lidar_ts in enumerate(all_lidar_timestamps):
88 | # Skip faulty point clouds
89 | if lidar_ts in FAULTY_POINTCLOUDS:
90 | continue
91 |
92 | # Find index of the closest timestamp
93 | closest_ts_ndx = find_nearest_ndx(lidar_ts, system_timestamps)
94 | delta = abs(system_timestamps[closest_ts_ndx] - lidar_ts)
95 | # Timestamp is in nanoseconds = 1e-9 second
96 | if delta > pose_time_tolerance * 1000000000:
97 | # Reject point cloud without corresponding pose
98 | count_rejected += 1
99 | continue
100 |
101 | lidar_timestamps.append(lidar_ts)
102 | lidar_poses.append(poses[closest_ts_ndx])
103 |
104 | lidar_timestamps = np.array(lidar_timestamps, dtype=np.int64)
105 |     lidar_poses = np.array(lidar_poses, dtype=np.float64)     # 4x4 SE(3) pose matrices
106 |
107 | print(f'{len(lidar_timestamps)} scans with valid pose, {count_rejected} rejected due to unknown pose')
108 | return lidar_timestamps, lidar_poses
109 |
110 |
111 | def relative_pose(m1, m2):
112 |     # SE(3) pose is a 4x4 matrix, such that
113 | # Pw = [R | T] @ [P]
114 | # [0 | 1] [1]
115 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame
116 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame
117 | # m2: coords in camera/lidar2 coords -> world coordinate frame
118 | # returns: relative pose of the first camera with respect to the second camera
119 | # transformation matrix to convert coords in camera/lidar1 reference frame to coords in
120 | # camera/lidar2 reference frame
121 | #
122 | m = np.linalg.inv(m2) @ m1
123 | # !!!!!!!!!! Fix for relative pose !!!!!!!!!!!!!
124 | m[:3, 3] = -m[:3, 3]
125 | return m
126 |
--------------------------------------------------------------------------------
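find_nearest_ndx associates each scan with a pose by binary search (np.searchsorted) over the sorted pose timestamps, then picks whichever neighbour is closer. A small sanity check of that behaviour with toy timestamps (the function is re-implemented inline so the check is self-contained):

import numpy as np

def find_nearest_ndx(ts, timestamps):
    ndx = np.searchsorted(timestamps, ts)
    if ndx == 0:
        return ndx
    if ndx == len(timestamps):
        return ndx - 1
    # Pick the closer of the two surrounding timestamps
    return ndx - 1 if ts - timestamps[ndx - 1] < timestamps[ndx] - ts else ndx

timestamps = np.array([100, 200, 300, 400])
assert find_nearest_ndx(205, timestamps) == 1   # closer to 200
assert find_nearest_ndx(395, timestamps) == 3   # closer to 400
assert find_nearest_ndx(50, timestamps) == 0    # clamped at the ends
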
/src/data/datasets/point_clouds_utils.py:
--------------------------------------------------------------------------------
1 | # This file is adapted from: https://github.com/jac99/Egonn/blob/main/misc/point_clouds.py
2 |
3 | import copy
4 | import os
5 |
6 | import numpy as np
7 | import open3d as o3d
8 |
9 |
10 | def draw_registration_result(source, target, transformation):
11 | source_temp = copy.deepcopy(source)
12 | target_temp = copy.deepcopy(target)
13 | source_temp.paint_uniform_color([1, 0.706, 0])
14 | target_temp.paint_uniform_color([0, 0.651, 0.929])
15 | source_temp.transform(transformation)
16 | o3d.visualization.draw_geometries([source_temp, target_temp],
17 | zoom=0.4459,
18 | front=[0.9288, -0.2951, -0.2242],
19 | lookat=[1.6784, 2.0612, 1.4451],
20 | up=[-0.3402, -0.9189, -0.1996])
21 |
22 |
23 | def draw_pc(pc):
24 | pc = copy.deepcopy(pc)
25 | pc.paint_uniform_color([1, 0.706, 0])
26 | o3d.visualization.draw_geometries([pc],
27 | zoom=0.4459,
28 | front=[0.9288, -0.2951, -0.2242],
29 | lookat=[1.6784, 2.0612, 1.4451],
30 | up=[-0.3402, -0.9189, -0.1996])
31 |
32 |
33 | def icp(anchor_pc, positive_pc, transform: np.ndarray = None, point2plane: bool = False,
34 | inlier_dist_threshold: float = 1.2, max_iteration: int = 200):
35 | # transform: initial alignment transform
36 | if transform is not None:
37 | transform = transform.astype(float)
38 |
39 | voxel_size = 0.1
40 | pcd1 = o3d.geometry.PointCloud()
41 | pcd1.points = o3d.utility.Vector3dVector(anchor_pc)
42 | pcd1 = pcd1.voxel_down_sample(voxel_size=voxel_size)
43 |
44 | pcd2 = o3d.geometry.PointCloud()
45 | pcd2.points = o3d.utility.Vector3dVector(positive_pc)
46 | pcd2 = pcd2.voxel_down_sample(voxel_size=voxel_size)
47 |
48 | if point2plane:
49 | pcd1.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))
50 | pcd2.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))
51 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPlane()
52 | else:
53 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint()
54 |
55 | if transform is not None:
56 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, transform,
57 | estimation_method=transform_estimation,
58 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration))
59 | else:
60 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold,
61 | estimation_method=transform_estimation,
62 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration))
63 |
64 | return reg_p2p.transformation, reg_p2p.fitness, reg_p2p.inlier_rmse
65 |
66 |
67 | def make_open3d_feature(data, dim, npts):
68 | feature = o3d.pipelines.registration.Feature()
69 | feature.resize(dim, npts)
70 | if not isinstance(data, np.ndarray):
71 | feature.data = data.cpu().numpy().astype('d').transpose()
72 | else:
73 | feature.data = data.astype('d').transpose()
74 | return feature
75 |
76 |
77 | def make_open3d_point_cloud(xyz, color=None):
78 | pcd = o3d.geometry.PointCloud()
79 | pcd.points = o3d.utility.Vector3dVector(xyz)
80 | if color is not None:
81 | pcd.colors = o3d.utility.Vector3dVector(color)
82 | return pcd
83 |
84 | def preprocess_pointcloud(pc, remove_zero_points: bool = False,
85 | min_x: float = None, max_x: float = None,
86 | min_y: float = None, max_y: float = None,
87 | min_z: float = None, max_z: float = None):
88 | if remove_zero_points:
89 | mask = np.all(np.isclose(pc, 0.), axis=1)
90 | pc = pc[~mask]
91 |
92 | if min_x is not None:
93 | mask = pc[:, 0] > min_x
94 | pc = pc[mask]
95 |
96 | if max_x is not None:
97 | mask = pc[:, 0] <= max_x
98 | pc = pc[mask]
99 |
100 | if min_y is not None:
101 | mask = pc[:, 1] > min_y
102 | pc = pc[mask]
103 |
104 | if max_y is not None:
105 | mask = pc[:, 1] <= max_y
106 | pc = pc[mask]
107 |
108 | if min_z is not None:
109 | mask = pc[:, 2] > min_z
110 | pc = pc[mask]
111 |
112 | if max_z is not None:
113 | mask = pc[:, 2] <= max_z
114 | pc = pc[mask]
115 |
116 | return pc
117 |
118 |
119 | class PointCloudLoader:
120 | # Generic point cloud loader class
121 | def __init__(self):
122 | # remove_zero_points: remove points with all zero coordinates
123 | # remove_ground_plane: remove points on ground plane level and below
124 | # ground_plane_level: ground plane level
125 | self.remove_zero_points = True
126 | self.remove_ground_plane = True
127 | self.ground_plane_level = None
128 | self.set_properties()
129 |
130 | def set_properties(self):
131 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes.
132 | raise NotImplementedError('set_properties must be defined in inherited classes')
133 |
134 | def __call__(self, file_pathname):
135 |         # Reads the point cloud from disk and preprocesses it (optional removal of zero points and of points
136 |         # on the ground plane and below)
137 | # file_pathname: relative file path
138 | assert os.path.exists(file_pathname), f"Cannot open point cloud: {file_pathname}"
139 | pc = self.read_pc(file_pathname)
140 | # assert pc.shape[1] == 3
141 |
142 | if self.remove_zero_points:
143 | mask = np.all(np.isclose(pc, 0), axis=1)
144 | pc = pc[~mask]
145 |
146 | if self.remove_ground_plane:
147 | mask = pc[:, 2] > self.ground_plane_level
148 | pc = pc[mask]
149 |
150 | return pc
151 |
152 | def read_pc(self, file_pathname):
153 | # Reads the point cloud without pre-processing
154 | raise NotImplementedError("read_pc must be overloaded in an inheriting class")
155 |
--------------------------------------------------------------------------------
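PointCloudLoader is an abstract base: a subclass sets the ground-plane level in set_properties and implements read_pc, and the base __call__ handles the zero-point and ground-plane filtering. A minimal hypothetical subclass for KITTI/Mulran-style .bin scans (float32 x, y, z, reflectance), shown only to illustrate the intended extension points; the import path assumes src/data is on sys.path:

import numpy as np
from datasets.point_clouds_utils import PointCloudLoader

class BinPointCloudLoader(PointCloudLoader):
    # Hypothetical loader, not part of the repo.
    def set_properties(self):
        self.ground_plane_level = -0.9   # same level used by load_pc in the Mulran tuple generator

    def read_pc(self, file_pathname):
        pc = np.fromfile(file_pathname, dtype=np.float32)
        return np.reshape(pc, (-1, 4))[:, :3]   # drop reflectance, keep x, y, z

# loader = BinPointCloudLoader()
# pc = loader('path/to/scan.bin')   # zero points and ground-plane points removed by __call__
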
/src/data/datasets/poses_utils.py:
--------------------------------------------------------------------------------
1 | # This file is directly copied from: https://github.com/jac99/Egonn/blob/main/misc/poses.py
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def q2r(q):
8 | # Rotation matrix from Hamiltonian quaternion
9 | # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
10 | w, x, y, z = tuple(q)
11 |
12 | n = 1.0/np.sqrt(x*x+y*y+z*z+w*w)
13 | x *= n
14 | y *= n
15 | z *= n
16 | w *= n
17 | r = np.array([[1.0 - 2.0*y*y - 2.0*z*z, 2.0*x*y - 2.0*z*w, 2.0*x*z + 2.0*y*w],
18 | [2.0*x*y + 2.0*z*w, 1.0 - 2.0*x*x - 2.0*z*z, 2.0*y*z - 2.0*x*w],
19 | [2.0*x*z - 2.0*y*w, 2.0*y*z + 2.0*x*w, 1.0 - 2.0*x*x - 2.0*y*y]])
20 | return r
21 |
22 |
23 | def m2ypr(m):
24 | # Get yaw, pitch, roll angles from 4x4 transformation matrix
25 | # Based on formulas in Section 2.5.1 in:
26 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization
27 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf
28 | assert m.shape == (4, 4)
29 | pitch = np.arctan2(-m[2][0], np.sqrt(m[0][0]**2 + m[1][0]**2))
30 | # We do not handle degenerate case, when pitch is 90 degrees a.k.a. gimball lock
31 | assert not np.isclose(np.abs(pitch), np.pi/2)
32 | yaw = np.arctan2(m[1][0], m[0][0])
33 | roll = np.arctan2(m[2][1], m[2][2])
34 | return yaw, pitch, roll
35 |
36 |
37 | def m2xyz_ypr(m):
38 |     # Get translation (x, y, z) and yaw, pitch, roll angles from a 4x4 transformation matrix
39 | # Based on formulas in Section 2.5.1 in:
40 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization
41 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf
42 | assert m.shape == (4, 4)
43 | yaw, pitch, roll = m2ypr(m)
44 | return m[0, 3], m[1, 3], m[2, 3], yaw, pitch, roll
45 |
46 |
47 | def ypr2m(yaw, pitch, roll):
48 | # Construct 4x4 transformation matrix with rotation part set to given yaw, pitch, roll. Translation is set to 0.
49 | # Based on formulas in Section 2.2.1
50 | m = np.array([[np.cos(yaw) * np.cos(pitch), np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll),
51 | np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll), 0.],
52 |                   [np.sin(yaw) * np.cos(pitch), np.sin(yaw) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll),
53 | np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll), 0.],
54 | [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll), 0.],
55 | [0., 0., 0., 1.]], dtype=np.float32)
56 |
57 | return m
58 |
59 |
60 | def xyz_ypr2m(x, y, z, yaw, pitch, roll):
61 | # Construct 4x4 transformation matrix with given translation and rotation part set to given yaw, pitch, roll.
62 | # Based on formulas in Section 2.2.1
63 | m = ypr2m(yaw, pitch, roll)
64 | m[0, 3] = x
65 | m[1, 3] = y
66 | m[2, 3] = z
67 | return m
68 |
69 |
70 | def apply_transform(pc: torch.Tensor, m: torch.Tensor):
71 | # Apply 4x4 SE(3) transformation matrix on (N, 3) point cloud or 3x3 transformation on (N, 2) point cloud
72 | assert pc.ndim == 2
73 | n_dim = pc.shape[1]
74 | assert n_dim == 2 or n_dim == 3
75 | assert m.shape == (n_dim + 1, n_dim + 1)
76 | # (m @ pc.t).t = pc @ m.t
77 | pc = pc @ m[:n_dim, :n_dim].transpose(1, 0) + m[:n_dim, -1]
78 | return pc
79 |
80 |
81 | def relative_pose(m1, m2):
82 | # !!! DO NOT USE THIS FUNCTION FOR MULRAN POSES !!!
83 |     # SE(3) pose is a 4x4 matrix, such that
84 | # Pw = [R | T] @ [P]
85 | # [0 | 1] [1]
86 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame
87 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame
88 | # m2: coords in camera/lidar2 coords -> world coordinate frame
89 | # returns: coords in camera/lidar1 reference frame -> coords in camera/lidar2 reference frame
90 | # relative pose of the first camera with respect to the second camera
91 | return np.linalg.inv(m2) @ m1
92 |
--------------------------------------------------------------------------------
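A quick round-trip sanity check of the ypr2m / m2ypr pair, keeping the angles away from the gimbal-lock pitch of ±90°; the import path assumes src/data is on sys.path:

import numpy as np
from datasets.poses_utils import m2ypr, xyz_ypr2m

yaw, pitch, roll = 0.3, -0.2, 0.1
m = xyz_ypr2m(1.0, 2.0, 3.0, yaw, pitch, roll)
assert np.allclose(m2ypr(m), (yaw, pitch, roll), atol=1e-5)   # angles recovered
assert np.allclose(m[:3, 3], [1.0, 2.0, 3.0])                 # translation recovered
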
/src/data/datasets/quantization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import List
3 | from abc import ABC, abstractmethod
4 | import torch
5 | # import MinkowskiEngine as ME
6 |
7 |
8 | class Quantizer(ABC):
9 | @abstractmethod
10 | def __call__(self, pc):
11 | pass
12 |
13 | @abstractmethod
14 | def dequantize(self, coords):
15 | pass
16 |
17 | @abstractmethod
18 | def keypoint_position(self, supervoxel_centers, stride, kp_offset):
19 | pass
20 |
21 |
22 | class PolarQuantizer(Quantizer):
23 | def __init__(self, quant_step: List[float]):
24 | assert len(quant_step) == 3, '3 quantization steps expected: for sector (in degrees), ring and z-coordinate (in meters)'
25 | self.quant_step = torch.tensor(quant_step, dtype=torch.float)
26 | self.theta_range = int(360. // self.quant_step[0])
27 |
28 |
29 | def __call__(self, pc):
30 | # Convert to polar coordinates and quantize with different step size for each coordinate
31 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z)
32 | assert pc.shape[1] == 3
33 |
34 | # theta is an angle in degrees in 0..360 range
35 | theta = 180. + torch.atan2(pc[:, 1], pc[:, 0]) * 180./np.pi
36 | # dist is a distance from a coordinate origin
37 | dist = torch.sqrt(pc[:, 0]**2 + pc[:, 1]**2)
38 | z = pc[:, 2]
39 | polar_pc = torch.stack([theta, dist, z], dim=1)
40 | # Scale each coordinate so after quantization with step 1. we got the required quantization step in each dim
41 | polar_pc = polar_pc / self.quant_step
42 | quantized_polar_pc, ndx = ME.utils.sparse_quantize(polar_pc, quantization_size=1., return_index=True)
43 | # Return quantized coordinates and index of selected elements
44 | return quantized_polar_pc, ndx
45 |
46 | def to_cartesian(self, pc):
47 | # Convert to radian in -180..180 range
48 | theta = np.pi * (pc[:, 0] - 180.) / 180.
49 | x = torch.cos(theta) * pc[:, 1]
50 | y = torch.sin(theta) * pc[:, 1]
51 | z = pc[:, 2]
52 | cartesian_pc = torch.stack([x, y, z], dim=1)
53 | return cartesian_pc
54 |
55 | def dequantize(self, coords):
56 | # Dequantize coords and convert to cartesian as (N, 3) tensor of floats
57 | pc = (0.5 + coords) * self.quant_step.to(coords.device)
58 | return self.to_cartesian(pc)
59 |
60 | def keypoint_position(self, supervoxel_centres, stride, kp_offset):
61 |         # Supervoxel centre in metric (polar) coordinates: (coords + 0.5) * quant_step.
62 |         # The keypoint offset is predicted in the -1..1 range, so it is scaled by half the
63 |         # supervoxel size (stride * quant_step) before being added to the centre.
64 | device = supervoxel_centres.device
65 | supervoxel_centres = (supervoxel_centres + 0.5) * self.quant_step.to(device)
66 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * \
67 | self.quant_step.to(device)
68 | #kp_pos = supervoxel_centres
69 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2.
70 |
71 | kp_pos = self.to_cartesian(kp_pos)
72 | return kp_pos
73 |
74 |
75 | class CartesianQuantizer(Quantizer):
76 | def __init__(self, quant_step: float):
77 | self.quant_step = quant_step
78 |
79 | def __call__(self, pc):
80 |         # Quantize Cartesian coordinates with a single step size for all three axes
81 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z)
82 | assert pc.shape[1] == 3
83 | quantized_pc, ndx = ME.utils.sparse_quantize(pc, quantization_size=self.quant_step, return_index=True)
84 | # Return quantized coordinates and index of selected elements
85 | return quantized_pc, ndx
86 |
87 | def dequantize(self, coords):
88 | # Dequantize coords and return as (N, 3) tensor of floats
89 | # Use coords of the voxel center
90 | pc = (0.5 + coords) * self.quant_step
91 | return pc
92 |
93 | def keypoint_position(self, supervoxel_centers, stride, kp_offset):
94 |         # Supervoxel centre in metric coordinates: (coords + 0.5) * quant_step.
95 |         # The keypoint offset is predicted in the -1..1 range, so it is scaled by half the
96 |         # supervoxel size (stride * quant_step) before being added to the centre.
97 | supervoxel_centres = (supervoxel_centers + 0.5) * self.quant_step
98 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * self.quant_step
99 | if kp_offset is not None:
100 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2.
101 | else:
102 | kp_pos = supervoxel_centres
103 | return kp_pos
104 |
105 |
106 | if __name__ == "__main__":
107 | n = 1000
108 | cart = torch.rand((n, 3), dtype=torch.float)
109 | cart[:, 0] = cart[:, 0] * 200. - 100.
110 | cart[:, 1] = cart[:, 1] * 200. - 100.
111 | cart[:, 2] = cart[:, 2] * 30. - 10.
112 |
113 | quantizer = PolarQuantizer([0.5, 0.3, 0.2])
114 | polar_quant, ndx = quantizer(cart)
115 | back2cart = quantizer.dequantize(polar_quant)
116 | cart_filtered = cart[ndx]
117 | dist = torch.norm(back2cart - cart_filtered, dim=1)
118 | print(f'Residual error - min: {torch.min(dist):0.5f} max: {torch.max(dist):0.5f} mean: {torch.mean(dist):0.5f}')
119 |
120 |
--------------------------------------------------------------------------------
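Both quantizers defer the actual voxel hashing to ME.utils.sparse_quantize, whose MinkowskiEngine import is commented out in this copy. A dependency-free stand-in with the same return convention (unique voxel coordinates plus the index of one retained point per voxel), given only as an illustrative sketch, not the library call itself:

import numpy as np
import torch

def sparse_quantize_simple(pc: torch.Tensor, quantization_size: float):
    # pc: (N, 3) coordinates. Keep the first point that falls into each voxel.
    coords = torch.floor(pc / quantization_size).long()
    _, ndx = np.unique(coords.numpy(), axis=0, return_index=True)
    ndx = torch.as_tensor(np.sort(ndx), dtype=torch.long)
    return coords[ndx], ndx

pc = torch.rand(1000, 3) * 100.
voxels, ndx = sparse_quantize_simple(pc, quantization_size=0.5)
assert voxels.shape[0] == ndx.shape[0] <= pc.shape[0]
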
/src/data/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | # Warsaw University of Technology
2 |
3 | import random
4 | import copy
5 | from torch.utils.data import Sampler
6 |
7 | from datasets.base_datasets import TrainingDataset
8 |
9 | VERBOSE = False
10 |
11 |
12 | class ListDict(object):
13 | def __init__(self, items=None):
14 | if items is not None:
15 | self.items = copy.deepcopy(items)
16 | self.item_to_position = {item: ndx for ndx, item in enumerate(items)}
17 | else:
18 | self.items = []
19 | self.item_to_position = {}
20 |
21 | def add(self, item):
22 | if item in self.item_to_position:
23 | return
24 | self.items.append(item)
25 | self.item_to_position[item] = len(self.items)-1
26 |
27 | def remove(self, item):
28 | position = self.item_to_position.pop(item)
29 | last_item = self.items.pop()
30 | if position != len(self.items):
31 | self.items[position] = last_item
32 | self.item_to_position[last_item] = position
33 |
34 | def choose_random(self):
35 | return random.choice(self.items)
36 |
37 | def __contains__(self, item):
38 | return item in self.item_to_position
39 |
40 | def __iter__(self):
41 | return iter(self.items)
42 |
43 | def __len__(self):
44 | return len(self.items)
45 |
46 |
47 | class BatchSampler(Sampler):
48 | # Sampler returning list of indices to form a mini-batch
49 | # Samples elements in groups consisting of k=2 similar elements (positives)
50 | # Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k
51 | def __init__(self, dataset: TrainingDataset, batch_size: int, batch_size_limit: int = None,
52 | batch_expansion_rate: float = None, max_batches: int = None):
53 | if batch_expansion_rate is not None:
54 | assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1'
55 | assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size'
56 |
57 | self.batch_size = batch_size
58 | self.batch_size_limit = batch_size_limit
59 | self.batch_expansion_rate = batch_expansion_rate
60 | self.max_batches = max_batches
61 | self.dataset = dataset
62 | self.k = 2 # Number of positive examples per group must be 2
63 | if self.batch_size < 2 * self.k:
64 | self.batch_size = 2 * self.k
65 | print('WARNING: Batch too small. Batch size increased to {}.'.format(self.batch_size))
66 |
67 | self.batch_idx = [] # Index of elements in each batch (re-generated every epoch)
68 | self.elems_ndx = list(self.dataset.queries) # List of point cloud indexes
69 |
70 | def __iter__(self):
71 | # Re-generate batches every epoch
72 | self.generate_batches()
73 | for batch in self.batch_idx:
74 | yield batch
75 |
76 |     def __len__(self):
77 | return len(self.batch_idx)
78 |
79 | def expand_batch(self):
80 | if self.batch_expansion_rate is None:
81 | print('WARNING: batch_expansion_rate is None')
82 | return
83 |
84 | if self.batch_size >= self.batch_size_limit:
85 | return
86 |
87 | old_batch_size = self.batch_size
88 | self.batch_size = int(self.batch_size * self.batch_expansion_rate)
89 | self.batch_size = min(self.batch_size, self.batch_size_limit)
90 | print('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size))
91 |
92 | def generate_batches(self):
93 | # Generate training/evaluation batches.
94 | # batch_idx holds indexes of elements in each batch as a list of lists
95 | self.batch_idx = []
96 |
97 | unused_elements_ndx = ListDict(self.elems_ndx)
98 | current_batch = []
99 |
100 | assert self.k == 2, 'sampler can sample only k=2 elements from the same class'
101 |
102 | while True:
103 | if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0:
104 | # Flush out batch, when it has a desired size, or a smaller batch, when there's no more
105 | # elements to process
106 | if len(current_batch) >= 2*self.k:
107 | # Ensure there're at least two groups of similar elements, otherwise, it would not be possible
108 | # to find negative examples in the batch
109 |                     assert len(current_batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(current_batch))
110 | self.batch_idx.append(current_batch)
111 | current_batch = []
112 | if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches):
113 | break
114 | if len(unused_elements_ndx) == 0:
115 | break
116 |
117 | # Add k=2 similar elements to the batch
118 | selected_element = unused_elements_ndx.choose_random()
119 | unused_elements_ndx.remove(selected_element)
120 | positives = self.dataset.get_positives(selected_element)
121 | if len(positives) == 0:
122 | # Broken dataset element without any positives
123 | continue
124 |
125 | unused_positives = [e for e in positives if e in unused_elements_ndx]
126 | # If there're unused elements similar to selected_element, sample from them
127 | # otherwise sample from all similar elements
128 | if len(unused_positives) > 0:
129 | second_positive = random.choice(unused_positives)
130 | unused_elements_ndx.remove(second_positive)
131 | else:
132 | second_positive = random.choice(list(positives))
133 |
134 | current_batch += [selected_element, second_positive]
135 |
136 | for batch in self.batch_idx:
137 |             assert len(batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(batch))
138 |
139 |
140 | if __name__ == '__main__':
141 | pass
142 |
143 |
--------------------------------------------------------------------------------
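BatchSampler builds each batch from groups of k=2 similar elements, so every element has an in-batch positive and the remaining groups supply negatives (ListDict keeps removal and uniform random choice O(1)). A toy re-enactment of the sampling loop with plain Python containers, shown only to illustrate the resulting batch layout:

import random

positives = {0: [1], 1: [0], 2: [3], 3: [2], 4: [5], 5: [4]}   # element -> its positives
unused = set(positives)
batch, batch_size = [], 4

while unused and len(batch) < batch_size:
    anchor = random.choice(sorted(unused))
    unused.discard(anchor)
    # Prefer an unused positive; otherwise reuse any positive of the anchor
    candidates = [p for p in positives[anchor] if p in unused] or list(positives[anchor])
    second = random.choice(candidates)
    unused.discard(second)
    batch += [anchor, second]           # layout: a1, p1, a2, p2, ...

print(batch)   # two groups of k=2 similar elements; negatives come from the other group
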
/src/data/datasets/southbay/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Generate evaluation sets
2 | # - Map point clouds are taken from MapData folder
3 | # - Query point clouds are taken from TestData
4 | # For each area (BaylandsToSeafood, ColumbiaPark, HighWay237, MathildaAVE, SanJoseDowntown, SunnyvaleBigloop) a
5 | # separate evaluation set is created. We do not match clouds from different areas.
6 |
7 | # This file is directly copied from: https://github.com/jac99/Egonn/blob/main/datasets/southbay/generate_evaluation_sets.py
8 |
9 | import argparse
10 | import numpy as np
11 | from typing import List
12 | import os
13 | import sys
14 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
15 |
16 | from datasets.southbay.southbay_raw import SouthBayDataset
17 | from datasets.base_datasets import EvaluationTuple, EvaluationSet, filter_query_elements
18 |
19 | def get_scans(ds: SouthBayDataset, split: str, area: str, min_displacement: float = 0.1) -> List[EvaluationTuple]:
20 | elems = []
21 | for ndx in ds.location_ndx[split][area]:
22 | pose = ds.global_ndx[ndx].pose
23 | position = pose[0:2, 3] # (x, y) position in global coordinate frame
24 | rel_scan_filepath = ds.global_ndx[ndx].rel_scan_filepath
25 | timestamp = ds.global_ndx[ndx].timestamp
26 |
27 | item = EvaluationTuple(timestamp, rel_scan_filepath, position=position, pose=pose)
28 | elems.append(item)
29 |
30 | print(f"{len(elems)} total elements in {split} split")
31 |
32 | # Filter-out elements leaving only 1 per grid cell with min_displacement size
33 | pos = np.zeros((len(elems), 2), dtype=np.float32)
34 | for ndx, e in enumerate(elems):
35 | pos[ndx] = e.position
36 |
37 | # Quantize x-y coordinates. Quantized coords start from 0
38 | pos = np.floor(pos / min_displacement)
39 | pos = pos.astype(int)
40 | _, unique_ndx = np.unique(pos, axis=0, return_index=True)
41 |
42 | # Leave only unique elements
43 | elems = [elems[i] for i in unique_ndx]
44 | print(f"{len(elems)} filtered elements in {split} split with grid cell size = {min_displacement}")
45 |
46 | return elems
47 |
48 |
49 | def generate_evaluation_set(ds: SouthBayDataset, area: str, min_displacement: float = 0.1, dist_threshold=5) -> \
50 | EvaluationSet:
51 | map_set = get_scans(ds, 'MapData', area, min_displacement)
52 | query_set = get_scans(ds, 'TestData', area, min_displacement)
53 | query_set = filter_query_elements(query_set, map_set, dist_threshold)
54 | print(f'Area: {area} - {len(map_set)} database elements, {len(query_set)} query elements\n')
55 | return EvaluationSet(query_set, map_set)
56 |
57 |
58 | if __name__ == '__main__':
59 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Apollo SouthBay dataset')
60 | parser.add_argument('--dataset_root', type=str, required=False, default='/data/raktim/Datasets/Apollo-Southbay')
61 | parser.add_argument('--min_displacement', type=float, default=1.0)
62 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters)
63 | parser.add_argument('--dist_threshold', type=float, default=5)
64 |
65 | args = parser.parse_args()
66 | print(f'Dataset root: {args.dataset_root}')
67 | print(f'Minimum displacement between scans in each set (map/query): {args.min_displacement}')
68 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}')
69 |
70 | ds = SouthBayDataset(args.dataset_root)
71 | ds.print_info()
72 |
73 | min_displacement = args.min_displacement
74 |
75 | area = 'SunnyvaleBigloop' # Evaluation area
76 | assert area in ds.location_ndx['TestData']
77 | eval_set = generate_evaluation_set(ds, area, min_displacement=min_displacement,
78 | dist_threshold=args.dist_threshold)
79 | pickle_name = f'test_{area}_{args.min_displacement}_{args.dist_threshold}_20m.pickle'
80 | file_path_name = os.path.join(args.dataset_root, pickle_name)
81 | eval_set.save(file_path_name)
82 |
--------------------------------------------------------------------------------
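The min_displacement filtering above is a simple grid decimation: positions are snapped to cells of min_displacement size and only the first element per cell is kept. A tiny numpy illustration of that step:

import numpy as np

pos = np.array([[0.05, 0.02], [0.07, 0.03], [1.40, 0.02]], dtype=np.float32)
min_displacement = 0.1

cells = np.floor(pos / min_displacement).astype(int)       # quantized x-y cell indices
_, unique_ndx = np.unique(cells, axis=0, return_index=True)
filtered = pos[np.sort(unique_ndx)]
assert len(filtered) == 2      # the first two points share a cell, so one is dropped
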
/src/data/datasets/southbay/generate_training_tuples.py:
--------------------------------------------------------------------------------
1 | # Generate training triplets
2 |
3 | import argparse
4 | import os
5 | import pickle
6 | import sys
7 |
8 | import numpy as np
9 | import tqdm
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
12 |
13 | from datasets.base_datasets import TrainingTuple
14 | from datasets.southbay.southbay_raw import SouthBayDataset
15 |
16 |
17 | class Triplet:
18 | def __init__(self, anchor: int, positives: np.ndarray, non_negatives: np.ndarray):
19 | self.anchor = anchor
20 | self.positives = positives
21 | self.non_negatives = non_negatives
22 |
23 |
24 | def generate_triplets(ds: SouthBayDataset, map_split: str, query_split: str,
25 | positives_th: int = 2, negatives_th: int = 10, min_displacement: float = 0.1):
26 | # All elements (anchors, positives and negatives) are taken from both map_split and query_split
27 | assert positives_th < negatives_th
28 |
29 | # Create a master table with positions of all point clouds from the query split and in map_split
30 | pc_ids, pc_poses = ds.get_poses2([query_split, map_split])
31 | pc_coords = pc_poses[:, :3, 3]
32 |
33 |     # Quantize x-y-z coordinates
34 | pos = np.floor(pc_coords / min_displacement)
35 | pos = pos.astype(int)
36 | _, unique_ndx = np.unique(pos, axis=0, return_index=True)
37 |
38 | # Leave only unique elements
39 | pc_ids = pc_ids[unique_ndx]
40 | pc_coords = pc_coords[unique_ndx]
41 | print(f'{len(pc_ids)} point clouds left from {len(pc_poses)} after filtering with min_displacement={min_displacement}')
42 |
43 | triplets = []
44 | count_zero_positives = 0
45 | for anchor_id in tqdm.tqdm(pc_ids):
46 | anchor_coords = ds.global_ndx[anchor_id].pose[:3, 3]
47 | dist = np.linalg.norm(pc_coords - anchor_coords, axis=1)
48 | positives_mask = dist <= positives_th
49 | non_negatives_mask = dist <= negatives_th
50 |
51 | positives_ndx = pc_ids[positives_mask]
52 | # remove anchor_id from positives
53 | positives_ndx = np.array([e for e in positives_ndx if e != anchor_id])
54 | non_negatives_ndx = pc_ids[non_negatives_mask]
55 |
56 | if len(positives_ndx) == 0:
57 | # Skip examples without positives
58 | count_zero_positives += 1
59 | continue
60 |
61 | t = Triplet(anchor_id, positives_ndx, non_negatives_ndx)
62 | triplets.append(t)
63 |
64 | print(f'{count_zero_positives} filtered out due to no positives')
65 | print(f'{len(triplets)} training tuples generated')
66 |
67 | # Remove ids from positives and negatives that are not anchors
68 | anchors_set = set([e.anchor for e in triplets])
69 | triplets = [Triplet(e.anchor, [p for p in e.positives if p in anchors_set],
70 | [nn for nn in e.non_negatives if nn in anchors_set]) for e in triplets]
71 | print(len(triplets))
72 | # All used global ids
73 | used_ids = set()
74 | for triplet in triplets:
75 | used_ids.add(triplet.anchor)
76 | used_ids.update(list(triplet.positives))
77 | used_ids.update(list(triplet.non_negatives))
78 |
79 | # New ids, consecutive and starting from 0
80 | new_ids = {old_ndx: ndx for ndx, old_ndx in enumerate(used_ids)}
81 |
82 | tuples = {}
83 | for triplet in triplets:
84 | new_anchor_ndx = new_ids[triplet.anchor]
85 | pc = ds.global_ndx[triplet.anchor]
86 | positives = np.array([new_ids[e] for e in triplet.positives], dtype=np.int32)
87 | non_negatives = np.array([new_ids[e] for e in triplet.non_negatives], dtype=np.int32)
88 |
89 | # Sort ascending order
90 | positives = np.sort(positives)
91 | non_negatives = np.sort(non_negatives)
92 |
93 | tuple = TrainingTuple(id=new_anchor_ndx, timestamp=pc.timestamp,
94 | rel_scan_filepath= ds.dataset_root + '/' +pc.rel_scan_filepath,
95 | positives=positives, non_negatives=non_negatives,
96 | pose=pc.pose, positives_poses=None)
97 | tuples[new_anchor_ndx] = tuple
98 |
99 | return tuples
100 |
101 |
102 | if __name__ == '__main__':
103 | parser = argparse.ArgumentParser(description='Generate training triplets for Apollo SouthBay dataset')
104 | parser.add_argument('--dataset_root', type=str, default='/data/raktim/Datasets/Apollo-Southbay', help='Path to Apollo SouthBay root folder')
105 | parser.add_argument('--pos_th', type=float, default=2, help='Positives threshold')
106 | parser.add_argument('--neg_th', type=float, default=10, help='Negatives threshold')
107 | parser.add_argument('--min_displacement', type=float, default=1.0)
108 |
109 | query_split = 'TrainData'
110 |
111 | args = parser.parse_args()
112 | print(f'Dataset root folder: {args.dataset_root}')
113 | print(f'Split for positives/negatives: {query_split}')
114 | print(f'Positives threshold: {args.pos_th}')
115 | print(f'Negatives threshold: {args.neg_th}')
116 | print(f'Minimum displacement between consecutive scans: {args.min_displacement}')
117 |
118 | ds = SouthBayDataset(args.dataset_root)
119 | ds.print_info()
120 |
121 | triplets = generate_triplets(ds, 'MapData', query_split, positives_th=args.pos_th, negatives_th=args.neg_th,
122 | min_displacement=args.min_displacement)
123 | print(f'{len(triplets)} anchors generated')
124 |
125 | pickle_name = f'train_southbay_{args.pos_th}_{args.neg_th}.pickle'
126 | pickle_filepath = os.path.join(args.dataset_root, pickle_name)
127 | pickle.dump(triplets, open(pickle_filepath, 'wb'))
128 |
--------------------------------------------------------------------------------
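After filtering, anchor ids are remapped to consecutive integers starting at 0 so the training tuples can index each other directly. A compact sketch of that remapping with toy ids:

import numpy as np

used_ids = [1013, 2077, 3140]                        # arbitrary global ids that survived filtering
new_ids = {old: ndx for ndx, old in enumerate(used_ids)}

positives_old = np.array([3140, 1013])
positives_new = np.sort(np.array([new_ids[e] for e in positives_old], dtype=np.int32))
assert positives_new.tolist() == [0, 2]
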
/src/data/datasets/southbay/test_SunnyvaleBigloop_1.0_5_20m.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/data/datasets/southbay/test_SunnyvaleBigloop_1.0_5_20m.pickle
--------------------------------------------------------------------------------
/src/evaluate/SALSA/sgv_utils.py:
--------------------------------------------------------------------------------
1 | # Functions in this file are adapted from: https://github.com/ZhiChen902/SC2-PCR/blob/main/SC2_PCR.py
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features):
8 | # normalize:
9 | # print(src_features.shape, tgt_features.shape)
10 | src_features = torch.nn.functional.normalize(src_features, p=2.0, dim=1)
11 | tgt_features = torch.nn.functional.normalize(tgt_features, p=2.0, dim=1)
12 | distance = torch.cdist(src_features, tgt_features)
13 | min_vals, min_ids = torch.min(distance, dim=2)
14 |
15 | min_ids = min_ids.unsqueeze(-1).expand(-1, -1, 3)
16 | tgt_keypts_corr = torch.gather(tgt_keypts, 1, min_ids)
17 | src_keypts_corr = src_keypts
18 |
19 | return src_keypts_corr, tgt_keypts_corr
20 |
21 | # def match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features):
22 | # # normalize:
23 | # src_features = torch.nn.functional.normalize(src_features, p=2.0, dim=1)
24 | # tgt_features = torch.nn.functional.normalize(tgt_features, p=2.0, dim=1)
25 | # # print(src_features.shape, tgt_features.shape)
26 | # distance = torch.cdist(src_features, tgt_features)
27 |
28 | # min_vals, min_ids = torch.min(distance, dim=2)
29 |
30 | # # print(distance.shape)
31 | # sorted_vals, sorted_ids = torch.sort(min_vals, dim=1)
32 | # sorted_ids = sorted_ids[:,:2000]
33 | # # print(sorted_ids.shape)
34 | # # print(sorted_ids)
35 | # # print(min_ids.shape)
36 | # # print(min_vals[:,sorted_ids[0]])
37 | # # print(min_ids[:,sorted_ids[0]])
38 |
39 | # tgt_ids = min_ids[:,sorted_ids[0]].unsqueeze(-1).expand(-1, -1, 3)
40 | # src_ids = sorted_ids.unsqueeze(-1).expand(-1, -1, 3)
41 |
42 | # tgt_keypts_corr = torch.gather(tgt_keypts, 1, tgt_ids)
43 | # # print(tgt_keypts_corr.shape)
44 | # src_keypts_corr = torch.gather(src_keypts, 1, src_ids)
45 | # # print(src_keypts_corr.shape)
46 |
47 | # return src_keypts_corr, tgt_keypts_corr
48 |
49 | def power_iteration(M, num_iterations=5):
50 | """
51 | Calculate the leading eigenvector using power iteration algorithm
52 | Input:
53 | - M: [bs, num_pts, num_pts] the adjacency matrix
54 | Output:
55 | - leading_eig: [bs, num_pts] leading eigenvector
56 | """
57 | leading_eig = torch.ones_like(M[:, :, 0:1])
58 | leading_eig_last = leading_eig
59 | for i in range(num_iterations):
60 | # print(i)
61 | leading_eig = torch.bmm(M, leading_eig)
62 | leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
63 | if torch.allclose(leading_eig, leading_eig_last):
64 | break
65 | leading_eig_last = leading_eig
66 | leading_eig = leading_eig.squeeze(-1)
67 | return leading_eig
68 |
69 |
70 | def cal_spatial_consistency( M, leading_eig):
71 | """
72 | Calculate the spatial consistency based on spectral analysis.
73 | Input:
74 | - M: [bs, num_pts, num_pts] the adjacency matrix
75 | - leading_eig [bs, num_pts] the leading eigenvector of matrix M
76 | Output:
77 | - sc_score_list [bs, 1]
78 | """
79 | spatial_consistency = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
80 | spatial_consistency = spatial_consistency.squeeze(-1) / M.shape[1]
81 | return spatial_consistency
82 |
83 |
84 | def sgv(src_keypts, tgt_keypts, src_features, tgt_features, d_thresh=5.0):
85 | """
86 | Input:
87 | - src_keypts: [1, num_pts, 3]
88 | - tgt_keypts: [bs, num_pts, 3]
89 | - src_features: [1, num_pts, D]
90 | - tgt_features: [bs, num_pts, D]
91 | Output:
92 | - sc_score_list: [bs, 1], spatial consistency score for each candidate
93 | """
94 | # print('src_kp', src_keypts.shape)
95 | # print('tgt_kp', tgt_keypts.shape)
96 | # Correspondence Estimation: Nearest Neighbour Matching
97 | src_keypts_corr, tgt_keypts_corr = match_pair_parallel(src_keypts, tgt_keypts, src_features, tgt_features)
98 | # print('src',src_keypts_corr.shape)
99 | # print('tgt',tgt_keypts_corr.shape)
100 |
101 | # Spatial Consistency Adjacency Matrix
102 | src_dist = torch.norm((src_keypts_corr[:, :, None, :] - src_keypts_corr[:, None, :, :]), dim=-1)
103 | target_dist = torch.norm((tgt_keypts_corr[:, :, None, :] - tgt_keypts_corr[:, None, :, :]), dim=-1)
104 | cross_dist = torch.abs(src_dist - target_dist)
105 | adj_mat = torch.clamp(1.0 - cross_dist ** 2 / d_thresh ** 2, min=0)
106 | # print(adj_mat)
107 | # Spatial Consistency Score
108 | # print(adj_mat.shape)
109 | lead_eigvec = power_iteration(adj_mat)
110 | # print(lead_eigvec.shape)
111 | sc_score_list = cal_spatial_consistency(adj_mat, lead_eigvec)
112 |
113 | sc_score_list = np.squeeze(sc_score_list.cpu().detach().numpy())
114 | return sc_score_list
115 |
116 | def sgv_fn(query_keypoints, candidate_keypoints, d_thresh=5.0, max_points=15000):
117 |
118 | # print(len(query_keypoints),len(candidate_keypoints))
119 | kp1 = query_keypoints['keypoints']
120 | kp2 = candidate_keypoints['keypoints']
121 | f1 = query_keypoints['features']
122 | f2 = candidate_keypoints['features']
123 |
124 | # draw_registration_result(kp1, kp2, np.eye(4))
125 |
126 | min_num_feat = min(len(kp1),len(kp2))
127 | min_num_feat = min(min_num_feat, max_points)
128 | kp1 = kp1[:min_num_feat]
129 | kp2 = kp2[:min_num_feat]
130 | f1 = f1[:min_num_feat]
131 | f2 = f2[:min_num_feat]
132 |
133 | src_keypts = kp1.unsqueeze(0).cuda()
134 | tgt_keypts = kp2.unsqueeze(0).cuda()
135 | src_features = f1.unsqueeze(0).cuda()
136 | tgt_features = f2.unsqueeze(0).cuda()
137 |
138 | conf = sgv(src_keypts, tgt_keypts, src_features, tgt_features, d_thresh=d_thresh)
139 |
140 |
141 |
142 | return conf
143 |
--------------------------------------------------------------------------------
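sgv scores each retrieval candidate by nearest-neighbour matching of local features, building a pairwise length-consistency (adjacency) matrix and taking the Rayleigh quotient with its leading eigenvector from power iteration. The functions shown above also run on CPU tensors (only sgv_fn moves inputs to CUDA), so a toy call looks like the sketch below; shapes follow the docstring, data is random, and the import assumes src/evaluate/SALSA is on sys.path:

import torch
from sgv_utils import sgv

num_pts, dim, num_candidates = 64, 16, 4
src_keypts = torch.rand(1, num_pts, 3)                    # query keypoints
tgt_keypts = torch.rand(num_candidates, num_pts, 3)       # keypoints of each retrieval candidate
src_features = torch.rand(1, num_pts, dim)
tgt_features = torch.rand(num_candidates, num_pts, dim)

scores = sgv(src_keypts, tgt_keypts, src_features, tgt_features, d_thresh=5.0)
print(scores.shape)   # one spatial-consistency score per candidate
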
/src/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/loss/__init__.py
--------------------------------------------------------------------------------
/src/loss/global_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def triplet_margin_loss(query, positive, negative, margin=0.1):
5 | distance_positive = torch.norm(query - positive, dim=1, p=2)**2
6 | distance_negative = torch.norm(query - negative, dim=1, p=2)**2
7 | losses = torch.relu(distance_positive - distance_negative + margin)
8 | loss = torch.mean(losses)
9 | return loss
10 |
11 | # def triplet_margin_loss(query, positive, negative, margin=0.1):
12 | # distance_positive = F.cosine_similarity(query, positive, dim=1)
13 | # distance_negative = F.cosine_similarity(query, negative, dim=1)
14 | # losses = torch.relu(distance_negative - distance_positive + margin)
15 | # loss = torch.mean(losses)
16 | # return loss
17 |
--------------------------------------------------------------------------------
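The squared-L2 triplet margin loss above is zero whenever the positive is already closer to the query than the negative by at least the margin. A short numeric check (descriptors are rows; the import assumes src is on sys.path):

import torch
from loss.global_loss import triplet_margin_loss

q = torch.zeros(1, 4)
p = torch.tensor([[0.1, 0., 0., 0.]])   # squared distance to q: 0.01
n = torch.tensor([[1.0, 0., 0., 0.]])   # squared distance to q: 1.0
assert triplet_margin_loss(q, p, n, margin=0.1).item() == 0.0
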
/src/loss/local_consistency_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | # sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 | from utils.misc_utils import pdist, hashM
8 |
9 |
10 | def point_contrastive_loss(F0, F1, positive_pairs, point_pos_margin=0.1, point_neg_margin=2.0,point_neg_weight=1.0,
11 | num_pos=5192,
12 | num_hn_samples=2048):
13 | """
14 | Randomly select 'num-pos' positive pairs.
15 | Find the hardest-negative (from a random subset of num_hn_samples) for each point in a positive pair.
16 | Calculate contrastive loss on the tuple (p1,p2,hn1,hn2)
17 | Based on: https://github.com/chrischoy/FCGF/blob/master/lib/trainer.py
18 | """
19 | N0, N1 = len(F0), len(F1)
20 | N_pos_pairs = len(positive_pairs)
21 | hash_seed = max(N0, N1)
22 | sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False)
23 | sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False)
24 |
25 | if N_pos_pairs > num_pos:
26 | pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False)
27 | sample_pos_pairs = positive_pairs[pos_sel]
28 | else:
29 | sample_pos_pairs = positive_pairs
30 |
31 | # Find negatives for all F1[positive_pairs[:, 1]]
32 | subF0, subF1 = F0[sel0], F1[sel1]
33 |
34 | pos_ind0 = sample_pos_pairs[:, 0] # .long()
35 | pos_ind1 = sample_pos_pairs[:, 1] # .long()
36 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1]
37 |
38 | D01 = pdist(posF0, subF1, dist_type='L2')
39 | D10 = pdist(posF1, subF0, dist_type='L2')
40 |
41 | D01min, D01ind = D01.min(1)
42 | D10min, D10ind = D10.min(1)
43 |
44 | if not isinstance(positive_pairs, np.ndarray):
45 | positive_pairs = np.array(positive_pairs, dtype=np.int64)
46 |
47 | pos_keys = hashM(positive_pairs, hash_seed)
48 |
49 | D01ind = sel1[D01ind.cpu().numpy()]
50 | D10ind = sel0[D10ind.cpu().numpy()]
51 | neg_keys0 = hashM([pos_ind0, D01ind], hash_seed)
52 | neg_keys1 = hashM([D10ind, pos_ind1], hash_seed)
53 |
54 | mask0 = torch.from_numpy(
55 | np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False)))
56 | mask1 = torch.from_numpy(
57 | np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False)))
58 | pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - point_pos_margin)
59 | neg_loss0 = F.relu(point_neg_margin - D01min[mask0]).pow(2)
60 | neg_loss1 = F.relu(point_neg_margin - D10min[mask1]).pow(2)
61 |
62 | pos_loss = pos_loss.mean()
63 | neg_loss = (neg_loss0.mean() + neg_loss1.mean()) / 2
64 | loss = pos_loss + point_neg_weight * neg_loss
65 | return loss
66 |
67 |
68 | def point_infonce_loss(query_feats, pos_feats, pos_pairs, neg_pairs, config): # TODO
69 | return 0
70 |
--------------------------------------------------------------------------------
/src/loss/loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
5 |
6 | from loss.global_loss import triplet_margin_loss
7 | from loss.local_consistency_loss import point_contrastive_loss
8 |
9 | def find_loss(local_features, global_descriptors,point_pos_pairs):
10 | global_loss = 0
11 | local_loss = 0
12 | batch_size = int(global_descriptors.shape[0]/3)
13 | for i in range(batch_size):
14 |
15 | ## Global descriptor triplet loss ###################################################
16 | q_gd = global_descriptors[i][None,...]
17 | p_gd = global_descriptors[(1*batch_size) + i][None,...]
18 | n_gd = global_descriptors[(2*batch_size) + i][None,...]
19 |
20 |
21 | global_loss += triplet_margin_loss(q_gd, p_gd, n_gd, margin=0.1)
22 |
23 | ## Local features loss ###############################################################
24 | ## Point triplet loss ###################
25 | if point_pos_pairs[i].shape[0]>0:
26 | point_loss = point_contrastive_loss(local_features[i], local_features[(1*batch_size) + i], point_pos_pairs[i])
27 | else:
28 | point_loss = 0
29 | local_loss += point_loss
30 | #########################################
31 | # loss = global_loss + local_loss
32 | loss = global_loss + local_loss
33 | return loss
--------------------------------------------------------------------------------
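find_loss assumes the batch is stacked as [queries; positives; negatives] along dim 0, so descriptors i, i+B and i+2B form one triplet. A shape-only illustration of that layout (random data; the local-feature term is skipped here by passing empty point_pos_pairs, and the import assumes src is on sys.path):

import numpy as np
import torch
from loss.loss import find_loss

batch_size, dim = 2, 8
# Descriptors stacked as [queries; positives; negatives] along dim 0
global_descriptors = torch.rand(3 * batch_size, dim)
local_features = [None] * (3 * batch_size)                     # unused when there are no point pairs
point_pos_pairs = [np.empty((0, 2), dtype=np.int64)] * batch_size
loss = find_loss(local_features, global_descriptors, point_pos_pairs)
print(loss)   # only the global triplet term contributes in this sketch
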
/src/misc/point_clouds.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import numpy as np
5 | import open3d as o3d
6 |
7 |
8 | def draw_registration_result(source, target, transformation):
9 | source_temp = copy.deepcopy(source)
10 | target_temp = copy.deepcopy(target)
11 | source_temp.paint_uniform_color([1, 0.706, 0])
12 | target_temp.paint_uniform_color([0, 0.651, 0.929])
13 | source_temp.transform(transformation)
14 | o3d.visualization.draw_geometries([source_temp, target_temp],
15 | zoom=0.4459,
16 | front=[0.9288, -0.2951, -0.2242],
17 | lookat=[1.6784, 2.0612, 1.4451],
18 | up=[-0.3402, -0.9189, -0.1996])
19 |
20 |
21 | def draw_pc(pc):
22 | pc = copy.deepcopy(pc)
23 | pc.paint_uniform_color([1, 0.706, 0])
24 | o3d.visualization.draw_geometries([pc],
25 | zoom=0.4459,
26 | front=[0.9288, -0.2951, -0.2242],
27 | lookat=[1.6784, 2.0612, 1.4451],
28 | up=[-0.3402, -0.9189, -0.1996])
29 |
30 |
31 | def icp(anchor_pc, positive_pc, transform: np.ndarray = None, point2plane: bool = False,
32 | inlier_dist_threshold: float = 1.2, max_iteration: int = 200):
33 | # transform: initial alignment transform
34 | if transform is not None:
35 | transform = transform.astype(float)
36 |
37 | voxel_size = 0.1
38 | pcd1 = o3d.geometry.PointCloud()
39 | pcd1.points = o3d.utility.Vector3dVector(anchor_pc)
40 | pcd1 = pcd1.voxel_down_sample(voxel_size=voxel_size)
41 |
42 | pcd2 = o3d.geometry.PointCloud()
43 | pcd2.points = o3d.utility.Vector3dVector(positive_pc)
44 | pcd2 = pcd2.voxel_down_sample(voxel_size=voxel_size)
45 |
46 | if point2plane:
47 | pcd1.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))
48 | pcd2.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))
49 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPlane()
50 | else:
51 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint()
52 |
53 | if transform is not None:
54 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, transform,
55 | estimation_method=transform_estimation,
56 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration))
57 | else:
58 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold,
59 | estimation_method=transform_estimation,
60 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration))
61 |
62 | return reg_p2p.transformation, reg_p2p.fitness, reg_p2p.inlier_rmse
63 |
64 |
65 | def make_open3d_feature(data, dim, npts):
66 | feature = o3d.pipelines.registration.Feature()
67 | feature.resize(dim, npts)
68 | feature.data = data.cpu().numpy().astype('d').transpose()
69 | return feature
70 |
71 |
72 | def make_open3d_point_cloud(xyz, color=None):
73 | pcd = o3d.geometry.PointCloud()
74 | pcd.points = o3d.utility.Vector3dVector(xyz)
75 | if color is not None:
76 | pcd.colors = o3d.utility.Vector3dVector(color)
77 | return pcd
78 |
79 |
80 | class PointCloudLoader:
81 | # Generic point cloud loader class
82 | def __init__(self):
83 | # remove_zero_points: remove points with all zero coordinates
84 | # remove_ground_plane: remove points on ground plane level and below
85 | # ground_plane_level: ground plane level
86 | self.remove_zero_points = True
87 | self.remove_ground_plane = True
88 | self.ground_plane_level = None
89 | self.set_properties()
90 |
91 | def set_properties(self):
92 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes.
93 | raise NotImplementedError('set_properties must be defined in inherited classes')
94 |
95 | def __call__(self, file_pathname):
96 | # Reads the point cloud from disk and preprocesses it (optional removal of zero points and of points
97 | # on the ground plane and below)
98 | # file_pathname: relative file path
99 | assert os.path.exists(file_pathname), f"Cannot open point cloud: {file_pathname}"
100 | pc = self.read_pc(file_pathname)
101 | # assert pc.shape[1] == 3
102 | if self.remove_zero_points:
103 | mask = np.all(np.isclose(pc, 0), axis=1)
104 | pc = pc[~mask]
105 | if self.remove_ground_plane:
106 | mask = pc[:, 2] > self.ground_plane_level
107 | pc = pc[mask]
108 |
109 | return pc
110 |
111 | def read_pc(self, file_pathname):
112 | # Reads the point cloud without pre-processing
113 | raise NotImplementedError("read_pc must be overloaded in an inheriting class")
114 |
--------------------------------------------------------------------------------
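`PointCloudLoader` above is abstract: a subclass supplies `ground_plane_level` in `set_properties` and the raw reader in `read_pc`. Below is a hypothetical sketch of such a subclass for KITTI-style `.bin` scans (float32 x, y, z, intensity), followed by a rough alignment with `icp`; the class name, the -1.5 m ground level, and the file names are illustrative assumptions, not the repo's own loaders.

```python
# Hypothetical loader sketch: a PointCloudLoader subclass for KITTI-style
# .bin scans, then a rough alignment of two scans with icp().
import numpy as np

class KittiBinLoader(PointCloudLoader):
    def set_properties(self):
        # Assumed ground level 1.5 m below the sensor; points at/below it are dropped.
        self.ground_plane_level = -1.5

    def read_pc(self, file_pathname):
        pc = np.fromfile(file_pathname, dtype=np.float32).reshape(-1, 4)
        return pc[:, :3]

# loader = KittiBinLoader()
# anchor = loader('scan_000000.bin')        # placeholder file names
# positive = loader('scan_000005.bin')
# T, fitness, rmse = icp(anchor, positive, point2plane=True)
```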
/src/misc/poses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def q2r(q):
6 | # Rotation matrix from a Hamilton-convention quaternion (w, x, y, z)
7 | # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
8 | w, x, y, z = tuple(q)
9 |
10 | n = 1.0/np.sqrt(x*x+y*y+z*z+w*w)
11 | x *= n
12 | y *= n
13 | z *= n
14 | w *= n
15 | r = np.array([[1.0 - 2.0*y*y - 2.0*z*z, 2.0*x*y - 2.0*z*w, 2.0*x*z + 2.0*y*w],
16 | [2.0*x*y + 2.0*z*w, 1.0 - 2.0*x*x - 2.0*z*z, 2.0*y*z - 2.0*x*w],
17 | [2.0*x*z - 2.0*y*w, 2.0*y*z + 2.0*x*w, 1.0 - 2.0*x*x - 2.0*y*y]])
18 | return r
19 |
20 |
21 | def m2ypr(m):
22 | # Get yaw, pitch, roll angles from 4x4 transformation matrix
23 | # Based on formulas in Section 2.5.1 in:
24 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization
25 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf
26 | assert m.shape == (4, 4)
27 | pitch = np.arctan2(-m[2][0], np.sqrt(m[0][0]**2 + m[1][0]**2))
28 | # We do not handle the degenerate case when pitch is 90 degrees, a.k.a. gimbal lock
29 | assert not np.isclose(np.abs(pitch), np.pi/2)
30 | yaw = np.arctan2(m[1][0], m[0][0])
31 | roll = np.arctan2(m[2][1], m[2][2])
32 | return yaw, pitch, roll
33 |
34 |
35 | def m2xyz_ypr(m):
36 | # Get translation (x, y, z) and yaw, pitch, roll angles from a 4x4 transformation matrix
37 | # Based on formulas in Section 2.5.1 in:
38 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization
39 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf
40 | assert m.shape == (4, 4)
41 | yaw, pitch, roll = m2ypr(m)
42 | return m[0, 3], m[1, 3], m[2, 3], yaw, pitch, roll
43 |
44 |
45 | def ypr2m(yaw, pitch, roll):
46 | # Construct 4x4 transformation matrix with rotation part set to given yaw, pitch, roll. Translation is set to 0.
47 | # Based on formulas in Section 2.2.1
48 | m = np.array([[np.cos(yaw) * np.cos(pitch), np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll),
49 | np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll), 0.],
50 | [np.sin(yaw) * np.cos(pitch), np.sin(yaw) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll),
51 | np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll), 0.],
52 | [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll), 0.],
53 | [0., 0., 0., 1.]], dtype=np.float32)
54 |
55 | return m
56 |
57 |
58 | def xyz_ypr2m(x, y, z, yaw, pitch, roll):
59 | # Construct 4x4 transformation matrix with given translation and rotation part set to given yaw, pitch, roll.
60 | # Based on formulas in Section 2.2.1
61 | m = ypr2m(yaw, pitch, roll)
62 | m[0, 3] = x
63 | m[1, 3] = y
64 | m[2, 3] = z
65 | return m
66 |
67 |
68 | def apply_transform(pc: torch.Tensor, m: torch.Tensor):
69 | # Apply 4x4 SE(3) transformation matrix on (N, 3) point cloud or 3x3 transformation on (N, 2) point cloud
70 | assert pc.ndim == 2
71 | n_dim = pc.shape[1]
72 | assert n_dim == 2 or n_dim == 3
73 | assert m.shape == (n_dim + 1, n_dim + 1)
74 | # (m @ pc.t).t = pc @ m.t
75 | pc = pc @ m[:n_dim, :n_dim].transpose(1, 0) + m[:n_dim, -1]
76 | return pc
77 |
78 |
79 | def relative_pose(m1, m2):
80 | # !!! DO NOT USE THIS FUNCTION FOR MULRAN POSES !!!
81 | # An SE(3) pose is a 4x4 matrix, such that
82 | # Pw = [R | T] @ [P]
83 | # [0 | 1] [1]
84 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame
85 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame
86 | # m2: coords in camera/lidar2 coords -> world coordinate frame
87 | # returns: coords in camera/lidar1 reference frame -> coords in camera/lidar2 reference frame
88 | # relative pose of the first camera with respect to the second camera
89 | return np.linalg.inv(m2) @ m1
90 |
--------------------------------------------------------------------------------
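A short sanity sketch (not part of the repo) showing how the helpers above compose: build two lidar-to-world poses, check the yaw/pitch/roll round trip, and map a cloud from the first frame into the second with `relative_pose` and `apply_transform`. All numeric values are arbitrary.

```python
# Sanity sketch: round-trip the Euler helpers and compose a relative transform.
import numpy as np
import torch

m1 = xyz_ypr2m(1.0, 2.0, 0.5, yaw=0.3, pitch=0.1, roll=-0.2)   # lidar1 -> world
m2 = xyz_ypr2m(1.5, 1.0, 0.4, yaw=0.1, pitch=0.0, roll=0.05)   # lidar2 -> world
assert np.allclose(m2ypr(m1), (0.3, 0.1, -0.2), atol=1e-5)

m_rel = relative_pose(m1, m2)                 # lidar1 frame -> lidar2 frame
pc_in_1 = torch.rand(100, 3)                  # random cloud in the lidar1 frame
pc_in_2 = apply_transform(pc_in_1, torch.from_numpy(m_rel))
```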
/src/models/Mixer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/Mixer/__init__.py
--------------------------------------------------------------------------------
/src/models/Mixer/mixer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
10 |
11 |
12 | class FeatureMixerLayer(nn.Module):
13 | def __init__(self, in_dim, mlp_ratio=1):
14 | super().__init__()
15 | self.mix = nn.Sequential(
16 | nn.LayerNorm(in_dim),
17 | nn.Linear(in_dim, int(in_dim * mlp_ratio)),
18 | nn.ReLU(),
19 | nn.Linear(int(in_dim * mlp_ratio), in_dim),
20 | )
21 |
22 | for m in self.modules():
23 | if isinstance(m, (nn.Linear)):
24 | nn.init.trunc_normal_(m.weight, std=0.02)
25 | if m.bias is not None:
26 | nn.init.zeros_(m.bias)
27 |
28 | def forward(self, x):
29 | return x + self.mix(x)
30 |
31 |
32 | class Mixer(nn.Module):
33 | def __init__(self,
34 | in_channels=35000,
35 | out_channels=1000,
36 | in_d=30,
37 | mix_depth=1,
38 | mlp_ratio=1,
39 | out_d=4,
40 | ) -> None:
41 | super().__init__()
42 |
43 | self.in_d = in_d
44 |
45 | self.out_d = out_d # row-wise projection dimension
46 |
47 | self.mix_depth = mix_depth # L the number of stacked FeatureMixers
48 | self.mlp_ratio = mlp_ratio # ratio of the mid projection layer in the mixer block
49 |
50 | self.mix = nn.Sequential(*[
51 | FeatureMixerLayer(in_dim=in_d, mlp_ratio=mlp_ratio)
52 | for _ in range(self.mix_depth)
53 | ])
54 | self.row_proj = nn.Linear(in_d, out_d)
55 | self.channel_proj = nn.Linear(in_channels, out_channels)
56 |
57 | def forward(self, x):
58 | # x = x.unsqueeze(0)
59 | x = self.mix(x)
60 | x = x.permute(0, 2, 1)
61 | x = self.channel_proj(x)
62 | x = x.permute(0, 2, 1)
63 | x = self.row_proj(x)
64 | x = F.normalize(x.flatten(1), p=2, dim=-1)
65 | return x
66 |
67 |
68 | # -------------------------------------------------------------------------------
69 |
70 | def print_nb_params(m):
71 | model_parameters = filter(lambda p: p.requires_grad, m.parameters())
72 | params = sum([np.prod(p.size()) for p in model_parameters])
73 | print(f'Trainable parameters: {params/1e6:.3}M')
74 |
75 |
76 | def main():
77 |
78 | model = Mixer().to('cuda')
79 | print_nb_params(model)
80 |
81 |
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
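The `Mixer` above maps a set of `in_channels` local features of dimension `in_d` to a single L2-normalized descriptor of length `out_channels * out_d` (with the defaults, 1000 × 4 = 4000). A shape-flow sketch with smaller, purely illustrative sizes:

```python
# Shape-flow sketch for Mixer with smaller, illustrative sizes.
import torch

mixer = Mixer(in_channels=4096, out_channels=256, in_d=30, out_d=4)
local_feats = torch.randn(2, 4096, 30)     # (batch, points, per-point feature dim)
desc = mixer(local_feats)                  # flattened to (batch, 256 * 4) = (2, 1024)
print(desc.shape, desc.norm(dim=-1))       # descriptors are L2-normalized
```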
/src/models/SphereFormer/SparseTransformer/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | build/
3 | dist/
4 | */__pycache__/
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/README.md:
--------------------------------------------------------------------------------
1 | # SpTr: PyTorch Spatially Sparse Transformer Library
2 |
3 | This library provides a **fast** and **memory-efficient** implementation for sparse transformer with **varying token numbers** (e.g., window transformer for 3D point cloud).
4 |
5 | This library has been used by the following works:
6 |
7 | * Spherical Transformer for LiDAR-based 3D Recognition (CVPR 2023): \[Paper\] [\[Code\]](https://github.com/dvlab-research/SphereFormer)
8 |
9 | * Stratified Transformer for 3D Point Cloud Segmentation (CVPR 2022): [\[Paper\]](https://openaccess.thecvf.com/content/CVPR2022/papers/Lai_Stratified_Transformer_for_3D_Point_Cloud_Segmentation_CVPR_2022_paper.pdf) [\[Code\]](https://github.com/dvlab-research/Stratified-Transformer)
10 |
11 | ## Installation
12 | ### Install Dependency
13 | ```
14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
15 | pip install torch_scatter
16 | pip install torch_geometric
17 | ```
18 |
19 | ### Compile sptr
20 | ```
21 | python3 setup.py install
22 | ```
23 |
24 |
25 | ## Usage
26 | SpTr can be easily used in most current transformer-based 3D point cloud networks, with only a few minor modifications. First, define the attention module `sptr.VarLengthMultiheadSA`. Then, wrap the input features and indices into `sptr.SparseTrTensor`, and forward it into the module. That's all. A simple example is given below. For more complex usage, you can refer to the code of the works above (e.g., SphereFormer, StratifiedFormer).
27 | ### Example
28 | ```
29 | import sptr
30 |
31 | # Define module
32 | dim = 48
33 | num_heads = 3
34 | indice_key = 'sptr_0'
35 | window_size = np.array([0.4, 0.4, 0.4]) # can also be integers for voxel-based methods
36 | shift_win = False # whether to adopt shifted window
37 | self.attn = sptr.VarLengthMultiheadSA(
38 | dim,
39 | num_heads,
40 | indice_key,
41 | window_size,
42 | shift_win
43 | )
44 |
45 | # Wrap the input features and indices into SparseTrTensor. Note: indices can be either integers for voxel-based methods or floats (i.e., xyz) for point-based methods
46 | # feats: [N, C], indices: [N, 4] with batch indices in the 0-th column
47 | input_tensor = sptr.SparseTrTensor(feats, indices, spatial_shape=None, batch_size=None)
48 | output_tensor = self.attn(input_tensor)
49 |
50 | # Extract features from output tensor
51 | output_feats = output_tensor.query_feats
52 | ```
53 |
54 | ## Authors
55 |
56 | Xin Lai (a Ph.D. student at CSE, CUHK) - Initial CUDA implementation, maintenance.
57 | 
58 | Fanbin Lu (a Ph.D. student at CSE, CUHK) - Improved CUDA implementation, maintenance.
59 | 
60 | Yukang Chen (a Ph.D. student at CSE, CUHK) - Maintenance.
61 |
62 |
63 | ## Cite
64 |
65 | If you find this project useful, please consider citing
66 | ```
67 | @inproceedings{lai2023spherical,
68 | title={Spherical Transformer for LiDAR-based 3D Recognition},
69 | author={Lai, Xin and Chen, Yukang and Lu, Fanbin and Liu, Jianhui and Jia, Jiaya},
70 | booktitle={CVPR},
71 | year={2023}
72 | }
73 | ```
74 | ```
75 | @inproceedings{lai2022stratified,
76 | title={Stratified transformer for 3d point cloud segmentation},
77 | author={Lai, Xin and Liu, Jianhui and Jiang, Li and Wang, Liwei and Zhao, Hengshuang and Liu, Shu and Qi, Xiaojuan and Jia, Jiaya},
78 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
79 | pages={8500--8509},
80 | year={2022}
81 | }
82 | ```
83 |
84 | ## License
85 |
86 | This project is licensed under the Apache License 2.0.
87 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/SphereFormer/SparseTransformer/__init__.py
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/setup.py:
--------------------------------------------------------------------------------
1 | #python3 setup.py install
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 | import os
5 | from distutils.sysconfig import get_config_vars
6 |
7 | (opt,) = get_config_vars('OPT')
8 | os.environ['OPT'] = " ".join(
9 | flag for flag in opt.split() if flag != '-Wstrict-prototypes'
10 | )
11 |
12 | setup(
13 | name='sptr',
14 | ext_modules=[
15 | CUDAExtension('sptr_cuda', [
16 | 'src/sptr/pointops_api.cpp',
17 | 'src/sptr/attention/attention_cuda.cpp',
18 | 'src/sptr/attention/attention_cuda_kernel.cu',
19 | 'src/sptr/precompute/precompute.cpp',
20 | 'src/sptr/precompute/precompute_cuda_kernel.cu',
21 | 'src/sptr/rpe/relative_pos_encoding_cuda.cpp',
22 | 'src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu',
23 | ],
24 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
25 | )
26 | ],
27 | cmdclass={'build_ext': BuildExtension}
28 | )
29 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/sptr/__init__.py:
--------------------------------------------------------------------------------
1 | from .functional import *
2 | from .utils import *
3 |
4 | class SparseTrTensor(object):
5 | def __init__(self, query_feats, query_indices, spatial_shape, batch_size, key_feats=None, value_feats=None, key_indices=None):
6 | """
7 | Args:
8 | query_feats: [num_points, num_features] feature tensor
9 | query_indices: [num_points, ndim + 1] index tensor. batch index saved in query_indices[:, 0]
10 | spatial_shape: spatial shape of your sparse data
11 | batch_size: batch size of your sparse data
12 | """
13 | self.query_feats = query_feats
14 | self.key_feats = key_feats
15 | self.value_feats = value_feats
16 | self.query_indices = query_indices
17 | self.key_indices = key_indices
18 | self.spatial_shape = spatial_shape
19 | self.batch_size = batch_size
20 | self.indice_dict = {}
21 |
22 | @property
23 | def spatial_size(self):
24 | return np.prod(self.spatial_shape)
25 |
26 | def find_indice_params(self, key):
27 | if key is None:
28 | return None
29 | if key in self.indice_dict:
30 | return self.indice_dict[key]
31 | return None
32 |
33 | from .modules import VarLengthMultiheadSA, sparse_self_attention
34 |
--------------------------------------------------------------------------------
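A small construction sketch for `SparseTrTensor` with point-based (float xyz) indices, matching the convention in the README example above; the shapes are illustrative.

```python
# Construction sketch for SparseTrTensor with point-based (float xyz) indices.
import torch

feats = torch.randn(1000, 48)                    # [N, C] per-point features
batch_idx = torch.zeros(1000, 1)                 # single batch element
xyz = torch.rand(1000, 3)
indices = torch.cat([batch_idx, xyz], dim=1)     # [N, 4], batch index in column 0

t = SparseTrTensor(feats, indices, spatial_shape=None, batch_size=1)
print(t.query_feats.shape, t.batch_size)
```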
/src/models/SphereFormer/SparseTransformer/sptr/position_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 |
10 |
11 | def shift_scale_points(pred_xyz, src_range, dst_range=None):
12 | """
13 | pred_xyz: B x N x 3
14 | src_range: [[B x 3], [B x 3]] - min and max XYZ coords
15 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords
16 | """
17 | if dst_range is None:
18 | dst_range = [
19 | torch.zeros((src_range[0].shape[0], 3), device=src_range[0].device),
20 | torch.ones((src_range[0].shape[0], 3), device=src_range[0].device),
21 | ]
22 |
23 | if pred_xyz.ndim == 4:
24 | src_range = [x[:, None] for x in src_range]
25 | dst_range = [x[:, None] for x in dst_range]
26 |
27 | assert src_range[0].shape[0] == pred_xyz.shape[0]
28 | assert dst_range[0].shape[0] == pred_xyz.shape[0]
29 | assert src_range[0].shape[-1] == pred_xyz.shape[-1]
30 | assert src_range[0].shape == src_range[1].shape
31 | assert dst_range[0].shape == dst_range[1].shape
32 | assert src_range[0].shape == dst_range[1].shape
33 |
34 | src_diff = src_range[1][:, None, :] - src_range[0][:, None, :]
35 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :]
36 | prop_xyz = (
37 | ((pred_xyz - src_range[0][:, None, :]) * dst_diff) / src_diff
38 | ) + dst_range[0][:, None, :]
39 | return prop_xyz
40 |
41 |
42 | class PositionEmbeddingCoordsSine(nn.Module):
43 | def __init__(
44 | self,
45 | temperature=10000,
46 | normalize=False,
47 | scale=None,
48 | pos_type="fourier",
49 | d_pos=None,
50 | d_in=3,
51 | gauss_scale=1.0
52 | ):
53 | super().__init__()
54 | self.d_pos = d_pos
55 | self.temperature = temperature
56 | self.normalize = normalize
57 | if scale is not None and normalize is False:
58 | raise ValueError("normalize should be True if scale is passed")
59 | if scale is None:
60 | scale = 2 * math.pi
61 | assert pos_type in ["sine", "fourier"]
62 | self.pos_type = pos_type
63 | self.scale = scale
64 | if pos_type == "fourier":
65 | assert d_pos is not None
66 | assert d_pos % 2 == 0
67 | # define a gaussian matrix input_ch -> output_ch
68 | B = torch.empty((d_in, d_pos // 2)).normal_()
69 | B *= gauss_scale
70 | self.register_buffer("gauss_B", B)
71 | self.d_pos = d_pos
72 |
73 | def get_sine_embeddings(self, xyz, num_channels, input_range):
74 | num_channels = self.d_pos
75 | # clone coords so that shift/scale operations do not affect original tensor
76 | orig_xyz = xyz
77 | xyz = orig_xyz.clone()
78 |
79 | ncoords = xyz.shape[1]
80 | if self.normalize:
81 | xyz = shift_scale_points(xyz, src_range=input_range)
82 |
83 | ndim = num_channels // xyz.shape[2]
84 | if ndim % 2 != 0:
85 | ndim -= 1
86 | # automatically handle remainder by assigning it to the first dim
87 | rems = num_channels - (ndim * xyz.shape[2])
88 |
89 | assert (
90 | ndim % 2 == 0
91 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}"
92 |
93 | final_embeds = []
94 | prev_dim = 0
95 |
96 | for d in range(xyz.shape[2]):
97 | cdim = ndim
98 | if rems > 0:
99 | # add remainder in increments of two to maintain even size
100 | cdim += 2
101 | rems -= 2
102 |
103 | if cdim != prev_dim:
104 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device)
105 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim)
106 |
107 | # create batch x cdim x ncoords embedding
108 | raw_pos = xyz[:, :, d]
109 | if self.scale:
110 | raw_pos *= self.scale
111 | pos = raw_pos[:, :, None] / dim_t
112 | pos = torch.stack(
113 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
114 | ).flatten(2)
115 | final_embeds.append(pos)
116 | prev_dim = cdim
117 |
118 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
119 | return final_embeds
120 |
121 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None):
122 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
123 |
124 | if num_channels is None:
125 | num_channels = self.gauss_B.shape[1] * 2
126 |
127 | bsize, npoints = xyz.shape[0], xyz.shape[1]
128 | assert num_channels > 0 and num_channels % 2 == 0
129 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1]
130 | d_out = num_channels // 2
131 | assert d_out <= max_d_out
132 | assert d_in == xyz.shape[-1]
133 |
134 | # clone coords so that shift/scale operations do not affect original tensor
135 | orig_xyz = xyz
136 | xyz = orig_xyz.clone()
137 |
138 | ncoords = xyz.shape[1]
139 | if self.normalize:
140 | xyz = shift_scale_points(xyz, src_range=input_range)
141 |
142 | xyz *= 2 * np.pi
143 | xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view(
144 | bsize, npoints, d_out
145 | )
146 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()]
147 |
148 | # return batch x d_pos x npoints embedding
149 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
150 | return final_embeds
151 |
152 | def forward(self, xyz, num_channels=None, input_range=None):
153 | assert isinstance(xyz, torch.Tensor)
154 | assert xyz.ndim == 3
155 | # xyz is batch x npoints x 3
156 | if self.pos_type == "sine":
157 | with torch.no_grad():
158 | out = self.get_sine_embeddings(xyz, num_channels, input_range)
159 | elif self.pos_type == "fourier":
160 | with torch.no_grad():
161 | out = self.get_fourier_embeddings(xyz, num_channels, input_range)
162 | else:
163 | raise ValueError(f"Unknown {self.pos_type}")
164 |
165 | return out
166 |
167 | def extra_repr(self):
168 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}"
169 | if hasattr(self, "gauss_B"):
170 | st += (
171 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}"
172 | )
173 | return st
--------------------------------------------------------------------------------
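A usage sketch for `PositionEmbeddingCoordsSine` in its "fourier" mode; the sizes are illustrative. With `normalize=True`, the caller supplies the per-batch coordinate range consumed by `shift_scale_points`.

```python
# Usage sketch for PositionEmbeddingCoordsSine in "fourier" mode (sizes illustrative).
import torch

pos_enc = PositionEmbeddingCoordsSine(pos_type="fourier", d_pos=64, normalize=True)
xyz = torch.rand(2, 1024, 3)                                    # batch x npoints x 3
input_range = [xyz.min(dim=1).values, xyz.max(dim=1).values]    # per-batch coord range
emb = pos_enc(xyz, num_channels=64, input_range=input_range)    # -> (2, 64, 1024)
```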
/src/models/SphereFormer/SparseTransformer/sptr/utils.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import torch
3 | import numpy as np
4 | from torch_scatter import segment_csr, gather_csr
5 | from torch_geometric.nn import voxel_grid
6 | from . import precompute_all
7 |
8 |
9 | def to_3d_numpy(size):
10 | if isinstance(size, numbers.Number):
11 | size = np.array([size, size, size]).astype(np.float32)
12 | elif isinstance(size, list):
13 | size = np.array(size)
14 | elif isinstance(size, np.ndarray):
15 | size = size
16 | else:
17 | raise ValueError("size is either a number, or a list, or a np.ndarray")
18 | return size
19 |
20 | def grid_sample(pos, batch, size, start, return_p2v=True, return_counts=True, return_unique=False):
21 | # pos: float [N, 3]
22 | # batch: long [N]
23 | # size: float [3, ]
24 | # start: float [3, ] / None
25 | cluster = voxel_grid(pos=pos, batch=batch, size=size, start=start) #[N, ]
26 |
27 | if not return_p2v and not return_counts:
28 | unique, cluster = torch.unique(cluster, sorted=True, return_inverse=True)
29 | return cluster
30 |
31 | unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)
32 |
33 | if not return_p2v and return_counts:
34 | return cluster, counts.max().item(), counts
35 |
36 | # obtain p2v_map
37 | n = unique.shape[0]
38 | k = counts.max().item()
39 | p2v_map = cluster.new_zeros(n, k) #[n, k]
40 | mask = torch.arange(k).cuda().unsqueeze(0) < counts.unsqueeze(-1) #[n, k]
41 | p2v_map[mask] = torch.argsort(cluster)
42 |
43 | if return_unique:
44 | return cluster, p2v_map, counts, unique
45 |
46 | return cluster, p2v_map, counts
47 |
48 | def get_indices_params(xyz, batch, window_size, shift_win: bool):
49 |
50 | if isinstance(window_size, list) or isinstance(window_size, np.ndarray):
51 | window_size = torch.from_numpy(window_size).type_as(xyz).to(xyz.device)
52 | else:
53 | window_size = torch.tensor([window_size]*3).type_as(xyz).to(xyz.device)
54 |
55 | if shift_win:
56 | v2p_map, k, counts = grid_sample(xyz+1/2*window_size, batch, window_size, start=xyz.min(0)[0], return_p2v=False, return_counts=True)
57 | else:
58 | v2p_map, k, counts = grid_sample(xyz, batch, window_size, start=None, return_p2v=False, return_counts=True)
59 |
60 | v2p_map, sort_idx = v2p_map.sort()
61 |
62 | n = counts.shape[0]
63 | N = v2p_map.shape[0]
64 |
65 | n_max = k
66 | index_0_offsets, index_1_offsets, index_0, index_1 = precompute_all(N, n, n_max, counts)
67 | index_0 = index_0.long()
68 | index_1 = index_1.long()
69 |
70 | return index_0, index_0_offsets, n_max, index_1, index_1_offsets, sort_idx
71 |
72 | def scatter_softmax_csr(src: torch.Tensor, indptr: torch.Tensor, dim: int = -1):
73 | ''' src: (N, C),
74 | index: (Ni+1, ), [0, n0^2, n0^2+n1^2, ...]
75 | '''
76 | max_value_per_index = segment_csr(src, indptr, reduce='max')
77 | max_per_src_element = gather_csr(max_value_per_index, indptr)
78 |
79 | recentered_scores = src - max_per_src_element
80 | recentered_scores_exp = recentered_scores.exp_()
81 |
82 | sum_per_index = segment_csr(
83 | recentered_scores_exp, indptr, reduce='sum')
84 |
85 | normalizing_constants = gather_csr(sum_per_index, indptr)
86 |
87 | return recentered_scores_exp.div(normalizing_constants)
88 |
--------------------------------------------------------------------------------
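A small numeric sketch (not from the repo) of what `scatter_softmax_csr` computes: a softmax restricted to CSR segments of the rows, so each segment's columns sum to one. It assumes `torch_scatter` is installed; the segment boundaries below are illustrative.

```python
# Numeric sketch: scatter_softmax_csr applies a softmax within CSR row segments.
import torch

src = torch.randn(6, 4)                 # e.g. attention logits, 6 rows
indptr = torch.tensor([0, 2, 6])        # two segments: rows 0-1 and rows 2-5
out = scatter_softmax_csr(src, indptr)

# Each segment sums to 1 column-wise.
print(out[:2].sum(0), out[2:].sum(0))
```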
/src/models/SphereFormer/SparseTransformer/src/sptr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/SphereFormer/SparseTransformer/src/sptr/__init__.py
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "attention_cuda_kernel.h"
6 |
7 |
8 | void attention_step1_forward_cuda(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor,
9 | at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor)
10 | {
11 | const float *q = q_tensor.data_ptr<float>();
12 | const float *k = k_tensor.data_ptr<float>();
13 | const int *index0 = index0_tensor.data_ptr<int>();
14 | const int *index1 = index1_tensor.data_ptr<int>();
15 | float *attn = attn_tensor.data_ptr<float>();
16 | attention_step1_forward_cuda_launcher(N_q, N_k, M, h, hdim, n_max, q, k, index0, index1, attn);
17 | }
18 |
19 | void attention_step1_backward_cuda(int N, int M, int h, int hdim, const unsigned int n_max, at::Tensor grad_out_tensor,
20 | at::Tensor index0_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor index1_tensor_offsets, at::Tensor q_tensor, at::Tensor k_tensor,
21 | at::Tensor grad_q_tensor, at::Tensor grad_k_tensor)
22 | {
23 | const float *grad_out = grad_out_tensor.data_ptr<float>();
24 | const int *index0 = index0_tensor.data_ptr<int>();
25 | const int *index0_offsets = index0_tensor_offsets.data_ptr<int>();
26 | const int *index1 = index1_tensor.data_ptr<int>();
27 | const int *index1_offsets = index1_tensor_offsets.data_ptr<int>();
28 | const float *q = q_tensor.data_ptr<float>();
29 | const float *k = k_tensor.data_ptr<float>();
30 | float *grad_q = grad_q_tensor.data_ptr<float>();
31 | float *grad_k = grad_k_tensor.data_ptr<float>();
32 | attention_step1_backward_cuda_launcher(N, M, h, hdim, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, q, k, grad_q, grad_k);
33 | }
34 |
35 | void attention_step2_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor,
36 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor output_tensor)
37 | {
38 | const float *attn = attn_tensor.data_ptr<float>();
39 | const float *v = v_tensor.data_ptr<float>();
40 | const int *index0_offsets = index0_offsets_tensor.data_ptr<int>();
41 | const int *index1 = index1_tensor.data_ptr<int>();
42 | float *output = output_tensor.data_ptr<float>();
43 | attention_step2_forward_cuda_launcher(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, output);
44 | }
45 |
46 | void attention_step2_backward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor,
47 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor,
48 | at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor)
49 | {
50 | const float *grad_out = grad_out_tensor.data_ptr<float>();
51 | const int *index0 = index0_tensor.data_ptr<int>();
52 | const int *index0_offsets = index0_offsets_tensor.data_ptr<int>();
53 | const int *index1 = index1_tensor.data_ptr<int>();
54 | const int *index1_offsets = index1_offsets_tensor.data_ptr<int>();
55 | const float *attn = attn_tensor.data_ptr<float>();
56 | const float *v = v_tensor.data_ptr<float>();
57 | float *grad_attn = grad_attn_tensor.data_ptr<float>();
58 | float *grad_v = grad_v_tensor.data_ptr<float>();
59 | attention_step2_backward_cuda_launcher(N, M, h, hdim, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, grad_attn, grad_v);
60 | }
61 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "attention_cuda_kernel.h"
3 |
4 | __global__ void attention_step1_forward_cuda_kernel( // M, h, C//h
5 | int N_q, int N_k, int M, int h, int d, const float *q, const float *k,
6 | const int *index0, const int *index1, float *attn) {
7 | // q: [N, h, d], k: [h, d, N], index0: [M], index1: [M], attn: [h, M]
8 |
9 | int h_idx = blockIdx.y;
10 | int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
11 |
12 | if(m_idx >= M) return;
13 | float s = 0.0;
14 | int index_q = index0[m_idx], index_k = index1[m_idx];
15 | for(int i = 0; i < d; i++){
16 | s += q[h_idx * d * N_q + i * N_q + index_q] * k[h_idx * d * N_k + i * N_k + index_k];
17 | }
18 | attn[h_idx * M + m_idx] = s;
19 | }
20 |
21 | void attention_step1_forward_cuda_launcher(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max,
22 | const float *q, const float *k, const int *index0, const int *index1, float *attn) {
23 | // input: attn: (h, M), index0: (M, ), index1: (M, )
24 | unsigned int n_threads = 512;
25 | dim3 blocks((M + n_threads - 1) / n_threads, h);
26 | attention_step1_forward_cuda_kernel<<<blocks, n_threads, 0>>>(N_q, N_k, M, h, hdim, q, k, index0, index1, attn);
27 | }
28 |
29 | __global__ void attention_step1_backward_cuda_kernel( // M, h, C//h
30 | int N, int M, int h, int d, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets,
31 | const float *q, const float *k, float *grad_q, float *grad_k) {
32 | // q: [N, h, d], k: [N, h, d], index0: [M], index1: [M], attn: [M, h], grad_out: [M, h]
33 | // grad_q: [N, h, hdim], grad_k: [N, h, hdim]
34 |
35 | int n_h = blockDim.x;
36 | int h_idx = blockIdx.y * n_h + threadIdx.y;
37 | int q_idx = blockIdx.x;
38 | int d_idx = threadIdx.x;
39 | int C = d * h;
40 |
41 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1];
42 | int n = end - start;
43 |
44 | float grad_q_val = 0;
45 | for(int i = 0; i < n; i++){
46 | int start_i = start + i;
47 | float grad_out_val = grad_out[start_i*h + h_idx];
48 | int k_idx = index1[start_i];
49 | grad_q_val += grad_out_val * k[k_idx*C + h_idx*d + d_idx];
50 | }
51 | grad_q[q_idx*C + h_idx*d + d_idx] = grad_q_val;
52 |
53 | float grad_k_val = 0;
54 | int start_k = index1_offsets[q_idx];
55 | for(int i = 0; i < n; i++){
56 | int start_i = start_k + i*n;
57 | float grad_out_val = grad_out[start_i*h + h_idx];
58 | int query_idx = index0[start_i];
59 | grad_k_val += grad_out_val * q[query_idx*C + h_idx*d + d_idx];
60 | }
61 | grad_k[q_idx*C + h_idx*d + d_idx] = grad_k_val;
62 | }
63 |
64 | void attention_step1_backward_cuda_launcher(int N, int M, int h, int hdim, const unsigned int n_max,
65 | const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *q, const float *k, float *grad_q, float *grad_k) {
66 | // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c)
67 |
68 | unsigned int n_h = h*hdim > 512 ? 512 / hdim : h;
69 |
70 | dim3 blocks(N, h/n_h);
71 | dim3 threads(hdim, n_h);
72 |
73 | attention_step1_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, hdim, grad_out, index0, index0_offsets, index1, index1_offsets, q, k, grad_q, grad_k);
74 |
75 | }
76 |
77 | __global__ void attention_step2_forward_cuda_kernel( // M, h, hdim
78 | int N, int M, const int h, int d, const float *attn, const float *v,
79 | const int *index0_offsets, const int *index1, float *output) {
80 | // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, 3, h, hdim), rel_idx: (M, 3)
81 |
82 | int q_idx = blockIdx.x;
83 | int n_h = blockDim.x;
84 | int h_idx = blockIdx.y * n_h + threadIdx.y;
85 | int d_idx = threadIdx.x;
86 |
87 | int C = h*d;
88 |
89 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1];
90 | int n = end - start;
91 | float sum = 0;
92 | for(int i = 0; i < n; i++){
93 | int start_i = start + i;
94 | int k_idx = index1[start_i];
95 | float v_val = v[k_idx*C + h_idx*d + d_idx];
96 | sum += attn[start_i*h + h_idx] * v_val;
97 | }
98 | output[q_idx*C + h_idx*d + d_idx] = sum;
99 | }
100 |
101 | void attention_step2_forward_cuda_launcher(int N, int M, const int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets,
102 | const int *index1, float *output) {
103 | // input: attn: (M, h), v: (N, h, hdim), index0: (M, ), index1: (M, ), table: (L, h, hdim, 3), rel_idx: (M, 3)
104 | unsigned int n_h = h*hdim > 512 ? 512 / hdim : h;
105 | dim3 blocks(N, h/n_h);
106 | dim3 threads(hdim, n_h);
107 | attention_step2_forward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, hdim, attn, v, index0_offsets, index1, output);
108 | }
109 |
110 | __global__ void attention_step2_grad_v_backward_cuda_kernel( // M, h, hdim
111 | int N, int M, int h, int hdim, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v,
112 | float *grad_v) {
113 | // input: attn: (M, h), v: (h, hdim, N), index0: (M, ), index1: (M, ), rel_idx: (3, M)
114 |
115 | int q_idx = blockIdx.x;
116 | int n_h = blockDim.x;
117 | int h_idx = blockIdx.y * n_h + threadIdx.y;
118 | int d_idx = threadIdx.x;
119 |
120 | int C = h*hdim;
121 |
122 | int start = index0_offsets[q_idx], end = index0_offsets[q_idx+1];
123 | int n = end - start;
124 | int start_k = index1_offsets[q_idx];
125 | float grad_v_val = 0;
126 |
127 | for(int i = 0; i < n; i ++){
128 | int start_i = start_k + i*n;
129 | int query_idx = index0[start_i];
130 | float grad_out_val = grad_out[query_idx*C + h_idx*hdim + d_idx];
131 |
132 | float grad_val = attn[start_i*h + h_idx] * grad_out_val;
133 | grad_v_val += grad_val;
134 | }
135 | grad_v[q_idx*C + h_idx*hdim + d_idx] = grad_v_val;
136 |
137 | }
138 |
139 | void attention_step2_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0, const int *index0_offsets,
140 | const int *index1, const int *index1_offsets, const float *attn, const float *v, float *grad_attn, float *grad_v) {
141 | // input: grad_out: (N, h, hdim)
142 |
143 | unsigned int n_h = h*hdim > 512 ? 512 / hdim : h;
144 | dim3 blocks(N, h/n_h);
145 | dim3 threads(hdim, n_h);
146 | attention_step2_grad_v_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, hdim, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, grad_v);
147 |
148 | unsigned int n_threads = 512;
149 | dim3 blocks_2((M + n_threads - 1) / n_threads, h);
150 |
151 | attention_step1_forward_cuda_kernel<<<blocks_2, n_threads, 0>>>(N, N, M, h, hdim, grad_out, v, index0, index1, grad_attn);
152 | }
153 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/attention/attention_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _ATTENTION_CUDA_KERNEL
2 | #define _ATTENTION_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void attention_step1_forward_cuda(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor index0_tensor, at::Tensor index1_tensor, at::Tensor attn_tensor);
8 | void attention_step1_backward_cuda(int N, int M, int h, int hdim, const unsigned int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_tensor_offsets, at::Tensor index1_tensor, at::Tensor index1_tensor_offsets, at::Tensor q_tensor, at::Tensor k_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor);
9 |
10 | void attention_step2_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor output_tensor);
11 | void attention_step2_backward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor);
12 |
13 | #ifdef __cplusplus
14 | extern "C" {
15 | #endif
16 |
17 | void attention_step1_forward_cuda_launcher(int N_q, int N_k, int M, int h, int hdim, const unsigned int n_max, const float *q, const float *k, const int *index0, const int *index1, float *attn);
18 | void attention_step1_backward_cuda_launcher(int N, int M, int h, int hdim, const unsigned int n_max, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *q, const float *k, float *grad_q, float *grad_k);
19 |
20 | void attention_step2_forward_cuda_launcher(int N, int M, const int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, const int *index1, float *output);
21 | void attention_step2_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v, float *grad_attn, float *grad_v);
22 |
23 | #ifdef __cplusplus
24 | }
25 | #endif
26 | #endif
27 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include
5 | #include
6 |
7 | #define TOTAL_THREADS 1024
8 | #define THREADS_PER_BLOCK 256
9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
10 |
11 | inline int opt_n_threads(int work_size) {
12 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
14 | }
15 |
16 | inline dim3 opt_block_config(int x, int y) {
17 | const int x_threads = opt_n_threads(x);
18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
19 | dim3 block_config(x_threads, y_threads, 1);
20 | return block_config;
21 | }
22 |
23 | #endif
24 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/pointops_api.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "attention/attention_cuda_kernel.h"
5 | #include "rpe/relative_pos_encoding_cuda_kernel.h"
6 | #include "precompute/precompute_cuda_kernel.h"
7 |
8 |
9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10 | m.def("attention_step1_forward_cuda", &attention_step1_forward_cuda, "attention_step1_forward_cuda");
11 | m.def("attention_step1_backward_cuda", &attention_step1_backward_cuda, "attention_step1_backward_cuda");
12 | m.def("attention_step2_forward_cuda", &attention_step2_forward_cuda, "attention_step2_forward_cuda");
13 | m.def("attention_step2_backward_cuda", &attention_step2_backward_cuda, "attention_step2_backward_cuda");
14 | m.def("precompute_all_cuda", &precompute_all_cuda, "precompute_all_cuda");
15 | m.def("dot_prod_with_idx_forward_cuda", &dot_prod_with_idx_forward_cuda, "dot_prod_with_idx_forward_cuda");
16 | m.def("dot_prod_with_idx_backward_cuda", &dot_prod_with_idx_backward_cuda, "dot_prod_with_idx_backward_cuda");
17 | m.def("attention_step2_with_rel_pos_value_forward_cuda", &attention_step2_with_rel_pos_value_forward_cuda, "attention_step2_with_rel_pos_value_forward_cuda");
18 | m.def("attention_step2_with_rel_pos_value_backward_cuda", &attention_step2_with_rel_pos_value_backward_cuda, "attention_step2_with_rel_pos_value_backward_cuda");
19 | m.def("dot_prod_with_idx_all_forward_cuda", &dot_prod_with_idx_all_forward_cuda, "dot_prod_with_idx_all_forward_cuda");
20 | }
21 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "precompute_cuda_kernel.h"
6 |
7 | void precompute_all_cuda(int N, int n, const unsigned int n_max, at::Tensor counts_tensor, at::Tensor offsets_tensor, at::Tensor sq_offsets_tensor, at::Tensor index_0_offsets_tensor, at::Tensor index_1_offsets_tensor, at::Tensor index_0_tensor, at::Tensor index_1_tensor)
8 | {
9 | const int *counts = counts_tensor.data_ptr<int>();
10 | const int *offsets = offsets_tensor.data_ptr<int>();
11 | const int *sq_offsets = sq_offsets_tensor.data_ptr<int>();
12 | int *index_0_offsets = index_0_offsets_tensor.data_ptr<int>();
13 | int *index_1_offsets = index_1_offsets_tensor.data_ptr<int>();
14 | int *index_0 = index_0_tensor.data_ptr<int>();
15 | int *index_1 = index_1_tensor.data_ptr<int>();
16 | precompute_all_cuda_launcher(N, n, n_max, counts, offsets, sq_offsets, index_0_offsets, index_1_offsets, index_0, index_1);
17 | }
18 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "precompute_cuda_kernel.h"
3 |
4 | __global__ void precompute_all_cuda_kernel( // M, h, C//h
5 | int N, int n, int k, const int *counts, const int *offsets, const int *sq_offsets, int *index0_offsets, int *index1_offsets, int *index0, int *index1) {
6 | // counts: (n), sq_offsets: (n), index0_offsets: (n), index1_offsets: (n)
7 |
8 | int n_idx = blockIdx.x;
9 | int thread_idx = threadIdx.x;
10 |
11 | int start = offsets[n_idx];
12 | int start_val = sq_offsets[n_idx];
13 | int length = counts[n_idx];
14 | for(int t_idx = thread_idx; t_idx < length; t_idx += blockDim.x){
15 | index0_offsets[start+t_idx] = start_val + length * t_idx;
16 | index1_offsets[start+t_idx] = start_val + t_idx;
17 | for(int i = 0; i < length; i++){
18 | index0[start_val + i*length + t_idx] = start+i;
19 | index1[start_val + i*length + t_idx] = start+t_idx;
20 | }
21 | }
22 | }
23 |
24 | void precompute_all_cuda_launcher(int N, int n, const unsigned int n_max, const int *counts,
25 | const int *offsets, const int *sq_offsets, int *index_0_offsets, int *index_1_offsets, int *index_0, int *index_1) {
26 | // input: attn: (M, h), index0: (M, ), index1: (M, )
27 |
28 | unsigned int blocks = n;
29 | unsigned int n_threads = opt_n_threads(n_max);
30 | n_threads = n_threads == n_max ? n_threads : n_threads * 2;
31 | n_threads = n_threads > 1024 ? 1024 : n_threads;
32 |
33 | precompute_all_cuda_kernel<<<blocks, n_threads, 0>>>(N, n, n_max, counts, offsets, sq_offsets, index_0_offsets, index_1_offsets, index_0, index_1);
34 |
35 | }
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/precompute/precompute_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef PRECOMPUTE_CUDA_KERNEL
2 | #define PRECOMPUTE_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void precompute_all_cuda(int N, int n, const unsigned int n_max, at::Tensor counts_tensor, at::Tensor offsets_tensor, at::Tensor sq_offsets_tensor, at::Tensor index_0_offsets_tensor, at::Tensor index_1_offsets_tensor, at::Tensor index_0, at::Tensor index_1);
8 |
9 | #ifdef __cplusplus
10 | extern "C" {
11 | #endif
12 |
13 | void precompute_all_cuda_launcher(int N, int n, const unsigned int n_max, const int *counts, const int *offsets, const int *sq_offsets, int *index_0_offsets, int *index_1_offsets, int *index_0, int *index_1);
14 |
15 | #ifdef __cplusplus
16 | }
17 | #endif
18 | #endif
19 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/rpe/relative_pos_encoding_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "relative_pos_encoding_cuda_kernel.h"
6 |
7 | void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor,
8 | at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor,
9 | at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor)
10 | {
11 | const float *q = q_tensor.data_ptr<float>();
12 | const int *index_q = index_q_tensor.data_ptr<int>();
13 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr<int>();
14 | const float *k = k_tensor.data_ptr<float>();
15 | const int *index_k = index_k_tensor.data_ptr<int>();
16 | const float *table_q = table_q_tensor.data_ptr<float>();
17 | const float *table_k = table_k_tensor.data_ptr<float>();
18 | const int *rel_idx = rel_idx_tensor.data_ptr<int>();
19 | float *output = output_tensor.data_ptr<float>();
20 | dot_prod_with_idx_forward_cuda_launcher(N, M, h, hdim, n_max, L, q, index_q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output);
21 | }
22 |
23 | void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor grad_out_tensor,
24 | at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_offsets_tensor, at::Tensor index_k_tensor,
25 | at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor,
26 | at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor)
27 | {
28 | const float *grad_out = grad_out_tensor.data_ptr<float>();
29 | const float *q = q_tensor.data_ptr<float>();
30 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr<int>();
31 | const float *k = k_tensor.data_ptr<float>();
32 | const int *index_k_offsets = index_k_offsets_tensor.data_ptr<int>();
33 | const int *index_k = index_k_tensor.data_ptr<int>();
34 | const float *table_q = table_q_tensor.data_ptr<float>();
35 | const float *table_k = table_k_tensor.data_ptr<float>();
36 | const int *rel_idx = rel_idx_tensor.data_ptr<int>();
37 | float *grad_q = grad_q_tensor.data_ptr<float>();
38 | float *grad_k = grad_k_tensor.data_ptr<float>();
39 | float *grad_table_q = grad_table_q_tensor.data_ptr<float>();
40 | float *grad_table_k = grad_table_k_tensor.data_ptr<float>();
41 | dot_prod_with_idx_backward_cuda_launcher(N, M, h, hdim, n_max, L, grad_out, q, index_q_offsets, k, index_k_offsets, index_k, table_q, table_k, rel_idx, grad_q, grad_k, grad_table_q, grad_table_k);
42 | }
43 |
44 | void dot_prod_with_idx_all_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor,
45 | at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor,
46 | at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor)
47 | {
48 | const float *q = q_tensor.data_ptr<float>();
49 | const int *index_q = index_q_tensor.data_ptr<int>();
50 | const int *index_q_offsets = index_q_offsets_tensor.data_ptr<int>();
51 | const float *k = k_tensor.data_ptr<float>();
52 | const int *index_k = index_k_tensor.data_ptr<int>();
53 | const float *table_q = table_q_tensor.data_ptr<float>();
54 | const float *table_k = table_k_tensor.data_ptr<float>();
55 | const int *rel_idx = rel_idx_tensor.data_ptr<int>();
56 | float *output = output_tensor.data_ptr<float>();
57 | dot_prod_with_idx_all_forward_cuda_launcher(N, M, h, hdim, n_max, L, q, index_q, index_q_offsets, k, index_k, table_q, table_k, rel_idx, output);
58 | }
59 |
60 | void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor,
61 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor)
62 | {
63 | const float *attn = attn_tensor.data_ptr<float>();
64 | const float *v = v_tensor.data_ptr<float>();
65 | const int *index0_offsets = index0_offsets_tensor.data_ptr<int>();
66 | const int *index1 = index1_tensor.data_ptr<int>();
67 | const float *table = table_tensor.data_ptr<float>();
68 | const int *rel_idx = rel_idx_tensor.data_ptr<int>();
69 | float *output = output_tensor.data_ptr<float>();
70 | attention_step2_with_rel_pos_value_forward_cuda_launcher(N, M, h, hdim, n_max, attn, v, index0_offsets, index1, table, rel_idx, output);
71 | }
72 |
73 | void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, int L, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor,
74 | at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor,
75 | at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor)
76 | {
77 | const float *grad_out = grad_out_tensor.data_ptr<float>();
78 | const int *index0 = index0_tensor.data_ptr<int>();
79 | const int *index0_offsets = index0_offsets_tensor.data_ptr<int>();
80 | const int *index1 = index1_tensor.data_ptr<int>();
81 | const int *index1_offsets = index1_offsets_tensor.data_ptr<int>();
82 | const float *attn = attn_tensor.data_ptr<float>();
83 | const float *v = v_tensor.data_ptr<float>();
84 | const float *table = table_tensor.data_ptr<float>();
85 | const int *rel_idx = rel_idx_tensor.data_ptr<int>();
86 | float *grad_attn = grad_attn_tensor.data_ptr<float>();
87 | float *grad_v = grad_v_tensor.data_ptr<float>();
88 | float *grad_table = grad_table_tensor.data_ptr<float>();
89 | attention_step2_with_rel_pos_value_backward_cuda_launcher(N, M, h, hdim, L, n_max, grad_out, index0, index0_offsets, index1, index1_offsets, attn, v, table, rel_idx, grad_attn, grad_v, grad_table);
90 | }
91 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/src/sptr/rpe/relative_pos_encoding_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _RPE_CUDA_KERNEL
2 | #define _RPE_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void dot_prod_with_idx_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor);
8 | void dot_prod_with_idx_backward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor grad_out_tensor, at::Tensor q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_offsets_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_q_tensor, at::Tensor grad_k_tensor, at::Tensor grad_table_q_tensor, at::Tensor grad_table_k_tensor);
9 |
10 | void dot_prod_with_idx_all_forward_cuda(int N, int M, int h, int hdim, int n_max, const int L, at::Tensor q_tensor, at::Tensor index_q_tensor, at::Tensor index_q_offsets_tensor, at::Tensor k_tensor, at::Tensor index_k_tensor, at::Tensor table_q_tensor, at::Tensor table_k_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor);
11 |
12 | void attention_step2_with_rel_pos_value_forward_cuda(int N, int M, int h, int hdim, int n_max, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor output_tensor);
13 | void attention_step2_with_rel_pos_value_backward_cuda(int N, int M, int h, int hdim, int L, int n_max, at::Tensor grad_out_tensor, at::Tensor index0_tensor, at::Tensor index0_offsets_tensor, at::Tensor index1_tensor, at::Tensor index1_offsets_tensor, at::Tensor attn_tensor, at::Tensor v_tensor, at::Tensor table_tensor, at::Tensor rel_idx_tensor, at::Tensor grad_attn_tensor, at::Tensor grad_v_tensor, at::Tensor grad_table_tensor);
14 |
15 | #ifdef __cplusplus
16 | extern "C" {
17 | #endif
18 |
19 | void dot_prod_with_idx_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *q, const int *index_q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *output);
20 | void dot_prod_with_idx_backward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *grad_out, const float *q, const int *index_q_offsets, const float *k, const int *index_k_offsets, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *grad_q, float *grad_k, float *grad_table_q, float *grad_table_k);
21 |
22 | void dot_prod_with_idx_all_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const int L, const float *q, const int *index_q, const int *index_q_offsets, const float *k, const int *index_k, const float *table_q, const float *table_k, const int *rel_idx, float *output);
23 |
24 | void attention_step2_with_rel_pos_value_forward_cuda_launcher(int N, int M, int h, int hdim, int n_max, const float *attn, const float *v, const int *index0_offsets, const int *index1, const float *table, const int *rel_idx, float *output);
25 | void attention_step2_with_rel_pos_value_backward_cuda_launcher(int N, int M, int h, int hdim, int L, int n_max, const float *grad_out, const int *index0, const int *index0_offsets, const int *index1, const int *index1_offsets, const float *attn, const float *v, const float *table, const int *rel_idx, float *grad_attn, float *grad_v, float *grad_table);
26 |
27 |
28 | #ifdef __cplusplus
29 | }
30 | #endif
31 | #endif
32 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/test/test_attention_op_step1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum
4 | import sys
5 | sys.path.append("..")
6 | import sptr
7 |
8 | torch.manual_seed(1)
9 |
10 | # M = 800000
11 | N = 35000
12 | n = 1500
13 | k = 100
14 | C = 96
15 | h = 6
16 | query = torch.rand(N, h, C//h).cuda()
17 | key = torch.rand(N, h, C//h).cuda()
18 |
19 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda()
20 | v2p_map, _ = v2p_map.sort()
21 | counts = v2p_map.bincount()
22 | M = (counts**2).sum().item()
23 |
24 | N = v2p_map.shape[0]
25 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
26 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
27 | v2p_map = v2p_map.long()
28 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map)
29 | p2v_map[mask] = torch.arange(N).cuda()
30 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
31 |
32 | index_0_counts = counts[v2p_map.long()] #[N, ]
33 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ]
34 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
35 | n_max = p2v_map.shape[1]
36 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets)
37 | index_0 = index_0.long()
38 | index_1 = index_1.long()
39 |
40 | # index_0 = torch.rand(M)
41 | # index_0[index_0 < 0] = 0
42 | # index_0 = (index_0*N).long().cuda()
43 |
44 | # index_1 = torch.rand(M)
45 | # index_1[index_1 < 0] = 0
46 | # index_1 = (index_1*N).long().cuda()
47 |
48 | query.requires_grad = True
49 | key.requires_grad = True
50 |
51 | # # rearrange index for acceleration
52 | # index_0, indices = torch.sort(index_0) #[M,]
53 | # index_1 = index_1[indices] #[M,]
54 | # index_0_counts = index_0.bincount()
55 | # n_max = index_0_counts.max()
56 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N]
57 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
58 |
59 | attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int())
60 | loss = attn_flat.sum()
61 | loss.backward()
62 | print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10]))
63 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
64 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
65 | # input()
66 |
67 | query_grad = query.grad.clone()
68 | key_grad = key.grad.clone()
69 |
70 | query.grad.zero_()
71 | key.grad.zero_()
72 |
73 | # # print("index_0[:100]: ", index_0[:100])
74 | # print("n_max: ", n_max)
75 | # print("index_0_offsets.shape: ", index_0_offsets.shape)
76 | # # input()
77 |
78 | # print("index_0_offsets[:100]: ", index_0_offsets[:100])
79 | # print("index_1[:20]: ", index_1[:20])
80 |
81 |
82 | # attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int())
83 | # loss = attn_flat.sum()
84 | # loss.backward()
85 | # # attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int())
86 | # # loss = attn_flat.sum()
87 | # # loss.backward()
88 | # print("attn_flat.shape: {}, attn_flat[:20,:10]: {}".format(attn_flat.shape, attn_flat[:20,:10]))
89 | # print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
90 | # print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
91 | # input()
92 |
93 | print("query.is_contiguous(): ", query.is_contiguous())
94 | print("key.is_contiguous(): ", key.is_contiguous())
95 | print("index_0.is_contiguous(): ", index_0.is_contiguous())
96 | print("index_1.is_contiguous(): ", index_1.is_contiguous())
97 |
98 | attn_flat_v2 = sptr.attention_step1(query.float(), key.float(), index_0.int(), index_0_offsets.int(), index_1.int(), ctg_index_1_offsets.int(), n_max)
99 | loss = attn_flat_v2.sum()
100 | loss.backward()
101 |
102 | # attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(), n_max)
103 | # loss = attn_flat_v2.sum()
104 | # loss.backward()
105 |
106 | print("attn_flat_v2.shape: {}, attn_flat_v2[:20,:10]: {}".format(attn_flat_v2.shape, attn_flat_v2[:20,:10]))
107 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
108 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
109 | # input()
110 |
111 | # mask = attn_flat_v2.sum(-1) != 0
112 | # print("mask.sum(): ", mask.sum())
113 | # print("attn_flat_v2[mask] - attn_flat[mask]: ", ((attn_flat_v2[mask] - attn_flat[mask])**2).max())
114 |
115 |
116 | print("((attn_flat-attn_flat_v2).abs()).max(): ", ((attn_flat-attn_flat_v2).abs()).max())
117 |
118 | print("(query.grad-query_grad).abs().max(): ", (query.grad-query_grad).abs().max())
119 | print("(key.grad-key_grad).abs().max(): ", (key.grad-key_grad).abs().max())
120 |
121 | # selected = 10000
122 | # print("torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0): ", torch.max((attn_flat[:selected]-attn_flat_v2[:selected])**2, 0))
123 |
124 |
--------------------------------------------------------------------------------
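Note (not part of the repository): the test above compares the `sptr` CUDA kernel against `pointops.attention_step1`. For readers without the extensions built, a minimal pure-PyTorch sketch of what that operation is assumed to compute, namely one query/key dot product per (index_0, index_1) pair and per head:

```python
import torch

def attention_step1_reference(query, key, index_0, index_1):
    # query, key: (N, h, hdim); index_0, index_1: (M,) point indices.
    # Gather the query/key rows for every pair and reduce over channels,
    # giving one attention logit per pair and per head: (M, h).
    q = query[index_0.long()]          # (M, h, hdim)
    k = key[index_1.long()]            # (M, h, hdim)
    return (q * k).sum(-1)

# Toy shapes for illustration only.
N, h, hdim, M = 100, 6, 16, 400
query, key = torch.rand(N, h, hdim), torch.rand(N, h, hdim)
index_0, index_1 = torch.randint(0, N, (M,)), torch.randint(0, N, (M,))
print(attention_step1_reference(query, key, index_0, index_1).shape)  # torch.Size([400, 6])
```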
/src/models/SphereFormer/SparseTransformer/test/test_attention_op_step2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum
4 | import sys
5 | sys.path.append("..")
6 | import sptr
7 |
8 | torch.manual_seed(2)
9 |
10 | # M = 80000
11 | N = 3500 #10 #5 #10 #3500
12 | n = 150 #2 #150
13 | # k = 5 #7 #65
14 | hdim = 16
15 | h = 6 #1 #6
16 | L = 31 #2 #31
17 | v = torch.rand(N, h, hdim).cuda()
18 | # table = torch.rand(L, h, hdim, 3).cuda()
19 |
20 | # index_0 = torch.rand(M)
21 | # index_0[index_0 < 0] = 0
22 | # index_0 = (index_0*N).long().cuda()
23 |
24 | # index_1 = torch.rand(M)
25 | # index_1[index_1 < 0] = 0
26 | # index_1 = (index_1*N).long().cuda()
27 |
28 |
29 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda()
30 | v2p_map, _ = v2p_map.sort()
31 | counts = v2p_map.bincount()
32 | M = (counts**2).sum().item()
33 | k = counts.max().item()
34 |
35 | print("counts: ", counts)
36 |
37 | attn = torch.rand(M, h).cuda()
38 |
39 | # v2p_map, ctg_sort_idx = v2p_map.sort()
40 | # n, k = p2v_map.shape
41 | # N = v2p_map.shape[0]
42 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
43 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
44 | v2p_map = v2p_map.long()
45 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map)
46 | p2v_map[mask] = torch.arange(N).cuda()
47 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
48 |
49 | print("M: ", M)
50 | # print("counts[:5]: {}, v2p_map[:5]: {}, p2v_map[:5]: {}".format(counts[:5], v2p_map[:5], p2v_map[:5]))
51 | # print("ctg_index_1_offsets[:50]: ", ctg_index_1_offsets[:50])
52 |
53 | # print("ctg_index_1_offsets.max(): {}, ctg_index_1_offsets.min(): {}".format(ctg_index_1_offsets.max(), ctg_index_1_offsets.min()))
54 | print("ctg_index_1_offsets: ", ctg_index_1_offsets)
55 |
56 | # rel_index = torch.rand(M, 3)
57 | # rel_index[rel_index < 0] = 0
58 | # rel_index = (rel_index*L).long().cuda()
59 |
60 | # # print("rel_index.min(): {}, rel_index.max(): {}".format(
61 | # # rel_index.min(), rel_index.max()
62 | # # ))
63 |
64 | # print("rel_index: ", rel_index)
65 |
66 | index_0_counts = counts[v2p_map.long()] #[N, ]
67 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ]
68 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
69 | n_max = p2v_map.shape[1]
70 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets)
71 | index_0 = index_0.long()
72 | index_1 = index_1.long()
73 |
74 | print("index_0: {}".format(index_0))
75 | print("index_1: {}".format(index_1))
76 |
77 | # print("index_0.max(): {}, index_0.min(): {}".format(index_0.max(), index_0.min()))
78 |
79 | # print("index_1.max(): {}, index_1.min(): {}".format(index_1.max(), index_1.min()))
80 |
81 | # input()
82 |
83 | # # rearrange index for acceleration
84 | # index_0, indices = torch.sort(index_0) #[M,]
85 | # index_1 = index_1[indices] #[M,]
86 | # rel_index = rel_index[indices]
87 | # index_0_counts = index_0.bincount()
88 |
89 | print("index_0_counts.shape: ", index_0_counts.shape)
90 |
91 | # n_max = index_0_counts.max()
92 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N]
93 |
94 | # print("v1 index_0_offsets.shape: ", index_0_offsets.shape)
95 |
96 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
97 |
98 |
99 | attn.requires_grad = True
100 | v.requires_grad = True
101 | # table.requires_grad = True
102 |
103 |
104 | # output = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int())
105 | output = pointops.attention_step2(attn, v, index_0.int(), index_1.int())
106 | loss = output.mean()
107 | loss.backward()
108 |
109 | # print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5]))
110 | # print("attn.grad[:5, :3]: ", attn.grad[:5, :3])
111 | # print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5])
112 | # print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2])
113 | # # input()
114 |
115 | attn_grad = attn.grad.clone()
116 | v_grad = v.grad.clone()
117 | # table_grad = table.grad.clone()
118 |
119 | attn.grad.zero_()
120 | v.grad.zero_()
121 | # table.grad.zero_()
122 |
123 | # print("query.is_contiguous(): ", query.is_contiguous())
124 | # print("key.is_contiguous(): ", key.is_contiguous())
125 | # print("index_0.is_contiguous(): ", index_0.is_contiguous())
126 | # print("index_1.is_contiguous(): ", index_1.is_contiguous())
127 |
128 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v7(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int(), table, rel_index.int())
129 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v4(attn, v, index_0_offsets.int(), n_max, index_1.int(), table, rel_index.int())
130 | output_v2 = sptr.attention_step2(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int())
131 | loss = output_v2.mean()
132 | loss.backward()
133 |
134 | # print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5]))
135 | # print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3])
136 | # print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5])
137 | # print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2])
138 | # # input()
139 |
140 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max())
141 |
142 | print("((attn_grad-attn.grad)**2).max(): ", ((attn_grad-attn.grad)**2).max())
143 |
144 | print("((v_grad-v.grad)**2).max(): ", ((v_grad-v.grad)**2).max())
145 |
146 | # print("((table_grad-table.grad)**2).max(): ", ((table_grad-table.grad)**2).max())
147 |
148 | # print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2))
149 |
150 |
--------------------------------------------------------------------------------
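Note (not part of the repository): a companion sketch for the second stage exercised above, assuming `attention_step2` scatters attention-weighted value rows back onto the query points:

```python
import torch

def attention_step2_reference(attn, v, index_0, index_1):
    # attn: (M, h) per-pair weights; v: (N, h, hdim);
    # index_0, index_1: (M,) query/key point indices per pair.
    # Accumulate attn[m, head] * v[index_1[m], head, :] into row index_0[m].
    contrib = attn.unsqueeze(-1) * v[index_1.long()]   # (M, h, hdim)
    out = torch.zeros_like(v)
    out.index_add_(0, index_0.long(), contrib)
    return out                                         # (N, h, hdim)

attn = torch.rand(400, 6)
v = torch.rand(100, 6, 16)
index_0, index_1 = torch.randint(0, 100, (400,)), torch.randint(0, 100, (400,))
print(attention_step2_reference(attn, v, index_0, index_1).shape)  # torch.Size([100, 6, 16])
```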
/src/models/SphereFormer/SparseTransformer/test/test_precompute_all.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | import sys
4 | sys.path.append("..")
5 | import sptr
6 |
7 | torch.manual_seed(1)
8 |
9 | v2p_map = torch.IntTensor([
10 | 1, 0, 0, 2, 0, 2, 2, 1, 2, 2, 2
11 | ]).cuda()
12 |
13 | p2v_map = torch.IntTensor([
14 | [1, 2, 4, 0, 0, 0],
15 | [0, 7, 0, 0, 0, 0],
16 | [5, 6, 3, 9, 8, 10]
17 | ]).cuda()
18 |
19 | counts = torch.IntTensor([3, 2, 6]).cuda()
20 |
21 | # index_0_counts = counts[v2p_map.long()] #[N, ]
22 | # index_0_offsets = index_0_counts.cumsum(-1) #[N, ]
23 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
24 |
25 | # N = v2p_map.shape[0]
26 |
27 | print("v2p_map.shape: {}, p2v_map.shape: {}".format(v2p_map.shape, p2v_map.shape))
28 |
29 | v2p_map, ctg_sort_idx = v2p_map.sort()
30 | n, k = p2v_map.shape
31 | N = v2p_map.shape[0]
32 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
33 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
34 | v2p_map = v2p_map.long()
35 | p2v_map = torch.zeros_like(p2v_map)
36 | p2v_map[mask] = torch.arange(N).int().cuda()
37 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
38 |
39 | index_params = (ctg_index_1_offsets, ctg_sort_idx)
40 | index_0_counts = counts[v2p_map.long()] #[N, ]
41 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ]
42 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
43 | n_max = p2v_map.shape[1]
44 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets)
45 | index_0 = index_0.long()
46 | index_1 = index_1.long()
47 |
48 |
49 | # index_0_offsets_, index_1_offsets_ = pointops.precompute_offsets(N, p2v_map.shape[0], p2v_map.shape[1], counts)
50 |
51 | # # assert (index_0_offsets_ == index_0_offsets).all()
52 | # print("index_0_offsets: ", index_0_offsets)
53 | # print("index_0_offsets_: ", index_0_offsets_)
54 | # print("ctg_index_1_offsets: ", ctg_index_1_offsets)
55 | # print("index_1_offsets_: ", index_1_offsets_)
56 |
57 | # index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets)
58 |
59 | # print("index_0: ", index_0)
60 | # print("index_1: ", index_1)
61 |
62 | # print("index_0.shape: ", index_0.shape)
63 |
64 |
65 | index_0_offsets_, index_1_offsets_, index_0_, index_1_ = sptr.precompute_all(N, p2v_map.shape[0], p2v_map.shape[1], counts)
66 |
67 | assert (index_0_offsets_ == index_0_offsets).all()
68 | assert (index_1_offsets_ == ctg_index_1_offsets).all()
69 | assert (index_0_ == index_0).all()
70 | assert (index_1_ == index_1).all()
71 |
72 | print("index_0_offsets: ", index_0_offsets)
73 | print("index_0_offsets_: ", index_0_offsets_)
74 | print("ctg_index_1_offsets: ", ctg_index_1_offsets)
75 | print("index_1_offsets_: ", index_1_offsets_)
76 |
77 | print("index_0_offsets.shape: {}, index_0_offsets_.shape: {}".format(index_0_offsets.shape, index_0_offsets_.shape))
78 |
79 | print("ctg_index_1_offsets.shape: {}, index_1_offsets_.shape: {}".format(ctg_index_1_offsets.shape, index_1_offsets_.shape))
80 |
81 | # print("index_0: ", index_0)
82 | # print("index_0_: ", index_0_)
83 | # print("index_1: ", index_1)
84 | # print("index_1_: ", index_1_)
85 |
86 | # print("index_0.shape: ", index_0.shape)
87 |
--------------------------------------------------------------------------------
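Note (not part of the repository): the toy tensors above make the pair-enumeration semantics easy to state. Assuming `precompute_index_pairs` / `precompute_all` emit every ordered pair of points that fall into the same voxel (so M equals the sum of squared counts; the exact ordering shown here is illustrative), a plain-Python sketch is:

```python
import torch

def precompute_index_pairs_reference(p2v_map, counts):
    # p2v_map: (n_voxels, k) padded point indices per voxel; counts: (n_voxels,).
    # Emit all c*c ordered (query, key) point pairs inside each voxel.
    index_0, index_1 = [], []
    for voxel, c in enumerate(counts.tolist()):
        pts = p2v_map[voxel, :c].tolist()
        for i in pts:
            for j in pts:
                index_0.append(i)
                index_1.append(j)
    return torch.tensor(index_0), torch.tensor(index_1)

counts = torch.tensor([3, 2, 6])
p2v_map = torch.tensor([[0, 1, 2, 0, 0, 0],
                        [3, 4, 0, 0, 0, 0],
                        [5, 6, 7, 8, 9, 10]])
i0, i1 = precompute_index_pairs_reference(p2v_map, counts)
print(i0.shape)  # torch.Size([49]) == 3*3 + 2*2 + 6*6
```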
/src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum
4 | import sys
5 | sys.path.append("..")
6 | import sptr
7 |
8 | torch.manual_seed(1)
9 |
10 | # M = 80000
11 | N = 3500
12 | n = 150
13 | k = 65
14 | # M = 80
15 | # N = 5
16 | hdim = 16
17 | h = 6
18 | L = 31
19 | query = torch.rand(N, h, hdim).cuda()
20 | table_q = torch.rand(L, h, hdim, 3).cuda()
21 | key = torch.rand(N, h, hdim).cuda()
22 | table_k = torch.rand(L, h, hdim, 3).cuda()
23 |
24 | # index_q = torch.rand(M)
25 | # index_q[index_q < 0] = 0
26 | # index_q = (index_q*N).long().cuda()
27 |
28 | # index_k = torch.rand(M)
29 | # index_k[index_k < 0] = 0
30 | # index_k = (index_k*N).long().cuda()
31 |
32 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda()
33 | v2p_map, _ = v2p_map.sort()
34 | counts = v2p_map.bincount()
35 | M = (counts**2).sum().item()
36 |
37 | # v2p_map, ctg_sort_idx = v2p_map.sort()
38 | # n, k = p2v_map.shape
39 | N = v2p_map.shape[0]
40 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
41 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
42 | v2p_map = v2p_map.long()
43 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map)
44 | p2v_map[mask] = torch.arange(N).cuda()
45 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
46 |
47 | rel_index = torch.rand(M, 3)
48 | rel_index[rel_index < 0] = 0
49 | rel_index = (rel_index*L).long().cuda()
50 |
51 | index_q_counts = counts[v2p_map.long()] #[N, ]
52 | index_q_offsets = index_q_counts.cumsum(-1) #[N, ]
53 | index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1]
54 | n_max = p2v_map.shape[1]
55 | index_q, index_k = pointops.precompute_index_pairs(p2v_map, counts, index_q_offsets)
56 | index_q = index_q.long()
57 | index_k = index_k.long()
58 |
59 | # # rearrange index for acceleration
60 | # index_q, indices = torch.sort(index_q) #[M,]
61 | # index_k = index_k[indices] #[M,]
62 | # rel_index = rel_index[indices]
63 | # index_q_counts = index_q.bincount()
64 |
65 | # print("index_q_counts.shape: ", index_q_counts.shape)
66 |
67 | # n_max = index_q_counts.max()
68 | # index_q_offsets = index_q_counts.cumsum(dim=-1) #[N]
69 |
70 | # print("v1 index_q_offsets.shape: ", index_q_offsets.shape)
71 |
72 | # index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1]
73 |
74 | # print("index_q[:100]: ", index_q[:100])
75 | print("n_max: ", n_max)
76 | print("index_q_offsets.shape: ", index_q_offsets.shape)
77 | # input()
78 |
79 | print("index_q_offsets[:100]: ", index_q_offsets[:100])
80 | print("index_k[:20]: ", index_k[:20])
81 |
82 | query.requires_grad = True
83 | table_q.requires_grad = True
84 | key.requires_grad = True
85 | table_k.requires_grad = True
86 |
87 | output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int())
88 | output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int())
89 | output = output1 + output2
90 | loss = output.mean()
91 | loss.backward()
92 |
93 | print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10]))
94 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
95 | print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2])
96 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
97 | print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2])
98 | # input()
99 |
100 | query_grad = query.grad.clone()
101 | key_grad = key.grad.clone()
102 | table_q_grad = table_q.grad.clone()
103 | table_k_grad = table_k.grad.clone()
104 |
105 | query.grad.zero_()
106 | key.grad.zero_()
107 | table_q.grad.zero_()
108 | table_k.grad.zero_()
109 |
110 |
111 | # print("query.is_contiguous(): ", query.is_contiguous())
112 | # print("key.is_contiguous(): ", key.is_contiguous())
113 | # print("index_q.is_contiguous(): ", index_q.is_contiguous())
114 | # print("index_k.is_contiguous(): ", index_k.is_contiguous())
115 | table_q = table_q.detach().permute(0,3,1,2).contiguous()
116 | table_k = table_k.detach().permute(0,3,1,2).contiguous()
117 | table_q.requires_grad = True
118 | table_k.requires_grad = True
119 | output_v2 = sptr.dot_prod_with_idx(query, index_q.int(), index_q_offsets.int(), n_max, key, ctg_index_1_offsets.int(), index_k.int(), table_q, table_k, rel_index.int())
120 | # output_v2 = pointops.dot_prod_with_idx_v5(query, index_q_offsets.int(), n_max, key, index_k.int(), table_q, table_k, rel_index.int())
121 | loss = output_v2.mean()
122 | loss.backward()
123 |
124 | table_q_grad2 = table_q.grad.clone().permute(0,2,3,1).contiguous()
125 | table_k_grad2 = table_k.grad.clone().permute(0,2,3,1).contiguous()
126 |
127 | print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10]))
128 | print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
129 | print("v2 table_q_grad2[:5, :3, :5, :2]: ", table_q_grad2[:5, :3, :5, :2])
130 | print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
131 | print("v2 table_k_grad2[:5, :3, :5, :2]: ", table_k_grad2[:5, :3, :5, :2])
132 | # input()
133 |
134 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max())
135 |
136 | print("((query.grad-query_grad)**2).max(): ", ((query.grad-query_grad)**2).max())
137 |
138 | print("((key.grad-key_grad)**2).max(): ", ((key.grad-key_grad)**2).max())
139 |
140 | print("((table_q_grad2-table_q_grad)**2).max(): ", ((table_q_grad2-table_q_grad)**2).max())
141 |
142 | print("((table_k_grad2-table_k_grad)**2).max(): ", ((table_k_grad2-table_k_grad)**2).max())
143 |
144 |
145 |
146 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step1_all.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum
4 | import sys
5 | sys.path.append("..")
6 | import sptr
7 |
8 | torch.manual_seed(1)
9 |
10 | # M = 80000
11 | N = 3500
12 | n = 150
13 | k = 65
14 | # M = 80
15 | # N = 5
16 | hdim = 16
17 | h = 6
18 | L = 31
19 | query = torch.rand(N, h, hdim).cuda()
20 | table_q = torch.rand(L, h, hdim, 3).cuda()
21 | key = torch.rand(N, h, hdim).cuda()
22 | table_k = torch.rand(L, h, hdim, 3).cuda()
23 |
24 | # index_q = torch.rand(M)
25 | # index_q[index_q < 0] = 0
26 | # index_q = (index_q*N).long().cuda()
27 |
28 | # index_k = torch.rand(M)
29 | # index_k[index_k < 0] = 0
30 | # index_k = (index_k*N).long().cuda()
31 |
32 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda()
33 | v2p_map, _ = v2p_map.sort()
34 | counts = v2p_map.bincount()
35 | M = (counts**2).sum().item()
36 |
37 | # v2p_map, ctg_sort_idx = v2p_map.sort()
38 | # n, k = p2v_map.shape
39 | N = v2p_map.shape[0]
40 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
41 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
42 | v2p_map = v2p_map.long()
43 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map)
44 | p2v_map[mask] = torch.arange(N).cuda()
45 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
46 |
47 | rel_index = torch.rand(M, 3)
48 | rel_index[rel_index < 0] = 0
49 | rel_index = (rel_index*L).long().cuda()
50 |
51 | index_q_counts = counts[v2p_map.long()] #[N, ]
52 | index_q_offsets = index_q_counts.cumsum(-1) #[N, ]
53 | index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1]
54 | n_max = p2v_map.shape[1]
55 | index_q, index_k = pointops.precompute_index_pairs(p2v_map, counts, index_q_offsets)
56 | index_q = index_q.long()
57 | index_k = index_k.long()
58 |
59 | # # rearrange index for acceleration
60 | # index_q, indices = torch.sort(index_q) #[M,]
61 | # index_k = index_k[indices] #[M,]
62 | # rel_index = rel_index[indices]
63 | # index_q_counts = index_q.bincount()
64 |
65 | # print("index_q_counts.shape: ", index_q_counts.shape)
66 |
67 | # n_max = index_q_counts.max()
68 | # index_q_offsets = index_q_counts.cumsum(dim=-1) #[N]
69 |
70 | # print("v1 index_q_offsets.shape: ", index_q_offsets.shape)
71 |
72 | # index_q_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_q_offsets], 0) #[N+1]
73 |
74 | # print("index_q[:100]: ", index_q[:100])
75 | print("n_max: ", n_max)
76 | print("index_q_offsets.shape: ", index_q_offsets.shape)
77 | # input()
78 |
79 | print("index_q_offsets[:100]: ", index_q_offsets[:100])
80 | print("index_k[:20]: ", index_k[:20])
81 |
82 | query.requires_grad = True
83 | table_q.requires_grad = True
84 | key.requires_grad = True
85 | table_k.requires_grad = True
86 |
87 | output1 = pointops.dot_prod_with_idx(query, index_q.int(), table_q, rel_index.int())
88 | output2 = pointops.dot_prod_with_idx(key, index_k.int(), table_k, rel_index.int())
89 | attn_flat = pointops.attention_step1(query.float(), key.float(), index_q.int(), index_k.int())
90 |
91 | output = output1 + output2 + attn_flat
92 | loss = output.mean()
93 | loss.backward()
94 |
95 | print("output.shape: {}, output[:5,:10]: {}".format(output.shape, output[:5,:10]))
96 | print("query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
97 | print("table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2])
98 | print("key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
99 | print("table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2])
100 | # input()
101 |
102 | query_grad = query.grad.clone()
103 | key_grad = key.grad.clone()
104 | table_q_grad = table_q.grad.clone()
105 | table_k_grad = table_k.grad.clone()
106 |
107 | query.grad.zero_()
108 | key.grad.zero_()
109 | table_q.grad.zero_()
110 | table_k.grad.zero_()
111 |
112 |
113 | # print("query.is_contiguous(): ", query.is_contiguous())
114 | # print("key.is_contiguous(): ", key.is_contiguous())
115 | # print("index_q.is_contiguous(): ", index_q.is_contiguous())
116 | # print("index_k.is_contiguous(): ", index_k.is_contiguous())
117 | table_q = table_q.detach().permute(0,3,1,2).contiguous()
118 | table_k = table_k.detach().permute(0,3,1,2).contiguous()
119 | table_q.requires_grad = True
120 | table_k.requires_grad = True
121 | output_v2 = sptr.dot_prod_with_idx_all(query, index_q.int(), index_q_offsets.int(), n_max, key, ctg_index_1_offsets.int(), index_k.int(), table_q, table_k, rel_index.int())
122 | loss = output_v2.mean()
123 | loss.backward()
124 |
125 | table_q_grad2 = table_q.grad.clone().permute(0,2,3,1).contiguous()
126 | table_k_grad2 = table_k.grad.clone().permute(0,2,3,1).contiguous()
127 |
128 | print("output[:5, :5]: ", output[:5, :5])
129 |
130 | print("output_v2[:5, :5]: ", output_v2[:5, :5])
131 |
132 | # print("output_v2.shape: {}, output_v2[:5,:10]: {}".format(output_v2.shape, output_v2[:5,:10]))
133 | print("v2 query.grad[:5, :3, :5]: ", query.grad[:5, :3, :5])
134 | print("v2 table_q.grad[:5, :3, :5, :2]: ", table_q.grad[:5, :3, :5, :2])
135 | print("v2 key.grad[:5, :3, :5]: ", key.grad[:5, :3, :5])
136 | print("v2 table_k.grad[:5, :3, :5, :2]: ", table_k.grad[:5, :3, :5, :2])
137 | # input()
138 |
139 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max())
140 |
141 | print("((query.grad-query_grad)**2).max(): ", ((query.grad-query_grad)**2).max())
142 |
143 | print("((key.grad-key_grad)**2).max(): ", ((key.grad-key_grad)**2).max())
144 |
145 | print("((table_q_grad2-table_q_grad)**2).max(): ", ((table_q_grad2-table_q_grad)**2).max())
146 |
147 | print("((table_k_grad2-table_k_grad)**2).max(): ", ((table_k_grad2-table_k_grad)**2).max())
148 |
149 |
150 |
--------------------------------------------------------------------------------
/src/models/SphereFormer/SparseTransformer/test/test_relative_pos_encoding_op_step2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pointops
3 | from torch_scatter import scatter_max, scatter_mean, scatter_add, scatter_min, scatter_sum
4 | import sys
5 | sys.path.append("..")
6 | import sptr
7 |
8 | torch.manual_seed(2)
9 |
10 | # M = 80000
11 | N = 3500 #10 #5 #10 #3500
12 | n = 150 #2 #150
13 | # k = 5 #7 #65
14 | hdim = 16
15 | h = 6 #1 #6
16 | L = 31 #2 #31
17 | v = torch.rand(N, h, hdim).cuda()
18 | table = torch.rand(L, h, hdim, 3).cuda()
19 |
20 | # index_0 = torch.rand(M)
21 | # index_0[index_0 < 0] = 0
22 | # index_0 = (index_0*N).long().cuda()
23 |
24 | # index_1 = torch.rand(M)
25 | # index_1[index_1 < 0] = 0
26 | # index_1 = (index_1*N).long().cuda()
27 |
28 |
29 | v2p_map = torch.randint(low=0, high=n, size=(N,)).cuda()
30 | v2p_map, _ = v2p_map.sort()
31 | counts = v2p_map.bincount()
32 | M = (counts**2).sum().item()
33 | k = counts.max().item()
34 |
35 | print("counts: ", counts)
36 |
37 | attn = torch.rand(M, h).cuda()
38 |
39 | # v2p_map, ctg_sort_idx = v2p_map.sort()
40 | # n, k = p2v_map.shape
41 | # N = v2p_map.shape[0]
42 | mask = torch.arange(k)[None].cuda().expand(n, -1) < counts[:, None] #[n, k]
43 | to_add = torch.arange(k)[None].cuda().expand(n, -1)[mask]
44 | v2p_map = v2p_map.long()
45 | p2v_map = torch.zeros(n, k).long().cuda() #torch.zeros_like(p2v_map)
46 | p2v_map[mask] = torch.arange(N).cuda()
47 | ctg_index_1_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), (counts ** 2).cumsum(-1)], 0)[v2p_map] + to_add
48 |
49 | print("M: ", M)
50 | # print("counts[:5]: {}, v2p_map[:5]: {}, p2v_map[:5]: {}".format(counts[:5], v2p_map[:5], p2v_map[:5]))
51 | # print("ctg_index_1_offsets[:50]: ", ctg_index_1_offsets[:50])
52 |
53 | # print("ctg_index_1_offsets.max(): {}, ctg_index_1_offsets.min(): {}".format(ctg_index_1_offsets.max(), ctg_index_1_offsets.min()))
54 | print("ctg_index_1_offsets: ", ctg_index_1_offsets)
55 |
56 | rel_index = torch.rand(M, 3)
57 | rel_index[rel_index < 0] = 0
58 | rel_index = (rel_index*L).long().cuda()
59 |
60 | # print("rel_index.min(): {}, rel_index.max(): {}".format(
61 | # rel_index.min(), rel_index.max()
62 | # ))
63 |
64 | print("rel_index: ", rel_index)
65 |
66 | index_0_counts = counts[v2p_map.long()] #[N, ]
67 | index_0_offsets = index_0_counts.cumsum(-1) #[N, ]
68 | index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
69 | n_max = p2v_map.shape[1]
70 | index_0, index_1 = pointops.precompute_index_pairs(p2v_map, counts, index_0_offsets)
71 | index_0 = index_0.long()
72 | index_1 = index_1.long()
73 |
74 | print("index_0: {}".format(index_0))
75 | print("index_1: {}".format(index_1))
76 |
77 | # print("index_0.max(): {}, index_0.min(): {}".format(index_0.max(), index_0.min()))
78 |
79 | # print("index_1.max(): {}, index_1.min(): {}".format(index_1.max(), index_1.min()))
80 |
81 | # input()
82 |
83 | # # rearrange index for acceleration
84 | # index_0, indices = torch.sort(index_0) #[M,]
85 | # index_1 = index_1[indices] #[M,]
86 | # rel_index = rel_index[indices]
87 | # index_0_counts = index_0.bincount()
88 |
89 | print("index_0_counts.shape: ", index_0_counts.shape)
90 |
91 | # n_max = index_0_counts.max()
92 | # index_0_offsets = index_0_counts.cumsum(dim=-1) #[N]
93 |
94 | # print("v1 index_0_offsets.shape: ", index_0_offsets.shape)
95 |
96 | # index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) #[N+1]
97 |
98 |
99 | attn.requires_grad = True
100 | v.requires_grad = True
101 | table.requires_grad = True
102 |
103 |
104 | output = pointops.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_1.int(), table, rel_index.int())
105 | loss = output.mean()
106 | loss.backward()
107 |
108 | # print("output.shape: {}, output[:5,:10,:5]: {}".format(output.shape, output[:5,:10, :5]))
109 | # print("attn.grad[:5, :3]: ", attn.grad[:5, :3])
110 | # print("v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5])
111 | # print("table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2])
112 | # # input()
113 |
114 | attn_grad = attn.grad.clone()
115 | v_grad = v.grad.clone()
116 | table_grad = table.grad.clone()
117 |
118 | attn.grad.zero_()
119 | v.grad.zero_()
120 | table.grad.zero_()
121 |
122 | # print("query.is_contiguous(): ", query.is_contiguous())
123 | # print("key.is_contiguous(): ", key.is_contiguous())
124 | # print("index_0.is_contiguous(): ", index_0.is_contiguous())
125 | # print("index_1.is_contiguous(): ", index_1.is_contiguous())
126 | table = table.detach().permute(0,3,1,2).contiguous()
127 | table.requires_grad = True
128 | output_v2 = sptr.attention_step2_with_rel_pos_value(attn, v, index_0.int(), index_0_offsets.int(), n_max, index_1.int(), ctg_index_1_offsets.int(), table, rel_index.int())
129 | # output_v2 = pointops.attention_step2_with_rel_pos_value_v4(attn, v, index_0_offsets.int(), n_max, index_1.int(), table, rel_index.int())
130 | loss = output_v2.mean()
131 | loss.backward()
132 |
133 | table_grad2 = table.grad.clone().permute(0,2,3,1).contiguous()
134 |
135 | # print("output_v2.shape: {}, output_v2[:5,:10,:5]: {}".format(output_v2.shape, output_v2[:5,:10,:5]))
136 | # print("v2 attn.grad[:5, :3]: ", attn.grad[:5, :3])
137 | # print("v2 v.grad[:5, :3, :5]: ", v.grad[:5, :3, :5])
138 | # print("v2 table.grad[:5, :3, :5, :2]: ", table.grad[:5, :3, :5, :2])
139 | # # input()
140 |
141 | print("((output-output_v2)**2).max(): ", ((output-output_v2)**2).max())
142 |
143 | print("((attn_grad-attn.grad)**2).max(): ", ((attn_grad-attn.grad)**2).max())
144 |
145 | print("((v_grad-v.grad)**2).max(): ", ((v_grad-v.grad)**2).max())
146 |
147 | print("((table_grad-table_grad2)**2).max(): ", ((table_grad-table_grad2)**2).max())
148 |
149 | # print("torch.max((attn_flat-attn_flat_v2)**2): ", torch.max((attn_flat-attn_flat_v2)**2))
150 |
151 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/adappool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class AdaptivePooling(nn.Module):
6 | def __init__(self, feature_dim, output_channels):
7 | super().__init__()
8 | self.output_channels = output_channels
9 | self.query = nn.Parameter(torch.randn(output_channels, feature_dim))
10 |
11 | def forward(self, x, return_weights=False):
12 |         """
13 |         Args:
14 |             x: Input tensor of shape (batch_size, num_points, feature_dim)
15 |             return_weights: If True, also return the per-point attention weights.
16 |         Returns:
17 |             Output tensor of shape (batch_size, output_channels, feature_dim);
18 |             if return_weights, also weights of shape (batch_size, num_points, output_channels).
19 | query = self.query.unsqueeze(0).repeat(x.shape[0],1,1)
20 |
21 | out = F.scaled_dot_product_attention(query=query,key=x,value=x)
22 | if return_weights:
23 | attn_scores = torch.einsum('ij,bkj->bki', self.query, x)
24 | attn_weights = F.softmax(attn_scores, dim=1)
25 | return out, attn_weights
26 |
27 | return out
28 |
--------------------------------------------------------------------------------
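Usage sketch (illustrative shapes, not part of the repository; run with `src/` on the path as the other modules expect): `AdaptivePooling` condenses a variable number of per-point features into a fixed set of `output_channels` tokens by letting learned queries attend over the points.

```python
import torch
from models.adappool import AdaptivePooling

pool = AdaptivePooling(feature_dim=16, output_channels=8)
x = torch.rand(3, 1200, 16)               # (batch, num_points, feature_dim), toy sizes
out = pool(x)                             # (3, 8, 16): 8 learned tokens per scan
out, w = pool(x, return_weights=True)     # w: (3, 1200, 8) per-point attention weights
print(out.shape, w.shape)
```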
/src/models/pca_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from models.salsa import SALSA
10 |
11 |
12 | class L2Norm(nn.Module):
13 | def forward(self, x):
14 | return torch.nn.functional.normalize(x, p=2.0, dim=1, eps=1e-12)
15 |
16 | class PCAModel(nn.Module):
17 | def __init__(self,num_in_features,num_out_features):
18 | super(PCAModel, self).__init__()
19 | self.pca_conv = nn.Conv2d(num_in_features, num_out_features, kernel_size=(1, 1), stride=1, padding=0)
20 | self.layer = nn.Sequential(*[self.pca_conv, nn.Flatten(), L2Norm()])
21 |
22 | def forward(self,x):
23 | return self.layer(x)
24 |
25 |
26 | class CombinedModel(nn.Module):
27 | def __init__(self, voxel_sz, num_in_features,num_out_features):
28 | super(CombinedModel, self).__init__()
29 | self.spherelpr = SALSA(voxel_sz=voxel_sz)
30 | self.pca_model = PCAModel(num_in_features, num_out_features)
31 |
32 | def forward(self,data):
33 | coord, xyz, feat, batch = data
34 | output_feats, output_desc = self.spherelpr(coord, xyz, feat, batch)
35 | output_desc = self.pca_model(output_desc[...,None][...,None])
36 | return output_feats, output_desc
37 | # return output_desc
38 |
39 |
--------------------------------------------------------------------------------
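Shape walk-through (illustrative, not part of the repository): `PCAModel` is a learned 1x1-conv projection followed by flatten and L2 normalisation. It is re-stated inline below so the example runs without the spconv/SphereFormer dependencies that `models.pca_model` pulls in; the 512 -> 256 sizes are made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Equivalent of PCAModel: 1x1 conv + flatten + L2 norm over the feature dim.
proj = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(1, 1), stride=1, padding=0),
                     nn.Flatten())
desc = torch.rand(4, 512)                        # (batch, in_features) global descriptors
out = F.normalize(proj(desc[..., None, None]), p=2.0, dim=1)
print(out.shape, out.norm(dim=1))                # torch.Size([4, 256]), rows of norm ~1.0
```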
/src/models/salsa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from datetime import datetime
4 | from time import time
5 |
6 | import tqdm
7 |
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 |
10 | import numpy as np
11 | import spconv.pytorch as spconv
12 | import torch
13 | import torch.nn as nn
14 |
15 | from models.Mixer.mixer import Mixer
16 | from models.SphereFormer.model.unet_spherical_transformer import Semantic
17 | from models.adappool import AdaptivePooling
18 | from utils.misc_utils import read_yaml_config
19 |
20 |
21 | def print_model_size(model):
22 | param_size = 0
23 | for param in model.parameters():
24 | param_size += param.nelement() * param.element_size()
25 | buffer_size = 0
26 | for buffer in model.buffers():
27 | buffer_size += buffer.nelement() * buffer.element_size()
28 |
29 | size_all_mb = (param_size + buffer_size) / 1024**2
30 | print('model size: {:.3f}MB'.format(size_all_mb))
31 |
32 |
33 | class SALSA(nn.Module):
34 | def __init__(self,voxel_sz):
35 | super(SALSA, self).__init__()
36 | config = read_yaml_config(os.path.join(os.path.dirname(__file__),'../config/model.yaml'))
37 | self.k = config['aggregator']['tokens']
38 | feature_dim = config['feat_extractor']['feature_dim']
39 | patch_size = config['feat_extractor']['patch_size']
40 | voxel_size = [voxel_sz, voxel_sz, voxel_sz]
41 | patch_size = np.array([voxel_size[i] * patch_size for i in range(3)]).astype(np.float32)
42 | window_size = patch_size * 6
43 | self.do_pe = True
44 | self.feature_extractor = Semantic(input_c=config['feat_extractor']['input_c'],
45 | m=config['feat_extractor']['m'],
46 | classes=feature_dim,
47 | block_reps=config['feat_extractor']['block_reps'],
48 | block_residual=True,
49 | layers=config['feat_extractor']['layers'],
50 | window_size=window_size,
51 | window_size_sphere=np.array(config['feat_extractor']['window_size_sphere']),
52 | quant_size=window_size/24,
53 | quant_size_sphere= np.array(config['feat_extractor']['window_size_sphere'])/24,
54 | rel_query=True,
55 | rel_key=True,
56 | rel_value=True,
57 | drop_path_rate=config['feat_extractor']['drop_path_rate'],
58 | window_size_scale=config['feat_extractor']['window_size_scale'],
59 | grad_checkpoint_layers=[],
60 | sphere_layers=config['feat_extractor']['sphere_layers'],
61 | a=config['feat_extractor']['a'],
62 | )
63 |
64 | self.attpool = AdaptivePooling(feature_dim=feature_dim,output_channels=self.k)
65 |
66 | self.descriptor_extractor = Mixer(in_channels=self.k,
67 | out_channels=config['aggregator']['out_channels'],
68 | in_d=feature_dim,
69 | mix_depth=config['aggregator']['mix_depth'],
70 | mlp_ratio=config['aggregator']['mlp_ratio'],
71 | out_d=config['aggregator']['out_d'])
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | def forward(self, coord, xyz, feat, batch, save_attn_weights=False):
80 | ########################## Feature extractor ########################################
81 | batch_shape = batch[-1]+1
82 | coord = torch.cat([batch.unsqueeze(-1), coord], -1)
83 | spatial_shape = np.clip((coord.max(0)[0][1:] + 1).cpu().numpy(), 128, None)
84 |
85 | sinput = spconv.SparseConvTensor(feat, coord.int(), spatial_shape, batch_shape)
86 |
87 | local_features = self.feature_extractor(sinput, xyz, batch)
88 | #####################################################################################
89 |
90 | #################### Adaptive pooling + Mixer based aggregator #####################
91 | padded_split_local_features = []
92 | _, counts = torch.unique(batch, return_counts=True)
93 | split_local_features = torch.split(local_features, list(counts)) # [(N1,d),(N2,d),(N3,d),(N4,d),...] with d = feature_dim
94 | for features in split_local_features:
95 | # print(features.shape)
96 | if save_attn_weights:
97 | attval, attn_weights = self.attpool(features.unsqueeze(0), return_weights=True)
98 | self.attn_weights = attn_weights
99 | else:
100 | attval = self.attpool(features.unsqueeze(0))
101 | padded_split_local_features.append(attval.squeeze(0)) ### Attention based pooling
102 | padded_split_local_features = torch.stack(padded_split_local_features, dim=0)
103 | global_descriptor = self.descriptor_extractor(padded_split_local_features)
104 | #####################################################################################
105 |
106 |
107 | return split_local_features, global_descriptor
108 |
109 |
110 | if __name__=='__main__':
111 | import random
112 | import time
113 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
114 | os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
115 | # seed = 3407
116 | seed = 1100
117 | # seed = np.random.randint(10000)
118 | print(seed)
119 | random.seed(seed)
120 | np.random.seed(seed)
121 | torch.manual_seed(seed)
122 | torch.backends.cudnn.benchmark = False
123 | torch.backends.cudnn.enabled = True
124 | torch.backends.cudnn.deterministic = True
125 | torch.cuda.manual_seed_all(seed)
126 | torch.use_deterministic_algorithms(True)
127 |
128 | model = SALSA(voxel_sz=0.5)
129 | # save_path = '/data/raktim/Projects/LPR/Main/src/checkpoints/Ablation/NewSphereMixerVoxel2/model_6.pth'
130 | # checkpoint = torch.load(save_path) # ,map_location='cuda:0')
131 | # model.load_state_dict(checkpoint)
132 | model.to('cuda')
133 |
134 | coords = torch.IntTensor(np.random.randint(0,100,size=[11000,3])).to('cuda')
135 | xyz = torch.FloatTensor(np.random.rand(11000,3)).to('cuda')
136 | feats = torch.FloatTensor(np.random.rand(11000,3)).to('cuda')
137 | batch_number = torch.IntTensor(np.ones([11000])).to('cuda')
138 | # print(coords.shape, xyz.shape, feats.shape, batch_number.shape)
139 | model.eval()
140 |
141 | N = 1000
142 |
143 | with torch.inference_mode():
144 | # torch.cuda.synchronize()
145 | start = time.time()
146 | for i in tqdm.tqdm(range(N)):
147 | local_features, output_desc = model(coords, xyz, feats, batch_number)
148 | # torch.cuda.synchronize()
149 | end = time.time()
150 | print('Forward pass took {} seconds for {} trials'.format(end - start,N))
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import random
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
12 |
13 | from data.sejong_southbay import SejongSouthbayTupleLoader
14 |
15 | from models.salsa import SALSA
16 |
17 | from loss.loss import find_loss
18 | from utils.misc_utils import tuple_collate_fn, read_yaml_config
19 |
20 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
21 | os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
22 |
23 | def print_nb_params(m):
24 | model_parameters = filter(lambda p: p.requires_grad, m.parameters())
25 | params = sum([np.prod(p.size()) for p in model_parameters])
26 | print(f'Trainable parameters: {params/1e6:.3}M')
27 | del model_parameters, params
28 |
29 | def print_model_size(model):
30 | param_size = 0
31 | for param in model.parameters():
32 | param_size += param.nelement() * param.element_size()
33 | buffer_size = 0
34 | for buffer in model.buffers():
35 | buffer_size += buffer.nelement() * buffer.element_size()
36 |
37 | size_all_mb = (param_size + buffer_size) / 1024**2
38 | print('model size: {:.3f}MB'.format(size_all_mb))
39 |
40 |
41 | def main():
42 | config = read_yaml_config(os.path.join(os.path.dirname(__file__),'config/train.yaml'))
43 | writer = SummaryWriter(config['writer_loc'])
44 | # Get data loader
45 | batch_size = config['batch_size']
46 | train_transform = None
47 | dataset = SejongSouthbayTupleLoader(cached_queries=config['cached_queries'], pcl_transform=train_transform)
48 | device = config['device']
49 |
50 | model = SALSA(voxel_sz=0.5).to(device)
51 |
52 | model.train()
53 | print_nb_params(model)
54 | print_model_size(model)
55 | MAX_EPOCH = config['max_epoch']
56 |
57 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
58 |
59 |
60 | kk_batch = 0
61 | kk_subcache = 0
62 | for e in range(MAX_EPOCH):
63 | EPOCH_LOSS = []
64 | time1 = time.time()
65 | dataset.new_epoch()
66 | steps_per_epoch = int(np.ceil(1000/batch_size))
67 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['max_lr'],epochs=dataset.nCacheSubset, steps_per_epoch=steps_per_epoch,anneal_strategy='cos', cycle_momentum=False)
68 | lr_list = [scheduler.get_last_lr()]
69 | for ii in range(dataset.nCacheSubset):
70 | scheduler.step((ii+1)*steps_per_epoch)
71 | lr_list.append(scheduler.get_last_lr())
72 |
73 | for current_subset in range(0,dataset.nCacheSubset):
74 | CACHE_LOSS = []
75 |
76 | dataset.current_subset=current_subset
77 | dataset.update_subcache(model,outputdim=config['outdim'])
78 | if len(dataset.triplets)==0:
79 | continue
80 | model.train()
81 | data_loader = torch.utils.data.DataLoader(dataset=dataset,shuffle=True, batch_size=batch_size,collate_fn=tuple_collate_fn, num_workers=16)
82 | scheduler_lr = np.linspace(lr_list[current_subset],lr_list[current_subset+1],len(data_loader))
83 |
84 | for i, batch_data in enumerate(data_loader):
85 | model.zero_grad()
86 | optimizer.zero_grad()
87 | coord, xyz, feat, batch_number, labels, point_pos_pairs = batch_data
88 | coord, xyz, feat, batch_number, labels = coord.to(device), xyz.to(device), feat.to(device), batch_number.to(device),labels.to(device)
89 | local_features, global_descriptor = model(coord, xyz, feat, batch_number)
90 |
91 | loss = find_loss(local_features, global_descriptor, point_pos_pairs)
92 | loss.backward()
93 | optimizer.step()
94 | for param_group in optimizer.param_groups:
95 | last_lr = param_group['lr']
96 | param_group['lr'] = scheduler_lr[i][0]
97 | writer.add_scalar("Batch Loss", loss.item(), kk_batch)
98 | writer.add_scalar("Batch LR", last_lr, kk_batch)
99 | kk_batch += 1
100 | CACHE_LOSS.append(loss.item())
101 | sys.stdout.write('\r' + 'Epoch ' + str(e + 1) + ' / ' + str(MAX_EPOCH) + ' Subset ' + str(current_subset + 1) + ' / ' + str(dataset.nCacheSubset) + ' Progress ' + str(i+1) + ' / ' + str(len(data_loader))+ ' Loss ' + str(format(loss.item(),'.2f')) + ' time '+ str(format(time.time()-time1,'.2f'))+' seconds.')
102 |
103 | torch.save(model.state_dict(),os.path.join(os.path.dirname(__file__),'checkpoints/SALSA/Model/model_'+str(e)+'.pth'))
104 | del coord, xyz, feat, batch_number, labels, local_features, global_descriptor, point_pos_pairs
105 | gc.collect()
106 | torch.cuda.empty_cache()
107 | cache_loss_avg = sum(CACHE_LOSS)/len(CACHE_LOSS)*steps_per_epoch
108 |
109 | writer.add_scalar("Subcache Loss", cache_loss_avg, kk_subcache)
110 | writer.add_scalar("Subcache LR", last_lr, kk_subcache)
111 | kk_subcache += 1
112 |
113 | EPOCH_LOSS.append(cache_loss_avg)
114 | print(' ')
115 | print('Avg. Subcache Loss', cache_loss_avg)
116 | torch.save(model.state_dict(),os.path.join(os.path.dirname(__file__),'checkpoints/SALSA/Model/model_'+str(e)+'.pth'))
117 | epoch_loss_avg = sum(EPOCH_LOSS)/len(EPOCH_LOSS)
118 | print(' ')
119 | print('Avg. EPOCH Loss', epoch_loss_avg)
120 | writer.add_scalar("Epoch Loss", epoch_loss_avg, e)
121 |
122 |
123 | if __name__ == "__main__":
124 | seed = 1100
125 | random.seed(seed)
126 | np.random.seed(seed)
127 | torch.manual_seed(seed)
128 | torch.backends.cudnn.benchmark = False
129 | torch.backends.cudnn.enabled = True
130 | torch.backends.cudnn.deterministic = True
131 | torch.cuda.manual_seed_all(seed)
132 | torch.use_deterministic_algorithms(True)
133 |
134 | gc.collect()
135 | main()
136 |
137 |
138 |
--------------------------------------------------------------------------------
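Note (illustrative values, not part of the repository): the loop above queries `OneCycleLR` only at subcache boundaries and then linearly interpolates the learning rate across the batches of each subcache. A minimal standalone sketch of that schedule:

```python
import numpy as np
import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
n_subsets, steps_per_epoch, batches_in_subset = 4, 10, 25   # toy sizes
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=1e-3, epochs=n_subsets, steps_per_epoch=steps_per_epoch,
    anneal_strategy='cos', cycle_momentum=False)

# Learning-rate anchors at the subcache boundaries.
lr_anchors = [sched.get_last_lr()]
for i in range(n_subsets):
    sched.step((i + 1) * steps_per_epoch)
    lr_anchors.append(sched.get_last_lr())

# Per-batch LRs inside the first subcache: linear interpolation between anchors.
per_batch_lr = np.linspace(lr_anchors[0], lr_anchors[1], batches_in_subset)
print(per_batch_lr[0, 0], per_batch_lr[-1, 0])
```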
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raktimgg/SALSA/9152e8dff6fb3fc07dd546c9f19bfab462d08a6f/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import yaml
7 |
8 |
9 | def collate_fn(batch):
10 | coord, xyz, feat = list(zip(*batch))
11 | offset, count = [], 0
12 |
13 | new_coord, new_xyz, new_feat = [], [], []
14 | k = 0
15 | for i, item in enumerate(xyz):
16 |
17 | count += item.shape[0]
18 | k += 1
19 | offset.append(count)
20 | new_coord.append(coord[i])
21 | new_xyz.append(xyz[i])
22 | new_feat.append(feat[i])
23 | offset_ = torch.IntTensor(offset[:k]).clone()
24 | offset_[1:] = offset_[1:] - offset_[:-1]
25 | batch_number = torch.cat([torch.tensor([ii]*o) for ii,o in enumerate(offset_)], 0).long()
26 | coords,xyz,feat = torch.cat(new_coord[:k]), torch.cat(new_xyz[:k]), torch.cat(new_feat[:k])
27 | return coords,xyz,feat,batch_number
28 |
29 |
30 | def tuple_collate_fn(batch):
31 | anchor_coords, anchor_xyz, anchor_feats, pos_coords, pos_xyz, pos_feats, neg_coords, neg_xyz, neg_feats, labels, point_pos_pairs = list(zip(*batch))
32 | offset, count = [], 0
33 |
34 | new_coord, new_xyz, new_feat, new_label, new_point_pos_pairs = [], [], [], [], []
35 |
36 | coord, xyz, feat = anchor_coords, anchor_xyz, anchor_feats
37 | for i, item in enumerate(xyz):
38 |
39 | count += item.shape[0]
40 | offset.append(count)
41 | new_coord.append(coord[i])
42 | new_xyz.append(xyz[i])
43 | new_feat.append(feat[i])
44 | new_label.append(labels[i][0])
45 |
46 | coord, xyz, feat = pos_coords, pos_xyz, pos_feats
47 | for i, item in enumerate(xyz):
48 |
49 | count += item.shape[0]
50 | offset.append(count)
51 | new_coord.append(coord[i])
52 | new_xyz.append(xyz[i])
53 | new_feat.append(feat[i])
54 | new_label.append(labels[i][1])
55 |
56 |
57 | coord, xyz, feat = neg_coords, neg_xyz, neg_feats
58 | for i, item in enumerate(xyz):
59 |
60 | count += item.shape[0]
61 | offset.append(count)
62 | new_coord.append(coord[i])
63 | new_xyz.append(xyz[i])
64 | new_feat.append(feat[i])
65 | new_label.append(labels[i][2])
66 |
67 | if point_pos_pairs is not None:
68 | for i, item in enumerate(point_pos_pairs):
69 | # item = np.array(item) + len(new_point_pos_pairs)
70 | # if i>0:
71 | # item1 = np.array(item)[:,0] + new_coord[i-1].shape[0]
72 | # item2 = np.array(item)[:,1] + new_coord[i-1].shape[0]
73 | new_point_pos_pairs.append(item)
74 |
75 |
76 | offset_ = torch.IntTensor(offset).clone()
77 | offset_[1:] = offset_[1:] - offset_[:-1]
78 | batch_number = torch.cat([torch.tensor([ii]*o) for ii,o in enumerate(offset_)], 0).long()
79 | coords,xyz,feat,labels = torch.cat(new_coord), torch.cat(new_xyz), torch.cat(new_feat), torch.Tensor(new_label)
80 | return coords,xyz,feat,batch_number,labels, new_point_pos_pairs
81 |
82 |
83 | def hashM(arr, M):
84 | if isinstance(arr, np.ndarray):
85 | N, D = arr.shape
86 | else:
87 | N, D = len(arr[0]), len(arr)
88 |
89 | hash_vec = np.zeros(N, dtype=np.int64)
90 | for d in range(D):
91 | if isinstance(arr, np.ndarray):
92 | hash_vec += arr[:, d] * M**d
93 | else:
94 | hash_vec += arr[d] * M**d
95 | return hash_vec
96 |
97 |
98 | def pdist(A, B, dist_type='L2'):
99 | if dist_type == 'L2':
100 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2)
101 | return torch.sqrt(D2 + 1e-7)
102 | elif dist_type == 'SquareL2':
103 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2)
104 | else:
105 | raise NotImplementedError('Not implemented')
106 |
107 |
108 | #####################################################################################
109 | # Load poses
110 | #####################################################################################
111 |
112 |
113 | def load_poses_from_csv(file_name):
114 | with open(file_name, newline='') as f:
115 | reader = csv.reader(f)
116 | data_poses = list(reader)
117 |
118 | transforms = []
119 | positions = []
120 | for cnt, line in enumerate(data_poses):
121 | line_f = [float(i) for i in line]
122 | P = np.vstack((np.reshape(line_f[1:], (3, 4)), [0, 0, 0, 1]))
123 | transforms.append(P)
124 | positions.append([P[0, 3], P[1, 3], P[2, 3]])
125 | return np.asarray(transforms), np.asarray(positions)
126 |
127 |
128 | #####################################################################################
129 | # Load timestamps
130 | #####################################################################################
131 |
132 |
133 | def load_timestamps_csv(file_name):
134 | with open(file_name, newline='') as f:
135 | reader = csv.reader(f)
136 | data_poses = list(reader)
137 | data_poses_ts = np.asarray(
138 | [float(t)/1e9 for t in np.asarray(data_poses)[:, 0]])
139 | return data_poses_ts
140 |
141 | def read_yaml_config(filename):
142 | with open(filename, 'r') as stream:
143 | try:
144 | # Load the YAML file
145 | config = yaml.safe_load(stream)
146 | return config
147 | except yaml.YAMLError as exc:
148 | print(exc)
149 | return None
--------------------------------------------------------------------------------
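Note (toy values, not part of the repository): both collate functions above follow the same batching convention, i.e. point clouds are concatenated along the point dimension and every point is tagged with the index of the sample it came from. A minimal standalone illustration:

```python
import torch

clouds = [torch.rand(2, 3), torch.rand(3, 3), torch.rand(1, 3)]   # 2, 3 and 1 points
xyz = torch.cat(clouds, 0)                                        # (6, 3)
batch_number = torch.cat([torch.full((c.shape[0],), i, dtype=torch.long)
                          for i, c in enumerate(clouds)])
print(xyz.shape, batch_number.tolist())   # torch.Size([6, 3]) [0, 0, 1, 1, 1, 2]
```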
/src/utils/o3d_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import open3d as o3d
5 |
6 |
7 | def make_open3d_point_cloud(xyz, color=None, tile=False):
8 | pcd = o3d.geometry.PointCloud()
9 | pcd.points = o3d.utility.Vector3dVector(xyz)
10 | if color is not None:
11 | if tile:
12 | if len(color) != len(xyz):
13 | color = np.tile(color, (len(xyz), 1))
14 | pcd.colors = o3d.utility.Vector3dVector(color)
15 | return pcd
16 |
17 |
18 | def get_matching_indices(source, target, search_voxel_size, K=None):
19 | source_copy = copy.deepcopy(source)
20 | target_copy = copy.deepcopy(target)
21 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy)
22 |
23 | match_inds = []
24 | for i, point in enumerate(source_copy.points):
25 | [_, idx, _] = pcd_tree.search_radius_vector_3d(
26 | point, search_voxel_size)
27 | if K is not None:
28 | idx = idx[:K]
29 | for j in idx:
30 | match_inds.append([i, j])
31 | return np.asarray(match_inds)
32 |
--------------------------------------------------------------------------------
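Usage sketch (values made up, not part of the repository; requires open3d and `src/` on the path): `get_matching_indices` returns index pairs of points that lie within `search_voxel_size` of each other, which can serve as point-level correspondences between two overlapping scans.

```python
import numpy as np
from utils.o3d_utils import make_open3d_point_cloud, get_matching_indices

np.random.seed(0)
src_xyz = np.random.rand(100, 3)
tgt_xyz = src_xyz + 0.01 * np.random.randn(100, 3)   # slightly perturbed copy
src = make_open3d_point_cloud(src_xyz)
tgt = make_open3d_point_cloud(tgt_xyz)
pairs = get_matching_indices(src, tgt, search_voxel_size=0.05)
print(pairs.shape)   # (num_matches, 2) rows of [source_index, target_index]
```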