├── .gitattributes ├── .idea ├── .gitignore ├── OverlapTransformer-master.iml ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── LICENSE ├── README.md ├── config ├── config.yml └── config_nclt.yml ├── demo.png ├── demo ├── __pycache__ │ └── com_overlap.cpython-38.pyc ├── com_overlap.py ├── demo_compute_overlap_sim.py └── scans │ ├── 000000.bin │ ├── 000005.bin │ └── 000015.bin ├── fig.png ├── mambapy ├── .gitignore ├── LICENSE ├── README.md ├── assets │ ├── logo.png │ ├── speed_comparison.png │ ├── training_vs_d_model.png │ ├── training_vs_d_state.png │ └── training_vs_seqlen_d_state_var.png ├── docs │ ├── assets │ │ ├── cumsum_rnns.jpg │ │ ├── down_sweep_rule.jpg │ │ ├── downsweep.jpg │ │ ├── downsweep_ex.jpg │ │ ├── downsweep_updated.jpg │ │ ├── reduction_tree.jpg │ │ ├── tensor_mem.jpeg │ │ ├── tensor_mem_tree.jpeg │ │ ├── tree_reduction.jpeg │ │ ├── tree_reduction_xs.jpeg │ │ └── up_down_trees.jpg │ └── pscan.ipynb ├── examples │ ├── buffer.py │ ├── example_e2e_training.ipynb │ ├── example_llm.ipynb │ └── tinyhome.py ├── mamba.py ├── mamba_lm.py ├── mlx │ ├── README.md │ ├── assets │ │ └── mamba_mlx.png │ ├── mamba_lm_mlx.py │ ├── mamba_mlx.py │ ├── misc.py │ ├── pscan_mlx.py │ ├── scripts │ │ └── generate.py │ └── utils.py ├── pscan.py └── tests │ ├── compare_mambapy_cuda.py │ ├── mem_mamba_3.py │ ├── profiling_mamba.py │ └── profiling_pscan.py ├── modules ├── __pycache__ │ ├── loss.cpython-38.pyc │ ├── loss.cpython-39.pyc │ ├── netvlad.cpython-38.pyc │ └── netvlad.cpython-39.pyc ├── loss.py ├── netvlad.py └── overlap_mamba.py ├── requirements.txt ├── test ├── test_kitti00_PR.py ├── test_kitti00_prepare.py ├── test_kitti00_topN.py ├── test_nclt_topn.py └── test_nclt_topn_prepare.py ├── tools ├── __pycache__ │ ├── read_all_sets.cpython-38.pyc │ ├── read_all_sets.cpython-39.pyc │ ├── read_samples.cpython-38.pyc │ ├── read_samples.cpython-39.pyc │ └── read_samples_haomo.cpython-39.pyc ├── name_folder.py ├── read_all_sets.py ├── read_samples.py ├── read_samples_haomo.py └── utils │ ├── __pycache__ │ ├── utils.cpython-38.pyc │ └── utils.cpython-39.pyc │ ├── com_overlap_yaw.py │ ├── gen_depth_data.py │ ├── gen_gt_data.py │ ├── split_train_val.py │ └── utils.py └── train ├── training_distributed.py ├── training_overlap_mamba_kitti.py └── training_overlap_mambar_nclt.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.tar filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/OverlapTransformer-master.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 40 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OverlapMamba: Novel Shift State Space Model for LiDAR-based Place Recognition 2 | 3 | 4 |

8 | [figure] Fig. 1 Core idea of the proposed OverlapMamba
9 | 10 | ### 💾Environment 11 | 12 | 13 | We use pytorch-gpu for neural networks. 14 | 15 | To use a GPU, first you need to install the nvidia driver and CUDA. 16 | 17 | - CUDA Installation guide: [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) 18 | We use CUDA 11.3 in our work. Other versions of CUDA are also supported but you should choose the corresponding torch version in the following Torch dependences. 19 | 20 | - System dependencies: 21 | 22 | ```bash 23 | sudo apt-get update 24 | sudo apt-get install -y python3-pip python3-tk 25 | sudo -H pip3 install --upgrade pip 26 | ``` 27 | - Torch dependences: 28 | Following this [link](https://pytorch.org/get-started/locally/), you can download Torch dependences by pip: 29 | ```bash 30 | pip3 install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 31 | ``` 32 | or by conda: 33 | ```bash 34 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 35 | ``` 36 | 37 | - Other Python dependencies (may also work with different versions than mentioned in the requirements file): 38 | 39 | ```bash 40 | sudo -H pip3 install -r requirements.txt 41 | ``` 42 | 43 | ## 📖How to use 44 | 45 | We provide a training and test tutorial for KITTI sequences in this repository. Before any operation, please modify the [config file](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml) according to your setups. 46 | 47 | ### 📚Dataset 48 | 49 | Download KITTI dataset from [KITTI](https://www.cvlibs.net/datasets/kitti/user_login.php). 50 | 51 | We recommend you follow our code and data structures as follows. 52 | 53 | ### Code Structure 54 | 55 | ```bash 56 | ├── config 57 | │   ├── config_nclt.yml 58 | │   └── config.yml 59 | ├── modules 60 | │   ├── loss.py 61 | │   ├── netvlad.py 62 | │   ├── overlap_mamba_nclt.py 63 | │   └── overlap_mamba.py 64 | ├── test 65 | │   ├── test_mamba_topn_prepare.py 66 | │   ├── test_mamba_topn.py 67 | │   ├── test_kitti00_prepare.py 68 | │   ├── test_kitti00_PR.py 69 | │   ├── test_kitti00_topN.py 70 | │   ├── test_results_nclt 71 | │   │   └── predicted_des_L2_dis_bet_traj_forward.npz (to be generated) 72 | │   └── test_results_kitti 73 | │   └── predicted_des_L2_dis.npz (to be generated) 74 | ├── tools 75 | │   ├── read_all_sets.py 76 | │   ├── read_samples_haomo.py 77 | │   ├── read_samples.py 78 | │   └── utils 79 | │   ├── gen_depth_data.py 80 | │   ├── split_train_val.py 81 | │   └── utils.py 82 | ├── train 83 | │   ├── training_overlap_mamba_nclt.py 84 | │   └── training_overlap_mamba_kitti.py 85 | ├── valid 86 | │   └── valid_seq.py 87 | ├── visualize 88 | │   ├── des_list.npy 89 | │   └── viz_nclt.py 90 | └── weights 91 | ├── pretrained_overlap_mamba_nclt.pth.tar 92 | └── pretrained_overlap_mamba.pth.tar 93 | ``` 94 | 95 | ### Dataset Structure 96 | ``` 97 | data_root_folder (KITTI sequences root) follows: 98 | ├── 00 99 | │   ├── depth_map 100 | │   ├── 000000.png 101 | │   ├── 000001.png 102 | │   ├── 000002.png 103 | │   ├── ... 104 | │   └── overlaps 105 | │   ├── train_set.npz 106 | ├── 01 107 | ├── 02 108 | ├── ... 109 | ├── 10 110 | └── loop_gt_seq00_0.3overlap_inactive.npz 111 | 112 | valid_scan_folder (KITTI sequence 02 velodyne) contains: 113 | ├── 000000.bin 114 | ├── 000001.bin 115 | ... 116 | 117 | gt_valid_folder (KITTI sequence 02 computed overlaps) contains: 118 | ├── 02 119 | │   ├── overlap_0.npy 120 | │   ├── overlap_10.npy 121 | ... 
122 | ``` 123 | You need to download or generate the following files and put them in the right positions of the structure above: 124 | - You can find the ground truth for KITTI 00 here: [loop_gt_seq00_0.3overlap_inactive.npz](https://drive.google.com/file/d/1upAwJBF-_UIB7R8evW0PuJBM3RnrTbzl/view?usp=sharing) 125 | - You can find `gt_valid_folder` for sequence 02 [here](https://drive.google.com/file/d/13_1j20Uq3ppjVEkYaYcKjiJ2Zm7tudyH/view?usp=sharing). 126 | - Since the whole set of KITTI sequences requires a large amount of storage, we recommend that you generate range images such as `00/depth_map/000000.png` yourself with the preprocessing from [Overlap_Localization](https://github.com/PRBonn/overlap_localization/blob/master/src/prepare_training/gen_depth_and_normal_map.py); we do not provide these images. 127 | - More directly, you can generate `.png` range images with [this script](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/OverlapMamba/tools/utils/gen_depth_data.py). 128 | - The `overlaps` folder of each sequence below `data_root_folder` is provided by the authors of OverlapNet [here](https://drive.google.com/file/d/1i333NUC1DnJglXasqkGYCmo9p45Fx28-/view?usp=sharing). You should rename the files to `train_set.npz`. 129 | 130 | 131 | ### Quick Use 132 | 133 | For quick use, you can download our [model pretrained on KITTI](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/data_root_folder/pretrained_overlap_mamba.pth.tar); the following two files should also be downloaded: 134 | - [calib_file](https://drive.google.com/file/d/1LAcFrRSZQPxdD4EKSwIC0d3-uGvLB3yk/view?usp=sharing): calibration file from KITTI 00. 135 | - [poses_file](https://drive.google.com/file/d/1n02m1OqxK122ce8Cjz_N68PkazGqzj9l/view?usp=sharing): pose file from KITTI 00. 136 | 137 | Then you should modify `demo1_config` in the file [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml). 138 | 139 | Run the demo by: 140 | 141 | ``` 142 | cd demo 143 | python ./demo_compute_overlap_sim.py 144 | ``` 145 | You will see a query scan (000000.bin of KITTI 00) with a reprojected positive sample (000005.bin of KITTI 00) and a reprojected negative sample (000015.bin of KITTI 00), together with the corresponding similarities. 146 | 147 | 148 | ### Training 149 | 150 | In the file [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml), `training_seqs` sets the KITTI sequences used for training. 151 | 152 | You can start the training with 153 | 154 | ``` 155 | cd train 156 | python ./training_overlap_mamba_kitti.py 157 | ``` 158 | You can also resume training from our pretrained model [here](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/data_root_folder/pretrained_overlap_mamba.pth.tar). 159 | 160 | 161 | ### Testing 162 | 163 | Once a model has been trained, the performance of the network can be evaluated. Before testing, the following parameters should be set in [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml): 164 | 165 | - `test_seqs`: sequence number for evaluation, which is "00" in our work. 166 | - `test_weights`: path of the pretrained model. 167 | - `gt_file`: path of the ground truth file provided by the authors of OverlapNet, which can be downloaded [here](https://drive.google.com/file/d/1upAwJBF-_UIB7R8evW0PuJBM3RnrTbzl/view?usp=sharing).
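For intuition about what the evaluation measures: each scan is encoded into a single global descriptor, candidates are ranked by the L2 distance between descriptors, and the demo converts a distance `d` into a similarity `1/(1+d)`. The following minimal sketch illustrates this retrieval step with random stand-in descriptors (it is not the repository's actual test code, and the descriptor dimensionality here is arbitrary):

```python
import numpy as np

# Hypothetical stand-ins: one query descriptor and a database of descriptors
# for previously visited scans, as produced by the trained network.
rng = np.random.default_rng(0)
query_des = rng.standard_normal(256).astype(np.float32)
database_des = rng.standard_normal((1000, 256)).astype(np.float32)

# L2 distance between the query descriptor and every database descriptor.
dists = np.linalg.norm(database_des - query_des, axis=1)

# Top-N retrieval: the N closest scans are proposed as loop-closure candidates.
top_n = 5
candidates = np.argsort(dists)[:top_n]

# Similarity score as used in the demo script: 1 / (1 + L2 distance).
similarities = 1.0 / (1.0 + dists[candidates])
print(candidates, similarities)
```

Whether a retrieved candidate counts as a correct match is then decided against the overlap-based ground truth in `gt_file`.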
168 | 169 | 170 | Then you can run the testing scripts as follows: 171 | 172 | ``` 173 | cd test 174 | mkdir test_results_kitti 175 | python test_kitti00_prepare.py 176 | python test_kitti00_PR.py 177 | python test_kitti00_topN.py 178 | ``` 179 | After you run `test_kitti00_prepare.py`, a file named `predicted_des_L2_dis.npz` is generated in `test_results_kitti`. It is used by `test_kitti00_PR.py` to calculate the PR curve and F1max, and by `test_kitti00_topN.py` to calculate the top-N recall. 180 | 181 | 182 | ## Performance 183 | The code of this repo uses a minimal PyTorch implementation of Mamba, which implements the scan operation as a sequential loop. It closely follows the corresponding file of the official Mamba implementation, but its performance is a bit worse than the official implementation. 184 | 185 | After our paper is accepted, we will provide a C++ implementation of OverlapMamba with libtorch for faster retrieval. 186 | 187 | 188 | ## 👏Acknowledgment 189 | This repo is based on [OverlapTransformer](https://github.com/haomo-ai/OverlapTransformer) and [mamba.py](https://github.com/alxndrTL/mamba.py). We are very grateful for their excellent work, 190 | appreciate their contributions to LiDAR-based place recognition (LPR), and highly recommend using their excellent publicly available code. 191 | -------------------------------------------------------------------------------- /config/config.yml: -------------------------------------------------------------------------------- 1 | data_root: 2 | 3 | data_root_folder: "/home/robot/Project/OverlapTransformer/data_root_folder/" # KITTI sequences root 4 | valid_scan_folder: "/home/robot/下载/kitti/sequences/02/velodyne" # KITTI sequence 02 velodyne 5 | gt_valid_folder: "/home/robot/Project/OverlapTransformer/gt_valid_folder/" # KITTI sequence 02 computed overlaps 6 | 7 | # data_root_folder: "/home/robot/Project/OverlapTransformer/dataset-1/" # Ford campus sequences root 8 | # valid_scan_folder: "/home/robot/下载/kitti/sequences/02/velodyne" # KITTI sequence 02 velodyne 9 | # gt_valid_folder: "/home/robot/Project/OverlapTransformer/gt_valid_folder/" # KITTI sequence 02 computed overlaps 10 | 11 | # data_root_folder: "/home/robot/Project/CVTNet/NCLT/" # NCLT sequences root 12 | 13 | 14 | demo1_config: 15 | 16 | calib_file: "/home/robot/Project/OverlapTransformer/data_root_folder/00/calib.txt" # calibration file from KITTI 00 17 | poses_file: "/home/robot/Project/OverlapTransformer/weights/00.txt" # pose file from KITTI 00 18 | test_weights: "/home/robot/Project/OverlapTransformer/weights/pretrained_overlap_transformer.pth.tar" # pretrained model 19 | 20 | 21 | training_config: 22 | 23 | training_seqs: ["03", "04", "05","06", "07", "08", "09", "10"] # KITTI sequences for training 24 | # training_seqs: ["2012-01-08_vel/2012-01-08"] # NCLT sequence for training 25 | 26 | test_config: 27 | 28 | test_seqs: ["00"] # KITTI sequence 00 for evaluation 29 | test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/final_weight/trained_overlap_transformer16.pth.tar" # pretrained model 30 | gt_file: "/home/robot/Project/OverlapTransformer/data_root_folder/loop_gt_seq00_0.3overlap_inactive.npz" # ground truth 31 | 32 | # test_seqs: [ "00" ] # Ford campus sequence 00 for evaluation 33 | # test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/shift_bimamba_sppf_80/trained_overlap_transformer4.pth.tar" # pretrained model 34 | # gt_file: "/home/robot/下载/Ford campus/loop_gt_seq00_0.3overlap_inactive.npz" # 
ground truth 35 | 36 | # test_seqs: ["2012-02-05_vel/2012-02-05"] # KITTI sequence 00 for evaluation 37 | # test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_npy/pretrained_overlap_transformer_haomo22.pth.tar" # pretrained model 38 | # gt_file: "/home/robot/Project/CVTNet/NCLT/2012-02-05_vel/2012-02-05/overlaps/gt.npz" # ground truth 39 | 40 | 41 | viz_config: 42 | 43 | calib_file: "/home/robot/Project/OverlapTransformer/data _root_folder/00/calib.txt" # calibration file from KITTI 00 44 | poses_file: "/home/robot/Project/OverlapTransformer/weights/00.txt" # pose file from KITTI 00 45 | cov_file: "/home/robot/Project/OverlapTransformer/weights/covariance_2nd.txt" # covariance file from SUMA++ on KITTI 00 46 | 47 | -------------------------------------------------------------------------------- /config/config_nclt.yml: -------------------------------------------------------------------------------- 1 | file_root: 2 | data_root_folder: "/home/robot/Project/CVTNet/NCLT/2012-01-08_vel/2012-01-08/depth_map/" 3 | triplets_for_training: "/home/robot/Project/CVTNet/more_chosen_normalized_data_120108.npy" 4 | pose_file_database: "/home/mjy/datasets/haomo_data/stamps_calibrated_poses_1208_1_01.npy" 5 | pose_file_query: "/home/mjy/datasets/haomo_data/stamps_calibrated_poses_1208_1_02.npy" 6 | 7 | 8 | data_root_folder_test: "/home/robot/Project/CVTNet/NCLT/2012-02-05_vel/2012-02-05/depth_map/" 9 | 10 | 11 | test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_o1-shift2/trained_overlap_transformer5.pth.tar" 12 | #gt_file: "/home/robot/Project/CVTNet/gt_120108_120205.npy" 13 | gt_file: "/home/robot/Project/SeqOT-main/data_prepararion/gt/gt_120108_120205.npy" 14 | 15 | training_config: 16 | 17 | training_seqs: ["2012-01-08_vel/2012-01-08"] # KITTI sequences for training -------------------------------------------------------------------------------- /demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo.png -------------------------------------------------------------------------------- /demo/__pycache__/com_overlap.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/__pycache__/com_overlap.cpython-38.pyc -------------------------------------------------------------------------------- /demo/com_overlap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Xieyuanli Chen and Thomas Läbe 3 | # This file is covered by the LICENSE file in the root of this project. 4 | # Brief: This script generate the overlap and orientation combined mapping file. 
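# Summary of the computation below (a sketch of what com_overlap() does): the reference
# scan is transformed into the current scan's frame, both scans are rendered as 64x900
# range images, and the overlap is the fraction of range-image pixels that agree within
# the threshold tau_ (set to 1.2 further down):
#   overlap = count(|ref_range - cur_range| < tau_, over pixels with ref_range > 0) / valid_num
# where valid_num is the number of valid pixels in the current scan's projection.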
5 | import sys 6 | import os 7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 8 | if p not in sys.path: 9 | sys.path.append(p) 10 | from tools.utils.utils import * 11 | 12 | def com_overlap(scan_paths, poses, frame_idx): 13 | # init ground truth overlap and yaw 14 | print('Start to compute ground truth overlap ...') 15 | overlaps = [] 16 | 17 | # we calculate the ground truth for one given frame only 18 | # generate range projection for the given frame 19 | current_points = load_vertex(scan_paths[frame_idx]) 20 | current_range, project_points, _, _ = range_projection(current_points, fov_up=3, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50) 21 | visible_points = project_points[current_range > 0] 22 | valid_num = len(visible_points) 23 | current_pose = poses[frame_idx] 24 | 25 | tau_ = 1.2 26 | print("threshold for overlap: ", tau_) 27 | 28 | reference_range_list = [] 29 | for i in range(len(scan_paths)): 30 | # generate range projection for the reference frame 31 | reference_idx = int(scan_paths[i][-10:-4]) 32 | reference_pose = poses[reference_idx] 33 | reference_points = load_vertex(scan_paths[i]) 34 | 35 | reference_points_world = reference_pose.dot(reference_points.T).T 36 | reference_points_in_current = np.linalg.inv(current_pose).dot(reference_points_world.T).T 37 | reference_range, _, _, _ = range_projection(reference_points_in_current, fov_up=3, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50) 38 | # calculate overlap 39 | overlap = np.count_nonzero( 40 | abs(reference_range[reference_range > 0] - current_range[reference_range > 0]) < tau_) / valid_num 41 | overlaps.append(overlap) 42 | reference_range_list.append(reference_range) 43 | 44 | 45 | # ground truth format: each row contains [current_frame_idx, reference_frame_idx, overlap,] 46 | ground_truth_mapping = np.zeros((len(scan_paths), 3)) 47 | ground_truth_mapping[:, 0] = np.ones(len(scan_paths)) * frame_idx 48 | ground_truth_mapping[:, 1] = np.arange(len(scan_paths)) 49 | ground_truth_mapping[:, 2] = overlaps 50 | 51 | print('Finish generating ground_truth_mapping!') 52 | 53 | return ground_truth_mapping, current_range, reference_range_list 54 | -------------------------------------------------------------------------------- /demo/demo_compute_overlap_sim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 4 | if p not in sys.path: 5 | sys.path.append(p) 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from com_overlap import com_overlap 9 | import yaml 10 | from tools.utils.utils import * 11 | from modules.overlap_mamba import featureExtracter 12 | import torch 13 | 14 | # load config ================================================================ 15 | config_filename = '../config/config.yml' 16 | config = yaml.safe_load(open(config_filename)) 17 | calib_file = config["demo1_config"]["calib_file"] 18 | poses_file = config["demo1_config"]["poses_file"] 19 | test_weights = config["demo1_config"]["test_weights"] 20 | # ============================================================================ 21 | 22 | # load scan paths 23 | scan_folder = "./scans" 24 | scan_paths = load_files(scan_folder) 25 | 26 | # load calibrations 27 | T_cam_velo = load_calib(calib_file) 28 | T_cam_velo = np.asarray(T_cam_velo).reshape((4, 4)) 29 | T_velo_cam = np.linalg.inv(T_cam_velo) 30 | 31 | # load poses 32 | poses = load_poses(poses_file) 33 | pose0_inv = 
np.linalg.inv(poses[0]) 34 | 35 | # for KITTI dataset, we need to convert the provided poses 36 | # from the camera coordinate system into the LiDAR coordinate system 37 | poses_new = [] 38 | for pose in poses: 39 | poses_new.append(T_velo_cam.dot(pose0_inv).dot(pose).dot(T_cam_velo)) 40 | poses = np.array(poses_new) 41 | 42 | # calculate overlap 43 | ground_truth_mapping, current_range, reference_range_list = com_overlap(scan_paths, poses, frame_idx=0) 44 | 45 | # build model and load pretrained weights 46 | amodel = featureExtracter(channels=1, use_transformer=True) 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | amodel.to(device) 49 | print("Loading weights from ", test_weights) 50 | checkpoint = torch.load(test_weights) 51 | amodel.load_state_dict(checkpoint['state_dict']) 52 | amodel.eval() 53 | 54 | overlap_pos = round(ground_truth_mapping[1,-1]*100,2) 55 | overlap_neg = round(ground_truth_mapping[2,-1]*100,2) 56 | 57 | reference_range_pos = reference_range_list[1] 58 | reference_range_neg = reference_range_list[2] 59 | currentrange_neg_tensor = torch.from_numpy(current_range).unsqueeze(0) 60 | currentrange_neg_tensor = currentrange_neg_tensor.unsqueeze(0).cuda() 61 | reference_range_pos_tensor = torch.from_numpy(reference_range_pos).unsqueeze(0) 62 | reference_range_pos_tensor = reference_range_pos_tensor.unsqueeze(0).cuda() 63 | reference_range_neg_tensor = torch.from_numpy(reference_range_neg).unsqueeze(0) 64 | reference_range_neg_tensor = reference_range_neg_tensor.unsqueeze(0).cuda() 65 | 66 | # generate descriptors 67 | des_cur = amodel(currentrange_neg_tensor).cpu().detach().numpy() 68 | des_pos = amodel(reference_range_pos_tensor).cpu().detach().numpy() 69 | des_neg = amodel(reference_range_neg_tensor).cpu().detach().numpy() 70 | 71 | # calculate similarity 72 | dis_pos = np.linalg.norm(des_cur - des_pos) 73 | dis_neg = np.linalg.norm(des_cur - des_neg) 74 | sim_pos = round(1/(1+dis_pos),2) 75 | sim_neg = round(1/(1+dis_neg),2) 76 | 77 | plt.figure(figsize=(8,4)) 78 | plt.subplot(311) 79 | plt.title("query: " + scan_paths[0]) 80 | plt.imshow(current_range) 81 | plt.subplot(312) 82 | plt.title("positive reference: " + scan_paths[1] + " - similarity: " + str(sim_pos)) 83 | plt.imshow(reference_range_list[1]) 84 | plt.subplot(313) 85 | plt.title("negative reference: " + scan_paths[2] + " - similarity: " + str(sim_neg)) 86 | plt.imshow(reference_range_list[2]) 87 | plt.show() 88 | -------------------------------------------------------------------------------- /demo/scans/000000.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000000.bin -------------------------------------------------------------------------------- /demo/scans/000005.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000005.bin -------------------------------------------------------------------------------- /demo/scans/000015.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000015.bin -------------------------------------------------------------------------------- /fig.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/fig.png -------------------------------------------------------------------------------- /mambapy/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/*.npz -------------------------------------------------------------------------------- /mambapy/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Alexandre TL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mambapy/README.md: -------------------------------------------------------------------------------- 1 | # mamba.py 🐍 : a simple and efficient Mamba implementation 2 | A straightfoward implementation of [Mamba](https://arxiv.org/abs/2312.00752) in PyTorch with a simple parallel scan implementation, offering an major speedup over a sequential implementation, as the parallel scan allows the parallelization over the time dimension. 3 | It combines the ease of read with good performances. 4 | 5 | ## Updates 6 | - 09/02/2024 : First part of the performance update. For small sequences (<128), it can speed up training by more than 20% compared to the first version. For setups close to what can found in practice (like in NLP), it can speed up training by 10%. 7 | 8 | - 22/01/2024 : Added a MLX version of `mamba.py`, which supports inference as well as training. This version is similar to PyTorch, and allows Mac users to play around with Mamba models. It was [tested]() on the largest Mamba trained to date (2.8b), as well as 9 | 10 | ___ 11 | ## Overview 12 | 13 | ![speed comparison](assets/speed_comparison.png) 14 | 15 | This graph shows the training time (forward and backward pass) of a single Mamba layer (`d_model=16, d_state=16`) using 3 different methods : `CUDA`, which is the official [Mamba implementation](https://github.com/state-spaces/mamba), `mamba.py`, which is this repo, and `sequential`, which is a sequential (RNN-like) implementation of the selective scan. 16 | 17 | This repo contains a simple and readable code implementing the [Mamba](https://arxiv.org/abs/2312.00752) architecture in pure PyTorch as well as MLX. Its primary goal is educational. 18 | 19 |
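To make the role of the scan concrete, here is a naive sequential reference for the kind of first-order linear recurrence that the model evaluates over time (the function below is only illustrative; it is not the repository's `pscan.py`, which computes the same kind of result with a parallel scan instead of a Python loop):

```python
import torch

def sequential_scan(a, b):
    # Naive reference for h_t = a_t * h_{t-1} + b_t, with h_{-1} = 0.
    # a, b: (B, L, D) tensors; returns h of the same shape.
    h = torch.empty_like(b)
    h_prev = torch.zeros_like(b[:, 0])
    for t in range(b.shape[1]):
        h_prev = a[:, t] * h_prev + b[:, t]  # one recurrence step per time index
        h[:, t] = h_prev
    return h

B, L, D = 2, 64, 16
h = sequential_scan(torch.rand(B, L, D), torch.randn(B, L, D))
assert h.shape == (B, L, D)
```

A parallel scan produces the same `h` but replaces the loop over `L` with O(log L) parallel steps, which is what makes training over the time dimension fast.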

20 | [logo: a python and a mamba]

22 | 23 | The repo is organized as follows : 24 | - `pscan.py` : a PyTorch implementation of Blelloch's parallel scan 25 | - `mamba.py` : the Mamba model, as described in the [paper](https://arxiv.org/abs/2312.00752). It is numerically equivalent (initialization, forward and backward pass). 26 | - `mamba_lm.py` : encapsulates a Mamba model in order to use it as a language model 27 | - `📁 mlx` : basically the same code as above, but in MLX. 28 | - `📁 docs` : a folder containing annotated explanations about the code, focusing on the parallel scan 29 | - `📁 examples` : two examples of how to use the Mamba model in PyTorch. 30 | 31 | ## Usage 32 | 33 | The most basic usage is to use the `Mamba` object ([mamba.py](mamba.py)), which implements a simple Mamba model given a configuration. 34 | No embedding, no head : input is `(B, L, D)` and output is `(B, L, D)` as well. 35 | 36 | ``` 37 | import torch 38 | from mamba import Mamba, MambaConfig 39 | 40 | config = MambaConfig(d_model=16, n_layers=2) 41 | model = Mamba(config) 42 | 43 | B, L, D = 2, 64, 16 44 | x = torch.randn(B, L, D) 45 | y = model(x) 46 | 47 | assert y.shape == x.shape 48 | ``` 49 | 50 | The class `MambaLM` ([mamba_lm.py](mamba_lm.py)) builds on the `Mamba` object and offers a classic API for language models. It can be used as follows : 51 | 52 | ``` 53 | from mamba_lm import MambaLM, MambaLMConfig 54 | 55 | config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=32000) 56 | model = MambaLM(config) 57 | 58 | x = torch.randint(high=32000, size=(16, 64)) 59 | logits = model(x) # (B, L, vocab_size) 60 | ``` 61 | 62 | It simply encapsulates a `Mamba` object with an embedding layer, a final normalization and a language modeling head. 63 | 64 | ## Examples 65 | There are two basics examples available : 66 | - `example_llm.ipynb` : load a Mamba model with pretrained weights (from 130M to 2.8B from HuggingFace) 67 | - `example_e2e_training.ipynb` : an end-to-end training example where a Mamba model is employed as a world model for a simple 3-3 grid game (training is not completed, the model should be larger). 68 | 69 | If you want a full training example (like in llama2.c), you can check the [othello_mamba repo](https://github.com/alxndrTL/othello_mamba) I've done. With this repo, you can train a Mamba from scratch, easily swipe it with a Transformer, come up with your own data, etc ... 70 | 71 | ___ 72 | ## Performances 73 | This section provides a more comprehensive performance comparison between `mamba.py` and the official Mamba implementation. 74 | Overall, as the first graph of this file shows, both have approximately the same asymptotic performance with respect to the sequence length. You can think as `mamba.py` as a regular Transformer implementation, while the official Mamba implementation is more like FlashAttention v1. Both have their owns advantages. 75 | 76 | That being said, does the two implementations have the same asymptotic performances with respect to the other parameters ? 77 | 78 | ##### `d_model` asymptotic performances 79 |

80 | [figure: training time vs. `d_model` — see assets/training_vs_d_model.png]

83 | 84 | We can see that both implementations behave the same as we increase `d_model`. The gap between the two stays roughly the same. (`mamba.py` is overall ~2x slower) 85 | 86 | ##### `d_state` asymptotic performances 87 |

88 | [figure: training time vs. `d_state` — see assets/training_vs_d_state.png]

91 | 92 | This graph is important. We see that here, the asymptotic performance is not the same as we increase `d_state`. For a reminder, `d_state`, or $N$ in the paper, is the state expansion factor : each channel of the input is expanded into $N$ channels of the hidden state. 93 | 94 | Note : the CUDA version doesn't seem to be impacted by the increase of `d_state`. This is because the benchmark was done with a batch size of 1 : the GPU was not at its full capacity and thus the impact of an increased `d_state` isn't visible. The same happens if you have a small model, or a small input length. See [this issue](https://github.com/alxndrTL/mamba.py/issues/8). 95 | 96 | Does it matter in practice ? As of now, all the pretrained Mamba models (up to 2.8B parameters) used `d_state=16`, so this change of performance over `d_state` isn't important in this case. As `d_state` is not something that is supposed to grow (contrary to the seq length or `d_model`), this isn't a catastrophic result, but something to consider. 97 | 98 | However, it is interesting to relate this observation with the claim made by Albert Gu and Tri Dao [Mamba paper](https://arxiv.org/abs/2312.00752) : The main idea is to leverage properties of modern accelerators (GPUs) to materialize the state ℎ only in more efficient levels of the memory hierarchy. 99 | They also describe (Annex D) the main data movements of their selective scan : working mainly in SRAM, they can reduce the memory reads/writes by a factor of $O(N)$. This explains the different asymptotic behaviors that we see here. 100 | 101 | With `d_state=16` (as in `state-spaces/mamba-2.8b-slimpj`), the gap between the two is relatively small, but with `d_state=64` (currently not used in any models), the gap widens. (note the OOM on the second graph) 102 | 103 |

104 | [figure: training time vs. sequence length for different `d_state` values — see assets/training_vs_seqlen_d_state_var.png]

107 | 108 | All the previous graph were computed with a batch size of 1, on a A100 80GB. 109 | It is a measure of both the forward and backward pass of a single Mamba block. 110 | 111 | The previous analysis showed the importance of kernel fusion, which reduces the memory accesses by $O(N)$, which makes the whole process faster. 112 | 113 | But memory requierement should also be considered : the official Mamba implementation uses recomputation in the backward pass : rather than keeping in memory the activations computed during the forward pass, it simply recomputes them in the backward pass, when needed. This greatly reduces the memory requierement of the Mamba model when doing training. This is not implemented in this repo. 114 | 115 | Hence, this repo implements one of the three techniques mentionned in the Mamba paper that form the so called "hardware-aware selective scan" : the parallel scan. 116 | We say how kernel fusion impacts the speed while recomputation the memory requierements. 117 | 118 | ___ 119 | ## Sources and where to learn more 120 | - the [Mamba paper](https://arxiv.org/abs/2312.00752) : describes the Mamba architecture as implemented in this repo, which allows to model sequences in linear time. 121 | - the [Mamba implementation](https://github.com/state-spaces/mamba), which is written in PyTorch but uses a parallel scan written in CUDA. This is the version that is the fastest. 122 | - [a minimal PyTorch implementation of Mamba](https://github.com/johnma2006/mamba-minimal), which implements the scan operation as a sequential loop (its performance are a bit worse than the 'sequential' line in the first graph). This code closely follows [this file](https://github.com/state-spaces/mamba/blob/da2626b5a5f347a8e844ac5e96a2cbcde3c34abb/mamba_ssm/modules/mamba_simple.py) from the officile Mamba implementation, but replaces the CUDA convolution with `torch.nn.Conv1d`, and the selective scan written in CUDA with a sequential loop. The code of this repo follows the structure of these 2 files. 123 | - [Prefix Sums and Their Applications](https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf), by Guy E. Blelloch (1993). 124 | - [Parallelizing Linear Recurrent Neural Nets Over Sequence Length](https://arxiv.org/abs/1709.04057) : applies a parallel scan over the sequence in order to get rid of the sequential for-loop. 125 | - x.com/fchollet : original pscan implementation. 126 | 127 | ## TODOs 128 | - docs 129 | - ~~more tests with an increased `d_model` (add a Performances section)~~ 130 | - ~~a step function, used for (auto-regressive) inference.~~ 131 | - ~~a training function, similar to [llama2.c](https://github.com/karpathy/llama2.c)~~ 132 | 133 | perfs : 134 | - ~~unfold the for-loops in `pscan.py` to achieve better performance (see [François Fleuret's pscan](https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygptrnn.git;a=blob;f=pscan.py;h=0bb0d145bf9c6c82115956c8ce1e6a063e56e747;hb=HEAD)) (although this will sacrifice readability of bit)~~ 135 | ~~- write a reverse parallel scan specifically for the backward pass. (For now, we have to flip the array before and after the scan).~~ 136 | - enable gradient checkpointing to reduce the memory usage 137 | - use torch.compile(). As far as I tested, it doesn’t work for now. It seems it isn’t happy with the custom PScan autograd function. Need to investigate. 
(see [PR#1](https://github.com/alxndrTL/mamba.py/pull/1)) 138 | -------------------------------------------------------------------------------- /mambapy/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/logo.png -------------------------------------------------------------------------------- /mambapy/assets/speed_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/speed_comparison.png -------------------------------------------------------------------------------- /mambapy/assets/training_vs_d_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_d_model.png -------------------------------------------------------------------------------- /mambapy/assets/training_vs_d_state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_d_state.png -------------------------------------------------------------------------------- /mambapy/assets/training_vs_seqlen_d_state_var.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_seqlen_d_state_var.png -------------------------------------------------------------------------------- /mambapy/docs/assets/cumsum_rnns.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/cumsum_rnns.jpg -------------------------------------------------------------------------------- /mambapy/docs/assets/down_sweep_rule.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/down_sweep_rule.jpg -------------------------------------------------------------------------------- /mambapy/docs/assets/downsweep.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep.jpg -------------------------------------------------------------------------------- /mambapy/docs/assets/downsweep_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep_ex.jpg -------------------------------------------------------------------------------- /mambapy/docs/assets/downsweep_updated.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep_updated.jpg -------------------------------------------------------------------------------- 
/mambapy/docs/assets/reduction_tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/reduction_tree.jpg -------------------------------------------------------------------------------- /mambapy/docs/assets/tensor_mem.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tensor_mem.jpeg -------------------------------------------------------------------------------- /mambapy/docs/assets/tensor_mem_tree.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tensor_mem_tree.jpeg -------------------------------------------------------------------------------- /mambapy/docs/assets/tree_reduction.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tree_reduction.jpeg -------------------------------------------------------------------------------- /mambapy/docs/assets/tree_reduction_xs.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tree_reduction_xs.jpeg -------------------------------------------------------------------------------- /mambapy/docs/assets/up_down_trees.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/up_down_trees.jpg -------------------------------------------------------------------------------- /mambapy/examples/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # todo : numpy est une bottleneck ? on a tout qui est en pytorch, on passe par numpy, pour ensuite reconvertir en pytorch pour le training... 4 | # pour la V1 (où la récolte et le training se déroulent pas en mm temps) ce n'est pas très genant 5 | 6 | # todo : ce qui est récolté en même temps est fourni en même temps dans les batch (ie across env)... encore une fois, pg pour la V1 les envs sont continuing (cas très particulier...) 
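# Usage sketch (mirroring example_e2e_training.ipynb): the buffer stores one transition per
# step for all vectorized envs, and sampling returns contiguous time windows per env:
#   buffer = ReplayBuffer(num_envs=512, capacity=int(1e6), obs_dim=25, act_dim=5)
#   buffer.store(obs.view(-1, 25), act, rew.squeeze(1))   # called once per environment step
#   batch = buffer.sample(batch_size=64, batch_len=10)    # dict with 'obs', 'act', 'rew', each (B, batch_len, ...)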
7 | 8 | class ReplayBuffer(): 9 | def __init__(self, num_envs, capacity, obs_dim, act_dim): 10 | self.obs_buffer = np.empty((capacity, num_envs, obs_dim), dtype=np.uint8) 11 | self.act_buffer = np.empty((capacity, num_envs), dtype=np.uint8) # no one-hot 12 | self.rew_buffer = np.empty((capacity, num_envs), dtype=np.float32) 13 | 14 | self.num_envs = num_envs 15 | self.capacity = capacity 16 | self.idx = 0 17 | self.size = 0 18 | self.rng = np.random.default_rng() 19 | 20 | def store(self, obs, act, rew): 21 | # obs : (num_envs, L*L) 22 | # act : (num_envs,) 23 | # rew : (num_envs,) 24 | 25 | self.obs_buffer[self.idx] = obs 26 | self.act_buffer[self.idx] = act 27 | self.rew_buffer[self.idx] = rew 28 | 29 | self.idx = (self.idx + 1) % self.capacity 30 | self.size = min(self.size + 1, self.capacity) 31 | 32 | def sample(self, batch_size, batch_len): 33 | assert self.size >= batch_len, "not enough experience stored" 34 | 35 | start_idxs = self.rng.integers(0, self.size - batch_len, size=batch_size) # (B,) 36 | env_idxs = self.rng.integers(0, self.num_envs, size=batch_size)[:, None] # (B, 1) 37 | 38 | # all indices for sampling : from start_idxs to start_idxs+batch_len 39 | idxs = start_idxs[:, None] + np.arange(batch_len) # (B, batch_len) 40 | 41 | batch_obs = self.obs_buffer[idxs, env_idxs] 42 | batch_acts = self.act_buffer[idxs, env_idxs] 43 | batch_rews = self.rew_buffer[idxs, env_idxs] 44 | 45 | batch = {'obs': batch_obs, 'act': batch_acts, 'rew': batch_rews} 46 | return batch 47 | -------------------------------------------------------------------------------- /mambapy/examples/example_e2e_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn.functional as F\n", 11 | "\n", 12 | "import time\n", 13 | "from IPython.display import clear_output\n", 14 | "\n", 15 | "from example_src.tinyhome import TinyHomeEngineV1, print_grid, print_act\n", 16 | "from example_src.buffer import ReplayBuffer\n", 17 | "\n", 18 | "from mamba_lm import MambaLM, MambaLMConfig" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 4, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "L = 5\n", 37 | "num_actions = 5\n", 38 | "num_obs_type = 4\n", 39 | "\n", 40 | "nb_instances = 512\n", 41 | "steps = 10000\n", 42 | "\n", 43 | "envs = TinyHomeEngineV1(B=nb_instances, h=L, w=L)\n", 44 | "buffer = ReplayBuffer(num_envs=nb_instances, capacity=int(1e6), obs_dim=L*L, act_dim=num_actions)\n", 45 | "\n", 46 | "obs = envs.reset()\n", 47 | "\n", 48 | "for _ in range(steps):\n", 49 | " a = torch.randint(low=0, high=num_actions, size=(nb_instances,))\n", 50 | " next_obs, rew = envs.step(a)\n", 51 | "\n", 52 | " buffer.store(obs.view(-1, L*L), a, rew.squeeze(1))\n", 53 | " obs = next_obs" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 5, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "/home/alex/miniconda3/envs/torch23/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 66 | " from .autonotebook import tqdm as notebook_tqdm\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=num_actions+num_obs_type, pad_vocab_size_multiple=num_actions+num_obs_type)\n", 72 | "model = MambaLM(config).to(device)\n", 73 | "optim = torch.optim.AdamW(model.parameters(), lr=3e-3)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "13664" 85 | ] 86 | }, 87 | "execution_count": 6, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "sum([p.numel() for p in model.parameters()])" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "6.719587326049805\n", 106 | "0.38883084058761597\n", 107 | "0.26524049043655396\n", 108 | "0.22898846864700317\n", 109 | "0.21574825048446655\n", 110 | "0.1874855011701584\n", 111 | "0.1682835817337036\n", 112 | "0.14247804880142212\n", 113 | "0.1338169276714325\n", 114 | "0.10260577499866486\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "for i in range(1000): \n", 120 | " B, T = 64, 10\n", 121 | " batch = buffer.sample(B, T)\n", 122 | "\n", 123 | " obs = torch.tensor(batch['obs']).long().to(device)\n", 124 | " act = torch.tensor(batch['act']).long().to(device)\n", 125 | "\n", 126 | " tokens = torch.cat([obs, torch.zeros(B, T, 1, dtype=torch.int, device='cuda')], dim=2).view(B, 26*T) # (B, 26T)\n", 127 | " tokens[:, 25::26] = act+4\n", 128 | "\n", 129 | " input = tokens\n", 130 | " output = tokens[:, 1:].reshape(-1)\n", 131 | "\n", 132 | " logits = model(tokens[:, :-1]) # (B, 26T-1, vocab_size)\n", 133 | " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), output)\n", 134 | "\n", 135 | " optim.zero_grad()\n", 136 | " loss.backward()\n", 137 | " optim.step()\n", 138 | "\n", 139 | " if i%100==0:\n", 140 | " print(loss.item())" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 15, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "tokens = torch.ones(1, 2, dtype=torch.long).cuda() # (B=1, 2)\n", 150 | "T = 20\n", 151 | "for _ in range(26*T-2):\n", 152 | " logits = model(tokens)[0, -1]\n", 153 | " probs = F.softmax(logits, dim=0)\n", 154 | " sampled = torch.multinomial(probs, num_samples=1, replacement=True)\n", 155 | " tokens = torch.cat([tokens, sampled.view(1, 1)], dim=1)\n", 156 | "tokens = tokens.view(T, 26) # (T, 26)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 16, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "#####\n", 169 | "# #\n", 170 | "# #\n", 171 | "#G@ #\n", 172 | "#####\n", 173 | "\n", 174 | "\n", 175 | "E\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "for timestep in tokens:\n", 181 | " grid = timestep[:-1].view(1, 5, 5)\n", 182 | " a = timestep[-1]-4\n", 183 | "\n", 184 | " clear_output(wait=True)\n", 185 | " print_grid(grid)\n", 186 | " print_act(a.item())\n", 187 | " time.sleep(0.1)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "torch23", 201 | "language": "python", 202 
| "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.11.5" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /mambapy/examples/example_llm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "\n", 11 | "from transformers import AutoTokenizer\n", 12 | "\n", 13 | "from mamba_lm import from_pretrained" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 4, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 7, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "model = from_pretrained('state-spaces/mamba-130m').to(device)\n", 40 | "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 8, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "output = model.generate(tokenizer, \"Mamba is a type of\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 9, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "Mamba is a type of black sheep. Many types of black sheep, from black hares, black mares and black oxen, to Mamba, which is a black sheep. Mamba is the black sheep in Mocha. As they move in and out of one\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "print(output)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "torch23", 80 | "language": "python", 81 | "name": "python3" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.11.5" 94 | } 95 | }, 96 | "nbformat": 4, 97 | "nbformat_minor": 2 98 | } 99 | -------------------------------------------------------------------------------- /mambapy/examples/tinyhome.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import time 6 | 7 | """ 8 | V1 : fully observable, tiny room with a goal generated randomly. 9 | going over the goal gets the agent a reward and spawn a new goal. 10 | the episodes ends after T steps. 
11 | 12 | it is convenient because: 13 | - all envs end at the same time (both convenient for the env engine, AND for the training of the transformer : no padding needed) 14 | - no obstacles to handle (convenient in the step() func) 15 | - continuing, so its easier for the replay buffer (see buffer.py) 16 | """ 17 | class TinyHomeEngineV1: 18 | def __init__(self, B, h=10, w=10, max_envs_disp=4): 19 | self.B = B 20 | self.h = h 21 | self.w = w 22 | self.max_envs_disp = max_envs_disp 23 | 24 | def reset(self): 25 | self.grid = torch.zeros(self.B, self.h, self.w, dtype=torch.int) 26 | self.grid[:, 0, :] = 1 27 | self.grid[:, -1, :] = 1 28 | self.grid[:, :, 0] = 1 29 | self.grid[:, :, -1] = 1 30 | 31 | self.pos_player = torch.randint(low=1, high=self.h-1, size=(self.B, 2)) 32 | self.pos_goal = torch.randint(low=1, high=self.h-1, size=(self.B, 2)) 33 | 34 | while True: 35 | overlap = torch.all(self.pos_player == self.pos_goal, dim=1) 36 | if not overlap.any(): 37 | break 38 | self.pos_goal[overlap] = torch.randint(low=1, high=self.h-1, size=(overlap.sum(), 2)) 39 | 40 | disp_grid = self.grid.clone() 41 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2 42 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3 43 | 44 | """ 45 | x = F.one_hot(self.pos_player[:, 0]-1, num_classes=3) 46 | y = F.one_hot(self.pos_player[:, 1]-1, num_classes=3) 47 | u = F.one_hot(self.pos_goal[:, 0]-1, num_classes=3) 48 | v = F.one_hot(self.pos_goal[:, 1]-1, num_classes=3) 49 | 50 | concatenated = torch.cat([x, y, u, v], dim=1) # (B, 12) 51 | """ 52 | 53 | return disp_grid 54 | 55 | def optimal_policy_vectorized(self, moves): 56 | B, _ = self.pos_player.shape 57 | 58 | # Expand pos_player to (B, 5, 2) to match the moves 59 | expanded_pos_player = self.pos_player.unsqueeze(1).expand(-1, moves.size(0), -1) 60 | 61 | # Compute new positions for each move 62 | new_positions = expanded_pos_player + moves 63 | new_positions = new_positions.clamp(min=1, max=self.h-2) 64 | 65 | # Calculate Manhattan distances for each new position 66 | distances = torch.sum(torch.abs(new_positions - self.pos_goal.unsqueeze(1)), dim=2) 67 | 68 | # Find the move with the minimum distance for each environment 69 | actions = torch.argmin(distances, dim=1) 70 | 71 | return actions 72 | 73 | def step(self, a): 74 | # a : (B,) 75 | 76 | moves = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]]) # X, N, E, S, W 77 | 78 | #a = self.optimal_policy_vectorized(moves) 79 | 80 | self.pos_player += moves[a] 81 | self.pos_player = self.pos_player.clamp(min=1, max=self.h-2) # pas du tout généralisable à des murs placés au milieu etc etc 82 | 83 | reached_goal = torch.all(self.pos_player == self.pos_goal, dim=1) 84 | reward = torch.where(reached_goal, 1., 0.).unsqueeze(1) 85 | 86 | # regen goal (only for "completed" env) 87 | num_reached = reached_goal.sum() 88 | if num_reached > 0: 89 | self.pos_goal[reached_goal] = torch.randint(low=1, high=self.h-1, size=(num_reached, 2)) 90 | 91 | # make sure that the regenerated goals are at a different place 92 | while True: 93 | overlap = torch.all(self.pos_player == self.pos_goal, dim=1) 94 | if not overlap.any(): 95 | break 96 | self.pos_goal[overlap] = torch.randint(low=1, high=self.h-1, size=(overlap.sum(), 2)) 97 | 98 | disp_grid = self.grid.clone() 99 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2 100 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3 101 | 102 | """ 103 | x = 
F.one_hot(self.pos_player[:, 0]-1, num_classes=3) 104 | y = F.one_hot(self.pos_player[:, 1]-1, num_classes=3) 105 | u = F.one_hot(self.pos_goal[:, 0]-1, num_classes=3) 106 | v = F.one_hot(self.pos_goal[:, 1]-1, num_classes=3) 107 | 108 | concatenated = torch.cat([x, y, u, v], dim=1) # (B, 12) 109 | """ 110 | 111 | return disp_grid, reward 112 | 113 | def display(self): 114 | os.system('cls' if os.name == 'nt' else 'clear') 115 | 116 | disp_grid = self.grid.clone() 117 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2 118 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3 119 | 120 | for b in range(min(self.B, self.max_envs_disp)): 121 | for row in disp_grid[b]: 122 | print(''.join(display_mapping.get(value.item(), '?') for value in row)) 123 | 124 | print("\n") 125 | 126 | display_mapping = { 127 | 0: ' ', 128 | 1: '#', 129 | 2: '@', 130 | 3: 'G' 131 | } 132 | 133 | def print_grid(grid): 134 | for b in range(grid.shape[0]): 135 | for row in grid[b]: 136 | print(''.join(display_mapping.get(value.item(), '?') for value in row)) 137 | 138 | print("\n") 139 | 140 | actions_to_char = ['X', 'N', 'E', 'S', 'W'] 141 | def print_act(act): 142 | print(actions_to_char[act]) 143 | 144 | if __name__ == "__main__": 145 | mode = "display" # "display" or "grind" or "collect" 146 | 147 | if mode == "display": 148 | nb_instances = 2 149 | steps = 10 150 | 151 | engine = TinyHomeEngineV1(nb_instances, 5, 5) 152 | engine.reset() 153 | 154 | engine.display() 155 | time.sleep(0.5) 156 | 157 | for _ in range(steps): 158 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,))) 159 | print(obs) 160 | #engine.display() 161 | print(rew.shape) 162 | time.sleep(0.05) 163 | 164 | elif mode == "grind": 165 | nb_instances = 1000 166 | steps = 1000 167 | 168 | engine = TinyHomeEngineV1(nb_instances) 169 | engine.reset() 170 | 171 | start_time = time.perf_counter() 172 | 173 | for _ in range(steps): 174 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,))) 175 | 176 | end_time = time.perf_counter() 177 | 178 | print(f"The collection of {nb_instances*steps} steps took {end_time-start_time} seconds") 179 | 180 | elif mode == "collect": 181 | nb_instances = 2 182 | steps = 1 183 | 184 | embed = torch.nn.Embedding(num_embeddings=6, embedding_dim=2) 185 | 186 | engine = TinyHomeEngineV1(nb_instances, 5, 5) 187 | engine.reset() 188 | 189 | for _ in range(steps): 190 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,))) 191 | # obs: (B, h, w), rew: (B, 1) 192 | 193 | obs = obs.view(nb_instances, 25) # (B, h*w) 194 | 195 | e = embed(obs) # (B, h*w, embed_dim) 196 | print(e.shape) -------------------------------------------------------------------------------- /mambapy/mamba_lm.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields, asdict 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from mamba import Mamba, MambaConfig, RMSNorm 9 | 10 | """ 11 | 12 | Encapsulates a Mamba model as language model. It has an embedding layer, and a LM head which maps the model output to logits. 13 | 14 | """ 15 | 16 | # TODO generate function : batch size != 1 ? 
(for now B=1) 17 | # TODO generate function : top-p sampling 18 | 19 | @dataclass 20 | class MambaLMConfig(MambaConfig): 21 | vocab_size: int = 32000 22 | pad_vocab_size_multiple: int = 8 23 | 24 | def __post_init__(self): 25 | super().__post_init__() 26 | 27 | if self.vocab_size % self.pad_vocab_size_multiple != 0: 28 | self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple) 29 | 30 | def to_mamba_config(self) -> MambaConfig: 31 | mamba_config_fields = {field.name for field in fields(MambaConfig)} 32 | filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields} 33 | return MambaConfig(**filtered_dict) 34 | 35 | # adapted from https://github.com/johnma2006/mamba-minimal 36 | def from_pretrained(name: str): 37 | """ 38 | Returns a model loaded with pretrained weights pulled from HuggingFace. 39 | 40 | Args: 41 | name: As of now, supports 42 | * 'state-spaces/mamba-2.8b-slimpj' 43 | * 'state-spaces/mamba-2.8b' 44 | * 'state-spaces/mamba-1.4b' 45 | * 'state-spaces/mamba-790m' 46 | * 'state-spaces/mamba-370m' 47 | * 'state-spaces/mamba-130m' 48 | 49 | Returns: 50 | model: a Mamba model configured with the proper parameters and initialized with the proper weights 51 | """ 52 | 53 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 54 | from transformers.utils.hub import cached_file 55 | 56 | def load_config_hf(model_name): 57 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 58 | return json.load(open(resolved_archive_file)) 59 | 60 | def load_state_dict_hf(model_name): 61 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 62 | return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) 63 | 64 | # copy config data 65 | config_data = load_config_hf(name) 66 | config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size']) 67 | 68 | model = MambaLM(config) 69 | 70 | # copy weights 71 | state_dict = load_state_dict_hf(name) 72 | 73 | new_state_dict = {} 74 | for key in state_dict: 75 | if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight': 76 | new_key = key.replace('backbone.', '') 77 | else: 78 | new_key = key.replace('backbone', 'mamba') 79 | 80 | new_state_dict[new_key] = state_dict[key] 81 | 82 | model.load_state_dict(new_state_dict) 83 | 84 | return model 85 | 86 | class MambaLM(nn.Module): 87 | def __init__(self, lm_config: MambaLMConfig): 88 | super().__init__() 89 | self.lm_config = lm_config 90 | self.config = lm_config.to_mamba_config() 91 | 92 | self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model) 93 | self.mamba = Mamba(self.config) 94 | self.norm_f = RMSNorm(self.config.d_model) 95 | 96 | self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False) 97 | self.lm_head.weight = self.embedding.weight 98 | 99 | def forward(self, tokens): 100 | # tokens : (B, L) 101 | 102 | # logits : (B, L, vocab_size) 103 | 104 | x = self.embedding(tokens) 105 | 106 | x = self.mamba(x) 107 | x = self.norm_f(x) 108 | 109 | logits = self.lm_head(x) 110 | 111 | return logits 112 | 113 | def step(self, token, caches): 114 | # token : (B) 115 | # caches : [cache(layer) for all layers], cache : (h, inputs) 116 | 117 | # logits : (B, vocab_size) 118 | # caches : [cache(layer) for all layers], cache : (h, inputs) 119 | 120 | x = self.embedding(token) 121 | 122 | x, 
caches = self.mamba.step(x, caches) 123 | x = self.norm_f(x) 124 | 125 | logits = self.lm_head(x) 126 | 127 | return logits, caches 128 | 129 | # TODO temperature 130 | # TODO process prompt in parallel, and pass in sequential mode when prompt is finished ? 131 | def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40): 132 | self.eval() 133 | 134 | input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens) 135 | 136 | # caches is a list of cache, one per layer 137 | # cache is composed of : the hidden state, and the last d_conv-1 inputs 138 | # the hidden state because the update is like an RNN 139 | # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large) 140 | caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)] 141 | 142 | for i in range(input_ids.size(1) + num_tokens - 1): 143 | with torch.no_grad(): 144 | # forward the new output, get new cache 145 | next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches 146 | 147 | # sample (no sampling when the prompt is being processed) 148 | if i+1 >= input_ids.size(1): 149 | probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size) 150 | 151 | if top_k is not None: 152 | values, _ = torch.topk(probs, k=top_k) # (1, k) ordered from lowest to biggest 153 | probs[probs < values[:, -1, None]] = 0 154 | probs = probs / probs.sum(axis=1, keepdims=True) 155 | 156 | if sample: 157 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1) 158 | else: 159 | next_token = torch.argmax(probs, dim=-1) # (1) 160 | 161 | input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1) 162 | 163 | output = [tokenizer.decode(output.tolist()) for output in input_ids][0] 164 | 165 | self.train() 166 | 167 | return output 168 | -------------------------------------------------------------------------------- /mambapy/mlx/README.md: -------------------------------------------------------------------------------- 1 | # MLX implementation of Mamba 🐍 2 | 3 | This folder contains a complete MLX implementation of [Mamba](https://arxiv.org/abs/2312.00752), which allows to train and do inference with Mamba models using an Apple silicon equiped Mac. 4 | Both the forward and backward pass are numerically equivalent to the PyTorch code from `mamba.py`, as well as to the official [Mamba implementation](https://github.com/state-spaces/mamba). 5 | 6 |

7 | ![a python and a mamba](assets/mamba_mlx.png)

9 | 10 | The folder is organized as follows : 11 | - `pscan_mlx.py` : a MLX implementation of Blelloch's parallel scan. 12 | - `mamba_mlx.py` : the Mamba model, as described in the [paper](https://arxiv.org/abs/2312.00752). It is numerically equivalent (initialization, forward and backward pass). 13 | - `mamba_lm_mlx.py` : encapsulates a Mamba model in order to use it as a language model. 14 | - `utils.py` : utilitary functions 15 | - `misc.py` : a temporary file containing functions not yet implemented in MLX 16 | - `📁 scripts` : example scripts to play around with Mamba. 17 | 18 | # Quickstart 19 | Make sur you have `mlx`, `torch` and `transformers` installed. 20 | 21 | First, you can clone the repo : 22 | 23 | ``` 24 | git clone https://github.com/alxndrTL/mamba.py.git 25 | cd mamba.py/mlx 26 | ``` 27 | 28 | If you want to do inference with pretrained models (from 130M to 2.8B parameters), you can simply do : 29 | 30 | ``` 31 | cd scripts 32 | python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100 33 | ``` 34 | 35 | It will download the specified model from [HuggingFace](https://huggingface.co/state-spaces), convert it (and save it to disk) to run with MLX, and stream generated words. 36 | As of now, you can choose from : 37 | 38 | ``` 39 | state-spaces/mamba-130m 40 | state-spaces/mamba-370m 41 | state-spaces/mamba-790m 42 | state-spaces/mamba-1.4b 43 | state-spaces/mamba-2.8b 44 | state-spaces/mamba-2.8b-slimpj 45 | ``` 46 | 47 | As of today, only single precision inference is supported. On an M2 Pro (16GB), the 790M model runs at ~30tok/s. 48 | 49 | Unlike the Transformers, inference doesn't depend on the sequence length, so we just have to carry along a hidden state 😎 (and the last `d_conv-1` inputs, where `d_conv` is usually 4). 50 | 51 | As of now, `generate.py` is the only available script. But you can train the model using your own script, just like you would with a Transformer. 52 | 53 | # About 54 | Mamba is a new state-space model that is able to do sequence modeling - just like Transformers do. 55 | While Transformers use attention to flow information through time, Mamba uses a simple hidden state, just like RNNs. It has the benefit of a constant inference time wrt. sequence length. 56 | What is important to know is that while it uses a hidden state that is updated sequentially through time : 57 | 58 | $$ 59 | h_t = A h_{t-1} + Bx_t 60 | $$ 61 | 62 | all the $h_t$ can actually be computed in parallel, thanks to an algorithm named the parallel scan, implemented in `pscan_mlx.py` in MLX. 63 | You can learn more about this algorithm and its implementation in `docs/pscan.ipynb` at the root of this repo. 64 | As you can see on the graph shown on the landing page of this repo, the naive sequential implementation is way slower than implementations than use this parallel scan. 65 | 66 | However, it's important to note that while the parallel scan gives correct computations with MLX, it's slow, so slow that it is sometimes actually harmful to use it. 67 | Why ? It is not yet clear. When translating the algorithm from PyTorch to MLX, a little modification is needed : at each iteration, we need to write back to our original arrays the numbers we computed. This is because MLX doesn't have views implemented (yet?). (see [this issue](https://github.com/ml-explore/mlx/issues/466)). 
I thus switched to a version which only uses slicing (see `pscan_mlx.py` for more details), but the performance still lags behind the sequential version (it should ideally be orders of magnitude faster). 68 | 69 | But MLX is not even 2 months old :) 70 | I will closely follow MLX development to watch for potential upgrades of this MLX implementation. 71 | 72 | # Why [mamba.py](../) in MLX ? 73 | While the primary goal of the PyTorch version is educational, this implementation (with a performant parallel scan) could power future fine-tuning scripts. We are early, as there are still not many resources about fine-tuned Mamba models (see [this](https://github.com/havenhq/mamba-chat)). MLX doesn't yet have an associative scan operation implemented. 74 | 75 | Also, the more people play around with and train Mamba models, the better we will understand its strengths and limits, allowing us to compare it against its "competitors" (Based, RWKV, StripedHyena, or the Transformer). 76 | 77 | And finally, it was a great exercise for me, after having implemented Mamba in PyTorch and not knowing MLX. 78 | 79 | # TODOs 80 | - fix large memory footprint at inference ([issue](https://github.com/alxndrTL/mamba.py/issues/5)) 81 | - add more ready-to-go scripts (training and fine-tuning) 82 | - support for mixed precision training ? (see [this](https://github.com/state-spaces/mamba/tree/main?tab=readme-ov-file#precision) from the official Mamba implementation) 83 | - set device (cpu and gpu) (see [A Simple Example](https://ml-explore.github.io/mlx/build/html/usage/unified_memory.html#a-simple-example) from the MLX docs) 84 | - see TODOs of the PyTorch versions 85 | - watch out for new MLX updates ;) 86 | 87 | Feel free to contribute ! -------------------------------------------------------------------------------- /mambapy/mlx/assets/mamba_mlx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/mlx/assets/mamba_mlx.png -------------------------------------------------------------------------------- /mambapy/mlx/mamba_lm_mlx.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields, asdict 2 | import json 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from mamba_mlx import Mamba, MambaConfig 8 | from misc import topk 9 | from utils import load_config_hf, load_state_dict_hf 10 | 11 | """ 12 | 13 | Encapsulates a Mamba model as a language model. It has an embedding layer, and an LM head which maps the model output to logits. 14 | 15 | """ 16 | 17 | # TODO generate function : batch size != 1 ?
(for now B=1) 18 | # TODO generate function : top-p sampling 19 | 20 | @dataclass 21 | class MambaLMConfig(MambaConfig): 22 | vocab_size: int = 32000 23 | pad_vocab_size_multiple: int = 8 24 | 25 | def __post_init__(self): 26 | super().__post_init__() 27 | 28 | if self.vocab_size % self.pad_vocab_size_multiple != 0: 29 | self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple) 30 | 31 | def to_mamba_config(self) -> MambaConfig: 32 | mamba_config_fields = {field.name for field in fields(MambaConfig)} 33 | filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields} 34 | return MambaConfig(**filtered_dict) 35 | 36 | class MambaLM(nn.Module): 37 | def __init__(self, lm_config: MambaLMConfig): 38 | super().__init__() 39 | self.lm_config = lm_config 40 | self.config = lm_config.to_mamba_config() 41 | 42 | self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model) 43 | self.mamba = Mamba(self.config) 44 | self.norm_f = nn.RMSNorm(self.config.d_model) 45 | 46 | self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False) 47 | self.lm_head.weight = self.embedding.weight #TODO this does not really tie the weights, investigate 48 | 49 | def __call__(self, tokens): 50 | # tokens : (B, L) 51 | 52 | # logits : (B, L, vocab_size) 53 | 54 | x = self.embedding(tokens) 55 | 56 | x = self.mamba(x) 57 | x = self.norm_f(x) 58 | 59 | logits = self.lm_head(x) 60 | 61 | return logits 62 | 63 | def step(self, token, caches): 64 | # token : (B) 65 | # caches : [cache(layer) for all layers], cache : (h, inputs) 66 | 67 | # logits : (B, vocab_size) 68 | # caches : [cache(layer) for all layers], cache : (h, inputs) 69 | 70 | x = self.embedding(token) 71 | 72 | x, caches = self.mamba.step(x, caches) 73 | x = self.norm_f(x) 74 | 75 | logits = self.lm_head(x) 76 | 77 | return logits, caches 78 | 79 | def generate(self, tokenizer, prompt: str, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): 80 | self.eval() 81 | 82 | input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, tokens_prompt) # (1, num_tokens) 83 | 84 | # caches is a list of cache, one per layer 85 | # cache is composed of : the hidden state, and the last d_conv-1 inputs 86 | # the hidden state because the update is like an RNN 87 | # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large) 88 | caches = [(None, mx.zeros([1, self.config.d_conv-1, self.config.d_inner])) for _ in range(self.config.n_layers)] 89 | 90 | for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): 91 | next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches 92 | 93 | # sample (no sampling when the prompt is being processed) 94 | if i+1 >= input_ids.shape[1]: 95 | 96 | if top_k is not None: 97 | values = topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest 98 | mask = next_token_logits < (values[:, 0, None]) 99 | next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now 100 | 101 | if sample and temperature > 0: 102 | next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) 103 | else: 104 | next_token = mx.argmax(next_token_logits, axis=-1)[:, None] 105 | 106 | input_ids = mx.concatenate([input_ids, next_token], axis=1) 107 | 108 | output = [tokenizer.decode(output.tolist()) for output in input_ids][0] 109 | 110 | self.train() 111 | 112 | 
return output 113 | 114 | @staticmethod 115 | def from_pretrained(name: str): 116 | """ 117 | Returns a model loaded with pretrained weights pulled from HuggingFace. 118 | 119 | Args: 120 | name: As of now, supports 121 | * 'state-spaces/mamba-2.8b-slimpj' 122 | * 'state-spaces/mamba-2.8b' 123 | * 'state-spaces/mamba-1.4b' 124 | * 'state-spaces/mamba-790m' 125 | * 'state-spaces/mamba-370m' 126 | * 'state-spaces/mamba-130m' 127 | 128 | Returns: 129 | model: a Mamba model configured with the proper parameters and initialized with the proper weights 130 | """ 131 | 132 | import os 133 | import numpy as np 134 | from mlx.utils import tree_unflatten 135 | 136 | from utils import map_mambassm_torch_to_mlx 137 | 138 | # copy config data 139 | config_data = load_config_hf(name) 140 | config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size']) 141 | 142 | model = MambaLM(config) 143 | 144 | # copy weights 145 | filename = name.split('/')[-1] + '.mlx.npz' 146 | 147 | if not os.path.exists(filename): 148 | state_dict = load_state_dict_hf(name) 149 | mlx_state_dict = map_mambassm_torch_to_mlx(state_dict) 150 | 151 | np.savez(filename, **mlx_state_dict) 152 | 153 | model.update(tree_unflatten(list(mx.load(filename).items()))) 154 | 155 | return model 156 | -------------------------------------------------------------------------------- /mambapy/mlx/misc.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | 4 | import torch 5 | 6 | """ 7 | 8 | This is a temporary file, as it contains additional functions which are needed but not yet implemented in MLX (as of release v0.0.10). 9 | The first functions are straightforward, while the depthwise 1d convolution is a bit more elaborared. 10 | 11 | """ 12 | 13 | def softplus(x, beta=1, threshold=20): 14 | scaled_x = beta * x 15 | mask = scaled_x > threshold 16 | return mx.where(mask, x, 1/beta * mx.logaddexp(0, x)) 17 | 18 | def unsqueeze(x, axis): 19 | """ 20 | Same API as PyTorch. 21 | """ 22 | 23 | assert axis <= len(x.shape) 24 | if axis >= 0: 25 | new_shape = x.shape[:axis] + [1] + x.shape[axis:] 26 | else: 27 | new_shape = x.shape + [1] 28 | return x.reshape(new_shape) 29 | 30 | def clamp(x, min=None, max=None): 31 | if min is not None: 32 | mask_lower = x < min 33 | if max is not None: 34 | mask_upper = x > max 35 | 36 | if min is not None: 37 | if max is not None: 38 | return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) 39 | return mx.where(mask_lower, min, x) 40 | 41 | return mx.where(mask_upper, max, x) 42 | 43 | def topk(x, k): 44 | """ 45 | Returns the top k biggest values of x along the 2nd dim. 46 | 47 | Args: 48 | x : (B, vocab_size). can be probs or logits 49 | 50 | Returns: 51 | values : (B, k). 
ordered from lowest to biggest val 52 | """ 53 | 54 | return mx.sort(x)[:, -k:] 55 | 56 | class DepthWiseConv1d(nn.Module): 57 | def __init__(self, channels, kernel_size, bias, padding): 58 | super().__init__() 59 | 60 | self.channels = channels 61 | self.kernel_size = kernel_size 62 | self.bias = bias 63 | self.padding = padding 64 | 65 | self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, 66 | kernel_size=kernel_size, bias=True, padding=padding) 67 | 68 | # see comment below 69 | indices = mx.arange(channels) 70 | mask = mx.zeros_like(self.conv1d.weight) 71 | mask[indices, :, indices] = 1 72 | self.conv1d.weight *= mask 73 | 74 | def __call__(self, x): 75 | return self.conv1d(x) 76 | 77 | def torch_to_mlx_depthwise_weights(torch_weights): 78 | """ 79 | 80 | A convolution is said to be "depthwise" when channel i of the output is only computed by passing the filter overing channel i of the input. 81 | In torch, this is done by setting groups=number of channels. 82 | Because it is not yet implemented in MLX, a workaround is to zero out the weights of a conv object initialized with groups=1 (groups=1 is when output channel i is computing by passing the filter over all input channels) 83 | To do that, we need to zero out all elements except those on the "diagonal": 84 | for channels=8 and kernel_size=4, the weights are (8, 4, 8). 85 | these are composed of 8 x (8, 4, 1) filter, each of those is used to compute one output channel. 86 | this (8, 4, 1) filter is composed of 8 x (1, 4, 1) filter, each of those is passed over each input channel. 87 | so we need to set to 0 all those 8 filters, except the one which corresponds to the output channel of these 8 filters (so that the channels don't mix) 88 | 89 | """ 90 | 91 | # torch_weights : (channels, 1, kernel_size) = (ED, 1, d_conv) 92 | 93 | # mlx_weights : (channels, kernel_size, channels) = (ED, d_conv, ED) 94 | 95 | torch_weights = torch_weights.transpose(2, 1) # (channels, kernel_size, 1) = (ED, d_conv, 1) 96 | channels, kernel_size, _ = torch_weights.shape 97 | 98 | mlx_weights = torch.zeros(channels, kernel_size, channels) 99 | 100 | indices = torch.arange(channels) 101 | if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': 102 | mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() 103 | else: 104 | mlx_weights[indices, :, indices] = torch_weights[:, :, 0] 105 | 106 | return mlx_weights 107 | -------------------------------------------------------------------------------- /mambapy/mlx/pscan_mlx.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mlx.core as mx 3 | 4 | """ 5 | 6 | An implementation of the parallel scan algorithm in MLX (Blelloch version). 7 | The PyTorch implementation is easier to read. 8 | 9 | In a few words, this algorithm computes the sequence H[t] = A[t] * H[t-1] + X[t] in parallel for all H[t]. 10 | This repaces the naive sequential way of computing this sequence : first H[0], then H[1] and so on. 11 | 12 | If you want more explanation about what's happening here, please see docs/pscan.ipynb. 13 | 14 | There are a few points which are different from PyTorch : 15 | - when taking a reshape, we have a new tensor rather than a simple view of it. 16 | this is quite problematic because this algorithm works with in-place updates 17 | thus, at the end of each iteration (down sweep and up sweep) we actually need to write back to A and X the computations done in the iteration. 
18 | - there is no need for hand-written backward computation ! 19 | 20 | Unfortunately, this parallel scan implementation is not worth it (compared to sequential implementation). 21 | From the different tests I've done, I suspect that it is partly caused by all the re-write we have to do at the end of each iteration. Still, turning them off does not make the pscan 22 | as fast as in PyTorch (relative to the sequential mode). There is a second version of this pscan (see the commented function) which works with indices only, but still is not competitive. 23 | 24 | This is *very different* from what is observed in PyTorch, where the pscan is on the same order of magnitude as the official Mamba impementation (for d_state=16), and orders of 25 | magnitude faster than the sequential mode. 26 | 27 | (Tests were done with a M2 Pro, 16GB) 28 | 29 | """ 30 | 31 | def pscan_f(A, X): 32 | # A : (B, D, L, N) 33 | # X : (B, D, L, N) 34 | 35 | # modifies X in place by doing a parallel scan. 36 | # more formally, X will be populated by these values : 37 | # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 38 | # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) 39 | 40 | Aa = A 41 | Xa = X 42 | 43 | B, D, L, _ = A.shape 44 | 45 | num_steps = int(math.log2(L)) 46 | 47 | # up sweep 48 | for k in range(num_steps): 49 | T = 2 * (Xa.shape[2] // 2) 50 | 51 | Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) 52 | Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) 53 | 54 | Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] 55 | Aa[:, :, :, 1] *= Aa[:, :, :, 0] 56 | 57 | A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] 58 | X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] 59 | 60 | Aa = Aa[:, :, :, 1] 61 | Xa = Xa[:, :, :, 1] 62 | 63 | # down sweep 64 | for k in range(num_steps-1, -1, -1): 65 | Aa = A[:, :, 2**k-1::2**k] 66 | Xa = X[:, :, 2**k-1::2**k] 67 | 68 | step_len = Xa.shape[2] 69 | T = 2 * (step_len // 2) 70 | 71 | if T < step_len: 72 | last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] 73 | last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] 74 | 75 | Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) 76 | Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) 77 | 78 | Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] 79 | Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] 80 | 81 | if T == step_len: 82 | A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] 83 | X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] 84 | else: 85 | A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) 86 | X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) 87 | 88 | # main function, used in the Mamba model (mamba_mlx.py) 89 | def pscan(A_in, X_in): 90 | """ 91 | Applies the parallel scan operation, as defined above. Returns a new tensor. 
92 | 93 | Args: 94 | A_in : (B, L, ED, N) 95 | X_in : (B, L, ED, N) 96 | 97 | Returns: 98 | H : (B, L, ED, N) 99 | """ 100 | 101 | A = A_in[:].transpose(0, 2, 1, 3) 102 | X = X_in[:].transpose(0, 2, 1, 3) 103 | 104 | pscan_f(A, X) 105 | 106 | return X.transpose(0, 2, 1, 3) 107 | 108 | """ 109 | def pscan_f(A, X): 110 | # A : (B, D, L, N) 111 | # X : (B, D, L, N) 112 | 113 | # This functions is numerically equivalent to the preivous one, but instead of creating new arrays (Aa and Xa) at each iterations, it simply 114 | # updates in-place A and X at the correct indices (hence the quite not-understandable code) 115 | # (it only works with L being a power of 2) 116 | 117 | # While being faster than the previous one (~4x), it stills is not competitive with the naive sequential implementation 118 | 119 | _, _, L, _ = A.shape 120 | 121 | num_steps = int(math.log2(L)) 122 | 123 | for k in range(0, num_steps): 124 | temp = 2**(k+1) 125 | X[:, :, temp-1::temp] += A[:, :, temp-1::temp] * X[:, :, 2**k-1::temp] 126 | A[:, :, temp-1::temp] *= A[:, :, 2**k-1::temp] 127 | 128 | for k in range(num_steps, -1, -1): 129 | temp = 2**(k+1) 130 | X[:, :, 3*2**k-1::temp] += A[:, :, 3*2**k-1::temp] * X[:, :, temp-1:L-2**k:temp] 131 | A[:, :, 3*2**k-1::temp] *= A[:, :, temp-1:L-2**k:temp] 132 | """ -------------------------------------------------------------------------------- /mambapy/mlx/scripts/generate.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 3 | 4 | import warnings 5 | warnings.filterwarnings('ignore') 6 | 7 | import argparse 8 | 9 | import mlx.core as mx 10 | 11 | import transformers 12 | transformers.logging.set_verbosity_error() 13 | from transformers import AutoTokenizer 14 | 15 | from mamba_lm_mlx import MambaLM 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--hf_model_name', type=str, default='state-spaces/mamba-130m') 20 | parser.add_argument('--model_dir', type=str, default=None, help='local model to load. 
overwrites hf_model_name') 21 | parser.add_argument('--prompt', type=str, default='Mamba is a type of') 22 | parser.add_argument('--n_tokens', type=int, default=50, help='number of tokens to generate') 23 | parser.add_argument('--temperature', type=float, default=1.0) 24 | parser.add_argument('--top_k', type=int, default=None, help='top_k sampling : only sample from the top k most probable tokens') 25 | 26 | args = parser.parse_args() 27 | 28 | if args.model_dir is not None: 29 | raise NotImplementedError 30 | else: 31 | model = MambaLM.from_pretrained(args.hf_model_name) 32 | tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') 33 | 34 | #print(model.config) 35 | 36 | mx.set_default_device(mx.gpu) 37 | 38 | output = model.generate(tokenizer, args.prompt, n_tokens_to_gen=args.n_tokens, temperature=args.temperature, top_k=args.top_k) 39 | 40 | print(output) 41 | -------------------------------------------------------------------------------- /mambapy/mlx/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Union 3 | 4 | import numpy as np 5 | import mlx.core as mx 6 | import torch 7 | 8 | from misc import torch_to_mlx_depthwise_weights 9 | 10 | # TODO : map_mlx_to_mambapy_torch 11 | 12 | def load_config_hf(model_name): 13 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 14 | from transformers.utils.hub import cached_file 15 | 16 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 17 | return json.load(open(resolved_archive_file)) 18 | 19 | def load_state_dict_hf(model_name): 20 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 21 | from transformers.utils.hub import cached_file 22 | 23 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 24 | return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True) 25 | 26 | def map_mambapy_torch_to_mlx(torch_state_dict): 27 | new_state_dict = {} 28 | for key, value in torch_state_dict.items(): 29 | 30 | # from torch to mlx, we need to convert the conv weights (see misc.py for explanations) 31 | if 'conv1d.weight' in key: 32 | value = torch_to_mlx_depthwise_weights(value) 33 | 34 | if 'conv1d' in key: 35 | key = key.replace('conv1d', 'conv1d.conv1d') 36 | 37 | if value.type() == 'torch.BFloat16Tensor': 38 | new_state_dict[key] = value.half().numpy() 39 | else: 40 | new_state_dict[key] = value.numpy() 41 | 42 | return new_state_dict 43 | 44 | def map_mambassm_torch_to_mlx(torch_state_dict): 45 | # convert mambassm to mambapy 46 | new_state_dict = {} 47 | for key in torch_state_dict: 48 | if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight': 49 | new_key = key.replace('backbone.', '') 50 | else: 51 | new_key = key.replace('backbone', 'mamba') 52 | 53 | new_state_dict[new_key] = torch_state_dict[key] 54 | 55 | # convert mambapy to mlx 56 | return map_mambapy_torch_to_mlx(new_state_dict) 57 | 58 | """ 59 | # todo : doesnt work, because MambaConfig and MambaLMConfig are not the ones defined in mamba.py and mamba_lm.py 60 | def mambapy_torch_to_mlx(torch_state_dict, config: Union[MambaConfig, MambaLMConfig]): 61 | mlx_state_dict = map_mambapy_torch_to_mlx(torch_state_dict) 62 | 63 | if isinstance(config, MambaConfig): 64 | model = Mamba(config) 65 | else: 66 | model = MambaLM(config) 67 | 68 | np.savez("weights.mlx.npz", **mlx_state_dict) # TODO name with config? 
69 | model.update(tree_unflatten(list(mx.load("weights.mlx.npz").items()))) 70 | 71 | # todo : name the file according to config 72 | # todo : check if file already exists 73 | 74 | return model 75 | """ 76 | -------------------------------------------------------------------------------- /mambapy/pscan.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | """ 7 | 8 | An implementation of the parallel scan operation in PyTorch (Blelloch version). 9 | Please see docs/pscan.ipynb for a detailed explanation of what happens here. 10 | 11 | """ 12 | 13 | def npo2(len): 14 | """ 15 | Returns the next power of 2 above len 16 | """ 17 | 18 | return 2 ** math.ceil(math.log2(len)) 19 | 20 | def pad_npo2(X): 21 | """ 22 | Pads input length dim to the next power of 2 23 | 24 | Args: 25 | X : (B, L, D, N) 26 | 27 | Returns: 28 | Y : (B, npo2(L), D, N) 29 | """ 30 | 31 | len_npo2 = npo2(X.size(1)) 32 | pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1)) 33 | return F.pad(X, pad_tuple, "constant", 0) 34 | 35 | class PScan(torch.autograd.Function): 36 | @staticmethod 37 | def pscan(A, X): 38 | # A : (B, D, L, N) 39 | # X : (B, D, L, N) 40 | 41 | # modifies X in place by doing a parallel scan. 42 | # more formally, X will be populated by these values : 43 | # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 44 | # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) 45 | 46 | # only supports L that is a power of two (mainly for a clearer code) 47 | 48 | B, D, L, _ = A.size() 49 | num_steps = int(math.log2(L)) 50 | 51 | # up sweep (last 2 steps unfolded) 52 | Aa = A 53 | Xa = X 54 | for _ in range(num_steps-2): 55 | T = Xa.size(2) 56 | Aa = Aa.view(B, D, T//2, 2, -1) 57 | Xa = Xa.view(B, D, T//2, 2, -1) 58 | 59 | Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0])) 60 | Aa[:, :, :, 1].mul_(Aa[:, :, :, 0]) 61 | 62 | Aa = Aa[:, :, :, 1] 63 | Xa = Xa[:, :, :, 1] 64 | 65 | # we have only 4, 2 or 1 nodes left 66 | if Xa.size(2) == 4: 67 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) 68 | Aa[:, :, 1].mul_(Aa[:, :, 0]) 69 | 70 | Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1]))) 71 | elif Xa.size(2) == 2: 72 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) 73 | return 74 | else: 75 | return 76 | 77 | # down sweep (first 2 steps unfolded) 78 | Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] 79 | Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] 80 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1])) 81 | Aa[:, :, 2].mul_(Aa[:, :, 1]) 82 | 83 | for k in range(num_steps-3, -1, -1): 84 | Aa = A[:, :, 2**k-1:L:2**k] 85 | Xa = X[:, :, 2**k-1:L:2**k] 86 | 87 | T = Xa.size(2) 88 | Aa = Aa.view(B, D, T//2, 2, -1) 89 | Xa = Xa.view(B, D, T//2, 2, -1) 90 | 91 | Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1])) 92 | Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1]) 93 | 94 | @staticmethod 95 | def pscan_rev(A, X): 96 | # A : (B, D, L, N) 97 | # X : (B, D, L, N) 98 | 99 | # the same function as above, but in reverse 100 | # (if you flip the input, call pscan, then flip the output, you get what this function outputs) 101 | # it is used in the backward pass 102 | 103 | # only supports L that is a power of two (mainly for a clearer code) 104 | 105 | B, D, L, _ = A.size() 106 | num_steps = int(math.log2(L)) 107 | 108 | # up sweep (last 2 steps unfolded) 109 | Aa = A 110 | Xa = X 111 | for _ in range(num_steps-2): 112 | T = Xa.size(2) 113 | Aa = 
Aa.view(B, D, T//2, 2, -1) 114 | Xa = Xa.view(B, D, T//2, 2, -1) 115 | 116 | Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1])) 117 | Aa[:, :, :, 0].mul_(Aa[:, :, :, 1]) 118 | 119 | Aa = Aa[:, :, :, 0] 120 | Xa = Xa[:, :, :, 0] 121 | 122 | # we have only 4, 2 or 1 nodes left 123 | if Xa.size(2) == 4: 124 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3])) 125 | Aa[:, :, 2].mul_(Aa[:, :, 3]) 126 | 127 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2])))) 128 | elif Xa.size(2) == 2: 129 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1])) 130 | return 131 | else: 132 | return 133 | 134 | # down sweep (first 2 steps unfolded) 135 | Aa = A[:, :, 0:L:2**(num_steps-2)] 136 | Xa = X[:, :, 0:L:2**(num_steps-2)] 137 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2])) 138 | Aa[:, :, 1].mul_(Aa[:, :, 2]) 139 | 140 | for k in range(num_steps-3, -1, -1): 141 | Aa = A[:, :, 0:L:2**k] 142 | Xa = X[:, :, 0:L:2**k] 143 | 144 | T = Xa.size(2) 145 | Aa = Aa.view(B, D, T//2, 2, -1) 146 | Xa = Xa.view(B, D, T//2, 2, -1) 147 | 148 | Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0])) 149 | Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0]) 150 | 151 | @staticmethod 152 | def forward(ctx, A_in, X_in): 153 | """ 154 | Applies the parallel scan operation, as defined above. Returns a new tensor. 155 | If you can, privilege sequence lengths that are powers of two. 156 | 157 | Args: 158 | A_in : (B, L, D, N) 159 | X_in : (B, L, D, N) 160 | 161 | Returns: 162 | H : (B, L, D, N) 163 | """ 164 | 165 | L = X_in.size(1) 166 | 167 | # cloning is requiered because of the in-place ops 168 | if L == npo2(L): 169 | A = A_in.clone() 170 | X = X_in.clone() 171 | else: 172 | # pad tensors (and clone btw) 173 | A = pad_npo2(A_in) # (B, npo2(L), D, N) 174 | X = pad_npo2(X_in) # (B, npo2(L), D, N) 175 | 176 | # prepare tensors 177 | A = A.transpose(2, 1) # (B, D, npo2(L), N) 178 | X = X.transpose(2, 1) # (B, D, npo2(L), N) 179 | 180 | # parallel scan (modifies X in-place) 181 | PScan.pscan(A, X) 182 | 183 | ctx.save_for_backward(A_in, X) 184 | 185 | # slice [:, :L] (cut if there was padding) 186 | return X.transpose(2, 1)[:, :L] 187 | 188 | @staticmethod 189 | def backward(ctx, grad_output_in): 190 | """ 191 | Flows the gradient from the output to the input. Returns two new tensors. 
192 | 193 | Args: 194 | ctx : A_in : (B, L, D, N), X : (B, D, L, N) 195 | grad_output_in : (B, L, D, N) 196 | 197 | Returns: 198 | gradA : (B, L, D, N), gradX : (B, L, D, N) 199 | """ 200 | 201 | A_in, X = ctx.saved_tensors 202 | 203 | L = grad_output_in.size(1) 204 | 205 | # cloning is requiered because of the in-place ops 206 | if L == npo2(L): 207 | grad_output = grad_output_in.clone() 208 | # the next padding will clone A_in 209 | else: 210 | grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N) 211 | A_in = pad_npo2(A_in) # (B, npo2(L), D, N) 212 | 213 | # prepare tensors 214 | grad_output = grad_output.transpose(2, 1) 215 | A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N) 216 | A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation) 217 | 218 | # reverse parallel scan (modifies grad_output in-place) 219 | PScan.pscan_rev(A, grad_output) 220 | 221 | Q = torch.zeros_like(X) 222 | Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:]) 223 | 224 | return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L] 225 | 226 | pscan = PScan.apply 227 | -------------------------------------------------------------------------------- /mambapy/tests/compare_mambapy_cuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mamba_ssm import Mamba 4 | 5 | from mamba import MambaBlock, MambaConfig 6 | 7 | batch, length, dim = 2, 512, 16 8 | 9 | x = torch.randn(batch, length, dim).to("cuda") 10 | x.requieres_grad = True 11 | 12 | # CUDA Model 13 | 14 | torch.manual_seed(1) 15 | 16 | model_cuda = Mamba( 17 | # This module uses roughly 3 * expand * d_model^2 parameters 18 | d_model=dim, # Model dimension d_model 19 | d_state=16, # SSM state expansion factor 20 | d_conv=4, # Local convolution width 21 | expand=2, # Block expansion factor 22 | ).to("cuda") 23 | 24 | y_cuda = model_cuda(x) 25 | 26 | print(sum([p.numel() for p in model_cuda.parameters()])) 27 | print(y_cuda.shape) 28 | 29 | # mamba.py model 30 | 31 | torch.manual_seed(1) 32 | 33 | config = MambaConfig(d_model=dim, n_layers=1) 34 | model = MambaBlock(config).to("cuda") 35 | 36 | y_pscan = model(x) 37 | 38 | print(sum([p.numel() for p in model.parameters()])) 39 | print(y_pscan.shape) 40 | 41 | # forward # 42 | print(torch.allclose(y_cuda, y_pscan, rtol=0.1)) 43 | 44 | # backward # 45 | J_cuda = y_cuda.sum() 46 | J_cuda.backward() 47 | 48 | J_pscan = y_pscan.sum() 49 | J_pscan.backward() 50 | 51 | print(torch.allclose(model_cuda.in_proj.weight.grad, model.in_proj.weight.grad, rtol=0.01)) 52 | 53 | 54 | -------------------------------------------------------------------------------- /mambapy/tests/mem_mamba_3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mamba import Mamba, MambaConfig 4 | 5 | device = "cuda" 6 | 7 | B, L, D, N = 16, 64, 128, 16 8 | 9 | config = MambaConfig(d_model=D, n_layers=8, d_state=N) 10 | model = Mamba(config).to(device) 11 | 12 | torch.cuda.empty_cache() 13 | torch.cuda.reset_peak_memory_stats(device) 14 | 15 | torch.cuda.reset_peak_memory_stats(device) 16 | initial_memory = torch.cuda.max_memory_allocated(device) 17 | 18 | for _ in range(100): 19 | X = torch.randn(B, L, D).to(device, non_blocking=True) 20 | 21 | output = model(X) 22 | loss = output.sum() 23 | 24 | loss.backward() 25 | 26 | peak_memory = torch.cuda.max_memory_allocated(device=device) # Peak memory during backward 27 | 28 | 
print(initial_memory/(1024**2)) 29 | print(peak_memory/(1024**2)) 30 | print("-----------------------------") 31 | 32 | # relate bien ce qu'on voit dans nvidia-smi (a qqchose pres : sans le torch-cuda overhead) -------------------------------------------------------------------------------- /mambapy/tests/profiling_mamba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd.profiler as profiler 3 | 4 | from mamba import Mamba, MambaConfig 5 | 6 | device = "cuda" 7 | 8 | B, L, D, N = 16, 1024, 1024, 16 9 | 10 | config = MambaConfig(d_model=D, n_layers=8, d_state=N) 11 | model = Mamba(config).to(device) 12 | 13 | X = torch.randn(B, L, D).to(device, non_blocking=True) 14 | 15 | model(X) 16 | 17 | with profiler.profile(record_shapes=True, use_cuda=True, profile_memory=True) as prof: 18 | with profiler.record_function("model_forward"): 19 | output = model(X) 20 | loss = output.sum() 21 | with profiler.record_function("model_backward"): 22 | loss.backward() 23 | 24 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) 25 | print(f"Peak CUDA Memory Usage: {prof.total_average().cuda_memory_usage / (1024 ** 2)} MB") 26 | -------------------------------------------------------------------------------- /mambapy/tests/profiling_pscan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd.profiler as profiler 3 | 4 | from pscan import pscan 5 | 6 | B, L, D, N = 16, 1024, 32, 16 7 | 8 | A = torch.randn(B, L, D, N).to("cuda") 9 | X = torch.randn(B, L, D, N).to("cuda") 10 | 11 | H = pscan(A, X) 12 | with profiler.profile(record_shapes=True, use_cuda=True, profile_memory=True) as prof: 13 | with profiler.record_function("pscan_custom_function"): 14 | H = pscan(A, X) 15 | 16 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100)) 17 | prof.export_chrome_trace("pscan_profiling_trace.json") 18 | 19 | print(f"Peak CUDA Memory Usage: {prof.total_average().cuda_memory_usage / (1024 ** 2)} MB") 20 | 21 | -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /modules/__pycache__/netvlad.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/netvlad.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/netvlad.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/netvlad.cpython-39.pyc 
-------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import sys 3 | # p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 4 | # if p not in sys.path: 5 | # sys.path.append(p) 6 | # 7 | # import torch 8 | # import torch.nn as nn 9 | # import os 10 | # import numpy as np 11 | # 12 | # 13 | # def best_pos_distance(query, pos_vecs): 14 | # num_pos = pos_vecs.shape[0] 15 | # query_copies = query.repeat(int(num_pos), 1) 16 | # diff = ((pos_vecs - query_copies) ** 2).sum(1) 17 | # diff, _ = torch.sort(diff, 0) 18 | # if num_pos == 0: 19 | # print(diff.size()) 20 | # 21 | # min_pos, _ = diff.min(0) 22 | # max_pos, _ = diff.max(0) 23 | # return min_pos, max_pos 24 | # 25 | # 26 | # def triplet_loss(q_vec, pos_vecs, neg_vecs, margin, use_min=False, lazy=False, ignore_zero_loss=False): 27 | # 28 | # if pos_vecs.shape[0] == 0: 29 | # 30 | # num_neg = neg_vecs.shape[0] 31 | # query_copies = q_vec.repeat(int(num_neg), 1) 32 | # 33 | # negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 34 | # negative, _ = negative.min(0) 35 | # negative = negative.repeat(int(num_neg), 1) 36 | # 37 | # loss = margin - negative 38 | # 39 | # loss = loss.clamp(min=0.0) 40 | # 41 | # # loss = torch.log1p(loss) 42 | # 43 | # if lazy: 44 | # triplet_loss = loss.max(1)[0] 45 | # else: 46 | # triplet_loss = loss.sum(0) 47 | # if ignore_zero_loss: 48 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float() 49 | # num_hard_triplets = torch.sum(hard_triplets) 50 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 51 | # else: 52 | # triplet_loss = triplet_loss.mean() 53 | # 54 | # else: 55 | # min_pos, max_pos = best_pos_distance(q_vec, pos_vecs) 56 | # 57 | # if use_min: 58 | # positive = min_pos 59 | # else: 60 | # positive = max_pos 61 | # num_neg = neg_vecs.shape[0] 62 | # query_copies = q_vec.repeat(int(num_neg), 1) 63 | # positive = positive.view(-1, 1) 64 | # positive = positive.repeat(int(num_neg), 1) 65 | # 66 | # negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 67 | # negative, _ = negative.min(0) 68 | # negative = negative.repeat(int(num_neg), 1) 69 | # 70 | # loss = margin + positive - negative 71 | # 72 | # loss = loss.clamp(min=0.0) 73 | # 74 | # # loss = torch.log1p(loss) 75 | # 76 | # if lazy: 77 | # triplet_loss = loss.max(1)[0] 78 | # else: 79 | # triplet_loss = loss.sum(0) 80 | # if ignore_zero_loss: 81 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float() 82 | # num_hard_triplets = torch.sum(hard_triplets) 83 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 84 | # else: 85 | # triplet_loss = triplet_loss.mean() 86 | # return triplet_loss 87 | # 88 | # def triplet_loss_inv(q_vec, pos_vecs, neg_vecs, margin, use_min=True, lazy=False, ignore_zero_loss=False): 89 | # 90 | # min_neg, max_neg = best_pos_distance(q_vec, neg_vecs) 91 | # 92 | # if use_min: 93 | # negative = min_neg 94 | # else: 95 | # negative = max_neg 96 | # num_neg = neg_vecs.shape[0] 97 | # num_pos= pos_vecs.shape[0] 98 | # query_copies = q_vec.repeat(int(num_pos), 1) 99 | # negative = negative.view(-1, 1) 100 | # negative = negative.repeat(int(num_pos), 1) 101 | # 102 | # loss = margin - negative + ((pos_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 103 | # 104 | # loss = loss.clamp(min=0.0) 105 | # 106 | # 107 | # 108 | # if lazy: 109 | # triplet_loss = loss.max(1)[0] 110 | # else: 111 | # triplet_loss = loss.sum(0) 
112 | # if ignore_zero_loss: 113 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float() 114 | # num_hard_triplets = torch.sum(hard_triplets) 115 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 116 | # else: 117 | # triplet_loss = triplet_loss.mean() 118 | # return triplet_loss 119 | # 120 | # 121 | # def triplet_loss_wrapper(q_vec, pos_vecs, neg_vecs, m1, m2, use_min=False, lazy=False, ignore_zero_loss=False): 122 | # return triplet_loss(q_vec, pos_vecs, neg_vecs, m1, use_min, lazy, ignore_zero_loss) 123 | import os 124 | import sys 125 | 126 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 127 | if p not in sys.path: 128 | sys.path.append(p) 129 | 130 | import torch 131 | import torch.nn as nn 132 | import os 133 | import numpy as np 134 | 135 | 136 | def best_pos_distance(query, pos_vecs): 137 | num_pos = pos_vecs.shape[0] 138 | query_copies = query.repeat(int(num_pos), 1) 139 | diff = ((pos_vecs - query_copies) ** 2).sum(1) 140 | 141 | min_pos, _ = diff.min(0) 142 | max_pos, _ = diff.max(0) 143 | return min_pos, max_pos 144 | 145 | 146 | def triplet_loss(q_vec, pos_vecs, neg_vecs, margin, use_min=False, lazy=False, ignore_zero_loss=False): 147 | if pos_vecs.shape[0] == 0: 148 | 149 | num_neg = neg_vecs.shape[0] 150 | num_pos = pos_vecs.shape[0] 151 | query_copies = q_vec.repeat(int(num_neg), 1) 152 | 153 | negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 154 | 155 | loss = margin - ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 156 | 157 | loss = loss.clamp(min=0.0) 158 | 159 | if lazy: 160 | triplet_loss = loss.max(1)[0] 161 | else: 162 | triplet_loss = loss.sum(0) 163 | if ignore_zero_loss: 164 | hard_triplets = torch.gt(triplet_loss, 1e-16).float() 165 | num_hard_triplets = torch.sum(hard_triplets) 166 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 167 | else: 168 | triplet_loss = triplet_loss.mean() 169 | return triplet_loss 170 | 171 | else: 172 | min_pos, max_pos = best_pos_distance(q_vec, pos_vecs) 173 | 174 | if use_min: 175 | positive = min_pos 176 | else: 177 | positive = max_pos 178 | num_neg = neg_vecs.shape[0] 179 | num_pos = pos_vecs.shape[0] 180 | query_copies = q_vec.repeat(int(num_neg), 1) 181 | positive = positive.view(-1, 1) 182 | positive = positive.repeat(int(num_neg), 1) 183 | 184 | negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 185 | 186 | loss = margin + positive - ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 187 | 188 | loss = loss.clamp(min=0.0) 189 | 190 | if lazy: 191 | triplet_loss = loss.max(1)[0] 192 | else: 193 | triplet_loss = loss.sum(0) 194 | if ignore_zero_loss: 195 | hard_triplets = torch.gt(triplet_loss, 1e-16).float() 196 | num_hard_triplets = torch.sum(hard_triplets) 197 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 198 | else: 199 | triplet_loss = triplet_loss.mean() 200 | return triplet_loss 201 | 202 | 203 | def triplet_loss_inv(q_vec, pos_vecs, neg_vecs, margin, use_min=True, lazy=False, ignore_zero_loss=False): 204 | min_neg, max_neg = best_pos_distance(q_vec, neg_vecs) 205 | 206 | if use_min: 207 | negative = min_neg 208 | else: 209 | negative = max_neg 210 | num_neg = neg_vecs.shape[0] 211 | num_pos = pos_vecs.shape[0] 212 | query_copies = q_vec.repeat(int(num_pos), 1) 213 | negative = negative.view(-1, 1) 214 | negative = negative.repeat(int(num_pos), 1) 215 | 216 | loss = margin - negative + ((pos_vecs - query_copies) ** 2).sum(1).unsqueeze(1) 217 | 218 | loss = loss.clamp(min=0.0) 219 | 220 | if lazy: 221 | 
triplet_loss = loss.max(1)[0] 222 | else: 223 | triplet_loss = loss.sum(0) 224 | if ignore_zero_loss: 225 | hard_triplets = torch.gt(triplet_loss, 1e-16).float() 226 | num_hard_triplets = torch.sum(hard_triplets) 227 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16) 228 | else: 229 | triplet_loss = triplet_loss.mean() 230 | return triplet_loss 231 | 232 | 233 | def triplet_loss_wrapper(q_vec, pos_vecs, neg_vecs, m1, m2, use_min=False, lazy=False, ignore_zero_loss=False): 234 | return triplet_loss(q_vec, pos_vecs, neg_vecs, m1, use_min, lazy, ignore_zero_loss) -------------------------------------------------------------------------------- /modules/netvlad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 4 | if p not in sys.path: 5 | sys.path.append(p) 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import math 11 | 12 | 13 | class NetVLADLoupe(nn.Module): 14 | def __init__(self, feature_size, max_samples, cluster_size, output_dim, 15 | gating=True, add_batch_norm=True, is_training=True): 16 | super(NetVLADLoupe, self).__init__() 17 | self.feature_size = feature_size 18 | self.max_samples = max_samples 19 | self.output_dim = output_dim 20 | self.is_training = is_training 21 | self.gating = gating 22 | self.add_batch_norm = add_batch_norm 23 | self.cluster_size = cluster_size 24 | self.softmax = nn.Softmax(dim=-1) 25 | 26 | self.cluster_weights = nn.Parameter(torch.randn( 27 | feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 28 | self.cluster_weights2 = nn.Parameter(torch.randn( 29 | 1, feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 30 | self.hidden1_weights = nn.Parameter(torch.randn( 31 | cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size)) 32 | 33 | if add_batch_norm: 34 | self.cluster_biases = None 35 | self.bn1 = nn.BatchNorm1d(cluster_size) 36 | else: 37 | self.cluster_biases = nn.Parameter(torch.randn( 38 | cluster_size) * 1 / math.sqrt(feature_size)) 39 | self.bn1 = None 40 | 41 | self.bn2 = nn.BatchNorm1d(output_dim) 42 | 43 | if gating: 44 | self.context_gating = GatingContext( 45 | output_dim, add_batch_norm=add_batch_norm) 46 | 47 | def forward(self, x): 48 | x = x.transpose(1, 3).contiguous() 49 | x = x.view((-1, self.max_samples, self.feature_size)) 50 | activation = torch.matmul(x, self.cluster_weights) 51 | if self.add_batch_norm: 52 | # activation = activation.transpose(1,2).contiguous() 53 | activation = activation.view(-1, self.cluster_size) 54 | activation = self.bn1(activation) 55 | activation = activation.view(-1, self.max_samples, self.cluster_size) 56 | # activation = activation.transpose(1,2).contiguous() 57 | else: 58 | activation = activation + self.cluster_biases 59 | activation = self.softmax(activation) 60 | activation = activation.view((-1, self.max_samples, self.cluster_size)) 61 | 62 | a_sum = activation.sum(-2, keepdim=True) 63 | a = a_sum * self.cluster_weights2 64 | 65 | activation = torch.transpose(activation, 2, 1) 66 | x = x.view((-1, self.max_samples, self.feature_size)) 67 | vlad = torch.matmul(activation, x) 68 | vlad = torch.transpose(vlad, 2, 1) 69 | vlad = vlad - a 70 | 71 | vlad = F.normalize(vlad, dim=1, p=2) 72 | # vlad = vlad.view((-1, self.cluster_size * self.feature_size)) 73 | vlad = vlad.reshape((-1, self.cluster_size * self.feature_size)) 74 | vlad = F.normalize(vlad, dim=1, p=2) 75 | 76 | vlad = 
torch.matmul(vlad, self.hidden1_weights) 77 | 78 | # vlad = self.bn2(vlad) 79 | 80 | if self.gating: 81 | vlad = self.context_gating(vlad) 82 | 83 | return vlad 84 | 85 | 86 | class GatingContext(nn.Module): 87 | def __init__(self, dim, add_batch_norm=True): 88 | super(GatingContext, self).__init__() 89 | self.dim = dim 90 | self.add_batch_norm = add_batch_norm 91 | self.gating_weights = nn.Parameter( 92 | torch.randn(dim, dim) * 1 / math.sqrt(dim)) 93 | self.sigmoid = nn.Sigmoid() 94 | 95 | if add_batch_norm: 96 | self.gating_biases = None 97 | self.bn1 = nn.BatchNorm1d(dim) 98 | else: 99 | self.gating_biases = nn.Parameter( 100 | torch.randn(dim) * 1 / math.sqrt(dim)) 101 | self.bn1 = None 102 | 103 | def forward(self, x): 104 | gates = torch.matmul(x, self.gating_weights) 105 | 106 | if self.add_batch_norm: 107 | gates = self.bn1(gates) 108 | else: 109 | gates = gates + self.gating_biases 110 | 111 | gates = self.sigmoid(gates) 112 | 113 | activation = x * gates 114 | 115 | return activation 116 | 117 | 118 | if __name__ == '__main__': 119 | net_vlad = NetVLADLoupe(feature_size=1024, max_samples=360, cluster_size=16, 120 | output_dim=20, gating=True, add_batch_norm=True, 121 | is_training=True) 122 | # input (bs, 1024, 360, 1) 123 | torch.manual_seed(1234) 124 | input_tensor = F.normalize(torch.randn((1,1024,360,1)), dim=1) 125 | input_tensor2 = torch.zeros_like(input_tensor) 126 | input_tensor2[:, :, 2:, :] = input_tensor[:, :, 0:-2, :].clone() 127 | input_tensor2[:, :, :2, :] = input_tensor[:, :, -2:, :].clone() 128 | input_tensor2= F.normalize(input_tensor2, dim=1) 129 | input_tensor_com = torch.cat((input_tensor, input_tensor2), dim=0) 130 | 131 | # print(input_tensor[0,0,:,0]) 132 | # print(input_tensor2[0,0,:,0]) 133 | print("==================================") 134 | 135 | with torch.no_grad(): 136 | net_vlad.eval() 137 | # output_tensor = net_vlad(input_tensor_com) 138 | # print(output_tensor) 139 | out1 = net_vlad(input_tensor) 140 | print(out1) 141 | net_vlad.eval() 142 | # input_tensor2[:, :, 20:, :] = 0.1 143 | input_tensor2 = F.normalize(input_tensor2, dim=1) 144 | out2 = net_vlad(input_tensor2) 145 | print(out2) 146 | net_vlad.eval() 147 | input_tensor3 = torch.randn((1,1024,360,1)) 148 | out3 = net_vlad(input_tensor3) 149 | print(out3) 150 | 151 | 152 | print(((out1-out2)**2).sum(1)) 153 | print(((out1-out3)**2).sum(1)) 154 | 155 | 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.4.3 2 | numpy==1.20.3 3 | tensorboardx==2.4 4 | pyyaml==6.0 5 | opencv-python==4.5.4.58 6 | faiss-cpu==1.7.1 7 | scikit-learn==0.24.2 8 | -------------------------------------------------------------------------------- /test/test_kitti00_PR.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | from datetime import datetime 7 | 8 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 9 | if p not in sys.path: 10 | sys.path.append(p) 11 | 12 | import copy 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import yaml 16 | from sklearn import metrics 17 | 18 | 19 | def cal_pr_curve(prediction_file_name, ground_truth_file_name): 20 | precisions = [] 21 | recalls = [] 22 | 23 | print(1) 24 | des_dists = np.load(prediction_file_name)['arr_0'] 25 | des_dists = np.asarray(des_dists, dtype='float32') 26 | des_dists = 
des_dists.reshape((len(des_dists), 3)) 27 | print('Load descrptor distances predictions with pairs of ', len(des_dists)) 28 | print(des_dists.shape) 29 | 30 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True')['arr_0'] 31 | 32 | """Changing the threshold will lead to different test results""" 33 | for thres in np.arange(0, 1.0, 0.01): 34 | current_time = datetime.now() 35 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]") 36 | print(formatted_time, " thresh: ", thres) 37 | tps = 0 38 | fps = 0 39 | tns = 0 40 | fns = 0 41 | """Update the start frame index""" 42 | for idx in range(150, len(ground_truth) - 1): 43 | gt_idxes = ground_truth[int(idx)] 44 | reject_flag = False 45 | 46 | if des_dists[des_dists[:, 0] == int(idx), 2][0] > thres: 47 | reject_flag = True 48 | if reject_flag: 49 | if not any(gt_idxes): 50 | tns += 1 51 | else: 52 | fns += 1 53 | else: 54 | predicted_idx = des_dists[des_dists[:, 0] == int(idx), 1][0] 55 | if any(predicted_idx == gt for gt in gt_idxes): 56 | # if des_dists[des_dists[:, 0] == int(idx), 1][0] in gt_idxes: 57 | tps += 1 58 | else: 59 | fps += 1 60 | 61 | # for idx in range(150, len(ground_truth) - 1): 62 | # gt_idxes = ground_truth[int(idx)] 63 | # reject_flag = False 64 | # 65 | # if des_dists[des_dists[:, 0] == int(idx), 2][0] > thres: 66 | # reject_flag = True 67 | # if reject_flag: 68 | # if not np.any(gt_idxes): 69 | # tns += 1 70 | # else: 71 | # fns += 1 72 | # else: 73 | # if any(des_dists[des_dists[:, 0] == int(idx), 1][0] == gt for gt in gt_idxes): 74 | # tps += 1 75 | # else: 76 | # fps += 1 77 | 78 | if fps == 0: 79 | precision = 1 80 | else: 81 | precision = float(tps) / (float(tps) + float(fps)) 82 | if fns == 0: 83 | recall = 1 84 | else: 85 | recall = float(tps) / (float(tps) + float(fns)) 86 | 87 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-10) 88 | print("f1 score: ", f1_score) 89 | print("recall: ", recall) 90 | print("precision:", precision) 91 | precisions.append(precision) 92 | recalls.append(recall) 93 | 94 | # print("precision ", precision) 95 | # print("recall ", recall) 96 | 97 | print("Highest precision: %s" % max(precisions)) 98 | print("Highest recall: %s" % max(recalls)) 99 | 100 | return precisions, recalls 101 | 102 | 103 | """Ploting and saving AUC.""" 104 | 105 | 106 | def plotPRC(precisions, recalls, print_2file=True): 107 | # initial plot 108 | plt.clf() 109 | 110 | if print_2file: 111 | save_name = "./" + dir_name + "/PR.png" 112 | 113 | recalls, precisions = (list(t) for t in zip(*sorted(zip(recalls, precisions), reverse=True))) 114 | auc = metrics.auc(recalls, precisions) * 100 115 | 116 | plt.plot(recalls, precisions, linewidth=1.0) 117 | 118 | plt.xlabel('Recall') 119 | plt.ylabel('Precision') 120 | plt.ylim([0.0, 1]) 121 | plt.xlim([0.0, 1]) 122 | plt.title('auc = ' + str(auc)) 123 | 124 | if print_2file: 125 | plt.savefig(save_name) 126 | 127 | plt.show() 128 | 129 | 130 | """Calculating Max F1 score.""" 131 | 132 | 133 | def cal_F1_score(): 134 | pr_values = np.load("./" + dir_name + "/PR.npz") 135 | f1_scores_max = -1 136 | for i in range(pr_values['precisions'].shape[0]): 137 | precision = pr_values['precisions'][i] 138 | recall = pr_values['recalls'][i] 139 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-10) 140 | if f1_score > f1_scores_max: 141 | f1_scores_max = copy.deepcopy(f1_score) 142 | print("f1_score on KITTI test seq: ", f1_scores_max) 143 | 144 | 145 | def test_with_PR(ground_truth_file_name): 146 | plot_curve = True 147 | 
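    # Note on the flags in this block: plot_curve draws the PR curve, save_pr_results caches the precision/recall arrays to PR.npz (when that file already exists the recomputation below is skipped), and print_2file additionally writes the figure to disk.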
save_pr_results = True 148 | print_2file = True 149 | 150 | prediction_file_name = "./" + dir_name + "/predicted_des_L2_dis.npz" 151 | 152 | if not os.path.exists("./" + dir_name + "/PR.npz"): 153 | precisions, recalls = cal_pr_curve(prediction_file_name, ground_truth_file_name) 154 | 155 | pr_values = np.asarray([precisions, recalls]) 156 | pr_values = pr_values[:, np.argsort(pr_values[0, :])] 157 | 158 | if (plot_curve): 159 | plotPRC(pr_values[0], pr_values[1], print_2file) 160 | 161 | if (save_pr_results): 162 | np.savez_compressed("./" + dir_name + "/PR", precisions=pr_values[0], recalls=pr_values[1]) 163 | 164 | cal_F1_score() 165 | 166 | 167 | if __name__ == "__main__": 168 | # load config ================================================================ 169 | config_filename = '../config/config.yml' 170 | 171 | dir_name = "nclt/1" 172 | 173 | config = yaml.safe_load(open(config_filename)) 174 | ground_truth_file_name = config["test_config"]["gt_file"] 175 | # ============================================================================ 176 | 177 | """ground truth file follows OverlapNet""" 178 | test_with_PR(ground_truth_file_name) 179 | -------------------------------------------------------------------------------- /test/test_kitti00_prepare.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 7 | if p not in sys.path: 8 | sys.path.append(p) 9 | 10 | from matplotlib import pyplot as plt 11 | import torch 12 | import numpy as np 13 | from modules.overlap_transformer import featureExtracter 14 | from tools.read_samples import read_one_need_from_seq 15 | np.set_printoptions(threshold=sys.maxsize) 16 | from tools.utils.utils import * 17 | import faiss 18 | import yaml 19 | import time 20 | 21 | """ 22 | Evaluation is conducted on KITTI 00 in our work. 23 | Args: 24 | amodel: pretrained model. 25 | data_root_folder: dataset root of KITTI. 26 | test_seq: "00" in our work. 
27 | """ 28 | def test_chosen_seq(amodel, data_root_folder, test_seq): 29 | range_images = os.listdir(os.path.join(data_root_folder, test_seq, "depth_map")) 30 | 31 | des_list = np.zeros((len(range_images), 256)) 32 | des_list_inv = np.zeros((len(range_images), 256)) 33 | 34 | """Calculate the descriptors of scans""" 35 | print("Calculating the descriptors of scans ...") 36 | for i in range(0, len(range_images)): 37 | current_batch = read_one_need_from_seq(data_root_folder, str(i).zfill(6), test_seq) 38 | current_batch_inv_double = torch.cat((current_batch, current_batch), dim=-1) 39 | current_batch_inv = current_batch_inv_double[:,:,:,450:1350] 40 | current_batch = torch.cat((current_batch, current_batch_inv), dim=0) 41 | amodel.eval() 42 | t = time.time() 43 | current_batch_des = amodel(current_batch) 44 | #print(f'cost:{time.time() - t:.8f}s') 45 | a = time.time() - t 46 | a *= 1000 47 | #print(a/2) 48 | des_list[i, :] = current_batch_des[0, :].cpu().detach().numpy() 49 | des_list_inv[i, :] = current_batch_des[1, :].cpu().detach().numpy() 50 | 51 | des_list = des_list.astype('float32') 52 | """TODO: You can test the rotation-invariance with des_list_inv.""" 53 | des_list_inv = des_list_inv.astype('float32') 54 | 55 | row_list = [] 56 | # for i in range(101, 3817 - 1): 57 | # for i in range(101, 4541-1): 58 | for i in range(101, 28239 - 1): 59 | nlist = 1 60 | k = 50 61 | d = 256 62 | quantizer = faiss.IndexFlatL2(d) 63 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2) 64 | assert not index.is_trained 65 | index.train(des_list[:i-100,:]) 66 | assert index.is_trained 67 | index.add(des_list[:i-100,:]) 68 | plt.clf() 69 | """Faiss searching""" 70 | D, I = index.search(des_list[i, :].reshape(1, -1), k) 71 | for j in range(D.shape[1]): 72 | """The nearest 100 frames are not considered.""" 73 | if (i-I[:,j])<100: 74 | continue 75 | else: 76 | one_row = np.zeros((1,3)) 77 | one_row[:, 0] = i 78 | one_row[:, 1] = I[:,j] 79 | one_row[:, 2] = D[:,j] 80 | row_list.append(one_row) 81 | print(str(i) + "---->" + str(I[:, j]) + " " + str(D[:, j])) 82 | 83 | row_list_arr = np.array(row_list) 84 | """Saving for the next test""" 85 | folder_name = "./nclt/17/" 86 | if not os.path.exists(folder_name): 87 | os.mkdir(folder_name) 88 | np.savez_compressed(folder_name + "predicted_des_L2_dis", row_list_arr) 89 | 90 | 91 | class testHandler(): 92 | def __init__(self, height=64, width=900, channels=5, norm_layer=None, 93 | data_root_folder=None, 94 | test_seq=None, test_weights=None): 95 | super(testHandler, self).__init__() 96 | 97 | self.height = height 98 | self.width = width 99 | self.channels = channels 100 | self.norm_layer = norm_layer 101 | self.data_root_folder = data_root_folder 102 | self.test_seq = test_seq 103 | 104 | 105 | self.amodel = featureExtracter(channels=self.channels) 106 | # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 107 | self.device = torch.device("cuda") 108 | self.amodel.to(self.device) 109 | 110 | self.parameters = self.amodel.parameters() 111 | self.test_weights = test_weights 112 | self.overlap_thresh = 0.3 113 | 114 | def eval(self): 115 | with torch.no_grad(): 116 | print("Loading weights from ", self.test_weights) 117 | checkpoint = torch.load(self.test_weights) 118 | self.amodel.load_state_dict(checkpoint['state_dict']) 119 | test_chosen_seq(self.amodel, self.data_root_folder, self.test_seq) 120 | 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | 126 | # load config 
================================================================ 127 | config_filename = '../config/config.yml' 128 | config = yaml.safe_load(open(config_filename)) 129 | data_root_folder = config["data_root"]["data_root_folder"] 130 | test_seq = config["test_config"]["test_seqs"][0] 131 | test_weights = config["test_config"]["test_weights"] 132 | # ============================================================================ 133 | 134 | """ 135 | testHandler to handle with testing process. 136 | Args: 137 | height: the height of the range image (the beam number for convenience). 138 | width: the width of the range image (900, alone the lines of OverlapNet). 139 | channels: 1 for depth only in our work. 140 | norm_layer: None in our work for better model. 141 | use_transformer: Whether to use MHSA. 142 | data_root_folder: root of KITTI sequences. It's better to follow our file structure. 143 | test_seq: "00" in the evaluation. 144 | test_weights: pretrained weights. 145 | """ 146 | test_handler = testHandler(height=32, width=900, channels=1, norm_layer=None, 147 | data_root_folder=data_root_folder, test_seq=test_seq, test_weights=test_weights) 148 | test_handler.eval() 149 | 150 | -------------------------------------------------------------------------------- /test/test_kitti00_topN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Junyi Ma, Xieyuanli Chen, and Jun Zhang 3 | # This file is covered by the LICENSE file in the root of the project OverlapTransformer: 4 | # https://github.com/haomo-ai/OverlapTransformer/ 5 | # Brief: calculate Recall@N using the prediction files generated by test_kitti00_prepare.py 6 | 7 | 8 | import os 9 | import sys 10 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 11 | if p not in sys.path: 12 | sys.path.append(p) 13 | 14 | import copy 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import yaml 18 | from sklearn import metrics 19 | 20 | 21 | 22 | def cal_topN(prediction_file_name, ground_truth_file_name, topn): 23 | precisions = [] 24 | recalls = [] 25 | 26 | 27 | # loading overlap predictions 28 | des_dists = np.load(prediction_file_name)['arr_0'] 29 | des_dists = np.asarray(des_dists, dtype='float32') 30 | des_dists = des_dists.reshape((len(des_dists), 3)) 31 | 32 | 33 | # loading ground truth in terms of distance 34 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True')['arr_0'] 35 | 36 | 37 | all_have_gt = 0 38 | tps = 0 39 | 40 | 41 | for idx in range(0,len(ground_truth)-1): 42 | gt_idxes = ground_truth[int(idx)] 43 | 44 | if not gt_idxes.any(): 45 | continue 46 | 47 | all_have_gt += 1 48 | for t in range(topn): 49 | if des_dists[des_dists[:,0]==int(idx),:][t, 1] in gt_idxes: 50 | tps += 1 51 | break 52 | 53 | recall_topN = tps/all_have_gt 54 | print(recall_topN) 55 | 56 | 57 | return recall_topN 58 | 59 | 60 | 61 | 62 | def test_with_topN(topn, ground_truth_file_name): 63 | 64 | prediction_file_name = dir_name + "predicted_des_L2_dis.npz" 65 | recall_topN = cal_topN(prediction_file_name, ground_truth_file_name, topn) 66 | return recall_topN 67 | 68 | 69 | 70 | if __name__ == "__main__": 71 | # load config ================================================================ 72 | config_filename = '../config/config.yml' 73 | config = yaml.safe_load(open(config_filename)) 74 | ground_truth_file_name = config["test_config"]["gt_file"] 75 | # ============================================================================ 76 | 
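    # Note: dir_name below must point at the folder holding the predicted_des_L2_dis.npz file produced by test_kitti00_prepare.py; topn is intended to be roughly 1% of the database size (KITTI 00 has 4541 scans), so adjust it for the sequence actually evaluated.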
dir_name = "./test_/" 77 | # topn = 46 # for KITTI 00 top1% 78 | topn = 38 # for KITTI 00 top1% 79 | recall_list = [] 80 | for i in range(1, topn): 81 | print("top"+str(i)+": ") 82 | rec = test_with_topN(i, ground_truth_file_name) 83 | recall_list.append(rec) 84 | print(recall_list) 85 | np.save(dir_name + "recall_list", np.array(recall_list)) 86 | 87 | 88 | -------------------------------------------------------------------------------- /test/test_nclt_topn.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | 7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 8 | if p not in sys.path: 9 | sys.path.append(p) 10 | 11 | import numpy as np 12 | import yaml 13 | import time 14 | 15 | rubish = [] 16 | 17 | 18 | def load_files(folder): 19 | file_paths = [os.path.join(dp, f) for dp, dn, fn in os.walk( 20 | os.path.expanduser(folder)) for f in fn] 21 | file_paths.sort() 22 | return file_paths 23 | 24 | 25 | def cal_pr_curve(prediction_file_name, ground_truth_file_name, topn): 26 | des_dists = np.load(prediction_file_name)['arr_0'] 27 | des_dists = np.asarray(des_dists, dtype='float32') 28 | des_dists = des_dists.reshape((len(des_dists), 3)) 29 | 30 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True') 31 | 32 | gt_num = 0 33 | all_num = 0 34 | check_out = 0 35 | 36 | img_paths_query = load_files(query_scan) 37 | print(len(img_paths_query)) 38 | sum1 = 0 39 | for idx in range(0, len(img_paths_query), 5): 40 | 41 | gt_idxes = np.array(ground_truth[int(gt_num)]) 42 | if gt_idxes.any(): 43 | all_num += 1 44 | else: 45 | gt_num += 1 46 | continue 47 | 48 | gt_num += 1 49 | 50 | dist_lists_cur = des_dists[des_dists[:, 0] == idx, :] 51 | idx_sorted = np.argsort(dist_lists_cur[:, -1], axis=-1) 52 | 53 | for i in range(topn): 54 | if int(dist_lists_cur[idx_sorted[i], 1]) in gt_idxes: 55 | check_out += 1 56 | break 57 | 58 | print("top" + str(topn) + " recall: ", check_out / all_num) 59 | return check_out / all_num 60 | 61 | 62 | def main(topn, ground_truth_file_name, dir_name): 63 | prediction_file_name = dir_name + "/predicted_des_L2_dis_bet_traj_forward.npz" 64 | 65 | topn_recall = cal_pr_curve(prediction_file_name, ground_truth_file_name, topn) 66 | 67 | return topn_recall 68 | 69 | 70 | if __name__ == "__main__": 71 | # load config ================================================================ 72 | config_filename = '../config/config_nclt.yml' 73 | config = yaml.safe_load(open(config_filename)) 74 | ground_truth_file_name = config["file_root"]["gt_file"] 75 | query_scan = config["file_root"]["data_root_folder_test"] 76 | # ============================================================================ 77 | topn = 20 78 | recall_sum = 0 79 | recall_list = [] 80 | 81 | dir_name = ("./nclt_o1shift2/12.6.15/5_15dis/") 82 | 83 | for i in range(1, topn + 1): 84 | rec = main(i, ground_truth_file_name, dir_name) 85 | recall_sum += rec 86 | if i == 1: 87 | print("AR@1 = ", recall_sum / i) 88 | if i == 5: 89 | print("AR@5 = ", recall_sum / i) 90 | if i == 20: 91 | print("AR@20 = ", recall_sum / i) 92 | recall_list.append(rec) 93 | 94 | print(recall_list) 95 | 96 | np.save(dir_name + "/recall_list", np.array(recall_list)) 97 | -------------------------------------------------------------------------------- /test/test_nclt_topn_prepare.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | 7 | p = 
os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 8 | if p not in sys.path: 9 | sys.path.append(p) 10 | sys.path.append('../tools/') 11 | sys.path.append('../modules/') 12 | 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import yaml 16 | import numpy as np 17 | 18 | from modules.overlap_transformer import featureExtracter 19 | from tools.read_samples_haomo import read_one_need_from_seq 20 | from tools.read_samples_haomo import read_one_need_from_seq_test 21 | 22 | np.set_printoptions(threshold=sys.maxsize) 23 | from tqdm import tqdm 24 | import faiss 25 | from tools.utils.utils import * 26 | 27 | 28 | def shift(tensor, dim, index): 29 | length = tensor.size(dim) 30 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index), 31 | tensor.narrow(dim, 0, index)), dim=dim) 32 | return shifted_tensor 33 | 34 | def unshift(tensor, dim, index): 35 | length = tensor.size(dim) 36 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index), 37 | tensor.narrow(dim, 0, length - index)), dim=dim) 38 | return unshifted_tensor 39 | 40 | class testHandler(): 41 | def __init__(self, height=32, width=900, channels=1, norm_layer=None, use_transformer=False, 42 | data_root_folder=None, data_root_folder_test=None, test_weights=None): 43 | super(testHandler, self).__init__() 44 | 45 | self.height = height 46 | self.width = width 47 | self.channels = channels 48 | self.norm_layer = norm_layer 49 | self.use_transformer = use_transformer 50 | self.data_root_folder = data_root_folder 51 | self.data_root_folder_test = data_root_folder_test 52 | 53 | self.amodel = featureExtracter(channels=self.channels, use_transformer=self.use_transformer) 54 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | self.amodel.to(self.device) 56 | self.parameters = self.amodel.parameters() 57 | 58 | self.test_weights = test_weights 59 | 60 | def eval(self): 61 | 62 | print("Resuming From ", self.test_weights) 63 | checkpoint = torch.load(self.test_weights) 64 | self.amodel.load_state_dict(checkpoint['state_dict']) 65 | 66 | range_image_paths_database = load_files(self.data_root_folder) 67 | print("scan number of database: ", len(range_image_paths_database)) 68 | 69 | des_list = np.zeros((len(range_image_paths_database), 512)) # for forward driving 70 | for j in tqdm(range(0, len(range_image_paths_database))): 71 | f1_index = str(j).zfill(6) 72 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index) 73 | # current_batch_double = torch.cat((current_batch, current_batch), dim=-1) 74 | # current_batch_inv = current_batch_double[:, :, :, 450:1350] 75 | # print(current_batch.shape) 76 | # current_batch = torch.cat((current_batch, current_batch_inv), dim=0) 77 | 78 | 79 | 80 | 81 | # print(current_batch.shape) 82 | self.amodel.eval() 83 | #current_batch_des = self.amodel(current_batch) 84 | index222 = int(torch.rand(1).item() * 900) 85 | input_batch_shift = shift(current_batch, 3, index222) 86 | 87 | global_des = self.amodel(current_batch) 88 | global_des_shift = self.amodel(input_batch_shift) 89 | global_des_add = torch.cat((global_des, global_des_shift), dim=1) 90 | des_list[(j), :] = global_des_add[0, :].cpu().detach().numpy() 91 | 92 | des_list = des_list.astype('float32') 93 | 94 | nlist = 1 95 | k = 50 96 | d = 512 97 | quantizer = faiss.IndexFlatL2(d) 98 | 99 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2) 100 | assert not index.is_trained 101 | 102 | index.train(des_list) 103 | assert index.is_trained 104 | 
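        # All database descriptors are added to the trained IVF index; each query descriptor below is then matched against it with a k=50 L2 nearest-neighbour search.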
index.add(des_list) 105 | row_list = [] 106 | 107 | range_image_paths_query = load_files(self.data_root_folder_test) 108 | print("scan number of query: ", len(range_image_paths_query)) 109 | 110 | for i in range(0, len(range_image_paths_query), 5): 111 | 112 | i_index = str(i).zfill(6) 113 | current_batch = read_one_need_from_seq_test(self.data_root_folder_test, i_index) # compute 1 descriptors 114 | # current_batch_double = torch.cat((current_batch, current_batch), dim=-1) 115 | # current_batch_inv = current_batch_double[:, :, :, 450:1350] 116 | # current_batch = torch.cat((current_batch, current_batch_inv), dim=0) 117 | self.amodel.eval() 118 | index111 = int(torch.rand(1).item() * 900) 119 | input_batch_shift = shift(current_batch, 3, index111) 120 | 121 | global_des = self.amodel(current_batch) 122 | global_des_shift = self.amodel(input_batch_shift) 123 | global_des_add = torch.cat((global_des, global_des_shift), dim=1) 124 | #current_batch_des = self.amodel(current_batch) # torch.Size([(1+pos_num+neg_num)), 256]) 125 | print() 126 | des_list_current = global_des_add[0, :].cpu().detach().numpy() 127 | 128 | D, I = index.search(des_list_current.reshape(1, -1), k) # actual search 129 | 130 | for j in range(D.shape[1]): 131 | one_row = np.zeros((1, 3)) 132 | one_row[:, 0] = i 133 | one_row[:, 1] = I[:, j] 134 | one_row[:, 2] = D[:, j] 135 | row_list.append(one_row) 136 | print("query:" + str(i) + "---->" + "database:" + str(I[:, j]) + " " + str(D[:, j])) 137 | 138 | row_list_arr = np.array(row_list) 139 | dir_name = "./nclt_cvtnet/12.2.5/5_15dis/" 140 | if not os.path.exists(dir_name): 141 | os.mkdir(dir_name) 142 | np.savez_compressed(dir_name + "predicted_des_L2_dis_bet_traj_forward", row_list_arr) 143 | 144 | 145 | if __name__ == '__main__': 146 | # data 147 | # load config ================================================================ 148 | config_filename = '../config/config_haomo.yml' 149 | config = yaml.safe_load(open(config_filename)) 150 | data_root_folder = config["file_root"]["data_root_folder"] 151 | data_root_folder_test = config["file_root"]["data_root_folder_test1"] 152 | test_weights = config["file_root"]["test_weights"] 153 | # ============================================================================ 154 | 155 | test_handler = testHandler(height=32, width=900, channels=1, norm_layer=None, use_transformer=False, 156 | data_root_folder=data_root_folder, data_root_folder_test=data_root_folder_test, 157 | test_weights=test_weights) 158 | 159 | test_handler.eval() 160 | -------------------------------------------------------------------------------- /tools/__pycache__/read_all_sets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_all_sets.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/read_all_sets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_all_sets.cpython-39.pyc -------------------------------------------------------------------------------- /tools/__pycache__/read_samples.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/read_samples.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples.cpython-39.pyc -------------------------------------------------------------------------------- /tools/__pycache__/read_samples_haomo.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples_haomo.cpython-39.pyc -------------------------------------------------------------------------------- /tools/name_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def create_folders_in_subfolders(folder_path, new_folder_name): 5 | # collect all sub-folders under folder_path 6 | subfolders = [f.path for f in os.scandir(folder_path) if f.is_dir()] 7 | 8 | # create the new folder inside every sub-folder (skip if one with the same name already exists) 9 | for subfolder in subfolders: 10 | new_folder_path = os.path.join(subfolder, new_folder_name) 11 | if not os.path.exists(new_folder_path): 12 | os.makedirs(new_folder_path) 13 | print(f"Created new folder: {new_folder_path}") 14 | else: 15 | print(f"Folder already exists, skipping: {new_folder_path}") 16 | 17 | 18 | # folder path to process 19 | folder_path = "/home/lenovo/xqc/OverlapTransformer-master/data_root_folder" 20 | 21 | # name of the new folder to create 22 | new_created_folder_name = "depth_map" 23 | 24 | 25 | def rename_folders_in_subfolders(folder_path, old_folder_name, new_folder_name): 26 | # collect all sub-folders under folder_path 27 | subfolders = [f.path for f in os.scandir(folder_path) if f.is_dir()] 28 | 29 | # walk every sub-folder; if it contains a folder named old_folder_name, rename it to new_folder_name 30 | for subfolder in subfolders: 31 | folder_list = [f.name for f in os.scandir(subfolder) if f.is_dir()] 32 | if old_folder_name in folder_list: 33 | old_folder_path = os.path.join(subfolder, old_folder_name) 34 | new_folder_path = os.path.join(subfolder, new_folder_name) 35 | os.rename(old_folder_path, new_folder_path) 36 | print(f"Renamed folder from {old_folder_path} to {new_folder_path}") 37 | 38 | 39 | # old folder name to replace and the new folder name to use 40 | old_folder_name = "depth_map_50" 41 | new_folder_name = "depth_map" 42 | 43 | # call the function to create new folders 44 | # create_folders_in_subfolders(folder_path, new_created_folder_name) 45 | 46 | # call the function to rename folders 47 | rename_folders_in_subfolders(folder_path, old_folder_name, new_folder_name) 48 | -------------------------------------------------------------------------------- /tools/read_all_sets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 4 | if p not in sys.path: 5 | sys.path.append(p) 6 | 7 | import numpy as np 8 | 9 | """Use the tools from OverlapNet""" 10 | 11 | def overlap_orientation_npz_file2string_string_nparray(npzfilenames, shuffle=True): 12 | 13 | imgf1_all = [] 14 | imgf2_all = [] 15 | dir1_all = [] 16 | dir2_all = [] 17 | overlap_all = [] 18 | 19 | for npzfilename in npzfilenames: 20 | h = np.load(npzfilename, allow_pickle=True) 21 | 22 | if len(h.files) == 3: 23 | # old format 24 | imgf1 = np.char.mod('%06d', h[h.files[0]][:, 0]).tolist() 25 | imgf2 
= np.char.mod('%06d', h[h.files[0]][:, 1]).tolist() 26 | overlap = h[h.files[0]][:, 2] 27 | orientation = h[h.files[0]][:, 3] 28 | n = len(imgf1) 29 | dir1 = np.array(['' for _ in range(n)]).tolist() 30 | dir2 = np.array(['' for _ in range(n)]).tolist() 31 | else: 32 | imgf1 = np.char.mod('%06d', h['overlaps'][:, 0]).tolist() 33 | imgf2 = np.char.mod('%06d', h['overlaps'][:, 1]).tolist() 34 | overlap = h['overlaps'][:, 2] 35 | dir1 = (h['seq'][:, 0]).tolist() 36 | dir2 = (h['seq'][:, 1]).tolist() 37 | 38 | if shuffle: 39 | shuffled_idx = np.random.permutation(overlap.shape[0]) 40 | imgf1 = (np.array(imgf1)[shuffled_idx]).tolist() 41 | imgf2 = (np.array(imgf2)[shuffled_idx]).tolist() 42 | dir1 = (np.array(dir1)[shuffled_idx]).tolist() 43 | dir2 = (np.array(dir2)[shuffled_idx]).tolist() 44 | overlap = overlap[shuffled_idx] 45 | 46 | imgf1_all.extend(imgf1) 47 | imgf2_all.extend(imgf2) 48 | dir1_all.extend(dir1) 49 | dir2_all.extend(dir2) 50 | overlap_all.extend(overlap) 51 | 52 | return (imgf1_all, imgf2_all, dir1_all, dir2_all, np.asarray(overlap_all)) 53 | -------------------------------------------------------------------------------- /tools/read_samples.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | 7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 8 | if p not in sys.path: 9 | sys.path.append(p) 10 | 11 | import matplotlib.pyplot as plt 12 | import torch 13 | import cv2 14 | import numpy as np 15 | 16 | np.set_printoptions(threshold=sys.maxsize) 17 | from utils.utils import * 18 | import yaml 19 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray 20 | 21 | """ 22 | read one needed $file_num range image from sequence $file_num. 23 | Args: 24 | data_root_folder: dataset root of KITTI. 25 | file_num: the index of the needed scan (zfill 6). 26 | seq_num: the sequence in which the needed scan is (zfill 2). 27 | """ 28 | 29 | 30 | def read_one_need_from_seq(data_root_folder, file_num, seq_num): 31 | depth_data = \ 32 | np.array(cv2.imread(data_root_folder + seq_num + "/depth_map/" + file_num + ".png", 33 | cv2.IMREAD_GRAYSCALE)) 34 | 35 | # depth_data = \ 36 | # np.array(np.load(data_root_folder + seq_num + "/depth_map/" + file_num + ".npy", 37 | # )) 38 | 39 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda() 40 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 41 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 42 | 43 | return depth_data_tensor 44 | 45 | 46 | """ 47 | read one batch of positive samples and negative samples with respect to $f1_index in sequence $f1_seq. 48 | Args: 49 | data_root_folder: dataset root of KITTI. 50 | f1_index: the index of the needed scan (zfill 6). 51 | f1_seq: the sequence in which the needed scan is (zfill 2). 52 | train_imgf1, train_imgf2, train_dir1, train_dir2: the index dictionary and sequence dictionary following OverlapNet. 53 | train_overlap: overlaps dictionary following OverlapNet. 54 | overlap_thresh: 0.3 following OverlapNet. 
55 | """ 56 | 57 | 58 | def read_one_batch_pos_neg(data_root_folder, f1_index, f1_seq, train_imgf1, train_imgf2, train_dir1, train_dir2, 59 | train_overlap, overlap_thresh): # without end 60 | 61 | batch_size = 0 62 | for tt in range(len(train_imgf1)): 63 | if f1_index == train_imgf1[tt] and f1_seq == train_dir1[tt]: 64 | batch_size = batch_size + 1 65 | 66 | sample_batch = torch.from_numpy(np.zeros((batch_size, 1, 32, 900))).type(torch.FloatTensor).cuda() 67 | sample_truth = torch.from_numpy(np.zeros((batch_size, 1))).type(torch.FloatTensor).cuda() 68 | 69 | pos_idx = 0 70 | neg_idx = 0 71 | pos_num = 0 72 | neg_num = 0 73 | 74 | for j in range(len(train_imgf1)): 75 | pos_flag = False 76 | if f1_index == train_imgf1[j] and f1_seq == train_dir1[j]: 77 | if train_overlap[j] > overlap_thresh: 78 | pos_num = pos_num + 1 79 | pos_flag = True 80 | else: 81 | neg_num = neg_num + 1 82 | 83 | depth_data_r = \ 84 | np.array(cv2.imread(data_root_folder + train_dir2[j] + "/depth_map/" + train_imgf2[j] + ".png", 85 | cv2.IMREAD_GRAYSCALE)) 86 | 87 | # depth_data_r = \ 88 | # np.array(np.load(data_root_folder + train_dir2[j] + "/depth_map/" + train_imgf2[j] + ".npy", 89 | # )) 90 | 91 | depth_data_tensor_r = torch.from_numpy(depth_data_r).type(torch.FloatTensor).cuda() 92 | depth_data_tensor_r = torch.unsqueeze(depth_data_tensor_r, dim=0) 93 | 94 | if pos_flag: 95 | sample_batch[pos_idx, :, :, :] = depth_data_tensor_r 96 | sample_truth[pos_idx, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda() 97 | pos_idx = pos_idx + 1 98 | else: 99 | sample_batch[batch_size - neg_idx - 1, :, :, :] = depth_data_tensor_r 100 | sample_truth[batch_size - neg_idx - 1, :] = torch.from_numpy(np.array(train_overlap[j])).type( 101 | torch.FloatTensor).cuda() 102 | neg_idx = neg_idx + 1 103 | 104 | return sample_batch, sample_truth, pos_num, neg_num 105 | 106 | 107 | if __name__ == '__main__': 108 | # load config ================================================================ 109 | config_filename = '../config/config.yml' 110 | config = yaml.safe_load(open(config_filename)) 111 | seqs_root = config["data_root"]["data_root_folder"] 112 | # ============================================================================ 113 | 114 | seq = "08" 115 | cur_frame_idx = "000887" 116 | current_frame = read_one_need_from_seq(seqs_root, cur_frame_idx, seq) 117 | 118 | traindata_npzfiles = [os.path.join(seqs_root, seq, 'overlaps/train_set.npz')] 119 | (train_imgf1, train_imgf2, train_dir1, train_dir2, train_overlap) = \ 120 | overlap_orientation_npz_file2string_string_nparray(traindata_npzfiles) 121 | reference_frames, reference_gts, pos_num, neg_num = read_one_batch_pos_neg(seqs_root, cur_frame_idx, seq, 122 | train_imgf1, train_imgf2, train_dir1, 123 | train_dir2, train_overlap, 0.3) 124 | 125 | # visualization 126 | print("the size of current_frame: ", current_frame.size()) 127 | plt.figure(figsize=(15, 3)) 128 | plt.title("One sampled range image from KITTI sequence " + seq + ": " + cur_frame_idx + ".bin") 129 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :]) 130 | plt.show() 131 | 132 | print("the size of reference_frames: ", reference_frames.size()) 133 | vis_idx = 5 # show the 5th sampled range image in the reference batch 134 | plt.figure(figsize=(15, 3)) 135 | plt.suptitle( 136 | "One sampled query-reference from KITTI sequence " + seq + ", Overlap: " + str(reference_gts[vis_idx].item())) 137 | plt.subplot(211) 138 | plt.title("query") 139 | 
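    # Top panel: the query range image; the bottom panel drawn below shows the sampled reference range image.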
plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :]) 140 | plt.subplot(212) 141 | plt.title("reference") 142 | plt.imshow(reference_frames.cpu().detach().numpy()[vis_idx, 0, :, :]) 143 | plt.show() 144 | -------------------------------------------------------------------------------- /tools/read_samples_haomo.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 7 | if p not in sys.path: 8 | sys.path.append(p) 9 | sys.path.append('../tools/') 10 | sys.path.append('../modules/') 11 | 12 | import torch 13 | import numpy as np 14 | np.set_printoptions(threshold=sys.maxsize) 15 | from utils.utils import * 16 | import yaml 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def read_one_need_from_seq(data_root_folder, file_num): 21 | 22 | depth_data = np.load(data_root_folder + file_num + ".npy") 23 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda() 24 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 25 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 26 | 27 | return depth_data_tensor 28 | 29 | def read_one_need_from_seq_test(data_root_folder_test, file_num): 30 | 31 | depth_data = np.load(data_root_folder_test + file_num + ".npy") 32 | 33 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda() 34 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 35 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0) 36 | 37 | return depth_data_tensor 38 | 39 | 40 | 41 | 42 | def read_one_batch_pos_neg(data_root_folder, f1_index, f1_seq, train_imgf1, train_imgf2, train_dir1, train_dir2, train_overlap, overlap_thresh): # without end 43 | 44 | batch_size = 0 45 | for tt in range(len(train_imgf1)): 46 | if f1_index == train_imgf1[tt] and f1_seq == train_dir1[tt] and (train_overlap[tt]> overlap_thresh or train_overlap[tt]<(overlap_thresh-0.0)): # TODO: You can update the range 47 | batch_size = batch_size + 1 48 | 49 | sample_batch = torch.from_numpy(np.zeros((batch_size, 1, 32, 900))).type(torch.FloatTensor).cuda() 50 | sample_truth = torch.from_numpy(np.zeros((batch_size, 1))).type(torch.FloatTensor).cuda() 51 | 52 | pos_idx = 0 53 | neg_idx = 0 54 | pos_num = 0 55 | neg_num = 0 56 | 57 | 58 | for j in range(len(train_imgf1)): 59 | pos_flag = False 60 | if f1_index == train_imgf1[j] and f1_seq==train_dir1[j]: 61 | if train_overlap[j]> overlap_thresh: 62 | pos_num = pos_num + 1 63 | pos_flag = True 64 | elif train_overlap[j]< overlap_thresh: 65 | neg_num = neg_num + 1 66 | else: 67 | continue 68 | 69 | depth_data_r = np.load(data_root_folder + train_imgf2[j] + ".npy") 70 | depth_data_tensor_r = torch.from_numpy(depth_data_r).type(torch.FloatTensor).cuda() 71 | depth_data_tensor_r = torch.unsqueeze(depth_data_tensor_r, dim=0) 72 | 73 | if pos_flag: 74 | sample_batch[pos_idx,:,:,:] = depth_data_tensor_r 75 | sample_truth[pos_idx, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda() 76 | pos_idx = pos_idx + 1 77 | else: 78 | sample_batch[batch_size-neg_idx-1, :, :, :] = depth_data_tensor_r 79 | sample_truth[batch_size-neg_idx-1, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda() 80 | neg_idx = neg_idx + 1 81 | 82 | 83 | return sample_batch, sample_truth, pos_num, neg_num 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | # load config ================================================================ 
89 | config_filename = '../config/config_haomo.yml' 90 | config = yaml.safe_load(open(config_filename)) 91 | data_root_folder = config["file_root"]["data_root_folder"] 92 | triplets_for_training = config["file_root"]["triplets_for_training"] 93 | training_seqs = config["training_config"]["training_seqs"] 94 | # ============================================================================ 95 | 96 | train_set_imgf1_imgf2_overlap = np.load(triplets_for_training) 97 | # print(train_set_imgf1_imgf2_overlap) 98 | 99 | cur_frame_idx = "003430" 100 | current_frame = read_one_need_from_seq(data_root_folder, cur_frame_idx) 101 | 102 | train_imgf1 = train_set_imgf1_imgf2_overlap[:, 0] 103 | train_imgf2 = train_set_imgf1_imgf2_overlap[:, 1] 104 | train_dir1 = np.zeros((len(train_imgf1),)) # to use the same form as KITTI 105 | train_dir2 = np.zeros((len(train_imgf2),)) 106 | train_overlap = train_set_imgf1_imgf2_overlap[:, 2].astype(float) 107 | reference_frames, reference_gts, pos_num, neg_num = read_one_batch_pos_neg \ 108 | (data_root_folder, cur_frame_idx, 0, train_imgf1, train_imgf2, train_dir1, 109 | train_dir2, train_overlap, 0.3) 110 | 111 | 112 | 113 | # visualization 114 | print("the size of current_frame: ", current_frame.size()) 115 | plt.figure(figsize=(15,3)) 116 | plt.title("One sampled range image from Haomo dataset: " + cur_frame_idx + ".bin") 117 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :]) 118 | plt.show() 119 | 120 | print("the size of reference_frames: ", reference_frames.size()) 121 | vis_idx = 5 # show the 2rd sampled range image in the reference batch 122 | plt.figure(figsize=(15,3)) 123 | plt.suptitle("One sampled query-reference from Haomo dataset, Overlap: " + str(reference_gts[vis_idx].item())) 124 | plt.subplot(211) 125 | plt.title("query") 126 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :]) 127 | plt.subplot(212) 128 | plt.title("reference") 129 | plt.imshow(reference_frames.cpu().detach().numpy()[vis_idx, 0, :, :]) 130 | plt.show() 131 | 132 | 133 | -------------------------------------------------------------------------------- /tools/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /tools/utils/com_overlap_yaw.py: -------------------------------------------------------------------------------- 1 | # #!/usr/bin/env python3 2 | # # Developed by Xieyuanli Chen and Thomas Läbe 3 | # # This file is covered by the LICENSE file in the root of this project. 4 | # # Brief: This script generate the overlap and orientation combined mapping file. 5 | 6 | 7 | 8 | try: 9 | from utils import * 10 | except: 11 | from utils import * 12 | 13 | 14 | def com_overlap_yaw(scan_paths, poses, frame_idx, leg_output_width=360): 15 | """compute the overlap and yaw ground truth from the ground truth poses, 16 | which is used for OverlapNet training and testing. 
17 | Args: 18 | scan_paths: paths of all raw LiDAR scans 19 | poses: ground-truth poses either given by the dataset or generated by SLAM or odometry 20 | frame_idx: the current frame index 21 | Returns: 22 | ground_truth_mapping: the ground truth overlap and yaw used for training OverlapNet, 23 | where each row contains [current_frame_idx, reference_frame_idx, overlap, yaw] 24 | """ 25 | # init ground truth overlap and yaw 26 | print('Start to compute ground truth overlap and yaw ...') 27 | overlaps = [] 28 | yaw_idxs = [] 29 | yaw_resolution = leg_output_width 30 | 31 | # we calculate the ground truth for one given frame only 32 | # generate range projection for the given frame 33 | current_points = load_vertex(scan_paths[frame_idx]) 34 | current_range, project_points, _, _ = range_projection(current_points) 35 | visible_points = project_points[current_range > 0] 36 | valid_num = len(visible_points) 37 | current_pose = poses[frame_idx] 38 | 39 | for reference_idx in range(len(scan_paths)): 40 | # generate range projection for the reference frame 41 | reference_pose = poses[reference_idx] 42 | reference_points = load_vertex(scan_paths[reference_idx]) 43 | reference_points_world = reference_pose.dot(reference_points.T).T 44 | reference_points_in_current = np.linalg.inv(current_pose).dot(reference_points_world.T).T 45 | reference_range, _, _, _ = range_projection(reference_points_in_current) 46 | 47 | # calculate overlap 48 | overlap = np.count_nonzero( 49 | abs(reference_range[reference_range > 0] - current_range[reference_range > 0]) < 1) / valid_num 50 | overlaps.append(overlap) 51 | 52 | # calculate yaw angle 53 | relative_transform = np.linalg.inv(current_pose).dot(reference_pose) 54 | relative_rotation = relative_transform[:3, :3] 55 | _, _, yaw = euler_angles_from_rotation_matrix(relative_rotation) 56 | 57 | # discretize yaw angle and shift the 0 degree to the center to make the network easy to lean 58 | yaw_element_idx = int(- (yaw / np.pi) * yaw_resolution // 2 + yaw_resolution // 2) 59 | yaw_idxs.append(yaw_element_idx) 60 | 61 | # print('finished pair id: ', reference_idx) 62 | 63 | # ground truth format: each row contains [current_frame_idx, reference_frame_idx, overlap, yaw] 64 | ground_truth_mapping = np.zeros((len(scan_paths), 4)) 65 | ground_truth_mapping[:, 0] = np.ones(len(scan_paths)) * frame_idx 66 | ground_truth_mapping[:, 1] = np.arange(len(scan_paths)) 67 | ground_truth_mapping[:, 2] = overlaps 68 | ground_truth_mapping[:, 3] = yaw_idxs 69 | # print(ground_truth_mapping) 70 | 71 | print('Finish generating ground_truth_mapping!') 72 | 73 | return ground_truth_mapping 74 | -------------------------------------------------------------------------------- /tools/utils/gen_depth_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Xieyuanli Chen and Thomas Läbe 3 | # This file is covered by the LICENSE file in the root of the project OverlapNet: 4 | #https://github.com/PRBonn/OverlapNet 5 | # Brief: a script to generate depth data 6 | import os 7 | from utils import load_files 8 | import numpy as np 9 | from utils import range_projection 10 | 11 | import cv2 12 | try: 13 | from utils import * 14 | except: 15 | from utils import * 16 | 17 | import scipy.linalg as linalg 18 | def rotate_mat( axis, radian): 19 | rot_matrix = linalg.expm(np.cross(np.eye(3), axis / linalg.norm(axis) * radian)) 20 | return rot_matrix 21 | # print(type(rot_matrix)) 22 | 23 | 24 | 25 | def gen_depth_data(scan_folder, 
dst_folder, normalize=False): 26 | """ Generate projected range data in the shape of (64, 900, 1). 27 | The input raw data are in the shape of (Num_points, 3). 28 | """ 29 | # specify the goal folder 30 | dst_folder = os.path.join(dst_folder, 'depth') 31 | try: 32 | os.stat(dst_folder) 33 | print('generating depth data in: ', dst_folder) 34 | except: 35 | print('creating new depth folder: ', dst_folder) 36 | os.mkdir(dst_folder) 37 | 38 | # load LiDAR scan files 39 | scan_paths = load_files(scan_folder) 40 | 41 | depths = [] 42 | axis_x, axis_y, axis_z = [1,0,0], [0,1,0], [0, 0, 1] 43 | 44 | # iterate over all scan files 45 | for idx in range(len(scan_paths)): 46 | # load a point cloud 47 | current_vertex = np.fromfile(scan_paths[idx], dtype=np.float32) 48 | current_vertex = current_vertex.reshape((-1, 4)) 49 | 50 | proj_range, _, _, _ = range_projection(current_vertex) # proj_ranges from larger to smaller 51 | 52 | # normalize the image 53 | if normalize: 54 | proj_range = proj_range / np.max(proj_range) 55 | 56 | # generate the destination path 57 | dst_path = os.path.join(dst_folder, str(idx).zfill(6)) 58 | 59 | # np.save(dst_path, proj_range) 60 | filename = dst_path + ".png" 61 | cv2.imwrite(filename, proj_range) 62 | print('finished generating depth data at: ', dst_path) 63 | 64 | return depths 65 | 66 | 67 | if __name__ == '__main__': 68 | scan_folder = '/home/lenovo/xqc/datasets/kitti/sequences/11/velodyne' 69 | dst_folder = '/data_root_folder/11' 70 | 71 | depth_data = gen_depth_data(scan_folder, dst_folder) 72 | -------------------------------------------------------------------------------- /tools/utils/gen_gt_data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 7 | if p not in sys.path: 8 | sys.path.append(p) 9 | 10 | import time 11 | from com_overlap_yaw import com_overlap_yaw 12 | from utils import * 13 | 14 | # paths of kitti dataset 15 | scan_folder = "/media/mjy/My Passport/kitti_all/dataset/sequences/00/velodyne" 16 | calib_file = "/home/mjy/datasets/kitti/data_odometry_calib/dataset/sequences/00/calib.txt" 17 | # prepare poses of semantic kitti dataset (refined poses) 18 | poses_file = "/media/mjy/Samsung_T5/SemanticKITTI/data_odometry_labels/dataset/sequences/00/poses.txt" 19 | 20 | scan_paths = load_files(scan_folder) 21 | T_cam_velo = load_calib(calib_file) 22 | T_cam_velo = np.asarray(T_cam_velo).reshape((4, 4)) 23 | T_velo_cam = np.linalg.inv(T_cam_velo) 24 | poses = load_poses(poses_file) 25 | pose0_inv = np.linalg.inv(poses[0]) 26 | poses_new = [] 27 | for pose in poses: 28 | poses_new.append(T_velo_cam.dot(pose0_inv).dot(pose).dot(T_cam_velo)) 29 | poses = np.array(poses_new) 30 | 31 | 32 | all_rows = [] 33 | thresh = 0.3 34 | for i in range(len(scan_paths)): 35 | print(str(i) + " -------------------------------->") 36 | time1 = time.time() 37 | scan_paths_this_frame = [] 38 | poses_this_frame = [] 39 | scan_paths_this_frame.append(scan_paths[i]) 40 | poses_this_frame.append(poses[i]) 41 | idx_in_range = [] 42 | for idx in range(len(scan_paths)): 43 | if np.linalg.norm(poses[idx, :3, -1] - poses[i, :3, -1]) < 30 and (i-idx) > 100: 44 | scan_paths_this_frame.append(scan_paths[idx]) 45 | poses_this_frame.append(poses[idx]) 46 | idx_in_range.append(idx) 47 | print("prepared indexes for current laser: ", idx_in_range) 48 | 49 | poses_new_this_frame = np.array(poses_this_frame) 50 | ground_truth_mapping = 
com_overlap_yaw(scan_paths_this_frame, poses_new_this_frame, frame_idx=0, leg_output_width=360) 51 | 52 | one_row = [] 53 | for m in range(1, ground_truth_mapping.shape[0]): 54 | if ground_truth_mapping[m,2] > thresh: 55 | one_row.append(idx_in_range[m-1]) 56 | all_rows.append(one_row) 57 | print("gt list for current laser: ", one_row) 58 | time2 = time.time() 59 | print("time: ", time2-time1) 60 | 61 | print(len(all_rows)) 62 | all_rows_array = np.array(all_rows) 63 | np.savez_compressed("loop_gt_seq00_0.3overlap_inactive", all_rows_array) 64 | 65 | 66 | -------------------------------------------------------------------------------- /tools/utils/split_train_val.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Xieyuanli Chen and Thomas Läbe 3 | # This file is covered by the LICENSE file in the root of the project OverlapNet: 4 | #https://github.com/PRBonn/OverlapNet 5 | # Brief: a simple example to split the ground truth data into training and validation parts 6 | 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | def split_train_val(ground_truth_mapping): 12 | """Split the ground truth data into training and validation two parts. 13 | Args: 14 | ground_truth_mapping: the raw ground truth mapping array 15 | Returns: 16 | train_set: data used for training 17 | test_set: data used for validation 18 | """ 19 | # set the ratio of validation data 20 | test_size = int(len(ground_truth_mapping) / 10) 21 | 22 | # use sklearn library to split the data 23 | train_set, test_set = train_test_split(ground_truth_mapping, test_size=test_size) 24 | 25 | print('finished generating training data and validation data') 26 | 27 | return train_set, test_set 28 | 29 | 30 | if __name__ == '__main__': 31 | # read from npz file 32 | ground_truth_file = 'path/to/the/groun-truth/file' 33 | ground_truth_mapping = np.load(ground_truth_file)['arr_0'] 34 | 35 | split_train_val(ground_truth_mapping) 36 | 37 | -------------------------------------------------------------------------------- /tools/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Developed by Xieyuanli Chen and Thomas Läbe 3 | # This file is covered by the LICENSE file in the root of the project OverlapNet: 4 | #https://github.com/PRBonn/OverlapNet 5 | # Brief: some utilities 6 | import os 7 | import math 8 | import numpy as np 9 | import time 10 | 11 | 12 | def load_poses(pose_path): 13 | """ Load ground truth poses (T_w_cam0) from file. 14 | Args: 15 | pose_path: (Complete) filename for the pose file 16 | Returns: 17 | A numpy array of size nx4x4 with n poses as 4x4 transformation 18 | matrices 19 | """ 20 | # Read and parse the poses 21 | poses = [] 22 | try: 23 | if '.txt' in pose_path: 24 | with open(pose_path, 'r') as f: 25 | lines = f.readlines() 26 | for line in lines: 27 | T_w_cam0 = np.fromstring(line, dtype=float, sep=' ') 28 | T_w_cam0 = T_w_cam0.reshape(3, 4) 29 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 30 | poses.append(T_w_cam0) 31 | else: 32 | poses = np.load(pose_path)['arr_0'] 33 | 34 | except FileNotFoundError: 35 | print('Ground truth poses are not avaialble.') 36 | 37 | return np.array(poses) 38 | 39 | 40 | def load_calib(calib_path): 41 | """ Load calibrations (T_cam_velo) from file. 
42 | """ 43 | # Read and parse the calibrations 44 | T_cam_velo = [] 45 | try: 46 | with open(calib_path, 'r') as f: 47 | lines = f.readlines() 48 | for line in lines: 49 | if 'Tr:' in line: 50 | line = line.replace('Tr:', '') 51 | T_cam_velo = np.fromstring(line, dtype=float, sep=' ') 52 | T_cam_velo = T_cam_velo.reshape(3, 4) 53 | T_cam_velo = np.vstack((T_cam_velo, [0, 0, 0, 1])) 54 | 55 | except FileNotFoundError: 56 | print('Calibrations are not avaialble.') 57 | 58 | return np.array(T_cam_velo) 59 | 60 | 61 | def range_projection(current_vertex, fov_up=3.0, fov_down=-25.0, proj_H=64, proj_W=900, max_range=80): 62 | """ Project a pointcloud into a spherical projection, range image. 63 | Args: 64 | current_vertex: raw point clouds 65 | Returns: 66 | proj_range: projected range image with depth, each pixel contains the corresponding depth 67 | proj_vertex: each pixel contains the corresponding point (x, y, z, 1) 68 | proj_intensity: each pixel contains the corresponding intensity 69 | proj_idx: each pixel contains the corresponding index of the point in the raw point cloud 70 | """ 71 | # laser parameters 72 | fov_up = fov_up / 180.0 * np.pi # field of view up in radians 73 | fov_down = fov_down / 180.0 * np.pi # field of view down in radians 74 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radians 75 | 76 | # get depth of all points 77 | depth = np.linalg.norm(current_vertex[:, :3], 2, axis=1) 78 | current_vertex = current_vertex[(depth > 0) & (depth < max_range)] # get rid of [0, 0, 0] points 79 | depth = depth[(depth > 0) & (depth < max_range)] 80 | 81 | # get scan components 82 | scan_x = current_vertex[:, 0] 83 | scan_y = current_vertex[:, 1] 84 | scan_z = current_vertex[:, 2] 85 | intensity = current_vertex[:, 3] 86 | 87 | # get angles of all points 88 | yaw = -np.arctan2(scan_y, scan_x) 89 | pitch = np.arcsin(scan_z / depth) 90 | 91 | # get projections in image coords 92 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] 93 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0] 94 | 95 | # scale to image size using angular resolution 96 | proj_x *= proj_W # in [0.0, W] 97 | proj_y *= proj_H # in [0.0, H] 98 | 99 | # round and clamp for use as index 100 | proj_x = np.floor(proj_x) 101 | proj_x = np.minimum(proj_W - 1, proj_x) 102 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1] 103 | 104 | proj_y = np.floor(proj_y) 105 | proj_y = np.minimum(proj_H - 1, proj_y) 106 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1] 107 | 108 | # order in decreasing depth 109 | order = np.argsort(depth)[::-1] 110 | depth = depth[order] 111 | intensity = intensity[order] 112 | proj_y = proj_y[order] 113 | proj_x = proj_x[order] 114 | 115 | scan_x = scan_x[order] 116 | scan_y = scan_y[order] 117 | scan_z = scan_z[order] 118 | 119 | indices = np.arange(depth.shape[0]) 120 | indices = indices[order] 121 | 122 | proj_range = np.full((proj_H, proj_W), -1, 123 | dtype=np.float32) # [H,W] range (-1 is no data) 124 | proj_vertex = np.full((proj_H, proj_W, 4), -1, 125 | dtype=np.float32) # [H,W] index (-1 is no data) 126 | proj_idx = np.full((proj_H, proj_W), -1, 127 | dtype=np.int32) # [H,W] index (-1 is no data) 128 | proj_intensity = np.full((proj_H, proj_W), -1, 129 | dtype=np.float32) # [H,W] index (-1 is no data) 130 | 131 | proj_range[proj_y, proj_x] = depth 132 | proj_vertex[proj_y, proj_x] = np.array([scan_x, scan_y, scan_z, np.ones(len(scan_x))]).T 133 | proj_idx[proj_y, proj_x] = indices 134 | proj_intensity[proj_y, proj_x] = intensity 135 | 
136 | return proj_range, proj_vertex, proj_intensity, proj_idx 137 | 138 | 139 | def gen_normal_map(current_range, current_vertex, proj_H=64, proj_W=900): # height 64, width 900 140 | """ Generate a normal image given the range projection of a point cloud. 141 | Args: 142 | current_range: range projection of a point cloud, each pixel contains the corresponding depth 143 | current_vertex: range projection of a point cloud, 144 | each pixel contains the corresponding point (x, y, z, 1) 145 | Returns: 146 | normal_data: each pixel contains the corresponding normal 147 | """ 148 | 149 | 150 | normal_data = np.full((proj_H, proj_W, 3), -1, dtype=np.float32) 151 | time_pre1 = time.time() 152 | 153 | # iterate over all pixels in the range image 154 | for x in range(proj_W): 155 | for y in range(proj_H - 1): 156 | p = current_vertex[y, x][:3] 157 | depth = current_range[y, x] 158 | 159 | if depth > 0: 160 | wrap_x = wrap(x + 1, proj_W) 161 | u = current_vertex[y, wrap_x][:3] 162 | u_depth = current_range[y, wrap_x] 163 | if u_depth <= 0: 164 | continue 165 | 166 | v = current_vertex[y + 1, x][:3] 167 | v_depth = current_range[y + 1, x] 168 | if v_depth <= 0: 169 | continue 170 | 171 | u_norm = (u - p) / np.linalg.norm(u - p) 172 | v_norm = (v - p) / np.linalg.norm(v - p) 173 | 174 | w = np.cross(v_norm, u_norm) 175 | norm = np.linalg.norm(w) 176 | if norm > 0: 177 | normal = w / norm 178 | normal_data[y, x] = normal 179 | time_pre2 = time.time() 180 | print("gen normal time ", time_pre2 - time_pre1) 181 | return normal_data 182 | 183 | 184 | def wrap(x, dim): 185 | """ Wrap the border of the range image. 186 | """ 187 | value = x 188 | if value >= dim: 189 | value = (value - dim) 190 | if value < 0: 191 | value = (value + dim) 192 | return value 193 | 194 | 195 | def euler_angles_from_rotation_matrix(R): 196 | """ From the paper by Gregory G. Slabaugh, 197 | Computing Euler angles from a rotation matrix 198 | psi, theta, phi = roll pitch yaw (x, y, z) 199 | Args: 200 | R: rotation matrix, a 3x3 numpy array 201 | Returns: 202 | a tuple with the 3 values psi, theta, phi in radians 203 | """ 204 | 205 | def isclose(x, y, rtol=1.e-5, atol=1.e-8): 206 | return abs(x - y) <= atol + rtol * abs(y) 207 | 208 | phi = 0.0 209 | if isclose(R[2, 0], -1.0): 210 | theta = math.pi / 2.0 211 | psi = math.atan2(R[0, 1], R[0, 2]) 212 | elif isclose(R[2, 0], 1.0): 213 | theta = -math.pi / 2.0 214 | psi = math.atan2(-R[0, 1], -R[0, 2]) 215 | else: 216 | theta = -math.asin(R[2, 0]) 217 | cos_theta = math.cos(theta) 218 | psi = math.atan2(R[2, 1] / cos_theta, R[2, 2] / cos_theta) 219 | phi = math.atan2(R[1, 0] / cos_theta, R[0, 0] / cos_theta) 220 | return psi, theta, phi 221 | 222 | 223 | def load_vertex(scan_path): 224 | """ Load 3D points of a scan. The file format is the .bin format used in 225 | the KITTI dataset. 226 | Args: 227 | scan_path: the (full) filename of the scan file 228 | Returns: 229 | A nx4 numpy array of homogeneous points (x, y, z, 1). 230 | """ 231 | current_vertex = np.fromfile(scan_path, dtype=np.float32) 232 | current_vertex = current_vertex.reshape((-1, 4)) 233 | current_points = current_vertex[:, 0:3] 234 | current_vertex = np.ones((current_points.shape[0], current_points.shape[1] + 1)) 235 | current_vertex[:, :-1] = current_points 236 | return current_vertex 237 | 238 | 239 | def load_files(folder): 240 | """ Load all files in a folder and sort.
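    Args:
      folder: the root folder that is walked recursively
    Returns:
      A sorted list with the full paths of all files found under the folder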
241 | """ 242 | file_paths = [os.path.join(dp, f) for dp, dn, fn in os.walk( 243 | os.path.expanduser(folder)) for f in fn] 244 | file_paths.sort() 245 | return file_paths 246 | 247 | 248 | semantic_mapping = { # bgr 249 | 0: [0, 0, 0], # "unlabeled", and others ignored 250 | 1: [245, 150, 100], # "car" 251 | 2: [245, 230, 100], # "bicycle" 252 | 3: [150, 60, 30], # "motorcycle" 253 | 4: [180, 30, 80], # "truck" 254 | 5: [255, 0, 0], # "other-vehicle" 255 | 6: [30, 30, 255], # "person" 256 | 7: [200, 40, 255], # "bicyclist" 257 | 8: [90, 30, 150], # "motorcyclist" 258 | 9: [255, 0, 255], # "road" 259 | 10: [255, 150, 255], # "parking" 260 | 11: [75, 0, 75], # "sidewalk" 261 | 12: [75, 0, 175], # "other-ground" 262 | 13: [0, 200, 255], # "building" 263 | 14: [50, 120, 255], # "fence" 264 | 15: [0, 175, 0], # "vegetation" 265 | 16: [0, 60, 135], # "trunk" 266 | 17: [80, 240, 150], # "terrain" 267 | 18: [150, 240, 255], # "pole" 268 | 19: [0, 0, 255] # "traffic-sign" 269 | } 270 | -------------------------------------------------------------------------------- /train/training_distributed.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 7 | if p not in sys.path: 8 | sys.path.append(p) 9 | sys.path.append('../tools/') 10 | sys.path.append('../modules/') 11 | import torch 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data.distributed import DistributedSampler 15 | 16 | 17 | import numpy as np 18 | from tensorboardX import SummaryWriter 19 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray 20 | from modules.overlap_transformer import featureExtracter 21 | from tools.read_samples import read_one_batch_pos_neg 22 | from tools.read_samples import read_one_need_from_seq 23 | np.set_printoptions(threshold=sys.maxsize) 24 | import modules.loss as PNV_loss 25 | from tools.utils.utils import * 26 | from valid.valid_seq import validate_seq_faiss 27 | import yaml 28 | import argparse 29 | 30 | def setup(rank, world_size): 31 | os.environ['MASTER_ADDR'] = 'localhost' 32 | os.environ['MASTER_PORT'] = '12355' 33 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 34 | 35 | def cleanup(): 36 | dist.destroy_process_group() 37 | 38 | class trainHandler(): 39 | def __init__(self, rank, world_size, height=64, width=900, channels=5, norm_layer=None, use_transformer=True, use_mamba=False, lr = 0.001, 40 | data_root_folder = None, train_set=None, training_seqs=None): 41 | super(trainHandler, self).__init__() 42 | 43 | self.rank = rank 44 | self.world_size = world_size 45 | setup(rank, world_size) 46 | 47 | self.height = height 48 | self.width = width 49 | self.channels = channels 50 | self.norm_layer = norm_layer 51 | self.use_transformer = use_transformer 52 | self.use_mamba = use_mamba 53 | self.learning_rate = lr 54 | self.data_root_folder = data_root_folder 55 | self.train_set = train_set 56 | self.training_seqs = training_seqs 57 | 58 | self.amodel = featureExtracter(channels=self.channels, use_transformer=self.use_transformer, use_mamba=self.use_mamba) 59 | self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 60 | self.amodel.to(self.device) 61 | self.amodel = DDP(self.amodel, device_ids=[rank]) 62 | self.parameters = self.amodel.parameters() 63 | self.optimizer = torch.optim.Adam(self.parameters, 
self.learning_rate) 64 | 65 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9) 66 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9) 67 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98) 68 | 69 | self.traindata_npzfiles = train_set 70 | 71 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 72 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles) 73 | 74 | """change the args for resuming training process""" 75 | self.resume = True 76 | self.save_name = "../weights/train/trained_overlap_transformer24.pth.tar" 77 | 78 | """overlap threshold follows OverlapNet""" 79 | self.overlap_thresh = 0.3 80 | 81 | def train(self): 82 | 83 | epochs = 100 84 | 85 | """resume from the saved model""" 86 | if self.resume: 87 | resume_filename = self.save_name 88 | print("Resuming from ", resume_filename) 89 | checkpoint = torch.load(resume_filename, map_location=self.device) 90 | starting_epoch = checkpoint['epoch'] 91 | self.amodel.load_state_dict(checkpoint['state_dict']) 92 | self.optimizer.load_state_dict(checkpoint['optimizer']) 93 | else: 94 | print("Training From Scratch ...") 95 | starting_epoch = 0 96 | 97 | writer1 = SummaryWriter(comment=f"LR_0.xxxx_{self.rank}") 98 | 99 | for i in range(starting_epoch+1, epochs): 100 | if self.rank == 0: 101 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 102 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True) 103 | data = [[self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap]] 104 | else: 105 | data = [None] 106 | # torch.distributed has no send_object/recv_object; broadcast_object_list shares rank 0's shuffled lists with all ranks 107 | dist.broadcast_object_list(data, src=0) 108 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap = data[0] 109 | 110 | dist.barrier() 111 | 112 | if self.rank == 0: 113 | print("=======================================================================\n\n\n") 114 | print("training with seq: ", np.unique(np.array(self.train_dir1))) 115 | print("total pairs: ", len(self.train_imgf1)) 116 | print("\n\n\n=======================================================================") 117 | 118 | loss_each_epoch = 0 119 | used_num = 0 120 | 121 | used_list_f1 = [] 122 | used_list_dir1 = [] 123 | 124 | for j in range(len(self.train_imgf1)): 125 | """ 126 | check whether the query is used to train before (continue_flag==True/False).
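                Keeping a set of (f1_index, dir1_index) tuples would make this membership test O(1) instead of a linear scan.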
127 | TODO: More efficient method 128 | """ 129 | f1_index = self.train_imgf1[j] 130 | dir1_index = self.train_dir1[j] 131 | continue_flag = False 132 | for iddd in range(len(used_list_f1)): 133 | if f1_index==used_list_f1[iddd] and dir1_index==used_list_dir1[iddd]: 134 | continue_flag = True 135 | else: 136 | used_list_f1.append(f1_index) 137 | used_list_dir1.append(dir1_index) 138 | 139 | if continue_flag: 140 | continue 141 | 142 | """read one query range image from KITTI sequences""" 143 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index) 144 | 145 | """ 146 | read several reference range images from KITTI sequences 147 | to consist of positive samples and negative samples 148 | """ 149 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \ 150 | (self.data_root_folder,f1_index, dir1_index, 151 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap, 152 | self.overlap_thresh) 153 | 154 | """ 155 | the balance of positive samples and negative samples. 156 | TODO: Update for better training results 157 | """ 158 | use_pos_num = 18 159 | use_neg_num = 18 160 | if pos_num >= use_pos_num and neg_num>=use_neg_num: 161 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num+use_neg_num, :, :, :]), dim=0) 162 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num+use_neg_num, :]), dim=0) 163 | pos_num = use_pos_num 164 | neg_num = use_neg_num 165 | elif pos_num >= use_pos_num: 166 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]), dim=0) 167 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0) 168 | pos_num = use_pos_num 169 | elif neg_num >= use_neg_num: 170 | sample_batch = sample_batch[0:pos_num+use_neg_num,:,:,:] 171 | sample_truth = sample_truth[0:pos_num+use_neg_num, :] 172 | neg_num = use_neg_num 173 | 174 | if neg_num == 0: 175 | continue 176 | 177 | input_batch = torch.cat((current_batch, sample_batch), dim=0) 178 | 179 | input_batch.requires_grad_(True) 180 | self.amodel.train() 181 | self.optimizer.zero_grad() 182 | 183 | global_des = self.amodel(input_batch) 184 | o1, o2, o3 = torch.split( 185 | global_des, [1, pos_num, neg_num], dim=0) 186 | MARGIN_1 = 0.5 187 | """ 188 | triplet_loss: Lazy for pos 189 | triplet_loss_inv: Lazy for neg 190 | """ 191 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False) 192 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True) 193 | loss.backward() 194 | self.optimizer.step() 195 | 196 | if self.rank == 0: 197 | print(str(used_num), loss) 198 | 199 | if torch.isnan(loss): 200 | if self.rank == 0: 201 | print("Something error ...") 202 | print(pos_num) 203 | print(neg_num) 204 | 205 | loss_each_epoch = loss_each_epoch + loss.item() 206 | used_num = used_num + 1 207 | 208 | if self.rank == 0: 209 | print("epoch {} loss {}".format(i, loss_each_epoch/used_num)) 210 | print("saving weights ...") 211 | self.scheduler.step() 212 | 213 | """save trained weights and optimizer states""" 214 | self.save_name = "../weights/train/trained_overlap_transformer"+str(i)+".pth.tar" 215 | 216 | torch.save({ 217 | 'epoch': i, 218 | 'state_dict': self.amodel.module.state_dict(), 219 | 'optimizer': self.optimizer.state_dict() 220 | }, 221 | self.save_name) 222 | 223 | print("Model Saved As " + f'trained_overlap_transformer{self.rank}_' + str(i) + '.pth.tar') 224 | 225 | 
writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i) 226 | 227 | print("validating ......") 228 | with torch.no_grad(): 229 | top1_rate = validate_seq_faiss(self.amodel.module, "02") 230 | writer1.add_scalar("top1_rate", top1_rate, global_step=i) 231 | 232 | cleanup() 233 | 234 | if __name__ == '__main__': 235 | parser = argparse.ArgumentParser() 236 | parser.add_argument("--local_rank", type=int) 237 | args = parser.parse_args() 238 | 239 | # load config ================================================================ 240 | config_filename = '../config/config.yml' 241 | config = yaml.safe_load(open(config_filename)) 242 | data_root_folder = config["data_root"]["data_root_folder"] 243 | training_seqs = config["training_config"]["training_seqs"] 244 | # ============================================================================ 245 | 246 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs] 247 | 248 | torch.cuda.set_device(args.local_rank) 249 | 250 | train_handler = trainHandler(rank=args.local_rank, world_size=torch.cuda.device_count(), height=64, width=900, channels=1, norm_layer=None, use_transformer=False, use_mamba=True, lr=0.000005, 251 | data_root_folder=data_root_folder, train_set=traindata_npzfiles, training_seqs = training_seqs) 252 | 253 | train_handler.train() 254 | -------------------------------------------------------------------------------- /train/training_overlap_mamba_kitti.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import sys 5 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 6 | if p not in sys.path: 7 | sys.path.append(p) 8 | sys.path.append('../tools/') 9 | sys.path.append('../modules/') 10 | import torch 11 | import numpy as np 12 | from tensorboardX import SummaryWriter 13 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray 14 | from modules.overlap_mamba import featureExtracter 15 | from tools.read_samples import read_one_batch_pos_neg 16 | from tools.read_samples import read_one_need_from_seq 17 | np.set_printoptions(threshold=sys.maxsize) 18 | import modules.loss as PNV_loss 19 | from tools.utils.utils import * 20 | from valid.valid_seq import validate_seq_faiss 21 | import yaml 22 | from datetime import datetime 23 | 24 | def shift(tensor, dim, index): 25 | length = tensor.size(dim) 26 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index), 27 | tensor.narrow(dim, 0, index)), dim=dim) 28 | return shifted_tensor 29 | 30 | def unshift(tensor, dim, index): 31 | length = tensor.size(dim) 32 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index), 33 | tensor.narrow(dim, 0, length - index)), dim=dim) 34 | return unshifted_tensor 35 | 36 | 37 | class trainHandler(): 38 | def __init__(self, height=64, width=900, channels=5, norm_layer=None, lr = 0.001, 39 | data_root_folder = None, train_set=None, training_seqs=None): 40 | super(trainHandler, self).__init__() 41 | 42 | self.height = height 43 | self.width = width 44 | self.channels = channels 45 | self.norm_layer = norm_layer 46 | self.learning_rate = lr 47 | self.data_root_folder = data_root_folder 48 | self.train_set = train_set 49 | self.training_seqs = training_seqs 50 | 51 | self.amodel = featureExtracter(channels=self.channels) 52 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 53 | self.amodel.to(self.device) 54 | self.parameters = 
self.amodel.parameters() 55 | self.optimizer = torch.optim.Adam(self.parameters, self.learning_rate) 56 | # 57 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9) 58 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9) 59 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98) 60 | 61 | self.traindata_npzfiles = train_set 62 | 63 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 64 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles) 65 | 66 | """change the args for resuming training process""" 67 | self.resume = False 68 | self.save_name = "../weights/shift_bimamba_sppf_80/trained_overlap_transformer19.pth.tar" 69 | """overlap threshold follows OverlapNet""" 70 | self.overlap_thresh = 0.3 71 | 72 | def train(self): 73 | 74 | epochs = 30 75 | 76 | """resume from the saved model""" 77 | if self.resume: 78 | resume_filename = self.save_name 79 | print("Resuming from ", resume_filename) 80 | checkpoint = torch.load(resume_filename) 81 | starting_epoch = checkpoint['epoch'] 82 | self.amodel.load_state_dict(checkpoint['state_dict']) 83 | self.optimizer.load_state_dict(checkpoint['optimizer']) 84 | else: 85 | print("Training From Scratch ..." ) 86 | starting_epoch = 0 87 | 88 | writer1 = SummaryWriter(comment="LR_0.xxxx", log_dir="test_device") 89 | 90 | for i in range(starting_epoch+1, epochs): 91 | 92 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 93 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True) 94 | 95 | print("=======================================================================\n\n\n") 96 | 97 | print("training with seq: ", np.unique(np.array(self.train_dir1))) 98 | print("total pairs: ", len(self.train_imgf1)) 99 | print("\n\n\n=======================================================================") 100 | 101 | loss_each_epoch = 0 102 | used_num = 0 103 | 104 | used_list_f1 = [] 105 | used_list_dir1 = [] 106 | 107 | for j in range(len(self.train_imgf1)): 108 | """ 109 | check whether the query is used to train before (continue_flag==True/False). 110 | TODO: More efficient method 111 | """ 112 | f1_index = self.train_imgf1[j] 113 | dir1_index = self.train_dir1[j] 114 | continue_flag = False 115 | for iddd in range(len(used_list_f1)): 116 | if f1_index==used_list_f1[iddd] and dir1_index==used_list_dir1[iddd]: 117 | continue_flag = True 118 | else: 119 | used_list_f1.append(f1_index) 120 | used_list_dir1.append(dir1_index) 121 | 122 | if continue_flag: 123 | continue 124 | 125 | """read one query range image from KITTI sequences""" 126 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index) 127 | 128 | """ 129 | read several reference range images from KITTI sequences 130 | to consist of positive samples and negative samples 131 | """ 132 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \ 133 | (self.data_root_folder,f1_index, dir1_index, 134 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap, 135 | self.overlap_thresh) 136 | 137 | """ 138 | the balance of positive samples and negative samples. 
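                At most use_pos_num positives and use_neg_num negatives (6 each below) are kept for every query.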
139 | TODO: Update for better training results 140 | """ 141 | use_pos_num = 6 142 | use_neg_num = 6 143 | if pos_num >= use_pos_num and neg_num>=use_neg_num: 144 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num+use_neg_num, :, :, :]), dim=0) 145 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num+use_neg_num, :]), dim=0) 146 | pos_num = use_pos_num 147 | neg_num = use_neg_num 148 | elif pos_num >= use_pos_num: 149 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]), dim=0) 150 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0) 151 | pos_num = use_pos_num 152 | elif neg_num >= use_neg_num: 153 | sample_batch = sample_batch[0:pos_num+use_neg_num,:,:,:] 154 | sample_truth = sample_truth[0:pos_num+use_neg_num, :] 155 | neg_num = use_neg_num 156 | 157 | if neg_num == 0: 158 | continue 159 | 160 | input_batch = torch.cat((current_batch, sample_batch), dim=0) 161 | 162 | input_batch.requires_grad_(True) 163 | self.amodel.train() 164 | self.optimizer.zero_grad() 165 | 166 | index = int(torch.rand(1).item() * 900) 167 | input_batch_shift = shift(input_batch, 3, index) 168 | 169 | 170 | global_des = self.amodel(input_batch) 171 | global_des_shift = self.amodel(input_batch_shift) 172 | global_des_add = torch.cat((global_des, global_des_shift), dim=1) 173 | 174 | o1, o2, o3 = torch.split( 175 | global_des_add, [1, pos_num, neg_num], dim=0) 176 | MARGIN_1 = 0.5 177 | """ 178 | triplet_loss: Lazy for pos 179 | triplet_loss_inv: Lazy for neg 180 | """ 181 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False) 182 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True) 183 | loss.backward() 184 | self.optimizer.step() 185 | 186 | current_time = datetime.now() 187 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]") 188 | 189 | if used_num % 1000 == 0: 190 | print(formatted_time, str(used_num), loss) 191 | 192 | if torch.isnan(loss): 193 | print("Something error ...") 194 | print(pos_num) 195 | print(neg_num) 196 | 197 | loss_each_epoch = loss_each_epoch + loss.item() 198 | used_num = used_num + 1 199 | 200 | 201 | print("epoch {} loss {}".format(i, loss_each_epoch/used_num)) 202 | print("saving weights ...") 203 | self.scheduler.step() 204 | 205 | """save trained weights and optimizer states""" 206 | self.save_name = "../weights/nclt_cvtnet/trained_overlap_transformer"+str(i)+".pth.tar" 207 | 208 | torch.save({ 209 | 'epoch': i, 210 | 'state_dict': self.amodel.state_dict(), 211 | 'optimizer': self.optimizer.state_dict() 212 | }, 213 | self.save_name) 214 | 215 | print("Model Saved As " + 'trained_overlap_transformer' + str(i) + '.pth.tar') 216 | 217 | writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i) 218 | 219 | """a simple validation with KITTI 02""" 220 | print("validating ......") 221 | with torch.no_grad(): 222 | top1_rate = validate_seq_faiss(self.amodel, "02") 223 | writer1.add_scalar("top1_rate", top1_rate, global_step=i) 224 | 225 | 226 | if __name__ == '__main__': 227 | # load config ================================================================ 228 | config_filename = '../config/config.yml' 229 | config = yaml.safe_load(open(config_filename)) 230 | data_root_folder = config["data_root"]["data_root_folder"] 231 | training_seqs = config["training_config"]["training_seqs"] 232 | # ============================================================================ 
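# For reference, config.yml is expected to provide at least the keys read above
# (a minimal sketch; the values are placeholders, not the repository's actual settings):
# data_root:
#     data_root_folder: "/path/to/preprocessed/kitti/sequences/"
# training_config:
#     training_seqs: ["03", "04", "05"]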
233 | 234 | # along the lines of OverlapNet 235 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs] 236 | 237 | """ 238 | trainHandler handles the training process. 239 | Args: 240 | height: the height of the range image (the beam number for convenience). 241 | width: the width of the range image (900, along the lines of OverlapNet). 242 | channels: 1 for depth only in our work. 243 | norm_layer: None in our work for a better model. 244 | use_transformer: whether to use MHSA (unused here; this script uses the Mamba backbone). 245 | lr: learning rate, which needs to be fine-tuned during training for the best performance. 246 | data_root_folder: root of KITTI sequences. It's better to follow our file structure. 247 | train_set: traindata_npzfiles (along the lines of OverlapNet). 248 | training_seqs: sequence numbers used for training (along the lines of OverlapNet). 249 | """ 250 | train_handler = trainHandler(height=64, width=900, channels=1, norm_layer=None, lr=0.000005, 251 | data_root_folder=data_root_folder, train_set=traindata_npzfiles, training_seqs = training_seqs) 252 | 253 | train_handler.train() 254 | -------------------------------------------------------------------------------- /train/training_overlap_mambar_nclt.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | 7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 8 | if p not in sys.path: 9 | sys.path.append(p) 10 | sys.path.append('../tools/') 11 | sys.path.append('../modules/') 12 | import torch 13 | import numpy as np 14 | from tensorboardX import SummaryWriter 15 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray 16 | from modules.overlap_mamba import featureExtracter 17 | from tools.read_samples import read_one_batch_pos_neg 18 | from tools.read_samples import read_one_need_from_seq 19 | 20 | np.set_printoptions(threshold=sys.maxsize) 21 | import modules.loss as PNV_loss 22 | from tools.utils.utils import * 23 | from valid.valid_seq import validate_seq_faiss 24 | import yaml 25 | from datetime import datetime 26 | 27 | 28 | def shift(tensor, dim, index): 29 | length = tensor.size(dim) 30 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index), 31 | tensor.narrow(dim, 0, index)), dim=dim) 32 | return shifted_tensor 33 | 34 | def unshift(tensor, dim, index): 35 | length = tensor.size(dim) 36 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index), 37 | tensor.narrow(dim, 0, length - index)), dim=dim) 38 | return unshifted_tensor 39 | 40 | 41 | class trainHandler(): 42 | def __init__(self, height=32, width=900, channels=5, norm_layer=None, lr=0.001, 43 | data_root_folder=None, train_set=None, training_seqs=None): 44 | super(trainHandler, self).__init__() 45 | 46 | self.height = height 47 | self.width = width 48 | self.channels = channels 49 | self.norm_layer = norm_layer 50 | self.learning_rate = lr 51 | self.data_root_folder = data_root_folder 52 | self.train_set = train_set 53 | self.training_seqs = training_seqs 54 | 55 | self.amodel = featureExtracter(channels=self.channels) 56 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 57 | self.amodel.to(self.device) 58 | self.parameters = self.amodel.parameters() 59 | self.optimizer = torch.optim.Adam(self.parameters, self.learning_rate) 60 | # 61 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9) 62 | self.scheduler =
torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9) 63 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98) 64 | 65 | self.traindata_npzfiles = train_set 66 | 67 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 68 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles) 69 | 70 | """change the args for resuming training process""" 71 | self.resume = False 72 | self.save_name = "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_noimloss/trained_overlap_transformer4.pth.tar" 73 | """overlap threshold follows OverlapNet""" 74 | self.overlap_thresh = 0.3 75 | 76 | def train(self): 77 | 78 | epochs = 30 79 | 80 | """resume from the saved model""" 81 | if self.resume: 82 | resume_filename = self.save_name 83 | print("Resuming from ", resume_filename) 84 | checkpoint = torch.load(resume_filename) 85 | starting_epoch = checkpoint['epoch'] 86 | self.amodel.load_state_dict(checkpoint['state_dict']) 87 | self.optimizer.load_state_dict(checkpoint['optimizer']) 88 | else: 89 | print("Training From Scratch ...") 90 | starting_epoch = 0 91 | 92 | writer1 = SummaryWriter(comment="LR_0.xxxx", log_dir="test_device") 93 | 94 | for i in range(starting_epoch + 1, epochs): 95 | 96 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \ 97 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True) 98 | 99 | print("=======================================================================\n\n\n") 100 | 101 | print("training with seq: ", np.unique(np.array(self.train_dir1))) 102 | print("total pairs: ", len(self.train_imgf1)) 103 | print("\n\n\n=======================================================================") 104 | 105 | loss_each_epoch = 0 106 | used_num = 0 107 | 108 | used_list_f1 = [] 109 | used_list_dir1 = [] 110 | 111 | for j in range(len(self.train_imgf1)): 112 | """ 113 | check whether the query is used to train before (continue_flag==True/False). 114 | TODO: More efficient method 115 | """ 116 | f1_index = self.train_imgf1[j] 117 | dir1_index = self.train_dir1[j] 118 | continue_flag = False 119 | for iddd in range(len(used_list_f1)): 120 | if f1_index == used_list_f1[iddd] and dir1_index == used_list_dir1[iddd]: 121 | continue_flag = True 122 | else: 123 | used_list_f1.append(f1_index) 124 | used_list_dir1.append(dir1_index) 125 | 126 | if continue_flag: 127 | continue 128 | 129 | """read one query range image from KITTI sequences""" 130 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index) 131 | 132 | """ 133 | read several reference range images from KITTI sequences 134 | to consist of positive samples and negative samples 135 | """ 136 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \ 137 | (self.data_root_folder, f1_index, dir1_index, 138 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap, 139 | self.overlap_thresh) 140 | 141 | """ 142 | the balance of positive samples and negative samples. 
143 | TODO: Update for better training results 144 | """ 145 | use_pos_num = 6 146 | use_neg_num = 6 147 | if pos_num >= use_pos_num and neg_num >= use_neg_num: 148 | sample_batch = torch.cat( 149 | (sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num + use_neg_num, :, :, :]), 150 | dim=0) 151 | sample_truth = torch.cat( 152 | (sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num + use_neg_num, :]), dim=0) 153 | pos_num = use_pos_num 154 | neg_num = use_neg_num 155 | elif pos_num >= use_pos_num: 156 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]), 157 | dim=0) 158 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0) 159 | pos_num = use_pos_num 160 | elif neg_num >= use_neg_num: 161 | sample_batch = sample_batch[0:pos_num + use_neg_num, :, :, :] 162 | sample_truth = sample_truth[0:pos_num + use_neg_num, :] 163 | neg_num = use_neg_num 164 | 165 | if neg_num == 0: 166 | continue 167 | 168 | # inputbatch = torch.cat((current_batch, shift_batch1), dim=0) 169 | # inputbatch1 = torch.cat((inputbatch, shift_batch2), dim=0) 170 | # input_batch = torch.cat((inputbatch1, sample_batch), dim=0) 171 | 172 | input_batch = torch.cat((current_batch, sample_batch), dim=0) 173 | 174 | input_batch.requires_grad_(True) 175 | self.amodel.train() 176 | self.optimizer.zero_grad() 177 | 178 | index = int(torch.rand(1).item() * 900) 179 | input_batch_shift = shift(input_batch, 3, index) 180 | 181 | global_des = self.amodel(input_batch) 182 | global_des_shift = self.amodel(input_batch_shift) 183 | global_des_add = torch.cat((global_des, global_des_shift), dim=1) 184 | 185 | o1, o2, o3 = torch.split( 186 | global_des_add, [1, pos_num, neg_num], dim=0) 187 | MARGIN_1 = 0.5 188 | """ 189 | triplet_loss: Lazy for pos 190 | triplet_loss_inv: Lazy for neg 191 | """ 192 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False) 193 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True) 194 | loss.backward() 195 | self.optimizer.step() 196 | 197 | current_time = datetime.now() 198 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]") 199 | 200 | if used_num % 1000 == 0: 201 | print(formatted_time, str(used_num), loss) 202 | 203 | if torch.isnan(loss): 204 | print("Something error ...") 205 | print(pos_num) 206 | print(neg_num) 207 | 208 | loss_each_epoch = loss_each_epoch + loss.item() 209 | used_num = used_num + 1 210 | 211 | print("epoch {} loss {}".format(i, loss_each_epoch / used_num)) 212 | print("saving weights ...") 213 | self.scheduler.step() 214 | 215 | """save trained weights and optimizer states""" 216 | self.save_name = "../weights/nclt_cvtnet/trained_overlap_transformer" + str(i) + ".pth.tar" 217 | 218 | torch.save({ 219 | 'epoch': i, 220 | 'state_dict': self.amodel.state_dict(), 221 | 'optimizer': self.optimizer.state_dict() 222 | }, 223 | self.save_name) 224 | 225 | print("Model Saved As " + 'trained_overlap_mamba' + str(i) + '.pth.tar') 226 | 227 | writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i) 228 | 229 | """a simple validation with KITTI 02""" 230 | print("validating ......") 231 | # with torch.no_grad(): 232 | # top1_rate = validate_seq_faiss(self.amodel, "02") 233 | # writer1.add_scalar("top1_rate", top1_rate, global_step=i) 234 | 235 | 236 | if __name__ == '__main__': 237 | # load config ================================================================ 238 | config_filename = '../config/config.yml' 239 | config = 
yaml.safe_load(open(config_filename)) 240 | data_root_folder = config["data_root"]["data_root_folder"] 241 | training_seqs = config["training_config"]["training_seqs"] 242 | # ============================================================================ 243 | 244 | # along the lines of OverlapNet 245 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs] 246 | 247 | """ 248 | trainHandler handles the training process. 249 | Args: 250 | height: the height of the range image (the beam number for convenience). 251 | width: the width of the range image (900, along the lines of OverlapNet). 252 | channels: 1 for depth only in our work. 253 | norm_layer: None in our work for a better model. 254 | use_transformer: whether to use MHSA (unused here; this script uses the Mamba backbone). 255 | lr: learning rate, which needs to be fine-tuned during training for the best performance. 256 | data_root_folder: root of KITTI sequences. It's better to follow our file structure. 257 | train_set: traindata_npzfiles (along the lines of OverlapNet). 258 | training_seqs: sequence numbers used for training (along the lines of OverlapNet). 259 | """ 260 | train_handler = trainHandler(height=32, width=900, channels=1, norm_layer=None, lr=0.000005, 261 | data_root_folder=data_root_folder, train_set=traindata_npzfiles, 262 | training_seqs=training_seqs) 263 | 264 | train_handler.train() 265 | --------------------------------------------------------------------------------