├── .gitignore ├── README.md ├── app.py ├── datasets_preprocess ├── preprocess_ase.py ├── preprocess_co3d.py └── preprocess_scannetpp.py ├── docs ├── data_preprocess.md └── recon_tips.md ├── evaluation ├── eval_recon.py └── process_gt.py ├── media ├── gradio.png ├── gradio_office.gif ├── gradio_office.jpg ├── replica.gif └── wild.gif ├── recon.py ├── requirements.txt ├── requirements_optional.txt ├── scripts ├── demo_replica.sh ├── demo_vis_wild.sh ├── demo_wild.sh ├── eval_replica.sh ├── train_i2p.sh └── train_l2w.sh ├── slam3r ├── __init__.py ├── blocks │ ├── __init__.py │ ├── basic_blocks.py │ └── multiview_blocks.py ├── datasets │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── base_stereo_view_dataset.py │ │ ├── batched_sampler.py │ │ └── easy_dataset.py │ ├── co3d_seq.py │ ├── project_aria_seq.py │ ├── replica_seq.py │ ├── scannetpp_seq.py │ ├── seven_scenes_seq.py │ ├── utils │ │ ├── __init__.py │ │ ├── cropping.py │ │ └── transforms.py │ └── wild_seq.py ├── heads │ ├── __init__.py │ ├── dpt_block.py │ ├── dpt_head.py │ ├── linear_head.py │ └── postprocess.py ├── inference.py ├── losses.py ├── models.py ├── patch_embed.py ├── pos_embed │ ├── __init__.py │ ├── curope │ │ ├── __init__.py │ │ ├── curope.cpp │ │ ├── curope2d.py │ │ ├── kernels.cu │ │ └── setup.py │ └── pos_embed.py ├── utils │ ├── __init__.py │ ├── croco_misc.py │ ├── device.py │ ├── geometry.py │ ├── image.py │ ├── misc.py │ └── recon_utils.py └── viz.py ├── train.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | checkpoints 3 | results 4 | stuffs 5 | recon 6 | visualization 7 | replica_gt 8 | tmp 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 |

9 |

[CVPR 2025 Highlight] SLAM3R: Real-Time Dense Scene Reconstruction from Monocular RGB Videos

10 |

11 | Yuzheng Liu* 12 | · 13 | Siyan Dong* 14 | · 15 | Shuzhe Wang 16 | · 17 | Yingda Yin 18 | · 19 | Yanchao Yang 20 | · 21 | Qingnan Fan 22 | · 23 | Baoquan Chen 24 |

25 |

Paper | Video | Poster

26 | 31 | 32 |
33 |

34 | 35 |
36 | 37 | 38 |
39 | 40 |

41 | SLAM3R is a real-time dense scene reconstruction system that regresses 3D points from video frames using feed-forward neural networks, without explicitly estimating camera parameters. 42 |

43 | 44 | 45 | 46 | ## News 47 | 48 | * **2025-04:** SLAM3R is covered by [机器之心 (Chinese)](https://mp.weixin.qq.com/s/fK5vJwbogcfwoduI9FuQ6w) 49 | 50 | * **2025-04:** 🎉 SLAM3R is selected as a **highlight paper** at CVPR 2025 and a **top-1 paper** at China3DV 2025. 51 | 52 | ## Table of Contents 53 | 54 | - [Installation](#installation) 55 | - [Demo](#demo) 56 | - [Gradio interface](#gradio-interface) 57 | - [Evaluation on the Replica dataset](#Evaluation-on-the-Replica-dataset) 58 | - [Training](#training) 59 | - [Citation](#citation) 60 | - [Acknowledgments](#acknowledgments) 61 | 62 | ## Installation 63 | 64 | 1. Clone SLAM3R 65 | ```bash 66 | git clone https://github.com/PKU-VCL-3DV/SLAM3R.git 67 | cd SLAM3R 68 | ``` 69 | 70 | 2. Prepare the environment 71 | ```bash 72 | conda create -n slam3r python=3.11 cmake=3.14.0 73 | conda activate slam3r 74 | # install PyTorch according to your CUDA version 75 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu118 76 | pip install -r requirements.txt 77 | # optional: install additional packages to support visualization and data preprocessing 78 | pip install -r requirements_optional.txt 79 | ``` 80 | 81 | 3. Optional: Accelerate SLAM3R with XFormers and custom CUDA kernels for RoPE 82 | ```bash 83 | # install XFormers according to your PyTorch version, see https://github.com/facebookresearch/xformers 84 | pip install xformers==0.0.28.post2 85 | # compile CUDA kernels for RoPE 86 | # if the compilation fails, try the proposed solution: https://github.com/CUT3R/CUT3R/issues/7. 87 | cd slam3r/pos_embed/curope/ 88 | python setup.py build_ext --inplace 89 | cd ../../../ 90 | ``` 91 | 92 | 4. Optional: Download the SLAM3R checkpoints for the [Image-to-Points](https://huggingface.co/siyan824/slam3r_i2p) and [Local-to-World](https://huggingface.co/siyan824/slam3r_l2w) models through HuggingFace 93 | ```python 94 | from slam3r.models import Image2PointsModel, Local2WorldModel 95 | Image2PointsModel.from_pretrained('siyan824/slam3r_i2p') 96 | Local2WorldModel.from_pretrained('siyan824/slam3r_l2w') 97 | ``` 98 | The pre-trained model weights will be downloaded automatically when running the demo and evaluation code below. 99 | 100 | 101 | ## Demo 102 | ### Replica dataset 103 | To run our demo on the Replica dataset, download the sample scene [here](https://drive.google.com/file/d/1NmBtJ2A30qEzdwM0kluXJOp2d1Y4cRcO/view?usp=drive_link) and unzip it to `./data/Replica_demo/`. Then run the following command to reconstruct the scene from the video images: 104 | 105 | ```bash 106 | bash scripts/demo_replica.sh 107 | ``` 108 | 109 | The results will be stored at `./results/` by default. 110 | 111 | ### Self-captured outdoor data 112 | We also provide a set of images extracted from an in-the-wild captured video. Download it [here](https://drive.google.com/file/d/1FVLFXgepsqZGkIwg4RdeR5ko_xorKyGt/view?usp=drive_link) and unzip it to `./data/wild/`. 113 | 114 | Set the required parameters in this [script](./scripts/demo_wild.sh), and then run SLAM3R with the following command: 115 | 116 | ```bash 117 | bash scripts/demo_wild.sh 118 | ``` 119 | 120 | When `--save_preds` is set in the script, the per-frame predictions for reconstruction will be saved at `./results/TEST_NAME/preds/`. You can then visualize the incremental reconstruction process with the following command: 121 | 122 | ```bash 123 | bash scripts/demo_vis_wild.sh 124 | ``` 125 | 126 | An Open3D window will appear after running the script.
Press the `space` key to record the adjusted rendering view, then close the window. The code will then render the incremental reconstruction. 127 | 128 | You can run SLAM3R on your self-captured video with the steps above. Here are [some tips](./docs/recon_tips.md) for doing so. 129 | 130 | 131 | ## Gradio interface 132 | We also provide a Gradio interface, where you can upload a directory, a video, or specific images to perform the reconstruction. After setting the reconstruction parameters, you can click the 'Run' button to start the process. Modifying the visualization parameters at the bottom allows you to directly display different visualization results without rerunning the inference. 133 | 134 | The interface can be launched with the following command: 135 | 136 | ```bash 137 | python app.py 138 | ``` 139 | 140 | Here is a demo GIF for the Gradio interface (accelerated). 141 | 142 | 143 | 144 | 145 | ## Evaluation on the Replica dataset 146 | 147 | 1. Download the Replica dataset generated by the authors of iMAP: 148 | ```bash 149 | cd data 150 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip 151 | unzip Replica.zip 152 | rm -rf Replica.zip 153 | ``` 154 | 155 | 2. Obtain the GT pointmaps and valid masks for each frame by running the following command: 156 | ```bash 157 | python evaluation/process_gt.py 158 | ``` 159 | The processed GT will be saved at `./results/gt/replica`. 160 | 161 | 3. Evaluate the reconstruction on the Replica dataset with the following command: 162 | 163 | ```bash 164 | bash ./scripts/eval_replica.sh 165 | ``` 166 | 167 | Both the numerical results and the error heatmaps will be saved in the directory `./results/TEST_NAME/eval/`. 168 | 169 | > [!NOTE] 170 | > Different versions of CUDA, PyTorch, and xformers can lead to slight variations in the predicted point cloud. These differences may be amplified during the alignment process in evaluation. Consequently, the numerical results you obtain might differ from those reported in the paper. However, the average values should remain approximately the same. 171 | 172 | ## Training 173 | 174 | ### Datasets 175 | 176 | We use ScanNet++, Aria Synthetic Environments, and CO3Dv2 to train our models. For data downloading and pre-processing, please refer to [here](./docs/data_preprocess.md). 177 | 178 | ### Pretrained weights 179 | 180 | ```bash 181 | # download the pretrained weights from DUSt3R 182 | mkdir checkpoints 183 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth -P checkpoints/ 184 | ``` 185 | 186 | ### Start training 187 | 188 | ```bash 189 | # train the Image-to-Points model and the retrieval module 190 | bash ./scripts/train_i2p.sh 191 | # train the Local-to-World model 192 | bash ./scripts/train_l2w.sh 193 | ``` 194 | > [!NOTE] 195 | > These training scripts are not strictly equivalent to what was used to train SLAM3R, but they should be close enough.
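After training, checkpoints such as `checkpoint-final.pth` are written to the `--output_dir` set in the scripts (e.g., `checkpoints/i2p/slam3r_i2p` and `checkpoints/slam3r_l2w`). As a quick sanity check of your environment, the released checkpoints can also be loaded directly in Python, mirroring step 4 of the installation; the snippet below is a minimal sketch (the `.to()` / `.eval()` calls are plain PyTorch usage, not a SLAM3R-specific API):

```python
import torch
from slam3r.models import Image2PointsModel, Local2WorldModel

# download (or reuse from the local cache) the released SLAM3R weights from HuggingFace
i2p = Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
l2w = Local2WorldModel.from_pretrained('siyan824/slam3r_l2w')

# both models are regular torch.nn.Module instances, so they can be moved to the GPU
# and switched to evaluation mode before running the demo or evaluation scripts
device = 'cuda' if torch.cuda.is_available() else 'cpu'
i2p = i2p.to(device).eval()
l2w = l2w.to(device).eval()
```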
196 | 197 | 198 | ## Citation 199 | 200 | If you find our work helpful in your research, please consider citing: 201 | ``` 202 | @article{slam3r, 203 | title={SLAM3R: Real-Time Dense Scene Reconstruction from Monocular RGB Videos}, 204 | author={Liu, Yuzheng and Dong, Siyan and Wang, Shuzhe and Yin, Yingda and Yang, Yanchao and Fan, Qingnan and Chen, Baoquan}, 205 | journal={arXiv preprint arXiv:2412.09401}, 206 | year={2024} 207 | } 208 | ``` 209 | 210 | 211 | ## Acknowledgments 212 | 213 | Our implementation is based on several awesome repositories: 214 | 215 | - [Croco](https://github.com/naver/croco) 216 | - [DUSt3R](https://github.com/naver/dust3r) 217 | - [NICER-SLAM](https://github.com/cvg/nicer-slam) 218 | - [Spann3R](https://github.com/HengyiWang/spann3r) 219 | 220 | We thank the respective authors for open-sourcing their code. 221 | 222 | -------------------------------------------------------------------------------- /datasets_preprocess/preprocess_ase.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -------------------------------------------------------- 3 | # Script to pre-process the aria-ase dataset 4 | # Usage: 5 | # 1. Prepare the codebase and environment for the projectaria_tools 6 | # 2. copy this script to the project root directory 7 | # 3. Run the script 8 | # -------------------------------------------------------- 9 | import matplotlib.colors as colors 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import plotly.graph_objects as go 13 | from pathlib import Path 14 | import os 15 | from PIL import Image 16 | from scipy.spatial.transform import Rotation as R 17 | from projectaria_tools.projects import ase 18 | from projectaria_tools.core import data_provider, calibration 19 | from projectaria_tools.core.image import InterpolationMethod 20 | from projects.AriaSyntheticEnvironment.tutorial.code_snippets.readers import read_trajectory_file 21 | import cv2 22 | from tqdm import tqdm 23 | import os, sys, json 24 | import open3d as o3d 25 | import random 26 | 27 | 28 | def save_pointcloud(points_3d_array, rgb ,pcd_name): 29 | # Flatten the instance values array 30 | rgb_values_flat = rgb 31 | 32 | # Check if the number of points matches the number of instance values 33 | assert points_3d_array.shape[0] == rgb_values_flat.shape[0], "The number of points must match the number of instance values" 34 | 35 | # Create an Open3D point cloud object 36 | pcd = o3d.geometry.PointCloud() 37 | 38 | # Assign the 3D points to the point cloud object 39 | pcd.points = o3d.utility.Vector3dVector(points_3d_array) 40 | 41 | # Assign the colors to the point cloud 42 | pcd.colors = o3d.utility.Vector3dVector(rgb_values_flat / 255.0) # Normalize colors to [0, 1] 43 | 44 | # Define the file path where you want to save the point cloud 45 | output_file_path = pcd_name+'.pcd' 46 | 47 | # Save the point cloud in PCD format 48 | o3d.io.write_point_cloud(output_file_path, pcd) 49 | 50 | print(f"Point cloud saved to {output_file_path}") 51 | 52 | 53 | def unproject(camera_params, undistorted_depth,undistorted_rgb): 54 | # Get the height and width of the depth image 55 | height, width = undistorted_depth.shape 56 | 57 | # Generate pixel coordinates 58 | y, x = np.indices((height, width)) 59 | pixel_coords = np.stack((x, y), axis=-1).reshape(-1, 2) 60 | 61 | # Flatten the depth image to create a 1D array of depth values 62 | depth_values_flat = undistorted_depth.flatten() 63 | rgb_values_flat = 
undistorted_rgb.reshape(-1,3) 64 | 65 | # Initialize an array to store 3D points 66 | points_3d = [] 67 | valid_rgb = [] 68 | 69 | for pixel_coord, depth, rgb in zip(pixel_coords, depth_values_flat, rgb_values_flat): 70 | # Format the pixel coordinate for unproject (reshape to [2, 1]) 71 | pixel_coord_reshaped = np.array([[pixel_coord[0]], [pixel_coord[1]]], dtype=np.float64) 72 | 73 | # Unproject the pixel to get the direction vector (ray) 74 | # direction_vector = device.unproject(pixel_coord_reshaped) 75 | X = (pixel_coord_reshaped[0] - camera_params[2]) / camera_params[0] # X = (u - cx) / fx 76 | Y = (pixel_coord_reshaped[1] - camera_params[3]) / camera_params[1] # Y = (v - cy) / fy 77 | direction_vector = np.array([X[0], Y[0], 1],dtype=np.float32) 78 | if direction_vector is not None: 79 | # Replace the z-value of the direction vector with the depth value 80 | # Assuming the direction vector is normalized 81 | direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector) 82 | point_3d = direction_vector_normalized * (depth / 1000) 83 | 84 | # Append the computed 3D point and the corresponding instance 85 | points_3d.append(point_3d.flatten()) 86 | valid_rgb.append(rgb) 87 | 88 | # Convert the list of 3D points to a numpy array 89 | points_3d_array = np.array(points_3d) 90 | points_rgb = np.array(valid_rgb) 91 | return points_3d_array,points_rgb 92 | 93 | def distance_to_depth(K, dist, uv=None): 94 | if uv is None and len(dist.shape) >= 2: 95 | # create mesh grid according to d 96 | uv = np.stack(np.meshgrid(np.arange(dist.shape[1]), np.arange(dist.shape[0])), -1) 97 | uv = uv.reshape(-1, 2) 98 | dist = dist.reshape(-1) 99 | if not isinstance(dist, np.ndarray): 100 | import torch 101 | uv = torch.from_numpy(uv).to(dist) 102 | if isinstance(dist, np.ndarray): 103 | # z * np.sqrt(x_temp**2+y_temp**2+z_temp**2) = dist 104 | uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1) 105 | uvh = uvh.T # N, 3 106 | temp_point = np.linalg.inv(K) @ uvh # 3, N 107 | temp_point = temp_point.T # N, 3 108 | z = dist / np.linalg.norm(temp_point, axis=1) 109 | else: 110 | uvh = torch.cat([uv, torch.ones(len(uv), 1).to(uv)], -1) 111 | temp_point = torch.inverse(K) @ uvh 112 | z = dist / torch.linalg.norm(temp_point, dim=1) 113 | return z 114 | 115 | def transform_3d_points(transform, points): 116 | N = len(points) 117 | points_h = np.concatenate([points, np.ones((N, 1))], axis=1) 118 | transformed_points_h = (transform @ points_h.T).T 119 | transformed_points = transformed_points_h[:, :-1] 120 | return transformed_points 121 | 122 | 123 | def aria_export_to_scannet(scene_id, seed): 124 | random.seed(int(seed + scene_id)) 125 | src_folder = Path("ase_raw/"+str(scene_id)) 126 | trgt_folder = Path("ase_processed/"+str(scene_id)) 127 | trgt_folder.mkdir(parents=True, exist_ok=True) 128 | SCENE_ID = src_folder.stem 129 | print("SCENE_ID:", SCENE_ID) 130 | 131 | scene_max_depth = 0 132 | scene_min_depth = np.inf 133 | Path(trgt_folder, "intrinsic").mkdir(exist_ok=True) 134 | Path(trgt_folder, "pose").mkdir(exist_ok=True) 135 | Path(trgt_folder, "depth").mkdir(exist_ok=True) 136 | Path(trgt_folder, "color").mkdir(exist_ok=True) 137 | 138 | rgb_dir = src_folder / "rgb" 139 | depth_dir = src_folder / "depth" 140 | # Load camera calibration 141 | device = ase.get_ase_rgb_calibration() 142 | # Load the trajectory using read_trajectory_file() 143 | trajectory_path = src_folder / "trajectory.csv" 144 | trajectory = read_trajectory_file(trajectory_path) 145 | all_points_3d = [] 146 | all_rgb = [] 
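    # The loop below walks over every RGB/depth frame in the sequence: it undistorts both images to a
    # 512x512 linear (pinhole) camera model, rotates them 90 degrees clockwise, and writes the color image,
    # depth map, camera-to-world pose and pinhole intrinsics in a ScanNet-style folder layout.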
147 | num_frames = len(list(rgb_dir.glob("*.jpg"))) 148 | # Path('./debug').mkdir(exist_ok=True) 149 | for frame_idx in tqdm(range(num_frames)): 150 | frame_id = str(frame_idx).zfill(7) 151 | rgb_path = rgb_dir / f"vignette{frame_id}.jpg" 152 | depth_path = depth_dir / f"depth{frame_id}.png" 153 | depth = Image.open(depth_path) # uint16 154 | rgb = cv2.imread(str(rgb_path), cv2.IMREAD_UNCHANGED) 155 | depth = np.array(depth) 156 | scene_min_depth = min(depth.min(), scene_min_depth) 157 | inf_value = np.iinfo(np.array(depth).dtype).max 158 | depth[depth == inf_value] = 0 # consider it as invalid, inplace with 0 159 | T_world_from_device = trajectory["Ts_world_from_device"][frame_idx] # camera-to-world 160 | assert device.get_image_size()[0] == 704 161 | # https://facebookresearch.github.io/projectaria_tools/docs/data_utilities/advanced_code_snippets/image_utilities 162 | focal_length = device.get_focal_lengths()[0] 163 | pinhole = calibration.get_linear_camera_calibration( 164 | 512, 165 | 512, 166 | focal_length, 167 | "camera-rgb", 168 | device.get_transform_device_camera() # important to get correct transformation matrix in pinhole_cw90 169 | ) 170 | # distort image 171 | rectified_rgb = calibration.distort_by_calibration(np.array(rgb), pinhole, device, InterpolationMethod.BILINEAR) 172 | # raw_image = np.array(depth) # Will not work 173 | depth = np.array(depth).astype(np.float32) # WILL WORK 174 | rectified_depth = calibration.distort_by_calibration(depth, pinhole, device) 175 | 176 | rotated_image = np.rot90(rectified_rgb, k=3) 177 | rotated_depth = np.rot90(rectified_depth, k=3) 178 | 179 | cv2.imwrite(str(Path(trgt_folder, "color", f"{frame_id}.jpg")), rotated_image) 180 | # # TODO: check this 181 | # plt.imsave(Path(f"./debug/debug_undistort_{frame_id}.png"), np.uint16(rotated_depth), cmap="plasma") 182 | # Get rotated image calibration 183 | pinhole_cw90 = calibration.rotate_camera_calib_cw90deg(pinhole) 184 | principal = pinhole_cw90.get_principal_point() 185 | cx, cy = principal[0], principal[1] 186 | focal_lengths = pinhole_cw90.get_focal_lengths() 187 | fx, fy = focal_lengths 188 | K = np.array([ # camera-to-pixel 189 | [fx, 0, cx], 190 | [0, fy, cy], 191 | [0, 0, 1.0]]) 192 | 193 | c2w = T_world_from_device 194 | c2w_rotation = pinhole_cw90.get_transform_device_camera().to_matrix() 195 | c2w_final = c2w @ c2w_rotation # right-matmul! 
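        # c2w_final is the camera-to-world pose of the rectified, rotated RGB image:
        # c2w maps device coordinates to world coordinates (T_world_from_device), and c2w_rotation
        # (the transform returned by pinhole_cw90.get_transform_device_camera()) maps the rotated
        # camera frame to the device frame, hence the right-multiplication above.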
196 | cam2world = c2w_final 197 | 198 | # save depth 199 | rotated_depth = np.uint16(rotated_depth) 200 | depth_image = Image.fromarray(rotated_depth, mode='I;16') 201 | depth_image.save(str(Path(trgt_folder, "depth", f"{frame_id}.png"))) 202 | # for debug; load depth and convert to pointcloud 203 | # depth_image = np.array(Image.open(str(Path(trgt_folder, "depth", f"{frame_id}.png"))), dtype=np.uint16) 204 | # points_3d_array, points_rgb = unproject((fx, fy, cx, cy), depth_image, rotated_image) 205 | # points_3d_world = transform_3d_points(cam2world, points_3d_array) 206 | # all_points_3d.append(points_3d_world) 207 | # all_rgb.append(points_rgb) 208 | # distance-to-depth 209 | # rotated_depth = distance_to_depth(K, rotated_depth).reshape((rotated_depth.shape[0], rotated_depth.shape[1]))#.reshape((dpt.shape[0], dpt.shape[1])) 210 | 211 | Path(trgt_folder, "intrinsic", "intrinsic_color.txt").write_text(f"""{K[0][0]} {K[0][1]} {K[0][2]} 0.00\n{K[1][0]} {K[1][1]} {K[1][2]} 0.00\n{K[2][0]} {K[2][1]} {K[2][2]} 0.00\n0.00 0.00 0.00 1.00""") 212 | Path(trgt_folder, "pose", f"{frame_id}.txt").write_text(f"""{cam2world[0, 0]} {cam2world[0, 1]} {cam2world[0, 2]} {cam2world[0, 3]}\n{cam2world[1, 0]} {cam2world[1, 1]} {cam2world[1, 2]} {cam2world[1, 3]}\n{cam2world[2, 0]} {cam2world[2, 1]} {cam2world[2, 2]} {cam2world[2, 3]}\n0.00 0.00 0.00 1.00""") 213 | 214 | 215 | 216 | if __name__ == "__main__": 217 | seed = 42 218 | for scene_id in tqdm(range(0, 500)): 219 | aria_export_to_scannet(scene_id=scene_id, seed = seed) 220 | 221 | 222 | -------------------------------------------------------------------------------- /docs/data_preprocess.md: -------------------------------------------------------------------------------- 1 | ### ScanNet++ 2 | 3 | 1. Download the [dataset](https://kaldir.vc.in.tum.de/scannetpp/), extract RGB frames and masks from the iPhone data following the [official instruction](https://github.com/scannetpp/scannetpp). 4 | 5 | 2. Preprocess the data with the following command: 6 | 7 | ```bash 8 | python datasets_preprocess/preprocess_scannetpp.py \ 9 | --scannetpp_dir $SCANNETPP_DATA_ROOT\ 10 | --output_dir data/scannetpp_processed 11 | ``` 12 | 13 | the processed data will be saved at `./data/scannetpp_processed` 14 | 15 | > We only use ScanNetpp-V1 (280 scenes in total) to train and validate our SLAM3R models now. ScanNetpp-V2 (906 scenes) is available for potential use, but you may need to modify the scripts for certain scenes in it. 16 | 17 | ### Aria Synthetic Environments 18 | 19 | For more details, please refer to the [official website](https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_synthetic_environments_dataset) 20 | 21 | 1. Prepare the codebase and environment 22 | ```bash 23 | mkdir data/projectaria 24 | cd data/projectaria 25 | git clone https://github.com/facebookresearch/projectaria_tools.git -b 1.5.7 26 | cd - 27 | conda create -n aria python=3.10 28 | conda activate aria 29 | pip install projectaria-tools'[all]' opencv-python open3d 30 | ``` 31 | 32 | 2. Get the download-urls file [here](https://www.projectaria.com/datasets/ase/) and place it under .`/data/projectaria/projectaria_tools`. 
Then download the ASE dataset: 33 | ```bash 34 | cd ./data/projectaria/projectaria_tools 35 | python projects/AriaSyntheticEnvironment/aria_synthetic_environments_downloader.py \ 36 | --set train \ 37 | --scene-ids 0-499 \ 38 | --unzip True \ 39 | --cdn-file aria_synthetic_environments_dataset_download_urls.json \ 40 | --output-dir $SLAM3R_DIR/data/projectaria/ase_raw 41 | ``` 42 | 43 | > We only use the first 500 scenes to train and validate our SLAM3R models now. You can leverage more scenes depending on your resources. 44 | 45 | 4. Preprocess the data. 46 | ```bash 47 | cp ./datasets_preprocess/preprocess_ase.py ./data/projectaria/projectaria_tools/ 48 | cd ./data/projectaria 49 | python projectaria_tools/preprocess_ase.py 50 | ``` 51 | The processed data will be saved at `./data/projectaria/ase_processed` 52 | 53 | 54 | ### CO3Dv2 55 | 1. Download the [dataset](https://github.com/facebookresearch/co3d) 56 | 57 | 2. Preprocess the data with the same script as in [DUSt3R](https://github.com/naver/dust3r?tab=readme-ov-file), and place the processed data at `./data/co3d_processed`. The data consists of 41 categories for training and 10 categories for validation. 58 | 59 | -------------------------------------------------------------------------------- /docs/recon_tips.md: -------------------------------------------------------------------------------- 1 | # Tips for running SLAM3R on self-captured data 2 | 3 | 4 | ## Image name format 5 | 6 | Your images should be consecutive frames (e.g., sampled from a video)with numbered filenames to indicate their sequential order, such as `frame-0031.color.png, output_0414.jpg, ...` with zero padding. 7 | 8 | 9 | ## Guidance on parameters in the [script](../scripts/demo_wild.sh) 10 | 11 | `KEYFRAME_STRIDE`: The selection of keyframe stride is crucial. A small stride may lack sufficient camera motion, while a large stride can cause insufficient overlap, both of which hinder the performance. We offer an automatic setting option, but you can also configure it manually. 12 | 13 | `CONF_THRES_I2P`: It helps to filter out points predicted by the Image-to-Points model before input to the Local-to-World model. Consider increasing this threshold if there are many areas in your image with unpredictable depth (e.g., sky or black-masked regions). 14 | 15 | `INITIAL_WINSIZE`: The maximum window size for initializing scene reconstruction. Note that if initial confidence scores are too low, the window size will automatically reduce, with a minimum size of 3. 16 | 17 | `BUFFER_SIZE`: The maximum size of the buffering set. The buffer size should be large enough to store the required scene frames, but an excessively large buffer can impede retrieval efficiency and degrade retrieval performance. 18 | 19 | `BUFFER_STRATEGY`: We provide two strategies for maintaining the scene frame buffering set under a limited buffer size. "reservoir" is suitable for single-room scenarios, while "fifo" performs more effectively on larger scenes. As discussed in our paper, SLAM3R suffers from drift issues in very large scenes. 20 | 21 | 22 | ## Failure cases 23 | 24 | SLAM3R's performance can degrade when processing sparse views with large viewpoint changes, low overlaps, or motion blur. In such cases, the default window size may include images that don't overlap enough, leading to reduced performance. This can also lead to incorrect retrievals, causing some frames to be registered outside the main reconstruction. 
25 | 26 | Due to limited training data (currently only ScanNet++, Aria Synthetic Environments, and CO3D-v2), SLAM3R cannot process images with unfamiliar camera distortion and has poor performance with wide-angle cameras and panoramas. The system also struggles with dynamic scenes. 27 | -------------------------------------------------------------------------------- /evaluation/process_gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path as osp 3 | import numpy as np 4 | import sys 5 | from tqdm import tqdm 6 | 7 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.abspath(__file__))) 8 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 9 | 10 | from slam3r.datasets import Replica 11 | 12 | 13 | def get_replica_gt_pcd(scene_id, save_dir, sample_stride=20): 14 | os.makedirs(save_dir, exist_ok=True) 15 | H, W = 224, 224 16 | dataset = Replica(resolution=(W,H), scene_name=scene_id, num_views=1, sample_freq=sample_stride) 17 | print(dataset[0][0]['pts3d'].shape) 18 | all_pcd = np.zeros([len(dataset),H,W,3]) 19 | valid_masks = np.ones([len(dataset),H,W], dtype=bool) 20 | for id in tqdm(range(len(dataset))): 21 | view = dataset[id][0] 22 | pcd =view['pts3d'] 23 | valid_masks[id] = view['valid_mask'] 24 | all_pcd[id] = pcd 25 | 26 | np.save(os.path.join(save_dir, f"{scene_id}_pcds.npy"), all_pcd) 27 | np.save(os.path.join(save_dir, f"{scene_id}_valid_masks.npy"), valid_masks) 28 | 29 | 30 | 31 | if __name__ == "__main__": 32 | for scene_id in ['office0', 'office1', 'office2', 'office3', 'office4', 'room0', 'room1', 'room2']: 33 | get_replica_gt_pcd(scene_id, sample_stride=1, save_dir="results/gt/replica") -------------------------------------------------------------------------------- /media/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio.png -------------------------------------------------------------------------------- /media/gradio_office.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio_office.gif -------------------------------------------------------------------------------- /media/gradio_office.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio_office.jpg -------------------------------------------------------------------------------- /media/replica.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/replica.gif -------------------------------------------------------------------------------- /media/wild.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/wild.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | roma 2 | gradio 3 | matplotlib 4 | tqdm 5 | opencv-python 6 | scipy 7 | einops 8 | trimesh 9 | tensorboard 10 | pyglet<2 11 | huggingface-hub[torch]>=0.22 12 | pycuda 13 | 
-------------------------------------------------------------------------------- /requirements_optional.txt: -------------------------------------------------------------------------------- 1 | # for visualization and evaluation 2 | open3d 3 | imageio[ffmpeg] 4 | scikit-image 5 | # for rendering depths in scannetpp 6 | pyrender -------------------------------------------------------------------------------- /scripts/demo_replica.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ###################################################################################### 4 | # set the img_dir below to the directory of the set of images you want to reconstruct 5 | # set the postfix below to the format of the rgb images in the img_dir 6 | ###################################################################################### 7 | TEST_DATASET="Seq_Data(img_dir='data/Replica_demo/room0', postfix='.jpg', \ 8 | img_size=224, silent=False, sample_freq=1, \ 9 | start_idx=0, num_views=-1, start_freq=1, to_tensor=True)" 10 | 11 | ###################################################################################### 12 | # set the parameters for whole scene reconstruction below 13 | # for defination of these parameters, please refer to the recon.py 14 | ###################################################################################### 15 | TEST_NAME="Replica_demo" 16 | KEYFRAME_STRIDE=20 17 | UPDATE_BUFFER_INTV=3 18 | MAX_NUM_REGISTER=10 19 | WIN_R=5 20 | NUM_SCENE_FRAME=10 21 | INITIAL_WINSIZE=5 22 | CONF_THRES_L2W=10 23 | CONF_THRES_I2P=1.5 24 | NUM_POINTS_SAVE=1000000 25 | 26 | GPU_ID=-1 27 | 28 | 29 | python recon.py \ 30 | --test_name $TEST_NAME \ 31 | --dataset "${TEST_DATASET}" \ 32 | --gpu_id $GPU_ID \ 33 | --keyframe_stride $KEYFRAME_STRIDE \ 34 | --win_r $WIN_R \ 35 | --num_scene_frame $NUM_SCENE_FRAME \ 36 | --initial_winsize $INITIAL_WINSIZE \ 37 | --conf_thres_l2w $CONF_THRES_L2W \ 38 | --conf_thres_i2p $CONF_THRES_I2P \ 39 | --num_points_save $NUM_POINTS_SAVE \ 40 | --update_buffer_intv $UPDATE_BUFFER_INTV \ 41 | --max_num_register $MAX_NUM_REGISTER 42 | -------------------------------------------------------------------------------- /scripts/demo_vis_wild.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python visualize.py \ 3 | --vis_dir results/wild_demo \ 4 | --save_stride 1 \ 5 | --enhance_z \ 6 | --conf_thres_l2w 12 \ 7 | # --vis_cam -------------------------------------------------------------------------------- /scripts/demo_wild.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ###################################################################################### 5 | # set the img_dir below to the directory of the set of images you want to reconstruct 6 | # set the postfix below to the format of the rgb images in the img_dir 7 | ###################################################################################### 8 | TEST_DATASET="Seq_Data(img_dir='data/wild/Library', postfix='.png', \ 9 | img_size=224, silent=False, sample_freq=1, \ 10 | start_idx=0, num_views=-1, start_freq=1, to_tensor=True)" 11 | 12 | ###################################################################################### 13 | # set the parameters for whole scene reconstruction below 14 | # for defination of these parameters, please refer to the recon.py 15 | 
###################################################################################### 16 | TEST_NAME="wild_demo" 17 | KEYFRAME_STRIDE=-1 #-1 for auto-adaptive keyframe stride selection 18 | WIN_R=5 19 | MAX_NUM_REGISTER=10 20 | NUM_SCENE_FRAME=10 21 | INITIAL_WINSIZE=5 22 | CONF_THRES_L2W=12 23 | CONF_THRES_I2P=1.5 24 | NUM_POINTS_SAVE=1000000 25 | 26 | UPDATE_BUFFER_INTV=1 27 | BUFFER_SIZE=100 # -1 if size is not limited 28 | BUFFER_STRATEGY="reservoir" # or "fifo" 29 | 30 | KEYFRAME_ADAPT_MIN=1 31 | KEYFRAME_ADAPT_MAX=20 32 | KEYFRAME_ADAPT_STRIDE=1 33 | 34 | GPU_ID=-1 35 | 36 | python recon.py \ 37 | --test_name $TEST_NAME \ 38 | --dataset "${TEST_DATASET}" \ 39 | --gpu_id $GPU_ID \ 40 | --keyframe_stride $KEYFRAME_STRIDE \ 41 | --win_r $WIN_R \ 42 | --num_scene_frame $NUM_SCENE_FRAME \ 43 | --initial_winsize $INITIAL_WINSIZE \ 44 | --conf_thres_l2w $CONF_THRES_L2W \ 45 | --conf_thres_i2p $CONF_THRES_I2P \ 46 | --num_points_save $NUM_POINTS_SAVE \ 47 | --update_buffer_intv $UPDATE_BUFFER_INTV \ 48 | --buffer_size $BUFFER_SIZE \ 49 | --buffer_strategy "${BUFFER_STRATEGY}" \ 50 | --max_num_register $MAX_NUM_REGISTER \ 51 | --keyframe_adapt_min $KEYFRAME_ADAPT_MIN \ 52 | --keyframe_adapt_max $KEYFRAME_ADAPT_MAX \ 53 | --keyframe_adapt_stride $KEYFRAME_ADAPT_STRIDE \ 54 | --save_preds 55 | -------------------------------------------------------------------------------- /scripts/eval_replica.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ###################################################################################### 4 | # set the parameters for whole scene reconstruction below 5 | # for defination of these parameters, please refer to the recon.py 6 | ###################################################################################### 7 | KEYFRAME_STRIDE=20 8 | UPDATE_BUFFER_INTV=3 9 | MAX_NUM_REGISTER=10 10 | WIN_R=5 11 | NUM_SCENE_FRAME=10 12 | INITIAL_WINSIZE=5 13 | CONF_THRES_I2P=1.5 14 | 15 | # the parameter below have nothing to do with the evaluation 16 | NUM_POINTS_SAVE=1000000 17 | CONF_THRES_L2W=10 18 | GPU_ID=-1 19 | 20 | SCENE_NAMES=("office0" "office1" "office2" "office3" "office4" "room0" "room1" "room2") 21 | 22 | for SCENE_NAME in ${SCENE_NAMES[@]}; 23 | do 24 | 25 | TEST_NAME="Replica_${SCENE_NAME}" 26 | 27 | echo "--------Start reconstructing scene ${SCENE_NAME} with test name ${TEST_NAME}--------" 28 | 29 | python recon.py \ 30 | --test_name "${TEST_NAME}" \ 31 | --img_dir "data/Replica/${SCENE_NAME}/results" \ 32 | --gpu_id $GPU_ID \ 33 | --keyframe_stride $KEYFRAME_STRIDE \ 34 | --win_r $WIN_R \ 35 | --num_scene_frame $NUM_SCENE_FRAME \ 36 | --initial_winsize $INITIAL_WINSIZE \ 37 | --conf_thres_l2w $CONF_THRES_L2W \ 38 | --conf_thres_i2p $CONF_THRES_I2P \ 39 | --num_points_save $NUM_POINTS_SAVE \ 40 | --update_buffer_intv $UPDATE_BUFFER_INTV \ 41 | --max_num_register $MAX_NUM_REGISTER \ 42 | --save_for_eval 43 | 44 | echo "--------Start evaluating scene ${SCENE_NAME} with test name ${TEST_NAME}--------" 45 | 46 | python evaluation/eval_recon.py \ 47 | --test_name="${TEST_NAME}" \ 48 | --gt_pcd="results/gt/replica/${SCENE_NAME}_pcds.npy" 49 | 50 | done -------------------------------------------------------------------------------- /scripts/train_i2p.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL="Image2PointsModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), 
conf_mode=('exp', 1, inf), \ 3 | enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \ 4 | mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 11)" 5 | 6 | TRAIN_DATASET="4000 @ ScanNetpp_Seq(filter=True, num_views=11, sample_freq=3, split='train', aug_crop=256, resolution=224, transform=ColorJitter, seed=233) + \ 7 | 2000 @ Aria_Seq(num_views=11, sample_freq=2, split='train', aug_crop=128, resolution=224, transform=ColorJitter, seed=233) + \ 8 | 2000 @ Co3d_Seq(num_views=11, sel_num=3, degree=180, mask_bg='rand', split='train', aug_crop=16, resolution=224, transform=ColorJitter, seed=233)" 9 | 10 | TEST_DATASET="1000 @ ScanNetpp_Seq(filter=True, num_views=11, split='test', resolution=224, seed=666) + \ 11 | 1000 @ Aria_Seq(num_views=11, split='test', resolution=224, seed=666) + \ 12 | 1000 @ Co3d_Seq(num_views=11, sel_num=3, degree=180, mask_bg='rand', split='test', resolution=224, seed=666)" 13 | 14 | # Stage 1: Train the i2p model for pointmap prediction 15 | PRETRAINED="checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth" 16 | TRAIN_OUT_DIR="checkpoints/i2p/slam3r_i2p_stage1" 17 | 18 | torchrun --nproc_per_node=8 train.py \ 19 | --train_dataset "${TRAIN_DATASET}" \ 20 | --test_dataset "${TEST_DATASET}" \ 21 | --model "$MODEL" \ 22 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ 23 | --test_criterion "Jointnorm_Regr3D(L21, norm_mode='avg_dis')" \ 24 | --pretrained $PRETRAINED \ 25 | --pretrained_type "dust3r" \ 26 | --lr 5e-5 --min_lr 5e-7 --warmup_epochs 10 --epochs 100 --batch_size 4 --accum_iter 1 \ 27 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\ 28 | --save_config\ 29 | --freeze "encoder"\ 30 | --loss_func 'i2p' \ 31 | --output_dir $TRAIN_OUT_DIR \ 32 | --ref_id -1 33 | 34 | 35 | # Stage 2: Train a simple mlp to predict the correlation score 36 | PRETRAINED="checkpoints/i2p/slam3r_i2p_stage1/checkpoint-final.pth" 37 | TRAIN_OUT_DIR="checkpoints/i2p/slam3r_i2p" 38 | 39 | torchrun --nproc_per_node=8 train.py \ 40 | --train_dataset "${TRAIN_DATASET}" \ 41 | --test_dataset "${TEST_DATASET}" \ 42 | --model "$MODEL" \ 43 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ 44 | --test_criterion "Jointnorm_Regr3D(L21, gt_scale=True)" \ 45 | --pretrained $PRETRAINED \ 46 | --pretrained_type "slam3r" \ 47 | --lr 1e-4 --min_lr 1e-6 --warmup_epochs 5 --epochs 50 --batch_size 4 --accum_iter 1 \ 48 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\ 49 | --save_config\ 50 | --freeze "corr_score_head_only"\ 51 | --loss_func "i2p_corr_score" \ 52 | --output_dir $TRAIN_OUT_DIR \ 53 | --ref_id -1 54 | -------------------------------------------------------------------------------- /scripts/train_l2w.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL="Local2WorldModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \ 3 | enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \ 4 | mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 12, need_encoder=True)" 5 | 6 | TRAIN_DATASET="4000 @ ScanNetpp_Seq(filter=True, sample_freq=3, num_views=13, split='train', aug_crop=256, resolution=224, transform=ColorJitter, seed=233) + \ 7 | 2000 @ Aria_Seq(num_views=13, 
sample_freq=2, split='train', aug_crop=128, resolution=224, transform=ColorJitter, seed=233) + \ 8 | 2000 @ Co3d_Seq(num_views=13, sel_num=3, degree=180, mask_bg='rand', split='train', aug_crop=16, resolution=224, transform=ColorJitter, seed=233)" 9 | TEST_DATASET="1000 @ ScanNetpp_Seq(filter=True, sample_freq=3, num_views=13, split='test', resolution=224, seed=666)+ \ 10 | 1000 @ Aria_Seq(num_views=13, split='test', resolution=224, seed=666) + \ 11 | 1000 @ Co3d_Seq(num_views=13, sel_num=3, degree=180, mask_bg='rand', split='test', resolution=224, seed=666)" 12 | 13 | PRETRAINED="checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth" 14 | TRAIN_OUT_DIR="checkpoints/slam3r_l2w" 15 | 16 | torchrun --nproc_per_node=8 train.py \ 17 | --train_dataset "${TRAIN_DATASET}" \ 18 | --test_dataset "${TEST_DATASET}" \ 19 | --model "$MODEL" \ 20 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21,norm_mode=None), alpha=0.2)" \ 21 | --test_criterion "Jointnorm_Regr3D(L21, norm_mode=None)" \ 22 | --pretrained $PRETRAINED \ 23 | --pretrained_type "dust3r" \ 24 | --lr 5e-5 --min_lr 5e-7 --warmup_epochs 20 --epochs 200 --batch_size 4 --accum_iter 1 \ 25 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\ 26 | --save_config\ 27 | --output_dir $TRAIN_OUT_DIR \ 28 | --freeze "encoder"\ 29 | --loss_func "l2w" \ 30 | --ref_ids 0 1 2 3 4 5 31 | 32 | -------------------------------------------------------------------------------- /slam3r/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /slam3r/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .basic_blocks import Block, DecoderBlock, PatchEmbed, Mlp 5 | from .multiview_blocks import MultiviewDecoderBlock_max -------------------------------------------------------------------------------- /slam3r/blocks/basic_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | 4 | 5 | # -------------------------------------------------------- 6 | # Main encoder/decoder blocks 7 | # -------------------------------------------------------- 8 | # References: 9 | # timm 10 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 11 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py 12 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py 13 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py 14 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from itertools import repeat 21 | import collections.abc 22 | 23 | 24 | def _ntuple(n): 25 | def parse(x): 26 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 27 | return x 28 | return tuple(repeat(x, n)) 29 | return parse 30 | to_2tuple = _ntuple(2) 31 | 32 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 34 | """ 35 | if drop_prob == 0. or not training: 36 | return x 37 | keep_prob = 1 - drop_prob 38 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 39 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 40 | if keep_prob > 0.0 and scale_by_keep: 41 | random_tensor.div_(keep_prob) 42 | return x * random_tensor 43 | 44 | class DropPath(nn.Module): 45 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 46 | """ 47 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 48 | super(DropPath, self).__init__() 49 | self.drop_prob = drop_prob 50 | self.scale_by_keep = scale_by_keep 51 | 52 | def forward(self, x): 53 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 54 | 55 | def extra_repr(self): 56 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 57 | 58 | class Mlp(nn.Module): 59 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" 60 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 61 | super().__init__() 62 | out_features = out_features or in_features 63 | hidden_features = hidden_features or in_features 64 | bias = to_2tuple(bias) 65 | drop_probs = to_2tuple(drop) 66 | 67 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 68 | self.act = act_layer() 69 | self.drop1 = nn.Dropout(drop_probs[0]) 70 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 71 | self.drop2 = nn.Dropout(drop_probs[1]) 72 | 73 | def forward(self, x): 74 | x = self.fc1(x) 75 | x = self.act(x) 76 | x = self.drop1(x) 77 | x = self.fc2(x) 78 | x = self.drop2(x) 79 | return x 80 | 81 | class Attention(nn.Module): 82 | 83 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 84 | super().__init__() 85 | self.num_heads = num_heads 86 | head_dim = dim // num_heads 87 | self.scale = head_dim ** -0.5 88 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 89 | self.attn_drop = nn.Dropout(attn_drop) 90 | self.proj = nn.Linear(dim, dim) 91 | self.proj_drop = nn.Dropout(proj_drop) 92 | self.rope = rope 93 | 94 | def forward(self, x, xpos): 95 | B, N, C = x.shape 96 | 97 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // 
self.num_heads).transpose(1,3) 98 | q, k, v = [qkv[:,:,i] for i in range(3)] 99 | # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) 100 | 101 | if self.rope is not None: 102 | q = self.rope(q, xpos) 103 | k = self.rope(k, xpos) 104 | 105 | attn = (q @ k.transpose(-2, -1)) * self.scale 106 | attn = attn.softmax(dim=-1) 107 | attn = self.attn_drop(attn) 108 | 109 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 110 | x = self.proj(x) 111 | x = self.proj_drop(x) 112 | return x 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 121 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 122 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 123 | self.norm2 = norm_layer(dim) 124 | mlp_hidden_dim = int(dim * mlp_ratio) 125 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 126 | 127 | def forward(self, x, xpos): 128 | x = x + self.drop_path(self.attn(self.norm1(x), xpos)) 129 | x = x + self.drop_path(self.mlp(self.norm2(x))) 130 | return x 131 | 132 | class CrossAttention(nn.Module): 133 | 134 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 135 | super().__init__() 136 | self.num_heads = num_heads 137 | head_dim = dim // num_heads 138 | self.scale = head_dim ** -0.5 139 | 140 | self.projq = nn.Linear(dim, dim, bias=qkv_bias) 141 | self.projk = nn.Linear(dim, dim, bias=qkv_bias) 142 | self.projv = nn.Linear(dim, dim, bias=qkv_bias) 143 | self.attn_drop = nn.Dropout(attn_drop) 144 | self.proj = nn.Linear(dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | self.rope = rope 148 | 149 | def forward(self, query, key, value, qpos, kpos): 150 | B, Nq, C = query.shape 151 | Nk = key.shape[1] 152 | Nv = value.shape[1] 153 | 154 | q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 155 | k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 156 | v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 157 | 158 | if self.rope is not None: 159 | q = self.rope(q, qpos) 160 | k = self.rope(k, kpos) 161 | 162 | attn = (q @ k.transpose(-2, -1)) * self.scale 163 | attn = attn.softmax(dim=-1) 164 | attn = self.attn_drop(attn) 165 | 166 | x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) 167 | x = self.proj(x) 168 | x = self.proj_drop(x) 169 | return x 170 | 171 | class DecoderBlock(nn.Module): 172 | 173 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 174 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): 175 | super().__init__() 176 | self.norm1 = norm_layer(dim) 177 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 178 | self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 179 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 180 | self.norm2 = norm_layer(dim) 181 | self.norm3 = norm_layer(dim) 182 | mlp_hidden_dim = int(dim * mlp_ratio) 183 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 184 | self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() 185 | 186 | def forward(self, x, y, xpos, ypos): 187 | x = x + self.drop_path(self.attn(self.norm1(x), xpos)) 188 | y_ = self.norm_y(y) 189 | x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) 190 | x = x + self.drop_path(self.mlp(self.norm3(x))) 191 | return x, y 192 | 193 | 194 | # patch embedding 195 | class PositionGetter(object): 196 | """ return positions of patches """ 197 | 198 | def __init__(self): 199 | self.cache_positions = {} 200 | 201 | def __call__(self, b, h, w, device): 202 | if not (h,w) in self.cache_positions: 203 | x = torch.arange(w, device=device) 204 | y = torch.arange(h, device=device) 205 | self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) 206 | pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() 207 | return pos 208 | 209 | class PatchEmbed(nn.Module): 210 | """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" 211 | 212 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 213 | super().__init__() 214 | img_size = to_2tuple(img_size) 215 | patch_size = to_2tuple(patch_size) 216 | self.img_size = img_size 217 | self.patch_size = patch_size 218 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 219 | self.num_patches = self.grid_size[0] * self.grid_size[1] 220 | self.flatten = flatten 221 | 222 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 223 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 224 | 225 | self.position_getter = PositionGetter() 226 | 227 | def forward(self, x): 228 | B, C, H, W = x.shape 229 | torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 230 | torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 231 | x = self.proj(x) 232 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 233 | if self.flatten: 234 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 235 | x = self.norm(x) 236 | return x, pos 237 | 238 | def _init_weights(self): 239 | w = self.proj.weight.data 240 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 241 | 242 | -------------------------------------------------------------------------------- /slam3r/blocks/multiview_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .basic_blocks import Mlp, Attention, CrossAttention, DropPath 5 | try: 6 | import xformers.ops as xops 7 | XFORMERS_AVALIABLE = True 8 | except ImportError: 9 | print("xformers not avaliable, use self-implemented attention instead") 10 | XFORMERS_AVALIABLE = False 11 | 12 | 13 | class XFormer_Attention(nn.Module): 14 | """Warpper for self-attention module with xformers. 15 | Calculate attention scores with xformers memory_efficient_attention. 16 | When inference is performed on the CPU or when xformer is not installed, it will degrade to the normal attention. 
17 | """ 18 | def __init__(self, old_module:Attention): 19 | super().__init__() 20 | self.num_heads = old_module.num_heads 21 | self.scale = old_module.scale 22 | self.qkv = old_module.qkv 23 | self.attn_drop_prob = old_module.attn_drop.p 24 | self.proj = old_module.proj 25 | self.proj_drop = old_module.proj_drop 26 | self.rope = old_module.rope 27 | self.attn_drop = old_module.attn_drop 28 | 29 | def forward(self, x, xpos): 30 | B, N, C = x.shape 31 | 32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3) 33 | q, k, v = [qkv[:,:,i] for i in range(3)] #shape: (B, num_heads, N, C//num_heads) 34 | 35 | if self.rope is not None: 36 | q = self.rope(q, xpos) # (B, H, N, K) 37 | k = self.rope(k, xpos) 38 | 39 | if x.is_cuda and XFORMERS_AVALIABLE: 40 | q = q.permute(0, 2, 1, 3) # (B, N, H, K) 41 | k = k.permute(0, 2, 1, 3) 42 | v = v.permute(0, 2, 1, 3) 43 | drop_prob = self.attn_drop_prob if self.training else 0 44 | x = xops.memory_efficient_attention(q, k, v, scale=self.scale, p=drop_prob) # (B, N, H, K) 45 | else: 46 | attn = (q @ k.transpose(-2, -1)) * self.scale 47 | attn = attn.softmax(dim=-1) 48 | attn = self.attn_drop(attn) 49 | x = (attn @ v).transpose(1, 2) 50 | 51 | x=x.reshape(B, N, C) 52 | x = self.proj(x) 53 | x = self.proj_drop(x) 54 | return x 55 | 56 | 57 | class MultiviewDecoderBlock_max(nn.Module): 58 | """Multiview decoder block, 59 | which takes as input arbitrary number of source views and target view features. 60 | Use max-pooling to merge features queried from different src views. 61 | """ 62 | 63 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 64 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): 65 | super().__init__() 66 | self.norm1 = norm_layer(dim) 67 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 68 | self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 70 | self.norm2 = norm_layer(dim) 71 | self.norm3 = norm_layer(dim) 72 | mlp_hidden_dim = int(dim * mlp_ratio) 73 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 74 | self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() 75 | if XFORMERS_AVALIABLE: 76 | self.attn = XFormer_Attention(self.attn) 77 | 78 | def batched_cross_attn(self, xs, ys, xposes, yposes, rel_ids_list_d, M): 79 | """ 80 | Calculate cross-attention between Vx target views and Vy source views in a single batch. 
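        Shapes follow forward(): xs is (Vx, B, Nx, C) with the target-view features, ys is (Vy, B, Ny, C)
        with the source-view features, rel_ids_list_d is a (Vx*M,) tensor of source-view indices, and M is
        the number of source views queried per target view. The returned tensor has shape (Vx, M, B, Nx, C),
        i.e. the per-source-view attention outputs before the max-pooling performed in forward().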
81 | """ 82 | xs_normed = self.norm2(xs) 83 | ys_normed = self.norm_y(ys) 84 | cross_attn = self.cross_attn 85 | Vx, B, Nx, C = xs.shape 86 | Vy, B, Ny, C = ys.shape 87 | num_heads = cross_attn.num_heads 88 | 89 | #precompute q,k,v for each view to save computation 90 | qs = cross_attn.projq(xs_normed).reshape(Vx*B,Nx,num_heads, C//num_heads).permute(0, 2, 1, 3) # (Vx*B,num_heads,Nx,C//num_heads) 91 | ks = cross_attn.projk(ys_normed).reshape(Vy*B,Ny,num_heads, C//num_heads).permute(0, 2, 1, 3) # (Vy*B,num_heads,Ny,C//num_heads) 92 | vs = cross_attn.projv(ys_normed).reshape(Vy,B,Ny,num_heads, C//num_heads) # (Vy*B,num_heads,Ny,C//num_heads) 93 | 94 | #add rope 95 | if cross_attn.rope is not None: 96 | qs = cross_attn.rope(qs, xposes) 97 | ks = cross_attn.rope(ks, yposes) 98 | qs = qs.permute(0, 2, 1, 3).reshape(Vx,B,Nx,num_heads,C// num_heads) # (Vx, B, Nx, H, K) 99 | ks = ks.permute(0, 2, 1, 3).reshape(Vy,B,Ny,num_heads,C// num_heads) # (Vy, B, Ny, H, K) 100 | 101 | # construct query, key, value for each target view 102 | ks_respect = torch.index_select(ks, 0, rel_ids_list_d) # (Vx*M, B, Ny, H, K) 103 | vs_respect = torch.index_select(vs, 0, rel_ids_list_d) # (Vx*M, B, Ny, H, K) 104 | qs_corresp = torch.unsqueeze(qs, 1).expand(-1, M, -1, -1, -1, -1) # (Vx, M, B, Nx, H, K) 105 | 106 | ks_compact = ks_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads) 107 | vs_compact = vs_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads) 108 | qs_compact = qs_corresp.reshape(Vx*M*B, Nx, num_heads, C//num_heads) 109 | 110 | # calculate attention results for all target views in one go 111 | if xs.is_cuda and XFORMERS_AVALIABLE: 112 | drop_prob = cross_attn.attn_drop.p if self.training else 0 113 | attn_outputs = xops.memory_efficient_attention(qs_compact, ks_compact, vs_compact, 114 | scale=self.cross_attn.scale, p=drop_prob) # (V*M*B, N, H, K) 115 | else: 116 | ks_compact = ks_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Ny, K) 117 | qs_compact = qs_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Nx, K) 118 | vs_compact = vs_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Ny, K) 119 | attn = (qs_compact @ ks_compact.transpose(-2, -1)) * self.cross_attn.scale # (V*M*B, H, Nx, Ny) 120 | attn = attn.softmax(dim=-1) # (V*M*B, H, Nx, Ny) 121 | attn = self.cross_attn.attn_drop(attn) 122 | attn_outputs = (attn @ vs_compact).transpose(1, 2).reshape(Vx*M*B, Nx, num_heads, C//num_heads) # (V*M*B, Nx, H, K) 123 | 124 | attn_outputs = attn_outputs.reshape(Vx, M, B, Nx, C) #(Vx, M, B, Nx, C) 125 | attn_outputs = cross_attn.proj_drop(cross_attn.proj(attn_outputs)) #(Vx, M, B, Nx, C) 126 | 127 | return attn_outputs 128 | 129 | def forward(self, xs:torch.Tensor, ys:torch.Tensor, 130 | xposes:torch.Tensor, yposes:torch.Tensor, 131 | rel_ids_list_d:torch.Tensor, M:int): 132 | """refine Vx target view feature parallelly, with the information of Vy source view 133 | 134 | Args: 135 | xs: (Vx,B,S,D): features of target views to refine.(S: number of tokens, D: feature dimension) 136 | ys: (Vy,B,S,D): features of source views to query from. 137 | M: number of source views to query from for each target view 138 | rel_ids_list_d: (Vx*M,) indices of source views to query from for each target view 139 | 140 | For example: 141 | Suppose we have 3 target views and 4 source views, 142 | then xs shuold has shape (3,B,S,D), ys should has shape (4,B,S,D). 
143 | 144 | If we require xs[0] to query features from ys[0], ys[1], 145 | xs[1] to query features from ys[2], ys[2],(duplicate number supported) 146 | xs[2] to query features from ys[2], ys[3], 147 | then we should set M=2, rel_ids_list_d=[0,1, 2,2, 2,3] 148 | """ 149 | Vx, B, Nx, C = xs.shape 150 | 151 | # self-attention on each target view feature 152 | xs = xs.reshape(-1, *xs.shape[2:]) # (Vx*B,S,C) 153 | xposes = xposes.reshape(-1, *xposes.shape[2:]) # (Vx*B,S,2) 154 | yposes = yposes.reshape(-1, *yposes.shape[2:]) 155 | xs = xs + self.drop_path(self.attn(self.norm1(xs), xposes)) #(Vx*B,S,C) 156 | 157 | # each target view conducts cross-attention with all source views to query features 158 | attn_outputs = self.batched_cross_attn(xs.reshape(Vx,B,Nx,C), ys, xposes, yposes, rel_ids_list_d, M) 159 | 160 | # max-pooling to aggregate features queried from different source views 161 | merged_ys, indices = torch.max(attn_outputs, dim=1) #(Vx, B, Nx, C) 162 | merged_ys = merged_ys.reshape(Vx*B,Nx,C) #(Vx*B,Nx,C) 163 | 164 | xs = xs + self.drop_path(merged_ys) 165 | xs = xs + self.drop_path(self.mlp(self.norm3(xs))) #(VB,N,C) 166 | xs = xs.reshape(Vx,B,Nx,C) 167 | return xs 168 | 169 | -------------------------------------------------------------------------------- /slam3r/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | import torch 4 | import numpy as np 5 | 6 | from .utils.transforms import * 7 | from .base.batched_sampler import BatchedRandomSampler # noqa: F401 8 | from .replica_seq import Replica 9 | from .scannetpp_seq import ScanNetpp_Seq 10 | from .project_aria_seq import Aria_Seq 11 | from .co3d_seq import Co3d_Seq 12 | from .base.base_stereo_view_dataset import EasyDataset 13 | 14 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): 15 | import torch 16 | from slam3r.utils.croco_misc import get_world_size, get_rank 17 | 18 | # pytorch dataset 19 | if isinstance(dataset, str): 20 | dataset = eval(dataset) 21 | 22 | world_size = get_world_size() 23 | rank = get_rank() 24 | 25 | try: 26 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, 27 | rank=rank, drop_last=drop_last) 28 | except (AttributeError, NotImplementedError): 29 | # not avail for this dataset 30 | if torch.distributed.is_initialized(): 31 | sampler = torch.utils.data.DistributedSampler( 32 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last 33 | ) 34 | elif shuffle: 35 | sampler = torch.utils.data.RandomSampler(dataset) 36 | else: 37 | sampler = torch.utils.data.SequentialSampler(dataset) 38 | 39 | data_loader = torch.utils.data.DataLoader( 40 | dataset, 41 | sampler=sampler, 42 | batch_size=batch_size, 43 | num_workers=num_workers, 44 | pin_memory=pin_mem, 45 | drop_last=drop_last, 46 | ) 47 | 48 | return data_loader 49 | 50 | class MultiDataLoader: 51 | def __init__(self, dataloaders:list, return_id=False): 52 | self.dataloaders = dataloaders 53 | self.len_dataloaders = [len(loader) for loader in dataloaders] 54 | self.total_length = sum(self.len_dataloaders) 55 | self.epoch = None 56 | self.return_id = return_id 57 | 58 | def __len__(self): 59 | return self.total_length 60 | 61 | def set_epoch(self, epoch): 62 | for data_loader in self.dataloaders: 63 | if hasattr(data_loader, 'dataset') and 
hasattr(data_loader.dataset, 'set_epoch'): 64 | data_loader.dataset.set_epoch(epoch) 65 | if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): 66 | data_loader.sampler.set_epoch(epoch) 67 | self.epoch = epoch 68 | 69 | def __iter__(self): 70 | loader_idx = [] 71 | for idx, length in enumerate(self.len_dataloaders): 72 | loader_idx += [idx]*length 73 | loader_idx = np.array(loader_idx) 74 | assert loader_idx.shape[0] == self.total_length 75 | #是否需要一个统一的seed让每个进程遍历dataloaders的顺序相同? 76 | if self.epoch is None: 77 | assert len(self.dataloaders) == 1 78 | else: 79 | seed = 777 + self.epoch 80 | rng = np.random.default_rng(seed=seed) 81 | rng.shuffle(loader_idx) 82 | batch_count = 0 83 | 84 | iters = [iter(loader) for loader in self.dataloaders] # iterator for each dataloader 85 | while True: 86 | idx = loader_idx[batch_count] 87 | try: 88 | batch = next(iters[idx]) 89 | except StopIteration: # this won't happen in distribute mode if drop_last is False 90 | iters[idx] = iter(self.dataloaders[idx]) 91 | batch = next(iters[idx]) 92 | if self.return_id: 93 | batch = (batch, idx) 94 | yield batch 95 | batch_count += 1 96 | if batch_count == self.total_length: 97 | # batch_count -= self.total_length 98 | break 99 | 100 | 101 | def get_multi_data_loader(dataset, batch_size, return_id=False, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): 102 | import torch 103 | from slam3r.utils.croco_misc import get_world_size, get_rank 104 | 105 | # pytorch dataset 106 | if isinstance(dataset, str): 107 | dataset = eval(dataset) 108 | 109 | if isinstance(dataset, EasyDataset): 110 | datasets = [dataset] 111 | else: 112 | datasets = dataset 113 | print(datasets) 114 | assert isinstance(datasets,(tuple, list)) 115 | 116 | world_size = get_world_size() 117 | rank = get_rank() 118 | dataloaders = [] 119 | for dataset in datasets: 120 | try: 121 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, 122 | rank=rank, drop_last=drop_last) 123 | except (AttributeError, NotImplementedError): 124 | # not avail for this dataset 125 | if torch.distributed.is_initialized(): 126 | sampler = torch.utils.data.DistributedSampler( 127 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last 128 | ) 129 | elif shuffle: 130 | sampler = torch.utils.data.RandomSampler(dataset) 131 | else: 132 | sampler = torch.utils.data.SequentialSampler(dataset) 133 | 134 | data_loader = torch.utils.data.DataLoader( 135 | dataset, 136 | sampler=sampler, 137 | batch_size=batch_size, 138 | num_workers=num_workers, 139 | pin_memory=pin_mem, 140 | drop_last=drop_last, 141 | ) 142 | dataloaders.append(data_loader) 143 | 144 | multi_dataloader = MultiDataLoader(dataloaders, return_id=return_id) 145 | return multi_dataloader 146 | -------------------------------------------------------------------------------- /slam3r/datasets/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /slam3r/datasets/base/batched_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
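The sampler defined below yields (sample_idx, feat_idx) tuples rather than plain integers, which is why the datasets in this package are indexed as dataset[(idx, feat_idx)] in their __main__ demos. A minimal usage sketch, assuming the package is importable; the toy dataset is purely illustrative, since the sampler only needs __len__:

from slam3r.datasets.base.batched_sampler import BatchedRandomSampler

class ToyDataset:
    # stand-in dataset: the sampler only uses its length
    def __len__(self):
        return 10

sampler = BatchedRandomSampler(ToyDataset(), batch_size=4, pool_size=3)
sampler.set_epoch(0)        # fixes the RNG seed; mandatory in distributed mode, optional otherwise
indices = list(sampler)     # e.g. [(7, 2), (3, 2), (0, 2), (5, 2), (6, 1), ...]

# every consecutive group of batch_size indices shares the same feat_idx, so all
# samples of a batch are drawn with the same aspect-ratio / resolution setting
for b in range(0, len(indices), 4):
    assert len({feat for _, feat in indices[b:b + 4]}) == 1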
3 | # 4 | # -------------------------------------------------------- 5 | # Random sampling under a constraint 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class BatchedRandomSampler: 12 | """ Random sampling under a constraint: each sample in the batch has the same feature, 13 | which is chosen randomly from a known pool of 'features' for each batch. 14 | 15 | For instance, the 'feature' could be the image aspect-ratio. 16 | 17 | The index returned is a tuple (sample_idx, feat_idx). 18 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. 19 | """ 20 | 21 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): 22 | self.batch_size = batch_size 23 | self.pool_size = pool_size 24 | 25 | self.len_dataset = N = len(dataset) 26 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N 27 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' 28 | 29 | # distributed sampler 30 | self.world_size = world_size 31 | self.rank = rank 32 | self.epoch = None 33 | 34 | def __len__(self): 35 | return self.total_size // self.world_size 36 | 37 | def set_epoch(self, epoch): 38 | self.epoch = epoch 39 | 40 | def __iter__(self): 41 | # prepare RNG 42 | if self.epoch is None: 43 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' 44 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 45 | else: 46 | seed = self.epoch + 777 47 | rng = np.random.default_rng(seed=seed) 48 | 49 | # random indices (will restart from 0 if not drop_last) 50 | sample_idxs = np.arange(self.total_size) 51 | rng.shuffle(sample_idxs) 52 | 53 | # random feat_idxs (same across each batch) 54 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size 55 | feat_idxs = rng.integers(self.pool_size, size=n_batches) # shape = (n_batches,) 56 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) # shape = (n_batches, batch_size) 57 | feat_idxs = feat_idxs.ravel()[:self.total_size] 58 | 59 | # put them together 60 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) 61 | 62 | # Distributed sampler: we select a subset of batches 63 | # make sure the slice for each node is aligned with batch_size 64 | size_per_proc = self.batch_size * ((self.total_size + self.world_size * 65 | self.batch_size-1) // (self.world_size * self.batch_size)) 66 | idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] 67 | 68 | yield from (tuple(idx) for idx in idxs) 69 | 70 | 71 | def round_by(total, multiple, up=False): 72 | if up: 73 | total = total + multiple-1 74 | return (total//multiple) * multiple 75 | -------------------------------------------------------------------------------- /slam3r/datasets/base/easy_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # A dataset base class that you can easily resize and combine. 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | from slam3r.datasets.base.batched_sampler import BatchedRandomSampler 9 | 10 | 11 | class EasyDataset: 12 | """ a dataset that you can easily resize and combine. 
13 | Examples: 14 | --------- 15 | 2 * dataset ==> duplicate each element 2x 16 | 17 | 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) 18 | 19 | dataset1 + dataset2 ==> concatenate datasets 20 | """ 21 | 22 | def __add__(self, other): 23 | return CatDataset([self, other]) 24 | 25 | def __rmul__(self, factor): 26 | return MulDataset(factor, self) 27 | 28 | def __rmatmul__(self, factor): 29 | return ResizedDataset(factor, self) 30 | 31 | def set_epoch(self, epoch): 32 | pass # nothing to do by default 33 | 34 | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): 35 | if not (shuffle): 36 | raise NotImplementedError() # cannot deal yet 37 | num_of_aspect_ratios = len(self._resolutions) 38 | return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) 39 | 40 | 41 | class MulDataset (EasyDataset): 42 | """ Artifically augmenting the size of a dataset. 43 | """ 44 | multiplicator: int 45 | 46 | def __init__(self, multiplicator, dataset): 47 | assert isinstance(multiplicator, int) and multiplicator > 0 48 | self.multiplicator = multiplicator 49 | self.dataset = dataset 50 | 51 | def __len__(self): 52 | return self.multiplicator * len(self.dataset) 53 | 54 | def __repr__(self): 55 | return f'{self.multiplicator}*{repr(self.dataset)}' 56 | 57 | def __getitem__(self, idx): 58 | if isinstance(idx, tuple): 59 | idx, other = idx 60 | return self.dataset[idx // self.multiplicator, other] 61 | else: 62 | return self.dataset[idx // self.multiplicator] 63 | 64 | @property 65 | def _resolutions(self): 66 | return self.dataset._resolutions 67 | 68 | 69 | class ResizedDataset (EasyDataset): 70 | """ Artifically changing the size of a dataset. 
71 | """ 72 | new_size: int 73 | 74 | def __init__(self, new_size, dataset): 75 | assert isinstance(new_size, int) and new_size > 0 76 | self.new_size = new_size 77 | self.dataset = dataset 78 | 79 | def __len__(self): 80 | return self.new_size 81 | 82 | def __repr__(self): 83 | size_str = str(self.new_size) 84 | for i in range((len(size_str)-1) // 3): 85 | sep = -4*i-3 86 | size_str = size_str[:sep] + '_' + size_str[sep:] 87 | return f'{size_str} @ {repr(self.dataset)}' 88 | 89 | def set_epoch(self, epoch): 90 | # this random shuffle only depends on the epoch 91 | rng = np.random.default_rng(seed=epoch+777) 92 | 93 | # shuffle all indices 94 | perm = rng.permutation(len(self.dataset)) 95 | 96 | # rotary extension until target size is met 97 | shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) 98 | self._idxs_mapping = shuffled_idxs[:self.new_size] 99 | 100 | assert len(self._idxs_mapping) == self.new_size 101 | 102 | def __getitem__(self, idx): 103 | assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' 104 | if isinstance(idx, tuple): 105 | idx, other = idx 106 | return self.dataset[self._idxs_mapping[idx], other] 107 | else: 108 | return self.dataset[self._idxs_mapping[idx]] 109 | 110 | @property 111 | def _resolutions(self): 112 | return self.dataset._resolutions 113 | 114 | 115 | class CatDataset (EasyDataset): 116 | """ Concatenation of several datasets 117 | """ 118 | 119 | def __init__(self, datasets): 120 | for dataset in datasets: 121 | assert isinstance(dataset, EasyDataset) 122 | self.datasets = datasets 123 | self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) 124 | 125 | def __len__(self): 126 | return self._cum_sizes[-1] 127 | 128 | def __repr__(self): 129 | # remove uselessly long transform 130 | return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) 131 | 132 | def set_epoch(self, epoch): 133 | for dataset in self.datasets: 134 | dataset.set_epoch(epoch) 135 | 136 | def __getitem__(self, idx): 137 | other = None 138 | if isinstance(idx, tuple): 139 | idx, other = idx 140 | 141 | if not (0 <= idx < len(self)): 142 | raise IndexError() 143 | 144 | db_idx = np.searchsorted(self._cum_sizes, idx, 'right') 145 | dataset = self.datasets[db_idx] 146 | new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) 147 | 148 | if other is not None: 149 | new_idx = (new_idx, other) 150 | return dataset[new_idx] 151 | 152 | @property 153 | def _resolutions(self): 154 | resolutions = self.datasets[0]._resolutions 155 | for dataset in self.datasets[1:]: 156 | assert tuple(dataset._resolutions) == tuple(resolutions) 157 | return resolutions 158 | -------------------------------------------------------------------------------- /slam3r/datasets/co3d_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed Co3d_v2 6 | # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International 7 | # See datasets_preprocess/preprocess_co3d.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import json 11 | import itertools 12 | from collections import deque 13 | import cv2 14 | import numpy as np 15 | 16 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 17 | import sys # noqa: E402 18 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 19 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 20 | from slam3r.utils.image import imread_cv2 21 | 22 | TRAINING_CATEGORIES = [ 23 | "apple","backpack","banana","baseballbat","baseballglove","bench","bicycle", 24 | "bottle","bowl","broccoli","cake","car","carrot","cellphone","chair","cup","donut","hairdryer","handbag","hydrant","keyboard", 25 | "laptop","microwave","motorcycle","mouse","orange","parkingmeter","pizza","plant","stopsign","teddybear","toaster","toilet", 26 | "toybus","toyplane","toytrain","toytruck","tv","umbrella","vase","wineglass", 27 | ] 28 | TEST_CATEGORIES = ["ball", "book", "couch", "frisbee", "hotdog", "kite", "remote", "sandwich", "skateboard", "suitcase"] 29 | 30 | 31 | class Co3d_Seq(BaseStereoViewDataset): 32 | def __init__(self, 33 | mask_bg=True, 34 | ROOT="data/co3d_processed", 35 | num_views=2, 36 | degree=90, # degree range to select views 37 | sel_num=1, # number of views to select inside a degree range 38 | *args, 39 | **kwargs): 40 | self.ROOT = ROOT 41 | super().__init__(*args, **kwargs) 42 | assert mask_bg in (True, False, 'rand') 43 | self.mask_bg = mask_bg 44 | self.degree = degree 45 | self.winsize = int(degree / 360 * 100) 46 | self.sel_num = sel_num 47 | self.sel_num_perseq = (101 - self.winsize) * self.sel_num 48 | self.num_views = num_views 49 | 50 | # load all scenes 51 | if self.split == 'train': 52 | self.categories = TRAINING_CATEGORIES 53 | elif self.split == 'test': 54 | self.categories = TEST_CATEGORIES 55 | else: 56 | raise ValueError(f"Unknown split {self.split}") 57 | self.scenes = {} 58 | for cate in TRAINING_CATEGORIES: 59 | with open(osp.join(self.ROOT, cate, f'selected_seqs_{self.split}.json'), 'r') as f: 60 | self.scenes[cate] = json.load(f) 61 | self.scenes = {(k, k2): v2 for k, v in self.scenes.items() 62 | for k2, v2 in v.items()} 63 | self.scene_list = list(self.scenes.keys()) # for each scene, we have about 100 images ==> 360 degrees (so 25 frames ~= 90 degrees) 64 | self.scene_lens = [len(v) for k,v in self.scenes.items()] 65 | # print(np.unique(np.array(self.scene_lens))) 66 | self.invalidate = {scene: {} for scene in self.scene_list} 67 | 68 | print(self) 69 | 70 | def __len__(self): 71 | return len(self.scene_list) * self.sel_num_perseq 72 | 73 | def get_img_idxes(self, idx, rng): 74 | sid = max(0, idx // self.sel_num - 1) #from 0 to 99-winsize 75 | eid = sid + self.winsize 76 | if idx % self.sel_num == 0: 77 | # generate a uniform sample between sid and eid 78 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int) 79 | 80 | # select the first and last, and randomly select the rest n-2 in between 81 | if self.num_views == 2: 82 | return [sid, eid] 83 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False) 84 | sel_ids.sort() 85 | return [sid] + list(sel_ids) + [eid] 86 | 87 | 88 | def _get_views(self, idx, 
resolution, rng): 89 | # choose a scene 90 | obj, instance = self.scene_list[idx // self.sel_num_perseq] 91 | image_pool = self.scenes[obj, instance] 92 | last = len(image_pool)-1 93 | if last <= self.winsize: 94 | return self._get_views(rng.integers(0, len(self)-1), resolution, rng) 95 | 96 | imgs_idxs = self.get_img_idxes(idx % self.sel_num_perseq, rng) 97 | 98 | for i, idx in enumerate(imgs_idxs): 99 | if idx > last: 100 | idx = idx % last 101 | imgs_idxs[i] = idx 102 | # print(imgs_idxs) 103 | 104 | if resolution not in self.invalidate[obj, instance]: # flag invalid images 105 | self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))] 106 | 107 | # decide now if we mask the bg 108 | mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) 109 | 110 | views = [] 111 | imgs_idxs = deque(imgs_idxs) 112 | 113 | while len(imgs_idxs) > 0: # some images (few) have zero depth 114 | im_idx = imgs_idxs.popleft() 115 | 116 | if self.invalidate[obj, instance][resolution][im_idx]: 117 | # search for a valid image 118 | random_direction = 2 * rng.choice(2) - 1 119 | for offset in range(1, len(image_pool)): 120 | tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool) 121 | if not self.invalidate[obj, instance][resolution][tentative_im_idx]: 122 | im_idx = tentative_im_idx 123 | break 124 | if offset == len(image_pool) - 1: 125 | # no valid image found 126 | return self._get_views((idx+1)%len(self), resolution, rng) 127 | 128 | view_idx = image_pool[im_idx] 129 | 130 | impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg') 131 | 132 | # load camera params 133 | input_metadata = np.load(impath.replace('jpg', 'npz')) 134 | camera_pose = input_metadata['camera_pose'].astype(np.float32) 135 | intrinsics = input_metadata['camera_intrinsics'].astype(np.float32) 136 | 137 | # load image and depth 138 | rgb_image = imread_cv2(impath) 139 | depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED) 140 | depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth']) 141 | if mask_bg: 142 | # load object mask 143 | maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png') 144 | maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) 145 | maskmap = (maskmap / 255.0) > 0.1 146 | 147 | # update the depthmap with mask 148 | depthmap *= maskmap 149 | 150 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 151 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) 152 | 153 | # TODO: check if this is resonable 154 | valid_depth = depthmap[depthmap > 0.0] 155 | if valid_depth.size > 0: 156 | median_depth = np.median(valid_depth) 157 | # print(f"median depth: {median_depth}") 158 | depthmap[depthmap > median_depth*3] = 0. 
# filter out floatig points 159 | 160 | num_valid = (depthmap > 0.0).sum() 161 | if num_valid == 0: 162 | # problem, invalidate image and retry 163 | self.invalidate[obj, instance][resolution][im_idx] = True 164 | imgs_idxs.append(im_idx) 165 | continue 166 | 167 | views.append(dict( 168 | img=rgb_image, 169 | depthmap=depthmap, 170 | camera_pose=camera_pose, 171 | camera_intrinsics=intrinsics, 172 | dataset='Co3d_v2', 173 | label=f"{obj}_{instance}_frame{view_idx:06n}.jpg", 174 | instance=osp.split(impath)[1], 175 | )) 176 | return views 177 | 178 | 179 | if __name__ == "__main__": 180 | from slam3r.datasets.base.base_stereo_view_dataset import view_name 181 | import os 182 | import trimesh 183 | 184 | num_views = 11 185 | dataset = Co3d_Seq(split='train', 186 | mask_bg=False, resolution=224, aug_crop=16, 187 | num_views=num_views, degree=90, sel_num=3) 188 | 189 | save_dir = "visualization/co3d_seq_views" 190 | os.makedirs(save_dir, exist_ok=True) 191 | 192 | # import tqdm 193 | # for idx in tqdm.tqdm(np.random.permutation(len(dataset))): 194 | # views = dataset[(idx,0)] 195 | # print([view['instance'] for view in views]) 196 | 197 | for idx in np.random.permutation(len(dataset))[:10]: 198 | # for idx in range(len(dataset))[5:10000:2000]: 199 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) 200 | views = dataset[(idx,0)] 201 | assert len(views) == num_views 202 | all_pts = [] 203 | all_color=[] 204 | for i, view in enumerate(views): 205 | img = np.array(view['img']).transpose(1, 2, 0) 206 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}") 207 | # img=cv2.COLOR_RGB2BGR(img) 208 | img=img[...,::-1] 209 | img = (img+1)/2 210 | cv2.imwrite(save_path, img*255) 211 | print(f"save to {save_path}") 212 | pts3d = np.array(view['pts3d']).reshape(-1,3) 213 | img = img[...,::-1] 214 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) 215 | pct.export(save_path.replace('.jpg','.ply')) 216 | all_pts.append(pts3d) 217 | all_color.append(img.reshape(-1, 3)) 218 | all_pts = np.concatenate(all_pts, axis=0) 219 | all_color = np.concatenate(all_color, axis=0) 220 | pct = trimesh.PointCloud(all_pts, all_color) 221 | pct.export(osp.join(save_dir, str(idx), f"all.ply")) 222 | -------------------------------------------------------------------------------- /slam3r/datasets/project_aria_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
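Co3d_Seq above assumes each CO3D sequence holds roughly 100 frames covering a full 360-degree sweep, so a window of `degree` degrees maps to winsize = int(degree / 360 * 100) consecutive frame slots, and every sequence contributes (101 - winsize) * sel_num samples. A worked sketch of that index arithmetic with the default degree=90 and sel_num=1 (numbers are illustrative only):

import numpy as np

degree, sel_num, num_views = 90, 1, 5
winsize = int(degree / 360 * 100)             # 25 slots ~= 90 degrees
sel_num_perseq = (101 - winsize) * sel_num    # 76 samples per sequence

idx = 40                                      # any intra-sequence sample index in [0, sel_num_perseq)
sid = max(0, idx // sel_num - 1)              # 39: first frame slot of the window
eid = sid + winsize                           # 64: last frame slot of the window

# idx % sel_num == 0 -> evenly spaced views across the window (otherwise the interior
# views are drawn at random between the two fixed endpoints, see get_img_idxes above)
frames = np.linspace(sid, eid, num_views, endpoint=True, dtype=int)   # [39 45 51 57 64]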
3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed project-aria dataset 6 | # -------------------------------------------------------- 7 | import os.path as osp 8 | import os 9 | import cv2 10 | import numpy as np 11 | import math 12 | 13 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 14 | import sys # noqa: E402 15 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 16 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 17 | from slam3r.utils.image import imread_cv2 18 | 19 | 20 | class Aria_Seq(BaseStereoViewDataset): 21 | def __init__(self, 22 | ROOT='data/projectaria/ase_processed', 23 | num_views=2, 24 | scene_name=None, # specify scene name(s) to load 25 | sample_freq=1, # stride of the frmaes inside the sliding window 26 | start_freq=1, # start frequency for the sliding window 27 | filter=False, # filter out the windows with abnormally large stride 28 | rand_sel=False, # randomly select views from a window 29 | winsize=0, # window size to randomly select views 30 | sel_num=0, # number of combinations to randomly select from a window 31 | *args,**kwargs): 32 | super().__init__(*args, **kwargs) 33 | self.ROOT = ROOT 34 | self.sample_freq = sample_freq 35 | self.start_freq = start_freq 36 | self.num_views = num_views 37 | 38 | self.rand_sel = rand_sel 39 | if rand_sel: 40 | assert winsize > 0 and sel_num > 0 41 | comb_num = math.comb(winsize-1, num_views-2) 42 | assert comb_num >= sel_num 43 | self.winsize = winsize 44 | self.sel_num = sel_num 45 | else: 46 | self.winsize = sample_freq*(num_views-1) 47 | 48 | self.scene_names = os.listdir(self.ROOT) 49 | self.scene_names = [int(scene_name) for scene_name in self.scene_names if scene_name.isdigit()] 50 | self.scene_names = sorted(self.scene_names) 51 | self.scene_names = [str(scene_name) for scene_name in self.scene_names] 52 | total_scene_num = len(self.scene_names) 53 | 54 | if self.split == 'train': 55 | # choose 90% of the data as training set 56 | self.scene_names = self.scene_names[:int(total_scene_num*0.9)] 57 | elif self.split=='test': 58 | self.scene_names = self.scene_names[int(total_scene_num*0.9):] 59 | if scene_name is not None: 60 | assert self.split is None 61 | if isinstance(scene_name, list): 62 | self.scene_names = scene_name 63 | else: 64 | if isinstance(scene_name, int): 65 | scene_name = str(scene_name) 66 | assert isinstance(scene_name, str) 67 | self.scene_names = [scene_name] 68 | 69 | self._load_data(filter=filter) 70 | print(self) 71 | 72 | def filter_windows(self, sid, eid, image_names): 73 | return False 74 | 75 | def _load_data(self, filter=False): 76 | self.sceneids = [] 77 | self.images = [] 78 | self.intrinsics = [] #scene_num*(3,3) 79 | self.win_bid = [] 80 | 81 | num_count = 0 82 | for id, scene_name in enumerate(self.scene_names): 83 | scene_dir = os.path.join(self.ROOT, scene_name) 84 | # print(id, scene_name) 85 | image_names = os.listdir(os.path.join(scene_dir, 'color')) 86 | image_names = sorted(image_names) 87 | intrinsic = np.loadtxt(os.path.join(scene_dir, 'intrinsic', 'intrinsic_color.txt'))[:3,:3] 88 | image_num = len(image_names) 89 | # precompute the window indices 90 | for i in range(0, image_num, self.start_freq): 91 | last_id = i+self.winsize 92 | if last_id >= image_num: 93 | break 94 | if filter and self.filter_windows(i, last_id, image_names): 95 | continue 96 | self.win_bid.append((num_count+i, num_count+last_id)) 97 | 98 | self.intrinsics.append(intrinsic) 99 | self.images += 
image_names 100 | self.sceneids += [id,] * image_num 101 | num_count += image_num 102 | # print(self.sceneids, self.scene_names) 103 | self.intrinsics = np.stack(self.intrinsics, axis=0) 104 | print(self.intrinsics.shape) 105 | assert len(self.sceneids)==len(self.images), f"{len(self.sceneids)}, {len(self.images)}" 106 | 107 | def __len__(self): 108 | if self.rand_sel: 109 | return self.sel_num*len(self.win_bid) 110 | return len(self.win_bid) 111 | 112 | def get_img_idxes(self, idx, rng): 113 | if self.rand_sel: 114 | sid, eid = self.win_bid[idx//self.sel_num] 115 | if idx % self.sel_num == 0: 116 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int) 117 | 118 | if self.num_views == 2: 119 | return [sid, eid] 120 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False) 121 | sel_ids.sort() 122 | return [sid] + list(sel_ids) + [eid] 123 | else: 124 | sid, eid = self.win_bid[idx] 125 | return [sid + i*self.sample_freq for i in range(self.num_views)] 126 | 127 | 128 | def _get_views(self, idx, resolution, rng): 129 | 130 | image_idxes = self.get_img_idxes(idx, rng) 131 | # print(image_idxes) 132 | views = [] 133 | for view_idx in image_idxes: 134 | scene_id = self.sceneids[view_idx] 135 | scene_dir = osp.join(self.ROOT, self.scene_names[scene_id]) 136 | 137 | intrinsics = self.intrinsics[scene_id] 138 | basename = self.images[view_idx] 139 | camera_pose = np.loadtxt(osp.join(scene_dir, 'pose', basename.replace('.jpg', '.txt'))) 140 | # Load RGB image 141 | rgb_image = imread_cv2(osp.join(scene_dir, 'color', basename)) 142 | # Load depthmap 143 | depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED) 144 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 145 | depthmap = depthmap.astype(np.float32) / 1000 146 | depthmap[depthmap > 20] = 0 # invalid 147 | 148 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 149 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 150 | 151 | views.append(dict( 152 | img=rgb_image, 153 | depthmap=depthmap.astype(np.float32), 154 | camera_pose=camera_pose.astype(np.float32), 155 | camera_intrinsics=intrinsics.astype(np.float32), 156 | dataset='Aria', 157 | label=self.scene_names[scene_id] + '_' + basename, 158 | instance=f'{str(idx)}_{str(view_idx)}', 159 | )) 160 | # print([view['label'] for view in views]) 161 | return views 162 | 163 | if __name__ == "__main__": 164 | import trimesh 165 | 166 | num_views = 4 167 | # dataset = Aria_Seq(resolution=(224,224), 168 | # num_views=num_views, 169 | # start_freq=1, sample_freq=2) 170 | dataset = Aria_Seq(split='train', resolution=(224,224), 171 | num_views=num_views, 172 | start_freq=1, rand_sel=True, winsize=6, sel_num=3) 173 | save_dir = "visualization/aria_seq_views" 174 | os.makedirs(save_dir, exist_ok=True) 175 | 176 | for idx in np.random.permutation(len(dataset))[:10]: 177 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) 178 | views = dataset[(idx,0)] 179 | assert len(views) == num_views 180 | all_pts = [] 181 | all_color=[] 182 | for i, view in enumerate(views): 183 | img = np.array(view['img']).transpose(1, 2, 0) 184 | # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg") 185 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}") 186 | # img=cv2.COLOR_RGB2BGR(img) 187 | img=img[...,::-1] 188 | img = (img+1)/2 189 | cv2.imwrite(save_path, img*255) 190 | print(f"save to {save_path}") 191 | img = img[...,::-1] 192 | pts3d = 
np.array(view['pts3d']).reshape(-1,3) 193 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) 194 | pct.export(save_path.replace('.jpg','.ply')) 195 | all_pts.append(pts3d) 196 | all_color.append(img.reshape(-1, 3)) 197 | all_pts = np.concatenate(all_pts, axis=0) 198 | all_color = np.concatenate(all_color, axis=0) 199 | pct = trimesh.PointCloud(all_pts, all_color) 200 | pct.export(osp.join(save_dir, str(idx), f"all.ply")) -------------------------------------------------------------------------------- /slam3r/datasets/replica_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # -------------------------------------------------------- 4 | # Dataloader for preprocessed Replica dataset provided by NICER-SLAM 5 | # -------------------------------------------------------- 6 | import os.path as osp 7 | import os 8 | import cv2 9 | import numpy as np 10 | from glob import glob 11 | import json 12 | import trimesh 13 | 14 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 15 | import sys # noqa: E402 16 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 17 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 18 | from slam3r.utils.image import imread_cv2 19 | 20 | 21 | class Replica(BaseStereoViewDataset): 22 | def __init__(self, 23 | ROOT='data/Replica', 24 | num_views=2, 25 | num_fetch_views=None, 26 | sel_view=None, 27 | scene_name=None, 28 | sample_freq=20, 29 | start_freq=1, 30 | sample_dis=1, 31 | cycle=False, 32 | ref_id=-1, 33 | print_mess=False, 34 | *args,**kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.ROOT = ROOT 37 | self.print_mess = print_mess 38 | self.sample_freq = sample_freq 39 | self.start_freq = start_freq 40 | self.sample_dis = sample_dis 41 | self.cycle=cycle 42 | self.num_fetch_views = num_fetch_views if num_fetch_views is not None else num_views 43 | self.sel_view = np.arange(num_views) if sel_view is None else np.array(sel_view) 44 | self.num_views = num_views 45 | assert ref_id < num_views 46 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2 47 | self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 48 | if self.split == 'train': 49 | self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2"] 50 | elif self.split=='val': 51 | self.scene_names = ["office3", "office4"] 52 | if scene_name is not None: 53 | assert self.split is None 54 | if isinstance(scene_name, list): 55 | self.scene_names = scene_name 56 | else: 57 | assert isinstance(scene_name, str) 58 | self.scene_names = [scene_name] 59 | self._load_data() 60 | print(self) 61 | 62 | def _load_data(self): 63 | self.sceneids = [] 64 | self.image_paths = [] 65 | self.trajectories = [] #c2w 66 | self.pairs = [] 67 | with open(os.path.join(self.ROOT,"cam_params.json"),'r') as f: 68 | self.intrinsic = json.load(f)['camera'] 69 | K = np.eye(3) 70 | K[0, 0] = self.intrinsic['fx'] 71 | K[1, 1] = self.intrinsic['fy'] 72 | K[0, 2] = self.intrinsic['cx'] 73 | K[1, 2] = self.intrinsic['cy'] 74 | self.intri_mat = K 75 | num_count = 0 76 | for id, scene_name in enumerate(self.scene_names): 77 | scene_dir = os.path.join(self.ROOT, scene_name) 78 | image_paths = sorted(glob(os.path.join(scene_dir,"results","frame*.jpg"))) 79 | 80 | image_paths = image_paths[::self.sample_freq] 81 | image_num = len(image_paths) 
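        # Two ways of building the view groups below:
        #   cycle=False: slide a window of num_fetch_views frames (one window every
        #                start_freq frames, sample_dis spacing inside the window) and
        #                stop once the window would run past the end of the sequence.
        #   cycle=True : build one window per start step, centered so that view ref_id
        #                is the current frame, wrapping indices around the sequence
        #                with a modulo so no frame index ever falls out of range.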
82 | 83 | if not self.cycle: 84 | for i in range(0, image_num, self.start_freq): 85 | last_id = i+self.sample_dis*(self.num_fetch_views-1) 86 | if last_id >= image_num: 87 | break 88 | self.pairs.append([j+num_count for j in range(i,last_id+1,self.sample_dis)]) 89 | else: 90 | for i in range(0, image_num, self.start_freq): 91 | pair = [] 92 | for j in range(0, self.num_fetch_views): 93 | pair.append((i+(j-self.ref_id)*self.sample_dis+image_num)%image_num + num_count) 94 | self.pairs.append(pair) 95 | 96 | self.trajectories.append(np.loadtxt(os.path.join(scene_dir,"traj.txt")).reshape(-1,4,4)[::self.sample_freq]) 97 | self.image_paths += image_paths 98 | self.sceneids += [id,] * image_num 99 | num_count += image_num 100 | # print(self.sceneids, self.scene_names) 101 | self.trajectories = np.concatenate(self.trajectories,axis=0) 102 | assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids)==len(self.image_paths), f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.image_paths)}" 103 | 104 | def __len__(self): 105 | return len(self.pairs) 106 | 107 | def _get_views(self, idx, resolution, rng): 108 | 109 | image_idxes = self.pairs[idx] 110 | assert len(image_idxes) == self.num_fetch_views 111 | image_idxes = [image_idxes[i] for i in self.sel_view] 112 | views = [] 113 | for view_idx in image_idxes: 114 | scene_id = self.sceneids[view_idx] 115 | camera_pose = self.trajectories[view_idx] 116 | image_path = self.image_paths[view_idx] 117 | image_name = os.path.basename(image_path) 118 | depth_path = image_path.replace(".jpg",".png").replace("frame","depth") 119 | # Load RGB image 120 | rgb_image = imread_cv2(image_path) 121 | 122 | # Load depthmap 123 | depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED) 124 | depthmap = depthmap.astype(np.float32) 125 | depthmap[~np.isfinite(depthmap)] = 0 # TODO:invalid 126 | depthmap /= self.intrinsic['scale'] 127 | 128 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 129 | rgb_image, depthmap, self.intri_mat, resolution, rng=rng, info=view_idx) 130 | # print(intrinsics) 131 | views.append(dict( 132 | img=rgb_image, 133 | depthmap=depthmap.astype(np.float32), 134 | camera_pose=camera_pose.astype(np.float32), 135 | camera_intrinsics=intrinsics.astype(np.float32), 136 | dataset='Replica', 137 | label=self.scene_names[scene_id] + '_' + image_name, 138 | instance=f'{str(idx)}_{str(view_idx)}', 139 | )) 140 | if self.print_mess: 141 | print(f"loading {[view['label'] for view in views]}") 142 | return views 143 | 144 | 145 | if __name__ == "__main__": 146 | num_views = 5 147 | dataset= Replica(ref_id=1, print_mess=True, cycle=True, resolution=224, num_views=num_views, sample_freq=100, seed=777, start_freq=1, sample_dis=1) 148 | save_dir = "visualization/replica_views" 149 | 150 | # combine the pointmaps from different views with c2ws 151 | # to check the correctness of the dataloader 152 | for idx in np.random.permutation(len(dataset))[:10]: 153 | # for idx in range(10): 154 | views = dataset[(idx,0)] 155 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) 156 | assert len(views) == num_views 157 | all_pts = [] 158 | all_color = [] 159 | for i, view in enumerate(views): 160 | img = np.array(view['img']).transpose(1, 2, 0) 161 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}") 162 | print(save_path) 163 | img=img[...,::-1] 164 | img = (img+1)/2 165 | cv2.imwrite(save_path, img*255) 166 | print(f"save to {save_path}") 167 | img = img[...,::-1] 168 | pts3d = np.array(view['pts3d']).reshape(-1,3) 
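            # view['pts3d'] comes from the base dataset already expressed in the world
            # frame (depth unprojected with the intrinsics, then mapped through the c2w
            # pose), so the per-view clouds exported below should overlap in all.ply
            # whenever poses, intrinsics and depth maps are mutually consistent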
169 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) 170 | pct.export(save_path.replace('.jpg','.ply')) 171 | all_pts.append(pts3d) 172 | all_color.append(img.reshape(-1, 3)) 173 | all_pts = np.concatenate(all_pts, axis=0) 174 | all_color = np.concatenate(all_color, axis=0) 175 | pct = trimesh.PointCloud(all_pts, all_color) 176 | pct.export(osp.join(save_dir, str(idx), f"all.ply")) 177 | -------------------------------------------------------------------------------- /slam3r/datasets/scannetpp_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # -------------------------------------------------------- 4 | # Dataloader for preprocessed scannet++ 5 | # dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes 6 | # https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf 7 | # See datasets_preprocess/preprocess_scannetpp.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import os 11 | import cv2 12 | import numpy as np 13 | import math 14 | 15 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 16 | import sys # noqa: E402 17 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 18 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 19 | from slam3r.utils.image import imread_cv2 20 | 21 | 22 | class ScanNetpp_Seq(BaseStereoViewDataset): 23 | def __init__(self, 24 | ROOT='data/scannetpp_processed', 25 | num_views=2, 26 | scene_name=None, # specify scene name(s) to load 27 | sample_freq=1, # stride of the frmaes inside the sliding window 28 | start_freq=1, # start frequency for the sliding window 29 | img_types=['iphone', 'dslr'], 30 | filter=False, # filter out the windows with abnormally large stride 31 | rand_sel=False, # randomly select views from a window 32 | winsize=0, # window size to randomly select views 33 | sel_num=0, # number of combinations to randomly select from a window 34 | *args,**kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.ROOT = ROOT 37 | self.sample_freq = sample_freq 38 | self.start_freq = start_freq 39 | self.num_views = num_views 40 | self.img_types = img_types 41 | 42 | self.rand_sel = rand_sel 43 | if rand_sel: 44 | assert winsize > 0 and sel_num > 0 45 | comb_num = math.comb(winsize-1, num_views-2) 46 | assert comb_num >= sel_num 47 | self.winsize = winsize 48 | self.sel_num = sel_num 49 | else: 50 | self.winsize = sample_freq*(num_views-1) 51 | 52 | if self.split == 'train': 53 | with open(os.path.join(ROOT, 'splits', 'nvs_sem_train.txt'), 'r') as f: 54 | self.scene_names = f.read().splitlines() 55 | elif self.split=='test': 56 | with open(os.path.join(ROOT, 'splits', 'nvs_sem_val.txt'), 'r') as f: 57 | self.scene_names = f.read().splitlines() 58 | if scene_name is not None: 59 | assert self.split is None 60 | if isinstance(scene_name, list): 61 | self.scene_names = scene_name 62 | else: 63 | assert isinstance(scene_name, str) 64 | self.scene_names = [scene_name] 65 | 66 | self._load_data(filter=filter) 67 | print(self) 68 | 69 | def filter_windows(self, img_type, sid, eid, image_names): 70 | if img_type == 'iphone': # frame_000450.jpg 71 | start_id = int(image_names[sid].split('_')[-1].split('.')[0]) 72 | end_id = int(image_names[eid].split('_')[-1].split('.')[0]) 73 | base_stride = 10*self.winsize 74 | elif 
img_type == 'dslr': # DSC06967.jpg 75 | start_id = int(image_names[sid].split('.')[0][-5:]) 76 | end_id = int(image_names[eid].split('.')[0][-5:]) 77 | base_stride = self.winsize 78 | # filiter out the windows with abnormally large stride 79 | if end_id - start_id >= base_stride*3: 80 | return True 81 | return False 82 | 83 | def _load_data(self, filter=False): 84 | self.sceneids = [] 85 | self.images = [] 86 | self.intrinsics = [] #(3,3) 87 | self.trajectories = [] #c2w (4,4) 88 | self.win_bid = [] 89 | 90 | num_count = 0 91 | for id, scene_name in enumerate(self.scene_names): 92 | scene_dir = os.path.join(self.ROOT, scene_name) 93 | for img_type in self.img_types: 94 | metadata_path = os.path.join(scene_dir, f'scene_{img_type}_metadata.npz') 95 | if not os.path.exists(metadata_path): 96 | continue 97 | metadata = np.load(metadata_path) 98 | image_names = metadata['images'].tolist() 99 | 100 | # check if the images are in the same sequence, 101 | # only work for certain scenes in ScanNet++ V2 102 | if img_type == 'dslr': 103 | prefixes = [name[0:3] for name in image_names] 104 | if len(set(prefixes)) > 1: 105 | # dslr images are not in the same sequence 106 | print(f"Warning: {scene_name} {img_type} images are not in the same sequence {set(prefixes)}") 107 | continue 108 | 109 | assert image_names == sorted(image_names) 110 | image_names = sorted(image_names) 111 | intrinsics = metadata['intrinsics'] 112 | trajectories = metadata['trajectories'] 113 | image_num = len(image_names) 114 | # precompute the window indices 115 | for i in range(0, image_num, self.start_freq): 116 | last_id = i+self.winsize 117 | if last_id >= image_num: 118 | break 119 | if filter and self.filter_windows(img_type, i, last_id, image_names): 120 | continue 121 | self.win_bid.append((num_count+i, num_count+last_id)) 122 | 123 | self.trajectories.append(trajectories) 124 | self.intrinsics.append(intrinsics) 125 | self.images += image_names 126 | self.sceneids += [id,] * image_num 127 | num_count += image_num 128 | # print(self.sceneids, self.scene_names) 129 | self.trajectories = np.concatenate(self.trajectories,axis=0) 130 | self.intrinsics = np.concatenate(self.intrinsics, axis=0) 131 | assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids)==len(self.images), f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.images)}" 132 | 133 | def __len__(self): 134 | if self.rand_sel: 135 | return self.sel_num*len(self.win_bid) 136 | return len(self.win_bid) 137 | 138 | def get_img_idxes(self, idx, rng): 139 | if self.rand_sel: 140 | sid, eid = self.win_bid[idx//self.sel_num] 141 | if idx % self.sel_num == 0: 142 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int) 143 | 144 | # random select the views, including the start and end view 145 | if self.num_views == 2: 146 | return [sid, eid] 147 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False) 148 | sel_ids.sort() 149 | return [sid] + list(sel_ids) + [eid] 150 | else: 151 | sid, eid = self.win_bid[idx] 152 | return [sid + i*self.sample_freq for i in range(self.num_views)] 153 | 154 | 155 | def _get_views(self, idx, resolution, rng): 156 | 157 | image_idxes = self.get_img_idxes(idx, rng) 158 | # print(image_idxes) 159 | views = [] 160 | for view_idx in image_idxes: 161 | scene_id = self.sceneids[view_idx] 162 | scene_dir = osp.join(self.ROOT, self.scene_names[scene_id]) 163 | 164 | intrinsics = self.intrinsics[view_idx] 165 | camera_pose = self.trajectories[view_idx] 166 | basename = self.images[view_idx] 
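            # unlike Replica above, intrinsics and poses are stored per frame rather than
            # per scene, since a scene may mix iPhone and DSLR captures whose calibrations
            # differ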
167 | # Load RGB image 168 | rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename)) 169 | # Load depthmap 170 | depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED) 171 | depthmap = depthmap.astype(np.float32) / 1000 172 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 173 | 174 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 175 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 176 | 177 | views.append(dict( 178 | img=rgb_image, 179 | depthmap=depthmap.astype(np.float32), 180 | camera_pose=camera_pose.astype(np.float32), 181 | camera_intrinsics=intrinsics.astype(np.float32), 182 | dataset='ScanNet++', 183 | label=self.scene_names[scene_id] + '_' + basename, 184 | instance=f'{str(idx)}_{str(view_idx)}', 185 | )) 186 | 187 | return views 188 | 189 | if __name__ == "__main__": 190 | from slam3r.datasets.base.base_stereo_view_dataset import view_name 191 | import trimesh 192 | 193 | num_views = 5 194 | dataset = ScanNetpp_Seq(split='train', resolution=(224,224), 195 | num_views=num_views, 196 | start_freq=1, sample_freq=3) 197 | # dataset = ScanNetpp_Seq(split='train', resolution=(224,224), 198 | # num_views=num_views, 199 | # start_freq=1, rand_sel=True, winsize=6, sel_num=3) 200 | save_dir = "visualization/scannetpp_seq_views" 201 | os.makedirs(save_dir, exist_ok=True) 202 | 203 | for idx in np.random.permutation(len(dataset))[:10]: 204 | # for idx in range(len(dataset))[5:10000:2000]: 205 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) 206 | views = dataset[(idx,0)] 207 | assert len(views) == num_views 208 | all_pts = [] 209 | all_color=[] 210 | for i, view in enumerate(views): 211 | img = np.array(view['img']).transpose(1, 2, 0) 212 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}") 213 | img=img[...,::-1] 214 | img = (img+1)/2 215 | cv2.imwrite(save_path, img*255) 216 | print(f"save to {save_path}") 217 | img = img[...,::-1] 218 | pts3d = np.array(view['pts3d']).reshape(-1,3) 219 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) 220 | pct.export(save_path.replace('.jpg','.ply')) 221 | all_pts.append(pts3d) 222 | all_color.append(img.reshape(-1, 3)) 223 | all_pts = np.concatenate(all_pts, axis=0) 224 | all_color = np.concatenate(all_color, axis=0) 225 | pct = trimesh.PointCloud(all_pts, all_color) 226 | pct.export(osp.join(save_dir, str(idx), f"all.ply")) -------------------------------------------------------------------------------- /slam3r/datasets/seven_scenes_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
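Aria_Seq and ScanNetpp_Seq above share the same rand_sel windowing scheme: every precomputed window of winsize+1 consecutive frames yields sel_num samples, the first with evenly spaced views and the rest with randomly chosen interior views between the two fixed endpoints; the constructor checks math.comb(winsize-1, num_views-2) >= sel_num so that enough distinct interior combinations exist. A worked sketch with the numbers used in the __main__ demos above (winsize=6, num_views=5, sel_num=3; illustrative only):

import math
import numpy as np

winsize, num_views, sel_num = 6, 5, 3
assert math.comb(winsize - 1, num_views - 2) >= sel_num    # C(5, 3) = 10 >= 3

sid, eid = 100, 100 + winsize     # one precomputed window (indices into the flat frame list)
rng = np.random.default_rng(0)

# sample 0 of this window: evenly spaced views, endpoints included
even = np.linspace(sid, eid, num_views, endpoint=True, dtype=int)           # [100 101 103 104 106]
# samples 1 .. sel_num-1: keep both endpoints, draw the interior views at random
inner = np.sort(rng.choice(range(sid + 1, eid), num_views - 2, replace=False))
random_pick = [sid, *map(int, inner), eid]                                  # e.g. [100, 102, 103, 105, 106]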
3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for 7Scenes dataset 6 | # -------------------------------------------------------- 7 | import os.path as osp 8 | import os 9 | import cv2 10 | import numpy as np 11 | import torch 12 | import itertools 13 | from glob import glob 14 | import json 15 | import trimesh 16 | 17 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 18 | import sys # noqa: E402 19 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 20 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 21 | from slam3r.utils.image import imread_cv2 22 | 23 | class SevenScenes_Seq(BaseStereoViewDataset): 24 | def __init__(self, 25 | ROOT='data/7Scenes', 26 | scene_id='office', 27 | seq_id=1, 28 | num_views=1, 29 | sample_freq=1, 30 | start_freq=1, 31 | cycle=False, 32 | ref_id=-1, 33 | *args,**kwargs): 34 | super().__init__(*args, **kwargs) 35 | self.ROOT = ROOT 36 | self.cycle = cycle 37 | self.scene_id = scene_id 38 | self.scene_names = [scene_id+'_seq-'+f"{seq_id:02d}"] 39 | self.seq_id = seq_id 40 | self.sample_freq = sample_freq 41 | self.start_freq = start_freq 42 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2 43 | self.num_views = num_views 44 | self.num_fetch_views = self.num_views 45 | self.data_dir = os.path.join(self.ROOT, scene_id, f'seq-{seq_id:02d}') 46 | self._load_data() 47 | print(self) 48 | 49 | def _load_data(self): 50 | self.intrinsics = np.array([[585, 0, 320], 51 | [0, 585, 240], 52 | [0, 0, 1]], dtype=np.float32) 53 | self.trajectories = [] #c2w (4,4) 54 | self.pairs = [] 55 | self.images = sorted(glob(osp.join(self.data_dir, '*.color.png'))) #frame-000000.color.png 56 | image_num = len(self.images) 57 | #这两行能否提速 58 | if not self.cycle: 59 | for i in range(0, image_num, self.start_freq): 60 | last_id = i+(self.num_views-1)*self.sample_freq 61 | if last_id >= image_num: break 62 | self.pairs.append([i+j*self.sample_freq 63 | for j in range(self.num_views)]) 64 | else: 65 | for i in range(0, image_num, self.start_freq): 66 | pair = [] 67 | for j in range(0, self.num_fetch_views): 68 | pair.append((i+(j-self.ref_id)*self.sample_freq+image_num)%image_num) 69 | self.pairs.append(pair) 70 | print(self.pairs) 71 | def __len__(self): 72 | return len(self.pairs) 73 | # return len(self.img_group) 74 | 75 | def _get_views(self, idx, resolution, rng): 76 | 77 | image_idxes = self.pairs[idx] 78 | views = [] 79 | scene_dir = self.data_dir 80 | 81 | for view_idx in image_idxes: 82 | 83 | intrinsics = self.intrinsics 84 | img_path = self.images[view_idx] 85 | 86 | # Load RGB image 87 | rgb_image = imread_cv2(img_path) 88 | # Load depthmap(16-bit, PNG, invalid depth is set to 65535) 89 | depthmap = imread_cv2(img_path.replace('.color.png','.depth.png'), cv2.IMREAD_UNCHANGED) 90 | depthmap[depthmap == 65535] = 0 91 | depthmap = depthmap.astype(np.float32) / 1000 92 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 93 | camera_pose = np.loadtxt(img_path.replace('.color.png','.pose.txt')) 94 | 95 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 96 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 97 | 98 | views.append(dict( 99 | img=rgb_image, 100 | depthmap=depthmap.astype(np.float32), 101 | camera_pose=camera_pose.astype(np.float32), 102 | camera_intrinsics=intrinsics.astype(np.float32), 103 | dataset='ScanNet++', 104 | label=img_path, 105 | instance=f'{str(idx)}_{str(view_idx)}', 106 | )) 107 | return views 108 | 109 | 110 | class 
SevenScenes_Seq_Cali(BaseStereoViewDataset): 111 | def __init__(self, 112 | ROOT='data/7s_dsac/dsac', 113 | scene_id='office', 114 | seq_id=1, 115 | num_views=1, 116 | sample_freq=1, 117 | start_freq=1, 118 | cycle=False, 119 | ref_id=-1, 120 | *args,**kwargs): 121 | super().__init__(*args, **kwargs) 122 | self.ROOT = ROOT 123 | self.cycle = cycle 124 | self.scene_id = scene_id 125 | self.scene_names = [scene_id+'_seq-'+f"{seq_id:02d}"] 126 | self.seq_id = seq_id 127 | self.sample_freq = sample_freq 128 | self.start_freq = start_freq 129 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2 130 | self.num_views = num_views 131 | self.num_fetch_views = self.num_views 132 | self.data_dir = os.path.join(self.ROOT, scene_id, f'seq-{seq_id:02d}') 133 | self._load_data() 134 | print(self) 135 | 136 | def _load_data(self): 137 | self.intrinsics = np.array([[525, 0, 320], 138 | [0, 525, 240], 139 | [0, 0, 1]], dtype=np.float32) 140 | self.trajectories = [] #c2w (4,4) 141 | self.pairs = [] 142 | self.images = sorted(glob(osp.join(self.data_dir, '*.color.png'))) #frame-000000.color.png 143 | image_num = len(self.images) 144 | #这两行能否提速 145 | if not self.cycle: 146 | for i in range(0, image_num, self.start_freq): 147 | last_id = i+(self.num_views-1)*self.sample_freq 148 | if last_id >= image_num: break 149 | self.pairs.append([i+j*self.sample_freq 150 | for j in range(self.num_views)]) 151 | else: 152 | for i in range(0, image_num, self.start_freq): 153 | pair = [] 154 | for j in range(0, self.num_fetch_views): 155 | pair.append((i+(j-self.ref_id)*self.sample_freq+image_num)%image_num) 156 | self.pairs.append(pair) 157 | # print(self.pairs) 158 | def __len__(self): 159 | return len(self.pairs) 160 | # return len(self.img_group) 161 | 162 | def _get_views(self, idx, resolution, rng): 163 | 164 | image_idxes = self.pairs[idx] 165 | views = [] 166 | scene_dir = self.data_dir 167 | 168 | for view_idx in image_idxes: 169 | 170 | intrinsics = self.intrinsics 171 | img_path = self.images[view_idx] 172 | 173 | # Load RGB image 174 | rgb_image = imread_cv2(img_path) 175 | # Load depthmap(16-bit, PNG, invalid depth is set to 65535) 176 | depthmap = imread_cv2(img_path.replace('.color.png','.depth_cali.png'), cv2.IMREAD_UNCHANGED) 177 | depthmap[depthmap == 65535] = 0 178 | depthmap = depthmap.astype(np.float32) / 1000 179 | depthmap[~np.isfinite(depthmap)] = 0 # invalid 180 | 181 | camera_pose = np.loadtxt(img_path.replace('.color.png','.pose.txt')) 182 | 183 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 184 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) 185 | print(intrinsics) 186 | views.append(dict( 187 | img=rgb_image, 188 | depthmap=depthmap.astype(np.float32), 189 | camera_pose=camera_pose.astype(np.float32), 190 | camera_intrinsics=intrinsics.astype(np.float32), 191 | dataset='SevenScenes', 192 | label=img_path, 193 | instance=f'{str(idx)}_{str(view_idx)}', 194 | )) 195 | return views 196 | 197 | 198 | if __name__ == "__main__": 199 | from slam3r.datasets.base.base_stereo_view_dataset import view_name 200 | from slam3r.viz import SceneViz, auto_cam_size 201 | from slam3r.utils.image import rgb 202 | 203 | num_views = 3 204 | dataset = SevenScenes_Seq(scene_id='office',seq_id=9,resolution=(224,224), num_views=num_views, 205 | start_freq=1, sample_freq=20) 206 | save_dir = "visualization/7scenes_seq_views" 207 | os.makedirs(save_dir, exist_ok=True) 208 | 209 | # for idx in np.random.permutation(len(dataset))[:10]: 210 | for idx in 
range(len(dataset))[:500:100]: 211 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) 212 | views = dataset[(idx,0)] 213 | assert len(views) == num_views 214 | all_pts = [] 215 | all_color=[] 216 | for i, view in enumerate(views): 217 | img = np.array(view['img']).transpose(1, 2, 0) 218 | # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg") 219 | print(view['label']) 220 | save_path = osp.join(save_dir, str(idx), f"{i}_{os.path.basename(view['label'])}") 221 | # img=cv2.COLOR_RGB2BGR(img) 222 | img=img[...,::-1] 223 | img = (img+1)/2 224 | cv2.imwrite(save_path, img*255) 225 | print(f"save to {save_path}") 226 | pts3d = np.array(view['pts3d']).reshape(-1,3) 227 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) 228 | pct.export(save_path.replace('.png','.ply')) 229 | all_pts.append(pts3d) 230 | all_color.append(img.reshape(-1, 3)) 231 | all_pts = np.concatenate(all_pts, axis=0) 232 | all_color = np.concatenate(all_color, axis=0) 233 | pct = trimesh.PointCloud(all_pts, all_color) 234 | pct.export(osp.join(save_dir, str(idx), f"all.ply")) 235 | # for idx in range(len(dataset)): 236 | # views = dataset[(idx,0)] 237 | # print([view['label'] for view in views]) -------------------------------------------------------------------------------- /slam3r/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /slam3r/datasets/utils/cropping.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # croppping utilities 6 | # -------------------------------------------------------- 7 | import PIL.Image 8 | import os 9 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 10 | import cv2 # noqa 11 | import numpy as np # noqa 12 | from slam3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa 13 | try: 14 | lanczos = PIL.Image.Resampling.LANCZOS 15 | except AttributeError: 16 | lanczos = PIL.Image.LANCZOS 17 | 18 | 19 | class ImageList: 20 | """ Convenience class to aply the same operation to a whole set of images. 
21 | """ 22 | 23 | def __init__(self, images): 24 | if not isinstance(images, (tuple, list, set)): 25 | images = [images] 26 | self.images = [] 27 | for image in images: 28 | if not isinstance(image, PIL.Image.Image): 29 | image = PIL.Image.fromarray(image) 30 | self.images.append(image) 31 | 32 | def __len__(self): 33 | return len(self.images) 34 | 35 | def to_pil(self): 36 | return tuple(self.images) if len(self.images) > 1 else self.images[0] 37 | 38 | @property 39 | def size(self): 40 | sizes = [im.size for im in self.images] 41 | assert all(sizes[0] == s for s in sizes) 42 | return sizes[0] 43 | 44 | def resize(self, *args, **kwargs): 45 | return ImageList(self._dispatch('resize', *args, **kwargs)) 46 | 47 | def crop(self, *args, **kwargs): 48 | return ImageList(self._dispatch('crop', *args, **kwargs)) 49 | 50 | def _dispatch(self, func, *args, **kwargs): 51 | return [getattr(im, func)(*args, **kwargs) for im in self.images] 52 | 53 | 54 | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution): 55 | """ Jointly rescale a (image, depthmap) 56 | so that (out_width, out_height) >= output_res 57 | """ 58 | image = ImageList(image) 59 | input_resolution = np.array(image.size) # (W,H) 60 | output_resolution = np.array(output_resolution) 61 | if depthmap is not None: 62 | # can also use this with masks instead of depthmaps 63 | # print(tuple(depthmap.shape[:2]), image.size[::-1]) 64 | assert tuple(depthmap.shape[:2]) == image.size[::-1] 65 | assert output_resolution.shape == (2,) 66 | # define output resolution 67 | scale_final = max(output_resolution / image.size) + 1e-8 68 | output_resolution = np.floor(input_resolution * scale_final).astype(int) 69 | 70 | # first rescale the image so that it contains the crop 71 | image = image.resize(output_resolution, resample=lanczos) 72 | if depthmap is not None: 73 | depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, 74 | fy=scale_final, interpolation=cv2.INTER_NEAREST) 75 | 76 | # no offset here; simple rescaling 77 | camera_intrinsics = camera_matrix_of_crop( 78 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) 79 | 80 | return image.to_pil(), depthmap, camera_intrinsics 81 | 82 | 83 | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): 84 | # Margins to offset the origin 85 | margins = np.asarray(input_resolution) * scaling - output_resolution 86 | assert np.all(margins >= 0.0) 87 | if offset is None: 88 | offset = offset_factor * margins 89 | 90 | # Generate new camera parameters 91 | output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) 92 | output_camera_matrix_colmap[:2, :] *= scaling 93 | output_camera_matrix_colmap[:2, 2] -= offset 94 | output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) 95 | 96 | return output_camera_matrix 97 | 98 | 99 | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): 100 | """ 101 | Return a crop of the input view. 
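    crop_bbox is (left, top, right, bottom) in pixel coordinates; the depthmap is
    cropped with the same bounds and the principal point of the returned intrinsics
    is shifted by (-left, -top) so that it stays consistent with the crop.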
102 | """ 103 | image = ImageList(image) 104 | l, t, r, b = crop_bbox 105 | 106 | image = image.crop((l, t, r, b)) 107 | depthmap = depthmap[t:b, l:r] 108 | 109 | camera_intrinsics = camera_intrinsics.copy() 110 | camera_intrinsics[0, 2] -= l 111 | camera_intrinsics[1, 2] -= t 112 | 113 | return image.to_pil(), depthmap, camera_intrinsics 114 | 115 | 116 | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): 117 | out_width, out_height = output_resolution 118 | l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) 119 | crop_bbox = (l, t, l+out_width, t+out_height) 120 | return crop_bbox 121 | -------------------------------------------------------------------------------- /slam3r/datasets/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # DUST3R default transforms 6 | # -------------------------------------------------------- 7 | import torchvision.transforms as tvf 8 | from slam3r.utils.image import ImgNorm 9 | 10 | # define the standard image transforms 11 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) 12 | -------------------------------------------------------------------------------- /slam3r/datasets/wild_seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for self-captured img sequence 6 | # -------------------------------------------------------- 7 | import os.path as osp 8 | import torch 9 | 10 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) 11 | import sys # noqa: E402 12 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402 13 | from slam3r.utils.image import load_images 14 | 15 | class Seq_Data(): 16 | def __init__(self, 17 | img_dir, # the directory of the img sequence 18 | img_size=224, # only img_size=224 is supported now 19 | silent=False, 20 | sample_freq=1, # the frequency of the imgs to be sampled 21 | num_views=-1, # only take the first num_views imgs in the img_dir 22 | start_freq=1, 23 | postfix=None, # the postfix of the img in the img_dir(.jpg, .png, ...) 24 | to_tensor=False, 25 | start_idx=0): 26 | 27 | # Note that only img_size=224 is supported now. 28 | # Imgs will be cropped and resized to 224x224, thus losing the information in the border. 29 | assert img_size==224, "Sorry, only img_size=224 is supported now." 30 | 31 | # load imgs with sequential number. 32 | # Imgs in the img_dir should have number in their names to indicate the order, 33 | # such as frame-0031.color.png, output_414.jpg, ... 
34 | self.imgs = load_images(img_dir, size=img_size, 35 | verbose=not silent, img_freq=sample_freq, 36 | postfix=postfix, start_idx=start_idx, img_num=num_views) 37 | 38 | self.num_views = num_views if num_views > 0 else len(self.imgs) 39 | self.stride = start_freq 40 | self.img_num = len(self.imgs) 41 | if to_tensor: 42 | for img in self.imgs: 43 | img['true_shape'] = torch.tensor(img['true_shape']) 44 | self.make_groups() 45 | self.length = len(self.groups) 46 | 47 | if isinstance(img_dir, str): 48 | if img_dir[-1] == '/': 49 | img_dir = img_dir[:-1] 50 | self.scene_names = ['_'.join(img_dir.split('/')[-2:])] 51 | 52 | def make_groups(self): 53 | self.groups = [] 54 | for start in range(0,self.img_num, self.stride): 55 | end = start + self.num_views 56 | if end > self.img_num: 57 | break 58 | self.groups.append(self.imgs[start:end]) 59 | 60 | def __len__(self): 61 | return len(self.groups) 62 | 63 | def __getitem__(self, idx): 64 | return self.groups[idx] 65 | 66 | 67 | 68 | if __name__ == "__main__": 69 | from slam3r.datasets.base.base_stereo_view_dataset import view_name 70 | from slam3r.viz import SceneViz, auto_cam_size 71 | from slam3r.utils.image import rgb 72 | 73 | dataset = Seq_Data(img_dir="dataset/7Scenes/office-09", 74 | img_size=224, silent=False, sample_freq=10, 75 | num_views=5, start_freq=2, postfix="color.png") 76 | for i in range(len(dataset)): 77 | data = dataset[i] 78 | print([img['idx'] for img in data]) 79 | # break -------------------------------------------------------------------------------- /slam3r/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # head factory 6 | # -------------------------------------------------------- 7 | from .linear_head import LinearPts3d 8 | from .dpt_head import create_dpt_head 9 | 10 | 11 | def head_factory(head_type, output_mode, net, has_conf=False): 12 | """" build a prediction head for the decoder 13 | """ 14 | if head_type == 'linear' and output_mode == 'pts3d': 15 | return LinearPts3d(net, has_conf) 16 | elif head_type == 'dpt' and output_mode == 'pts3d': 17 | return create_dpt_head(net, has_conf=has_conf) 18 | else: 19 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") 20 | -------------------------------------------------------------------------------- /slam3r/heads/dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionnary img_info with key "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | 16 | from .dpt_block import DPTOutputAdapter # noqa 17 | from .postprocess import postprocess 18 | 19 | 20 | class DPTOutputAdapter_fix(DPTOutputAdapter): 21 | """ 22 | Adapt croco's DPTOutputAdapter implementation for dust3r: 23 | remove duplicated weigths, and fix forward for dust3r 24 | """ 25 | 26 | def init(self, dim_tokens_enc=768): 27 | super().init(dim_tokens_enc) 28 | # these are duplicated weights 29 | del self.act_1_postprocess 30 | del self.act_2_postprocess 31 | del self.act_3_postprocess 32 | del self.act_4_postprocess 33 | 34 | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): 35 | assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' 36 | # H, W = input_info['image_size'] 37 | image_size = self.image_size if image_size is None else image_size 38 | H, W = image_size 39 | # print(encoder_tokens[0].shape, "size:", H, W) 40 | # Number of patches in height and width 41 | N_H = H // (self.stride_level * self.P_H) 42 | N_W = W // (self.stride_level * self.P_W) 43 | 44 | # Hook decoder onto 4 layers from specified ViT layers 45 | layers = [encoder_tokens[hook] for hook in self.hooks] 46 | 47 | # Extract only task-relevant tokens and ignore global tokens. 
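        # (The four hooked decoder outputs are first trimmed of any global/cls tokens,
        # then reshaped below into (B, C, N_H, N_W) feature maps so the DPT refinement
        # stages can fuse them like a feature pyramid.)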
48 | layers = [self.adapt_tokens(l) for l in layers] 49 | 50 | # Reshape tokens to spatial representation 51 | layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] 52 | 53 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 54 | # Project layers to chosen feature dim 55 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 56 | 57 | # Fuse layers using refinement stages 58 | path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] 59 | path_3 = self.scratch.refinenet3(path_4, layers[2]) 60 | path_2 = self.scratch.refinenet2(path_3, layers[1]) 61 | path_1 = self.scratch.refinenet1(path_2, layers[0]) 62 | 63 | # Output head 64 | out = self.head(path_1) 65 | 66 | return out 67 | 68 | 69 | class PixelwiseTaskWithDPT(nn.Module): 70 | """ DPT module for dust3r, can return 3D points + confidence for all pixels""" 71 | 72 | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, 73 | output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): 74 | super(PixelwiseTaskWithDPT, self).__init__() 75 | self.return_all_layers = True # backbone needs to return all layers 76 | self.postprocess = postprocess 77 | self.depth_mode = depth_mode 78 | self.conf_mode = conf_mode 79 | 80 | assert n_cls_token == 0, "Not implemented" 81 | dpt_args = dict(output_width_ratio=output_width_ratio, 82 | num_channels=num_channels, 83 | **kwargs) 84 | if hooks_idx is not None: 85 | dpt_args.update(hooks=hooks_idx) 86 | self.dpt = DPTOutputAdapter_fix(**dpt_args) 87 | dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} 88 | self.dpt.init(**dpt_init_args) 89 | 90 | def forward(self, x, img_info): 91 | # print(len(x), "img info", img_info) 92 | out = self.dpt(x, image_size=(img_info[0], img_info[1])) 93 | if self.postprocess: 94 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 95 | return out 96 | 97 | 98 | def create_dpt_head(net, has_conf=False): 99 | """ 100 | return PixelwiseTaskWithDPT for given net params 101 | """ 102 | assert net.dec_depth > 9 103 | l2 = net.dec_depth 104 | feature_dim = 256 105 | last_dim = feature_dim//2 106 | out_nchan = 3 107 | ed = net.enc_embed_dim 108 | dd = net.dec_embed_dim 109 | return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, 110 | feature_dim=feature_dim, 111 | last_dim=last_dim, 112 | hooks_idx=[0, l2*2//4, l2*3//4, l2], 113 | dim_tokens=[ed, dd, dd, dd], 114 | postprocess=postprocess, 115 | depth_mode=net.depth_mode, 116 | conf_mode=net.conf_mode, 117 | head_type='regression') 118 | -------------------------------------------------------------------------------- /slam3r/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # linear head implementation for DUST3R 6 | # -------------------------------------------------------- 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .postprocess import postprocess 10 | 11 | 12 | class LinearPts3d (nn.Module): 13 | """ 14 | Linear head for dust3r 15 | Each token outputs: - 16x16 3D points (+ confidence) 16 | """ 17 | 18 | def __init__(self, net, has_conf=False): 19 | super().__init__() 20 | self.patch_size = net.patch_embed.patch_size[0] 21 | self.depth_mode = net.depth_mode 22 | self.conf_mode = net.conf_mode 23 | self.has_conf = has_conf 24 | 25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) 26 | 27 | def setup(self, croconet): 28 | pass 29 | 30 | def forward(self, decout, img_shape): 31 | H, W = img_shape 32 | tokens = decout[-1] 33 | B, S, D = tokens.shape #S is the number of tokens 34 | 35 | # extract 3D points 36 | feat = self.proj(tokens) # B,S,D ;D is property of all the 3D points in the patch 37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W 39 | 40 | # permute + norm depth 41 | return postprocess(feat, self.depth_mode, self.conf_mode) 42 | -------------------------------------------------------------------------------- /slam3r/heads/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # post process function for all heads: extract 3D points/confidence from output 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def postprocess(out, depth_mode, conf_mode): 11 | """ 12 | extract 3D points/confidence from prediction head output 13 | """ 14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3 15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) 16 | 17 | if conf_mode is not None: 18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) 19 | return res 20 | 21 | 22 | def reg_dense_depth(xyz, mode): 23 | """ 24 | extract 3D points from prediction head output 25 | """ 26 | mode, vmin, vmax = mode 27 | 28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) 29 | assert no_bounds 30 | 31 | if mode == 'linear': 32 | if no_bounds: 33 | return xyz # [-inf, +inf] 34 | return xyz.clip(min=vmin, max=vmax) 35 | 36 | # distance to origin 37 | d = xyz.norm(dim=-1, keepdim=True) 38 | xyz = xyz / d.clip(min=1e-8) 39 | 40 | if mode == 'square': 41 | return xyz * d.square() 42 | 43 | if mode == 'exp': 44 | return xyz * torch.expm1(d) 45 | 46 | raise ValueError(f'bad {mode=}') 47 | 48 | 49 | def reg_dense_conf(x, mode): 50 | """ 51 | extract confidence from prediction head output 52 | """ 53 | mode, vmin, vmax = mode 54 | if mode == 'exp': 55 | return vmin + x.exp().clip(max=vmax-vmin) 56 | if mode == 'sigmoid': 57 | return (vmax - vmin) * torch.sigmoid(x) + vmin 58 | raise ValueError(f'bad {mode=}') 59 | -------------------------------------------------------------------------------- /slam3r/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # utilities needed for the inference 6 | # -------------------------------------------------------- 7 | import torch 8 | import numpy as np 9 | 10 | from .utils.misc import invalid_to_zeros 11 | from .utils.geometry import geotrf, inv 12 | 13 | 14 | def loss_of_one_batch(loss_func, batch, model, criterion, device, 15 | use_amp=False, ret=None, 16 | assist_model=None, train=False, epoch=0, 17 | args=None): 18 | if loss_func == "i2p": 19 | return loss_of_one_batch_multiview(batch, model, criterion, 20 | device, use_amp, ret, 21 | args.ref_id) 22 | elif loss_func == "i2p_corr_score": 23 | return loss_of_one_batch_multiview_corr_score(batch, model, criterion, 24 | device, use_amp, ret, 25 | args.ref_id) 26 | elif loss_func == "l2w": 27 | return loss_of_one_batch_l2w( 28 | batch, model, criterion, 29 | device, use_amp, ret, 30 | ref_ids=args.ref_ids, coord_frame_id=0, 31 | exclude_ident=True, to_zero=True 32 | ) 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | def loss_of_one_batch_multiview(batch, model, criterion, device, 38 | use_amp=False, ret=None, ref_id=-1): 39 | """ Function to compute the reconstruction loss of the Image-to-Points model 40 | """ 41 | views = batch 42 | for view in views: 43 | for name in 'img pts3d valid_mask camera_pose'.split(): # pseudo_focal 44 | if name not in view: 45 | continue 46 | view[name] = view[name].to(device, non_blocking=True) 47 | 48 | if ref_id == -1: 49 | ref_id = (len(views)-1)//2 50 | 51 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 52 | preds = model(views, ref_id=ref_id) 53 | assert len(preds) == len(views) 54 | 55 | with torch.cuda.amp.autocast(enabled=False): 56 | if criterion is None: 57 | loss = None 58 | else: 59 | loss = criterion(views, preds, ref_id=ref_id) 60 | 61 | result = dict(views=views, preds=preds, loss=loss) 62 | for i in range(len(preds)): 63 | result[f'pred{i+1}'] = preds[i] 64 | result[f'view{i+1}'] = views[i] 65 | return result[ret] if ret else result 66 | 67 | 68 | def loss_of_one_batch_multiview_corr_score(batch, model, criterion, device, 69 | use_amp=False, ret=None, ref_id=-1): 70 | 71 | views = batch 72 | for view in views: 73 | for name in 'img pts3d valid_mask camera_pose'.split(): # pseudo_focal 74 | if name not in view: 75 | continue 76 | view[name] = view[name].to(device, non_blocking=True) 77 | 78 | if ref_id == -1: 79 | ref_id = (len(views)-1)//2 80 | 81 | all_loss = [0, {}] 82 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 83 | preds = model(views, ref_id=ref_id, return_corr_score=True) 84 | assert len(preds) == len(views) 85 | for i,pred in enumerate(preds): 86 | if i == ref_id: 87 | continue 88 | patch_pseudo_conf = pred['pseudo_conf'] # (B,S) 89 | true_conf = (pred['conf']-1.).mean(dim=(1,2)) # (B,) mean(exp(x)) 90 | pseudo_conf = torch.exp(patch_pseudo_conf).mean(dim=1) # (B,) mean(exp(batch(x))) 91 | pseudo_conf = pseudo_conf / (1+pseudo_conf) 92 | true_conf = true_conf / (1+true_conf) 93 | dis = torch.abs(pseudo_conf-true_conf) 94 | loss = dis.mean() 95 | # if loss.isinf(): 96 | # print(((patch_pseudo_conf-patch_true_conf)**2).max()) 97 | all_loss[0] += loss 98 | all_loss[1][f'pseudo_conf_loss_{i}'] = loss 99 | 100 | result = dict(views=views, preds=preds, loss=all_loss) 101 | for i in range(len(preds)): 102 | result[f'pred{i+1}'] = preds[i] 103 | result[f'view{i+1}'] = views[i] 104 | return result[ret] if ret else result 105 | 106 | 107 | def get_multiview_scale(pts:list, valid:list, norm_mode='avg_dis'): 108 
| # adapted from DUSt3R
109 |     for i in range(len(pts)):
110 |         assert pts[i].ndim >= 3 and pts[i].shape[-1] == 3
111 |     assert len(pts) == len(valid)
112 |     norm_mode, dis_mode = norm_mode.split('_')
113 | 
114 |     if norm_mode == 'avg':
115 |         # gather all points together (joint normalization)
116 |         all_pts = []
117 |         all_nnz = 0
118 |         for i in range(len(pts)):
119 |             nan_pts, nnz = invalid_to_zeros(pts[i], valid[i], ndim=3)
120 |             # print(nnz, nan_pts.shape)  # (B,) (B,H*W,3)
121 |             all_pts.append(nan_pts)
122 |             all_nnz += nnz
123 |         all_pts = torch.cat(all_pts, dim=1)
124 |         # compute distance to origin
125 |         all_dis = all_pts.norm(dim=-1)
126 |         if dis_mode == 'dis':
127 |             pass  # do nothing
128 |         elif dis_mode == 'log1p':
129 |             all_dis = torch.log1p(all_dis)
130 |         else:
131 |             raise ValueError(f'bad {dis_mode=}')
132 | 
133 |         norm_factor = all_dis.sum(dim=1) / (all_nnz + 1e-8)
134 |     else:
135 |         raise ValueError(f'bad {norm_mode=}')
136 | 
137 |     norm_factor = norm_factor.clip(min=1e-8)
138 |     while norm_factor.ndim < pts[0].ndim:
139 |         norm_factor.unsqueeze_(-1)
140 |     # print('norm factor:', norm_factor)
141 |     return norm_factor
142 | 
143 | 
144 | def loss_of_one_batch_l2w(batch, model, criterion, device,
145 |                           use_amp=False, ret=None,
146 |                           ref_ids=-1, coord_frame_id=0,
147 |                           exclude_ident=True, to_zero=True):
148 |     """ Function to compute the reconstruction loss of the Local-to-World model
149 |     ref_ids: list of indices of the supporting frames (excluding the coord_frame)
150 |     coord_frame_id: all the pointmaps input and output will be in the coord_frame_id's camera coordinate
151 |     exclude_ident: whether to exclude the coord_frame to simulate real-life inference scenarios
152 |     to_zero: whether to set the invalid points to zero
153 |     """
154 |     views = batch
155 |     for view in views:
156 |         for name in 'img pts3d pts3d_cam valid_mask camera_pose'.split():  # pseudo_focal
157 |             if name not in view:
158 |                 continue
159 |             view[name] = view[name].to(device, non_blocking=True)
160 | 
161 |     if coord_frame_id == -1:
162 |         # randomly select a camera as the target camera
163 |         coord_frame_id = np.random.randint(0, len(views))
164 |         # print(coord_frame_id)
165 |     c2w = views[coord_frame_id]['camera_pose']
166 |     w2c = inv(c2w)
167 | 
168 |     # exclude the frame that has the identity pose
169 |     if exclude_ident:
170 |         views.pop(coord_frame_id)
171 | 
172 |     if ref_ids == -1:
173 |         ref_ids = [i for i in range(len(views)-1)]  # all views except the last one
174 |     elif ref_ids == -2:
175 |         # select half of the views randomly
176 |         ref_ids = np.random.choice(len(views), len(views)//2, replace=False).tolist()
177 |     else:
178 |         assert isinstance(ref_ids, list)
179 | 
180 |     for id in ref_ids:
181 |         views[id]['pts3d_world'] = geotrf(w2c, views[id]['pts3d'])  # transform to the target coordinate frame
182 |     norm_factor_world = get_multiview_scale([views[id]['pts3d_world'] for id in ref_ids],
183 |                                             [views[id]['valid_mask'] for id in ref_ids],
184 |                                             norm_mode='avg_dis')
185 |     for id, view in enumerate(views):
186 |         if id in ref_ids:
187 |             view['pts3d_world'] = view['pts3d_world'].permute(0,3,1,2) / norm_factor_world
188 |         else:
189 |             norm_factor_src = get_multiview_scale([view['pts3d_cam']],
190 |                                                   [view['valid_mask']],
191 |                                                   norm_mode='avg_dis')
192 |             view['pts3d_cam'] = view['pts3d_cam'].permute(0,3,1,2) / norm_factor_src
193 | 
194 |     if to_zero:
195 |         for id, view in enumerate(views):
196 |             valid_mask = view['valid_mask'].unsqueeze(1).float()  # B,1,H,W
197 |             if id in ref_ids:
198 |                 # print(view['pts3d_world'].shape, valid_mask.shape, (-valid_mask+1).sum())
199 | 
view['pts3d_world'] = view['pts3d_world'] * valid_mask 200 | else: 201 | view['pts3d_cam'] = view['pts3d_cam'] * valid_mask 202 | 203 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 204 | preds = model(views, ref_ids=ref_ids) 205 | assert len(preds) == len(views) 206 | with torch.cuda.amp.autocast(enabled=False): 207 | if criterion is None: 208 | loss = None 209 | else: 210 | loss = criterion(views, preds, ref_id=ref_ids, ref_camera=w2c, norm_scale=norm_factor_world) 211 | 212 | result = dict(views=views, preds=preds, loss=loss) 213 | for i in range(len(preds)): 214 | result[f'pred{i+1}'] = preds[i] 215 | result[f'view{i+1}'] = views[i] 216 | return result[ret] if ret else result 217 | -------------------------------------------------------------------------------- /slam3r/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Implementation of DUSt3R training losses 6 | # -------------------------------------------------------- 7 | from copy import copy, deepcopy 8 | import torch 9 | import torch.nn as nn 10 | 11 | from slam3r.utils.geometry import inv, geotrf, depthmap_to_pts3d, multiview_normalize_pointcloud 12 | 13 | def get_pred_pts3d(gt, pred, use_pose=False): 14 | if 'depth' in pred and 'pseudo_focal' in pred: 15 | try: 16 | pp = gt['camera_intrinsics'][..., :2, 2] 17 | except KeyError: 18 | pp = None 19 | pts3d = depthmap_to_pts3d(**pred, pp=pp) 20 | 21 | elif 'pts3d' in pred: 22 | # pts3d from my camera 23 | pts3d = pred['pts3d'] 24 | 25 | elif 'pts3d_in_other_view' in pred: 26 | # pts3d from the other camera, already transformed 27 | assert use_pose is True 28 | return pred['pts3d_in_other_view'] # return! 29 | 30 | if use_pose: 31 | camera_pose = pred.get('camera_pose') 32 | assert camera_pose is not None 33 | pts3d = geotrf(camera_pose, pts3d) 34 | 35 | return pts3d 36 | 37 | def Sum(*losses_and_masks): 38 | loss, mask = losses_and_masks[0] 39 | if loss.ndim > 0: 40 | # we are actually returning the loss for every pixels 41 | return losses_and_masks 42 | else: 43 | # we are returning the global loss 44 | for loss2, mask2 in losses_and_masks[1:]: 45 | loss = loss + loss2 46 | return loss 47 | 48 | 49 | class LLoss (nn.Module): 50 | """ L-norm loss 51 | """ 52 | 53 | def __init__(self, reduction='mean'): 54 | super().__init__() 55 | self.reduction = reduction 56 | 57 | def forward(self, a, b): 58 | assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' 59 | dist = self.distance(a, b) 60 | assert dist.ndim == a.ndim-1 # one dimension less 61 | if self.reduction == 'none': 62 | return dist 63 | if self.reduction == 'sum': 64 | return dist.sum() 65 | if self.reduction == 'mean': 66 | return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) 67 | raise ValueError(f'bad {self.reduction=} mode') 68 | 69 | def distance(self, a, b): 70 | raise NotImplementedError() 71 | 72 | 73 | class L21Loss (LLoss): 74 | """ Euclidean distance between 3d points """ 75 | 76 | def distance(self, a, b): 77 | return torch.norm(a - b, dim=-1) # normalized L2 distance 78 | 79 | 80 | L21 = L21Loss() 81 | 82 | 83 | class Criterion (nn.Module): 84 | def __init__(self, criterion=None): 85 | super().__init__() 86 | assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!' 
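        # keep a private copy of the criterion so that with_reduction() below can switch
        # its reduction mode without affecting the shared module-level instance (e.g. L21)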
87 | self.criterion = copy(criterion) 88 | 89 | def get_name(self): 90 | return f'{type(self).__name__}({self.criterion})' 91 | 92 | def with_reduction(self, mode): 93 | res = loss = deepcopy(self) 94 | while loss is not None: 95 | assert isinstance(loss, Criterion) 96 | loss.criterion.reduction = 'none' # make it return the loss for each sample 97 | loss = loss._loss2 # we assume loss is a Multiloss 98 | return res 99 | 100 | 101 | class MultiLoss (nn.Module): 102 | """ Easily combinable losses (also keep track of individual loss values): 103 | loss = MyLoss1() + 0.1*MyLoss2() 104 | Usage: 105 | Inherit from this class and override get_name() and compute_loss() 106 | """ 107 | 108 | def __init__(self): 109 | super().__init__() 110 | self._alpha = 1 111 | self._loss2 = None 112 | 113 | def compute_loss(self, *args, **kwargs): 114 | raise NotImplementedError() 115 | 116 | def get_name(self): 117 | raise NotImplementedError() 118 | 119 | def __mul__(self, alpha): 120 | assert isinstance(alpha, (int, float)) 121 | res = copy(self) 122 | res._alpha = alpha 123 | return res 124 | __rmul__ = __mul__ # same 125 | 126 | def __add__(self, loss2): 127 | assert isinstance(loss2, MultiLoss) 128 | res = cur = copy(self) 129 | # find the end of the chain 130 | while cur._loss2 is not None: 131 | cur = cur._loss2 132 | cur._loss2 = loss2 133 | return res 134 | 135 | def __repr__(self): 136 | name = self.get_name() 137 | if self._alpha != 1: 138 | name = f'{self._alpha:g}*{name}' 139 | if self._loss2: 140 | name = f'{name} + {self._loss2}' 141 | return name 142 | 143 | def forward(self, *args, **kwargs): 144 | loss = self.compute_loss(*args, **kwargs) 145 | if isinstance(loss, tuple): 146 | loss, details = loss 147 | elif loss.ndim == 0: 148 | details = {self.get_name(): float(loss)} 149 | else: 150 | details = {} 151 | loss = loss * self._alpha 152 | 153 | if self._loss2: 154 | loss2, details2 = self._loss2(*args, **kwargs) 155 | loss = loss + loss2 156 | details |= details2 157 | 158 | return loss, details 159 | 160 | class Jointnorm_Regr3D (Criterion, MultiLoss): 161 | """ Ensure that all 3D points are correct. 162 | Asymmetric loss: view1 is supposed to be the anchor. 163 | 164 | P1 = RT1 @ D1 165 | P2 = RT2 @ D2 166 | loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1) 167 | loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2) 168 | = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2) 169 | gt and pred are transformed into localframe1 170 | """ 171 | 172 | def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, dist_clip=None): 173 | super().__init__(criterion) 174 | self.norm_mode = norm_mode 175 | self.gt_scale = gt_scale 176 | self.dist_clip = dist_clip 177 | 178 | def get_all_pts3d(self, gts, preds, ref_id, in_camera=None, norm_scale=None, dist_clip=None): 179 | # everything is normalized w.r.t. in_camera. 
180 | # pointcloud normalization is conducted with the distance from the origin if norm_scale is None, otherwise use a fixed norm_scale 181 | if in_camera is None: 182 | in_camera = inv(gts[ref_id]['camera_pose']) 183 | gt_pts = [] 184 | valids = [] 185 | for gt in gts: 186 | gt_pts.append(geotrf(in_camera, gt['pts3d'])) 187 | valids.append(gt['valid_mask'].clone()) 188 | 189 | dist_clip = self.dist_clip if dist_clip is None else dist_clip 190 | if dist_clip is not None: 191 | # points that are too far-away == invalid 192 | for i in range(len(gts)): 193 | dis = gt_pts[i].norm(dim=-1) 194 | valids[i] = valids[i] & (dis conf_loss = x / 10 + alpha*log(10) 245 | low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) 246 | 247 | alpha: hyperparameter 248 | """ 249 | 250 | def __init__(self, pixel_loss, alpha=1): 251 | super().__init__() 252 | assert alpha > 0 253 | self.alpha = alpha 254 | self.pixel_loss = pixel_loss.with_reduction('none') 255 | 256 | def get_name(self): 257 | return f'ConfLoss({self.pixel_loss})' 258 | 259 | def get_conf_log(self, x): 260 | return x, torch.log(x) 261 | 262 | def compute_loss(self, gts, preds, head='', **kw): 263 | # compute per-pixel loss 264 | losses_and_masks, details = self.pixel_loss(gts, preds, head=head, **kw) 265 | for i in range(len(losses_and_masks)): 266 | if losses_and_masks[i][0].numel() == 0: 267 | print(f'NO VALID POINTS in img{i+1}', force=True) 268 | 269 | res_loss = 0 270 | res_info = details 271 | for i in range(len(losses_and_masks)): 272 | loss = losses_and_masks[i][0] 273 | mask = losses_and_masks[i][1] 274 | conf, log_conf = self.get_conf_log(preds[i]['conf'][mask]) 275 | conf_loss = loss * conf - self.alpha * log_conf 276 | conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0 277 | res_loss += conf_loss 278 | info_name = f"conf_loss_{i+1}" if head == '' else f"conf_loss_{head}_{i+1}" 279 | res_info[info_name] = float(conf_loss) 280 | 281 | return res_loss, res_info 282 | -------------------------------------------------------------------------------- /slam3r/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # PatchEmbed implementation for DUST3R, 6 | # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio 7 | # -------------------------------------------------------- 8 | import torch 9 | from .blocks import PatchEmbed # noqa 10 | 11 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): 12 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] 13 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) 14 | return patch_embed 15 | 16 | 17 | class PatchEmbedDust3R(PatchEmbed): 18 | def forward(self, x, **kw): 19 | B, C, H, W = x.shape 20 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 21 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
22 | x = self.proj(x) 23 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 24 | if self.flatten: 25 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 26 | x = self.norm(x) 27 | return x, pos 28 | 29 | 30 | class ManyAR_PatchEmbed (PatchEmbed): 31 | """ Handle images with non-square aspect ratio. 32 | All images in the same batch have the same aspect ratio. 33 | true_shape = [(height, width) ...] indicates the actual shape of each image. 34 | """ 35 | 36 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 37 | self.embed_dim = embed_dim 38 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) 39 | 40 | def forward(self, img, true_shape): 41 | B, C, H, W = img.shape 42 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' 43 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 44 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 45 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" 46 | 47 | # size expressed in tokens 48 | W //= self.patch_size[0] 49 | H //= self.patch_size[1] 50 | n_tokens = H * W 51 | 52 | height, width = true_shape.T 53 | is_landscape = (width >= height) 54 | is_portrait = ~is_landscape 55 | 56 | # allocate result 57 | x = img.new_zeros((B, n_tokens, self.embed_dim)) 58 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) 59 | 60 | # linear projection, transposed if necessary 61 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() 62 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() 63 | 64 | pos[is_landscape] = self.position_getter(1, H, W, pos.device) 65 | pos[is_portrait] = self.position_getter(1, W, H, pos.device) 66 | 67 | x = self.norm(x) 68 | return x, pos 69 | -------------------------------------------------------------------------------- /slam3r/pos_embed/__init__.py: -------------------------------------------------------------------------------- 1 | from .pos_embed import get_2d_sincos_pos_embed, RoPE2D -------------------------------------------------------------------------------- /slam3r/pos_embed/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .curope2d import cuRoPE2D 5 | -------------------------------------------------------------------------------- /slam3r/pos_embed/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
4 | */ 5 | 6 | #include 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /slam3r/pos_embed/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 40 | return tokens -------------------------------------------------------------------------------- /slam3r/pos_embed/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /slam3r/pos_embed/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | # all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | "-arch=native" 17 | ] 18 | 19 | setup( 20 | name = 'curope', 21 | ext_modules = [ 22 | CUDAExtension( 23 | name='curope', 24 | sources=[ 25 | "curope.cpp", 26 | "kernels.cu", 27 | ], 28 | extra_compile_args = dict( 29 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 30 | cxx=['-O3']) 31 | ) 32 | ], 33 | cmdclass = { 34 | 'build_ext': BuildExtension 35 | }) 36 | -------------------------------------------------------------------------------- /slam3r/pos_embed/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | 5 | # -------------------------------------------------------- 6 | # Position embedding utils 7 | # -------------------------------------------------------- 8 | 9 | 10 | 11 | import numpy as np 12 | 13 | import torch 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 19 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 20 | # MoCo v3: https://github.com/facebookresearch/moco-v3 21 | # -------------------------------------------------------- 22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): 23 | """ 24 | grid_size: int of the grid height and width 25 | return: 26 | pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 27 | """ 28 | grid_h = np.arange(grid_size, dtype=np.float32) 29 | grid_w = np.arange(grid_size, dtype=np.float32) 30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 31 | grid = np.stack(grid, axis=0) 32 | 33 | grid = grid.reshape([2, 1, grid_size, grid_size]) 34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 35 | if n_cls_token>0: 36 | pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) 37 | return pos_embed 38 | 39 | 40 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 41 | assert embed_dim % 2 == 0 42 | 43 | # use half of dimensions to encode grid_h 44 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 45 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 46 | 47 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 48 | return emb 49 | 50 | 51 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 52 | """ 53 | embed_dim: output dimension for each position 54 | pos: a list of positions to be encoded: size (M,) 55 | out: (M, D) 56 | """ 57 | assert embed_dim % 2 == 0 58 | omega = np.arange(embed_dim // 2, dtype=float) 59 | omega /= embed_dim / 2. 60 | omega = 1. 
/ 10000**omega # (D/2,) 61 | 62 | pos = pos.reshape(-1) # (M,) 63 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 64 | 65 | emb_sin = np.sin(out) # (M, D/2) 66 | emb_cos = np.cos(out) # (M, D/2) 67 | 68 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 69 | return emb 70 | 71 | 72 | # -------------------------------------------------------- 73 | # Interpolate position embeddings for high-resolution 74 | # References: 75 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 76 | # DeiT: https://github.com/facebookresearch/deit 77 | # -------------------------------------------------------- 78 | def interpolate_pos_embed(model, checkpoint_model): 79 | if 'pos_embed' in checkpoint_model: 80 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 81 | embedding_size = pos_embed_checkpoint.shape[-1] 82 | num_patches = model.patch_embed.num_patches 83 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 84 | # height (== width) for the checkpoint position embedding 85 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 86 | # height (== width) for the new position embedding 87 | new_size = int(num_patches ** 0.5) 88 | # class_token and dist_token are kept unchanged 89 | if orig_size != new_size: 90 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 91 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 92 | # only the position tokens are interpolated 93 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 94 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 95 | pos_tokens = torch.nn.functional.interpolate( 96 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 97 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 98 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 99 | checkpoint_model['pos_embed'] = new_pos_embed 100 | 101 | 102 | #---------------------------------------------------------- 103 | # RoPE2D: RoPE implementation in 2D 104 | #---------------------------------------------------------- 105 | 106 | try: 107 | from .curope import cuRoPE2D 108 | RoPE2D = cuRoPE2D 109 | except ImportError as e: 110 | # print(e) 111 | print(f'Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') 112 | 113 | class RoPE2D(torch.nn.Module): 114 | 115 | def __init__(self, freq=100.0, F0=1.0): 116 | super().__init__() 117 | self.base = freq 118 | self.F0 = F0 119 | self.cache = {} 120 | 121 | def get_cos_sin(self, D, seq_len, device, dtype): 122 | if (D,seq_len,device,dtype) not in self.cache: 123 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) 124 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) 125 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) 126 | freqs = torch.cat((freqs, freqs), dim=-1) 127 | cos = freqs.cos() # (Seq, Dim) 128 | sin = freqs.sin() 129 | self.cache[D,seq_len,device,dtype] = (cos,sin) 130 | return self.cache[D,seq_len,device,dtype] 131 | 132 | @staticmethod 133 | def rotate_half(x): 134 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 135 | return torch.cat((-x2, x1), dim=-1) 136 | 137 | def apply_rope1d(self, tokens, pos1d, cos, sin): 138 | assert pos1d.ndim==2 139 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] 140 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] 141 | return (tokens * cos) + 
(self.rotate_half(tokens) * sin) 142 | 143 | def forward(self, tokens, positions): 144 | """ 145 | input: 146 | * tokens: batch_size x nheads x ntokens x dim 147 | * positions: batch_size x ntokens x 2 (y and x position of each token) 148 | output: 149 | * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) 150 | """ 151 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two" 152 | D = tokens.size(3) // 2 153 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 154 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) 155 | # split features into two along the feature dimension, and apply rope1d on each half 156 | y, x = tokens.chunk(2, dim=-1) 157 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin) 158 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin) 159 | tokens = torch.cat((y, x), dim=-1) 160 | return tokens -------------------------------------------------------------------------------- /slam3r/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /slam3r/utils/device.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def todevice(batch, device, callback=None, non_blocking=False): 12 | ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). 13 | 14 | batch: list, tuple, dict of tensors or other things 15 | device: pytorch device or 'numpy' 16 | callback: function that would be called on every sub-elements. 
17 |     '''
18 |     if callback:
19 |         batch = callback(batch)
20 | 
21 |     if isinstance(batch, dict):
22 |         return {k: todevice(v, device) for k, v in batch.items()}
23 | 
24 |     if isinstance(batch, (tuple, list)):
25 |         return type(batch)(todevice(x, device) for x in batch)
26 | 
27 |     x = batch
28 |     if device == 'numpy':
29 |         if isinstance(x, torch.Tensor):
30 |             x = x.detach().cpu().numpy()
31 |     elif x is not None:
32 |         if isinstance(x, np.ndarray):
33 |             x = torch.from_numpy(x)
34 |         if torch.is_tensor(x):
35 |             x = x.to(device, non_blocking=non_blocking)
36 |     return x
37 | 
38 | 
39 | to_device = todevice  # alias
40 | 
41 | 
42 | def to_numpy(x): return todevice(x, 'numpy')
43 | def to_cpu(x): return todevice(x, 'cpu')
44 | def to_cuda(x): return todevice(x, 'cuda')
45 | 
46 | 
47 | def collate_with_cat(whatever, lists=False):
48 |     if isinstance(whatever, dict):
49 |         return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50 | 
51 |     elif isinstance(whatever, (tuple, list)):
52 |         if len(whatever) == 0:
53 |             return whatever
54 |         elem = whatever[0]
55 |         T = type(whatever)
56 | 
57 |         if elem is None:
58 |             return None
59 |         if isinstance(elem, (bool, float, int, str)):
60 |             return whatever
61 |         if isinstance(elem, tuple):  # combine corresponding elements of the tuples contained in whatever: [(1,2),(3,4)] -> [(1,3),(2,4)]
62 |             return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63 |         if isinstance(elem, dict):  # combine the dicts contained in whatever into a single dict
64 |             return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65 | 
66 |         if isinstance(elem, torch.Tensor):
67 |             return listify(whatever) if lists else torch.cat(whatever)
68 |         if isinstance(elem, np.ndarray):
69 |             return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70 | 
71 |         # otherwise, we just chain lists
72 |         return sum(whatever, T())
73 | 
74 | 
75 | def listify(elems):
76 |     return [x for e in elems for x in e]
77 | 
78 | class MyNvtxRange():
79 |     def __init__(self, name):
80 |         self.name = name
81 | 
82 |     def __enter__(self):
83 |         torch.cuda.synchronize()
84 |         torch.cuda.nvtx.range_push(self.name)
85 | 
86 |     def __exit__(self, type, value, traceback):
87 |         torch.cuda.synchronize()
88 |         torch.cuda.nvtx.range_pop()
89 | 
--------------------------------------------------------------------------------
/slam3r/utils/image.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions about images (loading/converting...)
6 | # --------------------------------------------------------
7 | import os
8 | import torch
9 | import numpy as np
10 | import PIL.Image
11 | from tqdm import tqdm
12 | from PIL.ImageOps import exif_transpose
13 | import torchvision.transforms as tvf
14 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15 | import cv2  # noqa
16 | 
17 | try:
18 |     from pillow_heif import register_heif_opener  # noqa
19 |     register_heif_opener()
20 |     heif_support_enabled = True
21 | except ImportError:
22 |     heif_support_enabled = False
23 | 
24 | from .geometry import depthmap_to_camera_coordinates
25 | import json
26 | 
27 | ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
28 | 
29 | 
30 | def imread_cv2(path, options=cv2.IMREAD_COLOR):
31 |     """ Open an image or a depthmap with opencv-python.
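    Files ending in .exr/.EXR are read with cv2.IMREAD_ANYDEPTH; 3-channel
    images are converted from BGR to RGB before being returned.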
32 | """ 33 | if path.endswith(('.exr', 'EXR')): 34 | options = cv2.IMREAD_ANYDEPTH 35 | img = cv2.imread(path, options) 36 | if img is None: 37 | raise IOError(f'Could not load image={path} with {options=}') 38 | if img.ndim == 3: 39 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 40 | return img 41 | 42 | 43 | def rgb(ftensor, true_shape=None): 44 | if isinstance(ftensor, list): 45 | return [rgb(x, true_shape=true_shape) for x in ftensor] 46 | if isinstance(ftensor, torch.Tensor): 47 | ftensor = ftensor.detach().cpu().numpy() # H,W,3 48 | if ftensor.ndim == 3 and ftensor.shape[0] == 3: 49 | ftensor = ftensor.transpose(1, 2, 0) 50 | elif ftensor.ndim == 4 and ftensor.shape[1] == 3: 51 | ftensor = ftensor.transpose(0, 2, 3, 1) 52 | if true_shape is not None: 53 | H, W = true_shape 54 | ftensor = ftensor[:H, :W] 55 | if ftensor.dtype == np.uint8: 56 | img = np.float32(ftensor) / 255 57 | else: 58 | img = (ftensor * 0.5) + 0.5 59 | return img.clip(min=0, max=1) 60 | 61 | 62 | def _resize_pil_image(img, long_edge_size): 63 | S = max(img.size) 64 | if S > long_edge_size: 65 | interp = PIL.Image.LANCZOS 66 | elif S <= long_edge_size: 67 | interp = PIL.Image.BICUBIC 68 | new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) 69 | return img.resize(new_size, interp) 70 | 71 | 72 | def load_images(folder_or_list, size, square_ok=False, 73 | verbose=1, img_num=0, img_freq=0, 74 | postfix=None, start_idx=0): 75 | """ open and convert all images in a list or folder to proper input format for DUSt3R 76 | """ 77 | if isinstance(folder_or_list, str): 78 | if verbose > 0: 79 | print(f'>> Loading images from {folder_or_list}') 80 | img_names = [name for name in os.listdir(folder_or_list) if not "depth" in name] 81 | if postfix is not None: 82 | img_names = [name for name in img_names if name.endswith(postfix)] 83 | root, folder_content = folder_or_list, img_names 84 | 85 | elif isinstance(folder_or_list, list): 86 | if verbose > 0: 87 | print(f'>> Loading a list of {len(folder_or_list)} images') 88 | root, folder_content = '', folder_or_list 89 | 90 | else: 91 | raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') 92 | 93 | # sort images by number in name 94 | len_postfix = len(postfix) if postfix is not None \ 95 | else len(folder_content[0]) - folder_content[0].rfind('.') 96 | 97 | img_numbers = [] 98 | for name in folder_content: 99 | dot_index = len(name) - len_postfix 100 | number_start = 0 101 | for i in range(dot_index-1, 0, -1): 102 | if not name[i].isdigit(): 103 | number_start = i + 1 104 | break 105 | img_numbers.append(float(name[number_start:dot_index])) 106 | folder_content = [x for _, x in sorted(zip(img_numbers, folder_content))] 107 | 108 | if start_idx > 0: 109 | folder_content = folder_content[start_idx:] 110 | if(img_freq > 0): 111 | folder_content = folder_content[::img_freq] 112 | if(img_num > 0): 113 | folder_content = folder_content[:img_num] 114 | 115 | # print(root, folder_content) 116 | 117 | supported_images_extensions = ['.jpg', '.jpeg', '.png'] 118 | if heif_support_enabled: 119 | supported_images_extensions += ['.heic', '.heif'] 120 | supported_images_extensions = tuple(supported_images_extensions) 121 | 122 | imgs = [] 123 | if verbose > 0: 124 | folder_content = tqdm(folder_content, desc='Loading images') 125 | for path in folder_content: 126 | if not path.lower().endswith(supported_images_extensions): 127 | continue 128 | img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB') 129 | W1, H1 = img.size 130 | if size == 224: 131 
| # resize short side to 224 (then crop) 132 | img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1))) 133 | else: 134 | # resize long side to 512 135 | img = _resize_pil_image(img, size) 136 | W, H = img.size 137 | cx, cy = W//2, H//2 138 | if size == 224: 139 | half = min(cx, cy) 140 | img = img.crop((cx-half, cy-half, cx+half, cy+half)) 141 | else: 142 | halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 143 | if not (square_ok) and W == H: 144 | halfh = 3*halfw/4 145 | img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) 146 | 147 | W2, H2 = img.size 148 | if verbose > 1: 149 | print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') 150 | 151 | imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( 152 | [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)), label=path)) 153 | 154 | assert imgs, 'no images found at '+ root 155 | if verbose > 0: 156 | print(f' ({len(imgs)} images loaded)') 157 | return imgs 158 | 159 | 160 | 161 | def crop_and_resize(image, depthmap, intrinsics, long_size, rng=None, info=None, use_crop=False): 162 | """ This function: 163 | 1. crops the image so that its principal point ends up exactly at the center 164 | 2. chooses the orientation (landscape/portrait) of the target resolution to match the orientation of the image 165 | """ 166 | import slam3r.datasets.utils.cropping as cropping 167 | if not isinstance(image, PIL.Image.Image): 168 | image = PIL.Image.fromarray(image) 169 | 170 | W, H = image.size 171 | cx, cy = intrinsics[:2, 2].round().astype(int) 172 | if(use_crop): 173 | # downscale with lanczos interpolation so that image.size == resolution 174 | # cropping centered on the principal point 175 | min_margin_x = min(cx, W-cx) 176 | min_margin_y = min(cy, H-cy) 177 | assert min_margin_x > W/5, f'Bad principal point in view={info}' 178 | assert min_margin_y > H/5, f'Bad principal point in view={info}' 179 | # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) 180 | l, t = cx - min_margin_x, cy - min_margin_y 181 | r, b = cx + min_margin_x, cy + min_margin_y 182 | crop_bbox = (l, t, r, b) 183 | image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) 184 | 185 | # transpose the resolution if necessary 186 | W, H = image.size # new size 187 | scale = long_size / max(W, H) 188 | 189 | # high-quality Lanczos down-scaling 190 | target_resolution = np.array([W, H]) * scale 191 | 192 | image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) 193 | 194 | return image, depthmap, intrinsics 195 | 196 | 197 | def load_scannetpp_images_pts3dcam(folder_or_list, size, square_ok=False, verbose=True, img_num=0, img_freq=0): 198 | """ open and convert all images in a list or folder to proper input format for DUSt3R 199 | """ 200 | if isinstance(folder_or_list, str): 201 | if verbose: 202 | print(f'>> Loading images from {folder_or_list}') 203 | root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) 204 | 205 | elif isinstance(folder_or_list, list): 206 | if verbose: 207 | print(f'>> Loading a list of {len(folder_or_list)} images') 208 | root, folder_content = '', folder_or_list 209 | 210 | else: 211 | raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') 212 | 213 | if(img_freq > 0): 214 | folder_content = folder_content[1000::img_freq] 215 | if(img_num > 0): 216 | folder_content = folder_content[:img_num] 217 | 218 | supported_images_extensions = ['.jpg', '.jpeg', '.png'] 219 | if heif_support_enabled: 220 | supported_images_extensions += ['.heic', '.heif'] 221 | supported_images_extensions = 
tuple(supported_images_extensions) 222 | 223 | imgs = [] 224 | 225 | intrinsic_path = os.path.join(os.path.dirname(root), 'pose_intrinsic_imu.json') 226 | with open(intrinsic_path, 'r') as f: 227 | info = json.load(f) 228 | 229 | for path in folder_content: 230 | if not path.lower().endswith(supported_images_extensions): 231 | continue 232 | img_path = os.path.join(root, path) 233 | img = exif_transpose(PIL.Image.open(img_path)).convert('RGB') 234 | W1, H1 = img.size 235 | 236 | depth_path = img_path.replace('.jpg', '.png').replace('rgb','depth') 237 | depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED) 238 | depthmap = depthmap.astype(np.float32) / 1000. 239 | """ 240 | img (PIL, size (W, H)) and depthmap (numpy, shape (H, W)) follow different shape conventions 241 | """ 242 | # print(img.size, depthmap.shape) 243 | depthmap = cv2.resize(depthmap, (W1,H1), interpolation=cv2.INTER_CUBIC) 244 | # print(img.size, depthmap.shape) 245 | img_id = os.path.basename(img_path)[:-4] 246 | intrinsics = np.array(info[img_id]['intrinsic']) 247 | # print(img, depthmap, intrinsics) 248 | img, depthmap, intrinsics = crop_and_resize(img, depthmap, intrinsics, size) 249 | # print(img, depthmap, intrinsics) 250 | pts3d_cam, mask = depthmap_to_camera_coordinates(depthmap, intrinsics) 251 | pts3d_cam = pts3d_cam * mask[..., None] 252 | # print(pts3d_cam.shape) 253 | valid_mask = np.isfinite(pts3d_cam).all(axis=-1) 254 | W2, H2 = img.size 255 | if verbose: 256 | print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') 257 | 258 | imgs.append(dict(img=ImgNorm(img)[None], 259 | true_shape=np.int32([img.size[::-1]]), 260 | idx=len(imgs), 261 | instance=str(len(imgs)), 262 | pts3d_cam=pts3d_cam[None], 263 | valid_mask=valid_mask[None] 264 | )) 265 | # break 266 | 267 | assert imgs, 'no images found at '+root 268 | if verbose: 269 | print(f' (Found {len(imgs)} images)') 270 | return imgs 271 | 272 | -------------------------------------------------------------------------------- /slam3r/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # utility functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def fill_default_args(kwargs, func): 11 | import inspect # a bit hacky but it works reliably 12 | signature = inspect.signature(func) 13 | 14 | for k, v in signature.parameters.items(): 15 | if v.default is inspect.Parameter.empty: 16 | continue 17 | kwargs.setdefault(k, v.default) 18 | 19 | return kwargs 20 | 21 | 22 | def freeze_all_params(modules): 23 | for module in modules: 24 | try: 25 | for n, param in module.named_parameters(): 26 | param.requires_grad = False 27 | except AttributeError: 28 | # module is directly a parameter 29 | module.requires_grad = False 30 | 31 | 32 | def is_symmetrized(gt1, gt2): 33 | x = gt1['instance'] 34 | y = gt2['instance'] 35 | if len(x) == len(y) and len(x) == 1: 36 | return False # special case of batchsize 1 37 | ok = True 38 | for i in range(0, len(x), 2): 39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i]) 40 | return ok 41 | 42 | 43 | def flip(tensor): 44 | """ flip so that tensor[0::2] <=> tensor[1::2] """ 45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) 46 | 47 | 48 | def interleave(tensor1, tensor2): 49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) 50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) 51 | return res1, res2 52 | 53 | 54 | def transpose_to_landscape(head, activate=True): 55 | """ Predict in the correct aspect-ratio, 56 | then transpose the result in landscape 57 | and stack everything back together. 58 | """ 59 | def wrapper_no(decout, true_shape): 60 | B = len(true_shape) 61 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' 62 | H, W = true_shape[0].cpu().tolist() 63 | res = head(decout, (H, W)) 64 | return res 65 | 66 | def wrapper_yes(decout, true_shape): 67 | B = len(true_shape) 68 | # by definition, the batch is in landscape mode so W >= H 69 | H, W = int(true_shape.min()), int(true_shape.max()) 70 | 71 | height, width = true_shape.T 72 | is_landscape = (width >= height) 73 | is_portrait = ~is_landscape 74 | 75 | # true_shape = true_shape.cpu() 76 | if is_landscape.all(): 77 | return head(decout, (H, W)) 78 | if is_portrait.all(): 79 | return transposed(head(decout, (W, H))) 80 | 81 | # batch is a mix of both portrait & landscape 82 | def selout(ar): return [d[ar] for d in decout] 83 | l_result = head(selout(is_landscape), (H, W)) 84 | p_result = transposed(head(selout(is_portrait), (W, H))) 85 | 86 | # allocate full result 87 | result = {} 88 | for k in l_result | p_result: # iterate over the union of the two result dicts' keys 89 | x = l_result[k].new(B, *l_result[k].shape[1:]) 90 | x[is_landscape] = l_result[k] 91 | x[is_portrait] = p_result[k] 92 | result[k] = x 93 | 94 | return result 95 | 96 | return wrapper_yes if activate else wrapper_no 97 | 98 | 99 | def transposed(dic): 100 | return {k: v.swapaxes(1, 2) for k, v in dic.items()} 101 | 102 | 103 | def invalid_to_nans(arr, valid_mask, ndim=999): 104 | if valid_mask is not None: 105 | arr = arr.clone() 106 | arr[~valid_mask] = float('nan') 107 | if arr.ndim > ndim: 108 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 109 | return arr 110 | 111 | 112 | def invalid_to_zeros(arr, valid_mask, ndim=999): 113 | if valid_mask is not None: 114 | arr = arr.clone() 115 | arr[~valid_mask] = 0 116 | nnz = valid_mask.view(len(valid_mask), -1).sum(1) 117 | else: 118 | nnz = arr.numel() // len(arr) if len(arr) else 0 # number of points per image 
119 | if arr.ndim > ndim: 120 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 121 | return arr, nnz 122 | -------------------------------------------------------------------------------- /slam3r/viz.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Visualization utilities. The code is adapted from Spann3r: 3 | # https://github.com/HengyiWang/spann3r/blob/main/spann3r/tools/vis.py 4 | # -------------------------------------------------------- 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | import open3d as o3d 9 | import imageio 10 | import os 11 | import os.path as osp 12 | import matplotlib.pyplot as plt 13 | import matplotlib.colors as mcolors 14 | from skimage import exposure 15 | 16 | 17 | def render_scene(vis, geometry, camera_parameters, bg_color=[1,1,1], point_size=1., 18 | uint8=True): 19 | vis.clear_geometries() 20 | for g in geometry: 21 | vis.add_geometry(g) 22 | 23 | ctr = vis.get_view_control() 24 | ctr.convert_from_pinhole_camera_parameters(camera_parameters, allow_arbitrary=True) 25 | 26 | opt = vis.get_render_option() 27 | # set the rendered point size 28 | opt.point_size = point_size 29 | opt.background_color = np.array(bg_color) 30 | 31 | vis.poll_events() 32 | vis.update_renderer() 33 | 34 | image = vis.capture_screen_float_buffer(do_render=True) 35 | 36 | if not uint8: 37 | return image 38 | else: 39 | image_uint8 = (np.asarray(image) * 255).astype(np.uint8) 40 | return image_uint8 41 | 42 | def render_frames(pts_all, image_all, camera_parameters, output_dir, mask=None, save_video=True, save_camera=True, 43 | init_ids=[], 44 | c2ws=None, 45 | vis_cam=False, 46 | save_stride=1, 47 | sample_ratio=1., 48 | incremental=True, 49 | save_name='render_frames', 50 | bg_color=[1, 1, 1], 51 | point_size=1., 52 | fps=10): 53 | 54 | t, h, w, _ = pts_all.shape 55 | 56 | vis = o3d.visualization.Visualizer() 57 | vis.create_window(width=960, height=544) 58 | 59 | render_frame_path = os.path.join(output_dir, save_name) 60 | os.makedirs(render_frame_path, exist_ok=True) 61 | 62 | if save_camera: 63 | o3d.io.write_pinhole_camera_parameters(os.path.join(render_frame_path, 'camera.json'), camera_parameters) 64 | 65 | video_path = os.path.join(output_dir, f'{save_name}.mp4') 66 | if save_video: 67 | writer = imageio.get_writer(video_path, fps=fps) 68 | 69 | # construct point cloud for initial window 70 | pcd = o3d.geometry.PointCloud() 71 | if init_ids is None: init_ids = [] 72 | if len(init_ids) > 0: 73 | init_ids = np.array(init_ids) 74 | init_masks = mask[init_ids] 75 | init_pts = pts_all[init_ids][init_masks] 76 | init_colors = image_all[init_ids][init_masks] 77 | if sample_ratio < 1.: 78 | sampled_idx = np.random.choice(len(init_pts), int(len(init_pts)*sample_ratio), replace=False) 79 | init_pts = init_pts[sampled_idx] 80 | init_colors = init_colors[sampled_idx] 81 | 82 | pcd.points = o3d.utility.Vector3dVector(init_pts) 83 | pcd.colors = o3d.utility.Vector3dVector(init_colors) 84 | 85 | vis.add_geometry(pcd) 86 | 87 | # visualize incremental reconstruction 88 | for i in tqdm(range(t), desc="Rendering incremental reconstruction"): 89 | if i not in init_ids: 90 | new_pts = pts_all[i].reshape(-1, 3) 91 | new_colors = image_all[i].reshape(-1, 3) 92 | 93 | if mask is not None: 94 | new_pts = new_pts[mask[i].reshape(-1)] 95 | new_colors = new_colors[mask[i].reshape(-1)] 96 | if sample_ratio < 1.: 97 | sampled_idx = np.random.choice(len(new_pts), int(len(new_pts)*sample_ratio), replace=False) 98 | 
new_pts = new_pts[sampled_idx] 99 | new_colors = new_colors[sampled_idx] 100 | if incremental: 101 | pcd.points.extend(o3d.utility.Vector3dVector(new_pts)) 102 | pcd.colors.extend(o3d.utility.Vector3dVector(new_colors)) 103 | else: 104 | pcd.points = o3d.utility.Vector3dVector(new_pts) 105 | pcd.colors = o3d.utility.Vector3dVector(new_colors) 106 | 107 | if (i+1) % save_stride != 0: 108 | continue 109 | 110 | geometry = [pcd] 111 | if vis_cam: 112 | geometry = geometry + draw_camera(c2ws[i], img=image_all[i]) 113 | 114 | image_uint8 = render_scene(vis, geometry, camera_parameters, bg_color=bg_color, point_size=point_size) 115 | frame_filename = f'frame_{i:03d}.png' 116 | imageio.imwrite(osp.join(render_frame_path, frame_filename), image_uint8) 117 | if save_video: 118 | writer.append_data(image_uint8) 119 | 120 | if save_video: 121 | writer.close() 122 | 123 | vis.destroy_window() 124 | 125 | 126 | def create_image_plane(img, c2w, scale=0.1): 127 | # simulate a image in 3D with point cloud 128 | H, W, _ = img.shape 129 | points = np.meshgrid(np.linspace(0, W-1, W), np.linspace(0, H-1, H)) 130 | points = np.stack(points, axis=-1).reshape(-1, 2) 131 | #translate the center of focal 132 | points -= np.array([W/2, H/2]) 133 | 134 | points *= 2*scale/W 135 | points = np.concatenate([points, 0.1*np.ones((len(points), 1))], axis=-1) 136 | 137 | colors = img.reshape(-1, 3) 138 | 139 | # no need for such resolution 140 | sample_stride = max(1, int(0.2/scale)) 141 | points = points[::sample_stride] 142 | colors = colors[::sample_stride] 143 | 144 | pcd = o3d.geometry.PointCloud() 145 | pcd.points = o3d.utility.Vector3dVector(points) 146 | pcd.colors = o3d.utility.Vector3dVector(colors) 147 | pcd.transform(c2w) 148 | 149 | return pcd 150 | 151 | def draw_camera(c2w, cam_width=0.2/2, cam_height=0.2/2, f=0.10, color=[0, 1, 0], 152 | show_axis=True, img=None): 153 | points = [[0, 0, 0], [-cam_width, -cam_height, f], [cam_width, -cam_height, f], 154 | [cam_width, cam_height, f], [-cam_width, cam_height, f]] 155 | lines = [[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [2, 3], [3, 4], [4, 1]] 156 | colors = [color for i in range(len(lines))] 157 | 158 | line_set = o3d.geometry.LineSet() 159 | line_set.points = o3d.utility.Vector3dVector(points) 160 | line_set.lines = o3d.utility.Vector2iVector(lines) 161 | line_set.colors = o3d.utility.Vector3dVector(colors) 162 | line_set.transform(c2w) 163 | 164 | res = [line_set] 165 | 166 | if show_axis: 167 | axis = o3d.geometry.TriangleMesh.create_coordinate_frame() 168 | axis.scale(min(cam_width, cam_height), np.array([0., 0., 0.])) 169 | axis.transform(c2w) 170 | res.append(axis) 171 | 172 | if img is not None: 173 | # draw image in the plane of the camera 174 | img_plane = create_image_plane(img, c2w) 175 | res.append(img_plane) 176 | 177 | return res 178 | 179 | def find_render_cam(pcd, poses_all=None, cam_width=0.016, cam_height=0.012, cam_f=0.02): 180 | last_camera_params = None 181 | 182 | def print_camera_pose(vis): 183 | nonlocal last_camera_params 184 | ctr = vis.get_view_control() 185 | camera_params = ctr.convert_to_pinhole_camera_parameters() 186 | last_camera_params = camera_params 187 | 188 | print("Intrinsic matrix:") 189 | print(camera_params.intrinsic.intrinsic_matrix) 190 | print("\nExtrinsic matrix:") 191 | print(camera_params.extrinsic) 192 | 193 | return False 194 | 195 | vis = o3d.visualization.VisualizerWithKeyCallback() 196 | vis.create_window(width=960, height=544) 197 | vis.add_geometry(pcd) 198 | if poses_all is not None: 199 | for pose in 
poses_all: 200 | for geometry in draw_camera(pose, cam_width, cam_height, cam_f): 201 | vis.add_geometry(geometry) 202 | 203 | opt = vis.get_render_option() 204 | opt.point_size = 1 205 | opt.background_color = np.array([0, 0, 0]) 206 | 207 | print_camera_pose(vis) 208 | print("Press the space key to record the current rendering view.") 209 | vis.register_key_callback(32, print_camera_pose) 210 | 211 | while vis.poll_events(): 212 | vis.update_renderer() 213 | 214 | vis.destroy_window() 215 | 216 | return last_camera_params 217 | 218 | def vis_frame_preds(preds, type, save_path, norm_dims=(0, 1, 2), 219 | enhance_z=False, cmap=True, 220 | save_imgs=True, save_video=True, fps=10): 221 | 222 | if norm_dims is not None: 223 | min_val = preds.min(axis=norm_dims, keepdims=True) 224 | max_val = preds.max(axis=norm_dims, keepdims=True) 225 | preds = (preds - min_val) / (max_val - min_val) 226 | 227 | save_path = osp.join(save_path, type) 228 | if save_imgs: 229 | os.makedirs(save_path, exist_ok=True) 230 | 231 | if save_video: 232 | video_path = osp.join(osp.dirname(save_path), f'{type}.mp4') 233 | writer = imageio.get_writer(video_path, fps=fps) 234 | 235 | for frame_id in tqdm(range(preds.shape[0]), desc=f"Visualizing {type}"): 236 | pred_vis = preds[frame_id].astype(np.float32) 237 | if cmap: 238 | if preds.shape[-1] == 3: 239 | h = 1-pred_vis[...,0] 240 | s = 1-pred_vis[...,1] 241 | v = 1-pred_vis[...,2] 242 | if enhance_z: 243 | new_v = exposure.equalize_adapthist(v, clip_limit=0.01, nbins=256) 244 | v = new_v*0.2 + v*0.8 245 | pred_vis = mcolors.hsv_to_rgb(np.stack([h, s, v], axis=-1)) 246 | elif len(pred_vis.shape)==2 or pred_vis.shape[-1] == 1: 247 | pred_vis = plt.cm.jet(pred_vis) 248 | pred_vis_rgb_uint8 = (pred_vis * 255).astype(np.uint8) 249 | 250 | if save_imgs: 251 | plt.imsave(osp.join(save_path, f'{type}_{frame_id:04d}.png'), pred_vis_rgb_uint8) 252 | 253 | if save_video: 254 | writer.append_data(pred_vis_rgb_uint8) 255 | 256 | if save_video: 257 | writer.close() 258 | 259 | 260 | 261 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import argparse 4 | import trimesh 5 | import torch 6 | from glob import glob 7 | from os.path import join 8 | from tqdm import tqdm 9 | import json 10 | 11 | from slam3r.utils.recon_utils import estimate_focal_knowing_depth, estimate_camera_pose 12 | from slam3r.viz import find_render_cam, render_frames, vis_frame_preds 13 | 14 | parser = argparse.ArgumentParser(description="Inference on a wild captured scene") 15 | parser.add_argument("--vis_cam", action="store_true", help="visualize camera poses") 16 | parser.add_argument("--vis_dir", type=str, required=True, help="directory to the predictions for visualization") 17 | parser.add_argument("--save_stride", type=int, default=1, help="the stride for visualizing per-frame predictions") 18 | parser.add_argument("--enhance_z", action="store_true", help="enhance the z axis for better visualization") 19 | parser.add_argument("--conf_thres_l2w", type=float, default=12, help="confidence threshold for filter out low-confidence points in L2W") 20 | 21 | def vis(args): 22 | 23 | root_dir = args.vis_dir 24 | 25 | preds_dir = join(args.vis_dir, "preds") 26 | local_pcds = np.load(join(preds_dir, 'local_pcds.npy')) # (V, 224, 224, 3) 27 | registered_pcds = np.load(join(preds_dir, 'registered_pcds.npy')) # (V, 224, 224, 3) 28 | local_confs 
= np.load(join(preds_dir, 'local_confs.npy')) # (V, 224, 224) 29 | registered_confs = np.load(join(preds_dir, 'registered_confs.npy')) # (V, 224, 224) 30 | rgb_imgs = np.load(join(preds_dir, 'input_imgs.npy')) # (V, 224, 224, 3) 31 | 32 | rgb_imgs = rgb_imgs/255. 33 | 34 | recon_res_path = glob(join(args.vis_dir, "*.ply"))[0] 35 | recon_res = trimesh.load(recon_res_path) 36 | whole_pcd = recon_res.vertices 37 | whole_colors = recon_res.visual.vertex_colors[:, :3]/255. 38 | 39 | # change to open3d coordinate x->x y->-y z->-z 40 | whole_pcd[..., 1:] *= -1 41 | registered_pcds[..., 1:] *= -1 42 | 43 | recon_pcd = o3d.geometry.PointCloud() 44 | recon_pcd.points = o3d.utility.Vector3dVector(whole_pcd) 45 | recon_pcd.colors = o3d.utility.Vector3dVector(whole_colors) 46 | 47 | # extract information about the initial window in the reconstruction 48 | num_views = local_pcds.shape[0] 49 | with open(join(preds_dir, "metadata.json"), 'r') as f: 50 | metadata = json.load(f) 51 | init_winsize = metadata['init_winsize'] 52 | kf_stride = metadata['kf_stride'] 53 | init_ids = list(range(0, init_winsize*kf_stride, kf_stride)) 54 | init_ref_id = metadata['init_ref_id'] * kf_stride 55 | 56 | if args.vis_cam: 57 | # estimate camera intrinsics and poses 58 | principal_point = torch.tensor((local_pcds[0].shape[0]//2, local_pcds[0].shape[1]//2)) 59 | init_window_focal = estimate_focal_knowing_depth(torch.tensor(local_pcds[init_ref_id][None]), 60 | principal_point, 61 | focal_mode='weiszfeld') 62 | 63 | focals = [] 64 | for i in tqdm(range(num_views), desc="estimating intrinsics"): 65 | if i in init_ids: 66 | focals.append(init_window_focal) 67 | else: 68 | focal = estimate_focal_knowing_depth(torch.tensor(local_pcds[i:i+1]), 69 | principal_point, 70 | focal_mode='weiszfeld') 71 | focals.append(focal) 72 | 73 | intrinsics = [] 74 | for i in range(num_views): 75 | intrinsic = np.eye(3) 76 | intrinsic[0, 0] = focals[i] 77 | intrinsic[1, 1] = focals[i] 78 | intrinsic[:2, 2] = principal_point 79 | intrinsics.append(intrinsic) 80 | 81 | mean_intrinsics = np.mean(np.stack(intrinsics,axis=0), axis=0) 82 | init_window_intrinsics = intrinsics[init_ref_id] 83 | 84 | c2ws = [] 85 | for i in tqdm(range(0, num_views, 1), desc="estimating camera poses"): 86 | registered_pcd = registered_pcds[i] 87 | # c2w, succ = estimate_camera_pose(registered_pcd, init_window_intrinsics) 88 | c2w, succ = estimate_camera_pose(registered_pcd, mean_intrinsics) 89 | # c2w, succ = estimate_camera_pose(registered_pcd, intrinsics[i]) 90 | if not succ: 91 | print(f"fail to estimate camera pose for view {i}") 92 | c2ws.append(c2w) 93 | 94 | # find the camera parameters for rendering incremental reconstruction process 95 | # It will show a window of open3d, and you can rotate and translate the camera 96 | # press space to save the camera parameters selected 97 | camera_parameters = find_render_cam(recon_pcd, c2ws if args.vis_cam else None) 98 | # render the incremental reconstruction process 99 | render_frames(registered_pcds, rgb_imgs, camera_parameters, root_dir, 100 | mask=(registered_confs > args.conf_thres_l2w), 101 | init_ids=init_ids, 102 | c2ws=c2ws if args.vis_cam else None, 103 | sample_ratio=1/args.save_stride, 104 | save_stride=args.save_stride, 105 | fps=10, 106 | vis_cam=args.vis_cam, 107 | ) 108 | 109 | # save visualizations of per-frame predictions, and combine them into a video 110 | vis_frame_preds(local_confs[::args.save_stride], type="I2P_conf", 111 | save_path=root_dir) 112 | vis_frame_preds(registered_confs[::args.save_stride], 
type="L2W_conf", 113 | save_path=root_dir) 114 | vis_frame_preds(local_pcds[::args.save_stride], type="I2P_pcds", 115 | save_path=root_dir, 116 | enhance_z=args.enhance_z 117 | ) 118 | vis_frame_preds(registered_pcds[::args.save_stride], type="L2W_pcds", 119 | save_path=root_dir, 120 | ) 121 | vis_frame_preds(rgb_imgs[::args.save_stride], type="imgs", 122 | save_path=root_dir, 123 | norm_dims=None, 124 | cmap=False 125 | ) 126 | 127 | if __name__ == "__main__": 128 | args = parser.parse_args() 129 | vis(args) --------------------------------------------------------------------------------