├── .gitignore
├── README.md
├── app.py
├── datasets_preprocess
│   ├── preprocess_ase.py
│   ├── preprocess_co3d.py
│   └── preprocess_scannetpp.py
├── docs
│   ├── data_preprocess.md
│   └── recon_tips.md
├── evaluation
│   ├── eval_recon.py
│   └── process_gt.py
├── media
│   ├── gradio.png
│   ├── gradio_office.gif
│   ├── gradio_office.jpg
│   ├── replica.gif
│   └── wild.gif
├── recon.py
├── requirements.txt
├── requirements_optional.txt
├── scripts
│   ├── demo_replica.sh
│   ├── demo_vis_wild.sh
│   ├── demo_wild.sh
│   ├── eval_replica.sh
│   ├── train_i2p.sh
│   └── train_l2w.sh
├── slam3r
│   ├── __init__.py
│   ├── blocks
│   │   ├── __init__.py
│   │   ├── basic_blocks.py
│   │   └── multiview_blocks.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── base_stereo_view_dataset.py
│   │   │   ├── batched_sampler.py
│   │   │   └── easy_dataset.py
│   │   ├── co3d_seq.py
│   │   ├── project_aria_seq.py
│   │   ├── replica_seq.py
│   │   ├── scannetpp_seq.py
│   │   ├── seven_scenes_seq.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── cropping.py
│   │   │   └── transforms.py
│   │   └── wild_seq.py
│   ├── heads
│   │   ├── __init__.py
│   │   ├── dpt_block.py
│   │   ├── dpt_head.py
│   │   ├── linear_head.py
│   │   └── postprocess.py
│   ├── inference.py
│   ├── losses.py
│   ├── models.py
│   ├── patch_embed.py
│   ├── pos_embed
│   │   ├── __init__.py
│   │   ├── curope
│   │   │   ├── __init__.py
│   │   │   ├── curope.cpp
│   │   │   ├── curope2d.py
│   │   │   ├── kernels.cu
│   │   │   └── setup.py
│   │   └── pos_embed.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── croco_misc.py
│   │   ├── device.py
│   │   ├── geometry.py
│   │   ├── image.py
│   │   ├── misc.py
│   │   └── recon_utils.py
│   └── viz.py
├── train.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | checkpoints
3 | results
4 | stuffs
5 | recon
6 | visualization
7 | replica_gt
8 | tmp
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2025 Highlight] SLAM3R: Real-Time Dense Scene Reconstruction from Monocular RGB Videos
2 |
3 | Yuzheng Liu* · Siyan Dong* · Shuzhe Wang · Yingda Yin · Yanchao Yang · Qingnan Fan · Baoquan Chen
4 |
41 | SLAM3R is a real-time dense scene reconstruction system that regresses 3D points from video frames using feed-forward neural networks, without explicitly estimating camera parameters.
42 |
43 |
44 |
45 |
46 | ## News
47 |
48 | * **2025-04:** SLAM3R is covered by [机器之心 (in Chinese)](https://mp.weixin.qq.com/s/fK5vJwbogcfwoduI9FuQ6w).
49 |
50 | * **2025-04:** 🎉 SLAM3R is selected as a **highlight paper** at CVPR 2025 and a **top-1 paper** at China3DV 2025.
51 |
52 | ## Table of Contents
53 |
54 | - [Installation](#installation)
55 | - [Demo](#demo)
56 | - [Gradio interface](#gradio-interface)
57 | - [Evaluation on the Replica dataset](#evaluation-on-the-replica-dataset)
58 | - [Training](#training)
59 | - [Citation](#citation)
60 | - [Acknowledgments](#acknowledgments)
61 |
62 | ## Installation
63 |
64 | 1. Clone SLAM3R
65 | ```bash
66 | git clone https://github.com/PKU-VCL-3DV/SLAM3R.git
67 | cd SLAM3R
68 | ```
69 |
70 | 2. Prepare environment
71 | ```bash
72 | conda create -n slam3r python=3.11 cmake=3.14.0
73 | conda activate slam3r
74 | # install torch according to your cuda version
75 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu118
76 | pip install -r requirements.txt
77 | # optional: install additional packages to support visualization and data preprocessing
78 | pip install -r requirements_optional.txt
79 | ```
80 |
81 | 3. Optional: Accelerate SLAM3R with XFormers and custom CUDA kernels for RoPE (a quick availability check is sketched at the end of this section)
82 | ```bash
83 | # install XFormers according to your pytorch version, see https://github.com/facebookresearch/xformers
84 | pip install xformers==0.0.28.post2
85 | # compile cuda kernels for RoPE
86 | # if the compilation fails, try the proposed solution: https://github.com/CUT3R/CUT3R/issues/7.
87 | cd slam3r/pos_embed/curope/
88 | python setup.py build_ext --inplace
89 | cd ../../../
90 | ```
91 |
92 | 4. Optional: Download the SLAM3R checkpoints for the [Image-to-Points](https://huggingface.co/siyan824/slam3r_i2p) and [Local-to-World](https://huggingface.co/siyan824/slam3r_l2w) models through HuggingFace
93 | ```python
94 | from slam3r.models import Image2PointsModel, Local2WorldModel
95 | Image2PointsModel.from_pretrained('siyan824/slam3r_i2p')
96 | Local2WorldModel.from_pretrained('siyan824/slam3r_l2w')
97 | ```
98 | The pre-trained model weights will automatically download when running the demo and evaluation code below.
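
To confirm whether the optional acceleration from step 3 is active, you can run a minimal check like the sketch below. It only mirrors the fallback logic in `slam3r/blocks/multiview_blocks.py` (xformers attention is used when the package is installed and the tensors are on a GPU); nothing here is required to run SLAM3R.

```python
# Minimal environment check (a sketch, not part of the official tooling).
import torch

try:
    import xformers.ops  # noqa: F401  # memory-efficient attention backend
    has_xformers = True
except ImportError:
    has_xformers = False

print("CUDA available:    ", torch.cuda.is_available())
print("xformers available:", has_xformers)
# If both are True, the multiview decoder blocks use xformers' memory_efficient_attention;
# otherwise SLAM3R falls back to its plain PyTorch attention implementation.
```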
99 |
100 |
101 | ## Demo
102 | ### Replica dataset
103 | To run our demo on the Replica dataset, download the sample scene [here](https://drive.google.com/file/d/1NmBtJ2A30qEzdwM0kluXJOp2d1Y4cRcO/view?usp=drive_link) and unzip it to `./data/Replica_demo/`. Then run the following command to reconstruct the scene from the video frames:
104 |
105 | ```bash
106 | bash scripts/demo_replica.sh
107 | ```
108 |
109 | The results will be stored at `./results/` by default.
110 |
111 | ### Self-captured outdoor data
112 | We also provide a set of images extracted from an in-the-wild captured video. Download it [here](https://drive.google.com/file/d/1FVLFXgepsqZGkIwg4RdeR5ko_xorKyGt/view?usp=drive_link) and unzip it to `./data/wild/`.
113 |
114 | Set the required parameters in this [script](./scripts/demo_wild.sh), then run SLAM3R with the following command:
115 |
116 | ```bash
117 | bash scripts/demo_wild.sh
118 | ```
119 |
120 | When `--save_preds` is set in the script, the per-frame predictions used for reconstruction will be saved at `./results/TEST_NAME/preds/`. You can then visualize the incremental reconstruction process with the following command:
121 |
122 | ```bash
123 | bash scripts/demo_vis_wild.sh
124 | ```
125 |
126 | An Open3D window will appear after running the script. Adjust the view, press the `space` key to record it, and close the window. The script will then render the incremental reconstruction from the recorded view.
127 |
128 | You can run SLAM3R on your own captured video by following the steps above. Here are [some tips](./docs/recon_tips.md) for doing so.
129 |
130 |
131 | ## Gradio interface
132 | We also provide a Gradio interface, where you can upload a directory, a video, or individual images to perform the reconstruction. After setting the reconstruction parameters, click the 'Run' button to start the process. Modifying the visualization parameters at the bottom lets you display different visualization results directly, without rerunning the inference.
133 |
134 | The interface can be launched with the following command:
135 |
136 | ```bash
137 | python app.py
138 | ```
139 |
140 | Here is a demo GIF for the Gradio interface (accelerated).
141 |
142 |
143 |
144 |
145 | ## Evaluation on the Replica dataset
146 |
147 | 1. Download the Replica dataset generated by the authors of iMAP:
148 | ```bash
149 | cd data
150 | wget https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip
151 | unzip Replica.zip
152 | rm -rf Replica.zip
153 | ```
154 |
155 | 2. Obtain the GT pointmaps and valid masks for each frame by running the following command:
156 | ```bash
157 | python evaluation/process_gt.py
158 | ```
159 | The processed GT will be saved at `./results/gt/replica` (a quick sanity-check sketch is given after this list).
160 |
161 | 3. Evaluate the reconstruction on the Replica dataset with the following command:
162 |
163 | ```bash
164 | bash ./scripts/eval_replica.sh
165 | ```
166 |
167 | Both the numerical results and the error heatmaps will be saved in the directory `./results/TEST_NAME/eval/`.
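
To double-check the ground truth produced in step 2, you can load one scene's arrays and inspect their shapes. This is only an optional sketch; the file names and shapes follow `evaluation/process_gt.py` (per-frame 224×224 pointmaps and boolean valid masks).

```python
# Optional sanity check (a sketch): inspect the processed Replica GT for one scene.
import numpy as np

pcds = np.load("results/gt/replica/office0_pcds.npy")          # (num_frames, 224, 224, 3)
masks = np.load("results/gt/replica/office0_valid_masks.npy")  # (num_frames, 224, 224), bool
print("pointmaps:", pcds.shape, "valid masks:", masks.shape,
      f"valid ratio: {masks.mean():.2%}")
```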
168 |
169 | > [!NOTE]
170 | > Different versions of CUDA, PyTorch, and xformers can lead to slight variations in the predicted point cloud. These differences may be amplified during the alignment process in evaluation. Consequently, the numerical results you obtain might differ from those reported in the paper. However, the average values should remain approximately the same.
171 |
172 | ## Training
173 |
174 | ### Datasets
175 |
176 | We use ScanNet++, Aria Synthetic Environments, and CO3Dv2 to train our models. For data downloading and pre-processing, please refer to [here](./docs/data_preprocess.md).
177 |
178 | ### Pretrained weights
179 |
180 | ```bash
181 | # download the pretrained weights from DUSt3R
182 | mkdir checkpoints
183 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth -P checkpoints/
184 | ```
185 |
186 | ### Start training
187 |
188 | ```bash
189 | # train the Image-to-Points model and the retrieval module
190 | bash ./scripts/train_i2p.sh
191 | # train the Local-to-World model
192 | bash ./scripts/train_l2w.sh
193 | ```
194 | > [!NOTE]
195 | > These training scripts are not strictly equivalent to the ones used to train the released SLAM3R models, but they should be close enough.
196 |
197 |
198 | ## Citation
199 |
200 | If you find our work helpful in your research, please consider citing:
201 | ```bibtex
202 | @article{slam3r,
203 | title={SLAM3R: Real-Time Dense Scene Reconstruction from Monocular RGB Videos},
204 | author={Liu, Yuzheng and Dong, Siyan and Wang, Shuzhe and Yin, Yingda and Yang, Yanchao and Fan, Qingnan and Chen, Baoquan},
205 | journal={arXiv preprint arXiv:2412.09401},
206 | year={2024}
207 | }
208 | ```
209 |
210 |
211 | ## Acknowledgments
212 |
213 | Our implementation is based on several awesome repositories:
214 |
215 | - [Croco](https://github.com/naver/croco)
216 | - [DUSt3R](https://github.com/naver/dust3r)
217 | - [NICER-SLAM](https://github.com/cvg/nicer-slam)
218 | - [Spann3R](https://github.com/HengyiWang/spann3r)
219 |
220 | We thank the respective authors for open-sourcing their code.
221 |
222 |
--------------------------------------------------------------------------------
/datasets_preprocess/preprocess_ase.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # --------------------------------------------------------
3 | # Script to pre-process the aria-ase dataset
4 | # Usage:
5 | # 1. Prepare the codebase and environment for the projectaria_tools
6 | # 2. Copy this script to the projectaria_tools root directory
7 | # 3. Run the script
8 | # --------------------------------------------------------
9 | import matplotlib.colors as colors
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import plotly.graph_objects as go
13 | from pathlib import Path
14 | import os
15 | from PIL import Image
16 | from scipy.spatial.transform import Rotation as R
17 | from projectaria_tools.projects import ase
18 | from projectaria_tools.core import data_provider, calibration
19 | from projectaria_tools.core.image import InterpolationMethod
20 | from projects.AriaSyntheticEnvironment.tutorial.code_snippets.readers import read_trajectory_file
21 | import cv2
22 | from tqdm import tqdm
23 | import os, sys, json
24 | import open3d as o3d
25 | import random
26 |
27 |
28 | def save_pointcloud(points_3d_array, rgb, pcd_name):
29 | # Flatten the instance values array
30 | rgb_values_flat = rgb
31 |
32 | # Check if the number of points matches the number of instance values
33 | assert points_3d_array.shape[0] == rgb_values_flat.shape[0], "The number of points must match the number of instance values"
34 |
35 | # Create an Open3D point cloud object
36 | pcd = o3d.geometry.PointCloud()
37 |
38 | # Assign the 3D points to the point cloud object
39 | pcd.points = o3d.utility.Vector3dVector(points_3d_array)
40 |
41 | # Assign the colors to the point cloud
42 | pcd.colors = o3d.utility.Vector3dVector(rgb_values_flat / 255.0) # Normalize colors to [0, 1]
43 |
44 | # Define the file path where you want to save the point cloud
45 | output_file_path = pcd_name+'.pcd'
46 |
47 | # Save the point cloud in PCD format
48 | o3d.io.write_point_cloud(output_file_path, pcd)
49 |
50 | print(f"Point cloud saved to {output_file_path}")
51 |
52 |
53 | def unproject(camera_params, undistorted_depth,undistorted_rgb):
54 | # Get the height and width of the depth image
55 | height, width = undistorted_depth.shape
56 |
57 | # Generate pixel coordinates
58 | y, x = np.indices((height, width))
59 | pixel_coords = np.stack((x, y), axis=-1).reshape(-1, 2)
60 |
61 | # Flatten the depth image to create a 1D array of depth values
62 | depth_values_flat = undistorted_depth.flatten()
63 | rgb_values_flat = undistorted_rgb.reshape(-1,3)
64 |
65 | # Initialize an array to store 3D points
66 | points_3d = []
67 | valid_rgb = []
68 |
69 | for pixel_coord, depth, rgb in zip(pixel_coords, depth_values_flat, rgb_values_flat):
70 | # Format the pixel coordinate for unproject (reshape to [2, 1])
71 | pixel_coord_reshaped = np.array([[pixel_coord[0]], [pixel_coord[1]]], dtype=np.float64)
72 |
73 | # Unproject the pixel to get the direction vector (ray)
74 | # direction_vector = device.unproject(pixel_coord_reshaped)
75 | X = (pixel_coord_reshaped[0] - camera_params[2]) / camera_params[0] # X = (u - cx) / fx
76 | Y = (pixel_coord_reshaped[1] - camera_params[3]) / camera_params[1] # Y = (v - cy) / fy
77 | direction_vector = np.array([X[0], Y[0], 1],dtype=np.float32)
78 | if direction_vector is not None:
79 | # Replace the z-value of the direction vector with the depth value
80 | # Assuming the direction vector is normalized
81 | direction_vector_normalized = direction_vector / np.linalg.norm(direction_vector)
82 | point_3d = direction_vector_normalized * (depth / 1000)
83 |
84 | # Append the computed 3D point and the corresponding instance
85 | points_3d.append(point_3d.flatten())
86 | valid_rgb.append(rgb)
87 |
88 | # Convert the list of 3D points to a numpy array
89 | points_3d_array = np.array(points_3d)
90 | points_rgb = np.array(valid_rgb)
91 | return points_3d_array,points_rgb
92 |
93 | def distance_to_depth(K, dist, uv=None):
94 | if uv is None and len(dist.shape) >= 2:
95 | # create mesh grid according to d
96 | uv = np.stack(np.meshgrid(np.arange(dist.shape[1]), np.arange(dist.shape[0])), -1)
97 | uv = uv.reshape(-1, 2)
98 | dist = dist.reshape(-1)
99 | if not isinstance(dist, np.ndarray):
100 | import torch
101 | uv = torch.from_numpy(uv).to(dist)
102 | if isinstance(dist, np.ndarray):
103 | # z * np.sqrt(x_temp**2+y_temp**2+z_temp**2) = dist
104 | uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1)
105 | uvh = uvh.T # N, 3
106 | temp_point = np.linalg.inv(K) @ uvh # 3, N
107 | temp_point = temp_point.T # N, 3
108 | z = dist / np.linalg.norm(temp_point, axis=1)
109 | else:
110 | uvh = torch.cat([uv, torch.ones(len(uv), 1).to(uv)], -1)
111 | temp_point = torch.inverse(K) @ uvh
112 | z = dist / torch.linalg.norm(temp_point, dim=1)
113 | return z
114 |
115 | def transform_3d_points(transform, points):
116 | N = len(points)
117 | points_h = np.concatenate([points, np.ones((N, 1))], axis=1)
118 | transformed_points_h = (transform @ points_h.T).T
119 | transformed_points = transformed_points_h[:, :-1]
120 | return transformed_points
121 |
122 |
123 | def aria_export_to_scannet(scene_id, seed):
124 | random.seed(int(seed + scene_id))
125 | src_folder = Path("ase_raw/"+str(scene_id))
126 | trgt_folder = Path("ase_processed/"+str(scene_id))
127 | trgt_folder.mkdir(parents=True, exist_ok=True)
128 | SCENE_ID = src_folder.stem
129 | print("SCENE_ID:", SCENE_ID)
130 |
131 | scene_max_depth = 0
132 | scene_min_depth = np.inf
133 | Path(trgt_folder, "intrinsic").mkdir(exist_ok=True)
134 | Path(trgt_folder, "pose").mkdir(exist_ok=True)
135 | Path(trgt_folder, "depth").mkdir(exist_ok=True)
136 | Path(trgt_folder, "color").mkdir(exist_ok=True)
137 |
138 | rgb_dir = src_folder / "rgb"
139 | depth_dir = src_folder / "depth"
140 | # Load camera calibration
141 | device = ase.get_ase_rgb_calibration()
142 | # Load the trajectory using read_trajectory_file()
143 | trajectory_path = src_folder / "trajectory.csv"
144 | trajectory = read_trajectory_file(trajectory_path)
145 | all_points_3d = []
146 | all_rgb = []
147 | num_frames = len(list(rgb_dir.glob("*.jpg")))
148 | # Path('./debug').mkdir(exist_ok=True)
149 | for frame_idx in tqdm(range(num_frames)):
150 | frame_id = str(frame_idx).zfill(7)
151 | rgb_path = rgb_dir / f"vignette{frame_id}.jpg"
152 | depth_path = depth_dir / f"depth{frame_id}.png"
153 | depth = Image.open(depth_path) # uint16
154 | rgb = cv2.imread(str(rgb_path), cv2.IMREAD_UNCHANGED)
155 | depth = np.array(depth)
156 | scene_min_depth = min(depth.min(), scene_min_depth)
157 | inf_value = np.iinfo(np.array(depth).dtype).max
158 | depth[depth == inf_value] = 0 # consider it as invalid, inplace with 0
159 | T_world_from_device = trajectory["Ts_world_from_device"][frame_idx] # camera-to-world
160 | assert device.get_image_size()[0] == 704
161 | # https://facebookresearch.github.io/projectaria_tools/docs/data_utilities/advanced_code_snippets/image_utilities
162 | focal_length = device.get_focal_lengths()[0]
163 | pinhole = calibration.get_linear_camera_calibration(
164 | 512,
165 | 512,
166 | focal_length,
167 | "camera-rgb",
168 | device.get_transform_device_camera() # important to get correct transformation matrix in pinhole_cw90
169 | )
170 | # rectify (undistort) the RGB image to the linear pinhole calibration
171 | rectified_rgb = calibration.distort_by_calibration(np.array(rgb), pinhole, device, InterpolationMethod.BILINEAR)
172 | # raw_image = np.array(depth) # Will not work
173 | depth = np.array(depth).astype(np.float32) # WILL WORK
174 | rectified_depth = calibration.distort_by_calibration(depth, pinhole, device)
175 |
176 | rotated_image = np.rot90(rectified_rgb, k=3)
177 | rotated_depth = np.rot90(rectified_depth, k=3)
178 |
179 | cv2.imwrite(str(Path(trgt_folder, "color", f"{frame_id}.jpg")), rotated_image)
180 | # # TODO: check this
181 | # plt.imsave(Path(f"./debug/debug_undistort_{frame_id}.png"), np.uint16(rotated_depth), cmap="plasma")
182 | # Get rotated image calibration
183 | pinhole_cw90 = calibration.rotate_camera_calib_cw90deg(pinhole)
184 | principal = pinhole_cw90.get_principal_point()
185 | cx, cy = principal[0], principal[1]
186 | focal_lengths = pinhole_cw90.get_focal_lengths()
187 | fx, fy = focal_lengths
188 | K = np.array([ # camera-to-pixel
189 | [fx, 0, cx],
190 | [0, fy, cy],
191 | [0, 0, 1.0]])
192 |
193 | c2w = T_world_from_device
194 | c2w_rotation = pinhole_cw90.get_transform_device_camera().to_matrix()
195 | c2w_final = c2w @ c2w_rotation # right-matmul!
196 | cam2world = c2w_final
197 |
198 | # save depth
199 | rotated_depth = np.uint16(rotated_depth)
200 | depth_image = Image.fromarray(rotated_depth, mode='I;16')
201 | depth_image.save(str(Path(trgt_folder, "depth", f"{frame_id}.png")))
202 | # for debug; load depth and convert to pointcloud
203 | # depth_image = np.array(Image.open(str(Path(trgt_folder, "depth", f"{frame_id}.png"))), dtype=np.uint16)
204 | # points_3d_array, points_rgb = unproject((fx, fy, cx, cy), depth_image, rotated_image)
205 | # points_3d_world = transform_3d_points(cam2world, points_3d_array)
206 | # all_points_3d.append(points_3d_world)
207 | # all_rgb.append(points_rgb)
208 | # distance-to-depth
209 | # rotated_depth = distance_to_depth(K, rotated_depth).reshape((rotated_depth.shape[0], rotated_depth.shape[1]))#.reshape((dpt.shape[0], dpt.shape[1]))
210 |
211 | Path(trgt_folder, "intrinsic", "intrinsic_color.txt").write_text(f"""{K[0][0]} {K[0][1]} {K[0][2]} 0.00\n{K[1][0]} {K[1][1]} {K[1][2]} 0.00\n{K[2][0]} {K[2][1]} {K[2][2]} 0.00\n0.00 0.00 0.00 1.00""")
212 | Path(trgt_folder, "pose", f"{frame_id}.txt").write_text(f"""{cam2world[0, 0]} {cam2world[0, 1]} {cam2world[0, 2]} {cam2world[0, 3]}\n{cam2world[1, 0]} {cam2world[1, 1]} {cam2world[1, 2]} {cam2world[1, 3]}\n{cam2world[2, 0]} {cam2world[2, 1]} {cam2world[2, 2]} {cam2world[2, 3]}\n0.00 0.00 0.00 1.00""")
213 |
214 |
215 |
216 | if __name__ == "__main__":
217 | seed = 42
218 | for scene_id in tqdm(range(0, 500)):
219 | aria_export_to_scannet(scene_id=scene_id, seed = seed)
220 |
221 |
222 |
--------------------------------------------------------------------------------
/docs/data_preprocess.md:
--------------------------------------------------------------------------------
1 | ### ScanNet++
2 |
3 | 1. Download the [dataset](https://kaldir.vc.in.tum.de/scannetpp/), extract RGB frames and masks from the iPhone data following the [official instruction](https://github.com/scannetpp/scannetpp).
4 |
5 | 2. Preprocess the data with the following command:
6 |
7 | ```bash
8 | python datasets_preprocess/preprocess_scannetpp.py \
9 | --scannetpp_dir $SCANNETPP_DATA_ROOT\
10 | --output_dir data/scannetpp_processed
11 | ```
12 |
13 | The processed data will be saved at `./data/scannetpp_processed`.
14 |
15 | > We currently use only ScanNet++ V1 (280 scenes in total) to train and validate our SLAM3R models. ScanNet++ V2 (906 scenes) is available for potential use, but you may need to modify the scripts for certain scenes in it.
16 |
17 | ### Aria Synthetic Environments
18 |
19 | For more details, please refer to the [official website](https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_synthetic_environments_dataset)
20 |
21 | 1. Prepare the codebase and environment
22 | ```bash
23 | mkdir data/projectaria
24 | cd data/projectaria
25 | git clone https://github.com/facebookresearch/projectaria_tools.git -b 1.5.7
26 | cd -
27 | conda create -n aria python=3.10
28 | conda activate aria
29 | pip install projectaria-tools'[all]' opencv-python open3d
30 | ```
31 |
32 | 2. Get the download-urls file [here](https://www.projectaria.com/datasets/ase/) and place it under `./data/projectaria/projectaria_tools`. Then download the ASE dataset:
33 | ```bash
34 | cd ./data/projectaria/projectaria_tools
35 | python projects/AriaSyntheticEnvironment/aria_synthetic_environments_downloader.py \
36 | --set train \
37 | --scene-ids 0-499 \
38 | --unzip True \
39 | --cdn-file aria_synthetic_environments_dataset_download_urls.json \
40 | --output-dir $SLAM3R_DIR/data/projectaria/ase_raw
41 | ```
42 |
43 | > We currently use only the first 500 scenes to train and validate our SLAM3R models. You can leverage more scenes depending on your resources.
44 |
45 | 3. Preprocess the data:
46 | ```bash
47 | cp ./datasets_preprocess/preprocess_ase.py ./data/projectaria/projectaria_tools/
48 | cd ./data/projectaria
49 | python projectaria_tools/preprocess_ase.py
50 | ```
51 | The processed data will be saved at `./data/projectaria/ase_processed`.
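
As an optional sanity check (a sketch, not part of the official pipeline), you can verify that every processed scene contains matching `color`/`depth`/`pose` files plus an intrinsics file, which is the layout produced by `datasets_preprocess/preprocess_ase.py`:

```python
# Check the processed ASE scenes for a consistent ScanNet-style layout.
from pathlib import Path

root = Path("data/projectaria/ase_processed")
for scene in sorted(p for p in root.iterdir() if p.is_dir()):
    n_color = len(list((scene / "color").glob("*.jpg")))
    n_depth = len(list((scene / "depth").glob("*.png")))
    n_pose = len(list((scene / "pose").glob("*.txt")))
    has_K = (scene / "intrinsic" / "intrinsic_color.txt").exists()
    ok = (n_color == n_depth == n_pose) and has_K
    print(f"{scene.name}: color={n_color} depth={n_depth} pose={n_pose} "
          f"intrinsics={has_K} -> {'OK' if ok else 'MISMATCH'}")
```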
52 |
53 |
54 | ### CO3Dv2
55 | 1. Download the [dataset](https://github.com/facebookresearch/co3d)
56 |
57 | 2. Preprocess the data with the same script as in [DUSt3R](https://github.com/naver/dust3r?tab=readme-ov-file), and place the processed data at `./data/co3d_processed`. The data consists of 41 categories for training and 10 categories for validation.
58 |
59 |
--------------------------------------------------------------------------------
/docs/recon_tips.md:
--------------------------------------------------------------------------------
1 | # Tips for running SLAM3R on self-captured data
2 |
3 |
4 | ## Image name format
5 |
6 | Your images should be consecutive frames (e.g., sampled from a video) with zero-padded, numbered filenames that indicate their sequential order, such as `frame-0031.color.png, output_0414.jpg, ...`.
7 |
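If you start from a raw video, a minimal sketch like the one below (assuming OpenCV, which is already listed in `requirements.txt`; the paths and the sampling stride are placeholders) produces such a zero-padded image sequence:

```python
# Extract frames from a video into zero-padded, sequentially numbered images.
import os
import cv2

def extract_frames(video_path="my_video.mp4", out_dir="data/wild/my_scene", stride=1):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = saved = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % stride == 0:
            # zero padding keeps the frames sorted in capture order
            cv2.imwrite(os.path.join(out_dir, f"frame_{saved:06d}.png"), frame)
            saved += 1
        idx += 1
    cap.release()
    print(f"saved {saved} frames to {out_dir}")

if __name__ == "__main__":
    extract_frames()
```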
8 |
9 | ## Guidance on parameters in the [script](../scripts/demo_wild.sh)
10 |
11 | `KEYFRAME_STRIDE`: The choice of keyframe stride is crucial. A small stride may provide too little camera motion between keyframes, while a large stride can cause insufficient overlap; both hinder performance. We offer an automatic setting option, but you can also configure it manually.
12 |
13 | `CONF_THRES_I2P`: This threshold filters out low-confidence points predicted by the Image-to-Points model before they are fed to the Local-to-World model (see the sketch after this list). Consider increasing it if your images contain many areas with unpredictable depth (e.g., sky or black-masked regions).
14 |
15 | `INITIAL_WINSIZE`: The maximum window size for initializing scene reconstruction. Note that if initial confidence scores are too low, the window size will automatically reduce, with a minimum size of 3.
16 |
17 | `BUFFER_SIZE`: The maximum size of the buffering set. The buffer size should be large enough to store the required scene frames, but an excessively large buffer can impede retrieval efficiency and degrade retrieval performance.
18 |
19 | `BUFFER_STRATEGY`: We provide two strategies for maintaining the scene frame buffering set under a limited buffer size. "reservoir" is suitable for single-room scenarios, while "fifo" performs more effectively on larger scenes. As discussed in our paper, SLAM3R suffers from drift issues in very large scenes.
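
The following is a purely conceptual sketch (not the actual `recon.py` implementation; the arrays are random placeholders) of how a confidence threshold such as `CONF_THRES_I2P` acts on a predicted pointmap:

```python
# Conceptual illustration: keep only points whose confidence exceeds the threshold.
import numpy as np

H, W = 224, 224
pts3d = np.random.randn(H, W, 3).astype(np.float32)   # hypothetical per-pixel 3D points
conf = np.random.rand(H, W).astype(np.float32) * 3.0  # hypothetical confidence map
CONF_THRES_I2P = 1.5

mask = conf >= CONF_THRES_I2P   # pixels considered reliable
kept = pts3d[mask]              # (N, 3) points passed on to the next stage
print(f"kept {kept.shape[0]} / {H * W} points")
```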
20 |
21 |
22 | ## Failure cases
23 |
24 | SLAM3R's performance can degrade when processing sparse views with large viewpoint changes, low overlaps, or motion blur. In such cases, the default window size may include images that don't overlap enough, leading to reduced performance. This can also lead to incorrect retrievals, causing some frames to be registered outside the main reconstruction.
25 |
26 | Due to limited training data (currently only ScanNet++, Aria Synthetic Environments, and CO3D-v2), SLAM3R cannot process images with unfamiliar camera distortion and has poor performance with wide-angle cameras and panoramas. The system also struggles with dynamic scenes.
27 |
--------------------------------------------------------------------------------
/evaluation/process_gt.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path as osp
3 | import numpy as np
4 | import sys
5 | from tqdm import tqdm
6 |
7 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.abspath(__file__)))
8 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
9 |
10 | from slam3r.datasets import Replica
11 |
12 |
13 | def get_replica_gt_pcd(scene_id, save_dir, sample_stride=20):
14 | os.makedirs(save_dir, exist_ok=True)
15 | H, W = 224, 224
16 | dataset = Replica(resolution=(W,H), scene_name=scene_id, num_views=1, sample_freq=sample_stride)
17 | print(dataset[0][0]['pts3d'].shape)
18 | all_pcd = np.zeros([len(dataset),H,W,3])
19 | valid_masks = np.ones([len(dataset),H,W], dtype=bool)
20 | for id in tqdm(range(len(dataset))):
21 | view = dataset[id][0]
22 | pcd = view['pts3d']
23 | valid_masks[id] = view['valid_mask']
24 | all_pcd[id] = pcd
25 |
26 | np.save(os.path.join(save_dir, f"{scene_id}_pcds.npy"), all_pcd)
27 | np.save(os.path.join(save_dir, f"{scene_id}_valid_masks.npy"), valid_masks)
28 |
29 |
30 |
31 | if __name__ == "__main__":
32 | for scene_id in ['office0', 'office1', 'office2', 'office3', 'office4', 'room0', 'room1', 'room2']:
33 | get_replica_gt_pcd(scene_id, sample_stride=1, save_dir="results/gt/replica")
--------------------------------------------------------------------------------
/media/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio.png
--------------------------------------------------------------------------------
/media/gradio_office.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio_office.gif
--------------------------------------------------------------------------------
/media/gradio_office.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/gradio_office.jpg
--------------------------------------------------------------------------------
/media/replica.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/replica.gif
--------------------------------------------------------------------------------
/media/wild.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-VCL-3DV/SLAM3R/c5d0bddb14dc2a04a725a985a3526fc7d5182a5c/media/wild.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | roma
2 | gradio
3 | matplotlib
4 | tqdm
5 | opencv-python
6 | scipy
7 | einops
8 | trimesh
9 | tensorboard
10 | pyglet<2
11 | huggingface-hub[torch]>=0.22
12 | pycuda
13 |
--------------------------------------------------------------------------------
/requirements_optional.txt:
--------------------------------------------------------------------------------
1 | # for visualization and evaluation
2 | open3d
3 | imageio[ffmpeg]
4 | scikit-image
5 | # for rendering depths in scannetpp
6 | pyrender
--------------------------------------------------------------------------------
/scripts/demo_replica.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ######################################################################################
4 | # set the img_dir below to the directory of the set of images you want to reconstruct
5 | # set the postfix below to the format of the rgb images in the img_dir
6 | ######################################################################################
7 | TEST_DATASET="Seq_Data(img_dir='data/Replica_demo/room0', postfix='.jpg', \
8 | img_size=224, silent=False, sample_freq=1, \
9 | start_idx=0, num_views=-1, start_freq=1, to_tensor=True)"
10 |
11 | ######################################################################################
12 | # set the parameters for whole scene reconstruction below
13 | # for definitions of these parameters, please refer to recon.py
14 | ######################################################################################
15 | TEST_NAME="Replica_demo"
16 | KEYFRAME_STRIDE=20
17 | UPDATE_BUFFER_INTV=3
18 | MAX_NUM_REGISTER=10
19 | WIN_R=5
20 | NUM_SCENE_FRAME=10
21 | INITIAL_WINSIZE=5
22 | CONF_THRES_L2W=10
23 | CONF_THRES_I2P=1.5
24 | NUM_POINTS_SAVE=1000000
25 |
26 | GPU_ID=-1
27 |
28 |
29 | python recon.py \
30 | --test_name $TEST_NAME \
31 | --dataset "${TEST_DATASET}" \
32 | --gpu_id $GPU_ID \
33 | --keyframe_stride $KEYFRAME_STRIDE \
34 | --win_r $WIN_R \
35 | --num_scene_frame $NUM_SCENE_FRAME \
36 | --initial_winsize $INITIAL_WINSIZE \
37 | --conf_thres_l2w $CONF_THRES_L2W \
38 | --conf_thres_i2p $CONF_THRES_I2P \
39 | --num_points_save $NUM_POINTS_SAVE \
40 | --update_buffer_intv $UPDATE_BUFFER_INTV \
41 | --max_num_register $MAX_NUM_REGISTER
42 |
--------------------------------------------------------------------------------
/scripts/demo_vis_wild.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python visualize.py \
3 | --vis_dir results/wild_demo \
4 | --save_stride 1 \
5 | --enhance_z \
6 | --conf_thres_l2w 12 \
7 | # --vis_cam
--------------------------------------------------------------------------------
/scripts/demo_wild.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | ######################################################################################
5 | # set the img_dir below to the directory of the set of images you want to reconstruct
6 | # set the postfix below to the format of the rgb images in the img_dir
7 | ######################################################################################
8 | TEST_DATASET="Seq_Data(img_dir='data/wild/Library', postfix='.png', \
9 | img_size=224, silent=False, sample_freq=1, \
10 | start_idx=0, num_views=-1, start_freq=1, to_tensor=True)"
11 |
12 | ######################################################################################
13 | # set the parameters for whole scene reconstruction below
14 | # for definitions of these parameters, please refer to recon.py
15 | ######################################################################################
16 | TEST_NAME="wild_demo"
17 | KEYFRAME_STRIDE=-1 #-1 for auto-adaptive keyframe stride selection
18 | WIN_R=5
19 | MAX_NUM_REGISTER=10
20 | NUM_SCENE_FRAME=10
21 | INITIAL_WINSIZE=5
22 | CONF_THRES_L2W=12
23 | CONF_THRES_I2P=1.5
24 | NUM_POINTS_SAVE=1000000
25 |
26 | UPDATE_BUFFER_INTV=1
27 | BUFFER_SIZE=100 # -1 if size is not limited
28 | BUFFER_STRATEGY="reservoir" # or "fifo"
29 |
30 | KEYFRAME_ADAPT_MIN=1
31 | KEYFRAME_ADAPT_MAX=20
32 | KEYFRAME_ADAPT_STRIDE=1
33 |
34 | GPU_ID=-1
35 |
36 | python recon.py \
37 | --test_name $TEST_NAME \
38 | --dataset "${TEST_DATASET}" \
39 | --gpu_id $GPU_ID \
40 | --keyframe_stride $KEYFRAME_STRIDE \
41 | --win_r $WIN_R \
42 | --num_scene_frame $NUM_SCENE_FRAME \
43 | --initial_winsize $INITIAL_WINSIZE \
44 | --conf_thres_l2w $CONF_THRES_L2W \
45 | --conf_thres_i2p $CONF_THRES_I2P \
46 | --num_points_save $NUM_POINTS_SAVE \
47 | --update_buffer_intv $UPDATE_BUFFER_INTV \
48 | --buffer_size $BUFFER_SIZE \
49 | --buffer_strategy "${BUFFER_STRATEGY}" \
50 | --max_num_register $MAX_NUM_REGISTER \
51 | --keyframe_adapt_min $KEYFRAME_ADAPT_MIN \
52 | --keyframe_adapt_max $KEYFRAME_ADAPT_MAX \
53 | --keyframe_adapt_stride $KEYFRAME_ADAPT_STRIDE \
54 | --save_preds
55 |
--------------------------------------------------------------------------------
/scripts/eval_replica.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ######################################################################################
4 | # set the parameters for whole scene reconstruction below
5 | # for definitions of these parameters, please refer to recon.py
6 | ######################################################################################
7 | KEYFRAME_STRIDE=20
8 | UPDATE_BUFFER_INTV=3
9 | MAX_NUM_REGISTER=10
10 | WIN_R=5
11 | NUM_SCENE_FRAME=10
12 | INITIAL_WINSIZE=5
13 | CONF_THRES_I2P=1.5
14 |
15 | # the parameters below have nothing to do with the evaluation
16 | NUM_POINTS_SAVE=1000000
17 | CONF_THRES_L2W=10
18 | GPU_ID=-1
19 |
20 | SCENE_NAMES=("office0" "office1" "office2" "office3" "office4" "room0" "room1" "room2")
21 |
22 | for SCENE_NAME in ${SCENE_NAMES[@]};
23 | do
24 |
25 | TEST_NAME="Replica_${SCENE_NAME}"
26 |
27 | echo "--------Start reconstructing scene ${SCENE_NAME} with test name ${TEST_NAME}--------"
28 |
29 | python recon.py \
30 | --test_name "${TEST_NAME}" \
31 | --img_dir "data/Replica/${SCENE_NAME}/results" \
32 | --gpu_id $GPU_ID \
33 | --keyframe_stride $KEYFRAME_STRIDE \
34 | --win_r $WIN_R \
35 | --num_scene_frame $NUM_SCENE_FRAME \
36 | --initial_winsize $INITIAL_WINSIZE \
37 | --conf_thres_l2w $CONF_THRES_L2W \
38 | --conf_thres_i2p $CONF_THRES_I2P \
39 | --num_points_save $NUM_POINTS_SAVE \
40 | --update_buffer_intv $UPDATE_BUFFER_INTV \
41 | --max_num_register $MAX_NUM_REGISTER \
42 | --save_for_eval
43 |
44 | echo "--------Start evaluating scene ${SCENE_NAME} with test name ${TEST_NAME}--------"
45 |
46 | python evaluation/eval_recon.py \
47 | --test_name="${TEST_NAME}" \
48 | --gt_pcd="results/gt/replica/${SCENE_NAME}_pcds.npy"
49 |
50 | done
--------------------------------------------------------------------------------
/scripts/train_i2p.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | MODEL="Image2PointsModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
3 | enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
4 | mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 11)"
5 |
6 | TRAIN_DATASET="4000 @ ScanNetpp_Seq(filter=True, num_views=11, sample_freq=3, split='train', aug_crop=256, resolution=224, transform=ColorJitter, seed=233) + \
7 | 2000 @ Aria_Seq(num_views=11, sample_freq=2, split='train', aug_crop=128, resolution=224, transform=ColorJitter, seed=233) + \
8 | 2000 @ Co3d_Seq(num_views=11, sel_num=3, degree=180, mask_bg='rand', split='train', aug_crop=16, resolution=224, transform=ColorJitter, seed=233)"
9 |
10 | TEST_DATASET="1000 @ ScanNetpp_Seq(filter=True, num_views=11, split='test', resolution=224, seed=666) + \
11 | 1000 @ Aria_Seq(num_views=11, split='test', resolution=224, seed=666) + \
12 | 1000 @ Co3d_Seq(num_views=11, sel_num=3, degree=180, mask_bg='rand', split='test', resolution=224, seed=666)"
13 |
14 | # Stage 1: Train the i2p model for pointmap prediction
15 | PRETRAINED="checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth"
16 | TRAIN_OUT_DIR="checkpoints/i2p/slam3r_i2p_stage1"
17 |
18 | torchrun --nproc_per_node=8 train.py \
19 | --train_dataset "${TRAIN_DATASET}" \
20 | --test_dataset "${TEST_DATASET}" \
21 | --model "$MODEL" \
22 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
23 | --test_criterion "Jointnorm_Regr3D(L21, norm_mode='avg_dis')" \
24 | --pretrained $PRETRAINED \
25 | --pretrained_type "dust3r" \
26 | --lr 5e-5 --min_lr 5e-7 --warmup_epochs 10 --epochs 100 --batch_size 4 --accum_iter 1 \
27 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\
28 | --save_config\
29 | --freeze "encoder"\
30 | --loss_func 'i2p' \
31 | --output_dir $TRAIN_OUT_DIR \
32 | --ref_id -1
33 |
34 |
35 | # Stage 2: Train a simple mlp to predict the correlation score
36 | PRETRAINED="checkpoints/i2p/slam3r_i2p_stage1/checkpoint-final.pth"
37 | TRAIN_OUT_DIR="checkpoints/i2p/slam3r_i2p"
38 |
39 | torchrun --nproc_per_node=8 train.py \
40 | --train_dataset "${TRAIN_DATASET}" \
41 | --test_dataset "${TEST_DATASET}" \
42 | --model "$MODEL" \
43 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
44 | --test_criterion "Jointnorm_Regr3D(L21, gt_scale=True)" \
45 | --pretrained $PRETRAINED \
46 | --pretrained_type "slam3r" \
47 | --lr 1e-4 --min_lr 1e-6 --warmup_epochs 5 --epochs 50 --batch_size 4 --accum_iter 1 \
48 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\
49 | --save_config\
50 | --freeze "corr_score_head_only"\
51 | --loss_func "i2p_corr_score" \
52 | --output_dir $TRAIN_OUT_DIR \
53 | --ref_id -1
54 |
--------------------------------------------------------------------------------
/scripts/train_l2w.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | MODEL="Local2WorldModel(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), \
3 | enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, \
4 | mv_dec1='MultiviewDecoderBlock_max',mv_dec2='MultiviewDecoderBlock_max', enc_minibatch = 12, need_encoder=True)"
5 |
6 | TRAIN_DATASET="4000 @ ScanNetpp_Seq(filter=True, sample_freq=3, num_views=13, split='train', aug_crop=256, resolution=224, transform=ColorJitter, seed=233) + \
7 | 2000 @ Aria_Seq(num_views=13, sample_freq=2, split='train', aug_crop=128, resolution=224, transform=ColorJitter, seed=233) + \
8 | 2000 @ Co3d_Seq(num_views=13, sel_num=3, degree=180, mask_bg='rand', split='train', aug_crop=16, resolution=224, transform=ColorJitter, seed=233)"
9 | TEST_DATASET="1000 @ ScanNetpp_Seq(filter=True, sample_freq=3, num_views=13, split='test', resolution=224, seed=666)+ \
10 | 1000 @ Aria_Seq(num_views=13, split='test', resolution=224, seed=666) + \
11 | 1000 @ Co3d_Seq(num_views=13, sel_num=3, degree=180, mask_bg='rand', split='test', resolution=224, seed=666)"
12 |
13 | PRETRAINED="checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth"
14 | TRAIN_OUT_DIR="checkpoints/slam3r_l2w"
15 |
16 | torchrun --nproc_per_node=8 train.py \
17 | --train_dataset "${TRAIN_DATASET}" \
18 | --test_dataset "${TEST_DATASET}" \
19 | --model "$MODEL" \
20 | --train_criterion "Jointnorm_ConfLoss(Jointnorm_Regr3D(L21,norm_mode=None), alpha=0.2)" \
21 | --test_criterion "Jointnorm_Regr3D(L21, norm_mode=None)" \
22 | --pretrained $PRETRAINED \
23 | --pretrained_type "dust3r" \
24 | --lr 5e-5 --min_lr 5e-7 --warmup_epochs 20 --epochs 200 --batch_size 4 --accum_iter 1 \
25 | --save_freq 2 --keep_freq 20 --eval_freq 1 --print_freq 20\
26 | --save_config\
27 | --output_dir $TRAIN_OUT_DIR \
28 | --freeze "encoder"\
29 | --loss_func "l2w" \
30 | --ref_ids 0 1 2 3 4 5
31 |
32 |
--------------------------------------------------------------------------------
/slam3r/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/slam3r/blocks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from .basic_blocks import Block, DecoderBlock, PatchEmbed, Mlp
5 | from .multiview_blocks import MultiviewDecoderBlock_max
--------------------------------------------------------------------------------
/slam3r/blocks/basic_blocks.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 |
5 | # --------------------------------------------------------
6 | # Main encoder/decoder blocks
7 | # --------------------------------------------------------
8 | # References:
9 | # timm
10 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from itertools import repeat
21 | import collections.abc
22 |
23 |
24 | def _ntuple(n):
25 | def parse(x):
26 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27 | return x
28 | return tuple(repeat(x, n))
29 | return parse
30 | to_2tuple = _ntuple(2)
31 |
32 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34 | """
35 | if drop_prob == 0. or not training:
36 | return x
37 | keep_prob = 1 - drop_prob
38 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40 | if keep_prob > 0.0 and scale_by_keep:
41 | random_tensor.div_(keep_prob)
42 | return x * random_tensor
43 |
44 | class DropPath(nn.Module):
45 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46 | """
47 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48 | super(DropPath, self).__init__()
49 | self.drop_prob = drop_prob
50 | self.scale_by_keep = scale_by_keep
51 |
52 | def forward(self, x):
53 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54 |
55 | def extra_repr(self):
56 | return f'drop_prob={round(self.drop_prob,3):0.3f}'
57 |
58 | class Mlp(nn.Module):
59 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61 | super().__init__()
62 | out_features = out_features or in_features
63 | hidden_features = hidden_features or in_features
64 | bias = to_2tuple(bias)
65 | drop_probs = to_2tuple(drop)
66 |
67 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68 | self.act = act_layer()
69 | self.drop1 = nn.Dropout(drop_probs[0])
70 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71 | self.drop2 = nn.Dropout(drop_probs[1])
72 |
73 | def forward(self, x):
74 | x = self.fc1(x)
75 | x = self.act(x)
76 | x = self.drop1(x)
77 | x = self.fc2(x)
78 | x = self.drop2(x)
79 | return x
80 |
81 | class Attention(nn.Module):
82 |
83 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84 | super().__init__()
85 | self.num_heads = num_heads
86 | head_dim = dim // num_heads
87 | self.scale = head_dim ** -0.5
88 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89 | self.attn_drop = nn.Dropout(attn_drop)
90 | self.proj = nn.Linear(dim, dim)
91 | self.proj_drop = nn.Dropout(proj_drop)
92 | self.rope = rope
93 |
94 | def forward(self, x, xpos):
95 | B, N, C = x.shape
96 |
97 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98 | q, k, v = [qkv[:,:,i] for i in range(3)]
99 | # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100 |
101 | if self.rope is not None:
102 | q = self.rope(q, xpos)
103 | k = self.rope(k, xpos)
104 |
105 | attn = (q @ k.transpose(-2, -1)) * self.scale
106 | attn = attn.softmax(dim=-1)
107 | attn = self.attn_drop(attn)
108 |
109 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110 | x = self.proj(x)
111 | x = self.proj_drop(x)
112 | return x
113 |
114 | class Block(nn.Module):
115 |
116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
118 | super().__init__()
119 | self.norm1 = norm_layer(dim)
120 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
121 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
122 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123 | self.norm2 = norm_layer(dim)
124 | mlp_hidden_dim = int(dim * mlp_ratio)
125 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
126 |
127 | def forward(self, x, xpos):
128 | x = x + self.drop_path(self.attn(self.norm1(x), xpos))
129 | x = x + self.drop_path(self.mlp(self.norm2(x)))
130 | return x
131 |
132 | class CrossAttention(nn.Module):
133 |
134 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
135 | super().__init__()
136 | self.num_heads = num_heads
137 | head_dim = dim // num_heads
138 | self.scale = head_dim ** -0.5
139 |
140 | self.projq = nn.Linear(dim, dim, bias=qkv_bias)
141 | self.projk = nn.Linear(dim, dim, bias=qkv_bias)
142 | self.projv = nn.Linear(dim, dim, bias=qkv_bias)
143 | self.attn_drop = nn.Dropout(attn_drop)
144 | self.proj = nn.Linear(dim, dim)
145 | self.proj_drop = nn.Dropout(proj_drop)
146 |
147 | self.rope = rope
148 |
149 | def forward(self, query, key, value, qpos, kpos):
150 | B, Nq, C = query.shape
151 | Nk = key.shape[1]
152 | Nv = value.shape[1]
153 |
154 | q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
155 | k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
156 | v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
157 |
158 | if self.rope is not None:
159 | q = self.rope(q, qpos)
160 | k = self.rope(k, kpos)
161 |
162 | attn = (q @ k.transpose(-2, -1)) * self.scale
163 | attn = attn.softmax(dim=-1)
164 | attn = self.attn_drop(attn)
165 |
166 | x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
167 | x = self.proj(x)
168 | x = self.proj_drop(x)
169 | return x
170 |
171 | class DecoderBlock(nn.Module):
172 |
173 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
174 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
175 | super().__init__()
176 | self.norm1 = norm_layer(dim)
177 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
178 | self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180 | self.norm2 = norm_layer(dim)
181 | self.norm3 = norm_layer(dim)
182 | mlp_hidden_dim = int(dim * mlp_ratio)
183 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184 | self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
185 |
186 | def forward(self, x, y, xpos, ypos):
187 | x = x + self.drop_path(self.attn(self.norm1(x), xpos))
188 | y_ = self.norm_y(y)
189 | x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
190 | x = x + self.drop_path(self.mlp(self.norm3(x)))
191 | return x, y
192 |
193 |
194 | # patch embedding
195 | class PositionGetter(object):
196 | """ return positions of patches """
197 |
198 | def __init__(self):
199 | self.cache_positions = {}
200 |
201 | def __call__(self, b, h, w, device):
202 | if not (h,w) in self.cache_positions:
203 | x = torch.arange(w, device=device)
204 | y = torch.arange(h, device=device)
205 | self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
206 | pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
207 | return pos
208 |
209 | class PatchEmbed(nn.Module):
210 | """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
211 |
212 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
213 | super().__init__()
214 | img_size = to_2tuple(img_size)
215 | patch_size = to_2tuple(patch_size)
216 | self.img_size = img_size
217 | self.patch_size = patch_size
218 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
219 | self.num_patches = self.grid_size[0] * self.grid_size[1]
220 | self.flatten = flatten
221 |
222 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
224 |
225 | self.position_getter = PositionGetter()
226 |
227 | def forward(self, x):
228 | B, C, H, W = x.shape
229 | torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
230 | torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
231 | x = self.proj(x)
232 | pos = self.position_getter(B, x.size(2), x.size(3), x.device)
233 | if self.flatten:
234 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
235 | x = self.norm(x)
236 | return x, pos
237 |
238 | def _init_weights(self):
239 | w = self.proj.weight.data
240 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
241 |
242 |
--------------------------------------------------------------------------------
/slam3r/blocks/multiview_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .basic_blocks import Mlp, Attention, CrossAttention, DropPath
5 | try:
6 | import xformers.ops as xops
7 | XFORMERS_AVALIABLE = True
8 | except ImportError:
9 | print("xformers not avaliable, use self-implemented attention instead")
10 | XFORMERS_AVALIABLE = False
11 |
12 |
13 | class XFormer_Attention(nn.Module):
14 | """Warpper for self-attention module with xformers.
15 | Calculate attention scores with xformers memory_efficient_attention.
16 | When inference is performed on the CPU or when xformer is not installed, it will degrade to the normal attention.
17 | """
18 | def __init__(self, old_module:Attention):
19 | super().__init__()
20 | self.num_heads = old_module.num_heads
21 | self.scale = old_module.scale
22 | self.qkv = old_module.qkv
23 | self.attn_drop_prob = old_module.attn_drop.p
24 | self.proj = old_module.proj
25 | self.proj_drop = old_module.proj_drop
26 | self.rope = old_module.rope
27 | self.attn_drop = old_module.attn_drop
28 |
29 | def forward(self, x, xpos):
30 | B, N, C = x.shape
31 |
32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
33 | q, k, v = [qkv[:,:,i] for i in range(3)] #shape: (B, num_heads, N, C//num_heads)
34 |
35 | if self.rope is not None:
36 | q = self.rope(q, xpos) # (B, H, N, K)
37 | k = self.rope(k, xpos)
38 |
39 | if x.is_cuda and XFORMERS_AVALIABLE:
40 | q = q.permute(0, 2, 1, 3) # (B, N, H, K)
41 | k = k.permute(0, 2, 1, 3)
42 | v = v.permute(0, 2, 1, 3)
43 | drop_prob = self.attn_drop_prob if self.training else 0
44 | x = xops.memory_efficient_attention(q, k, v, scale=self.scale, p=drop_prob) # (B, N, H, K)
45 | else:
46 | attn = (q @ k.transpose(-2, -1)) * self.scale
47 | attn = attn.softmax(dim=-1)
48 | attn = self.attn_drop(attn)
49 | x = (attn @ v).transpose(1, 2)
50 |
51 | x=x.reshape(B, N, C)
52 | x = self.proj(x)
53 | x = self.proj_drop(x)
54 | return x
55 |
56 |
57 | class MultiviewDecoderBlock_max(nn.Module):
58 | """Multiview decoder block,
59 | which takes as input an arbitrary number of source-view and target-view features.
60 | Max-pooling is used to merge the features queried from different source views.
61 | """
62 |
63 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
64 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
65 | super().__init__()
66 | self.norm1 = norm_layer(dim)
67 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
68 | self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
69 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
70 | self.norm2 = norm_layer(dim)
71 | self.norm3 = norm_layer(dim)
72 | mlp_hidden_dim = int(dim * mlp_ratio)
73 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
74 | self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
75 | if XFORMERS_AVALIABLE:
76 | self.attn = XFormer_Attention(self.attn)
77 |
78 | def batched_cross_attn(self, xs, ys, xposes, yposes, rel_ids_list_d, M):
79 | """
80 | Calculate cross-attention between Vx target views and Vy source views in a single batch.
81 | """
82 | xs_normed = self.norm2(xs)
83 | ys_normed = self.norm_y(ys)
84 | cross_attn = self.cross_attn
85 | Vx, B, Nx, C = xs.shape
86 | Vy, B, Ny, C = ys.shape
87 | num_heads = cross_attn.num_heads
88 |
89 | #precompute q,k,v for each view to save computation
90 | qs = cross_attn.projq(xs_normed).reshape(Vx*B,Nx,num_heads, C//num_heads).permute(0, 2, 1, 3) # (Vx*B,num_heads,Nx,C//num_heads)
91 | ks = cross_attn.projk(ys_normed).reshape(Vy*B,Ny,num_heads, C//num_heads).permute(0, 2, 1, 3) # (Vy*B,num_heads,Ny,C//num_heads)
92 |         vs = cross_attn.projv(ys_normed).reshape(Vy,B,Ny,num_heads, C//num_heads) # (Vy, B, Ny, num_heads, C//num_heads)
93 |
94 | #add rope
95 | if cross_attn.rope is not None:
96 | qs = cross_attn.rope(qs, xposes)
97 | ks = cross_attn.rope(ks, yposes)
98 | qs = qs.permute(0, 2, 1, 3).reshape(Vx,B,Nx,num_heads,C// num_heads) # (Vx, B, Nx, H, K)
99 | ks = ks.permute(0, 2, 1, 3).reshape(Vy,B,Ny,num_heads,C// num_heads) # (Vy, B, Ny, H, K)
100 |
101 | # construct query, key, value for each target view
102 | ks_respect = torch.index_select(ks, 0, rel_ids_list_d) # (Vx*M, B, Ny, H, K)
103 | vs_respect = torch.index_select(vs, 0, rel_ids_list_d) # (Vx*M, B, Ny, H, K)
104 | qs_corresp = torch.unsqueeze(qs, 1).expand(-1, M, -1, -1, -1, -1) # (Vx, M, B, Nx, H, K)
105 |
106 | ks_compact = ks_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads)
107 | vs_compact = vs_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads)
108 | qs_compact = qs_corresp.reshape(Vx*M*B, Nx, num_heads, C//num_heads)
109 |
110 | # calculate attention results for all target views in one go
111 | if xs.is_cuda and XFORMERS_AVALIABLE:
112 | drop_prob = cross_attn.attn_drop.p if self.training else 0
113 | attn_outputs = xops.memory_efficient_attention(qs_compact, ks_compact, vs_compact,
114 | scale=self.cross_attn.scale, p=drop_prob) # (V*M*B, N, H, K)
115 | else:
116 | ks_compact = ks_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Ny, K)
117 | qs_compact = qs_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Nx, K)
118 | vs_compact = vs_compact.permute(0, 2, 1, 3) # (Vx*M*B, H, Ny, K)
119 | attn = (qs_compact @ ks_compact.transpose(-2, -1)) * self.cross_attn.scale # (V*M*B, H, Nx, Ny)
120 | attn = attn.softmax(dim=-1) # (V*M*B, H, Nx, Ny)
121 | attn = self.cross_attn.attn_drop(attn)
122 | attn_outputs = (attn @ vs_compact).transpose(1, 2).reshape(Vx*M*B, Nx, num_heads, C//num_heads) # (V*M*B, Nx, H, K)
123 |
124 | attn_outputs = attn_outputs.reshape(Vx, M, B, Nx, C) #(Vx, M, B, Nx, C)
125 | attn_outputs = cross_attn.proj_drop(cross_attn.proj(attn_outputs)) #(Vx, M, B, Nx, C)
126 |
127 | return attn_outputs
128 |
129 | def forward(self, xs:torch.Tensor, ys:torch.Tensor,
130 | xposes:torch.Tensor, yposes:torch.Tensor,
131 | rel_ids_list_d:torch.Tensor, M:int):
132 | """refine Vx target view feature parallelly, with the information of Vy source view
133 |
134 | Args:
135 |             xs: (Vx,B,S,D): features of target views to refine. (S: number of tokens, D: feature dimension)
136 | ys: (Vy,B,S,D): features of source views to query from.
137 | M: number of source views to query from for each target view
138 | rel_ids_list_d: (Vx*M,) indices of source views to query from for each target view
139 |
140 | For example:
141 | Suppose we have 3 target views and 4 source views,
142 |             then xs should have shape (3,B,S,D) and ys should have shape (4,B,S,D).
143 |
144 | If we require xs[0] to query features from ys[0], ys[1],
145 |             xs[1] to query features from ys[2], ys[2] (duplicate indices are supported),
146 | xs[2] to query features from ys[2], ys[3],
147 | then we should set M=2, rel_ids_list_d=[0,1, 2,2, 2,3]
148 | """
149 | Vx, B, Nx, C = xs.shape
150 |
151 | # self-attention on each target view feature
152 | xs = xs.reshape(-1, *xs.shape[2:]) # (Vx*B,S,C)
153 | xposes = xposes.reshape(-1, *xposes.shape[2:]) # (Vx*B,S,2)
154 | yposes = yposes.reshape(-1, *yposes.shape[2:])
155 | xs = xs + self.drop_path(self.attn(self.norm1(xs), xposes)) #(Vx*B,S,C)
156 |
157 |         # each target view cross-attends with its assigned source views to query features
158 | attn_outputs = self.batched_cross_attn(xs.reshape(Vx,B,Nx,C), ys, xposes, yposes, rel_ids_list_d, M)
159 |
160 | # max-pooling to aggregate features queried from different source views
161 | merged_ys, indices = torch.max(attn_outputs, dim=1) #(Vx, B, Nx, C)
162 | merged_ys = merged_ys.reshape(Vx*B,Nx,C) #(Vx*B,Nx,C)
163 |
164 | xs = xs + self.drop_path(merged_ys)
165 | xs = xs + self.drop_path(self.mlp(self.norm3(xs))) #(VB,N,C)
166 | xs = xs.reshape(Vx,B,Nx,C)
167 | return xs
168 |
169 |
--------------------------------------------------------------------------------
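
The forward() docstring above explains the rel_ids_list_d convention in prose. Below is a minimal usage sketch (not part of the repository) showing how that exact example maps onto a call to MultiviewDecoderBlock_max; the feature dimension, head count and token count are illustrative assumptions, and rope is left at its default of None so the position tensors are only placeholders.

    import torch
    from slam3r.blocks.multiview_blocks import MultiviewDecoderBlock_max

    Vx, Vy, B, N, C = 3, 4, 2, 196, 768        # 3 target views, 4 source views (assumed sizes)
    block = MultiviewDecoderBlock_max(dim=C, num_heads=12)

    xs = torch.randn(Vx, B, N, C)              # target-view tokens to refine
    ys = torch.randn(Vy, B, N, C)              # source-view tokens to query from
    xposes = torch.zeros(Vx, B, N, 2, dtype=torch.long)  # unused because rope is None
    yposes = torch.zeros(Vy, B, N, 2, dtype=torch.long)

    # xs[0] queries ys[0], ys[1]; xs[1] queries ys[2], ys[2]; xs[2] queries ys[2], ys[3]
    M = 2
    rel_ids = torch.tensor([0, 1, 2, 2, 2, 3])

    out = block(xs, ys, xposes, yposes, rel_ids, M)   # refined target features, shape (Vx, B, N, C)
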
/slam3r/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | import torch
4 | import numpy as np
5 |
6 | from .utils.transforms import *
7 | from .base.batched_sampler import BatchedRandomSampler # noqa: F401
8 | from .replica_seq import Replica
9 | from .scannetpp_seq import ScanNetpp_Seq
10 | from .project_aria_seq import Aria_Seq
11 | from .co3d_seq import Co3d_Seq
12 | from .base.base_stereo_view_dataset import EasyDataset
13 |
14 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
15 | import torch
16 | from slam3r.utils.croco_misc import get_world_size, get_rank
17 |
18 | # pytorch dataset
19 | if isinstance(dataset, str):
20 | dataset = eval(dataset)
21 |
22 | world_size = get_world_size()
23 | rank = get_rank()
24 |
25 | try:
26 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
27 | rank=rank, drop_last=drop_last)
28 | except (AttributeError, NotImplementedError):
29 | # not avail for this dataset
30 | if torch.distributed.is_initialized():
31 | sampler = torch.utils.data.DistributedSampler(
32 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
33 | )
34 | elif shuffle:
35 | sampler = torch.utils.data.RandomSampler(dataset)
36 | else:
37 | sampler = torch.utils.data.SequentialSampler(dataset)
38 |
39 | data_loader = torch.utils.data.DataLoader(
40 | dataset,
41 | sampler=sampler,
42 | batch_size=batch_size,
43 | num_workers=num_workers,
44 | pin_memory=pin_mem,
45 | drop_last=drop_last,
46 | )
47 |
48 | return data_loader
49 |
50 | class MultiDataLoader:
51 | def __init__(self, dataloaders:list, return_id=False):
52 | self.dataloaders = dataloaders
53 | self.len_dataloaders = [len(loader) for loader in dataloaders]
54 | self.total_length = sum(self.len_dataloaders)
55 | self.epoch = None
56 | self.return_id = return_id
57 |
58 | def __len__(self):
59 | return self.total_length
60 |
61 | def set_epoch(self, epoch):
62 | for data_loader in self.dataloaders:
63 | if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
64 | data_loader.dataset.set_epoch(epoch)
65 | if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
66 | data_loader.sampler.set_epoch(epoch)
67 | self.epoch = epoch
68 |
69 | def __iter__(self):
70 | loader_idx = []
71 | for idx, length in enumerate(self.len_dataloaders):
72 | loader_idx += [idx]*length
73 | loader_idx = np.array(loader_idx)
74 | assert loader_idx.shape[0] == self.total_length
75 |         # Should a shared seed be used so that every process iterates over the dataloaders in the same order?
76 | if self.epoch is None:
77 | assert len(self.dataloaders) == 1
78 | else:
79 | seed = 777 + self.epoch
80 | rng = np.random.default_rng(seed=seed)
81 | rng.shuffle(loader_idx)
82 | batch_count = 0
83 |
84 | iters = [iter(loader) for loader in self.dataloaders] # iterator for each dataloader
85 | while True:
86 | idx = loader_idx[batch_count]
87 | try:
88 | batch = next(iters[idx])
89 |             except StopIteration: # this won't happen in distributed mode if drop_last is False
90 | iters[idx] = iter(self.dataloaders[idx])
91 | batch = next(iters[idx])
92 | if self.return_id:
93 | batch = (batch, idx)
94 | yield batch
95 | batch_count += 1
96 | if batch_count == self.total_length:
97 | # batch_count -= self.total_length
98 | break
99 |
100 |
101 | def get_multi_data_loader(dataset, batch_size, return_id=False, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
102 | import torch
103 | from slam3r.utils.croco_misc import get_world_size, get_rank
104 |
105 | # pytorch dataset
106 | if isinstance(dataset, str):
107 | dataset = eval(dataset)
108 |
109 | if isinstance(dataset, EasyDataset):
110 | datasets = [dataset]
111 | else:
112 | datasets = dataset
113 | print(datasets)
114 | assert isinstance(datasets,(tuple, list))
115 |
116 | world_size = get_world_size()
117 | rank = get_rank()
118 | dataloaders = []
119 | for dataset in datasets:
120 | try:
121 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
122 | rank=rank, drop_last=drop_last)
123 | except (AttributeError, NotImplementedError):
124 | # not avail for this dataset
125 | if torch.distributed.is_initialized():
126 | sampler = torch.utils.data.DistributedSampler(
127 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
128 | )
129 | elif shuffle:
130 | sampler = torch.utils.data.RandomSampler(dataset)
131 | else:
132 | sampler = torch.utils.data.SequentialSampler(dataset)
133 |
134 | data_loader = torch.utils.data.DataLoader(
135 | dataset,
136 | sampler=sampler,
137 | batch_size=batch_size,
138 | num_workers=num_workers,
139 | pin_memory=pin_mem,
140 | drop_last=drop_last,
141 | )
142 | dataloaders.append(data_loader)
143 |
144 | multi_dataloader = MultiDataLoader(dataloaders, return_id=return_id)
145 | return multi_dataloader
146 |
--------------------------------------------------------------------------------
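
As a quick illustration of how get_multi_data_loader and MultiDataLoader interleave several datasets, here is a hypothetical sketch; the constructor arguments are borrowed from the __main__ examples in the dataset files below and are not prescribed values.

    from slam3r.datasets import Replica, ScanNetpp_Seq, get_multi_data_loader

    datasets = [
        Replica(split='train', resolution=224, num_views=5, sample_freq=20),
        ScanNetpp_Seq(split='train', resolution=(224, 224), num_views=5,
                      start_freq=1, sample_freq=3),
    ]

    # one DataLoader per dataset; MultiDataLoader shuffles which loader each batch is drawn from
    loader = get_multi_data_loader(datasets, batch_size=4, return_id=True, num_workers=0)
    loader.set_epoch(0)          # fixes the per-epoch interleaving order (and the samplers' seeds)

    for batch, dataset_id in loader:
        pass                     # dataset_id indicates which dataset produced this batch
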
/slam3r/datasets/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/slam3r/datasets/base/batched_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Random sampling under a constraint
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class BatchedRandomSampler:
12 | """ Random sampling under a constraint: each sample in the batch has the same feature,
13 | which is chosen randomly from a known pool of 'features' for each batch.
14 |
15 | For instance, the 'feature' could be the image aspect-ratio.
16 |
17 | The index returned is a tuple (sample_idx, feat_idx).
18 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
19 | """
20 |
21 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
22 | self.batch_size = batch_size
23 | self.pool_size = pool_size
24 |
25 | self.len_dataset = N = len(dataset)
26 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N
27 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
28 |
29 | # distributed sampler
30 | self.world_size = world_size
31 | self.rank = rank
32 | self.epoch = None
33 |
34 | def __len__(self):
35 | return self.total_size // self.world_size
36 |
37 | def set_epoch(self, epoch):
38 | self.epoch = epoch
39 |
40 | def __iter__(self):
41 | # prepare RNG
42 | if self.epoch is None:
43 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
44 | seed = int(torch.empty((), dtype=torch.int64).random_().item())
45 | else:
46 | seed = self.epoch + 777
47 | rng = np.random.default_rng(seed=seed)
48 |
49 | # random indices (will restart from 0 if not drop_last)
50 | sample_idxs = np.arange(self.total_size)
51 | rng.shuffle(sample_idxs)
52 |
53 | # random feat_idxs (same across each batch)
54 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size
55 | feat_idxs = rng.integers(self.pool_size, size=n_batches) # shape = (n_batches,)
56 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) # shape = (n_batches, batch_size)
57 | feat_idxs = feat_idxs.ravel()[:self.total_size]
58 |
59 | # put them together
60 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
61 |
62 | # Distributed sampler: we select a subset of batches
63 | # make sure the slice for each node is aligned with batch_size
64 | size_per_proc = self.batch_size * ((self.total_size + self.world_size *
65 | self.batch_size-1) // (self.world_size * self.batch_size))
66 | idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]
67 |
68 | yield from (tuple(idx) for idx in idxs)
69 |
70 |
71 | def round_by(total, multiple, up=False):
72 | if up:
73 | total = total + multiple-1
74 | return (total//multiple) * multiple
75 |
--------------------------------------------------------------------------------
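
A small illustrative sketch (assumed setup, not repository code) of the index tuples BatchedRandomSampler yields: each run of batch_size consecutive indices shares the same feat_idx drawn from pool_size choices, e.g. a common aspect-ratio index per batch.

    from slam3r.datasets.base.batched_sampler import BatchedRandomSampler

    dummy_dataset = range(10)                  # the sampler only needs len(dataset)
    sampler = BatchedRandomSampler(dummy_dataset, batch_size=4, pool_size=3)
    sampler.set_epoch(0)                       # seeds the shuffle deterministically

    for sample_idx, feat_idx in sampler:
        print(sample_idx, feat_idx)            # feat_idx is constant within each group of 4
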
/slam3r/datasets/base/easy_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # A dataset base class that you can easily resize and combine.
6 | # --------------------------------------------------------
7 | import numpy as np
8 | from slam3r.datasets.base.batched_sampler import BatchedRandomSampler
9 |
10 |
11 | class EasyDataset:
12 | """ a dataset that you can easily resize and combine.
13 | Examples:
14 | ---------
15 | 2 * dataset ==> duplicate each element 2x
16 |
17 | 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
18 |
19 | dataset1 + dataset2 ==> concatenate datasets
20 | """
21 |
22 | def __add__(self, other):
23 | return CatDataset([self, other])
24 |
25 | def __rmul__(self, factor):
26 | return MulDataset(factor, self)
27 |
28 | def __rmatmul__(self, factor):
29 | return ResizedDataset(factor, self)
30 |
31 | def set_epoch(self, epoch):
32 | pass # nothing to do by default
33 |
34 | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
35 | if not (shuffle):
36 | raise NotImplementedError() # cannot deal yet
37 | num_of_aspect_ratios = len(self._resolutions)
38 | return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
39 |
40 |
41 | class MulDataset (EasyDataset):
42 | """ Artifically augmenting the size of a dataset.
43 | """
44 | multiplicator: int
45 |
46 | def __init__(self, multiplicator, dataset):
47 | assert isinstance(multiplicator, int) and multiplicator > 0
48 | self.multiplicator = multiplicator
49 | self.dataset = dataset
50 |
51 | def __len__(self):
52 | return self.multiplicator * len(self.dataset)
53 |
54 | def __repr__(self):
55 | return f'{self.multiplicator}*{repr(self.dataset)}'
56 |
57 | def __getitem__(self, idx):
58 | if isinstance(idx, tuple):
59 | idx, other = idx
60 | return self.dataset[idx // self.multiplicator, other]
61 | else:
62 | return self.dataset[idx // self.multiplicator]
63 |
64 | @property
65 | def _resolutions(self):
66 | return self.dataset._resolutions
67 |
68 |
69 | class ResizedDataset (EasyDataset):
70 | """ Artifically changing the size of a dataset.
71 | """
72 | new_size: int
73 |
74 | def __init__(self, new_size, dataset):
75 | assert isinstance(new_size, int) and new_size > 0
76 | self.new_size = new_size
77 | self.dataset = dataset
78 |
79 | def __len__(self):
80 | return self.new_size
81 |
82 | def __repr__(self):
83 | size_str = str(self.new_size)
84 | for i in range((len(size_str)-1) // 3):
85 | sep = -4*i-3
86 | size_str = size_str[:sep] + '_' + size_str[sep:]
87 | return f'{size_str} @ {repr(self.dataset)}'
88 |
89 | def set_epoch(self, epoch):
90 | # this random shuffle only depends on the epoch
91 | rng = np.random.default_rng(seed=epoch+777)
92 |
93 | # shuffle all indices
94 | perm = rng.permutation(len(self.dataset))
95 |
96 | # rotary extension until target size is met
97 | shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
98 | self._idxs_mapping = shuffled_idxs[:self.new_size]
99 |
100 | assert len(self._idxs_mapping) == self.new_size
101 |
102 | def __getitem__(self, idx):
103 | assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
104 | if isinstance(idx, tuple):
105 | idx, other = idx
106 | return self.dataset[self._idxs_mapping[idx], other]
107 | else:
108 | return self.dataset[self._idxs_mapping[idx]]
109 |
110 | @property
111 | def _resolutions(self):
112 | return self.dataset._resolutions
113 |
114 |
115 | class CatDataset (EasyDataset):
116 | """ Concatenation of several datasets
117 | """
118 |
119 | def __init__(self, datasets):
120 | for dataset in datasets:
121 | assert isinstance(dataset, EasyDataset)
122 | self.datasets = datasets
123 | self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
124 |
125 | def __len__(self):
126 | return self._cum_sizes[-1]
127 |
128 | def __repr__(self):
129 | # remove uselessly long transform
130 | return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
131 |
132 | def set_epoch(self, epoch):
133 | for dataset in self.datasets:
134 | dataset.set_epoch(epoch)
135 |
136 | def __getitem__(self, idx):
137 | other = None
138 | if isinstance(idx, tuple):
139 | idx, other = idx
140 |
141 | if not (0 <= idx < len(self)):
142 | raise IndexError()
143 |
144 | db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
145 | dataset = self.datasets[db_idx]
146 | new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
147 |
148 | if other is not None:
149 | new_idx = (new_idx, other)
150 | return dataset[new_idx]
151 |
152 | @property
153 | def _resolutions(self):
154 | resolutions = self.datasets[0]._resolutions
155 | for dataset in self.datasets[1:]:
156 | assert tuple(dataset._resolutions) == tuple(resolutions)
157 | return resolutions
158 |
--------------------------------------------------------------------------------
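
The operators documented in the EasyDataset docstring compose into MulDataset, ResizedDataset and CatDataset; the following is a brief sketch of that composition, assuming dataset constructor arguments similar to the __main__ examples in the dataset files below.

    from slam3r.datasets import Replica, ScanNetpp_Seq

    replica = Replica(split='train', resolution=224, num_views=5)
    scannet = ScanNetpp_Seq(split='train', resolution=224, num_views=5,
                            start_freq=1, sample_freq=3)

    # 2 * dataset      -> MulDataset: every element appears twice
    # 1000 @ dataset   -> ResizedDataset: resampled to exactly 1000 elements
    # a + b            -> CatDataset: concatenation
    combined = 2 * replica + 1000 @ scannet

    combined.set_epoch(0)                      # required before indexing a ResizedDataset
    sampler = combined.make_sampler(batch_size=4)
    print(len(combined), repr(combined))
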
/slam3r/datasets/co3d_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dataloader for preprocessed Co3d_v2
6 | # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
7 | # See datasets_preprocess/preprocess_co3d.py
8 | # --------------------------------------------------------
9 | import os.path as osp
10 | import json
11 | import itertools
12 | from collections import deque
13 | import cv2
14 | import numpy as np
15 |
16 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
17 | import sys # noqa: E402
18 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
19 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
20 | from slam3r.utils.image import imread_cv2
21 |
22 | TRAINING_CATEGORIES = [
23 | "apple","backpack","banana","baseballbat","baseballglove","bench","bicycle",
24 | "bottle","bowl","broccoli","cake","car","carrot","cellphone","chair","cup","donut","hairdryer","handbag","hydrant","keyboard",
25 | "laptop","microwave","motorcycle","mouse","orange","parkingmeter","pizza","plant","stopsign","teddybear","toaster","toilet",
26 | "toybus","toyplane","toytrain","toytruck","tv","umbrella","vase","wineglass",
27 | ]
28 | TEST_CATEGORIES = ["ball", "book", "couch", "frisbee", "hotdog", "kite", "remote", "sandwich", "skateboard", "suitcase"]
29 |
30 |
31 | class Co3d_Seq(BaseStereoViewDataset):
32 | def __init__(self,
33 | mask_bg=True,
34 | ROOT="data/co3d_processed",
35 | num_views=2,
36 | degree=90, # degree range to select views
37 | sel_num=1, # number of views to select inside a degree range
38 | *args,
39 | **kwargs):
40 | self.ROOT = ROOT
41 | super().__init__(*args, **kwargs)
42 | assert mask_bg in (True, False, 'rand')
43 | self.mask_bg = mask_bg
44 | self.degree = degree
45 | self.winsize = int(degree / 360 * 100)
46 | self.sel_num = sel_num
47 | self.sel_num_perseq = (101 - self.winsize) * self.sel_num
48 | self.num_views = num_views
49 |
50 | # load all scenes
51 | if self.split == 'train':
52 | self.categories = TRAINING_CATEGORIES
53 | elif self.split == 'test':
54 | self.categories = TEST_CATEGORIES
55 | else:
56 | raise ValueError(f"Unknown split {self.split}")
57 | self.scenes = {}
58 |         for cate in self.categories:
59 | with open(osp.join(self.ROOT, cate, f'selected_seqs_{self.split}.json'), 'r') as f:
60 | self.scenes[cate] = json.load(f)
61 | self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
62 | for k2, v2 in v.items()}
63 | self.scene_list = list(self.scenes.keys()) # for each scene, we have about 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
64 | self.scene_lens = [len(v) for k,v in self.scenes.items()]
65 | # print(np.unique(np.array(self.scene_lens)))
66 | self.invalidate = {scene: {} for scene in self.scene_list}
67 |
68 | print(self)
69 |
70 | def __len__(self):
71 | return len(self.scene_list) * self.sel_num_perseq
72 |
73 | def get_img_idxes(self, idx, rng):
74 | sid = max(0, idx // self.sel_num - 1) #from 0 to 99-winsize
75 | eid = sid + self.winsize
76 | if idx % self.sel_num == 0:
77 | # generate a uniform sample between sid and eid
78 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
79 |
80 | # select the first and last, and randomly select the rest n-2 in between
81 | if self.num_views == 2:
82 | return [sid, eid]
83 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
84 | sel_ids.sort()
85 | return [sid] + list(sel_ids) + [eid]
86 |
87 |
88 | def _get_views(self, idx, resolution, rng):
89 | # choose a scene
90 | obj, instance = self.scene_list[idx // self.sel_num_perseq]
91 | image_pool = self.scenes[obj, instance]
92 | last = len(image_pool)-1
93 | if last <= self.winsize:
94 | return self._get_views(rng.integers(0, len(self)-1), resolution, rng)
95 |
96 | imgs_idxs = self.get_img_idxes(idx % self.sel_num_perseq, rng)
97 |
98 | for i, idx in enumerate(imgs_idxs):
99 | if idx > last:
100 | idx = idx % last
101 | imgs_idxs[i] = idx
102 | # print(imgs_idxs)
103 |
104 | if resolution not in self.invalidate[obj, instance]: # flag invalid images
105 | self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
106 |
107 | # decide now if we mask the bg
108 | mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))
109 |
110 | views = []
111 | imgs_idxs = deque(imgs_idxs)
112 |
113 | while len(imgs_idxs) > 0: # some images (few) have zero depth
114 | im_idx = imgs_idxs.popleft()
115 |
116 | if self.invalidate[obj, instance][resolution][im_idx]:
117 | # search for a valid image
118 | random_direction = 2 * rng.choice(2) - 1
119 | for offset in range(1, len(image_pool)):
120 | tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
121 | if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
122 | im_idx = tentative_im_idx
123 | break
124 | if offset == len(image_pool) - 1:
125 | # no valid image found
126 | return self._get_views((idx+1)%len(self), resolution, rng)
127 |
128 | view_idx = image_pool[im_idx]
129 |
130 | impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')
131 |
132 | # load camera params
133 | input_metadata = np.load(impath.replace('jpg', 'npz'))
134 | camera_pose = input_metadata['camera_pose'].astype(np.float32)
135 | intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
136 |
137 | # load image and depth
138 | rgb_image = imread_cv2(impath)
139 | depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
140 | depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
141 | if mask_bg:
142 | # load object mask
143 | maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
144 | maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
145 | maskmap = (maskmap / 255.0) > 0.1
146 |
147 | # update the depthmap with mask
148 | depthmap *= maskmap
149 |
150 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
151 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
152 |
153 |             # TODO: check if this is reasonable
154 | valid_depth = depthmap[depthmap > 0.0]
155 | if valid_depth.size > 0:
156 | median_depth = np.median(valid_depth)
157 | # print(f"median depth: {median_depth}")
158 |                 depthmap[depthmap > median_depth*3] = 0.  # filter out floating outlier points
159 |
160 | num_valid = (depthmap > 0.0).sum()
161 | if num_valid == 0:
162 | # problem, invalidate image and retry
163 | self.invalidate[obj, instance][resolution][im_idx] = True
164 | imgs_idxs.append(im_idx)
165 | continue
166 |
167 | views.append(dict(
168 | img=rgb_image,
169 | depthmap=depthmap,
170 | camera_pose=camera_pose,
171 | camera_intrinsics=intrinsics,
172 | dataset='Co3d_v2',
173 | label=f"{obj}_{instance}_frame{view_idx:06n}.jpg",
174 | instance=osp.split(impath)[1],
175 | ))
176 | return views
177 |
178 |
179 | if __name__ == "__main__":
180 | from slam3r.datasets.base.base_stereo_view_dataset import view_name
181 | import os
182 | import trimesh
183 |
184 | num_views = 11
185 | dataset = Co3d_Seq(split='train',
186 | mask_bg=False, resolution=224, aug_crop=16,
187 | num_views=num_views, degree=90, sel_num=3)
188 |
189 | save_dir = "visualization/co3d_seq_views"
190 | os.makedirs(save_dir, exist_ok=True)
191 |
192 | # import tqdm
193 | # for idx in tqdm.tqdm(np.random.permutation(len(dataset))):
194 | # views = dataset[(idx,0)]
195 | # print([view['instance'] for view in views])
196 |
197 | for idx in np.random.permutation(len(dataset))[:10]:
198 | # for idx in range(len(dataset))[5:10000:2000]:
199 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
200 | views = dataset[(idx,0)]
201 | assert len(views) == num_views
202 | all_pts = []
203 | all_color=[]
204 | for i, view in enumerate(views):
205 | img = np.array(view['img']).transpose(1, 2, 0)
206 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
207 | # img=cv2.COLOR_RGB2BGR(img)
208 | img=img[...,::-1]
209 | img = (img+1)/2
210 | cv2.imwrite(save_path, img*255)
211 | print(f"save to {save_path}")
212 | pts3d = np.array(view['pts3d']).reshape(-1,3)
213 | img = img[...,::-1]
214 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
215 | pct.export(save_path.replace('.jpg','.ply'))
216 | all_pts.append(pts3d)
217 | all_color.append(img.reshape(-1, 3))
218 | all_pts = np.concatenate(all_pts, axis=0)
219 | all_color = np.concatenate(all_color, axis=0)
220 | pct = trimesh.PointCloud(all_pts, all_color)
221 | pct.export(osp.join(save_dir, str(idx), f"all.ply"))
222 |
--------------------------------------------------------------------------------
/slam3r/datasets/project_aria_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dataloader for preprocessed project-aria dataset
6 | # --------------------------------------------------------
7 | import os.path as osp
8 | import os
9 | import cv2
10 | import numpy as np
11 | import math
12 |
13 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
14 | import sys # noqa: E402
15 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
16 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
17 | from slam3r.utils.image import imread_cv2
18 |
19 |
20 | class Aria_Seq(BaseStereoViewDataset):
21 | def __init__(self,
22 | ROOT='data/projectaria/ase_processed',
23 | num_views=2,
24 | scene_name=None, # specify scene name(s) to load
25 |                  sample_freq=1,   # stride of the frames inside the sliding window
26 | start_freq=1, # start frequency for the sliding window
27 | filter=False, # filter out the windows with abnormally large stride
28 | rand_sel=False, # randomly select views from a window
29 | winsize=0, # window size to randomly select views
30 | sel_num=0, # number of combinations to randomly select from a window
31 | *args,**kwargs):
32 | super().__init__(*args, **kwargs)
33 | self.ROOT = ROOT
34 | self.sample_freq = sample_freq
35 | self.start_freq = start_freq
36 | self.num_views = num_views
37 |
38 | self.rand_sel = rand_sel
39 | if rand_sel:
40 | assert winsize > 0 and sel_num > 0
41 | comb_num = math.comb(winsize-1, num_views-2)
42 | assert comb_num >= sel_num
43 | self.winsize = winsize
44 | self.sel_num = sel_num
45 | else:
46 | self.winsize = sample_freq*(num_views-1)
47 |
48 | self.scene_names = os.listdir(self.ROOT)
49 | self.scene_names = [int(scene_name) for scene_name in self.scene_names if scene_name.isdigit()]
50 | self.scene_names = sorted(self.scene_names)
51 | self.scene_names = [str(scene_name) for scene_name in self.scene_names]
52 | total_scene_num = len(self.scene_names)
53 |
54 | if self.split == 'train':
55 | # choose 90% of the data as training set
56 | self.scene_names = self.scene_names[:int(total_scene_num*0.9)]
57 | elif self.split=='test':
58 | self.scene_names = self.scene_names[int(total_scene_num*0.9):]
59 | if scene_name is not None:
60 | assert self.split is None
61 | if isinstance(scene_name, list):
62 | self.scene_names = scene_name
63 | else:
64 | if isinstance(scene_name, int):
65 | scene_name = str(scene_name)
66 | assert isinstance(scene_name, str)
67 | self.scene_names = [scene_name]
68 |
69 | self._load_data(filter=filter)
70 | print(self)
71 |
72 | def filter_windows(self, sid, eid, image_names):
73 | return False
74 |
75 | def _load_data(self, filter=False):
76 | self.sceneids = []
77 | self.images = []
78 | self.intrinsics = [] #scene_num*(3,3)
79 | self.win_bid = []
80 |
81 | num_count = 0
82 | for id, scene_name in enumerate(self.scene_names):
83 | scene_dir = os.path.join(self.ROOT, scene_name)
84 | # print(id, scene_name)
85 | image_names = os.listdir(os.path.join(scene_dir, 'color'))
86 | image_names = sorted(image_names)
87 | intrinsic = np.loadtxt(os.path.join(scene_dir, 'intrinsic', 'intrinsic_color.txt'))[:3,:3]
88 | image_num = len(image_names)
89 | # precompute the window indices
90 | for i in range(0, image_num, self.start_freq):
91 | last_id = i+self.winsize
92 | if last_id >= image_num:
93 | break
94 | if filter and self.filter_windows(i, last_id, image_names):
95 | continue
96 | self.win_bid.append((num_count+i, num_count+last_id))
97 |
98 | self.intrinsics.append(intrinsic)
99 | self.images += image_names
100 | self.sceneids += [id,] * image_num
101 | num_count += image_num
102 | # print(self.sceneids, self.scene_names)
103 | self.intrinsics = np.stack(self.intrinsics, axis=0)
104 | print(self.intrinsics.shape)
105 | assert len(self.sceneids)==len(self.images), f"{len(self.sceneids)}, {len(self.images)}"
106 |
107 | def __len__(self):
108 | if self.rand_sel:
109 | return self.sel_num*len(self.win_bid)
110 | return len(self.win_bid)
111 |
112 | def get_img_idxes(self, idx, rng):
113 | if self.rand_sel:
114 | sid, eid = self.win_bid[idx//self.sel_num]
115 | if idx % self.sel_num == 0:
116 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
117 |
118 | if self.num_views == 2:
119 | return [sid, eid]
120 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
121 | sel_ids.sort()
122 | return [sid] + list(sel_ids) + [eid]
123 | else:
124 | sid, eid = self.win_bid[idx]
125 | return [sid + i*self.sample_freq for i in range(self.num_views)]
126 |
127 |
128 | def _get_views(self, idx, resolution, rng):
129 |
130 | image_idxes = self.get_img_idxes(idx, rng)
131 | # print(image_idxes)
132 | views = []
133 | for view_idx in image_idxes:
134 | scene_id = self.sceneids[view_idx]
135 | scene_dir = osp.join(self.ROOT, self.scene_names[scene_id])
136 |
137 | intrinsics = self.intrinsics[scene_id]
138 | basename = self.images[view_idx]
139 | camera_pose = np.loadtxt(osp.join(scene_dir, 'pose', basename.replace('.jpg', '.txt')))
140 | # Load RGB image
141 | rgb_image = imread_cv2(osp.join(scene_dir, 'color', basename))
142 | # Load depthmap
143 | depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED)
144 | depthmap[~np.isfinite(depthmap)] = 0 # invalid
145 | depthmap = depthmap.astype(np.float32) / 1000
146 | depthmap[depthmap > 20] = 0 # invalid
147 |
148 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
149 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)
150 |
151 | views.append(dict(
152 | img=rgb_image,
153 | depthmap=depthmap.astype(np.float32),
154 | camera_pose=camera_pose.astype(np.float32),
155 | camera_intrinsics=intrinsics.astype(np.float32),
156 | dataset='Aria',
157 | label=self.scene_names[scene_id] + '_' + basename,
158 | instance=f'{str(idx)}_{str(view_idx)}',
159 | ))
160 | # print([view['label'] for view in views])
161 | return views
162 |
163 | if __name__ == "__main__":
164 | import trimesh
165 |
166 | num_views = 4
167 | # dataset = Aria_Seq(resolution=(224,224),
168 | # num_views=num_views,
169 | # start_freq=1, sample_freq=2)
170 | dataset = Aria_Seq(split='train', resolution=(224,224),
171 | num_views=num_views,
172 | start_freq=1, rand_sel=True, winsize=6, sel_num=3)
173 | save_dir = "visualization/aria_seq_views"
174 | os.makedirs(save_dir, exist_ok=True)
175 |
176 | for idx in np.random.permutation(len(dataset))[:10]:
177 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
178 | views = dataset[(idx,0)]
179 | assert len(views) == num_views
180 | all_pts = []
181 | all_color=[]
182 | for i, view in enumerate(views):
183 | img = np.array(view['img']).transpose(1, 2, 0)
184 | # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg")
185 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
186 | # img=cv2.COLOR_RGB2BGR(img)
187 | img=img[...,::-1]
188 | img = (img+1)/2
189 | cv2.imwrite(save_path, img*255)
190 | print(f"save to {save_path}")
191 | img = img[...,::-1]
192 | pts3d = np.array(view['pts3d']).reshape(-1,3)
193 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
194 | pct.export(save_path.replace('.jpg','.ply'))
195 | all_pts.append(pts3d)
196 | all_color.append(img.reshape(-1, 3))
197 | all_pts = np.concatenate(all_pts, axis=0)
198 | all_color = np.concatenate(all_color, axis=0)
199 | pct = trimesh.PointCloud(all_pts, all_color)
200 | pct.export(osp.join(save_dir, str(idx), f"all.ply"))
--------------------------------------------------------------------------------
/slam3r/datasets/replica_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | # --------------------------------------------------------
4 | # Dataloader for preprocessed Replica dataset provided by NICER-SLAM
5 | # --------------------------------------------------------
6 | import os.path as osp
7 | import os
8 | import cv2
9 | import numpy as np
10 | from glob import glob
11 | import json
12 | import trimesh
13 |
14 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
15 | import sys # noqa: E402
16 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
17 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
18 | from slam3r.utils.image import imread_cv2
19 |
20 |
21 | class Replica(BaseStereoViewDataset):
22 | def __init__(self,
23 | ROOT='data/Replica',
24 | num_views=2,
25 | num_fetch_views=None,
26 | sel_view=None,
27 | scene_name=None,
28 | sample_freq=20,
29 | start_freq=1,
30 | sample_dis=1,
31 | cycle=False,
32 | ref_id=-1,
33 | print_mess=False,
34 | *args,**kwargs):
35 | super().__init__(*args, **kwargs)
36 | self.ROOT = ROOT
37 | self.print_mess = print_mess
38 | self.sample_freq = sample_freq
39 | self.start_freq = start_freq
40 | self.sample_dis = sample_dis
41 | self.cycle=cycle
42 | self.num_fetch_views = num_fetch_views if num_fetch_views is not None else num_views
43 | self.sel_view = np.arange(num_views) if sel_view is None else np.array(sel_view)
44 | self.num_views = num_views
45 | assert ref_id < num_views
46 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2
47 | self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
48 | if self.split == 'train':
49 | self.scene_names = ["room0", "room1", "room2", "office0", "office1", "office2"]
50 | elif self.split=='val':
51 | self.scene_names = ["office3", "office4"]
52 | if scene_name is not None:
53 | assert self.split is None
54 | if isinstance(scene_name, list):
55 | self.scene_names = scene_name
56 | else:
57 | assert isinstance(scene_name, str)
58 | self.scene_names = [scene_name]
59 | self._load_data()
60 | print(self)
61 |
62 | def _load_data(self):
63 | self.sceneids = []
64 | self.image_paths = []
65 | self.trajectories = [] #c2w
66 | self.pairs = []
67 | with open(os.path.join(self.ROOT,"cam_params.json"),'r') as f:
68 | self.intrinsic = json.load(f)['camera']
69 | K = np.eye(3)
70 | K[0, 0] = self.intrinsic['fx']
71 | K[1, 1] = self.intrinsic['fy']
72 | K[0, 2] = self.intrinsic['cx']
73 | K[1, 2] = self.intrinsic['cy']
74 | self.intri_mat = K
75 | num_count = 0
76 | for id, scene_name in enumerate(self.scene_names):
77 | scene_dir = os.path.join(self.ROOT, scene_name)
78 | image_paths = sorted(glob(os.path.join(scene_dir,"results","frame*.jpg")))
79 |
80 | image_paths = image_paths[::self.sample_freq]
81 | image_num = len(image_paths)
82 |
83 | if not self.cycle:
84 | for i in range(0, image_num, self.start_freq):
85 | last_id = i+self.sample_dis*(self.num_fetch_views-1)
86 | if last_id >= image_num:
87 | break
88 | self.pairs.append([j+num_count for j in range(i,last_id+1,self.sample_dis)])
89 | else:
90 | for i in range(0, image_num, self.start_freq):
91 | pair = []
92 | for j in range(0, self.num_fetch_views):
93 | pair.append((i+(j-self.ref_id)*self.sample_dis+image_num)%image_num + num_count)
94 | self.pairs.append(pair)
95 |
96 | self.trajectories.append(np.loadtxt(os.path.join(scene_dir,"traj.txt")).reshape(-1,4,4)[::self.sample_freq])
97 | self.image_paths += image_paths
98 | self.sceneids += [id,] * image_num
99 | num_count += image_num
100 | # print(self.sceneids, self.scene_names)
101 | self.trajectories = np.concatenate(self.trajectories,axis=0)
102 | assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids)==len(self.image_paths), f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.image_paths)}"
103 |
104 | def __len__(self):
105 | return len(self.pairs)
106 |
107 | def _get_views(self, idx, resolution, rng):
108 |
109 | image_idxes = self.pairs[idx]
110 | assert len(image_idxes) == self.num_fetch_views
111 | image_idxes = [image_idxes[i] for i in self.sel_view]
112 | views = []
113 | for view_idx in image_idxes:
114 | scene_id = self.sceneids[view_idx]
115 | camera_pose = self.trajectories[view_idx]
116 | image_path = self.image_paths[view_idx]
117 | image_name = os.path.basename(image_path)
118 | depth_path = image_path.replace(".jpg",".png").replace("frame","depth")
119 | # Load RGB image
120 | rgb_image = imread_cv2(image_path)
121 |
122 | # Load depthmap
123 | depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
124 | depthmap = depthmap.astype(np.float32)
125 | depthmap[~np.isfinite(depthmap)] = 0 # TODO:invalid
126 | depthmap /= self.intrinsic['scale']
127 |
128 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
129 | rgb_image, depthmap, self.intri_mat, resolution, rng=rng, info=view_idx)
130 | # print(intrinsics)
131 | views.append(dict(
132 | img=rgb_image,
133 | depthmap=depthmap.astype(np.float32),
134 | camera_pose=camera_pose.astype(np.float32),
135 | camera_intrinsics=intrinsics.astype(np.float32),
136 | dataset='Replica',
137 | label=self.scene_names[scene_id] + '_' + image_name,
138 | instance=f'{str(idx)}_{str(view_idx)}',
139 | ))
140 | if self.print_mess:
141 | print(f"loading {[view['label'] for view in views]}")
142 | return views
143 |
144 |
145 | if __name__ == "__main__":
146 | num_views = 5
147 | dataset= Replica(ref_id=1, print_mess=True, cycle=True, resolution=224, num_views=num_views, sample_freq=100, seed=777, start_freq=1, sample_dis=1)
148 | save_dir = "visualization/replica_views"
149 |
150 | # combine the pointmaps from different views with c2ws
151 | # to check the correctness of the dataloader
152 | for idx in np.random.permutation(len(dataset))[:10]:
153 | # for idx in range(10):
154 | views = dataset[(idx,0)]
155 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
156 | assert len(views) == num_views
157 | all_pts = []
158 | all_color = []
159 | for i, view in enumerate(views):
160 | img = np.array(view['img']).transpose(1, 2, 0)
161 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
162 | print(save_path)
163 | img=img[...,::-1]
164 | img = (img+1)/2
165 | cv2.imwrite(save_path, img*255)
166 | print(f"save to {save_path}")
167 | img = img[...,::-1]
168 | pts3d = np.array(view['pts3d']).reshape(-1,3)
169 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
170 | pct.export(save_path.replace('.jpg','.ply'))
171 | all_pts.append(pts3d)
172 | all_color.append(img.reshape(-1, 3))
173 | all_pts = np.concatenate(all_pts, axis=0)
174 | all_color = np.concatenate(all_color, axis=0)
175 | pct = trimesh.PointCloud(all_pts, all_color)
176 | pct.export(osp.join(save_dir, str(idx), f"all.ply"))
177 |
--------------------------------------------------------------------------------
/slam3r/datasets/scannetpp_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | # --------------------------------------------------------
4 | # Dataloader for preprocessed scannet++
5 | # dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes
6 | # https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf
7 | # See datasets_preprocess/preprocess_scannetpp.py
8 | # --------------------------------------------------------
9 | import os.path as osp
10 | import os
11 | import cv2
12 | import numpy as np
13 | import math
14 |
15 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
16 | import sys # noqa: E402
17 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
18 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
19 | from slam3r.utils.image import imread_cv2
20 |
21 |
22 | class ScanNetpp_Seq(BaseStereoViewDataset):
23 | def __init__(self,
24 | ROOT='data/scannetpp_processed',
25 | num_views=2,
26 | scene_name=None, # specify scene name(s) to load
27 |                  sample_freq=1,   # stride of the frames inside the sliding window
28 | start_freq=1, # start frequency for the sliding window
29 | img_types=['iphone', 'dslr'],
30 | filter=False, # filter out the windows with abnormally large stride
31 | rand_sel=False, # randomly select views from a window
32 | winsize=0, # window size to randomly select views
33 | sel_num=0, # number of combinations to randomly select from a window
34 | *args,**kwargs):
35 | super().__init__(*args, **kwargs)
36 | self.ROOT = ROOT
37 | self.sample_freq = sample_freq
38 | self.start_freq = start_freq
39 | self.num_views = num_views
40 | self.img_types = img_types
41 |
42 | self.rand_sel = rand_sel
43 | if rand_sel:
44 | assert winsize > 0 and sel_num > 0
45 | comb_num = math.comb(winsize-1, num_views-2)
46 | assert comb_num >= sel_num
47 | self.winsize = winsize
48 | self.sel_num = sel_num
49 | else:
50 | self.winsize = sample_freq*(num_views-1)
51 |
52 | if self.split == 'train':
53 | with open(os.path.join(ROOT, 'splits', 'nvs_sem_train.txt'), 'r') as f:
54 | self.scene_names = f.read().splitlines()
55 | elif self.split=='test':
56 | with open(os.path.join(ROOT, 'splits', 'nvs_sem_val.txt'), 'r') as f:
57 | self.scene_names = f.read().splitlines()
58 | if scene_name is not None:
59 | assert self.split is None
60 | if isinstance(scene_name, list):
61 | self.scene_names = scene_name
62 | else:
63 | assert isinstance(scene_name, str)
64 | self.scene_names = [scene_name]
65 |
66 | self._load_data(filter=filter)
67 | print(self)
68 |
69 | def filter_windows(self, img_type, sid, eid, image_names):
70 | if img_type == 'iphone': # frame_000450.jpg
71 | start_id = int(image_names[sid].split('_')[-1].split('.')[0])
72 | end_id = int(image_names[eid].split('_')[-1].split('.')[0])
73 | base_stride = 10*self.winsize
74 | elif img_type == 'dslr': # DSC06967.jpg
75 | start_id = int(image_names[sid].split('.')[0][-5:])
76 | end_id = int(image_names[eid].split('.')[0][-5:])
77 | base_stride = self.winsize
78 |         # filter out the windows with abnormally large stride
79 | if end_id - start_id >= base_stride*3:
80 | return True
81 | return False
82 |
83 | def _load_data(self, filter=False):
84 | self.sceneids = []
85 | self.images = []
86 | self.intrinsics = [] #(3,3)
87 | self.trajectories = [] #c2w (4,4)
88 | self.win_bid = []
89 |
90 | num_count = 0
91 | for id, scene_name in enumerate(self.scene_names):
92 | scene_dir = os.path.join(self.ROOT, scene_name)
93 | for img_type in self.img_types:
94 | metadata_path = os.path.join(scene_dir, f'scene_{img_type}_metadata.npz')
95 | if not os.path.exists(metadata_path):
96 | continue
97 | metadata = np.load(metadata_path)
98 | image_names = metadata['images'].tolist()
99 |
100 | # check if the images are in the same sequence,
101 |                 # only works for certain scenes in ScanNet++ V2
102 | if img_type == 'dslr':
103 | prefixes = [name[0:3] for name in image_names]
104 | if len(set(prefixes)) > 1:
105 | # dslr images are not in the same sequence
106 | print(f"Warning: {scene_name} {img_type} images are not in the same sequence {set(prefixes)}")
107 | continue
108 |
109 | assert image_names == sorted(image_names)
110 | image_names = sorted(image_names)
111 | intrinsics = metadata['intrinsics']
112 | trajectories = metadata['trajectories']
113 | image_num = len(image_names)
114 | # precompute the window indices
115 | for i in range(0, image_num, self.start_freq):
116 | last_id = i+self.winsize
117 | if last_id >= image_num:
118 | break
119 | if filter and self.filter_windows(img_type, i, last_id, image_names):
120 | continue
121 | self.win_bid.append((num_count+i, num_count+last_id))
122 |
123 | self.trajectories.append(trajectories)
124 | self.intrinsics.append(intrinsics)
125 | self.images += image_names
126 | self.sceneids += [id,] * image_num
127 | num_count += image_num
128 | # print(self.sceneids, self.scene_names)
129 | self.trajectories = np.concatenate(self.trajectories,axis=0)
130 | self.intrinsics = np.concatenate(self.intrinsics, axis=0)
131 | assert len(self.trajectories) == len(self.sceneids) and len(self.sceneids)==len(self.images), f"{len(self.trajectories)}, {len(self.sceneids)}, {len(self.images)}"
132 |
133 | def __len__(self):
134 | if self.rand_sel:
135 | return self.sel_num*len(self.win_bid)
136 | return len(self.win_bid)
137 |
138 | def get_img_idxes(self, idx, rng):
139 | if self.rand_sel:
140 | sid, eid = self.win_bid[idx//self.sel_num]
141 | if idx % self.sel_num == 0:
142 | return np.linspace(sid, eid, self.num_views, endpoint=True, dtype=int)
143 |
144 |             # randomly select the views, always including the start and end view
145 | if self.num_views == 2:
146 | return [sid, eid]
147 | sel_ids = rng.choice(range(sid+1, eid), self.num_views-2, replace=False)
148 | sel_ids.sort()
149 | return [sid] + list(sel_ids) + [eid]
150 | else:
151 | sid, eid = self.win_bid[idx]
152 | return [sid + i*self.sample_freq for i in range(self.num_views)]
153 |
154 |
155 | def _get_views(self, idx, resolution, rng):
156 |
157 | image_idxes = self.get_img_idxes(idx, rng)
158 | # print(image_idxes)
159 | views = []
160 | for view_idx in image_idxes:
161 | scene_id = self.sceneids[view_idx]
162 | scene_dir = osp.join(self.ROOT, self.scene_names[scene_id])
163 |
164 | intrinsics = self.intrinsics[view_idx]
165 | camera_pose = self.trajectories[view_idx]
166 | basename = self.images[view_idx]
167 | # Load RGB image
168 | rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename))
169 | # Load depthmap
170 | depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED)
171 | depthmap = depthmap.astype(np.float32) / 1000
172 | depthmap[~np.isfinite(depthmap)] = 0 # invalid
173 |
174 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
175 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)
176 |
177 | views.append(dict(
178 | img=rgb_image,
179 | depthmap=depthmap.astype(np.float32),
180 | camera_pose=camera_pose.astype(np.float32),
181 | camera_intrinsics=intrinsics.astype(np.float32),
182 | dataset='ScanNet++',
183 | label=self.scene_names[scene_id] + '_' + basename,
184 | instance=f'{str(idx)}_{str(view_idx)}',
185 | ))
186 |
187 | return views
188 |
189 | if __name__ == "__main__":
190 | from slam3r.datasets.base.base_stereo_view_dataset import view_name
191 | import trimesh
192 |
193 | num_views = 5
194 | dataset = ScanNetpp_Seq(split='train', resolution=(224,224),
195 | num_views=num_views,
196 | start_freq=1, sample_freq=3)
197 | # dataset = ScanNetpp_Seq(split='train', resolution=(224,224),
198 | # num_views=num_views,
199 | # start_freq=1, rand_sel=True, winsize=6, sel_num=3)
200 | save_dir = "visualization/scannetpp_seq_views"
201 | os.makedirs(save_dir, exist_ok=True)
202 |
203 | for idx in np.random.permutation(len(dataset))[:10]:
204 | # for idx in range(len(dataset))[5:10000:2000]:
205 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
206 | views = dataset[(idx,0)]
207 | assert len(views) == num_views
208 | all_pts = []
209 | all_color=[]
210 | for i, view in enumerate(views):
211 | img = np.array(view['img']).transpose(1, 2, 0)
212 | save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}")
213 | img=img[...,::-1]
214 | img = (img+1)/2
215 | cv2.imwrite(save_path, img*255)
216 | print(f"save to {save_path}")
217 | img = img[...,::-1]
218 | pts3d = np.array(view['pts3d']).reshape(-1,3)
219 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
220 | pct.export(save_path.replace('.jpg','.ply'))
221 | all_pts.append(pts3d)
222 | all_color.append(img.reshape(-1, 3))
223 | all_pts = np.concatenate(all_pts, axis=0)
224 | all_color = np.concatenate(all_color, axis=0)
225 | pct = trimesh.PointCloud(all_pts, all_color)
226 | pct.export(osp.join(save_dir, str(idx), f"all.ply"))
--------------------------------------------------------------------------------
/slam3r/datasets/seven_scenes_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dataloader for 7Scenes dataset
6 | # --------------------------------------------------------
7 | import os.path as osp
8 | import os
9 | import cv2
10 | import numpy as np
11 | import torch
12 | import itertools
13 | from glob import glob
14 | import json
15 | import trimesh
16 |
17 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
18 | import sys # noqa: E402
19 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
20 | from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
21 | from slam3r.utils.image import imread_cv2
22 |
23 | class SevenScenes_Seq(BaseStereoViewDataset):
24 | def __init__(self,
25 | ROOT='data/7Scenes',
26 | scene_id='office',
27 | seq_id=1,
28 | num_views=1,
29 | sample_freq=1,
30 | start_freq=1,
31 | cycle=False,
32 | ref_id=-1,
33 | *args,**kwargs):
34 | super().__init__(*args, **kwargs)
35 | self.ROOT = ROOT
36 | self.cycle = cycle
37 | self.scene_id = scene_id
38 | self.scene_names = [scene_id+'_seq-'+f"{seq_id:02d}"]
39 | self.seq_id = seq_id
40 | self.sample_freq = sample_freq
41 | self.start_freq = start_freq
42 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2
43 | self.num_views = num_views
44 | self.num_fetch_views = self.num_views
45 | self.data_dir = os.path.join(self.ROOT, scene_id, f'seq-{seq_id:02d}')
46 | self._load_data()
47 | print(self)
48 |
49 | def _load_data(self):
50 | self.intrinsics = np.array([[585, 0, 320],
51 | [0, 585, 240],
52 | [0, 0, 1]], dtype=np.float32)
53 | self.trajectories = [] #c2w (4,4)
54 | self.pairs = []
55 | self.images = sorted(glob(osp.join(self.data_dir, '*.color.png'))) #frame-000000.color.png
56 | image_num = len(self.images)
57 |         # could these two lines be sped up?
58 | if not self.cycle:
59 | for i in range(0, image_num, self.start_freq):
60 | last_id = i+(self.num_views-1)*self.sample_freq
61 | if last_id >= image_num: break
62 | self.pairs.append([i+j*self.sample_freq
63 | for j in range(self.num_views)])
64 | else:
65 | for i in range(0, image_num, self.start_freq):
66 | pair = []
67 | for j in range(0, self.num_fetch_views):
68 | pair.append((i+(j-self.ref_id)*self.sample_freq+image_num)%image_num)
69 | self.pairs.append(pair)
70 | print(self.pairs)
71 | def __len__(self):
72 | return len(self.pairs)
73 | # return len(self.img_group)
74 |
75 | def _get_views(self, idx, resolution, rng):
76 |
77 | image_idxes = self.pairs[idx]
78 | views = []
79 | scene_dir = self.data_dir
80 |
81 | for view_idx in image_idxes:
82 |
83 | intrinsics = self.intrinsics
84 | img_path = self.images[view_idx]
85 |
86 | # Load RGB image
87 | rgb_image = imread_cv2(img_path)
88 | # Load depthmap(16-bit, PNG, invalid depth is set to 65535)
89 | depthmap = imread_cv2(img_path.replace('.color.png','.depth.png'), cv2.IMREAD_UNCHANGED)
90 | depthmap[depthmap == 65535] = 0
91 | depthmap = depthmap.astype(np.float32) / 1000
92 | depthmap[~np.isfinite(depthmap)] = 0 # invalid
93 | camera_pose = np.loadtxt(img_path.replace('.color.png','.pose.txt'))
94 |
95 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
96 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)
97 |
98 | views.append(dict(
99 | img=rgb_image,
100 | depthmap=depthmap.astype(np.float32),
101 | camera_pose=camera_pose.astype(np.float32),
102 | camera_intrinsics=intrinsics.astype(np.float32),
103 |                 dataset='SevenScenes',
104 | label=img_path,
105 | instance=f'{str(idx)}_{str(view_idx)}',
106 | ))
107 | return views
108 |
109 |
110 | class SevenScenes_Seq_Cali(BaseStereoViewDataset):
111 | def __init__(self,
112 | ROOT='data/7s_dsac/dsac',
113 | scene_id='office',
114 | seq_id=1,
115 | num_views=1,
116 | sample_freq=1,
117 | start_freq=1,
118 | cycle=False,
119 | ref_id=-1,
120 | *args,**kwargs):
121 | super().__init__(*args, **kwargs)
122 | self.ROOT = ROOT
123 | self.cycle = cycle
124 | self.scene_id = scene_id
125 | self.scene_names = [scene_id+'_seq-'+f"{seq_id:02d}"]
126 | self.seq_id = seq_id
127 | self.sample_freq = sample_freq
128 | self.start_freq = start_freq
129 | self.ref_id = ref_id if ref_id >= 0 else (num_views-1) // 2
130 | self.num_views = num_views
131 | self.num_fetch_views = self.num_views
132 | self.data_dir = os.path.join(self.ROOT, scene_id, f'seq-{seq_id:02d}')
133 | self._load_data()
134 | print(self)
135 |
136 | def _load_data(self):
137 | self.intrinsics = np.array([[525, 0, 320],
138 | [0, 525, 240],
139 | [0, 0, 1]], dtype=np.float32)
140 | self.trajectories = [] #c2w (4,4)
141 | self.pairs = []
142 | self.images = sorted(glob(osp.join(self.data_dir, '*.color.png'))) #frame-000000.color.png
143 | image_num = len(self.images)
144 |         # could these two lines be sped up?
145 | if not self.cycle:
146 | for i in range(0, image_num, self.start_freq):
147 | last_id = i+(self.num_views-1)*self.sample_freq
148 | if last_id >= image_num: break
149 | self.pairs.append([i+j*self.sample_freq
150 | for j in range(self.num_views)])
151 | else:
152 | for i in range(0, image_num, self.start_freq):
153 | pair = []
154 | for j in range(0, self.num_fetch_views):
155 | pair.append((i+(j-self.ref_id)*self.sample_freq+image_num)%image_num)
156 | self.pairs.append(pair)
157 | # print(self.pairs)
158 | def __len__(self):
159 | return len(self.pairs)
160 | # return len(self.img_group)
161 |
162 | def _get_views(self, idx, resolution, rng):
163 |
164 | image_idxes = self.pairs[idx]
165 | views = []
166 | scene_dir = self.data_dir
167 |
168 | for view_idx in image_idxes:
169 |
170 | intrinsics = self.intrinsics
171 | img_path = self.images[view_idx]
172 |
173 | # Load RGB image
174 | rgb_image = imread_cv2(img_path)
175 | # Load depthmap(16-bit, PNG, invalid depth is set to 65535)
176 | depthmap = imread_cv2(img_path.replace('.color.png','.depth_cali.png'), cv2.IMREAD_UNCHANGED)
177 | depthmap[depthmap == 65535] = 0
178 | depthmap = depthmap.astype(np.float32) / 1000
179 | depthmap[~np.isfinite(depthmap)] = 0 # invalid
180 |
181 | camera_pose = np.loadtxt(img_path.replace('.color.png','.pose.txt'))
182 |
183 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
184 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)
185 | # print(intrinsics)
186 | views.append(dict(
187 | img=rgb_image,
188 | depthmap=depthmap.astype(np.float32),
189 | camera_pose=camera_pose.astype(np.float32),
190 | camera_intrinsics=intrinsics.astype(np.float32),
191 | dataset='SevenScenes',
192 | label=img_path,
193 | instance=f'{str(idx)}_{str(view_idx)}',
194 | ))
195 | return views
196 |
197 |
198 | if __name__ == "__main__":
199 | from slam3r.datasets.base.base_stereo_view_dataset import view_name
200 | from slam3r.viz import SceneViz, auto_cam_size
201 | from slam3r.utils.image import rgb
202 |
203 | num_views = 3
204 | dataset = SevenScenes_Seq(scene_id='office',seq_id=9,resolution=(224,224), num_views=num_views,
205 | start_freq=1, sample_freq=20)
206 | save_dir = "visualization/7scenes_seq_views"
207 | os.makedirs(save_dir, exist_ok=True)
208 |
209 | # for idx in np.random.permutation(len(dataset))[:10]:
210 | for idx in range(len(dataset))[:500:100]:
211 | os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True)
212 | views = dataset[(idx,0)]
213 | assert len(views) == num_views
214 | all_pts = []
215 | all_color=[]
216 | for i, view in enumerate(views):
217 | img = np.array(view['img']).transpose(1, 2, 0)
218 | # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg")
219 | print(view['label'])
220 | save_path = osp.join(save_dir, str(idx), f"{i}_{os.path.basename(view['label'])}")
221 | # img=cv2.COLOR_RGB2BGR(img)
222 | img=img[...,::-1]
223 | img = (img+1)/2
224 | cv2.imwrite(save_path, img*255)
225 | print(f"save to {save_path}")
226 | pts3d = np.array(view['pts3d']).reshape(-1,3)
227 | pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3))
228 | pct.export(save_path.replace('.png','.ply'))
229 | all_pts.append(pts3d)
230 | all_color.append(img.reshape(-1, 3))
231 | all_pts = np.concatenate(all_pts, axis=0)
232 | all_color = np.concatenate(all_color, axis=0)
233 | pct = trimesh.PointCloud(all_pts, all_color)
234 | pct.export(osp.join(save_dir, str(idx), f"all.ply"))
235 | # for idx in range(len(dataset)):
236 | # views = dataset[(idx,0)]
237 | # print([view['label'] for view in views])
--------------------------------------------------------------------------------
/slam3r/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/slam3r/datasets/utils/cropping.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # cropping utilities
6 | # --------------------------------------------------------
7 | import PIL.Image
8 | import os
9 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
10 | import cv2 # noqa
11 | import numpy as np # noqa
12 | from slam3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
13 | try:
14 | lanczos = PIL.Image.Resampling.LANCZOS
15 | except AttributeError:
16 | lanczos = PIL.Image.LANCZOS
17 |
18 |
19 | class ImageList:
20 | """ Convenience class to apply the same operation to a whole set of images.
21 | """
22 |
23 | def __init__(self, images):
24 | if not isinstance(images, (tuple, list, set)):
25 | images = [images]
26 | self.images = []
27 | for image in images:
28 | if not isinstance(image, PIL.Image.Image):
29 | image = PIL.Image.fromarray(image)
30 | self.images.append(image)
31 |
32 | def __len__(self):
33 | return len(self.images)
34 |
35 | def to_pil(self):
36 | return tuple(self.images) if len(self.images) > 1 else self.images[0]
37 |
38 | @property
39 | def size(self):
40 | sizes = [im.size for im in self.images]
41 | assert all(sizes[0] == s for s in sizes)
42 | return sizes[0]
43 |
44 | def resize(self, *args, **kwargs):
45 | return ImageList(self._dispatch('resize', *args, **kwargs))
46 |
47 | def crop(self, *args, **kwargs):
48 | return ImageList(self._dispatch('crop', *args, **kwargs))
49 |
50 | def _dispatch(self, func, *args, **kwargs):
51 | return [getattr(im, func)(*args, **kwargs) for im in self.images]
52 |
53 |
54 | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
55 | """ Jointly rescale a (image, depthmap)
56 | so that (out_width, out_height) >= output_res
57 | """
58 | image = ImageList(image)
59 | input_resolution = np.array(image.size) # (W,H)
60 | output_resolution = np.array(output_resolution)
61 | if depthmap is not None:
62 | # can also use this with masks instead of depthmaps
63 | # print(tuple(depthmap.shape[:2]), image.size[::-1])
64 | assert tuple(depthmap.shape[:2]) == image.size[::-1]
65 | assert output_resolution.shape == (2,)
66 | # define output resolution
67 | scale_final = max(output_resolution / image.size) + 1e-8
68 | output_resolution = np.floor(input_resolution * scale_final).astype(int)
69 |
70 | # first rescale the image so that it contains the crop
71 | image = image.resize(output_resolution, resample=lanczos)
72 | if depthmap is not None:
73 | depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
74 | fy=scale_final, interpolation=cv2.INTER_NEAREST)
75 |
76 | # no offset here; simple rescaling
77 | camera_intrinsics = camera_matrix_of_crop(
78 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
79 |
80 | return image.to_pil(), depthmap, camera_intrinsics
81 |
82 |
83 | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
84 | # Margins to offset the origin
85 | margins = np.asarray(input_resolution) * scaling - output_resolution
86 | assert np.all(margins >= 0.0)
87 | if offset is None:
88 | offset = offset_factor * margins
89 |
90 | # Generate new camera parameters
91 | output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
92 | output_camera_matrix_colmap[:2, :] *= scaling
93 | output_camera_matrix_colmap[:2, 2] -= offset
94 | output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
95 |
96 | return output_camera_matrix
97 |
98 |
99 | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
100 | """
101 | Return a crop of the input view.
102 | """
103 | image = ImageList(image)
104 | l, t, r, b = crop_bbox
105 |
106 | image = image.crop((l, t, r, b))
107 | depthmap = depthmap[t:b, l:r]
108 |
109 | camera_intrinsics = camera_intrinsics.copy()
110 | camera_intrinsics[0, 2] -= l
111 | camera_intrinsics[1, 2] -= t
112 |
113 | return image.to_pil(), depthmap, camera_intrinsics
114 |
115 |
116 | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
117 | out_width, out_height = output_resolution
118 | l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
119 | crop_bbox = (l, t, l+out_width, t+out_height)
120 | return crop_bbox
121 |
--------------------------------------------------------------------------------
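
The helpers above are meant to be chained: rescale so the image just covers the target resolution, derive the intrinsics of a centered crop, then crop image, depthmap and intrinsics together. The following is a rough standalone sketch with dummy data (the chaining mirrors what the datasets' _crop_resize_if_necessary presumably does; it is not part of the repo itself):

import numpy as np
from slam3r.datasets.utils.cropping import (
    rescale_image_depthmap, camera_matrix_of_crop,
    bbox_from_intrinsics_in_out, crop_image_depthmap)

H, W = 480, 640
image = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)   # dummy RGB
depth = np.random.rand(H, W).astype(np.float32)                # dummy depth
K = np.array([[585., 0., 320.], [0., 585., 240.], [0., 0., 1.]], dtype=np.float32)

target = (224, 224)  # (out_width, out_height)
# 1) rescale so the image just contains the target crop
image, depth, K = rescale_image_depthmap(image, depth, K, np.array(target))
# 2) intrinsics of a centered target crop, and the matching bounding box
K_crop = camera_matrix_of_crop(K, image.size, target, offset_factor=0.5)
bbox = bbox_from_intrinsics_in_out(K, K_crop, target)
# 3) crop image, depth and intrinsics jointly
image, depth, K = crop_image_depthmap(image, depth, K, bbox)
print(image.size, depth.shape, K[:2, 2])   # (224, 224), (224, 224), shifted principal point
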
/slam3r/datasets/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # DUST3R default transforms
6 | # --------------------------------------------------------
7 | import torchvision.transforms as tvf
8 | from slam3r.utils.image import ImgNorm
9 |
10 | # define the standard image transforms
11 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
12 |
--------------------------------------------------------------------------------
/slam3r/datasets/wild_seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dataloader for self-captured img sequence
6 | # --------------------------------------------------------
7 | import os.path as osp
8 | import torch
9 |
10 | SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
11 | import sys # noqa: E402
12 | sys.path.insert(0, SLAM3R_DIR) # noqa: E402
13 | from slam3r.utils.image import load_images
14 |
15 | class Seq_Data():
16 | def __init__(self,
17 | img_dir, # the directory of the img sequence
18 | img_size=224, # only img_size=224 is supported now
19 | silent=False,
20 | sample_freq=1, # the frequency of the imgs to be sampled
21 | num_views=-1, # only take the first num_views imgs in the img_dir
22 | start_freq=1,
23 | postfix=None, # the postfix of the imgs in the img_dir (.jpg, .png, ...)
24 | to_tensor=False,
25 | start_idx=0):
26 |
27 | # Note that only img_size=224 is supported now.
28 | # Imgs will be cropped and resized to 224x224, thus losing information near the borders.
29 | assert img_size==224, "Sorry, only img_size=224 is supported now."
30 |
31 | # load imgs with sequential number.
32 | # Imgs in the img_dir should have numbers in their names to indicate their order,
33 | # such as frame-0031.color.png, output_414.jpg, ...
34 | self.imgs = load_images(img_dir, size=img_size,
35 | verbose=not silent, img_freq=sample_freq,
36 | postfix=postfix, start_idx=start_idx, img_num=num_views)
37 |
38 | self.num_views = num_views if num_views > 0 else len(self.imgs)
39 | self.stride = start_freq
40 | self.img_num = len(self.imgs)
41 | if to_tensor:
42 | for img in self.imgs:
43 | img['true_shape'] = torch.tensor(img['true_shape'])
44 | self.make_groups()
45 | self.length = len(self.groups)
46 |
47 | if isinstance(img_dir, str):
48 | if img_dir[-1] == '/':
49 | img_dir = img_dir[:-1]
50 | self.scene_names = ['_'.join(img_dir.split('/')[-2:])]
51 |
52 | def make_groups(self):
53 | self.groups = []
54 | for start in range(0,self.img_num, self.stride):
55 | end = start + self.num_views
56 | if end > self.img_num:
57 | break
58 | self.groups.append(self.imgs[start:end])
59 |
60 | def __len__(self):
61 | return len(self.groups)
62 |
63 | def __getitem__(self, idx):
64 | return self.groups[idx]
65 |
66 |
67 |
68 | if __name__ == "__main__":
69 | from slam3r.datasets.base.base_stereo_view_dataset import view_name
70 | from slam3r.viz import SceneViz, auto_cam_size
71 | from slam3r.utils.image import rgb
72 |
73 | dataset = Seq_Data(img_dir="dataset/7Scenes/office-09",
74 | img_size=224, silent=False, sample_freq=10,
75 | num_views=5, start_freq=2, postfix="color.png")
76 | for i in range(len(dataset)):
77 | data = dataset[i]
78 | print([img['idx'] for img in data])
79 | # break
--------------------------------------------------------------------------------
/slam3r/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # head factory
6 | # --------------------------------------------------------
7 | from .linear_head import LinearPts3d
8 | from .dpt_head import create_dpt_head
9 |
10 |
11 | def head_factory(head_type, output_mode, net, has_conf=False):
12 | """ build a prediction head for the decoder
13 | """
14 | if head_type == 'linear' and output_mode == 'pts3d':
15 | return LinearPts3d(net, has_conf)
16 | elif head_type == 'dpt' and output_mode == 'pts3d':
17 | return create_dpt_head(net, has_conf=has_conf)
18 | else:
19 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
20 |
--------------------------------------------------------------------------------
/slam3r/heads/dpt_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # dpt head implementation for DUST3R
6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens);
7 | # if a head takes as input the output at every layer, its attribute return_all_layers should be set to True
8 | # the forward function also takes as input a dictionary img_info with keys "height" and "width"
9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W
10 | # --------------------------------------------------------
11 | from einops import rearrange
12 | from typing import List
13 | import torch
14 | import torch.nn as nn
15 |
16 | from .dpt_block import DPTOutputAdapter # noqa
17 | from .postprocess import postprocess
18 |
19 |
20 | class DPTOutputAdapter_fix(DPTOutputAdapter):
21 | """
22 | Adapt croco's DPTOutputAdapter implementation for dust3r:
23 | remove duplicated weights and fix the forward pass for dust3r
24 | """
25 |
26 | def init(self, dim_tokens_enc=768):
27 | super().init(dim_tokens_enc)
28 | # these are duplicated weights
29 | del self.act_1_postprocess
30 | del self.act_2_postprocess
31 | del self.act_3_postprocess
32 | del self.act_4_postprocess
33 |
34 | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
35 | assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
36 | # H, W = input_info['image_size']
37 | image_size = self.image_size if image_size is None else image_size
38 | H, W = image_size
39 | # print(encoder_tokens[0].shape, "size:", H, W)
40 | # Number of patches in height and width
41 | N_H = H // (self.stride_level * self.P_H)
42 | N_W = W // (self.stride_level * self.P_W)
43 |
44 | # Hook decoder onto 4 layers from specified ViT layers
45 | layers = [encoder_tokens[hook] for hook in self.hooks]
46 |
47 | # Extract only task-relevant tokens and ignore global tokens.
48 | layers = [self.adapt_tokens(l) for l in layers]
49 |
50 | # Reshape tokens to spatial representation
51 | layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
52 |
53 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
54 | # Project layers to chosen feature dim
55 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
56 |
57 | # Fuse layers using refinement stages
58 | path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
59 | path_3 = self.scratch.refinenet3(path_4, layers[2])
60 | path_2 = self.scratch.refinenet2(path_3, layers[1])
61 | path_1 = self.scratch.refinenet1(path_2, layers[0])
62 |
63 | # Output head
64 | out = self.head(path_1)
65 |
66 | return out
67 |
68 |
69 | class PixelwiseTaskWithDPT(nn.Module):
70 | """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
71 |
72 | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
73 | output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
74 | super(PixelwiseTaskWithDPT, self).__init__()
75 | self.return_all_layers = True # backbone needs to return all layers
76 | self.postprocess = postprocess
77 | self.depth_mode = depth_mode
78 | self.conf_mode = conf_mode
79 |
80 | assert n_cls_token == 0, "Not implemented"
81 | dpt_args = dict(output_width_ratio=output_width_ratio,
82 | num_channels=num_channels,
83 | **kwargs)
84 | if hooks_idx is not None:
85 | dpt_args.update(hooks=hooks_idx)
86 | self.dpt = DPTOutputAdapter_fix(**dpt_args)
87 | dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
88 | self.dpt.init(**dpt_init_args)
89 |
90 | def forward(self, x, img_info):
91 | # print(len(x), "img info", img_info)
92 | out = self.dpt(x, image_size=(img_info[0], img_info[1]))
93 | if self.postprocess:
94 | out = self.postprocess(out, self.depth_mode, self.conf_mode)
95 | return out
96 |
97 |
98 | def create_dpt_head(net, has_conf=False):
99 | """
100 | return PixelwiseTaskWithDPT for given net params
101 | """
102 | assert net.dec_depth > 9
103 | l2 = net.dec_depth
104 | feature_dim = 256
105 | last_dim = feature_dim//2
106 | out_nchan = 3
107 | ed = net.enc_embed_dim
108 | dd = net.dec_embed_dim
109 | return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
110 | feature_dim=feature_dim,
111 | last_dim=last_dim,
112 | hooks_idx=[0, l2*2//4, l2*3//4, l2],
113 | dim_tokens=[ed, dd, dd, dd],
114 | postprocess=postprocess,
115 | depth_mode=net.depth_mode,
116 | conf_mode=net.conf_mode,
117 | head_type='regression')
118 |
--------------------------------------------------------------------------------
/slam3r/heads/linear_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # linear head implementation for DUST3R
6 | # --------------------------------------------------------
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from .postprocess import postprocess
10 |
11 |
12 | class LinearPts3d (nn.Module):
13 | """
14 | Linear head for dust3r
15 | Each token outputs: - 16x16 3D points (+ confidence)
16 | """
17 |
18 | def __init__(self, net, has_conf=False):
19 | super().__init__()
20 | self.patch_size = net.patch_embed.patch_size[0]
21 | self.depth_mode = net.depth_mode
22 | self.conf_mode = net.conf_mode
23 | self.has_conf = has_conf
24 |
25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
26 |
27 | def setup(self, croconet):
28 | pass
29 |
30 | def forward(self, decout, img_shape):
31 | H, W = img_shape
32 | tokens = decout[-1]
33 | B, S, D = tokens.shape #S is the number of tokens
34 |
35 | # extract 3D points
36 | feat = self.proj(tokens) # B,S,D; D packs the 3D points (+ confidence) of every pixel in the patch
37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
39 |
40 | # permute + norm depth
41 | return postprocess(feat, self.depth_mode, self.conf_mode)
42 |
--------------------------------------------------------------------------------
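
For reference, a small self-contained sketch (made-up dimensions, plain torch, not part of the repo) of the shape bookkeeping in LinearPts3d.forward: each decoder token is projected to 3*patch_size**2 values, and F.pixel_shuffle scatters them back into a dense B x 3 x H x W point map.

import torch
import torch.nn.functional as F

B, H, W, patch = 2, 224, 224, 16
S = (H // patch) * (W // patch)          # 196 tokens per image
D = 768                                  # decoder embedding dim (illustrative)
tokens = torch.randn(B, S, D)
proj = torch.nn.Linear(D, 3 * patch**2)  # 3 values (x, y, z) per pixel of the patch

feat = proj(tokens)                                                 # B, S, 3*16*16
feat = feat.transpose(-1, -2).view(B, -1, H // patch, W // patch)   # B, 768, 14, 14
pts3d = F.pixel_shuffle(feat, patch)                                # B, 3, 224, 224
print(pts3d.shape)                                                  # torch.Size([2, 3, 224, 224])
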
/slam3r/heads/postprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # post process function for all heads: extract 3D points/confidence from output
6 | # --------------------------------------------------------
7 | import torch
8 |
9 |
10 | def postprocess(out, depth_mode, conf_mode):
11 | """
12 | extract 3D points/confidence from prediction head output
13 | """
14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3
15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
16 |
17 | if conf_mode is not None:
18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
19 | return res
20 |
21 |
22 | def reg_dense_depth(xyz, mode):
23 | """
24 | extract 3D points from prediction head output
25 | """
26 | mode, vmin, vmax = mode
27 |
28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
29 | assert no_bounds
30 |
31 | if mode == 'linear':
32 | if no_bounds:
33 | return xyz # [-inf, +inf]
34 | return xyz.clip(min=vmin, max=vmax)
35 |
36 | # distance to origin
37 | d = xyz.norm(dim=-1, keepdim=True)
38 | xyz = xyz / d.clip(min=1e-8)
39 |
40 | if mode == 'square':
41 | return xyz * d.square()
42 |
43 | if mode == 'exp':
44 | return xyz * torch.expm1(d)
45 |
46 | raise ValueError(f'bad {mode=}')
47 |
48 |
49 | def reg_dense_conf(x, mode):
50 | """
51 | extract confidence from prediction head output
52 | """
53 | mode, vmin, vmax = mode
54 | if mode == 'exp':
55 | return vmin + x.exp().clip(max=vmax-vmin)
56 | if mode == 'sigmoid':
57 | return (vmax - vmin) * torch.sigmoid(x) + vmin
58 | raise ValueError(f'bad {mode=}')
59 |
--------------------------------------------------------------------------------
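
As a quick illustration of the regression modes above, here is a standalone sketch with toy tensors (the ('exp', 1, inf) confidence bounds are only an example configuration, not a statement about the trained models): in 'exp' depth mode the raw output is read as a direction whose length encodes log(1 + distance), and 'exp' confidence maps the raw value into [vmin, vmax].

import torch

# raw head output for a tiny 4x4 map: B,H,W,3 points and B,H,W confidence logits
raw_xyz = torch.randn(1, 4, 4, 3)
raw_conf = torch.randn(1, 4, 4)

# 'exp' depth mode: keep the direction, remap the length through expm1
d = raw_xyz.norm(dim=-1, keepdim=True)
pts3d = raw_xyz / d.clip(min=1e-8) * torch.expm1(d)

# 'exp' confidence mode with illustrative bounds (vmin=1, vmax=inf)
vmin, vmax = 1.0, float('inf')
conf = vmin + raw_conf.exp().clip(max=vmax - vmin)

print(pts3d.shape, float(conf.min()) >= vmin)   # torch.Size([1, 4, 4, 3]) True
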
/slam3r/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utilities needed for the inference
6 | # --------------------------------------------------------
7 | import torch
8 | import numpy as np
9 |
10 | from .utils.misc import invalid_to_zeros
11 | from .utils.geometry import geotrf, inv
12 |
13 |
14 | def loss_of_one_batch(loss_func, batch, model, criterion, device,
15 | use_amp=False, ret=None,
16 | assist_model=None, train=False, epoch=0,
17 | args=None):
18 | if loss_func == "i2p":
19 | return loss_of_one_batch_multiview(batch, model, criterion,
20 | device, use_amp, ret,
21 | args.ref_id)
22 | elif loss_func == "i2p_corr_score":
23 | return loss_of_one_batch_multiview_corr_score(batch, model, criterion,
24 | device, use_amp, ret,
25 | args.ref_id)
26 | elif loss_func == "l2w":
27 | return loss_of_one_batch_l2w(
28 | batch, model, criterion,
29 | device, use_amp, ret,
30 | ref_ids=args.ref_ids, coord_frame_id=0,
31 | exclude_ident=True, to_zero=True
32 | )
33 | else:
34 | raise NotImplementedError
35 |
36 |
37 | def loss_of_one_batch_multiview(batch, model, criterion, device,
38 | use_amp=False, ret=None, ref_id=-1):
39 | """ Function to compute the reconstruction loss of the Image-to-Points model
40 | """
41 | views = batch
42 | for view in views:
43 | for name in 'img pts3d valid_mask camera_pose'.split(): # pseudo_focal
44 | if name not in view:
45 | continue
46 | view[name] = view[name].to(device, non_blocking=True)
47 |
48 | if ref_id == -1:
49 | ref_id = (len(views)-1)//2
50 |
51 | with torch.cuda.amp.autocast(enabled=bool(use_amp)):
52 | preds = model(views, ref_id=ref_id)
53 | assert len(preds) == len(views)
54 |
55 | with torch.cuda.amp.autocast(enabled=False):
56 | if criterion is None:
57 | loss = None
58 | else:
59 | loss = criterion(views, preds, ref_id=ref_id)
60 |
61 | result = dict(views=views, preds=preds, loss=loss)
62 | for i in range(len(preds)):
63 | result[f'pred{i+1}'] = preds[i]
64 | result[f'view{i+1}'] = views[i]
65 | return result[ret] if ret else result
66 |
67 |
68 | def loss_of_one_batch_multiview_corr_score(batch, model, criterion, device,
69 | use_amp=False, ret=None, ref_id=-1):
70 |
71 | views = batch
72 | for view in views:
73 | for name in 'img pts3d valid_mask camera_pose'.split(): # pseudo_focal
74 | if name not in view:
75 | continue
76 | view[name] = view[name].to(device, non_blocking=True)
77 |
78 | if ref_id == -1:
79 | ref_id = (len(views)-1)//2
80 |
81 | all_loss = [0, {}]
82 | with torch.cuda.amp.autocast(enabled=bool(use_amp)):
83 | preds = model(views, ref_id=ref_id, return_corr_score=True)
84 | assert len(preds) == len(views)
85 | for i,pred in enumerate(preds):
86 | if i == ref_id:
87 | continue
88 | patch_pseudo_conf = pred['pseudo_conf'] # (B,S)
89 | true_conf = (pred['conf']-1.).mean(dim=(1,2)) # (B,) mean(exp(x))
90 | pseudo_conf = torch.exp(patch_pseudo_conf).mean(dim=1) # (B,) mean(exp(batch(x)))
91 | pseudo_conf = pseudo_conf / (1+pseudo_conf)
92 | true_conf = true_conf / (1+true_conf)
93 | dis = torch.abs(pseudo_conf-true_conf)
94 | loss = dis.mean()
95 | # if loss.isinf():
96 | # print(((patch_pseudo_conf-patch_true_conf)**2).max())
97 | all_loss[0] += loss
98 | all_loss[1][f'pseudo_conf_loss_{i}'] = loss
99 |
100 | result = dict(views=views, preds=preds, loss=all_loss)
101 | for i in range(len(preds)):
102 | result[f'pred{i+1}'] = preds[i]
103 | result[f'view{i+1}'] = views[i]
104 | return result[ret] if ret else result
105 |
106 |
107 | def get_multiview_scale(pts:list, valid:list, norm_mode='avg_dis'):
108 | # adapted from DUSt3R
109 | for i in range(len(pts)):
110 | assert pts[i].ndim >= 3 and pts[i].shape[-1] == 3
111 | assert len(pts) == len(valid)
112 | norm_mode, dis_mode = norm_mode.split('_')
113 |
114 | if norm_mode == 'avg':
115 | # gather all points together (joint normalization)
116 | all_pts = []
117 | all_nnz = 0
118 | for i in range(len(pts)):
119 | nan_pts, nnz = invalid_to_zeros(pts[i], valid[i], ndim=3)
120 | # print(nnz,nan_pts.shape) #(B,) (B,H*W,3)
121 | all_pts.append(nan_pts)
122 | all_nnz += nnz
123 | all_pts = torch.cat(all_pts, dim=1)
124 | # compute distance to origin
125 | all_dis = all_pts.norm(dim=-1)
126 | if dis_mode == 'dis':
127 | pass # do nothing
128 | elif dis_mode == 'log1p':
129 | all_dis = torch.log1p(all_dis)
130 | else:
131 | raise ValueError(f'bad {dis_mode=}')
132 |
133 | norm_factor = all_dis.sum(dim=1) / (all_nnz + 1e-8)
134 | else:
135 | raise ValueError(f'bad {norm_mode=}')
136 |
137 | norm_factor = norm_factor.clip(min=1e-8)
138 | while norm_factor.ndim < pts[0].ndim:
139 | norm_factor.unsqueeze_(-1)
140 | # print('norm factor:', norm_factor)
141 | return norm_factor
142 |
143 |
144 | def loss_of_one_batch_l2w(batch, model, criterion, device,
145 | use_amp=False, ret=None,
146 | ref_ids=-1, coord_frame_id=0,
147 | exclude_ident=True, to_zero=True):
148 | """ Function to compute the reconstruction loss of the Local-to-World model
149 | ref_ids: list of indices of the supporting frames (excluding the coord_frame)
150 | coord_frame_id: all input and output pointmaps are expressed in the camera coordinate frame of view coord_frame_id
151 | exclude_ident: whether to exclude the coord_frame to simulate real-life inference scenarios
152 | to_zero: whether to set the invalid points to zero
153 | """
154 | views = batch
155 | for view in views:
156 | for name in 'img pts3d pts3d_cam valid_mask camera_pose'.split(): # pseudo_focal
157 | if name not in view:
158 | continue
159 | view[name] = view[name].to(device, non_blocking=True)
160 |
161 | if coord_frame_id == -1:
162 | # randomly select a camera as the target camera
163 | coord_frame_id = np.random.randint(0, len(views))
164 | # print(coord_frame_id)
165 | c2w = views[coord_frame_id]['camera_pose']
166 | w2c = inv(c2w)
167 |
168 | # exclude the frame that has the identity pose
169 | if exclude_ident:
170 | views.pop(coord_frame_id)
171 |
172 | if ref_ids == -1:
173 | ref_ids = [i for i in range(len(views)-1)] # all views except the last one
174 | elif ref_ids == -2:
175 | #select half of the views randomly
176 | ref_ids = np.random.choice(len(views), len(views)//2, replace=False).tolist()
177 | else:
178 | assert isinstance(ref_ids, list)
179 |
180 | for id in ref_ids:
181 | views[id]['pts3d_world'] = geotrf(w2c, views[id]['pts3d']) # transform to the target coordinate frame
182 | norm_factor_world = get_multiview_scale([views[id]['pts3d_world'] for id in ref_ids],
183 | [views[id]['valid_mask'] for id in ref_ids],
184 | norm_mode='avg_dis')
185 | for id,view in enumerate(views):
186 | if id in ref_ids:
187 | view['pts3d_world'] = view['pts3d_world'].permute(0,3,1,2) / norm_factor_world
188 | else:
189 | norm_factor_src = get_multiview_scale([view['pts3d_cam']],
190 | [view['valid_mask']],
191 | norm_mode='avg_dis')
192 | view['pts3d_cam'] = view['pts3d_cam'].permute(0,3,1,2) / norm_factor_src
193 |
194 | if to_zero:
195 | for id,view in enumerate(views):
196 | valid_mask = view['valid_mask'].unsqueeze(1).float() # B,1,H,W
197 | if id in ref_ids:
198 | # print(view['pts3d_world'].shape, valid_mask.shape, (-valid_mask+1).sum())
199 | view['pts3d_world'] = view['pts3d_world'] * valid_mask
200 | else:
201 | view['pts3d_cam'] = view['pts3d_cam'] * valid_mask
202 |
203 | with torch.cuda.amp.autocast(enabled=bool(use_amp)):
204 | preds = model(views, ref_ids=ref_ids)
205 | assert len(preds) == len(views)
206 | with torch.cuda.amp.autocast(enabled=False):
207 | if criterion is None:
208 | loss = None
209 | else:
210 | loss = criterion(views, preds, ref_id=ref_ids, ref_camera=w2c, norm_scale=norm_factor_world)
211 |
212 | result = dict(views=views, preds=preds, loss=loss)
213 | for i in range(len(preds)):
214 | result[f'pred{i+1}'] = preds[i]
215 | result[f'view{i+1}'] = views[i]
216 | return result[ret] if ret else result
217 |
--------------------------------------------------------------------------------
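
A small standalone sketch (toy tensors, not part of the repo) of get_multiview_scale, which the local-to-world loss above relies on: all views are normalized jointly by the average distance of their valid points to the origin, so dividing by the returned factor brings that average to roughly 1.

import torch
from slam3r.inference import get_multiview_scale

B, H, W = 1, 8, 8
pts1 = torch.randn(B, H, W, 3)
pts2 = 2.0 * torch.randn(B, H, W, 3)
valid1 = torch.ones(B, H, W, dtype=torch.bool)
valid2 = torch.ones(B, H, W, dtype=torch.bool)

norm = get_multiview_scale([pts1, pts2], [valid1, valid2], norm_mode='avg_dis')
print(norm.shape)   # broadcastable to the point maps, e.g. torch.Size([1, 1, 1, 1])

# after dividing by the joint factor, the average distance to the origin is ~1
pts = torch.cat([pts1 / norm, pts2 / norm], dim=1)
print(pts.norm(dim=-1).mean())
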
/slam3r/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Implementation of DUSt3R training losses
6 | # --------------------------------------------------------
7 | from copy import copy, deepcopy
8 | import torch
9 | import torch.nn as nn
10 |
11 | from slam3r.utils.geometry import inv, geotrf, depthmap_to_pts3d, multiview_normalize_pointcloud
12 |
13 | def get_pred_pts3d(gt, pred, use_pose=False):
14 | if 'depth' in pred and 'pseudo_focal' in pred:
15 | try:
16 | pp = gt['camera_intrinsics'][..., :2, 2]
17 | except KeyError:
18 | pp = None
19 | pts3d = depthmap_to_pts3d(**pred, pp=pp)
20 |
21 | elif 'pts3d' in pred:
22 | # pts3d from my camera
23 | pts3d = pred['pts3d']
24 |
25 | elif 'pts3d_in_other_view' in pred:
26 | # pts3d from the other camera, already transformed
27 | assert use_pose is True
28 | return pred['pts3d_in_other_view'] # return!
29 |
30 | if use_pose:
31 | camera_pose = pred.get('camera_pose')
32 | assert camera_pose is not None
33 | pts3d = geotrf(camera_pose, pts3d)
34 |
35 | return pts3d
36 |
37 | def Sum(*losses_and_masks):
38 | loss, mask = losses_and_masks[0]
39 | if loss.ndim > 0:
40 | # we are actually returning the loss for every pixel
41 | return losses_and_masks
42 | else:
43 | # we are returning the global loss
44 | for loss2, mask2 in losses_and_masks[1:]:
45 | loss = loss + loss2
46 | return loss
47 |
48 |
49 | class LLoss (nn.Module):
50 | """ L-norm loss
51 | """
52 |
53 | def __init__(self, reduction='mean'):
54 | super().__init__()
55 | self.reduction = reduction
56 |
57 | def forward(self, a, b):
58 | assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
59 | dist = self.distance(a, b)
60 | assert dist.ndim == a.ndim-1 # one dimension less
61 | if self.reduction == 'none':
62 | return dist
63 | if self.reduction == 'sum':
64 | return dist.sum()
65 | if self.reduction == 'mean':
66 | return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
67 | raise ValueError(f'bad {self.reduction=} mode')
68 |
69 | def distance(self, a, b):
70 | raise NotImplementedError()
71 |
72 |
73 | class L21Loss (LLoss):
74 | """ Euclidean distance between 3d points """
75 |
76 | def distance(self, a, b):
77 | return torch.norm(a - b, dim=-1) # normalized L2 distance
78 |
79 |
80 | L21 = L21Loss()
81 |
82 |
83 | class Criterion (nn.Module):
84 | def __init__(self, criterion=None):
85 | super().__init__()
86 | assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
87 | self.criterion = copy(criterion)
88 |
89 | def get_name(self):
90 | return f'{type(self).__name__}({self.criterion})'
91 |
92 | def with_reduction(self, mode):
93 | res = loss = deepcopy(self)
94 | while loss is not None:
95 | assert isinstance(loss, Criterion)
96 | loss.criterion.reduction = 'none' # make it return the loss for each sample
97 | loss = loss._loss2 # we assume loss is a Multiloss
98 | return res
99 |
100 |
101 | class MultiLoss (nn.Module):
102 | """ Easily combinable losses (also keep track of individual loss values):
103 | loss = MyLoss1() + 0.1*MyLoss2()
104 | Usage:
105 | Inherit from this class and override get_name() and compute_loss()
106 | """
107 |
108 | def __init__(self):
109 | super().__init__()
110 | self._alpha = 1
111 | self._loss2 = None
112 |
113 | def compute_loss(self, *args, **kwargs):
114 | raise NotImplementedError()
115 |
116 | def get_name(self):
117 | raise NotImplementedError()
118 |
119 | def __mul__(self, alpha):
120 | assert isinstance(alpha, (int, float))
121 | res = copy(self)
122 | res._alpha = alpha
123 | return res
124 | __rmul__ = __mul__ # same
125 |
126 | def __add__(self, loss2):
127 | assert isinstance(loss2, MultiLoss)
128 | res = cur = copy(self)
129 | # find the end of the chain
130 | while cur._loss2 is not None:
131 | cur = cur._loss2
132 | cur._loss2 = loss2
133 | return res
134 |
135 | def __repr__(self):
136 | name = self.get_name()
137 | if self._alpha != 1:
138 | name = f'{self._alpha:g}*{name}'
139 | if self._loss2:
140 | name = f'{name} + {self._loss2}'
141 | return name
142 |
143 | def forward(self, *args, **kwargs):
144 | loss = self.compute_loss(*args, **kwargs)
145 | if isinstance(loss, tuple):
146 | loss, details = loss
147 | elif loss.ndim == 0:
148 | details = {self.get_name(): float(loss)}
149 | else:
150 | details = {}
151 | loss = loss * self._alpha
152 |
153 | if self._loss2:
154 | loss2, details2 = self._loss2(*args, **kwargs)
155 | loss = loss + loss2
156 | details |= details2
157 |
158 | return loss, details
159 |
160 | class Jointnorm_Regr3D (Criterion, MultiLoss):
161 | """ Ensure that all 3D points are correct.
162 | Asymmetric loss: view1 is supposed to be the anchor.
163 |
164 | P1 = RT1 @ D1
165 | P2 = RT2 @ D2
166 | loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
167 | loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
168 | = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
169 | gt and pred are transformed into localframe1
170 | """
171 |
172 | def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, dist_clip=None):
173 | super().__init__(criterion)
174 | self.norm_mode = norm_mode
175 | self.gt_scale = gt_scale
176 | self.dist_clip = dist_clip
177 |
178 | def get_all_pts3d(self, gts, preds, ref_id, in_camera=None, norm_scale=None, dist_clip=None):
179 | # everything is normalized w.r.t. in_camera.
180 | # pointcloud normalization uses the distance from the origin if norm_scale is None; otherwise a fixed norm_scale is used
181 | if in_camera is None:
182 | in_camera = inv(gts[ref_id]['camera_pose'])
183 | gt_pts = []
184 | valids = []
185 | for gt in gts:
186 | gt_pts.append(geotrf(in_camera, gt['pts3d']))
187 | valids.append(gt['valid_mask'].clone())
188 |
189 | dist_clip = self.dist_clip if dist_clip is None else dist_clip
190 | if dist_clip is not None:
191 | # points that are too far-away == invalid
192 | for i in range(len(gts)):
193 | dis = gt_pts[i].norm(dim=-1)
194 | valids[i] = valids[i] & (dis <= dist_clip)
244 | high confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
245 | low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
246 |
247 | alpha: hyperparameter
248 | """
249 |
250 | def __init__(self, pixel_loss, alpha=1):
251 | super().__init__()
252 | assert alpha > 0
253 | self.alpha = alpha
254 | self.pixel_loss = pixel_loss.with_reduction('none')
255 |
256 | def get_name(self):
257 | return f'ConfLoss({self.pixel_loss})'
258 |
259 | def get_conf_log(self, x):
260 | return x, torch.log(x)
261 |
262 | def compute_loss(self, gts, preds, head='', **kw):
263 | # compute per-pixel loss
264 | losses_and_masks, details = self.pixel_loss(gts, preds, head=head, **kw)
265 | for i in range(len(losses_and_masks)):
266 | if losses_and_masks[i][0].numel() == 0:
267 | print(f'NO VALID POINTS in img{i+1}', force=True)
268 |
269 | res_loss = 0
270 | res_info = details
271 | for i in range(len(losses_and_masks)):
272 | loss = losses_and_masks[i][0]
273 | mask = losses_and_masks[i][1]
274 | conf, log_conf = self.get_conf_log(preds[i]['conf'][mask])
275 | conf_loss = loss * conf - self.alpha * log_conf
276 | conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
277 | res_loss += conf_loss
278 | info_name = f"conf_loss_{i+1}" if head == '' else f"conf_loss_{head}_{i+1}"
279 | res_info[info_name] = float(conf_loss)
280 |
281 | return res_loss, res_info
282 |
--------------------------------------------------------------------------------
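
A standalone numeric sketch (made-up error and confidence values; alpha is only illustrative) of the confidence weighting used in ConfLoss.compute_loss: each pixel's regression error is scaled by its predicted confidence, and the -alpha*log(conf) term discourages the network from simply driving all confidences down.

import torch

alpha = 0.2
pixel_err = torch.tensor([0.05, 0.50, 2.00])   # per-pixel regression loss
conf = torch.tensor([10.0, 2.0, 1.0])          # predicted confidence for those pixels

conf_loss = pixel_err * conf - alpha * torch.log(conf)
print(conf_loss)          # confident pixels pay more for large errors, earn a log bonus otherwise
print(conf_loss.mean())   # the scalar that gets accumulated into res_loss
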
/slam3r/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # PatchEmbed implementation for DUST3R,
6 | # in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
7 | # --------------------------------------------------------
8 | import torch
9 | from .blocks import PatchEmbed # noqa
10 |
11 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
12 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
13 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
14 | return patch_embed
15 |
16 |
17 | class PatchEmbedDust3R(PatchEmbed):
18 | def forward(self, x, **kw):
19 | B, C, H, W = x.shape
20 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
21 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
22 | x = self.proj(x)
23 | pos = self.position_getter(B, x.size(2), x.size(3), x.device)
24 | if self.flatten:
25 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
26 | x = self.norm(x)
27 | return x, pos
28 |
29 |
30 | class ManyAR_PatchEmbed (PatchEmbed):
31 | """ Handle images with non-square aspect ratio.
32 | All images in the same batch have the same aspect ratio.
33 | true_shape = [(height, width) ...] indicates the actual shape of each image.
34 | """
35 |
36 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
37 | self.embed_dim = embed_dim
38 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
39 |
40 | def forward(self, img, true_shape):
41 | B, C, H, W = img.shape
42 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
43 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
44 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
45 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
46 |
47 | # size expressed in tokens
48 | W //= self.patch_size[0]
49 | H //= self.patch_size[1]
50 | n_tokens = H * W
51 |
52 | height, width = true_shape.T
53 | is_landscape = (width >= height)
54 | is_portrait = ~is_landscape
55 |
56 | # allocate result
57 | x = img.new_zeros((B, n_tokens, self.embed_dim))
58 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
59 |
60 | # linear projection, transposed if necessary
61 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
62 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
63 |
64 | pos[is_landscape] = self.position_getter(1, H, W, pos.device)
65 | pos[is_portrait] = self.position_getter(1, W, H, pos.device)
66 |
67 | x = self.norm(x)
68 | return x, pos
69 |
--------------------------------------------------------------------------------
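
A minimal usage sketch (dummy batch; 224/16/768 mirror the defaults used above) of the factory: ManyAR_PatchEmbed consumes a landscape-packed batch plus the true (height, width) of each image and returns per-token features and integer 2D positions. This assumes the blocks.PatchEmbed base provides proj and position_getter as in the croco code it is adapted from.

import torch
from slam3r.patch_embed import get_patch_embed

embed = get_patch_embed('ManyAR_PatchEmbed', img_size=224, patch_size=16, enc_embed_dim=768)
img = torch.randn(2, 3, 224, 224)                     # batch must be stored landscape (W >= H)
true_shape = torch.tensor([[224, 224], [224, 224]])   # actual (height, width) of each image
tokens, pos = embed(img, true_shape)
print(tokens.shape, pos.shape)   # torch.Size([2, 196, 768]) torch.Size([2, 196, 2])
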
/slam3r/pos_embed/__init__.py:
--------------------------------------------------------------------------------
1 | from .pos_embed import get_2d_sincos_pos_embed, RoPE2D
--------------------------------------------------------------------------------
/slam3r/pos_embed/curope/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from .curope2d import cuRoPE2D
5 |
--------------------------------------------------------------------------------
/slam3r/pos_embed/curope/curope.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 |
8 | // forward declaration
9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10 |
11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12 | {
13 | const int B = tokens.size(0);
14 | const int N = tokens.size(1);
15 | const int H = tokens.size(2);
16 | const int D = tokens.size(3) / 4;
17 |
18 | auto tok = tokens.accessor<float, 4>();
19 | auto pos = positions.accessor<int64_t, 3>();
20 |
21 | for (int b = 0; b < B; b++) {
22 | for (int x = 0; x < 2; x++) { // y and then x (2d)
23 | for (int n = 0; n < N; n++) {
24 |
25 | // grab the token position
26 | const int p = pos[b][n][x];
27 |
28 | for (int h = 0; h < H; h++) {
29 | for (int d = 0; d < D; d++) {
30 | // grab the two values
31 | float u = tok[b][n][h][d+0+x*2*D];
32 | float v = tok[b][n][h][d+D+x*2*D];
33 |
34 | // grab the cos,sin
35 | const float inv_freq = fwd * p / powf(base, d/float(D));
36 | float c = cosf(inv_freq);
37 | float s = sinf(inv_freq);
38 |
39 | // write the result
40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42 | }
43 | }
44 | }
45 | }
46 | }
47 | }
48 |
49 | void rope_2d( torch::Tensor tokens, // B,N,H,D
50 | const torch::Tensor positions, // B,N,2
51 | const float base,
52 | const float fwd )
53 | {
54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60 |
61 | if (tokens.is_cuda())
62 | rope_2d_cuda( tokens, positions, base, fwd );
63 | else
64 | rope_2d_cpu( tokens, positions, base, fwd );
65 | }
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69 | }
70 |
--------------------------------------------------------------------------------
/slam3r/pos_embed/curope/curope2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import torch
5 |
6 | try:
7 | import curope as _kernels # run `python setup.py install`
8 | except ModuleNotFoundError:
9 | from . import curope as _kernels # run `python setup.py build_ext --inplace`
10 |
11 |
12 | class cuRoPE2D_func (torch.autograd.Function):
13 |
14 | @staticmethod
15 | def forward(ctx, tokens, positions, base, F0=1):
16 | ctx.save_for_backward(positions)
17 | ctx.saved_base = base
18 | ctx.saved_F0 = F0
19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work
20 | _kernels.rope_2d( tokens, positions, base, F0 )
21 | ctx.mark_dirty(tokens)
22 | return tokens
23 |
24 | @staticmethod
25 | def backward(ctx, grad_res):
26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
27 | _kernels.rope_2d( grad_res, positions, base, -F0 )
28 | ctx.mark_dirty(grad_res)
29 | return grad_res, None, None, None
30 |
31 |
32 | class cuRoPE2D(torch.nn.Module):
33 | def __init__(self, freq=100.0, F0=1.0):
34 | super().__init__()
35 | self.base = freq
36 | self.F0 = F0
37 |
38 | def forward(self, tokens, positions):
39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
40 | return tokens
--------------------------------------------------------------------------------
/slam3r/pos_embed/curope/kernels.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 | #include <cuda.h>
8 | #include <cuda_runtime.h>
9 | #include <vector>
10 |
11 | #define CHECK_CUDA(tensor) {\
12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15 |
16 |
17 | template < typename scalar_t >
18 | __global__ void rope_2d_cuda_kernel(
19 | //scalar_t* __restrict__ tokens,
20 | torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21 | const int64_t* __restrict__ pos,
22 | const float base,
23 | const float fwd )
24 | // const int N, const int H, const int D )
25 | {
26 | // tokens shape = (B, N, H, D)
27 | const int N = tokens.size(1);
28 | const int H = tokens.size(2);
29 | const int D = tokens.size(3);
30 |
31 | // each block update a single token, for all heads
32 | // each thread takes care of a single output
33 | extern __shared__ float shared[];
34 | float* shared_inv_freq = shared + D;
35 |
36 | const int b = blockIdx.x / N;
37 | const int n = blockIdx.x % N;
38 |
39 | const int Q = D / 4;
40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41 | // u_Y v_Y u_X v_X
42 |
43 | // shared memory: first, compute inv_freq
44 | if (threadIdx.x < Q)
45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46 | __syncthreads();
47 |
48 | // start of X or Y part
49 | const int X = threadIdx.x < D/2 ? 0 : 1;
50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51 |
52 | // grab the cos,sin appropriate for me
53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54 | const float cos = cosf(freq);
55 | const float sin = sinf(freq);
56 | /*
57 | float* shared_cos_sin = shared + D + D/4;
58 | if ((threadIdx.x % (D/2)) < Q)
59 | shared_cos_sin[m+0] = cosf(freq);
60 | else
61 | shared_cos_sin[m+Q] = sinf(freq);
62 | __syncthreads();
63 | const float cos = shared_cos_sin[m+0];
64 | const float sin = shared_cos_sin[m+Q];
65 | */
66 |
67 | for (int h = 0; h < H; h++)
68 | {
69 | // then, load all the token for this head in shared memory
70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71 | __syncthreads();
72 |
73 | const float u = shared[m];
74 | const float v = shared[m+Q];
75 |
76 | // write output
77 | if ((threadIdx.x % (D/2)) < Q)
78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79 | else
80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81 | }
82 | }
83 |
84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85 | {
86 | const int B = tokens.size(0); // batch size
87 | const int N = tokens.size(1); // sequence length
88 | const int H = tokens.size(2); // number of heads
89 | const int D = tokens.size(3); // dimension per head
90 |
91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95 |
96 | // one block for each layer, one thread per local-max
97 | const int THREADS_PER_BLOCK = D;
98 | const int N_BLOCKS = B * N; // each block takes care of H*D values
99 | const int SHARED_MEM = sizeof(float) * (D + D/4);
100 |
101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102 | rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103 | //tokens.data_ptr<scalar_t>(),
104 | tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105 | pos.data_ptr<int64_t>(),
106 | base, fwd); //, N, H, D );
107 | }));
108 | }
109 |
--------------------------------------------------------------------------------
/slam3r/pos_embed/curope/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from setuptools import setup
5 | from torch import cuda
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | # compile for all possible CUDA architectures
9 | # all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
10 | # alternatively, you can list cuda archs that you want, eg:
11 | all_cuda_archs = [
12 | # '-gencode', 'arch=compute_70,code=sm_70',
13 | # '-gencode', 'arch=compute_75,code=sm_75',
14 | # '-gencode', 'arch=compute_80,code=sm_80',
15 | # '-gencode', 'arch=compute_86,code=sm_86'
16 | "-arch=native"
17 | ]
18 |
19 | setup(
20 | name = 'curope',
21 | ext_modules = [
22 | CUDAExtension(
23 | name='curope',
24 | sources=[
25 | "curope.cpp",
26 | "kernels.cu",
27 | ],
28 | extra_compile_args = dict(
29 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
30 | cxx=['-O3'])
31 | )
32 | ],
33 | cmdclass = {
34 | 'build_ext': BuildExtension
35 | })
36 |
--------------------------------------------------------------------------------
/slam3r/pos_embed/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 |
5 | # --------------------------------------------------------
6 | # Position embedding utils
7 | # --------------------------------------------------------
8 |
9 |
10 |
11 | import numpy as np
12 |
13 | import torch
14 |
15 | # --------------------------------------------------------
16 | # 2D sine-cosine position embedding
17 | # References:
18 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20 | # MoCo v3: https://github.com/facebookresearch/moco-v3
21 | # --------------------------------------------------------
22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
23 | """
24 | grid_size: int of the grid height and width
25 | return:
26 | pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
27 | """
28 | grid_h = np.arange(grid_size, dtype=np.float32)
29 | grid_w = np.arange(grid_size, dtype=np.float32)
30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
31 | grid = np.stack(grid, axis=0)
32 |
33 | grid = grid.reshape([2, 1, grid_size, grid_size])
34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35 | if n_cls_token>0:
36 | pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
37 | return pos_embed
38 |
39 |
40 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41 | assert embed_dim % 2 == 0
42 |
43 | # use half of dimensions to encode grid_h
44 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46 |
47 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48 | return emb
49 |
50 |
51 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52 | """
53 | embed_dim: output dimension for each position
54 | pos: a list of positions to be encoded: size (M,)
55 | out: (M, D)
56 | """
57 | assert embed_dim % 2 == 0
58 | omega = np.arange(embed_dim // 2, dtype=float)
59 | omega /= embed_dim / 2.
60 | omega = 1. / 10000**omega # (D/2,)
61 |
62 | pos = pos.reshape(-1) # (M,)
63 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64 |
65 | emb_sin = np.sin(out) # (M, D/2)
66 | emb_cos = np.cos(out) # (M, D/2)
67 |
68 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69 | return emb
70 |
71 |
72 | # --------------------------------------------------------
73 | # Interpolate position embeddings for high-resolution
74 | # References:
75 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76 | # DeiT: https://github.com/facebookresearch/deit
77 | # --------------------------------------------------------
78 | def interpolate_pos_embed(model, checkpoint_model):
79 | if 'pos_embed' in checkpoint_model:
80 | pos_embed_checkpoint = checkpoint_model['pos_embed']
81 | embedding_size = pos_embed_checkpoint.shape[-1]
82 | num_patches = model.patch_embed.num_patches
83 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
84 | # height (== width) for the checkpoint position embedding
85 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
86 | # height (== width) for the new position embedding
87 | new_size = int(num_patches ** 0.5)
88 | # class_token and dist_token are kept unchanged
89 | if orig_size != new_size:
90 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
91 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
92 | # only the position tokens are interpolated
93 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
94 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
95 | pos_tokens = torch.nn.functional.interpolate(
96 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
97 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
98 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
99 | checkpoint_model['pos_embed'] = new_pos_embed
100 |
101 |
102 | #----------------------------------------------------------
103 | # RoPE2D: RoPE implementation in 2D
104 | #----------------------------------------------------------
105 |
106 | try:
107 | from .curope import cuRoPE2D
108 | RoPE2D = cuRoPE2D
109 | except ImportError as e:
110 | # print(e)
111 | print(f'Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
112 |
113 | class RoPE2D(torch.nn.Module):
114 |
115 | def __init__(self, freq=100.0, F0=1.0):
116 | super().__init__()
117 | self.base = freq
118 | self.F0 = F0
119 | self.cache = {}
120 |
121 | def get_cos_sin(self, D, seq_len, device, dtype):
122 | if (D,seq_len,device,dtype) not in self.cache:
123 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
124 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
125 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
126 | freqs = torch.cat((freqs, freqs), dim=-1)
127 | cos = freqs.cos() # (Seq, Dim)
128 | sin = freqs.sin()
129 | self.cache[D,seq_len,device,dtype] = (cos,sin)
130 | return self.cache[D,seq_len,device,dtype]
131 |
132 | @staticmethod
133 | def rotate_half(x):
134 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
135 | return torch.cat((-x2, x1), dim=-1)
136 |
137 | def apply_rope1d(self, tokens, pos1d, cos, sin):
138 | assert pos1d.ndim==2
139 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
140 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
141 | return (tokens * cos) + (self.rotate_half(tokens) * sin)
142 |
143 | def forward(self, tokens, positions):
144 | """
145 | input:
146 | * tokens: batch_size x nheads x ntokens x dim
147 | * positions: batch_size x ntokens x 2 (y and x position of each token)
148 | output:
149 |                 * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
150 | """
151 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
152 | D = tokens.size(3) // 2
153 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
154 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
155 | # split features into two along the feature dimension, and apply rope1d on each half
156 | y, x = tokens.chunk(2, dim=-1)
157 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
158 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
159 | tokens = torch.cat((y, x), dim=-1)
160 | return tokens
--------------------------------------------------------------------------------
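
Note: a minimal usage sketch for the PyTorch fallback RoPE2D defined above. The shapes follow its forward() docstring; the import path, grid size, and tensor values are illustrative assumptions, and with the cuRoPE2D CUDA extension installed the tensors would need to live on the GPU.

    import torch
    from slam3r.pos_embed.pos_embed import RoPE2D  # module path assumed from the repo layout

    rope = RoPE2D(freq=100.0)
    B, nheads, ntokens, dim = 2, 4, 16, 64                         # dim must be even
    tokens = torch.randn(B, nheads, ntokens, dim)
    ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing='ij')
    positions = torch.stack([ys.flatten(), xs.flatten()], dim=-1)  # (16, 2): (y, x) per token
    positions = positions[None].expand(B, -1, -1)                  # (B, ntokens, 2)
    out = rope(tokens, positions)                                  # same shape as tokens
    assert out.shape == tokens.shape
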
/slam3r/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/slam3r/utils/device.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for DUSt3R
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def todevice(batch, device, callback=None, non_blocking=False):
12 |     ''' Recursively transfer variables to another device (e.g. GPU, CPU:torch, CPU:numpy).
13 |
14 | batch: list, tuple, dict of tensors or other things
15 | device: pytorch device or 'numpy'
16 |     callback: optional function applied to the batch before the transfer.
17 | '''
18 | if callback:
19 | batch = callback(batch)
20 |
21 | if isinstance(batch, dict):
22 |         return {k: todevice(v, device, non_blocking=non_blocking) for k, v in batch.items()}
23 |
24 | if isinstance(batch, (tuple, list)):
25 |         return type(batch)(todevice(x, device, non_blocking=non_blocking) for x in batch)
26 |
27 | x = batch
28 | if device == 'numpy':
29 | if isinstance(x, torch.Tensor):
30 | x = x.detach().cpu().numpy()
31 | elif x is not None:
32 | if isinstance(x, np.ndarray):
33 | x = torch.from_numpy(x)
34 | if torch.is_tensor(x):
35 | x = x.to(device, non_blocking=non_blocking)
36 | return x
37 |
38 |
39 | to_device = todevice # alias
40 |
41 |
42 | def to_numpy(x): return todevice(x, 'numpy')
43 | def to_cpu(x): return todevice(x, 'cpu')
44 | def to_cuda(x): return todevice(x, 'cuda')
45 |
46 |
47 | def collate_with_cat(whatever, lists=False):
48 | if isinstance(whatever, dict):
49 | return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50 |
51 | elif isinstance(whatever, (tuple, list)):
52 | if len(whatever) == 0:
53 | return whatever
54 | elem = whatever[0]
55 | T = type(whatever)
56 |
57 | if elem is None:
58 | return None
59 | if isinstance(elem, (bool, float, int, str)):
60 | return whatever
61 |         if isinstance(elem, tuple):  # zip corresponding elements of the tuples in whatever: [(1,2),(3,4)] -> [(1,3),(2,4)]
62 | return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63 |         if isinstance(elem, dict):  # merge the dicts in whatever into a single dict, collating values per key
64 | return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65 |
66 | if isinstance(elem, torch.Tensor):
67 | return listify(whatever) if lists else torch.cat(whatever)
68 | if isinstance(elem, np.ndarray):
69 | return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70 |
71 | # otherwise, we just chain lists
72 | return sum(whatever, T())
73 |
74 |
75 | def listify(elems):
76 | return [x for e in elems for x in e]
77 |
78 | class MyNvtxRange():
79 | def __init__(self, name):
80 | self.name = name
81 |
82 | def __enter__(self):
83 | torch.cuda.synchronize()
84 | torch.cuda.nvtx.range_push(self.name)
85 |
86 | def __exit__(self, type, value, traceback):
87 | torch.cuda.synchronize()
88 | torch.cuda.nvtx.range_pop()
89 |
--------------------------------------------------------------------------------
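
Note: a small composition sketch (not part of the repo) showing how the helpers above behave on nested structures; the shapes and the dummy batch are arbitrary.

    import torch
    from slam3r.utils.device import todevice, to_numpy, collate_with_cat

    batch = {'img': torch.randn(1, 3, 224, 224), 'meta': {'idx': 0}}
    batch_np = to_numpy(batch)           # tensors become numpy arrays, plain ints pass through
    batch_cpu = todevice(batch, 'cpu')   # the same call works with 'cuda' or a torch.device

    preds = [{'pts3d': torch.randn(1, 224, 224, 3)} for _ in range(4)]
    merged = collate_with_cat(preds)     # {'pts3d': tensor of shape (4, 224, 224, 3)}
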
/slam3r/utils/image.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for images (loading/converting...)
6 | # --------------------------------------------------------
7 | import os
8 | import torch
9 | import numpy as np
10 | import PIL.Image
11 | from tqdm import tqdm
12 | from PIL.ImageOps import exif_transpose
13 | import torchvision.transforms as tvf
14 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15 | import cv2 # noqa
16 |
17 | try:
18 | from pillow_heif import register_heif_opener # noqa
19 | register_heif_opener()
20 | heif_support_enabled = True
21 | except ImportError:
22 | heif_support_enabled = False
23 |
24 | from .geometry import depthmap_to_camera_coordinates
25 | import json
26 |
27 | ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
28 |
29 |
30 | def imread_cv2(path, options=cv2.IMREAD_COLOR):
31 | """ Open an image or a depthmap with opencv-python.
32 | """
33 | if path.endswith(('.exr', 'EXR')):
34 | options = cv2.IMREAD_ANYDEPTH
35 | img = cv2.imread(path, options)
36 | if img is None:
37 | raise IOError(f'Could not load image={path} with {options=}')
38 | if img.ndim == 3:
39 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 | return img
41 |
42 |
43 | def rgb(ftensor, true_shape=None):
44 | if isinstance(ftensor, list):
45 | return [rgb(x, true_shape=true_shape) for x in ftensor]
46 | if isinstance(ftensor, torch.Tensor):
47 | ftensor = ftensor.detach().cpu().numpy() # H,W,3
48 | if ftensor.ndim == 3 and ftensor.shape[0] == 3:
49 | ftensor = ftensor.transpose(1, 2, 0)
50 | elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
51 | ftensor = ftensor.transpose(0, 2, 3, 1)
52 | if true_shape is not None:
53 | H, W = true_shape
54 | ftensor = ftensor[:H, :W]
55 | if ftensor.dtype == np.uint8:
56 | img = np.float32(ftensor) / 255
57 | else:
58 | img = (ftensor * 0.5) + 0.5
59 | return img.clip(min=0, max=1)
60 |
61 |
62 | def _resize_pil_image(img, long_edge_size):
63 | S = max(img.size)
64 | if S > long_edge_size:
65 | interp = PIL.Image.LANCZOS
66 | elif S <= long_edge_size:
67 | interp = PIL.Image.BICUBIC
68 | new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
69 | return img.resize(new_size, interp)
70 |
71 |
72 | def load_images(folder_or_list, size, square_ok=False,
73 | verbose=1, img_num=0, img_freq=0,
74 | postfix=None, start_idx=0):
75 | """ open and convert all images in a list or folder to proper input format for DUSt3R
76 | """
77 | if isinstance(folder_or_list, str):
78 | if verbose > 0:
79 | print(f'>> Loading images from {folder_or_list}')
80 |         img_names = [name for name in os.listdir(folder_or_list) if "depth" not in name]
81 | if postfix is not None:
82 | img_names = [name for name in img_names if name.endswith(postfix)]
83 | root, folder_content = folder_or_list, img_names
84 |
85 | elif isinstance(folder_or_list, list):
86 | if verbose > 0:
87 | print(f'>> Loading a list of {len(folder_or_list)} images')
88 | root, folder_content = '', folder_or_list
89 |
90 | else:
91 | raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
92 |
93 | # sort images by number in name
94 | len_postfix = len(postfix) if postfix is not None \
95 | else len(folder_content[0]) - folder_content[0].rfind('.')
96 |
97 | img_numbers = []
98 | for name in folder_content:
99 | dot_index = len(name) - len_postfix
100 | number_start = 0
101 | for i in range(dot_index-1, 0, -1):
102 | if not name[i].isdigit():
103 | number_start = i + 1
104 | break
105 | img_numbers.append(float(name[number_start:dot_index]))
106 | folder_content = [x for _, x in sorted(zip(img_numbers, folder_content))]
107 |
108 | if start_idx > 0:
109 | folder_content = folder_content[start_idx:]
110 |     if img_freq > 0:
111 |         folder_content = folder_content[::img_freq]
112 |     if img_num > 0:
113 |         folder_content = folder_content[:img_num]
114 |
115 | # print(root, folder_content)
116 |
117 | supported_images_extensions = ['.jpg', '.jpeg', '.png']
118 | if heif_support_enabled:
119 | supported_images_extensions += ['.heic', '.heif']
120 | supported_images_extensions = tuple(supported_images_extensions)
121 |
122 | imgs = []
123 | if verbose > 0:
124 | folder_content = tqdm(folder_content, desc='Loading images')
125 | for path in folder_content:
126 | if not path.lower().endswith(supported_images_extensions):
127 | continue
128 | img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
129 | W1, H1 = img.size
130 | if size == 224:
131 | # resize short side to 224 (then crop)
132 | img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
133 | else:
134 |             # resize the long side to the requested size
135 | img = _resize_pil_image(img, size)
136 | W, H = img.size
137 | cx, cy = W//2, H//2
138 | if size == 224:
139 | half = min(cx, cy)
140 | img = img.crop((cx-half, cy-half, cx+half, cy+half))
141 | else:
142 | halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
143 | if not (square_ok) and W == H:
144 | halfh = 3*halfw/4
145 | img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
146 |
147 | W2, H2 = img.size
148 | if verbose > 1:
149 | print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
150 |
151 | imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
152 | [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)), label=path))
153 |
154 |     assert imgs, 'no images found at ' + root
155 | if verbose > 0:
156 | print(f' ({len(imgs)} images loaded)')
157 | return imgs
158 |
159 |
160 |
161 | def crop_and_resize(image, depthmap, intrinsics, long_size, rng=None, info=None, use_crop=False):
162 | """ This function:
163 |     1. optionally crop the image so that its principal point is exactly centered
164 |     2. choose the orientation of the target resolution (landscape/portrait) to match the image
165 | """
166 | import slam3r.datasets.utils.cropping as cropping
167 | if not isinstance(image, PIL.Image.Image):
168 | image = PIL.Image.fromarray(image)
169 |
170 | W, H = image.size
171 | cx, cy = intrinsics[:2, 2].round().astype(int)
172 |     if use_crop:
173 | # downscale with lanczos interpolation so that image.size == resolution
174 | # cropping centered on the principal point
175 | min_margin_x = min(cx, W-cx)
176 | min_margin_y = min(cy, H-cy)
177 | assert min_margin_x > W/5, f'Bad principal point in view={info}'
178 | assert min_margin_y > H/5, f'Bad principal point in view={info}'
179 | # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
180 | l, t = cx - min_margin_x, cy - min_margin_y
181 | r, b = cx + min_margin_x, cy + min_margin_y
182 | crop_bbox = (l, t, r, b)
183 | image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
184 |
185 | # transpose the resolution if necessary
186 | W, H = image.size # new size
187 | scale = long_size / max(W, H)
188 |
189 | # high-quality Lanczos down-scaling
190 | target_resolution = np.array([W, H]) * scale
191 |
192 | image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
193 |
194 | return image, depthmap, intrinsics
195 |
196 |
197 | def load_scannetpp_images_pts3dcam(folder_or_list, size, square_ok=False, verbose=True, img_num=0, img_freq=0):
198 | """ open and convert all images in a list or folder to proper input format for DUSt3R
199 | """
200 | if isinstance(folder_or_list, str):
201 | if verbose:
202 | print(f'>> Loading images from {folder_or_list}')
203 | root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
204 |
205 | elif isinstance(folder_or_list, list):
206 | if verbose:
207 | print(f'>> Loading a list of {len(folder_or_list)} images')
208 | root, folder_content = '', folder_or_list
209 |
210 | else:
211 | raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
212 |
213 |     if img_freq > 0:
214 |         folder_content = folder_content[1000::img_freq]
215 |     if img_num > 0:
216 |         folder_content = folder_content[:img_num]
217 |
218 | supported_images_extensions = ['.jpg', '.jpeg', '.png']
219 | if heif_support_enabled:
220 | supported_images_extensions += ['.heic', '.heif']
221 | supported_images_extensions = tuple(supported_images_extensions)
222 |
223 | imgs = []
224 |
225 | intrinsic_path = os.path.join(os.path.dirname(root), 'pose_intrinsic_imu.json')
226 | with open(intrinsic_path, 'r') as f:
227 | info = json.load(f)
228 |
229 | for path in folder_content:
230 | if not path.lower().endswith(supported_images_extensions):
231 | continue
232 | img_path = os.path.join(root, path)
233 | img = exif_transpose(PIL.Image.open(img_path)).convert('RGB')
234 | W1, H1 = img.size
235 |
236 | depth_path = img_path.replace('.jpg', '.png').replace('rgb','depth')
237 | depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
238 | depthmap = depthmap.astype(np.float32) / 1000.
239 | """
240 |         note: the PIL image uses a (W, H) size while the numpy depthmap uses an (H, W) shape
241 | """
242 | # print(img.size, depthmap.shape)
243 | depthmap = cv2.resize(depthmap, (W1,H1), interpolation=cv2.INTER_CUBIC)
244 | # print(img.size, depthmap.shape)
245 | img_id = os.path.basename(img_path)[:-4]
246 | intrinsics = np.array(info[img_id]['intrinsic'])
247 | # print(img, depthmap, intrinsics)
248 | img, depthmap, intrinsics = crop_and_resize(img, depthmap, intrinsics, size)
249 | # print(img, depthmap, intrinsics)
250 | pts3d_cam, mask = depthmap_to_camera_coordinates(depthmap, intrinsics)
251 | pts3d_cam = pts3d_cam * mask[..., None]
252 | # print(pts3d_cam.shape)
253 | valid_mask = np.isfinite(pts3d_cam).all(axis=-1)
254 | W2, H2 = img.size
255 | if verbose:
256 | print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
257 |
258 | imgs.append(dict(img=ImgNorm(img)[None],
259 | true_shape=np.int32([img.size[::-1]]),
260 | idx=len(imgs),
261 | instance=str(len(imgs)),
262 | pts3d_cam=pts3d_cam[None],
263 | valid_mask=valid_mask[None]
264 | ))
265 | # break
266 |
267 |     assert imgs, 'no images found at ' + root
268 | if verbose:
269 | print(f' (Found {len(imgs)} images)')
270 | return imgs
271 |
272 |
--------------------------------------------------------------------------------
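
Note: an illustrative call of load_images above; the folder path is a placeholder. Each returned element is a dict holding a normalized image tensor of shape (1, 3, H, W) together with its true_shape, idx, instance and label fields.

    from slam3r.utils.image import load_images

    views = load_images('data/my_scene/rgb', size=224, img_freq=2, img_num=100)
    print(len(views), views[0]['img'].shape, views[0]['true_shape'])
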
/slam3r/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for DUSt3R
6 | # --------------------------------------------------------
7 | import torch
8 |
9 |
10 | def fill_default_args(kwargs, func):
11 | import inspect # a bit hacky but it works reliably
12 | signature = inspect.signature(func)
13 |
14 | for k, v in signature.parameters.items():
15 | if v.default is inspect.Parameter.empty:
16 | continue
17 | kwargs.setdefault(k, v.default)
18 |
19 | return kwargs
20 |
21 |
22 | def freeze_all_params(modules):
23 | for module in modules:
24 | try:
25 | for n, param in module.named_parameters():
26 | param.requires_grad = False
27 | except AttributeError:
28 | # module is directly a parameter
29 | module.requires_grad = False
30 |
31 |
32 | def is_symmetrized(gt1, gt2):
33 | x = gt1['instance']
34 | y = gt2['instance']
35 | if len(x) == len(y) and len(x) == 1:
36 | return False # special case of batchsize 1
37 | ok = True
38 | for i in range(0, len(x), 2):
39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
40 | return ok
41 |
42 |
43 | def flip(tensor):
44 | """ flip so that tensor[0::2] <=> tensor[1::2] """
45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
46 |
47 |
48 | def interleave(tensor1, tensor2):
49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
51 | return res1, res2
52 |
53 |
54 | def transpose_to_landscape(head, activate=True):
55 | """ Predict in the correct aspect-ratio,
56 | then transpose the result in landscape
57 | and stack everything back together.
58 | """
59 | def wrapper_no(decout, true_shape):
60 | B = len(true_shape)
61 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
62 | H, W = true_shape[0].cpu().tolist()
63 | res = head(decout, (H, W))
64 | return res
65 |
66 | def wrapper_yes(decout, true_shape):
67 | B = len(true_shape)
68 | # by definition, the batch is in landscape mode so W >= H
69 | H, W = int(true_shape.min()), int(true_shape.max())
70 |
71 | height, width = true_shape.T
72 | is_landscape = (width >= height)
73 | is_portrait = ~is_landscape
74 |
75 | # true_shape = true_shape.cpu()
76 | if is_landscape.all():
77 | return head(decout, (H, W))
78 | if is_portrait.all():
79 | return transposed(head(decout, (W, H)))
80 |
81 |         # the batch is a mix of portrait & landscape views
82 | def selout(ar): return [d[ar] for d in decout]
83 | l_result = head(selout(is_landscape), (H, W))
84 | p_result = transposed(head(selout(is_portrait), (W, H)))
85 |
86 | # allocate full result
87 | result = {}
88 |         for k in l_result | p_result:  # iterate over the union of the two dicts' keys
89 | x = l_result[k].new(B, *l_result[k].shape[1:])
90 | x[is_landscape] = l_result[k]
91 | x[is_portrait] = p_result[k]
92 | result[k] = x
93 |
94 | return result
95 |
96 | return wrapper_yes if activate else wrapper_no
97 |
98 |
99 | def transposed(dic):
100 | return {k: v.swapaxes(1, 2) for k, v in dic.items()}
101 |
102 |
103 | def invalid_to_nans(arr, valid_mask, ndim=999):
104 | if valid_mask is not None:
105 | arr = arr.clone()
106 | arr[~valid_mask] = float('nan')
107 | if arr.ndim > ndim:
108 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
109 | return arr
110 |
111 |
112 | def invalid_to_zeros(arr, valid_mask, ndim=999):
113 | if valid_mask is not None:
114 | arr = arr.clone()
115 | arr[~valid_mask] = 0
116 | nnz = valid_mask.view(len(valid_mask), -1).sum(1)
117 | else:
118 |         nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
119 | if arr.ndim > ndim:
120 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
121 | return arr, nnz
122 |
--------------------------------------------------------------------------------
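
Note: a quick illustration of the pair helpers above; interleave builds the symmetrized (a,b)/(b,a) ordering and flip swaps adjacent pairs back. Values are dummies.

    import torch
    from slam3r.utils.misc import flip, interleave

    a = torch.tensor([0, 1])
    b = torch.tensor([10, 11])
    r1, r2 = interleave(a, b)         # r1 = [0, 10, 1, 11], r2 = [10, 0, 11, 1]
    assert torch.equal(flip(r1), r2)
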
/slam3r/viz.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Visualization utilities. The code is adapted from Spann3r:
3 | # https://github.com/HengyiWang/spann3r/blob/main/spann3r/tools/vis.py
4 | # --------------------------------------------------------
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 | import open3d as o3d
9 | import imageio
10 | import os
11 | import os.path as osp
12 | import matplotlib.pyplot as plt
13 | import matplotlib.colors as mcolors
14 | from skimage import exposure
15 |
16 |
17 | def render_scene(vis, geometry, camera_parameters, bg_color=[1,1,1], point_size=1.,
18 | uint8=True):
19 | vis.clear_geometries()
20 | for g in geometry:
21 | vis.add_geometry(g)
22 |
23 | ctr = vis.get_view_control()
24 | ctr.convert_from_pinhole_camera_parameters(camera_parameters, allow_arbitrary=True)
25 |
26 | opt = vis.get_render_option()
27 |     # adjust the rendered point size
28 | opt.point_size = point_size
29 | opt.background_color = np.array(bg_color)
30 |
31 | vis.poll_events()
32 | vis.update_renderer()
33 |
34 | image = vis.capture_screen_float_buffer(do_render=True)
35 |
36 | if not uint8:
37 | return image
38 | else:
39 | image_uint8 = (np.asarray(image) * 255).astype(np.uint8)
40 | return image_uint8
41 |
42 | def render_frames(pts_all, image_all, camera_parameters, output_dir, mask=None, save_video=True, save_camera=True,
43 | init_ids=[],
44 | c2ws=None,
45 | vis_cam=False,
46 | save_stride=1,
47 | sample_ratio=1.,
48 | incremental=True,
49 | save_name='render_frames',
50 | bg_color=[1, 1, 1],
51 | point_size=1.,
52 | fps=10):
53 |
54 | t, h, w, _ = pts_all.shape
55 |
56 | vis = o3d.visualization.Visualizer()
57 | vis.create_window(width=960, height=544)
58 |
59 | render_frame_path = os.path.join(output_dir, save_name)
60 | os.makedirs(render_frame_path, exist_ok=True)
61 |
62 | if save_camera:
63 | o3d.io.write_pinhole_camera_parameters(os.path.join(render_frame_path, 'camera.json'), camera_parameters)
64 |
65 | video_path = os.path.join(output_dir, f'{save_name}.mp4')
66 | if save_video:
67 | writer = imageio.get_writer(video_path, fps=fps)
68 |
69 | # construct point cloud for initial window
70 | pcd = o3d.geometry.PointCloud()
71 | if init_ids is None: init_ids = []
72 | if len(init_ids) > 0:
73 | init_ids = np.array(init_ids)
74 | init_masks = mask[init_ids]
75 | init_pts = pts_all[init_ids][init_masks]
76 | init_colors = image_all[init_ids][init_masks]
77 | if sample_ratio < 1.:
78 | sampled_idx = np.random.choice(len(init_pts), int(len(init_pts)*sample_ratio), replace=False)
79 | init_pts = init_pts[sampled_idx]
80 | init_colors = init_colors[sampled_idx]
81 |
82 | pcd.points = o3d.utility.Vector3dVector(init_pts)
83 | pcd.colors = o3d.utility.Vector3dVector(init_colors)
84 |
85 | vis.add_geometry(pcd)
86 |
87 | # visualize incremental reconstruction
88 | for i in tqdm(range(t), desc="Rendering incremental reconstruction"):
89 | if i not in init_ids:
90 | new_pts = pts_all[i].reshape(-1, 3)
91 | new_colors = image_all[i].reshape(-1, 3)
92 |
93 | if mask is not None:
94 | new_pts = new_pts[mask[i].reshape(-1)]
95 | new_colors = new_colors[mask[i].reshape(-1)]
96 | if sample_ratio < 1.:
97 | sampled_idx = np.random.choice(len(new_pts), int(len(new_pts)*sample_ratio), replace=False)
98 | new_pts = new_pts[sampled_idx]
99 | new_colors = new_colors[sampled_idx]
100 | if incremental:
101 | pcd.points.extend(o3d.utility.Vector3dVector(new_pts))
102 | pcd.colors.extend(o3d.utility.Vector3dVector(new_colors))
103 | else:
104 | pcd.points = o3d.utility.Vector3dVector(new_pts)
105 | pcd.colors = o3d.utility.Vector3dVector(new_colors)
106 |
107 | if (i+1) % save_stride != 0:
108 | continue
109 |
110 | geometry = [pcd]
111 | if vis_cam:
112 | geometry = geometry + draw_camera(c2ws[i], img=image_all[i])
113 |
114 | image_uint8 = render_scene(vis, geometry, camera_parameters, bg_color=bg_color, point_size=point_size)
115 | frame_filename = f'frame_{i:03d}.png'
116 | imageio.imwrite(osp.join(render_frame_path, frame_filename), image_uint8)
117 | if save_video:
118 | writer.append_data(image_uint8)
119 |
120 | if save_video:
121 | writer.close()
122 |
123 | vis.destroy_window()
124 |
125 |
126 | def create_image_plane(img, c2w, scale=0.1):
127 |     # simulate an image plane in 3D with a point cloud
128 | H, W, _ = img.shape
129 | points = np.meshgrid(np.linspace(0, W-1, W), np.linspace(0, H-1, H))
130 | points = np.stack(points, axis=-1).reshape(-1, 2)
131 |     # shift the pixel grid so it is centered on the image center
132 | points -= np.array([W/2, H/2])
133 |
134 | points *= 2*scale/W
135 | points = np.concatenate([points, 0.1*np.ones((len(points), 1))], axis=-1)
136 |
137 | colors = img.reshape(-1, 3)
138 |
139 |     # full resolution is not needed here, so subsample the points
140 | sample_stride = max(1, int(0.2/scale))
141 | points = points[::sample_stride]
142 | colors = colors[::sample_stride]
143 |
144 | pcd = o3d.geometry.PointCloud()
145 | pcd.points = o3d.utility.Vector3dVector(points)
146 | pcd.colors = o3d.utility.Vector3dVector(colors)
147 | pcd.transform(c2w)
148 |
149 | return pcd
150 |
151 | def draw_camera(c2w, cam_width=0.2/2, cam_height=0.2/2, f=0.10, color=[0, 1, 0],
152 | show_axis=True, img=None):
153 | points = [[0, 0, 0], [-cam_width, -cam_height, f], [cam_width, -cam_height, f],
154 | [cam_width, cam_height, f], [-cam_width, cam_height, f]]
155 | lines = [[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [2, 3], [3, 4], [4, 1]]
156 | colors = [color for i in range(len(lines))]
157 |
158 | line_set = o3d.geometry.LineSet()
159 | line_set.points = o3d.utility.Vector3dVector(points)
160 | line_set.lines = o3d.utility.Vector2iVector(lines)
161 | line_set.colors = o3d.utility.Vector3dVector(colors)
162 | line_set.transform(c2w)
163 |
164 | res = [line_set]
165 |
166 | if show_axis:
167 | axis = o3d.geometry.TriangleMesh.create_coordinate_frame()
168 | axis.scale(min(cam_width, cam_height), np.array([0., 0., 0.]))
169 | axis.transform(c2w)
170 | res.append(axis)
171 |
172 | if img is not None:
173 | # draw image in the plane of the camera
174 | img_plane = create_image_plane(img, c2w)
175 | res.append(img_plane)
176 |
177 | return res
178 |
179 | def find_render_cam(pcd, poses_all=None, cam_width=0.016, cam_height=0.012, cam_f=0.02):
180 | last_camera_params = None
181 |
182 | def print_camera_pose(vis):
183 | nonlocal last_camera_params
184 | ctr = vis.get_view_control()
185 | camera_params = ctr.convert_to_pinhole_camera_parameters()
186 | last_camera_params = camera_params
187 |
188 | print("Intrinsic matrix:")
189 | print(camera_params.intrinsic.intrinsic_matrix)
190 | print("\nExtrinsic matrix:")
191 | print(camera_params.extrinsic)
192 |
193 | return False
194 |
195 | vis = o3d.visualization.VisualizerWithKeyCallback()
196 | vis.create_window(width=960, height=544)
197 | vis.add_geometry(pcd)
198 | if poses_all is not None:
199 | for pose in poses_all:
200 | for geometry in draw_camera(pose, cam_width, cam_height, cam_f):
201 | vis.add_geometry(geometry)
202 |
203 | opt = vis.get_render_option()
204 | opt.point_size = 1
205 | opt.background_color = np.array([0, 0, 0])
206 |
207 | print_camera_pose(vis)
208 | print("Press the space key to record the current rendering view.")
209 | vis.register_key_callback(32, print_camera_pose)
210 |
211 | while vis.poll_events():
212 | vis.update_renderer()
213 |
214 | vis.destroy_window()
215 |
216 | return last_camera_params
217 |
218 | def vis_frame_preds(preds, type, save_path, norm_dims=(0, 1, 2),
219 | enhance_z=False, cmap=True,
220 | save_imgs=True, save_video=True, fps=10):
221 |
222 | if norm_dims is not None:
223 | min_val = preds.min(axis=norm_dims, keepdims=True)
224 | max_val = preds.max(axis=norm_dims, keepdims=True)
225 | preds = (preds - min_val) / (max_val - min_val)
226 |
227 | save_path = osp.join(save_path, type)
228 | if save_imgs:
229 | os.makedirs(save_path, exist_ok=True)
230 |
231 | if save_video:
232 | video_path = osp.join(osp.dirname(save_path), f'{type}.mp4')
233 | writer = imageio.get_writer(video_path, fps=fps)
234 |
235 | for frame_id in tqdm(range(preds.shape[0]), desc=f"Visualizing {type}"):
236 | pred_vis = preds[frame_id].astype(np.float32)
237 | if cmap:
238 | if preds.shape[-1] == 3:
239 | h = 1-pred_vis[...,0]
240 | s = 1-pred_vis[...,1]
241 | v = 1-pred_vis[...,2]
242 | if enhance_z:
243 | new_v = exposure.equalize_adapthist(v, clip_limit=0.01, nbins=256)
244 | v = new_v*0.2 + v*0.8
245 | pred_vis = mcolors.hsv_to_rgb(np.stack([h, s, v], axis=-1))
246 | elif len(pred_vis.shape)==2 or pred_vis.shape[-1] == 1:
247 | pred_vis = plt.cm.jet(pred_vis)
248 | pred_vis_rgb_uint8 = (pred_vis * 255).astype(np.uint8)
249 |
250 | if save_imgs:
251 | plt.imsave(osp.join(save_path, f'{type}_{frame_id:04d}.png'), pred_vis_rgb_uint8)
252 |
253 | if save_video:
254 | writer.append_data(pred_vis_rgb_uint8)
255 |
256 | if save_video:
257 | writer.close()
258 |
259 |
260 |
261 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | import argparse
4 | import trimesh
5 | import torch
6 | from glob import glob
7 | from os.path import join
8 | from tqdm import tqdm
9 | import json
10 |
11 | from slam3r.utils.recon_utils import estimate_focal_knowing_depth, estimate_camera_pose
12 | from slam3r.viz import find_render_cam, render_frames, vis_frame_preds
13 |
14 | parser = argparse.ArgumentParser(description="Visualize SLAM3R reconstruction results on a captured scene")
15 | parser.add_argument("--vis_cam", action="store_true", help="visualize camera poses")
16 | parser.add_argument("--vis_dir", type=str, required=True, help="directory to the predictions for visualization")
17 | parser.add_argument("--save_stride", type=int, default=1, help="the stride for visualizing per-frame predictions")
18 | parser.add_argument("--enhance_z", action="store_true", help="enhance the z axis for better visualization")
19 | parser.add_argument("--conf_thres_l2w", type=float, default=12, help="confidence threshold for filter out low-confidence points in L2W")
20 |
21 | def vis(args):
22 |
23 | root_dir = args.vis_dir
24 |
25 | preds_dir = join(args.vis_dir, "preds")
26 | local_pcds = np.load(join(preds_dir, 'local_pcds.npy')) # (V, 224, 224, 3)
27 | registered_pcds = np.load(join(preds_dir, 'registered_pcds.npy')) # (V, 224, 224, 3)
28 | local_confs = np.load(join(preds_dir, 'local_confs.npy')) # (V, 224, 224)
29 | registered_confs = np.load(join(preds_dir, 'registered_confs.npy')) # (V, 224, 224)
30 | rgb_imgs = np.load(join(preds_dir, 'input_imgs.npy')) # (V, 224, 224, 3)
31 |
32 | rgb_imgs = rgb_imgs/255.
33 |
34 | recon_res_path = glob(join(args.vis_dir, "*.ply"))[0]
35 | recon_res = trimesh.load(recon_res_path)
36 | whole_pcd = recon_res.vertices
37 | whole_colors = recon_res.visual.vertex_colors[:, :3]/255.
38 |
39 |     # convert to the Open3D coordinate convention: x -> x, y -> -y, z -> -z
40 | whole_pcd[..., 1:] *= -1
41 | registered_pcds[..., 1:] *= -1
42 |
43 | recon_pcd = o3d.geometry.PointCloud()
44 | recon_pcd.points = o3d.utility.Vector3dVector(whole_pcd)
45 | recon_pcd.colors = o3d.utility.Vector3dVector(whole_colors)
46 |
47 | # extract information about the initial window in the reconstruction
48 | num_views = local_pcds.shape[0]
49 | with open(join(preds_dir, "metadata.json"), 'r') as f:
50 | metadata = json.load(f)
51 | init_winsize = metadata['init_winsize']
52 | kf_stride = metadata['kf_stride']
53 | init_ids = list(range(0, init_winsize*kf_stride, kf_stride))
54 | init_ref_id = metadata['init_ref_id'] * kf_stride
55 |
56 | if args.vis_cam:
57 | # estimate camera intrinsics and poses
58 | principal_point = torch.tensor((local_pcds[0].shape[0]//2, local_pcds[0].shape[1]//2))
59 | init_window_focal = estimate_focal_knowing_depth(torch.tensor(local_pcds[init_ref_id][None]),
60 | principal_point,
61 | focal_mode='weiszfeld')
62 |
63 | focals = []
64 | for i in tqdm(range(num_views), desc="estimating intrinsics"):
65 | if i in init_ids:
66 | focals.append(init_window_focal)
67 | else:
68 | focal = estimate_focal_knowing_depth(torch.tensor(local_pcds[i:i+1]),
69 | principal_point,
70 | focal_mode='weiszfeld')
71 | focals.append(focal)
72 |
73 | intrinsics = []
74 | for i in range(num_views):
75 | intrinsic = np.eye(3)
76 | intrinsic[0, 0] = focals[i]
77 | intrinsic[1, 1] = focals[i]
78 | intrinsic[:2, 2] = principal_point
79 | intrinsics.append(intrinsic)
80 |
81 | mean_intrinsics = np.mean(np.stack(intrinsics,axis=0), axis=0)
82 | init_window_intrinsics = intrinsics[init_ref_id]
83 |
84 | c2ws = []
85 | for i in tqdm(range(0, num_views, 1), desc="estimating camera poses"):
86 | registered_pcd = registered_pcds[i]
87 | # c2w, succ = estimate_camera_pose(registered_pcd, init_window_intrinsics)
88 | c2w, succ = estimate_camera_pose(registered_pcd, mean_intrinsics)
89 | # c2w, succ = estimate_camera_pose(registered_pcd, intrinsics[i])
90 | if not succ:
91 | print(f"fail to estimate camera pose for view {i}")
92 | c2ws.append(c2w)
93 |
94 |     # find the camera parameters for rendering the incremental reconstruction process:
95 |     # an Open3D window will pop up where you can rotate and translate the viewpoint;
96 |     # press the space key to record the selected camera parameters
97 | camera_parameters = find_render_cam(recon_pcd, c2ws if args.vis_cam else None)
98 | # render the incremental reconstruction process
99 | render_frames(registered_pcds, rgb_imgs, camera_parameters, root_dir,
100 | mask=(registered_confs > args.conf_thres_l2w),
101 | init_ids=init_ids,
102 | c2ws=c2ws if args.vis_cam else None,
103 | sample_ratio=1/args.save_stride,
104 | save_stride=args.save_stride,
105 | fps=10,
106 | vis_cam=args.vis_cam,
107 | )
108 |
109 | # save visualizations of per-frame predictions, and combine them into a video
110 | vis_frame_preds(local_confs[::args.save_stride], type="I2P_conf",
111 | save_path=root_dir)
112 | vis_frame_preds(registered_confs[::args.save_stride], type="L2W_conf",
113 | save_path=root_dir)
114 | vis_frame_preds(local_pcds[::args.save_stride], type="I2P_pcds",
115 | save_path=root_dir,
116 | enhance_z=args.enhance_z
117 | )
118 | vis_frame_preds(registered_pcds[::args.save_stride], type="L2W_pcds",
119 | save_path=root_dir,
120 | )
121 | vis_frame_preds(rgb_imgs[::args.save_stride], type="imgs",
122 | save_path=root_dir,
123 | norm_dims=None,
124 | cmap=False
125 | )
126 |
127 | if __name__ == "__main__":
128 | args = parser.parse_args()
129 | vis(args)
--------------------------------------------------------------------------------
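
Note: a typical invocation of visualize.py; the results directory is a placeholder and is expected to contain the preds/*.npy files and the reconstructed *.ply saved by the reconstruction script.

    python visualize.py --vis_dir results/my_scene --vis_cam --save_stride 2
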