5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OverlapMamba: Novel Shift State Space Model for LiDAR-based Place Recognition
2 |
3 |
4 |
5 |
6 |
7 | ![Core idea of OverlapMamba](fig.png)
8 | Fig. 1 Core idea of the proposed OverlapMamba
9 |
10 | ### 💾Environment
11 |
12 |
13 | We use PyTorch with GPU support for the neural networks.
14 |
15 | To use a GPU, you first need to install the NVIDIA driver and CUDA.
16 |
17 | - CUDA Installation guide: [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
18 | We use CUDA 11.3 in our work. Other CUDA versions are also supported, but you should choose the corresponding torch version in the Torch dependencies below.
19 |
20 | - System dependencies:
21 |
22 | ```bash
23 | sudo apt-get update
24 | sudo apt-get install -y python3-pip python3-tk
25 | sudo -H pip3 install --upgrade pip
26 | ```
27 | - Torch dependencies:
28 | Following this [link](https://pytorch.org/get-started/locally/), you can install the Torch dependencies with pip:
29 | ```bash
30 | pip3 install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
31 | ```
32 | or by conda:
33 | ```bash
34 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
35 | ```
36 |
37 | - Other Python dependencies (may also work with different versions than mentioned in the requirements file):
38 |
39 | ```bash
40 | sudo -H pip3 install -r requirements.txt
41 | ```
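
After installing the dependencies, a quick sanity check that the GPU build of PyTorch is working:

```python
import torch

print(torch.__version__)          # should report a CUDA build, e.g. a +cu113 wheel
print(torch.cuda.is_available())  # True if the NVIDIA driver and CUDA are set up correctly
```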
42 |
43 | ## 📖How to use
44 |
45 | We provide a training and testing tutorial for KITTI sequences in this repository. Before doing anything else, please modify the [config file](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml) according to your setup.
46 |
47 | ### 📚Dataset
48 |
49 | Download KITTI dataset from [KITTI](https://www.cvlibs.net/datasets/kitti/user_login.php).
50 |
51 | We recommend following the code and data structure below.
52 |
53 | ### Code Structure
54 |
55 | ```bash
56 | ├── config
57 | │ ├── config_nclt.yml
58 | │ └── config.yml
59 | ├── modules
60 | │ ├── loss.py
61 | │ ├── netvlad.py
62 | │ ├── overlap_mamba_nclt.py
63 | │ └── overlap_mamba.py
64 | ├── test
65 | │ ├── test_mamba_topn_prepare.py
66 | │ ├── test_mamba_topn.py
67 | │ ├── test_kitti00_prepare.py
68 | │ ├── test_kitti00_PR.py
69 | │ ├── test_kitti00_topN.py
70 | │ ├── test_results_nclt
71 | │ │ └── predicted_des_L2_dis_bet_traj_forward.npz (to be generated)
72 | │ └── test_results_kitti
73 | │ └── predicted_des_L2_dis.npz (to be generated)
74 | ├── tools
75 | │ ├── read_all_sets.py
76 | │ ├── read_samples_haomo.py
77 | │ ├── read_samples.py
78 | │ └── utils
79 | │ ├── gen_depth_data.py
80 | │ ├── split_train_val.py
81 | │ └── utils.py
82 | ├── train
83 | │ ├── training_overlap_mamba_nclt.py
84 | │ └── training_overlap_mamba_kitti.py
85 | ├── valid
86 | │ └── valid_seq.py
87 | ├── visualize
88 | │ ├── des_list.npy
89 | │ └── viz_nclt.py
90 | └── weights
91 | ├── pretrained_overlap_mamba_nclt.pth.tar
92 | └── pretrained_overlap_mamba.pth.tar
93 | ```
94 |
95 | ### Dataset Structure
96 | ```
97 | data_root_folder (KITTI sequences root) follows:
98 | ├── 00
99 | │ ├── depth_map
100 | │ ├── 000000.png
101 | │ ├── 000001.png
102 | │ ├── 000002.png
103 | │ ├── ...
104 | │ └── overlaps
105 | │ ├── train_set.npz
106 | ├── 01
107 | ├── 02
108 | ├── ...
109 | ├── 10
110 | └── loop_gt_seq00_0.3overlap_inactive.npz
111 |
112 | valid_scan_folder (KITTI sequence 02 velodyne) contains:
113 | ├── 000000.bin
114 | ├── 000001.bin
115 | ...
116 |
117 | gt_valid_folder (KITTI sequence 02 computed overlaps) contains:
118 | ├── 02
119 | │ ├── overlap_0.npy
120 | │ ├── overlap_10.npy
121 | ...
122 | ```
123 | You need to download or generate the following files and put them in the correct positions in the structure above:
124 | - You can find the ground truth for KITTI 00 here: [loop_gt_seq00_0.3overlap_inactive.npz](https://drive.google.com/file/d/1upAwJBF-_UIB7R8evW0PuJBM3RnrTbzl/view?usp=sharing)
125 | - You can find `gt_valid_folder` for sequence 02 [here](https://drive.google.com/file/d/13_1j20Uq3ppjVEkYaYcKjiJ2Zm7tudyH/view?usp=sharing).
126 | - Since the full KITTI sequences require a large amount of memory, we recommend generating the range images such as `00/depth_map/000000.png` yourself with the preprocessing from [Overlap_Localization](https://github.com/PRBonn/overlap_localization/blob/master/src/prepare_training/gen_depth_and_normal_map.py); we do not provide these images.
127 | - More directly, you can generate the `.png` range images with [this script](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/OverlapMamba/tools/utils/gen_depth_data.py) (see the sketch after this list).
128 | - The `overlaps` folder of each sequence under `data_root_folder` is provided by the authors of OverlapNet [here](https://drive.google.com/file/d/1i333NUC1DnJglXasqkGYCmo9p45Fx28-/view?usp=sharing). You should rename the downloaded files to `train_set.npz`.
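
As a reference for the range-image generation mentioned above, each scan is turned into a depth map by a spherical projection. The sketch below only illustrates the idea using `load_vertex()` and `range_projection()` from `tools/utils/utils.py` (the same helpers used in `demo/com_overlap.py`); the scan path is a placeholder, and the actual file naming and `.png` encoding are handled by `gen_depth_data.py`.

```python
import numpy as np
from tools.utils.utils import load_vertex, range_projection

scan_path = "path/to/sequences/00/velodyne/000000.bin"  # placeholder path to a raw KITTI scan
points = load_vertex(scan_path)                         # (N, 4) homogeneous LiDAR points

# project the point cloud into a 64 x 900 range image (depth map)
depth_map, _, _, _ = range_projection(points, fov_up=3.0, fov_down=-25.0,
                                      proj_H=64, proj_W=900, max_range=50)
# gen_depth_data.py stores this image as e.g. 00/depth_map/000000.png
```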
129 |
130 |
131 | ### Quick Use
132 |
133 | For quick use, you can download our [model pretrained on KITTI](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/data_root_folder/pretrained_overlap_mamba.pth.tar); the following two files should also be downloaded:
134 | - [calib_file](https://drive.google.com/file/d/1LAcFrRSZQPxdD4EKSwIC0d3-uGvLB3yk/view?usp=sharing): calibration file from KITTI 00.
135 | - [poses_file](https://drive.google.com/file/d/1n02m1OqxK122ce8Cjz_N68PkazGqzj9l/view?usp=sharing): pose file from KITTI 00.
136 |
137 | Then you should modify `demo1_config` in the file [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml).
138 |
139 | Run the demo by:
140 |
141 | ```bash
142 | cd demo
143 | python ./demo_compute_overlap_sim.py
144 | ```
145 | You will see a query scan (000000.bin of KITTI 00) together with a reprojected positive sample (000005.bin of KITTI 00) and a reprojected negative sample (000015.bin of KITTI 00), along with the corresponding similarities.
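
The reported similarity is derived from the L2 distance between the global descriptors, as in `demo/demo_compute_overlap_sim.py` (the descriptor size below is just a stand-in for the network output):

```python
import numpy as np

des_cur = np.random.rand(1, 256).astype(np.float32)  # query descriptor (stand-in)
des_ref = np.random.rand(1, 256).astype(np.float32)  # reference descriptor (stand-in)

dis = np.linalg.norm(des_cur - des_ref)  # L2 distance between descriptors
sim = 1.0 / (1.0 + dis)                  # similarity in (0, 1], as printed by the demo
print(round(sim, 2))
```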
146 |
147 |
148 | ### Training
149 |
150 | In the file [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml), `training_seqs` specifies the KITTI sequences used for training.
151 |
152 | You can start the training with
153 |
154 | ```bash
155 | cd train
156 | python ./training_overlap_mamba_kitti.py
157 | ```
158 | You can also resume training from our pretrained model [here](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/data_root_folder/pretrained_overlap_mamba.pth.tar).
159 |
160 |
161 | ### Testing
162 |
163 | Once a model has been trained, the performance of the network can be evaluated. Before testing, the parameters should be set in [config.yml](https://github.com/SCNU-RISLAB/OverlapMamba/blob/main/config/config.yml):
164 |
165 | - `test_seqs`: sequence number for evaluation, which is "00" in our work.
166 | - `test_weights`: path of the pretrained model.
167 | - `gt_file`: path of the ground truth file provided by the authors of OverlapNet, which can be downloaded [here](https://drive.google.com/file/d/1upAwJBF-_UIB7R8evW0PuJBM3RnrTbzl/view?usp=sharing).
168 |
169 |
170 | Then you can run the testing scripts as follows:
171 |
172 | ```bash
173 | cd test
174 | mkdir test_results_kitti
175 | python test_kitti00_prepare.py
176 | python test_kitti00_PR.py
177 | python test_kitti00_topN.py
178 | ```
179 | After you run `test_kitti00_prepare.py`, a file named `predicted_des_L2_dis.npz` is generated in `test_results_kitti`. It is then used by `test_kitti00_PR.py` to calculate the PR curve and F1max, and by `test_kitti00_topN.py` to calculate the top-N recall.
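
For reference, F1max is the maximum F1 score over all operating points of the PR curve. A minimal sketch (not the exact code of `test_kitti00_PR.py`):

```python
import numpy as np

def f1_max(precisions, recalls):
    # F1 score at every threshold on the PR curve; F1max is the best of them
    p = np.asarray(precisions, dtype=np.float64)
    r = np.asarray(recalls, dtype=np.float64)
    f1 = 2.0 * p * r / np.maximum(p + r, 1e-12)
    return float(f1.max())

print(f1_max([1.00, 0.95, 0.80], [0.30, 0.60, 0.90]))  # toy values -> ~0.85
```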
180 |
181 |
182 | ## Performance
183 | The code of this repo uses a minimal PyTorch implementation of Mamba, which implements the scan operation as a sequential loop. It closely follows the corresponding file of the official Mamba implementation, but its performance is somewhat worse.
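
For illustration, such a sequential selective scan boils down to a per-timestep linear recurrence like the sketch below (simplified, not the exact code of this repo):

```python
import torch

def sequential_scan(deltaA, deltaBx):
    # deltaA, deltaBx: (B, L, D, N) discretized state-transition and input terms
    B, L, D, N = deltaA.shape
    h = torch.zeros(B, D, N, dtype=deltaA.dtype, device=deltaA.device)
    hs = []
    for t in range(L):  # O(L) sequential steps: this loop is what the official CUDA kernel avoids
        h = deltaA[:, t] * h + deltaBx[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)  # hidden states for every timestep, (B, L, D, N)
```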
184 |
185 | After our paper is accepted, we will provide a C++ implementation of OverlapMamba with libtorch for faster retrieval.
186 |
187 |
188 | ## 👏Acknowledgment
189 | This repo is based on [OverlapTransformer](https://github.com/haomo-ai/OverlapTransformer) and [mamba.py](https://github.com/alxndrTL/mamba.py). We are very grateful for their excellent work,
190 | appreciate their contributions to LiDAR-based place recognition (LPR), and highly recommend using their publicly available code.
191 |
--------------------------------------------------------------------------------
/config/config.yml:
--------------------------------------------------------------------------------
1 | data_root:
2 |
3 | data_root_folder: "/home/robot/Project/OverlapTransformer/data_root_folder/" # KITTI sequences root
4 | valid_scan_folder: "/home/robot/下载/kitti/sequences/02/velodyne" # KITTI sequence 02 velodyne
5 | gt_valid_folder: "/home/robot/Project/OverlapTransformer/gt_valid_folder/" # KITTI sequence 02 computed overlaps
6 |
7 | # data_root_folder: "/home/robot/Project/OverlapTransformer/dataset-1/" # Ford campus sequences root
8 | # valid_scan_folder: "/home/robot/下载/kitti/sequences/02/velodyne" # KITTI sequence 02 velodyne
9 | # gt_valid_folder: "/home/robot/Project/OverlapTransformer/gt_valid_folder/" # KITTI sequence 02 computed overlaps
10 |
11 | # data_root_folder: "/home/robot/Project/CVTNet/NCLT/" # NCLT sequences root
12 |
13 |
14 | demo1_config:
15 |
16 | calib_file: "/home/robot/Project/OverlapTransformer/data_root_folder/00/calib.txt" # calibration file from KITTI 00
17 | poses_file: "/home/robot/Project/OverlapTransformer/weights/00.txt" # pose file from KITTI 00
18 | test_weights: "/home/robot/Project/OverlapTransformer/weights/pretrained_overlap_transformer.pth.tar" # pretrained model
19 |
20 |
21 | training_config:
22 |
23 | training_seqs: ["03", "04", "05","06", "07", "08", "09", "10"] # KITTI sequences for training
24 | # training_seqs: ["2012-01-08_vel/2012-01-08"] # NCLT sequence for training
25 |
26 | test_config:
27 |
28 | test_seqs: ["00"] # KITTI sequence 00 for evaluation
29 | test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/final_weight/trained_overlap_transformer16.pth.tar" # pretrained model
30 | gt_file: "/home/robot/Project/OverlapTransformer/data_root_folder/loop_gt_seq00_0.3overlap_inactive.npz" # ground truth
31 |
32 | # test_seqs: [ "00" ] # Ford campus sequence 00 for evaluation
33 | # test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/shift_bimamba_sppf_80/trained_overlap_transformer4.pth.tar" # pretrained model
34 | # gt_file: "/home/robot/下载/Ford campus/loop_gt_seq00_0.3overlap_inactive.npz" # ground truth
35 |
36 | # test_seqs: ["2012-02-05_vel/2012-02-05"] # NCLT sequence for evaluation
37 | # test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_npy/pretrained_overlap_transformer_haomo22.pth.tar" # pretrained model
38 | # gt_file: "/home/robot/Project/CVTNet/NCLT/2012-02-05_vel/2012-02-05/overlaps/gt.npz" # ground truth
39 |
40 |
41 | viz_config:
42 |
43 | calib_file: "/home/robot/Project/OverlapTransformer/data_root_folder/00/calib.txt" # calibration file from KITTI 00
44 | poses_file: "/home/robot/Project/OverlapTransformer/weights/00.txt" # pose file from KITTI 00
45 | cov_file: "/home/robot/Project/OverlapTransformer/weights/covariance_2nd.txt" # covariance file from SUMA++ on KITTI 00
46 |
47 |
--------------------------------------------------------------------------------
/config/config_nclt.yml:
--------------------------------------------------------------------------------
1 | file_root:
2 | data_root_folder: "/home/robot/Project/CVTNet/NCLT/2012-01-08_vel/2012-01-08/depth_map/"
3 | triplets_for_training: "/home/robot/Project/CVTNet/more_chosen_normalized_data_120108.npy"
4 | pose_file_database: "/home/mjy/datasets/haomo_data/stamps_calibrated_poses_1208_1_01.npy"
5 | pose_file_query: "/home/mjy/datasets/haomo_data/stamps_calibrated_poses_1208_1_02.npy"
6 |
7 |
8 | data_root_folder_test: "/home/robot/Project/CVTNet/NCLT/2012-02-05_vel/2012-02-05/depth_map/"
9 |
10 |
11 | test_weights: "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_o1-shift2/trained_overlap_transformer5.pth.tar"
12 | #gt_file: "/home/robot/Project/CVTNet/gt_120108_120205.npy"
13 | gt_file: "/home/robot/Project/SeqOT-main/data_prepararion/gt/gt_120108_120205.npy"
14 |
15 | training_config:
16 |
17 | training_seqs: ["2012-01-08_vel/2012-01-08"] # NCLT sequence for training
--------------------------------------------------------------------------------
/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo.png
--------------------------------------------------------------------------------
/demo/__pycache__/com_overlap.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/__pycache__/com_overlap.cpython-38.pyc
--------------------------------------------------------------------------------
/demo/com_overlap.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Xieyuanli Chen and Thomas Läbe
3 | # This file is covered by the LICENSE file in the root of this project.
4 | # Brief: This script generates the overlap and orientation combined mapping file.
5 | import sys
6 | import os
7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
8 | if p not in sys.path:
9 | sys.path.append(p)
10 | from tools.utils.utils import *
11 |
12 | def com_overlap(scan_paths, poses, frame_idx):
13 | # init ground truth overlap and yaw
14 | print('Start to compute ground truth overlap ...')
15 | overlaps = []
16 |
17 | # we calculate the ground truth for one given frame only
18 | # generate range projection for the given frame
19 | current_points = load_vertex(scan_paths[frame_idx])
20 | current_range, project_points, _, _ = range_projection(current_points, fov_up=3, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50)
21 | visible_points = project_points[current_range > 0]
22 | valid_num = len(visible_points)
23 | current_pose = poses[frame_idx]
24 |
25 | tau_ = 1.2
26 | print("threshold for overlap: ", tau_)
27 |
28 | reference_range_list = []
29 | for i in range(len(scan_paths)):
30 | # generate range projection for the reference frame
31 | reference_idx = int(scan_paths[i][-10:-4])
32 | reference_pose = poses[reference_idx]
33 | reference_points = load_vertex(scan_paths[i])
34 |
35 | reference_points_world = reference_pose.dot(reference_points.T).T
36 | reference_points_in_current = np.linalg.inv(current_pose).dot(reference_points_world.T).T
37 | reference_range, _, _, _ = range_projection(reference_points_in_current, fov_up=3, fov_down=-25.0, proj_H=64, proj_W=900, max_range=50)
38 | # calculate overlap
39 | overlap = np.count_nonzero(
40 | abs(reference_range[reference_range > 0] - current_range[reference_range > 0]) < tau_) / valid_num
41 | overlaps.append(overlap)
42 | reference_range_list.append(reference_range)
43 |
44 |
45 | # ground truth format: each row contains [current_frame_idx, reference_frame_idx, overlap,]
46 | ground_truth_mapping = np.zeros((len(scan_paths), 3))
47 | ground_truth_mapping[:, 0] = np.ones(len(scan_paths)) * frame_idx
48 | ground_truth_mapping[:, 1] = np.arange(len(scan_paths))
49 | ground_truth_mapping[:, 2] = overlaps
50 |
51 | print('Finish generating ground_truth_mapping!')
52 |
53 | return ground_truth_mapping, current_range, reference_range_list
54 |
--------------------------------------------------------------------------------
/demo/demo_compute_overlap_sim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
4 | if p not in sys.path:
5 | sys.path.append(p)
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from com_overlap import com_overlap
9 | import yaml
10 | from tools.utils.utils import *
11 | from modules.overlap_mamba import featureExtracter
12 | import torch
13 |
14 | # load config ================================================================
15 | config_filename = '../config/config.yml'
16 | config = yaml.safe_load(open(config_filename))
17 | calib_file = config["demo1_config"]["calib_file"]
18 | poses_file = config["demo1_config"]["poses_file"]
19 | test_weights = config["demo1_config"]["test_weights"]
20 | # ============================================================================
21 |
22 | # load scan paths
23 | scan_folder = "./scans"
24 | scan_paths = load_files(scan_folder)
25 |
26 | # load calibrations
27 | T_cam_velo = load_calib(calib_file)
28 | T_cam_velo = np.asarray(T_cam_velo).reshape((4, 4))
29 | T_velo_cam = np.linalg.inv(T_cam_velo)
30 |
31 | # load poses
32 | poses = load_poses(poses_file)
33 | pose0_inv = np.linalg.inv(poses[0])
34 |
35 | # for KITTI dataset, we need to convert the provided poses
36 | # from the camera coordinate system into the LiDAR coordinate system
37 | poses_new = []
38 | for pose in poses:
39 | poses_new.append(T_velo_cam.dot(pose0_inv).dot(pose).dot(T_cam_velo))
40 | poses = np.array(poses_new)
41 |
42 | # calculate overlap
43 | ground_truth_mapping, current_range, reference_range_list = com_overlap(scan_paths, poses, frame_idx=0)
44 |
45 | # build model and load pretrained weights
46 | amodel = featureExtracter(channels=1, use_transformer=True)
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 | amodel.to(device)
49 | print("Loading weights from ", test_weights)
50 | checkpoint = torch.load(test_weights, map_location=device)
51 | amodel.load_state_dict(checkpoint['state_dict'])
52 | amodel.eval()
53 |
54 | overlap_pos = round(ground_truth_mapping[1,-1]*100,2)
55 | overlap_neg = round(ground_truth_mapping[2,-1]*100,2)
56 |
57 | reference_range_pos = reference_range_list[1]
58 | reference_range_neg = reference_range_list[2]
59 | current_range_tensor = torch.from_numpy(current_range).unsqueeze(0)
60 | current_range_tensor = current_range_tensor.unsqueeze(0).to(device)
61 | reference_range_pos_tensor = torch.from_numpy(reference_range_pos).unsqueeze(0)
62 | reference_range_pos_tensor = reference_range_pos_tensor.unsqueeze(0).to(device)
63 | reference_range_neg_tensor = torch.from_numpy(reference_range_neg).unsqueeze(0)
64 | reference_range_neg_tensor = reference_range_neg_tensor.unsqueeze(0).to(device)
65 |
66 | # generate descriptors for the query, positive and negative scans
67 | des_cur = amodel(current_range_tensor).cpu().detach().numpy()
68 | des_pos = amodel(reference_range_pos_tensor).cpu().detach().numpy()
69 | des_neg = amodel(reference_range_neg_tensor).cpu().detach().numpy()
70 |
71 | # calculate similarity
72 | dis_pos = np.linalg.norm(des_cur - des_pos)
73 | dis_neg = np.linalg.norm(des_cur - des_neg)
74 | sim_pos = round(1/(1+dis_pos),2)
75 | sim_neg = round(1/(1+dis_neg),2)
76 |
77 | plt.figure(figsize=(8,4))
78 | plt.subplot(311)
79 | plt.title("query: " + scan_paths[0])
80 | plt.imshow(current_range)
81 | plt.subplot(312)
82 | plt.title("positive reference: " + scan_paths[1] + " - similarity: " + str(sim_pos))
83 | plt.imshow(reference_range_list[1])
84 | plt.subplot(313)
85 | plt.title("negative reference: " + scan_paths[2] + " - similarity: " + str(sim_neg))
86 | plt.imshow(reference_range_list[2])
87 | plt.show()
88 |
--------------------------------------------------------------------------------
/demo/scans/000000.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000000.bin
--------------------------------------------------------------------------------
/demo/scans/000005.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000005.bin
--------------------------------------------------------------------------------
/demo/scans/000015.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/demo/scans/000015.bin
--------------------------------------------------------------------------------
/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/fig.png
--------------------------------------------------------------------------------
/mambapy/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | **/*.npz
--------------------------------------------------------------------------------
/mambapy/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Alexandre TL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/mambapy/README.md:
--------------------------------------------------------------------------------
1 | # mamba.py 🐍 : a simple and efficient Mamba implementation
2 | A straightforward implementation of [Mamba](https://arxiv.org/abs/2312.00752) in PyTorch with a simple parallel scan implementation, offering a major speedup over a sequential implementation, as the parallel scan allows parallelization over the time dimension.
3 | It combines ease of reading with good performance.
4 |
5 | ## Updates
6 | - 09/02/2024 : First part of the performance update. For small sequences (<128), it can speed up training by more than 20% compared to the first version. For setups close to what can be found in practice (like in NLP), it can speed up training by 10%.
7 |
8 | - 22/01/2024 : Added a MLX version of `mamba.py`, which supports inference as well as training. This version is similar to PyTorch, and allows Mac users to play around with Mamba models. It was [tested]() on the largest Mamba trained to date (2.8b), as well as
9 |
10 | ___
11 | ## Overview
12 |
13 | ![speed comparison](assets/speed_comparison.png)
14 |
15 | This graph shows the training time (forward and backward pass) of a single Mamba layer (`d_model=16, d_state=16`) using 3 different methods : `CUDA`, which is the official [Mamba implementation](https://github.com/state-spaces/mamba), `mamba.py`, which is this repo, and `sequential`, which is a sequential (RNN-like) implementation of the selective scan.
16 |
17 | This repo contains a simple and readable code implementing the [Mamba](https://arxiv.org/abs/2312.00752) architecture in pure PyTorch as well as MLX. Its primary goal is educational.
18 |
19 |
20 |
21 |
22 |
23 | The repo is organized as follows :
24 | - `pscan.py` : a PyTorch implementation of Blelloch's parallel scan (see the sketch after this list)
25 | - `mamba.py` : the Mamba model, as described in the [paper](https://arxiv.org/abs/2312.00752). It is numerically equivalent (initialization, forward and backward pass).
26 | - `mamba_lm.py` : encapsulates a Mamba model in order to use it as a language model
27 | - `📁 mlx` : basically the same code as above, but in MLX.
28 | - `📁 docs` : a folder containing annotated explanations about the code, focusing on the parallel scan
29 | - `📁 examples` : two examples of how to use the Mamba model in PyTorch.
30 |
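To give an idea of why the parallel scan helps: the selective-scan recurrence `h_t = a_t * h_{t-1} + b_t` can be computed with an associative operator over `(a, b)` pairs, so the loop over time can be replaced by O(log L) vectorized steps. The sketch below uses a naive doubling scan (O(L log L) work), not the work-efficient Blelloch scan implemented in `pscan.py`, but the principle is the same:

```
import torch

def parallel_linear_scan(a, b):
    # computes h_t = a_t * h_{t-1} + b_t for all t (h_{-1} = 0), a and b of shape (B, L, D)
    L = a.size(1)
    step = 1
    while step < L:
        # combine each position t with the partial result at position t - step:
        # (a', b') = (a_{t-step} * a_t, a_t * b_{t-step} + b_t)
        b = torch.cat([b[:, :step], a[:, step:] * b[:, :-step] + b[:, step:]], dim=1)
        a = torch.cat([a[:, :step], a[:, :-step] * a[:, step:]], dim=1)
        step *= 2
    return b  # b now holds h_t for every timestep

B, L, D = 2, 64, 16
h = parallel_linear_scan(torch.rand(B, L, D), torch.randn(B, L, D))  # (B, L, D)
```
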
31 | ## Usage
32 |
33 | The most basic usage is to use the `Mamba` object ([mamba.py](mamba.py)), which implements a simple Mamba model given a configuration.
34 | No embedding, no head : input is `(B, L, D)` and output is `(B, L, D)` as well.
35 |
36 | ```
37 | import torch
38 | from mamba import Mamba, MambaConfig
39 |
40 | config = MambaConfig(d_model=16, n_layers=2)
41 | model = Mamba(config)
42 |
43 | B, L, D = 2, 64, 16
44 | x = torch.randn(B, L, D)
45 | y = model(x)
46 |
47 | assert y.shape == x.shape
48 | ```
49 |
50 | The class `MambaLM` ([mamba_lm.py](mamba_lm.py)) builds on the `Mamba` object and offers a classic API for language models. It can be used as follows :
51 |
52 | ```
53 | import torch
from mamba_lm import MambaLM, MambaLMConfig
54 |
55 | config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=32000)
56 | model = MambaLM(config)
57 |
58 | x = torch.randint(high=32000, size=(16, 64))
59 | logits = model(x) # (B, L, vocab_size)
60 | ```
61 |
62 | It simply encapsulates a `Mamba` object with an embedding layer, a final normalization and a language modeling head.
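
For example, generating text from pretrained weights (as in `examples/example_llm.ipynb`) looks like this :

```
import torch
from transformers import AutoTokenizer
from mamba_lm import from_pretrained

device = "cuda" if torch.cuda.is_available() else "cpu"

model = from_pretrained('state-spaces/mamba-130m').to(device)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

print(model.generate(tokenizer, "Mamba is a type of"))
```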
63 |
64 | ## Examples
65 | There are two basic examples available:
66 | - `example_llm.ipynb` : load a Mamba model with pretrained weights (from 130M to 2.8B from HuggingFace)
67 | - `example_e2e_training.ipynb` : an end-to-end training example where a Mamba model is employed as a world model for a simple 3-3 grid game (training is not completed, the model should be larger).
68 |
69 | If you want a full training example (like in llama2.c), you can check the [othello_mamba repo](https://github.com/alxndrTL/othello_mamba) I've done. With this repo, you can train a Mamba from scratch, easily swap it with a Transformer, come up with your own data, etc ...
70 |
71 | ___
72 | ## Performances
73 | This section provides a more comprehensive performance comparison between `mamba.py` and the official Mamba implementation.
74 | Overall, as the first graph of this file shows, both have approximately the same asymptotic performance with respect to the sequence length. You can think of `mamba.py` as a regular Transformer implementation, while the official Mamba implementation is more like FlashAttention v1. Both have their own advantages.
75 |
76 | That being said, do the two implementations have the same asymptotic performance with respect to the other parameters?
77 |
78 | ##### `d_model` asymptotic performances
79 |
80 |
81 | ![training time vs d_model](assets/training_vs_d_model.png)
82 |
83 |
84 | We can see that both implementations behave the same as we increase `d_model`. The gap between the two stays roughly the same. (`mamba.py` is overall ~2x slower)
85 |
86 | ##### `d_state` asymptotic performances
87 |
88 |
89 | ![training time vs d_state](assets/training_vs_d_state.png)
90 |
91 |
92 | This graph is important. We see that here, the asymptotic performance is not the same as we increase `d_state`. As a reminder, `d_state`, or $N$ in the paper, is the state expansion factor : each channel of the input is expanded into $N$ channels of the hidden state.
93 |
94 | Note : the CUDA version doesn't seem to be impacted by the increase of `d_state`. This is because the benchmark was done with a batch size of 1 : the GPU was not at its full capacity and thus the impact of an increased `d_state` isn't visible. The same happens if you have a small model, or a small input length. See [this issue](https://github.com/alxndrTL/mamba.py/issues/8).
95 |
96 | Does it matter in practice ? As of now, all the pretrained Mamba models (up to 2.8B parameters) used `d_state=16`, so this change of performance over `d_state` isn't important in this case. As `d_state` is not something that is supposed to grow (contrary to the seq length or `d_model`), this isn't a catastrophic result, but something to consider.
97 |
98 | However, it is interesting to relate this observation with the claim made by Albert Gu and Tri Dao in the [Mamba paper](https://arxiv.org/abs/2312.00752) : "The main idea is to leverage properties of modern accelerators (GPUs) to materialize the state ℎ only in more efficient levels of the memory hierarchy."
99 | They also describe (Annex D) the main data movements of their selective scan : working mainly in SRAM, they can reduce the memory reads/writes by a factor of $O(N)$. This explains the different asymptotic behaviors that we see here.
100 |
101 | With `d_state=16` (as in `state-spaces/mamba-2.8b-slimpj`), the gap between the two is relatively small, but with `d_state=64` (currently not used in any models), the gap widens. (note the OOM on the second graph)
102 |
103 |
104 |
105 | ![training time vs sequence length for several d_state](assets/training_vs_seqlen_d_state_var.png)
106 |
107 |
108 | All the previous graphs were computed with a batch size of 1, on an A100 80GB.
109 | It is a measure of both the forward and backward pass of a single Mamba block.
110 |
111 | The previous analysis showed the importance of kernel fusion, which reduces the memory accesses by a factor of $O(N)$ and thus makes the whole process faster.
112 |
113 | But memory requirements should also be considered : the official Mamba implementation uses recomputation in the backward pass : rather than keeping in memory the activations computed during the forward pass, it simply recomputes them in the backward pass when needed. This greatly reduces the memory requirements of the Mamba model during training. This is not implemented in this repo.
114 |
115 | Hence, this repo implements one of the three techniques mentioned in the Mamba paper that form the so-called "hardware-aware selective scan" : the parallel scan.
116 | We saw how kernel fusion impacts the speed, while recomputation impacts the memory requirements.
117 |
118 | ___
119 | ## Sources and where to learn more
120 | - the [Mamba paper](https://arxiv.org/abs/2312.00752) : describes the Mamba architecture as implemented in this repo, which allows to model sequences in linear time.
121 | - the [Mamba implementation](https://github.com/state-spaces/mamba), which is written in PyTorch but uses a parallel scan written in CUDA. This is the version that is the fastest.
122 | [a minimal PyTorch implementation of Mamba](https://github.com/johnma2006/mamba-minimal), which implements the scan operation as a sequential loop (its performance is a bit worse than the 'sequential' line in the first graph). This code closely follows [this file](https://github.com/state-spaces/mamba/blob/da2626b5a5f347a8e844ac5e96a2cbcde3c34abb/mamba_ssm/modules/mamba_simple.py) from the official Mamba implementation, but replaces the CUDA convolution with `torch.nn.Conv1d`, and the selective scan written in CUDA with a sequential loop. The code of this repo follows the structure of these 2 files.
123 | - [Prefix Sums and Their Applications](https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf), by Guy E. Blelloch (1993).
124 | - [Parallelizing Linear Recurrent Neural Nets Over Sequence Length](https://arxiv.org/abs/1709.04057) : applies a parallel scan over the sequence in order to get rid of the sequential for-loop.
125 | - x.com/fchollet : original pscan implementation.
126 |
127 | ## TODOs
128 | - docs
129 | - ~~more tests with an increased `d_model` (add a Performances section)~~
130 | - ~~a step function, used for (auto-regressive) inference.~~
131 | - ~~a training function, similar to [llama2.c](https://github.com/karpathy/llama2.c)~~
132 |
133 | perfs :
134 | - ~~unfold the for-loops in `pscan.py` to achieve better performance (see [François Fleuret's pscan](https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygptrnn.git;a=blob;f=pscan.py;h=0bb0d145bf9c6c82115956c8ce1e6a063e56e747;hb=HEAD)) (although this will sacrifice readability a bit)~~
135 | - ~~write a reverse parallel scan specifically for the backward pass. (For now, we have to flip the array before and after the scan).~~
136 | - enable gradient checkpointing to reduce the memory usage
137 | - use torch.compile(). As far as I tested, it doesn’t work for now. It seems it isn’t happy with the custom PScan autograd function. Need to investigate. (see [PR#1](https://github.com/alxndrTL/mamba.py/pull/1))
138 |
--------------------------------------------------------------------------------
/mambapy/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/logo.png
--------------------------------------------------------------------------------
/mambapy/assets/speed_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/speed_comparison.png
--------------------------------------------------------------------------------
/mambapy/assets/training_vs_d_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_d_model.png
--------------------------------------------------------------------------------
/mambapy/assets/training_vs_d_state.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_d_state.png
--------------------------------------------------------------------------------
/mambapy/assets/training_vs_seqlen_d_state_var.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/assets/training_vs_seqlen_d_state_var.png
--------------------------------------------------------------------------------
/mambapy/docs/assets/cumsum_rnns.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/cumsum_rnns.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/down_sweep_rule.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/down_sweep_rule.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/downsweep.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/downsweep_ex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep_ex.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/downsweep_updated.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/downsweep_updated.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/reduction_tree.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/reduction_tree.jpg
--------------------------------------------------------------------------------
/mambapy/docs/assets/tensor_mem.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tensor_mem.jpeg
--------------------------------------------------------------------------------
/mambapy/docs/assets/tensor_mem_tree.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tensor_mem_tree.jpeg
--------------------------------------------------------------------------------
/mambapy/docs/assets/tree_reduction.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tree_reduction.jpeg
--------------------------------------------------------------------------------
/mambapy/docs/assets/tree_reduction_xs.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/tree_reduction_xs.jpeg
--------------------------------------------------------------------------------
/mambapy/docs/assets/up_down_trees.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/docs/assets/up_down_trees.jpg
--------------------------------------------------------------------------------
/mambapy/examples/buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # todo : is numpy a bottleneck? everything is in pytorch, we go through numpy, and then convert back to pytorch for training...
4 | # for V1 (where collection and training do not happen at the same time) this is not a big issue
5 |
6 | # todo : what is collected at the same time is fed at the same time in the batches (i.e. across envs)... again, not a problem for V1 since the envs are continuing (a very particular case...)
7 |
8 | class ReplayBuffer():
9 | def __init__(self, num_envs, capacity, obs_dim, act_dim):
10 | self.obs_buffer = np.empty((capacity, num_envs, obs_dim), dtype=np.uint8)
11 | self.act_buffer = np.empty((capacity, num_envs), dtype=np.uint8) # no one-hot
12 | self.rew_buffer = np.empty((capacity, num_envs), dtype=np.float32)
13 |
14 | self.num_envs = num_envs
15 | self.capacity = capacity
16 | self.idx = 0
17 | self.size = 0
18 | self.rng = np.random.default_rng()
19 |
20 | def store(self, obs, act, rew):
21 | # obs : (num_envs, L*L)
22 | # act : (num_envs,)
23 | # rew : (num_envs,)
24 |
25 | self.obs_buffer[self.idx] = obs
26 | self.act_buffer[self.idx] = act
27 | self.rew_buffer[self.idx] = rew
28 |
29 | self.idx = (self.idx + 1) % self.capacity
30 | self.size = min(self.size + 1, self.capacity)
31 |
32 | def sample(self, batch_size, batch_len):
33 | assert self.size >= batch_len, "not enough experience stored"
34 |
35 | start_idxs = self.rng.integers(0, self.size - batch_len, size=batch_size) # (B,)
36 | env_idxs = self.rng.integers(0, self.num_envs, size=batch_size)[:, None] # (B, 1)
37 |
38 | # all indices for sampling : from start_idxs to start_idxs+batch_len
39 | idxs = start_idxs[:, None] + np.arange(batch_len) # (B, batch_len)
40 |
41 | batch_obs = self.obs_buffer[idxs, env_idxs]
42 | batch_acts = self.act_buffer[idxs, env_idxs]
43 | batch_rews = self.rew_buffer[idxs, env_idxs]
44 |
45 | batch = {'obs': batch_obs, 'act': batch_acts, 'rew': batch_rews}
46 | return batch
47 |
--------------------------------------------------------------------------------
/mambapy/examples/example_e2e_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 10,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn.functional as F\n",
11 | "\n",
12 | "import time\n",
13 | "from IPython.display import clear_output\n",
14 | "\n",
15 | "from example_src.tinyhome import TinyHomeEngineV1, print_grid, print_act\n",
16 | "from example_src.buffer import ReplayBuffer\n",
17 | "\n",
18 | "from mamba_lm import MambaLM, MambaLMConfig"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 3,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 4,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "L = 5\n",
37 | "num_actions = 5\n",
38 | "num_obs_type = 4\n",
39 | "\n",
40 | "nb_instances = 512\n",
41 | "steps = 10000\n",
42 | "\n",
43 | "envs = TinyHomeEngineV1(B=nb_instances, h=L, w=L)\n",
44 | "buffer = ReplayBuffer(num_envs=nb_instances, capacity=int(1e6), obs_dim=L*L, act_dim=num_actions)\n",
45 | "\n",
46 | "obs = envs.reset()\n",
47 | "\n",
48 | "for _ in range(steps):\n",
49 | " a = torch.randint(low=0, high=num_actions, size=(nb_instances,))\n",
50 | " next_obs, rew = envs.step(a)\n",
51 | "\n",
52 | " buffer.store(obs.view(-1, L*L), a, rew.squeeze(1))\n",
53 | " obs = next_obs"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 5,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "name": "stderr",
63 | "output_type": "stream",
64 | "text": [
65 | "/home/alex/miniconda3/envs/torch23/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
66 | " from .autonotebook import tqdm as notebook_tqdm\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=num_actions+num_obs_type, pad_vocab_size_multiple=num_actions+num_obs_type)\n",
72 | "model = MambaLM(config).to(device)\n",
73 | "optim = torch.optim.AdamW(model.parameters(), lr=3e-3)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 6,
79 | "metadata": {},
80 | "outputs": [
81 | {
82 | "data": {
83 | "text/plain": [
84 | "13664"
85 | ]
86 | },
87 | "execution_count": 6,
88 | "metadata": {},
89 | "output_type": "execute_result"
90 | }
91 | ],
92 | "source": [
93 | "sum([p.numel() for p in model.parameters()])"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 7,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "6.719587326049805\n",
106 | "0.38883084058761597\n",
107 | "0.26524049043655396\n",
108 | "0.22898846864700317\n",
109 | "0.21574825048446655\n",
110 | "0.1874855011701584\n",
111 | "0.1682835817337036\n",
112 | "0.14247804880142212\n",
113 | "0.1338169276714325\n",
114 | "0.10260577499866486\n"
115 | ]
116 | }
117 | ],
118 | "source": [
119 | "for i in range(1000): \n",
120 | " B, T = 64, 10\n",
121 | " batch = buffer.sample(B, T)\n",
122 | "\n",
123 | " obs = torch.tensor(batch['obs']).long().to(device)\n",
124 | " act = torch.tensor(batch['act']).long().to(device)\n",
125 | "\n",
126 | " tokens = torch.cat([obs, torch.zeros(B, T, 1, dtype=torch.int, device='cuda')], dim=2).view(B, 26*T) # (B, 26T)\n",
127 | " tokens[:, 25::26] = act+4\n",
128 | "\n",
129 | " input = tokens\n",
130 | " output = tokens[:, 1:].reshape(-1)\n",
131 | "\n",
132 | " logits = model(tokens[:, :-1]) # (B, 26T-1, vocab_size)\n",
133 | " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), output)\n",
134 | "\n",
135 | " optim.zero_grad()\n",
136 | " loss.backward()\n",
137 | " optim.step()\n",
138 | "\n",
139 | " if i%100==0:\n",
140 | " print(loss.item())"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 15,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "tokens = torch.ones(1, 2, dtype=torch.long).cuda() # (B=1, 2)\n",
150 | "T = 20\n",
151 | "for _ in range(26*T-2):\n",
152 | " logits = model(tokens)[0, -1]\n",
153 | " probs = F.softmax(logits, dim=0)\n",
154 | " sampled = torch.multinomial(probs, num_samples=1, replacement=True)\n",
155 | " tokens = torch.cat([tokens, sampled.view(1, 1)], dim=1)\n",
156 | "tokens = tokens.view(T, 26) # (T, 26)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 16,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "#####\n",
169 | "# #\n",
170 | "# #\n",
171 | "#G@ #\n",
172 | "#####\n",
173 | "\n",
174 | "\n",
175 | "E\n"
176 | ]
177 | }
178 | ],
179 | "source": [
180 | "for timestep in tokens:\n",
181 | " grid = timestep[:-1].view(1, 5, 5)\n",
182 | " a = timestep[-1]-4\n",
183 | "\n",
184 | " clear_output(wait=True)\n",
185 | " print_grid(grid)\n",
186 | " print_act(a.item())\n",
187 | " time.sleep(0.1)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": []
196 | }
197 | ],
198 | "metadata": {
199 | "kernelspec": {
200 | "display_name": "torch23",
201 | "language": "python",
202 | "name": "python3"
203 | },
204 | "language_info": {
205 | "codemirror_mode": {
206 | "name": "ipython",
207 | "version": 3
208 | },
209 | "file_extension": ".py",
210 | "mimetype": "text/x-python",
211 | "name": "python",
212 | "nbconvert_exporter": "python",
213 | "pygments_lexer": "ipython3",
214 | "version": "3.11.5"
215 | }
216 | },
217 | "nbformat": 4,
218 | "nbformat_minor": 2
219 | }
220 |
--------------------------------------------------------------------------------
/mambapy/examples/example_llm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "\n",
11 | "from transformers import AutoTokenizer\n",
12 | "\n",
13 | "from mamba_lm import from_pretrained"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 4,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 7,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stderr",
32 | "output_type": "stream",
33 | "text": [
34 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
35 | ]
36 | }
37 | ],
38 | "source": [
39 | "model = from_pretrained('state-spaces/mamba-130m').to(device)\n",
40 | "tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 8,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "output = model.generate(tokenizer, \"Mamba is a type of\")"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 9,
55 | "metadata": {},
56 | "outputs": [
57 | {
58 | "name": "stdout",
59 | "output_type": "stream",
60 | "text": [
61 | "Mamba is a type of black sheep. Many types of black sheep, from black hares, black mares and black oxen, to Mamba, which is a black sheep. Mamba is the black sheep in Mocha. As they move in and out of one\n"
62 | ]
63 | }
64 | ],
65 | "source": [
66 | "print(output)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": []
75 | }
76 | ],
77 | "metadata": {
78 | "kernelspec": {
79 | "display_name": "torch23",
80 | "language": "python",
81 | "name": "python3"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.11.5"
94 | }
95 | },
96 | "nbformat": 4,
97 | "nbformat_minor": 2
98 | }
99 |
--------------------------------------------------------------------------------
/mambapy/examples/tinyhome.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import os
5 | import time
6 |
7 | """
8 | V1 : fully observable, tiny room with a goal generated randomly.
9 | going over the goal gets the agent a reward and spawns a new goal.
10 | the episode ends after T steps.
11 |
12 | it is convenient because:
13 | - all envs end at the same time (both convenient for the env engine, AND for the training of the transformer : no padding needed)
14 | - no obstacles to handle (convenient in the step() func)
15 | - continuing, so its easier for the replay buffer (see buffer.py)
16 | """
17 | class TinyHomeEngineV1:
18 | def __init__(self, B, h=10, w=10, max_envs_disp=4):
19 | self.B = B
20 | self.h = h
21 | self.w = w
22 | self.max_envs_disp = max_envs_disp
23 |
24 | def reset(self):
25 | self.grid = torch.zeros(self.B, self.h, self.w, dtype=torch.int)
26 | self.grid[:, 0, :] = 1
27 | self.grid[:, -1, :] = 1
28 | self.grid[:, :, 0] = 1
29 | self.grid[:, :, -1] = 1
30 |
31 | self.pos_player = torch.randint(low=1, high=self.h-1, size=(self.B, 2))
32 | self.pos_goal = torch.randint(low=1, high=self.h-1, size=(self.B, 2))
33 |
34 | while True:
35 | overlap = torch.all(self.pos_player == self.pos_goal, dim=1)
36 | if not overlap.any():
37 | break
38 | self.pos_goal[overlap] = torch.randint(low=1, high=self.h-1, size=(overlap.sum(), 2))
39 |
40 | disp_grid = self.grid.clone()
41 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2
42 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3
43 |
44 | """
45 | x = F.one_hot(self.pos_player[:, 0]-1, num_classes=3)
46 | y = F.one_hot(self.pos_player[:, 1]-1, num_classes=3)
47 | u = F.one_hot(self.pos_goal[:, 0]-1, num_classes=3)
48 | v = F.one_hot(self.pos_goal[:, 1]-1, num_classes=3)
49 |
50 | concatenated = torch.cat([x, y, u, v], dim=1) # (B, 12)
51 | """
52 |
53 | return disp_grid
54 |
55 | def optimal_policy_vectorized(self, moves):
56 | B, _ = self.pos_player.shape
57 |
58 | # Expand pos_player to (B, 5, 2) to match the moves
59 | expanded_pos_player = self.pos_player.unsqueeze(1).expand(-1, moves.size(0), -1)
60 |
61 | # Compute new positions for each move
62 | new_positions = expanded_pos_player + moves
63 | new_positions = new_positions.clamp(min=1, max=self.h-2)
64 |
65 | # Calculate Manhattan distances for each new position
66 | distances = torch.sum(torch.abs(new_positions - self.pos_goal.unsqueeze(1)), dim=2)
67 |
68 | # Find the move with the minimum distance for each environment
69 | actions = torch.argmin(distances, dim=1)
70 |
71 | return actions
72 |
73 | def step(self, a):
74 | # a : (B,)
75 |
76 | moves = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]]) # X, N, E, S, W
77 |
78 | #a = self.optimal_policy_vectorized(moves)
79 |
80 | self.pos_player += moves[a]
81 | self.pos_player = self.pos_player.clamp(min=1, max=self.h-2) # not at all generalizable to walls placed in the middle, etc.
82 |
83 | reached_goal = torch.all(self.pos_player == self.pos_goal, dim=1)
84 | reward = torch.where(reached_goal, 1., 0.).unsqueeze(1)
85 |
86 | # regen goal (only for "completed" env)
87 | num_reached = reached_goal.sum()
88 | if num_reached > 0:
89 | self.pos_goal[reached_goal] = torch.randint(low=1, high=self.h-1, size=(num_reached, 2))
90 |
91 | # make sure that the regenerated goals are at a different place
92 | while True:
93 | overlap = torch.all(self.pos_player == self.pos_goal, dim=1)
94 | if not overlap.any():
95 | break
96 | self.pos_goal[overlap] = torch.randint(low=1, high=self.h-1, size=(overlap.sum(), 2))
97 |
98 | disp_grid = self.grid.clone()
99 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2
100 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3
101 |
102 | """
103 | x = F.one_hot(self.pos_player[:, 0]-1, num_classes=3)
104 | y = F.one_hot(self.pos_player[:, 1]-1, num_classes=3)
105 | u = F.one_hot(self.pos_goal[:, 0]-1, num_classes=3)
106 | v = F.one_hot(self.pos_goal[:, 1]-1, num_classes=3)
107 |
108 | concatenated = torch.cat([x, y, u, v], dim=1) # (B, 12)
109 | """
110 |
111 | return disp_grid, reward
112 |
113 | def display(self):
114 | os.system('cls' if os.name == 'nt' else 'clear')
115 |
116 | disp_grid = self.grid.clone()
117 | disp_grid[torch.arange(self.B), self.pos_player[:, 0], self.pos_player[:, 1]] = 2
118 | disp_grid[torch.arange(self.B), self.pos_goal[:, 0], self.pos_goal[:, 1]] = 3
119 |
120 | for b in range(min(self.B, self.max_envs_disp)):
121 | for row in disp_grid[b]:
122 | print(''.join(display_mapping.get(value.item(), '?') for value in row))
123 |
124 | print("\n")
125 |
126 | display_mapping = {
127 | 0: ' ',
128 | 1: '#',
129 | 2: '@',
130 | 3: 'G'
131 | }
132 |
133 | def print_grid(grid):
134 | for b in range(grid.shape[0]):
135 | for row in grid[b]:
136 | print(''.join(display_mapping.get(value.item(), '?') for value in row))
137 |
138 | print("\n")
139 |
140 | actions_to_char = ['X', 'N', 'E', 'S', 'W']
141 | def print_act(act):
142 | print(actions_to_char[act])
143 |
144 | if __name__ == "__main__":
145 | mode = "display" # "display" or "grind" or "collect"
146 |
147 | if mode == "display":
148 | nb_instances = 2
149 | steps = 10
150 |
151 | engine = TinyHomeEngineV1(nb_instances, 5, 5)
152 | engine.reset()
153 |
154 | engine.display()
155 | time.sleep(0.5)
156 |
157 | for _ in range(steps):
158 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,)))
159 | print(obs)
160 | #engine.display()
161 | print(rew.shape)
162 | time.sleep(0.05)
163 |
164 | elif mode == "grind":
165 | nb_instances = 1000
166 | steps = 1000
167 |
168 | engine = TinyHomeEngineV1(nb_instances)
169 | engine.reset()
170 |
171 | start_time = time.perf_counter()
172 |
173 | for _ in range(steps):
174 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,)))
175 |
176 | end_time = time.perf_counter()
177 |
178 | print(f"The collection of {nb_instances*steps} steps took {end_time-start_time} seconds")
179 |
180 | elif mode == "collect":
181 | nb_instances = 2
182 | steps = 1
183 |
184 | embed = torch.nn.Embedding(num_embeddings=6, embedding_dim=2)
185 |
186 | engine = TinyHomeEngineV1(nb_instances, 5, 5)
187 | engine.reset()
188 |
189 | for _ in range(steps):
190 | obs, rew = engine.step(torch.randint(low=0, high=5, size=(nb_instances,)))
191 | # obs: (B, h, w), rew: (B, 1)
192 |
193 | obs = obs.view(nb_instances, 25) # (B, h*w)
194 |
195 | e = embed(obs) # (B, h*w, embed_dim)
196 | print(e.shape)
--------------------------------------------------------------------------------
/mambapy/mamba_lm.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, fields, asdict
2 | import json
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from mamba import Mamba, MambaConfig, RMSNorm
9 |
10 | """
11 |
12 | Encapsulates a Mamba model as language model. It has an embedding layer, and a LM head which maps the model output to logits.
13 |
14 | """
15 |
16 | # TODO generate function : batch size != 1 ? (for now B=1)
17 | # TODO generate function : top-p sampling
18 |
19 | @dataclass
20 | class MambaLMConfig(MambaConfig):
21 | vocab_size: int = 32000
22 | pad_vocab_size_multiple: int = 8
23 |
24 | def __post_init__(self):
25 | super().__post_init__()
26 |
27 | if self.vocab_size % self.pad_vocab_size_multiple != 0:
28 | self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
29 |
30 | def to_mamba_config(self) -> MambaConfig:
31 | mamba_config_fields = {field.name for field in fields(MambaConfig)}
32 | filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
33 | return MambaConfig(**filtered_dict)
34 |
35 | # adapted from https://github.com/johnma2006/mamba-minimal
36 | def from_pretrained(name: str):
37 | """
38 | Returns a model loaded with pretrained weights pulled from HuggingFace.
39 |
40 | Args:
41 | name: As of now, supports
42 | * 'state-spaces/mamba-2.8b-slimpj'
43 | * 'state-spaces/mamba-2.8b'
44 | * 'state-spaces/mamba-1.4b'
45 | * 'state-spaces/mamba-790m'
46 | * 'state-spaces/mamba-370m'
47 | * 'state-spaces/mamba-130m'
48 |
49 | Returns:
50 | model: a Mamba model configured with the proper parameters and initialized with the proper weights
51 | """
52 |
53 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
54 | from transformers.utils.hub import cached_file
55 |
56 | def load_config_hf(model_name):
57 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
58 | return json.load(open(resolved_archive_file))
59 |
60 | def load_state_dict_hf(model_name):
61 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
62 | return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
63 |
64 | # copy config data
65 | config_data = load_config_hf(name)
66 | config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
67 |
68 | model = MambaLM(config)
69 |
70 | # copy weights
71 | state_dict = load_state_dict_hf(name)
72 |
73 | new_state_dict = {}
74 | for key in state_dict:
75 | if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
76 | new_key = key.replace('backbone.', '')
77 | else:
78 | new_key = key.replace('backbone', 'mamba')
79 |
80 | new_state_dict[new_key] = state_dict[key]
81 |
82 | model.load_state_dict(new_state_dict)
83 |
84 | return model
85 |
86 | class MambaLM(nn.Module):
87 | def __init__(self, lm_config: MambaLMConfig):
88 | super().__init__()
89 | self.lm_config = lm_config
90 | self.config = lm_config.to_mamba_config()
91 |
92 | self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
93 | self.mamba = Mamba(self.config)
94 | self.norm_f = RMSNorm(self.config.d_model)
95 |
96 | self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
97 | self.lm_head.weight = self.embedding.weight
98 |
99 | def forward(self, tokens):
100 | # tokens : (B, L)
101 |
102 | # logits : (B, L, vocab_size)
103 |
104 | x = self.embedding(tokens)
105 |
106 | x = self.mamba(x)
107 | x = self.norm_f(x)
108 |
109 | logits = self.lm_head(x)
110 |
111 | return logits
112 |
113 | def step(self, token, caches):
114 | # token : (B)
115 | # caches : [cache(layer) for all layers], cache : (h, inputs)
116 |
117 | # logits : (B, vocab_size)
118 | # caches : [cache(layer) for all layers], cache : (h, inputs)
119 |
120 | x = self.embedding(token)
121 |
122 | x, caches = self.mamba.step(x, caches)
123 | x = self.norm_f(x)
124 |
125 | logits = self.lm_head(x)
126 |
127 | return logits, caches
128 |
129 | # TODO temperature
130 | # TODO process prompt in parallel, and pass in sequential mode when prompt is finished ?
131 | def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
132 | self.eval()
133 |
134 | input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens)
135 |
136 | # caches is a list of cache, one per layer
137 | # cache is composed of : the hidden state, and the last d_conv-1 inputs
138 | # the hidden state because the update is like an RNN
139 | # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
140 | caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
141 |
142 | for i in range(input_ids.size(1) + num_tokens - 1):
143 | with torch.no_grad():
144 | # forward the new output, get new cache
145 | next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
146 |
147 | # sample (no sampling when the prompt is being processed)
148 | if i+1 >= input_ids.size(1):
149 | probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)
150 |
151 | if top_k is not None:
152 | values, _ = torch.topk(probs, k=top_k) # (1, k) ordered from lowest to biggest
153 | probs[probs < values[:, -1, None]] = 0
154 | probs = probs / probs.sum(axis=1, keepdims=True)
155 |
156 | if sample:
157 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
158 | else:
159 | next_token = torch.argmax(probs, dim=-1) # (1)
160 |
161 | input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
162 |
163 | output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
164 |
165 | self.train()
166 |
167 | return output
168 |
--------------------------------------------------------------------------------
/mambapy/mlx/README.md:
--------------------------------------------------------------------------------
1 | # MLX implementation of Mamba 🐍
2 |
3 | This folder contains a complete MLX implementation of [Mamba](https://arxiv.org/abs/2312.00752), which lets you train and run inference with Mamba models on an Apple silicon equipped Mac.
4 | Both the forward and backward passes are numerically equivalent to the PyTorch code from `mamba.py`, as well as to the official [Mamba implementation](https://github.com/state-spaces/mamba).
5 |
6 |
7 |
8 |
9 |
10 | The folder is organized as follows:
11 | - `pscan_mlx.py`: an MLX implementation of Blelloch's parallel scan.
12 | - `mamba_mlx.py`: the Mamba model, as described in the [paper](https://arxiv.org/abs/2312.00752). It is numerically equivalent to the PyTorch version (initialization, forward and backward pass).
13 | - `mamba_lm_mlx.py`: encapsulates a Mamba model in order to use it as a language model.
14 | - `utils.py`: utility functions.
15 | - `misc.py`: a temporary file containing functions needed but not yet implemented in MLX.
16 | - `📁 scripts`: example scripts to play around with Mamba.
17 |
18 | # Quickstart
19 | Make sure you have `mlx`, `torch` and `transformers` installed.
20 |
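For example (these are the package names as published on PyPI; pin versions as needed for your setup):

```
pip3 install mlx torch transformers
```
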
21 | First, clone the repo:
22 |
23 | ```
24 | git clone https://github.com/alxndrTL/mamba.py.git
25 | cd mamba.py/mlx
26 | ```
27 |
28 | If you want to run inference with pretrained models (from 130M to 2.8B parameters), you can simply do:
29 |
30 | ```
31 | cd scripts
32 | python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100
33 | ```
34 |
35 | It will download the specified model from [HuggingFace](https://huggingface.co/state-spaces), convert it to MLX format (and save it to disk), and stream the generated text.
36 | As of now, you can choose from:
37 |
38 | ```
39 | state-spaces/mamba-130m
40 | state-spaces/mamba-370m
41 | state-spaces/mamba-790m
42 | state-spaces/mamba-1.4b
43 | state-spaces/mamba-2.8b
44 | state-spaces/mamba-2.8b-slimpj
45 | ```
46 |
47 | As of today, only single-precision inference is supported. On an M2 Pro (16GB), the 790M model runs at ~30 tok/s.
48 |
49 | Unlike Transformers, the per-token inference cost doesn't depend on the sequence length, so we just have to carry along a hidden state 😎 (plus the last `d_conv-1` inputs, where `d_conv` is usually 4).
50 |
51 | As of now, `generate.py` is the only available script. But you can train the model using your own script, just like you would with a Transformer.
52 |
53 | # About
54 | Mamba is a new state-space model that does sequence modeling - just like Transformers do.
55 | While Transformers use attention to move information through time, Mamba uses a simple hidden state, just like RNNs. This gives it a constant per-token inference cost with respect to sequence length.
56 | What is important to know is that while it uses a hidden state that is updated sequentially through time:
57 |
58 | $$
59 | h_t = A h_{t-1} + Bx_t
60 | $$
61 |
62 | all the $h_t$ can actually be computed in parallel, thanks to an algorithm called the parallel scan, implemented in MLX in `pscan_mlx.py` (see the short sketch below).
63 | You can learn more about this algorithm and its implementation in `docs/pscan.ipynb` at the root of this repo.
64 | As you can see on the graph shown on the landing page of this repo, the naive sequential implementation is much slower than implementations that use this parallel scan.
65 |
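To make the recurrence concrete, here is a minimal sketch (shapes follow the `(B, L, ED, N)` convention of `pscan_mlx.py`; the toy sizes below are arbitrary) contrasting the naive sequential loop with the parallel `pscan`:

```
import mlx.core as mx

B, L, ED, N = 1, 8, 4, 16                  # batch, length, inner dim, state size
A = mx.random.uniform(shape=(B, L, ED, N))
X = mx.random.uniform(shape=(B, L, ED, N))

# naive sequential recurrence: h_t = a_t * h_{t-1} + x_t
h = mx.zeros((B, ED, N))
H_seq = []
for t in range(L):
    h = A[:, t] * h + X[:, t]              # elementwise, like an RNN update
    H_seq.append(h)
H_seq = mx.stack(H_seq, axis=1)            # (B, L, ED, N)

# from pscan_mlx import pscan
# H_par = pscan(A, X)                      # same values, computed in parallel
```
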
66 | However, it's important to note that while the parallel scan gives correct results in MLX, it's slow - so slow that it is sometimes actually harmful to use it.
67 | Why? It is not yet clear. When translating the algorithm from PyTorch to MLX, a small modification is needed: at each iteration, we need to write back to our original arrays the values we computed, because MLX doesn't have views implemented (yet?) (see [this issue](https://github.com/ml-explore/mlx/issues/466)). I thus switched to a version which only uses slicing (see `pscan_mlx.py` for more details), but performance still lags behind the sequential version (it should be orders of magnitude faster).
68 |
69 | But MLX is not even 2 months old :)
70 | I will closely follow MLX development and watch for potential upgrades to this implementation.
71 |
72 | # Why [mamba.py](../) in MLX?
73 | While the primary goal of the PyTorch version is educational, this implementation (once the parallel scan performs well) could power future fine-tuning scripts. It is still early days: there are not yet many resources about fine-tuned Mamba models (see [this](https://github.com/havenhq/mamba-chat)), and MLX doesn't yet have an associative scan operation implemented.
74 |
75 | Also, the more people play around with and train Mamba models, the better we will understand its strengths and limits, allowing us to compare it against its "competitors" (Based, RWKV, StripedHyena, or the Transformer).
76 |
77 | And finally, it was a great exercise for me, having already implemented Mamba in PyTorch but not knowing MLX.
78 |
79 | # TODOs
80 | - fix large memory footprint at inference ([issue](https://github.com/alxndrTL/mamba.py/issues/5))
81 | - add more ready-to-go scripts (training and fine-tuning)
82 | - support for mixed precision training? (see [this](https://github.com/state-spaces/mamba/tree/main?tab=readme-ov-file#precision) from the official Mamba implementation)
83 | - set device (cpu and gpu) (see [A Simple Example](https://ml-explore.github.io/mlx/build/html/usage/unified_memory.html#a-simple-example) from the MLX docs)
84 | - see TODOs of the PyTorch versions
85 | - watch out for new MLX updates ;)
86 |
87 | Feel free to contribute!
--------------------------------------------------------------------------------
/mambapy/mlx/assets/mamba_mlx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/mambapy/mlx/assets/mamba_mlx.png
--------------------------------------------------------------------------------
/mambapy/mlx/mamba_lm_mlx.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, fields, asdict
2 | import json
3 |
4 | import mlx.core as mx
5 | import mlx.nn as nn
6 |
7 | from mamba_mlx import Mamba, MambaConfig
8 | from misc import topk
9 | from utils import load_config_hf, load_state_dict_hf
10 |
11 | """
12 |
13 | Encapsulates a Mamba model as a language model. It has an embedding layer and an LM head which maps the model output to logits.
14 |
15 | """
16 |
17 | # TODO generate function : batch size != 1 ? (for now B=1)
18 | # TODO generate function : top-p sampling
19 |
20 | @dataclass
21 | class MambaLMConfig(MambaConfig):
22 | vocab_size: int = 32000
23 | pad_vocab_size_multiple: int = 8
24 |
25 | def __post_init__(self):
26 | super().__post_init__()
27 |
28 | if self.vocab_size % self.pad_vocab_size_multiple != 0:
29 | self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
30 |
31 | def to_mamba_config(self) -> MambaConfig:
32 | mamba_config_fields = {field.name for field in fields(MambaConfig)}
33 | filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
34 | return MambaConfig(**filtered_dict)
35 |
36 | class MambaLM(nn.Module):
37 | def __init__(self, lm_config: MambaLMConfig):
38 | super().__init__()
39 | self.lm_config = lm_config
40 | self.config = lm_config.to_mamba_config()
41 |
42 | self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
43 | self.mamba = Mamba(self.config)
44 | self.norm_f = nn.RMSNorm(self.config.d_model)
45 |
46 | self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
47 | self.lm_head.weight = self.embedding.weight #TODO this does not really tie the weights, investigate
48 |
49 | def __call__(self, tokens):
50 | # tokens : (B, L)
51 |
52 | # logits : (B, L, vocab_size)
53 |
54 | x = self.embedding(tokens)
55 |
56 | x = self.mamba(x)
57 | x = self.norm_f(x)
58 |
59 | logits = self.lm_head(x)
60 |
61 | return logits
62 |
63 | def step(self, token, caches):
64 | # token : (B)
65 | # caches : [cache(layer) for all layers], cache : (h, inputs)
66 |
67 | # logits : (B, vocab_size)
68 | # caches : [cache(layer) for all layers], cache : (h, inputs)
69 |
70 | x = self.embedding(token)
71 |
72 | x, caches = self.mamba.step(x, caches)
73 | x = self.norm_f(x)
74 |
75 | logits = self.lm_head(x)
76 |
77 | return logits, caches
78 |
79 | def generate(self, tokenizer, prompt: str, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None):
80 | self.eval()
81 |
82 |         input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, num_tokens_prompt)
83 |
84 | # caches is a list of cache, one per layer
85 | # cache is composed of : the hidden state, and the last d_conv-1 inputs
86 | # the hidden state because the update is like an RNN
87 | # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
88 | caches = [(None, mx.zeros([1, self.config.d_conv-1, self.config.d_inner])) for _ in range(self.config.n_layers)]
89 |
90 | for i in range(input_ids.shape[1] + n_tokens_to_gen - 1):
91 | next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
92 |
93 | # sample (no sampling when the prompt is being processed)
94 | if i+1 >= input_ids.shape[1]:
95 |
96 | if top_k is not None:
97 | values = topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest
98 | mask = next_token_logits < (values[:, 0, None])
99 | next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now
100 |
101 | if sample and temperature > 0:
102 | next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1)
103 | else:
104 | next_token = mx.argmax(next_token_logits, axis=-1)[:, None]
105 |
106 | input_ids = mx.concatenate([input_ids, next_token], axis=1)
107 |
108 | output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
109 |
110 | self.train()
111 |
112 | return output
113 |
114 | @staticmethod
115 | def from_pretrained(name: str):
116 | """
117 | Returns a model loaded with pretrained weights pulled from HuggingFace.
118 |
119 | Args:
120 | name: As of now, supports
121 | * 'state-spaces/mamba-2.8b-slimpj'
122 | * 'state-spaces/mamba-2.8b'
123 | * 'state-spaces/mamba-1.4b'
124 | * 'state-spaces/mamba-790m'
125 | * 'state-spaces/mamba-370m'
126 | * 'state-spaces/mamba-130m'
127 |
128 | Returns:
129 | model: a Mamba model configured with the proper parameters and initialized with the proper weights
130 | """
131 |
132 | import os
133 | import numpy as np
134 | from mlx.utils import tree_unflatten
135 |
136 | from utils import map_mambassm_torch_to_mlx
137 |
138 | # copy config data
139 | config_data = load_config_hf(name)
140 | config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
141 |
142 | model = MambaLM(config)
143 |
144 | # copy weights
145 | filename = name.split('/')[-1] + '.mlx.npz'
146 |
147 | if not os.path.exists(filename):
148 | state_dict = load_state_dict_hf(name)
149 | mlx_state_dict = map_mambassm_torch_to_mlx(state_dict)
150 |
151 | np.savez(filename, **mlx_state_dict)
152 |
153 | model.update(tree_unflatten(list(mx.load(filename).items())))
154 |
155 | return model
156 |
--------------------------------------------------------------------------------
/mambapy/mlx/misc.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | import mlx.nn as nn
3 |
4 | import torch
5 |
6 | """
7 |
8 | This is a temporary file, as it contains additional functions which are needed but not yet implemented in MLX (as of release v0.0.10).
9 | The first functions are straightforward, while the depthwise 1d convolution is a bit more elaborate.
10 |
11 | """
12 |
13 | def softplus(x, beta=1, threshold=20):
14 | scaled_x = beta * x
15 | mask = scaled_x > threshold
16 |     return mx.where(mask, x, 1/beta * mx.logaddexp(0, scaled_x))  # (1/beta) * log(1 + exp(beta*x)), linear above the threshold
17 |
18 | def unsqueeze(x, axis):
19 | """
20 | Same API as PyTorch.
21 | """
22 |
23 | assert axis <= len(x.shape)
24 | if axis >= 0:
25 | new_shape = x.shape[:axis] + [1] + x.shape[axis:]
26 | else:
27 | new_shape = x.shape + [1]
28 | return x.reshape(new_shape)
29 |
30 | def clamp(x, min=None, max=None):
31 | if min is not None:
32 | mask_lower = x < min
33 | if max is not None:
34 | mask_upper = x > max
35 |
36 | if min is not None:
37 | if max is not None:
38 | return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
39 | return mx.where(mask_lower, min, x)
40 |
41 | return mx.where(mask_upper, max, x)
42 |
43 | def topk(x, k):
44 | """
45 | Returns the top k biggest values of x along the 2nd dim.
46 |
47 | Args:
48 | x : (B, vocab_size). can be probs or logits
49 |
50 | Returns:
51 | values : (B, k). ordered from lowest to biggest val
52 | """
53 |
54 | return mx.sort(x)[:, -k:]
55 |
56 | class DepthWiseConv1d(nn.Module):
57 | def __init__(self, channels, kernel_size, bias, padding):
58 | super().__init__()
59 |
60 | self.channels = channels
61 | self.kernel_size = kernel_size
62 | self.bias = bias
63 | self.padding = padding
64 |
65 | self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels,
66 |                                 kernel_size=kernel_size, bias=bias, padding=padding)
67 |
68 | # see comment below
69 | indices = mx.arange(channels)
70 | mask = mx.zeros_like(self.conv1d.weight)
71 | mask[indices, :, indices] = 1
72 | self.conv1d.weight *= mask
73 |
74 | def __call__(self, x):
75 | return self.conv1d(x)
76 |
77 | def torch_to_mlx_depthwise_weights(torch_weights):
78 | """
79 |
80 |     A convolution is said to be "depthwise" when channel i of the output is computed only by passing the filter over channel i of the input.
81 |     In torch, this is done by setting groups=number of channels.
82 |     Because grouped convolutions are not yet implemented in MLX, a workaround is to zero out the weights of a conv object initialized with groups=1 (groups=1 means output channel i is computed by passing the filter over all input channels).
83 |     To do that, we need to zero out all elements except those on the "diagonal":
84 |     for channels=8 and kernel_size=4, the MLX weights are (8, 4, 8).
85 |     These are composed of 8 filters of shape (1, 4, 8), each of which is used to compute one output channel.
86 |     Each (1, 4, 8) filter is in turn composed of 8 sub-filters of shape (1, 4, 1), one per input channel.
87 |     So, for each output channel, we set all 8 sub-filters to 0 except the one whose index matches the output channel (so that the channels don't mix).
88 |
89 | """
90 |
91 | # torch_weights : (channels, 1, kernel_size) = (ED, 1, d_conv)
92 |
93 | # mlx_weights : (channels, kernel_size, channels) = (ED, d_conv, ED)
94 |
95 | torch_weights = torch_weights.transpose(2, 1) # (channels, kernel_size, 1) = (ED, d_conv, 1)
96 | channels, kernel_size, _ = torch_weights.shape
97 |
98 | mlx_weights = torch.zeros(channels, kernel_size, channels)
99 |
100 | indices = torch.arange(channels)
101 | if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor':
102 | mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float()
103 | else:
104 | mlx_weights[indices, :, indices] = torch_weights[:, :, 0]
105 |
106 | return mlx_weights
107 |
--------------------------------------------------------------------------------
/mambapy/mlx/pscan_mlx.py:
--------------------------------------------------------------------------------
1 | import math
2 | import mlx.core as mx
3 |
4 | """
5 |
6 | An implementation of the parallel scan algorithm in MLX (Blelloch version).
7 | The PyTorch implementation is easier to read.
8 |
9 | In a few words, this algorithm computes the sequence H[t] = A[t] * H[t-1] + X[t] in parallel for all H[t].
10 | This replaces the naive sequential way of computing this sequence: first H[0], then H[1] and so on.
11 |
12 | If you want more explanation about what's happening here, please see docs/pscan.ipynb.
13 |
14 | There are a few points which are different from PyTorch:
15 | - a reshape gives us a new tensor rather than a simple view of it.
16 |   This is quite problematic because this algorithm works with in-place updates,
17 |   so at the end of each iteration (up sweep and down sweep) we actually need to write back to A and X the computations done in that iteration.
18 | - there is no need for hand-written backward computation!
19 |
20 | Unfortunately, this parallel scan implementation is not worth it (compared to the sequential implementation).
21 | From the different tests I've done, I suspect that this is partly caused by all the write-backs we have to do at the end of each iteration. Still, turning them off does not make the pscan
22 | as fast as in PyTorch (relative to the sequential mode). There is a second version of this pscan (see the commented function) which works with indices only, but it is still not competitive.
23 | 
24 | This is *very different* from what is observed in PyTorch, where the pscan is on the same order of magnitude as the official Mamba implementation (for d_state=16), and orders of
25 | magnitude faster than the sequential mode.
26 |
27 | (Tests were done with a M2 Pro, 16GB)
28 |
29 | """
30 |
31 | def pscan_f(A, X):
32 | # A : (B, D, L, N)
33 | # X : (B, D, L, N)
34 |
35 | # modifies X in place by doing a parallel scan.
36 | # more formally, X will be populated by these values :
37 | # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
38 | # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
39 |
40 | Aa = A
41 | Xa = X
42 |
43 | B, D, L, _ = A.shape
44 |
45 | num_steps = int(math.log2(L))
46 |
47 | # up sweep
48 | for k in range(num_steps):
49 | T = 2 * (Xa.shape[2] // 2)
50 |
51 | Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1)
52 | Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1)
53 |
54 | Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0]
55 | Aa[:, :, :, 1] *= Aa[:, :, :, 0]
56 |
57 | A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1]
58 | X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1]
59 |
60 | Aa = Aa[:, :, :, 1]
61 | Xa = Xa[:, :, :, 1]
62 |
63 | # down sweep
64 | for k in range(num_steps-1, -1, -1):
65 | Aa = A[:, :, 2**k-1::2**k]
66 | Xa = X[:, :, 2**k-1::2**k]
67 |
68 | step_len = Xa.shape[2]
69 | T = 2 * (step_len // 2)
70 |
71 | if T < step_len:
72 | last_val_aa = Aa[:, :, -1] * Aa[:, :, -2]
73 | last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2]
74 |
75 | Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1)
76 | Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1)
77 |
78 | Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1]
79 | Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1]
80 |
81 | if T == step_len:
82 | A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0]
83 | X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0]
84 | else:
85 | A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2)
86 | X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2)
87 |
88 | # main function, used in the Mamba model (mamba_mlx.py)
89 | def pscan(A_in, X_in):
90 | """
91 | Applies the parallel scan operation, as defined above. Returns a new tensor.
92 |
93 | Args:
94 | A_in : (B, L, ED, N)
95 | X_in : (B, L, ED, N)
96 |
97 | Returns:
98 | H : (B, L, ED, N)
99 | """
100 |
101 | A = A_in[:].transpose(0, 2, 1, 3)
102 | X = X_in[:].transpose(0, 2, 1, 3)
103 |
104 | pscan_f(A, X)
105 |
106 | return X.transpose(0, 2, 1, 3)
107 |
108 | """
109 | def pscan_f(A, X):
110 | # A : (B, D, L, N)
111 | # X : (B, D, L, N)
112 |
113 |     # This function is numerically equivalent to the previous one, but instead of creating new arrays (Aa and Xa) at each iteration, it simply
114 |     # updates A and X in place at the correct indices (hence the rather hard-to-read code)
115 |     # (it only works when L is a power of 2)
116 | 
117 |     # While being faster than the previous one (~4x), it is still not competitive with the naive sequential implementation
118 |
119 | _, _, L, _ = A.shape
120 |
121 | num_steps = int(math.log2(L))
122 |
123 | for k in range(0, num_steps):
124 | temp = 2**(k+1)
125 | X[:, :, temp-1::temp] += A[:, :, temp-1::temp] * X[:, :, 2**k-1::temp]
126 | A[:, :, temp-1::temp] *= A[:, :, 2**k-1::temp]
127 |
128 | for k in range(num_steps, -1, -1):
129 | temp = 2**(k+1)
130 | X[:, :, 3*2**k-1::temp] += A[:, :, 3*2**k-1::temp] * X[:, :, temp-1:L-2**k:temp]
131 | A[:, :, 3*2**k-1::temp] *= A[:, :, temp-1:L-2**k:temp]
132 | """
--------------------------------------------------------------------------------
/mambapy/mlx/scripts/generate.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
3 |
4 | import warnings
5 | warnings.filterwarnings('ignore')
6 |
7 | import argparse
8 |
9 | import mlx.core as mx
10 |
11 | import transformers
12 | transformers.logging.set_verbosity_error()
13 | from transformers import AutoTokenizer
14 |
15 | from mamba_lm_mlx import MambaLM
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--hf_model_name', type=str, default='state-spaces/mamba-130m')
20 | parser.add_argument('--model_dir', type=str, default=None, help='local model to load. overwrites hf_model_name')
21 | parser.add_argument('--prompt', type=str, default='Mamba is a type of')
22 | parser.add_argument('--n_tokens', type=int, default=50, help='number of tokens to generate')
23 | parser.add_argument('--temperature', type=float, default=1.0)
24 | parser.add_argument('--top_k', type=int, default=None, help='top_k sampling : only sample from the top k most probable tokens')
25 |
26 | args = parser.parse_args()
27 |
28 | if args.model_dir is not None:
29 | raise NotImplementedError
30 | else:
31 | model = MambaLM.from_pretrained(args.hf_model_name)
32 | tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
33 |
34 | #print(model.config)
35 |
36 | mx.set_default_device(mx.gpu)
37 |
38 | output = model.generate(tokenizer, args.prompt, n_tokens_to_gen=args.n_tokens, temperature=args.temperature, top_k=args.top_k)
39 |
40 | print(output)
41 |
--------------------------------------------------------------------------------
/mambapy/mlx/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Union
3 |
4 | import numpy as np
5 | import mlx.core as mx
6 | import torch
7 |
8 | from misc import torch_to_mlx_depthwise_weights
9 |
10 | # TODO : map_mlx_to_mambapy_torch
11 |
12 | def load_config_hf(model_name):
13 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
14 | from transformers.utils.hub import cached_file
15 |
16 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
17 | return json.load(open(resolved_archive_file))
18 |
19 | def load_state_dict_hf(model_name):
20 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
21 | from transformers.utils.hub import cached_file
22 |
23 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
24 | return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
25 |
26 | def map_mambapy_torch_to_mlx(torch_state_dict):
27 | new_state_dict = {}
28 | for key, value in torch_state_dict.items():
29 |
30 | # from torch to mlx, we need to convert the conv weights (see misc.py for explanations)
31 | if 'conv1d.weight' in key:
32 | value = torch_to_mlx_depthwise_weights(value)
33 |
34 | if 'conv1d' in key:
35 | key = key.replace('conv1d', 'conv1d.conv1d')
36 |
37 | if value.type() == 'torch.BFloat16Tensor':
38 | new_state_dict[key] = value.half().numpy()
39 | else:
40 | new_state_dict[key] = value.numpy()
41 |
42 | return new_state_dict
43 |
44 | def map_mambassm_torch_to_mlx(torch_state_dict):
45 | # convert mambassm to mambapy
46 | new_state_dict = {}
47 | for key in torch_state_dict:
48 | if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
49 | new_key = key.replace('backbone.', '')
50 | else:
51 | new_key = key.replace('backbone', 'mamba')
52 |
53 | new_state_dict[new_key] = torch_state_dict[key]
54 |
55 | # convert mambapy to mlx
56 | return map_mambapy_torch_to_mlx(new_state_dict)
57 |
58 | """
59 | # todo: doesn't work, because MambaConfig and MambaLMConfig are not the ones defined in mamba.py and mamba_lm.py
60 | def mambapy_torch_to_mlx(torch_state_dict, config: Union[MambaConfig, MambaLMConfig]):
61 | mlx_state_dict = map_mambapy_torch_to_mlx(torch_state_dict)
62 |
63 | if isinstance(config, MambaConfig):
64 | model = Mamba(config)
65 | else:
66 | model = MambaLM(config)
67 |
68 | np.savez("weights.mlx.npz", **mlx_state_dict) # TODO name with config?
69 | model.update(tree_unflatten(list(mx.load("weights.mlx.npz").items())))
70 |
71 | # todo : name the file according to config
72 | # todo : check if file already exists
73 |
74 | return model
75 | """
76 |
--------------------------------------------------------------------------------
/mambapy/pscan.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | """
7 |
8 | An implementation of the parallel scan operation in PyTorch (Blelloch version).
9 | Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10 |
11 | """
12 |
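# Example usage (sketch): pscan consumes (B, L, D, N) tensors and returns H where
# H[t] = A[t] * H[t-1] + X[t], with the initial hidden state taken as 0:
#
#   A = torch.rand(2, 64, 32, 16)
#   X = torch.rand(2, 64, 32, 16)
#   H = pscan(A, X)  # (2, 64, 32, 16)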
13 | def npo2(len):
14 | """
15 |     Returns the smallest power of 2 greater than or equal to len
16 | """
17 |
18 | return 2 ** math.ceil(math.log2(len))
19 |
20 | def pad_npo2(X):
21 | """
22 | Pads input length dim to the next power of 2
23 |
24 | Args:
25 | X : (B, L, D, N)
26 |
27 | Returns:
28 | Y : (B, npo2(L), D, N)
29 | """
30 |
31 | len_npo2 = npo2(X.size(1))
32 | pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
33 | return F.pad(X, pad_tuple, "constant", 0)
34 |
35 | class PScan(torch.autograd.Function):
36 | @staticmethod
37 | def pscan(A, X):
38 | # A : (B, D, L, N)
39 | # X : (B, D, L, N)
40 |
41 | # modifies X in place by doing a parallel scan.
42 | # more formally, X will be populated by these values :
43 | # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
44 | # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
45 |
46 |         # only supports L that is a power of two (mainly for clearer code)
47 |
48 | B, D, L, _ = A.size()
49 | num_steps = int(math.log2(L))
50 |
51 | # up sweep (last 2 steps unfolded)
52 | Aa = A
53 | Xa = X
54 | for _ in range(num_steps-2):
55 | T = Xa.size(2)
56 | Aa = Aa.view(B, D, T//2, 2, -1)
57 | Xa = Xa.view(B, D, T//2, 2, -1)
58 |
59 | Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
60 | Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
61 |
62 | Aa = Aa[:, :, :, 1]
63 | Xa = Xa[:, :, :, 1]
64 |
65 | # we have only 4, 2 or 1 nodes left
66 | if Xa.size(2) == 4:
67 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
68 | Aa[:, :, 1].mul_(Aa[:, :, 0])
69 |
70 | Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
71 | elif Xa.size(2) == 2:
72 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
73 | return
74 | else:
75 | return
76 |
77 | # down sweep (first 2 steps unfolded)
78 | Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
79 | Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
80 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
81 | Aa[:, :, 2].mul_(Aa[:, :, 1])
82 |
83 | for k in range(num_steps-3, -1, -1):
84 | Aa = A[:, :, 2**k-1:L:2**k]
85 | Xa = X[:, :, 2**k-1:L:2**k]
86 |
87 | T = Xa.size(2)
88 | Aa = Aa.view(B, D, T//2, 2, -1)
89 | Xa = Xa.view(B, D, T//2, 2, -1)
90 |
91 | Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
92 | Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
93 |
94 | @staticmethod
95 | def pscan_rev(A, X):
96 | # A : (B, D, L, N)
97 | # X : (B, D, L, N)
98 |
99 | # the same function as above, but in reverse
100 | # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
101 | # it is used in the backward pass
102 |
103 |         # only supports L that is a power of two (mainly for clearer code)
104 |
105 | B, D, L, _ = A.size()
106 | num_steps = int(math.log2(L))
107 |
108 | # up sweep (last 2 steps unfolded)
109 | Aa = A
110 | Xa = X
111 | for _ in range(num_steps-2):
112 | T = Xa.size(2)
113 | Aa = Aa.view(B, D, T//2, 2, -1)
114 | Xa = Xa.view(B, D, T//2, 2, -1)
115 |
116 | Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
117 | Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
118 |
119 | Aa = Aa[:, :, :, 0]
120 | Xa = Xa[:, :, :, 0]
121 |
122 | # we have only 4, 2 or 1 nodes left
123 | if Xa.size(2) == 4:
124 | Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
125 | Aa[:, :, 2].mul_(Aa[:, :, 3])
126 |
127 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
128 | elif Xa.size(2) == 2:
129 | Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
130 | return
131 | else:
132 | return
133 |
134 | # down sweep (first 2 steps unfolded)
135 | Aa = A[:, :, 0:L:2**(num_steps-2)]
136 | Xa = X[:, :, 0:L:2**(num_steps-2)]
137 | Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
138 | Aa[:, :, 1].mul_(Aa[:, :, 2])
139 |
140 | for k in range(num_steps-3, -1, -1):
141 | Aa = A[:, :, 0:L:2**k]
142 | Xa = X[:, :, 0:L:2**k]
143 |
144 | T = Xa.size(2)
145 | Aa = Aa.view(B, D, T//2, 2, -1)
146 | Xa = Xa.view(B, D, T//2, 2, -1)
147 |
148 | Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
149 | Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
150 |
151 | @staticmethod
152 | def forward(ctx, A_in, X_in):
153 | """
154 | Applies the parallel scan operation, as defined above. Returns a new tensor.
155 |         If you can, prefer sequence lengths that are powers of two.
156 |
157 | Args:
158 | A_in : (B, L, D, N)
159 | X_in : (B, L, D, N)
160 |
161 | Returns:
162 | H : (B, L, D, N)
163 | """
164 |
165 | L = X_in.size(1)
166 |
167 |         # cloning is required because of the in-place ops
168 | if L == npo2(L):
169 | A = A_in.clone()
170 | X = X_in.clone()
171 | else:
172 | # pad tensors (and clone btw)
173 | A = pad_npo2(A_in) # (B, npo2(L), D, N)
174 | X = pad_npo2(X_in) # (B, npo2(L), D, N)
175 |
176 | # prepare tensors
177 | A = A.transpose(2, 1) # (B, D, npo2(L), N)
178 | X = X.transpose(2, 1) # (B, D, npo2(L), N)
179 |
180 | # parallel scan (modifies X in-place)
181 | PScan.pscan(A, X)
182 |
183 | ctx.save_for_backward(A_in, X)
184 |
185 | # slice [:, :L] (cut if there was padding)
186 | return X.transpose(2, 1)[:, :L]
187 |
188 | @staticmethod
189 | def backward(ctx, grad_output_in):
190 | """
191 | Flows the gradient from the output to the input. Returns two new tensors.
192 |
193 | Args:
194 | ctx : A_in : (B, L, D, N), X : (B, D, L, N)
195 | grad_output_in : (B, L, D, N)
196 |
197 | Returns:
198 | gradA : (B, L, D, N), gradX : (B, L, D, N)
199 | """
200 |
201 | A_in, X = ctx.saved_tensors
202 |
203 | L = grad_output_in.size(1)
204 |
205 |         # cloning is required because of the in-place ops
206 | if L == npo2(L):
207 | grad_output = grad_output_in.clone()
208 | # the next padding will clone A_in
209 | else:
210 | grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
211 | A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
212 |
213 | # prepare tensors
214 | grad_output = grad_output.transpose(2, 1)
215 | A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
216 | A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
217 |
218 | # reverse parallel scan (modifies grad_output in-place)
219 | PScan.pscan_rev(A, grad_output)
220 |
221 | Q = torch.zeros_like(X)
222 | Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
223 |
224 | return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
225 |
226 | pscan = PScan.apply
227 |
--------------------------------------------------------------------------------
/mambapy/tests/compare_mambapy_cuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mamba_ssm import Mamba
4 |
5 | from mamba import MambaBlock, MambaConfig
6 |
7 | batch, length, dim = 2, 512, 16
8 |
9 | x = torch.randn(batch, length, dim).to("cuda")
10 | x.requires_grad = True
11 |
12 | # CUDA Model
13 |
14 | torch.manual_seed(1)
15 |
16 | model_cuda = Mamba(
17 | # This module uses roughly 3 * expand * d_model^2 parameters
18 | d_model=dim, # Model dimension d_model
19 | d_state=16, # SSM state expansion factor
20 | d_conv=4, # Local convolution width
21 | expand=2, # Block expansion factor
22 | ).to("cuda")
23 |
24 | y_cuda = model_cuda(x)
25 |
26 | print(sum([p.numel() for p in model_cuda.parameters()]))
27 | print(y_cuda.shape)
28 |
29 | # mamba.py model
30 |
31 | torch.manual_seed(1)
32 |
33 | config = MambaConfig(d_model=dim, n_layers=1)
34 | model = MambaBlock(config).to("cuda")
35 |
36 | y_pscan = model(x)
37 |
38 | print(sum([p.numel() for p in model.parameters()]))
39 | print(y_pscan.shape)
40 |
41 | # forward #
42 | print(torch.allclose(y_cuda, y_pscan, rtol=0.1))
43 |
44 | # backward #
45 | J_cuda = y_cuda.sum()
46 | J_cuda.backward()
47 |
48 | J_pscan = y_pscan.sum()
49 | J_pscan.backward()
50 |
51 | print(torch.allclose(model_cuda.in_proj.weight.grad, model.in_proj.weight.grad, rtol=0.01))
52 |
53 |
54 |
--------------------------------------------------------------------------------
/mambapy/tests/mem_mamba_3.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mamba import Mamba, MambaConfig
4 |
5 | device = "cuda"
6 |
7 | B, L, D, N = 16, 64, 128, 16
8 |
9 | config = MambaConfig(d_model=D, n_layers=8, d_state=N)
10 | model = Mamba(config).to(device)
11 |
12 | torch.cuda.empty_cache()
13 | torch.cuda.reset_peak_memory_stats(device)
14 |
15 | torch.cuda.reset_peak_memory_stats(device)
16 | initial_memory = torch.cuda.max_memory_allocated(device)
17 |
18 | for _ in range(100):
19 | X = torch.randn(B, L, D).to(device, non_blocking=True)
20 |
21 | output = model(X)
22 | loss = output.sum()
23 |
24 | loss.backward()
25 |
26 | peak_memory = torch.cuda.max_memory_allocated(device=device) # Peak memory during backward
27 |
28 | print(initial_memory/(1024**2))
29 | print(peak_memory/(1024**2))
30 | print("-----------------------------")
31 |
32 | # matches what we see in nvidia-smi quite well (more or less: without the torch/CUDA overhead)
--------------------------------------------------------------------------------
/mambapy/tests/profiling_mamba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd.profiler as profiler
3 |
4 | from mamba import Mamba, MambaConfig
5 |
6 | device = "cuda"
7 |
8 | B, L, D, N = 16, 1024, 1024, 16
9 |
10 | config = MambaConfig(d_model=D, n_layers=8, d_state=N)
11 | model = Mamba(config).to(device)
12 |
13 | X = torch.randn(B, L, D).to(device, non_blocking=True)
14 |
15 | model(X)
16 |
17 | with profiler.profile(record_shapes=True, use_cuda=True, profile_memory=True) as prof:
18 | with profiler.record_function("model_forward"):
19 | output = model(X)
20 | loss = output.sum()
21 | with profiler.record_function("model_backward"):
22 | loss.backward()
23 |
24 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
25 | print(f"Peak CUDA Memory Usage: {prof.total_average().cuda_memory_usage / (1024 ** 2)} MB")
26 |
--------------------------------------------------------------------------------
/mambapy/tests/profiling_pscan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd.profiler as profiler
3 |
4 | from pscan import pscan
5 |
6 | B, L, D, N = 16, 1024, 32, 16
7 |
8 | A = torch.randn(B, L, D, N).to("cuda")
9 | X = torch.randn(B, L, D, N).to("cuda")
10 |
11 | H = pscan(A, X)
12 | with profiler.profile(record_shapes=True, use_cuda=True, profile_memory=True) as prof:
13 | with profiler.record_function("pscan_custom_function"):
14 | H = pscan(A, X)
15 |
16 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
17 | prof.export_chrome_trace("pscan_profiling_trace.json")
18 |
19 | print(f"Peak CUDA Memory Usage: {prof.total_average().cuda_memory_usage / (1024 ** 2)} MB")
20 |
21 |
--------------------------------------------------------------------------------
/modules/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/netvlad.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/netvlad.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/netvlad.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/modules/__pycache__/netvlad.cpython-39.pyc
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # import sys
3 | # p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
4 | # if p not in sys.path:
5 | # sys.path.append(p)
6 | #
7 | # import torch
8 | # import torch.nn as nn
9 | # import os
10 | # import numpy as np
11 | #
12 | #
13 | # def best_pos_distance(query, pos_vecs):
14 | # num_pos = pos_vecs.shape[0]
15 | # query_copies = query.repeat(int(num_pos), 1)
16 | # diff = ((pos_vecs - query_copies) ** 2).sum(1)
17 | # diff, _ = torch.sort(diff, 0)
18 | # if num_pos == 0:
19 | # print(diff.size())
20 | #
21 | # min_pos, _ = diff.min(0)
22 | # max_pos, _ = diff.max(0)
23 | # return min_pos, max_pos
24 | #
25 | #
26 | # def triplet_loss(q_vec, pos_vecs, neg_vecs, margin, use_min=False, lazy=False, ignore_zero_loss=False):
27 | #
28 | # if pos_vecs.shape[0] == 0:
29 | #
30 | # num_neg = neg_vecs.shape[0]
31 | # query_copies = q_vec.repeat(int(num_neg), 1)
32 | #
33 | # negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
34 | # negative, _ = negative.min(0)
35 | # negative = negative.repeat(int(num_neg), 1)
36 | #
37 | # loss = margin - negative
38 | #
39 | # loss = loss.clamp(min=0.0)
40 | #
41 | # # loss = torch.log1p(loss)
42 | #
43 | # if lazy:
44 | # triplet_loss = loss.max(1)[0]
45 | # else:
46 | # triplet_loss = loss.sum(0)
47 | # if ignore_zero_loss:
48 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float()
49 | # num_hard_triplets = torch.sum(hard_triplets)
50 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
51 | # else:
52 | # triplet_loss = triplet_loss.mean()
53 | #
54 | # else:
55 | # min_pos, max_pos = best_pos_distance(q_vec, pos_vecs)
56 | #
57 | # if use_min:
58 | # positive = min_pos
59 | # else:
60 | # positive = max_pos
61 | # num_neg = neg_vecs.shape[0]
62 | # query_copies = q_vec.repeat(int(num_neg), 1)
63 | # positive = positive.view(-1, 1)
64 | # positive = positive.repeat(int(num_neg), 1)
65 | #
66 | # negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
67 | # negative, _ = negative.min(0)
68 | # negative = negative.repeat(int(num_neg), 1)
69 | #
70 | # loss = margin + positive - negative
71 | #
72 | # loss = loss.clamp(min=0.0)
73 | #
74 | # # loss = torch.log1p(loss)
75 | #
76 | # if lazy:
77 | # triplet_loss = loss.max(1)[0]
78 | # else:
79 | # triplet_loss = loss.sum(0)
80 | # if ignore_zero_loss:
81 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float()
82 | # num_hard_triplets = torch.sum(hard_triplets)
83 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
84 | # else:
85 | # triplet_loss = triplet_loss.mean()
86 | # return triplet_loss
87 | #
88 | # def triplet_loss_inv(q_vec, pos_vecs, neg_vecs, margin, use_min=True, lazy=False, ignore_zero_loss=False):
89 | #
90 | # min_neg, max_neg = best_pos_distance(q_vec, neg_vecs)
91 | #
92 | # if use_min:
93 | # negative = min_neg
94 | # else:
95 | # negative = max_neg
96 | # num_neg = neg_vecs.shape[0]
97 | # num_pos= pos_vecs.shape[0]
98 | # query_copies = q_vec.repeat(int(num_pos), 1)
99 | # negative = negative.view(-1, 1)
100 | # negative = negative.repeat(int(num_pos), 1)
101 | #
102 | # loss = margin - negative + ((pos_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
103 | #
104 | # loss = loss.clamp(min=0.0)
105 | #
106 | #
107 | #
108 | # if lazy:
109 | # triplet_loss = loss.max(1)[0]
110 | # else:
111 | # triplet_loss = loss.sum(0)
112 | # if ignore_zero_loss:
113 | # hard_triplets = torch.gt(triplet_loss, 1e-16).float()
114 | # num_hard_triplets = torch.sum(hard_triplets)
115 | # triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
116 | # else:
117 | # triplet_loss = triplet_loss.mean()
118 | # return triplet_loss
119 | #
120 | #
121 | # def triplet_loss_wrapper(q_vec, pos_vecs, neg_vecs, m1, m2, use_min=False, lazy=False, ignore_zero_loss=False):
122 | # return triplet_loss(q_vec, pos_vecs, neg_vecs, m1, use_min, lazy, ignore_zero_loss)
123 | import os
124 | import sys
125 |
126 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
127 | if p not in sys.path:
128 | sys.path.append(p)
129 |
130 | import torch
131 | import torch.nn as nn
132 | import os
133 | import numpy as np
134 |
135 |
136 | def best_pos_distance(query, pos_vecs):
137 | num_pos = pos_vecs.shape[0]
138 | query_copies = query.repeat(int(num_pos), 1)
139 | diff = ((pos_vecs - query_copies) ** 2).sum(1)
140 |
141 | min_pos, _ = diff.min(0)
142 | max_pos, _ = diff.max(0)
143 | return min_pos, max_pos
144 |
145 |
146 | def triplet_loss(q_vec, pos_vecs, neg_vecs, margin, use_min=False, lazy=False, ignore_zero_loss=False):
147 | if pos_vecs.shape[0] == 0:
148 |
149 | num_neg = neg_vecs.shape[0]
150 | num_pos = pos_vecs.shape[0]
151 | query_copies = q_vec.repeat(int(num_neg), 1)
152 |
153 | negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
154 |
155 | loss = margin - ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
156 |
157 | loss = loss.clamp(min=0.0)
158 |
159 | if lazy:
160 | triplet_loss = loss.max(1)[0]
161 | else:
162 | triplet_loss = loss.sum(0)
163 | if ignore_zero_loss:
164 | hard_triplets = torch.gt(triplet_loss, 1e-16).float()
165 | num_hard_triplets = torch.sum(hard_triplets)
166 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
167 | else:
168 | triplet_loss = triplet_loss.mean()
169 | return triplet_loss
170 |
171 | else:
172 | min_pos, max_pos = best_pos_distance(q_vec, pos_vecs)
173 |
174 | if use_min:
175 | positive = min_pos
176 | else:
177 | positive = max_pos
178 | num_neg = neg_vecs.shape[0]
179 | num_pos = pos_vecs.shape[0]
180 | query_copies = q_vec.repeat(int(num_neg), 1)
181 | positive = positive.view(-1, 1)
182 | positive = positive.repeat(int(num_neg), 1)
183 |
184 | negative = ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
185 |
186 | loss = margin + positive - ((neg_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
187 |
188 | loss = loss.clamp(min=0.0)
189 |
190 | if lazy:
191 | triplet_loss = loss.max(1)[0]
192 | else:
193 | triplet_loss = loss.sum(0)
194 | if ignore_zero_loss:
195 | hard_triplets = torch.gt(triplet_loss, 1e-16).float()
196 | num_hard_triplets = torch.sum(hard_triplets)
197 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
198 | else:
199 | triplet_loss = triplet_loss.mean()
200 | return triplet_loss
201 |
202 |
203 | def triplet_loss_inv(q_vec, pos_vecs, neg_vecs, margin, use_min=True, lazy=False, ignore_zero_loss=False):
204 | min_neg, max_neg = best_pos_distance(q_vec, neg_vecs)
205 |
206 | if use_min:
207 | negative = min_neg
208 | else:
209 | negative = max_neg
210 | num_neg = neg_vecs.shape[0]
211 | num_pos = pos_vecs.shape[0]
212 | query_copies = q_vec.repeat(int(num_pos), 1)
213 | negative = negative.view(-1, 1)
214 | negative = negative.repeat(int(num_pos), 1)
215 |
216 | loss = margin - negative + ((pos_vecs - query_copies) ** 2).sum(1).unsqueeze(1)
217 |
218 | loss = loss.clamp(min=0.0)
219 |
220 | if lazy:
221 | triplet_loss = loss.max(1)[0]
222 | else:
223 | triplet_loss = loss.sum(0)
224 | if ignore_zero_loss:
225 | hard_triplets = torch.gt(triplet_loss, 1e-16).float()
226 | num_hard_triplets = torch.sum(hard_triplets)
227 | triplet_loss = triplet_loss.sum() / (num_hard_triplets + 1e-16)
228 | else:
229 | triplet_loss = triplet_loss.mean()
230 | return triplet_loss
231 |
232 |
233 | def triplet_loss_wrapper(q_vec, pos_vecs, neg_vecs, m1, m2, use_min=False, lazy=False, ignore_zero_loss=False):
234 | return triplet_loss(q_vec, pos_vecs, neg_vecs, m1, use_min, lazy, ignore_zero_loss)
--------------------------------------------------------------------------------
/modules/netvlad.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
4 | if p not in sys.path:
5 | sys.path.append(p)
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import math
11 |
12 |
13 | class NetVLADLoupe(nn.Module):
14 | def __init__(self, feature_size, max_samples, cluster_size, output_dim,
15 | gating=True, add_batch_norm=True, is_training=True):
16 | super(NetVLADLoupe, self).__init__()
17 | self.feature_size = feature_size
18 | self.max_samples = max_samples
19 | self.output_dim = output_dim
20 | self.is_training = is_training
21 | self.gating = gating
22 | self.add_batch_norm = add_batch_norm
23 | self.cluster_size = cluster_size
24 | self.softmax = nn.Softmax(dim=-1)
25 |
26 | self.cluster_weights = nn.Parameter(torch.randn(
27 | feature_size, cluster_size) * 1 / math.sqrt(feature_size))
28 | self.cluster_weights2 = nn.Parameter(torch.randn(
29 | 1, feature_size, cluster_size) * 1 / math.sqrt(feature_size))
30 | self.hidden1_weights = nn.Parameter(torch.randn(
31 | cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size))
32 |
33 | if add_batch_norm:
34 | self.cluster_biases = None
35 | self.bn1 = nn.BatchNorm1d(cluster_size)
36 | else:
37 | self.cluster_biases = nn.Parameter(torch.randn(
38 | cluster_size) * 1 / math.sqrt(feature_size))
39 | self.bn1 = None
40 |
41 | self.bn2 = nn.BatchNorm1d(output_dim)
42 |
43 | if gating:
44 | self.context_gating = GatingContext(
45 | output_dim, add_batch_norm=add_batch_norm)
46 |
47 | def forward(self, x):
48 | x = x.transpose(1, 3).contiguous()
49 | x = x.view((-1, self.max_samples, self.feature_size))
50 | activation = torch.matmul(x, self.cluster_weights)
51 | if self.add_batch_norm:
52 | # activation = activation.transpose(1,2).contiguous()
53 | activation = activation.view(-1, self.cluster_size)
54 | activation = self.bn1(activation)
55 | activation = activation.view(-1, self.max_samples, self.cluster_size)
56 | # activation = activation.transpose(1,2).contiguous()
57 | else:
58 | activation = activation + self.cluster_biases
59 | activation = self.softmax(activation)
60 | activation = activation.view((-1, self.max_samples, self.cluster_size))
61 |
62 | a_sum = activation.sum(-2, keepdim=True)
63 | a = a_sum * self.cluster_weights2
64 |
65 | activation = torch.transpose(activation, 2, 1)
66 | x = x.view((-1, self.max_samples, self.feature_size))
67 | vlad = torch.matmul(activation, x)
68 | vlad = torch.transpose(vlad, 2, 1)
69 | vlad = vlad - a
70 |
71 | vlad = F.normalize(vlad, dim=1, p=2)
72 | # vlad = vlad.view((-1, self.cluster_size * self.feature_size))
73 | vlad = vlad.reshape((-1, self.cluster_size * self.feature_size))
74 | vlad = F.normalize(vlad, dim=1, p=2)
75 |
76 | vlad = torch.matmul(vlad, self.hidden1_weights)
77 |
78 | # vlad = self.bn2(vlad)
79 |
80 | if self.gating:
81 | vlad = self.context_gating(vlad)
82 |
83 | return vlad
84 |
85 |
86 | class GatingContext(nn.Module):
87 | def __init__(self, dim, add_batch_norm=True):
88 | super(GatingContext, self).__init__()
89 | self.dim = dim
90 | self.add_batch_norm = add_batch_norm
91 | self.gating_weights = nn.Parameter(
92 | torch.randn(dim, dim) * 1 / math.sqrt(dim))
93 | self.sigmoid = nn.Sigmoid()
94 |
95 | if add_batch_norm:
96 | self.gating_biases = None
97 | self.bn1 = nn.BatchNorm1d(dim)
98 | else:
99 | self.gating_biases = nn.Parameter(
100 | torch.randn(dim) * 1 / math.sqrt(dim))
101 | self.bn1 = None
102 |
103 | def forward(self, x):
104 | gates = torch.matmul(x, self.gating_weights)
105 |
106 | if self.add_batch_norm:
107 | gates = self.bn1(gates)
108 | else:
109 | gates = gates + self.gating_biases
110 |
111 | gates = self.sigmoid(gates)
112 |
113 | activation = x * gates
114 |
115 | return activation
116 |
117 |
118 | if __name__ == '__main__':
119 | net_vlad = NetVLADLoupe(feature_size=1024, max_samples=360, cluster_size=16,
120 | output_dim=20, gating=True, add_batch_norm=True,
121 | is_training=True)
122 | # input (bs, 1024, 360, 1)
123 | torch.manual_seed(1234)
124 | input_tensor = F.normalize(torch.randn((1,1024,360,1)), dim=1)
125 | input_tensor2 = torch.zeros_like(input_tensor)
126 | input_tensor2[:, :, 2:, :] = input_tensor[:, :, 0:-2, :].clone()
127 | input_tensor2[:, :, :2, :] = input_tensor[:, :, -2:, :].clone()
128 | input_tensor2= F.normalize(input_tensor2, dim=1)
129 | input_tensor_com = torch.cat((input_tensor, input_tensor2), dim=0)
130 |
131 | # print(input_tensor[0,0,:,0])
132 | # print(input_tensor2[0,0,:,0])
133 | print("==================================")
134 |
135 | with torch.no_grad():
136 | net_vlad.eval()
137 | # output_tensor = net_vlad(input_tensor_com)
138 | # print(output_tensor)
139 | out1 = net_vlad(input_tensor)
140 | print(out1)
141 | net_vlad.eval()
142 | # input_tensor2[:, :, 20:, :] = 0.1
143 | input_tensor2 = F.normalize(input_tensor2, dim=1)
144 | out2 = net_vlad(input_tensor2)
145 | print(out2)
146 | net_vlad.eval()
147 | input_tensor3 = torch.randn((1,1024,360,1))
148 | out3 = net_vlad(input_tensor3)
149 | print(out3)
150 |
151 |
152 | print(((out1-out2)**2).sum(1))
153 | print(((out1-out3)**2).sum(1))
154 |
155 |
156 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.4.3
2 | numpy==1.20.3
3 | tensorboardx==2.4
4 | pyyaml==6.0
5 | opencv-python==4.5.4.58
6 | faiss-cpu==1.7.1
7 | scikit-learn==0.24.2
8 |
--------------------------------------------------------------------------------
/test/test_kitti00_PR.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 | from datetime import datetime
7 |
8 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
9 | if p not in sys.path:
10 | sys.path.append(p)
11 |
12 | import copy
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import yaml
16 | from sklearn import metrics
17 |
18 |
19 | def cal_pr_curve(prediction_file_name, ground_truth_file_name):
20 | precisions = []
21 | recalls = []
22 |
23 | print(1)
24 | des_dists = np.load(prediction_file_name)['arr_0']
25 | des_dists = np.asarray(des_dists, dtype='float32')
26 | des_dists = des_dists.reshape((len(des_dists), 3))
27 |     print('Loaded descriptor distance predictions with', len(des_dists), 'pairs')
28 | print(des_dists.shape)
29 |
30 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True')['arr_0']
31 |
32 | """Changing the threshold will lead to different test results"""
33 | for thres in np.arange(0, 1.0, 0.01):
34 | current_time = datetime.now()
35 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]")
36 | print(formatted_time, " thresh: ", thres)
37 | tps = 0
38 | fps = 0
39 | tns = 0
40 | fns = 0
41 | """Update the start frame index"""
42 | for idx in range(150, len(ground_truth) - 1):
43 | gt_idxes = ground_truth[int(idx)]
44 | reject_flag = False
45 |
46 | if des_dists[des_dists[:, 0] == int(idx), 2][0] > thres:
47 | reject_flag = True
48 | if reject_flag:
49 | if not any(gt_idxes):
50 | tns += 1
51 | else:
52 | fns += 1
53 | else:
54 | predicted_idx = des_dists[des_dists[:, 0] == int(idx), 1][0]
55 | if any(predicted_idx == gt for gt in gt_idxes):
56 | # if des_dists[des_dists[:, 0] == int(idx), 1][0] in gt_idxes:
57 | tps += 1
58 | else:
59 | fps += 1
60 |
61 | # for idx in range(150, len(ground_truth) - 1):
62 | # gt_idxes = ground_truth[int(idx)]
63 | # reject_flag = False
64 | #
65 | # if des_dists[des_dists[:, 0] == int(idx), 2][0] > thres:
66 | # reject_flag = True
67 | # if reject_flag:
68 | # if not np.any(gt_idxes):
69 | # tns += 1
70 | # else:
71 | # fns += 1
72 | # else:
73 | # if any(des_dists[des_dists[:, 0] == int(idx), 1][0] == gt for gt in gt_idxes):
74 | # tps += 1
75 | # else:
76 | # fps += 1
77 |
78 | if fps == 0:
79 | precision = 1
80 | else:
81 | precision = float(tps) / (float(tps) + float(fps))
82 | if fns == 0:
83 | recall = 1
84 | else:
85 | recall = float(tps) / (float(tps) + float(fns))
86 |
87 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-10)
88 | print("f1 score: ", f1_score)
89 | print("recall: ", recall)
90 | print("precision:", precision)
91 | precisions.append(precision)
92 | recalls.append(recall)
93 |
94 | # print("precision ", precision)
95 | # print("recall ", recall)
96 |
97 | print("Highest precision: %s" % max(precisions))
98 | print("Highest recall: %s" % max(recalls))
99 |
100 | return precisions, recalls
101 |
102 |
103 | """Ploting and saving AUC."""
104 |
105 |
106 | def plotPRC(precisions, recalls, print_2file=True):
107 | # initial plot
108 | plt.clf()
109 |
110 | if print_2file:
111 | save_name = "./" + dir_name + "/PR.png"
112 |
113 | recalls, precisions = (list(t) for t in zip(*sorted(zip(recalls, precisions), reverse=True)))
114 | auc = metrics.auc(recalls, precisions) * 100
115 |
116 | plt.plot(recalls, precisions, linewidth=1.0)
117 |
118 | plt.xlabel('Recall')
119 | plt.ylabel('Precision')
120 | plt.ylim([0.0, 1])
121 | plt.xlim([0.0, 1])
122 | plt.title('auc = ' + str(auc))
123 |
124 | if print_2file:
125 | plt.savefig(save_name)
126 |
127 | plt.show()
128 |
129 |
130 | """Calculating Max F1 score."""
131 |
132 |
133 | def cal_F1_score():
134 | pr_values = np.load("./" + dir_name + "/PR.npz")
135 | f1_scores_max = -1
136 | for i in range(pr_values['precisions'].shape[0]):
137 | precision = pr_values['precisions'][i]
138 | recall = pr_values['recalls'][i]
139 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-10)
140 | if f1_score > f1_scores_max:
141 | f1_scores_max = copy.deepcopy(f1_score)
142 | print("f1_score on KITTI test seq: ", f1_scores_max)
143 |
144 |
145 | def test_with_PR(ground_truth_file_name):
146 | plot_curve = True
147 | save_pr_results = True
148 | print_2file = True
149 |
150 | prediction_file_name = "./" + dir_name + "/predicted_des_L2_dis.npz"
151 |
152 | if not os.path.exists("./" + dir_name + "/PR.npz"):
153 | precisions, recalls = cal_pr_curve(prediction_file_name, ground_truth_file_name)
154 |
155 | pr_values = np.asarray([precisions, recalls])
156 | pr_values = pr_values[:, np.argsort(pr_values[0, :])]
157 |
158 | if (plot_curve):
159 | plotPRC(pr_values[0], pr_values[1], print_2file)
160 |
161 | if (save_pr_results):
162 | np.savez_compressed("./" + dir_name + "/PR", precisions=pr_values[0], recalls=pr_values[1])
163 |
164 | cal_F1_score()
165 |
166 |
167 | if __name__ == "__main__":
168 | # load config ================================================================
169 | config_filename = '../config/config.yml'
170 |
171 | dir_name = "nclt/1"
172 |
173 | config = yaml.safe_load(open(config_filename))
174 | ground_truth_file_name = config["test_config"]["gt_file"]
175 | # ============================================================================
176 |
177 | """ground truth file follows OverlapNet"""
178 | test_with_PR(ground_truth_file_name)
179 |
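
A minimal, self-contained sketch of the precision/recall/F1 bookkeeping used in cal_pr_curve above (not part of the repository; the counts are made up, only the formulas and the fps == 0 / fns == 0 conventions mirror the script):

```python
def pr_f1(tps, fps, fns):
    # precision defaults to 1 when there are no false positives,
    # recall defaults to 1 when there are no false negatives (as in cal_pr_curve)
    precision = 1.0 if fps == 0 else tps / (tps + fps)
    recall = 1.0 if fns == 0 else tps / (tps + fns)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1

# made-up counts for a single distance threshold
print(pr_f1(tps=120, fps=10, fns=30))   # -> (0.923..., 0.8, 0.857...)
```

Sweeping the threshold from 0 to 1 produces one such (precision, recall) pair per step, which is what plotPRC later turns into the AUC.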
--------------------------------------------------------------------------------
/test/test_kitti00_prepare.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
7 | if p not in sys.path:
8 | sys.path.append(p)
9 |
10 | from matplotlib import pyplot as plt
11 | import torch
12 | import numpy as np
13 | from modules.overlap_transformer import featureExtracter
14 | from tools.read_samples import read_one_need_from_seq
15 | np.set_printoptions(threshold=sys.maxsize)
16 | from tools.utils.utils import *
17 | import faiss
18 | import yaml
19 | import time
20 |
21 | """
22 | Evaluation is conducted on KITTI 00 in our work.
23 | Args:
24 | amodel: pretrained model.
25 | data_root_folder: dataset root of KITTI.
26 | test_seq: "00" in our work.
27 | """
28 | def test_chosen_seq(amodel, data_root_folder, test_seq):
29 | range_images = os.listdir(os.path.join(data_root_folder, test_seq, "depth_map"))
30 |
31 | des_list = np.zeros((len(range_images), 256))
32 | des_list_inv = np.zeros((len(range_images), 256))
33 |
34 | """Calculate the descriptors of scans"""
35 | print("Calculating the descriptors of scans ...")
36 | for i in range(0, len(range_images)):
37 | current_batch = read_one_need_from_seq(data_root_folder, str(i).zfill(6), test_seq)
38 | current_batch_inv_double = torch.cat((current_batch, current_batch), dim=-1)
39 | current_batch_inv = current_batch_inv_double[:,:,:,450:1350]
40 | current_batch = torch.cat((current_batch, current_batch_inv), dim=0)
41 | amodel.eval()
42 | t = time.time()
43 | current_batch_des = amodel(current_batch)
44 | #print(f'cost:{time.time() - t:.8f}s')
45 | a = time.time() - t
46 | a *= 1000
47 | #print(a/2)
48 | des_list[i, :] = current_batch_des[0, :].cpu().detach().numpy()
49 | des_list_inv[i, :] = current_batch_des[1, :].cpu().detach().numpy()
50 |
51 | des_list = des_list.astype('float32')
52 | """TODO: You can test the rotation-invariance with des_list_inv."""
53 | des_list_inv = des_list_inv.astype('float32')
54 |
55 | row_list = []
56 | # for i in range(101, 3817 - 1):
57 | # for i in range(101, 4541-1):
58 |     for i in range(101, len(range_images) - 1):
59 | nlist = 1
60 | k = 50
61 | d = 256
62 | quantizer = faiss.IndexFlatL2(d)
63 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
64 | assert not index.is_trained
65 | index.train(des_list[:i-100,:])
66 | assert index.is_trained
67 | index.add(des_list[:i-100,:])
68 | plt.clf()
69 | """Faiss searching"""
70 | D, I = index.search(des_list[i, :].reshape(1, -1), k)
71 | for j in range(D.shape[1]):
72 | """The nearest 100 frames are not considered."""
73 | if (i-I[:,j])<100:
74 | continue
75 | else:
76 | one_row = np.zeros((1,3))
77 | one_row[:, 0] = i
78 | one_row[:, 1] = I[:,j]
79 | one_row[:, 2] = D[:,j]
80 | row_list.append(one_row)
81 | print(str(i) + "---->" + str(I[:, j]) + " " + str(D[:, j]))
82 |
83 | row_list_arr = np.array(row_list)
84 | """Saving for the next test"""
85 | folder_name = "./nclt/17/"
86 | if not os.path.exists(folder_name):
87 | os.mkdir(folder_name)
88 | np.savez_compressed(folder_name + "predicted_des_L2_dis", row_list_arr)
89 |
90 |
91 | class testHandler():
92 | def __init__(self, height=64, width=900, channels=5, norm_layer=None,
93 | data_root_folder=None,
94 | test_seq=None, test_weights=None):
95 | super(testHandler, self).__init__()
96 |
97 | self.height = height
98 | self.width = width
99 | self.channels = channels
100 | self.norm_layer = norm_layer
101 | self.data_root_folder = data_root_folder
102 | self.test_seq = test_seq
103 |
104 |
105 | self.amodel = featureExtracter(channels=self.channels)
106 | # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107 | self.device = torch.device("cuda")
108 | self.amodel.to(self.device)
109 |
110 | self.parameters = self.amodel.parameters()
111 | self.test_weights = test_weights
112 | self.overlap_thresh = 0.3
113 |
114 | def eval(self):
115 | with torch.no_grad():
116 | print("Loading weights from ", self.test_weights)
117 | checkpoint = torch.load(self.test_weights)
118 | self.amodel.load_state_dict(checkpoint['state_dict'])
119 | test_chosen_seq(self.amodel, self.data_root_folder, self.test_seq)
120 |
121 |
122 |
123 |
124 | if __name__ == '__main__':
125 |
126 | # load config ================================================================
127 | config_filename = '../config/config.yml'
128 | config = yaml.safe_load(open(config_filename))
129 | data_root_folder = config["data_root"]["data_root_folder"]
130 | test_seq = config["test_config"]["test_seqs"][0]
131 | test_weights = config["test_config"]["test_weights"]
132 | # ============================================================================
133 |
134 | """
135 |         testHandler handles the testing process.
136 | Args:
137 | height: the height of the range image (the beam number for convenience).
138 |             width: the width of the range image (900, along the lines of OverlapNet).
139 | channels: 1 for depth only in our work.
140 |             norm_layer: None in our work for a better model.
141 | use_transformer: Whether to use MHSA.
142 | data_root_folder: root of KITTI sequences. It's better to follow our file structure.
143 | test_seq: "00" in the evaluation.
144 | test_weights: pretrained weights.
145 | """
146 | test_handler = testHandler(height=32, width=900, channels=1, norm_layer=None,
147 | data_root_folder=data_root_folder, test_seq=test_seq, test_weights=test_weights)
148 | test_handler.eval()
149 |
150 |
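
A minimal sketch of the faiss retrieval step used in test_chosen_seq (not part of the repository): build an L2 index over the descriptors of past frames and query the current descriptor for its k nearest neighbours. The sizes are arbitrary, and a flat index stands in for IndexIVFFlat, which behaves the same with nlist = 1:

```python
import faiss
import numpy as np

d, n_db, k = 256, 1000, 5
rng = np.random.default_rng(0)
db_des = rng.standard_normal((n_db, d)).astype('float32')   # descriptors of past scans
query_des = rng.standard_normal((1, d)).astype('float32')   # descriptor of the current scan

index = faiss.IndexFlatL2(d)        # exact L2 search
index.add(db_des)
D, I = index.search(query_des, k)   # D: squared L2 distances, I: database indices
print(I[0], D[0])
```

The script stores one (query index, retrieved index, distance) triplet per neighbour; these become the rows of predicted_des_L2_dis.npz used by the PR and top-N evaluations.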
--------------------------------------------------------------------------------
/test/test_kitti00_topN.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Junyi Ma, Xieyuanli Chen, and Jun Zhang
3 | # This file is covered by the LICENSE file in the root of the project OverlapTransformer:
4 | # https://github.com/haomo-ai/OverlapTransformer/
5 | # Brief: calculate Recall@N using the prediction files generated by test_kitti00_prepare.py
6 |
7 |
8 | import os
9 | import sys
10 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
11 | if p not in sys.path:
12 | sys.path.append(p)
13 |
14 | import copy
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import yaml
18 | from sklearn import metrics
19 |
20 |
21 |
22 | def cal_topN(prediction_file_name, ground_truth_file_name, topn):
23 | precisions = []
24 | recalls = []
25 |
26 |
27 | # loading overlap predictions
28 | des_dists = np.load(prediction_file_name)['arr_0']
29 | des_dists = np.asarray(des_dists, dtype='float32')
30 | des_dists = des_dists.reshape((len(des_dists), 3))
31 |
32 |
33 | # loading ground truth in terms of distance
34 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True')['arr_0']
35 |
36 |
37 | all_have_gt = 0
38 | tps = 0
39 |
40 |
41 | for idx in range(0,len(ground_truth)-1):
42 | gt_idxes = ground_truth[int(idx)]
43 |
44 | if not gt_idxes.any():
45 | continue
46 |
47 | all_have_gt += 1
48 | for t in range(topn):
49 | if des_dists[des_dists[:,0]==int(idx),:][t, 1] in gt_idxes:
50 | tps += 1
51 | break
52 |
53 | recall_topN = tps/all_have_gt
54 | print(recall_topN)
55 |
56 |
57 | return recall_topN
58 |
59 |
60 |
61 |
62 | def test_with_topN(topn, ground_truth_file_name):
63 |
64 | prediction_file_name = dir_name + "predicted_des_L2_dis.npz"
65 | recall_topN = cal_topN(prediction_file_name, ground_truth_file_name, topn)
66 | return recall_topN
67 |
68 |
69 |
70 | if __name__ == "__main__":
71 | # load config ================================================================
72 | config_filename = '../config/config.yml'
73 | config = yaml.safe_load(open(config_filename))
74 | ground_truth_file_name = config["test_config"]["gt_file"]
75 | # ============================================================================
76 | dir_name = "./test_/"
77 | # topn = 46 # for KITTI 00 top1%
78 | topn = 38 # for KITTI 00 top1%
79 | recall_list = []
80 | for i in range(1, topn):
81 | print("top"+str(i)+": ")
82 | rec = test_with_topN(i, ground_truth_file_name)
83 | recall_list.append(rec)
84 | print(recall_list)
85 | np.save(dir_name + "recall_list", np.array(recall_list))
86 |
87 |
88 |
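
A minimal sketch of the Recall@N idea implemented in cal_topN (not part of the repository): a query counts as a hit if any of its top-N retrieved frames is in its ground-truth loop set, and queries without any ground-truth loop are skipped. The candidate lists and ground truth below are made up:

```python
def recall_at_n(ranked_candidates, ground_truth, n):
    hits, queries_with_gt = 0, 0
    for cands, gt in zip(ranked_candidates, ground_truth):
        if not gt:                    # no true loop for this query -> not counted
            continue
        queries_with_gt += 1
        if any(c in gt for c in cands[:n]):
            hits += 1
    return hits / queries_with_gt

cands = [[12, 40, 7], [3, 99, 1], [55, 56, 57]]   # ranked database indices per query
gts = [{40}, set(), {57}]                         # ground-truth loop sets per query
print([recall_at_n(cands, gts, n) for n in (1, 2, 3)])   # [0.0, 0.5, 1.0]
```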
--------------------------------------------------------------------------------
/test/test_nclt_topn.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 |
7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
8 | if p not in sys.path:
9 | sys.path.append(p)
10 |
11 | import numpy as np
12 | import yaml
13 | import time
14 |
15 | rubish = []
16 |
17 |
18 | def load_files(folder):
19 | file_paths = [os.path.join(dp, f) for dp, dn, fn in os.walk(
20 | os.path.expanduser(folder)) for f in fn]
21 | file_paths.sort()
22 | return file_paths
23 |
24 |
25 | def cal_pr_curve(prediction_file_name, ground_truth_file_name, topn):
26 | des_dists = np.load(prediction_file_name)['arr_0']
27 | des_dists = np.asarray(des_dists, dtype='float32')
28 | des_dists = des_dists.reshape((len(des_dists), 3))
29 |
30 | ground_truth = np.load(ground_truth_file_name, allow_pickle='True')
31 |
32 | gt_num = 0
33 | all_num = 0
34 | check_out = 0
35 |
36 | img_paths_query = load_files(query_scan)
37 | print(len(img_paths_query))
38 | sum1 = 0
39 | for idx in range(0, len(img_paths_query), 5):
40 |
41 | gt_idxes = np.array(ground_truth[int(gt_num)])
42 | if gt_idxes.any():
43 | all_num += 1
44 | else:
45 | gt_num += 1
46 | continue
47 |
48 | gt_num += 1
49 |
50 | dist_lists_cur = des_dists[des_dists[:, 0] == idx, :]
51 | idx_sorted = np.argsort(dist_lists_cur[:, -1], axis=-1)
52 |
53 | for i in range(topn):
54 | if int(dist_lists_cur[idx_sorted[i], 1]) in gt_idxes:
55 | check_out += 1
56 | break
57 |
58 | print("top" + str(topn) + " recall: ", check_out / all_num)
59 | return check_out / all_num
60 |
61 |
62 | def main(topn, ground_truth_file_name, dir_name):
63 | prediction_file_name = dir_name + "/predicted_des_L2_dis_bet_traj_forward.npz"
64 |
65 | topn_recall = cal_pr_curve(prediction_file_name, ground_truth_file_name, topn)
66 |
67 | return topn_recall
68 |
69 |
70 | if __name__ == "__main__":
71 | # load config ================================================================
72 | config_filename = '../config/config_nclt.yml'
73 | config = yaml.safe_load(open(config_filename))
74 | ground_truth_file_name = config["file_root"]["gt_file"]
75 | query_scan = config["file_root"]["data_root_folder_test"]
76 | # ============================================================================
77 | topn = 20
78 | recall_sum = 0
79 | recall_list = []
80 |
81 | dir_name = ("./nclt_o1shift2/12.6.15/5_15dis/")
82 |
83 | for i in range(1, topn + 1):
84 | rec = main(i, ground_truth_file_name, dir_name)
85 | recall_sum += rec
86 | if i == 1:
87 | print("AR@1 = ", recall_sum / i)
88 | if i == 5:
89 | print("AR@5 = ", recall_sum / i)
90 | if i == 20:
91 | print("AR@20 = ", recall_sum / i)
92 | recall_list.append(rec)
93 |
94 | print(recall_list)
95 |
96 | np.save(dir_name + "/recall_list", np.array(recall_list))
97 |
--------------------------------------------------------------------------------
/test/test_nclt_topn_prepare.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 |
7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
8 | if p not in sys.path:
9 | sys.path.append(p)
10 | sys.path.append('../tools/')
11 | sys.path.append('../modules/')
12 |
13 | import matplotlib.pyplot as plt
14 | import torch
15 | import yaml
16 | import numpy as np
17 |
18 | from modules.overlap_transformer import featureExtracter
19 | from tools.read_samples_haomo import read_one_need_from_seq
20 | from tools.read_samples_haomo import read_one_need_from_seq_test
21 |
22 | np.set_printoptions(threshold=sys.maxsize)
23 | from tqdm import tqdm
24 | import faiss
25 | from tools.utils.utils import *
26 |
27 |
28 | def shift(tensor, dim, index):
29 | length = tensor.size(dim)
30 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index),
31 | tensor.narrow(dim, 0, index)), dim=dim)
32 | return shifted_tensor
33 |
34 | def unshift(tensor, dim, index):
35 | length = tensor.size(dim)
36 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index),
37 | tensor.narrow(dim, 0, length - index)), dim=dim)
38 | return unshifted_tensor
39 |
40 | class testHandler():
41 | def __init__(self, height=32, width=900, channels=1, norm_layer=None, use_transformer=False,
42 | data_root_folder=None, data_root_folder_test=None, test_weights=None):
43 | super(testHandler, self).__init__()
44 |
45 | self.height = height
46 | self.width = width
47 | self.channels = channels
48 | self.norm_layer = norm_layer
49 | self.use_transformer = use_transformer
50 | self.data_root_folder = data_root_folder
51 | self.data_root_folder_test = data_root_folder_test
52 |
53 | self.amodel = featureExtracter(channels=self.channels, use_transformer=self.use_transformer)
54 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55 | self.amodel.to(self.device)
56 | self.parameters = self.amodel.parameters()
57 |
58 | self.test_weights = test_weights
59 |
60 | def eval(self):
61 |
62 | print("Resuming From ", self.test_weights)
63 | checkpoint = torch.load(self.test_weights)
64 | self.amodel.load_state_dict(checkpoint['state_dict'])
65 |
66 | range_image_paths_database = load_files(self.data_root_folder)
67 | print("scan number of database: ", len(range_image_paths_database))
68 |
69 | des_list = np.zeros((len(range_image_paths_database), 512)) # for forward driving
70 | for j in tqdm(range(0, len(range_image_paths_database))):
71 | f1_index = str(j).zfill(6)
72 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index)
73 | # current_batch_double = torch.cat((current_batch, current_batch), dim=-1)
74 | # current_batch_inv = current_batch_double[:, :, :, 450:1350]
75 | # print(current_batch.shape)
76 | # current_batch = torch.cat((current_batch, current_batch_inv), dim=0)
77 |
78 |
79 |
80 |
81 | # print(current_batch.shape)
82 | self.amodel.eval()
83 | #current_batch_des = self.amodel(current_batch)
84 | index222 = int(torch.rand(1).item() * 900)
85 | input_batch_shift = shift(current_batch, 3, index222)
86 |
87 | global_des = self.amodel(current_batch)
88 | global_des_shift = self.amodel(input_batch_shift)
89 | global_des_add = torch.cat((global_des, global_des_shift), dim=1)
90 | des_list[(j), :] = global_des_add[0, :].cpu().detach().numpy()
91 |
92 | des_list = des_list.astype('float32')
93 |
94 | nlist = 1
95 | k = 50
96 | d = 512
97 | quantizer = faiss.IndexFlatL2(d)
98 |
99 | index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
100 | assert not index.is_trained
101 |
102 | index.train(des_list)
103 | assert index.is_trained
104 | index.add(des_list)
105 | row_list = []
106 |
107 | range_image_paths_query = load_files(self.data_root_folder_test)
108 | print("scan number of query: ", len(range_image_paths_query))
109 |
110 | for i in range(0, len(range_image_paths_query), 5):
111 |
112 | i_index = str(i).zfill(6)
113 | current_batch = read_one_need_from_seq_test(self.data_root_folder_test, i_index) # compute 1 descriptors
114 | # current_batch_double = torch.cat((current_batch, current_batch), dim=-1)
115 | # current_batch_inv = current_batch_double[:, :, :, 450:1350]
116 | # current_batch = torch.cat((current_batch, current_batch_inv), dim=0)
117 | self.amodel.eval()
118 | index111 = int(torch.rand(1).item() * 900)
119 | input_batch_shift = shift(current_batch, 3, index111)
120 |
121 | global_des = self.amodel(current_batch)
122 | global_des_shift = self.amodel(input_batch_shift)
123 | global_des_add = torch.cat((global_des, global_des_shift), dim=1)
124 | #current_batch_des = self.amodel(current_batch) # torch.Size([(1+pos_num+neg_num)), 256])
125 | print()
126 | des_list_current = global_des_add[0, :].cpu().detach().numpy()
127 |
128 | D, I = index.search(des_list_current.reshape(1, -1), k) # actual search
129 |
130 | for j in range(D.shape[1]):
131 | one_row = np.zeros((1, 3))
132 | one_row[:, 0] = i
133 | one_row[:, 1] = I[:, j]
134 | one_row[:, 2] = D[:, j]
135 | row_list.append(one_row)
136 | print("query:" + str(i) + "---->" + "database:" + str(I[:, j]) + " " + str(D[:, j]))
137 |
138 | row_list_arr = np.array(row_list)
139 | dir_name = "./nclt_cvtnet/12.2.5/5_15dis/"
140 | if not os.path.exists(dir_name):
141 | os.mkdir(dir_name)
142 | np.savez_compressed(dir_name + "predicted_des_L2_dis_bet_traj_forward", row_list_arr)
143 |
144 |
145 | if __name__ == '__main__':
146 | # data
147 | # load config ================================================================
148 | config_filename = '../config/config_haomo.yml'
149 | config = yaml.safe_load(open(config_filename))
150 | data_root_folder = config["file_root"]["data_root_folder"]
151 | data_root_folder_test = config["file_root"]["data_root_folder_test1"]
152 | test_weights = config["file_root"]["test_weights"]
153 | # ============================================================================
154 |
155 | test_handler = testHandler(height=32, width=900, channels=1, norm_layer=None, use_transformer=False,
156 | data_root_folder=data_root_folder, data_root_folder_test=data_root_folder_test,
157 | test_weights=test_weights)
158 |
159 | test_handler.eval()
160 |
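
A minimal sketch of what shift/unshift above compute (not part of the repository): a circular shift of the range image along its width (yaw) axis. It is equivalent to torch.roll with a negative/positive offset, and shifting then unshifting restores the input exactly; the tensor and offset below are made up:

```python
import torch

x = torch.arange(12).reshape(1, 1, 1, 12)          # toy (B, C, H, W) range image, W = 12
k = 5                                               # yaw offset (drawn randomly in the script)

shifted = torch.roll(x, shifts=-k, dims=3)          # same result as shift(x, 3, k)
restored = torch.roll(shifted, shifts=k, dims=3)    # same result as unshift(shifted, 3, k)

print(shifted[0, 0, 0])            # tensor([ 5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4])
print(torch.equal(restored, x))    # True
```

Concatenating the descriptors of the original and the shifted view, as done above, is intended to make the final descriptor less sensitive to the unknown yaw of the query scan.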
--------------------------------------------------------------------------------
/tools/__pycache__/read_all_sets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_all_sets.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/read_all_sets.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_all_sets.cpython-39.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/read_samples.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/read_samples.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples.cpython-39.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/read_samples_haomo.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/__pycache__/read_samples_haomo.cpython-39.pyc
--------------------------------------------------------------------------------
/tools/name_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def create_folders_in_subfolders(folder_path, new_folder_name):
5 |     # collect all sub-folders under the given folder path
6 | subfolders = [f.path for f in os.scandir(folder_path) if f.is_dir()]
7 |
8 |     # create the new folder in every sub-folder (skip if a folder with the same name already exists)
9 | for subfolder in subfolders:
10 | new_folder_path = os.path.join(subfolder, new_folder_name)
11 | if not os.path.exists(new_folder_path):
12 | os.makedirs(new_folder_path)
13 | print(f"Created new folder: {new_folder_path}")
14 | else:
15 | print(f"Folder already exists, skipping: {new_folder_path}")
16 |
17 |
18 | # specify the folder path
19 | folder_path = "/home/lenovo/xqc/OverlapTransformer-master/data_root_folder"
20 |
21 | # specify the name of the new folder
22 | new_created_folder_name = "depth_map"
23 |
24 |
25 | def rename_folders_in_subfolders(folder_path, old_folder_name, new_folder_name):
27 |     # collect all sub-folders under the given folder path
27 | subfolders = [f.path for f in os.scandir(folder_path) if f.is_dir()]
28 |
29 |     # iterate over each sub-folder; if it contains a folder named old_folder_name, rename it to new_folder_name
30 | for subfolder in subfolders:
31 | folder_list = [f.name for f in os.scandir(subfolder) if f.is_dir()]
32 | if old_folder_name in folder_list:
33 | old_folder_path = os.path.join(subfolder, old_folder_name)
34 | new_folder_path = os.path.join(subfolder, new_folder_name)
35 | os.rename(old_folder_path, new_folder_path)
36 | print(f"Renamed folder from {old_folder_path} to {new_folder_path}")
37 |
38 |
39 | # specify the old folder name to be replaced and the new folder name
40 | old_folder_name = "depth_map_50"
41 | new_folder_name = "depth_map"
42 |
43 | # call the function to create new folders
44 | # create_folders_in_subfolders(folder_path, new_created_folder_name)
45 |
46 | # 调用函数重命名文件夹
47 | rename_folders_in_subfolders(folder_path, old_folder_name, new_folder_name)
48 |
--------------------------------------------------------------------------------
/tools/read_all_sets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
4 | if p not in sys.path:
5 | sys.path.append(p)
6 |
7 | import numpy as np
8 |
9 | """Use the tools from OverlapNet"""
10 |
11 | def overlap_orientation_npz_file2string_string_nparray(npzfilenames, shuffle=True):
12 |
13 | imgf1_all = []
14 | imgf2_all = []
15 | dir1_all = []
16 | dir2_all = []
17 | overlap_all = []
18 |
19 | for npzfilename in npzfilenames:
20 | h = np.load(npzfilename, allow_pickle=True)
21 |
22 | if len(h.files) == 3:
23 | # old format
24 | imgf1 = np.char.mod('%06d', h[h.files[0]][:, 0]).tolist()
25 | imgf2 = np.char.mod('%06d', h[h.files[0]][:, 1]).tolist()
26 | overlap = h[h.files[0]][:, 2]
27 | orientation = h[h.files[0]][:, 3]
28 | n = len(imgf1)
29 | dir1 = np.array(['' for _ in range(n)]).tolist()
30 | dir2 = np.array(['' for _ in range(n)]).tolist()
31 | else:
32 | imgf1 = np.char.mod('%06d', h['overlaps'][:, 0]).tolist()
33 | imgf2 = np.char.mod('%06d', h['overlaps'][:, 1]).tolist()
34 | overlap = h['overlaps'][:, 2]
35 | dir1 = (h['seq'][:, 0]).tolist()
36 | dir2 = (h['seq'][:, 1]).tolist()
37 |
38 | if shuffle:
39 | shuffled_idx = np.random.permutation(overlap.shape[0])
40 | imgf1 = (np.array(imgf1)[shuffled_idx]).tolist()
41 | imgf2 = (np.array(imgf2)[shuffled_idx]).tolist()
42 | dir1 = (np.array(dir1)[shuffled_idx]).tolist()
43 | dir2 = (np.array(dir2)[shuffled_idx]).tolist()
44 | overlap = overlap[shuffled_idx]
45 |
46 | imgf1_all.extend(imgf1)
47 | imgf2_all.extend(imgf2)
48 | dir1_all.extend(dir1)
49 | dir2_all.extend(dir2)
50 | overlap_all.extend(overlap)
51 |
52 | return (imgf1_all, imgf2_all, dir1_all, dir2_all, np.asarray(overlap_all))
53 |
--------------------------------------------------------------------------------
/tools/read_samples.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 |
7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
8 | if p not in sys.path:
9 | sys.path.append(p)
10 |
11 | import matplotlib.pyplot as plt
12 | import torch
13 | import cv2
14 | import numpy as np
15 |
16 | np.set_printoptions(threshold=sys.maxsize)
17 | from utils.utils import *
18 | import yaml
19 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray
20 |
21 | """
22 | read one needed $file_num range image from sequence $seq_num.
23 | Args:
24 | data_root_folder: dataset root of KITTI.
25 | file_num: the index of the needed scan (zfill 6).
26 | seq_num: the sequence in which the needed scan is (zfill 2).
27 | """
28 |
29 |
30 | def read_one_need_from_seq(data_root_folder, file_num, seq_num):
31 | depth_data = \
32 | np.array(cv2.imread(data_root_folder + seq_num + "/depth_map/" + file_num + ".png",
33 | cv2.IMREAD_GRAYSCALE))
34 |
35 | # depth_data = \
36 | # np.array(np.load(data_root_folder + seq_num + "/depth_map/" + file_num + ".npy",
37 | # ))
38 |
39 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda()
40 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
41 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
42 |
43 | return depth_data_tensor
44 |
45 |
46 | """
47 | read one batch of positive samples and negative samples with respect to $f1_index in sequence $f1_seq.
48 | Args:
49 | data_root_folder: dataset root of KITTI.
50 | f1_index: the index of the needed scan (zfill 6).
51 | f1_seq: the sequence in which the needed scan is (zfill 2).
52 | train_imgf1, train_imgf2, train_dir1, train_dir2: the index dictionary and sequence dictionary following OverlapNet.
53 | train_overlap: overlaps dictionary following OverlapNet.
54 | overlap_thresh: 0.3 following OverlapNet.
55 | """
56 |
57 |
58 | def read_one_batch_pos_neg(data_root_folder, f1_index, f1_seq, train_imgf1, train_imgf2, train_dir1, train_dir2,
59 | train_overlap, overlap_thresh): # without end
60 |
61 | batch_size = 0
62 | for tt in range(len(train_imgf1)):
63 | if f1_index == train_imgf1[tt] and f1_seq == train_dir1[tt]:
64 | batch_size = batch_size + 1
65 |
66 | sample_batch = torch.from_numpy(np.zeros((batch_size, 1, 32, 900))).type(torch.FloatTensor).cuda()
67 | sample_truth = torch.from_numpy(np.zeros((batch_size, 1))).type(torch.FloatTensor).cuda()
68 |
69 | pos_idx = 0
70 | neg_idx = 0
71 | pos_num = 0
72 | neg_num = 0
73 |
74 | for j in range(len(train_imgf1)):
75 | pos_flag = False
76 | if f1_index == train_imgf1[j] and f1_seq == train_dir1[j]:
77 | if train_overlap[j] > overlap_thresh:
78 | pos_num = pos_num + 1
79 | pos_flag = True
80 | else:
81 | neg_num = neg_num + 1
82 |
83 | depth_data_r = \
84 | np.array(cv2.imread(data_root_folder + train_dir2[j] + "/depth_map/" + train_imgf2[j] + ".png",
85 | cv2.IMREAD_GRAYSCALE))
86 |
87 | # depth_data_r = \
88 | # np.array(np.load(data_root_folder + train_dir2[j] + "/depth_map/" + train_imgf2[j] + ".npy",
89 | # ))
90 |
91 | depth_data_tensor_r = torch.from_numpy(depth_data_r).type(torch.FloatTensor).cuda()
92 | depth_data_tensor_r = torch.unsqueeze(depth_data_tensor_r, dim=0)
93 |
94 | if pos_flag:
95 | sample_batch[pos_idx, :, :, :] = depth_data_tensor_r
96 | sample_truth[pos_idx, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda()
97 | pos_idx = pos_idx + 1
98 | else:
99 | sample_batch[batch_size - neg_idx - 1, :, :, :] = depth_data_tensor_r
100 | sample_truth[batch_size - neg_idx - 1, :] = torch.from_numpy(np.array(train_overlap[j])).type(
101 | torch.FloatTensor).cuda()
102 | neg_idx = neg_idx + 1
103 |
104 | return sample_batch, sample_truth, pos_num, neg_num
105 |
106 |
107 | if __name__ == '__main__':
108 | # load config ================================================================
109 | config_filename = '../config/config.yml'
110 | config = yaml.safe_load(open(config_filename))
111 | seqs_root = config["data_root"]["data_root_folder"]
112 | # ============================================================================
113 |
114 | seq = "08"
115 | cur_frame_idx = "000887"
116 | current_frame = read_one_need_from_seq(seqs_root, cur_frame_idx, seq)
117 |
118 | traindata_npzfiles = [os.path.join(seqs_root, seq, 'overlaps/train_set.npz')]
119 | (train_imgf1, train_imgf2, train_dir1, train_dir2, train_overlap) = \
120 | overlap_orientation_npz_file2string_string_nparray(traindata_npzfiles)
121 | reference_frames, reference_gts, pos_num, neg_num = read_one_batch_pos_neg(seqs_root, cur_frame_idx, seq,
122 | train_imgf1, train_imgf2, train_dir1,
123 | train_dir2, train_overlap, 0.3)
124 |
125 | # visualization
126 | print("the size of current_frame: ", current_frame.size())
127 | plt.figure(figsize=(15, 3))
128 | plt.title("One sampled range image from KITTI sequence " + seq + ": " + cur_frame_idx + ".bin")
129 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :])
130 | plt.show()
131 |
132 | print("the size of reference_frames: ", reference_frames.size())
133 | vis_idx = 5 # show the 5th sampled range image in the reference batch
134 | plt.figure(figsize=(15, 3))
135 | plt.suptitle(
136 | "One sampled query-reference from KITTI sequence " + seq + ", Overlap: " + str(reference_gts[vis_idx].item()))
137 | plt.subplot(211)
138 | plt.title("query")
139 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :])
140 | plt.subplot(212)
141 | plt.title("reference")
142 | plt.imshow(reference_frames.cpu().detach().numpy()[vis_idx, 0, :, :])
143 | plt.show()
144 |
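
A minimal sketch of how read_one_batch_pos_neg orders a batch (not part of the repository): reference frames whose overlap exceeds the threshold fill the batch from the front, the remaining ones fill it from the back. The overlap values below are made up:

```python
import numpy as np

overlaps = np.array([0.8, 0.1, 0.5, 0.05, 0.4])   # made-up overlaps for one query
thresh = 0.3

batch = np.empty(len(overlaps))
pos_idx, neg_idx = 0, 0
for ov in overlaps:
    if ov > thresh:                               # positive sample -> front of the batch
        batch[pos_idx] = ov
        pos_idx += 1
    else:                                         # negative sample -> back of the batch
        batch[len(overlaps) - neg_idx - 1] = ov
        neg_idx += 1

print(batch)   # [0.8  0.5  0.4  0.05 0.1 ] -> positives first, negatives filled from the end
```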
--------------------------------------------------------------------------------
/tools/read_samples_haomo.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
7 | if p not in sys.path:
8 | sys.path.append(p)
9 | sys.path.append('../tools/')
10 | sys.path.append('../modules/')
11 |
12 | import torch
13 | import numpy as np
14 | np.set_printoptions(threshold=sys.maxsize)
15 | from utils.utils import *
16 | import yaml
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | def read_one_need_from_seq(data_root_folder, file_num):
21 |
22 | depth_data = np.load(data_root_folder + file_num + ".npy")
23 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda()
24 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
25 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
26 |
27 | return depth_data_tensor
28 |
29 | def read_one_need_from_seq_test(data_root_folder_test, file_num):
30 |
31 | depth_data = np.load(data_root_folder_test + file_num + ".npy")
32 |
33 | depth_data_tensor = torch.from_numpy(depth_data).type(torch.FloatTensor).cuda()
34 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
35 | depth_data_tensor = torch.unsqueeze(depth_data_tensor, dim=0)
36 |
37 | return depth_data_tensor
38 |
39 |
40 |
41 |
42 | def read_one_batch_pos_neg(data_root_folder, f1_index, f1_seq, train_imgf1, train_imgf2, train_dir1, train_dir2, train_overlap, overlap_thresh): # without end
43 |
44 | batch_size = 0
45 | for tt in range(len(train_imgf1)):
46 | if f1_index == train_imgf1[tt] and f1_seq == train_dir1[tt] and (train_overlap[tt]> overlap_thresh or train_overlap[tt]<(overlap_thresh-0.0)): # TODO: You can update the range
47 | batch_size = batch_size + 1
48 |
49 | sample_batch = torch.from_numpy(np.zeros((batch_size, 1, 32, 900))).type(torch.FloatTensor).cuda()
50 | sample_truth = torch.from_numpy(np.zeros((batch_size, 1))).type(torch.FloatTensor).cuda()
51 |
52 | pos_idx = 0
53 | neg_idx = 0
54 | pos_num = 0
55 | neg_num = 0
56 |
57 |
58 | for j in range(len(train_imgf1)):
59 | pos_flag = False
60 | if f1_index == train_imgf1[j] and f1_seq==train_dir1[j]:
61 | if train_overlap[j]> overlap_thresh:
62 | pos_num = pos_num + 1
63 | pos_flag = True
64 | elif train_overlap[j]< overlap_thresh:
65 | neg_num = neg_num + 1
66 | else:
67 | continue
68 |
69 | depth_data_r = np.load(data_root_folder + train_imgf2[j] + ".npy")
70 | depth_data_tensor_r = torch.from_numpy(depth_data_r).type(torch.FloatTensor).cuda()
71 | depth_data_tensor_r = torch.unsqueeze(depth_data_tensor_r, dim=0)
72 |
73 | if pos_flag:
74 | sample_batch[pos_idx,:,:,:] = depth_data_tensor_r
75 | sample_truth[pos_idx, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda()
76 | pos_idx = pos_idx + 1
77 | else:
78 | sample_batch[batch_size-neg_idx-1, :, :, :] = depth_data_tensor_r
79 | sample_truth[batch_size-neg_idx-1, :] = torch.from_numpy(np.array(train_overlap[j])).type(torch.FloatTensor).cuda()
80 | neg_idx = neg_idx + 1
81 |
82 |
83 | return sample_batch, sample_truth, pos_num, neg_num
84 |
85 |
86 |
87 | if __name__ == '__main__':
88 | # load config ================================================================
89 | config_filename = '../config/config_haomo.yml'
90 | config = yaml.safe_load(open(config_filename))
91 | data_root_folder = config["file_root"]["data_root_folder"]
92 | triplets_for_training = config["file_root"]["triplets_for_training"]
93 | training_seqs = config["training_config"]["training_seqs"]
94 | # ============================================================================
95 |
96 | train_set_imgf1_imgf2_overlap = np.load(triplets_for_training)
97 | # print(train_set_imgf1_imgf2_overlap)
98 |
99 | cur_frame_idx = "003430"
100 | current_frame = read_one_need_from_seq(data_root_folder, cur_frame_idx)
101 |
102 | train_imgf1 = train_set_imgf1_imgf2_overlap[:, 0]
103 | train_imgf2 = train_set_imgf1_imgf2_overlap[:, 1]
104 | train_dir1 = np.zeros((len(train_imgf1),)) # to use the same form as KITTI
105 | train_dir2 = np.zeros((len(train_imgf2),))
106 | train_overlap = train_set_imgf1_imgf2_overlap[:, 2].astype(float)
107 | reference_frames, reference_gts, pos_num, neg_num = read_one_batch_pos_neg \
108 | (data_root_folder, cur_frame_idx, 0, train_imgf1, train_imgf2, train_dir1,
109 | train_dir2, train_overlap, 0.3)
110 |
111 |
112 |
113 | # visualization
114 | print("the size of current_frame: ", current_frame.size())
115 | plt.figure(figsize=(15,3))
116 | plt.title("One sampled range image from Haomo dataset: " + cur_frame_idx + ".bin")
117 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :])
118 | plt.show()
119 |
120 | print("the size of reference_frames: ", reference_frames.size())
121 |     vis_idx = 5  # show the 5th sampled range image in the reference batch
122 | plt.figure(figsize=(15,3))
123 | plt.suptitle("One sampled query-reference from Haomo dataset, Overlap: " + str(reference_gts[vis_idx].item()))
124 | plt.subplot(211)
125 | plt.title("query")
126 | plt.imshow(current_frame.cpu().detach().numpy()[0, 0, :, :])
127 | plt.subplot(212)
128 | plt.title("reference")
129 | plt.imshow(reference_frames.cpu().detach().numpy()[vis_idx, 0, :, :])
130 | plt.show()
131 |
132 |
133 |
--------------------------------------------------------------------------------
/tools/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCNU-RISLAB/OverlapMamba/b20add9b83ea87727a993eb316869d271436be78/tools/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/tools/utils/com_overlap_yaw.py:
--------------------------------------------------------------------------------
1 | # #!/usr/bin/env python3
2 | # # Developed by Xieyuanli Chen and Thomas Läbe
3 | # # This file is covered by the LICENSE file in the root of this project.
4 | # # Brief: This script generate the overlap and orientation combined mapping file.
5 |
6 |
7 |
8 | try:
9 |     from utils import *
10 | except ImportError:
11 |     from tools.utils.utils import *  # fall back to the package path when run from the repository root
12 |
13 |
14 | def com_overlap_yaw(scan_paths, poses, frame_idx, leg_output_width=360):
15 | """compute the overlap and yaw ground truth from the ground truth poses,
16 | which is used for OverlapNet training and testing.
17 | Args:
18 | scan_paths: paths of all raw LiDAR scans
19 | poses: ground-truth poses either given by the dataset or generated by SLAM or odometry
20 | frame_idx: the current frame index
21 | Returns:
22 | ground_truth_mapping: the ground truth overlap and yaw used for training OverlapNet,
23 | where each row contains [current_frame_idx, reference_frame_idx, overlap, yaw]
24 | """
25 | # init ground truth overlap and yaw
26 | print('Start to compute ground truth overlap and yaw ...')
27 | overlaps = []
28 | yaw_idxs = []
29 | yaw_resolution = leg_output_width
30 |
31 | # we calculate the ground truth for one given frame only
32 | # generate range projection for the given frame
33 | current_points = load_vertex(scan_paths[frame_idx])
34 | current_range, project_points, _, _ = range_projection(current_points)
35 | visible_points = project_points[current_range > 0]
36 | valid_num = len(visible_points)
37 | current_pose = poses[frame_idx]
38 |
39 | for reference_idx in range(len(scan_paths)):
40 | # generate range projection for the reference frame
41 | reference_pose = poses[reference_idx]
42 | reference_points = load_vertex(scan_paths[reference_idx])
43 | reference_points_world = reference_pose.dot(reference_points.T).T
44 | reference_points_in_current = np.linalg.inv(current_pose).dot(reference_points_world.T).T
45 | reference_range, _, _, _ = range_projection(reference_points_in_current)
46 |
47 | # calculate overlap
48 | overlap = np.count_nonzero(
49 | abs(reference_range[reference_range > 0] - current_range[reference_range > 0]) < 1) / valid_num
50 | overlaps.append(overlap)
51 |
52 | # calculate yaw angle
53 | relative_transform = np.linalg.inv(current_pose).dot(reference_pose)
54 | relative_rotation = relative_transform[:3, :3]
55 | _, _, yaw = euler_angles_from_rotation_matrix(relative_rotation)
56 |
57 |         # discretize the yaw angle and shift 0 degrees to the center to make it easier for the network to learn
58 | yaw_element_idx = int(- (yaw / np.pi) * yaw_resolution // 2 + yaw_resolution // 2)
59 | yaw_idxs.append(yaw_element_idx)
60 |
61 | # print('finished pair id: ', reference_idx)
62 |
63 | # ground truth format: each row contains [current_frame_idx, reference_frame_idx, overlap, yaw]
64 | ground_truth_mapping = np.zeros((len(scan_paths), 4))
65 | ground_truth_mapping[:, 0] = np.ones(len(scan_paths)) * frame_idx
66 | ground_truth_mapping[:, 1] = np.arange(len(scan_paths))
67 | ground_truth_mapping[:, 2] = overlaps
68 | ground_truth_mapping[:, 3] = yaw_idxs
69 | # print(ground_truth_mapping)
70 |
71 | print('Finish generating ground_truth_mapping!')
72 |
73 | return ground_truth_mapping
74 |
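
A minimal numpy sketch of the overlap measure computed above (not part of the repository): the fraction of valid current-scan pixels whose reprojected reference range differs by less than 1 m. The 2 x 4 range images below are synthetic, with -1 marking pixels without a return:

```python
import numpy as np

current_range = np.array([[10.0, 20.0, -1.0, 5.0],
                          [ 8.0,  7.5, 12.0, -1.0]])
reference_range = np.array([[10.4, 25.0, 11.0, 5.2],
                            [-1.0,  7.0, 12.8, 3.0]])

valid_num = np.count_nonzero(current_range > 0)            # pixels the current scan actually sees
close = np.abs(reference_range[reference_range > 0]
               - current_range[reference_range > 0]) < 1   # same 1 m criterion as com_overlap_yaw
overlap = np.count_nonzero(close) / valid_num
print(valid_num, overlap)                                  # 6 0.666...
```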
--------------------------------------------------------------------------------
/tools/utils/gen_depth_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Xieyuanli Chen and Thomas Läbe
3 | # This file is covered by the LICENSE file in the root of the project OverlapNet:
4 | #https://github.com/PRBonn/OverlapNet
5 | # Brief: a script to generate depth data
6 | import os
7 | from utils import load_files
8 | import numpy as np
9 | from utils import range_projection
10 |
11 | import cv2
12 | try:
13 |     from utils import *
14 | except ImportError:
15 |     from tools.utils.utils import *  # fall back to the package path when run from the repository root
16 |
17 | import scipy.linalg as linalg
18 | def rotate_mat( axis, radian):
19 | rot_matrix = linalg.expm(np.cross(np.eye(3), axis / linalg.norm(axis) * radian))
20 | return rot_matrix
21 | # print(type(rot_matrix))
22 |
23 |
24 |
25 | def gen_depth_data(scan_folder, dst_folder, normalize=False):
26 | """ Generate projected range data in the shape of (64, 900, 1).
27 | The input raw data are in the shape of (Num_points, 3).
28 | """
29 | # specify the goal folder
30 | dst_folder = os.path.join(dst_folder, 'depth')
31 | try:
32 | os.stat(dst_folder)
33 | print('generating depth data in: ', dst_folder)
34 | except:
35 | print('creating new depth folder: ', dst_folder)
36 | os.mkdir(dst_folder)
37 |
38 | # load LiDAR scan files
39 | scan_paths = load_files(scan_folder)
40 |
41 | depths = []
42 | axis_x, axis_y, axis_z = [1,0,0], [0,1,0], [0, 0, 1]
43 |
44 | # iterate over all scan files
45 | for idx in range(len(scan_paths)):
46 | # load a point cloud
47 | current_vertex = np.fromfile(scan_paths[idx], dtype=np.float32)
48 | current_vertex = current_vertex.reshape((-1, 4))
49 |
50 | proj_range, _, _, _ = range_projection(current_vertex) # proj_ranges from larger to smaller
51 |
52 | # normalize the image
53 | if normalize:
54 | proj_range = proj_range / np.max(proj_range)
55 |
56 | # generate the destination path
57 | dst_path = os.path.join(dst_folder, str(idx).zfill(6))
58 |
59 | # np.save(dst_path, proj_range)
60 | filename = dst_path + ".png"
61 | cv2.imwrite(filename, proj_range)
62 | print('finished generating depth data at: ', dst_path)
63 |
64 | return depths
65 |
66 |
67 | if __name__ == '__main__':
68 | scan_folder = '/home/lenovo/xqc/datasets/kitti/sequences/11/velodyne'
69 | dst_folder = '/data_root_folder/11'
70 |
71 | depth_data = gen_depth_data(scan_folder, dst_folder)
72 |
--------------------------------------------------------------------------------
/tools/utils/gen_gt_data.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
7 | if p not in sys.path:
8 | sys.path.append(p)
9 |
10 | import time
11 | from com_overlap_yaw import com_overlap_yaw
12 | from utils import *
13 |
14 | # paths of kitti dataset
15 | scan_folder = "/media/mjy/My Passport/kitti_all/dataset/sequences/00/velodyne"
16 | calib_file = "/home/mjy/datasets/kitti/data_odometry_calib/dataset/sequences/00/calib.txt"
17 | # prepare poses of semantic kitti dataset (refined poses)
18 | poses_file = "/media/mjy/Samsung_T5/SemanticKITTI/data_odometry_labels/dataset/sequences/00/poses.txt"
19 |
20 | scan_paths = load_files(scan_folder)
21 | T_cam_velo = load_calib(calib_file)
22 | T_cam_velo = np.asarray(T_cam_velo).reshape((4, 4))
23 | T_velo_cam = np.linalg.inv(T_cam_velo)
24 | poses = load_poses(poses_file)
25 | pose0_inv = np.linalg.inv(poses[0])
26 | poses_new = []
27 | for pose in poses:
28 | poses_new.append(T_velo_cam.dot(pose0_inv).dot(pose).dot(T_cam_velo))
29 | poses = np.array(poses_new)
30 |
31 |
32 | all_rows = []
33 | thresh = 0.3
34 | for i in range(len(scan_paths)):
35 | print(str(i) + " -------------------------------->")
36 | time1 = time.time()
37 | scan_paths_this_frame = []
38 | poses_this_frame = []
39 | scan_paths_this_frame.append(scan_paths[i])
40 | poses_this_frame.append(poses[i])
41 | idx_in_range = []
42 | for idx in range(len(scan_paths)):
43 | if np.linalg.norm(poses[idx, :3, -1] - poses[i, :3, -1]) < 30 and (i-idx) > 100:
44 | scan_paths_this_frame.append(scan_paths[idx])
45 | poses_this_frame.append(poses[idx])
46 | idx_in_range.append(idx)
47 | print("prepared indexes for current laser: ", idx_in_range)
48 |
49 | poses_new_this_frame = np.array(poses_this_frame)
50 | ground_truth_mapping = com_overlap_yaw(scan_paths_this_frame, poses_new_this_frame, frame_idx=0, leg_output_width=360)
51 |
52 | one_row = []
53 | for m in range(1, ground_truth_mapping.shape[0]):
54 | if ground_truth_mapping[m,2] > thresh:
55 | one_row.append(idx_in_range[m-1])
56 | all_rows.append(one_row)
57 | print("gt list for current laser: ", one_row)
58 | time2 = time.time()
59 | print("time: ", time2-time1)
60 |
61 | print(len(all_rows))
62 | all_rows_array = np.array(all_rows)
63 | np.savez_compressed("loop_gt_seq00_0.3overlap_inactive", all_rows_array)
64 |
65 |
66 |
--------------------------------------------------------------------------------
/tools/utils/split_train_val.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Xieyuanli Chen and Thomas Läbe
3 | # This file is covered by the LICENSE file in the root of the project OverlapNet:
4 | #https://github.com/PRBonn/OverlapNet
5 | # Brief: a simple example to split the ground truth data into training and validation parts
6 |
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 |
10 |
11 | def split_train_val(ground_truth_mapping):
12 | """Split the ground truth data into training and validation two parts.
13 | Args:
14 | ground_truth_mapping: the raw ground truth mapping array
15 | Returns:
16 | train_set: data used for training
17 | test_set: data used for validation
18 | """
19 | # set the ratio of validation data
20 | test_size = int(len(ground_truth_mapping) / 10)
21 |
22 | # use sklearn library to split the data
23 | train_set, test_set = train_test_split(ground_truth_mapping, test_size=test_size)
24 |
25 | print('finished generating training data and validation data')
26 |
27 | return train_set, test_set
28 |
29 |
30 | if __name__ == '__main__':
31 | # read from npz file
32 |     ground_truth_file = 'path/to/the/ground-truth/file'
33 | ground_truth_mapping = np.load(ground_truth_file)['arr_0']
34 |
35 | split_train_val(ground_truth_mapping)
36 |
37 |
--------------------------------------------------------------------------------
/tools/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Developed by Xieyuanli Chen and Thomas Läbe
3 | # This file is covered by the LICENSE file in the root of the project OverlapNet:
4 | #https://github.com/PRBonn/OverlapNet
5 | # Brief: some utilities
6 | import os
7 | import math
8 | import numpy as np
9 | import time
10 |
11 |
12 | def load_poses(pose_path):
13 | """ Load ground truth poses (T_w_cam0) from file.
14 | Args:
15 | pose_path: (Complete) filename for the pose file
16 | Returns:
17 | A numpy array of size nx4x4 with n poses as 4x4 transformation
18 | matrices
19 | """
20 | # Read and parse the poses
21 | poses = []
22 | try:
23 | if '.txt' in pose_path:
24 | with open(pose_path, 'r') as f:
25 | lines = f.readlines()
26 | for line in lines:
27 | T_w_cam0 = np.fromstring(line, dtype=float, sep=' ')
28 | T_w_cam0 = T_w_cam0.reshape(3, 4)
29 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1]))
30 | poses.append(T_w_cam0)
31 | else:
32 | poses = np.load(pose_path)['arr_0']
33 |
34 | except FileNotFoundError:
35 |         print('Ground truth poses are not available.')
36 |
37 | return np.array(poses)
38 |
39 |
40 | def load_calib(calib_path):
41 | """ Load calibrations (T_cam_velo) from file.
42 | """
43 | # Read and parse the calibrations
44 | T_cam_velo = []
45 | try:
46 | with open(calib_path, 'r') as f:
47 | lines = f.readlines()
48 | for line in lines:
49 | if 'Tr:' in line:
50 | line = line.replace('Tr:', '')
51 | T_cam_velo = np.fromstring(line, dtype=float, sep=' ')
52 | T_cam_velo = T_cam_velo.reshape(3, 4)
53 | T_cam_velo = np.vstack((T_cam_velo, [0, 0, 0, 1]))
54 |
55 | except FileNotFoundError:
56 |         print('Calibrations are not available.')
57 |
58 | return np.array(T_cam_velo)
59 |
60 |
61 | def range_projection(current_vertex, fov_up=3.0, fov_down=-25.0, proj_H=64, proj_W=900, max_range=80):
62 | """ Project a pointcloud into a spherical projection, range image.
63 | Args:
64 | current_vertex: raw point clouds
65 | Returns:
66 | proj_range: projected range image with depth, each pixel contains the corresponding depth
67 | proj_vertex: each pixel contains the corresponding point (x, y, z, 1)
68 | proj_intensity: each pixel contains the corresponding intensity
69 | proj_idx: each pixel contains the corresponding index of the point in the raw point cloud
70 | """
71 | # laser parameters
72 | fov_up = fov_up / 180.0 * np.pi # field of view up in radians
73 | fov_down = fov_down / 180.0 * np.pi # field of view down in radians
74 | fov = abs(fov_down) + abs(fov_up) # get field of view total in radians
75 |
76 | # get depth of all points
77 | depth = np.linalg.norm(current_vertex[:, :3], 2, axis=1)
78 | current_vertex = current_vertex[(depth > 0) & (depth < max_range)] # get rid of [0, 0, 0] points
79 | depth = depth[(depth > 0) & (depth < max_range)]
80 |
81 | # get scan components
82 | scan_x = current_vertex[:, 0]
83 | scan_y = current_vertex[:, 1]
84 | scan_z = current_vertex[:, 2]
85 | intensity = current_vertex[:, 3]
86 |
87 | # get angles of all points
88 | yaw = -np.arctan2(scan_y, scan_x)
89 | pitch = np.arcsin(scan_z / depth)
90 |
91 | # get projections in image coords
92 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0]
93 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0]
94 |
95 | # scale to image size using angular resolution
96 | proj_x *= proj_W # in [0.0, W]
97 | proj_y *= proj_H # in [0.0, H]
98 |
99 | # round and clamp for use as index
100 | proj_x = np.floor(proj_x)
101 | proj_x = np.minimum(proj_W - 1, proj_x)
102 | proj_x = np.maximum(0, proj_x).astype(np.int32) # in [0,W-1]
103 |
104 | proj_y = np.floor(proj_y)
105 | proj_y = np.minimum(proj_H - 1, proj_y)
106 | proj_y = np.maximum(0, proj_y).astype(np.int32) # in [0,H-1]
107 |
108 | # order in decreasing depth
109 | order = np.argsort(depth)[::-1]
110 | depth = depth[order]
111 | intensity = intensity[order]
112 | proj_y = proj_y[order]
113 | proj_x = proj_x[order]
114 |
115 | scan_x = scan_x[order]
116 | scan_y = scan_y[order]
117 | scan_z = scan_z[order]
118 |
119 | indices = np.arange(depth.shape[0])
120 | indices = indices[order]
121 |
122 | proj_range = np.full((proj_H, proj_W), -1,
123 | dtype=np.float32) # [H,W] range (-1 is no data)
124 | proj_vertex = np.full((proj_H, proj_W, 4), -1,
125 | dtype=np.float32) # [H,W] index (-1 is no data)
126 | proj_idx = np.full((proj_H, proj_W), -1,
127 | dtype=np.int32) # [H,W] index (-1 is no data)
128 | proj_intensity = np.full((proj_H, proj_W), -1,
129 | dtype=np.float32) # [H,W] index (-1 is no data)
130 |
131 | proj_range[proj_y, proj_x] = depth
132 | proj_vertex[proj_y, proj_x] = np.array([scan_x, scan_y, scan_z, np.ones(len(scan_x))]).T
133 | proj_idx[proj_y, proj_x] = indices
134 | proj_intensity[proj_y, proj_x] = intensity
135 |
136 | return proj_range, proj_vertex, proj_intensity, proj_idx
137 |
138 |
139 | def gen_normal_map(current_range, current_vertex, proj_H=64, proj_W=900):  # height 64, width 900
140 | """ Generate a normal image given the range projection of a point cloud.
141 | Args:
142 | current_range: range projection of a point cloud, each pixel contains the corresponding depth
143 | current_vertex: range projection of a point cloud,
144 | each pixel contains the corresponding point (x, y, z, 1)
145 | Returns:
146 | normal_data: each pixel contains the corresponding normal
147 | """
148 |
149 |
150 | normal_data = np.full((proj_H, proj_W, 3), -1, dtype=np.float32)
151 | time_pre1 = time.time()
152 |
153 | # iterate over all pixels in the range image
154 | for x in range(proj_W):
155 | for y in range(proj_H - 1):
156 | p = current_vertex[y, x][:3]
157 | depth = current_range[y, x]
158 |
159 | if depth > 0:
160 | wrap_x = wrap(x + 1, proj_W)
161 | u = current_vertex[y, wrap_x][:3]
162 | u_depth = current_range[y, wrap_x]
163 | if u_depth <= 0:
164 | continue
165 |
166 | v = current_vertex[y + 1, x][:3]
167 | v_depth = current_range[y + 1, x]
168 | if v_depth <= 0:
169 | continue
170 |
171 | u_norm = (u - p) / np.linalg.norm(u - p)
172 | v_norm = (v - p) / np.linalg.norm(v - p)
173 |
174 | w = np.cross(v_norm, u_norm)
175 | norm = np.linalg.norm(w)
176 | if norm > 0:
177 | normal = w / norm
178 | normal_data[y, x] = normal
179 | time_pre2 = time.time()
180 | print("gen normal time ", time_pre2 - time_pre1)
181 | return normal_data
182 |
183 |
184 | def wrap(x, dim):
185 | """ Wrap the boarder of the range image.
186 | """
187 | value = x
188 | if value >= dim:
189 | value = (value - dim)
190 | if value < 0:
191 | value = (value + dim)
192 | return value
193 |
194 |
195 | def euler_angles_from_rotation_matrix(R):
196 | """ From the paper by Gregory G. Slabaugh,
197 | Computing Euler angles from a rotation matrix
198 | psi, theta, phi = roll pitch yaw (x, y, z)
199 | Args:
200 | R: rotation matrix, a 3x3 numpy array
201 | Returns:
202 | a tuple with the 3 values psi, theta, phi in radians
203 | """
204 |
205 | def isclose(x, y, rtol=1.e-5, atol=1.e-8):
206 | return abs(x - y) <= atol + rtol * abs(y)
207 |
208 | phi = 0.0
209 | if isclose(R[2, 0], -1.0):
210 | theta = math.pi / 2.0
211 | psi = math.atan2(R[0, 1], R[0, 2])
212 | elif isclose(R[2, 0], 1.0):
213 | theta = -math.pi / 2.0
214 | psi = math.atan2(-R[0, 1], -R[0, 2])
215 | else:
216 | theta = -math.asin(R[2, 0])
217 | cos_theta = math.cos(theta)
218 | psi = math.atan2(R[2, 1] / cos_theta, R[2, 2] / cos_theta)
219 | phi = math.atan2(R[1, 0] / cos_theta, R[0, 0] / cos_theta)
220 | return psi, theta, phi
221 |
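# Quick sanity-check sketch (hypothetical input, not part of the original code):
# the identity rotation should yield roll, pitch and yaw that are all ~0.0
#   psi, theta, phi = euler_angles_from_rotation_matrix(np.eye(3))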
222 |
223 | def load_vertex(scan_path):
224 | """ Load 3D points of a scan. The fileformat is the .bin format used in
225 | the KITTI dataset.
226 | Args:
227 | scan_path: the (full) filename of the scan file
228 | Returns:
229 | An nx4 numpy array of homogeneous points (x, y, z, 1).
230 | """
231 | current_vertex = np.fromfile(scan_path, dtype=np.float32)
232 | current_vertex = current_vertex.reshape((-1, 4))
233 | current_points = current_vertex[:, 0:3]
234 | current_vertex = np.ones((current_points.shape[0], current_points.shape[1] + 1))
235 | current_vertex[:, :-1] = current_points
236 | return current_vertex
237 |
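# Usage sketch (hypothetical scan path): load one KITTI scan as homogeneous points
#   scan = load_vertex('/path/to/sequences/00/velodyne/000000.bin')  # shape (n, 4)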
238 |
239 | def load_files(folder):
240 | """ Load all files in a folder and sort.
241 | """
242 | file_paths = [os.path.join(dp, f) for dp, dn, fn in os.walk(
243 | os.path.expanduser(folder)) for f in fn]
244 | file_paths.sort()
245 | return file_paths
246 |
247 |
248 | semantic_mapping = { # bgr
249 | 0: [0, 0, 0], # "unlabeled", and others ignored
250 | 1: [245, 150, 100], # "car"
251 | 2: [245, 230, 100], # "bicycle"
252 | 3: [150, 60, 30], # "motorcycle"
253 | 4: [180, 30, 80], # "truck"
254 | 5: [255, 0, 0], # "other-vehicle"
255 | 6: [30, 30, 255], # "person"
256 | 7: [200, 40, 255], # "bicyclist"
257 | 8: [90, 30, 150], # "motorcyclist"
258 | 9: [255, 0, 255], # "road"
259 | 10: [255, 150, 255], # "parking"
260 | 11: [75, 0, 75], # "sidewalk"
261 | 12: [75, 0, 175], # "other-ground"
262 | 13: [0, 200, 255], # "building"
263 | 14: [50, 120, 255], # "fence"
264 | 15: [0, 175, 0], # "vegetation"
265 | 16: [0, 60, 135], # "trunk"
266 | 17: [80, 240, 150], # "terrain"
267 | 18: [150, 240, 255], # "pole"
268 | 19: [0, 0, 255] # "traffic-sign"
269 | }
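# Usage sketch (hypothetical label array `labels` of shape (H, W) with values in 0..19):
#   color_img = np.array([semantic_mapping[int(l)] for l in labels.flatten()],
#                        dtype=np.uint8).reshape(labels.shape[0], labels.shape[1], 3)  # BGR image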
270 |
--------------------------------------------------------------------------------
/train/training_distributed.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
7 | if p not in sys.path:
8 | sys.path.append(p)
9 | sys.path.append('../tools/')
10 | sys.path.append('../modules/')
11 | import torch
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data.distributed import DistributedSampler
15 |
16 |
17 | import numpy as np
18 | from tensorboardX import SummaryWriter
19 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray
20 | from modules.overlap_transformer import featureExtracter
21 | from tools.read_samples import read_one_batch_pos_neg
22 | from tools.read_samples import read_one_need_from_seq
23 | np.set_printoptions(threshold=sys.maxsize)
24 | import modules.loss as PNV_loss
25 | from tools.utils.utils import *
26 | from valid.valid_seq import validate_seq_faiss
27 | import yaml
28 | import argparse
29 |
30 | def setup(rank, world_size):
31 | os.environ['MASTER_ADDR'] = 'localhost'
32 | os.environ['MASTER_PORT'] = '12355'
33 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
34 |
35 | def cleanup():
36 | dist.destroy_process_group()
37 |
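# Launch sketch (assumption, not documented in this repo): start one process per GPU, e.g. with
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> training_distributed.py
# so that each process receives the --local_rank argument parsed in the __main__ block below.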
38 | class trainHandler():
39 | def __init__(self, rank, world_size, height=64, width=900, channels=5, norm_layer=None, use_transformer=True, use_mamba=False, lr = 0.001,
40 | data_root_folder = None, train_set=None, training_seqs=None):
41 | super(trainHandler, self).__init__()
42 |
43 | self.rank = rank
44 | self.world_size = world_size
45 | setup(rank, world_size)
46 |
47 | self.height = height
48 | self.width = width
49 | self.channels = channels
50 | self.norm_layer = norm_layer
51 | self.use_transformer = use_transformer
52 | self.use_mamba = use_mamba
53 | self.learning_rate = lr
54 | self.data_root_folder = data_root_folder
55 | self.train_set = train_set
56 | self.training_seqs = training_seqs
57 |
58 | self.amodel = featureExtracter(channels=self.channels, use_transformer=self.use_transformer, use_mamba=self.use_mamba)
59 | self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
60 | self.amodel.to(self.device)
61 | self.amodel = DDP(self.amodel, device_ids=[rank])
62 | self.parameters = self.amodel.parameters()
63 | self.optimizer = torch.optim.Adam(self.parameters, self.learning_rate)
64 |
65 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9)
66 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
67 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
68 |
69 | self.traindata_npzfiles = train_set
70 |
71 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
72 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles)
73 |
74 | """change the args for resuming training process"""
75 | self.resume = True
76 | self.save_name = "../weights/train/trained_overlap_transformer24.pth.tar"
77 |
78 | """overlap threshold follows OverlapNet"""
79 | self.overlap_thresh = 0.3
80 |
81 | def train(self):
82 |
83 | epochs = 100
84 |
85 | """resume from the saved model"""
86 | if self.resume:
87 | resume_filename = self.save_name
88 | print("Resuming from ", resume_filename)
89 | checkpoint = torch.load(resume_filename, map_location=self.device)
90 | starting_epoch = checkpoint['epoch']
91 | self.amodel.load_state_dict(checkpoint['state_dict'])
92 | self.optimizer.load_state_dict(checkpoint['optimizer'])
93 | else:
94 | print("Training From Scratch ..." )
95 | starting_epoch = 0
96 |
97 | writer1 = SummaryWriter(comment=f"LR_0.xxxx_{self.rank}")
98 |
99 | for i in range(starting_epoch+1, epochs):
100 | if self.rank == 0:
101 |     (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
102 |         overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True)
103 |     data = [[self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap]]
104 | else:
105 |     data = [None]
106 | # torch.distributed has no send_object/recv_object; broadcast the shuffled lists from rank 0 instead
107 | dist.broadcast_object_list(data, src=0)
108 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap = data[0]
109 |
110 | dist.barrier()
111 |
112 | if self.rank == 0:
113 | print("=======================================================================\n\n\n")
114 | print("training with seq: ", np.unique(np.array(self.train_dir1)))
115 | print("total pairs: ", len(self.train_imgf1))
116 | print("\n\n\n=======================================================================")
117 |
118 | loss_each_epoch = 0
119 | used_num = 0
120 |
121 | used_list_f1 = []
122 | used_list_dir1 = []
123 |
124 | for j in range(len(self.train_imgf1)):
125 | """
126 | check whether the query has already been used for training (continue_flag==True/False).
127 | TODO: More efficient method
128 | """
129 | f1_index = self.train_imgf1[j]
130 | dir1_index = self.train_dir1[j]
131 | continue_flag = False
132 | for iddd in range(len(used_list_f1)):
133 |     if f1_index==used_list_f1[iddd] and dir1_index==used_list_dir1[iddd]:
134 |         continue_flag = True
135 | if not continue_flag:
136 |     used_list_f1.append(f1_index)
137 |     used_list_dir1.append(dir1_index)
138 |
139 | if continue_flag:
140 | continue
141 |
142 | """read one query range image from KITTI sequences"""
143 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index)
144 |
145 | """
146 | read several reference range images from KITTI sequences
147 | to form the positive and negative samples
148 | """
149 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \
150 | (self.data_root_folder,f1_index, dir1_index,
151 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap,
152 | self.overlap_thresh)
153 |
154 | """
155 | balance the numbers of positive and negative samples.
156 | TODO: Update for better training results
157 | """
158 | use_pos_num = 18
159 | use_neg_num = 18
160 | if pos_num >= use_pos_num and neg_num>=use_neg_num:
161 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num+use_neg_num, :, :, :]), dim=0)
162 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num+use_neg_num, :]), dim=0)
163 | pos_num = use_pos_num
164 | neg_num = use_neg_num
165 | elif pos_num >= use_pos_num:
166 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]), dim=0)
167 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0)
168 | pos_num = use_pos_num
169 | elif neg_num >= use_neg_num:
170 | sample_batch = sample_batch[0:pos_num+use_neg_num,:,:,:]
171 | sample_truth = sample_truth[0:pos_num+use_neg_num, :]
172 | neg_num = use_neg_num
173 |
174 | if neg_num == 0:
175 | continue
176 |
177 | input_batch = torch.cat((current_batch, sample_batch), dim=0)
178 |
179 | input_batch.requires_grad_(True)
180 | self.amodel.train()
181 | self.optimizer.zero_grad()
182 |
183 | global_des = self.amodel(input_batch)
184 | o1, o2, o3 = torch.split(
185 | global_des, [1, pos_num, neg_num], dim=0)
186 | MARGIN_1 = 0.5
187 | """
188 | triplet_loss: Lazy for pos
189 | triplet_loss_inv: Lazy for neg
190 | """
191 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False)
192 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True)
193 | loss.backward()
194 | self.optimizer.step()
195 |
196 | if self.rank == 0:
197 | print(str(used_num), loss)
198 |
199 | if torch.isnan(loss):
200 | if self.rank == 0:
201 | print("Something error ...")
202 | print(pos_num)
203 | print(neg_num)
204 |
205 | loss_each_epoch = loss_each_epoch + loss.item()
206 | used_num = used_num + 1
207 |
208 | if self.rank == 0:
209 | print("epoch {} loss {}".format(i, loss_each_epoch/used_num))
210 | print("saving weights ...")
211 | self.scheduler.step()
212 |
213 | """save trained weights and optimizer states"""
214 | self.save_name = "../weights/train/trained_overlap_transformer"+str(i)+".pth.tar"
215 |
216 | torch.save({
217 | 'epoch': i,
218 | 'state_dict': self.amodel.module.state_dict(),
219 | 'optimizer': self.optimizer.state_dict()
220 | },
221 | self.save_name)
222 |
223 | print("Model Saved As " + f'trained_overlap_transformer{self.rank}_' + str(i) + '.pth.tar')
224 |
225 | writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i)
226 |
227 | print("validating ......")
228 | with torch.no_grad():
229 | top1_rate = validate_seq_faiss(self.amodel.module, "02")
230 | writer1.add_scalar("top1_rate", top1_rate, global_step=i)
231 |
232 | cleanup()
233 |
234 | if __name__ == '__main__':
235 | parser = argparse.ArgumentParser()
236 | parser.add_argument("--local_rank", type=int)
237 | args = parser.parse_args()
238 |
239 | # load config ================================================================
240 | config_filename = '../config/config.yml'
241 | config = yaml.safe_load(open(config_filename))
242 | data_root_folder = config["data_root"]["data_root_folder"]
243 | training_seqs = config["training_config"]["training_seqs"]
244 | # ============================================================================
245 |
246 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs]
247 |
248 | torch.cuda.set_device(args.local_rank)
249 |
250 | train_handler = trainHandler(rank=args.local_rank, world_size=torch.cuda.device_count(), height=64, width=900, channels=1, norm_layer=None, use_transformer=False, use_mamba=True, lr=0.000005,
251 | data_root_folder=data_root_folder, train_set=traindata_npzfiles, training_seqs = training_seqs)
252 |
253 | train_handler.train()
254 |
--------------------------------------------------------------------------------
/train/training_overlap_mamba_kitti.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import sys
5 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
6 | if p not in sys.path:
7 | sys.path.append(p)
8 | sys.path.append('../tools/')
9 | sys.path.append('../modules/')
10 | import torch
11 | import numpy as np
12 | from tensorboardX import SummaryWriter
13 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray
14 | from modules.overlap_mamba import featureExtracter
15 | from tools.read_samples import read_one_batch_pos_neg
16 | from tools.read_samples import read_one_need_from_seq
17 | np.set_printoptions(threshold=sys.maxsize)
18 | import modules.loss as PNV_loss
19 | from tools.utils.utils import *
20 | from valid.valid_seq import validate_seq_faiss
21 | import yaml
22 | from datetime import datetime
23 |
24 | def shift(tensor, dim, index):
25 | length = tensor.size(dim)
26 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index),
27 | tensor.narrow(dim, 0, index)), dim=dim)
28 | return shifted_tensor
29 |
30 | def unshift(tensor, dim, index):
31 | length = tensor.size(dim)
32 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index),
33 | tensor.narrow(dim, 0, length - index)), dim=dim)
34 | return unshifted_tensor
35 |
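# Note: shift() applies a circular shift of `index` columns along `dim` (here the
# width / yaw axis of the range image) and unshift() reverses it.
# Minimal sanity-check sketch (hypothetical tensor, not part of the training code):
#   x = torch.arange(24).reshape(1, 1, 2, 12)
#   assert torch.equal(unshift(shift(x, 3, 5), 3, 5), x)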
36 |
37 | class trainHandler():
38 | def __init__(self, height=64, width=900, channels=5, norm_layer=None, lr = 0.001,
39 | data_root_folder = None, train_set=None, training_seqs=None):
40 | super(trainHandler, self).__init__()
41 |
42 | self.height = height
43 | self.width = width
44 | self.channels = channels
45 | self.norm_layer = norm_layer
46 | self.learning_rate = lr
47 | self.data_root_folder = data_root_folder
48 | self.train_set = train_set
49 | self.training_seqs = training_seqs
50 |
51 | self.amodel = featureExtracter(channels=self.channels)
52 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53 | self.amodel.to(self.device)
54 | self.parameters = self.amodel.parameters()
55 | self.optimizer = torch.optim.Adam(self.parameters, self.learning_rate)
56 | #
57 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9)
58 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
59 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
60 |
61 | self.traindata_npzfiles = train_set
62 |
63 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
64 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles)
65 |
66 | """change the args for resuming training process"""
67 | self.resume = False
68 | self.save_name = "../weights/shift_bimamba_sppf_80/trained_overlap_transformer19.pth.tar"
69 | """overlap threshold follows OverlapNet"""
70 | self.overlap_thresh = 0.3
71 |
72 | def train(self):
73 |
74 | epochs = 30
75 |
76 | """resume from the saved model"""
77 | if self.resume:
78 | resume_filename = self.save_name
79 | print("Resuming from ", resume_filename)
80 | checkpoint = torch.load(resume_filename)
81 | starting_epoch = checkpoint['epoch']
82 | self.amodel.load_state_dict(checkpoint['state_dict'])
83 | self.optimizer.load_state_dict(checkpoint['optimizer'])
84 | else:
85 | print("Training From Scratch ..." )
86 | starting_epoch = 0
87 |
88 | writer1 = SummaryWriter(comment="LR_0.xxxx", log_dir="test_device")
89 |
90 | for i in range(starting_epoch+1, epochs):
91 |
92 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
93 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True)
94 |
95 | print("=======================================================================\n\n\n")
96 |
97 | print("training with seq: ", np.unique(np.array(self.train_dir1)))
98 | print("total pairs: ", len(self.train_imgf1))
99 | print("\n\n\n=======================================================================")
100 |
101 | loss_each_epoch = 0
102 | used_num = 0
103 |
104 | used_list_f1 = []
105 | used_list_dir1 = []
106 |
107 | for j in range(len(self.train_imgf1)):
108 | """
109 | check whether the query has already been used for training (continue_flag==True/False).
110 | TODO: More efficient method
111 | """
112 | f1_index = self.train_imgf1[j]
113 | dir1_index = self.train_dir1[j]
114 | continue_flag = False
115 | for iddd in range(len(used_list_f1)):
116 |     if f1_index==used_list_f1[iddd] and dir1_index==used_list_dir1[iddd]:
117 |         continue_flag = True
118 | if not continue_flag:
119 |     used_list_f1.append(f1_index)
120 |     used_list_dir1.append(dir1_index)
121 |
122 | if continue_flag:
123 | continue
124 |
125 | """read one query range image from KITTI sequences"""
126 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index)
127 |
128 | """
129 | read several reference range images from KITTI sequences
130 | to form the positive and negative samples
131 | """
132 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \
133 | (self.data_root_folder,f1_index, dir1_index,
134 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap,
135 | self.overlap_thresh)
136 |
137 | """
138 | balance the numbers of positive and negative samples.
139 | TODO: Update for better training results
140 | """
141 | use_pos_num = 6
142 | use_neg_num = 6
143 | if pos_num >= use_pos_num and neg_num>=use_neg_num:
144 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num+use_neg_num, :, :, :]), dim=0)
145 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num+use_neg_num, :]), dim=0)
146 | pos_num = use_pos_num
147 | neg_num = use_neg_num
148 | elif pos_num >= use_pos_num:
149 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]), dim=0)
150 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0)
151 | pos_num = use_pos_num
152 | elif neg_num >= use_neg_num:
153 | sample_batch = sample_batch[0:pos_num+use_neg_num,:,:,:]
154 | sample_truth = sample_truth[0:pos_num+use_neg_num, :]
155 | neg_num = use_neg_num
156 |
157 | if neg_num == 0:
158 | continue
159 |
160 | input_batch = torch.cat((current_batch, sample_batch), dim=0)
161 |
162 | input_batch.requires_grad_(True)
163 | self.amodel.train()
164 | self.optimizer.zero_grad()
165 |
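# Yaw-shift augmentation: randomly roll the range image along its width (dim 3),
# then concatenate the descriptors of the original and the shifted input.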
166 | index = int(torch.rand(1).item() * 900)
167 | input_batch_shift = shift(input_batch, 3, index)
168 |
169 |
170 | global_des = self.amodel(input_batch)
171 | global_des_shift = self.amodel(input_batch_shift)
172 | global_des_add = torch.cat((global_des, global_des_shift), dim=1)
173 |
174 | o1, o2, o3 = torch.split(
175 | global_des_add, [1, pos_num, neg_num], dim=0)
176 | MARGIN_1 = 0.5
177 | """
178 | triplet_loss: Lazy for pos
179 | triplet_loss_inv: Lazy for neg
180 | """
181 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False)
182 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True)
183 | loss.backward()
184 | self.optimizer.step()
185 |
186 | current_time = datetime.now()
187 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]")
188 |
189 | if used_num % 1000 == 0:
190 | print(formatted_time, str(used_num), loss)
191 |
192 | if torch.isnan(loss):
193 | print("Something error ...")
194 | print(pos_num)
195 | print(neg_num)
196 |
197 | loss_each_epoch = loss_each_epoch + loss.item()
198 | used_num = used_num + 1
199 |
200 |
201 | print("epoch {} loss {}".format(i, loss_each_epoch/used_num))
202 | print("saving weights ...")
203 | self.scheduler.step()
204 |
205 | """save trained weights and optimizer states"""
206 | self.save_name = "../weights/nclt_cvtnet/trained_overlap_transformer"+str(i)+".pth.tar"
207 |
208 | torch.save({
209 | 'epoch': i,
210 | 'state_dict': self.amodel.state_dict(),
211 | 'optimizer': self.optimizer.state_dict()
212 | },
213 | self.save_name)
214 |
215 | print("Model Saved As " + 'trained_overlap_transformer' + str(i) + '.pth.tar')
216 |
217 | writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i)
218 |
219 | """a simple validation with KITTI 02"""
220 | print("validating ......")
221 | with torch.no_grad():
222 | top1_rate = validate_seq_faiss(self.amodel, "02")
223 | writer1.add_scalar("top1_rate", top1_rate, global_step=i)
224 |
225 |
226 | if __name__ == '__main__':
227 | # load config ================================================================
228 | config_filename = '../config/config.yml'
229 | config = yaml.safe_load(open(config_filename))
230 | data_root_folder = config["data_root"]["data_root_folder"]
231 | training_seqs = config["training_config"]["training_seqs"]
232 | # ============================================================================
233 |
234 | # along the lines of OverlapNet
235 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs]
236 |
237 | """
238 | trainHandler handles the training process.
239 | Args:
240 | height: the height of the range image (the beam number for convenience).
241 | width: the width of the range image (900, along the lines of OverlapNet).
242 | channels: 1 for depth only in our work.
243 | norm_layer: None in our work for a better model.
244 | use_transformer: whether to use MHSA (not used by this Mamba-based script).
245 | lr: learning rate, which needs to be fine-tuned during training for the best performance.
246 | data_root_folder: root of KITTI sequences. It's better to follow our file structure.
247 | train_set: traindata_npzfiles (along the lines of OverlapNet).
248 | training_seqs: sequence numbers for training (along the lines of OverlapNet).
249 | """
250 | train_handler = trainHandler(height=64, width=900, channels=1, norm_layer=None, lr=0.000005,
251 | data_root_folder=data_root_folder, train_set=traindata_npzfiles, training_seqs = training_seqs)
252 |
253 | train_handler.train()
254 |
--------------------------------------------------------------------------------
/train/training_overlap_mamba_nclt.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import os
5 | import sys
6 |
7 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
8 | if p not in sys.path:
9 | sys.path.append(p)
10 | sys.path.append('../tools/')
11 | sys.path.append('../modules/')
12 | import torch
13 | import numpy as np
14 | from tensorboardX import SummaryWriter
15 | from tools.read_all_sets import overlap_orientation_npz_file2string_string_nparray
16 | from modules.overlap_mamba import featureExtracter
17 | from tools.read_samples import read_one_batch_pos_neg
18 | from tools.read_samples import read_one_need_from_seq
19 |
20 | np.set_printoptions(threshold=sys.maxsize)
21 | import modules.loss as PNV_loss
22 | from tools.utils.utils import *
23 | from valid.valid_seq import validate_seq_faiss
24 | import yaml
25 | from datetime import datetime
26 |
27 |
28 | def shift(tensor, dim, index):
29 | length = tensor.size(dim)
30 | shifted_tensor = torch.cat((tensor.narrow(dim, index, length - index),
31 | tensor.narrow(dim, 0, index)), dim=dim)
32 | return shifted_tensor
33 |
34 | def unshift(tensor, dim, index):
35 | length = tensor.size(dim)
36 | unshifted_tensor = torch.cat((tensor.narrow(dim, length - index, index),
37 | tensor.narrow(dim, 0, length - index)), dim=dim)
38 | return unshifted_tensor
39 |
40 |
41 | class trainHandler():
42 | def __init__(self, height=32, width=900, channels=5, norm_layer=None, lr=0.001,
43 | data_root_folder=None, train_set=None, training_seqs=None):
44 | super(trainHandler, self).__init__()
45 |
46 | self.height = height
47 | self.width = width
48 | self.channels = channels
49 | self.norm_layer = norm_layer
50 | self.learning_rate = lr
51 | self.data_root_folder = data_root_folder
52 | self.train_set = train_set
53 | self.training_seqs = training_seqs
54 |
55 | self.amodel = featureExtracter(channels=self.channels)
56 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57 | self.amodel.to(self.device)
58 | self.parameters = self.amodel.parameters()
59 | self.optimizer = torch.optim.Adam(self.parameters, self.learning_rate)
60 | #
61 | # self.optimizer = torch.optim.SGD(self.parameters, lr=self.learning_rate, momentum=0.9)
62 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.9)
63 | # self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)
64 |
65 | self.traindata_npzfiles = train_set
66 |
67 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
68 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles)
69 |
70 | """change the args for resuming training process"""
71 | self.resume = False
72 | self.save_name = "/home/robot/下载/OverlapTransformer-master-copy/weights/nclt_noimloss/trained_overlap_transformer4.pth.tar"
73 | """overlap threshold follows OverlapNet"""
74 | self.overlap_thresh = 0.3
75 |
76 | def train(self):
77 |
78 | epochs = 30
79 |
80 | """resume from the saved model"""
81 | if self.resume:
82 | resume_filename = self.save_name
83 | print("Resuming from ", resume_filename)
84 | checkpoint = torch.load(resume_filename)
85 | starting_epoch = checkpoint['epoch']
86 | self.amodel.load_state_dict(checkpoint['state_dict'])
87 | self.optimizer.load_state_dict(checkpoint['optimizer'])
88 | else:
89 | print("Training From Scratch ...")
90 | starting_epoch = 0
91 |
92 | writer1 = SummaryWriter(comment="LR_0.xxxx", log_dir="test_device")
93 |
94 | for i in range(starting_epoch + 1, epochs):
95 |
96 | (self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap) = \
97 | overlap_orientation_npz_file2string_string_nparray(self.traindata_npzfiles, shuffle=True)
98 |
99 | print("=======================================================================\n\n\n")
100 |
101 | print("training with seq: ", np.unique(np.array(self.train_dir1)))
102 | print("total pairs: ", len(self.train_imgf1))
103 | print("\n\n\n=======================================================================")
104 |
105 | loss_each_epoch = 0
106 | used_num = 0
107 |
108 | used_list_f1 = []
109 | used_list_dir1 = []
110 |
111 | for j in range(len(self.train_imgf1)):
112 | """
113 | check whether the query has already been used for training (continue_flag==True/False).
114 | TODO: More efficient method
115 | """
116 | f1_index = self.train_imgf1[j]
117 | dir1_index = self.train_dir1[j]
118 | continue_flag = False
119 | for iddd in range(len(used_list_f1)):
120 |     if f1_index == used_list_f1[iddd] and dir1_index == used_list_dir1[iddd]:
121 |         continue_flag = True
122 | if not continue_flag:
123 |     used_list_f1.append(f1_index)
124 |     used_list_dir1.append(dir1_index)
125 |
126 | if continue_flag:
127 | continue
128 |
129 | """read one query range image from KITTI sequences"""
130 | current_batch = read_one_need_from_seq(self.data_root_folder, f1_index, dir1_index)
131 |
132 | """
133 | read several reference range images from KITTI sequences
134 | to form the positive and negative samples
135 | """
136 | sample_batch, sample_truth, pos_num, neg_num = read_one_batch_pos_neg \
137 | (self.data_root_folder, f1_index, dir1_index,
138 | self.train_imgf1, self.train_imgf2, self.train_dir1, self.train_dir2, self.train_overlap,
139 | self.overlap_thresh)
140 |
141 | """
142 | balance the numbers of positive and negative samples.
143 | TODO: Update for better training results
144 | """
145 | use_pos_num = 6
146 | use_neg_num = 6
147 | if pos_num >= use_pos_num and neg_num >= use_neg_num:
148 | sample_batch = torch.cat(
149 | (sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:pos_num + use_neg_num, :, :, :]),
150 | dim=0)
151 | sample_truth = torch.cat(
152 | (sample_truth[0:use_pos_num, :], sample_truth[pos_num:pos_num + use_neg_num, :]), dim=0)
153 | pos_num = use_pos_num
154 | neg_num = use_neg_num
155 | elif pos_num >= use_pos_num:
156 | sample_batch = torch.cat((sample_batch[0:use_pos_num, :, :, :], sample_batch[pos_num:, :, :, :]),
157 | dim=0)
158 | sample_truth = torch.cat((sample_truth[0:use_pos_num, :], sample_truth[pos_num:, :]), dim=0)
159 | pos_num = use_pos_num
160 | elif neg_num >= use_neg_num:
161 | sample_batch = sample_batch[0:pos_num + use_neg_num, :, :, :]
162 | sample_truth = sample_truth[0:pos_num + use_neg_num, :]
163 | neg_num = use_neg_num
164 |
165 | if neg_num == 0:
166 | continue
167 |
168 | # inputbatch = torch.cat((current_batch, shift_batch1), dim=0)
169 | # inputbatch1 = torch.cat((inputbatch, shift_batch2), dim=0)
170 | # input_batch = torch.cat((inputbatch1, sample_batch), dim=0)
171 |
172 | input_batch = torch.cat((current_batch, sample_batch), dim=0)
173 |
174 | input_batch.requires_grad_(True)
175 | self.amodel.train()
176 | self.optimizer.zero_grad()
177 |
178 | index = int(torch.rand(1).item() * 900)
179 | input_batch_shift = shift(input_batch, 3, index)
180 |
181 | global_des = self.amodel(input_batch)
182 | global_des_shift = self.amodel(input_batch_shift)
183 | global_des_add = torch.cat((global_des, global_des_shift), dim=1)
184 |
185 | o1, o2, o3 = torch.split(
186 | global_des_add, [1, pos_num, neg_num], dim=0)
187 | MARGIN_1 = 0.5
188 | """
189 | triplet_loss: Lazy for pos
190 | triplet_loss_inv: Lazy for neg
191 | """
192 | loss = PNV_loss.triplet_loss(o1, o2, o3, MARGIN_1, lazy=False)
193 | # loss = PNV_loss.triplet_loss_inv(o1, o2, o3, MARGIN_1, lazy=False, use_min=True)
194 | loss.backward()
195 | self.optimizer.step()
196 |
197 | current_time = datetime.now()
198 | formatted_time = current_time.strftime("[%Y-%m-%d %H:%M:%S]")
199 |
200 | if used_num % 1000 == 0:
201 | print(formatted_time, str(used_num), loss)
202 |
203 | if torch.isnan(loss):
204 | print("Something error ...")
205 | print(pos_num)
206 | print(neg_num)
207 |
208 | loss_each_epoch = loss_each_epoch + loss.item()
209 | used_num = used_num + 1
210 |
211 | print("epoch {} loss {}".format(i, loss_each_epoch / used_num))
212 | print("saving weights ...")
213 | self.scheduler.step()
214 |
215 | """save trained weights and optimizer states"""
216 | self.save_name = "../weights/nclt_cvtnet/trained_overlap_transformer" + str(i) + ".pth.tar"
217 |
218 | torch.save({
219 | 'epoch': i,
220 | 'state_dict': self.amodel.state_dict(),
221 | 'optimizer': self.optimizer.state_dict()
222 | },
223 | self.save_name)
224 |
225 | print("Model Saved As " + 'trained_overlap_mamba' + str(i) + '.pth.tar')
226 |
227 | writer1.add_scalar("loss", loss_each_epoch / used_num, global_step=i)
228 |
229 | """a simple validation with KITTI 02"""
230 | print("validating ......")
231 | # with torch.no_grad():
232 | # top1_rate = validate_seq_faiss(self.amodel, "02")
233 | # writer1.add_scalar("top1_rate", top1_rate, global_step=i)
234 |
235 |
236 | if __name__ == '__main__':
237 | # load config ================================================================
238 | config_filename = '../config/config.yml'
239 | config = yaml.safe_load(open(config_filename))
240 | data_root_folder = config["data_root"]["data_root_folder"]
241 | training_seqs = config["training_config"]["training_seqs"]
242 | # ============================================================================
243 |
244 | # along the lines of OverlapNet
245 | traindata_npzfiles = [os.path.join(data_root_folder, seq, 'overlaps/train_set.npz') for seq in training_seqs]
246 |
247 | """
248 | trainHandler handles the training process.
249 | Args:
250 | height: the height of the range image (the beam number for convenience).
251 | width: the width of the range image (900, along the lines of OverlapNet).
252 | channels: 1 for depth only in our work.
253 | norm_layer: None in our work for a better model.
254 | use_transformer: whether to use MHSA (not used by this Mamba-based script).
255 | lr: learning rate, which needs to be fine-tuned during training for the best performance.
256 | data_root_folder: root of KITTI sequences. It's better to follow our file structure.
257 | train_set: traindata_npzfiles (along the lines of OverlapNet).
258 | training_seqs: sequence numbers for training (along the lines of OverlapNet).
259 | """
260 | train_handler = trainHandler(height=32, width=900, channels=1, norm_layer=None, lr=0.000005,
261 | data_root_folder=data_root_folder, train_set=traindata_npzfiles,
262 | training_seqs=training_seqs)
263 |
264 | train_handler.train()
265 |
--------------------------------------------------------------------------------