├── .gitignore ├── .gitmodules ├── LICENSE.BSD ├── README.md ├── datasets ├── __init__.py ├── data_module.py ├── dataset.py ├── dataset_dict.py ├── euroc_dataset.py ├── nerf_dataset.py ├── real_sense_dataset.py ├── replica_dataset.py └── tum_dataset.py ├── droid.pth ├── examples ├── __init__.py └── slam_demo.py ├── factor_graph ├── __init__.py ├── factor.py ├── factor_graph.py ├── key.py ├── loss_function.py └── variables.py ├── fusion ├── __init__.py ├── fusion_module.py ├── nerf_fusion.py └── tsdf_fusion.py ├── gui ├── __init__.py ├── gui_module.py └── open3d_gui.py ├── media ├── intro.gif ├── mit.png ├── mrg_logo.png └── sparklab_logo.png ├── networks ├── __init__.py ├── droid_frontend.py ├── droid_net.py ├── factor_graph.py ├── geom │ ├── __init__.py │ ├── ba.py │ ├── chol.py │ ├── graph_utils.py │ ├── losses.py │ ├── projective_ops.py │ └── rgbd_utils.py ├── modules │ ├── __init__.py │ ├── clipping.py │ ├── corr.py │ ├── extractor.py │ └── gru.py └── motion_filter.py ├── pipeline ├── __init__.py ├── pipeline.py └── pipeline_module.py ├── requirements.txt ├── scripts ├── convergence_plots.ipynb ├── download_cube.bash ├── download_replica.bash ├── download_replica_sample.bash ├── record_real_sense.py ├── replica_results.py ├── replica_to_nerf_dataset.py └── unzip_tartan_air.py ├── setup.py ├── slam ├── __init__.py ├── inertial_frontends │ ├── __init__.py │ └── inertial_frontend.py ├── meta_slam.py ├── slam_module.py ├── vio_slam.py └── visual_frontends │ ├── __init__.py │ └── visual_frontend.py ├── solvers ├── __init__.py ├── linear_solver.py ├── meta_solver.py └── nonlinear_solver.py ├── src ├── altcorr_kernel.cu ├── correlation_kernels.cu ├── droid.cpp └── droid_kernels.cu └── utils ├── __init__.py ├── evaluation.py ├── flow_viz.py ├── open3d_pickle.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | imgui.ini 2 | *.npz 3 | *.svg 4 | 5 | # Meshes 6 | *.pcd 7 | *.ply 8 | *.obj 9 | 10 | 11 | ### Python specific gitignore 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
171 | #.idea/ 172 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/lietorch"] 2 | path = thirdparty/lietorch 3 | url = https://github.com/princeton-vl/lietorch 4 | [submodule "thirdparty/eigen"] 5 | path = thirdparty/eigen 6 | url = https://gitlab.com/libeigen/eigen.git 7 | [submodule "thirdparty/instant-ngp"] 8 | path = thirdparty/instant-ngp 9 | url = https://github.com/ToniRV/instant-ngp.git 10 | branch = feature/nerf_slam 11 | [submodule "thirdparty/gtsam"] 12 | path = thirdparty/gtsam 13 | url = https://github.com/ToniRV/gtsam-1.git 14 | branch = feature/nerf_slam 15 | -------------------------------------------------------------------------------- /LICENSE.BSD: -------------------------------------------------------------------------------- 1 | Copyright 2022 Massachusetts Institute of Technology. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
sparklab · kimera · mit

# NeRF-SLAM

### Real-Time Dense Monocular SLAM with Neural Radiance Fields

Antoni Rosinol · John J. Leonard · Luca Carlone

Paper | Video

39 | 40 |
41 | Table of Contents 42 |
1. Install
2. Download Datasets
3. Run
4. Citation
5. License
6. Acknowledgments
7. Contact
66 | 67 | ## Install 68 | 69 | Clone repo with submodules: 70 | ``` 71 | git clone https://github.com/ToniRV/NeRF-SLAM.git --recurse-submodules 72 | git submodule update --init --recursive 73 | ``` 74 | 75 | From this point on, use a virtual environment... 76 | Install torch (see [here](https://pytorch.org/get-started/previous-versions) for other versions): 77 | ``` 78 | # CUDA 11.3 79 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 80 | ``` 81 | 82 | Pip install requirements: 83 | ``` 84 | pip install -r requirements.txt 85 | pip install -r ./thirdparty/gtsam/python/requirements.txt 86 | ``` 87 | 88 | Compile ngp (you need cmake>3.22): 89 | ``` 90 | cmake ./thirdparty/instant-ngp -B build_ngp 91 | cmake --build build_ngp --config RelWithDebInfo -j 92 | ``` 93 | 94 | Compile gtsam and enable the python wrapper: 95 | ``` 96 | cmake ./thirdparty/gtsam -DGTSAM_BUILD_PYTHON=1 -B build_gtsam 97 | cmake --build build_gtsam --config RelWithDebInfo -j 98 | cd build_gtsam 99 | make python-install 100 | ``` 101 | 102 | Install: 103 | ``` 104 | python setup.py install 105 | ``` 106 | 107 | ## Download Sample Data 108 | 109 | This will just download one of the replica scenes: 110 | ``` 111 | ./scripts/download_replica_sample.bash 112 | ``` 113 | 114 | ## Run 115 | 116 | ``` 117 | python ./examples/slam_demo.py --dataset_dir=./datasets/Replica/office0 --dataset_name=nerf --buffer=100 --slam --parallel_run --img_stride=2 --fusion='nerf' --multi_gpu --gui 118 | ``` 119 | 120 | This repo also implements [Sigma-Fusion](https://arxiv.org/abs/2210.01276): just change `--fusion='sigma'` to run that. 121 | 122 | ## FAQ 123 | 124 | ### GPU Memory 125 | 126 | This is a GPU memory intensive pipeline, to monitor your GPU usage, I'd recommend to use `nvitop`. 127 | Install nvitop in a local env: 128 | ``` 129 | pip3 install --upgrade nvitop 130 | ``` 131 | 132 | Keep it running on a terminal, and monitor GPU memory usage: 133 | ``` 134 | nvitop --monitor 135 | ``` 136 | 137 | If you consistently see "out-of-memory" errors, you may either need to change parameters or buy better GPUs :). 138 | The memory consuming parts of this pipeline are: 139 | - Frame to frame correlation volumes (but can be avoided using on-the-fly correlation computation). 140 | - Volumetric rendering (intrinsically memory intensive, tricks exist, but ultimately we need to move to light fields or some better representation (OpenVDB?)). 141 | 142 | ### Installation issues 143 | 144 | 1. Gtsam not working: check that the python wrapper is installed, check instructions here: [gtsam_python](https://github.com/ToniRV/gtsam-1/blob/develop/python/README.md). Make sure you use our gtsam fork, which exposes more of gtsam's functionality to python. 145 | 2. Gtsam's dependency is not really needed, I just used to experiment adding IMU and/or stereo cameras, and have an easier interface to build factor-graphs. This didn't quite work though, because the network seemed to have a concept of scale, and it didn't quite work when updating poses/landmarks and then optical flow. 146 | 3. Somehow the parser converts [this](https://github.com/borglab/gtsam/compare/develop...ToniRV:gtsam-1:feature/nerf_slam#diff-add3627555fb7411e36ea4d863c15f4187e018b6e00b608ab260e3221aef057aR345) to 147 | `const std::vector&`, and I need to remove manually in 148 | `gtsam/build/python/linear.cpp` 149 | the inner `const X& ...`, and also add `` because: 150 | ``` 151 | Did you forget to `#include `? 
152 | ``` 153 | 154 | ## Citation 155 | 156 | ```bibtex 157 | @article{rosinol2022nerf, 158 | title={NeRF-SLAM: Real-Time Dense Monocular SLAM with Neural Radiance Fields}, 159 | author={Rosinol, Antoni and Leonard, John J and Carlone, Luca}, 160 | journal={arXiv preprint arXiv:2210.13641}, 161 | year={2022} 162 | } 163 | ``` 164 | 165 | ## License 166 | 167 | This repo is BSD Licensed. 168 | It reimplements parts of Droid-SLAM (BSD Licensed). 169 | Our changes to instant-NGP (Nvidia License) are released in our [fork of instant-ngp](https://github.com/ToniRV/instant-ngp) (branch `feature/nerf_slam`) and 170 | added here as a thirdparty dependency using git submodules. 171 | 172 | ## Acknowledgments 173 | 174 | This work has been possible thanks to the open-source code from [Droid-SLAM](https://github.com/princeton-vl/DROID-SLAM) and 175 | [Instant-NGP](https://github.com/NVlabs/instant-ngp), as well as the open-source datasets [Replica](https://github.com/facebookresearch/Replica-Dataset) and [Cube-Diorama](https://github.com/jc211/nerf-cube-diorama-dataset). 176 | 177 | ## Contact 178 | 179 | I have many ideas on how to improve this approach, but I just graduated so I won't have much time to do another PhD... 180 | If you are interested in building on top of this, 181 | feel free to reach out :) [arosinol@mit.edu](arosinol@mit.edu) 182 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /datasets/data_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import colored_glog as log 4 | from pipeline.pipeline_module import MIMOPipelineModule 5 | 6 | class DataModule(MIMOPipelineModule): 7 | def __init__(self, name, args, device="cpu") -> None: 8 | super().__init__(name, args.parallel_run, args) 9 | self.device = device 10 | self.idx = -1 11 | 12 | def get_input_packet(self): 13 | return True 14 | 15 | def spin_once(self, input): 16 | log.check(input) 17 | if self.name == 'real': 18 | return self.dataset.stream() 19 | else: 20 | self.idx += 1 21 | if self.idx < len(self.dataset): 22 | return self.dataset[self.idx] 23 | else: 24 | print("Stopping data module!") 25 | super().shutdown_module() 26 | return None 27 | 28 | def initialize_module(self): 29 | if self.name == "euroc": 30 | from datasets.euroc_dataset import EurocDataset 31 | self.dataset = EurocDataset(self.args, self.device) 32 | elif self.name == "tum": 33 | from datasets.tum_dataset import TumDataset 34 | self.dataset = TumDataset(self.args, self.device) 35 | elif self.name == "nerf": 36 | from datasets.nerf_dataset import NeRFDataset 37 | self.dataset = NeRFDataset(self.args, self.device) 38 | elif self.name == "replica": 39 | from datasets.replica_dataset import ReplicaDataset 40 | self.dataset = ReplicaDataset(self.args, self.device) 41 | elif self.name == "real": 42 | from datasets.real_sense_dataset import RealSenseDataset 43 | self.dataset = RealSenseDataset(self.args, self.device) 44 | else: 45 | raise Exception(f"Unknown dataset: {self.name}") 46 | return super().initialize_module() 47 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 
| import numpy as np 4 | import open3d as o3d 5 | 6 | from abc import abstractclassmethod 7 | from torch.utils.data.dataset import Dataset 8 | 9 | class Dataset(Dataset): 10 | def __init__(self, name, args, device) -> None: 11 | super().__init__() 12 | self.name = name 13 | self.args = args 14 | self.device = device 15 | 16 | self.dataset_dir = args.dataset_dir 17 | self.initial_k = args.initial_k # first frame to load 18 | self.final_k = args.final_k # last frame to load, if -1 load all 19 | self.img_stride = args.img_stride # stride for loading images 20 | self.stereo = args.stereo 21 | 22 | self.viz = False 23 | 24 | # list of data packets, 25 | # each data packet consists of two frames and all imu data in between 26 | self.data_packets = None 27 | 28 | @abstractclassmethod 29 | def stream(self): 30 | pass 31 | 32 | class PointCloudTransmissionFormat: 33 | def __init__(self, pointcloud: o3d.geometry.PointCloud): 34 | self.points = np.array(pointcloud.points) 35 | self.colors = np.array(pointcloud.colors) 36 | self.normals = np.array(pointcloud.normals) 37 | 38 | def create_pointcloud(self) -> o3d.geometry.PointCloud: 39 | pointcloud = o3d.geometry.PointCloud() 40 | pointcloud.points = o3d.utility.Vector3dVector(self.points) 41 | pointcloud.colors = o3d.utility.Vector3dVector(self.colors) 42 | pointcloud.normals = o3d.utility.Vector3dVector(self.normals) 43 | return pointcloud 44 | 45 | class CameraModel: 46 | def __init__(self, model) -> None: 47 | self.model = model 48 | pass 49 | 50 | def project(self, xyz): 51 | return self.model.project(xyz) 52 | 53 | def backproject(self, uv): 54 | return self.model.backproject(uv) 55 | 56 | class Resolution: 57 | def __init__(self, width, height) -> None: 58 | self.width = width 59 | self.height = height 60 | 61 | def numpy(self): 62 | return np.array([self.width, self.height]) 63 | 64 | def total(self): 65 | return self.width * self.height 66 | 67 | class PinholeCameraModel(CameraModel): 68 | def __init__(self, fx, fy, cx, cy) -> None: 69 | super().__init__('Pinhole') 70 | self.fx = fx 71 | self.fy = fy 72 | self.cx = cx 73 | self.cy = cy 74 | 75 | self.K = np.eye(3) 76 | self.K[0,0] = fx 77 | self.K[0,2] = cx 78 | self.K[1,1] = fy 79 | self.K[1,2] = cy 80 | 81 | def scale_intrinsics(self, scale_x, scale_y): 82 | self.fx *= scale_x # fx, cx 83 | self.cx *= scale_x # fx, cx 84 | self.fy *= scale_y # fx, cx 85 | self.cy *= scale_y # fx, cx 86 | 87 | self.K = np.eye(3) 88 | self.K[0,0] = self.fx 89 | self.K[0,2] = self.cx 90 | self.K[1,1] = self.fy 91 | self.K[1,2] = self.cy 92 | 93 | def numpy(self): 94 | return np.array([self.fx, self.fy, self.cx, self.cy]) 95 | 96 | def matrix(self): 97 | return self.K 98 | 99 | class DistortionModel: 100 | def __init__(self, model) -> None: 101 | self.model = model 102 | 103 | class RadTanDistortionModel(DistortionModel): 104 | def __init__(self, k1, k2, p1, p2) -> None: 105 | super().__init__('RadTan') 106 | # Distortioncoefficients=(k1 k2 p1 p2 k3) #OpenCV convention 107 | self.k1 = k1 108 | self.k2 = k2 109 | self.p1 = p1 110 | self.p2 = p2 111 | 112 | def get_distortion_as_vector(self): 113 | return np.array([self.k1, self.k2, self.p1, self.p2]) 114 | 115 | class CameraCalibration: 116 | def __init__(self, body_T_cam, camera_model, distortion_model, rate_hz, resolution, aabb, depth_scale) -> None: 117 | self.body_T_cam = body_T_cam 118 | self.camera_model = camera_model 119 | self.distortion_model = distortion_model 120 | self.rate_hz = rate_hz 121 | self.resolution = resolution 122 | self.aabb = aabb 123 | 
self.depth_scale = depth_scale 124 | 125 | class ImuCalibration: 126 | def __init__(self, body_T_imu, a_n, a_b, g_n, g_b, rate_hz, imu_integration_sigma, imu_time_shift, n_gravity) -> None: 127 | self.body_T_imu = body_T_imu 128 | self.a_n = a_n 129 | self.g_n = g_n 130 | self.a_b = a_b 131 | self.g_b = g_b 132 | self.rate_hz = rate_hz 133 | self.imu_integration_sigma = imu_integration_sigma 134 | self.imu_time_shift = imu_time_shift 135 | self.n_gravity = n_gravity 136 | pass 137 | 138 | class ViconCalibration: 139 | def __init__(self, body_T_vicon) -> None: 140 | self.body_T_vicon = body_T_vicon 141 | -------------------------------------------------------------------------------- /datasets/dataset_dict.py: -------------------------------------------------------------------------------- 1 | from .euroc_dataset import EurocDataset 2 | 3 | dataset_dict = {'euroc': EurocDataset} -------------------------------------------------------------------------------- /datasets/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 6 | 7 | import cv2 8 | import json 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from datasets.dataset import * 13 | from utils.utils import * 14 | 15 | class NeRFDataset(Dataset): 16 | def __init__(self, args, device): 17 | super().__init__("Nerf", args, device) 18 | self.parse_metadata() 19 | # self._build_dataset_index() # Loads all the data first, and then streams. 20 | self.tqdm = tqdm(total=self.__len__()) # Call after parsing metadata 21 | 22 | def get_cam_calib(self): 23 | w, h = self.json["w"], self.json["h"] 24 | fx, fy = self.json["fl_x"], self.json["fl_y"] 25 | cx, cy = self.json["cx"], self.json["cy"] 26 | 27 | body_T_cam0 = np.eye(4,4) 28 | rate_hz = 10.0 29 | resolution = Resolution(w, h) 30 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy) 31 | distortion0 = RadTanDistortionModel(0, 0, 0, 0) 32 | 33 | aabb = self.json["aabb"] 34 | depth_scale = self.json["integer_depth_scale"] if "integer_depth_scale" in self.json else 1.0 35 | 36 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale) 37 | 38 | def parse_metadata(self): 39 | with open(os.path.join(self.dataset_dir, "transforms.json"), 'r') as f: 40 | self.json = json.load(f) 41 | 42 | self.calib = self.get_cam_calib() 43 | 44 | self.resize_images = False 45 | if self.calib.resolution.total() > 640*640: 46 | self.resize_images = True 47 | # TODO(Toni): keep aspect ratio, and resize max res to 640 48 | self.output_image_size = [341, 640] # h, w 49 | 50 | self.image_paths = [] 51 | self.depth_paths = [] 52 | self.w2c = [] 53 | 54 | if self.resize_images: 55 | h0, w0 = self.calib.resolution.height, self.calib.resolution.width 56 | total_output_pixels = (self.output_image_size[0] * self.output_image_size[1]) 57 | self.h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0))) 58 | self.w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0))) 59 | self.h1 = self.h1 - self.h1 % 8 60 | self.w1 = self.w1 - self.w1 % 8 61 | self.calib.camera_model.scale_intrinsics(self.w1 / w0, self.h1 / h0) 62 | self.calib.resolution = Resolution(self.w1, self.h1) 63 | 64 | frames = self.json["frames"] 65 | frames = frames[self.initial_k:self.final_k:self.img_stride] 66 | print(f'Loading {len(frames)} images.') 67 | for i, frame in enumerate(frames): 68 | # Convert from nerf to ngp 69 | # TODO: convert poses to 
our format 70 | c2w = np.array(frame['transform_matrix']) 71 | c2w = nerf_matrix_to_ngp(c2w) # THIS multiplies by scale = 1 and offset = 0.5 72 | # TODO(TONI): prone to numerical errors, do se(3) inverse instead 73 | w2c = np.linalg.inv(c2w) 74 | 75 | # Get rgb/depth images path 76 | if frame['file_path'].endswith(".png") or frame['file_path'].endswith(".jpg"): 77 | image_path = os.path.join(self.dataset_dir, f"{frame['file_path']}") 78 | else: 79 | image_path = os.path.join(self.dataset_dir, f"{frame['file_path']}.png") 80 | depth_path = None 81 | if 'depth_path' in frame: 82 | depth_path = os.path.join(self.dataset_dir, f"{frame['depth_path']}") 83 | 84 | self.image_paths.append([i, image_path]) 85 | self.depth_paths += [depth_path] 86 | self.w2c += [w2c] 87 | 88 | # Sort paths chronologically 89 | if os.path.splitext(os.path.basename(self.image_paths[0][1]))[0].isdigit(): 90 | # name is "000000.jpg" for Cube-Diorama 91 | sorted(self.image_paths, key=lambda path: int(os.path.splitext(os.path.basename(path[1]))[0])) 92 | else: 93 | # name is "frame000000.jpg" for Replica 94 | sorted(self.image_paths, key=lambda path: int(os.path.splitext(os.path.basename(path[1]))[0][5:])) 95 | 96 | # Store the first pose, used as prior and initial state in SLAM. 97 | self.args.world_T_imu_t0 = self.w2c[0] 98 | 99 | def _get_data_packet(self, k0, k1=None): 100 | if k1 is None: 101 | k1 = k0 + 1 102 | else: 103 | assert(k1 >= k0) 104 | 105 | timestamps = [] 106 | poses = [] 107 | images = [] 108 | depths = [] 109 | calibs = [] 110 | 111 | W, H = self.calib.resolution.width, self.calib.resolution.height 112 | 113 | # Parse images and tfs 114 | for k in np.arange(k0, k1): 115 | i, image_path = self.image_paths[k] 116 | depth_path = self.depth_paths[i] # index with i, bcs we sorted image_paths to have increasing timestamps. 
117 | w2c = self.w2c[i] 118 | 119 | # Parse rgb/depth images 120 | image = cv2.imread(image_path) # H, W, C=4 121 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) # Required for Nerf Fusion, perhaps we can put it in there 122 | if depth_path: 123 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)[..., None] # H, W, C=1 124 | else: 125 | depth = (-1 * np.ones_like(image[:, :, 0])).astype(np.uint16) # invalid depth 126 | 127 | if self.resize_images: 128 | w1, h1 = self.w1, self.h1 129 | image = cv2.resize(image, (w1, h1)) 130 | depth = cv2.resize(depth, (w1, h1)) 131 | depth = depth[:, :, np.newaxis] 132 | 133 | if self.viz: 134 | cv2.imshow(f"Img Resized", image) 135 | cv2.imshow(f"Depth Resized", depth) 136 | cv2.waitKey(1) 137 | 138 | assert(H == image.shape[0]) 139 | assert(W == image.shape[1]) 140 | assert(3 == image.shape[2] or 4 == image.shape[2]) 141 | assert(np.uint8 == image.dtype) 142 | assert(H == depth.shape[0]) 143 | assert(W == depth.shape[1]) 144 | assert(1 == depth.shape[2]) 145 | assert(np.uint16 == depth.dtype) 146 | 147 | depth = depth.astype(np.int32) # converting to int32, because torch does not support uint16, and I don't want to lose precision 148 | 149 | timestamps += [i] 150 | poses += [w2c] 151 | images += [image] 152 | depths += [depth] 153 | calibs += [self.calib] 154 | 155 | return {"k": np.arange(k0,k1), 156 | "t_cams": np.array(timestamps), 157 | "poses": np.array(poses), 158 | "images": np.array(images), 159 | "depths": np.array(depths), 160 | "calibs": np.array(calibs), 161 | "is_last_frame": (i >= self.__len__() - 1), 162 | } 163 | 164 | def __len__(self): 165 | return len(self.image_paths) 166 | 167 | def __getitem__(self, k): 168 | self.tqdm.update(1) 169 | return self._get_data_packet(k) if self.data_packets is None else self.data_packets[k] 170 | 171 | 172 | # Up to you how you index the dataset depending on your training procedure 173 | def _build_dataset_index(self): 174 | # Go through the stream and bundle as you wish 175 | self.data_packets = [data_packet for data_packet in self.stream()] 176 | 177 | def stream(self): 178 | for k in range(self.__len__()): 179 | yield self._get_data_packet(k) 180 | -------------------------------------------------------------------------------- /datasets/real_sense_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import pyrealsense2 as rs 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | import tqdm 8 | 9 | from datasets.dataset import * 10 | from icecream import ic 11 | 12 | class RealSenseDataset(Dataset): 13 | 14 | def __init__(self, args, device): 15 | super().__init__("RealSense", args, device) 16 | self.parse_metadata() 17 | 18 | def parse_metadata(self): 19 | # Configure depth and color streams 20 | self.pipeline = rs.pipeline() 21 | config = rs.config() 22 | 23 | # Get device product line for setting a supporting resolution 24 | pipeline_wrapper = rs.pipeline_wrapper(self.pipeline) 25 | pipeline_profile = config.resolve(pipeline_wrapper) 26 | 27 | device = pipeline_profile.get_device() 28 | device_product_line = str(device.get_info(rs.camera_info.product_line)) 29 | 30 | found_rgb = False 31 | for s in device.sensors: 32 | if s.get_info(rs.camera_info.name) == 'RGB Camera': 33 | found_rgb = True 34 | break 35 | 36 | if not found_rgb: 37 | raise NotImplementedError("No RGB camera found") 38 | 39 | if device_product_line == 'L500': 40 | raise NotImplementedError 41 | 42 | self.rate_hz = 30 43 | config.enable_stream(rs.stream.color, 640, 480, 
rs.format.rgb8, self.rate_hz) 44 | config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, self.rate_hz) 45 | 46 | # Set timestamp 47 | self.timestamp = 0 48 | 49 | # Start streaming 50 | ic("Start streaming") 51 | cfg = self.pipeline.start(config) 52 | 53 | # Set profile 54 | depth_cam, rgb_cam = cfg.get_device().query_sensors() 55 | rgb_cam.set_option(rs.option.enable_auto_exposure, False) 56 | rgb_cam.set_option(rs.option.exposure, 238) 57 | rgb_cam.set_option(rs.option.enable_auto_white_balance, False) 58 | rgb_cam.set_option(rs.option.white_balance, 3700) 59 | rgb_cam.set_option(rs.option.gain, 0) 60 | depth_cam.set_option(rs.option.enable_auto_exposure, False) 61 | depth_cam.set_option(rs.option.exposure, 438) 62 | depth_cam.set_option(rs.option.gain, 0) 63 | # color_sensor.set_option(rs.option.backlight_compensation, 0) # Disable backlight compensation 64 | 65 | # Set calib 66 | profile = cfg.get_stream(rs.stream.color) # Fetch stream profile for color stream 67 | intrinsics = profile.as_video_stream_profile().get_intrinsics() # Downcast to video_stream_profile and fetch intrinsics 68 | self.calib = self._get_cam_calib(intrinsics) 69 | 70 | self.resize_images = True 71 | if self.resize_images: 72 | self.output_image_size = [315, 420] # h, w 73 | h0, w0 = self.calib.resolution.height, self.calib.resolution.width 74 | total_output_pixels = (self.output_image_size[0] * self.output_image_size[1]) 75 | self.h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0))) 76 | self.w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0))) 77 | self.h1 = self.h1 - self.h1 % 8 78 | self.w1 = self.w1 - self.w1 % 8 79 | self.calib.camera_model.scale_intrinsics(self.w1 / w0, self.h1 / h0) 80 | self.calib.resolution = Resolution(self.w1, self.h1) 81 | 82 | def _get_cam_calib(self, intrinsics): 83 | """ intrinsics: 84 | model Distortion model of the image 85 | coeffs Distortion coefficients 86 | fx Focal length of the image plane, as a multiple of pixel width 87 | fy Focal length of the image plane, as a multiple of pixel height 88 | ppx Horizontal coordinate of the principal point of the image, as a pixel offset from the left edge 89 | ppy Vertical coordinate of the principal point of the image, as a pixel offset from the top edge 90 | height Height of the image in pixels 91 | width Width of the image in pixels 92 | """ 93 | w, h = intrinsics.width, intrinsics.height 94 | fx, fy, cx, cy= intrinsics.fx, intrinsics.fy, intrinsics.ppx, intrinsics.ppy 95 | 96 | distortion_coeffs = intrinsics.coeffs 97 | distortion_model = intrinsics.model 98 | k1, k2, p1, p2 = 0, 0, 0, 0 99 | body_T_cam0 = np.eye(4) 100 | rate_hz = self.rate_hz 101 | 102 | resolution = Resolution(w, h) 103 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy) 104 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2) 105 | 106 | aabb = (2*np.array([[-2, -2, -2], [2, 2, 2]])).tolist() # Computed automatically in to_nerf() 107 | depth_scale = 1.0 # TODO # Since we multiply as gt_depth *= depth_scale, we need to invert camera["scale"] 108 | 109 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale) 110 | 111 | 112 | def stream(self): 113 | self.viz=True 114 | 115 | timestamps = [] 116 | poses = [] 117 | images = [] 118 | depths = [] 119 | calibs = [] 120 | 121 | got_image = False 122 | while not got_image: 123 | # Wait for a coherent pair of frames: depth and color 124 | try: 125 | frames = self.pipeline.wait_for_frames() 126 | except Exception as e: 127 | print(e) 128 | continue 129 | color_frame 
= frames.get_color_frame() 130 | depth_frame = frames.get_depth_frame() 131 | #depth_frame = np.zeros((color_image.shape[0], color_image.shape[1], 1)) 132 | 133 | if not depth_frame or not color_frame: 134 | print("No depth and color frame parsed.") 135 | continue 136 | 137 | # Convert images to numpy arrays 138 | color_image = np.asanyarray(color_frame.get_data()) 139 | depth_image = np.asanyarray(depth_frame.get_data()) 140 | 141 | 142 | if self.resize_images: 143 | color_image = cv2.resize(color_image, (self.w1, self.h1)) 144 | depth_image = cv2.resize(depth_image, (self.w1, self.h1)) 145 | 146 | if self.viz: 147 | # Apply colormap on depth image (image must be converted to 8-bit per pixel first) 148 | depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET) 149 | cv2.imshow(f"Color Img", color_image) 150 | cv2.imshow(f"Depth Img", depth_colormap) 151 | cv2.waitKey(1) 152 | 153 | self.timestamp += 1 154 | if self.args.img_stride > 1 and self.timestamp % self.args.img_stride == 0: 155 | # Software imposed fps to rate_hz/img_stride 156 | continue 157 | 158 | timestamps += [self.timestamp] 159 | poses += [np.eye(4)] # We don't have poses 160 | images += [color_image] 161 | depths += [depth_image] # We don't use depth 162 | calibs += [self.calib] 163 | got_image = True 164 | 165 | return {"k": np.arange(self.timestamp-1,self.timestamp), 166 | "t_cams": np.array(timestamps), 167 | "poses": np.array(poses), 168 | "images": np.array(images), 169 | "depths": np.array(depths), 170 | "calibs": np.array(calibs), 171 | "is_last_frame": False, #TODO 172 | } 173 | 174 | def shutdown(self): 175 | # Stop streaming 176 | self.pipeline.stop() 177 | 178 | def to_nerf_format(self, data_packets): 179 | print("Exporting RealSense dataset to Nerf") 180 | OUT_PATH = "transforms.json" 181 | AABB_SCALE = 4 182 | out = { 183 | "fl_x": self.calib.camera_model.fx, 184 | "fl_y": self.calib.camera_model.fy, 185 | "k1": self.calib.distortion_model.k1, 186 | "k2": self.calib.distortion_model.k2, 187 | "p1": self.calib.distortion_model.p1, 188 | "p2": self.calib.distortion_model.p2, 189 | "cx": self.calib.camera_model.cx, 190 | "cy": self.calib.camera_model.cy, 191 | "w": self.calib.resolution.width, 192 | "h": self.calib.resolution.height, 193 | "aabb": self.calib.aabb, 194 | "aabb_scale": AABB_SCALE, 195 | "integer_depth_scale": self.calib.depth_scale, 196 | "frames": [], 197 | } 198 | 199 | from PIL import Image 200 | 201 | c2w = np.eye(4).tolist() 202 | for data_packet in tqdm(data_packets): 203 | # Image 204 | ic(data_packet["k"]) 205 | k = data_packet["k"][0] 206 | i = data_packet["images"][0] 207 | d = data_packet["depths"][0] 208 | 209 | # Store image paths 210 | color_path = os.path.join(self.args.dataset_dir, "results", f"frame{k:05}.png") 211 | depth_path = os.path.join(self.args.dataset_dir, "results", f"depth{k:05}.png") 212 | 213 | # Save image to disk 214 | color = Image.fromarray(i) 215 | depth = Image.fromarray(d) 216 | color.save(color_path) 217 | depth.save(depth_path) 218 | 219 | # Sharpness 220 | sharp = sharpness(i) 221 | 222 | # Store relative path 223 | relative_color_path = os.path.join("results", os.path.basename(color_path)) 224 | relative_depth_path = os.path.join("results", os.path.basename(depth_path)) 225 | 226 | frame = {"file_path": relative_color_path, 227 | "sharpness": sharp, 228 | "depth_path": relative_depth_path, 229 | "transform_matrix": c2w} 230 | out["frames"].append(frame) 231 | 232 | with open(os.path.join(self.args.dataset_dir, 
OUT_PATH), "w") as outfile: 233 | import json 234 | json.dump(out, outfile, indent=2) -------------------------------------------------------------------------------- /datasets/replica_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import os 4 | import json 5 | import numpy as np 6 | 7 | import cv2 8 | import tqdm 9 | 10 | from icecream import ic 11 | from datasets.dataset import * 12 | 13 | class ReplicaDataset(Dataset): 14 | def __init__(self, args, device): 15 | super().__init__("Replica", args, device) 16 | self.dataset_dir = args.dataset_dir 17 | self.parse_dataset() 18 | self._build_dataset_index() 19 | 20 | def load_poses(self, path): 21 | poses = [] 22 | with open(path, "r") as f: 23 | lines = f.readlines() 24 | for i in range(len(self.image_paths)): 25 | line = lines[i] 26 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 27 | c2w[:3, 1] *= -1 28 | c2w[:3, 2] *= -1 29 | w2c = np.linalg.inv(c2w) 30 | poses.append(w2c) 31 | return poses 32 | 33 | def _get_cam_calib(self, path): 34 | with open(os.path.join(self.dataset_dir, "../cam_params.json"), 'r') as f: 35 | self.json = json.load(f) 36 | 37 | camera = self.json["camera"] 38 | w, h = camera['w'], camera['h'] 39 | fx, fy, cx, cy= camera['fx'], camera['fy'], camera['cx'], camera['cy'] 40 | 41 | k1, k2, p1, p2 = 0, 0, 0, 0 42 | body_T_cam0 = np.eye(4) 43 | rate_hz = 0 44 | 45 | resolution = Resolution(w, h) 46 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy) 47 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2) 48 | 49 | aabb = np.array([[-2, -2, -2], [2, 2, 2]]) # Computed automatically in to_nerf() 50 | depth_scale = 1.0 / camera["scale"] # Since we multiply as gt_depth *= depth_scale, we need to invert camera["scale"] 51 | 52 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale) 53 | 54 | def parse_dataset(self): 55 | self.timestamps = [] 56 | self.poses = [] 57 | self.images = [] 58 | self.depths = [] 59 | self.calibs = [] 60 | 61 | self.image_paths = sorted(glob.glob(f'{self.dataset_dir}/results/frame*.jpg')) 62 | self.depth_paths = sorted(glob.glob(f'{self.dataset_dir}/results/depth*.png')) 63 | self.poses = self.load_poses(f'{self.dataset_dir}/traj.txt') 64 | self.calib = self._get_cam_calib(f'{self.dataset_dir}/../cam_params.json') 65 | 66 | N = self.args.buffer 67 | H, W = self.calib.resolution.height, self.calib.resolution.width 68 | 69 | # Parse images and tfs 70 | for i, (image_path, depth_path) in enumerate(tqdm(zip(self.image_paths, self.depth_paths))): 71 | if i >= N: 72 | break 73 | 74 | # Parse rgb/depth images 75 | image = cv2.imread(image_path) 76 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) 77 | 78 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # this is for NERF 79 | depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)[..., None] # H, W, C=1 80 | 81 | H, W, _ = depth.shape 82 | assert(H == image.shape[0]) 83 | assert(W == image.shape[1]) 84 | assert(3 == image.shape[2] or 4 == image.shape[2]) 85 | assert(np.uint8 == image.dtype) 86 | assert(H == depth.shape[0]) 87 | assert(W == depth.shape[1]) 88 | assert(1 == depth.shape[2]) 89 | assert(np.uint16 == depth.dtype) 90 | 91 | depth = depth.astype(np.int32) # converting to int32, because torch does not support uint16, and I don't want to lose precision 92 | 93 | self.timestamps += [i] 94 | self.images += [image] 95 | self.depths += [depth] 96 | self.calibs += [self.calib] 97 | 98 | self.poses = self.poses[:N] 99 | 100 | 
self.timestamps = np.array(self.timestamps) 101 | self.poses = np.array(self.poses) 102 | self.images = np.array(self.images) 103 | self.depths = np.array(self.depths) 104 | self.calibs = np.array(self.calibs) 105 | 106 | N = len(self.timestamps) 107 | assert(N == self.poses.shape[0]) 108 | assert(N == self.images.shape[0]) 109 | assert(N == self.depths.shape[0]) 110 | assert(N == self.calibs.shape[0]) 111 | 112 | def __len__(self): 113 | return len(self.poses) 114 | 115 | def __getitem__(self, k): 116 | return self.data_packets[k] if self.data_packets is not None else self._get_data_packet(k) 117 | 118 | def _get_data_packet(self, k0, k1=None): 119 | if k1 is None: 120 | k1 = k0 + 1 121 | else: 122 | assert(k1 >= k0) 123 | return {"k": np.arange(k0,k1), 124 | "t_cams": self.timestamps[k0:k1], 125 | "poses": self.poses[k0:k1], 126 | "images": self.images[k0:k1], 127 | "depths": self.depths[k0:k1], 128 | "calibs": self.calibs[k0:k1], 129 | "is_last_frame": (k0 >= self.__len__() - 1), 130 | } 131 | 132 | # Up to you how you index the dataset depending on your training procedure 133 | def _build_dataset_index(self): 134 | # Go through the stream and bundle as you wish 135 | self.data_packets = [data_packet for data_packet in self.stream()] 136 | 137 | def stream(self): 138 | for k in range(self.__len__()): 139 | yield self._get_data_packet(k) 140 | 141 | def to_nerf_format(self): 142 | print("Exporting Replica dataset to Nerf") 143 | OUT_PATH = "transforms.json" 144 | AABB_SCALE = 4 145 | out = { 146 | "fl_x": self.calib.camera_model.fx, 147 | "fl_y": self.calib.camera_model.fy, 148 | "k1": self.calib.distortion_model.k1, 149 | "k2": self.calib.distortion_model.k2, 150 | "p1": self.calib.distortion_model.p1, 151 | "p2": self.calib.distortion_model.p2, 152 | "cx": self.calib.camera_model.cx, 153 | "cy": self.calib.camera_model.cy, 154 | "w": self.calib.resolution.width, 155 | "h": self.calib.resolution.height, 156 | # TODO(Toni): calculate this automatically. 
Box that fits all cameras +2m 157 | "aabb": self.calib.aabb, 158 | "aabb_scale": AABB_SCALE, 159 | "integer_depth_scale": self.calib.depth_scale, 160 | "frames": [], 161 | } 162 | 163 | poses_t = [] 164 | if self.data_packets is None: 165 | self._build_dataset_index() 166 | for data_packet in self.data_packets: 167 | # Image 168 | ic(data_packet["k"]) 169 | color_path = self.image_paths[data_packet["k"][0]] 170 | depth_path = self.depth_paths[data_packet["k"][0]] 171 | 172 | relative_color_path = os.path.join("results", os.path.basename(color_path)) 173 | relative_depth_path = os.path.join("results", os.path.basename(depth_path)) 174 | 175 | # Transform 176 | w2c = data_packet["poses"][0] 177 | c2w = np.linalg.inv(w2c) 178 | 179 | # Convert from opencv convention to nerf convention 180 | c2w[0:3, 1] *= -1 # flip the y axis 181 | c2w[0:3, 2] *= -1 # flip the z axis 182 | 183 | poses_t += [w2c[:3,3].flatten()] 184 | 185 | frame = {"file_path": relative_color_path, # "sharpness": b, 186 | "depth_path": relative_depth_path, 187 | "transform_matrix": c2w.tolist()} 188 | out["frames"].append(frame) 189 | 190 | poses_t = np.array(poses_t) 191 | delta_t = 1.0 # 1 meter extra to allow for the depth of the camera 192 | t_max = np.amax(poses_t, 0).flatten() 193 | t_min = np.amin(poses_t, 0).flatten() 194 | out["aabb"] = np.array([t_min-delta_t, t_max+delta_t]).tolist() 195 | 196 | # Save the path to the ground-truth mesh as well 197 | out["gt_mesh"] = os.path.join("..", os.path.basename(self.dataset_dir)+"_mesh.ply") 198 | ic(out["gt_mesh"]) 199 | 200 | with open(OUT_PATH, "w") as outfile: 201 | import json 202 | json.dump(out, outfile, indent=2) 203 | -------------------------------------------------------------------------------- /datasets/tum_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import pandas as pd 4 | 5 | from datasets.dataset import * 6 | 7 | # TODO: Didn't finish to implement 8 | class TumDataset(Dataset): 9 | def __init__(self, args, device): 10 | super().__init__("Tum", args, device) 11 | self.dataset_dir = args.dataset_dir 12 | 13 | json_file = os.path.join(self.dataset_dir, "associations.txt") 14 | self.associations = open(associations_file).readlines() 15 | assert self.associations is not None 16 | 17 | self.resize_images = False 18 | 19 | self.t0 = None 20 | 21 | self.final_k = self.final_k if self.final_k != -1.0 and self.final_k < len( 22 | self.associations) else len(self.associations) 23 | 24 | self.parse_dataset() 25 | 26 | def parse_dataset(self): 27 | ## Get Cam Calib 28 | self.cam0_calib : CameraCalibration = self._get_cam_calib() 29 | self.cam_calibs = [self.cam0_calib] 30 | 31 | ## Get ground-truth pose 32 | gt_dir = os.path.join(self.dataset_dir, 'groundtruth.txt') 33 | self.gt_df = pd.read_csv(gt_dir, sep=" ", comment="#", 34 | names=["timestamp", "tx", "ty", "tz", "qx", "qy", "qz", "qw"]) 35 | self.gt_df.timestamp *= 10000 36 | self.gt_df.timestamp = self.gt_df['timestamp'].astype('int64') 37 | self.gt_df.set_index("timestamp", drop=False, inplace=True) 38 | 39 | def get_rgb(self, frame_id): 40 | return self._get_img(frame_id, 'rgb') 41 | 42 | def get_depth(self, frame_id): 43 | return self._get_img(frame_id, 'depth') 44 | 45 | def _get_img(self, frame_id, type): 46 | if frame_id >= self.final_k: 47 | return None, None 48 | 49 | row = self.associations[frame_id].strip().split() 50 | if type == 'rgb': 51 | img_file_name = row[1] 52 | elif type == 'depth': 53 | img_file_name = row[3] 54 | 
else: 55 | raise "Unknown img type" 56 | 57 | timestamp = float(row[0]) 58 | img = cv2.imread(os.path.join(self.dataset_dir, img_file_name)) 59 | 60 | assert img is not None 61 | 62 | return timestamp, img 63 | 64 | def _get_data_packet(self, k): 65 | # The img_filename has the timestamp of the image! At least for Euroc! 66 | t_rgb0, img_0 = self.get_rgb(k) 67 | t_depth0, depth_0 = self.get_depth(k) 68 | assert t_rgb0 == t_depth0 69 | t_cam0 = t_rgb0 70 | 71 | t_cams = [t_cam0] 72 | images = [img_0] 73 | depths = [depth_0] 74 | 75 | if self.viz: 76 | for i, (rgb, depth) in enumerate(zip(images, depths)): 77 | if rgb is not None: 78 | cv2.imshow(f"Img{i}", rgb) 79 | if depth is not None: 80 | cv2.imshow(f"Depth{i}", depth) 81 | 82 | # Resize images 83 | if self.resize_images: 84 | h0, w0, _ = images[0].shape 85 | output_image_size = [384, 512] 86 | total_output_pixels = (output_image_size[0] * output_image_size[1]) 87 | h1 = int(h0 * np.sqrt(total_output_pixels / (h0 * w0))) 88 | w1 = int(w0 * np.sqrt(total_output_pixels / (h0 * w0))) 89 | 90 | for i in range(len(images)): 91 | images[i] = cv2.resize(images[i], (w1, h1)) 92 | images[i] = images[i][:h1-h1%8, :w1-w1%8] 93 | 94 | # TODO: can we do this for depths?? 95 | depths[i] = cv2.resize(depths[i], (w1, h1)) 96 | depths[i] = depths[i][:h1-h1%8, :w1-w1%8] 97 | 98 | if self.viz: 99 | for i, rgb, depth in enumerate(zip(images, depths)): 100 | cv2.imshow(f"Img{i} Resized", rgb) 101 | cv2.imshow(f"Depth{i} Resized", depth) 102 | 103 | if self.viz: 104 | cv2.waitKey(1) 105 | 106 | t_cams = np.array(t_cams) 107 | images = np.array(images) 108 | assert len(images) == len(t_cams) 109 | assert len(depths) == len(t_cams) 110 | 111 | # Ground-truth 112 | #t1_gt_near = self.gt_df['timestamp'].sub(t1).abs().idxmin() 113 | #t0_gt_near = self.gt_df['timestamp'].sub(self.t0).abs().idxmin() 114 | t1 = t_cam0 115 | t1_near = self.gt_df.index.get_indexer([t1], method="nearest")[0] 116 | if self.t0 is not None: 117 | gt_t0_t1 = self.gt_df.iloc[self.t0:t1_near+1] # +1 to include t1 118 | else: 119 | gt_t0_t1 = self.gt_df.iloc[t1_near] 120 | self.t0 = t1_near 121 | 122 | return {"k": k, "t_cams": t_cams, "images": images, 123 | "cam_calibs": self.cam_calibs if not self.resize_images else self.cam_calibs_resized, 124 | "gt_t0_t1": gt_t0_t1, 125 | "is_last_frame": (k >= self.__len__() - 1)} 126 | 127 | def _get_cam_calib(self): 128 | # TODO: remove hardcoded 129 | body_T_cam0 = np.eye(4) 130 | rate_hz = 0.0 131 | width, height = 640, 480 132 | fx, fy, cx, cy = 535.4, 539.2, 320.1, 247.6 133 | k1, k2, p1, p2 = 0.0, 0.0, 0.0, 0.0 134 | 135 | resolution = Resolution(width, height) 136 | pinhole0 = PinholeCameraModel(fx, fy, cx, cy) 137 | distortion0 = RadTanDistortionModel(k1, k2, p1, p2) 138 | 139 | aabb = np.array([[0,0,0],[1,1,1]]) 140 | depth_scale = 1.0 141 | 142 | return CameraCalibration(body_T_cam0, pinhole0, distortion0, rate_hz, resolution, aabb, depth_scale) 143 | 144 | # Up to you how you index the dataset depending on your training procedure 145 | def _build_dataset_index(self): 146 | # Go through the stream and bundle as you wish 147 | # Here we do the simplest scenario, send imu data between frames, 148 | # and the next frame as a packet 149 | self.data_packets = [data_packet for data_packet in self.stream()] 150 | 151 | def __getitem__(self, index): 152 | return self.data_packets[index] if self.data_packets is not None else self._get_data_packet(index) 153 | 154 | def __len__(self): 155 | return len(self.data_packets) if self.data_packets is not None 
else len(self.associations) 156 | 157 | # Return all data btw frames, plus the subsequent frame. 158 | def stream(self): 159 | for k in self.__len__(): 160 | yield self._get_data_packet(k) 161 | 162 | -------------------------------------------------------------------------------- /droid.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/droid.pth -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/slam_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | sys.settrace 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | import torch 9 | from torch.multiprocessing import Process 10 | 11 | from datasets.data_module import DataModule 12 | from gui.gui_module import GuiModule 13 | from slam.slam_module import SlamModule 14 | from fusion.fusion_module import FusionModule 15 | 16 | from icecream import ic 17 | 18 | import argparse 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Instant-SLAM") 22 | 23 | # SLAM ARGS 24 | parser.add_argument("--parallel_run", action="store_true", help="Whether to run in parallel") 25 | parser.add_argument("--multi_gpu", action="store_true", help="Whether to run with multiple (two) GPUs") 26 | parser.add_argument("--initial_k", type=int, help="Initial frame to parse in the dataset", default=0) 27 | parser.add_argument("--final_k", type=int, help="Final frame to parse in the dataset, -1 is all.", default=-1) 28 | parser.add_argument("--img_stride", type=int, help="Number of frames to skip when parsing the dataset", default=1) 29 | parser.add_argument("--stereo", action="store_true", help="Use stereo images") 30 | parser.add_argument("--weights", default="droid.pth", help="Path to the weights file") 31 | parser.add_argument("--buffer", type=int, default=512, help="Number of keyframes to keep") 32 | 33 | parser.add_argument("--dataset_dir", type=str, 34 | help="Path to the dataset directory", 35 | default="/home/tonirv/Datasets/euroc/V1_01_easy") 36 | parser.add_argument('--dataset_name', type=str, default='euroc', 37 | choices=['euroc', 'nerf', 'replica', 'real'], 38 | help='Dataset format to use.') 39 | 40 | parser.add_argument("--mask_type", type=str, default='ours', choices=['no_depth', 'raw', 'ours', 'ours_w_thresh']) 41 | 42 | #parser.add_argument("--gui", action="store_true", help="Run the testbed GUI interactively.") 43 | parser.add_argument("--slam", action="store_true", help="Run SLAM.") 44 | parser.add_argument("--fusion", type=str, default='', choices=['tsdf', 'sigma', 'nerf', ''], 45 | help="Fusion approach ('' for none):\n\ 46 | -`tsdf' classical tsdf-fusion using Open3D\n \ 47 | -`sigma' tsdf-fusion with uncertainty values (Rosinol22wacv)\n \ 48 | -`nerf' radiance field reconstruction using Instant-NGP.") 49 | 50 | # GUI ARGS 51 | parser.add_argument("--gui", action="store_true", help="Run O3D Gui, use when volume='tsdf'or'sigma'.") 52 | parser.add_argument("--width", "--screenshot_w", type=int, default=0, help="Resolution width of GUI and screenshots.") 53 | parser.add_argument("--height", "--screenshot_h", type=int, 
default=0, help="Resolution height of GUI and screenshots.") 54 | 55 | # NERF ARGS 56 | parser.add_argument("--network", default="", help="Path to the network config. Uses the scene's default if unspecified.") 57 | 58 | parser.add_argument("--eval", action="store_true", help="Evaluate method.") 59 | 60 | return parser.parse_args() 61 | 62 | def run(args): 63 | if args.parallel_run and args.multi_gpu: 64 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1" 65 | cpu = 'cpu' 66 | cuda_slam = 'cuda:0' 67 | cuda_fusion = 'cuda:1' # you can also try same device as in slam. 68 | else: 69 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 70 | cpu = 'cpu' 71 | cuda_slam = cuda_fusion = 'cuda:0' 72 | print(f"Running with GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}") 73 | 74 | if not args.parallel_run: 75 | from queue import Queue 76 | else: 77 | from torch.multiprocessing import Queue 78 | 79 | # Create the Queue object 80 | data_for_viz_output_queue = Queue() 81 | data_for_fusion_output_queue = Queue() 82 | data_output_queue = Queue() 83 | slam_output_queue_for_fusion = Queue() 84 | slam_output_queue_for_o3d = Queue() 85 | fusion_output_queue_for_gui = Queue() 86 | gui_output_queue_for_fusion = Queue() 87 | 88 | # Create Dataset provider 89 | data_provider_module = DataModule(args.dataset_name, args, device=cpu) 90 | 91 | # Create MetaSLAM pipeline 92 | # The SLAM module takes care of creating the SLAM object itself, to avoid pickling issues 93 | # (see initialize_module() method inside) 94 | slam = args.slam 95 | if slam: 96 | slam_module = SlamModule("VioSLAM", args, device=cuda_slam) 97 | data_provider_module.register_output_queue(data_output_queue) 98 | slam_module.register_input_queue("data", data_output_queue) 99 | 100 | # Create Neural Volume 101 | fusion = args.fusion != "" 102 | if fusion: 103 | fusion_module = FusionModule(args.fusion, args, device=cuda_fusion) 104 | if slam: 105 | slam_module.register_output_queue(slam_output_queue_for_fusion) 106 | fusion_module.register_input_queue("slam", slam_output_queue_for_fusion) 107 | 108 | if (args.fusion == 'nerf' and not slam) or (args.fusion != 'nerf' and args.eval): 109 | # Only used for evaluation, or in case we do not use slam (for nerf) 110 | data_provider_module.register_output_queue(data_for_fusion_output_queue) 111 | fusion_module.register_input_queue("data", data_for_fusion_output_queue) 112 | 113 | # Create interactive Gui 114 | gui = args.gui and args.fusion != 'nerf' # nerf has its own gui 115 | if gui: 116 | gui_module = GuiModule("Open3DGui", args, device=cuda_slam) # don't use cuda:1, o3d doesn't work... 
117 | data_provider_module.register_output_queue(data_for_viz_output_queue) 118 | slam_module.register_output_queue(slam_output_queue_for_o3d) 119 | gui_module.register_input_queue("data", data_for_viz_output_queue) 120 | gui_module.register_input_queue("slam", slam_output_queue_for_o3d) 121 | if fusion and (fusion_module.name == "tsdf" or fusion_module.name == "sigma"): 122 | fusion_module.register_output_queue(fusion_output_queue_for_gui) 123 | gui_module.register_input_queue("fusion", fusion_output_queue_for_gui) 124 | gui_module.register_output_queue(gui_output_queue_for_fusion) 125 | fusion_module.register_input_queue("gui", gui_output_queue_for_fusion) 126 | 127 | # Run 128 | if args.parallel_run: 129 | print("Running pipeline in parallel mode.") 130 | 131 | data_provider_thread = Process(target=data_provider_module.spin, args=()) 132 | if fusion: fusion_thread = Process(target=fusion_module.spin) # FUSION NEEDS TO BE IN A PROCESS 133 | #if slam: slam_thread = Process(target=slam_module.spin, args=()) 134 | if gui: gui_thread = Process(target=gui_module.spin, args=()) 135 | 136 | data_provider_thread.start() 137 | if fusion: fusion_thread.start() 138 | #if slam: slam_thread.start() 139 | if gui: gui_thread.start() 140 | 141 | # Runs in main thread 142 | if slam: 143 | slam_module.spin() # visualizer should be the main spin, but pytorch has a memory bug/leak if threaded... 144 | slam_module.shutdown_module() 145 | ic("Deleting SLAM module to free memory") 146 | torch.cuda.empty_cache() 147 | # slam_module.slam. # add function to empty all matrices? 148 | del slam_module 149 | print("FINISHED RUNNING SLAM") 150 | while (fusion and fusion_thread.exitcode == None): 151 | continue 152 | print("FINISHED RUNNING FUSION") 153 | while (gui and not gui_module.shutdown): 154 | continue 155 | print("FINISHED RUNNING GUI") 156 | 157 | # This is not doing what you think, because Process has another module 158 | if gui: gui_module.shutdown_module() 159 | if fusion: fusion_module.shutdown_module() 160 | data_provider_module.shutdown_module() 161 | 162 | if gui: gui_thread.terminate() # violent, should be join() 163 | #if slam: slam_thread.terminate() # violent, should be join() 164 | if fusion: fusion_thread.terminate() # violent, should be join() 165 | data_provider_thread.terminate() # violent, should be a join(), but I don't know how to flush the queue 166 | else: 167 | print("Running pipeline in sequential mode.") 168 | 169 | # Initialize all modules first (and register 3D volume) 170 | if data_provider_module.spin() \ 171 | and (not slam or slam_module.spin()) \ 172 | and (not fusion or fusion_module.spin()): 173 | if gui: 174 | gui_module.spin() 175 | #gui_module.register_volume(fusion_module.fusion.volume) 176 | 177 | # Run sequential, dataprovider fills queue and gui empties it 178 | while data_provider_module.spin() \ 179 | and (not slam or slam_module.spin()) \ 180 | and (not fusion or fusion_module.spin()) \ 181 | and (not gui or gui_module.spin()): 182 | continue 183 | 184 | # Then gui runs indefinitely until user closes window 185 | ok = True 186 | while ok: 187 | if gui: ok &= gui_module.spin() 188 | if fusion: ok &= fusion_module.spin() 189 | 190 | # Delete everything and clean memory 191 | 192 | if __name__ == '__main__': 193 | args = parse_args() 194 | 195 | torch.multiprocessing.set_start_method('spawn') 196 | torch.cuda.empty_cache() 197 | torch.backends.cudnn.benchmark = True 198 | torch.set_grad_enabled(False) 199 | 200 | run(args) 201 | print('Done...') 202 | 
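The `run()` function above connects the modules purely through queues: the data provider pushes packets into its output queues, and the SLAM, fusion, and GUI modules drain their registered input queues, either each in its own `Process` (`--parallel_run`) or by calling each module's `spin()` in turn from the main loop. Below is a minimal, self-contained sketch of that producer/consumer pattern; it is not repo code, and the `producer`/`consumer` functions and the `None` sentinel are illustrative stand-ins for the `register_output_queue`/`register_input_queue`/`spin` interface of the pipeline modules:

```python
import torch.multiprocessing as mp

def producer(queue):
    # Stands in for DataModule.spin(): push data packets into the output queue.
    for k in range(5):
        queue.put({"k": k})
    queue.put(None)  # sentinel marking the end of the stream (cf. is_last_frame)

def consumer(queue):
    # Stands in for SlamModule.spin(): drain the input queue until the sentinel arrives.
    while True:
        packet = queue.get()
        if packet is None:
            break
        print("processing frame", packet["k"])

if __name__ == "__main__":
    mp.set_start_method("spawn")   # slam_demo.py also uses the 'spawn' start method
    q = mp.Queue()
    p = mp.Process(target=producer, args=(q,))
    p.start()
    consumer(q)                    # runs in the main process, like slam_module.spin() above
    p.join()
```

In the sequential mode the same queues are used, but no extra processes are spawned: the main loop simply keeps calling each module's `spin()` until the data provider reports the last frame.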
-------------------------------------------------------------------------------- /factor_graph/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /factor_graph/factor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | import torch as th 7 | from torch import nn 8 | 9 | from factor_graph.loss_function import LossFunction 10 | from factor_graph.key import Key 11 | from factor_graph.variables import Variable, Variables 12 | 13 | class Factor(ABC, nn.Module): 14 | def __init__(self, keys: List[Key], loss_function: LossFunction, device='cpu') -> None: 15 | super().__init__() 16 | self.keys = keys 17 | self.loss_function = loss_function 18 | self.device = device 19 | 20 | @abstractmethod 21 | def linearize(self, x0: Variables) -> th.Tensor: 22 | raise 23 | 24 | # Forward === error 25 | def forward(self, x: Variables or th.Tensor) -> th.Tensor: 26 | return self.error(x) 27 | 28 | # f(x) 29 | @abstractmethod 30 | def error(self, x: Variables or th.Tensor) -> th.Tensor: 31 | pass 32 | 33 | # w(e) where e = f(x), 34 | def weight(self, x: Variables) -> th.Tensor: 35 | return self.loss_function(self.error(x)) 36 | 37 | def _batch_jacobian(self, f, x0: th.Tensor): 38 | f_sum = lambda x: th.sum(f(x), axis=0) # sum over all batches 39 | return th.autograd.functional.jacobian(f_sum, x0, create_graph=True).swapaxes(1, 0) 40 | -------------------------------------------------------------------------------- /factor_graph/factor_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Tuple 4 | 5 | import torch as th 6 | from torch import nn 7 | 8 | from icecream import ic 9 | 10 | import colored_glog as log 11 | 12 | from factor_graph.variables import Variables 13 | 14 | from gtsam import NonlinearFactorGraph as FactorGraph 15 | 16 | class FactorGraphManager: 17 | def __init__(self) -> None: 18 | self.factor_graph = FactorGraph() 19 | 20 | def add(self, factor_graph): 21 | # Check we don't have repeated factors 22 | 23 | if factor_graph.empty(): 24 | #log.warn("Attempted to add factors, but none were provided") 25 | return False 26 | 27 | # Remove None factors? 
28 | 29 | # Add factors 30 | self.factor_graph.push_back(factor_graph) 31 | 32 | # Perhaps remove old factors 33 | 34 | # Update map from factor-id to slot 35 | return True 36 | 37 | def replace(self, key, factor): 38 | self.factor_graph.replace(key, factor) 39 | 40 | def remove(self, key): 41 | self.factor_graph.remove(key) 42 | 43 | def __iter__(self): 44 | return self.factor_graph.__iter__() 45 | 46 | def __getitem__(self, key): 47 | assert self.factor_graph.exists(key) 48 | return self.factor_graph.at(key) 49 | 50 | def __len__(self): 51 | # self.factor_graph.nrFactors() 52 | return self.factor_graph.size() 53 | 54 | def is_empty(self): 55 | return self.factor_graph.empty() 56 | 57 | def reset_factor_graph(self): 58 | self.factor_graph = FactorGraph() 59 | 60 | def get_factor_graph(self): 61 | return self.factor_graph 62 | 63 | 64 | class TorchFactorGraph(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | self.factors = nn.ModuleList([]) 68 | ic(self.factors) 69 | self.run_jit = False 70 | 71 | def add(self, factors): 72 | # Check we don't have repeated factors 73 | 74 | # Append factors 75 | self.factors.extend(factors) 76 | 77 | # Update map from factor-id to slot 78 | pass 79 | 80 | def remove(self, factor_ids): 81 | pass 82 | 83 | def __iter__(self): 84 | return self.factors.__iter__() 85 | 86 | def __getitem__(self, key): 87 | return self.factors.__getitem__(key) 88 | 89 | def __len__(self): 90 | return len(self.factors) 91 | 92 | def is_empty(self): 93 | return self.__len__() == 0 94 | 95 | def forward(self, x: Variables) -> th.Tensor: 96 | return self._forward_jit(x) if self.run_jit else self._forward(x) 97 | 98 | def _forward(self, x: Variables) -> th.Tensor: 99 | residuals = th.stack([factor(x) for factor in self.factors]) 100 | weights = th.stack([factor.weight(x) for factor in self.factors]) 101 | return th.sum(th.pow(residuals, 2) * weights, dim=0) 102 | 103 | # Not necessarily faster... 104 | def _forward_jit(self, x: Variables) -> th.Tensor: 105 | residuals_calc = [th.jit.fork(factor, x) for factor in self.factors] 106 | weights_calc = [th.jit.fork(factor.weight, x) for factor in self.factors] 107 | residuals = th.stack([th.jit.wait(thread) for thread in residuals_calc]) 108 | weights = th.stack([th.jit.wait(thread) for thread in weights_calc]) 109 | return th.sum(th.pow(residuals, 2) * weights, dim=0) 110 | 111 | 112 | # TODO parallelize 113 | # These variables are already ordered. 114 | def linearize(self, x0: Variables) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: 115 | # Build the Jacobian matrix, here only for one f 116 | # Returns a tensor per each entry [i][j] corresponding to J_ij = d(output_i)/d(input_j) 117 | # this allows us to reason about the jacobian in terms of blocks. 118 | # A[i has dimensions of the output f(x0)][j has dimensions of the input x0] 119 | # vectorize=True --> first dimension is considered batch dimension 120 | # A = torch.autograd.functional.jacobian(f, x0, vectorize=True, create_graph=True) 121 | 122 | # Here we could query how to linearize: 123 | ## a) Linearize using AutoDiff (current) 124 | ## b) Linearize using closed-form jacobian 125 | ## extra) Linearize itself + compress with Schur complement, and return reduced Hessian matrix? 126 | AA = None; bb = None; ww = None 127 | # TODO Linearize factors in parallel... 
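        # (Editor's note) Each factor contributes a block of rows to the stacked
        # system below: for batched (>=2-D) tensors, th.hstack concatenates along
        # dim=1, i.e. the residual dimension, so AA/bb/ww grow by one block per factor.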
128 | for factor in self.factors: 129 | if factor is None: 130 | log.warn("Got a None factor...") 131 | continue 132 | A, b, w = factor.linearize(x0) 133 | if AA is None: 134 | AA = A; bb = b; ww = w 135 | continue 136 | AA = th.hstack((AA, A)) 137 | bb = th.hstack((bb, b)) 138 | ww = th.hstack((ww, w)) 139 | return AA, bb, ww 140 | 141 | # TODO parallelize 142 | def weight(self, x: Variables) -> th.Tensor: 143 | ww = None; 144 | for factor in self.factors: 145 | w = factor.weight(x) 146 | if ww is None: 147 | ww = w 148 | continue 149 | ww = th.hstack((ww, w)) 150 | return ww 151 | -------------------------------------------------------------------------------- /factor_graph/key.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | class Key(): 4 | def __init__(self, name: str, idx: int): 5 | self.name = name 6 | self.idx = idx 7 | 8 | def name(self): 9 | return self.name 10 | 11 | def idx(self): 12 | return self.idx 13 | 14 | def __hash__(self): 15 | return hash((self.name, self.idx)) 16 | 17 | def __eq__(self, other): 18 | return (self.name, self.idx) == (other.name, other.idx) 19 | 20 | def __ne__(self, other): 21 | # Not strictly necessary, but to avoid having both x==y and x!=y 22 | # True at the same time 23 | return not(self == other) 24 | 25 | def __repr__(self) -> str: 26 | return self.__str__() 27 | 28 | def __str__(self) -> str: 29 | return f"{self.name}_{self.idx}" 30 | -------------------------------------------------------------------------------- /factor_graph/loss_function.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch as th 4 | from torch import nn 5 | 6 | class LossFunction(nn.Module): 7 | def __init__(self, device) -> None: 8 | nn.Module.__init__(self) 9 | super(LossFunction).__init__() 10 | self.device = device 11 | 12 | class CauchyLossFunction(LossFunction): 13 | def __init__(self, device, k_initial: float = 1.0) -> None: 14 | super().__init__(device) 15 | self.k = th.nn.Parameter(k_initial * th.ones(1, device=self.device, requires_grad=True)) 16 | 17 | # w(e) = w(f(x)) 18 | # The loss function weights the residuals 19 | def forward(self, residuals: th.Tensor) -> th.Tensor: 20 | return 1.0 / (1.0 + th.pow(residuals / self.k, 2)) 21 | 22 | class GMLossFunction(LossFunction): 23 | def __init__(self, device, k_initial: float = 1.0) -> None: 24 | super().__init__(device) 25 | self.k = th.nn.Parameter(k_initial * th.ones(1, device=self.device, requires_grad=True)) 26 | 27 | # w(e) = w(f(x)) 28 | # The loss function weights the residuals 29 | def forward(self, residuals: th.Tensor) -> th.Tensor: 30 | return th.pow(self.k, 2) / th.pow(self.k + th.pow(residuals, 2), 2) # Isn't this cauchy? 
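Editor's note on the trailing question in GMLossFunction ("Isn't this cauchy?"): the two IRLS weights are related but not identical. Reading the weights straight off the code (k the tuning constant, e the residual), their large-residual behaviour differs:

    w_{Cauchy}(e) = \frac{1}{1 + (e/k)^2} \sim \frac{k^2}{e^2}, \qquad
    w_{GM}(e) = \frac{k^2}{(k + e^2)^2} \sim \frac{k^2}{e^4} \quad (|e| \to \infty)

So the Geman-McClure-style weight implemented here decays two powers faster than the Cauchy weight, i.e. it suppresses large residuals more aggressively; it is not the same robust loss.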
-------------------------------------------------------------------------------- /factor_graph/variables.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import List, Dict 4 | import colored_glog as log 5 | import torch as th 6 | 7 | from factor_graph.key import Key 8 | 9 | class Variable: 10 | def __init__(self, key: Key, value: th.Tensor): 11 | self.key = key 12 | self.value = value 13 | 14 | def __repr__(self) -> str: 15 | return self.__str__() 16 | 17 | def __str__(self) -> str: 18 | return f"Variable(key={self.key}, value={self.value})" 19 | 20 | class Variables: 21 | def __init__(self): 22 | self.vars: Dict[Key, Variable] = {} 23 | pass 24 | 25 | def add(self, variable: Variable): 26 | self.vars[variable.key] = variable 27 | 28 | # The order of the keys is important 29 | def at(self, keys: List[Key]) -> th.Tensor: 30 | # Stack all variables in the order of the keys 31 | return th.hstack([self.vars[key].value for key in keys]) 32 | 33 | # The order of the keys is important 34 | def stack(self) -> th.Tensor: 35 | # Stack all variables in the order of the keys 36 | # TODO: super slow, we should keep an hstack perhaps 37 | return th.hstack(list(map(lambda x: x.value, self.vars.values()))) 38 | 39 | # The order of the delta is important, it must match the order of the keys 40 | # in the Variables object self.vars 41 | # delta -> [B, N] 42 | # vars -> [B, N] 43 | def retract(self, delta: th.Tensor): 44 | log.check_eq(delta.shape[1], len(self.vars)) 45 | # TODO 46 | # Retract variables in parallel, rather than sequentially 47 | # Use ordered dict, instead of retrieving in order 48 | x_new = {} 49 | for delta_i, key in zip(delta.t(), self.vars.keys()): 50 | x_new[key] = Variable(key, self.vars[key].value + delta_i.unsqueeze(0).t()) 51 | return x_new 52 | 53 | def __repr__(self) -> str: 54 | return self.__str__() 55 | 56 | def __str__(self) -> str: 57 | return f"Variables(vars={self.vars})" 58 | 59 | -------------------------------------------------------------------------------- /fusion/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /fusion/fusion_module.py: -------------------------------------------------------------------------------- 1 | from pipeline.pipeline_module import MIMOPipelineModule 2 | 3 | class FusionModule(MIMOPipelineModule): 4 | def __init__(self, name, args, device="cpu") -> None: 5 | super().__init__(name, args.parallel_run, args) 6 | self.device = device 7 | 8 | def spin_once(self, data_packet): 9 | output = self.fusion.fuse(data_packet) 10 | # TODO: if you uncomment this, we never reach gui/fusion loop, but if you comment it never stops. 
11 | if self.fusion.stop_condition(): 12 | print("Stopping fusion module!") 13 | super().shutdown_module() 14 | #if not output: 15 | # super().shutdown_module() 16 | return output 17 | 18 | def initialize_module(self): 19 | self.set_cuda_device() # This needs to be done before importing NerfFusion or TsdfFusion 20 | if self.name == "tsdf" or self.name == "sigma": 21 | from fusion.tsdf_fusion import TsdfFusion 22 | self.fusion = TsdfFusion(self.name, self.args, self.device) 23 | elif self.name == "nerf": 24 | from fusion.nerf_fusion import NerfFusion 25 | self.fusion = NerfFusion(self.name, self.args, self.device) 26 | else: 27 | raise NotImplementedError 28 | return super().initialize_module() 29 | 30 | def get_input_packet(self): 31 | input = super().get_input_packet(timeout=0.0000000001) # don't block fusion waiting for input 32 | return input if input is not None else False # so that we keep running, and do not just stop in spin() 33 | 34 | def set_cuda_device(self): 35 | if self.device == "cpu": 36 | return 37 | 38 | import os 39 | if self.device == "cuda:0": 40 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 41 | elif self.device == "cuda:1": 42 | os.environ['CUDA_VISIBLE_DEVICES'] = "1" 43 | self.device = "cuda:0" # Since only 1 will be visible... 44 | else: 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /gui/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /gui/gui_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pipeline.pipeline_module import MIMOPipelineModule 4 | 5 | class GuiModule(MIMOPipelineModule): 6 | def __init__(self, name, args, device="cpu") -> None: 7 | super().__init__(name, args.parallel_run, args) 8 | self.device = device 9 | 10 | def spin_once(self, data_packet): 11 | # If the queue is empty, queue.get() will block until the queue has data 12 | output = self.gui.visualize(data_packet) 13 | #if not output: 14 | # super().shutdown_module() 15 | return output 16 | 17 | def initialize_module(self): 18 | if self.name == "Open3DGui": 19 | from gui.open3d_gui import Open3dGui 20 | self.gui = Open3dGui(self.args, self.device) 21 | elif self.name == "DearPyGui": 22 | raise NotImplementedError 23 | else: 24 | raise NotImplementedError 25 | self.gui.initialize() 26 | return super().initialize_module() 27 | 28 | def get_input_packet(self): 29 | # don't block rendering waiting for input 30 | input = super().get_input_packet(timeout=0.000000001) 31 | # so that we keep running, and do not just stop in spin() 32 | return input if input is not None else False 33 | -------------------------------------------------------------------------------- /media/intro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/intro.gif -------------------------------------------------------------------------------- /media/mit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/mit.png -------------------------------------------------------------------------------- /media/mrg_logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/mrg_logo.png -------------------------------------------------------------------------------- /media/sparklab_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/media/sparklab_logo.png -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /networks/droid_frontend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from lietorch import SE3 6 | from networks.factor_graph import FactorGraph 7 | 8 | 9 | class DroidFrontend: 10 | def __init__(self, droid_net, video, args): 11 | self.video = video # has the frames, the features, and the context activations 12 | self.update_net = droid_net.update_net # The GRU update network 13 | self.graph = FactorGraph(video, droid_net.update_net, max_factors=48) 14 | 15 | # local optimization window 16 | self.t0 = 0 17 | self.t1 = 0 18 | 19 | # frontent variables 20 | self.is_initialized = False 21 | self.count = 0 22 | 23 | self.max_age = 25 24 | self.iters1 = 4 25 | self.iters2 = 2 26 | 27 | self.warmup = args.warmup 28 | self.beta = args.beta 29 | self.frontend_nms = args.frontend_nms 30 | self.keyframe_thresh = args.keyframe_thresh 31 | self.frontend_window = args.frontend_window 32 | self.frontend_thresh = args.frontend_thresh 33 | self.frontend_radius = args.frontend_radius 34 | 35 | def __update(self): 36 | """ add edges, perform update """ 37 | 38 | self.count += 1 39 | self.t1 += 1 40 | 41 | if self.graph.correlation_volumes is not None: 42 | self.graph.rm_factors(self.graph.age > self.max_age, store=True) 43 | 44 | self.graph.add_proximity_factors(self.t1-5, max(self.t1-self.frontend_window, 0), 45 | rad=self.frontend_radius, nms=self.frontend_nms, thresh=self.frontend_thresh, beta=self.beta, remove=True) 46 | 47 | self.video.disps[self.t1-1] = torch.where(self.video.disps_sens[self.t1-1] > 0, 48 | self.video.disps_sens[self.t1-1], self.video.disps[self.t1-1]) 49 | 50 | for itr in range(self.iters1): 51 | self.graph.update(None, None, use_inactive=True) 52 | 53 | # set initial pose for next frame 54 | poses = SE3(self.video.poses) 55 | d = self.video.distance([self.t1-3], [self.t1-2], beta=self.beta, bidirectional=True) 56 | 57 | if d.item() < self.keyframe_thresh: 58 | self.graph.rm_keyframe(self.t1 - 2) 59 | 60 | with self.video.get_lock(): 61 | self.video.counter.value -= 1 62 | self.t1 -= 1 63 | 64 | else: 65 | for itr in range(self.iters2): 66 | self.graph.update(None, None, use_inactive=True) 67 | 68 | # set pose for next iteration 69 | self.video.poses[self.t1] = self.video.poses[self.t1-1] 70 | self.video.disps[self.t1] = self.video.disps[self.t1-1].mean() 71 | 72 | # update visualization 73 | self.video.dirty[self.graph.ii.min():self.t1] = True 74 | 75 | # Do we really need this special initialization? 
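    # (Editor's note) Reading the method below: once the video buffer holds
    # `warmup` keyframes, it adds sequential neighbourhood factors (r=3), runs 8
    # GRU/BA update iterations, adds proximity factors, runs 8 more iterations,
    # seeds the next frame's pose and disparity from the latest estimates, marks
    # the system initialized, and finally prunes factors touching the earliest frames.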
76 | def __initialize(self): 77 | """ initialize the SLAM system """ 78 | 79 | self.t0 = 0 80 | self.t1 = self.video.counter.value 81 | 82 | # Just adds the `r' sequential frames to the graph 83 | self.graph.add_neighborhood_factors(self.t0, self.t1, r=3) 84 | 85 | for itr in range(8): 86 | self.graph.update(1, use_inactive=True) 87 | 88 | self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False) 89 | 90 | for itr in range(8): 91 | self.graph.update(1, use_inactive=True) 92 | 93 | # Set initial pose/depth for next iteration 94 | # self.video.normalize() 95 | self.video.poses[self.t1] = self.video.poses[self.t1-1].clone() 96 | self.video.disps[self.t1] = self.video.disps[self.t1-4:self.t1].mean() # Next depth is just the mean? 97 | 98 | # initialization complete 99 | self.is_initialized = True 100 | self.last_pose = self.video.poses[self.t1-1].clone() 101 | self.last_disp = self.video.disps[self.t1-1].clone() 102 | self.last_time = self.video.tstamp[self.t1-1].clone() 103 | 104 | # for visualization 105 | with self.video.get_lock(): 106 | self.video.ready.value = 1 107 | self.video.dirty[:self.t1] = True 108 | 109 | self.graph.rm_factors(self.graph.ii < self.warmup-4, store=True) 110 | 111 | def __call__(self): 112 | """ main update """ 113 | 114 | # do initialization 115 | if not self.is_initialized and self.video.counter.value == self.warmup: 116 | self.__initialize() 117 | 118 | # do update 119 | elif self.is_initialized and self.t1 < self.video.counter.value: 120 | self.__update() 121 | 122 | 123 | -------------------------------------------------------------------------------- /networks/droid_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | from networks.modules.extractor import BasicEncoder 8 | from networks.modules.corr import CorrBlock 9 | from networks.modules.gru import ConvGRU 10 | from networks.modules.clipping import GradientClip 11 | 12 | from lietorch import SE3 13 | from networks.geom.ba import BA 14 | 15 | import networks.geom.projective_ops as pops 16 | from networks.geom.graph_utils import graph_to_edge_list, keyframe_indicies 17 | 18 | from torch_scatter import scatter_mean 19 | 20 | 21 | def cvx_upsample(data, mask): 22 | """ upsample pixel-wise transformation field """ 23 | batch, ht, wd, dim = data.shape 24 | data = data.permute(0, 3, 1, 2) 25 | mask = mask.view(batch, 1, 9, 8, 8, ht, wd) 26 | mask = torch.softmax(mask, dim=2) 27 | 28 | up_data = F.unfold(data, [3,3], padding=1) 29 | up_data = up_data.view(batch, dim, 9, 1, 1, ht, wd) 30 | 31 | up_data = torch.sum(mask * up_data, dim=2) 32 | up_data = up_data.permute(0, 4, 2, 5, 3, 1) 33 | up_data = up_data.reshape(batch, 8*ht, 8*wd, dim) 34 | 35 | return up_data 36 | 37 | def upsample_disp(disp, mask): 38 | batch, num, ht, wd = disp.shape 39 | disp = disp.view(batch*num, ht, wd, 1) 40 | mask = mask.view(batch*num, -1, ht, wd) 41 | return cvx_upsample(disp, mask).view(batch, num, 8*ht, 8*wd) 42 | 43 | 44 | class GraphAgg(nn.Module): 45 | def __init__(self): 46 | super(GraphAgg, self).__init__() 47 | self.conv1 = nn.Conv2d(128, 128, 3, padding=1) 48 | self.conv2 = nn.Conv2d(128, 128, 3, padding=1) 49 | self.relu = nn.ReLU(inplace=True) 50 | 51 | self.eta = nn.Sequential( 52 | nn.Conv2d(128, 1, 3, padding=1), 53 | GradientClip(), 54 | nn.Softplus()) 55 | 56 | self.upmask = nn.Sequential( 57 | 
nn.Conv2d(128, 8*8*9, 1, padding=0)) 58 | 59 | def forward(self, net, ii): 60 | batch, num, ch, ht, wd = net.shape 61 | net = net.view(batch*num, ch, ht, wd) 62 | 63 | _, ix = torch.unique(ii, return_inverse=True) 64 | net = self.relu(self.conv1(net)) 65 | 66 | net = net.view(batch, num, 128, ht, wd) 67 | net = scatter_mean(net, ix, dim=1) 68 | net = net.view(-1, 128, ht, wd) 69 | 70 | net = self.relu(self.conv2(net)) 71 | 72 | eta = self.eta(net).view(batch, -1, ht, wd) 73 | upmask = self.upmask(net).view(batch, -1, 8*8*9, ht, wd) 74 | 75 | return .01 * eta, upmask 76 | 77 | 78 | class UpdateModule(nn.Module): 79 | def __init__(self): 80 | super(UpdateModule, self).__init__() 81 | cor_planes = 4 * (2*3 + 1)**2 82 | 83 | self.corr_encoder = nn.Sequential( 84 | nn.Conv2d(cor_planes, 128, 1, padding=0), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(128, 128, 3, padding=1), 87 | nn.ReLU(inplace=True)) 88 | 89 | self.flow_encoder = nn.Sequential( 90 | nn.Conv2d(4, 128, 7, padding=3), 91 | nn.ReLU(inplace=True), 92 | nn.Conv2d(128, 64, 3, padding=1), 93 | nn.ReLU(inplace=True)) 94 | 95 | self.weight = nn.Sequential( 96 | nn.Conv2d(128, 128, 3, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(128, 2, 3, padding=1), 99 | GradientClip(), 100 | nn.Sigmoid()) 101 | 102 | self.delta = nn.Sequential( 103 | nn.Conv2d(128, 128, 3, padding=1), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(128, 2, 3, padding=1), 106 | GradientClip()) 107 | 108 | # Inputs to gru are: 109 | # i) hidden state (128) 110 | # ii) input state (128+128+64) 111 | # a) context_encoder_output(128), 112 | # b) corr_encoder_output (128), 113 | # c) flow_encoder_output (64), 114 | # Input planes: 128=constant_context, 128=correlation, 64=motion, 64=sigma 115 | self.gru = ConvGRU(128, 128+128+64) 116 | self.agg = GraphAgg() 117 | 118 | def forward(self, net, inp, corr, flow=None, ii=None, jj=None): 119 | """ RaftSLAM update operator """ 120 | 121 | batch, num, ch, ht, wd = net.shape 122 | 123 | if flow is None: 124 | # Initialize flow to zero... 
(TODO: better initialization using IMU) 125 | flow = torch.zeros(batch, num, 4, ht, wd, device=net.device) 126 | 127 | output_dim = (batch, num, -1, ht, wd) 128 | net = net.view(batch*num, -1, ht, wd) 129 | inp = inp.view(batch*num, -1, ht, wd) 130 | corr = corr.view(batch*num, -1, ht, wd) 131 | flow = flow.view(batch*num, -1, ht, wd) 132 | 133 | corr = self.corr_encoder(corr) 134 | flow = self.flow_encoder(flow) 135 | net = self.gru(net, inp, corr, flow) 136 | 137 | ### update variables ### 138 | delta = self.delta(net).view(*output_dim) 139 | weight = self.weight(net).view(*output_dim) 140 | 141 | delta = delta.permute(0,1,3,4,2)[...,:2].contiguous() 142 | weight = weight.permute(0,1,3,4,2)[...,:2].contiguous() 143 | 144 | net = net.view(*output_dim) 145 | 146 | if ii is not None: 147 | eta, upmask = self.agg(net, ii.to(net.device)) 148 | return net, delta, weight, eta, upmask 149 | else: 150 | return net, delta, weight 151 | 152 | 153 | class DroidNet(nn.Module): 154 | def __init__(self): 155 | super(DroidNet, self).__init__() 156 | self.feature_net = BasicEncoder(output_dim=128, norm_fn='instance') 157 | self.context_net = BasicEncoder(output_dim=256, norm_fn='none') 158 | self.update_net = UpdateModule() 159 | 160 | 161 | #### EVERYTHING BELOW IS ONLY USED FOR TRAINING NOT INFERENCE #### 162 | def extract_features(self, images): 163 | """ run feature extraction networks """ 164 | 165 | # normalize images 166 | b, n, c1, h1, w1 = x.shape 167 | images = images[:, :, [2,1,0]] / 255.0 168 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device) 169 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device) 170 | images = images.sub_(mean[:, None, None]).div_(std[:, None, None]) 171 | 172 | feature_maps = self.feature_net(images) 173 | context_maps = self.context_net(images) 174 | 175 | context_maps, gru_input_maps = context_maps.split([128,128], dim=2) 176 | context_maps = torch.tanh(context_maps) 177 | gru_input_maps = torch.relu(gru_input_maps) 178 | return feature_maps, context_maps, gru_input_maps 179 | 180 | 181 | def forward(self, Gs, images, disps, intrinsics, graph=None, num_steps=12, fixedp=2): 182 | """ Estimates SE3 or Sim3 between pair of frames """ 183 | 184 | u = keyframe_indicies(graph) 185 | ii, jj, kk = graph_to_edge_list(graph) 186 | 187 | ii = ii.to(device=images.device, dtype=torch.long) 188 | jj = jj.to(device=images.device, dtype=torch.long) 189 | 190 | feature_maps, context_maps, gru_input_maps = self.extract_features(images) 191 | context_maps, gru_input_maps = context_maps[:,ii], gru_input_maps[:,ii] 192 | corr_fn = CorrBlock(feature_maps[:,ii], feature_maps[:,jj], num_levels=4, radius=3) 193 | 194 | ht, wd = images.shape[-2:] 195 | coords0 = pops.coords_grid(ht//8, wd//8, device=images.device) 196 | 197 | coords1, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj) 198 | target = coords1.clone() 199 | 200 | Gs_list, disp_list, residual_list = [], [], [] 201 | for _ in range(num_steps): 202 | Gs = Gs.detach() 203 | disps = disps.detach() 204 | coords1 = coords1.detach() 205 | target = target.detach() 206 | 207 | # extract motion features 208 | corr = corr_fn(coords1) 209 | residual = target - coords1 210 | flow = coords1 - coords0 211 | 212 | motion = torch.cat([flow, residual], dim=-1) 213 | motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0) 214 | 215 | context_maps, delta, weight, eta, upmask = \ 216 | self.update_net(context_maps, gru_input_maps, corr, motion, ii, jj) 217 | 218 | target = coords1 + delta 219 | 220 | for _ in 
range(2): 221 | Gs, disps = BA(target, weight, eta, Gs, disps, intrinsics, ii, jj, fixedp=2) 222 | 223 | coords1, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj) 224 | residual = (target - coords1) 225 | 226 | Gs_list.append(Gs) 227 | disp_list.append(upsample_disp(disps, upmask)) 228 | residual_list.append(valid_mask * residual) 229 | 230 | return Gs_list, disp_list, residual_list -------------------------------------------------------------------------------- /networks/geom/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /networks/geom/ba.py: -------------------------------------------------------------------------------- 1 | import lietorch 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from .chol import block_solve, schur_solve 6 | from . import projective_ops as pops 7 | 8 | from torch_scatter import scatter_sum 9 | 10 | 11 | # utility functions for scattering ops 12 | def safe_scatter_add_mat(A, ii, jj, n, m): 13 | v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m) 14 | return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m) 15 | 16 | def safe_scatter_add_vec(b, ii, n): 17 | v = (ii >= 0) & (ii < n) 18 | return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n) 19 | 20 | # apply retraction operator to inv-depth maps 21 | def disp_retr(disps, dz, ii): 22 | ii = ii.to(device=dz.device) 23 | return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1]) 24 | 25 | # apply retraction operator to poses 26 | def pose_retr(poses, dx, ii): 27 | ii = ii.to(device=dx.device) 28 | return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1])) 29 | 30 | 31 | def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 32 | """ Full Bundle Adjustment """ 33 | 34 | B, P, ht, wd = disps.shape 35 | N = ii.shape[0] 36 | D = poses.manifold_dim 37 | 38 | ### 1: commpute jacobians and residuals ### 39 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 40 | poses, disps, intrinsics, ii, jj, jacobian=True) 41 | 42 | r = (target - coords).view(B, N, -1, 1) 43 | w = .001 * (valid * weight).view(B, N, -1, 1) 44 | 45 | ### 2: construct linear system ### 46 | Ji = Ji.reshape(B, N, -1, D) 47 | Jj = Jj.reshape(B, N, -1, D) 48 | wJiT = (w * Ji).transpose(2,3) 49 | wJjT = (w * Jj).transpose(2,3) 50 | 51 | Jz = Jz.reshape(B, N, ht*wd, -1) 52 | 53 | Hii = torch.matmul(wJiT, Ji) 54 | Hij = torch.matmul(wJiT, Jj) 55 | Hji = torch.matmul(wJjT, Ji) 56 | Hjj = torch.matmul(wJjT, Jj) 57 | 58 | vi = torch.matmul(wJiT, r).squeeze(-1) 59 | vj = torch.matmul(wJjT, r).squeeze(-1) 60 | 61 | Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 62 | Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 63 | 64 | w = w.view(B, N, ht*wd, -1) 65 | r = r.view(B, N, ht*wd, -1) 66 | wk = torch.sum(w*r*Jz, dim=-1) 67 | Ck = torch.sum(w*Jz*Jz, dim=-1) 68 | 69 | kx, kk = torch.unique(ii, return_inverse=True) 70 | M = kx.shape[0] 71 | 72 | # only optimize keyframe poses 73 | P = P // rig - fixedp 74 | ii = ii // rig - fixedp 75 | jj = jj // rig - fixedp 76 | 77 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 78 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 79 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 80 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 81 | 82 | E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \ 83 | safe_scatter_add_mat(Ej, jj, kk, P, M) 84 | 85 | v = safe_scatter_add_vec(vi, 
ii, P) + \ 86 | safe_scatter_add_vec(vj, jj, P) 87 | 88 | C = safe_scatter_add_vec(Ck, kk, M) 89 | w = safe_scatter_add_vec(wk, kk, M) 90 | 91 | C = C + eta.view(*C.shape) + 1e-7 92 | 93 | H = H.view(B, P, P, D, D) 94 | E = E.view(B, P, M, D, ht*wd) 95 | 96 | ### 3: solve the system ### 97 | dx, dz = schur_solve(H, E, C, v, w) 98 | 99 | ### 4: apply retraction ### 100 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 101 | disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx) 102 | 103 | disps = torch.where(disps > 10, torch.zeros_like(disps), disps) 104 | disps = disps.clamp(min=0.0) 105 | 106 | return poses, disps 107 | 108 | 109 | def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 110 | """ Motion only bundle adjustment """ 111 | 112 | B, P, ht, wd = disps.shape 113 | N = ii.shape[0] 114 | D = poses.manifold_dim 115 | 116 | ### 1: commpute jacobians and residuals ### 117 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 118 | poses, disps, intrinsics, ii, jj, jacobian=True) 119 | 120 | r = (target - coords).view(B, N, -1, 1) 121 | w = .001 * (valid * weight).view(B, N, -1, 1) 122 | 123 | ### 2: construct linear system ### 124 | Ji = Ji.reshape(B, N, -1, D) 125 | Jj = Jj.reshape(B, N, -1, D) 126 | wJiT = (w * Ji).transpose(2,3) 127 | wJjT = (w * Jj).transpose(2,3) 128 | 129 | Hii = torch.matmul(wJiT, Ji) 130 | Hij = torch.matmul(wJiT, Jj) 131 | Hji = torch.matmul(wJjT, Ji) 132 | Hjj = torch.matmul(wJjT, Jj) 133 | 134 | vi = torch.matmul(wJiT, r).squeeze(-1) 135 | vj = torch.matmul(wJjT, r).squeeze(-1) 136 | 137 | # only optimize keyframe poses 138 | P = P // rig - fixedp 139 | ii = ii // rig - fixedp 140 | jj = jj // rig - fixedp 141 | 142 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 143 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 144 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 145 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 146 | 147 | v = safe_scatter_add_vec(vi, ii, P) + \ 148 | safe_scatter_add_vec(vj, jj, P) 149 | 150 | H = H.view(B, P, P, D, D) 151 | 152 | ### 3: solve the system ### 153 | dx = block_solve(H, v) 154 | 155 | ### 4: apply retraction ### 156 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 157 | return poses 158 | 159 | -------------------------------------------------------------------------------- /networks/geom/chol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from . 
import projective_ops as pops 4 | 5 | class CholeskySolver(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, H, b): 8 | # don't crash training if cholesky decomp fails 9 | try: 10 | U = torch.linalg.cholesky(H) 11 | xs = torch.cholesky_solve(b, U) 12 | ctx.save_for_backward(U, xs) 13 | ctx.failed = False 14 | except Exception as e: 15 | print(e) 16 | ctx.failed = True 17 | xs = torch.zeros_like(b) 18 | 19 | return xs 20 | 21 | @staticmethod 22 | def backward(ctx, grad_x): 23 | if ctx.failed: 24 | return None, None 25 | 26 | U, xs = ctx.saved_tensors 27 | dz = torch.cholesky_solve(grad_x, U) 28 | dH = -torch.matmul(xs, dz.transpose(-1,-2)) 29 | 30 | return dH, dz 31 | 32 | def block_solve(H, b, ep=0.1, lm=0.0001): 33 | """ solve normal equations """ 34 | B, N, _, D, _ = H.shape 35 | I = torch.eye(D).to(H.device) 36 | H = H + (ep + lm*H) * I 37 | 38 | H = H.permute(0,1,3,2,4) 39 | H = H.reshape(B, N*D, N*D) 40 | b = b.reshape(B, N*D, 1) 41 | 42 | x = CholeskySolver.apply(H,b) 43 | return x.reshape(B, N, D) 44 | 45 | 46 | def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False): 47 | """ solve using shur complement """ 48 | 49 | B, P, M, D, HW = E.shape 50 | H = H.permute(0,1,3,2,4).reshape(B, P*D, P*D) 51 | E = E.permute(0,1,3,2,4).reshape(B, P*D, M*HW) 52 | Q = (1.0 / C).view(B, M*HW, 1) 53 | 54 | # damping 55 | I = torch.eye(P*D).to(H.device) 56 | H = H + (ep + lm*H) * I 57 | 58 | v = v.reshape(B, P*D, 1) 59 | w = w.reshape(B, M*HW, 1) 60 | 61 | Et = E.transpose(1,2) 62 | S = H - torch.matmul(E, Q*Et) 63 | v = v - torch.matmul(E, Q*w) 64 | 65 | dx = CholeskySolver.apply(S, v) 66 | if sless: 67 | return dx.reshape(B, P, D) 68 | 69 | dz = Q * (w - Et @ dx) 70 | dx = dx.reshape(B, P, D) 71 | dz = dz.reshape(B, M, HW) 72 | 73 | return dx, dz 74 | -------------------------------------------------------------------------------- /networks/geom/graph_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import lietorch 7 | from .rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2 8 | 9 | def graph_to_edge_list(graph): 10 | ii, jj, kk = [], [], [] 11 | for s, u in enumerate(graph): 12 | for v in graph[u]: 13 | ii.append(u) 14 | jj.append(v) 15 | kk.append(s) 16 | 17 | ii = torch.as_tensor(ii) 18 | jj = torch.as_tensor(jj) 19 | kk = torch.as_tensor(kk) 20 | return ii, jj, kk 21 | 22 | def keyframe_indicies(graph): 23 | return torch.as_tensor([u for u in graph]) 24 | 25 | def meshgrid(m, n, device='cuda'): 26 | ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n)) 27 | return ii.reshape(-1).to(device), jj.reshape(-1).to(device) 28 | 29 | def neighbourhood_graph(n, r): 30 | ii, jj = meshgrid(n, n) 31 | d = (ii - jj).abs() 32 | keep = (d >= 1) & (d <= r) 33 | return ii[keep], jj[keep] 34 | 35 | 36 | def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 37 | """ construct a frame graph between co-visible frames """ 38 | N = poses.shape[1] 39 | poses = poses[0].cpu().numpy() 40 | disps = disps[0][:,3::8,3::8].cpu().numpy() 41 | intrinsics = intrinsics[0].cpu().numpy() / 8.0 42 | d = compute_distance_matrix_flow(poses, disps, intrinsics) 43 | 44 | count = 0 45 | graph = OrderedDict() 46 | 47 | for i in range(N): 48 | graph[i] = [] 49 | d[i,i] = np.inf 50 | for j in range(i-r, i+r+1): 51 | if 0 <= j < N and i != j: 52 | graph[i].append(j) 53 | d[i,j] = np.inf 54 | count += 1 55 | 56 | while count < num: 57 | 
ix = np.argmin(d) 58 | i, j = ix // N, ix % N 59 | 60 | if d[i,j] < thresh: 61 | graph[i].append(j) 62 | d[i,j] = np.inf 63 | count += 1 64 | else: 65 | break 66 | 67 | return graph 68 | 69 | 70 | 71 | def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 72 | """ construct a frame graph between co-visible frames """ 73 | N = poses.shape[1] 74 | # poses = poses[0].cpu().numpy() 75 | # disps = disps[0].cpu().numpy() 76 | # intrinsics = intrinsics[0].cpu().numpy() 77 | d = compute_distance_matrix_flow2(poses, disps, intrinsics) 78 | 79 | # import matplotlib.pyplot as plt 80 | # plt.imshow(d) 81 | # plt.show() 82 | 83 | count = 0 84 | graph = OrderedDict() 85 | 86 | for i in range(N): 87 | graph[i] = [] 88 | d[i,i] = np.inf 89 | for j in range(i-r, i+r+1): 90 | if 0 <= j < N and i != j: 91 | graph[i].append(j) 92 | d[i,j] = np.inf 93 | count += 1 94 | 95 | while 1: 96 | ix = np.argmin(d) 97 | i, j = ix // N, ix % N 98 | 99 | if d[i,j] < thresh: 100 | graph[i].append(j) 101 | 102 | for i1 in range(i-1, i+2): 103 | for j1 in range(j-1, j+2): 104 | if 0 <= i1 < N and 0 <= j1 < N: 105 | d[i1, j1] = np.inf 106 | 107 | count += 1 108 | else: 109 | break 110 | 111 | return graph 112 | 113 | -------------------------------------------------------------------------------- /networks/geom/losses.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | import torch 4 | from lietorch import SO3, SE3, Sim3 5 | from .graph_utils import graph_to_edge_list 6 | from .projective_ops import projective_transform 7 | 8 | 9 | def pose_metrics(dE): 10 | """ Translation/Rotation/Scaling metrics from Sim3 """ 11 | t, q, s = dE.data.split([3, 4, 1], -1) 12 | ang = SO3(q).log().norm(dim=-1) 13 | 14 | # convert radians to degrees 15 | r_err = (180 / np.pi) * ang 16 | t_err = t.norm(dim=-1) 17 | s_err = (s - 1.0).abs() 18 | return r_err, t_err, s_err 19 | 20 | 21 | def fit_scale(Ps, Gs): 22 | b = Ps.shape[0] 23 | t1 = Ps.data[...,:3].detach().reshape(b, -1) 24 | t2 = Gs.data[...,:3].detach().reshape(b, -1) 25 | 26 | s = (t1*t2).sum(-1) / ((t2*t2).sum(-1) + 1e-8) 27 | return s 28 | 29 | 30 | def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True): 31 | """ Loss function for training network """ 32 | 33 | # relative pose 34 | ii, jj, kk = graph_to_edge_list(graph) 35 | dP = Ps[:,jj] * Ps[:,ii].inv() 36 | 37 | n = len(Gs) 38 | geodesic_loss = 0.0 39 | 40 | for i in range(n): 41 | w = gamma ** (n - i - 1) 42 | dG = Gs[i][:,jj] * Gs[i][:,ii].inv() 43 | 44 | if do_scale: 45 | s = fit_scale(dP, dG) 46 | dG = dG.scale(s[:,None]) 47 | 48 | # pose error 49 | d = (dG * dP.inv()).log() 50 | 51 | if isinstance(dG, SE3): 52 | tau, phi = d.split([3,3], dim=-1) 53 | geodesic_loss += w * ( 54 | tau.norm(dim=-1).mean() + 55 | phi.norm(dim=-1).mean()) 56 | 57 | elif isinstance(dG, Sim3): 58 | tau, phi, sig = d.split([3,3,1], dim=-1) 59 | geodesic_loss += w * ( 60 | tau.norm(dim=-1).mean() + 61 | phi.norm(dim=-1).mean() + 62 | 0.05 * sig.norm(dim=-1).mean()) 63 | 64 | dE = Sim3(dG * dP.inv()).detach() 65 | r_err, t_err, s_err = pose_metrics(dE) 66 | 67 | metrics = { 68 | 'rot_error': r_err.mean().item(), 69 | 'tr_error': t_err.mean().item(), 70 | 'bad_rot': (r_err < .1).float().mean().item(), 71 | 'bad_tr': (t_err < .01).float().mean().item(), 72 | } 73 | 74 | return geodesic_loss, metrics 75 | 76 | 77 | def residual_loss(residuals, gamma=0.9): 78 | """ loss on system residuals """ 79 | residual_loss = 0.0 80 | n = 
len(residuals) 81 | 82 | for i in range(n): 83 | w = gamma ** (n - i - 1) 84 | residual_loss += w * residuals[i].abs().mean() 85 | 86 | return residual_loss, {'residual': residual_loss.item()} 87 | 88 | 89 | def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9): 90 | """ optical flow loss """ 91 | 92 | N = Ps.shape[1] 93 | graph = OrderedDict() 94 | for i in range(N): 95 | graph[i] = [j for j in range(N) if abs(i-j)==1] 96 | 97 | ii, jj, kk = graph_to_edge_list(graph) 98 | coords0, val0 = projective_transform(Ps, disps, intrinsics, ii, jj) 99 | val0 = val0 * (disps[:,ii] > 0).float().unsqueeze(dim=-1) 100 | 101 | n = len(poses_est) 102 | flow_loss = 0.0 103 | 104 | for i in range(n): 105 | w = gamma ** (n - i - 1) 106 | coords1, val1 = projective_transform(poses_est[i], disps_est[i], intrinsics, ii, jj) 107 | 108 | v = (val0 * val1).squeeze(dim=-1) 109 | epe = v * (coords1 - coords0).norm(dim=-1) 110 | flow_loss += w * epe.mean() 111 | 112 | epe = epe.reshape(-1)[v.reshape(-1) > 0.5] 113 | metrics = { 114 | 'f_error': epe.mean().item(), 115 | '1px': (epe<1.0).float().mean().item(), 116 | } 117 | 118 | return flow_loss, metrics 119 | -------------------------------------------------------------------------------- /networks/geom/projective_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from lietorch import SE3, Sim3 5 | 6 | from icecream import ic 7 | 8 | MIN_DEPTH = 0.2 # ????? 9 | 10 | def extract_intrinsics(intrinsics): 11 | return intrinsics[...,None,None,:].unbind(dim=-1) 12 | 13 | def coords_grid(ht, wd, **kwargs): 14 | y, x = torch.meshgrid( 15 | torch.arange(ht).to(**kwargs).float(), 16 | torch.arange(wd).to(**kwargs).float()) 17 | 18 | return torch.stack([x, y], dim=-1) 19 | 20 | def iproj(disps, intrinsics, jacobian=False): 21 | """ pinhole camera inverse projection """ 22 | ht, wd = disps.shape[2:] 23 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 24 | 25 | y, x = torch.meshgrid( 26 | torch.arange(ht).to(disps.device).float(), 27 | torch.arange(wd).to(disps.device).float()) 28 | 29 | i = torch.ones_like(disps) 30 | X = (x - cx) / fx 31 | Y = (y - cy) / fy 32 | pts = torch.stack([X, Y, i, disps], dim=-1) 33 | 34 | if jacobian: 35 | J = torch.zeros_like(pts) 36 | J[...,-1] = 1.0 37 | return pts, J 38 | 39 | return pts, None 40 | 41 | def proj(Xs, intrinsics, jacobian=False, return_depth=False): 42 | """ pinhole camera projection """ 43 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 44 | X, Y, Z, D = Xs.unbind(dim=-1) 45 | 46 | Z = torch.where(Z < 0.5*MIN_DEPTH, torch.ones_like(Z), Z) 47 | d = 1.0 / Z 48 | 49 | x = fx * (X * d) + cx 50 | y = fy * (Y * d) + cy 51 | if return_depth: 52 | coords = torch.stack([x, y, D*d], dim=-1) 53 | else: 54 | coords = torch.stack([x, y], dim=-1) 55 | 56 | if jacobian: 57 | B, N, H, W = d.shape 58 | o = torch.zeros_like(d) 59 | proj_jac = torch.stack([ 60 | fx*d, o, -fx*X*d*d, o, 61 | o, fy*d, -fy*Y*d*d, o, 62 | # o, o, -D*d*d, d, 63 | ], dim=-1).view(B, N, H, W, 2, 4) 64 | 65 | return coords, proj_jac 66 | 67 | return coords, None 68 | 69 | def actp(Gij, X0, jacobian=False): 70 | """ action on point cloud """ 71 | X1 = Gij[:,:,None,None] * X0 72 | 73 | if jacobian: 74 | X, Y, Z, d = X1.unbind(dim=-1) 75 | o = torch.zeros_like(d) 76 | B, N, H, W = d.shape 77 | 78 | if isinstance(Gij, SE3): 79 | Ja = torch.stack([ 80 | d, o, o, o, Z, -Y, 81 | o, d, o, -Z, o, X, 82 | o, o, d, Y, -X, o, 83 | o, o, o, o, o, o, 84 | ], 
dim=-1).view(B, N, H, W, 4, 6) 85 | 86 | elif isinstance(Gij, Sim3): 87 | Ja = torch.stack([ 88 | d, o, o, o, Z, -Y, X, 89 | o, d, o, -Z, o, X, Y, 90 | o, o, d, Y, -X, o, Z, 91 | o, o, o, o, o, o, o 92 | ], dim=-1).view(B, N, H, W, 4, 7) 93 | 94 | return X1, Ja 95 | 96 | return X1, None 97 | 98 | def projective_transform(poses, depths, intrinsics, ii, jj, 99 | cam_T_body=None, # aka extrinsics 100 | stereo_extrinsics=[-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 101 | jacobian=False, return_depth=False): 102 | """ map points from ii->jj """ 103 | 104 | # inverse project (pinhole) 105 | X0, Jz = iproj(depths[:,ii], intrinsics[:,ii], jacobian=jacobian) 106 | 107 | # transform 108 | Gij = poses[:,jj] * poses[:,ii].inv() 109 | 110 | Gij.data[:,ii==jj] = torch.as_tensor(stereo_extrinsics, device=poses.device) 111 | X1, Ja = actp(Gij, X0, jacobian=jacobian) 112 | 113 | # project (pinhole) 114 | x1, Jp = proj(X1, intrinsics[:,jj], jacobian=jacobian, return_depth=return_depth) 115 | 116 | # exclude points too close to camera 117 | valid = ((X1[...,2] > MIN_DEPTH) & (X0[...,2] > MIN_DEPTH)).float() 118 | valid = valid.unsqueeze(-1) 119 | 120 | if jacobian: 121 | # Ji transforms according to dual adjoint 122 | Jj = torch.matmul(Jp, Ja) 123 | Ji = -Gij[:,:,None,None,None].adjT(Jj) 124 | 125 | # TODO, TODO, TODO: check this math... 126 | # Account for body to camera transformation 127 | if cam_T_body is not None: 128 | cam_T_body = SE3(cam_T_body) 129 | Ji = cam_T_body[None,None,None,None,None].adjT(Ji) 130 | Jj = cam_T_body[None,None,None,None,None].adjT(Jj) 131 | 132 | # To get right jacobians wrt world_T_body (doesn't really matter for covariance calculation since we take the square) 133 | Ji *= -1.0 134 | Jj *= -1.0 135 | 136 | # Account for Droid's (x,y,z,wx,wy,wz) convention to GTSAM's (wx,wy,wz,x,y,z) convention 137 | Ji = Ji[...,[3,4,5,0,1,2]] 138 | Jj = Jj[...,[3,4,5,0,1,2]] 139 | 140 | Jz = Gij[:,:,None,None] * Jz 141 | Jz = torch.matmul(Jp, Jz.unsqueeze(-1)) 142 | 143 | return x1, valid, (Ji, Jj, Jz) 144 | 145 | return x1, valid, (None, None, None) 146 | 147 | def induced_flow(poses, disps, intrinsics, ii, jj): 148 | """ optical flow induced by camera motion """ 149 | 150 | ht, wd = disps.shape[2:] 151 | y, x = torch.meshgrid( 152 | torch.arange(ht).to(disps.device).float(), 153 | torch.arange(wd).to(disps.device).float()) 154 | 155 | coords0 = torch.stack([x, y], dim=-1) 156 | coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False) 157 | 158 | return coords1[...,:2] - coords0, valid 159 | 160 | -------------------------------------------------------------------------------- /networks/geom/rgbd_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | import torch 5 | from lietorch import SE3 6 | 7 | from . 
import projective_ops as pops 8 | from scipy.spatial.transform import Rotation 9 | 10 | 11 | def parse_list(filepath, skiprows=0): 12 | """ read list data """ 13 | data = np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows) 14 | return data 15 | 16 | def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=1.0): 17 | """ pair images, depths, and poses """ 18 | associations = [] 19 | for i, t in enumerate(tstamp_image): 20 | if tstamp_pose is None: 21 | j = np.argmin(np.abs(tstamp_depth - t)) 22 | if (np.abs(tstamp_depth[j] - t) < max_dt): 23 | associations.append((i, j)) 24 | 25 | else: 26 | j = np.argmin(np.abs(tstamp_depth - t)) 27 | k = np.argmin(np.abs(tstamp_pose - t)) 28 | 29 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \ 30 | (np.abs(tstamp_pose[k] - t) < max_dt): 31 | associations.append((i, j, k)) 32 | 33 | return associations 34 | 35 | def loadtum(datapath, frame_rate=-1): 36 | """ read video data in tum-rgbd format """ 37 | if osp.isfile(osp.join(datapath, 'groundtruth.txt')): 38 | pose_list = osp.join(datapath, 'groundtruth.txt') 39 | 40 | elif osp.isfile(osp.join(datapath, 'pose.txt')): 41 | pose_list = osp.join(datapath, 'pose.txt') 42 | 43 | else: 44 | return None, None, None, None 45 | 46 | image_list = osp.join(datapath, 'rgb.txt') 47 | depth_list = osp.join(datapath, 'depth.txt') 48 | 49 | calib_path = osp.join(datapath, 'calibration.txt') 50 | intrinsic = None 51 | if osp.isfile(calib_path): 52 | intrinsic = np.loadtxt(calib_path, delimiter=' ') 53 | intrinsic = intrinsic.astype(np.float64) 54 | 55 | image_data = parse_list(image_list) 56 | depth_data = parse_list(depth_list) 57 | pose_data = parse_list(pose_list, skiprows=1) 58 | pose_vecs = pose_data[:,1:].astype(np.float64) 59 | 60 | tstamp_image = image_data[:,0].astype(np.float64) 61 | tstamp_depth = depth_data[:,0].astype(np.float64) 62 | tstamp_pose = pose_data[:,0].astype(np.float64) 63 | associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose) 64 | 65 | # print(len(tstamp_image)) 66 | # print(len(associations)) 67 | 68 | indicies = range(len(associations))[::5] 69 | 70 | # indicies = [ 0 ] 71 | # for i in range(1, len(associations)): 72 | # t0 = tstamp_image[associations[indicies[-1]][0]] 73 | # t1 = tstamp_image[associations[i][0]] 74 | # if t1 - t0 > 1.0 / frame_rate: 75 | # indicies += [ i ] 76 | 77 | images, poses, depths, intrinsics, tstamps = [], [], [], [], [] 78 | for ix in indicies: 79 | (i, j, k) = associations[ix] 80 | images += [ osp.join(datapath, image_data[i,1]) ] 81 | depths += [ osp.join(datapath, depth_data[j,1]) ] 82 | poses += [ pose_vecs[k] ] 83 | tstamps += [ tstamp_image[i] ] 84 | 85 | if intrinsic is not None: 86 | intrinsics += [ intrinsic ] 87 | 88 | return images, depths, poses, intrinsics, tstamps 89 | 90 | 91 | def all_pairs_distance_matrix(poses, beta=2.5): 92 | """ compute distance matrix between all pairs of poses """ 93 | poses = np.array(poses, dtype=np.float32) 94 | poses[:,:3] *= beta # scale to balence rot + trans 95 | poses = SE3(torch.from_numpy(poses)) 96 | 97 | r = (poses[:,None].inv() * poses[None,:]).log() 98 | return r.norm(dim=-1).cpu().numpy() 99 | 100 | def pose_matrix_to_quaternion(pose): 101 | """ convert 4x4 pose matrix to (t, q) """ 102 | q = Rotation.from_matrix(pose[:3, :3]).as_quat() 103 | return np.concatenate([pose[:3, 3], q], axis=0) 104 | 105 | def compute_distance_matrix_flow(poses, disps, intrinsics): 106 | """ compute flow magnitude between all pairs of frames """ 107 | if not isinstance(poses, SE3): 
108 | poses = torch.from_numpy(poses).float().cuda()[None] 109 | poses = SE3(poses).inv() 110 | 111 | disps = torch.from_numpy(disps).float().cuda()[None] 112 | intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] 113 | 114 | N = poses.shape[1] 115 | 116 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) 117 | ii = ii.reshape(-1).cuda() 118 | jj = jj.reshape(-1).cuda() 119 | 120 | MAX_FLOW = 100.0 121 | matrix = np.zeros((N, N), dtype=np.float32) 122 | 123 | s = 2048 124 | for i in range(0, ii.shape[0], s): 125 | flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 126 | flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s]) 127 | 128 | flow = torch.stack([flow1, flow2], dim=2) 129 | val = torch.stack([val1, val2], dim=2) 130 | 131 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) 132 | mag = mag.view(mag.shape[1], -1) 133 | val = val.view(val.shape[1], -1) 134 | 135 | mag = (mag * val).mean(-1) / val.mean(-1) 136 | mag[val.mean(-1) < 0.7] = np.inf 137 | 138 | i1 = ii[i:i+s].cpu().numpy() 139 | j1 = jj[i:i+s].cpu().numpy() 140 | matrix[i1, j1] = mag.cpu().numpy() 141 | 142 | return matrix 143 | 144 | 145 | def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4): 146 | """ compute flow magnitude between all pairs of frames """ 147 | # if not isinstance(poses, SE3): 148 | # poses = torch.from_numpy(poses).float().cuda()[None] 149 | # poses = SE3(poses).inv() 150 | 151 | # disps = torch.from_numpy(disps).float().cuda()[None] 152 | # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] 153 | 154 | N = poses.shape[1] 155 | 156 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) 157 | ii = ii.reshape(-1) 158 | jj = jj.reshape(-1) 159 | 160 | MAX_FLOW = 128.0 161 | matrix = np.zeros((N, N), dtype=np.float32) 162 | 163 | s = 2048 164 | for i in range(0, ii.shape[0], s): 165 | flow1a, val1a = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True) 166 | flow1b, val1b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 167 | flow2a, val2a = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True) 168 | flow2b, val2b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 169 | 170 | flow1 = flow1a + beta * flow1b 171 | val1 = val1a * val2b 172 | 173 | flow2 = flow2a + beta * flow2b 174 | val2 = val2a * val2b 175 | 176 | flow = torch.stack([flow1, flow2], dim=2) 177 | val = torch.stack([val1, val2], dim=2) 178 | 179 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) 180 | mag = mag.view(mag.shape[1], -1) 181 | val = val.view(val.shape[1], -1) 182 | 183 | mag = (mag * val).mean(-1) / val.mean(-1) 184 | mag[val.mean(-1) < 0.8] = np.inf 185 | 186 | i1 = ii[i:i+s].cpu().numpy() 187 | j1 = jj[i:i+s].cpu().numpy() 188 | matrix[i1, j1] = mag.cpu().numpy() 189 | 190 | return matrix 191 | -------------------------------------------------------------------------------- /networks/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | -------------------------------------------------------------------------------- /networks/modules/clipping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | GRAD_CLIP = .01 6 | 7 | class GradClip(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return x 11 | 12 | @staticmethod 13 | def 
backward(ctx, grad_x): 14 | o = torch.zeros_like(grad_x) 15 | grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x) 16 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x) 17 | return grad_x 18 | 19 | class GradientClip(nn.Module): 20 | def __init__(self): 21 | super(GradientClip, self).__init__() 22 | 23 | def forward(self, x): 24 | return GradClip.apply(x) -------------------------------------------------------------------------------- /networks/modules/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import droid_backends 5 | 6 | class CorrSampler(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, volume, coords, radius): 10 | ctx.save_for_backward(volume,coords) 11 | ctx.radius = radius 12 | corr, = droid_backends.corr_index_forward(volume, coords, radius) 13 | return corr 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | volume, coords = ctx.saved_tensors 18 | grad_output = grad_output.contiguous() 19 | grad_volume, = droid_backends.corr_index_backward(volume, coords, grad_output, ctx.radius) 20 | return grad_volume, None, None 21 | 22 | 23 | class CorrBlock: 24 | def __init__(self, fmap1, fmap2, num_levels=4, radius=3): 25 | self.num_levels = num_levels 26 | self.radius = radius 27 | self.corr_pyramid = [] 28 | 29 | # all pairs correlation 30 | corr = CorrBlock.corr(fmap1, fmap2) 31 | 32 | batch, num, h1, w1, h2, w2 = corr.shape 33 | corr = corr.reshape(batch*num*h1*w1, 1, h2, w2) 34 | 35 | for i in range(self.num_levels): 36 | self.corr_pyramid.append( 37 | corr.view(batch*num, h1, w1, h2//2**i, w2//2**i)) 38 | corr = F.avg_pool2d(corr, 2, stride=2) 39 | 40 | def __call__(self, coords): 41 | out_pyramid = [] 42 | batch, num, ht, wd, _ = coords.shape 43 | coords = coords.permute(0,1,4,2,3) 44 | coords = coords.contiguous().view(batch*num, 2, ht, wd) 45 | 46 | for i in range(self.num_levels): 47 | corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius) 48 | out_pyramid.append(corr.view(batch, num, -1, ht, wd)) 49 | 50 | return torch.cat(out_pyramid, dim=2) 51 | 52 | def cat(self, other): 53 | for i in range(self.num_levels): 54 | self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0) 55 | return self 56 | 57 | def __getitem__(self, index): 58 | for i in range(self.num_levels): 59 | self.corr_pyramid[i] = self.corr_pyramid[i][index] 60 | return self 61 | 62 | 63 | @staticmethod 64 | def corr(fmap1, fmap2): 65 | """ all-pairs correlation """ 66 | batch, num, dim, ht, wd = fmap1.shape 67 | fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0 68 | fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0 69 | 70 | # multiplies along the `dim' dimension, for each batch, for each feature map 71 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 72 | return corr.view(batch, num, ht, wd, ht, wd) 73 | 74 | 75 | class CorrLayer(torch.autograd.Function): 76 | @staticmethod 77 | def forward(ctx, fmap1, fmap2, coords, r): 78 | ctx.r = r 79 | ctx.save_for_backward(fmap1, fmap2, coords) 80 | corr, = droid_backends.altcorr_forward(fmap1, fmap2, coords, ctx.r) 81 | return corr 82 | 83 | @staticmethod 84 | def backward(ctx, grad_corr): 85 | fmap1, fmap2, coords = ctx.saved_tensors 86 | grad_corr = grad_corr.contiguous() 87 | fmap1_grad, fmap2_grad, coords_grad = \ 88 | droid_backends.altcorr_backward(fmap1, fmap2, coords, grad_corr, ctx.r) 89 | return fmap1_grad, fmap2_grad, coords_grad, None 90 | 91 | 92 | class AltCorrBlock: 93 
| def __init__(self, fmaps, num_levels=4, radius=3): 94 | self.num_levels = num_levels 95 | self.radius = radius 96 | 97 | B, N, C, H, W = fmaps.shape 98 | fmaps = fmaps.view(B*N, C, H, W) / 4.0 99 | 100 | self.pyramid = [] 101 | for i in range(self.num_levels): 102 | sz = (B, N, H//2**i, W//2**i, C) 103 | fmap_lvl = fmaps.permute(0, 2, 3, 1).contiguous() 104 | self.pyramid.append(fmap_lvl.view(*sz)) 105 | fmaps = F.avg_pool2d(fmaps, 2, stride=2) 106 | 107 | def corr_fn(self, coords, ii, jj): 108 | B, N, H, W, S, _ = coords.shape 109 | coords = coords.permute(0, 1, 4, 2, 3, 5) 110 | 111 | corr_list = [] 112 | for i in range(self.num_levels): 113 | r = self.radius 114 | fmap1_i = self.pyramid[0][:, ii] 115 | fmap2_i = self.pyramid[i][:, jj] 116 | 117 | coords_i = (coords / 2**i).reshape(B*N, S, H, W, 2).contiguous() 118 | fmap1_i = fmap1_i.reshape((B*N,) + fmap1_i.shape[2:]) 119 | fmap2_i = fmap2_i.reshape((B*N,) + fmap2_i.shape[2:]) 120 | 121 | corr = CorrLayer.apply(fmap1_i.float(), fmap2_i.float(), coords_i, self.radius) 122 | corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2) 123 | corr_list.append(corr) 124 | 125 | corr = torch.cat(corr_list, dim=2) 126 | return corr 127 | 128 | 129 | def __call__(self, coords, ii, jj): 130 | squeeze_output = False 131 | if len(coords.shape) == 5: 132 | coords = coords.unsqueeze(dim=-2) 133 | squeeze_output = True 134 | 135 | corr = self.corr_fn(coords, ii, jj) 136 | 137 | if squeeze_output: 138 | corr = corr.squeeze(dim=-1) 139 | 140 | return corr.contiguous() 141 | 142 | -------------------------------------------------------------------------------- /networks/modules/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x+y) 56 | 57 | 58 | class BottleneckBlock(nn.Module): 59 | def __init__(self, in_planes, planes, 
norm_fn='group', stride=1): 60 | super(BottleneckBlock, self).__init__() 61 | 62 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 63 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 64 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 65 | self.relu = nn.ReLU(inplace=True) 66 | 67 | num_groups = planes // 8 68 | 69 | if norm_fn == 'group': 70 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 71 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 72 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 73 | if not stride == 1: 74 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | 76 | elif norm_fn == 'batch': 77 | self.norm1 = nn.BatchNorm2d(planes//4) 78 | self.norm2 = nn.BatchNorm2d(planes//4) 79 | self.norm3 = nn.BatchNorm2d(planes) 80 | if not stride == 1: 81 | self.norm4 = nn.BatchNorm2d(planes) 82 | 83 | elif norm_fn == 'instance': 84 | self.norm1 = nn.InstanceNorm2d(planes//4) 85 | self.norm2 = nn.InstanceNorm2d(planes//4) 86 | self.norm3 = nn.InstanceNorm2d(planes) 87 | if not stride == 1: 88 | self.norm4 = nn.InstanceNorm2d(planes) 89 | 90 | elif norm_fn == 'none': 91 | self.norm1 = nn.Sequential() 92 | self.norm2 = nn.Sequential() 93 | self.norm3 = nn.Sequential() 94 | if not stride == 1: 95 | self.norm4 = nn.Sequential() 96 | 97 | if stride == 1: 98 | self.downsample = None 99 | 100 | else: 101 | self.downsample = nn.Sequential( 102 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 103 | 104 | def forward(self, x): 105 | y = x 106 | y = self.relu(self.norm1(self.conv1(y))) 107 | y = self.relu(self.norm2(self.conv2(y))) 108 | y = self.relu(self.norm3(self.conv3(y))) 109 | 110 | if self.downsample is not None: 111 | x = self.downsample(x) 112 | 113 | return self.relu(x+y) 114 | 115 | 116 | DIM=32 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | self.multidim = multidim 123 | 124 | if self.norm_fn == 'group': 125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM) 126 | 127 | elif self.norm_fn == 'batch': 128 | self.norm1 = nn.BatchNorm2d(DIM) 129 | 130 | elif self.norm_fn == 'instance': 131 | self.norm1 = nn.InstanceNorm2d(DIM) 132 | 133 | elif self.norm_fn == 'none': 134 | self.norm1 = nn.Sequential() 135 | 136 | self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3) 137 | self.relu1 = nn.ReLU(inplace=True) 138 | 139 | self.in_planes = DIM 140 | self.layer1 = self._make_layer(DIM, stride=1) 141 | self.layer2 = self._make_layer(2*DIM, stride=2) 142 | self.layer3 = self._make_layer(4*DIM, stride=2) 143 | 144 | # output convolution 145 | self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1) 146 | 147 | if self.multidim: 148 | self.layer4 = self._make_layer(256, stride=2) 149 | self.layer5 = self._make_layer(512, stride=2) 150 | 151 | self.in_planes = 256 152 | self.layer6 = self._make_layer(256, stride=1) 153 | 154 | self.in_planes = 128 155 | self.layer7 = self._make_layer(128, stride=1) 156 | 157 | self.up1 = nn.Conv2d(512, 256, 1) 158 | self.up2 = nn.Conv2d(256, 128, 1) 159 | self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1) 160 | 161 | if dropout > 0: 162 | self.dropout = nn.Dropout2d(p=dropout) 163 | else: 164 | self.dropout = None 165 | 166 | for m in self.modules(): 167 | if isinstance(m, 
nn.Conv2d): 168 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 169 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 170 | if m.weight is not None: 171 | nn.init.constant_(m.weight, 1) 172 | if m.bias is not None: 173 | nn.init.constant_(m.bias, 0) 174 | 175 | def _make_layer(self, dim, stride=1): 176 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 177 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 178 | layers = (layer1, layer2) 179 | 180 | self.in_planes = dim 181 | return nn.Sequential(*layers) 182 | 183 | def forward(self, x): 184 | b, n, c1, h1, w1 = x.shape 185 | x = x.view(b*n, c1, h1, w1) 186 | 187 | x = self.conv1(x) 188 | x = self.norm1(x) 189 | x = self.relu1(x) 190 | 191 | x = self.layer1(x) 192 | x = self.layer2(x) 193 | x = self.layer3(x) 194 | 195 | x = self.conv2(x) 196 | 197 | _, c2, h2, w2 = x.shape 198 | return x.view(b, n, c2, h2, w2) 199 | -------------------------------------------------------------------------------- /networks/modules/gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvGRU(nn.Module): 6 | def __init__(self, h_planes=128, i_planes=128): 7 | super(ConvGRU, self).__init__() 8 | self.do_checkpoint = False 9 | self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 10 | self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 11 | self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 12 | 13 | self.w = nn.Conv2d(h_planes, h_planes, 1, padding=0) 14 | 15 | self.convz_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 16 | self.convr_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 17 | self.convq_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 18 | 19 | def forward(self, net, *inputs): 20 | inp = torch.cat(inputs, dim=1) 21 | net_inp = torch.cat([net, inp], dim=1) 22 | 23 | b, c, h, w = net.shape 24 | glo = torch.sigmoid(self.w(net)) * net 25 | glo = glo.view(b, c, h*w).mean(-1).view(b, c, 1, 1) 26 | 27 | z = torch.sigmoid(self.convz(net_inp) + self.convz_glo(glo)) 28 | r = torch.sigmoid(self.convr(net_inp) + self.convr_glo(glo)) 29 | q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)) + self.convq_glo(glo)) 30 | 31 | net = (1-z) * net + z * q 32 | return net 33 | 34 | 35 | -------------------------------------------------------------------------------- /networks/motion_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | 4 | from .geom import projective_ops as pops 5 | from .modules.corr import CorrBlock 6 | 7 | 8 | # Basically populates the video with the frames that have enough motion 9 | # It calculates the features for each frame, does one step GRU update 10 | # to get a sense of the flow and if the flow > min_flow 11 | # it adds the frame, together with features and context to the video object 12 | class MotionFilter: 13 | """ This class is used to filter incoming frames and extract features """ 14 | def __init__(self, net, video, min_flow_thresh=2.5, device="cuda:0"): 15 | # split net modules 16 | self.context_net = net.cnet 17 | self.feature_net = net.fnet 18 | self.update_net = net.update 19 | 20 | self.video = video 21 | self.min_flow_thresh = min_flow_thresh 22 | 23 | self.skipped_frames = 0 24 | self.device = device 25 | 26 | # mean, std for image normalization 27 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, 
None, None] 28 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None] 29 | 30 | @torch.cuda.amp.autocast(enabled=True) 31 | def __context_encoder(self, image): 32 | """ context features """ 33 | context_maps, gru_input_maps = self.context_net(image).split([128,128], dim=2) 34 | return context_maps.tanh().squeeze(0), gru_input_maps.relu().squeeze(0) 35 | 36 | @torch.cuda.amp.autocast(enabled=True) 37 | def __feature_encoder(self, image): 38 | """ features for correlation volume """ 39 | return self.feature_net(image).squeeze(0) 40 | 41 | @torch.cuda.amp.autocast(enabled=True) 42 | @torch.no_grad() 43 | def track(self, k, timestamp, image, depth=None, intrinsics=None): 44 | """ main update operation - run on every frame in video """ 45 | 46 | # normalize images 47 | img_normalized = image[None, :, [2,1,0]].to(self.device) / 255.0 48 | img_normalized = img_normalized.sub_(self.MEAN).div_(self.STDV) 49 | 50 | # extract features 51 | feature_map = self.__feature_encoder(img_normalized) 52 | 53 | ### always add first frame to the depth video ### 54 | if k == 0: 55 | self.add_frame_to_video(timestamp, image, img_normalized, feature_map, depth, intrinsics) 56 | ### only add new frame if there is enough motion ### 57 | else: 58 | ht = image.shape[-2] // 8 59 | wd = image.shape[-1] // 8 60 | 61 | # Index correlation volume 62 | coords0 = pops.coords_grid(ht, wd, device=self.device)[None,None] 63 | corr = CorrBlock(self.feature_maps[None,[0]], feature_map[None,[0]])(coords0) # TODO why not send the corr block? 64 | 65 | # Approximate flow magnitude using 1 update iteration 66 | _, delta, weight = self.update_net(self.context_maps[None], self.gru_input_maps[None], corr) 67 | 68 | # Check motion magnitue / add new frame to video 69 | has_enough_motion = delta.norm(dim=-1).mean().item() > self.min_flow_thresh 70 | if has_enough_motion: 71 | self.add_frame_to_video(timestamp, image, img_normalized, feature_map, depth, intrinsics) 72 | self.skipped_frames = 0 73 | else: 74 | self.skipped_frames += 1 75 | 76 | # Save the network responses (feature + context) and image for later inference 77 | def add_frame_to_video(self, timestamp, image, img_normalized, feature_map, depth_img=None, intrinsics=None): 78 | context_maps, gru_input_maps = self.__context_encoder(img_normalized[:,[0]]) 79 | self.context_maps, self.gru_input_maps, self.feature_maps = context_maps, gru_input_maps, feature_map 80 | identity_pose = lietorch.SE3.Identity(1,).data.squeeze() 81 | self.video.append(timestamp, image[0], identity_pose, 82 | 1.0, # Initialize disps at 1 83 | depth_img, # If available 84 | intrinsics / 8.0, 85 | feature_map, context_maps[0,0], gru_input_maps[0,0]) 86 | -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ToniRV/NeRF-SLAM/4e407e8cb1378c6ece18e621b19ccd5be982b7dd/pipeline/__init__.py -------------------------------------------------------------------------------- /pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | class Pipeline: 5 | def __init__(self, ): 6 | super().__init__() 7 | 8 | -------------------------------------------------------------------------------- /pipeline/pipeline_module.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import 
colored_glog as log 3 | from icecream import ic 4 | 5 | #from torch.profiler import profile, ProfilerActivity 6 | 7 | class PipelineModuleBase: 8 | def __init__(self, name, parallel_run, args=None, grad=False): 9 | self.name = name 10 | self.parallel_run = parallel_run 11 | self.grad = grad # determines if we are tracing grads or not 12 | self.shutdown = False # needs to be atomic 13 | self.is_initialized = False # needs to be atomic 14 | self.is_thread_working = False # needs to be atomic 15 | self.args = args # arguments to init module 16 | # Callbacks to be called in case module does not return an output. 17 | self.on_failure_callbacks = [] 18 | self.profile = False # Profile the code for runtime and/or memory 19 | 20 | @abstractmethod 21 | def initialize_module(self): 22 | # Allocate memory and initialize variables here so that 23 | # we do not need to avoid when running in parallel a: 24 | # "TypeError: cannot pickle 'XXXX' object" 25 | self.is_initialized = True 26 | 27 | @abstractmethod 28 | def spin(self) -> bool: 29 | pass 30 | 31 | @abstractmethod 32 | def shutdown_queues(self): 33 | pass 34 | 35 | @abstractmethod 36 | def has_work(self): 37 | pass 38 | 39 | def shutdown_module(self): 40 | # TODO shouldn't self.shutdown be atomic? (i.e. thread-safe?) 41 | if self.shutdown: 42 | log.warn(f"Module: {self.name} - Shutdown requested, but was already shutdown.") 43 | log.debug(f"Stopping module {self.name} and its queues...") 44 | self.shutdown_queues() 45 | log.info(f"Module: {self.name} - Shutting down.") 46 | self.shutdown = True 47 | 48 | def restart(self): 49 | log.info(f"Module: {self.name} - Resetting shutdown flag to false") 50 | self.shutdown = False 51 | 52 | def is_working(self): 53 | return self.is_thread_working or self.hasWork() 54 | 55 | def register_on_failure_callback(self, callback): 56 | log.check(callback) 57 | self.on_failure_callbacks.append(callback) 58 | 59 | def notify_on_failure(self): 60 | for on_failure_callback in self.on_failure_callbacks: 61 | if on_failure_callback: 62 | on_failure_callback() 63 | else: 64 | log.error(f"Invalid OnFailureCallback for module: {self.name}") 65 | 66 | class PipelineModule(PipelineModuleBase): 67 | def __init__(self, name_id, parallel_run, args=None, grad=False) -> None: 68 | super().__init__(name_id, parallel_run, args, grad) 69 | 70 | @abstractmethod 71 | def get_input_packet(self): 72 | raise 73 | 74 | @abstractmethod 75 | def push_output_packet(self, output_packet) -> bool: 76 | raise 77 | 78 | @abstractmethod 79 | def spin_once(self, input): 80 | raise 81 | 82 | # Spin is called in a thread. 
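# The loop below repeats three steps: pull a packet from the registered
# input queues (get_input_packet), process it (spin_once), and fan the
# result out to the output queues and callbacks (push_output_packet).
# If spin_once returns no output, the on-failure callbacks are notified
# instead. In sequential mode (parallel_run=False) a single pass is made
# and spin() returns True immediately; in parallel mode it keeps spinning
# until shutdown_module() sets the shutdown flag, then returns False.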
83 | def spin(self): 84 | if self.parallel_run: 85 | log.info(f'Module: {self.name} - Spinning.') 86 | 87 | if not self.is_initialized: 88 | self.initialize_module() 89 | 90 | while not self.shutdown: 91 | self.is_thread_working = False; 92 | input = self.get_input_packet(); 93 | self.is_thread_working = True; 94 | if input is not None: 95 | output = None 96 | if self.profile: 97 | #with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: 98 | output = self.spin_once(input); 99 | #print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)) 100 | else: 101 | output = self.spin_once(input); 102 | if output is not None: 103 | # Received a valid output, send to output queue 104 | if not self.push_output_packet(output): 105 | log.warn(f"Module: {self.name} - Output push failed.") 106 | else: 107 | log.debug(f"Module: {self.name} - Pushed output.") 108 | else: 109 | log.debug(f"Module: {self.name} - Skipped sending an output.") 110 | # Notify interested parties about failure. 111 | self.notify_on_failure(); 112 | else: 113 | log.log(2, f"Module: {self.name} - No Input received.") 114 | 115 | # Break the while loop if we are in sequential mode. 116 | if not self.parallel_run: 117 | self.is_thread_working = False; 118 | return True; 119 | 120 | self.is_thread_working = False; 121 | log.info(f"Module: {self.name} - Successful shutdown.") 122 | return False; 123 | 124 | class MIMOPipelineModule(PipelineModule): 125 | def __init__(self, name_id, parallel_run, args=None, grad=False): 126 | super().__init__(name_id, parallel_run, args, grad) 127 | self.input_queues = {} 128 | self.output_callbacks = [] 129 | self.output_queues = [] 130 | 131 | def register_input_queue(self, name, input_queue): 132 | self.input_queues[name] = input_queue 133 | 134 | def register_output_callback(self, output_callback): 135 | self.output_callbacks.append(output_callback) 136 | 137 | def register_output_queue(self, output_queue): 138 | self.output_queues.append(output_queue) 139 | 140 | # TODO: warn when callbacks take too long 141 | def push_output_packet(self, output_packet): 142 | push_success = True 143 | # Push output to all queues 144 | for output_queue in self.output_queues: 145 | try: 146 | output_queue.put(output_packet) 147 | except Exception as e: 148 | log.warn(e) 149 | push_success = False 150 | # Push output to all callbacks 151 | for callback in self.output_callbacks: 152 | try: 153 | callback(output_packet) 154 | except Exception as e: 155 | log.warn(e) 156 | push_success = False 157 | return push_success 158 | 159 | def get_input_packet(self, timeout=0.1): 160 | inputs = {} 161 | if self.parallel_run: 162 | for name, input_queue in self.input_queues.items(): 163 | try: 164 | inputs[name] = input_queue.get(timeout=timeout) 165 | except Exception as e: 166 | log.debug(e) 167 | else: 168 | for name, input_queue in self.input_queues.items(): 169 | try: 170 | inputs[name] = input_queue.get_nowait() 171 | except Exception as e: 172 | log.debug(e) 173 | 174 | if len(inputs) == 0: 175 | log.debug(f"Module: {self.name} - Input queues didn't return an output.") 176 | inputs = None 177 | return inputs 178 | 179 | def shutdown_queues(self): 180 | super().shutdown_queues() 181 | # This should be automatically called by garbage collector 182 | # [input_queue.close() for input_queue in self.input_queues] 183 | # [output_queue.close() for output_queue in self.output_queues] -------------------------------------------------------------------------------- 
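# Illustrative usage sketch for the MIMOPipelineModule defined above. The
# DoublerModule class, queue names and values are assumptions made only
# for this example; they are not part of the repository.
from queue import Queue

class DoublerModule(MIMOPipelineModule):
    def spin_once(self, input):
        # 'input' is a dict keyed by the registered input-queue names.
        return {name: 2 * value for name, value in input.items()}

in_queue, out_queue = Queue(), Queue()
module = DoublerModule("doubler", parallel_run=False)
module.register_input_queue("numbers", in_queue)
module.register_output_queue(out_queue)

in_queue.put(21)
module.spin()           # sequential mode: one pass over the inputs, then return
print(out_queue.get())  # {'numbers': 42}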
/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | tqdm 4 | matplotlib 5 | 6 | open3d 7 | opencv-python 8 | 9 | torch-scatter 10 | 11 | jupyter 12 | imageio 13 | 14 | glog 15 | icecream 16 | 17 | pandas 18 | pyrealsense2 19 | 20 | pybind11 21 | gdown 22 | -------------------------------------------------------------------------------- /scripts/convergence_plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "10010259", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "plt.rcParams[\"text.usetex\"] = True\n", 15 | "plt.rcParams[\"font.family\"] = \"serif\"\n", 16 | "plt.rcParams[\"font.size\"] = \"14\"\n", 17 | "\n", 18 | "\n", 19 | "def plot_results(dfs, show=True, save=True, xlim=5000):\n", 20 | " fig, ax1 = plt.subplots()\n", 21 | " \n", 22 | " from itertools import cycle\n", 23 | " lines = [\"-\",\"--\",\"-.\",\":\"]\n", 24 | " linecycler = cycle(lines)\n", 25 | "\n", 26 | " \n", 27 | " ax1.set_xlabel('Iter [-]')\n", 28 | " \n", 29 | " psnr_color = 'tab:red'\n", 30 | " ax1.set_ylabel('PSNR [dB]', color=psnr_color)\n", 31 | " ax1.tick_params(axis='y', labelcolor=psnr_color)\n", 32 | " #ax1.set_xlim(0, df['Iter'][df.index[-1]])\n", 33 | " ax1.set_xlim(0, xlim)\n", 34 | " \n", 35 | " ax2 = ax1.twiny()\n", 36 | " ax2.set_xlabel('$\\Delta$t [s]')\n", 37 | "\n", 38 | " l1_color = 'tab:blue'\n", 39 | " ax3 = ax1.twinx() # instantiate a second axes that shares the same x-axis\n", 40 | " ax3.set_ylabel('L1 [cm]', color=l1_color)\n", 41 | " ax3.set_ylim(0, 30)\n", 42 | " ax3.tick_params(axis='y', labelcolor=l1_color)\n", 43 | " \n", 44 | " for name, df in dfs.items(): \n", 45 | " \n", 46 | " #df = df.iloc[1: , :] # Drop first row which only has 0.0\n", 47 | " df = df.set_index('Iter')\n", 48 | " \n", 49 | " line_style = next(linecycler)\n", 50 | "\n", 51 | " \n", 52 | " ax1.plot(df.index, df['PSNR'], color=psnr_color, linestyle=line_style, label=name) \n", 53 | " ax3.plot(df.index, df['L1'], color=l1_color, linestyle=line_style)\n", 54 | "\n", 55 | " ax2.set_xticks(ax1.get_xticks())\n", 56 | " ax2.set_xbound(ax1.get_xbound())\n", 57 | " xticklabels = [df['Dt'].loc[xtick].round(0).item() for xtick in ax1.get_xticks()]\n", 58 | " print(xticklabels)\n", 59 | " ax2.set_xticklabels(xticklabels)\n", 60 | " \n", 61 | " \n", 62 | " #ax1.legend(loc='upper center', fontsize='x-large')\n", 63 | " leg = ax1.legend(loc='center right')\n", 64 | " [lgd.set_color('black') for lgd in leg.legendHandles]\n", 65 | "\n", 66 | " fig.tight_layout() # otherwise the right y-label is slightly clipped\n", 67 | " if show: plt.show()\n", 68 | " if save: \n", 69 | " #fig.set_size_inches(3, 2)\n", 70 | " out = ''\n", 71 | " for name in list(dfs.keys())[:-1]:\n", 72 | " out += name + '_vs_'\n", 73 | " out += list(dfs.keys())[-1]\n", 74 | " print(f\"Saving to: {out}\")\n", 75 | " fig.savefig(out + '.svg', dpi=100, transparent=False, bbox_inches='tight')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "edc45e33", 82 | "metadata": { 83 | "scrolled": false 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "dfs = {\n", 88 | " #\"ours_l1\": pd.read_csv('../room_results_ours_l1.csv'),\n", 89 | " \"no depth\": pd.read_csv('../room_results_no_depth_3.csv'),\n", 90 | " \"raw 
depth\": pd.read_csv('../room_results_raw_depth.csv'),\n", 91 | " #\"raw depth (w annealing)\": pd.read_csv('../room_results_raw_w_annealing.csv'),\n", 92 | " \"weighted depth\": pd.read_csv('../room_results_ours_wo_thresh.csv'),\n", 93 | " #\"weighted + annealed\": pd.read_csv('../room_results_ours_l2_full.csv'),\n", 94 | " \"weighted + annealed\": pd.read_csv('../room_results_ours_w_annealing.csv'),\n", 95 | "}\n", 96 | "plot_results(dfs, show=True, save=\"room_results\")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "ac677682", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "dfs = {\n", 107 | " #\"ours_no_anneal_thresh\": pd.read_csv('../replica_office0_ours_wo_anneal_w_thresh.csv').iloc[9:,:],\n", 108 | " \"no depth\": pd.read_csv('../replica_office0_no_depth_2.csv'),\n", 109 | " #\"ours_median_annealed\": pd.read_csv('../replica_office0_ours_median_annealed.csv'),\n", 110 | " #\"ours_median_annealed\": pd.read_csv('../replica_office0_ours_w_anneal_w_median.csv'), \n", 111 | " #\"ours_wo_thresh_10\": pd.read_csv('../replica_office0_ours_wo_anneal_wo_thresh_3.csv'),\n", 112 | " #\"ours_wo_thresh\": pd.read_csv('../replica_office0_ours_wo_anneal_wo_thresh_2.csv'),\n", 113 | " \"raw\": pd.read_csv('../replica_office0_raw_2.csv'),\n", 114 | " \"ours_median\": pd.read_csv('../replica_office0_ours_wo_anneal_w_median.csv'),\n", 115 | " #\"ours_median_100\": pd.read_csv('../replica_office0_ours_wo_anneal_median_100_lucky.csv')\n", 116 | "}\n", 117 | "plot_results(dfs, show=True, save=\"room_results\", xlim=20000)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "7d52e51f", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "dfs = {\n", 128 | " \"room0\": pd.read_csv('../room0_nerf_results.csv'),\n", 129 | " \"room1\": pd.read_csv('../room1_nerf_results.csv'),\n", 130 | " \"room2\": pd.read_csv('../room2_nerf_results.csv'),\n", 131 | " \"office0\": pd.read_csv('../office0_nerf_results.csv'),\n", 132 | " \"office1\": pd.read_csv('../office1_nerf_results.csv'),\n", 133 | "}\n", 134 | "plot_results(dfs, show=True, save=\"room_results\", xlim=25000)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "a7327f16", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "dfs = {\n", 145 | " \"no depth\": pd.read_csv('../room_nerf_no_depth_results.csv'),\n", 146 | " \n", 147 | " #\"weighted+filtered\": pd.read_csv('../room_nerf_ours_w_thresh_results.csv'),\n", 148 | " \"raw\": pd.read_csv('../room_nerf_raw_results.csv'),\n", 149 | " \"weighted\": pd.read_csv('../room_nerf_ours_results.csv'),\n", 150 | "}\n", 151 | "plot_results(dfs, show=True, save=\"room_results\", xlim=25000)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "0e9db743", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "dfs = {\n", 162 | " \"no depth\": pd.read_csv('../room0_nerf_no_depth_results.csv'),\n", 163 | " \"ours\": pd.read_csv('../room0_nerf_ours_results.csv'),\n", 164 | " \"raw\": pd.read_csv('../room0_nerf_raw_results.csv'),\n", 165 | "}\n", 166 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "7fa4ea26", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "dfs = {\n", 177 | " \"no depth\": pd.read_csv('../room1_nerf_no_depth_results.csv'),\n", 178 | 
" \"ours\": pd.read_csv('../room1_nerf_ours_results.csv'),\n", 179 | " \"raw\": pd.read_csv('../room1_nerf_raw_results.csv'),\n", 180 | "}\n", 181 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "5676517b", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "dfs = {\n", 192 | " \"no depth\": pd.read_csv('../room2_nerf_no_depth_results.csv'),\n", 193 | " \"ours\": pd.read_csv('../room2_nerf_ours_results.csv'),\n", 194 | " \"raw\": pd.read_csv('../room2_nerf_raw_results.csv'),\n", 195 | "}\n", 196 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "42de5872", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "dfs = {\n", 207 | " \"no depth\": pd.read_csv('../office0_nerf_no_depth_results.csv'),\n", 208 | " \"ours\": pd.read_csv('../office0_nerf_ours_results.csv'),\n", 209 | " \"raw\": pd.read_csv('../office0_nerf_raw_results.csv'),\n", 210 | "}\n", 211 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "a45fc9c7", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "dfs = {\n", 222 | " \"no depth\": pd.read_csv('../office1_nerf_no_depth_results.csv'),\n", 223 | " \"ours\": pd.read_csv('../office1_nerf_ours_results.csv'),\n", 224 | " \"raw\": pd.read_csv('../office1_nerf_raw_results.csv'),\n", 225 | "}\n", 226 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "08cb721d", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "dfs = {\n", 237 | " \"no depth\": pd.read_csv('../office2_nerf_no_depth_results.csv'),\n", 238 | " \"ours\": pd.read_csv('../office2_nerf_ours_results.csv'),\n", 239 | " \"raw\": pd.read_csv('../office2_nerf_raw_results.csv'),\n", 240 | "}\n", 241 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "09db9418", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "dfs = {\n", 252 | " \"no depth\": pd.read_csv('../office3_nerf_no_depth_results.csv'),\n", 253 | " \"ours\": pd.read_csv('../office3_nerf_ours_results.csv'),\n", 254 | " \"raw\": pd.read_csv('../office3_nerf_raw_results.csv'),\n", 255 | "}\n", 256 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "id": "e40a7acc", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "dfs = {\n", 267 | " \"no depth\": pd.read_csv('../office4_nerf_no_depth_results.csv'),\n", 268 | " \"ours\": pd.read_csv('../office4_nerf_ours_results.csv'),\n", 269 | " \"raw\": pd.read_csv('../office4_nerf_raw_results.csv'),\n", 270 | "}\n", 271 | "plot_results(dfs, show=True, save=\"room0_results\", xlim=25000)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "02bbbd47", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [] 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "display_name": "Python 3 (ipykernel)", 286 | "language": "python", 287 | "name": "python3" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | 
"name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.8.10" 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 5 304 | } 305 | -------------------------------------------------------------------------------- /scripts/download_cube.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | mkdir -p Datasets 4 | cd Datasets 5 | git clone https://github.com/jc211/nerf-cube-diorama-dataset.git 6 | -------------------------------------------------------------------------------- /scripts/download_replica.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | mkdir -p Datasets 4 | cd Datasets 5 | 6 | # This is 9.2Gb of data: contains all the images, transforms (.json), and meshes (.ply). 7 | gdown 1tdioYdNGK6yZZdfyQKkQI84ALtEqd2fH 8 | 9 | unzip Replica.zip 10 | 11 | -------------------------------------------------------------------------------- /scripts/download_replica_sample.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | mkdir -p Datasets 4 | cd Datasets 5 | 6 | gdown 1f4RJ4W9uxlhCvihG12X9Wa4yZrlBD4TE 7 | 8 | mkdir -p Replica/ 9 | unzip ReplicaSample.zip -d Replica 10 | 11 | -------------------------------------------------------------------------------- /scripts/record_real_sense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | sys.settrace 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | from icecream import ic 9 | import argparse 10 | 11 | import numpy as np 12 | 13 | import cv2 14 | 15 | from datasets.data_module import DataModule 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="RealSense Recorder") 19 | return parser.parse_args() 20 | 21 | 22 | if __name__ == '__main__': 23 | args = parse_args() 24 | 25 | args.parallel_run = False 26 | args.dataset_dir = "/home/tonirv/Datasets/RealSense" 27 | args.dataset_name = "real" 28 | 29 | args.initial_k = 0 30 | args.final_k = 0 31 | args.img_stride = 1 32 | args.stereo = False 33 | 34 | data_provider_module = DataModule(args.dataset_name, args, device="cpu") 35 | data_provider_module.initialize_module() 36 | 37 | # Start once we press 's' 38 | print("Waiting to start recorgin, press 's'.") 39 | cv2.imshow("Click 's' to start; 'q' to stop", np.ones((200,200))) 40 | while cv2.waitKey(33) != ord('s'): 41 | continue 42 | 43 | print("Recording") 44 | data_packets = [] 45 | while cv2.waitKey(33) != ord('q'): # quit 46 | output = data_provider_module.spin_once("aha") 47 | data_packets += [output] 48 | print("Stopping") 49 | ic(len(data_packets)) 50 | 51 | # Write images to disk, and save path with to_nerf() function 52 | print("Saving to nerf format") 53 | data_provider_module.dataset.to_nerf_format(data_packets) 54 | print('Done...') 55 | 56 | -------------------------------------------------------------------------------- /scripts/replica_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | sys.settrace 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | from icecream import ic 9 | 10 | import argparse 
11 | from tqdm import tqdm 12 | import copy 13 | import torch 14 | 15 | from examples.slam_demo import run 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="INSTANT SLAM") 19 | return parser.parse_args() 20 | 21 | 22 | if __name__ == '__main__': 23 | args = parse_args() 24 | 25 | args.parallel_run = True 26 | args.multi_gpu = True 27 | 28 | args.initial_k = 0 29 | args.final_k = 2000 30 | 31 | args.stereo = False 32 | 33 | args.weights = "droid.pth" 34 | args.network = "" 35 | args.width = 0 36 | args.height = 0 37 | 38 | args.dataset_dir = None 39 | args.dataset_name = "nerf" 40 | 41 | args.slam = True 42 | args.fusion = 'nerf' 43 | args.gui = True 44 | 45 | args.eval = True 46 | 47 | args.sigma_fusion = True 48 | 49 | 50 | args.buffer = 100 51 | args.img_stride = 1 52 | 53 | # "ours", "ours_w_thresh" or "raw", "no_depth" 54 | args1 = copy.deepcopy(args) 55 | args1.mask_type = "ours" 56 | 57 | args2 = copy.deepcopy(args) 58 | args2.mask_type = "no_depth" 59 | 60 | args3 = copy.deepcopy(args) 61 | args3.mask_type = "raw" 62 | 63 | args4 = copy.deepcopy(args) 64 | args4.mask_type = "ours_w_thresh" 65 | 66 | args_to_test = { 67 | #args1.mask_type: args1, 68 | #args2.mask_type: args2, 69 | args3.mask_type: args3, 70 | #args4.mask_type: args4, 71 | } 72 | 73 | 74 | results = "results.csv" 75 | datasets = [ 76 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/room", 77 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/bluebell", 78 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/book", 79 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/cup", 80 | #"/home/tonirv/Datasets/nerf-cube-diorama-dataset/laptop" 81 | #"/home/tonirv/Datasets/Replica/office0", 82 | #"/home/tonirv/Datasets/Replica/office2", 83 | #"/home/tonirv/Datasets/Replica/office3", 84 | #"/home/tonirv/Datasets/Replica/office4", 85 | #"/home/tonirv/Datasets/Replica/room1", 86 | #"/home/tonirv/Datasets/Replica/room2", 87 | # These get stuck not sure why.... 
88 | #"/home/tonirv/Datasets/Replica/room0", 89 | "/home/tonirv/Datasets/Replica/office1", 90 | ] 91 | 92 | torch.multiprocessing.set_start_method('spawn') 93 | torch.cuda.empty_cache() 94 | torch.backends.cudnn.benchmark = True 95 | torch.set_grad_enabled(False) 96 | 97 | for dataset in tqdm(datasets): 98 | for test_name, test_args in args_to_test.items(): 99 | output = os.path.basename(dataset) + '_nerf_' + test_name + '_' + results 100 | ic(output) 101 | test_args.dataset_dir = dataset 102 | print(f"Processing dataset: {test_args.dataset_dir}") 103 | try: 104 | run(test_args) 105 | except Exception as e: 106 | print(e) 107 | print(f"Saving output in: {output}") 108 | # Copy transforms.json to its args.dataset_dir folder 109 | os.replace(results, os.path.join(output)) 110 | #torch.cuda.empty_cache() 111 | print('Done...') 112 | 113 | -------------------------------------------------------------------------------- /scripts/replica_to_nerf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | sys.settrace 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | import argparse 9 | from tqdm import tqdm 10 | 11 | from datasets.data_module import DataModule 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="INSTANT SLAM") 15 | parser.add_argument("--replica_dir", type=str, 16 | help="Path to the Replica dataset root dir", 17 | default="/home/tonirv/Datasets/Replica/") 18 | 19 | return parser.parse_args() 20 | 21 | if __name__ == '__main__': 22 | args = parse_args() 23 | args.parallel_run = True 24 | args.dataset_dir = None 25 | args.initial_k = 0 # first frame to load 26 | args.final_k = -1 # last frame to load, if -1 load all 27 | args.img_stride = 1 # stride for loading images 28 | args.stereo = False 29 | args.buffer = 3000 30 | 31 | transform = "transforms.json" 32 | dataset_names = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 33 | for dataset in tqdm(dataset_names): 34 | args.dataset_dir = os.path.join(args.replica_dir, dataset) 35 | print(f"Processing dataset: {args.dataset_dir}") 36 | # Parse dataset and transform to Nerf 37 | data_provider_module = DataModule("replica", args) 38 | data_provider_module.initialize_module() 39 | data_provider_module.dataset.to_nerf_format() 40 | # Copy transforms.json to its args.dataset_dir folder 41 | os.replace(transform, os.path.join(args.dataset_dir, transform)) 42 | print('Done...') 43 | 44 | -------------------------------------------------------------------------------- /scripts/unzip_tartan_air.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import glob 4 | import os 5 | import os.path as osp 6 | from tqdm import tqdm 7 | import shutil 8 | 9 | cur_path = osp.dirname(osp.abspath(__file__)) 10 | 11 | # Download dataset should be: 12 | # python download_training.py --rgb --depth --only-left --output-dir ./datasets/TartanAir 13 | 14 | # Dirs inside datasets/TartanAir must be 15 | # {dataset}/{level}/P***/{depth_left,image_left,pose_left.txt} 16 | 17 | LEVELS = ['Easy', 'Hard'] 18 | 19 | def unzip(tartanair_path='datasets/TartanAir', remove_zip=False): 20 | datasets_paths = glob.glob(osp.join(tartanair_path, "*")) 21 | for dataset in tqdm(sorted(datasets_paths)): 22 | dataset_name = os.path.basename(dataset) 23 | print("Dataset: %s" % dataset_name) 24 | for level in LEVELS: 25 | print("Level: 
%s"%(level)) 26 | 27 | ### Form paths and pass checks 28 | dataset_level_path = osp.join(dataset, level) 29 | depth_dataset_level_path = osp.join(dataset_level_path, "depth_left.zip") 30 | if not osp.exists(depth_dataset_level_path): 31 | print("Missing Depth zip file for Dataset/Level: %s/%s" %(dataset, level)) 32 | continue 33 | image_dataset_level_path = osp.join(dataset_level_path, "image_left.zip") 34 | if not osp.exists(image_dataset_level_path): 35 | print("Missing Image zip file for Dataset/Level: %s/%s"%(dataset, level)) 36 | continue 37 | 38 | if osp.exists(osp.join(dataset_level_path, dataset_name)) or len(glob.glob(osp.join(dataset_level_path, "P*"))) != 0: 39 | print("Seems like the dataset was already unzipped? %s" % dataset) 40 | else: 41 | ### Unzip dataset 42 | command = "unzip -q -n %s -d %s"%(depth_dataset_level_path, dataset_level_path) 43 | print(command) 44 | os.system(command) 45 | if remove_zip: 46 | os.remove(depth_dataset_level_path) 47 | command = "unzip -q -n %s -d %s"%(image_dataset_level_path, dataset_level_path) 48 | print(command) 49 | os.system(command) 50 | if remove_zip: 51 | os.remove(image_dataset_level_path) 52 | 53 | ### Remove junk directories 54 | from_ = osp.join(dataset_level_path, "*/*/*/P*") 55 | if len(glob.glob(from_)) != 0: # We have junk folders 56 | to_ = dataset_level_path 57 | command = "mv %s %s"%(from_,to_) 58 | print(command) 59 | os.system(command) 60 | shutil.rmtree(osp.join(dataset_level_path,dataset_name)) 61 | 62 | import argparse 63 | 64 | def dir_path(string): 65 | if os.path.isdir(string): 66 | return string 67 | else: 68 | raise NotADirectoryError(string) 69 | 70 | if __name__ == "__main__": 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--dataset_path', type=dir_path) 73 | parser.add_argument('--remove_zip', action="store_true") 74 | args = parser.parse_args() 75 | 76 | print("Unzipping TartanAir dataset") 77 | unzip(args.dataset_path, args.remove_zip) 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | 6 | import os.path as osp 7 | ROOT = osp.dirname(osp.abspath(__file__)) 8 | 9 | setup( 10 | name='droid_backends', 11 | ext_modules=[ 12 | CUDAExtension('droid_backends', 13 | include_dirs=[osp.join(ROOT, 'thirdparty/eigen')], 14 | sources=[ 15 | 'src/droid.cpp', 16 | 'src/droid_kernels.cu', 17 | 'src/correlation_kernels.cu', 18 | 'src/altcorr_kernel.cu', 19 | ], 20 | extra_compile_args={ 21 | 'cxx': ['-O3'], 22 | 'nvcc': ['-O3', 23 | '-gencode=arch=compute_60,code=sm_60', 24 | '-gencode=arch=compute_61,code=sm_61', 25 | '-gencode=arch=compute_70,code=sm_70', 26 | '-gencode=arch=compute_75,code=sm_75', 27 | '-gencode=arch=compute_80,code=sm_80', 28 | '-gencode=arch=compute_86,code=sm_86', 29 | ] 30 | }), 31 | ], 32 | cmdclass={ 'build_ext' : BuildExtension } 33 | ) 34 | 35 | setup( 36 | name='lietorch', 37 | version='0.2', 38 | description='Lie Groups for PyTorch', 39 | packages=['lietorch'], 40 | package_dir={'': 'thirdparty/lietorch'}, 41 | ext_modules=[ 42 | CUDAExtension('lietorch_backends', 43 | include_dirs=[ 44 | osp.join(ROOT, 'thirdparty/lietorch/lietorch/include'), 45 | osp.join(ROOT, 'thirdparty/eigen')], 46 | sources=[ 47 | 'thirdparty/lietorch/lietorch/src/lietorch.cpp', 48 | 
'thirdparty/lietorch/lietorch/src/lietorch_gpu.cu', 49 | 'thirdparty/lietorch/lietorch/src/lietorch_cpu.cpp'], 50 | extra_compile_args={ 51 | 'cxx': ['-O2'], 52 | 'nvcc': ['-O2', 53 | '-gencode=arch=compute_60,code=sm_60', 54 | '-gencode=arch=compute_61,code=sm_61', 55 | '-gencode=arch=compute_70,code=sm_70', 56 | '-gencode=arch=compute_75,code=sm_75', 57 | '-gencode=arch=compute_80,code=sm_80', 58 | '-gencode=arch=compute_86,code=sm_86', 59 | ] 60 | }), 61 | ], 62 | cmdclass={ 'build_ext' : BuildExtension } 63 | ) 64 | 65 | -------------------------------------------------------------------------------- /slam/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /slam/inertial_frontends/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /slam/inertial_frontends/inertial_frontend.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | from __future__ import print_function 5 | 6 | from abc import abstractclassmethod 7 | 8 | from icecream import ic 9 | 10 | import torch as th 11 | from torch import nn 12 | from factor_graph.variables import Variable, Variables 13 | from factor_graph.factor import Factor 14 | from factor_graph.factor_graph import TorchFactorGraph 15 | 16 | from slam.meta_slam import SLAM 17 | import numpy as np 18 | 19 | import gtsam 20 | from gtsam import (ImuFactor, Pose3, Rot3, Point3) 21 | from gtsam import PriorFactorPose3, PriorFactorConstantBias, PriorFactorVector 22 | from gtsam.symbol_shorthand import B, V, X 23 | 24 | 25 | def vector3(x, y, z): 26 | """Create 3d double numpy array.""" 27 | return np.array([x, y, z], dtype=float) 28 | 29 | 30 | class InertialFrontend(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | 34 | @abstractclassmethod 35 | def forward(self, mini_batch): 36 | pass 37 | 38 | 39 | import gtsam 40 | from gtsam import Values 41 | from gtsam import NonlinearFactorGraph as FactorGraph 42 | from gtsam.symbol_shorthand import B 43 | 44 | import torch as th 45 | 46 | 47 | class PreIntegrationInertialFrontend(InertialFrontend): 48 | def __init__(self): 49 | super().__init__() 50 | self.last_key = None 51 | 52 | # THESE ARE A LOT OF ALLOCATIONS: #frames * imu_buffer_size * 7 * 2 (because double precision!) 53 | # self.imu_buffer = 200 54 | # self.imu_t0_t1 = th.empty(self.imu_buffer, 7, device='cpu', dtype=th.float64)#.share_memory_() 55 | 56 | def initialize_imu_frontend(): 57 | pass 58 | 59 | def forward(self, mini_batch, last_state): 60 | # Call IMU preintegration from kimera/gtsam or re-implement... 61 | # The output of PreIntegrationInertialFrontend is a 62 | # bunch of (pose,vel,bias)-to-(pose,vel,bias) factors. 
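# Schematically (comment added for clarity, not part of the original
# source), for consecutive keyframes j = self.last_key and k this method
# returns a graph containing:
#
#   ImuFactor( X(j), V(j), X(k), V(k), B(j), pim )              # preintegrated IMU
#   BetweenFactorConstantBias( B(j), B(k), 0, sigma*sqrt(dt) )  # slow bias drift
#
# together with the predicted NavState at k as the initial guess x0. On
# the very first call (k == 0) it instead returns prior factors on X(0),
# V(0), B(0) and their initial values.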
63 | print("PreIntegrationInertialFrontend.forward") 64 | k = mini_batch["k"] 65 | if last_state is None and k == 0: 66 | self.last_key = k 67 | # Initialize the preintegration 68 | self.imu_params = mini_batch["imu_calib"] 69 | self.preint_imu_params = self.get_preint_imu_params(self.imu_params) 70 | initial_state = self.initial_state() 71 | self.pim = gtsam.PreintegratedImuMeasurements(self.preint_imu_params, 72 | initial_state[2]) 73 | # Set initial priors and values 74 | x0, factors = self.initial_priors_and_values(k, initial_state) 75 | return x0, factors 76 | 77 | imu_t0_t1 = mini_batch["imu_t0_t1"] 78 | # imu_meas_count = len(imu_t0_t1_df) 79 | # self.imu_t0_t1[:imu_meas_count] = th.as_tensor(imu_t0_t1_df, device='cpu', dtype=th.float64).cpu().numpy() 80 | 81 | # Integrate measurements between frames (101 measurements) 82 | #self.preintegrate_imu(self.imu_t0_t1, imu_meas_count) 83 | self.preintegrate_imu(imu_t0_t1, -1) 84 | #self.pim.print() 85 | 86 | #if k % 10 != 0: # simulate keyframe selection 87 | # return Values(), FactorGraph() 88 | 89 | # Get factors 90 | imu_factor = ImuFactor(X(self.last_key), V(self.last_key), X(k), V(k), B(self.last_key), self.pim) 91 | bias_btw_factor = self.get_bias_btw_factor(k, self.pim.deltaTij()) 92 | 93 | # Add factors to inertial graph 94 | graph = FactorGraph() 95 | graph.add(imu_factor) 96 | graph.add(bias_btw_factor) 97 | 98 | print("LAST STATE") 99 | print(last_state) 100 | # Get guessed values (from IMU integration) 101 | last_W_Pose_B, last_W_Vel_B, last_imu_bias = \ 102 | last_state.atPose3(X(self.last_key)), \ 103 | last_state.atVector(V(self.last_key)),\ 104 | last_state.atConstantBias(B(self.last_key)) 105 | last_navstate = gtsam.NavState(last_W_Pose_B, last_W_Vel_B) 106 | new_navstate = self.pim.predict(last_navstate, last_imu_bias); 107 | 108 | x0 = Values() 109 | x0.insert(X(k), new_navstate.pose()) 110 | x0.insert(V(k), new_navstate.velocity()) 111 | x0.insert(B(k), last_imu_bias) 112 | 113 | self.last_key = k 114 | self.pim.resetIntegrationAndSetBias(last_imu_bias) 115 | 116 | return x0, graph 117 | 118 | # k: current key, n: number of imu measurements 119 | def preintegrate_imu(self, imu_t0_t1, n): 120 | meas_acc = imu_t0_t1[:n, 4:7] 121 | meas_gyr = imu_t0_t1[:n, 1:4] 122 | delta_t = (imu_t0_t1[1:n, 0] - imu_t0_t1[0:n-1, 0]) * 1e-9 # convert to seconds 123 | for acc, gyr, dt in zip(meas_acc, meas_gyr, delta_t): 124 | self.pim.integrateMeasurement(acc, gyr, dt) 125 | # TODO: fix this loop! 
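# (Note added for clarity: the loop above integrates one (acc, gyr, dt)
# triple at a time; timestamps in column 0 of imu_t0_t1 are nanoseconds,
# hence the 1e-9 conversion to seconds. Whether the batched overload
# commented out below is available depends on the installed GTSAM Python
# bindings, so treating it as a drop-in replacement is an assumption.)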
126 | #self.pim.integrateMeasurements(meas_acc, meas_gyr, delta_t) 127 | 128 | def get_bias_btw_factor(self, k, delta_t_ij): 129 | # Bias evolution as given in the IMU metadata 130 | sqrt_delta_t_ij = np.sqrt(delta_t_ij); 131 | bias_sigmas_acc = sqrt_delta_t_ij * self.imu_params.a_b * np.ones(3) 132 | bias_sigmas_gyr = sqrt_delta_t_ij * self.imu_params.g_b * np.ones(3) 133 | bias_sigmas = np.concatenate((bias_sigmas_acc, bias_sigmas_gyr), axis=None) 134 | bias_noise_model = gtsam.noiseModel.Diagonal.Sigmas(bias_sigmas) 135 | bias_value = gtsam.imuBias.ConstantBias() 136 | return gtsam.BetweenFactorConstantBias(B(self.last_key), B(k), bias_value, bias_noise_model) 137 | 138 | # Send IMU priors 139 | def initial_priors_and_values(self, k, initial_state): 140 | pose_key = X(k) 141 | vel_key = V(k) 142 | bias_key = B(k) 143 | 144 | pose_noise = gtsam.noiseModel.Diagonal.Sigmas( 145 | np.array([0.001, 0.001, 0.001, 0.01, 0.01, 0.01])) 146 | vel_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.001) 147 | bias_noise = gtsam.noiseModel.Isotropic.Sigma(6, 0.01) 148 | 149 | initial_pose, initial_vel, initial_bias = initial_state 150 | 151 | # Get inertial factors 152 | pose_prior = PriorFactorPose3(pose_key, initial_pose, pose_noise) 153 | vel_prior = PriorFactorVector(vel_key, initial_vel, vel_noise) 154 | bias_prior = PriorFactorConstantBias(bias_key, initial_bias, bias_noise) 155 | 156 | # Add factors to inertial graph 157 | graph = FactorGraph() 158 | graph.push_back(pose_prior) 159 | graph.push_back(vel_prior) 160 | graph.push_back(bias_prior) 161 | 162 | # Get guessed values 163 | x0 = Values() 164 | x0.insert(pose_key, initial_pose) 165 | x0.insert(vel_key, initial_vel) 166 | x0.insert(bias_key, initial_bias) 167 | 168 | return x0, graph 169 | 170 | def initial_state(self): 171 | true_pose = gtsam.Pose3(gtsam.Rot3(0.060514,-0.828459,-0.058956,-0.553641), # qw, qx, qy, qz 172 | gtsam.Point3(0.878612,2.142470,0.947262)) 173 | true_vel = np.array([0.009474,-0.014009,-0.002145]) 174 | true_bias = gtsam.imuBias.ConstantBias(np.array([-0.012492,0.547666,0.069073]), np.array([-0.002229,0.020700,0.076350])) 175 | naive_pose = gtsam.Pose3.identity() 176 | naive_vel = np.zeros(3) 177 | naive_bias = gtsam.imuBias.ConstantBias() 178 | initial_pose = true_pose 179 | initial_vel = true_vel 180 | initial_bias = true_bias 181 | return initial_pose, initial_vel, initial_bias 182 | 183 | def get_preint_imu_params(self, imu_calib): 184 | I = np.eye(3) 185 | preint_params = gtsam.PreintegrationParams(imu_calib.n_gravity); 186 | preint_params.setAccelerometerCovariance(np.power(imu_calib.a_n, 2.0) * I) 187 | preint_params.setGyroscopeCovariance(np.power(imu_calib.g_n, 2.0) * I) 188 | preint_params.setIntegrationCovariance(np.power(imu_calib.imu_integration_sigma, 2.0) * I) 189 | preint_params.setUse2ndOrderCoriolis(False) 190 | preint_params.setOmegaCoriolis(np.zeros(3, dtype=float)) 191 | preint_params.print() 192 | return preint_params -------------------------------------------------------------------------------- /slam/meta_slam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from abc import abstractclassmethod 4 | from icecream import ic 5 | 6 | from torch import nn 7 | 8 | from factor_graph.variables import Variable, Variables 9 | from factor_graph.factor_graph import FactorGraphManager 10 | 11 | from solvers.nonlinear_solver import iSAM2, LevenbergMarquardt, Solver 12 | 13 | # An abstract learned SLAM model 14 | class SLAM(nn.Module): 15 
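# Per-frame flow of this abstract class, summarised for clarity:
# forward() plays the role of spin_once. The derived _frontend turns the
# incoming data packet into an initial guess x0 plus new factors; those
# factors are appended to the FactorGraphManager; the derived _backend
# then solves the updated graph and returns the new state estimate. When
# the backend is iSAM2, the staged factors are cleared afterwards, since
# they have already been pushed into the incremental solver.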
| def __init__(self, name, args, device): 16 | super().__init__() 17 | self.name = name 18 | self.args = args 19 | self.device = device 20 | self.factor_graph_manager = FactorGraphManager() 21 | self.state = None 22 | self.delta = None 23 | 24 | # This is our spin_once 25 | def forward(self, batch): 26 | # TODO Parallelize frontend/backend 27 | assert("data" in batch) 28 | 29 | # Frontend 30 | output = self._frontend(batch["data"], self.state, self.delta) 31 | if output == False: 32 | return output 33 | else: 34 | x0, factors, viz_out = output 35 | 36 | # Backend 37 | self.factor_graph_manager.add(factors) 38 | factor_graph = self.factor_graph_manager.get_factor_graph() 39 | self.state, self.delta = self._backend(factor_graph, x0) 40 | if type(self.backend) is iSAM2: 41 | self.factor_graph_manager.reset_factor_graph() 42 | self.backend = iSAM2() # uncomment if running DroidSLAM 43 | return [self.state, viz_out] 44 | 45 | # Converts sensory inputs to factors and initial guess (x0) 46 | @abstractclassmethod # Implemented by derived classes 47 | def _frontend(self, mini_batch, last_state, last_delta): 48 | raise 49 | 50 | # Solves the factor graph, given an initial estimate 51 | @abstractclassmethod # Implemented by derived classes 52 | def _backend(self, factor_graph, x0): 53 | raise 54 | 55 | 56 | class MetaSLAM(SLAM): 57 | def __init__(self, name, device): 58 | super().__init__(name, device) 59 | 60 | @abstractclassmethod 61 | def calculate_loss(self, x, mini_batch): 62 | # mini_batch contains the ground-truth parameters as well 63 | pass 64 | -------------------------------------------------------------------------------- /slam/slam_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pipeline.pipeline_module import MIMOPipelineModule 4 | 5 | class SlamModule(MIMOPipelineModule): 6 | def __init__(self, name, args, device="cpu"): 7 | super().__init__(name, args.parallel_run, args) 8 | self.device = device 9 | 10 | def spin_once(self, input): 11 | output = self.slam(input) 12 | if not output or self.slam.stop_condition(): 13 | super().shutdown_module() 14 | return output 15 | 16 | def initialize_module(self): 17 | if self.name == "VioSLAM": 18 | from slam.vio_slam import VioSLAM 19 | self.slam = VioSLAM(self.name, self.args, self.device) 20 | else: 21 | raise NotImplementedError 22 | return super().initialize_module() 23 | -------------------------------------------------------------------------------- /slam/vio_slam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from abc import abstractclassmethod 4 | 5 | from icecream import ic 6 | 7 | from torch import nn 8 | from factor_graph.variables import Variable, Variables 9 | from factor_graph.factor_graph import TorchFactorGraph 10 | 11 | from gtsam import Values 12 | from gtsam import NonlinearFactorGraph 13 | from gtsam import GaussianFactorGraph 14 | 15 | from slam.meta_slam import SLAM 16 | from slam.inertial_frontends.inertial_frontend import PreIntegrationInertialFrontend 17 | from slam.visual_frontends.visual_frontend import RaftVisualFrontend 18 | from solvers.nonlinear_solver import iSAM2, LevenbergMarquardt, Solver 19 | 20 | ########################### REMOVE ############################ 21 | import numpy as np 22 | import gtsam 23 | from gtsam import (ImuFactor, Pose3, Rot3, Point3) 24 | from gtsam import PriorFactorPose3, PriorFactorConstantBias, PriorFactorVector 25 | from 
gtsam.symbol_shorthand import B, V, X 26 | 27 | 28 | # Send IMU priors 29 | def initial_priors_and_values(k, initial_state): 30 | pose_key = X(k) 31 | vel_key = V(k) 32 | bias_key = B(k) 33 | 34 | pose_noise = gtsam.noiseModel.Diagonal.Sigmas( 35 | np.array([0.000001, 0.000001, 0.000001, 0.00001, 0.00001, 0.00001])) 36 | vel_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.000001) 37 | bias_noise = gtsam.noiseModel.Isotropic.Sigma(6, 0.00001) 38 | 39 | initial_pose, initial_vel, initial_bias = initial_state 40 | 41 | # Get inertial factors 42 | pose_prior = PriorFactorPose3(pose_key, initial_pose, pose_noise) 43 | vel_prior = PriorFactorVector(vel_key, initial_vel, vel_noise) 44 | bias_prior = PriorFactorConstantBias(bias_key, initial_bias, bias_noise) 45 | 46 | # Add factors to inertial graph 47 | graph = NonlinearFactorGraph() 48 | graph.push_back(pose_prior) 49 | graph.push_back(vel_prior) 50 | graph.push_back(bias_prior) 51 | 52 | # Get guessed values 53 | x0 = Values() 54 | x0.insert(pose_key, initial_pose) 55 | x0.insert(vel_key, initial_vel) 56 | x0.insert(bias_key, initial_bias) 57 | 58 | return x0, graph 59 | 60 | def initial_state(): 61 | true_world_T_imu_t0 = gtsam.Pose3(gtsam.Rot3(0.060514, -0.828459, -0.058956, -0.553641), # qw, qx, qy, qz 62 | gtsam.Point3(0.878612, 2.142470, 0.947262)) 63 | true_vel = np.array([0.009474,-0.014009,-0.002145]) 64 | true_bias = gtsam.imuBias.ConstantBias(np.array([-0.012492,0.547666,0.069073]), np.array([-0.002229,0.020700,0.076350])) 65 | naive_pose = gtsam.Pose3.identity() 66 | naive_vel = np.zeros(3) 67 | naive_bias = gtsam.imuBias.ConstantBias() 68 | initial_pose = true_world_T_imu_t0 69 | initial_vel = true_vel 70 | initial_bias = true_bias 71 | initial_pose = naive_pose 72 | initial_vel = naive_vel 73 | initial_bias = naive_bias 74 | return initial_pose, initial_vel, initial_bias 75 | ############################################################### 76 | 77 | 78 | class VioSLAM(SLAM): 79 | def __init__(self, name, args, device): 80 | super().__init__(name, args, device) 81 | world_T_imu_t0, _, _ = initial_state() 82 | imu_T_cam0 = Pose3(np.array([[0.0148655429818, -0.999880929698, 0.00414029679422, -0.0216401454975], 83 | [0.999557249008, 0.0149672133247, 0.025715529948, -0.064676986768], 84 | [-0.0257744366974, 0.00375618835797, 0.999660727178, 0.00981073058949], 85 | [0.0, 0.0, 0.0, 1.0]])) 86 | 87 | imu_T_cam0 = Pose3(np.eye(4)) 88 | 89 | #world_T_imu_t0 = Pose3(args.world_T_imu_t0) 90 | #world_T_imu_t0 = Pose3(np.eye(4)) 91 | world_T_imu_t0 = Pose3(np.array( 92 | [[-7.6942980e-02, -3.1037781e-01, 9.4749427e-01, 8.9643948e-02], 93 | [-2.8366595e-10, -9.5031142e-01, -3.1130061e-01, 4.1829333e-01], 94 | [ 9.9703550e-01, -2.3952398e-02, 7.3119797e-02, 4.8306200e-01], 95 | [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0000000e+00]])) 96 | 97 | self.visual_frontend = RaftVisualFrontend(world_T_imu_t0, imu_T_cam0, args, device=device) 98 | #self.inertial_frontend = PreIntegrationInertialFrontend() 99 | 100 | self.backend = iSAM2() 101 | 102 | self.values = Values() 103 | self.inertial_factors = NonlinearFactorGraph() 104 | self.visual_factors = NonlinearFactorGraph() 105 | 106 | self.last_state = None 107 | 108 | def stop_condition(self): 109 | return self.visual_frontend.stop_condition() 110 | 111 | # Converts sensory inputs to measurements and initial guess 112 | def _frontend(self, batch, last_state, last_delta): 113 | # Compute optical flow 114 | x0_visual, visual_factors, viz_out = self.visual_frontend(batch) # TODO: currently also calls 
BA, and global BA 115 | self.last_state = x0_visual 116 | 117 | if x0_visual is None: 118 | return False 119 | 120 | # Wrap guesses 121 | x0 = Values() 122 | factors = NonlinearFactorGraph() 123 | 124 | return x0, factors, viz_out 125 | 126 | def _backend(self, factor_graph, x0): 127 | return self.backend.solve(factor_graph, x0) 128 | 129 | -------------------------------------------------------------------------------- /slam/visual_frontends/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /solvers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /solvers/linear_solver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import enum 4 | import colored_glog as log 5 | import torch as th 6 | 7 | class LinearSolverType(enum.Enum): 8 | Inverse = 0 9 | Cholesky = 1 10 | QR = 2 # to be implemented 11 | MultifrontalQR = 3 # to be implemented 12 | CG = 4 # Conjugate Gradient, to be implemented 13 | # Gauss-Seidel for dense flow? 14 | # Use scipy.optimize... 15 | 16 | # Linear Solver 17 | class LinearLS: 18 | # loss = |Y - ([X, 1] * [w, b])|^2_sigma 19 | # [X, 1] == A, Y - ([X, 1] * [^w, ^b]) == b [w, b] == x 20 | @staticmethod 21 | def solve_inverse(A, b, E): 22 | # loss = |Ax - b|^2_sigma(\theta) 23 | At = th.transpose(A, 1, 2) 24 | AtE = At @ E 25 | AtEA = AtE @ A 26 | AtEb = AtE @ b 27 | x = th.linalg.inv(AtEA) @ AtEb 28 | return x 29 | 30 | # TODO(Toni): We should make the sparse version of this 31 | # A has size: (B, M, N) where B = batch, M = measurements, N = variables 32 | # w has size: (B, M, 1) and is the weight for each measurement (assumes diagonal weights) 33 | # b has size: (B, N) 34 | @staticmethod 35 | def solve_cholesky(A: th.Tensor, b: th.Tensor, w: th.Tensor) -> th.Tensor: 36 | # loss = |Ax - b|^2_w 37 | # Solve all Batch linear problems 38 | if A is None or b is None or w is None: 39 | return None 40 | 41 | # Check dimensionality, first is the batch dimension 42 | B, M = b.shape 43 | _, _, N = A.shape 44 | log.check_eq(th.Size([B, M, N]), A.shape) 45 | log.check_eq(th.Size([B, M]), w.shape) # We assume E diagonal: E=diag(w) 46 | WA = A * w.unsqueeze(-1) # multiply each row by the weight (uses broadcasting) 47 | AtWt = WA.transpose(1,2) 48 | AtEA = AtWt @ WA 49 | AtWtb = AtWt @ b.unsqueeze(-1) 50 | L = th.linalg.cholesky(AtEA) 51 | y = th.linalg.solve_triangular(L, AtWtb, upper=False) 52 | x = th.linalg.solve_triangular(th.transpose(L, 1, 2), y, upper=True) 53 | return x.squeeze(-1) 54 | 55 | # TODO(Toni): We should make the sparse version of this 56 | # A has size: (B, M, N) where B = batch, M = measurements, N = variables <- In standard form 57 | # b has size: (B, N) 58 | # @staticmethod 59 | # def solve_cholesky(A, b): 60 | # # loss = |Ax - b|^2_2 61 | # # Solve all Batch linear problems 62 | 63 | # # Check dimensionality, first is the batch dimension 64 | # B, M = b.shape 65 | # _, _, N = A.shape 66 | # log.check_eq(th.Size([B, M, N]), A.shape) 67 | # At = A.transpose(1, 2) # Batched A transposes 68 | # AtA = At @ A 69 | # Atb = At @ b.unsqueeze(-1) 70 | # L = th.linalg.cholesky(AtA) 71 | # y = th.linalg.solve_triangular(L, Atb, upper=False) 72 | # x = 
th.linalg.solve_triangular(th.transpose(L, 1, 2), y, upper=True) 73 | # return x.squeeze(-1) 74 | 75 | # For pyth 1.11 76 | @staticmethod 77 | def solve_cholesky_11(A, b, E): 78 | # loss = |Ax - b|^2_sigma(\theta) 79 | At = th.transpose(A, 1, 2) 80 | AtE = At @ E 81 | AtEA = AtE @ A 82 | AtEb = AtE @ b 83 | L = th.linalg.cholesky(AtEA) 84 | x = th.linalg.cholesky_solve(AtEb, L, upper=False) 85 | return x 86 | 87 | -------------------------------------------------------------------------------- /solvers/meta_solver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch as th 4 | from torch.utils.data import DataLoader 5 | 6 | from slam.meta_slam import MetaSLAM 7 | 8 | class MetaSolver: 9 | def __init__(self, epochs: int, optimizer: th.optim.Optimizer, meta_slam: MetaSLAM): 10 | self.epochs = epochs 11 | self.optimizer = optimizer 12 | self.meta_slam = meta_slam 13 | self.log_every_n = 1 14 | 15 | def solve(self, dataloader: DataLoader): 16 | for epoch in range(self.epochs): 17 | print(f"Epoch {epoch}\n-------------------------------") 18 | for i, mini_batch in enumerate(dataloader): 19 | print(f"Batch {i}\n-------------------------------") 20 | 21 | # Build and solve factor graphs 22 | x = self.meta_slam(mini_batch) 23 | if self.logging_callback: 24 | self.logging_callback() 25 | 26 | # Compute and print loss 27 | if self.optimizer: 28 | loss = self.meta_slam.calculate_loss(x, mini_batch) 29 | if i % self.log_every_n == 0: 30 | print(f"outer_loss: {loss.item():>7f} [{i:>5d}/{len(dataloader):>5d}]") 31 | 32 | # Zero gradients, perform a backward pass, and update the weights from the outer optimization 33 | self.optimizer.zero_grad() 34 | loss.backward() 35 | self.optimizer.step() 36 | 37 | def register_logging_callback(self, log): 38 | self.logging_callback = log -------------------------------------------------------------------------------- /solvers/nonlinear_solver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from abc import abstractclassmethod 4 | 5 | import colored_glog as log 6 | 7 | import torch as th 8 | 9 | from factor_graph.variables import Variable, Variables 10 | from factor_graph.factor_graph import TorchFactorGraph 11 | 12 | from solvers.linear_solver import LinearLS, LinearSolverType 13 | 14 | import gtsam 15 | from gtsam import (ISAM2, LevenbergMarquardtOptimizer, NonlinearFactorGraph, PriorFactorPose2, Values) 16 | from gtsam import NonlinearFactorGraph as FactorGraph 17 | 18 | from icecream import ic 19 | 20 | class Solver: 21 | def __init__(self): 22 | pass 23 | 24 | @abstractclassmethod 25 | def solve(self, factor_graph, x0): 26 | raise 27 | 28 | class iSAM2(Solver): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | # Set ISAM2 parameters and create ISAM2 solver object 33 | isam_params = gtsam.ISAM2Params() 34 | 35 | # Dogleg params 36 | dogleg_params = gtsam.ISAM2DoglegParams() 37 | dogleg_params.setInitialDelta(1.0) 38 | dogleg_params.setWildfireThreshold(1e-5) 39 | dogleg_params.setVerbose(True) 40 | # dogleg_params.setAdaptationMode(string adaptationMode); 41 | 42 | # Gauss-Newton params 43 | gauss_newton_params = gtsam.ISAM2GaussNewtonParams() 44 | gauss_newton_params.setWildfireThreshold(1e-5) 45 | 46 | # Optimization parameters 47 | isam_params.setOptimizationParams(gauss_newton_params) 48 | isam_params.setFactorization("CHOLESKY") # QR or Cholesky 49 | 50 | # Linearization parameters 51 | 
isam_params.enableRelinearization = True 52 | isam_params.enablePartialRelinearizationCheck = False 53 | isam_params.setRelinearizeThreshold(0.1) # TODO 54 | isam_params.relinearizeSkip = 10 55 | 56 | # Memory efficiency, but slower 57 | isam_params.findUnusedFactorSlots = True 58 | 59 | # Debugging parameters, disable for speed 60 | isam_params.evaluateNonlinearError = True 61 | isam_params.enableDetailedResults = True 62 | 63 | #isam_params.print() 64 | 65 | self.isam2 = gtsam.ISAM2(isam_params) 66 | 67 | def solve(self, factor_graph, x0): 68 | # factors_to_remove = factor_graph.keyVector() # since all are structureless, and we re-add all inertial 69 | # print(factors_to_remove) 70 | # self.isam2.update(factor_graph, x0, factors_to_remove) 71 | self.isam2.update(factor_graph, x0) # Only one iteration!! 72 | result = self.isam2.calculateEstimate() 73 | delta = self.isam2.getDelta() 74 | return result, delta 75 | 76 | # class GaussNewton(Solver): 77 | # def __init__(self): 78 | # super().__init__() 79 | # self.params = gtsam.GaussNewtonParams() 80 | # self.params.setVerbosityLM("SUMMARY") 81 | # self.x0 = Values() 82 | # 83 | # def solve(self, factor_graph, x0): 84 | # self.x0.insert(x0) 85 | # optimizer = gtsam.LevenbergMarquardtOptimizer(factor_graph, self.x0, self.params) 86 | # return optimizer.optimize() 87 | 88 | class LevenbergMarquardt(Solver): 89 | def __init__(self): 90 | super().__init__() 91 | self.params = gtsam.LevenbergMarquardtParams() 92 | # static void SetLegacyDefaults(LevenbergMarquardtParams* p) { 93 | # // Relevant NonlinearOptimizerParams: 94 | # p->maxIterations = 100; 95 | # p->relativeErrorTol = 1e-5; 96 | # p->absoluteErrorTol = 1e-5; 97 | # // LM-specific: 98 | # p->lambdaInitial = 1e-5; 99 | # p->lambdaFactor = 10.0; 100 | # p->lambdaUpperBound = 1e5; 101 | # p->lambdaLowerBound = 0.0; 102 | # p->minModelFidelity = 1e-3; 103 | # p->diagonalDamping = false; 104 | # p->useFixedLambdaFactor = true; 105 | self.params.setVerbosityLM("SUMMARY") 106 | self.x0 = Values() 107 | 108 | def solve(self, factor_graph, x0): 109 | self.x0.insert(x0) 110 | optimizer = gtsam.LevenbergMarquardtOptimizer(factor_graph, self.x0, self.params) 111 | return optimizer.optimize() 112 | 113 | 114 | class NonlinearLS(Solver): 115 | def __init__(self, iterations=2, min_x_update=0.001, linear_solver=LinearSolverType.Cholesky): 116 | super().__init__() 117 | self.iters = iterations 118 | self.min_x_update = min_x_update 119 | if linear_solver == LinearSolverType.Inverse: 120 | self.linear_solver = LinearLS.solve_inverse 121 | elif linear_solver == LinearSolverType.Cholesky: 122 | self.linear_solver = LinearLS.solve_cholesky 123 | else: 124 | raise ValueError(f'Unknown linear solver: {linear_solver}') 125 | 126 | # Set all data members to 0 127 | self._reset_history() 128 | 129 | # Logging 130 | self.log_every_n = 1 131 | 132 | # f is a nonlinear function of the states resulting in measurements 133 | # w is the measurement weights 134 | # x0 is the linearization point (initial guess) 135 | def solve(self, fg: TorchFactorGraph, x0: Variables) -> th.Tensor: 136 | # Start iterative call to nonlinear solve 137 | self._reset_history() 138 | return self._solve_nonlinear(fg, x0, 1) 139 | 140 | # TODO(Toni): abstract this, and perhaps even call iSAM2 from gtsam... 141 | # f is expected to be in standard form, meaning weighted/whitened. 
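# Each recursive call below performs one Gauss-Newton-style step (linearize the
# factor graph at x0, solve the weighted linear system with the configured
# LinearLS backend, retract the increment, log) and recurses until terminate() fires.
# Note that the recursion returns the last iterate x_new; the lowest-loss iterate
# seen so far is kept separately in self.x_best / self.loss_best.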
142 | def _solve_nonlinear(self, fg: TorchFactorGraph, x0: Variables, i: int) -> th.Tensor: 143 | if fg.is_empty(): 144 | log.warn("Factor graph is empty, returning initial guess...") 145 | return x0 146 | 147 | # 1. Linearize nonlinear system with respect to x, this can do Schur complement as well. 148 | A, b, w = fg.linearize(x0) 149 | 150 | # 2. Solve linear system 151 | delta_x = self.linear_solver(A, b, w) 152 | 153 | if delta_x is None: 154 | log.warn("Linear system is not invertible, returning initial guess...") 155 | return x0 156 | 157 | # 3. Retract/Update 158 | x_new = Variables() 159 | for var in x0.retract(delta_x).items(): 160 | x_new.add(var[1]) 161 | 162 | # 4. Store best solutions so far 163 | with th.no_grad(): 164 | loss_new = fg(x_new) 165 | if self.x_best is None or th.sum(loss_new) < th.sum(self.loss_best): 166 | self.x_best = x_new 167 | self.loss_best = loss_new 168 | 169 | # Logging 170 | self._logging(x0, delta_x, loss_new, i) 171 | 172 | # 5. Repeat until we reach termination condition 173 | return x_new if self.terminate(i, x_new) else self._solve_nonlinear(fg, x_new, i + 1) 174 | 175 | def _logging(self, x0, delta_x, loss, i): 176 | self.x0_list.append(x0) 177 | self.delta_x_list.append(delta_x) 178 | 179 | if i % self.log_every_n == 0: 180 | # Print sum over batches loss 181 | print(f"inner_loss: {th.sum(loss):>7f} [{i:>5d}/{self.iters:>5d}]") 182 | 183 | 184 | # should be abstract 185 | def terminate(self, i, x): 186 | # one implementation is just to look at how many iters we have done 187 | convergence = False 188 | # with th.no_grad(): 189 | # ic(x.shape) 190 | # # We need to determine per-batch convergence :O 191 | # # And how do we avoid per-batch updates for the converged ones? 192 | # convergence = th.abs(x - self.x0_list[-1]) < self.x_tol 193 | # convergence = self.delta_x_list[-1] < self.min_x_update 194 | return i >= self.iters or convergence 195 | 196 | 197 | def _reset_history(self): 198 | self.x_best = None 199 | self.loss_best = None 200 | self.A_list = [] 201 | self.b_list = [] 202 | self.x0_list = [] 203 | self.delta_x_list = [] 204 | -------------------------------------------------------------------------------- /src/correlation_kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #define BLOCK 16 14 | 15 | __forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) { 16 | return h >= 0 && h < H && w >= 0 && w < W; 17 | } 18 | 19 | template 20 | __global__ void corr_index_forward_kernel( 21 | const torch::PackedTensorAccessor32 volume, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | // batch index 27 | const int x = blockIdx.x * blockDim.x + threadIdx.x; 28 | const int y = blockIdx.y * blockDim.y + threadIdx.y; 29 | const int n = blockIdx.z; 30 | 31 | const int h1 = volume.size(1); 32 | const int w1 = volume.size(2); 33 | const int h2 = volume.size(3); 34 | const int w2 = volume.size(4); 35 | 36 | if (!within_bounds(y, x, h1, w1)) { 37 | return; 38 | } 39 | 40 | float x0 = coords[n][0][y][x]; 41 | float y0 = coords[n][1][y][x]; 42 | 43 | float dx = x0 - floor(x0); 44 | float dy = y0 - floor(y0); 45 | 46 | int rd = 2*r + 1; 47 | for (int i=0; i(floor(x0)) - r + i; 50 | int y1 = static_cast(floor(y0)) - r + j; 51 | 52 | if (within_bounds(y1, x1, h2, w2)) { 53 | scalar_t s = 
volume[n][y][x][y1][x1]; 54 | 55 | if (i > 0 && j > 0) 56 | corr[n][i-1][j-1][y][x] += s * scalar_t(dx * dy); 57 | 58 | if (i > 0 && j < rd) 59 | corr[n][i-1][j][y][x] += s * scalar_t(dx * (1.0f-dy)); 60 | 61 | if (i < rd && j > 0) 62 | corr[n][i][j-1][y][x] += s * scalar_t((1.0f-dx) * dy); 63 | 64 | if (i < rd && j < rd) 65 | corr[n][i][j][y][x] += s * scalar_t((1.0f-dx) * (1.0f-dy)); 66 | 67 | } 68 | } 69 | } 70 | } 71 | 72 | 73 | template 74 | __global__ void corr_index_backward_kernel( 75 | const torch::PackedTensorAccessor32 coords, 76 | const torch::PackedTensorAccessor32 corr_grad, 77 | torch::PackedTensorAccessor32 volume_grad, 78 | int r) 79 | { 80 | // batch index 81 | const int x = blockIdx.x * blockDim.x + threadIdx.x; 82 | const int y = blockIdx.y * blockDim.y + threadIdx.y; 83 | const int n = blockIdx.z; 84 | 85 | const int h1 = volume_grad.size(1); 86 | const int w1 = volume_grad.size(2); 87 | const int h2 = volume_grad.size(3); 88 | const int w2 = volume_grad.size(4); 89 | 90 | if (!within_bounds(y, x, h1, w1)) { 91 | return; 92 | } 93 | 94 | float x0 = coords[n][0][y][x]; 95 | float y0 = coords[n][1][y][x]; 96 | 97 | float dx = x0 - floor(x0); 98 | float dy = y0 - floor(y0); 99 | 100 | int rd = 2*r + 1; 101 | for (int i=0; i(floor(x0)) - r + i; 104 | int y1 = static_cast(floor(y0)) - r + j; 105 | 106 | if (within_bounds(y1, x1, h2, w2)) { 107 | scalar_t g = 0.0; 108 | if (i > 0 && j > 0) 109 | g += corr_grad[n][i-1][j-1][y][x] * scalar_t(dx * dy); 110 | 111 | if (i > 0 && j < rd) 112 | g += corr_grad[n][i-1][j][y][x] * scalar_t(dx * (1.0f-dy)); 113 | 114 | if (i < rd && j > 0) 115 | g += corr_grad[n][i][j-1][y][x] * scalar_t((1.0f-dx) * dy); 116 | 117 | if (i < rd && j < rd) 118 | g += corr_grad[n][i][j][y][x] * scalar_t((1.0f-dx) * (1.0f-dy)); 119 | 120 | volume_grad[n][y][x][y1][x1] += g; 121 | } 122 | } 123 | } 124 | } 125 | 126 | std::vector corr_index_cuda_forward( 127 | torch::Tensor volume, 128 | torch::Tensor coords, 129 | int radius) 130 | { 131 | const auto batch_size = volume.size(0); 132 | const auto ht = volume.size(1); 133 | const auto wd = volume.size(2); 134 | 135 | const dim3 blocks((wd + BLOCK - 1) / BLOCK, 136 | (ht + BLOCK - 1) / BLOCK, 137 | batch_size); 138 | 139 | const dim3 threads(BLOCK, BLOCK); 140 | 141 | auto opts = volume.options(); 142 | torch::Tensor corr = torch::zeros( 143 | {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts); 144 | 145 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_forward_kernel", ([&] { 146 | corr_index_forward_kernel<<>>( 147 | volume.packed_accessor32(), 148 | coords.packed_accessor32(), 149 | corr.packed_accessor32(), 150 | radius); 151 | })); 152 | 153 | return {corr}; 154 | 155 | } 156 | 157 | std::vector corr_index_cuda_backward( 158 | torch::Tensor volume, 159 | torch::Tensor coords, 160 | torch::Tensor corr_grad, 161 | int radius) 162 | { 163 | const auto batch_size = volume.size(0); 164 | const auto ht = volume.size(1); 165 | const auto wd = volume.size(2); 166 | 167 | auto volume_grad = torch::zeros_like(volume); 168 | 169 | const dim3 blocks((wd + BLOCK - 1) / BLOCK, 170 | (ht + BLOCK - 1) / BLOCK, 171 | batch_size); 172 | 173 | const dim3 threads(BLOCK, BLOCK); 174 | 175 | 176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_backward_kernel", ([&] { 177 | corr_index_backward_kernel<<>>( 178 | coords.packed_accessor32(), 179 | corr_grad.packed_accessor32(), 180 | volume_grad.packed_accessor32(), 181 | radius); 182 | })); 183 | 184 | return {volume_grad}; 185 | } 
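For readability, here is a minimal PyTorch sketch of the lookup that corr_index_forward performs: for every pixel of the first feature map it bilinearly samples a (2*radius+1) x (2*radius+1) window of the precomputed 5-D correlation volume around the sub-pixel lookup coordinate given in coords. The function name corr_index_reference, the border handling, and the exact output axis ordering are illustrative assumptions rather than part of the CUDA extension's API; the kernel above remains the authoritative implementation.

import torch
import torch.nn.functional as F

def corr_index_reference(volume: torch.Tensor, coords: torch.Tensor, radius: int) -> torch.Tensor:
    # volume: (B, H1, W1, H2, W2) correlation volume.
    # coords: (B, 2, H1, W1) with (x, y) lookup locations in the second frame.
    # returns: (B, 2r+1, 2r+1, H1, W1), roughly matching corr[n][i][j][y][x] above.
    B, H1, W1, H2, W2 = volume.shape
    r = radius
    # Treat each (y, x) slice of the volume as a tiny image and sample it on a
    # (2r+1) x (2r+1) grid centred at the sub-pixel lookup coordinate.
    vol = volume.reshape(B * H1 * W1, 1, H2, W2)
    cx = coords[:, 0].reshape(B * H1 * W1, 1, 1)
    cy = coords[:, 1].reshape(B * H1 * W1, 1, 1)
    d = torch.arange(-r, r + 1, device=coords.device, dtype=coords.dtype)
    gx = (cx + d.view(1, 1, -1)).expand(-1, 2 * r + 1, -1)  # x offsets vary along the last dim
    gy = (cy + d.view(1, -1, 1)).expand(-1, -1, 2 * r + 1)  # y offsets vary along the middle dim
    # Normalise pixel coordinates to [-1, 1] as expected by grid_sample.
    grid = torch.stack([2 * gx / (W2 - 1) - 1, 2 * gy / (H2 - 1) - 1], dim=-1).to(vol.dtype)
    win = F.grid_sample(vol, grid, mode='bilinear', align_corners=True, padding_mode='zeros')
    # (B*H1*W1, 1, 2r+1, 2r+1) -> (B, x-offset, y-offset, H1, W1)
    return win.view(B, H1, W1, 2 * r + 1, 2 * r + 1).permute(0, 4, 3, 1, 2)

This is only intended as a readable counterpart for debugging or unit tests; the CUDA kernel avoids materializing the per-pixel sampling grids and is what the extension actually exposes.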
-------------------------------------------------------------------------------- /src/droid.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | // CUDA forward declarations 9 | std::vector projective_transform_cuda( 10 | torch::Tensor poses, 11 | torch::Tensor disps, 12 | torch::Tensor intrinsics, 13 | torch::Tensor ii, 14 | torch::Tensor jj); 15 | 16 | 17 | 18 | torch::Tensor depth_filter_cuda( 19 | torch::Tensor poses, 20 | torch::Tensor disps, 21 | torch::Tensor intrinsics, 22 | torch::Tensor ix, 23 | torch::Tensor thresh); 24 | 25 | 26 | torch::Tensor frame_distance_cuda( 27 | torch::Tensor poses, 28 | torch::Tensor disps, 29 | torch::Tensor intrinsics, 30 | torch::Tensor ii, 31 | torch::Tensor jj, 32 | const float beta); 33 | 34 | std::vector projmap_cuda( 35 | torch::Tensor poses, 36 | torch::Tensor disps, 37 | torch::Tensor intrinsics, 38 | torch::Tensor ii, 39 | torch::Tensor jj); 40 | 41 | torch::Tensor iproj_cuda( 42 | torch::Tensor poses, 43 | torch::Tensor disps, 44 | torch::Tensor intrinsics); 45 | 46 | std::vector ba_cuda( 47 | torch::Tensor poses, 48 | torch::Tensor body_poses, 49 | torch::Tensor disps, 50 | torch::Tensor intrinsics, 51 | torch::Tensor extrinsics, 52 | torch::Tensor disps_sens, 53 | torch::Tensor targets, 54 | torch::Tensor weights, 55 | torch::Tensor eta, 56 | torch::Tensor ii, 57 | torch::Tensor jj, 58 | const int t0, 59 | const int t1, 60 | const int iterations, 61 | const float lm, 62 | const float ep, 63 | const bool motion_only); 64 | 65 | std::vector 66 | reduced_camera_matrix_cuda( 67 | torch::Tensor poses, 68 | torch::Tensor body_poses, 69 | torch::Tensor disps, 70 | torch::Tensor intrinsics, 71 | torch::Tensor extrinsics, 72 | torch::Tensor disps_sens, 73 | torch::Tensor targets, 74 | torch::Tensor weights, 75 | torch::Tensor eta, 76 | torch::Tensor ii, 77 | torch::Tensor jj, 78 | const int t0, 79 | const int t1); 80 | 81 | void solve_depth_cuda( 82 | torch::Tensor dx, 83 | torch::Tensor disps, 84 | torch::Tensor Q, 85 | torch::Tensor E, 86 | torch::Tensor w, 87 | torch::Tensor ii, 88 | torch::Tensor jj, 89 | const int t0, 90 | const int t1); 91 | 92 | // void solve_cuda( 93 | // torch::Tensor A, 94 | // torch::Tensor S, 95 | // const float lm, 96 | // const float ep); 97 | // 98 | void solve_poses_cuda( 99 | torch::Tensor poses, 100 | torch::Tensor dx, 101 | const int t0, 102 | const int t1); 103 | 104 | std::vector corr_index_cuda_forward( 105 | torch::Tensor volume, 106 | torch::Tensor coords, 107 | int radius); 108 | 109 | std::vector corr_index_cuda_backward( 110 | torch::Tensor volume, 111 | torch::Tensor coords, 112 | torch::Tensor corr_grad, 113 | int radius); 114 | 115 | std::vector altcorr_cuda_forward( 116 | torch::Tensor fmap1, 117 | torch::Tensor fmap2, 118 | torch::Tensor coords, 119 | int radius); 120 | 121 | std::vector altcorr_cuda_backward( 122 | torch::Tensor fmap1, 123 | torch::Tensor fmap2, 124 | torch::Tensor coords, 125 | torch::Tensor corr_grad, 126 | int radius); 127 | 128 | 129 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 130 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 131 | 132 | 133 | std::vector ba( 134 | torch::Tensor poses, 135 | torch::Tensor body_poses, 136 | torch::Tensor disps, 137 | torch::Tensor intrinsics, 138 | torch::Tensor extrinsics, 139 | torch::Tensor disps_sens, 140 | torch::Tensor targets, 141 | torch::Tensor weights, 142 | torch::Tensor eta, 143 | 
torch::Tensor ii, 144 | torch::Tensor jj, 145 | const int t0, 146 | const int t1, 147 | const int iterations, 148 | const float lm, 149 | const float ep, 150 | const bool motion_only) { 151 | 152 | CHECK_INPUT(targets); 153 | CHECK_INPUT(weights); 154 | CHECK_INPUT(poses); 155 | CHECK_INPUT(body_poses); 156 | CHECK_INPUT(disps); 157 | CHECK_INPUT(intrinsics); 158 | CHECK_INPUT(disps_sens); 159 | CHECK_INPUT(ii); 160 | CHECK_INPUT(jj); 161 | 162 | return ba_cuda(poses, body_poses, disps, intrinsics, extrinsics, disps_sens, targets, weights, 163 | eta, ii, jj, t0, t1, iterations, lm, ep, motion_only); 164 | 165 | } 166 | 167 | std::vector 168 | reduced_camera_matrix( 169 | torch::Tensor poses, 170 | torch::Tensor body_poses, 171 | torch::Tensor disps, 172 | torch::Tensor intrinsics, 173 | torch::Tensor extrinsics, 174 | torch::Tensor disps_sens, 175 | torch::Tensor targets, 176 | torch::Tensor weights, 177 | torch::Tensor eta, 178 | torch::Tensor ii, 179 | torch::Tensor jj, 180 | const int t0, 181 | const int t1) { 182 | 183 | CHECK_INPUT(poses); 184 | CHECK_INPUT(disps); 185 | CHECK_INPUT(intrinsics); 186 | CHECK_INPUT(extrinsics); 187 | CHECK_INPUT(disps_sens); 188 | CHECK_INPUT(targets); 189 | CHECK_INPUT(weights); 190 | CHECK_INPUT(eta); 191 | CHECK_INPUT(ii); 192 | CHECK_INPUT(jj); 193 | 194 | return reduced_camera_matrix_cuda(poses, body_poses, disps, intrinsics, extrinsics, disps_sens, targets, weights, 195 | eta, ii, jj, t0, t1); 196 | } 197 | 198 | void solve_depth( 199 | torch::Tensor dx, 200 | torch::Tensor disps, 201 | torch::Tensor Q, 202 | torch::Tensor E, 203 | torch::Tensor w, 204 | torch::Tensor ii, 205 | torch::Tensor jj, 206 | const int t0, 207 | const int t1) { 208 | 209 | CHECK_INPUT(dx); 210 | CHECK_INPUT(disps); 211 | CHECK_INPUT(Q); 212 | CHECK_INPUT(E); 213 | CHECK_INPUT(w); 214 | CHECK_INPUT(ii); 215 | CHECK_INPUT(jj); 216 | 217 | return solve_depth_cuda(dx, disps, Q, E, w, ii, jj, t0, t1); 218 | } 219 | 220 | void solve_poses(torch::Tensor poses, 221 | torch::Tensor dx, 222 | const int t0, 223 | const int t1) { 224 | CHECK_INPUT(poses); 225 | CHECK_INPUT(dx); 226 | 227 | return solve_poses_cuda(poses, dx, t0, t1); 228 | } 229 | 230 | torch::Tensor frame_distance( 231 | torch::Tensor poses, 232 | torch::Tensor disps, 233 | torch::Tensor intrinsics, 234 | torch::Tensor ii, 235 | torch::Tensor jj, 236 | const float beta) { 237 | 238 | CHECK_INPUT(poses); 239 | CHECK_INPUT(disps); 240 | CHECK_INPUT(intrinsics); 241 | CHECK_INPUT(ii); 242 | CHECK_INPUT(jj); 243 | 244 | return frame_distance_cuda(poses, disps, intrinsics, ii, jj, beta); 245 | 246 | } 247 | 248 | 249 | std::vector projmap( 250 | torch::Tensor poses, 251 | torch::Tensor disps, 252 | torch::Tensor intrinsics, 253 | torch::Tensor ii, 254 | torch::Tensor jj) { 255 | 256 | CHECK_INPUT(poses); 257 | CHECK_INPUT(disps); 258 | CHECK_INPUT(intrinsics); 259 | CHECK_INPUT(ii); 260 | CHECK_INPUT(jj); 261 | 262 | return projmap_cuda(poses, disps, intrinsics, ii, jj); 263 | 264 | } 265 | 266 | 267 | torch::Tensor iproj( 268 | torch::Tensor poses, 269 | torch::Tensor disps, 270 | torch::Tensor intrinsics) { 271 | CHECK_INPUT(poses); 272 | CHECK_INPUT(disps); 273 | CHECK_INPUT(intrinsics); 274 | 275 | return iproj_cuda(poses, disps, intrinsics); 276 | } 277 | 278 | 279 | // c++ python binding 280 | std::vector corr_index_forward( 281 | torch::Tensor volume, 282 | torch::Tensor coords, 283 | int radius) { 284 | CHECK_INPUT(volume); 285 | CHECK_INPUT(coords); 286 | 287 | return corr_index_cuda_forward(volume, coords, 
radius); 288 | } 289 | 290 | std::vector corr_index_backward( 291 | torch::Tensor volume, 292 | torch::Tensor coords, 293 | torch::Tensor corr_grad, 294 | int radius) { 295 | CHECK_INPUT(volume); 296 | CHECK_INPUT(coords); 297 | CHECK_INPUT(corr_grad); 298 | 299 | auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius); 300 | return {volume_grad}; 301 | } 302 | 303 | std::vector altcorr_forward( 304 | torch::Tensor fmap1, 305 | torch::Tensor fmap2, 306 | torch::Tensor coords, 307 | int radius) { 308 | CHECK_INPUT(fmap1); 309 | CHECK_INPUT(fmap2); 310 | CHECK_INPUT(coords); 311 | 312 | return altcorr_cuda_forward(fmap1, fmap2, coords, radius); 313 | } 314 | 315 | std::vector altcorr_backward( 316 | torch::Tensor fmap1, 317 | torch::Tensor fmap2, 318 | torch::Tensor coords, 319 | torch::Tensor corr_grad, 320 | int radius) { 321 | CHECK_INPUT(fmap1); 322 | CHECK_INPUT(fmap2); 323 | CHECK_INPUT(coords); 324 | CHECK_INPUT(corr_grad); 325 | 326 | return altcorr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 327 | } 328 | 329 | 330 | torch::Tensor depth_filter( 331 | torch::Tensor poses, 332 | torch::Tensor disps, 333 | torch::Tensor intrinsics, 334 | torch::Tensor ix, 335 | torch::Tensor thresh) { 336 | 337 | CHECK_INPUT(poses); 338 | CHECK_INPUT(disps); 339 | CHECK_INPUT(intrinsics); 340 | CHECK_INPUT(ix); 341 | CHECK_INPUT(thresh); 342 | 343 | return depth_filter_cuda(poses, disps, intrinsics, ix, thresh); 344 | } 345 | 346 | 347 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 348 | // bundle adjustment kernels 349 | m.def("ba", &ba, "bundle adjustment"); 350 | m.def("reduced_camera_matrix", &reduced_camera_matrix, "reduced camera matrix"); 351 | m.def("solve_depth", &solve_depth, "Retract depth given dx"); 352 | m.def("solve_poses", &solve_poses, "Retract poses given dx"); 353 | m.def("frame_distance", &frame_distance, "frame_distance"); 354 | m.def("projmap", &projmap, "projmap"); 355 | m.def("depth_filter", &depth_filter, "depth_filter"); 356 | m.def("iproj", &iproj, "back projection"); 357 | 358 | // correlation volume kernels 359 | m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward"); 360 | m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward"); 361 | m.def("corr_index_forward", &corr_index_forward, "INDEX forward"); 362 | m.def("corr_index_backward", &corr_index_backward, "INDEX backward"); 363 | } 364 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | from icecream import ic 6 | 7 | class MeshRenderer: 8 | def __init__(self, mesh_path, intrinsics, resolution) -> None: 9 | self.mesh = o3d.io.read_triangle_mesh(mesh_path) 10 | # Do we need to register the gt_mesh to the SLAM frame of ref? I don't think so because we start SLAM with gt pose 11 | # Nonethelesssss, we need to register the SLAM's estimated trajectory with the ground-truth using Sim(3)! 12 | # And scale the gt_mesh with the resulting scale parameter, 13 | # OR! Just render depths at the ground-truth trajectory and register the depth-maps with scale? 
(I think that is what ORbeez doess) 14 | 15 | focal_length = intrinsics[:2] 16 | principal_point = intrinsics[2:] 17 | fx, fy = focal_length[0], focal_length[1] 18 | cx, cy = principal_point[0], principal_point[1] 19 | w, h = resolution[0], resolution[1] 20 | 21 | self.cam = o3d.camera.PinholeCameraParameters() 22 | self.cam.intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy) 23 | 24 | self.viz = o3d.visualization.Visualizer() 25 | self.viz.create_window(width=w, height=h) 26 | self.viz.get_render_option().mesh_show_back_face = True 27 | 28 | self.mesh_uploaded = False 29 | 30 | self.viz_world_frame = False 31 | 32 | def render_mesh(self, c2w): 33 | viewport = self.viz.get_view_control().convert_to_pinhole_camera_parameters() 34 | 35 | self.cam.extrinsic = np.linalg.inv(c2w) 36 | 37 | if self.viz_world_frame: 38 | self.viz.add_geometry(self.create_frame_actor(np.eye(4), scale=0.5), reset_bounding_box=True) 39 | self.viz_world_frame = False 40 | 41 | if not self.mesh_uploaded: 42 | self.viz.add_geometry(self.mesh, reset_bounding_box=True) 43 | self.mesh_uploaded = True 44 | 45 | #ctr = self.viz.get_view_control() 46 | #ctr.set_constant_z_far(20) 47 | #ctr.convert_from_pinhole_camera_parameters(self.cam) 48 | 49 | self.viz.poll_events() 50 | self.viz.update_renderer() 51 | 52 | gt_depth = self.viz.capture_depth_float_buffer(True) 53 | gt_depth = np.asarray(gt_depth) 54 | 55 | # hack to allow interacting when using add_geometry 56 | viewport = self.viz.get_view_control().convert_from_pinhole_camera_parameters(viewport) 57 | 58 | self.viz.poll_events() 59 | self.viz.update_renderer() 60 | 61 | return gt_depth 62 | 63 | 64 | def create_frame_actor(self, pose, scale=0.05): 65 | frame_actor = o3d.geometry.TriangleMesh.create_coordinate_frame(size=scale, 66 | origin=np.array([0., 0., 0.])) 67 | frame_actor.transform(pose) 68 | return frame_actor -------------------------------------------------------------------------------- /utils/open3d_pickle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import open3d as o3d 4 | import numpy as np 5 | 6 | 7 | class _MeshTransmissionFormat: 8 | def __init__(self, mesh: o3d.geometry.TriangleMesh): 9 | assert(mesh) 10 | self.triangle_material_ids = np.array(mesh.triangle_material_ids) 11 | self.triangle_normals = np.array(mesh.triangle_normals) 12 | self.triangle_uvs = np.array(mesh.triangle_uvs) 13 | self.triangles = np.array(mesh.triangles) 14 | 15 | self.vertex_colors = np.array(mesh.vertex_colors) 16 | self.vertex_normals = np.array(mesh.vertex_normals) 17 | self.vertices = np.array(mesh.vertices) 18 | 19 | def create_mesh(self) -> o3d.geometry.TriangleMesh: 20 | mesh = o3d.geometry.TriangleMesh() 21 | 22 | mesh.triangle_material_ids = o3d.utility.IntVector( 23 | self.triangle_material_ids) 24 | mesh.triangle_normals = o3d.utility.Vector3dVector( 25 | self.triangle_normals) 26 | mesh.triangle_uvs = o3d.utility.Vector2dVector(self.triangle_uvs) 27 | mesh.triangles = o3d.utility.Vector3iVector(self.triangles) 28 | 29 | mesh.vertex_colors = o3d.utility.Vector3dVector(self.vertex_colors) 30 | mesh.vertex_normals = o3d.utility.Vector3dVector(self.vertex_normals) 31 | 32 | mesh.vertices = o3d.utility.Vector3dVector(self.vertices) 33 | return mesh 34 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import cv2 4 | 5 
| import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | from icecream import ic 9 | import torch 10 | 11 | def qvec2rotmat(qvec): 12 | return np.array([ 13 | [ 14 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 15 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 16 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 17 | ], [ 18 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 19 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 20 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 21 | ], [ 22 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 23 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 24 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 25 | ] 26 | ]) 27 | 28 | def rotmat(a, b): 29 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 30 | try: 31 | v = np.cross(a, b) 32 | except: 33 | pass 34 | c = np.dot(a, b) 35 | s = np.linalg.norm(v) 36 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 37 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 38 | 39 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 40 | da = da / np.linalg.norm(da) 41 | db = db / np.linalg.norm(db) 42 | try: 43 | c = np.cross(da, db) 44 | except: 45 | pass 46 | denom = np.linalg.norm(c)**2 47 | t = ob - oa 48 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 49 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 50 | if ta > 0: 51 | ta = 0 52 | if tb > 0: 53 | tb = 0 54 | return (oa+ta*da+ob+tb*db) * 0.5, denom 55 | 56 | def variance_of_laplacian(image): 57 | return cv2.Laplacian(image, cv2.CV_64F).var() 58 | 59 | def sharpness(image): 60 | if image is str: 61 | image = cv2.imread(image) 62 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 63 | fm = variance_of_laplacian(gray) 64 | return fm 65 | 66 | def pose_matrix_to_t_and_quat(pose): 67 | """ convert 4x4 pose matrix to (t, q) """ 68 | q = Rotation.from_matrix(pose[:3, :3]).as_quat() 69 | return np.concatenate([pose[:3, 3], q], axis=0) 70 | 71 | def get_pose_from_df(gt_df): 72 | x = gt_df.loc['tx'] # [m] 73 | y = gt_df.loc['ty'] # [m] 74 | z = gt_df.loc['tz'] # [m] 75 | qx = gt_df.loc['qx'] 76 | qy = gt_df.loc['qy'] 77 | qz = gt_df.loc['qz'] 78 | qw = gt_df.loc['qw'] 79 | r = Rotation.from_quat([qx, qy, qz, qw]) 80 | pose = np.eye(4) 81 | pose[:3, :3] = r.as_matrix() 82 | pose[:3, 3] = [x, y, z] 83 | return pose 84 | 85 | def get_velocity(euroc_df): 86 | # TODO: rename these 87 | vx = euroc_df.loc[' v_RS_R_x [m s^-1]'] 88 | vy = euroc_df.loc[' v_RS_R_y [m s^-1]'] 89 | vz = euroc_df.loc[' v_RS_R_z [m s^-1]'] 90 | return np.array([vx, vy, vz]) 91 | 92 | def get_bias(euroc_df): 93 | # TODO: rename these 94 | ba_x = euroc_df.loc[' b_a_RS_S_x [m s^-2]'] 95 | ba_y = euroc_df.loc[' b_a_RS_S_y [m s^-2]'] 96 | ba_z = euroc_df.loc[' b_a_RS_S_z [m s^-2]'] 97 | bg_x = euroc_df.loc[' b_w_RS_S_x [rad s^-1]'] 98 | bg_y = euroc_df.loc[' b_w_RS_S_y [rad s^-1]'] 99 | bg_z = euroc_df.loc[' b_w_RS_S_z [rad s^-1]'] 100 | return np.array([ba_x, ba_y, ba_z, bg_x, bg_y, bg_z]) 101 | 102 | 103 | # offse's default is 0.5 bcs we scale/offset poses to 1.0/0.5 before feeding to nerf 104 | def nerf_matrix_to_ngp(nerf_matrix, scale=1.0, offset=0.5): 105 | result = nerf_matrix.copy() 106 | result[:3, 1] *= -1 107 | result[:3, 2] *= -1 108 | result[:3, 3] = result[:3, 3] * scale + offset 109 | 110 | # Cycle axes xyz<-yzx 111 | tmp = result[0, :].copy() 112 | result[0, :] = result[1, :] 113 | result[1, :] = result[2, :] 114 | result[2, :] = tmp 115 | 116 | 
return result 117 | 118 | # offset's default is 0.5 because we scale/offset poses to 1.0/0.5 before feeding to nerf 119 | def ngp_matrix_to_nerf(ngp_matrix, scale=1.0, offset=0.5): 120 | result = ngp_matrix.copy() 121 | 122 | # Cycle axes xyz->yzx (assign in this order so no row is overwritten before it is read) 123 | tmp = result[2, :].copy() 124 | result[2, :] = result[1, :] 125 | result[1, :] = result[0, :] 126 | result[0, :] = tmp 127 | 128 | result[:3, 0] *= 1 / scale 129 | result[:3, 1] *= -1 / scale 130 | result[:3, 2] *= -1 / scale 131 | result[:3, 3] = (result[:3, 3] - offset) / scale 132 | 133 | return result 134 | 135 | # This can be very slow, send to gpu device... 136 | def srgb_to_linear(img, device): 137 | img = img.to(device=device) 138 | limit = 0.04045 139 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92) 140 | 141 | 142 | def linear_to_srgb(img): 143 | limit = 0.0031308 144 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 145 | 146 | 147 | def get_scale_and_offset(aabb): 148 | # map the given aabb of the form [[minx,miny,minz],[maxx,maxy,maxz]] 149 | # via an isotropic scale and translate to fit in the (0,0,0)-(1,1,1) cube, 150 | # with the given center at 0.5,0.5,0.5 151 | aabb = np.array(aabb, dtype=np.float64) 152 | dx = aabb[1][0]-aabb[0][0] 153 | dy = aabb[1][1]-aabb[0][1] 154 | dz = aabb[1][2]-aabb[0][2] 155 | length = max(0.000001, max(max(abs(dx), abs(dy)), abs(dz))) 156 | scale = 1.0 / length 157 | offset = np.array([((aabb[1][0]+aabb[0][0])*0.5) * -scale + 0.5, 158 | ((aabb[1][1]+aabb[0][1])*0.5) * -scale + 0.5, 159 | ((aabb[1][2]+aabb[0][2])*0.5) * -scale + 0.5]) 160 | return scale, offset 161 | 162 | 163 | def scale_offset_poses(poses, scale, offset): # for c2w poses! 164 | poses[:, :3, 3] = poses[:, :3, 3] * scale + offset 165 | return poses 166 | 167 | 168 | def mse2psnr(x): 169 | return -10.*np.log(x)/np.log(10.) 170 | 171 | 172 | def L2(img, ref): 173 | return (img - ref)**2 174 | 175 | 176 | def compute_mse_img(img, ref): 177 | img[np.logical_not(np.isfinite(img))] = 0 178 | img = np.maximum(img, 0.) 179 | return L2(img, ref) 180 | 181 | 182 | def compute_error(img, ref): 183 | metric_map = compute_mse_img(img, ref) 184 | metric_map[np.logical_not(np.isfinite(metric_map))] = 0 185 | if len(metric_map.shape) == 3: 186 | metric_map = np.mean(metric_map, axis=2) 187 | mean = np.mean(metric_map) 188 | return mean 189 | --------------------------------------------------------------------------------
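As a usage illustration (not part of the repository), the sketch below shows how the helpers in utils/utils.py compose: get_scale_and_offset and scale_offset_poses fit a set of camera-to-world poses into the unit cube, nerf_matrix_to_ngp converts one of them to the ngp axis convention, and mse2psnr turns an image MSE into a PSNR. The toy poses and the import path utils.utils are assumptions based on the repository layout.

import numpy as np
from utils.utils import (get_scale_and_offset, scale_offset_poses,
                         nerf_matrix_to_ngp, mse2psnr)

# Three toy camera-to-world poses with random translations (illustrative only).
poses = np.stack([np.eye(4) for _ in range(3)])
poses[:, :3, 3] = np.random.uniform(-2.0, 2.0, size=(3, 3))

# Fit the bounding box of the camera centres into the (0,0,0)-(1,1,1) cube.
aabb = [poses[:, :3, 3].min(axis=0).tolist(), poses[:, :3, 3].max(axis=0).tolist()]
scale, offset = get_scale_and_offset(aabb)
poses = scale_offset_poses(poses.copy(), scale, offset)

# Convert the first (already normalized) pose to the ngp axis convention.
ngp_pose = nerf_matrix_to_ngp(poses[0], scale=1.0, offset=0.0)

# An MSE of 1e-3 corresponds to a PSNR of 30 dB.
print(mse2psnr(1e-3))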