├── .gitignore
├── .gitmodules
├── LICENSE.BSD
├── README.md
├── datasets
│   ├── __init__.py
│   ├── data_module.py
│   ├── dataset.py
│   ├── dataset_dict.py
│   ├── euroc_dataset.py
│   ├── nerf_dataset.py
│   ├── real_sense_dataset.py
│   ├── replica_dataset.py
│   └── tum_dataset.py
├── droid.pth
├── examples
│   ├── __init__.py
│   └── slam_demo.py
├── factor_graph
│   ├── __init__.py
│   ├── factor.py
│   ├── factor_graph.py
│   ├── key.py
│   ├── loss_function.py
│   └── variables.py
├── fusion
│   ├── __init__.py
│   ├── fusion_module.py
│   ├── nerf_fusion.py
│   └── tsdf_fusion.py
├── gui
│   ├── __init__.py
│   ├── gui_module.py
│   └── open3d_gui.py
├── media
│   ├── intro.gif
│   ├── mit.png
│   ├── mrg_logo.png
│   └── sparklab_logo.png
├── networks
│   ├── __init__.py
│   ├── droid_frontend.py
│   ├── droid_net.py
│   ├── factor_graph.py
│   ├── geom
│   │   ├── __init__.py
│   │   ├── ba.py
│   │   ├── chol.py
│   │   ├── graph_utils.py
│   │   ├── losses.py
│   │   ├── projective_ops.py
│   │   └── rgbd_utils.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── clipping.py
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   └── gru.py
│   └── motion_filter.py
├── pipeline
│   ├── __init__.py
│   ├── pipeline.py
│   └── pipeline_module.py
├── requirements.txt
├── scripts
│   ├── convergence_plots.ipynb
│   ├── download_cube.bash
│   ├── download_replica.bash
│   ├── download_replica_sample.bash
│   ├── record_real_sense.py
│   ├── replica_results.py
│   ├── replica_to_nerf_dataset.py
│   └── unzip_tartan_air.py
├── setup.py
├── slam
│   ├── __init__.py
│   ├── inertial_frontends
│   │   ├── __init__.py
│   │   └── inertial_frontend.py
│   ├── meta_slam.py
│   ├── slam_module.py
│   ├── vio_slam.py
│   └── visual_frontends
│       ├── __init__.py
│       └── visual_frontend.py
├── solvers
│   ├── __init__.py
│   ├── linear_solver.py
│   ├── meta_solver.py
│   └── nonlinear_solver.py
├── src
│   ├── altcorr_kernel.cu
│   ├── correlation_kernels.cu
│   ├── droid.cpp
│   └── droid_kernels.cu
└── utils
    ├── __init__.py
    ├── evaluation.py
    ├── flow_viz.py
    ├── open3d_pickle.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | imgui.ini
2 | *.npz
3 | *.svg
4 |
5 | # Meshes
6 | *.pcd
7 | *.ply
8 | *.obj
9 |
10 |
11 | ### Python specific gitignore
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 | cover/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | .pybuilder/
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | # For a library or package, you might want to ignore these files since the code is
98 | # intended to run in multiple environments; otherwise, check them in:
99 | # .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # poetry
109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110 | # This is especially recommended for binary packages to ensure reproducibility, and is more
111 | # commonly ignored for libraries.
112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113 | #poetry.lock
114 |
115 | # pdm
116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117 | #pdm.lock
118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119 | # in version control.
120 | # https://pdm.fming.dev/#use-with-ide
121 | .pdm.toml
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # PyCharm
167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169 | # and can be added to the global gitignore or merged into this file. For a more nuclear
170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171 | #.idea/
172 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "thirdparty/lietorch"]
2 | path = thirdparty/lietorch
3 | url = https://github.com/princeton-vl/lietorch
4 | [submodule "thirdparty/eigen"]
5 | path = thirdparty/eigen
6 | url = https://gitlab.com/libeigen/eigen.git
7 | [submodule "thirdparty/instant-ngp"]
8 | path = thirdparty/instant-ngp
9 | url = https://github.com/ToniRV/instant-ngp.git
10 | branch = feature/nerf_slam
11 | [submodule "thirdparty/gtsam"]
12 | path = thirdparty/gtsam
13 | url = https://github.com/ToniRV/gtsam-1.git
14 | branch = feature/nerf_slam
15 |
--------------------------------------------------------------------------------
/LICENSE.BSD:
--------------------------------------------------------------------------------
1 | Copyright 2022 Massachusetts Institute of Technology.
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
15 | # NeRF-SLAM
18 | Real-Time Dense Monocular SLAM with Neural Radiance Fields
20 | Antoni Rosinol · John J. Leonard · Luca Carlone
28 | Paper | Video
41 | Table of Contents
43 | - Install
46 | - Download Datasets
49 | - Run
52 | - Citation
55 | - License
58 | - Acknowledgments
61 | - Contact
66 |
67 | ## Install
68 |
69 | Clone repo with submodules:
70 | ```
71 | git clone https://github.com/ToniRV/NeRF-SLAM.git --recurse-submodules
72 | git submodule update --init --recursive
73 | ```
74 |
75 | From this point on, use a virtual environment...
76 | Install torch (see [here](https://pytorch.org/get-started/previous-versions) for other versions):
77 | ```
78 | # CUDA 11.3
79 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
80 | ```
81 |
82 | Pip install requirements:
83 | ```
84 | pip install -r requirements.txt
85 | pip install -r ./thirdparty/gtsam/python/requirements.txt
86 | ```
87 |
88 | Compile ngp (you need cmake>3.22):
89 | ```
90 | cmake ./thirdparty/instant-ngp -B build_ngp
91 | cmake --build build_ngp --config RelWithDebInfo -j
92 | ```
93 |
94 | Compile gtsam and enable the python wrapper:
95 | ```
96 | cmake ./thirdparty/gtsam -DGTSAM_BUILD_PYTHON=1 -B build_gtsam
97 | cmake --build build_gtsam --config RelWithDebInfo -j
98 | cd build_gtsam
99 | make python-install
100 | ```
101 |
102 | Install:
103 | ```
104 | python setup.py install
105 | ```
106 |
107 | ## Download Sample Data
108 |
109 | This will just download one of the Replica scenes:
110 | ```
111 | ./scripts/download_replica_sample.bash
112 | ```
113 |
114 | ## Run
115 |
116 | ```
117 | python ./examples/slam_demo.py --dataset_dir=./datasets/Replica/office0 --dataset_name=nerf --buffer=100 --slam --parallel_run --img_stride=2 --fusion='nerf' --multi_gpu --gui
118 | ```
119 |
120 | This repo also implements [Sigma-Fusion](https://arxiv.org/abs/2210.01276): just change `--fusion='sigma'` to run that.
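For reference, here is the same demo invocation as above with only the fusion flag changed (a sketch; keep or drop the other flags to match your setup):
```
python ./examples/slam_demo.py --dataset_dir=./datasets/Replica/office0 --dataset_name=nerf --buffer=100 --slam --parallel_run --img_stride=2 --fusion='sigma' --multi_gpu --gui
```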
121 |
122 | ## FAQ
123 |
124 | ### GPU Memory
125 |
126 | This is a GPU-memory-intensive pipeline; to monitor your GPU usage, I recommend using `nvitop`.
127 | Install nvitop in a local env:
128 | ```
129 | pip3 install --upgrade nvitop
130 | ```
131 |
132 | Keep it running on a terminal, and monitor GPU memory usage:
133 | ```
134 | nvitop --monitor
135 | ```
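If you prefer checking memory from inside Python, PyTorch also exposes its allocator statistics; a minimal sketch (assumes the pipeline runs on GPU 0):
```python
import torch

# Memory currently/maximally held by PyTorch's caching allocator on GPU 0 (bytes).
print(f"allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"peak:      {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB")
print(f"reserved:  {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
```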
136 |
137 | If you consistently see "out-of-memory" errors, you may either need to change parameters or buy better GPUs :).
138 | The memory-consuming parts of this pipeline are:
139 | - Frame-to-frame correlation volumes (these can be avoided by computing correlations on the fly).
140 | - Volumetric rendering (intrinsically memory-intensive; tricks exist, but ultimately we need to move to light fields or some better representation (OpenVDB?)).
141 |
142 | ### Installation issues
143 |
144 | 1. Gtsam not working: check that the python wrapper is installed; see the instructions here: [gtsam_python](https://github.com/ToniRV/gtsam-1/blob/develop/python/README.md). Make sure you use our gtsam fork, which exposes more of gtsam's functionality to python.
145 | 2. The Gtsam dependency is not strictly needed; I only used it to experiment with adding an IMU and/or stereo cameras, and to have an easier interface for building factor graphs. This didn't quite work though, because the network seems to have a notion of scale, and updating poses/landmarks and then optical flow didn't quite work.
146 | 3. Somehow the parser converts [this](https://github.com/borglab/gtsam/compare/develop...ToniRV:gtsam-1:feature/nerf_slam#diff-add3627555fb7411e36ea4d863c15f4187e018b6e00b608ab260e3221aef057aR345) to
147 | `const std::vector&`, and I need to manually remove the inner `const X& ...` in
148 | `gtsam/build/python/linear.cpp`,
149 | and also add `` because:
150 | ```
151 | Did you forget to `#include `?
152 | ```
153 |
154 | ## Citation
155 |
156 | ```bibtex
157 | @article{rosinol2022nerf,
158 | title={NeRF-SLAM: Real-Time Dense Monocular SLAM with Neural Radiance Fields},
159 | author={Rosinol, Antoni and Leonard, John J and Carlone, Luca},
160 | journal={arXiv preprint arXiv:2210.13641},
161 | year={2022}
162 | }
163 | ```
164 |
165 | ## License
166 |
167 | This repo is BSD Licensed.
168 | It reimplements parts of Droid-SLAM (BSD Licensed).
169 | Our changes to instant-NGP (Nvidia License) are released in our [fork of instant-ngp](https://github.com/ToniRV/instant-ngp) (branch `feature/nerf_slam`) and
170 | added here as a thirdparty dependency using git submodules.
171 |
172 | ## Acknowledgments
173 |
174 | This work has been possible thanks to the open-source code from [Droid-SLAM](https://github.com/princeton-vl/DROID-SLAM) and
175 | [Instant-NGP](https://github.com/NVlabs/instant-ngp), as well as the open-source datasets [Replica](https://github.com/facebookresearch/Replica-Dataset) and [Cube-Diorama](https://github.com/jc211/nerf-cube-diorama-dataset).
176 |
177 | ## Contact
178 |
179 | I have many ideas on how to improve this approach, but I just graduated so I won't have much time to do another PhD...
180 | If you are interested in building on top of this,
181 | feel free to reach out :) [arosinol@mit.edu](mailto:arosinol@mit.edu)
182 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/datasets/data_module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import colored_glog as log
4 | from pipeline.pipeline_module import MIMOPipelineModule
5 |
6 | class DataModule(MIMOPipelineModule):
7 | def __init__(self, name, args, device="cpu") -> None:
8 | super().__init__(name, args.parallel_run, args)
9 | self.device = device
10 | self.idx = -1
11 |
12 | def get_input_packet(self):
13 | return True
14 |
15 | def spin_once(self, input):
16 | log.check(input)
17 | if self.name == 'real':
18 | return self.dataset.stream()
19 | else:
20 | self.idx += 1
21 | if self.idx < len(self.dataset):
22 | return self.dataset[self.idx]
23 | else:
24 | print("Stopping data module!")
25 | super().shutdown_module()
26 | return None
27 |
28 | def initialize_module(self):
29 | if self.name == "euroc":
30 | from datasets.euroc_dataset import EurocDataset
31 | self.dataset = EurocDataset(self.args, self.device)
32 | elif self.name == "tum":
33 | from datasets.tum_dataset import TumDataset
34 | self.dataset = TumDataset(self.args, self.device)
35 | elif self.name == "nerf":
36 | from datasets.nerf_dataset import NeRFDataset
37 | self.dataset = NeRFDataset(self.args, self.device)
38 | elif self.name == "replica":
39 | from datasets.replica_dataset import ReplicaDataset
40 | self.dataset = ReplicaDataset(self.args, self.device)
41 | elif self.name == "real":
42 | from datasets.real_sense_dataset import RealSenseDataset
43 | self.dataset = RealSenseDataset(self.args, self.device)
44 | else:
45 | raise Exception(f"Unknown dataset: {self.name}")
46 | return super().initialize_module()
47 |
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import numpy as np
4 | import open3d as o3d
5 |
6 | from abc import abstractmethod
7 | from torch.utils.data.dataset import Dataset
8 |
9 | class Dataset(Dataset):
10 | def __init__(self, name, args, device) -> None:
11 | super().__init__()
12 | self.name = name
13 | self.args = args
14 | self.device = device
15 |
16 | self.dataset_dir = args.dataset_dir
17 | self.initial_k = args.initial_k # first frame to load
18 | self.final_k = args.final_k # last frame to load, if -1 load all
19 | self.img_stride = args.img_stride # stride for loading images
20 | self.stereo = args.stereo
21 |
22 | self.viz = False
23 |
24 | # list of data packets,
25 | # each data packet consists of two frames and all imu data in between
26 | self.data_packets = None
27 |
28 | @abstractmethod
29 | def stream(self):
30 | pass
31 |
32 | class PointCloudTransmissionFormat:
33 | def __init__(self, pointcloud: o3d.geometry.PointCloud):
34 | self.points = np.array(pointcloud.points)
35 | self.colors = np.array(pointcloud.colors)
36 | self.normals = np.array(pointcloud.normals)
37 |
38 | def create_pointcloud(self) -> o3d.geometry.PointCloud:
39 | pointcloud = o3d.geometry.PointCloud()
40 | pointcloud.points = o3d.utility.Vector3dVector(self.points)
41 | pointcloud.colors = o3d.utility.Vector3dVector(self.colors)
42 | pointcloud.normals = o3d.utility.Vector3dVector(self.normals)
43 | return pointcloud
44 |
45 | class CameraModel:
46 | def __init__(self, model) -> None:
47 | self.model = model
48 | pass
49 |
50 | def project(self, xyz):
51 | return self.model.project(xyz)
52 |
53 | def backproject(self, uv):
54 | return self.model.backproject(uv)
55 |
56 | class Resolution:
57 | def __init__(self, width, height) -> None:
58 | self.width = width
59 | self.height = height
60 |
61 | def numpy(self):
62 | return np.array([self.width, self.height])
63 |
64 | def total(self):
65 | return self.width * self.height
66 |
67 | class PinholeCameraModel(CameraModel):
68 | def __init__(self, fx, fy, cx, cy) -> None:
69 | super().__init__('Pinhole')
70 | self.fx = fx
71 | self.fy = fy
72 | self.cx = cx
73 | self.cy = cy
74 |
75 | self.K = np.eye(3)
76 | self.K[0,0] = fx
77 | self.K[0,2] = cx
78 | self.K[1,1] = fy
79 | self.K[1,2] = cy
80 |
81 | def scale_intrinsics(self, scale_x, scale_y):
82 | self.fx *= scale_x
83 | self.cx *= scale_x
84 | self.fy *= scale_y
85 | self.cy *= scale_y
86 |
87 | self.K = np.eye(3)
88 | self.K[0,0] = self.fx
89 | self.K[0,2] = self.cx
90 | self.K[1,1] = self.fy
91 | self.K[1,2] = self.cy
92 |
93 | def numpy(self):
94 | return np.array([self.fx, self.fy, self.cx, self.cy])
95 |
96 | def matrix(self):
97 | return self.K
98 |
99 | class DistortionModel:
100 | def __init__(self, model) -> None:
101 | self.model = model
102 |
103 | class RadTanDistortionModel(DistortionModel):
104 | def __init__(self, k1, k2, p1, p2) -> None:
105 | super().__init__('RadTan')
106 | # Distortioncoefficients=(k1 k2 p1 p2 k3) #OpenCV convention
107 | self.k1 = k1
108 | self.k2 = k2
109 | self.p1 = p1
110 | self.p2 = p2
111 |
112 | def get_distortion_as_vector(self):
113 | return np.array([self.k1, self.k2, self.p1, self.p2])
114 |
115 | class CameraCalibration:
116 | def __init__(self, body_T_cam, camera_model, distortion_model, rate_hz, resolution, aabb, depth_scale) -> None:
117 | self.body_T_cam = body_T_cam
118 | self.camera_model = camera_model
119 | self.distortion_model = distortion_model
120 | self.rate_hz = rate_hz
121 | self.resolution = resolution
122 | self.aabb = aabb
123 | self.depth_scale = depth_scale
124 |
125 | class ImuCalibration:
126 | def __init__(self, body_T_imu, a_n, a_b, g_n, g_b, rate_hz, imu_integration_sigma, imu_time_shift, n_gravity) -> None:
127 | self.body_T_imu = body_T_imu
128 | self.a_n = a_n
129 | self.g_n = g_n
130 | self.a_b = a_b
131 | self.g_b = g_b
132 | self.rate_hz = rate_hz
133 | self.imu_integration_sigma = imu_integration_sigma
134 | self.imu_time_shift = imu_time_shift
135 | self.n_gravity = n_gravity
136 | pass
137 |
138 | class ViconCalibration:
139 | def __init__(self, body_T_vicon) -> None:
140 | self.body_T_vicon = body_T_vicon
141 |
--------------------------------------------------------------------------------
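A quick illustration of the calibration helpers in `datasets/dataset.py` above: a minimal sketch (the camera values are hypothetical) showing how `PinholeCameraModel.scale_intrinsics` keeps `K` consistent when the dataset loaders resize images:
```python
from datasets.dataset import PinholeCameraModel, Resolution

cam = PinholeCameraModel(fx=535.4, fy=539.2, cx=320.1, cy=247.6)  # hypothetical camera
res = Resolution(640, 480)

# Downscaling to 320x240 means fx/cx scale by w1/w0 and fy/cy by h1/h0.
w1, h1 = 320, 240
cam.scale_intrinsics(w1 / res.width, h1 / res.height)

print(cam.numpy())   # [fx', fy', cx', cy'] after scaling
print(cam.matrix())  # 3x3 K rebuilt with the scaled values
```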
/datasets/dataset_dict.py:
--------------------------------------------------------------------------------
1 | from .euroc_dataset import EurocDataset
2 |
3 | dataset_dict = {'euroc': EurocDataset}
--------------------------------------------------------------------------------
/datasets/nerf_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import sys
5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
6 |
7 | import cv2
8 | import json
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from datasets.dataset import *
13 | from utils.utils import *
14 |
15 | class NeRFDataset(Dataset):
16 | def __init__(self, args, device):
17 | super().__init__("Nerf", args, device)
18 | self.parse_metadata()
19 | # self._build_dataset_index() # Loads all the data first, and then streams.
20 | self.tqdm = tqdm(total=self.__len__()) # Call after parsing metadata
21 |
22 | def get_cam_calib(self):
23 | w, h = self.json["w"], self.json["h"]
24 | fx, fy = self.json["fl_x"], self.json["fl_y"]
25 | cx, cy = self.json["cx"], self.json["cy"]
26 |
27 | body_T_cam0 = np.eye(4,4)
28 | rate_hz = 10.0
29 | resolution = Resolution(w, h)
30 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy)
31 | distortion0 = RadTanDistortionModel(0, 0, 0, 0)
32 |
33 | aabb = self.json["aabb"]
34 | depth_scale = self.json["integer_depth_scale"] if "integer_depth_scale" in self.json else 1.0
35 |
36 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale)
37 |
38 | def parse_metadata(self):
39 | with open(os.path.join(self.dataset_dir, "transforms.json"), 'r') as f:
40 | self.json = json.load(f)
41 |
42 | self.calib = self.get_cam_calib()
43 |
44 | self.resize_images = False
45 | if self.calib.resolution.total() > 640*640:
46 | self.resize_images = True
47 | # TODO(Toni): keep aspect ratio, and resize max res to 640
48 | self.output_image_size = [341, 640] # h, w
49 |
50 | self.image_paths = []
51 | self.depth_paths = []
52 | self.w2c = []
53 |
54 | if self.resize_images:
55 | h0, w0 = self.calib.resolution.height, self.calib.resolution.width
56 | total_output_pixels = (self.output_image_size[0] * self.output_image_size[1])
57 | self.h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0)))
58 | self.w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0)))
59 | self.h1 = self.h1 - self.h1 % 8
60 | self.w1 = self.w1 - self.w1 % 8
61 | self.calib.camera_model.scale_intrinsics(self.w1 / w0, self.h1 / h0)
62 | self.calib.resolution = Resolution(self.w1, self.h1)
63 |
64 | frames = self.json["frames"]
65 | frames = frames[self.initial_k:self.final_k:self.img_stride]
66 | print(f'Loading {len(frames)} images.')
67 | for i, frame in enumerate(frames):
68 | # Convert from nerf to ngp
69 | # TODO: convert poses to our format
70 | c2w = np.array(frame['transform_matrix'])
71 | c2w = nerf_matrix_to_ngp(c2w) # THIS multiplies by scale = 1 and offset = 0.5
72 | # TODO(TONI): prone to numerical errors, do se(3) inverse instead
73 | w2c = np.linalg.inv(c2w)
74 |
75 | # Get rgb/depth images path
76 | if frame['file_path'].endswith(".png") or frame['file_path'].endswith(".jpg"):
77 | image_path = os.path.join(self.dataset_dir, f"{frame['file_path']}")
78 | else:
79 | image_path = os.path.join(self.dataset_dir, f"{frame['file_path']}.png")
80 | depth_path = None
81 | if 'depth_path' in frame:
82 | depth_path = os.path.join(self.dataset_dir, f"{frame['depth_path']}")
83 |
84 | self.image_paths.append([i, image_path])
85 | self.depth_paths += [depth_path]
86 | self.w2c += [w2c]
87 |
88 | # Sort paths chronologically
89 | if os.path.splitext(os.path.basename(self.image_paths[0][1]))[0].isdigit():
90 | # name is "000000.jpg" for Cube-Diorama
91 | self.image_paths.sort(key=lambda path: int(os.path.splitext(os.path.basename(path[1]))[0]))
92 | else:
93 | # name is "frame000000.jpg" for Replica
94 | self.image_paths.sort(key=lambda path: int(os.path.splitext(os.path.basename(path[1]))[0][5:]))
95 |
96 | # Store the first pose, used as prior and initial state in SLAM.
97 | self.args.world_T_imu_t0 = self.w2c[0]
98 |
99 | def _get_data_packet(self, k0, k1=None):
100 | if k1 is None:
101 | k1 = k0 + 1
102 | else:
103 | assert(k1 >= k0)
104 |
105 | timestamps = []
106 | poses = []
107 | images = []
108 | depths = []
109 | calibs = []
110 |
111 | W, H = self.calib.resolution.width, self.calib.resolution.height
112 |
113 | # Parse images and tfs
114 | for k in np.arange(k0, k1):
115 | i, image_path = self.image_paths[k]
116 | depth_path = self.depth_paths[i] # index with i, bcs we sorted image_paths to have increasing timestamps.
117 | w2c = self.w2c[i]
118 |
119 | # Parse rgb/depth images
120 | image = cv2.imread(image_path) # H, W, C=4
121 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) # Required for Nerf Fusion, perhaps we can put it in there
122 | if depth_path:
123 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)[..., None] # H, W, C=1
124 | else:
125 | depth = (-1 * np.ones_like(image[:, :, 0])).astype(np.uint16) # invalid depth
126 |
127 | if self.resize_images:
128 | w1, h1 = self.w1, self.h1
129 | image = cv2.resize(image, (w1, h1))
130 | depth = cv2.resize(depth, (w1, h1))
131 | depth = depth[:, :, np.newaxis]
132 |
133 | if self.viz:
134 | cv2.imshow(f"Img Resized", image)
135 | cv2.imshow(f"Depth Resized", depth)
136 | cv2.waitKey(1)
137 |
138 | assert(H == image.shape[0])
139 | assert(W == image.shape[1])
140 | assert(3 == image.shape[2] or 4 == image.shape[2])
141 | assert(np.uint8 == image.dtype)
142 | assert(H == depth.shape[0])
143 | assert(W == depth.shape[1])
144 | assert(1 == depth.shape[2])
145 | assert(np.uint16 == depth.dtype)
146 |
147 | depth = depth.astype(np.int32) # converting to int32, because torch does not support uint16, and I don't want to lose precision
148 |
149 | timestamps += [i]
150 | poses += [w2c]
151 | images += [image]
152 | depths += [depth]
153 | calibs += [self.calib]
154 |
155 | return {"k": np.arange(k0,k1),
156 | "t_cams": np.array(timestamps),
157 | "poses": np.array(poses),
158 | "images": np.array(images),
159 | "depths": np.array(depths),
160 | "calibs": np.array(calibs),
161 | "is_last_frame": (i >= self.__len__() - 1),
162 | }
163 |
164 | def __len__(self):
165 | return len(self.image_paths)
166 |
167 | def __getitem__(self, k):
168 | self.tqdm.update(1)
169 | return self._get_data_packet(k) if self.data_packets is None else self.data_packets[k]
170 |
171 |
172 | # Up to you how you index the dataset depending on your training procedure
173 | def _build_dataset_index(self):
174 | # Go through the stream and bundle as you wish
175 | self.data_packets = [data_packet for data_packet in self.stream()]
176 |
177 | def stream(self):
178 | for k in range(self.__len__()):
179 | yield self._get_data_packet(k)
180 |
--------------------------------------------------------------------------------
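A minimal usage sketch for `NeRFDataset` above (the path is the Replica sample from the README; only the argument fields read by `Dataset`/`NeRFDataset` are populated, everything else is hypothetical):
```python
from types import SimpleNamespace

from datasets.nerf_dataset import NeRFDataset

args = SimpleNamespace(dataset_dir="./datasets/Replica/office0",  # expects a transforms.json inside
                       initial_k=0, final_k=-1, img_stride=2, stereo=False)
dataset = NeRFDataset(args, device="cpu")

packet = dataset[0]  # dict with "k", "t_cams", "poses", "images", "depths", "calibs", "is_last_frame"
print(packet["images"].shape, packet["depths"].shape, packet["poses"].shape)
```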
/datasets/real_sense_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import pyrealsense2 as rs
3 | import numpy as np
4 | import cv2
5 | import os
6 |
7 | from tqdm import tqdm
8 |
9 | from datasets.dataset import *
10 | from icecream import ic
11 |
12 | class RealSenseDataset(Dataset):
13 |
14 | def __init__(self, args, device):
15 | super().__init__("RealSense", args, device)
16 | self.parse_metadata()
17 |
18 | def parse_metadata(self):
19 | # Configure depth and color streams
20 | self.pipeline = rs.pipeline()
21 | config = rs.config()
22 |
23 | # Get device product line for setting a supporting resolution
24 | pipeline_wrapper = rs.pipeline_wrapper(self.pipeline)
25 | pipeline_profile = config.resolve(pipeline_wrapper)
26 |
27 | device = pipeline_profile.get_device()
28 | device_product_line = str(device.get_info(rs.camera_info.product_line))
29 |
30 | found_rgb = False
31 | for s in device.sensors:
32 | if s.get_info(rs.camera_info.name) == 'RGB Camera':
33 | found_rgb = True
34 | break
35 |
36 | if not found_rgb:
37 | raise NotImplementedError("No RGB camera found")
38 |
39 | if device_product_line == 'L500':
40 | raise NotImplementedError
41 |
42 | self.rate_hz = 30
43 | config.enable_stream(rs.stream.color, 640, 480, rs.format.rgb8, self.rate_hz)
44 | config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, self.rate_hz)
45 |
46 | # Set timestamp
47 | self.timestamp = 0
48 |
49 | # Start streaming
50 | ic("Start streaming")
51 | cfg = self.pipeline.start(config)
52 |
53 | # Set profile
54 | depth_cam, rgb_cam = cfg.get_device().query_sensors()
55 | rgb_cam.set_option(rs.option.enable_auto_exposure, False)
56 | rgb_cam.set_option(rs.option.exposure, 238)
57 | rgb_cam.set_option(rs.option.enable_auto_white_balance, False)
58 | rgb_cam.set_option(rs.option.white_balance, 3700)
59 | rgb_cam.set_option(rs.option.gain, 0)
60 | depth_cam.set_option(rs.option.enable_auto_exposure, False)
61 | depth_cam.set_option(rs.option.exposure, 438)
62 | depth_cam.set_option(rs.option.gain, 0)
63 | # color_sensor.set_option(rs.option.backlight_compensation, 0) # Disable backlight compensation
64 |
65 | # Set calib
66 | profile = cfg.get_stream(rs.stream.color) # Fetch stream profile for color stream
67 | intrinsics = profile.as_video_stream_profile().get_intrinsics() # Downcast to video_stream_profile and fetch intrinsics
68 | self.calib = self._get_cam_calib(intrinsics)
69 |
70 | self.resize_images = True
71 | if self.resize_images:
72 | self.output_image_size = [315, 420] # h, w
73 | h0, w0 = self.calib.resolution.height, self.calib.resolution.width
74 | total_output_pixels = (self.output_image_size[0] * self.output_image_size[1])
75 | self.h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0)))
76 | self.w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0)))
77 | self.h1 = self.h1 - self.h1 % 8
78 | self.w1 = self.w1 - self.w1 % 8
79 | self.calib.camera_model.scale_intrinsics(self.w1 / w0, self.h1 / h0)
80 | self.calib.resolution = Resolution(self.w1, self.h1)
81 |
82 | def _get_cam_calib(self, intrinsics):
83 | """ intrinsics:
84 | model Distortion model of the image
85 | coeffs Distortion coefficients
86 | fx Focal length of the image plane, as a multiple of pixel width
87 | fy Focal length of the image plane, as a multiple of pixel height
88 | ppx Horizontal coordinate of the principal point of the image, as a pixel offset from the left edge
89 | ppy Vertical coordinate of the principal point of the image, as a pixel offset from the top edge
90 | height Height of the image in pixels
91 | width Width of the image in pixels
92 | """
93 | w, h = intrinsics.width, intrinsics.height
94 | fx, fy, cx, cy= intrinsics.fx, intrinsics.fy, intrinsics.ppx, intrinsics.ppy
95 |
96 | distortion_coeffs = intrinsics.coeffs
97 | distortion_model = intrinsics.model
98 | k1, k2, p1, p2 = 0, 0, 0, 0
99 | body_T_cam0 = np.eye(4)
100 | rate_hz = self.rate_hz
101 |
102 | resolution = Resolution(w, h)
103 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy)
104 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2)
105 |
106 | aabb = (2*np.array([[-2, -2, -2], [2, 2, 2]])).tolist() # Computed automatically in to_nerf()
107 | depth_scale = 1.0 # TODO # Since we multiply as gt_depth *= depth_scale, we need to invert camera["scale"]
108 |
109 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale)
110 |
111 |
112 | def stream(self):
113 | self.viz=True
114 |
115 | timestamps = []
116 | poses = []
117 | images = []
118 | depths = []
119 | calibs = []
120 |
121 | got_image = False
122 | while not got_image:
123 | # Wait for a coherent pair of frames: depth and color
124 | try:
125 | frames = self.pipeline.wait_for_frames()
126 | except Exception as e:
127 | print(e)
128 | continue
129 | color_frame = frames.get_color_frame()
130 | depth_frame = frames.get_depth_frame()
131 | #depth_frame = np.zeros((color_image.shape[0], color_image.shape[1], 1))
132 |
133 | if not depth_frame or not color_frame:
134 | print("No depth and color frame parsed.")
135 | continue
136 |
137 | # Convert images to numpy arrays
138 | color_image = np.asanyarray(color_frame.get_data())
139 | depth_image = np.asanyarray(depth_frame.get_data())
140 |
141 |
142 | if self.resize_images:
143 | color_image = cv2.resize(color_image, (self.w1, self.h1))
144 | depth_image = cv2.resize(depth_image, (self.w1, self.h1))
145 |
146 | if self.viz:
147 | # Apply colormap on depth image (image must be converted to 8-bit per pixel first)
148 | depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET)
149 | cv2.imshow(f"Color Img", color_image)
150 | cv2.imshow(f"Depth Img", depth_colormap)
151 | cv2.waitKey(1)
152 |
153 | self.timestamp += 1
154 | if self.args.img_stride > 1 and self.timestamp % self.args.img_stride == 0:
155 | # Software imposed fps to rate_hz/img_stride
156 | continue
157 |
158 | timestamps += [self.timestamp]
159 | poses += [np.eye(4)] # We don't have poses
160 | images += [color_image]
161 | depths += [depth_image] # We don't use depth
162 | calibs += [self.calib]
163 | got_image = True
164 |
165 | return {"k": np.arange(self.timestamp-1,self.timestamp),
166 | "t_cams": np.array(timestamps),
167 | "poses": np.array(poses),
168 | "images": np.array(images),
169 | "depths": np.array(depths),
170 | "calibs": np.array(calibs),
171 | "is_last_frame": False, #TODO
172 | }
173 |
174 | def shutdown(self):
175 | # Stop streaming
176 | self.pipeline.stop()
177 |
178 | def to_nerf_format(self, data_packets):
179 | print("Exporting RealSense dataset to Nerf")
180 | OUT_PATH = "transforms.json"
181 | AABB_SCALE = 4
182 | out = {
183 | "fl_x": self.calib.camera_model.fx,
184 | "fl_y": self.calib.camera_model.fy,
185 | "k1": self.calib.distortion_model.k1,
186 | "k2": self.calib.distortion_model.k2,
187 | "p1": self.calib.distortion_model.p1,
188 | "p2": self.calib.distortion_model.p2,
189 | "cx": self.calib.camera_model.cx,
190 | "cy": self.calib.camera_model.cy,
191 | "w": self.calib.resolution.width,
192 | "h": self.calib.resolution.height,
193 | "aabb": self.calib.aabb,
194 | "aabb_scale": AABB_SCALE,
195 | "integer_depth_scale": self.calib.depth_scale,
196 | "frames": [],
197 | }
198 |
199 | from PIL import Image
200 |
201 | c2w = np.eye(4).tolist()
202 | for data_packet in tqdm(data_packets):
203 | # Image
204 | ic(data_packet["k"])
205 | k = data_packet["k"][0]
206 | i = data_packet["images"][0]
207 | d = data_packet["depths"][0]
208 |
209 | # Store image paths
210 | color_path = os.path.join(self.args.dataset_dir, "results", f"frame{k:05}.png")
211 | depth_path = os.path.join(self.args.dataset_dir, "results", f"depth{k:05}.png")
212 |
213 | # Save image to disk
214 | color = Image.fromarray(i)
215 | depth = Image.fromarray(d)
216 | color.save(color_path)
217 | depth.save(depth_path)
218 |
219 | # Sharpness
220 | sharp = sharpness(i)
221 |
222 | # Store relative path
223 | relative_color_path = os.path.join("results", os.path.basename(color_path))
224 | relative_depth_path = os.path.join("results", os.path.basename(depth_path))
225 |
226 | frame = {"file_path": relative_color_path,
227 | "sharpness": sharp,
228 | "depth_path": relative_depth_path,
229 | "transform_matrix": c2w}
230 | out["frames"].append(frame)
231 |
232 | with open(os.path.join(self.args.dataset_dir, OUT_PATH), "w") as outfile:
233 | import json
234 | json.dump(out, outfile, indent=2)
--------------------------------------------------------------------------------
/datasets/replica_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import glob
3 | import os
4 | import json
5 | import numpy as np
6 |
7 | import cv2
8 | from tqdm import tqdm
9 |
10 | from icecream import ic
11 | from datasets.dataset import *
12 |
13 | class ReplicaDataset(Dataset):
14 | def __init__(self, args, device):
15 | super().__init__("Replica", args, device)
16 | self.dataset_dir = args.dataset_dir
17 | self.parse_dataset()
18 | self._build_dataset_index()
19 |
20 | def load_poses(self, path):
21 | poses = []
22 | with open(path, "r") as f:
23 | lines = f.readlines()
24 | for i in range(len(self.image_paths)):
25 | line = lines[i]
26 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
27 | c2w[:3, 1] *= -1
28 | c2w[:3, 2] *= -1
29 | w2c = np.linalg.inv(c2w)
30 | poses.append(w2c)
31 | return poses
32 |
33 | def _get_cam_calib(self, path):
34 | with open(os.path.join(self.dataset_dir, "../cam_params.json"), 'r') as f:
35 | self.json = json.load(f)
36 |
37 | camera = self.json["camera"]
38 | w, h = camera['w'], camera['h']
39 | fx, fy, cx, cy= camera['fx'], camera['fy'], camera['cx'], camera['cy']
40 |
41 | k1, k2, p1, p2 = 0, 0, 0, 0
42 | body_T_cam0 = np.eye(4)
43 | rate_hz = 0
44 |
45 | resolution = Resolution(w, h)
46 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy)
47 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2)
48 |
49 | aabb = np.array([[-2, -2, -2], [2, 2, 2]]) # Computed automatically in to_nerf()
50 | depth_scale = 1.0 / camera["scale"] # Since we multiply as gt_depth *= depth_scale, we need to invert camera["scale"]
51 |
52 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale)
53 |
54 | def parse_dataset(self):
55 | self.timestamps = []
56 | self.poses = []
57 | self.images = []
58 | self.depths = []
59 | self.calibs = []
60 |
61 | self.image_paths = sorted(glob.glob(f'{self.dataset_dir}/results/frame*.jpg'))
62 | self.depth_paths = sorted(glob.glob(f'{self.dataset_dir}/results/depth*.png'))
63 | self.poses = self.load_poses(f'{self.dataset_dir}/traj.txt')
64 | self.calib = self._get_cam_calib(f'{self.dataset_dir}/../cam_params.json')
65 |
66 | N = self.args.buffer
67 | H, W = self.calib.resolution.height, self.calib.resolution.width
68 |
69 | # Parse images and tfs
70 | for i, (image_path, depth_path) in enumerate(tqdm(zip(self.image_paths, self.depth_paths))):
71 | if i >= N:
72 | break
73 |
74 | # Parse rgb/depth images
75 | image = cv2.imread(image_path)
76 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
77 |
78 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # this is for NERF
79 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)[..., None] # H, W, C=1
80 |
81 | H, W, _ = depth.shape
82 | assert(H == image.shape[0])
83 | assert(W == image.shape[1])
84 | assert(3 == image.shape[2] or 4 == image.shape[2])
85 | assert(np.uint8 == image.dtype)
86 | assert(H == depth.shape[0])
87 | assert(W == depth.shape[1])
88 | assert(1 == depth.shape[2])
89 | assert(np.uint16 == depth.dtype)
90 |
91 | depth = depth.astype(np.int32) # converting to int32, because torch does not support uint16, and I don't want to lose precision
92 |
93 | self.timestamps += [i]
94 | self.images += [image]
95 | self.depths += [depth]
96 | self.calibs += [self.calib]
97 |
98 | self.poses = self.poses[:N]
99 |
100 | self.timestamps = np.array(self.timestamps)
101 | self.poses = np.array(self.poses)
102 | self.images = np.array(self.images)
103 | self.depths = np.array(self.depths)
104 | self.calibs = np.array(self.calibs)
105 |
106 | N = len(self.timestamps)
107 | assert(N == self.poses.shape[0])
108 | assert(N == self.images.shape[0])
109 | assert(N == self.depths.shape[0])
110 | assert(N == self.calibs.shape[0])
111 |
112 | def __len__(self):
113 | return len(self.poses)
114 |
115 | def __getitem__(self, k):
116 | return self.data_packets[k] if self.data_packets is not None else self._get_data_packet(k)
117 |
118 | def _get_data_packet(self, k0, k1=None):
119 | if k1 is None:
120 | k1 = k0 + 1
121 | else:
122 | assert(k1 >= k0)
123 | return {"k": np.arange(k0,k1),
124 | "t_cams": self.timestamps[k0:k1],
125 | "poses": self.poses[k0:k1],
126 | "images": self.images[k0:k1],
127 | "depths": self.depths[k0:k1],
128 | "calibs": self.calibs[k0:k1],
129 | "is_last_frame": (k0 >= self.__len__() - 1),
130 | }
131 |
132 | # Up to you how you index the dataset depending on your training procedure
133 | def _build_dataset_index(self):
134 | # Go through the stream and bundle as you wish
135 | self.data_packets = [data_packet for data_packet in self.stream()]
136 |
137 | def stream(self):
138 | for k in range(self.__len__()):
139 | yield self._get_data_packet(k)
140 |
141 | def to_nerf_format(self):
142 | print("Exporting Replica dataset to Nerf")
143 | OUT_PATH = "transforms.json"
144 | AABB_SCALE = 4
145 | out = {
146 | "fl_x": self.calib.camera_model.fx,
147 | "fl_y": self.calib.camera_model.fy,
148 | "k1": self.calib.distortion_model.k1,
149 | "k2": self.calib.distortion_model.k2,
150 | "p1": self.calib.distortion_model.p1,
151 | "p2": self.calib.distortion_model.p2,
152 | "cx": self.calib.camera_model.cx,
153 | "cy": self.calib.camera_model.cy,
154 | "w": self.calib.resolution.width,
155 | "h": self.calib.resolution.height,
156 | # TODO(Toni): calculate this automatically. Box that fits all cameras +2m
157 | "aabb": self.calib.aabb,
158 | "aabb_scale": AABB_SCALE,
159 | "integer_depth_scale": self.calib.depth_scale,
160 | "frames": [],
161 | }
162 |
163 | poses_t = []
164 | if self.data_packets is None:
165 | self._build_dataset_index()
166 | for data_packet in self.data_packets:
167 | # Image
168 | ic(data_packet["k"])
169 | color_path = self.image_paths[data_packet["k"][0]]
170 | depth_path = self.depth_paths[data_packet["k"][0]]
171 |
172 | relative_color_path = os.path.join("results", os.path.basename(color_path))
173 | relative_depth_path = os.path.join("results", os.path.basename(depth_path))
174 |
175 | # Transform
176 | w2c = data_packet["poses"][0]
177 | c2w = np.linalg.inv(w2c)
178 |
179 | # Convert from opencv convention to nerf convention
180 | c2w[0:3, 1] *= -1 # flip the y axis
181 | c2w[0:3, 2] *= -1 # flip the z axis
182 |
183 | poses_t += [w2c[:3,3].flatten()]
184 |
185 | frame = {"file_path": relative_color_path, # "sharpness": b,
186 | "depth_path": relative_depth_path,
187 | "transform_matrix": c2w.tolist()}
188 | out["frames"].append(frame)
189 |
190 | poses_t = np.array(poses_t)
191 | delta_t = 1.0 # 1 meter extra to allow for the depth of the camera
192 | t_max = np.amax(poses_t, 0).flatten()
193 | t_min = np.amin(poses_t, 0).flatten()
194 | out["aabb"] = np.array([t_min-delta_t, t_max+delta_t]).tolist()
195 |
196 | # Save the path to the ground-truth mesh as well
197 | out["gt_mesh"] = os.path.join("..", os.path.basename(self.dataset_dir)+"_mesh.ply")
198 | ic(out["gt_mesh"])
199 |
200 | with open(OUT_PATH, "w") as outfile:
201 | import json
202 | json.dump(out, outfile, indent=2)
203 |
--------------------------------------------------------------------------------
/datasets/tum_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import pandas as pd
4 |
5 | from datasets.dataset import *
6 |
7 | # TODO: didn't finish implementing this
8 | class TumDataset(Dataset):
9 | def __init__(self, args, device):
10 | super().__init__("Tum", args, device)
11 | self.dataset_dir = args.dataset_dir
12 |
13 | associations_file = os.path.join(self.dataset_dir, "associations.txt")
14 | self.associations = open(associations_file).readlines()
15 | assert self.associations is not None
16 |
17 | self.resize_images = False
18 |
19 | self.t0 = None
20 |
21 | self.final_k = self.final_k if self.final_k != -1.0 and self.final_k < len(
22 | self.associations) else len(self.associations)
23 |
24 | self.parse_dataset()
25 |
26 | def parse_dataset(self):
27 | ## Get Cam Calib
28 | self.cam0_calib : CameraCalibration = self._get_cam_calib()
29 | self.cam_calibs = [self.cam0_calib]
30 |
31 | ## Get ground-truth pose
32 | gt_dir = os.path.join(self.dataset_dir, 'groundtruth.txt')
33 | self.gt_df = pd.read_csv(gt_dir, sep=" ", comment="#",
34 | names=["timestamp", "tx", "ty", "tz", "qx", "qy", "qz", "qw"])
35 | self.gt_df.timestamp *= 10000
36 | self.gt_df.timestamp = self.gt_df['timestamp'].astype('int64')
37 | self.gt_df.set_index("timestamp", drop=False, inplace=True)
38 |
39 | def get_rgb(self, frame_id):
40 | return self._get_img(frame_id, 'rgb')
41 |
42 | def get_depth(self, frame_id):
43 | return self._get_img(frame_id, 'depth')
44 |
45 | def _get_img(self, frame_id, type):
46 | if frame_id >= self.final_k:
47 | return None, None
48 |
49 | row = self.associations[frame_id].strip().split()
50 | if type == 'rgb':
51 | img_file_name = row[1]
52 | elif type == 'depth':
53 | img_file_name = row[3]
54 | else:
55 | raise "Unknown img type"
56 |
57 | timestamp = float(row[0])
58 | img = cv2.imread(os.path.join(self.dataset_dir, img_file_name))
59 |
60 | assert img is not None
61 |
62 | return timestamp, img
63 |
64 | def _get_data_packet(self, k):
65 | # The img_filename has the timestamp of the image! At least for Euroc!
66 | t_rgb0, img_0 = self.get_rgb(k)
67 | t_depth0, depth_0 = self.get_depth(k)
68 | assert t_rgb0 == t_depth0
69 | t_cam0 = t_rgb0
70 |
71 | t_cams = [t_cam0]
72 | images = [img_0]
73 | depths = [depth_0]
74 |
75 | if self.viz:
76 | for i, (rgb, depth) in enumerate(zip(images, depths)):
77 | if rgb is not None:
78 | cv2.imshow(f"Img{i}", rgb)
79 | if depth is not None:
80 | cv2.imshow(f"Depth{i}", depth)
81 |
82 | # Resize images
83 | if self.resize_images:
84 | h0, w0, _ = images[0].shape
85 | output_image_size = [384, 512]
86 | total_output_pixels = (output_image_size[0] * output_image_size[1])
87 | h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0)))
88 | w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0)))
89 |
90 | for i in range(len(images)):
91 | images[i] = cv2.resize(images[i], (w1, h1))
92 | images[i] = images[i][:h1-h1%8, :w1-w1%8]
93 |
94 | # TODO: can we do this for depths??
95 | depths[i] = cv2.resize(depths[i], (w1, h1))
96 | depths[i] = depths[i][:h1-h1%8, :w1-w1%8]
97 |
98 | if self.viz:
99 | for i, (rgb, depth) in enumerate(zip(images, depths)):
100 | cv2.imshow(f"Img{i} Resized", rgb)
101 | cv2.imshow(f"Depth{i} Resized", depth)
102 |
103 | if self.viz:
104 | cv2.waitKey(1)
105 |
106 | t_cams = np.array(t_cams)
107 | images = np.array(images)
108 | assert len(images) == len(t_cams)
109 | assert len(depths) == len(t_cams)
110 |
111 | # Ground-truth
112 | #t1_gt_near = self.gt_df['timestamp'].sub(t1).abs().idxmin()
113 | #t0_gt_near = self.gt_df['timestamp'].sub(self.t0).abs().idxmin()
114 | t1 = t_cam0
115 | t1_near = self.gt_df.index.get_indexer([t1], method="nearest")[0]
116 | if self.t0 is not None:
117 | gt_t0_t1 = self.gt_df.iloc[self.t0:t1_near+1] # +1 to include t1
118 | else:
119 | gt_t0_t1 = self.gt_df.iloc[t1_near]
120 | self.t0 = t1_near
121 |
122 | return {"k": k, "t_cams": t_cams, "images": images,
123 | "cam_calibs": self.cam_calibs if not self.resize_images else self.cam_calibs_resized,
124 | "gt_t0_t1": gt_t0_t1,
125 | "is_last_frame": (k >= self.__len__() - 1)}
126 |
127 | def _get_cam_calib(self):
128 | # TODO: remove hardcoded
129 | body_T_cam0 = np.eye(4)
130 | rate_hz = 0.0
131 | width, height = 640, 480
132 | fx, fy, cx, cy = 535.4, 539.2, 320.1, 247.6
133 | k1, k2, p1, p2 = 0.0, 0.0, 0.0, 0.0
134 |
135 | resolution = Resolution(width, height)
136 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy)
137 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2)
138 |
139 | aabb = np.array([[0,0,0],[1,1,1]])
140 | depth_scale = 1.0
141 |
142 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale)
143 |
144 | # Up to you how you index the dataset depending on your training procedure
145 | def _build_dataset_index(self):
146 | # Go through the stream and bundle as you wish
147 | # Here we do the simplest scenario, send imu data between frames,
148 | # and the next frame as a packet
149 | self.data_packets = [data_packet for data_packet in self.stream()]
150 |
151 | def __getitem__(self, index):
152 | return self.data_packets[index] if self.data_packets is not None else self._get_data_packet(index)
153 |
154 | def __len__(self):
155 | return len(self.data_packets) if self.data_packets is not None else len(self.associations)
156 |
157 | # Return all data btw frames, plus the subsequent frame.
158 | def stream(self):
159 | for k in range(self.__len__()):
160 | yield self._get_data_packet(k)
161 |
162 |
--------------------------------------------------------------------------------
/droid.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/droid.pth
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/examples/slam_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import sys
5 | sys.settrace
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 |
8 | import torch
9 | from torch.multiprocessing import Process
10 |
11 | from datasets.data_module import DataModule
12 | from gui.gui_module import GuiModule
13 | from slam.slam_module import SlamModule
14 | from fusion.fusion_module import FusionModule
15 |
16 | from icecream import ic
17 |
18 | import argparse
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description="Instant-SLAM")
22 |
23 | # SLAM ARGS
24 | parser.add_argument("--parallel_run", action="store_true", help="Whether to run in parallel")
25 | parser.add_argument("--multi_gpu", action="store_true", help="Whether to run with multiple (two) GPUs")
26 | parser.add_argument("--initial_k", type=int, help="Initial frame to parse in the dataset", default=0)
27 | parser.add_argument("--final_k", type=int, help="Final frame to parse in the dataset, -1 is all.", default=-1)
28 | parser.add_argument("--img_stride", type=int, help="Number of frames to skip when parsing the dataset", default=1)
29 | parser.add_argument("--stereo", action="store_true", help="Use stereo images")
30 | parser.add_argument("--weights", default="droid.pth", help="Path to the weights file")
31 | parser.add_argument("--buffer", type=int, default=512, help="Number of keyframes to keep")
32 |
33 | parser.add_argument("--dataset_dir", type=str,
34 | help="Path to the dataset directory",
35 | default="/home/tonirv/Datasets/euroc/V1_01_easy")
36 | parser.add_argument('--dataset_name', type=str, default='euroc',
37 | choices=['euroc', 'nerf', 'replica', 'real'],
38 | help='Dataset format to use.')
39 |
40 | parser.add_argument("--mask_type", type=str, default='ours', choices=['no_depth', 'raw', 'ours', 'ours_w_thresh'])
41 |
42 | #parser.add_argument("--gui", action="store_true", help="Run the testbed GUI interactively.")
43 | parser.add_argument("--slam", action="store_true", help="Run SLAM.")
44 | parser.add_argument("--fusion", type=str, default='', choices=['tsdf', 'sigma', 'nerf', ''],
45 | help="Fusion approach ('' for none):\n\
46 | -`tsdf' classical tsdf-fusion using Open3D\n \
47 | -`sigma' tsdf-fusion with uncertainty values (Rosinol22wacv)\n \
48 | -`nerf' radiance field reconstruction using Instant-NGP.")
49 |
50 | # GUI ARGS
51 | parser.add_argument("--gui", action="store_true", help="Run O3D Gui, use when volume='tsdf'or'sigma'.")
52 | parser.add_argument("--width", "--screenshot_w", type=int, default=0, help="Resolution width of GUI and screenshots.")
53 | parser.add_argument("--height", "--screenshot_h", type=int, default=0, help="Resolution height of GUI and screenshots.")
54 |
55 | # NERF ARGS
56 | parser.add_argument("--network", default="", help="Path to the network config. Uses the scene's default if unspecified.")
57 |
58 | parser.add_argument("--eval", action="store_true", help="Evaluate method.")
59 |
60 | return parser.parse_args()
61 |
62 | def run(args):
63 | if args.parallel_run and args.multi_gpu:
64 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
65 | cpu = 'cpu'
66 | cuda_slam = 'cuda:0'
67 | cuda_fusion = 'cuda:1' # you can also try same device as in slam.
68 | else:
69 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
70 | cpu = 'cpu'
71 | cuda_slam = cuda_fusion = 'cuda:0'
72 | print(f"Running with GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")
73 |
74 | if not args.parallel_run:
75 | from queue import Queue
76 | else:
77 | from torch.multiprocessing import Queue
78 |
79 | # Create the Queue object
80 | data_for_viz_output_queue = Queue()
81 | data_for_fusion_output_queue = Queue()
82 | data_output_queue = Queue()
83 | slam_output_queue_for_fusion = Queue()
84 | slam_output_queue_for_o3d = Queue()
85 | fusion_output_queue_for_gui = Queue()
86 | gui_output_queue_for_fusion = Queue()
87 |
88 | # Create Dataset provider
89 | data_provider_module = DataModule(args.dataset_name, args, device=cpu)
90 |
91 | # Create MetaSLAM pipeline
92 | # The SLAM module takes care of creating the SLAM object itself, to avoid pickling issues
93 | # (see initialize_module() method inside)
94 | slam = args.slam
95 | if slam:
96 | slam_module = SlamModule("VioSLAM", args, device=cuda_slam)
97 | data_provider_module.register_output_queue(data_output_queue)
98 | slam_module.register_input_queue("data", data_output_queue)
99 |
100 | # Create Neural Volume
101 | fusion = args.fusion != ""
102 | if fusion:
103 | fusion_module = FusionModule(args.fusion, args, device=cuda_fusion)
104 | if slam:
105 | slam_module.register_output_queue(slam_output_queue_for_fusion)
106 | fusion_module.register_input_queue("slam", slam_output_queue_for_fusion)
107 |
108 | if (args.fusion == 'nerf' and not slam) or (args.fusion != 'nerf' and args.eval):
109 | # Only used for evaluation, or in case we do not use slam (for nerf)
110 | data_provider_module.register_output_queue(data_for_fusion_output_queue)
111 | fusion_module.register_input_queue("data", data_for_fusion_output_queue)
112 |
113 | # Create interactive Gui
114 | gui = args.gui and args.fusion != 'nerf' # nerf has its own gui
115 | if gui:
116 | gui_module = GuiModule("Open3DGui", args, device=cuda_slam) # don't use cuda:1, o3d doesn't work...
117 | data_provider_module.register_output_queue(data_for_viz_output_queue)
118 | slam_module.register_output_queue(slam_output_queue_for_o3d)
119 | gui_module.register_input_queue("data", data_for_viz_output_queue)
120 | gui_module.register_input_queue("slam", slam_output_queue_for_o3d)
121 | if fusion and (fusion_module.name == "tsdf" or fusion_module.name == "sigma"):
122 | fusion_module.register_output_queue(fusion_output_queue_for_gui)
123 | gui_module.register_input_queue("fusion", fusion_output_queue_for_gui)
124 | gui_module.register_output_queue(gui_output_queue_for_fusion)
125 | fusion_module.register_input_queue("gui", gui_output_queue_for_fusion)
126 |
127 | # Run
128 | if args.parallel_run:
129 | print("Running pipeline in parallel mode.")
130 |
131 | data_provider_thread = Process(target=data_provider_module.spin, args=())
132 | if fusion: fusion_thread = Process(target=fusion_module.spin) # FUSION NEEDS TO BE IN A PROCESS
133 | #if slam: slam_thread = Process(target=slam_module.spin, args=())
134 | if gui: gui_thread = Process(target=gui_module.spin, args=())
135 |
136 | data_provider_thread.start()
137 | if fusion: fusion_thread.start()
138 | #if slam: slam_thread.start()
139 | if gui: gui_thread.start()
140 |
141 | # Runs in main thread
142 | if slam:
143 | slam_module.spin() # visualizer should be the main spin, but pytorch has a memory bug/leak if threaded...
144 | slam_module.shutdown_module()
145 | ic("Deleting SLAM module to free memory")
146 | torch.cuda.empty_cache()
147 | # slam_module.slam. # add function to empty all matrices?
148 | del slam_module
149 | print("FINISHED RUNNING SLAM")
150 | while (fusion and fusion_thread.exitcode == None):
151 | continue
152 | print("FINISHED RUNNING FUSION")
153 | while (gui and not gui_module.shutdown):
154 | continue
155 | print("FINISHED RUNNING GUI")
156 |
157 | # This is not doing what you think, because Process has another module
158 | if gui: gui_module.shutdown_module()
159 | if fusion: fusion_module.shutdown_module()
160 | data_provider_module.shutdown_module()
161 |
162 | if gui: gui_thread.terminate() # violent, should be join()
163 | #if slam: slam_thread.terminate() # violent, should be join()
164 | if fusion: fusion_thread.terminate() # violent, should be join()
165 | data_provider_thread.terminate() # violent, should be a join(), but I don't know how to flush the queue
166 | else:
167 | print("Running pipeline in sequential mode.")
168 |
169 | # Initialize all modules first (and register 3D volume)
170 | if data_provider_module.spin() \
171 | and (not slam or slam_module.spin()) \
172 | and (not fusion or fusion_module.spin()):
173 | if gui:
174 | gui_module.spin()
175 | #gui_module.register_volume(fusion_module.fusion.volume)
176 |
177 | # Run sequential, dataprovider fills queue and gui empties it
178 | while data_provider_module.spin() \
179 | and (not slam or slam_module.spin()) \
180 | and (not fusion or fusion_module.spin()) \
181 | and (not gui or gui_module.spin()):
182 | continue
183 |
184 | # Then gui runs indefinitely until user closes window
185 | ok = True
186 | while ok:
187 | if gui: ok &= gui_module.spin()
188 | if fusion: ok &= fusion_module.spin()
189 |
190 | # Delete everything and clean memory
191 |
192 | if __name__ == '__main__':
193 | args = parse_args()
194 |
195 | torch.multiprocessing.set_start_method('spawn')
196 | torch.cuda.empty_cache()
197 | torch.backends.cudnn.benchmark = True
198 | torch.set_grad_enabled(False)
199 |
200 | run(args)
201 | print('Done...')
202 |
--------------------------------------------------------------------------------
/factor_graph/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/factor_graph/factor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import List, Union
5 |
6 | import torch as th
7 | from torch import nn
8 |
9 | from factor_graph.loss_function import LossFunction
10 | from factor_graph.key import Key
11 | from factor_graph.variables import Variable, Variables
12 |
13 | class Factor(ABC, nn.Module):
14 | def __init__(self, keys: List[Key], loss_function: LossFunction, device='cpu') -> None:
15 | super().__init__()
16 | self.keys = keys
17 | self.loss_function = loss_function
18 | self.device = device
19 |
20 | @abstractmethod
21 | def linearize(self, x0: Variables) -> th.Tensor:
22 |         raise NotImplementedError
23 |
24 | # Forward === error
25 |     def forward(self, x: Union[Variables, th.Tensor]) -> th.Tensor:
26 | return self.error(x)
27 |
28 | # f(x)
29 | @abstractmethod
30 |     def error(self, x: Union[Variables, th.Tensor]) -> th.Tensor:
31 | pass
32 |
33 | # w(e) where e = f(x),
34 | def weight(self, x: Variables) -> th.Tensor:
35 | return self.loss_function(self.error(x))
36 |
37 | def _batch_jacobian(self, f, x0: th.Tensor):
38 | f_sum = lambda x: th.sum(f(x), axis=0) # sum over all batches
39 | return th.autograd.functional.jacobian(f_sum, x0, create_graph=True).swapaxes(1, 0)
40 |
--------------------------------------------------------------------------------
/factor_graph/factor_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from typing import Tuple
4 |
5 | import torch as th
6 | from torch import nn
7 |
8 | from icecream import ic
9 |
10 | import colored_glog as log
11 |
12 | from factor_graph.variables import Variables
13 |
14 | from gtsam import NonlinearFactorGraph as FactorGraph
15 |
16 | class FactorGraphManager:
17 | def __init__(self) -> None:
18 | self.factor_graph = FactorGraph()
19 |
20 | def add(self, factor_graph):
21 | # Check we don't have repeated factors
22 |
23 | if factor_graph.empty():
24 | #log.warn("Attempted to add factors, but none were provided")
25 | return False
26 |
27 | # Remove None factors?
28 |
29 | # Add factors
30 | self.factor_graph.push_back(factor_graph)
31 |
32 | # Perhaps remove old factors
33 |
34 | # Update map from factor-id to slot
35 | return True
36 |
37 | def replace(self, key, factor):
38 | self.factor_graph.replace(key, factor)
39 |
40 | def remove(self, key):
41 | self.factor_graph.remove(key)
42 |
43 | def __iter__(self):
44 | return self.factor_graph.__iter__()
45 |
46 | def __getitem__(self, key):
47 | assert self.factor_graph.exists(key)
48 | return self.factor_graph.at(key)
49 |
50 | def __len__(self):
51 | # self.factor_graph.nrFactors()
52 | return self.factor_graph.size()
53 |
54 | def is_empty(self):
55 | return self.factor_graph.empty()
56 |
57 | def reset_factor_graph(self):
58 | self.factor_graph = FactorGraph()
59 |
60 | def get_factor_graph(self):
61 | return self.factor_graph
62 |
63 |
64 | class TorchFactorGraph(nn.Module):
65 | def __init__(self):
66 | super().__init__()
67 | self.factors = nn.ModuleList([])
68 | ic(self.factors)
69 | self.run_jit = False
70 |
71 | def add(self, factors):
72 | # Check we don't have repeated factors
73 |
74 | # Append factors
75 | self.factors.extend(factors)
76 |
77 | # Update map from factor-id to slot
78 | pass
79 |
80 | def remove(self, factor_ids):
81 | pass
82 |
83 | def __iter__(self):
84 | return self.factors.__iter__()
85 |
86 | def __getitem__(self, key):
87 | return self.factors.__getitem__(key)
88 |
89 | def __len__(self):
90 | return len(self.factors)
91 |
92 | def is_empty(self):
93 | return self.__len__() == 0
94 |
95 | def forward(self, x: Variables) -> th.Tensor:
96 | return self._forward_jit(x) if self.run_jit else self._forward(x)
97 |
98 | def _forward(self, x: Variables) -> th.Tensor:
99 | residuals = th.stack([factor(x) for factor in self.factors])
100 | weights = th.stack([factor.weight(x) for factor in self.factors])
101 | return th.sum(th.pow(residuals, 2) * weights, dim=0)
102 |
103 | # Not necessarily faster...
104 | def _forward_jit(self, x: Variables) -> th.Tensor:
105 | residuals_calc = [th.jit.fork(factor, x) for factor in self.factors]
106 | weights_calc = [th.jit.fork(factor.weight, x) for factor in self.factors]
107 | residuals = th.stack([th.jit.wait(thread) for thread in residuals_calc])
108 | weights = th.stack([th.jit.wait(thread) for thread in weights_calc])
109 | return th.sum(th.pow(residuals, 2) * weights, dim=0)
110 |
111 |
112 | # TODO parallelize
113 | # These variables are already ordered.
114 | def linearize(self, x0: Variables) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
115 | # Build the Jacobian matrix, here only for one f
116 |         # autograd.functional.jacobian returns a tensor where entry [i][j] is J_ij = d(output_i)/d(input_j),
117 |         # which lets us reason about the jacobian in terms of blocks:
118 |         # A[i spans the dimensions of the output f(x0)][j spans the dimensions of the input x0]
119 | # vectorize=True --> first dimension is considered batch dimension
120 | # A = torch.autograd.functional.jacobian(f, x0, vectorize=True, create_graph=True)
121 |
122 | # Here we could query how to linearize:
123 | ## a) Linearize using AutoDiff (current)
124 | ## b) Linearize using closed-form jacobian
125 | ## extra) Linearize itself + compress with Schur complement, and return reduced Hessian matrix?
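        # Shape sketch (illustrative; assumes a batch of B problems):
        #   each factor.linearize(x0) is expected to return
        #     A: jacobian block  [B, residual_dim, state_dim]
        #     b: residual        [B, residual_dim]
        #     w: robust weights  [B, residual_dim]
        #   and the loop below hstacks the per-factor blocks along the residual dimension.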
126 | AA = None; bb = None; ww = None
127 | # TODO Linearize factors in parallel...
128 | for factor in self.factors:
129 | if factor is None:
130 | log.warn("Got a None factor...")
131 | continue
132 | A, b, w = factor.linearize(x0)
133 | if AA is None:
134 | AA = A; bb = b; ww = w
135 | continue
136 | AA = th.hstack((AA, A))
137 | bb = th.hstack((bb, b))
138 | ww = th.hstack((ww, w))
139 | return AA, bb, ww
140 |
141 | # TODO parallelize
142 | def weight(self, x: Variables) -> th.Tensor:
143 | ww = None;
144 | for factor in self.factors:
145 | w = factor.weight(x)
146 | if ww is None:
147 | ww = w
148 | continue
149 | ww = th.hstack((ww, w))
150 | return ww
151 |
--------------------------------------------------------------------------------
/factor_graph/key.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | class Key():
4 | def __init__(self, name: str, idx: int):
5 | self.name = name
6 | self.idx = idx
7 |
8 |     # NOTE: `name` and `idx` are plain attributes; same-named accessor methods
9 |     # would be shadowed by the instance attributes set in __init__, so they are
10 |     # intentionally omitted.
13 |
14 | def __hash__(self):
15 | return hash((self.name, self.idx))
16 |
17 | def __eq__(self, other):
18 | return (self.name, self.idx) == (other.name, other.idx)
19 |
20 | def __ne__(self, other):
21 | # Not strictly necessary, but to avoid having both x==y and x!=y
22 | # True at the same time
23 | return not(self == other)
24 |
25 | def __repr__(self) -> str:
26 | return self.__str__()
27 |
28 | def __str__(self) -> str:
29 | return f"{self.name}_{self.idx}"
30 |
--------------------------------------------------------------------------------
/factor_graph/loss_function.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import torch as th
4 | from torch import nn
5 |
6 | class LossFunction(nn.Module):
7 | def __init__(self, device) -> None:
8 |         super().__init__()
10 | self.device = device
11 |
12 | class CauchyLossFunction(LossFunction):
13 | def __init__(self, device, k_initial: float = 1.0) -> None:
14 | super().__init__(device)
15 | self.k = th.nn.Parameter(k_initial * th.ones(1, device=self.device, requires_grad=True))
16 |
17 | # w(e) = w(f(x))
18 | # The loss function weights the residuals
19 | def forward(self, residuals: th.Tensor) -> th.Tensor:
20 | return 1.0 / (1.0 + th.pow(residuals / self.k, 2))
21 |
22 | class GMLossFunction(LossFunction):
23 | def __init__(self, device, k_initial: float = 1.0) -> None:
24 | super().__init__(device)
25 | self.k = th.nn.Parameter(k_initial * th.ones(1, device=self.device, requires_grad=True))
26 |
27 | # w(e) = w(f(x))
28 | # The loss function weights the residuals
29 | def forward(self, residuals: th.Tensor) -> th.Tensor:
30 | return th.pow(self.k, 2) / th.pow(self.k + th.pow(residuals, 2), 2) # Isn't this cauchy?
--------------------------------------------------------------------------------
/factor_graph/variables.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from typing import List, Dict
4 | import colored_glog as log
5 | import torch as th
6 |
7 | from factor_graph.key import Key
8 |
9 | class Variable:
10 | def __init__(self, key: Key, value: th.Tensor):
11 | self.key = key
12 | self.value = value
13 |
14 | def __repr__(self) -> str:
15 | return self.__str__()
16 |
17 | def __str__(self) -> str:
18 | return f"Variable(key={self.key}, value={self.value})"
19 |
20 | class Variables:
21 | def __init__(self):
22 | self.vars: Dict[Key, Variable] = {}
23 | pass
24 |
25 | def add(self, variable: Variable):
26 | self.vars[variable.key] = variable
27 |
28 | # The order of the keys is important
29 | def at(self, keys: List[Key]) -> th.Tensor:
30 | # Stack all variables in the order of the keys
31 | return th.hstack([self.vars[key].value for key in keys])
32 |
33 |     # Variables are stacked in the insertion order of self.vars
34 | def stack(self) -> th.Tensor:
35 | # Stack all variables in the order of the keys
36 | # TODO: super slow, we should keep an hstack perhaps
37 | return th.hstack(list(map(lambda x: x.value, self.vars.values())))
38 |
39 | # The order of the delta is important, it must match the order of the keys
40 | # in the Variables object self.vars
41 | # delta -> [B, N]
42 | # vars -> [B, N]
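    # Illustrative example (assuming each variable stores a [B, 1] value):
    #   delta[:, i] is the update for the i-th variable in insertion order, and
    #   retract() returns new Variables with value_i + delta[:, i:i+1].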
43 | def retract(self, delta: th.Tensor):
44 | log.check_eq(delta.shape[1], len(self.vars))
45 | # TODO
46 | # Retract variables in parallel, rather than sequentially
47 | # Use ordered dict, instead of retrieving in order
48 | x_new = {}
49 | for delta_i, key in zip(delta.t(), self.vars.keys()):
50 | x_new[key] = Variable(key, self.vars[key].value + delta_i.unsqueeze(0).t())
51 | return x_new
52 |
53 | def __repr__(self) -> str:
54 | return self.__str__()
55 |
56 | def __str__(self) -> str:
57 | return f"Variables(vars={self.vars})"
58 |
59 |
--------------------------------------------------------------------------------
/fusion/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/fusion/fusion_module.py:
--------------------------------------------------------------------------------
1 | from pipeline.pipeline_module import MIMOPipelineModule
2 |
3 | class FusionModule(MIMOPipelineModule):
4 | def __init__(self, name, args, device="cpu") -> None:
5 | super().__init__(name, args.parallel_run, args)
6 | self.device = device
7 |
8 | def spin_once(self, data_packet):
9 | output = self.fusion.fuse(data_packet)
10 |         # TODO: if the commented-out shutdown below is enabled we never reach the gui/fusion loop, but if it stays commented out the module never stops.
11 | if self.fusion.stop_condition():
12 | print("Stopping fusion module!")
13 | super().shutdown_module()
14 | #if not output:
15 | # super().shutdown_module()
16 | return output
17 |
18 | def initialize_module(self):
19 | self.set_cuda_device() # This needs to be done before importing NerfFusion or TsdfFusion
20 | if self.name == "tsdf" or self.name == "sigma":
21 | from fusion.tsdf_fusion import TsdfFusion
22 | self.fusion = TsdfFusion(self.name, self.args, self.device)
23 | elif self.name == "nerf":
24 | from fusion.nerf_fusion import NerfFusion
25 | self.fusion = NerfFusion(self.name, self.args, self.device)
26 | else:
27 | raise NotImplementedError
28 | return super().initialize_module()
29 |
30 | def get_input_packet(self):
31 | input = super().get_input_packet(timeout=0.0000000001) # don't block fusion waiting for input
32 | return input if input is not None else False # so that we keep running, and do not just stop in spin()
33 |
34 | def set_cuda_device(self):
35 | if self.device == "cpu":
36 | return
37 |
38 | import os
39 | if self.device == "cuda:0":
40 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
41 | elif self.device == "cuda:1":
42 | os.environ['CUDA_VISIBLE_DEVICES'] = "1"
43 | self.device = "cuda:0" # Since only 1 will be visible...
44 | else:
45 | raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/gui/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/gui/gui_module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from pipeline.pipeline_module import MIMOPipelineModule
4 |
5 | class GuiModule(MIMOPipelineModule):
6 | def __init__(self, name, args, device="cpu") -> None:
7 | super().__init__(name, args.parallel_run, args)
8 | self.device = device
9 |
10 | def spin_once(self, data_packet):
11 | # If the queue is empty, queue.get() will block until the queue has data
12 | output = self.gui.visualize(data_packet)
13 | #if not output:
14 | # super().shutdown_module()
15 | return output
16 |
17 | def initialize_module(self):
18 | if self.name == "Open3DGui":
19 | from gui.open3d_gui import Open3dGui
20 | self.gui = Open3dGui(self.args, self.device)
21 | elif self.name == "DearPyGui":
22 | raise NotImplementedError
23 | else:
24 | raise NotImplementedError
25 | self.gui.initialize()
26 | return super().initialize_module()
27 |
28 | def get_input_packet(self):
29 | # don't block rendering waiting for input
30 | input = super().get_input_packet(timeout=0.000000001)
31 | # so that we keep running, and do not just stop in spin()
32 | return input if input is not None else False
33 |
--------------------------------------------------------------------------------
/media/intro.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/intro.gif
--------------------------------------------------------------------------------
/media/mit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/mit.png
--------------------------------------------------------------------------------
/media/mrg_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/mrg_logo.png
--------------------------------------------------------------------------------
/media/sparklab_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/sparklab_logo.png
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/networks/droid_frontend.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lietorch
3 | import numpy as np
4 |
5 | from lietorch import SE3
6 | from networks.factor_graph import FactorGraph
7 |
8 |
9 | class DroidFrontend:
10 | def __init__(self, droid_net, video, args):
11 | self.video = video # has the frames, the features, and the context activations
12 | self.update_net = droid_net.update_net # The GRU update network
13 | self.graph = FactorGraph(video, droid_net.update_net, max_factors=48)
14 |
15 | # local optimization window
16 | self.t0 = 0
17 | self.t1 = 0
18 |
19 |         # frontend variables
20 | self.is_initialized = False
21 | self.count = 0
22 |
23 | self.max_age = 25
24 | self.iters1 = 4
25 | self.iters2 = 2
26 |
27 | self.warmup = args.warmup
28 | self.beta = args.beta
29 | self.frontend_nms = args.frontend_nms
30 | self.keyframe_thresh = args.keyframe_thresh
31 | self.frontend_window = args.frontend_window
32 | self.frontend_thresh = args.frontend_thresh
33 | self.frontend_radius = args.frontend_radius
34 |
35 | def __update(self):
36 | """ add edges, perform update """
37 |
38 | self.count += 1
39 | self.t1 += 1
40 |
41 | if self.graph.correlation_volumes is not None:
42 | self.graph.rm_factors(self.graph.age > self.max_age, store=True)
43 |
44 | self.graph.add_proximity_factors(self.t1-5, max(self.t1-self.frontend_window, 0),
45 | rad=self.frontend_radius, nms=self.frontend_nms, thresh=self.frontend_thresh, beta=self.beta, remove=True)
46 |
47 | self.video.disps[self.t1-1] = torch.where(self.video.disps_sens[self.t1-1] > 0,
48 | self.video.disps_sens[self.t1-1], self.video.disps[self.t1-1])
49 |
50 | for itr in range(self.iters1):
51 | self.graph.update(None, None, use_inactive=True)
52 |
53 | # set initial pose for next frame
54 | poses = SE3(self.video.poses)
55 | d = self.video.distance([self.t1-3], [self.t1-2], beta=self.beta, bidirectional=True)
56 |
57 | if d.item() < self.keyframe_thresh:
58 | self.graph.rm_keyframe(self.t1 - 2)
59 |
60 | with self.video.get_lock():
61 | self.video.counter.value -= 1
62 | self.t1 -= 1
63 |
64 | else:
65 | for itr in range(self.iters2):
66 | self.graph.update(None, None, use_inactive=True)
67 |
68 | # set pose for next iteration
69 | self.video.poses[self.t1] = self.video.poses[self.t1-1]
70 | self.video.disps[self.t1] = self.video.disps[self.t1-1].mean()
71 |
72 | # update visualization
73 | self.video.dirty[self.graph.ii.min():self.t1] = True
74 |
75 | # Do we really need this special initialization?
76 | def __initialize(self):
77 | """ initialize the SLAM system """
78 |
79 | self.t0 = 0
80 | self.t1 = self.video.counter.value
81 |
82 |         # Add edges between temporally adjacent frames within radius r
83 | self.graph.add_neighborhood_factors(self.t0, self.t1, r=3)
84 |
85 | for itr in range(8):
86 | self.graph.update(1, use_inactive=True)
87 |
88 | self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
89 |
90 | for itr in range(8):
91 | self.graph.update(1, use_inactive=True)
92 |
93 | # Set initial pose/depth for next iteration
94 | # self.video.normalize()
95 | self.video.poses[self.t1] = self.video.poses[self.t1-1].clone()
96 | self.video.disps[self.t1] = self.video.disps[self.t1-4:self.t1].mean() # Next depth is just the mean?
97 |
98 | # initialization complete
99 | self.is_initialized = True
100 | self.last_pose = self.video.poses[self.t1-1].clone()
101 | self.last_disp = self.video.disps[self.t1-1].clone()
102 | self.last_time = self.video.tstamp[self.t1-1].clone()
103 |
104 | # for visualization
105 | with self.video.get_lock():
106 | self.video.ready.value = 1
107 | self.video.dirty[:self.t1] = True
108 |
109 | self.graph.rm_factors(self.graph.ii < self.warmup-4, store=True)
110 |
111 | def __call__(self):
112 | """ main update """
113 |
114 | # do initialization
115 | if not self.is_initialized and self.video.counter.value == self.warmup:
116 | self.__initialize()
117 |
118 | # do update
119 | elif self.is_initialized and self.t1 < self.video.counter.value:
120 | self.__update()
121 |
122 |
123 |
--------------------------------------------------------------------------------
/networks/droid_net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | from networks.modules.extractor import BasicEncoder
8 | from networks.modules.corr import CorrBlock
9 | from networks.modules.gru import ConvGRU
10 | from networks.modules.clipping import GradientClip
11 |
12 | from lietorch import SE3
13 | from networks.geom.ba import BA
14 |
15 | import networks.geom.projective_ops as pops
16 | from networks.geom.graph_utils import graph_to_edge_list, keyframe_indicies
17 |
18 | from torch_scatter import scatter_mean
19 |
20 |
21 | def cvx_upsample(data, mask):
22 | """ upsample pixel-wise transformation field """
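    # Convex upsampling (as introduced in RAFT): each pixel of the 8x-upsampled
    # output is a softmax-weighted (convex) combination of its 3x3 coarse-grid
    # neighborhood, with the 9 weights predicted per fine pixel in `mask`.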
23 | batch, ht, wd, dim = data.shape
24 | data = data.permute(0, 3, 1, 2)
25 | mask = mask.view(batch, 1, 9, 8, 8, ht, wd)
26 | mask = torch.softmax(mask, dim=2)
27 |
28 | up_data = F.unfold(data, [3,3], padding=1)
29 | up_data = up_data.view(batch, dim, 9, 1, 1, ht, wd)
30 |
31 | up_data = torch.sum(mask * up_data, dim=2)
32 | up_data = up_data.permute(0, 4, 2, 5, 3, 1)
33 | up_data = up_data.reshape(batch, 8*ht, 8*wd, dim)
34 |
35 | return up_data
36 |
37 | def upsample_disp(disp, mask):
38 | batch, num, ht, wd = disp.shape
39 | disp = disp.view(batch*num, ht, wd, 1)
40 | mask = mask.view(batch*num, -1, ht, wd)
41 | return cvx_upsample(disp, mask).view(batch, num, 8*ht, 8*wd)
42 |
43 |
44 | class GraphAgg(nn.Module):
45 | def __init__(self):
46 | super(GraphAgg, self).__init__()
47 | self.conv1 = nn.Conv2d(128, 128, 3, padding=1)
48 | self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
49 | self.relu = nn.ReLU(inplace=True)
50 |
51 | self.eta = nn.Sequential(
52 | nn.Conv2d(128, 1, 3, padding=1),
53 | GradientClip(),
54 | nn.Softplus())
55 |
56 | self.upmask = nn.Sequential(
57 | nn.Conv2d(128, 8*8*9, 1, padding=0))
58 |
59 | def forward(self, net, ii):
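        # `ii` holds the source-frame index of every edge; scatter_mean below
        # averages the hidden states of all edges that share a source keyframe
        # before predicting the damping term eta and the upsampling mask.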
60 | batch, num, ch, ht, wd = net.shape
61 | net = net.view(batch*num, ch, ht, wd)
62 |
63 | _, ix = torch.unique(ii, return_inverse=True)
64 | net = self.relu(self.conv1(net))
65 |
66 | net = net.view(batch, num, 128, ht, wd)
67 | net = scatter_mean(net, ix, dim=1)
68 | net = net.view(-1, 128, ht, wd)
69 |
70 | net = self.relu(self.conv2(net))
71 |
72 | eta = self.eta(net).view(batch, -1, ht, wd)
73 | upmask = self.upmask(net).view(batch, -1, 8*8*9, ht, wd)
74 |
75 | return .01 * eta, upmask
76 |
77 |
78 | class UpdateModule(nn.Module):
79 | def __init__(self):
80 | super(UpdateModule, self).__init__()
81 | cor_planes = 4 * (2*3 + 1)**2
82 |
83 | self.corr_encoder = nn.Sequential(
84 | nn.Conv2d(cor_planes, 128, 1, padding=0),
85 | nn.ReLU(inplace=True),
86 | nn.Conv2d(128, 128, 3, padding=1),
87 | nn.ReLU(inplace=True))
88 |
89 | self.flow_encoder = nn.Sequential(
90 | nn.Conv2d(4, 128, 7, padding=3),
91 | nn.ReLU(inplace=True),
92 | nn.Conv2d(128, 64, 3, padding=1),
93 | nn.ReLU(inplace=True))
94 |
95 | self.weight = nn.Sequential(
96 | nn.Conv2d(128, 128, 3, padding=1),
97 | nn.ReLU(inplace=True),
98 | nn.Conv2d(128, 2, 3, padding=1),
99 | GradientClip(),
100 | nn.Sigmoid())
101 |
102 | self.delta = nn.Sequential(
103 | nn.Conv2d(128, 128, 3, padding=1),
104 | nn.ReLU(inplace=True),
105 | nn.Conv2d(128, 2, 3, padding=1),
106 | GradientClip())
107 |
108 | # Inputs to gru are:
109 | # i) hidden state (128)
110 | # ii) input state (128+128+64)
111 | # a) context_encoder_output(128),
112 | # b) corr_encoder_output (128),
113 | # c) flow_encoder_output (64),
114 |         # Input planes: 128=constant_context, 128=correlation, 64=motion (the 64=sigma planes are not part of the GRU input below)
115 | self.gru = ConvGRU(128, 128+128+64)
116 | self.agg = GraphAgg()
117 |
118 | def forward(self, net, inp, corr, flow=None, ii=None, jj=None):
119 | """ RaftSLAM update operator """
120 |
121 | batch, num, ch, ht, wd = net.shape
122 |
123 | if flow is None:
124 | # Initialize flow to zero... (TODO: better initialization using IMU)
125 | flow = torch.zeros(batch, num, 4, ht, wd, device=net.device)
126 |
127 | output_dim = (batch, num, -1, ht, wd)
128 | net = net.view(batch*num, -1, ht, wd)
129 | inp = inp.view(batch*num, -1, ht, wd)
130 | corr = corr.view(batch*num, -1, ht, wd)
131 | flow = flow.view(batch*num, -1, ht, wd)
132 |
133 | corr = self.corr_encoder(corr)
134 | flow = self.flow_encoder(flow)
135 | net = self.gru(net, inp, corr, flow)
136 |
137 | ### update variables ###
138 | delta = self.delta(net).view(*output_dim)
139 | weight = self.weight(net).view(*output_dim)
140 |
141 | delta = delta.permute(0,1,3,4,2)[...,:2].contiguous()
142 | weight = weight.permute(0,1,3,4,2)[...,:2].contiguous()
143 |
144 | net = net.view(*output_dim)
145 |
146 | if ii is not None:
147 | eta, upmask = self.agg(net, ii.to(net.device))
148 | return net, delta, weight, eta, upmask
149 | else:
150 | return net, delta, weight
151 |
152 |
153 | class DroidNet(nn.Module):
154 | def __init__(self):
155 | super(DroidNet, self).__init__()
156 | self.feature_net = BasicEncoder(output_dim=128, norm_fn='instance')
157 | self.context_net = BasicEncoder(output_dim=256, norm_fn='none')
158 | self.update_net = UpdateModule()
159 |
160 |
161 | #### EVERYTHING BELOW IS ONLY USED FOR TRAINING NOT INFERENCE ####
162 | def extract_features(self, images):
163 | """ run feature extraction networks """
164 |
165 | # normalize images
166 |         b, n, c1, h1, w1 = images.shape
167 | images = images[:, :, [2,1,0]] / 255.0
168 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
169 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
170 | images = images.sub_(mean[:, None, None]).div_(std[:, None, None])
171 |
172 | feature_maps = self.feature_net(images)
173 | context_maps = self.context_net(images)
174 |
175 | context_maps, gru_input_maps = context_maps.split([128,128], dim=2)
176 | context_maps = torch.tanh(context_maps)
177 | gru_input_maps = torch.relu(gru_input_maps)
178 | return feature_maps, context_maps, gru_input_maps
179 |
180 |
181 | def forward(self, Gs, images, disps, intrinsics, graph=None, num_steps=12, fixedp=2):
182 | """ Estimates SE3 or Sim3 between pair of frames """
183 |
184 | u = keyframe_indicies(graph)
185 | ii, jj, kk = graph_to_edge_list(graph)
186 |
187 | ii = ii.to(device=images.device, dtype=torch.long)
188 | jj = jj.to(device=images.device, dtype=torch.long)
189 |
190 | feature_maps, context_maps, gru_input_maps = self.extract_features(images)
191 | context_maps, gru_input_maps = context_maps[:,ii], gru_input_maps[:,ii]
192 | corr_fn = CorrBlock(feature_maps[:,ii], feature_maps[:,jj], num_levels=4, radius=3)
193 |
194 | ht, wd = images.shape[-2:]
195 | coords0 = pops.coords_grid(ht//8, wd//8, device=images.device)
196 |
197 |         coords1, _, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
198 | target = coords1.clone()
199 |
200 | Gs_list, disp_list, residual_list = [], [], []
201 | for _ in range(num_steps):
202 | Gs = Gs.detach()
203 | disps = disps.detach()
204 | coords1 = coords1.detach()
205 | target = target.detach()
206 |
207 | # extract motion features
208 | corr = corr_fn(coords1)
209 | residual = target - coords1
210 | flow = coords1 - coords0
211 |
212 | motion = torch.cat([flow, residual], dim=-1)
213 | motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0)
214 |
215 | context_maps, delta, weight, eta, upmask = \
216 | self.update_net(context_maps, gru_input_maps, corr, motion, ii, jj)
217 |
218 | target = coords1 + delta
219 |
220 | for _ in range(2):
221 | Gs, disps = BA(target, weight, eta, Gs, disps, intrinsics, ii, jj, fixedp=2)
222 |
223 |             coords1, valid_mask, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
224 | residual = (target - coords1)
225 |
226 | Gs_list.append(Gs)
227 | disp_list.append(upsample_disp(disps, upmask))
228 | residual_list.append(valid_mask * residual)
229 |
230 | return Gs_list, disp_list, residual_list
--------------------------------------------------------------------------------
/networks/geom/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/networks/geom/ba.py:
--------------------------------------------------------------------------------
1 | import lietorch
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from .chol import block_solve, schur_solve
6 | from . import projective_ops as pops
7 |
8 | from torch_scatter import scatter_sum
9 |
10 |
11 | # utility functions for scattering ops
12 | def safe_scatter_add_mat(A, ii, jj, n, m):
13 | v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
14 | return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m)
15 |
16 | def safe_scatter_add_vec(b, ii, n):
17 | v = (ii >= 0) & (ii < n)
18 | return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n)
19 |
20 | # apply retraction operator to inv-depth maps
21 | def disp_retr(disps, dz, ii):
22 | ii = ii.to(device=dz.device)
23 | return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1])
24 |
25 | # apply retraction operator to poses
26 | def pose_retr(poses, dx, ii):
27 | ii = ii.to(device=dx.device)
28 | return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1]))
29 |
30 |
31 | def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1):
32 | """ Full Bundle Adjustment """
33 |
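    # One damped Gauss-Newton step over keyframe poses and per-pixel inverse depths:
    # linearize the reprojection residuals (Ji, Jj wrt the two poses of each edge,
    # Jz wrt inverse depth), scatter the per-edge blocks into the global system,
    # eliminate the depth updates with a Schur complement (schur_solve), and retract
    # the pose update on the SE3 manifold while adding the depth update elementwise.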
34 | B, P, ht, wd = disps.shape
35 | N = ii.shape[0]
36 | D = poses.manifold_dim
37 |
38 |     ### 1: compute jacobians and residuals ###
39 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
40 | poses, disps, intrinsics, ii, jj, jacobian=True)
41 |
42 | r = (target - coords).view(B, N, -1, 1)
43 | w = .001 * (valid * weight).view(B, N, -1, 1)
44 |
45 | ### 2: construct linear system ###
46 | Ji = Ji.reshape(B, N, -1, D)
47 | Jj = Jj.reshape(B, N, -1, D)
48 | wJiT = (w * Ji).transpose(2,3)
49 | wJjT = (w * Jj).transpose(2,3)
50 |
51 | Jz = Jz.reshape(B, N, ht*wd, -1)
52 |
53 | Hii = torch.matmul(wJiT, Ji)
54 | Hij = torch.matmul(wJiT, Jj)
55 | Hji = torch.matmul(wJjT, Ji)
56 | Hjj = torch.matmul(wJjT, Jj)
57 |
58 | vi = torch.matmul(wJiT, r).squeeze(-1)
59 | vj = torch.matmul(wJjT, r).squeeze(-1)
60 |
61 | Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1)
62 | Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1)
63 |
64 | w = w.view(B, N, ht*wd, -1)
65 | r = r.view(B, N, ht*wd, -1)
66 | wk = torch.sum(w*r*Jz, dim=-1)
67 | Ck = torch.sum(w*Jz*Jz, dim=-1)
68 |
69 | kx, kk = torch.unique(ii, return_inverse=True)
70 | M = kx.shape[0]
71 |
72 | # only optimize keyframe poses
73 | P = P // rig - fixedp
74 | ii = ii // rig - fixedp
75 | jj = jj // rig - fixedp
76 |
77 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \
78 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \
79 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \
80 | safe_scatter_add_mat(Hjj, jj, jj, P, P)
81 |
82 | E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \
83 | safe_scatter_add_mat(Ej, jj, kk, P, M)
84 |
85 | v = safe_scatter_add_vec(vi, ii, P) + \
86 | safe_scatter_add_vec(vj, jj, P)
87 |
88 | C = safe_scatter_add_vec(Ck, kk, M)
89 | w = safe_scatter_add_vec(wk, kk, M)
90 |
91 | C = C + eta.view(*C.shape) + 1e-7
92 |
93 | H = H.view(B, P, P, D, D)
94 | E = E.view(B, P, M, D, ht*wd)
95 |
96 | ### 3: solve the system ###
97 | dx, dz = schur_solve(H, E, C, v, w)
98 |
99 | ### 4: apply retraction ###
100 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp)
101 | disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx)
102 |
103 | disps = torch.where(disps > 10, torch.zeros_like(disps), disps)
104 | disps = disps.clamp(min=0.0)
105 |
106 | return poses, disps
107 |
108 |
109 | def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1):
110 | """ Motion only bundle adjustment """
111 |
112 | B, P, ht, wd = disps.shape
113 | N = ii.shape[0]
114 | D = poses.manifold_dim
115 |
116 |     ### 1: compute jacobians and residuals ###
117 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
118 | poses, disps, intrinsics, ii, jj, jacobian=True)
119 |
120 | r = (target - coords).view(B, N, -1, 1)
121 | w = .001 * (valid * weight).view(B, N, -1, 1)
122 |
123 | ### 2: construct linear system ###
124 | Ji = Ji.reshape(B, N, -1, D)
125 | Jj = Jj.reshape(B, N, -1, D)
126 | wJiT = (w * Ji).transpose(2,3)
127 | wJjT = (w * Jj).transpose(2,3)
128 |
129 | Hii = torch.matmul(wJiT, Ji)
130 | Hij = torch.matmul(wJiT, Jj)
131 | Hji = torch.matmul(wJjT, Ji)
132 | Hjj = torch.matmul(wJjT, Jj)
133 |
134 | vi = torch.matmul(wJiT, r).squeeze(-1)
135 | vj = torch.matmul(wJjT, r).squeeze(-1)
136 |
137 | # only optimize keyframe poses
138 | P = P // rig - fixedp
139 | ii = ii // rig - fixedp
140 | jj = jj // rig - fixedp
141 |
142 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \
143 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \
144 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \
145 | safe_scatter_add_mat(Hjj, jj, jj, P, P)
146 |
147 | v = safe_scatter_add_vec(vi, ii, P) + \
148 | safe_scatter_add_vec(vj, jj, P)
149 |
150 | H = H.view(B, P, P, D, D)
151 |
152 | ### 3: solve the system ###
153 | dx = block_solve(H, v)
154 |
155 | ### 4: apply retraction ###
156 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp)
157 | return poses
158 |
159 |
--------------------------------------------------------------------------------
/networks/geom/chol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from . import projective_ops as pops
4 |
5 | class CholeskySolver(torch.autograd.Function):
6 | @staticmethod
7 | def forward(ctx, H, b):
8 | # don't crash training if cholesky decomp fails
9 | try:
10 | U = torch.linalg.cholesky(H)
11 | xs = torch.cholesky_solve(b, U)
12 | ctx.save_for_backward(U, xs)
13 | ctx.failed = False
14 | except Exception as e:
15 | print(e)
16 | ctx.failed = True
17 | xs = torch.zeros_like(b)
18 |
19 | return xs
20 |
21 | @staticmethod
22 | def backward(ctx, grad_x):
23 | if ctx.failed:
24 | return None, None
25 |
26 | U, xs = ctx.saved_tensors
27 | dz = torch.cholesky_solve(grad_x, U)
28 | dH = -torch.matmul(xs, dz.transpose(-1,-2))
29 |
30 | return dH, dz
31 |
32 | def block_solve(H, b, ep=0.1, lm=0.0001):
33 | """ solve normal equations """
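    # ep/lm add a Levenberg-Marquardt-style damping term to the diagonal of every
    # DxD block before the [B, N, N, D, D] block matrix is flattened into a dense
    # [B, N*D, N*D] system and solved by Cholesky factorization.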
34 | B, N, _, D, _ = H.shape
35 | I = torch.eye(D).to(H.device)
36 | H = H + (ep + lm*H) * I
37 |
38 | H = H.permute(0,1,3,2,4)
39 | H = H.reshape(B, N*D, N*D)
40 | b = b.reshape(B, N*D, 1)
41 |
42 | x = CholeskySolver.apply(H,b)
43 | return x.reshape(B, N, D)
44 |
45 |
46 | def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False):
47 |     """ solve using Schur complement """
48 |
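    # Block system being solved (dx: pose updates, dz: inverse-depth updates):
    #   [ H   E ] [dx]   [v]
    #   [ E^T C ] [dz] = [w]        with C diagonal over pixels,
    # reduced via the Schur complement S = H - E C^-1 E^T:
    #   S dx = v - E C^-1 w,   then   dz = C^-1 (w - E^T dx).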
49 | B, P, M, D, HW = E.shape
50 | H = H.permute(0,1,3,2,4).reshape(B, P*D, P*D)
51 | E = E.permute(0,1,3,2,4).reshape(B, P*D, M*HW)
52 | Q = (1.0 / C).view(B, M*HW, 1)
53 |
54 | # damping
55 | I = torch.eye(P*D).to(H.device)
56 | H = H + (ep + lm*H) * I
57 |
58 | v = v.reshape(B, P*D, 1)
59 | w = w.reshape(B, M*HW, 1)
60 |
61 | Et = E.transpose(1,2)
62 | S = H - torch.matmul(E, Q*Et)
63 | v = v - torch.matmul(E, Q*w)
64 |
65 | dx = CholeskySolver.apply(S, v)
66 | if sless:
67 | return dx.reshape(B, P, D)
68 |
69 | dz = Q * (w - Et @ dx)
70 | dx = dx.reshape(B, P, D)
71 | dz = dz.reshape(B, M, HW)
72 |
73 | return dx, dz
74 |
--------------------------------------------------------------------------------
/networks/geom/graph_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | import lietorch
7 | from .rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2
8 |
9 | def graph_to_edge_list(graph):
10 | ii, jj, kk = [], [], []
11 | for s, u in enumerate(graph):
12 | for v in graph[u]:
13 | ii.append(u)
14 | jj.append(v)
15 | kk.append(s)
16 |
17 | ii = torch.as_tensor(ii)
18 | jj = torch.as_tensor(jj)
19 | kk = torch.as_tensor(kk)
20 | return ii, jj, kk
21 |
22 | def keyframe_indicies(graph):
23 | return torch.as_tensor([u for u in graph])
24 |
25 | def meshgrid(m, n, device='cuda'):
26 | ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n))
27 | return ii.reshape(-1).to(device), jj.reshape(-1).to(device)
28 |
29 | def neighbourhood_graph(n, r):
30 | ii, jj = meshgrid(n, n)
31 | d = (ii - jj).abs()
32 | keep = (d >= 1) & (d <= r)
33 | return ii[keep], jj[keep]
34 |
35 |
36 | def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
37 | """ construct a frame graph between co-visible frames """
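    # Greedy construction: every frame is first connected to its temporal neighbours
    # within radius r; the remaining candidate pairs are then added in order of
    # increasing flow distance until `num` edges exist or the next candidate exceeds
    # `thresh` (mean induced-flow magnitude in pixels).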
38 | N = poses.shape[1]
39 | poses = poses[0].cpu().numpy()
40 | disps = disps[0][:,3::8,3::8].cpu().numpy()
41 | intrinsics = intrinsics[0].cpu().numpy() / 8.0
42 | d = compute_distance_matrix_flow(poses, disps, intrinsics)
43 |
44 | count = 0
45 | graph = OrderedDict()
46 |
47 | for i in range(N):
48 | graph[i] = []
49 | d[i,i] = np.inf
50 | for j in range(i-r, i+r+1):
51 | if 0 <= j < N and i != j:
52 | graph[i].append(j)
53 | d[i,j] = np.inf
54 | count += 1
55 |
56 | while count < num:
57 | ix = np.argmin(d)
58 | i, j = ix // N, ix % N
59 |
60 | if d[i,j] < thresh:
61 | graph[i].append(j)
62 | d[i,j] = np.inf
63 | count += 1
64 | else:
65 | break
66 |
67 | return graph
68 |
69 |
70 |
71 | def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
72 | """ construct a frame graph between co-visible frames """
73 | N = poses.shape[1]
74 | # poses = poses[0].cpu().numpy()
75 | # disps = disps[0].cpu().numpy()
76 | # intrinsics = intrinsics[0].cpu().numpy()
77 | d = compute_distance_matrix_flow2(poses, disps, intrinsics)
78 |
79 | # import matplotlib.pyplot as plt
80 | # plt.imshow(d)
81 | # plt.show()
82 |
83 | count = 0
84 | graph = OrderedDict()
85 |
86 | for i in range(N):
87 | graph[i] = []
88 | d[i,i] = np.inf
89 | for j in range(i-r, i+r+1):
90 | if 0 <= j < N and i != j:
91 | graph[i].append(j)
92 | d[i,j] = np.inf
93 | count += 1
94 |
95 | while 1:
96 | ix = np.argmin(d)
97 | i, j = ix // N, ix % N
98 |
99 | if d[i,j] < thresh:
100 | graph[i].append(j)
101 |
102 | for i1 in range(i-1, i+2):
103 | for j1 in range(j-1, j+2):
104 | if 0 <= i1 < N and 0 <= j1 < N:
105 | d[i1, j1] = np.inf
106 |
107 | count += 1
108 | else:
109 | break
110 |
111 | return graph
112 |
113 |
--------------------------------------------------------------------------------
/networks/geom/losses.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import numpy as np
3 | import torch
4 | from lietorch import SO3, SE3, Sim3
5 | from .graph_utils import graph_to_edge_list
6 | from .projective_ops import projective_transform
7 |
8 |
9 | def pose_metrics(dE):
10 | """ Translation/Rotation/Scaling metrics from Sim3 """
11 | t, q, s = dE.data.split([3, 4, 1], -1)
12 | ang = SO3(q).log().norm(dim=-1)
13 |
14 | # convert radians to degrees
15 | r_err = (180 / np.pi) * ang
16 | t_err = t.norm(dim=-1)
17 | s_err = (s - 1.0).abs()
18 | return r_err, t_err, s_err
19 |
20 |
21 | def fit_scale(Ps, Gs):
22 | b = Ps.shape[0]
23 | t1 = Ps.data[...,:3].detach().reshape(b, -1)
24 | t2 = Gs.data[...,:3].detach().reshape(b, -1)
25 |
26 | s = (t1*t2).sum(-1) / ((t2*t2).sum(-1) + 1e-8)
27 | return s
28 |
29 |
30 | def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True):
31 | """ Loss function for training network """
32 |
33 | # relative pose
34 | ii, jj, kk = graph_to_edge_list(graph)
35 | dP = Ps[:,jj] * Ps[:,ii].inv()
36 |
37 | n = len(Gs)
38 | geodesic_loss = 0.0
39 |
40 | for i in range(n):
41 | w = gamma ** (n - i - 1)
42 | dG = Gs[i][:,jj] * Gs[i][:,ii].inv()
43 |
44 | if do_scale:
45 | s = fit_scale(dP, dG)
46 | dG = dG.scale(s[:,None])
47 |
48 | # pose error
49 | d = (dG * dP.inv()).log()
50 |
51 | if isinstance(dG, SE3):
52 | tau, phi = d.split([3,3], dim=-1)
53 | geodesic_loss += w * (
54 | tau.norm(dim=-1).mean() +
55 | phi.norm(dim=-1).mean())
56 |
57 | elif isinstance(dG, Sim3):
58 | tau, phi, sig = d.split([3,3,1], dim=-1)
59 | geodesic_loss += w * (
60 | tau.norm(dim=-1).mean() +
61 | phi.norm(dim=-1).mean() +
62 | 0.05 * sig.norm(dim=-1).mean())
63 |
64 | dE = Sim3(dG * dP.inv()).detach()
65 | r_err, t_err, s_err = pose_metrics(dE)
66 |
67 | metrics = {
68 | 'rot_error': r_err.mean().item(),
69 | 'tr_error': t_err.mean().item(),
70 | 'bad_rot': (r_err < .1).float().mean().item(),
71 | 'bad_tr': (t_err < .01).float().mean().item(),
72 | }
73 |
74 | return geodesic_loss, metrics
75 |
76 |
77 | def residual_loss(residuals, gamma=0.9):
78 | """ loss on system residuals """
79 | residual_loss = 0.0
80 | n = len(residuals)
81 |
82 | for i in range(n):
83 | w = gamma ** (n - i - 1)
84 | residual_loss += w * residuals[i].abs().mean()
85 |
86 | return residual_loss, {'residual': residual_loss.item()}
87 |
88 |
89 | def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9):
90 | """ optical flow loss """
91 |
92 | N = Ps.shape[1]
93 | graph = OrderedDict()
94 | for i in range(N):
95 | graph[i] = [j for j in range(N) if abs(i-j)==1]
96 |
97 | ii, jj, kk = graph_to_edge_list(graph)
98 |     coords0, val0, _ = projective_transform(Ps, disps, intrinsics, ii, jj)
99 | val0 = val0 * (disps[:,ii] > 0).float().unsqueeze(dim=-1)
100 |
101 | n = len(poses_est)
102 | flow_loss = 0.0
103 |
104 | for i in range(n):
105 | w = gamma ** (n - i - 1)
106 |         coords1, val1, _ = projective_transform(poses_est[i], disps_est[i], intrinsics, ii, jj)
107 |
108 | v = (val0 * val1).squeeze(dim=-1)
109 | epe = v * (coords1 - coords0).norm(dim=-1)
110 | flow_loss += w * epe.mean()
111 |
112 | epe = epe.reshape(-1)[v.reshape(-1) > 0.5]
113 | metrics = {
114 | 'f_error': epe.mean().item(),
115 | '1px': (epe<1.0).float().mean().item(),
116 | }
117 |
118 | return flow_loss, metrics
119 |
--------------------------------------------------------------------------------
/networks/geom/projective_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from lietorch import SE3, Sim3
5 |
6 | from icecream import ic
7 |
8 | MIN_DEPTH = 0.2 # ?????
9 |
10 | def extract_intrinsics(intrinsics):
11 | return intrinsics[...,None,None,:].unbind(dim=-1)
12 |
13 | def coords_grid(ht, wd, **kwargs):
14 | y, x = torch.meshgrid(
15 | torch.arange(ht).to(**kwargs).float(),
16 | torch.arange(wd).to(**kwargs).float())
17 |
18 | return torch.stack([x, y], dim=-1)
19 |
20 | def iproj(disps, intrinsics, jacobian=False):
21 | """ pinhole camera inverse projection """
22 | ht, wd = disps.shape[2:]
23 | fx, fy, cx, cy = extract_intrinsics(intrinsics)
24 |
25 | y, x = torch.meshgrid(
26 | torch.arange(ht).to(disps.device).float(),
27 | torch.arange(wd).to(disps.device).float())
28 |
29 | i = torch.ones_like(disps)
30 | X = (x - cx) / fx
31 | Y = (y - cy) / fy
32 | pts = torch.stack([X, Y, i, disps], dim=-1)
33 |
34 | if jacobian:
35 | J = torch.zeros_like(pts)
36 | J[...,-1] = 1.0
37 | return pts, J
38 |
39 | return pts, None
40 |
41 | def proj(Xs, intrinsics, jacobian=False, return_depth=False):
42 | """ pinhole camera projection """
43 | fx, fy, cx, cy = extract_intrinsics(intrinsics)
44 | X, Y, Z, D = Xs.unbind(dim=-1)
45 |
46 | Z = torch.where(Z < 0.5*MIN_DEPTH, torch.ones_like(Z), Z)
47 | d = 1.0 / Z
48 |
49 | x = fx * (X * d) + cx
50 | y = fy * (Y * d) + cy
51 | if return_depth:
52 | coords = torch.stack([x, y, D*d], dim=-1)
53 | else:
54 | coords = torch.stack([x, y], dim=-1)
55 |
56 | if jacobian:
57 | B, N, H, W = d.shape
58 | o = torch.zeros_like(d)
59 | proj_jac = torch.stack([
60 | fx*d, o, -fx*X*d*d, o,
61 | o, fy*d, -fy*Y*d*d, o,
62 | # o, o, -D*d*d, d,
63 | ], dim=-1).view(B, N, H, W, 2, 4)
64 |
65 | return coords, proj_jac
66 |
67 | return coords, None
68 |
69 | def actp(Gij, X0, jacobian=False):
70 | """ action on point cloud """
71 | X1 = Gij[:,:,None,None] * X0
72 |
73 | if jacobian:
74 | X, Y, Z, d = X1.unbind(dim=-1)
75 | o = torch.zeros_like(d)
76 | B, N, H, W = d.shape
77 |
78 | if isinstance(Gij, SE3):
79 | Ja = torch.stack([
80 | d, o, o, o, Z, -Y,
81 | o, d, o, -Z, o, X,
82 | o, o, d, Y, -X, o,
83 | o, o, o, o, o, o,
84 | ], dim=-1).view(B, N, H, W, 4, 6)
85 |
86 | elif isinstance(Gij, Sim3):
87 | Ja = torch.stack([
88 | d, o, o, o, Z, -Y, X,
89 | o, d, o, -Z, o, X, Y,
90 | o, o, d, Y, -X, o, Z,
91 | o, o, o, o, o, o, o
92 | ], dim=-1).view(B, N, H, W, 4, 7)
93 |
94 | return X1, Ja
95 |
96 | return X1, None
97 |
98 | def projective_transform(poses, depths, intrinsics, ii, jj,
99 | cam_T_body=None, # aka extrinsics
100 | stereo_extrinsics=[-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
101 | jacobian=False, return_depth=False):
102 | """ map points from ii->jj """
103 |
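    # Overall composition: x1 = proj( Gij * iproj(x_i, d_i) ) with Gij = poses[jj] * poses[ii]^-1,
    # i.e. back-project the pixels of frame ii using their inverse depths, move them
    # by the relative pose, and re-project them into frame jj.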
104 | # inverse project (pinhole)
105 | X0, Jz = iproj(depths[:,ii], intrinsics[:,ii], jacobian=jacobian)
106 |
107 | # transform
108 | Gij = poses[:,jj] * poses[:,ii].inv()
109 |
110 | Gij.data[:,ii==jj] = torch.as_tensor(stereo_extrinsics, device=poses.device)
111 | X1, Ja = actp(Gij, X0, jacobian=jacobian)
112 |
113 | # project (pinhole)
114 | x1, Jp = proj(X1, intrinsics[:,jj], jacobian=jacobian, return_depth=return_depth)
115 |
116 | # exclude points too close to camera
117 | valid = ((X1[...,2] > MIN_DEPTH) & (X0[...,2] > MIN_DEPTH)).float()
118 | valid = valid.unsqueeze(-1)
119 |
120 | if jacobian:
121 | # Ji transforms according to dual adjoint
122 | Jj = torch.matmul(Jp, Ja)
123 | Ji = -Gij[:,:,None,None,None].adjT(Jj)
124 |
125 | # TODO, TODO, TODO: check this math...
126 | # Account for body to camera transformation
127 | if cam_T_body is not None:
128 | cam_T_body = SE3(cam_T_body)
129 | Ji = cam_T_body[None,None,None,None,None].adjT(Ji)
130 | Jj = cam_T_body[None,None,None,None,None].adjT(Jj)
131 |
132 |             # Flip sign to get the jacobians wrt world_T_body (the sign doesn't matter for the covariance, where jacobians enter quadratically)
133 | Ji *= -1.0
134 | Jj *= -1.0
135 |
136 | # Account for Droid's (x,y,z,wx,wy,wz) convention to GTSAM's (wx,wy,wz,x,y,z) convention
137 | Ji = Ji[...,[3,4,5,0,1,2]]
138 | Jj = Jj[...,[3,4,5,0,1,2]]
139 |
140 | Jz = Gij[:,:,None,None] * Jz
141 | Jz = torch.matmul(Jp, Jz.unsqueeze(-1))
142 |
143 | return x1, valid, (Ji, Jj, Jz)
144 |
145 | return x1, valid, (None, None, None)
146 |
147 | def induced_flow(poses, disps, intrinsics, ii, jj):
148 | """ optical flow induced by camera motion """
149 |
150 | ht, wd = disps.shape[2:]
151 | y, x = torch.meshgrid(
152 | torch.arange(ht).to(disps.device).float(),
153 | torch.arange(wd).to(disps.device).float())
154 |
155 | coords0 = torch.stack([x, y], dim=-1)
156 |     coords1, valid, _ = projective_transform(poses, disps, intrinsics, ii, jj)
157 |
158 | return coords1[...,:2] - coords0, valid
159 |
160 |
--------------------------------------------------------------------------------
/networks/geom/rgbd_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as osp
3 |
4 | import torch
5 | from lietorch import SE3
6 |
7 | from . import projective_ops as pops
8 | from scipy.spatial.transform import Rotation
9 |
10 |
11 | def parse_list(filepath, skiprows=0):
12 | """ read list data """
13 | data = np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
14 | return data
15 |
16 | def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=1.0):
17 | """ pair images, depths, and poses """
18 | associations = []
19 | for i, t in enumerate(tstamp_image):
20 | if tstamp_pose is None:
21 | j = np.argmin(np.abs(tstamp_depth - t))
22 | if (np.abs(tstamp_depth[j] - t) < max_dt):
23 | associations.append((i, j))
24 |
25 | else:
26 | j = np.argmin(np.abs(tstamp_depth - t))
27 | k = np.argmin(np.abs(tstamp_pose - t))
28 |
29 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \
30 | (np.abs(tstamp_pose[k] - t) < max_dt):
31 | associations.append((i, j, k))
32 |
33 | return associations
34 |
35 | def loadtum(datapath, frame_rate=-1):
36 | """ read video data in tum-rgbd format """
37 | if osp.isfile(osp.join(datapath, 'groundtruth.txt')):
38 | pose_list = osp.join(datapath, 'groundtruth.txt')
39 |
40 | elif osp.isfile(osp.join(datapath, 'pose.txt')):
41 | pose_list = osp.join(datapath, 'pose.txt')
42 |
43 | else:
44 | return None, None, None, None
45 |
46 | image_list = osp.join(datapath, 'rgb.txt')
47 | depth_list = osp.join(datapath, 'depth.txt')
48 |
49 | calib_path = osp.join(datapath, 'calibration.txt')
50 | intrinsic = None
51 | if osp.isfile(calib_path):
52 | intrinsic = np.loadtxt(calib_path, delimiter=' ')
53 | intrinsic = intrinsic.astype(np.float64)
54 |
55 | image_data = parse_list(image_list)
56 | depth_data = parse_list(depth_list)
57 | pose_data = parse_list(pose_list, skiprows=1)
58 | pose_vecs = pose_data[:,1:].astype(np.float64)
59 |
60 | tstamp_image = image_data[:,0].astype(np.float64)
61 | tstamp_depth = depth_data[:,0].astype(np.float64)
62 | tstamp_pose = pose_data[:,0].astype(np.float64)
63 | associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose)
64 |
65 | # print(len(tstamp_image))
66 | # print(len(associations))
67 |
68 | indicies = range(len(associations))[::5]
69 |
70 | # indicies = [ 0 ]
71 | # for i in range(1, len(associations)):
72 | # t0 = tstamp_image[associations[indicies[-1]][0]]
73 | # t1 = tstamp_image[associations[i][0]]
74 | # if t1 - t0 > 1.0 / frame_rate:
75 | # indicies += [ i ]
76 |
77 | images, poses, depths, intrinsics, tstamps = [], [], [], [], []
78 | for ix in indicies:
79 | (i, j, k) = associations[ix]
80 | images += [ osp.join(datapath, image_data[i,1]) ]
81 | depths += [ osp.join(datapath, depth_data[j,1]) ]
82 | poses += [ pose_vecs[k] ]
83 | tstamps += [ tstamp_image[i] ]
84 |
85 | if intrinsic is not None:
86 | intrinsics += [ intrinsic ]
87 |
88 | return images, depths, poses, intrinsics, tstamps
89 |
90 |
91 | def all_pairs_distance_matrix(poses, beta=2.5):
92 | """ compute distance matrix between all pairs of poses """
93 | poses = np.array(poses, dtype=np.float32)
94 |     poses[:,:3] *= beta # scale to balance rot + trans
95 | poses = SE3(torch.from_numpy(poses))
96 |
97 | r = (poses[:,None].inv() * poses[None,:]).log()
98 | return r.norm(dim=-1).cpu().numpy()
99 |
100 | def pose_matrix_to_quaternion(pose):
101 | """ convert 4x4 pose matrix to (t, q) """
102 | q = Rotation.from_matrix(pose[:3, :3]).as_quat()
103 | return np.concatenate([pose[:3, 3], q], axis=0)
104 |
105 | def compute_distance_matrix_flow(poses, disps, intrinsics):
106 | """ compute flow magnitude between all pairs of frames """
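    # Distance between frames i and j = mean magnitude of the optical flow induced
    # by the relative camera motion (both directions, clamped at MAX_FLOW pixels);
    # pairs with fewer than ~70% co-visible pixels are set to infinity.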
107 | if not isinstance(poses, SE3):
108 | poses = torch.from_numpy(poses).float().cuda()[None]
109 | poses = SE3(poses).inv()
110 |
111 | disps = torch.from_numpy(disps).float().cuda()[None]
112 | intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
113 |
114 | N = poses.shape[1]
115 |
116 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
117 | ii = ii.reshape(-1).cuda()
118 | jj = jj.reshape(-1).cuda()
119 |
120 | MAX_FLOW = 100.0
121 | matrix = np.zeros((N, N), dtype=np.float32)
122 |
123 | s = 2048
124 | for i in range(0, ii.shape[0], s):
125 | flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
126 | flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
127 |
128 | flow = torch.stack([flow1, flow2], dim=2)
129 | val = torch.stack([val1, val2], dim=2)
130 |
131 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
132 | mag = mag.view(mag.shape[1], -1)
133 | val = val.view(val.shape[1], -1)
134 |
135 | mag = (mag * val).mean(-1) / val.mean(-1)
136 | mag[val.mean(-1) < 0.7] = np.inf
137 |
138 | i1 = ii[i:i+s].cpu().numpy()
139 | j1 = jj[i:i+s].cpu().numpy()
140 | matrix[i1, j1] = mag.cpu().numpy()
141 |
142 | return matrix
143 |
144 |
145 | def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
146 | """ compute flow magnitude between all pairs of frames """
147 | # if not isinstance(poses, SE3):
148 | # poses = torch.from_numpy(poses).float().cuda()[None]
149 | # poses = SE3(poses).inv()
150 |
151 | # disps = torch.from_numpy(disps).float().cuda()[None]
152 | # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
153 |
154 | N = poses.shape[1]
155 |
156 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
157 | ii = ii.reshape(-1)
158 | jj = jj.reshape(-1)
159 |
160 | MAX_FLOW = 128.0
161 | matrix = np.zeros((N, N), dtype=np.float32)
162 |
163 | s = 2048
164 | for i in range(0, ii.shape[0], s):
165 | flow1a, val1a = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
166 | flow1b, val1b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
167 | flow2a, val2a = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
168 | flow2b, val2b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
169 |
170 | flow1 = flow1a + beta * flow1b
171 | val1 = val1a * val2b
172 |
173 | flow2 = flow2a + beta * flow2b
174 | val2 = val2a * val2b
175 |
176 | flow = torch.stack([flow1, flow2], dim=2)
177 | val = torch.stack([val1, val2], dim=2)
178 |
179 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
180 | mag = mag.view(mag.shape[1], -1)
181 | val = val.view(val.shape[1], -1)
182 |
183 | mag = (mag * val).mean(-1) / val.mean(-1)
184 | mag[val.mean(-1) < 0.8] = np.inf
185 |
186 | i1 = ii[i:i+s].cpu().numpy()
187 | j1 = jj[i:i+s].cpu().numpy()
188 | matrix[i1, j1] = mag.cpu().numpy()
189 |
190 | return matrix
191 |
--------------------------------------------------------------------------------
/networks/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
--------------------------------------------------------------------------------
/networks/modules/clipping.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | GRAD_CLIP = .01
6 |
7 | class GradClip(torch.autograd.Function):
8 | @staticmethod
9 | def forward(ctx, x):
10 | return x
11 |
12 | @staticmethod
13 | def backward(ctx, grad_x):
14 | o = torch.zeros_like(grad_x)
15 | grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x)
16 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
17 | return grad_x
18 |
19 | class GradientClip(nn.Module):
20 | def __init__(self):
21 | super(GradientClip, self).__init__()
22 |
23 | def forward(self, x):
24 | return GradClip.apply(x)
--------------------------------------------------------------------------------
/networks/modules/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import droid_backends
5 |
6 | class CorrSampler(torch.autograd.Function):
7 |
8 | @staticmethod
9 | def forward(ctx, volume, coords, radius):
10 | ctx.save_for_backward(volume,coords)
11 | ctx.radius = radius
12 | corr, = droid_backends.corr_index_forward(volume, coords, radius)
13 | return corr
14 |
15 | @staticmethod
16 | def backward(ctx, grad_output):
17 | volume, coords = ctx.saved_tensors
18 | grad_output = grad_output.contiguous()
19 | grad_volume, = droid_backends.corr_index_backward(volume, coords, grad_output, ctx.radius)
20 | return grad_volume, None, None
21 |
22 |
23 | class CorrBlock:
24 | def __init__(self, fmap1, fmap2, num_levels=4, radius=3):
25 | self.num_levels = num_levels
26 | self.radius = radius
27 | self.corr_pyramid = []
28 |
29 | # all pairs correlation
30 | corr = CorrBlock.corr(fmap1, fmap2)
31 |
32 | batch, num, h1, w1, h2, w2 = corr.shape
33 | corr = corr.reshape(batch*num*h1*w1, 1, h2, w2)
34 |
35 | for i in range(self.num_levels):
36 | self.corr_pyramid.append(
37 | corr.view(batch*num, h1, w1, h2//2**i, w2//2**i))
38 | corr = F.avg_pool2d(corr, 2, stride=2)
39 |
40 | def __call__(self, coords):
41 | out_pyramid = []
42 | batch, num, ht, wd, _ = coords.shape
43 | coords = coords.permute(0,1,4,2,3)
44 | coords = coords.contiguous().view(batch*num, 2, ht, wd)
45 |
46 | for i in range(self.num_levels):
47 | corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius)
48 | out_pyramid.append(corr.view(batch, num, -1, ht, wd))
49 |
50 | return torch.cat(out_pyramid, dim=2)
51 |
52 | def cat(self, other):
53 | for i in range(self.num_levels):
54 | self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0)
55 | return self
56 |
57 | def __getitem__(self, index):
58 | for i in range(self.num_levels):
59 | self.corr_pyramid[i] = self.corr_pyramid[i][index]
60 | return self
61 |
62 |
63 | @staticmethod
64 | def corr(fmap1, fmap2):
65 | """ all-pairs correlation """
66 | batch, num, dim, ht, wd = fmap1.shape
67 | fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0
68 | fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0
69 |
70 | # multiplies along the `dim' dimension, for each batch, for each feature map
71 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
72 | return corr.view(batch, num, ht, wd, ht, wd)
73 |
74 |
75 | class CorrLayer(torch.autograd.Function):
76 | @staticmethod
77 | def forward(ctx, fmap1, fmap2, coords, r):
78 | ctx.r = r
79 | ctx.save_for_backward(fmap1, fmap2, coords)
80 | corr, = droid_backends.altcorr_forward(fmap1, fmap2, coords, ctx.r)
81 | return corr
82 |
83 | @staticmethod
84 | def backward(ctx, grad_corr):
85 | fmap1, fmap2, coords = ctx.saved_tensors
86 | grad_corr = grad_corr.contiguous()
87 | fmap1_grad, fmap2_grad, coords_grad = \
88 | droid_backends.altcorr_backward(fmap1, fmap2, coords, grad_corr, ctx.r)
89 | return fmap1_grad, fmap2_grad, coords_grad, None
90 |
91 |
92 | class AltCorrBlock:
93 | def __init__(self, fmaps, num_levels=4, radius=3):
94 | self.num_levels = num_levels
95 | self.radius = radius
96 |
97 | B, N, C, H, W = fmaps.shape
98 | fmaps = fmaps.view(B*N, C, H, W) / 4.0
99 |
100 | self.pyramid = []
101 | for i in range(self.num_levels):
102 | sz = (B, N, H//2**i, W//2**i, C)
103 | fmap_lvl = fmaps.permute(0, 2, 3, 1).contiguous()
104 | self.pyramid.append(fmap_lvl.view(*sz))
105 | fmaps = F.avg_pool2d(fmaps, 2, stride=2)
106 |
107 | def corr_fn(self, coords, ii, jj):
108 | B, N, H, W, S, _ = coords.shape
109 | coords = coords.permute(0, 1, 4, 2, 3, 5)
110 |
111 | corr_list = []
112 | for i in range(self.num_levels):
113 | r = self.radius
114 | fmap1_i = self.pyramid[0][:, ii]
115 | fmap2_i = self.pyramid[i][:, jj]
116 |
117 | coords_i = (coords / 2**i).reshape(B*N, S, H, W, 2).contiguous()
118 | fmap1_i = fmap1_i.reshape((B*N,) + fmap1_i.shape[2:])
119 | fmap2_i = fmap2_i.reshape((B*N,) + fmap2_i.shape[2:])
120 |
121 | corr = CorrLayer.apply(fmap1_i.float(), fmap2_i.float(), coords_i, self.radius)
122 | corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2)
123 | corr_list.append(corr)
124 |
125 | corr = torch.cat(corr_list, dim=2)
126 | return corr
127 |
128 |
129 | def __call__(self, coords, ii, jj):
130 | squeeze_output = False
131 | if len(coords.shape) == 5:
132 | coords = coords.unsqueeze(dim=-2)
133 | squeeze_output = True
134 |
135 | corr = self.corr_fn(coords, ii, jj)
136 |
137 | if squeeze_output:
138 | corr = corr.squeeze(dim=-1)
139 |
140 | return corr.contiguous()
141 |
142 |
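The samplers above call into the droid_backends CUDA extension. As a reference for what CorrBlock.corr computes (semantics only, not the CUDA implementation), here is a pure-PyTorch sketch of the all-pairs correlation; the feature dimensions are illustrative.

import torch

def all_pairs_correlation(fmap1, fmap2):
    # Dot product between every pixel of fmap1 and every pixel of fmap2,
    # with the same per-map 1/4 scaling used above.
    batch, num, dim, ht, wd = fmap1.shape
    f1 = fmap1.reshape(batch * num, dim, ht * wd) / 4.0
    f2 = fmap2.reshape(batch * num, dim, ht * wd) / 4.0
    corr = torch.matmul(f1.transpose(1, 2), f2)  # (batch*num, ht*wd, ht*wd)
    return corr.view(batch, num, ht, wd, ht, wd)

fmap1 = torch.randn(1, 2, 128, 30, 40)
fmap2 = torch.randn(1, 2, 128, 30, 40)
print(all_pairs_correlation(fmap1, fmap2).shape)  # torch.Size([1, 2, 30, 40, 30, 40])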
--------------------------------------------------------------------------------
/networks/modules/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 | def forward(self, x):
48 | y = x
49 | y = self.relu(self.norm1(self.conv1(y)))
50 | y = self.relu(self.norm2(self.conv2(y)))
51 |
52 | if self.downsample is not None:
53 | x = self.downsample(x)
54 |
55 | return self.relu(x+y)
56 |
57 |
58 | class BottleneckBlock(nn.Module):
59 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
60 | super(BottleneckBlock, self).__init__()
61 |
62 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
63 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
64 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
65 | self.relu = nn.ReLU(inplace=True)
66 |
67 | num_groups = planes // 8
68 |
69 | if norm_fn == 'group':
70 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
71 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
72 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
73 | if not stride == 1:
74 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 |
76 | elif norm_fn == 'batch':
77 | self.norm1 = nn.BatchNorm2d(planes//4)
78 | self.norm2 = nn.BatchNorm2d(planes//4)
79 | self.norm3 = nn.BatchNorm2d(planes)
80 | if not stride == 1:
81 | self.norm4 = nn.BatchNorm2d(planes)
82 |
83 | elif norm_fn == 'instance':
84 | self.norm1 = nn.InstanceNorm2d(planes//4)
85 | self.norm2 = nn.InstanceNorm2d(planes//4)
86 | self.norm3 = nn.InstanceNorm2d(planes)
87 | if not stride == 1:
88 | self.norm4 = nn.InstanceNorm2d(planes)
89 |
90 | elif norm_fn == 'none':
91 | self.norm1 = nn.Sequential()
92 | self.norm2 = nn.Sequential()
93 | self.norm3 = nn.Sequential()
94 | if not stride == 1:
95 | self.norm4 = nn.Sequential()
96 |
97 | if stride == 1:
98 | self.downsample = None
99 |
100 | else:
101 | self.downsample = nn.Sequential(
102 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
103 |
104 | def forward(self, x):
105 | y = x
106 | y = self.relu(self.norm1(self.conv1(y)))
107 | y = self.relu(self.norm2(self.conv2(y)))
108 | y = self.relu(self.norm3(self.conv3(y)))
109 |
110 | if self.downsample is not None:
111 | x = self.downsample(x)
112 |
113 | return self.relu(x+y)
114 |
115 |
116 | DIM=32
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 | self.multidim = multidim
123 |
124 | if self.norm_fn == 'group':
125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
126 |
127 | elif self.norm_fn == 'batch':
128 | self.norm1 = nn.BatchNorm2d(DIM)
129 |
130 | elif self.norm_fn == 'instance':
131 | self.norm1 = nn.InstanceNorm2d(DIM)
132 |
133 | elif self.norm_fn == 'none':
134 | self.norm1 = nn.Sequential()
135 |
136 | self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
137 | self.relu1 = nn.ReLU(inplace=True)
138 |
139 | self.in_planes = DIM
140 | self.layer1 = self._make_layer(DIM, stride=1)
141 | self.layer2 = self._make_layer(2*DIM, stride=2)
142 | self.layer3 = self._make_layer(4*DIM, stride=2)
143 |
144 | # output convolution
145 | self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
146 |
147 | if self.multidim:
148 | self.layer4 = self._make_layer(256, stride=2)
149 | self.layer5 = self._make_layer(512, stride=2)
150 |
151 | self.in_planes = 256
152 | self.layer6 = self._make_layer(256, stride=1)
153 |
154 | self.in_planes = 128
155 | self.layer7 = self._make_layer(128, stride=1)
156 |
157 | self.up1 = nn.Conv2d(512, 256, 1)
158 | self.up2 = nn.Conv2d(256, 128, 1)
159 | self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1)
160 |
161 | if dropout > 0:
162 | self.dropout = nn.Dropout2d(p=dropout)
163 | else:
164 | self.dropout = None
165 |
166 | for m in self.modules():
167 | if isinstance(m, nn.Conv2d):
168 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
169 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
170 | if m.weight is not None:
171 | nn.init.constant_(m.weight, 1)
172 | if m.bias is not None:
173 | nn.init.constant_(m.bias, 0)
174 |
175 | def _make_layer(self, dim, stride=1):
176 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
177 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
178 | layers = (layer1, layer2)
179 |
180 | self.in_planes = dim
181 | return nn.Sequential(*layers)
182 |
183 | def forward(self, x):
184 | b, n, c1, h1, w1 = x.shape
185 | x = x.view(b*n, c1, h1, w1)
186 |
187 | x = self.conv1(x)
188 | x = self.norm1(x)
189 | x = self.relu1(x)
190 |
191 | x = self.layer1(x)
192 | x = self.layer2(x)
193 | x = self.layer3(x)
194 |
195 | x = self.conv2(x)
196 |
197 | _, c2, h2, w2 = x.shape
198 | return x.view(b, n, c2, h2, w2)
199 |
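A shape-check sketch for BasicEncoder (not part of extractor.py): the encoder downsamples by 8 overall (stride-2 conv1, layer2, and layer3), so a (b, n, 3, H, W) batch of frames maps to (b, n, output_dim, H/8, W/8) feature maps. It assumes the repository root is on PYTHONPATH; the sizes below are illustrative.

import torch

from networks.modules.extractor import BasicEncoder

encoder = BasicEncoder(output_dim=128, norm_fn='instance')
frames = torch.randn(1, 3, 3, 240, 320)   # (batch, num_frames, rgb, H, W)
features = encoder(frames)
print(features.shape)                     # torch.Size([1, 3, 128, 30, 40])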
--------------------------------------------------------------------------------
/networks/modules/gru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ConvGRU(nn.Module):
6 | def __init__(self, h_planes=128, i_planes=128):
7 | super(ConvGRU, self).__init__()
8 | self.do_checkpoint = False
9 | self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
10 | self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
11 | self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
12 |
13 | self.w = nn.Conv2d(h_planes, h_planes, 1, padding=0)
14 |
15 | self.convz_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0)
16 | self.convr_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0)
17 | self.convq_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0)
18 |
19 | def forward(self, net, *inputs):
20 | inp = torch.cat(inputs, dim=1)
21 | net_inp = torch.cat([net, inp], dim=1)
22 |
23 | b, c, h, w = net.shape
24 | glo = torch.sigmoid(self.w(net)) * net
25 | glo = glo.view(b, c, h*w).mean(-1).view(b, c, 1, 1)
26 |
27 | z = torch.sigmoid(self.convz(net_inp) + self.convz_glo(glo))
28 | r = torch.sigmoid(self.convr(net_inp) + self.convr_glo(glo))
29 | q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)) + self.convq_glo(glo))
30 |
31 | net = (1-z) * net + z * q
32 | return net
33 |
34 |
35 |
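A shape-check sketch for ConvGRU (not part of gru.py): all extra inputs are concatenated on the channel dimension, so i_planes must equal the sum of their channel counts, and the hidden state keeps its shape across an update. It assumes the repository root is on PYTHONPATH; channel counts are illustrative.

import torch

from networks.modules.gru import ConvGRU

gru = ConvGRU(h_planes=128, i_planes=128 + 64)
net = torch.randn(2, 128, 30, 40)     # hidden state
inp_a = torch.randn(2, 128, 30, 40)   # e.g. correlation features
inp_b = torch.randn(2, 64, 30, 40)    # e.g. extra context input
net = gru(net, inp_a, inp_b)
print(net.shape)                      # torch.Size([2, 128, 30, 40])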
--------------------------------------------------------------------------------
/networks/motion_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lietorch
3 |
4 | from .geom import projective_ops as pops
5 | from .modules.corr import CorrBlock
6 |
7 |
8 | # Populates the video with the frames that have enough motion:
9 | # it computes the features for each frame, runs one GRU update step
10 | # to estimate the flow, and, if the flow magnitude exceeds min_flow,
11 | # adds the frame, together with its features and context, to the video object.
12 | class MotionFilter:
13 | """ This class is used to filter incoming frames and extract features """
14 | def __init__(self, net, video, min_flow_thresh=2.5, device="cuda:0"):
15 | # split net modules
16 | self.context_net = net.cnet
17 | self.feature_net = net.fnet
18 | self.update_net = net.update
19 |
20 | self.video = video
21 | self.min_flow_thresh = min_flow_thresh
22 |
23 | self.skipped_frames = 0
24 | self.device = device
25 |
26 | # mean, std for image normalization
27 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None]
28 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None]
29 |
30 | @torch.cuda.amp.autocast(enabled=True)
31 | def __context_encoder(self, image):
32 | """ context features """
33 | context_maps, gru_input_maps = self.context_net(image).split([128,128], dim=2)
34 | return context_maps.tanh().squeeze(0), gru_input_maps.relu().squeeze(0)
35 |
36 | @torch.cuda.amp.autocast(enabled=True)
37 | def __feature_encoder(self, image):
38 | """ features for correlation volume """
39 | return self.feature_net(image).squeeze(0)
40 |
41 | @torch.cuda.amp.autocast(enabled=True)
42 | @torch.no_grad()
43 | def track(self, k, timestamp, image, depth=None, intrinsics=None):
44 | """ main update operation - run on every frame in video """
45 |
46 | # normalize images
47 | img_normalized = image[None, :, [2,1,0]].to(self.device) / 255.0
48 | img_normalized = img_normalized.sub_(self.MEAN).div_(self.STDV)
49 |
50 | # extract features
51 | feature_map = self.__feature_encoder(img_normalized)
52 |
53 | ### always add first frame to the depth video ###
54 | if k == 0:
55 | self.add_frame_to_video(timestamp, image, img_normalized, feature_map, depth, intrinsics)
56 | ### only add new frame if there is enough motion ###
57 | else:
58 | ht = image.shape[-2] // 8
59 | wd = image.shape[-1] // 8
60 |
61 | # Index correlation volume
62 | coords0 = pops.coords_grid(ht, wd, device=self.device)[None,None]
63 | corr = CorrBlock(self.feature_maps[None,[0]], feature_map[None,[0]])(coords0) # TODO why not send the corr block?
64 |
65 | # Approximate flow magnitude using 1 update iteration
66 | _, delta, weight = self.update_net(self.context_maps[None], self.gru_input_maps[None], corr)
67 |
68 |             # Check motion magnitude / add new frame to video
69 | has_enough_motion = delta.norm(dim=-1).mean().item() > self.min_flow_thresh
70 | if has_enough_motion:
71 | self.add_frame_to_video(timestamp, image, img_normalized, feature_map, depth, intrinsics)
72 | self.skipped_frames = 0
73 | else:
74 | self.skipped_frames += 1
75 |
76 | # Save the network responses (feature + context) and image for later inference
77 | def add_frame_to_video(self, timestamp, image, img_normalized, feature_map, depth_img=None, intrinsics=None):
78 | context_maps, gru_input_maps = self.__context_encoder(img_normalized[:,[0]])
79 | self.context_maps, self.gru_input_maps, self.feature_maps = context_maps, gru_input_maps, feature_map
80 | identity_pose = lietorch.SE3.Identity(1,).data.squeeze()
81 | self.video.append(timestamp, image[0], identity_pose,
82 | 1.0, # Initialize disps at 1
83 | depth_img, # If available
84 | intrinsics / 8.0,
85 | feature_map, context_maps[0,0], gru_input_maps[0,0])
86 |
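A small sketch isolating the keyframe test used in track() above (not part of motion_filter.py): a frame is kept when the mean magnitude of the one-step flow update exceeds min_flow_thresh. The tensors below are hand-made for illustration.

import torch

def has_enough_motion(delta, min_flow_thresh=2.5):
    # delta has shape (..., ht, wd, 2): one flow update vector per pixel.
    return delta.norm(dim=-1).mean().item() > min_flow_thresh

small_motion = torch.full((1, 30, 40, 2), 0.5)
large_motion = torch.full((1, 30, 40, 2), 3.0)
print(has_enough_motion(small_motion), has_enough_motion(large_motion))  # False True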
--------------------------------------------------------------------------------
/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/pipeline/__init__.py
--------------------------------------------------------------------------------
/pipeline/pipeline.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | class Pipeline:
5 |     def __init__(self):
6 | super().__init__()
7 |
8 |
--------------------------------------------------------------------------------
/pipeline/pipeline_module.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import colored_glog as log
3 | from icecream import ic
4 |
5 | #from torch.profiler import profile, ProfilerActivity
6 |
7 | class PipelineModuleBase:
8 | def __init__(self, name, parallel_run, args=None, grad=False):
9 | self.name = name
10 | self.parallel_run = parallel_run
11 | self.grad = grad # determines if we are tracing grads or not
12 | self.shutdown = False # needs to be atomic
13 | self.is_initialized = False # needs to be atomic
14 | self.is_thread_working = False # needs to be atomic
15 | self.args = args # arguments to init module
16 | # Callbacks to be called in case module does not return an output.
17 | self.on_failure_callbacks = []
18 | self.profile = False # Profile the code for runtime and/or memory
19 |
20 | @abstractmethod
21 | def initialize_module(self):
22 |         # Allocate memory and initialize variables here, rather than in
23 |         # __init__, so that running in parallel does not raise a:
24 |         # "TypeError: cannot pickle 'XXXX' object"
25 | self.is_initialized = True
26 |
27 | @abstractmethod
28 | def spin(self) -> bool:
29 | pass
30 |
31 | @abstractmethod
32 | def shutdown_queues(self):
33 | pass
34 |
35 | @abstractmethod
36 | def has_work(self):
37 | pass
38 |
39 | def shutdown_module(self):
40 | # TODO shouldn't self.shutdown be atomic? (i.e. thread-safe?)
41 | if self.shutdown:
42 | log.warn(f"Module: {self.name} - Shutdown requested, but was already shutdown.")
43 | log.debug(f"Stopping module {self.name} and its queues...")
44 | self.shutdown_queues()
45 | log.info(f"Module: {self.name} - Shutting down.")
46 | self.shutdown = True
47 |
48 | def restart(self):
49 | log.info(f"Module: {self.name} - Resetting shutdown flag to false")
50 | self.shutdown = False
51 |
52 | def is_working(self):
53 |         return self.is_thread_working or self.has_work()
54 |
55 | def register_on_failure_callback(self, callback):
56 | log.check(callback)
57 | self.on_failure_callbacks.append(callback)
58 |
59 | def notify_on_failure(self):
60 | for on_failure_callback in self.on_failure_callbacks:
61 | if on_failure_callback:
62 | on_failure_callback()
63 | else:
64 | log.error(f"Invalid OnFailureCallback for module: {self.name}")
65 |
66 | class PipelineModule(PipelineModuleBase):
67 | def __init__(self, name_id, parallel_run, args=None, grad=False) -> None:
68 | super().__init__(name_id, parallel_run, args, grad)
69 |
70 | @abstractmethod
71 | def get_input_packet(self):
72 | raise
73 |
74 | @abstractmethod
75 | def push_output_packet(self, output_packet) -> bool:
76 | raise
77 |
78 | @abstractmethod
79 | def spin_once(self, input):
80 | raise
81 |
82 | # Spin is called in a thread.
83 | def spin(self):
84 | if self.parallel_run:
85 | log.info(f'Module: {self.name} - Spinning.')
86 |
87 | if not self.is_initialized:
88 | self.initialize_module()
89 |
90 | while not self.shutdown:
91 | self.is_thread_working = False;
92 | input = self.get_input_packet();
93 | self.is_thread_working = True;
94 | if input is not None:
95 | output = None
96 | if self.profile:
97 | #with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
98 | output = self.spin_once(input);
99 | #print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
100 | else:
101 | output = self.spin_once(input);
102 | if output is not None:
103 | # Received a valid output, send to output queue
104 | if not self.push_output_packet(output):
105 | log.warn(f"Module: {self.name} - Output push failed.")
106 | else:
107 | log.debug(f"Module: {self.name} - Pushed output.")
108 | else:
109 | log.debug(f"Module: {self.name} - Skipped sending an output.")
110 | # Notify interested parties about failure.
111 | self.notify_on_failure();
112 | else:
113 | log.log(2, f"Module: {self.name} - No Input received.")
114 |
115 | # Break the while loop if we are in sequential mode.
116 | if not self.parallel_run:
117 | self.is_thread_working = False;
118 | return True;
119 |
120 | self.is_thread_working = False;
121 | log.info(f"Module: {self.name} - Successful shutdown.")
122 | return False;
123 |
124 | class MIMOPipelineModule(PipelineModule):
125 | def __init__(self, name_id, parallel_run, args=None, grad=False):
126 | super().__init__(name_id, parallel_run, args, grad)
127 | self.input_queues = {}
128 | self.output_callbacks = []
129 | self.output_queues = []
130 |
131 | def register_input_queue(self, name, input_queue):
132 | self.input_queues[name] = input_queue
133 |
134 | def register_output_callback(self, output_callback):
135 | self.output_callbacks.append(output_callback)
136 |
137 | def register_output_queue(self, output_queue):
138 | self.output_queues.append(output_queue)
139 |
140 | # TODO: warn when callbacks take too long
141 | def push_output_packet(self, output_packet):
142 | push_success = True
143 | # Push output to all queues
144 | for output_queue in self.output_queues:
145 | try:
146 | output_queue.put(output_packet)
147 | except Exception as e:
148 | log.warn(e)
149 | push_success = False
150 | # Push output to all callbacks
151 | for callback in self.output_callbacks:
152 | try:
153 | callback(output_packet)
154 | except Exception as e:
155 | log.warn(e)
156 | push_success = False
157 | return push_success
158 |
159 | def get_input_packet(self, timeout=0.1):
160 | inputs = {}
161 | if self.parallel_run:
162 | for name, input_queue in self.input_queues.items():
163 | try:
164 | inputs[name] = input_queue.get(timeout=timeout)
165 | except Exception as e:
166 | log.debug(e)
167 | else:
168 | for name, input_queue in self.input_queues.items():
169 | try:
170 | inputs[name] = input_queue.get_nowait()
171 | except Exception as e:
172 | log.debug(e)
173 |
174 | if len(inputs) == 0:
175 | log.debug(f"Module: {self.name} - Input queues didn't return an output.")
176 | inputs = None
177 | return inputs
178 |
179 | def shutdown_queues(self):
180 | super().shutdown_queues()
181 | # This should be automatically called by garbage collector
182 | # [input_queue.close() for input_queue in self.input_queues]
183 | # [output_queue.close() for output_queue in self.output_queues]
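A minimal sequential-mode sketch of a MIMOPipelineModule subclass (not part of the repository): the DoublerModule class, module name, and queue names are made up for illustration, and it assumes the repository root is on PYTHONPATH with colored_glog and icecream installed, since this file imports them.

import queue

from pipeline.pipeline_module import MIMOPipelineModule

class DoublerModule(MIMOPipelineModule):
    # Toy module: reads a number from each input queue and emits its double.
    def spin_once(self, input):
        # `input` is the dict {queue_name: packet} built by get_input_packet().
        return {name: 2 * value for name, value in input.items()}

in_q, out_q = queue.Queue(), queue.Queue()
module = DoublerModule("doubler", parallel_run=False)
module.register_input_queue("numbers", in_q)
module.register_output_queue(out_q)

in_q.put(21)
module.spin()        # sequential mode: handles one packet, then returns True
print(out_q.get())   # {'numbers': 42}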
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | tqdm
4 | matplotlib
5 |
6 | open3d
7 | opencv-python
8 |
9 | torch-scatter
10 |
11 | jupyter
12 | imageio
13 |
14 | glog
15 | icecream
16 |
17 | pandas
18 | pyrealsense2
19 |
20 | pybind11
21 | gdown
22 |
--------------------------------------------------------------------------------
/scripts/convergence_plots.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "10010259",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import pandas as pd\n",
11 | "\n",
12 | "import numpy as np\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "plt.rcParams[\"text.usetex\"] = True\n",
15 | "plt.rcParams[\"font.family\"] = \"serif\"\n",
16 | "plt.rcParams[\"font.size\"] = \"14\"\n",
17 | "\n",
18 | "\n",
19 | "def plot_results(dfs, show=True, save=True, xlim=5000):\n",
20 | " fig, ax1 = plt.subplots()\n",
21 | " \n",
22 | " from itertools import cycle\n",
23 | " lines = [\"-\",\"--\",\"-.\",\":\"]\n",
24 | " linecycler = cycle(lines)\n",
25 | "\n",
26 | " \n",
27 | " ax1.set_xlabel('Iter [-]')\n",
28 | " \n",
29 | " psnr_color = 'tab:red'\n",
30 | " ax1.set_ylabel('PSNR [dB]', color=psnr_color)\n",
31 | " ax1.tick_params(axis='y', labelcolor=psnr_color)\n",
32 | " #ax1.set_xlim(0, df['Iter'][df.index[-1]])\n",
33 | " ax1.set_xlim(0, xlim)\n",
34 | " \n",
35 | " ax2 = ax1.twiny()\n",
36 | " ax2.set_xlabel('$\\Delta$t [s]')\n",
37 | "\n",
38 | " l1_color = 'tab:blue'\n",
39 | " ax3 = ax1.twinx() # instantiate a second axes that shares the same x-axis\n",
40 | " ax3.set_ylabel('L1 [cm]', color=l1_color)\n",
41 | " ax3.set_ylim(0, 30)\n",
42 | " ax3.tick_params(axis='y', labelcolor=l1_color)\n",
43 | " \n",
44 | " for name, df in dfs.items(): \n",
45 | " \n",
46 | " #df = df.iloc[1: , :] # Drop first row which only has 0.0\n",
47 | " df = df.set_index('Iter')\n",
48 | " \n",
49 | " line_style = next(linecycler)\n",
50 | "\n",
51 | " \n",
52 | " ax1.plot(df.index, df['PSNR'], color=psnr_color, linestyle=line_style, label=name) \n",
53 | " ax3.plot(df.index, df['L1'], color=l1_color, linestyle=line_style)\n",
54 | "\n",
55 | " ax2.set_xticks(ax1.get_xticks())\n",
56 | " ax2.set_xbound(ax1.get_xbound())\n",
57 | " xticklabels = [df['Dt'].loc[xtick].round(0).item() for xtick in ax1.get_xticks()]\n",
58 | " print(xticklabels)\n",
59 | " ax2.set_xticklabels(xticklabels)\n",
60 | " \n",
61 | " \n",
62 | " #ax1.legend(loc='upper center', fontsize='x-large')\n",
63 | " leg = ax1.legend(loc='center right')\n",
64 | " [lgd.set_color('black') for lgd in leg.legendHandles]\n",
65 | "\n",
66 | " fig.tight_layout() # otherwise the right y-label is slightly clipped\n",
67 | " if show: plt.show()\n",
68 | " if save: \n",
69 | " #fig.set_size_inches(3, 2)\n",
70 | " out = ''\n",
71 | " for name in list(dfs.keys())[:-1]:\n",
72 | " out += name + '_vs_'\n",
73 | " out += list(dfs.keys())[-1]\n",
74 | " print(f\"Saving to: {out}\")\n",
75 | " fig.savefig(out + '.svg', dpi=100, transparent=False, bbox_inches='tight')"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "id": "edc45e33",
82 | "metadata": {
83 | "scrolled": false
84 | },
85 | "outputs": [],
86 | "source": [
87 | "dfs = {\n",
88 | " #\"ours_l1\": pd.read_csv('../room_results_ours_l1.csv'),\n",
89 | " \"no depth\": pd.read_csv('../room_results_no_depth_3.csv'),\n",
90 | " \"raw depth\": pd.read_csv('../room_results_raw_depth.csv'),\n",
91 | " #\"raw depth (w annealing)\": pd.read_csv('../room_results_raw_w_annealing.csv'),\n",
92 | " \"weighted depth\": pd.read_csv('../room_results_ours_wo_thresh.csv'),\n",
93 | " #\"weighted + annealed\": pd.read_csv('../room_results_ours_l2_full.csv'),\n",
94 | " \"weighted + annealed\": pd.read_csv('../room_results_ours_w_annealing.csv'),\n",
95 | "}\n",
96 | "plot_results(dfs, show=True, save=\"room_results\")"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "ac677682",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "dfs = {\n",
107 | " #\"ours_no_anneal_thresh\": pd.read_csv('../replica_office0_ours_wo_anneal_w_thresh.csv').iloc[9:,:],\n",
108 | " \"no depth\": pd.read_csv('../replica_office0_no_depth_2.csv'),\n",
109 | " #\"ours_median_annealed\": pd.read_csv('../replica_office0_ours_median_annealed.csv'),\n",
110 | " #\"ours_median_annealed\": pd.read_csv('../replica_office0_ours_w_anneal_w_median.csv'), \n",
111 | " #\"ours_wo_thresh_10\": pd.read_csv('../replica_office0_ours_wo_anneal_wo_thresh_3.csv'),\n",
112 | " #\"ours_wo_thresh\": pd.read_csv('../replica_office0_ours_wo_anneal_wo_thresh_2.csv'),\n",
113 | " \"raw\": pd.read_csv('../replica_office0_raw_2.csv'),\n",
114 | " \"ours_median\": pd.read_csv('../replica_office0_ours_wo_anneal_w_median.csv'),\n",
115 | " #\"ours_median_100\": pd.read_csv('../replica_office0_ours_wo_anneal_median_100_lucky.csv')\n",
116 | "}\n",
117 | "plot_results(dfs, show=True, save=\"room_results\", xlim=20000)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "7d52e51f",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "dfs = {\n",
128 | " \"room0\": pd.read_csv('../room0_nerf_results.csv'),\n",
129 | " \"room1\": pd.read_csv('../room1_nerf_results.csv'),\n",
130 | " \"room2\": pd.read_csv('../room2_nerf_results.csv'),\n",
131 | " \"office0\": pd.read_csv('../office0_nerf_results.csv'),\n",
132 | " \"office1\": pd.read_csv('../office1_nerf_results.csv'),\n",
133 | "}\n",
134 | "plot_results(dfs, show=True, save=\"room_results\", xlim=25000)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "id": "a7327f16",
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "dfs = {\n",
145 | " \"no depth\": pd.read_csv('../room_nerf_no_depth_results.csv'),\n",
146 | " \n",
147 | " #\"weighted+filtered\": pd.read_csv('../room_nerf_ours_w_thresh_results.csv'),\n",
148 | " \"raw\": pd.read_csv('../room_nerf_raw_results.csv'),\n",
149 | " \"weighted\": pd.read_csv('../room_nerf_ours_results.csv'),\n",
150 | "}\n",
151 | "plot_results(dfs, show=True, save=\"room_results\", xlim=25000)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "id": "0e9db743",
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "dfs = {\n",
162 | " \"no depth\": pd.read_csv('../room0_nerf_no_depth_results.csv'),\n",
163 | " \"ours\": pd.read_csv('../room0_nerf_ours_results.csv'),\n",
164 | " \"raw\": pd.read_csv('../room0_nerf_raw_results.csv'),\n",
165 | "}\n",
166 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "id": "7fa4ea26",
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "dfs = {\n",
177 | " \"no depth\": pd.read_csv('../room1_nerf_no_depth_results.csv'),\n",
178 | " \"ours\": pd.read_csv('../room1_nerf_ours_results.csv'),\n",
179 | " \"raw\": pd.read_csv('../room1_nerf_raw_results.csv'),\n",
180 | "}\n",
181 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "5676517b",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "dfs = {\n",
192 | " \"no depth\": pd.read_csv('../room2_nerf_no_depth_results.csv'),\n",
193 | " \"ours\": pd.read_csv('../room2_nerf_ours_results.csv'),\n",
194 | " \"raw\": pd.read_csv('../room2_nerf_raw_results.csv'),\n",
195 | "}\n",
196 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "42de5872",
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "dfs = {\n",
207 | " \"no depth\": pd.read_csv('../office0_nerf_no_depth_results.csv'),\n",
208 | " \"ours\": pd.read_csv('../office0_nerf_ours_results.csv'),\n",
209 | " \"raw\": pd.read_csv('../office0_nerf_raw_results.csv'),\n",
210 | "}\n",
211 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "id": "a45fc9c7",
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "dfs = {\n",
222 | " \"no depth\": pd.read_csv('../office1_nerf_no_depth_results.csv'),\n",
223 | " \"ours\": pd.read_csv('../office1_nerf_ours_results.csv'),\n",
224 | " \"raw\": pd.read_csv('../office1_nerf_raw_results.csv'),\n",
225 | "}\n",
226 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "id": "08cb721d",
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "dfs = {\n",
237 | " \"no depth\": pd.read_csv('../office2_nerf_no_depth_results.csv'),\n",
238 | " \"ours\": pd.read_csv('../office2_nerf_ours_results.csv'),\n",
239 | " \"raw\": pd.read_csv('../office2_nerf_raw_results.csv'),\n",
240 | "}\n",
241 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "id": "09db9418",
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "dfs = {\n",
252 | " \"no depth\": pd.read_csv('../office3_nerf_no_depth_results.csv'),\n",
253 | " \"ours\": pd.read_csv('../office3_nerf_ours_results.csv'),\n",
254 | " \"raw\": pd.read_csv('../office3_nerf_raw_results.csv'),\n",
255 | "}\n",
256 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "id": "e40a7acc",
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "dfs = {\n",
267 | " \"no depth\": pd.read_csv('../office4_nerf_no_depth_results.csv'),\n",
268 | " \"ours\": pd.read_csv('../office4_nerf_ours_results.csv'),\n",
269 | " \"raw\": pd.read_csv('../office4_nerf_raw_results.csv'),\n",
270 | "}\n",
271 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "id": "02bbbd47",
278 | "metadata": {},
279 | "outputs": [],
280 | "source": []
281 | }
282 | ],
283 | "metadata": {
284 | "kernelspec": {
285 | "display_name": "Python 3 (ipykernel)",
286 | "language": "python",
287 | "name": "python3"
288 | },
289 | "language_info": {
290 | "codemirror_mode": {
291 | "name": "ipython",
292 | "version": 3
293 | },
294 | "file_extension": ".py",
295 | "mimetype": "text/x-python",
296 | "name": "python",
297 | "nbconvert_exporter": "python",
298 | "pygments_lexer": "ipython3",
299 | "version": "3.8.10"
300 | }
301 | },
302 | "nbformat": 4,
303 | "nbformat_minor": 5
304 | }
305 |
--------------------------------------------------------------------------------
/scripts/download_cube.bash:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | mkdir -p Datasets
4 | cd Datasets
5 | git clone https://github.com/jc211/nerf-cube-diorama-dataset.git
6 |
--------------------------------------------------------------------------------
/scripts/download_replica.bash:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | mkdir -p Datasets
4 | cd Datasets
5 |
6 | # This is 9.2Gb of data: contains all the images, transforms (.json), and meshes (.ply).
7 | gdown 1tdioYdNGK6yZZdfyQKkQI84ALtEqd2fH
8 |
9 | unzip Replica.zip
10 |
11 |
--------------------------------------------------------------------------------
/scripts/download_replica_sample.bash:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | mkdir -p Datasets
4 | cd Datasets
5 |
6 | gdown 1f4RJ4W9uxlhCvihG12X9Wa4yZrlBD4TE
7 |
8 | mkdir -p Replica/
9 | unzip ReplicaSample.zip -d Replica
10 |
11 |
--------------------------------------------------------------------------------
/scripts/record_real_sense.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import sys
5 | sys.settrace
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 |
8 | from icecream import ic
9 | import argparse
10 |
11 | import numpy as np
12 |
13 | import cv2
14 |
15 | from datasets.data_module import DataModule
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(description="RealSense Recorder")
19 | return parser.parse_args()
20 |
21 |
22 | if __name__ == '__main__':
23 | args = parse_args()
24 |
25 | args.parallel_run = False
26 | args.dataset_dir = "/home/tonirv/Datasets/RealSense"
27 | args.dataset_name = "real"
28 |
29 | args.initial_k = 0
30 | args.final_k = 0
31 | args.img_stride = 1
32 | args.stereo = False
33 |
34 | data_provider_module = DataModule(args.dataset_name, args, device="cpu")
35 | data_provider_module.initialize_module()
36 |
37 | # Start once we press 's'
38 |     print("Waiting to start recording, press 's'.")
39 | cv2.imshow("Click 's' to start; 'q' to stop", np.ones((200,200)))
40 | while cv2.waitKey(33) != ord('s'):
41 | continue
42 |
43 | print("Recording")
44 | data_packets = []
45 | while cv2.waitKey(33) != ord('q'): # quit
46 | output = data_provider_module.spin_once("aha")
47 | data_packets += [output]
48 | print("Stopping")
49 | ic(len(data_packets))
50 |
51 | # Write images to disk, and save path with to_nerf() function
52 | print("Saving to nerf format")
53 | data_provider_module.dataset.to_nerf_format(data_packets)
54 | print('Done...')
55 |
56 |
--------------------------------------------------------------------------------
/scripts/replica_results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import sys
5 | sys.settrace
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 |
8 | from icecream import ic
9 |
10 | import argparse
11 | from tqdm import tqdm
12 | import copy
13 | import torch
14 |
15 | from examples.slam_demo import run
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(description="INSTANT SLAM")
19 | return parser.parse_args()
20 |
21 |
22 | if __name__ == '__main__':
23 | args = parse_args()
24 |
25 | args.parallel_run = True
26 | args.multi_gpu = True
27 |
28 | args.initial_k = 0
29 | args.final_k = 2000
30 |
31 | args.stereo = False
32 |
33 | args.weights = "droid.pth"
34 | args.network = ""
35 | args.width = 0
36 | args.height = 0
37 |
38 | args.dataset_dir = None
39 | args.dataset_name = "nerf"
40 |
41 | args.slam = True
42 | args.fusion = 'nerf'
43 | args.gui = True
44 |
45 | args.eval = True
46 |
47 | args.sigma_fusion = True
48 |
49 |
50 | args.buffer = 100
51 | args.img_stride = 1
52 |
53 | # "ours", "ours_w_thresh" or "raw", "no_depth"
54 | args1 = copy.deepcopy(args)
55 | args1.mask_type = "ours"
56 |
57 | args2 = copy.deepcopy(args)
58 | args2.mask_type = "no_depth"
59 |
60 | args3 = copy.deepcopy(args)
61 | args3.mask_type = "raw"
62 |
63 | args4 = copy.deepcopy(args)
64 | args4.mask_type = "ours_w_thresh"
65 |
66 | args_to_test = {
67 | #args1.mask_type: args1,
68 | #args2.mask_type: args2,
69 | args3.mask_type: args3,
70 | #args4.mask_type: args4,
71 | }
72 |
73 |
74 | results = "results.csv"
75 | datasets = [
76 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/room",
77 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/bluebell",
78 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/book",
79 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/cup",
80 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/laptop"
81 | #"/home/tonirv/Datasets/Replica/office0",
82 | #"/home/tonirv/Datasets/Replica/office2",
83 | #"/home/tonirv/Datasets/Replica/office3",
84 | #"/home/tonirv/Datasets/Replica/office4",
85 | #"/home/tonirv/Datasets/Replica/room1",
86 | #"/home/tonirv/Datasets/Replica/room2",
87 | # These get stuck not sure why....
88 | #"/home/tonirv/Datasets/Replica/room0",
89 | "/home/tonirv/Datasets/Replica/office1",
90 | ]
91 |
92 | torch.multiprocessing.set_start_method('spawn')
93 | torch.cuda.empty_cache()
94 | torch.backends.cudnn.benchmark = True
95 | torch.set_grad_enabled(False)
96 |
97 | for dataset in tqdm(datasets):
98 | for test_name, test_args in args_to_test.items():
99 | output = os.path.basename(dataset) + '_nerf_' + test_name + '_' + results
100 | ic(output)
101 | test_args.dataset_dir = dataset
102 | print(f"Processing dataset: {test_args.dataset_dir}")
103 | try:
104 | run(test_args)
105 | except Exception as e:
106 | print(e)
107 | print(f"Saving output in: {output}")
108 | # Copy transforms.json to its args.dataset_dir folder
109 | os.replace(results, os.path.join(output))
110 | #torch.cuda.empty_cache()
111 | print('Done...')
112 |
113 |
--------------------------------------------------------------------------------
/scripts/replica_to_nerf_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import sys
5 | sys.settrace
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 |
8 | import argparse
9 | from tqdm import tqdm
10 |
11 | from datasets.data_module import DataModule
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="INSTANT SLAM")
15 | parser.add_argument("--replica_dir", type=str,
16 | help="Path to the Replica dataset root dir",
17 | default="/home/tonirv/Datasets/Replica/")
18 |
19 | return parser.parse_args()
20 |
21 | if __name__ == '__main__':
22 | args = parse_args()
23 | args.parallel_run = True
24 | args.dataset_dir = None
25 | args.initial_k = 0 # first frame to load
26 | args.final_k = -1 # last frame to load, if -1 load all
27 | args.img_stride = 1 # stride for loading images
28 | args.stereo = False
29 | args.buffer = 3000
30 |
31 | transform = "transforms.json"
32 | dataset_names = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
33 | for dataset in tqdm(dataset_names):
34 | args.dataset_dir = os.path.join(args.replica_dir, dataset)
35 | print(f"Processing dataset: {args.dataset_dir}")
36 | # Parse dataset and transform to Nerf
37 | data_provider_module = DataModule("replica", args)
38 | data_provider_module.initialize_module()
39 | data_provider_module.dataset.to_nerf_format()
40 | # Copy transforms.json to its args.dataset_dir folder
41 | os.replace(transform, os.path.join(args.dataset_dir, transform))
42 | print('Done...')
43 |
44 |
--------------------------------------------------------------------------------
/scripts/unzip_tartan_air.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import glob
4 | import os
5 | import os.path as osp
6 | from tqdm import tqdm
7 | import shutil
8 |
9 | cur_path = osp.dirname(osp.abspath(__file__))
10 |
11 | # The dataset should be downloaded with:
12 | # python download_training.py --rgb --depth --only-left --output-dir ./datasets/TartanAir
13 |
14 | # Dirs inside datasets/TartanAir must be
15 | # {dataset}/{level}/P***/{depth_left,image_left,pose_left.txt}
16 |
17 | LEVELS = ['Easy', 'Hard']
18 |
19 | def unzip(tartanair_path='datasets/TartanAir', remove_zip=False):
20 | datasets_paths = glob.glob(osp.join(tartanair_path, "*"))
21 | for dataset in tqdm(sorted(datasets_paths)):
22 | dataset_name = os.path.basename(dataset)
23 | print("Dataset: %s" % dataset_name)
24 | for level in LEVELS:
25 | print("Level: %s"%(level))
26 |
27 | ### Form paths and pass checks
28 | dataset_level_path = osp.join(dataset, level)
29 | depth_dataset_level_path = osp.join(dataset_level_path, "depth_left.zip")
30 | if not osp.exists(depth_dataset_level_path):
31 | print("Missing Depth zip file for Dataset/Level: %s/%s" %(dataset, level))
32 | continue
33 | image_dataset_level_path = osp.join(dataset_level_path, "image_left.zip")
34 | if not osp.exists(image_dataset_level_path):
35 | print("Missing Image zip file for Dataset/Level: %s/%s"%(dataset, level))
36 | continue
37 |
38 | if osp.exists(osp.join(dataset_level_path, dataset_name)) or len(glob.glob(osp.join(dataset_level_path, "P*"))) != 0:
39 | print("Seems like the dataset was already unzipped? %s" % dataset)
40 | else:
41 | ### Unzip dataset
42 | command = "unzip -q -n %s -d %s"%(depth_dataset_level_path, dataset_level_path)
43 | print(command)
44 | os.system(command)
45 | if remove_zip:
46 | os.remove(depth_dataset_level_path)
47 | command = "unzip -q -n %s -d %s"%(image_dataset_level_path, dataset_level_path)
48 | print(command)
49 | os.system(command)
50 | if remove_zip:
51 | os.remove(image_dataset_level_path)
52 |
53 | ### Remove junk directories
54 | from_ = osp.join(dataset_level_path, "*/*/*/P*")
55 | if len(glob.glob(from_)) != 0: # We have junk folders
56 | to_ = dataset_level_path
57 | command = "mv %s %s"%(from_,to_)
58 | print(command)
59 | os.system(command)
60 | shutil.rmtree(osp.join(dataset_level_path,dataset_name))
61 |
62 | import argparse
63 |
64 | def dir_path(string):
65 | if os.path.isdir(string):
66 | return string
67 | else:
68 | raise NotADirectoryError(string)
69 |
70 | if __name__ == "__main__":
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument('--dataset_path', type=dir_path)
73 | parser.add_argument('--remove_zip', action="store_true")
74 | args = parser.parse_args()
75 |
76 | print("Unzipping TartanAir dataset")
77 | unzip(args.dataset_path, args.remove_zip)
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
5 |
6 | import os.path as osp
7 | ROOT = osp.dirname(osp.abspath(__file__))
8 |
9 | setup(
10 | name='droid_backends',
11 | ext_modules=[
12 | CUDAExtension('droid_backends',
13 | include_dirs=[osp.join(ROOT, 'thirdparty/eigen')],
14 | sources=[
15 | 'src/droid.cpp',
16 | 'src/droid_kernels.cu',
17 | 'src/correlation_kernels.cu',
18 | 'src/altcorr_kernel.cu',
19 | ],
20 | extra_compile_args={
21 | 'cxx': ['-O3'],
22 | 'nvcc': ['-O3',
23 | '-gencode=arch=compute_60,code=sm_60',
24 | '-gencode=arch=compute_61,code=sm_61',
25 | '-gencode=arch=compute_70,code=sm_70',
26 | '-gencode=arch=compute_75,code=sm_75',
27 | '-gencode=arch=compute_80,code=sm_80',
28 | '-gencode=arch=compute_86,code=sm_86',
29 | ]
30 | }),
31 | ],
32 | cmdclass={ 'build_ext' : BuildExtension }
33 | )
34 |
35 | setup(
36 | name='lietorch',
37 | version='0.2',
38 | description='Lie Groups for PyTorch',
39 | packages=['lietorch'],
40 | package_dir={'': 'thirdparty/lietorch'},
41 | ext_modules=[
42 | CUDAExtension('lietorch_backends',
43 | include_dirs=[
44 | osp.join(ROOT, 'thirdparty/lietorch/lietorch/include'),
45 | osp.join(ROOT, 'thirdparty/eigen')],
46 | sources=[
47 | 'thirdparty/lietorch/lietorch/src/lietorch.cpp',
48 | 'thirdparty/lietorch/lietorch/src/lietorch_gpu.cu',
49 | 'thirdparty/lietorch/lietorch/src/lietorch_cpu.cpp'],
50 | extra_compile_args={
51 | 'cxx': ['-O2'],
52 | 'nvcc': ['-O2',
53 | '-gencode=arch=compute_60,code=sm_60',
54 | '-gencode=arch=compute_61,code=sm_61',
55 | '-gencode=arch=compute_70,code=sm_70',
56 | '-gencode=arch=compute_75,code=sm_75',
57 | '-gencode=arch=compute_80,code=sm_80',
58 | '-gencode=arch=compute_86,code=sm_86',
59 | ]
60 | }),
61 | ],
62 | cmdclass={ 'build_ext' : BuildExtension }
63 | )
64 |
65 |
--------------------------------------------------------------------------------
/slam/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
--------------------------------------------------------------------------------
/slam/inertial_frontends/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/slam/inertial_frontends/inertial_frontend.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | from __future__ import print_function
5 |
6 | from abc import abstractclassmethod
7 |
8 | from icecream import ic
9 |
10 | import torch as th
11 | from torch import nn
12 | from factor_graph.variables import Variable, Variables
13 | from factor_graph.factor import Factor
14 | from factor_graph.factor_graph import TorchFactorGraph
15 |
16 | from slam.meta_slam import SLAM
17 | import numpy as np
18 |
19 | import gtsam
20 | from gtsam import (ImuFactor, Pose3, Rot3, Point3)
21 | from gtsam import PriorFactorPose3, PriorFactorConstantBias, PriorFactorVector
22 | from gtsam.symbol_shorthand import B, V, X
23 |
24 |
25 | def vector3(x, y, z):
26 | """Create 3d double numpy array."""
27 | return np.array([x, y, z], dtype=float)
28 |
29 |
30 | class InertialFrontend(nn.Module):
31 | def __init__(self):
32 | super().__init__()
33 |
34 | @abstractclassmethod
35 | def forward(self, mini_batch):
36 | pass
37 |
38 |
39 | import gtsam
40 | from gtsam import Values
41 | from gtsam import NonlinearFactorGraph as FactorGraph
42 | from gtsam.symbol_shorthand import B
43 |
44 | import torch as th
45 |
46 |
47 | class PreIntegrationInertialFrontend(InertialFrontend):
48 | def __init__(self):
49 | super().__init__()
50 | self.last_key = None
51 |
52 | # THESE ARE A LOT OF ALLOCATIONS: #frames * imu_buffer_size * 7 * 2 (because double precision!)
53 | # self.imu_buffer = 200
54 | # self.imu_t0_t1 = th.empty(self.imu_buffer, 7, device='cpu', dtype=th.float64)#.share_memory_()
55 |
56 |     def initialize_imu_frontend(self):
57 | pass
58 |
59 | def forward(self, mini_batch, last_state):
60 | # Call IMU preintegration from kimera/gtsam or re-implement...
61 | # The output of PreIntegrationInertialFrontend is a
62 | # bunch of (pose,vel,bias)-to-(pose,vel,bias) factors.
63 | print("PreIntegrationInertialFrontend.forward")
64 | k = mini_batch["k"]
65 | if last_state is None and k == 0:
66 | self.last_key = k
67 | # Initialize the preintegration
68 | self.imu_params = mini_batch["imu_calib"]
69 | self.preint_imu_params = self.get_preint_imu_params(self.imu_params)
70 | initial_state = self.initial_state()
71 | self.pim = gtsam.PreintegratedImuMeasurements(self.preint_imu_params,
72 | initial_state[2])
73 | # Set initial priors and values
74 | x0, factors = self.initial_priors_and_values(k, initial_state)
75 | return x0, factors
76 |
77 | imu_t0_t1 = mini_batch["imu_t0_t1"]
78 | # imu_meas_count = len(imu_t0_t1_df)
79 | # self.imu_t0_t1[:imu_meas_count] = th.as_tensor(imu_t0_t1_df, device='cpu', dtype=th.float64).cpu().numpy()
80 |
81 | # Integrate measurements between frames (101 measurements)
82 | #self.preintegrate_imu(self.imu_t0_t1, imu_meas_count)
83 | self.preintegrate_imu(imu_t0_t1, -1)
84 | #self.pim.print()
85 |
86 | #if k % 10 != 0: # simulate keyframe selection
87 | # return Values(), FactorGraph()
88 |
89 | # Get factors
90 | imu_factor = ImuFactor(X(self.last_key), V(self.last_key), X(k), V(k), B(self.last_key), self.pim)
91 | bias_btw_factor = self.get_bias_btw_factor(k, self.pim.deltaTij())
92 |
93 | # Add factors to inertial graph
94 | graph = FactorGraph()
95 | graph.add(imu_factor)
96 | graph.add(bias_btw_factor)
97 |
98 | print("LAST STATE")
99 | print(last_state)
100 | # Get guessed values (from IMU integration)
101 | last_W_Pose_B, last_W_Vel_B, last_imu_bias = \
102 | last_state.atPose3(X(self.last_key)), \
103 | last_state.atVector(V(self.last_key)),\
104 | last_state.atConstantBias(B(self.last_key))
105 | last_navstate = gtsam.NavState(last_W_Pose_B, last_W_Vel_B)
106 | new_navstate = self.pim.predict(last_navstate, last_imu_bias);
107 |
108 | x0 = Values()
109 | x0.insert(X(k), new_navstate.pose())
110 | x0.insert(V(k), new_navstate.velocity())
111 | x0.insert(B(k), last_imu_bias)
112 |
113 | self.last_key = k
114 | self.pim.resetIntegrationAndSetBias(last_imu_bias)
115 |
116 | return x0, graph
117 |
118 | # k: current key, n: number of imu measurements
119 | def preintegrate_imu(self, imu_t0_t1, n):
120 | meas_acc = imu_t0_t1[:n, 4:7]
121 | meas_gyr = imu_t0_t1[:n, 1:4]
122 | delta_t = (imu_t0_t1[1:n, 0] - imu_t0_t1[0:n-1, 0]) * 1e-9 # convert to seconds
123 | for acc, gyr, dt in zip(meas_acc, meas_gyr, delta_t):
124 | self.pim.integrateMeasurement(acc, gyr, dt)
125 | # TODO: fix this loop!
126 | #self.pim.integrateMeasurements(meas_acc, meas_gyr, delta_t)
127 |
128 | def get_bias_btw_factor(self, k, delta_t_ij):
129 | # Bias evolution as given in the IMU metadata
130 | sqrt_delta_t_ij = np.sqrt(delta_t_ij);
131 | bias_sigmas_acc = sqrt_delta_t_ij * self.imu_params.a_b * np.ones(3)
132 | bias_sigmas_gyr = sqrt_delta_t_ij * self.imu_params.g_b * np.ones(3)
133 | bias_sigmas = np.concatenate((bias_sigmas_acc, bias_sigmas_gyr), axis=None)
134 | bias_noise_model = gtsam.noiseModel.Diagonal.Sigmas(bias_sigmas)
135 | bias_value = gtsam.imuBias.ConstantBias()
136 | return gtsam.BetweenFactorConstantBias(B(self.last_key), B(k), bias_value, bias_noise_model)
137 |
138 | # Send IMU priors
139 | def initial_priors_and_values(self, k, initial_state):
140 | pose_key = X(k)
141 | vel_key = V(k)
142 | bias_key = B(k)
143 |
144 | pose_noise = gtsam.noiseModel.Diagonal.Sigmas(
145 | np.array([0.001, 0.001, 0.001, 0.01, 0.01, 0.01]))
146 | vel_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.001)
147 | bias_noise = gtsam.noiseModel.Isotropic.Sigma(6, 0.01)
148 |
149 | initial_pose, initial_vel, initial_bias = initial_state
150 |
151 | # Get inertial factors
152 | pose_prior = PriorFactorPose3(pose_key, initial_pose, pose_noise)
153 | vel_prior = PriorFactorVector(vel_key, initial_vel, vel_noise)
154 | bias_prior = PriorFactorConstantBias(bias_key, initial_bias, bias_noise)
155 |
156 | # Add factors to inertial graph
157 | graph = FactorGraph()
158 | graph.push_back(pose_prior)
159 | graph.push_back(vel_prior)
160 | graph.push_back(bias_prior)
161 |
162 | # Get guessed values
163 | x0 = Values()
164 | x0.insert(pose_key, initial_pose)
165 | x0.insert(vel_key, initial_vel)
166 | x0.insert(bias_key, initial_bias)
167 |
168 | return x0, graph
169 |
170 | def initial_state(self):
171 | true_pose = gtsam.Pose3(gtsam.Rot3(0.060514,-0.828459,-0.058956,-0.553641), # qw, qx, qy, qz
172 | gtsam.Point3(0.878612,2.142470,0.947262))
173 | true_vel = np.array([0.009474,-0.014009,-0.002145])
174 | true_bias = gtsam.imuBias.ConstantBias(np.array([-0.012492,0.547666,0.069073]), np.array([-0.002229,0.020700,0.076350]))
175 | naive_pose = gtsam.Pose3.identity()
176 | naive_vel = np.zeros(3)
177 | naive_bias = gtsam.imuBias.ConstantBias()
178 | initial_pose = true_pose
179 | initial_vel = true_vel
180 | initial_bias = true_bias
181 | return initial_pose, initial_vel, initial_bias
182 |
183 | def get_preint_imu_params(self, imu_calib):
184 | I = np.eye(3)
185 | preint_params = gtsam.PreintegrationParams(imu_calib.n_gravity);
186 | preint_params.setAccelerometerCovariance(np.power(imu_calib.a_n, 2.0) * I)
187 | preint_params.setGyroscopeCovariance(np.power(imu_calib.g_n, 2.0) * I)
188 | preint_params.setIntegrationCovariance(np.power(imu_calib.imu_integration_sigma, 2.0) * I)
189 | preint_params.setUse2ndOrderCoriolis(False)
190 | preint_params.setOmegaCoriolis(np.zeros(3, dtype=float))
191 | preint_params.print()
192 | return preint_params
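A minimal gtsam preintegration sketch mirroring get_preint_imu_params and preintegrate_imu above (not part of this file): the noise magnitudes, rate, and synthetic at-rest measurements are illustrative assumptions.

import numpy as np
import gtsam

params = gtsam.PreintegrationParams.MakeSharedU(9.81)  # gravity along -Z
I = np.eye(3)
params.setAccelerometerCovariance(1e-3 * I)
params.setGyroscopeCovariance(1e-4 * I)
params.setIntegrationCovariance(1e-6 * I)

bias = gtsam.imuBias.ConstantBias()  # zero accelerometer/gyro bias
pim = gtsam.PreintegratedImuMeasurements(params, bias)

# 1 second of synthetic 100 Hz measurements for a sensor at rest:
# the accelerometer measures the specific force countering gravity.
for _ in range(100):
    pim.integrateMeasurement(np.array([0.0, 0.0, 9.81]), np.zeros(3), 0.01)

start = gtsam.NavState(gtsam.Pose3(), np.zeros(3))
print(pim.predict(start, bias))  # ~identity pose, ~zero velocity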
--------------------------------------------------------------------------------
/slam/meta_slam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from abc import abstractclassmethod
4 | from icecream import ic
5 |
6 | from torch import nn
7 |
8 | from factor_graph.variables import Variable, Variables
9 | from factor_graph.factor_graph import FactorGraphManager
10 |
11 | from solvers.nonlinear_solver import iSAM2, LevenbergMarquardt, Solver
12 |
13 | # An abstract learned SLAM model
14 | class SLAM(nn.Module):
15 | def __init__(self, name, args, device):
16 | super().__init__()
17 | self.name = name
18 | self.args = args
19 | self.device = device
20 | self.factor_graph_manager = FactorGraphManager()
21 | self.state = None
22 | self.delta = None
23 |
24 | # This is our spin_once
25 | def forward(self, batch):
26 | # TODO Parallelize frontend/backend
27 | assert("data" in batch)
28 |
29 | # Frontend
30 | output = self._frontend(batch["data"], self.state, self.delta)
31 | if output == False:
32 | return output
33 | else:
34 | x0, factors, viz_out = output
35 |
36 | # Backend
37 | self.factor_graph_manager.add(factors)
38 | factor_graph = self.factor_graph_manager.get_factor_graph()
39 | self.state, self.delta = self._backend(factor_graph, x0)
40 | if type(self.backend) is iSAM2:
41 | self.factor_graph_manager.reset_factor_graph()
42 | self.backend = iSAM2() # uncomment if running DroidSLAM
43 | return [self.state, viz_out]
44 |
45 | # Converts sensory inputs to factors and initial guess (x0)
46 | @abstractclassmethod # Implemented by derived classes
47 | def _frontend(self, mini_batch, last_state, last_delta):
48 | raise
49 |
50 | # Solves the factor graph, given an initial estimate
51 | @abstractclassmethod # Implemented by derived classes
52 | def _backend(self, factor_graph, x0):
53 | raise
54 |
55 |
56 | class MetaSLAM(SLAM):
57 |     def __init__(self, name, args, device):
58 |         super().__init__(name, args, device)
59 |
60 | @abstractclassmethod
61 | def calculate_loss(self, x, mini_batch):
62 | # mini_batch contains the ground-truth parameters as well
63 | pass
64 |
--------------------------------------------------------------------------------
/slam/slam_module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from pipeline.pipeline_module import MIMOPipelineModule
4 |
5 | class SlamModule(MIMOPipelineModule):
6 | def __init__(self, name, args, device="cpu"):
7 | super().__init__(name, args.parallel_run, args)
8 | self.device = device
9 |
10 | def spin_once(self, input):
11 | output = self.slam(input)
12 | if not output or self.slam.stop_condition():
13 | super().shutdown_module()
14 | return output
15 |
16 | def initialize_module(self):
17 | if self.name == "VioSLAM":
18 | from slam.vio_slam import VioSLAM
19 | self.slam = VioSLAM(self.name, self.args, self.device)
20 | else:
21 | raise NotImplementedError
22 | return super().initialize_module()
23 |
--------------------------------------------------------------------------------
/slam/vio_slam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from abc import abstractclassmethod
4 |
5 | from icecream import ic
6 |
7 | from torch import nn
8 | from factor_graph.variables import Variable, Variables
9 | from factor_graph.factor_graph import TorchFactorGraph
10 |
11 | from gtsam import Values
12 | from gtsam import NonlinearFactorGraph
13 | from gtsam import GaussianFactorGraph
14 |
15 | from slam.meta_slam import SLAM
16 | from slam.inertial_frontends.inertial_frontend import PreIntegrationInertialFrontend
17 | from slam.visual_frontends.visual_frontend import RaftVisualFrontend
18 | from solvers.nonlinear_solver import iSAM2, LevenbergMarquardt, Solver
19 |
20 | ########################### REMOVE ############################
21 | import numpy as np
22 | import gtsam
23 | from gtsam import (ImuFactor, Pose3, Rot3, Point3)
24 | from gtsam import PriorFactorPose3, PriorFactorConstantBias, PriorFactorVector
25 | from gtsam.symbol_shorthand import B, V, X
26 |
27 |
28 | # Send IMU priors
29 | def initial_priors_and_values(k, initial_state):
30 | pose_key = X(k)
31 | vel_key = V(k)
32 | bias_key = B(k)
33 |
34 | pose_noise = gtsam.noiseModel.Diagonal.Sigmas(
35 | np.array([0.000001, 0.000001, 0.000001, 0.00001, 0.00001, 0.00001]))
36 | vel_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.000001)
37 | bias_noise = gtsam.noiseModel.Isotropic.Sigma(6, 0.00001)
38 |
39 | initial_pose, initial_vel, initial_bias = initial_state
40 |
41 | # Get inertial factors
42 | pose_prior = PriorFactorPose3(pose_key, initial_pose, pose_noise)
43 | vel_prior = PriorFactorVector(vel_key, initial_vel, vel_noise)
44 | bias_prior = PriorFactorConstantBias(bias_key, initial_bias, bias_noise)
45 |
46 | # Add factors to inertial graph
47 | graph = NonlinearFactorGraph()
48 | graph.push_back(pose_prior)
49 | graph.push_back(vel_prior)
50 | graph.push_back(bias_prior)
51 |
52 | # Get guessed values
53 | x0 = Values()
54 | x0.insert(pose_key, initial_pose)
55 | x0.insert(vel_key, initial_vel)
56 | x0.insert(bias_key, initial_bias)
57 |
58 | return x0, graph
59 |
60 | def initial_state():
61 | true_world_T_imu_t0 = gtsam.Pose3(gtsam.Rot3(0.060514, -0.828459, -0.058956, -0.553641), # qw, qx, qy, qz
62 | gtsam.Point3(0.878612, 2.142470, 0.947262))
63 | true_vel = np.array([0.009474,-0.014009,-0.002145])
64 | true_bias = gtsam.imuBias.ConstantBias(np.array([-0.012492,0.547666,0.069073]), np.array([-0.002229,0.020700,0.076350]))
65 | naive_pose = gtsam.Pose3.identity()
66 | naive_vel = np.zeros(3)
67 | naive_bias = gtsam.imuBias.ConstantBias()
68 |     initial_pose = true_world_T_imu_t0  # ground-truth initialization (overridden below)
69 |     initial_vel = true_vel
70 |     initial_bias = true_bias
71 |     initial_pose = naive_pose           # naive initialization is the one actually used
72 |     initial_vel = naive_vel
73 |     initial_bias = naive_bias
74 | return initial_pose, initial_vel, initial_bias
75 | ###############################################################
76 |
77 |
78 | class VioSLAM(SLAM):
79 | def __init__(self, name, args, device):
80 | super().__init__(name, args, device)
81 | world_T_imu_t0, _, _ = initial_state()
82 | imu_T_cam0 = Pose3(np.array([[0.0148655429818, -0.999880929698, 0.00414029679422, -0.0216401454975],
83 | [0.999557249008, 0.0149672133247, 0.025715529948, -0.064676986768],
84 | [-0.0257744366974, 0.00375618835797, 0.999660727178, 0.00981073058949],
85 | [0.0, 0.0, 0.0, 1.0]]))
86 |
87 |         imu_T_cam0 = Pose3(np.eye(4))  # NOTE: overrides the calibrated extrinsics above with identity
88 |
89 | #world_T_imu_t0 = Pose3(args.world_T_imu_t0)
90 | #world_T_imu_t0 = Pose3(np.eye(4))
91 | world_T_imu_t0 = Pose3(np.array(
92 | [[-7.6942980e-02, -3.1037781e-01, 9.4749427e-01, 8.9643948e-02],
93 | [-2.8366595e-10, -9.5031142e-01, -3.1130061e-01, 4.1829333e-01],
94 | [ 9.9703550e-01, -2.3952398e-02, 7.3119797e-02, 4.8306200e-01],
95 | [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0000000e+00]]))
96 |
97 | self.visual_frontend = RaftVisualFrontend(world_T_imu_t0, imu_T_cam0, args, device=device)
98 | #self.inertial_frontend = PreIntegrationInertialFrontend()
99 |
100 | self.backend = iSAM2()
101 |
102 | self.values = Values()
103 | self.inertial_factors = NonlinearFactorGraph()
104 | self.visual_factors = NonlinearFactorGraph()
105 |
106 | self.last_state = None
107 |
108 | def stop_condition(self):
109 | return self.visual_frontend.stop_condition()
110 |
111 | # Converts sensory inputs to measurements and initial guess
112 | def _frontend(self, batch, last_state, last_delta):
113 | # Compute optical flow
114 | x0_visual, visual_factors, viz_out = self.visual_frontend(batch) # TODO: currently also calls BA, and global BA
115 | self.last_state = x0_visual
116 |
117 | if x0_visual is None:
118 | return False
119 |
120 | # Wrap guesses
121 | x0 = Values()
122 | factors = NonlinearFactorGraph()
123 |
124 | return x0, factors, viz_out
125 |
126 | def _backend(self, factor_graph, x0):
127 | return self.backend.solve(factor_graph, x0)
128 |
129 |
--------------------------------------------------------------------------------
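
The debug block above builds pose/velocity/bias priors and hands them to iSAM2 (via the Solver wrappers in /solvers/nonlinear_solver.py). The standalone sketch below mirrors that pattern with GTSAM directly; it is illustrative only, using default-initialized state rather than the EuRoC values hard-coded above.

import numpy as np
import gtsam
from gtsam import (NonlinearFactorGraph, Values,
                   PriorFactorPose3, PriorFactorVector, PriorFactorConstantBias)
from gtsam.symbol_shorthand import B, V, X

k = 0
pose0, vel0, bias0 = gtsam.Pose3(), np.zeros(3), gtsam.imuBias.ConstantBias()

pose_noise = gtsam.noiseModel.Diagonal.Sigmas(np.full(6, 1e-5))
vel_noise = gtsam.noiseModel.Isotropic.Sigma(3, 1e-6)
bias_noise = gtsam.noiseModel.Isotropic.Sigma(6, 1e-5)

graph = NonlinearFactorGraph()
graph.push_back(PriorFactorPose3(X(k), pose0, pose_noise))
graph.push_back(PriorFactorVector(V(k), vel0, vel_noise))
graph.push_back(PriorFactorConstantBias(B(k), bias0, bias_noise))

x0 = Values()
x0.insert(X(k), pose0)
x0.insert(V(k), vel0)
x0.insert(B(k), bias0)

isam = gtsam.ISAM2()
isam.update(graph, x0)                 # one update, as in iSAM2.solve()
print(isam.calculateEstimate().atPose3(X(k)))
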
/slam/visual_frontends/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
--------------------------------------------------------------------------------
/solvers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/solvers/linear_solver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import enum
4 | import colored_glog as log
5 | import torch as th
6 |
7 | class LinearSolverType(enum.Enum):
8 | Inverse = 0
9 | Cholesky = 1
10 | QR = 2 # to be implemented
11 | MultifrontalQR = 3 # to be implemented
12 | CG = 4 # Conjugate Gradient, to be implemented
13 | # Gauss-Seidel for dense flow?
14 | # Use scipy.optimize...
15 |
16 | # Linear Solver
17 | class LinearLS:
18 |     # loss = |Y - [X, 1] @ [w; b]|^2_sigma
19 |     # i.e. A == [X, 1], b == Y - [X, 1] @ [w_hat; b_hat] (residual at the current estimate), x == [w; b]
20 | @staticmethod
21 | def solve_inverse(A, b, E):
22 | # loss = |Ax - b|^2_sigma(\theta)
23 | At = th.transpose(A, 1, 2)
24 | AtE = At @ E
25 | AtEA = AtE @ A
26 | AtEb = AtE @ b
27 | x = th.linalg.inv(AtEA) @ AtEb
28 | return x
29 |
30 | # TODO(Toni): We should make the sparse version of this
31 | # A has size: (B, M, N) where B = batch, M = measurements, N = variables
32 |     # w has size: (B, M) and is the weight for each measurement (assumes diagonal weights)
33 |     # b has size: (B, M)
34 | @staticmethod
35 | def solve_cholesky(A: th.Tensor, b: th.Tensor, w: th.Tensor) -> th.Tensor:
36 | # loss = |Ax - b|^2_w
37 | # Solve all Batch linear problems
38 | if A is None or b is None or w is None:
39 | return None
40 |
41 | # Check dimensionality, first is the batch dimension
42 | B, M = b.shape
43 | _, _, N = A.shape
44 | log.check_eq(th.Size([B, M, N]), A.shape)
45 | log.check_eq(th.Size([B, M]), w.shape) # We assume E diagonal: E=diag(w)
46 |         At = A.transpose(1, 2)
47 |         AtWt = At * w.unsqueeze(1)  # A^T @ diag(w): scale each column of A^T by its measurement weight
48 |         AtEA = AtWt @ A             # normal-equation matrix A^T E A, with E = diag(w)
49 |         AtWtb = AtWt @ b.unsqueeze(-1)  # right-hand side A^T E b
50 | L = th.linalg.cholesky(AtEA)
51 | y = th.linalg.solve_triangular(L, AtWtb, upper=False)
52 | x = th.linalg.solve_triangular(th.transpose(L, 1, 2), y, upper=True)
53 | return x.squeeze(-1)
54 |
55 | # TODO(Toni): We should make the sparse version of this
56 | # A has size: (B, M, N) where B = batch, M = measurements, N = variables <- In standard form
57 |     # b has size: (B, M)
58 | # @staticmethod
59 | # def solve_cholesky(A, b):
60 | # # loss = |Ax - b|^2_2
61 | # # Solve all Batch linear problems
62 |
63 | # # Check dimensionality, first is the batch dimension
64 | # B, M = b.shape
65 | # _, _, N = A.shape
66 | # log.check_eq(th.Size([B, M, N]), A.shape)
67 | # At = A.transpose(1, 2) # Batched A transposes
68 | # AtA = At @ A
69 | # Atb = At @ b.unsqueeze(-1)
70 | # L = th.linalg.cholesky(AtA)
71 | # y = th.linalg.solve_triangular(L, Atb, upper=False)
72 | # x = th.linalg.solve_triangular(th.transpose(L, 1, 2), y, upper=True)
73 | # return x.squeeze(-1)
74 |
75 |     # For PyTorch >= 1.11
76 | @staticmethod
77 | def solve_cholesky_11(A, b, E):
78 | # loss = |Ax - b|^2_sigma(\theta)
79 | At = th.transpose(A, 1, 2)
80 | AtE = At @ E
81 | AtEA = AtE @ A
82 | AtEb = AtE @ b
83 | L = th.linalg.cholesky(AtEA)
84 |         x = th.cholesky_solve(AtEb, L, upper=False)
85 | return x
86 |
87 |
--------------------------------------------------------------------------------
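
The Cholesky path above solves the weighted normal equations per batch element. The snippet below is a small standalone sanity check (not repo code): with unit weights it should agree with torch.linalg.lstsq, assuming the repo and its dependencies (torch, colored_glog) are importable.

import torch as th
from solvers.linear_solver import LinearLS

B, M, N = 4, 50, 3                                   # batch, measurements, variables
A = th.randn(B, M, N, dtype=th.float64)
x_true = th.randn(B, N, 1, dtype=th.float64)
b = (A @ x_true).squeeze(-1)                         # noiseless measurements, shape (B, M)
w = th.ones(B, M, dtype=th.float64)                  # unit weights -> ordinary least squares

x_chol = LinearLS.solve_cholesky(A, b, w)                           # (B, N)
x_lstsq = th.linalg.lstsq(A, b.unsqueeze(-1)).solution.squeeze(-1)  # (B, N)
print(th.allclose(x_chol, x_lstsq, atol=1e-8))                      # expected: True
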
/solvers/meta_solver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import torch as th
4 | from torch.utils.data import DataLoader
5 |
6 | from slam.meta_slam import MetaSLAM
7 |
8 | class MetaSolver:
9 | def __init__(self, epochs: int, optimizer: th.optim.Optimizer, meta_slam: MetaSLAM):
10 | self.epochs = epochs
11 | self.optimizer = optimizer
12 | self.meta_slam = meta_slam
13 | self.log_every_n = 1
14 |
15 | def solve(self, dataloader: DataLoader):
16 | for epoch in range(self.epochs):
17 | print(f"Epoch {epoch}\n-------------------------------")
18 | for i, mini_batch in enumerate(dataloader):
19 | print(f"Batch {i}\n-------------------------------")
20 |
21 | # Build and solve factor graphs
22 | x = self.meta_slam(mini_batch)
23 |                 if getattr(self, "logging_callback", None):  # logging hook is optional
24 | self.logging_callback()
25 |
26 | # Compute and print loss
27 | if self.optimizer:
28 | loss = self.meta_slam.calculate_loss(x, mini_batch)
29 | if i % self.log_every_n == 0:
30 | print(f"outer_loss: {loss.item():>7f} [{i:>5d}/{len(dataloader):>5d}]")
31 |
32 | # Zero gradients, perform a backward pass, and update the weights from the outer optimization
33 | self.optimizer.zero_grad()
34 | loss.backward()
35 | self.optimizer.step()
36 |
37 | def register_logging_callback(self, log):
38 | self.logging_callback = log
--------------------------------------------------------------------------------
/solvers/nonlinear_solver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from abc import abstractmethod
4 |
5 | import colored_glog as log
6 |
7 | import torch as th
8 |
9 | from factor_graph.variables import Variable, Variables
10 | from factor_graph.factor_graph import TorchFactorGraph
11 |
12 | from solvers.linear_solver import LinearLS, LinearSolverType
13 |
14 | import gtsam
15 | from gtsam import (ISAM2, LevenbergMarquardtOptimizer, NonlinearFactorGraph, PriorFactorPose2, Values)
16 | from gtsam import NonlinearFactorGraph as FactorGraph
17 |
18 | from icecream import ic
19 |
20 | class Solver:
21 | def __init__(self):
22 | pass
23 |
24 |     @abstractmethod
25 |     def solve(self, factor_graph, x0):
26 |         raise NotImplementedError
27 |
28 | class iSAM2(Solver):
29 |
30 | def __init__(self):
31 | super().__init__()
32 | # Set ISAM2 parameters and create ISAM2 solver object
33 | isam_params = gtsam.ISAM2Params()
34 |
35 | # Dogleg params
36 | dogleg_params = gtsam.ISAM2DoglegParams()
37 | dogleg_params.setInitialDelta(1.0)
38 | dogleg_params.setWildfireThreshold(1e-5)
39 | dogleg_params.setVerbose(True)
40 | # dogleg_params.setAdaptationMode(string adaptationMode);
41 |
42 | # Gauss-Newton params
43 | gauss_newton_params = gtsam.ISAM2GaussNewtonParams()
44 | gauss_newton_params.setWildfireThreshold(1e-5)
45 |
46 | # Optimization parameters
47 | isam_params.setOptimizationParams(gauss_newton_params)
48 | isam_params.setFactorization("CHOLESKY") # QR or Cholesky
49 |
50 | # Linearization parameters
51 | isam_params.enableRelinearization = True
52 | isam_params.enablePartialRelinearizationCheck = False
53 | isam_params.setRelinearizeThreshold(0.1) # TODO
54 | isam_params.relinearizeSkip = 10
55 |
56 | # Memory efficiency, but slower
57 | isam_params.findUnusedFactorSlots = True
58 |
59 | # Debugging parameters, disable for speed
60 | isam_params.evaluateNonlinearError = True
61 | isam_params.enableDetailedResults = True
62 |
63 | #isam_params.print()
64 |
65 | self.isam2 = gtsam.ISAM2(isam_params)
66 |
67 | def solve(self, factor_graph, x0):
68 | # factors_to_remove = factor_graph.keyVector() # since all are structureless, and we re-add all inertial
69 | # print(factors_to_remove)
70 | # self.isam2.update(factor_graph, x0, factors_to_remove)
71 | self.isam2.update(factor_graph, x0) # Only one iteration!!
72 | result = self.isam2.calculateEstimate()
73 | delta = self.isam2.getDelta()
74 | return result, delta
75 |
76 | # class GaussNewton(Solver):
77 | # def __init__(self):
78 | # super().__init__()
79 | # self.params = gtsam.GaussNewtonParams()
80 | # self.params.setVerbosityLM("SUMMARY")
81 | # self.x0 = Values()
82 | #
83 | # def solve(self, factor_graph, x0):
84 | # self.x0.insert(x0)
85 | # optimizer = gtsam.LevenbergMarquardtOptimizer(factor_graph, self.x0, self.params)
86 | # return optimizer.optimize()
87 |
88 | class LevenbergMarquardt(Solver):
89 | def __init__(self):
90 | super().__init__()
91 | self.params = gtsam.LevenbergMarquardtParams()
92 | # static void SetLegacyDefaults(LevenbergMarquardtParams* p) {
93 | # // Relevant NonlinearOptimizerParams:
94 | # p->maxIterations = 100;
95 | # p->relativeErrorTol = 1e-5;
96 | # p->absoluteErrorTol = 1e-5;
97 | # // LM-specific:
98 | # p->lambdaInitial = 1e-5;
99 | # p->lambdaFactor = 10.0;
100 | # p->lambdaUpperBound = 1e5;
101 | # p->lambdaLowerBound = 0.0;
102 | # p->minModelFidelity = 1e-3;
103 | # p->diagonalDamping = false;
104 | # p->useFixedLambdaFactor = true;
105 | self.params.setVerbosityLM("SUMMARY")
106 | self.x0 = Values()
107 |
108 | def solve(self, factor_graph, x0):
109 | self.x0.insert(x0)
110 | optimizer = gtsam.LevenbergMarquardtOptimizer(factor_graph, self.x0, self.params)
111 | return optimizer.optimize()
112 |
113 |
114 | class NonlinearLS(Solver):
115 | def __init__(self, iterations=2, min_x_update=0.001, linear_solver=LinearSolverType.Cholesky):
116 | super().__init__()
117 | self.iters = iterations
118 | self.min_x_update = min_x_update
119 | if linear_solver == LinearSolverType.Inverse:
120 | self.linear_solver = LinearLS.solve_inverse
121 | elif linear_solver == LinearSolverType.Cholesky:
122 | self.linear_solver = LinearLS.solve_cholesky
123 | else:
124 | raise ValueError(f'Unknown linear solver: {linear_solver}')
125 |
126 | # Set all data members to 0
127 | self._reset_history()
128 |
129 | # Logging
130 | self.log_every_n = 1
131 |
132 | # f is a nonlinear function of the states resulting in measurements
133 | # w is the measurement weights
134 | # x0 is the linearization point (initial guess)
135 | def solve(self, fg: TorchFactorGraph, x0: Variables) -> th.Tensor:
136 | # Start iterative call to nonlinear solve
137 | self._reset_history()
138 | return self._solve_nonlinear(fg, x0, 1)
139 |
140 | # TODO(Toni): abstract this, and perhaps even call iSAM2 from gtsam...
141 | # f is expected to be in standard form, meaning weighted/whitened.
142 | def _solve_nonlinear(self, fg: TorchFactorGraph, x0: Variables, i: int) -> th.Tensor:
143 | if fg.is_empty():
144 | log.warn("Factor graph is empty, returning initial guess...")
145 | return x0
146 |
147 | # 1. Linearize nonlinear system with respect to x, this can do Schur complement as well.
148 | A, b, w = fg.linearize(x0)
149 |
150 | # 2. Solve linear system
151 | delta_x = self.linear_solver(A, b, w)
152 |
153 | if delta_x is None:
154 | log.warn("Linear system is not invertible, returning initial guess...")
155 | return x0
156 |
157 | # 3. Retract/Update
158 | x_new = Variables()
159 | for var in x0.retract(delta_x).items():
160 | x_new.add(var[1])
161 |
162 | # 4. Store best solutions so far
163 | with th.no_grad():
164 | loss_new = fg(x_new)
165 | if self.x_best is None or th.sum(loss_new) < th.sum(self.loss_best):
166 | self.x_best = x_new
167 | self.loss_best = loss_new
168 |
169 | # Logging
170 | self._logging(x0, delta_x, loss_new, i)
171 |
172 | # 5. Repeat until we reach termination condition
173 | return x_new if self.terminate(i, x_new) else self._solve_nonlinear(fg, x_new, i + 1)
174 |
175 | def _logging(self, x0, delta_x, loss, i):
176 | self.x0_list.append(x0)
177 | self.delta_x_list.append(delta_x)
178 |
179 | if i % self.log_every_n == 0:
180 | # Print sum over batches loss
181 | print(f"inner_loss: {th.sum(loss):>7f} [{i:>5d}/{self.iters:>5d}]")
182 |
183 |
184 | # should be abstract
185 | def terminate(self, i, x):
186 | # one implementation is just to look at how many iters we have done
187 | convergence = False
188 | # with th.no_grad():
189 | # ic(x.shape)
190 | # # We need to determine per-batch convergence :O
191 | # # And how do we avoid per-batch updates for the converged ones?
192 | # convergence = th.abs(x - self.x0_list[-1]) < self.x_tol
193 | # convergence = self.delta_x_list[-1] < self.min_x_update
194 | return i >= self.iters or convergence
195 |
196 |
197 | def _reset_history(self):
198 | self.x_best = None
199 | self.loss_best = None
200 | self.A_list = []
201 | self.b_list = []
202 | self.x0_list = []
203 | self.delta_x_list = []
204 |
--------------------------------------------------------------------------------
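
NonlinearLS repeats a linearize / linear-solve / retract cycle until terminate() fires. The toy Gauss-Newton loop below (standalone, not repo code) runs the same cycle on a 1-D curve-fitting problem, with plain torch calls standing in for the factor graph and LinearLS.

import torch as th

# Fit y = exp(a * t) for the unknown scalar a, Gauss-Newton style.
t = th.linspace(0.0, 1.0, 20, dtype=th.float64)
a_true = 1.3
y = th.exp(a_true * t)

a = th.tensor(0.0, dtype=th.float64)            # initial guess (x0)
for i in range(10):
    r = y - th.exp(a * t)                       # residual at the current estimate
    J = (t * th.exp(a * t)).unsqueeze(-1)       # Jacobian d f / d a, shape (M, 1)
    L = th.linalg.cholesky(J.T @ J)             # linearize + Cholesky, as in LinearLS.solve_cholesky
    delta = th.cholesky_solve(J.T @ r.unsqueeze(-1), L).squeeze()
    a = a + delta                               # retract (plain addition for a scalar)
    if delta.abs() < 1e-12:                     # terminate(): negligible update
        break
print(float(a))                                 # ~1.3
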
/src/correlation_kernels.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <cuda_fp16.h>
5 | #include <vector>
6 | #include <iostream>
7 |
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/NativeFunctions.h>
11 | #include <ATen/Parallel.h>
12 |
13 | #define BLOCK 16
14 |
15 | __forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
16 | return h >= 0 && h < H && w >= 0 && w < W;
17 | }
18 |
19 | template <typename scalar_t>
20 | __global__ void corr_index_forward_kernel(
21 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume,
22 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> coords,
23 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
24 | int r)
25 | {
26 | // batch index
27 | const int x = blockIdx.x * blockDim.x + threadIdx.x;
28 | const int y = blockIdx.y * blockDim.y + threadIdx.y;
29 | const int n = blockIdx.z;
30 |
31 | const int h1 = volume.size(1);
32 | const int w1 = volume.size(2);
33 | const int h2 = volume.size(3);
34 | const int w2 = volume.size(4);
35 |
36 | if (!within_bounds(y, x, h1, w1)) {
37 | return;
38 | }
39 |
40 | float x0 = coords[n][0][y][x];
41 | float y0 = coords[n][1][y][x];
42 |
43 | float dx = x0 - floor(x0);
44 | float dy = y0 - floor(y0);
45 |
46 | int rd = 2*r + 1;
47 |   for (int i=0; i<rd+1; i++) {
48 |     for (int j=0; j<rd+1; j++) {
49 |       int x1 = static_cast<int>(floor(x0)) - r + i;
50 |       int y1 = static_cast<int>(floor(y0)) - r + j;
51 |
52 | if (within_bounds(y1, x1, h2, w2)) {
53 | scalar_t s = volume[n][y][x][y1][x1];
54 |
55 | if (i > 0 && j > 0)
56 | corr[n][i-1][j-1][y][x] += s * scalar_t(dx * dy);
57 |
58 | if (i > 0 && j < rd)
59 | corr[n][i-1][j][y][x] += s * scalar_t(dx * (1.0f-dy));
60 |
61 | if (i < rd && j > 0)
62 | corr[n][i][j-1][y][x] += s * scalar_t((1.0f-dx) * dy);
63 |
64 | if (i < rd && j < rd)
65 | corr[n][i][j][y][x] += s * scalar_t((1.0f-dx) * (1.0f-dy));
66 |
67 | }
68 | }
69 | }
70 | }
71 |
72 |
73 | template <typename scalar_t>
74 | __global__ void corr_index_backward_kernel(
75 |     const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> coords,
76 |     const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
77 |     torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume_grad,
78 | int r)
79 | {
80 | // batch index
81 | const int x = blockIdx.x * blockDim.x + threadIdx.x;
82 | const int y = blockIdx.y * blockDim.y + threadIdx.y;
83 | const int n = blockIdx.z;
84 |
85 | const int h1 = volume_grad.size(1);
86 | const int w1 = volume_grad.size(2);
87 | const int h2 = volume_grad.size(3);
88 | const int w2 = volume_grad.size(4);
89 |
90 | if (!within_bounds(y, x, h1, w1)) {
91 | return;
92 | }
93 |
94 | float x0 = coords[n][0][y][x];
95 | float y0 = coords[n][1][y][x];
96 |
97 | float dx = x0 - floor(x0);
98 | float dy = y0 - floor(y0);
99 |
100 | int rd = 2*r + 1;
101 |   for (int i=0; i<rd+1; i++) {
102 |     for (int j=0; j<rd+1; j++) {
103 |       int x1 = static_cast<int>(floor(x0)) - r + i;
104 |       int y1 = static_cast<int>(floor(y0)) - r + j;
105 |
106 | if (within_bounds(y1, x1, h2, w2)) {
107 | scalar_t g = 0.0;
108 | if (i > 0 && j > 0)
109 | g += corr_grad[n][i-1][j-1][y][x] * scalar_t(dx * dy);
110 |
111 | if (i > 0 && j < rd)
112 | g += corr_grad[n][i-1][j][y][x] * scalar_t(dx * (1.0f-dy));
113 |
114 | if (i < rd && j > 0)
115 | g += corr_grad[n][i][j-1][y][x] * scalar_t((1.0f-dx) * dy);
116 |
117 | if (i < rd && j < rd)
118 | g += corr_grad[n][i][j][y][x] * scalar_t((1.0f-dx) * (1.0f-dy));
119 |
120 | volume_grad[n][y][x][y1][x1] += g;
121 | }
122 | }
123 | }
124 | }
125 |
126 | std::vector<torch::Tensor> corr_index_cuda_forward(
127 | torch::Tensor volume,
128 | torch::Tensor coords,
129 | int radius)
130 | {
131 | const auto batch_size = volume.size(0);
132 | const auto ht = volume.size(1);
133 | const auto wd = volume.size(2);
134 |
135 | const dim3 blocks((wd + BLOCK - 1) / BLOCK,
136 | (ht + BLOCK - 1) / BLOCK,
137 | batch_size);
138 |
139 | const dim3 threads(BLOCK, BLOCK);
140 |
141 | auto opts = volume.options();
142 | torch::Tensor corr = torch::zeros(
143 | {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);
144 |
145 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_forward_kernel", ([&] {
146 |     corr_index_forward_kernel<scalar_t><<<blocks, threads>>>(
147 |       volume.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
148 |       coords.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
149 |       corr.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
150 | radius);
151 | }));
152 |
153 | return {corr};
154 |
155 | }
156 |
157 | std::vector<torch::Tensor> corr_index_cuda_backward(
158 | torch::Tensor volume,
159 | torch::Tensor coords,
160 | torch::Tensor corr_grad,
161 | int radius)
162 | {
163 | const auto batch_size = volume.size(0);
164 | const auto ht = volume.size(1);
165 | const auto wd = volume.size(2);
166 |
167 | auto volume_grad = torch::zeros_like(volume);
168 |
169 | const dim3 blocks((wd + BLOCK - 1) / BLOCK,
170 | (ht + BLOCK - 1) / BLOCK,
171 | batch_size);
172 |
173 | const dim3 threads(BLOCK, BLOCK);
174 |
175 |
176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_backward_kernel", ([&] {
177 |     corr_index_backward_kernel<scalar_t><<<blocks, threads>>>(
178 |       coords.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
179 |       corr_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
180 |       volume_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
181 | radius);
182 | }));
183 |
184 | return {volume_grad};
185 | }
--------------------------------------------------------------------------------
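
The forward kernel above gathers, for every pixel of frame 1, a (2r+1)x(2r+1) window of the 5-D correlation volume around the real-valued lookup coordinates, with bilinear interpolation and zero padding at the borders. The pure-PyTorch function below is an illustrative reference (not part of the repo) intended to compute the same quantity on small inputs.

import torch
import torch.nn.functional as F

def corr_index_forward_ref(volume, coords, r):
    """Illustrative PyTorch reference for corr_index_forward_kernel.

    volume: (B, H1, W1, H2, W2) correlation volume
    coords: (B, 2, H1, W1) lookup centres (x, y) in frame-2 pixel units
    returns: (B, 2r+1, 2r+1, H1, W1), window indices ordered (x-offset, y-offset) as in the kernel
    """
    B, H1, W1, H2, W2 = volume.shape
    fmap = volume.reshape(B * H1 * W1, 1, H2, W2)

    centres = coords.permute(0, 2, 3, 1).reshape(B * H1 * W1, 1, 1, 2)   # (x, y) per frame-1 pixel
    d = torch.arange(-r, r + 1, dtype=volume.dtype, device=volume.device)
    dx, dy = torch.meshgrid(d, d, indexing="ij")                          # dx varies along dim 0, dy along dim 1
    offsets = torch.stack([dx, dy], dim=-1).reshape(1, 2 * r + 1, 2 * r + 1, 2)
    grid = centres + offsets                                              # pixel coordinates to sample

    # Normalise to [-1, 1] for grid_sample (align_corners=True convention).
    gx = 2.0 * grid[..., 0] / (W2 - 1) - 1.0
    gy = 2.0 * grid[..., 1] / (H2 - 1) - 1.0
    grid = torch.stack([gx, gy], dim=-1)

    out = F.grid_sample(fmap, grid, mode="bilinear",
                        padding_mode="zeros", align_corners=True)         # zeros outside, like the kernel
    return out.reshape(B, H1, W1, 2 * r + 1, 2 * r + 1).permute(0, 3, 4, 1, 2)

# Example shapes: volume (1, 30, 40, 30, 40), coords (1, 2, 30, 40) -> corr (1, 7, 7, 30, 40) for r=3.
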
/src/droid.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | #include <iostream>
4 |
5 | #include <ATen/ATen.h>
6 | #include <ATen/Parallel.h>
7 |
8 | // CUDA forward declarations
9 | std::vector<torch::Tensor> projective_transform_cuda(
10 | torch::Tensor poses,
11 | torch::Tensor disps,
12 | torch::Tensor intrinsics,
13 | torch::Tensor ii,
14 | torch::Tensor jj);
15 |
16 |
17 |
18 | torch::Tensor depth_filter_cuda(
19 | torch::Tensor poses,
20 | torch::Tensor disps,
21 | torch::Tensor intrinsics,
22 | torch::Tensor ix,
23 | torch::Tensor thresh);
24 |
25 |
26 | torch::Tensor frame_distance_cuda(
27 | torch::Tensor poses,
28 | torch::Tensor disps,
29 | torch::Tensor intrinsics,
30 | torch::Tensor ii,
31 | torch::Tensor jj,
32 | const float beta);
33 |
34 | std::vector<torch::Tensor> projmap_cuda(
35 | torch::Tensor poses,
36 | torch::Tensor disps,
37 | torch::Tensor intrinsics,
38 | torch::Tensor ii,
39 | torch::Tensor jj);
40 |
41 | torch::Tensor iproj_cuda(
42 | torch::Tensor poses,
43 | torch::Tensor disps,
44 | torch::Tensor intrinsics);
45 |
46 | std::vector<torch::Tensor> ba_cuda(
47 | torch::Tensor poses,
48 | torch::Tensor body_poses,
49 | torch::Tensor disps,
50 | torch::Tensor intrinsics,
51 | torch::Tensor extrinsics,
52 | torch::Tensor disps_sens,
53 | torch::Tensor targets,
54 | torch::Tensor weights,
55 | torch::Tensor eta,
56 | torch::Tensor ii,
57 | torch::Tensor jj,
58 | const int t0,
59 | const int t1,
60 | const int iterations,
61 | const float lm,
62 | const float ep,
63 | const bool motion_only);
64 |
65 | std::vector<torch::Tensor>
66 | reduced_camera_matrix_cuda(
67 | torch::Tensor poses,
68 | torch::Tensor body_poses,
69 | torch::Tensor disps,
70 | torch::Tensor intrinsics,
71 | torch::Tensor extrinsics,
72 | torch::Tensor disps_sens,
73 | torch::Tensor targets,
74 | torch::Tensor weights,
75 | torch::Tensor eta,
76 | torch::Tensor ii,
77 | torch::Tensor jj,
78 | const int t0,
79 | const int t1);
80 |
81 | void solve_depth_cuda(
82 | torch::Tensor dx,
83 | torch::Tensor disps,
84 | torch::Tensor Q,
85 | torch::Tensor E,
86 | torch::Tensor w,
87 | torch::Tensor ii,
88 | torch::Tensor jj,
89 | const int t0,
90 | const int t1);
91 |
92 | // void solve_cuda(
93 | // torch::Tensor A,
94 | // torch::Tensor S,
95 | // const float lm,
96 | // const float ep);
97 | //
98 | void solve_poses_cuda(
99 | torch::Tensor poses,
100 | torch::Tensor dx,
101 | const int t0,
102 | const int t1);
103 |
104 | std::vector<torch::Tensor> corr_index_cuda_forward(
105 | torch::Tensor volume,
106 | torch::Tensor coords,
107 | int radius);
108 |
109 | std::vector<torch::Tensor> corr_index_cuda_backward(
110 | torch::Tensor volume,
111 | torch::Tensor coords,
112 | torch::Tensor corr_grad,
113 | int radius);
114 |
115 | std::vector<torch::Tensor> altcorr_cuda_forward(
116 | torch::Tensor fmap1,
117 | torch::Tensor fmap2,
118 | torch::Tensor coords,
119 | int radius);
120 |
121 | std::vector<torch::Tensor> altcorr_cuda_backward(
122 | torch::Tensor fmap1,
123 | torch::Tensor fmap2,
124 | torch::Tensor coords,
125 | torch::Tensor corr_grad,
126 | int radius);
127 |
128 |
129 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
130 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x)
131 |
132 |
133 | std::vector<torch::Tensor> ba(
134 | torch::Tensor poses,
135 | torch::Tensor body_poses,
136 | torch::Tensor disps,
137 | torch::Tensor intrinsics,
138 | torch::Tensor extrinsics,
139 | torch::Tensor disps_sens,
140 | torch::Tensor targets,
141 | torch::Tensor weights,
142 | torch::Tensor eta,
143 | torch::Tensor ii,
144 | torch::Tensor jj,
145 | const int t0,
146 | const int t1,
147 | const int iterations,
148 | const float lm,
149 | const float ep,
150 | const bool motion_only) {
151 |
152 | CHECK_INPUT(targets);
153 | CHECK_INPUT(weights);
154 | CHECK_INPUT(poses);
155 | CHECK_INPUT(body_poses);
156 | CHECK_INPUT(disps);
157 | CHECK_INPUT(intrinsics);
158 | CHECK_INPUT(disps_sens);
159 | CHECK_INPUT(ii);
160 | CHECK_INPUT(jj);
161 |
162 | return ba_cuda(poses, body_poses, disps, intrinsics, extrinsics, disps_sens, targets, weights,
163 | eta, ii, jj, t0, t1, iterations, lm, ep, motion_only);
164 |
165 | }
166 |
167 | std::vector<torch::Tensor>
168 | reduced_camera_matrix(
169 | torch::Tensor poses,
170 | torch::Tensor body_poses,
171 | torch::Tensor disps,
172 | torch::Tensor intrinsics,
173 | torch::Tensor extrinsics,
174 | torch::Tensor disps_sens,
175 | torch::Tensor targets,
176 | torch::Tensor weights,
177 | torch::Tensor eta,
178 | torch::Tensor ii,
179 | torch::Tensor jj,
180 | const int t0,
181 | const int t1) {
182 |
183 | CHECK_INPUT(poses);
184 | CHECK_INPUT(disps);
185 | CHECK_INPUT(intrinsics);
186 | CHECK_INPUT(extrinsics);
187 | CHECK_INPUT(disps_sens);
188 | CHECK_INPUT(targets);
189 | CHECK_INPUT(weights);
190 | CHECK_INPUT(eta);
191 | CHECK_INPUT(ii);
192 | CHECK_INPUT(jj);
193 |
194 | return reduced_camera_matrix_cuda(poses, body_poses, disps, intrinsics, extrinsics, disps_sens, targets, weights,
195 | eta, ii, jj, t0, t1);
196 | }
197 |
198 | void solve_depth(
199 | torch::Tensor dx,
200 | torch::Tensor disps,
201 | torch::Tensor Q,
202 | torch::Tensor E,
203 | torch::Tensor w,
204 | torch::Tensor ii,
205 | torch::Tensor jj,
206 | const int t0,
207 | const int t1) {
208 |
209 | CHECK_INPUT(dx);
210 | CHECK_INPUT(disps);
211 | CHECK_INPUT(Q);
212 | CHECK_INPUT(E);
213 | CHECK_INPUT(w);
214 | CHECK_INPUT(ii);
215 | CHECK_INPUT(jj);
216 |
217 | return solve_depth_cuda(dx, disps, Q, E, w, ii, jj, t0, t1);
218 | }
219 |
220 | void solve_poses(torch::Tensor poses,
221 | torch::Tensor dx,
222 | const int t0,
223 | const int t1) {
224 | CHECK_INPUT(poses);
225 | CHECK_INPUT(dx);
226 |
227 | return solve_poses_cuda(poses, dx, t0, t1);
228 | }
229 |
230 | torch::Tensor frame_distance(
231 | torch::Tensor poses,
232 | torch::Tensor disps,
233 | torch::Tensor intrinsics,
234 | torch::Tensor ii,
235 | torch::Tensor jj,
236 | const float beta) {
237 |
238 | CHECK_INPUT(poses);
239 | CHECK_INPUT(disps);
240 | CHECK_INPUT(intrinsics);
241 | CHECK_INPUT(ii);
242 | CHECK_INPUT(jj);
243 |
244 | return frame_distance_cuda(poses, disps, intrinsics, ii, jj, beta);
245 |
246 | }
247 |
248 |
249 | std::vector<torch::Tensor> projmap(
250 | torch::Tensor poses,
251 | torch::Tensor disps,
252 | torch::Tensor intrinsics,
253 | torch::Tensor ii,
254 | torch::Tensor jj) {
255 |
256 | CHECK_INPUT(poses);
257 | CHECK_INPUT(disps);
258 | CHECK_INPUT(intrinsics);
259 | CHECK_INPUT(ii);
260 | CHECK_INPUT(jj);
261 |
262 | return projmap_cuda(poses, disps, intrinsics, ii, jj);
263 |
264 | }
265 |
266 |
267 | torch::Tensor iproj(
268 | torch::Tensor poses,
269 | torch::Tensor disps,
270 | torch::Tensor intrinsics) {
271 | CHECK_INPUT(poses);
272 | CHECK_INPUT(disps);
273 | CHECK_INPUT(intrinsics);
274 |
275 | return iproj_cuda(poses, disps, intrinsics);
276 | }
277 |
278 |
279 | // c++ python binding
280 | std::vector<torch::Tensor> corr_index_forward(
281 | torch::Tensor volume,
282 | torch::Tensor coords,
283 | int radius) {
284 | CHECK_INPUT(volume);
285 | CHECK_INPUT(coords);
286 |
287 | return corr_index_cuda_forward(volume, coords, radius);
288 | }
289 |
290 | std::vector<torch::Tensor> corr_index_backward(
291 | torch::Tensor volume,
292 | torch::Tensor coords,
293 | torch::Tensor corr_grad,
294 | int radius) {
295 | CHECK_INPUT(volume);
296 | CHECK_INPUT(coords);
297 | CHECK_INPUT(corr_grad);
298 |
299 | auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius);
300 | return {volume_grad};
301 | }
302 |
303 | std::vector<torch::Tensor> altcorr_forward(
304 | torch::Tensor fmap1,
305 | torch::Tensor fmap2,
306 | torch::Tensor coords,
307 | int radius) {
308 | CHECK_INPUT(fmap1);
309 | CHECK_INPUT(fmap2);
310 | CHECK_INPUT(coords);
311 |
312 | return altcorr_cuda_forward(fmap1, fmap2, coords, radius);
313 | }
314 |
315 | std::vector<torch::Tensor> altcorr_backward(
316 | torch::Tensor fmap1,
317 | torch::Tensor fmap2,
318 | torch::Tensor coords,
319 | torch::Tensor corr_grad,
320 | int radius) {
321 | CHECK_INPUT(fmap1);
322 | CHECK_INPUT(fmap2);
323 | CHECK_INPUT(coords);
324 | CHECK_INPUT(corr_grad);
325 |
326 | return altcorr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
327 | }
328 |
329 |
330 | torch::Tensor depth_filter(
331 | torch::Tensor poses,
332 | torch::Tensor disps,
333 | torch::Tensor intrinsics,
334 | torch::Tensor ix,
335 | torch::Tensor thresh) {
336 |
337 | CHECK_INPUT(poses);
338 | CHECK_INPUT(disps);
339 | CHECK_INPUT(intrinsics);
340 | CHECK_INPUT(ix);
341 | CHECK_INPUT(thresh);
342 |
343 | return depth_filter_cuda(poses, disps, intrinsics, ix, thresh);
344 | }
345 |
346 |
347 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
348 | // bundle adjustment kernels
349 | m.def("ba", &ba, "bundle adjustment");
350 | m.def("reduced_camera_matrix", &reduced_camera_matrix, "reduced camera matrix");
351 | m.def("solve_depth", &solve_depth, "Retract depth given dx");
352 | m.def("solve_poses", &solve_poses, "Retract poses given dx");
353 | m.def("frame_distance", &frame_distance, "frame_distance");
354 | m.def("projmap", &projmap, "projmap");
355 | m.def("depth_filter", &depth_filter, "depth_filter");
356 | m.def("iproj", &iproj, "back projection");
357 |
358 | // correlation volume kernels
359 | m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward");
360 | m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward");
361 | m.def("corr_index_forward", &corr_index_forward, "INDEX forward");
362 | m.def("corr_index_backward", &corr_index_backward, "INDEX backward");
363 | }
364 |
--------------------------------------------------------------------------------
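
These bindings are what the Python side calls into once the extension has been compiled by setup.py. The snippet below is an illustrative call to corr_index_forward; the module name droid_backends is an assumption made for this example (the real name is whatever setup.py registers), and a CUDA device plus contiguous float tensors are required by the checks above.

import torch
import droid_backends  # assumed module name; see setup.py for the actual extension target

B, H1, W1, H2, W2, r = 1, 30, 40, 30, 40, 3
volume = torch.randn(B, H1, W1, H2, W2, device="cuda").contiguous()
coords = (torch.rand(B, 2, H1, W1, device="cuda")
          * torch.tensor([W2 - 1.0, H2 - 1.0], device="cuda").view(1, 2, 1, 1)).contiguous()

(corr,) = droid_backends.corr_index_forward(volume, coords, r)
print(corr.shape)  # (B, 2r+1, 2r+1, H1, W1) -> (1, 7, 7, 30, 40)
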
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 |
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import numpy as np
4 | import open3d as o3d
5 | from icecream import ic
6 |
7 | class MeshRenderer:
8 | def __init__(self, mesh_path, intrinsics, resolution) -> None:
9 | self.mesh = o3d.io.read_triangle_mesh(mesh_path)
10 |         # Do we need to register the gt_mesh to the SLAM frame of reference? Probably not, since SLAM is started from the ground-truth pose.
11 |         # Nonetheless, we do need to register the SLAM-estimated trajectory to the ground truth with a Sim(3) alignment,
12 |         # and scale the gt_mesh with the resulting scale parameter,
13 |         # OR just render depths along the ground-truth trajectory and register the depth maps up to scale (I think that is what Orbeez does).
14 |
15 | focal_length = intrinsics[:2]
16 | principal_point = intrinsics[2:]
17 | fx, fy = focal_length[0], focal_length[1]
18 | cx, cy = principal_point[0], principal_point[1]
19 | w, h = resolution[0], resolution[1]
20 |
21 | self.cam = o3d.camera.PinholeCameraParameters()
22 | self.cam.intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)
23 |
24 | self.viz = o3d.visualization.Visualizer()
25 | self.viz.create_window(width=w, height=h)
26 | self.viz.get_render_option().mesh_show_back_face = True
27 |
28 | self.mesh_uploaded = False
29 |
30 | self.viz_world_frame = False
31 |
32 | def render_mesh(self, c2w):
33 | viewport = self.viz.get_view_control().convert_to_pinhole_camera_parameters()
34 |
35 | self.cam.extrinsic = np.linalg.inv(c2w)
36 |
37 | if self.viz_world_frame:
38 | self.viz.add_geometry(self.create_frame_actor(np.eye(4), scale=0.5), reset_bounding_box=True)
39 | self.viz_world_frame = False
40 |
41 | if not self.mesh_uploaded:
42 | self.viz.add_geometry(self.mesh, reset_bounding_box=True)
43 | self.mesh_uploaded = True
44 |
45 | #ctr = self.viz.get_view_control()
46 | #ctr.set_constant_z_far(20)
47 | #ctr.convert_from_pinhole_camera_parameters(self.cam)
48 |
49 | self.viz.poll_events()
50 | self.viz.update_renderer()
51 |
52 | gt_depth = self.viz.capture_depth_float_buffer(True)
53 | gt_depth = np.asarray(gt_depth)
54 |
55 | # hack to allow interacting when using add_geometry
56 | viewport = self.viz.get_view_control().convert_from_pinhole_camera_parameters(viewport)
57 |
58 | self.viz.poll_events()
59 | self.viz.update_renderer()
60 |
61 | return gt_depth
62 |
63 |
64 | def create_frame_actor(self, pose, scale=0.05):
65 | frame_actor = o3d.geometry.TriangleMesh.create_coordinate_frame(size=scale,
66 | origin=np.array([0., 0., 0.]))
67 | frame_actor.transform(pose)
68 | return frame_actor
--------------------------------------------------------------------------------
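
MeshRenderer renders depth maps of a ground-truth mesh at given camera-to-world poses (e.g. to evaluate reconstructed depth). A minimal usage sketch, with a placeholder mesh path and calibration that are not values from the repo:

import numpy as np
from utils.evaluation import MeshRenderer

intrinsics = [600.0, 600.0, 320.0, 240.0]   # [fx, fy, cx, cy], placeholder values
resolution = [640, 480]                     # [width, height]

renderer = MeshRenderer("path/to/gt_mesh.ply", intrinsics, resolution)  # hypothetical mesh path
depth = renderer.render_mesh(np.eye(4))     # render from an identity camera-to-world pose
print(depth.shape)                          # (height, width) float depth buffer
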
/utils/open3d_pickle.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import open3d as o3d
4 | import numpy as np
5 |
6 |
7 | class _MeshTransmissionFormat:
8 | def __init__(self, mesh: o3d.geometry.TriangleMesh):
9 | assert(mesh)
10 | self.triangle_material_ids = np.array(mesh.triangle_material_ids)
11 | self.triangle_normals = np.array(mesh.triangle_normals)
12 | self.triangle_uvs = np.array(mesh.triangle_uvs)
13 | self.triangles = np.array(mesh.triangles)
14 |
15 | self.vertex_colors = np.array(mesh.vertex_colors)
16 | self.vertex_normals = np.array(mesh.vertex_normals)
17 | self.vertices = np.array(mesh.vertices)
18 |
19 | def create_mesh(self) -> o3d.geometry.TriangleMesh:
20 | mesh = o3d.geometry.TriangleMesh()
21 |
22 | mesh.triangle_material_ids = o3d.utility.IntVector(
23 | self.triangle_material_ids)
24 | mesh.triangle_normals = o3d.utility.Vector3dVector(
25 | self.triangle_normals)
26 | mesh.triangle_uvs = o3d.utility.Vector2dVector(self.triangle_uvs)
27 | mesh.triangles = o3d.utility.Vector3iVector(self.triangles)
28 |
29 | mesh.vertex_colors = o3d.utility.Vector3dVector(self.vertex_colors)
30 | mesh.vertex_normals = o3d.utility.Vector3dVector(self.vertex_normals)
31 |
32 | mesh.vertices = o3d.utility.Vector3dVector(self.vertices)
33 | return mesh
34 |
--------------------------------------------------------------------------------
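
Open3D geometries are not directly picklable, which is the problem _MeshTransmissionFormat works around (e.g. to pass meshes between processes through a queue). A minimal round-trip sketch, assuming open3d is installed:

import pickle

import open3d as o3d
from utils.open3d_pickle import _MeshTransmissionFormat

mesh = o3d.geometry.TriangleMesh.create_box()
mesh.compute_vertex_normals()

payload = pickle.dumps(_MeshTransmissionFormat(mesh))   # numpy-only snapshot, safe to pickle
restored = pickle.loads(payload).create_mesh()          # back to an o3d TriangleMesh
print(len(restored.vertices), len(restored.triangles))  # 8 12 for the unit box
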
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import cv2
4 |
5 | import numpy as np
6 | from scipy.spatial.transform import Rotation
7 |
8 | from icecream import ic
9 | import torch
10 |
11 | def qvec2rotmat(qvec):
12 | return np.array([
13 | [
14 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
15 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
16 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
17 | ], [
18 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
19 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
20 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
21 | ], [
22 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
23 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
24 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
25 | ]
26 | ])
27 |
28 | def rotmat(a, b):
29 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
30 | try:
31 | v = np.cross(a, b)
32 | except:
33 | pass
34 | c = np.dot(a, b)
35 | s = np.linalg.norm(v)
36 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
37 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))
38 |
39 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
40 | da = da / np.linalg.norm(da)
41 | db = db / np.linalg.norm(db)
42 | try:
43 | c = np.cross(da, db)
44 | except:
45 | pass
46 | denom = np.linalg.norm(c)**2
47 | t = ob - oa
48 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
49 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
50 | if ta > 0:
51 | ta = 0
52 | if tb > 0:
53 | tb = 0
54 | return (oa+ta*da+ob+tb*db) * 0.5, denom
55 |
56 | def variance_of_laplacian(image):
57 | return cv2.Laplacian(image, cv2.CV_64F).var()
58 |
59 | def sharpness(image):
60 |     if isinstance(image, str):
61 | image = cv2.imread(image)
62 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
63 | fm = variance_of_laplacian(gray)
64 | return fm
65 |
66 | def pose_matrix_to_t_and_quat(pose):
67 | """ convert 4x4 pose matrix to (t, q) """
68 | q = Rotation.from_matrix(pose[:3, :3]).as_quat()
69 | return np.concatenate([pose[:3, 3], q], axis=0)
70 |
71 | def get_pose_from_df(gt_df):
72 | x = gt_df.loc['tx'] # [m]
73 | y = gt_df.loc['ty'] # [m]
74 | z = gt_df.loc['tz'] # [m]
75 | qx = gt_df.loc['qx']
76 | qy = gt_df.loc['qy']
77 | qz = gt_df.loc['qz']
78 | qw = gt_df.loc['qw']
79 | r = Rotation.from_quat([qx, qy, qz, qw])
80 | pose = np.eye(4)
81 | pose[:3, :3] = r.as_matrix()
82 | pose[:3, 3] = [x, y, z]
83 | return pose
84 |
85 | def get_velocity(euroc_df):
86 | # TODO: rename these
87 | vx = euroc_df.loc[' v_RS_R_x [m s^-1]']
88 | vy = euroc_df.loc[' v_RS_R_y [m s^-1]']
89 | vz = euroc_df.loc[' v_RS_R_z [m s^-1]']
90 | return np.array([vx, vy, vz])
91 |
92 | def get_bias(euroc_df):
93 | # TODO: rename these
94 | ba_x = euroc_df.loc[' b_a_RS_S_x [m s^-2]']
95 | ba_y = euroc_df.loc[' b_a_RS_S_y [m s^-2]']
96 | ba_z = euroc_df.loc[' b_a_RS_S_z [m s^-2]']
97 | bg_x = euroc_df.loc[' b_w_RS_S_x [rad s^-1]']
98 | bg_y = euroc_df.loc[' b_w_RS_S_y [rad s^-1]']
99 | bg_z = euroc_df.loc[' b_w_RS_S_z [rad s^-1]']
100 | return np.array([ba_x, ba_y, ba_z, bg_x, bg_y, bg_z])
101 |
102 |
103 | # offset's default is 0.5 because we scale/offset poses to 1.0/0.5 before feeding them to the NeRF
104 | def nerf_matrix_to_ngp(nerf_matrix, scale=1.0, offset=0.5):
105 | result = nerf_matrix.copy()
106 | result[:3, 1] *= -1
107 | result[:3, 2] *= -1
108 | result[:3, 3] = result[:3, 3] * scale + offset
109 |
110 | # Cycle axes xyz<-yzx
111 | tmp = result[0, :].copy()
112 | result[0, :] = result[1, :]
113 | result[1, :] = result[2, :]
114 | result[2, :] = tmp
115 |
116 | return result
117 |
118 | # offset's default is 0.5 because we scale/offset poses to 1.0/0.5 before feeding them to the NeRF
119 | def ngp_matrix_to_nerf(ngp_matrix, scale=1.0, offset=0.5):
120 | result = ngp_matrix.copy()
121 |
122 | # Cycle axes xyz->yzx
123 | tmp = result[2, :].copy()
124 |     result[2, :] = result[1, :]  # assign row 2 before row 1 is overwritten
125 |     result[1, :] = result[0, :]
126 | result[0, :] = tmp
127 |
128 | result[:3, 0] *= 1 / scale
129 | result[:3, 1] *= -1 / scale
130 | result[:3, 2] *= -1 / scale
131 | result[:3, 3] = (result[:3, 3] - offset) / scale
132 |
133 |     return result
134 |
135 | # This can be very slow, send to gpu device...
136 | def srgb_to_linear(img, device):
137 | img = img.to(device=device)
138 | limit = 0.04045
139 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92)
140 |
141 |
142 | def linear_to_srgb(img):
143 | limit = 0.0031308
144 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
145 |
146 |
147 | def get_scale_and_offset(aabb):
148 | # map the given aabb of the form [[minx,miny,minz],[maxx,maxy,maxz]]
149 | # via an isotropic scale and translate to fit in the (0,0,0)-(1,1,1) cube,
150 | # with the given center at 0.5,0.5,0.5
151 | aabb = np.array(aabb, dtype=np.float64)
152 | dx = aabb[1][0]-aabb[0][0]
153 | dy = aabb[1][1]-aabb[0][1]
154 | dz = aabb[1][2]-aabb[0][2]
155 | length = max(0.000001, max(max(abs(dx), abs(dy)), abs(dz)))
156 | scale = 1.0 / length
157 | offset = np.array([((aabb[1][0]+aabb[0][0])*0.5) * -scale + 0.5,
158 | ((aabb[1][1]+aabb[0][1])*0.5) * -scale + 0.5,
159 | ((aabb[1][2]+aabb[0][2])*0.5) * -scale + 0.5])
160 | return scale, offset
161 |
162 |
163 | def scale_offset_poses(poses, scale, offset): # for c2w poses!
164 | poses[:, :3, 3] = poses[:, :3, 3] * scale + offset
165 | return poses
166 |
167 |
168 | def mse2psnr(x):
169 | return -10.*np.log(x)/np.log(10.)
170 |
171 |
172 | def L2(img, ref):
173 | return (img - ref)**2
174 |
175 |
176 | def compute_mse_img(img, ref):
177 | img[np.logical_not(np.isfinite(img))] = 0
178 | img = np.maximum(img, 0.)
179 | return L2(img, ref)
180 |
181 |
182 | def compute_error(img, ref):
183 | metric_map = compute_mse_img(img, ref)
184 | metric_map[np.logical_not(np.isfinite(metric_map))] = 0
185 | if len(metric_map.shape) == 3:
186 | metric_map = np.mean(metric_map, axis=2)
187 | mean = np.mean(metric_map)
188 | return mean
189 |
--------------------------------------------------------------------------------
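
The pose helpers above are mostly used to move camera-to-world poses between the SLAM convention and the instant-ngp convention. A quick standalone round-trip check (assuming the repo's requirements such as numpy, scipy, opencv-python and torch are installed, since utils.utils imports them):

import numpy as np
from scipy.spatial.transform import Rotation

from utils.utils import nerf_matrix_to_ngp, ngp_matrix_to_nerf, pose_matrix_to_t_and_quat

pose = np.eye(4)                                             # a camera-to-world pose
pose[:3, :3] = Rotation.from_euler("xyz", [0.3, -0.2, 0.1]).as_matrix()
pose[:3, 3] = [0.5, -1.0, 2.0]

ngp = nerf_matrix_to_ngp(pose)        # axis cycle + sign flips + translation offset
back = ngp_matrix_to_nerf(ngp)        # inverse mapping with the default scale/offset
print(np.allclose(back, pose))        # expected: True

print(pose_matrix_to_t_and_quat(pose))  # [tx, ty, tz, qx, qy, qz, qw]
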