├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── calib ├── barn.txt ├── eth.txt ├── euroc.txt ├── tartan.txt └── tum3.txt ├── data └── euroc_groundtruth │ ├── MH_01_easy.txt │ ├── MH_02_easy.txt │ ├── MH_03_medium.txt │ ├── MH_04_difficult.txt │ ├── MH_05_difficult.txt │ ├── V1_01_easy.txt │ ├── V1_02_medium.txt │ ├── V1_03_difficult.txt │ ├── V2_01_easy.txt │ ├── V2_02_medium.txt │ └── V2_03_difficult.txt ├── demo.py ├── droid_slam ├── align.py ├── cuda_timer.py ├── data_readers │ ├── __init__.py │ ├── augmentation.py │ ├── base.py │ ├── factory.py │ ├── rgbd_utils.py │ ├── stream.py │ ├── tartan.py │ └── tartan_test.txt ├── depth_video.py ├── droid.py ├── droid_async.py ├── droid_backend.py ├── droid_frontend.py ├── droid_net.py ├── factor_graph.py ├── geom │ ├── __init__.py │ ├── ba.py │ ├── chol.py │ ├── graph_utils.py │ ├── losses.py │ └── projective_ops.py ├── logger.py ├── modules │ ├── __init__.py │ ├── clipping.py │ ├── corr.py │ ├── extractor.py │ └── gru.py ├── motion_filter.py ├── trajectory_filler.py ├── visualization.py └── visualizer │ ├── __init__.py │ ├── camera.py │ └── droid_visualizer.py ├── environment.yaml ├── environment_novis.yaml ├── evaluation_scripts ├── parse_results.py ├── test_eth3d.py ├── test_euroc.py ├── test_tartanair.py ├── test_tum.py └── validate_tartanair.py ├── misc ├── DROID.png ├── renderoption.json └── screenshot.png ├── requirements.txt ├── requirements_frozen.txt ├── setup.py ├── src ├── altcorr_kernel.cu ├── correlation_kernels.cu ├── droid.cpp └── droid_kernels.cu ├── thirdparty └── tartanair_tools │ ├── LICENSE │ ├── README.md │ ├── TartanAir_Sample.ipynb │ ├── data_type.md │ ├── download_cvpr_slam_test.txt │ ├── download_training.py │ ├── download_training_zipfiles.txt │ ├── evaluation │ ├── __init__.py │ ├── evaluate_ate_scale.py │ ├── evaluate_kitti.py │ ├── evaluate_rpe.py │ ├── evaluator_base.py │ ├── pose_est.txt │ ├── pose_gt.txt │ ├── tartanair_evaluator.py │ ├── trajectory_transform.py │ └── transformation.py │ └── seg_rgbs.txt ├── tools ├── download_eth3d.sh ├── download_euroc.sh ├── download_model.sh ├── download_sample_data.sh ├── download_tartanair_test.sh ├── download_tum.sh ├── evaluate_eth3d.sh ├── evaluate_euroc.sh ├── evaluate_tum.sh └── validate_tartanair.sh ├── train.py └── view_reconstruction.py /.gitignore: -------------------------------------------------------------------------------- 1 | a# Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | 142 | __pycache__ 143 | build 144 | dist 145 | *.egg-info 146 | *.vscode/ 147 | *.pth 148 | tests 149 | checkpoints 150 | datasets 151 | runs 152 | cache 153 | *.out 154 | *.o 155 | data 156 | figures/*.pdf 157 | 158 | 159 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/lietorch"] 2 | path = thirdparty/lietorch 3 | url = https://github.com/princeton-vl/lietorch.git 4 | [submodule "thirdparty/pytorch_scatter"] 5 | path = thirdparty/pytorch_scatter 6 | url = https://github.com/rusty1s/pytorch_scatter.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Princeton Vision & Learning Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /calib/barn.txt: -------------------------------------------------------------------------------- 1 | 1161.545689 1161.545689 960.000000 540.000000 -0.025158 0.0 0.0 0.0 -------------------------------------------------------------------------------- /calib/eth.txt: -------------------------------------------------------------------------------- 1 | 726.21081542969 726.21081542969 359.2048034668 202.47247314453 -------------------------------------------------------------------------------- /calib/euroc.txt: -------------------------------------------------------------------------------- 1 | 458.654 457.296 367.215 248.375 -0.28340811 0.07395907 0.00019359 1.76187114e-05 -------------------------------------------------------------------------------- /calib/tartan.txt: -------------------------------------------------------------------------------- 1 | 320.0 320.0 320.0 240.0 -------------------------------------------------------------------------------- /calib/tum3.txt: -------------------------------------------------------------------------------- 1 | 535.4 539.2 320.1 247.6 -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | 14 | from torch.multiprocessing import Process 15 | from droid import Droid 16 | from droid_async import DroidAsync 17 | 18 | import torch.nn.functional as F 19 | 20 | 21 | def show_image(image): 22 | image = image.permute(1, 2, 0).cpu().numpy() 23 | cv2.imshow('image', image / 255.0) 24 | cv2.waitKey(1) 25 | 26 | def image_stream(imagedir, calib, stride): 27 | """ image generator """ 28 | 29 | calib = np.loadtxt(calib, delimiter=" ") 30 | fx, fy, cx, cy = calib[:4] 31 | 32 | K = np.eye(3) 33 | K[0,0] = fx 34 | K[0,2] = cx 35 | K[1,1] = fy 36 | K[1,2] = cy 37 | 38 | image_list = sorted(os.listdir(imagedir))[::stride] 39 | 40 | for t, imfile in enumerate(image_list): 41 | image = cv2.imread(os.path.join(imagedir, imfile)) 42 | 
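# calibration files may list distortion coefficients after fx, fy, cx, cy; undistort only when they are present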
if len(calib) > 4: 43 | image = cv2.undistort(image, K, calib[4:]) 44 | 45 | h0, w0, _ = image.shape 46 | h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) 47 | w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) 48 | 49 | image = cv2.resize(image, (w1, h1)) 50 | image = image[:h1-h1%8, :w1-w1%8] 51 | image = torch.as_tensor(image).permute(2, 0, 1) 52 | 53 | intrinsics = torch.as_tensor([fx, fy, cx, cy]) 54 | intrinsics[0::2] *= (w1 / w0) 55 | intrinsics[1::2] *= (h1 / h0) 56 | 57 | yield t, image[None], intrinsics 58 | 59 | 60 | def save_reconstruction(droid, save_path): 61 | 62 | if hasattr(droid, "video2"): 63 | video = droid.video2 64 | else: 65 | video = droid.video 66 | 67 | t = video.counter.value 68 | save_data = { 69 | "tstamps": video.tstamp[:t].cpu(), 70 | "images": video.images[:t].cpu(), 71 | "disps": video.disps_up[:t].cpu(), 72 | "poses": video.poses[:t].cpu(), 73 | "intrinsics": video.intrinsics[:t].cpu() 74 | } 75 | 76 | torch.save(save_data, save_path) 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--imagedir", type=str, help="path to image directory") 82 | parser.add_argument("--calib", type=str, help="path to calibration file") 83 | parser.add_argument("--t0", default=0, type=int, help="starting frame") 84 | parser.add_argument("--stride", default=3, type=int, help="frame stride") 85 | 86 | parser.add_argument("--weights", default="droid.pth") 87 | parser.add_argument("--buffer", type=int, default=512) 88 | parser.add_argument("--image_size", default=[240, 320]) 89 | parser.add_argument("--disable_vis", action="store_true") 90 | 91 | parser.add_argument("--beta", type=float, default=0.3, help="weight for translation / rotation components of flow") 92 | parser.add_argument("--filter_thresh", type=float, default=2.4, help="how much motion before considering new keyframe") 93 | parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames") 94 | parser.add_argument("--keyframe_thresh", type=float, default=4.0, help="threshold to create a new keyframe") 95 | parser.add_argument("--frontend_thresh", type=float, default=16.0, help="add edges between frames whithin this distance") 96 | parser.add_argument("--frontend_window", type=int, default=25, help="frontend optimization window") 97 | parser.add_argument("--frontend_radius", type=int, default=2, help="force edges between frames within radius") 98 | parser.add_argument("--frontend_nms", type=int, default=1, help="non-maximal supression of edges") 99 | 100 | parser.add_argument("--backend_thresh", type=float, default=22.0) 101 | parser.add_argument("--backend_radius", type=int, default=2) 102 | parser.add_argument("--backend_nms", type=int, default=3) 103 | parser.add_argument("--upsample", action="store_true") 104 | parser.add_argument("--asynchronous", action="store_true") 105 | parser.add_argument("--frontend_device", type=str, default="cuda") 106 | parser.add_argument("--backend_device", type=str, default="cuda") 107 | 108 | parser.add_argument("--reconstruction_path", help="path to saved reconstruction") 109 | args = parser.parse_args() 110 | 111 | args.stereo = False 112 | torch.multiprocessing.set_start_method('spawn') 113 | 114 | droid = None 115 | 116 | # need high resolution depths 117 | if args.reconstruction_path is not None: 118 | args.upsample = True 119 | 120 | tstamps = [] 121 | for (t, image, intrinsics) in tqdm(image_stream(args.imagedir, args.calib, args.stride)): 122 | if t < args.t0: 123 | continue 124 | 125 | if not 
args.disable_vis: 126 | show_image(image[0]) 127 | 128 | if droid is None: 129 | args.image_size = [image.shape[2], image.shape[3]] 130 | droid = DroidAsync(args) if args.asynchronous else Droid(args) 131 | 132 | droid.track(t, image, intrinsics=intrinsics) 133 | 134 | traj_est = droid.terminate(image_stream(args.imagedir, args.calib, args.stride)) 135 | 136 | if args.reconstruction_path is not None: 137 | save_reconstruction(droid, args.reconstruction_path) 138 | -------------------------------------------------------------------------------- /droid_slam/align.py: -------------------------------------------------------------------------------- 1 | from lietorch import SE3 2 | 3 | def align_pose_fragements(pose0, pose1): 4 | P0 = SE3(pose0.clone()) 5 | P1 = SE3(pose1.clone()) 6 | 7 | dP1 = P0[None, :].inv() * P0[:, None] 8 | dP2 = P1[None, :].inv() * P1[:, None] 9 | 10 | dt1 = dP1.matrix()[:, :, :3, 3].view(-1, 3) 11 | dt2 = dP2.matrix()[:, :, :3, 3].view(-1, 3) 12 | 13 | s = (dt1 * dt2).sum() / (dt1 * dt1).sum() 14 | 15 | P0.data[..., :3] *= s 16 | 17 | dP = P1 * P0.inv() 18 | dG = dP[[0]] 19 | 20 | for _ in range(3): 21 | e = (P1 * (dG * P0).inv()).log() 22 | dG = SE3.exp(e.mean(dim=0, keepdim=True)) * dG 23 | 24 | return dG, s 25 | -------------------------------------------------------------------------------- /droid_slam/cuda_timer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CudaTimer: 4 | def __init__(self, name, enabled=True): 5 | self.name = name 6 | self.enabled = enabled 7 | 8 | if self.enabled: 9 | self.start = torch.cuda.Event(enable_timing=True) 10 | self.end = torch.cuda.Event(enable_timing=True) 11 | 12 | def __enter__(self): 13 | if self.enabled: 14 | self.start.record() 15 | 16 | def __exit__(self, type, value, traceback): 17 | global all_times 18 | if self.enabled: 19 | self.end.record() 20 | torch.cuda.synchronize() 21 | 22 | elapsed = self.start.elapsed_time(self.end) 23 | print(self.name, elapsed) 24 | -------------------------------------------------------------------------------- /droid_slam/data_readers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /droid_slam/data_readers/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class RGBDAugmentor: 8 | """ perform augmentation on RGB-D video """ 9 | 10 | def __init__(self, crop_size): 11 | self.crop_size = crop_size 12 | self.augcolor = transforms.Compose([ 13 | transforms.ToPILImage(), 14 | transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.4/3.14), 15 | transforms.RandomGrayscale(p=0.1), 16 | transforms.ToTensor()]) 17 | 18 | self.max_scale = 0.25 19 | 20 | def spatial_transform(self, images, depths, poses, intrinsics): 21 | """ cropping and resizing """ 22 | ht, wd = images.shape[2:] 23 | 24 | max_scale = self.max_scale 25 | min_scale = np.log2(np.maximum( 26 | (self.crop_size[0] + 1) / float(ht), 27 | (self.crop_size[1] + 1) / float(wd))) 28 | 29 | scale = 2 ** np.random.uniform(min_scale, max_scale) 30 | intrinsics = scale * intrinsics 31 | depths = depths.unsqueeze(dim=1) 32 | 33 | images = F.interpolate(images, scale_factor=scale, mode='bilinear', 34 | align_corners=False, recompute_scale_factor=True) 
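# resize the depths with the same random scale factor so they stay aligned with the resized images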
35 | 36 | depths = F.interpolate(depths, scale_factor=scale, recompute_scale_factor=True) 37 | 38 | # always perform center crop (TODO: try non-center crops) 39 | y0 = (images.shape[2] - self.crop_size[0]) // 2 40 | x0 = (images.shape[3] - self.crop_size[1]) // 2 41 | 42 | intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0]) 43 | images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 44 | depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 45 | 46 | depths = depths.squeeze(dim=1) 47 | return images, poses, depths, intrinsics 48 | 49 | def color_transform(self, images): 50 | """ color jittering """ 51 | num, ch, ht, wd = images.shape 52 | images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num) 53 | images = 255 * self.augcolor(images[[2,1,0]] / 255.0) 54 | return images[[2,1,0]].reshape(ch, ht, wd, num).permute(3,0,1,2).contiguous() 55 | 56 | def __call__(self, images, poses, depths, intrinsics): 57 | images = self.color_transform(images) 58 | return self.spatial_transform(images, depths, poses, intrinsics) 59 | -------------------------------------------------------------------------------- /droid_slam/data_readers/base.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | import torch.nn.functional as F 6 | 7 | import csv 8 | import os 9 | import cv2 10 | import math 11 | import random 12 | import json 13 | import pickle 14 | import os.path as osp 15 | 16 | from .augmentation import RGBDAugmentor 17 | from .rgbd_utils import * 18 | 19 | class RGBDDataset(data.Dataset): 20 | def __init__(self, name, datapath, n_frames=4, crop_size=[384,512], fmin=8.0, fmax=75.0, do_aug=True): 21 | """ Base class for RGBD dataset """ 22 | self.aug = None 23 | self.root = datapath 24 | self.name = name 25 | 26 | self.n_frames = n_frames 27 | self.fmin = fmin # exclude very easy examples 28 | self.fmax = fmax # exclude very hard examples 29 | 30 | if do_aug: 31 | self.aug = RGBDAugmentor(crop_size=crop_size) 32 | 33 | # building dataset is expensive, cache so only needs to be performed once 34 | cur_path = osp.dirname(osp.abspath(__file__)) 35 | if not os.path.isdir(osp.join(cur_path, 'cache')): 36 | os.mkdir(osp.join(cur_path, 'cache')) 37 | 38 | cache_path = osp.join(cur_path, 'cache', '{}.pickle'.format(self.name)) 39 | 40 | if osp.isfile(cache_path): 41 | scene_info = pickle.load(open(cache_path, 'rb'))[0] 42 | else: 43 | scene_info = self._build_dataset() 44 | with open(cache_path, 'wb') as cachefile: 45 | pickle.dump((scene_info,), cachefile) 46 | 47 | self.scene_info = scene_info 48 | self._build_dataset_index() 49 | 50 | def _build_dataset_index(self): 51 | self.dataset_index = [] 52 | for scene in self.scene_info: 53 | if not self.__class__.is_test_scene(scene): 54 | graph = self.scene_info[scene]['graph'] 55 | for i in graph: 56 | if len(graph[i][0]) > self.n_frames: 57 | self.dataset_index.append((scene, i)) 58 | else: 59 | print("Reserving {} for validation".format(scene)) 60 | 61 | @staticmethod 62 | def image_read(image_file): 63 | return cv2.imread(image_file) 64 | 65 | @staticmethod 66 | def depth_read(depth_file): 67 | return np.load(depth_file) 68 | 69 | def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256): 70 | """ compute optical flow distance between all pairs of frames """ 71 | def read_disp(fn): 72 | depth = self.__class__.depth_read(fn)[f//2::f, f//2::f] 73 | depth[depth < 0.01] = np.mean(depth) 74 | 
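# return inverse depth (disparity); near-zero depths were replaced above to avoid huge disparity values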
return 1.0 / depth 75 | 76 | poses = np.array(poses) 77 | intrinsics = np.array(intrinsics) / f 78 | 79 | disps = np.stack(list(map(read_disp, depths)), 0) 80 | d = f * compute_distance_matrix_flow(poses, disps, intrinsics) 81 | 82 | # uncomment for nice visualization 83 | # import matplotlib.pyplot as plt 84 | # plt.imshow(d) 85 | # plt.show() 86 | 87 | graph = {} 88 | for i in range(d.shape[0]): 89 | j, = np.where(d[i] < max_flow) 90 | graph[i] = (j, d[i,j]) 91 | 92 | return graph 93 | 94 | def __getitem__(self, index): 95 | """ return training video """ 96 | 97 | index = index % len(self.dataset_index) 98 | scene_id, ix = self.dataset_index[index] 99 | 100 | frame_graph = self.scene_info[scene_id]['graph'] 101 | images_list = self.scene_info[scene_id]['images'] 102 | depths_list = self.scene_info[scene_id]['depths'] 103 | poses_list = self.scene_info[scene_id]['poses'] 104 | intrinsics_list = self.scene_info[scene_id]['intrinsics'] 105 | 106 | inds = [ ix ] 107 | while len(inds) < self.n_frames: 108 | # get other frames within flow threshold 109 | k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax) 110 | frames = frame_graph[ix][0][k] 111 | 112 | # prefer frames forward in time 113 | if np.count_nonzero(frames[frames > ix]): 114 | ix = np.random.choice(frames[frames > ix]) 115 | 116 | elif np.count_nonzero(frames): 117 | ix = np.random.choice(frames) 118 | 119 | inds += [ ix ] 120 | 121 | images, depths, poses, intrinsics = [], [], [], [] 122 | for i in inds: 123 | images.append(self.__class__.image_read(images_list[i])) 124 | depths.append(self.__class__.depth_read(depths_list[i])) 125 | poses.append(poses_list[i]) 126 | intrinsics.append(intrinsics_list[i]) 127 | 128 | images = np.stack(images).astype(np.float32) 129 | depths = np.stack(depths).astype(np.float32) 130 | poses = np.stack(poses).astype(np.float32) 131 | intrinsics = np.stack(intrinsics).astype(np.float32) 132 | 133 | images = torch.from_numpy(images).float() 134 | images = images.permute(0, 3, 1, 2) 135 | 136 | disps = torch.from_numpy(1.0 / depths) 137 | poses = torch.from_numpy(poses) 138 | intrinsics = torch.from_numpy(intrinsics) 139 | 140 | if self.aug is not None: 141 | images, poses, disps, intrinsics = \ 142 | self.aug(images, poses, disps, intrinsics) 143 | 144 | # scale scene 145 | if len(disps[disps>0.01]) > 0: 146 | s = disps[disps>0.01].mean() 147 | disps = disps / s 148 | poses[...,:3] *= s 149 | 150 | return images, poses, disps, intrinsics 151 | 152 | def __len__(self): 153 | return len(self.dataset_index) 154 | 155 | def __imul__(self, x): 156 | self.dataset_index *= x 157 | return self 158 | -------------------------------------------------------------------------------- /droid_slam/data_readers/factory.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import os 4 | import os.path as osp 5 | 6 | # RGBD-Dataset 7 | from .tartan import TartanAir 8 | 9 | from .stream import ImageStream 10 | from .stream import StereoStream 11 | from .stream import RGBDStream 12 | 13 | # streaming datasets for inference 14 | from .tartan import TartanAirStream 15 | from .tartan import TartanAirTestStream 16 | 17 | def dataset_factory(dataset_list, **kwargs): 18 | """ create a combined dataset """ 19 | 20 | from torch.utils.data import ConcatDataset 21 | 22 | dataset_map = { 'tartan': (TartanAir, ) } 23 | db_list = [] 24 | for key in dataset_list: 25 | # cache datasets for faster future loading 26 | db = dataset_map[key][0](**kwargs) 27 | 28 | 
print("Dataset {} has {} images".format(key, len(db))) 29 | db_list.append(db) 30 | 31 | return ConcatDataset(db_list) 32 | 33 | 34 | def create_datastream(dataset_path, **kwargs): 35 | """ create data_loader to stream images 1 by 1 """ 36 | 37 | from torch.utils.data import DataLoader 38 | 39 | if osp.isfile(osp.join(dataset_path, 'calibration.txt')): 40 | db = ETH3DStream(dataset_path, **kwargs) 41 | 42 | elif osp.isdir(osp.join(dataset_path, 'image_left')): 43 | db = TartanAirStream(dataset_path, **kwargs) 44 | 45 | elif osp.isfile(osp.join(dataset_path, 'rgb.txt')): 46 | db = TUMStream(dataset_path, **kwargs) 47 | 48 | elif osp.isdir(osp.join(dataset_path, 'mav0')): 49 | db = EurocStream(dataset_path, **kwargs) 50 | 51 | elif osp.isfile(osp.join(dataset_path, 'calib.txt')): 52 | db = KITTIStream(dataset_path, **kwargs) 53 | 54 | else: 55 | # db = TartanAirStream(dataset_path, **kwargs) 56 | db = TartanAirTestStream(dataset_path, **kwargs) 57 | 58 | stream = DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 59 | return stream 60 | 61 | 62 | def create_imagestream(dataset_path, **kwargs): 63 | """ create data_loader to stream images 1 by 1 """ 64 | from torch.utils.data import DataLoader 65 | 66 | db = ImageStream(dataset_path, **kwargs) 67 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 68 | 69 | def create_stereostream(dataset_path, **kwargs): 70 | """ create data_loader to stream images 1 by 1 """ 71 | from torch.utils.data import DataLoader 72 | 73 | db = StereoStream(dataset_path, **kwargs) 74 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 75 | 76 | def create_rgbdstream(dataset_path, **kwargs): 77 | """ create data_loader to stream images 1 by 1 """ 78 | from torch.utils.data import DataLoader 79 | 80 | db = RGBDStream(dataset_path, **kwargs) 81 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 82 | 83 | -------------------------------------------------------------------------------- /droid_slam/data_readers/rgbd_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | import torch 5 | from lietorch import SE3 6 | 7 | import geom.projective_ops as pops 8 | from scipy.spatial.transform import Rotation 9 | 10 | 11 | def parse_list(filepath, skiprows=0): 12 | """ read list data """ 13 | data = np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows) 14 | return data 15 | 16 | def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=1.0): 17 | """ pair images, depths, and poses """ 18 | associations = [] 19 | for i, t in enumerate(tstamp_image): 20 | if tstamp_pose is None: 21 | j = np.argmin(np.abs(tstamp_depth - t)) 22 | if (np.abs(tstamp_depth[j] - t) < max_dt): 23 | associations.append((i, j)) 24 | 25 | else: 26 | j = np.argmin(np.abs(tstamp_depth - t)) 27 | k = np.argmin(np.abs(tstamp_pose - t)) 28 | 29 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \ 30 | (np.abs(tstamp_pose[k] - t) < max_dt): 31 | associations.append((i, j, k)) 32 | 33 | return associations 34 | 35 | def loadtum(datapath, frame_rate=-1): 36 | """ read video data in tum-rgbd format """ 37 | if osp.isfile(osp.join(datapath, 'groundtruth.txt')): 38 | pose_list = osp.join(datapath, 'groundtruth.txt') 39 | 40 | elif osp.isfile(osp.join(datapath, 'pose.txt')): 41 | pose_list = osp.join(datapath, 'pose.txt') 42 | 43 | else: 44 | return None, None, None, None 45 | 46 | image_list = osp.join(datapath, 'rgb.txt') 47 | depth_list = 
osp.join(datapath, 'depth.txt') 48 | 49 | calib_path = osp.join(datapath, 'calibration.txt') 50 | intrinsic = None 51 | if osp.isfile(calib_path): 52 | intrinsic = np.loadtxt(calib_path, delimiter=' ') 53 | intrinsic = intrinsic.astype(np.float64) 54 | 55 | image_data = parse_list(image_list) 56 | depth_data = parse_list(depth_list) 57 | pose_data = parse_list(pose_list, skiprows=1) 58 | pose_vecs = pose_data[:,1:].astype(np.float64) 59 | 60 | tstamp_image = image_data[:,0].astype(np.float64) 61 | tstamp_depth = depth_data[:,0].astype(np.float64) 62 | tstamp_pose = pose_data[:,0].astype(np.float64) 63 | associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose) 64 | 65 | # print(len(tstamp_image)) 66 | # print(len(associations)) 67 | 68 | indicies = range(len(associations))[::5] 69 | 70 | # indicies = [ 0 ] 71 | # for i in range(1, len(associations)): 72 | # t0 = tstamp_image[associations[indicies[-1]][0]] 73 | # t1 = tstamp_image[associations[i][0]] 74 | # if t1 - t0 > 1.0 / frame_rate: 75 | # indicies += [ i ] 76 | 77 | images, poses, depths, intrinsics, tstamps = [], [], [], [], [] 78 | for ix in indicies: 79 | (i, j, k) = associations[ix] 80 | images += [ osp.join(datapath, image_data[i,1]) ] 81 | depths += [ osp.join(datapath, depth_data[j,1]) ] 82 | poses += [ pose_vecs[k] ] 83 | tstamps += [ tstamp_image[i] ] 84 | 85 | if intrinsic is not None: 86 | intrinsics += [ intrinsic ] 87 | 88 | return images, depths, poses, intrinsics, tstamps 89 | 90 | 91 | def all_pairs_distance_matrix(poses, beta=2.5): 92 | """ compute distance matrix between all pairs of poses """ 93 | poses = np.array(poses, dtype=np.float32) 94 | poses[:,:3] *= beta # scale to balence rot + trans 95 | poses = SE3(torch.from_numpy(poses)) 96 | 97 | r = (poses[:,None].inv() * poses[None,:]).log() 98 | return r.norm(dim=-1).cpu().numpy() 99 | 100 | def pose_matrix_to_quaternion(pose): 101 | """ convert 4x4 pose matrix to (t, q) """ 102 | q = Rotation.from_matrix(pose[:3, :3]).as_quat() 103 | return np.concatenate([pose[:3, 3], q], axis=0) 104 | 105 | def compute_distance_matrix_flow(poses, disps, intrinsics): 106 | """ compute flow magnitude between all pairs of frames """ 107 | if not isinstance(poses, SE3): 108 | poses = torch.from_numpy(poses).float().cuda()[None] 109 | poses = SE3(poses).inv() 110 | 111 | disps = torch.from_numpy(disps).float().cuda()[None] 112 | intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] 113 | 114 | N = poses.shape[1] 115 | 116 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) 117 | ii = ii.reshape(-1).cuda() 118 | jj = jj.reshape(-1).cuda() 119 | 120 | MAX_FLOW = 100.0 121 | matrix = np.zeros((N, N), dtype=np.float32) 122 | 123 | s = 2048 124 | for i in range(0, ii.shape[0], s): 125 | flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 126 | flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s]) 127 | 128 | flow = torch.stack([flow1, flow2], dim=2) 129 | val = torch.stack([val1, val2], dim=2) 130 | 131 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) 132 | mag = mag.view(mag.shape[1], -1) 133 | val = val.view(val.shape[1], -1) 134 | 135 | mag = (mag * val).mean(-1) / val.mean(-1) 136 | mag[val.mean(-1) < 0.7] = np.inf 137 | 138 | i1 = ii[i:i+s].cpu().numpy() 139 | j1 = jj[i:i+s].cpu().numpy() 140 | matrix[i1, j1] = mag.cpu().numpy() 141 | 142 | return matrix 143 | 144 | 145 | def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4): 146 | """ compute flow magnitude between all pairs of frames 
""" 147 | # if not isinstance(poses, SE3): 148 | # poses = torch.from_numpy(poses).float().cuda()[None] 149 | # poses = SE3(poses).inv() 150 | 151 | # disps = torch.from_numpy(disps).float().cuda()[None] 152 | # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] 153 | 154 | N = poses.shape[1] 155 | 156 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) 157 | ii = ii.reshape(-1) 158 | jj = jj.reshape(-1) 159 | 160 | MAX_FLOW = 128.0 161 | matrix = np.zeros((N, N), dtype=np.float32) 162 | 163 | s = 2048 164 | for i in range(0, ii.shape[0], s): 165 | flow1a, val1a = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True) 166 | flow1b, val1b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 167 | flow2a, val2a = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True) 168 | flow2b, val2b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) 169 | 170 | flow1 = flow1a + beta * flow1b 171 | val1 = val1a * val2b 172 | 173 | flow2 = flow2a + beta * flow2b 174 | val2 = val2a * val2b 175 | 176 | flow = torch.stack([flow1, flow2], dim=2) 177 | val = torch.stack([val1, val2], dim=2) 178 | 179 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) 180 | mag = mag.view(mag.shape[1], -1) 181 | val = val.view(val.shape[1], -1) 182 | 183 | mag = (mag * val).mean(-1) / val.mean(-1) 184 | mag[val.mean(-1) < 0.8] = np.inf 185 | 186 | i1 = ii[i:i+s].cpu().numpy() 187 | j1 = jj[i:i+s].cpu().numpy() 188 | matrix[i1, j1] = mag.cpu().numpy() 189 | 190 | return matrix 191 | -------------------------------------------------------------------------------- /droid_slam/data_readers/tartan.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import glob 5 | import cv2 6 | import os 7 | import os.path as osp 8 | 9 | from lietorch import SE3 10 | from .base import RGBDDataset 11 | from .stream import RGBDStream 12 | 13 | cur_path = osp.dirname(osp.abspath(__file__)) 14 | test_split = osp.join(cur_path, 'tartan_test.txt') 15 | test_split = open(test_split).read().split() 16 | 17 | 18 | class TartanAir(RGBDDataset): 19 | 20 | # scale depths to balance rot & trans 21 | DEPTH_SCALE = 5.0 22 | 23 | def __init__(self, mode='training', **kwargs): 24 | self.mode = mode 25 | self.n_frames = 2 26 | super(TartanAir, self).__init__(name='TartanAir', **kwargs) 27 | 28 | @staticmethod 29 | def is_test_scene(scene): 30 | # print(scene, any(x in scene for x in test_split)) 31 | return any(x in scene for x in test_split) 32 | 33 | def _build_dataset(self): 34 | from tqdm import tqdm 35 | print("Building TartanAir dataset") 36 | 37 | scene_info = {} 38 | scenes = glob.glob(osp.join(self.root, '*/*/*/*')) 39 | for scene in tqdm(sorted(scenes)): 40 | images = sorted(glob.glob(osp.join(scene, 'image_left/*.png'))) 41 | depths = sorted(glob.glob(osp.join(scene, 'depth_left/*.npy'))) 42 | 43 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ') 44 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 45 | poses[:,:3] /= TartanAir.DEPTH_SCALE 46 | intrinsics = [TartanAir.calib_read()] * len(images) 47 | 48 | # graph of co-visible frames based on flow 49 | graph = self.build_frame_graph(poses, depths, intrinsics) 50 | 51 | scene = '/'.join(scene.split('/')) 52 | scene_info[scene] = {'images': images, 'depths': depths, 53 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph} 54 | 55 | return scene_info 56 | 57 | @staticmethod 58 | def calib_read(): 59 | return 
np.array([320.0, 320.0, 320.0, 240.0]) 60 | 61 | @staticmethod 62 | def image_read(image_file): 63 | return cv2.imread(image_file) 64 | 65 | @staticmethod 66 | def depth_read(depth_file): 67 | depth = np.load(depth_file) / TartanAir.DEPTH_SCALE 68 | depth[depth==np.nan] = 1.0 69 | depth[depth==np.inf] = 1.0 70 | return depth 71 | 72 | 73 | class TartanAirStream(RGBDStream): 74 | def __init__(self, datapath, **kwargs): 75 | super(TartanAirStream, self).__init__(datapath=datapath, **kwargs) 76 | 77 | def _build_dataset_index(self): 78 | """ build list of images, poses, depths, and intrinsics """ 79 | self.root = 'datasets/TartanAir' 80 | 81 | scene = osp.join(self.root, self.datapath) 82 | image_glob = osp.join(scene, 'image_left/*.png') 83 | images = sorted(glob.glob(image_glob)) 84 | 85 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ') 86 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 87 | 88 | poses = SE3(torch.as_tensor(poses)) 89 | poses = poses[[0]].inv() * poses 90 | poses = poses.data.cpu().numpy() 91 | 92 | intrinsic = self.calib_read(self.datapath) 93 | intrinsics = np.tile(intrinsic[None], (len(images), 1)) 94 | 95 | self.images = images[::int(self.frame_rate)] 96 | self.poses = poses[::int(self.frame_rate)] 97 | self.intrinsics = intrinsics[::int(self.frame_rate)] 98 | 99 | @staticmethod 100 | def calib_read(datapath): 101 | return np.array([320.0, 320.0, 320.0, 240.0]) 102 | 103 | @staticmethod 104 | def image_read(image_file): 105 | return cv2.imread(image_file) 106 | 107 | 108 | class TartanAirTestStream(RGBDStream): 109 | def __init__(self, datapath, **kwargs): 110 | super(TartanAirTestStream, self).__init__(datapath=datapath, **kwargs) 111 | 112 | def _build_dataset_index(self): 113 | """ build list of images, poses, depths, and intrinsics """ 114 | self.root = 'datasets/mono' 115 | image_glob = osp.join(self.root, self.datapath, '*.png') 116 | images = sorted(glob.glob(image_glob)) 117 | 118 | poses = np.loadtxt(osp.join(self.root, 'mono_gt', self.datapath + '.txt'), delimiter=' ') 119 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 120 | 121 | poses = SE3(torch.as_tensor(poses)) 122 | poses = poses[[0]].inv() * poses 123 | poses = poses.data.cpu().numpy() 124 | 125 | intrinsic = self.calib_read(self.datapath) 126 | intrinsics = np.tile(intrinsic[None], (len(images), 1)) 127 | 128 | self.images = images[::int(self.frame_rate)] 129 | self.poses = poses[::int(self.frame_rate)] 130 | self.intrinsics = intrinsics[::int(self.frame_rate)] 131 | 132 | @staticmethod 133 | def calib_read(datapath): 134 | return np.array([320.0, 320.0, 320.0, 240.0]) 135 | 136 | @staticmethod 137 | def image_read(image_file): 138 | return cv2.imread(image_file) -------------------------------------------------------------------------------- /droid_slam/data_readers/tartan_test.txt: -------------------------------------------------------------------------------- 1 | abandonedfactory/abandonedfactory/Easy/P011 2 | abandonedfactory/abandonedfactory/Hard/P011 3 | abandonedfactory_night/abandonedfactory_night/Easy/P013 4 | abandonedfactory_night/abandonedfactory_night/Hard/P014 5 | amusement/amusement/Easy/P008 6 | amusement/amusement/Hard/P007 7 | carwelding/carwelding/Easy/P007 8 | endofworld/endofworld/Easy/P009 9 | gascola/gascola/Easy/P008 10 | gascola/gascola/Hard/P009 11 | hospital/hospital/Easy/P036 12 | hospital/hospital/Hard/P049 13 | japanesealley/japanesealley/Easy/P007 14 | japanesealley/japanesealley/Hard/P005 15 | neighborhood/neighborhood/Easy/P021 16 | 
neighborhood/neighborhood/Hard/P017 17 | ocean/ocean/Easy/P013 18 | ocean/ocean/Hard/P009 19 | office2/office2/Easy/P011 20 | office2/office2/Hard/P010 21 | office/office/Hard/P007 22 | oldtown/oldtown/Easy/P007 23 | oldtown/oldtown/Hard/P008 24 | seasidetown/seasidetown/Easy/P009 25 | seasonsforest/seasonsforest/Easy/P011 26 | seasonsforest/seasonsforest/Hard/P006 27 | seasonsforest_winter/seasonsforest_winter/Easy/P009 28 | seasonsforest_winter/seasonsforest_winter/Hard/P018 29 | soulcity/soulcity/Easy/P012 30 | soulcity/soulcity/Hard/P009 31 | westerndesert/westerndesert/Easy/P013 32 | westerndesert/westerndesert/Hard/P007 33 | -------------------------------------------------------------------------------- /droid_slam/droid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from droid_net import DroidNet 6 | from depth_video import DepthVideo 7 | from motion_filter import MotionFilter 8 | from droid_frontend import DroidFrontend 9 | from droid_backend import DroidBackend 10 | from trajectory_filler import PoseTrajectoryFiller 11 | 12 | from collections import OrderedDict 13 | from torch.multiprocessing import Process 14 | 15 | 16 | class Droid: 17 | def __init__(self, args): 18 | super(Droid, self).__init__() 19 | self.load_weights(args.weights) 20 | self.args = args 21 | self.disable_vis = args.disable_vis 22 | 23 | # store images, depth, poses, intrinsics (shared between processes) 24 | self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo) 25 | 26 | # filter incoming frames so that there is enough motion 27 | self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh) 28 | 29 | # frontend process 30 | self.frontend = DroidFrontend(self.net, self.video, self.args) 31 | 32 | # backend process 33 | self.backend = DroidBackend(self.net, self.video, self.args) 34 | 35 | # visualizer 36 | if not self.disable_vis: 37 | from visualizer.droid_visualizer import visualization_fn 38 | self.visualizer = Process(target=visualization_fn, args=(self.video, None)) 39 | self.visualizer.start() 40 | 41 | # post processor - fill in poses for non-keyframes 42 | self.traj_filler = PoseTrajectoryFiller(self.net, self.video) 43 | 44 | 45 | def load_weights(self, weights): 46 | """ load trained model weights """ 47 | 48 | print(weights) 49 | self.net = DroidNet() 50 | state_dict = OrderedDict([ 51 | (k.replace("module.", ""), v) for (k, v) in torch.load(weights).items()]) 52 | 53 | state_dict["update.weight.2.weight"] = state_dict["update.weight.2.weight"][:2] 54 | state_dict["update.weight.2.bias"] = state_dict["update.weight.2.bias"][:2] 55 | state_dict["update.delta.2.weight"] = state_dict["update.delta.2.weight"][:2] 56 | state_dict["update.delta.2.bias"] = state_dict["update.delta.2.bias"][:2] 57 | 58 | self.net.load_state_dict(state_dict) 59 | self.net.to("cuda:0").eval() 60 | 61 | def track(self, tstamp, image, depth=None, intrinsics=None): 62 | """ main thread - update map """ 63 | 64 | with torch.no_grad(): 65 | # check there is enough motion 66 | self.filterx.track(tstamp, image, depth, intrinsics) 67 | 68 | # local bundle adjustment 69 | self.frontend() 70 | 71 | def terminate(self, stream=None): 72 | """ terminate the visualization process, return poses [t, q] """ 73 | 74 | del self.frontend 75 | 76 | torch.cuda.empty_cache() 77 | print("#" * 32) 78 | self.backend(7) 79 | 80 | torch.cuda.empty_cache() 81 | print("#" * 32) 82 | self.backend(12) 83 | 84 | 
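# interpolate poses for the non-keyframes by passing the full image stream through the trajectory filler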
camera_trajectory = self.traj_filler(stream) 85 | return camera_trajectory.inv().data.cpu().numpy() 86 | 87 | -------------------------------------------------------------------------------- /droid_slam/droid_backend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from lietorch import SE3 6 | from factor_graph import FactorGraph 7 | 8 | 9 | class DroidBackend: 10 | def __init__(self, net, video, args): 11 | self.video = video 12 | self.update_op = net.update 13 | 14 | # global optimization window 15 | self.t0 = 0 16 | self.t1 = 0 17 | 18 | self.upsample = args.upsample 19 | self.beta = args.beta 20 | self.backend_thresh = args.backend_thresh 21 | self.backend_radius = args.backend_radius 22 | self.backend_nms = args.backend_nms 23 | 24 | @torch.no_grad() 25 | def __call__(self, steps=12, normalize=True): 26 | """ main update """ 27 | 28 | t = self.video.counter.value 29 | if normalize: 30 | if not self.video.stereo and not torch.any(self.video.disps_sens): 31 | self.video.normalize() 32 | 33 | graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t, upsample=self.upsample) 34 | 35 | graph.add_proximity_factors(rad=self.backend_radius, 36 | nms=self.backend_nms, 37 | thresh=self.backend_thresh, 38 | beta=self.beta) 39 | 40 | graph.update_lowmem(steps=steps) 41 | graph.clear_edges() 42 | self.video.dirty[:t] = True 43 | 44 | 45 | class DroidAsyncBackend: 46 | def __init__(self, net, video, args, max_age = 7): 47 | self.video = video 48 | self.update_op = net.update 49 | self.max_age = max_age 50 | 51 | # global optimization window 52 | self.t0 = 0 53 | self.t1 = 0 54 | 55 | self.upsample = args.upsample 56 | self.beta = args.beta 57 | self.backend_thresh = args.backend_thresh 58 | self.backend_radius = args.backend_radius 59 | self.backend_nms = args.backend_nms 60 | 61 | self.graph = FactorGraph( 62 | self.video, 63 | self.update_op, 64 | corr_impl="alt", 65 | max_factors=-1, 66 | upsample=self.upsample, 67 | ) 68 | 69 | @torch.no_grad() 70 | def __call__(self, steps=12, normalize=True): 71 | """main update""" 72 | 73 | t = self.video.counter.value 74 | if normalize: 75 | if not self.video.stereo and not torch.any(self.video.disps_sens): 76 | self.video.normalize() 77 | 78 | self.graph.add_proximity_factors( 79 | rad=self.backend_radius, 80 | nms=self.backend_nms, 81 | thresh=self.backend_thresh, 82 | beta=self.beta, 83 | ) 84 | 85 | self.graph.update_lowmem(steps=steps, use_inactive=True) 86 | self.graph.rm_factors(self.graph.age > self.max_age, store=True) 87 | 88 | self.video.dirty[:t] = True 89 | -------------------------------------------------------------------------------- /droid_slam/droid_frontend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from lietorch import SE3 6 | from factor_graph import FactorGraph 7 | 8 | from cuda_timer import CudaTimer 9 | 10 | 11 | ENABLE_TIMING = False 12 | 13 | class DroidFrontend: 14 | def __init__(self, net, video, args): 15 | self.video = video 16 | self.update_op = net.update 17 | self.graph = FactorGraph( 18 | video, net.update, max_factors=48, upsample=args.upsample 19 | ) 20 | 21 | # local optimization window 22 | self.t0 = 0 23 | self.t1 = 0 24 | 25 | # frontent variables 26 | self.is_initialized = False 27 | self.count = 0 28 | 29 | self.max_age = 20 30 | self.iters1 = 3 31 | self.iters2 = 2 32 | 33 | 
self.keyframe_removal_index = 3 34 | 35 | self.warmup = args.warmup 36 | self.beta = args.beta 37 | self.frontend_nms = args.frontend_nms 38 | self.keyframe_thresh = args.keyframe_thresh 39 | self.frontend_window = args.frontend_window 40 | self.frontend_thresh = args.frontend_thresh 41 | self.frontend_radius = args.frontend_radius 42 | 43 | self.depth_window = 3 44 | 45 | self.motion_damping = 0.0 46 | if hasattr(args, "motion_damping"): 47 | self.motion_damping = args.motion_damping 48 | 49 | def _init_next_state(self): 50 | # set pose / depth for next iteration 51 | self.video.poses[self.t1] = self.video.poses[self.t1 - 1] 52 | 53 | self.video.disps[self.t1] = torch.quantile( 54 | self.video.disps[self.t1 - 3 : self.t1 - 1], 0.5 55 | ) 56 | 57 | # damped linear velocity model 58 | if self.motion_damping >= 0: 59 | poses = SE3(self.video.poses) 60 | vel = (poses[self.t1 - 1] * poses[self.t1 - 2].inv()).log() 61 | damped_vel = self.motion_damping * vel 62 | next_pose = SE3.exp(damped_vel) * poses[self.t1 - 1] 63 | self.video.poses[self.t1] = next_pose.data 64 | 65 | def _update(self): 66 | """add edges, perform update""" 67 | 68 | self.count += 1 69 | self.t1 += 1 70 | 71 | if self.graph.corr is not None: 72 | self.graph.rm_factors(self.graph.age > self.max_age, store=True) 73 | 74 | self.graph.add_proximity_factors( 75 | self.t1 - 5, 76 | max(self.t1 - self.frontend_window, 0), 77 | rad=self.frontend_radius, 78 | nms=self.frontend_nms, 79 | thresh=self.frontend_thresh, 80 | beta=self.beta, 81 | remove=True, 82 | ) 83 | 84 | self.video.disps[self.t1 - 1] = torch.where( 85 | self.video.disps_sens[self.t1 - 1] > 0, 86 | self.video.disps_sens[self.t1 - 1], 87 | self.video.disps[self.t1 - 1], 88 | ) 89 | 90 | for itr in range(self.iters1): 91 | self.graph.update(None, None, use_inactive=True) 92 | 93 | # set initial pose for next frame 94 | d = self.video.distance( 95 | [self.t1 - 4], [self.t1 - 2], beta=self.beta, bidirectional=True 96 | ) 97 | 98 | if d.item() < 2 * self.keyframe_thresh: 99 | self.graph.rm_keyframe(self.t1 - 3) 100 | 101 | with self.video.get_lock(): 102 | self.video.counter.value -= 1 103 | self.t1 -= 1 104 | 105 | else: 106 | for itr in range(self.iters2): 107 | self.graph.update(None, None, use_inactive=True) 108 | 109 | 110 | # set pose for next itration 111 | self.video.poses[self.t1] = self.video.poses[self.t1 - 1] 112 | self.video.disps[self.t1] = torch.quantile( 113 | self.video.disps[self.t1 - self.depth_window - 1 : self.t1 - 1], 0.7 114 | ) 115 | 116 | # update visualization 117 | self.video.dirty[self.graph.ii.min() : self.t1] = True 118 | 119 | def _initialize(self): 120 | """initialize the SLAM system""" 121 | 122 | self.t0 = 0 123 | self.t1 = self.video.counter.value 124 | 125 | self.graph.add_neighborhood_factors(self.t0, self.t1, r=3) 126 | 127 | for itr in range(8): 128 | self.graph.update(1, use_inactive=True) 129 | 130 | self.graph.add_proximity_factors( 131 | 0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False 132 | ) 133 | 134 | for itr in range(8): 135 | self.graph.update(1, use_inactive=True) 136 | 137 | # self.video.normalize() 138 | self.video.poses[self.t1] = self.video.poses[self.t1 - 1].clone() 139 | self.video.disps[self.t1] = self.video.disps[self.t1 - 4 : self.t1].mean() 140 | 141 | # initialization complete 142 | self.is_initialized = True 143 | self.last_pose = self.video.poses[self.t1 - 1].clone() 144 | self.last_disp = self.video.disps[self.t1 - 1].clone() 145 | self.last_time = self.video.tstamp[self.t1 - 1].clone() 146 | 147 | 
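# mark initialization as complete for the other processes and flag the existing keyframes for a visualization update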
with self.video.get_lock(): 148 | self.video.ready.value = 1 149 | self.video.dirty[: self.t1] = True 150 | 151 | self.graph.rm_factors(self.graph.ii < self.warmup - 4, store=True) 152 | 153 | def __call__(self): 154 | """main update""" 155 | 156 | # do initialization 157 | if not self.is_initialized and self.video.counter.value == self.warmup: 158 | self._initialize() 159 | self._init_next_state() 160 | 161 | # do update 162 | elif self.is_initialized and self.t1 < self.video.counter.value: 163 | self._update() 164 | self._init_next_state() 165 | -------------------------------------------------------------------------------- /droid_slam/droid_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | 7 | from modules.extractor import BasicEncoder 8 | from modules.corr import CorrBlock 9 | from modules.gru import ConvGRU 10 | from modules.clipping import GradientClip 11 | 12 | from lietorch import SE3 13 | from geom.ba import BA 14 | 15 | import geom.projective_ops as pops 16 | from geom.graph_utils import graph_to_edge_list, keyframe_indicies 17 | 18 | from torch_scatter import scatter_mean 19 | 20 | 21 | def cvx_upsample(data, mask): 22 | """ upsample pixel-wise transformation field """ 23 | batch, ht, wd, dim = data.shape 24 | data = data.permute(0, 3, 1, 2) 25 | mask = mask.view(batch, 1, 9, 8, 8, ht, wd) 26 | mask = torch.softmax(mask, dim=2) 27 | 28 | up_data = F.unfold(data, [3,3], padding=1) 29 | up_data = up_data.view(batch, dim, 9, 1, 1, ht, wd) 30 | 31 | up_data = torch.sum(mask * up_data, dim=2) 32 | up_data = up_data.permute(0, 4, 2, 5, 3, 1) 33 | up_data = up_data.reshape(batch, 8*ht, 8*wd, dim) 34 | 35 | return up_data 36 | 37 | def upsample_disp(disp, mask): 38 | batch, num, ht, wd = disp.shape 39 | disp = disp.view(batch*num, ht, wd, 1) 40 | mask = mask.view(batch*num, -1, ht, wd) 41 | return cvx_upsample(disp, mask).view(batch, num, 8*ht, 8*wd) 42 | 43 | 44 | class GraphAgg(nn.Module): 45 | def __init__(self): 46 | super(GraphAgg, self).__init__() 47 | self.conv1 = nn.Conv2d(128, 128, 3, padding=1) 48 | self.conv2 = nn.Conv2d(128, 128, 3, padding=1) 49 | self.relu = nn.ReLU(inplace=True) 50 | 51 | self.eta = nn.Sequential( 52 | nn.Conv2d(128, 1, 3, padding=1), 53 | GradientClip(), 54 | nn.Softplus()) 55 | 56 | self.upmask = nn.Sequential( 57 | nn.Conv2d(128, 8*8*9, 1, padding=0)) 58 | 59 | def forward(self, net, ii): 60 | batch, num, ch, ht, wd = net.shape 61 | net = net.view(batch*num, ch, ht, wd) 62 | 63 | _, ix = torch.unique(ii, return_inverse=True) 64 | net = self.relu(self.conv1(net)) 65 | 66 | net = net.view(batch, num, 128, ht, wd) 67 | net = scatter_mean(net, ix, dim=1) 68 | net = net.view(-1, 128, ht, wd) 69 | 70 | net = self.relu(self.conv2(net)) 71 | 72 | eta = self.eta(net).view(batch, -1, ht, wd) 73 | upmask = self.upmask(net).view(batch, -1, 8*8*9, ht, wd) 74 | 75 | return .01 * eta, upmask 76 | 77 | 78 | class UpdateModule(nn.Module): 79 | def __init__(self): 80 | super(UpdateModule, self).__init__() 81 | cor_planes = 4 * (2*3 + 1)**2 82 | 83 | self.corr_encoder = nn.Sequential( 84 | nn.Conv2d(cor_planes, 128, 1, padding=0), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(128, 128, 3, padding=1), 87 | nn.ReLU(inplace=True)) 88 | 89 | self.flow_encoder = nn.Sequential( 90 | nn.Conv2d(4, 128, 7, padding=3), 91 | nn.ReLU(inplace=True), 92 | nn.Conv2d(128, 64, 3, padding=1), 93 | 
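# 64-channel flow features; together with the 128-d context and 128-d correlation features they form the 128+128+64 GRU input below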
nn.ReLU(inplace=True)) 94 | 95 | self.weight = nn.Sequential( 96 | nn.Conv2d(128, 128, 3, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(128, 2, 3, padding=1), 99 | GradientClip(), 100 | nn.Sigmoid()) 101 | 102 | self.delta = nn.Sequential( 103 | nn.Conv2d(128, 128, 3, padding=1), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(128, 2, 3, padding=1), 106 | GradientClip()) 107 | 108 | self.gru = ConvGRU(128, 128+128+64) 109 | self.agg = GraphAgg() 110 | 111 | def forward(self, net, inp, corr, flow=None, ii=None, jj=None): 112 | """ RaftSLAM update operator """ 113 | 114 | batch, num, ch, ht, wd = net.shape 115 | 116 | if flow is None: 117 | flow = torch.zeros(batch, num, 4, ht, wd, device=net.device) 118 | 119 | output_dim = (batch, num, -1, ht, wd) 120 | net = net.view(batch*num, -1, ht, wd) 121 | inp = inp.view(batch*num, -1, ht, wd) 122 | corr = corr.view(batch*num, -1, ht, wd) 123 | flow = flow.view(batch*num, -1, ht, wd) 124 | 125 | corr = self.corr_encoder(corr) 126 | flow = self.flow_encoder(flow) 127 | net = self.gru(net, inp, corr, flow) 128 | 129 | ### update variables ### 130 | delta = self.delta(net).view(*output_dim) 131 | weight = self.weight(net).view(*output_dim) 132 | 133 | delta = delta.permute(0,1,3,4,2)[...,:2].contiguous() 134 | weight = weight.permute(0,1,3,4,2)[...,:2].contiguous() 135 | 136 | net = net.view(*output_dim) 137 | 138 | if ii is not None: 139 | eta, upmask = self.agg(net, ii.to(net.device)) 140 | return net, delta, weight, eta, upmask 141 | 142 | else: 143 | return net, delta, weight 144 | 145 | 146 | class DroidNet(nn.Module): 147 | def __init__(self): 148 | super(DroidNet, self).__init__() 149 | self.fnet = BasicEncoder(output_dim=128, norm_fn='instance') 150 | self.cnet = BasicEncoder(output_dim=256, norm_fn='none') 151 | self.update = UpdateModule() 152 | 153 | 154 | def extract_features(self, images): 155 | """ run feeature extraction networks """ 156 | 157 | # normalize images 158 | images = images[:, :, [2,1,0]] / 255.0 159 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device) 160 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device) 161 | images = images.sub_(mean[:, None, None]).div_(std[:, None, None]) 162 | 163 | fmaps = self.fnet(images) 164 | net = self.cnet(images) 165 | 166 | net, inp = net.split([128,128], dim=2) 167 | net = torch.tanh(net) 168 | inp = torch.relu(inp) 169 | return fmaps, net, inp 170 | 171 | 172 | def forward(self, Gs, images, disps, intrinsics, graph=None, num_steps=12, fixedp=2): 173 | """ Estimates SE3 or Sim3 between pair of frames """ 174 | 175 | u = keyframe_indicies(graph) 176 | ii, jj, kk = graph_to_edge_list(graph) 177 | 178 | ii = ii.to(device=images.device, dtype=torch.long) 179 | jj = jj.to(device=images.device, dtype=torch.long) 180 | 181 | fmaps, net, inp = self.extract_features(images) 182 | net, inp = net[:,ii], inp[:,ii] 183 | corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3) 184 | 185 | ht, wd = images.shape[-2:] 186 | coords0 = pops.coords_grid(ht//8, wd//8, device=images.device) 187 | 188 | coords1, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj) 189 | target = coords1.clone() 190 | 191 | Gs_list, disp_list, residual_list = [], [], [] 192 | for step in range(num_steps): 193 | Gs = Gs.detach() 194 | disps = disps.detach() 195 | coords1 = coords1.detach() 196 | target = target.detach() 197 | 198 | # extract motion features 199 | corr = corr_fn(coords1) 200 | resd = target - coords1 201 | flow = coords1 - coords0 202 | 203 | motion = 
torch.cat([flow, resd], dim=-1) 204 | motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0) 205 | 206 | net, delta, weight, eta, upmask = \ 207 | self.update(net, inp, corr, motion, ii, jj) 208 | 209 | target = coords1 + delta 210 | 211 | for i in range(2): 212 | Gs, disps = BA(target, weight, eta, Gs, disps, intrinsics, ii, jj, fixedp=2) 213 | 214 | coords1, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj) 215 | residual = (target - coords1) 216 | 217 | Gs_list.append(Gs) 218 | disp_list.append(upsample_disp(disps, upmask)) 219 | residual_list.append(valid_mask * residual) 220 | 221 | 222 | return Gs_list, disp_list, residual_list 223 | -------------------------------------------------------------------------------- /droid_slam/geom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/droid_slam/geom/__init__.py -------------------------------------------------------------------------------- /droid_slam/geom/ba.py: -------------------------------------------------------------------------------- 1 | import lietorch 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from .chol import block_solve, schur_solve 6 | import geom.projective_ops as pops 7 | 8 | from torch_scatter import scatter_sum 9 | 10 | 11 | # utility functions for scattering ops 12 | def safe_scatter_add_mat(A, ii, jj, n, m): 13 | v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m) 14 | return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m) 15 | 16 | def safe_scatter_add_vec(b, ii, n): 17 | v = (ii >= 0) & (ii < n) 18 | return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n) 19 | 20 | # apply retraction operator to inv-depth maps 21 | def disp_retr(disps, dz, ii): 22 | ii = ii.to(device=dz.device) 23 | return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1]) 24 | 25 | # apply retraction operator to poses 26 | def pose_retr(poses, dx, ii): 27 | ii = ii.to(device=dx.device) 28 | return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1])) 29 | 30 | 31 | def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 32 | """ Full Bundle Adjustment """ 33 | 34 | B, P, ht, wd = disps.shape 35 | N = ii.shape[0] 36 | D = poses.manifold_dim 37 | 38 | ### 1: commpute jacobians and residuals ### 39 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 40 | poses, disps, intrinsics, ii, jj, jacobian=True) 41 | 42 | r = (target - coords).view(B, N, -1, 1) 43 | w = .001 * (valid * weight).view(B, N, -1, 1) 44 | 45 | ### 2: construct linear system ### 46 | Ji = Ji.reshape(B, N, -1, D) 47 | Jj = Jj.reshape(B, N, -1, D) 48 | wJiT = (w * Ji).transpose(2,3) 49 | wJjT = (w * Jj).transpose(2,3) 50 | 51 | Jz = Jz.reshape(B, N, ht*wd, -1) 52 | 53 | Hii = torch.matmul(wJiT, Ji) 54 | Hij = torch.matmul(wJiT, Jj) 55 | Hji = torch.matmul(wJjT, Ji) 56 | Hjj = torch.matmul(wJjT, Jj) 57 | 58 | vi = torch.matmul(wJiT, r).squeeze(-1) 59 | vj = torch.matmul(wJjT, r).squeeze(-1) 60 | 61 | Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 62 | Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 63 | 64 | w = w.view(B, N, ht*wd, -1) 65 | r = r.view(B, N, ht*wd, -1) 66 | wk = torch.sum(w*r*Jz, dim=-1) 67 | Ck = torch.sum(w*Jz*Jz, dim=-1) 68 | 69 | kx, kk = torch.unique(ii, return_inverse=True) 70 | M = kx.shape[0] 71 | 72 | # only optimize keyframe poses 73 | P = P // rig - fixedp 74 | ii = ii // rig - fixedp 75 | jj = jj // rig - 
fixedp 76 | 77 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 78 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 79 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 80 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 81 | 82 | E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \ 83 | safe_scatter_add_mat(Ej, jj, kk, P, M) 84 | 85 | v = safe_scatter_add_vec(vi, ii, P) + \ 86 | safe_scatter_add_vec(vj, jj, P) 87 | 88 | C = safe_scatter_add_vec(Ck, kk, M) 89 | w = safe_scatter_add_vec(wk, kk, M) 90 | 91 | C = C + eta.view(*C.shape) + 1e-7 92 | 93 | H = H.view(B, P, P, D, D) 94 | E = E.view(B, P, M, D, ht*wd) 95 | 96 | ### 3: solve the system ### 97 | dx, dz = schur_solve(H, E, C, v, w) 98 | 99 | ### 4: apply retraction ### 100 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 101 | disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx) 102 | 103 | disps = torch.where(disps > 10, torch.zeros_like(disps), disps) 104 | disps = disps.clamp(min=0.0) 105 | 106 | return poses, disps 107 | 108 | 109 | def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 110 | """ Motion only bundle adjustment """ 111 | 112 | B, P, ht, wd = disps.shape 113 | N = ii.shape[0] 114 | D = poses.manifold_dim 115 | 116 | ### 1: compute jacobians and residuals ### 117 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 118 | poses, disps, intrinsics, ii, jj, jacobian=True) 119 | 120 | r = (target - coords).view(B, N, -1, 1) 121 | w = .001 * (valid * weight).view(B, N, -1, 1) 122 | 123 | ### 2: construct linear system ### 124 | Ji = Ji.reshape(B, N, -1, D) 125 | Jj = Jj.reshape(B, N, -1, D) 126 | wJiT = (w * Ji).transpose(2,3) 127 | wJjT = (w * Jj).transpose(2,3) 128 | 129 | Hii = torch.matmul(wJiT, Ji) 130 | Hij = torch.matmul(wJiT, Jj) 131 | Hji = torch.matmul(wJjT, Ji) 132 | Hjj = torch.matmul(wJjT, Jj) 133 | 134 | vi = torch.matmul(wJiT, r).squeeze(-1) 135 | vj = torch.matmul(wJjT, r).squeeze(-1) 136 | 137 | # only optimize keyframe poses 138 | P = P // rig - fixedp 139 | ii = ii // rig - fixedp 140 | jj = jj // rig - fixedp 141 | 142 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 143 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 144 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 145 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 146 | 147 | v = safe_scatter_add_vec(vi, ii, P) + \ 148 | safe_scatter_add_vec(vj, jj, P) 149 | 150 | H = H.view(B, P, P, D, D) 151 | 152 | ### 3: solve the system ### 153 | dx = block_solve(H, v) 154 | 155 | ### 4: apply retraction ### 156 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 157 | return poses 158 | 159 | -------------------------------------------------------------------------------- /droid_slam/geom/chol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import geom.projective_ops as pops 4 | 5 | class CholeskySolver(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, H, b): 8 | # don't crash training if cholesky decomp fails 9 | try: 10 | U = torch.linalg.cholesky(H) 11 | xs = torch.cholesky_solve(b, U) 12 | ctx.save_for_backward(U, xs) 13 | ctx.failed = False 14 | except Exception as e: 15 | print(e) 16 | ctx.failed = True 17 | xs = torch.zeros_like(b) 18 | 19 | return xs 20 | 21 | @staticmethod 22 | def backward(ctx, grad_x): 23 | if ctx.failed: 24 | return None, None 25 | 26 | U, xs = ctx.saved_tensors 27 | dz = torch.cholesky_solve(grad_x, U) 28 | dH = -torch.matmul(xs, dz.transpose(-1,-2)) 29 | 30 | return dH, dz 31 | 32 | def block_solve(H, b, 
ep=0.1, lm=0.0001): 33 | """ solve normal equations """ 34 | B, N, _, D, _ = H.shape 35 | I = torch.eye(D).to(H.device) 36 | H = H + (ep + lm*H) * I 37 | 38 | H = H.permute(0,1,3,2,4) 39 | H = H.reshape(B, N*D, N*D) 40 | b = b.reshape(B, N*D, 1) 41 | 42 | x = CholeskySolver.apply(H,b) 43 | return x.reshape(B, N, D) 44 | 45 | 46 | def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False): 47 | """ solve using Schur complement """ 48 | 49 | B, P, M, D, HW = E.shape 50 | H = H.permute(0,1,3,2,4).reshape(B, P*D, P*D) 51 | E = E.permute(0,1,3,2,4).reshape(B, P*D, M*HW) 52 | Q = (1.0 / C).view(B, M*HW, 1) 53 | 54 | # damping 55 | I = torch.eye(P*D).to(H.device) 56 | H = H + (ep + lm*H) * I 57 | 58 | v = v.reshape(B, P*D, 1) 59 | w = w.reshape(B, M*HW, 1) 60 | 61 | Et = E.transpose(1,2) 62 | S = H - torch.matmul(E, Q*Et) 63 | v = v - torch.matmul(E, Q*w) 64 | 65 | dx = CholeskySolver.apply(S, v) 66 | if sless: 67 | return dx.reshape(B, P, D) 68 | 69 | dz = Q * (w - Et @ dx) 70 | dx = dx.reshape(B, P, D) 71 | dz = dz.reshape(B, M, HW) 72 | 73 | return dx, dz -------------------------------------------------------------------------------- /droid_slam/geom/graph_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import lietorch 7 | from data_readers.rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2 8 | 9 | 10 | def graph_to_edge_list(graph): 11 | ii, jj, kk = [], [], [] 12 | for s, u in enumerate(graph): 13 | for v in graph[u]: 14 | ii.append(u) 15 | jj.append(v) 16 | kk.append(s) 17 | 18 | ii = torch.as_tensor(ii) 19 | jj = torch.as_tensor(jj) 20 | kk = torch.as_tensor(kk) 21 | return ii, jj, kk 22 | 23 | def keyframe_indicies(graph): 24 | return torch.as_tensor([u for u in graph]) 25 | 26 | def meshgrid(m, n, device='cuda'): 27 | ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n), indexing="ij") 28 | return ii.reshape(-1).to(device), jj.reshape(-1).to(device) 29 | 30 | def neighbourhood_graph(n, r): 31 | ii, jj = meshgrid(n, n) 32 | d = (ii - jj).abs() 33 | keep = (d >= 1) & (d <= r) 34 | return ii[keep], jj[keep] 35 | 36 | 37 | def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 38 | """ construct a frame graph between co-visible frames """ 39 | N = poses.shape[1] 40 | poses = poses[0].cpu().numpy() 41 | disps = disps[0][:,3::8,3::8].cpu().numpy() 42 | intrinsics = intrinsics[0].cpu().numpy() / 8.0 43 | d = compute_distance_matrix_flow(poses, disps, intrinsics) 44 | 45 | count = 0 46 | graph = OrderedDict() 47 | 48 | for i in range(N): 49 | graph[i] = [] 50 | d[i,i] = np.inf 51 | for j in range(i-r, i+r+1): 52 | if 0 <= j < N and i != j: 53 | graph[i].append(j) 54 | d[i,j] = np.inf 55 | count += 1 56 | 57 | while count < num: 58 | ix = np.argmin(d) 59 | i, j = ix // N, ix % N 60 | 61 | if d[i,j] < thresh: 62 | graph[i].append(j) 63 | d[i,j] = np.inf 64 | count += 1 65 | else: 66 | break 67 | 68 | return graph 69 | 70 | 71 | 72 | def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 73 | """ construct a frame graph between co-visible frames """ 74 | N = poses.shape[1] 75 | # poses = poses[0].cpu().numpy() 76 | # disps = disps[0].cpu().numpy() 77 | # intrinsics = intrinsics[0].cpu().numpy() 78 | d = compute_distance_matrix_flow2(poses, disps, intrinsics) 79 | 80 | # import matplotlib.pyplot as plt 81 | # plt.imshow(d) 82 | # plt.show() 83 | 84 | count = 0 85 | graph = OrderedDict() 86 
| 87 | for i in range(N): 88 | graph[i] = [] 89 | d[i,i] = np.inf 90 | for j in range(i-r, i+r+1): 91 | if 0 <= j < N and i != j: 92 | graph[i].append(j) 93 | d[i,j] = np.inf 94 | count += 1 95 | 96 | while 1: 97 | ix = np.argmin(d) 98 | i, j = ix // N, ix % N 99 | 100 | if d[i,j] < thresh: 101 | graph[i].append(j) 102 | 103 | for i1 in range(i-1, i+2): 104 | for j1 in range(j-1, j+2): 105 | if 0 <= i1 < N and 0 <= j1 < N: 106 | d[i1, j1] = np.inf 107 | 108 | count += 1 109 | else: 110 | break 111 | 112 | return graph 113 | 114 | -------------------------------------------------------------------------------- /droid_slam/geom/losses.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | import torch 4 | from lietorch import SO3, SE3, Sim3 5 | from .graph_utils import graph_to_edge_list 6 | from .projective_ops import projective_transform 7 | 8 | 9 | def pose_metrics(dE): 10 | """ Translation/Rotation/Scaling metrics from Sim3 """ 11 | t, q, s = dE.data.split([3, 4, 1], -1) 12 | ang = SO3(q).log().norm(dim=-1) 13 | 14 | # convert radians to degrees 15 | r_err = (180 / np.pi) * ang 16 | t_err = t.norm(dim=-1) 17 | s_err = (s - 1.0).abs() 18 | return r_err, t_err, s_err 19 | 20 | 21 | def fit_scale(Ps, Gs): 22 | b = Ps.shape[0] 23 | t1 = Ps.data[...,:3].detach().reshape(b, -1) 24 | t2 = Gs.data[...,:3].detach().reshape(b, -1) 25 | 26 | s = (t1*t2).sum(-1) / ((t2*t2).sum(-1) + 1e-8) 27 | return s 28 | 29 | 30 | def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True): 31 | """ Loss function for training network """ 32 | 33 | # relative pose 34 | ii, jj, kk = graph_to_edge_list(graph) 35 | dP = Ps[:,jj] * Ps[:,ii].inv() 36 | 37 | n = len(Gs) 38 | geodesic_loss = 0.0 39 | 40 | for i in range(n): 41 | w = gamma ** (n - i - 1) 42 | dG = Gs[i][:,jj] * Gs[i][:,ii].inv() 43 | 44 | if do_scale: 45 | s = fit_scale(dP, dG) 46 | dG = dG.scale(s[:,None]) 47 | 48 | # pose error 49 | d = (dG * dP.inv()).log() 50 | 51 | if isinstance(dG, SE3): 52 | tau, phi = d.split([3,3], dim=-1) 53 | geodesic_loss += w * ( 54 | tau.norm(dim=-1).mean() + 55 | phi.norm(dim=-1).mean()) 56 | 57 | elif isinstance(dG, Sim3): 58 | tau, phi, sig = d.split([3,3,1], dim=-1) 59 | geodesic_loss += w * ( 60 | tau.norm(dim=-1).mean() + 61 | phi.norm(dim=-1).mean() + 62 | 0.05 * sig.norm(dim=-1).mean()) 63 | 64 | dE = Sim3(dG * dP.inv()).detach() 65 | r_err, t_err, s_err = pose_metrics(dE) 66 | 67 | metrics = { 68 | 'rot_error': r_err.mean().item(), 69 | 'tr_error': t_err.mean().item(), 70 | 'bad_rot': (r_err < .1).float().mean().item(), 71 | 'bad_tr': (t_err < .01).float().mean().item(), 72 | } 73 | 74 | return geodesic_loss, metrics 75 | 76 | 77 | def residual_loss(residuals, gamma=0.9): 78 | """ loss on system residuals """ 79 | residual_loss = 0.0 80 | n = len(residuals) 81 | 82 | for i in range(n): 83 | w = gamma ** (n - i - 1) 84 | residual_loss += w * residuals[i].abs().mean() 85 | 86 | return residual_loss, {'residual': residual_loss.item()} 87 | 88 | 89 | def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9): 90 | """ optical flow loss """ 91 | 92 | N = Ps.shape[1] 93 | graph = OrderedDict() 94 | for i in range(N): 95 | graph[i] = [j for j in range(N) if abs(i-j)==1] 96 | 97 | ii, jj, kk = graph_to_edge_list(graph) 98 | coords0, val0 = projective_transform(Ps, disps, intrinsics, ii, jj) 99 | val0 = val0 * (disps[:,ii] > 0).float().unsqueeze(dim=-1) 100 | 101 | n = len(poses_est) 102 | flow_loss = 0.0 103 | 
104 | for i in range(n): 105 | w = gamma ** (n - i - 1) 106 | coords1, val1 = projective_transform(poses_est[i], disps_est[i], intrinsics, ii, jj) 107 | 108 | v = (val0 * val1).squeeze(dim=-1) 109 | epe = v * (coords1 - coords0).norm(dim=-1) 110 | flow_loss += w * epe.mean() 111 | 112 | epe = epe.reshape(-1)[v.reshape(-1) > 0.5] 113 | metrics = { 114 | 'f_error': epe.mean().item(), 115 | '1px': (epe<1.0).float().mean().item(), 116 | } 117 | 118 | return flow_loss, metrics 119 | -------------------------------------------------------------------------------- /droid_slam/geom/projective_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from lietorch import SE3, Sim3 5 | 6 | MIN_DEPTH = 0.2 7 | 8 | 9 | def extract_intrinsics(intrinsics): 10 | return intrinsics[..., None, None, :].unbind(dim=-1) 11 | 12 | 13 | def coords_grid(ht, wd, **kwargs): 14 | y, x = torch.meshgrid( 15 | torch.arange(ht, dtype=torch.float, **kwargs), 16 | torch.arange(wd, dtype=torch.float, **kwargs), 17 | indexing="ij", 18 | ) 19 | 20 | return torch.stack([x, y], dim=-1) 21 | 22 | 23 | def iproj(disps, intrinsics, jacobian=False): 24 | """pinhole camera inverse projection""" 25 | ht, wd = disps.shape[2:] 26 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 27 | 28 | y, x = torch.meshgrid( 29 | torch.arange(ht, device=disps.device, dtype=torch.float), 30 | torch.arange(wd, device=disps.device, dtype=torch.float), 31 | indexing="ij", 32 | ) 33 | 34 | i = torch.ones_like(disps) 35 | X = (x - cx) / fx 36 | Y = (y - cy) / fy 37 | pts = torch.stack([X, Y, i, disps], dim=-1) 38 | 39 | if jacobian: 40 | J = torch.zeros_like(pts) 41 | J[..., -1] = 1.0 42 | return pts, J 43 | 44 | return pts, None 45 | 46 | 47 | def proj(Xs, intrinsics, jacobian=False, return_depth=False): 48 | """pinhole camera projection""" 49 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 50 | X, Y, Z, D = Xs.unbind(dim=-1) 51 | 52 | Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z) 53 | d = 1.0 / Z 54 | 55 | x = fx * (X * d) + cx 56 | y = fy * (Y * d) + cy 57 | if return_depth: 58 | coords = torch.stack([x, y, D * d], dim=-1) 59 | else: 60 | coords = torch.stack([x, y], dim=-1) 61 | 62 | if jacobian: 63 | B, N, H, W = d.shape 64 | o = torch.zeros_like(d) 65 | proj_jac = torch.stack( 66 | [ 67 | fx * d, 68 | o, 69 | -fx * X * d * d, 70 | o, 71 | o, 72 | fy * d, 73 | -fy * Y * d * d, 74 | o, 75 | # o, o, -D*d*d, d, 76 | ], 77 | dim=-1, 78 | ).view(B, N, H, W, 2, 4) 79 | 80 | return coords, proj_jac 81 | 82 | return coords, None 83 | 84 | 85 | def actp(Gij, X0, jacobian=False): 86 | """action on point cloud""" 87 | X1 = Gij[:, :, None, None] * X0 88 | 89 | if jacobian: 90 | X, Y, Z, d = X1.unbind(dim=-1) 91 | o = torch.zeros_like(d) 92 | B, N, H, W = d.shape 93 | 94 | if isinstance(Gij, SE3): 95 | Ja = torch.stack( 96 | [ 97 | d, 98 | o, 99 | o, 100 | o, 101 | Z, 102 | -Y, 103 | o, 104 | d, 105 | o, 106 | -Z, 107 | o, 108 | X, 109 | o, 110 | o, 111 | d, 112 | Y, 113 | -X, 114 | o, 115 | o, 116 | o, 117 | o, 118 | o, 119 | o, 120 | o, 121 | ], 122 | dim=-1, 123 | ).view(B, N, H, W, 4, 6) 124 | 125 | elif isinstance(Gij, Sim3): 126 | Ja = torch.stack( 127 | [ 128 | d, 129 | o, 130 | o, 131 | o, 132 | Z, 133 | -Y, 134 | X, 135 | o, 136 | d, 137 | o, 138 | -Z, 139 | o, 140 | X, 141 | Y, 142 | o, 143 | o, 144 | d, 145 | Y, 146 | -X, 147 | o, 148 | Z, 149 | o, 150 | o, 151 | o, 152 | o, 153 | o, 154 | o, 155 | o, 156 | ], 157 | dim=-1, 158 | ).view(B, N, H, 
W, 4, 7) 159 | 160 | return X1, Ja 161 | 162 | return X1, None 163 | 164 | 165 | def projective_transform( 166 | poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False 167 | ): 168 | """map points from ii->jj""" 169 | 170 | # inverse project (pinhole) 171 | X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian) 172 | 173 | # transform 174 | Gij = poses[:, jj] * poses[:, ii].inv() 175 | 176 | Gij.data[:, ii == jj] = torch.as_tensor( 177 | [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda" 178 | ) 179 | X1, Ja = actp(Gij, X0, jacobian=jacobian) 180 | 181 | # project (pinhole) 182 | x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth) 183 | 184 | # exclude points too close to camera 185 | valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float() 186 | valid = valid.unsqueeze(-1) 187 | 188 | if jacobian: 189 | # Ji transforms according to dual adjoint 190 | Jj = torch.matmul(Jp, Ja) 191 | Ji = -Gij[:, :, None, None, None].adjT(Jj) 192 | 193 | Jz = Gij[:, :, None, None] * Jz 194 | Jz = torch.matmul(Jp, Jz.unsqueeze(-1)) 195 | 196 | return x1, valid, (Ji, Jj, Jz) 197 | 198 | return x1, valid 199 | 200 | 201 | def induced_flow(poses, disps, intrinsics, ii, jj): 202 | """optical flow induced by camera motion""" 203 | 204 | ht, wd = disps.shape[2:] 205 | y, x = torch.meshgrid( 206 | torch.arange(ht, device=disps.device, dtype=torch.float), 207 | torch.arange(wd, device=disps.device, dtype=torch.float), 208 | indexing="ij", 209 | ) 210 | 211 | coords0 = torch.stack([x, y], dim=-1) 212 | coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False) 213 | 214 | return coords1[..., :2] - coords0, valid 215 | -------------------------------------------------------------------------------- /droid_slam/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | 6 | SUM_FREQ = 100 7 | 8 | class Logger: 9 | def __init__(self, name, scheduler): 10 | self.total_steps = 0 11 | self.running_loss = {} 12 | self.writer = None 13 | self.name = name 14 | self.scheduler = scheduler 15 | 16 | def _print_training_status(self): 17 | if self.writer is None: 18 | self.writer = SummaryWriter('runs/%s' % self.name) 19 | print([k for k in self.running_loss]) 20 | 21 | lr = self.scheduler.get_lr().pop() 22 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()] 23 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr) 24 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 25 | 26 | # print the training status 27 | print(training_str + metrics_str) 28 | 29 | for key in self.running_loss: 30 | val = self.running_loss[key] / SUM_FREQ 31 | self.writer.add_scalar(key, val, self.total_steps) 32 | self.running_loss[key] = 0.0 33 | 34 | def push(self, metrics): 35 | 36 | for key in metrics: 37 | if key not in self.running_loss: 38 | self.running_loss[key] = 0.0 39 | 40 | self.running_loss[key] += metrics[key] 41 | 42 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 43 | self._print_training_status() 44 | self.running_loss = {} 45 | 46 | self.total_steps += 1 47 | 48 | def write_dict(self, results): 49 | for key in results: 50 | self.writer.add_scalar(key, results[key], self.total_steps) 51 | 52 | def close(self): 53 | self.writer.close() 54 | 55 | -------------------------------------------------------------------------------- /droid_slam/modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/droid_slam/modules/__init__.py -------------------------------------------------------------------------------- /droid_slam/modules/clipping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | GRAD_CLIP = .01 6 | 7 | class GradClip(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return x 11 | 12 | @staticmethod 13 | def backward(ctx, grad_x): 14 | o = torch.zeros_like(grad_x) 15 | grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x) 16 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x) 17 | return grad_x 18 | 19 | class GradientClip(nn.Module): 20 | def __init__(self): 21 | super(GradientClip, self).__init__() 22 | 23 | def forward(self, x): 24 | return GradClip.apply(x) -------------------------------------------------------------------------------- /droid_slam/modules/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import droid_backends 5 | 6 | class CorrSampler(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, volume, coords, radius): 10 | ctx.save_for_backward(volume,coords) 11 | ctx.radius = radius 12 | corr, = droid_backends.corr_index_forward(volume, coords, radius) 13 | return corr 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | volume, coords = ctx.saved_tensors 18 | grad_output = grad_output.contiguous() 19 | grad_volume, = droid_backends.corr_index_backward(volume, coords, grad_output, ctx.radius) 20 | return grad_volume, None, None 21 | 22 | 23 | class CorrBlock: 24 | def __init__(self, fmap1, fmap2, num_levels=4, radius=3): 25 | self.num_levels = num_levels 26 | self.radius = radius 27 | self.corr_pyramid = [] 28 | 29 | # all pairs correlation 30 | corr = CorrBlock.corr(fmap1, fmap2) 31 | 32 | batch, num, h1, w1, h2, w2 = corr.shape 33 | corr = corr.reshape(batch*num*h1*w1, 1, h2, w2) 34 | 35 | for i in range(self.num_levels): 36 | self.corr_pyramid.append( 37 | corr.view(batch*num, h1, w1, h2//2**i, w2//2**i)) 38 | corr = F.avg_pool2d(corr, 2, stride=2) 39 | 40 | def __call__(self, coords): 41 | out_pyramid = [] 42 | batch, num, ht, wd, _ = coords.shape 43 | coords = coords.permute(0,1,4,2,3) 44 | coords = coords.contiguous().view(batch*num, 2, ht, wd) 45 | 46 | for i in range(self.num_levels): 47 | corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius) 48 | out_pyramid.append(corr.view(batch, num, -1, ht, wd)) 49 | 50 | return torch.cat(out_pyramid, dim=2) 51 | 52 | def cat(self, other): 53 | for i in range(self.num_levels): 54 | self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0) 55 | return self 56 | 57 | def __getitem__(self, index): 58 | for i in range(self.num_levels): 59 | self.corr_pyramid[i] = self.corr_pyramid[i][index] 60 | return self 61 | 62 | 63 | @staticmethod 64 | def corr(fmap1, fmap2): 65 | """ all-pairs correlation """ 66 | batch, num, dim, ht, wd = fmap1.shape 67 | fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0 68 | fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0 69 | 70 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 71 | return corr.view(batch, num, ht, wd, ht, wd) 72 | 73 | 74 | class CorrLayer(torch.autograd.Function): 75 | 
@staticmethod 76 | def forward(ctx, fmap1, fmap2, coords, ii, jj, r): 77 | ctx.r = r 78 | ctx.save_for_backward(fmap1, fmap2, coords, ii, jj) 79 | corr, = droid_backends.altcorr_forward(fmap1, fmap2, coords, ii, jj, ctx.r) 80 | return corr 81 | 82 | @staticmethod 83 | def backward(ctx, grad_corr): 84 | fmap1, fmap2, coords, ii, jj = ctx.saved_tensors 85 | fmap1_grad, fmap2_grad, coords_grad = \ 86 | droid_backends.altcorr_backward(fmap1, fmap2, coords, ii, jj, grad_corr, ctx.r) 87 | return fmap1_grad, fmap2_grad, coords_grad, None, None, None 88 | 89 | class AltCorrBlock: 90 | def __init__(self, fmaps, num_levels=4, radius=3): 91 | self.num_levels = num_levels 92 | self.radius = radius 93 | 94 | B, N, C, H, W = fmaps.shape 95 | fmaps = fmaps.view(B*N, C, H, W) 96 | 97 | self.pyramid = [] 98 | for i in range(self.num_levels): 99 | sz = (B, N, C, H//2**i, W//2**i) 100 | self.pyramid.append(fmaps.view(*sz)) 101 | fmaps = F.avg_pool2d(fmaps, 2, stride=2) 102 | 103 | 104 | def __call__(self, coords, ii, jj): 105 | 106 | coords = coords.permute(0, 1, 4, 2, 3).contiguous() 107 | 108 | corr_list = [] 109 | for i in range(self.num_levels): 110 | corr = CorrLayer.apply( 111 | self.pyramid[0], self.pyramid[i], coords / 2**i, ii, jj, self.radius 112 | ) 113 | 114 | corr_list.append(corr.flatten(2, 3)) 115 | 116 | corr = torch.stack(corr_list, dim=2).flatten(2, 3) 117 | return corr 118 | -------------------------------------------------------------------------------- /droid_slam/modules/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x+y) 56 | 57 | 58 | class BottleneckBlock(nn.Module): 59 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 60 | super(BottleneckBlock, self).__init__() 61 | 62 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 63 | 
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 64 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 65 | self.relu = nn.ReLU(inplace=True) 66 | 67 | num_groups = planes // 8 68 | 69 | if norm_fn == 'group': 70 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 71 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 72 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 73 | if not stride == 1: 74 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | 76 | elif norm_fn == 'batch': 77 | self.norm1 = nn.BatchNorm2d(planes//4) 78 | self.norm2 = nn.BatchNorm2d(planes//4) 79 | self.norm3 = nn.BatchNorm2d(planes) 80 | if not stride == 1: 81 | self.norm4 = nn.BatchNorm2d(planes) 82 | 83 | elif norm_fn == 'instance': 84 | self.norm1 = nn.InstanceNorm2d(planes//4) 85 | self.norm2 = nn.InstanceNorm2d(planes//4) 86 | self.norm3 = nn.InstanceNorm2d(planes) 87 | if not stride == 1: 88 | self.norm4 = nn.InstanceNorm2d(planes) 89 | 90 | elif norm_fn == 'none': 91 | self.norm1 = nn.Sequential() 92 | self.norm2 = nn.Sequential() 93 | self.norm3 = nn.Sequential() 94 | if not stride == 1: 95 | self.norm4 = nn.Sequential() 96 | 97 | if stride == 1: 98 | self.downsample = None 99 | 100 | else: 101 | self.downsample = nn.Sequential( 102 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 103 | 104 | def forward(self, x): 105 | y = x 106 | y = self.relu(self.norm1(self.conv1(y))) 107 | y = self.relu(self.norm2(self.conv2(y))) 108 | y = self.relu(self.norm3(self.conv3(y))) 109 | 110 | if self.downsample is not None: 111 | x = self.downsample(x) 112 | 113 | return self.relu(x+y) 114 | 115 | 116 | DIM=32 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | self.multidim = multidim 123 | 124 | if self.norm_fn == 'group': 125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM) 126 | 127 | elif self.norm_fn == 'batch': 128 | self.norm1 = nn.BatchNorm2d(DIM) 129 | 130 | elif self.norm_fn == 'instance': 131 | self.norm1 = nn.InstanceNorm2d(DIM) 132 | 133 | elif self.norm_fn == 'none': 134 | self.norm1 = nn.Sequential() 135 | 136 | self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3) 137 | self.relu1 = nn.ReLU(inplace=True) 138 | 139 | self.in_planes = DIM 140 | self.layer1 = self._make_layer(DIM, stride=1) 141 | self.layer2 = self._make_layer(2*DIM, stride=2) 142 | self.layer3 = self._make_layer(4*DIM, stride=2) 143 | 144 | # output convolution 145 | self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1) 146 | 147 | if self.multidim: 148 | self.layer4 = self._make_layer(256, stride=2) 149 | self.layer5 = self._make_layer(512, stride=2) 150 | 151 | self.in_planes = 256 152 | self.layer6 = self._make_layer(256, stride=1) 153 | 154 | self.in_planes = 128 155 | self.layer7 = self._make_layer(128, stride=1) 156 | 157 | self.up1 = nn.Conv2d(512, 256, 1) 158 | self.up2 = nn.Conv2d(256, 128, 1) 159 | self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1) 160 | 161 | if dropout > 0: 162 | self.dropout = nn.Dropout2d(p=dropout) 163 | else: 164 | self.dropout = None 165 | 166 | for m in self.modules(): 167 | if isinstance(m, nn.Conv2d): 168 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 169 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 
170 | if m.weight is not None: 171 | nn.init.constant_(m.weight, 1) 172 | if m.bias is not None: 173 | nn.init.constant_(m.bias, 0) 174 | 175 | def _make_layer(self, dim, stride=1): 176 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 177 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 178 | layers = (layer1, layer2) 179 | 180 | self.in_planes = dim 181 | return nn.Sequential(*layers) 182 | 183 | def forward(self, x): 184 | b, n, c1, h1, w1 = x.shape 185 | x = x.view(b*n, c1, h1, w1) 186 | 187 | x = self.conv1(x) 188 | x = self.norm1(x) 189 | x = self.relu1(x) 190 | 191 | x = self.layer1(x) 192 | x = self.layer2(x) 193 | x = self.layer3(x) 194 | 195 | x = self.conv2(x) 196 | 197 | _, c2, h2, w2 = x.shape 198 | return x.view(b, n, c2, h2, w2) 199 | -------------------------------------------------------------------------------- /droid_slam/modules/gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvGRU(nn.Module): 6 | def __init__(self, h_planes=128, i_planes=128): 7 | super(ConvGRU, self).__init__() 8 | self.do_checkpoint = False 9 | self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 10 | self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 11 | self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 12 | 13 | self.w = nn.Conv2d(h_planes, h_planes, 1, padding=0) 14 | 15 | self.convz_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 16 | self.convr_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 17 | self.convq_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 18 | 19 | def forward(self, net, *inputs): 20 | inp = torch.cat(inputs, dim=1) 21 | net_inp = torch.cat([net, inp], dim=1) 22 | 23 | b, c, h, w = net.shape 24 | glo = torch.sigmoid(self.w(net)) * net 25 | glo = glo.view(b, c, h*w).mean(-1).view(b, c, 1, 1) 26 | 27 | z = torch.sigmoid(self.convz(net_inp) + self.convz_glo(glo)) 28 | r = torch.sigmoid(self.convr(net_inp) + self.convr_glo(glo)) 29 | q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)) + self.convq_glo(glo)) 30 | 31 | net = (1-z) * net + z * q 32 | return net 33 | 34 | 35 | -------------------------------------------------------------------------------- /droid_slam/motion_filter.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import lietorch 4 | 5 | from collections import OrderedDict 6 | from droid_net import DroidNet 7 | 8 | import geom.projective_ops as pops 9 | from modules.corr import CorrBlock 10 | 11 | from functools import partial 12 | 13 | if torch.__version__.startswith("2"): 14 | autocast = partial(torch.autocast, device_type="cuda") 15 | else: 16 | autocast = torch.cuda.amp.autocast 17 | 18 | 19 | class MotionFilter: 20 | """ This class is used to filter incoming frames and extract features """ 21 | 22 | def __init__(self, net, video, thresh=2.5, device="cuda"): 23 | 24 | # split net modules 25 | self.cnet = net.cnet 26 | self.fnet = net.fnet 27 | self.update = net.update 28 | 29 | self.video = video 30 | self.thresh = thresh 31 | self.device = device 32 | 33 | self.count = 0 34 | 35 | # mean, std for image normalization 36 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None] 37 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None] 38 | 39 | @autocast(enabled=True) 40 | def __context_encoder(self, image): 41 | """ context features """ 42 | 
net, inp = self.cnet(image).split([128,128], dim=2) 43 | return net.tanh().squeeze(0), inp.relu().squeeze(0) 44 | 45 | @autocast(enabled=True) 46 | def __feature_encoder(self, image): 47 | """ features for correlation volume """ 48 | return self.fnet(image).squeeze(0) 49 | 50 | @autocast(enabled=True) 51 | @torch.no_grad() 52 | def track(self, tstamp, image, depth=None, intrinsics=None): 53 | """ main update operation - run on every frame in video """ 54 | 55 | Id = lietorch.SE3.Identity(1,).data.squeeze() 56 | ht = image.shape[-2] // 8 57 | wd = image.shape[-1] // 8 58 | 59 | image = image.cuda() 60 | 61 | # normalize images 62 | inputs = image[None, :, [2,1,0]].to(self.device) / 255.0 63 | inputs = inputs.sub_(self.MEAN).div_(self.STDV) 64 | 65 | # extract features 66 | gmap = self.__feature_encoder(inputs) 67 | 68 | ### always add first frame to the depth video ### 69 | if self.video.counter.value == 0: 70 | net, inp = self.__context_encoder(inputs[:,[0]]) 71 | self.net, self.inp, self.fmap = net, inp, gmap 72 | self.video.append(tstamp, image[0], Id, 1.0, depth, intrinsics / 8.0, gmap, net[0,0], inp[0,0]) 73 | 74 | ### only add new frame if there is enough motion ### 75 | else: 76 | # index correlation volume 77 | coords0 = pops.coords_grid(ht, wd, device=self.device)[None,None] 78 | corr = CorrBlock(self.fmap[None,[0]], gmap[None,[0]])(coords0) 79 | 80 | # approximate flow magnitude using 1 update iteration 81 | _, delta, weight = self.update(self.net[None], self.inp[None], corr) 82 | 83 | # check motion magnitude / add new frame to video 84 | if delta.norm(dim=-1).mean().item() > self.thresh: 85 | self.count = 0 86 | net, inp = self.__context_encoder(inputs[:,[0]]) 87 | self.net, self.inp, self.fmap = net, inp, gmap 88 | self.video.append(tstamp, image[0], None, None, depth, intrinsics / 8.0, gmap, net[0], inp[0]) 89 | 90 | else: 91 | self.count += 1 92 | -------------------------------------------------------------------------------- /droid_slam/trajectory_filler.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import lietorch 4 | 5 | from lietorch import SE3 6 | from collections import OrderedDict 7 | from factor_graph import FactorGraph 8 | from droid_net import DroidNet 9 | import geom.projective_ops as pops 10 | 11 | from functools import partial 12 | 13 | if torch.__version__.startswith("2"): 14 | autocast = partial(torch.autocast, device_type="cuda") 15 | else: 16 | autocast = torch.cuda.amp.autocast 17 | 18 | 19 | class PoseTrajectoryFiller: 20 | """ This class is used to fill in non-keyframe poses """ 21 | 22 | def __init__(self, net, video, device="cuda"): 23 | 24 | # split net modules 25 | self.cnet = net.cnet 26 | self.fnet = net.fnet 27 | self.update = net.update 28 | 29 | self.count = 0 30 | self.video = video 31 | self.device = device 32 | 33 | # mean, std for image normalization 34 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None] 35 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None] 36 | 37 | @autocast(enabled=True) 38 | def __feature_encoder(self, image): 39 | """ features for correlation volume """ 40 | return self.fnet(image) 41 | 42 | def __fill(self, tstamps, images, intrinsics): 43 | """ fill operator """ 44 | 45 | tt = torch.as_tensor(tstamps, device="cuda") 46 | images = torch.stack(images, 0).cuda() 47 | intrinsics = torch.stack(intrinsics, 0) 48 | inputs = images[:,:,[2,1,0]].to(self.device) / 255.0 49 | 50 | ### 
linear pose interpolation ### 51 | N = self.video.counter.value 52 | M = len(tstamps) 53 | 54 | ts = self.video.tstamp[:N] 55 | Ps = SE3(self.video.poses[:N]) 56 | 57 | t0 = torch.as_tensor([ts[ts<=t].shape[0] - 1 for t in tstamps]) 58 | t1 = torch.where(t0 0: 107 | pose_list += self.__fill(tstamps, images, intrinsics) 108 | 109 | # stitch pose segments together 110 | return lietorch.cat(pose_list, 0) 111 | 112 | -------------------------------------------------------------------------------- /droid_slam/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import lietorch 4 | import droid_backends 5 | import time 6 | import argparse 7 | import numpy as np 8 | import open3d as o3d 9 | 10 | from lietorch import SE3 11 | import geom.projective_ops as pops 12 | 13 | CAM_POINTS = np.array([ 14 | [ 0, 0, 0], 15 | [-1, -1, 1.5], 16 | [ 1, -1, 1.5], 17 | [ 1, 1, 1.5], 18 | [-1, 1, 1.5], 19 | [-0.5, 1, 1.5], 20 | [ 0.5, 1, 1.5], 21 | [ 0, 1.2, 1.5]]) 22 | 23 | CAM_LINES = np.array([ 24 | [1,2], [2,3], [3,4], [4,1], [1,0], [0,2], [3,0], [0,4], [5,7], [7,6]]) 25 | 26 | def white_balance(img): 27 | # from https://stackoverflow.com/questions/46390779/automatic-white-balancing-with-grayworld-assumption 28 | result = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 29 | avg_a = np.average(result[:, :, 1]) 30 | avg_b = np.average(result[:, :, 2]) 31 | result[:, :, 1] = result[:, :, 1] - ((avg_a - 128) * (result[:, :, 0] / 255.0) * 1.1) 32 | result[:, :, 2] = result[:, :, 2] - ((avg_b - 128) * (result[:, :, 0] / 255.0) * 1.1) 33 | result = cv2.cvtColor(result, cv2.COLOR_LAB2BGR) 34 | return result 35 | 36 | def create_camera_actor(g, scale=0.05): 37 | """ build open3d camera polydata """ 38 | camera_actor = o3d.geometry.LineSet( 39 | points=o3d.utility.Vector3dVector(scale * CAM_POINTS), 40 | lines=o3d.utility.Vector2iVector(CAM_LINES)) 41 | 42 | color = (g * 1.0, 0.5 * (1-g), 0.9 * (1-g)) 43 | camera_actor.paint_uniform_color(color) 44 | return camera_actor 45 | 46 | def create_point_actor(points, colors): 47 | """ open3d point cloud from numpy array """ 48 | point_cloud = o3d.geometry.PointCloud() 49 | point_cloud.points = o3d.utility.Vector3dVector(points) 50 | point_cloud.colors = o3d.utility.Vector3dVector(colors) 51 | return point_cloud 52 | 53 | def droid_visualization(video, device="cuda:0"): 54 | """ DROID visualization frontend """ 55 | 56 | torch.cuda.set_device(device) 57 | droid_visualization.video = video 58 | droid_visualization.cameras = {} 59 | droid_visualization.points = {} 60 | droid_visualization.warmup = 8 61 | droid_visualization.scale = 1.0 62 | droid_visualization.ix = 0 63 | 64 | droid_visualization.filter_thresh = 0.005 65 | droid_visualization.is_initialized = False 66 | 67 | def increase_filter(vis): 68 | droid_visualization.filter_thresh *= 2 69 | with droid_visualization.video.get_lock(): 70 | droid_visualization.video.dirty[:droid_visualization.video.counter.value] = True 71 | 72 | def decrease_filter(vis): 73 | droid_visualization.filter_thresh *= 0.5 74 | with droid_visualization.video.get_lock(): 75 | droid_visualization.video.dirty[:droid_visualization.video.counter.value] = True 76 | 77 | def animation_callback(vis): 78 | cam = vis.get_view_control().convert_to_pinhole_camera_parameters() 79 | 80 | if not droid_visualization.is_initialized: 81 | extrinsics = np.eye(4) 82 | extrinsics[2, 3] = 2.0 83 | cam.extrinsic = extrinsics 84 | droid_visualization.is_initialized = True 85 | 86 | with 
torch.no_grad(): 87 | 88 | with video.get_lock(): 89 | t = video.counter.value 90 | dirty_index, = torch.where(video.dirty.clone()) 91 | dirty_index = dirty_index 92 | 93 | if len(dirty_index) == 0: 94 | return 95 | 96 | video.dirty[dirty_index] = False 97 | 98 | # convert poses to 4x4 matrix 99 | poses = torch.index_select(video.poses, 0, dirty_index) 100 | disps = torch.index_select(video.disps, 0, dirty_index) 101 | Ps = SE3(poses).inv().matrix().cpu().numpy() 102 | 103 | images = torch.index_select(video.images, 0, dirty_index) 104 | images = images.cpu()[:,[2,1,0],3::8,3::8].permute(0,2,3,1) / 255.0 105 | points = droid_backends.iproj(SE3(poses).inv().data, disps, video.intrinsics[0]).cpu() 106 | 107 | thresh = droid_visualization.filter_thresh * torch.ones_like(disps.mean(dim=[1,2])) 108 | 109 | count = droid_backends.depth_filter( 110 | video.poses, video.disps, video.intrinsics[0], dirty_index, thresh) 111 | 112 | count = count.cpu() 113 | disps = disps.cpu() 114 | masks = ((count >= 2) & (disps > .5*disps.mean(dim=[1,2], keepdim=True))) 115 | 116 | for i in range(len(dirty_index)): 117 | pose = Ps[i] 118 | ix = dirty_index[i].item() 119 | 120 | if ix in droid_visualization.cameras: 121 | vis.remove_geometry(droid_visualization.cameras[ix]) 122 | del droid_visualization.cameras[ix] 123 | 124 | if ix in droid_visualization.points: 125 | vis.remove_geometry(droid_visualization.points[ix]) 126 | del droid_visualization.points[ix] 127 | 128 | ### add camera actor ### 129 | cam_actor = create_camera_actor(True) 130 | cam_actor.transform(pose) 131 | vis.add_geometry(cam_actor) 132 | droid_visualization.cameras[ix] = cam_actor 133 | 134 | mask = masks[i].reshape(-1) 135 | pts = points[i].reshape(-1, 3)[mask].cpu().numpy() 136 | clr = images[i].reshape(-1, 3)[mask].cpu().numpy() 137 | 138 | ### add point actor ### 139 | point_actor = create_point_actor(pts, clr) 140 | vis.add_geometry(point_actor) 141 | droid_visualization.points[ix] = point_actor 142 | 143 | # hack to allow interacting with visualization during inference 144 | cam = vis.get_view_control().convert_from_pinhole_camera_parameters(cam) 145 | 146 | droid_visualization.ix += 1 147 | vis.poll_events() 148 | vis.update_renderer() 149 | 150 | ### create Open3D visualization ### 151 | vis = o3d.visualization.VisualizerWithKeyCallback() 152 | vis.register_animation_callback(animation_callback) 153 | vis.register_key_callback(ord("S"), increase_filter) 154 | vis.register_key_callback(ord("A"), decrease_filter) 155 | 156 | vis.create_window(height=540, width=960) 157 | vis.get_render_option().load_from_json("misc/renderoption.json") 158 | 159 | vis.run() 160 | vis.destroy_window() 161 | -------------------------------------------------------------------------------- /droid_slam/visualizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/droid_slam/visualizer/__init__.py -------------------------------------------------------------------------------- /droid_slam/visualizer/camera.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Optional, Union 3 | 4 | import glm 5 | from glm import cos, radians, sin 6 | 7 | import moderngl_window 8 | from moderngl_window.scene.camera import Camera 9 | 10 | 11 | class OrbitCamera(Camera): 12 | def __init__( 13 | self, 14 | target: Union[glm.vec3, tuple[float, float, float]] = (0.0, 
0.0, 0.0), 15 | radius: float = 2.0, 16 | angles: tuple[float, float] = (60.0, -100.0), 17 | **kwargs: Any, 18 | ): 19 | self.radius = radius # radius in base units 20 | self.angle_x, self.angle_y = angles # angles in degrees 21 | self.target = glm.vec3(target) # camera target in base units 22 | self.up = glm.vec3(0.0, 1.0, 0.0) # camera up vector 23 | 24 | self._mouse_sensitivity = 1.0 25 | self._zoom_sensitivity = 1.0 26 | 27 | self.world_up = glm.vec3(0.0, -1.0, 0.0) 28 | self._mouse_sensitivity = 1.0 29 | self._zoom_sensitivity = 1.0 30 | self._pan_sensitivity = 0.001 31 | super().__init__(**kwargs) 32 | 33 | @property 34 | def pan_sensitivity(self) -> float: 35 | return self._pan_sensitivity 36 | 37 | @pan_sensitivity.setter 38 | def pan_sensitivity(self, value: float): 39 | self._pan_sensitivity = value 40 | 41 | def rot_state(self, dx: float, dy: float) -> None: 42 | """Unclamped, continuous orbit around the target.""" 43 | self.angle_x = (self.angle_x - dx * self.mouse_sensitivity / 10.0) 44 | self.angle_y = (self.angle_y - dy * self.mouse_sensitivity / 10.0) 45 | self.angle_y = max(min(self.angle_y, -5.0), -175.0) 46 | # self.angle_y = self.angle_y.clamp() 47 | 48 | def zoom_state(self, y_offset: float) -> None: 49 | self.radius = max(1.0, self.radius - y_offset * self._zoom_sensitivity) 50 | 51 | @property 52 | def matrix(self) -> glm.mat4: 53 | # compute camera position as before 54 | px = cos(radians(self.angle_x)) * sin(radians(self.angle_y)) * self.radius + self.target.x 55 | py = cos(radians(self.angle_y)) * self.radius + self.target.y 56 | pz = sin(radians(self.angle_x)) * sin(radians(self.angle_y)) * self.radius + self.target.z 57 | pos = glm.vec3(px, py, pz) 58 | self.set_position(*pos) 59 | return glm.lookAt(pos, self.target, self.world_up) 60 | 61 | def pan_state(self, dx: float, dy: float) -> None: 62 | """Pan the orbit‐center using camera‐relative axes.""" 63 | # Recompute camera position & forward vector 64 | px = cos(radians(self.angle_x)) * sin(radians(self.angle_y)) * self.radius + self.target.x 65 | py = cos(radians(self.angle_y)) * self.radius + self.target.y 66 | pz = sin(radians(self.angle_x)) * sin(radians(self.angle_y)) * self.radius + self.target.z 67 | pos = glm.vec3(px, py, pz) 68 | forward = glm.normalize(self.target - pos) 69 | 70 | # Build a stable right & up in camera‐space: 71 | right = glm.normalize(glm.cross(forward, self.world_up)) 72 | up = glm.normalize(glm.cross(right, forward)) 73 | 74 | # Screen‐space offset: right = +dx, up = +dy 75 | offset = (-right * dx + up * dy) * self._pan_sensitivity * self.radius 76 | self.target += offset 77 | 78 | 79 | 80 | class OrbitDragCameraWindow(moderngl_window.WindowConfig): 81 | """Base class with drag-based 3D orbit support 82 | 83 | Click and drag with the left mouse button to orbit the camera around the view point. 
84 | """ 85 | 86 | def __init__(self, **kwargs): 87 | super().__init__(**kwargs) 88 | self.camera = OrbitCamera(aspect_ratio=self.wnd.aspect_ratio) 89 | 90 | def on_key_event(self, key, action, modifiers): 91 | keys = self.wnd.keys 92 | 93 | if action == keys.ACTION_PRESS: 94 | if key == keys.SPACE: 95 | self.timer.toggle_pause() 96 | 97 | def on_mouse_drag_event(self, x: int, y: int, dx: float, dy: float): 98 | mb = self.wnd.mouse_states 99 | if mb.right: # ← right‑button drag → pan 100 | self.camera.pan_state(dx, dy) 101 | else: # ← left‑button drag → orbit 102 | self.camera.rot_state(dx, dy) 103 | 104 | def on_mouse_scroll_event(self, x_offset: float, y_offset: float): 105 | self.camera.zoom_state(y_offset) 106 | 107 | def on_resize(self, width: int, height: int): 108 | self.camera.projection.update(aspect_ratio=self.wnd.aspect_ratio) 109 | 110 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: droidenv 2 | channels: 3 | - rusty1s 4 | - pytorch 5 | - open3d-admin 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - pytorch-scatter 11 | - torchaudio 12 | - torchvision 13 | - open3d 14 | - pytorch=1.10 15 | - cudatoolkit=11.3 16 | - tensorboard 17 | - scipy 18 | - opencv 19 | - tqdm 20 | - suitesparse 21 | - matplotlib 22 | - pyyaml 23 | -------------------------------------------------------------------------------- /environment_novis.yaml: -------------------------------------------------------------------------------- 1 | name: droidenv 2 | channels: 3 | - rusty1s 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - pytorch-scatter 10 | - torchaudio 11 | - torchvision 12 | - pytorch=1.10 13 | - cudatoolkit=11.3 14 | - tensorboard 15 | - scipy 16 | - opencv 17 | - tqdm 18 | - suitesparse 19 | - matplotlib 20 | - pyyaml 21 | -------------------------------------------------------------------------------- /evaluation_scripts/parse_results.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import numpy as np 4 | 5 | def extract_rmse_from_file(filepath): 6 | rmse_values = [] 7 | # Regular expression to match 'rmse': 8 | rmse_pattern = re.compile(r'^\s*rmse\s+([0-9]+\.[0-9]+)') 9 | 10 | with open(filepath, 'r') as file: 11 | for line in file: 12 | match = rmse_pattern.search(line) 13 | if match: 14 | try: 15 | rmse = float(match.group(1)) 16 | rmse_values.append(rmse) 17 | except ValueError: 18 | print(f"Skipping invalid float: {match.group(1)}") 19 | 20 | return rmse_values 21 | 22 | 23 | # Example usage: 24 | filepath = sys.argv[1] 25 | rmse_list = extract_rmse_from_file(filepath) 26 | 27 | arr = 100 * np.array(rmse_list) 28 | print("rmse auc 2cm", np.sum(np.clip(2.0 - arr, 0.0, None))) 29 | print("rmse auc 8cm", np.sum(np.clip(8.0 - arr, 0.0, None))) 30 | 31 | print() 32 | print("Listing RMSE") 33 | for rmse in rmse_list: 34 | print(rmse) 35 | 36 | print(f"Average: {np.mean(rmse_list)}") -------------------------------------------------------------------------------- /evaluation_scripts/test_eth3d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | 14 | import torch.nn.functional 
as F 15 | from droid import Droid 16 | from droid_async import DroidAsync 17 | 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def show_image(image): 22 | image = image.permute(1, 2, 0).cpu().numpy() 23 | cv2.imshow('image', image / 255.0) 24 | cv2.waitKey(1) 25 | 26 | def image_stream(datapath, use_depth=False, stride=1): 27 | """ image generator """ 28 | 29 | fx, fy, cx, cy = np.loadtxt(os.path.join(datapath, 'calibration.txt')).tolist() 30 | image_list = sorted(glob.glob(os.path.join(datapath, 'rgb', '*.png')))[::stride] 31 | depth_list = sorted(glob.glob(os.path.join(datapath, 'depth', '*.png')))[::stride] 32 | 33 | for t, (image_file, depth_file) in enumerate(zip(image_list, depth_list)): 34 | image = cv2.imread(image_file) 35 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH) / 5000.0 36 | 37 | h0, w0, _ = image.shape 38 | h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) 39 | w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) 40 | 41 | image = cv2.resize(image, (w1, h1)) 42 | image = image[:h1-h1%8, :w1-w1%8] 43 | image = torch.as_tensor(image).permute(2, 0, 1) 44 | 45 | depth = torch.as_tensor(depth) 46 | depth = F.interpolate(depth[None,None], (h1, w1)).squeeze() 47 | depth = depth[:h1-h1%8, :w1-w1%8] 48 | 49 | intrinsics = torch.as_tensor([fx, fy, cx, cy]) 50 | intrinsics[0::2] *= (w1 / w0) 51 | intrinsics[1::2] *= (h1 / h0) 52 | 53 | if use_depth: 54 | yield t, image[None], depth, intrinsics 55 | 56 | else: 57 | yield t, image[None], intrinsics 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--datapath") 62 | parser.add_argument("--weights", default="droid.pth") 63 | parser.add_argument("--buffer", type=int, default=1024) 64 | parser.add_argument("--image_size", default=[240, 320]) 65 | parser.add_argument("--disable_vis", action="store_true") 66 | 67 | parser.add_argument("--beta", type=float, default=0.5) 68 | parser.add_argument("--filter_thresh", type=float, default=2.0) 69 | parser.add_argument("--warmup", type=int, default=8) 70 | parser.add_argument("--keyframe_thresh", type=float, default=3.5) 71 | parser.add_argument("--frontend_thresh", type=float, default=16.0) 72 | parser.add_argument("--frontend_window", type=int, default=20) 73 | parser.add_argument("--frontend_radius", type=int, default=2) 74 | parser.add_argument("--frontend_nms", type=int, default=1) 75 | 76 | parser.add_argument("--stereo", action="store_true") 77 | parser.add_argument("--depth", action="store_true") 78 | 79 | parser.add_argument("--backend_thresh", type=float, default=22.0) 80 | parser.add_argument("--backend_radius", type=int, default=2) 81 | parser.add_argument("--backend_nms", type=int, default=3) 82 | parser.add_argument("--motion_damping", type=float, default=0.5) 83 | 84 | parser.add_argument("--upsample", action="store_true") 85 | parser.add_argument("--asynchronous", action="store_true") 86 | parser.add_argument("--frontend_device", type=str, default="cuda") 87 | parser.add_argument("--backend_device", type=str, default="cuda") 88 | 89 | 90 | args = parser.parse_args() 91 | 92 | torch.multiprocessing.set_start_method('spawn') 93 | 94 | print("Running evaluation on {}".format(args.datapath)) 95 | print(args) 96 | 97 | # this can usually be set to 2-3 except for "camera_shake" scenes 98 | # set to 2 for test scenes 99 | stride = 1 100 | 101 | tstamps = [] 102 | for (t, image, depth, intrinsics) in tqdm(image_stream(args.datapath, use_depth=True, stride=stride)): 103 | if not args.disable_vis: 104 | show_image(image[0]) 105 | 106 | 
if t == 0: 107 | args.image_size = [image.shape[2], image.shape[3]] 108 | droid = DroidAsync(args) if args.asynchronous else Droid(args) 109 | 110 | droid.track(t, image, depth, intrinsics=intrinsics) 111 | 112 | traj_est = droid.terminate(image_stream(args.datapath, use_depth=False, stride=stride)) 113 | 114 | ### run evaluation ### 115 | 116 | print("#"*20 + " Results...") 117 | 118 | import evo 119 | from evo.core.trajectory import PoseTrajectory3D 120 | from evo.tools import file_interface 121 | from evo.core import sync 122 | import evo.main_ape as main_ape 123 | from evo.core.metrics import PoseRelation 124 | 125 | image_path = os.path.join(args.datapath, 'rgb') 126 | images_list = sorted(glob.glob(os.path.join(image_path, '*.png')))[::stride] 127 | tstamps = [float(x.split('/')[-1][:-4]) for x in images_list] 128 | 129 | traj_est = PoseTrajectory3D( 130 | positions_xyz=traj_est[:,:3], 131 | orientations_quat_wxyz=traj_est[:,3:], 132 | timestamps=np.array(tstamps)) 133 | 134 | gt_file = os.path.join(args.datapath, 'groundtruth.txt') 135 | traj_ref = file_interface.read_tum_trajectory_file(gt_file) 136 | 137 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) 138 | 139 | result = main_ape.ape(traj_ref, traj_est, est_name='traj', 140 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=False) 141 | 142 | print(result) 143 | 144 | -------------------------------------------------------------------------------- /evaluation_scripts/test_euroc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | 14 | from pathlib import Path 15 | from torch.multiprocessing import Process 16 | from droid import Droid 17 | from droid_async import DroidAsync 18 | 19 | 20 | import torch.nn.functional as F 21 | 22 | 23 | def show_image(image): 24 | image = image.permute(1, 2, 0).cpu().numpy() 25 | cv2.imshow('image', image / 255.0) 26 | cv2.waitKey(1) 27 | 28 | def image_stream(datapath, image_size=[320, 512], stereo=False, stride=1): 29 | """ image generator """ 30 | 31 | K_l = np.array([458.654, 0.0, 367.215, 0.0, 457.296, 248.375, 0.0, 0.0, 1.0]).reshape(3,3) 32 | d_l = np.array([-0.28340811, 0.07395907, 0.00019359, 1.76187114e-05, 0.0]) 33 | R_l = np.array([ 34 | 0.999966347530033, -0.001422739138722922, 0.008079580483432283, 35 | 0.001365741834644127, 0.9999741760894847, 0.007055629199258132, 36 | -0.008089410156878961, -0.007044357138835809, 0.9999424675829176 37 | ]).reshape(3,3) 38 | 39 | P_l = np.array([435.2046959714599, 0, 367.4517211914062, 0, 0, 435.2046959714599, 252.2008514404297, 0, 0, 0, 1, 0]).reshape(3,4) 40 | map_l = cv2.initUndistortRectifyMap(K_l, d_l, R_l, P_l[:3,:3], (752, 480), cv2.CV_32F) 41 | 42 | K_r = np.array([457.587, 0.0, 379.999, 0.0, 456.134, 255.238, 0.0, 0.0, 1]).reshape(3,3) 43 | d_r = np.array([-0.28368365, 0.07451284, -0.00010473, -3.555907e-05, 0.0]).reshape(5) 44 | R_r = np.array([ 45 | 0.9999633526194376, -0.003625811871560086, 0.007755443660172947, 46 | 0.003680398547259526, 0.9999684752771629, -0.007035845251224894, 47 | -0.007729688520722713, 0.007064130529506649, 0.999945173484644 48 | ]).reshape(3,3) 49 | 50 | P_r = np.array([435.2046959714599, 0, 367.4517211914062, -47.90639384423901, 0, 435.2046959714599, 252.2008514404297, 0, 0, 0, 1, 0]).reshape(3,4) 51 | map_r = 
cv2.initUndistortRectifyMap(K_r, d_r, R_r, P_r[:3,:3], (752, 480), cv2.CV_32F) 52 | 53 | intrinsics_vec = [435.2046959714599, 435.2046959714599, 367.4517211914062, 252.2008514404297] 54 | ht0, wd0 = [480, 752] 55 | 56 | # read all png images in folder 57 | images_left = sorted(glob.glob(os.path.join(datapath, 'mav0/cam0/data/*.png')))[::stride] 58 | images_right = [x.replace('cam0', 'cam1') for x in images_left] 59 | 60 | data_list = [] 61 | for t, (imgL, imgR) in enumerate(zip(images_left, images_right)): 62 | if stereo and not os.path.isfile(imgR): 63 | continue 64 | tstamp = float(imgL.split('/')[-1][:-4]) 65 | images = [cv2.remap(cv2.imread(imgL), map_l[0], map_l[1], interpolation=cv2.INTER_LINEAR)] 66 | if stereo: 67 | images += [cv2.remap(cv2.imread(imgR), map_r[0], map_r[1], interpolation=cv2.INTER_LINEAR)] 68 | 69 | images = [cv2.resize(image, (image_size[1], image_size[0])) for image in images] 70 | images = torch.from_numpy(np.stack(images, 0)) 71 | images = images.permute(0, 3, 1, 2).to(dtype=torch.float32) 72 | 73 | intrinsics = torch.as_tensor(intrinsics_vec) 74 | intrinsics[0] *= image_size[1] / wd0 75 | intrinsics[1] *= image_size[0] / ht0 76 | intrinsics[2] *= image_size[1] / wd0 77 | intrinsics[3] *= image_size[0] / ht0 78 | 79 | data_list.append((stride*t, images, intrinsics)) 80 | 81 | return data_list 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--datapath", help="path to euroc sequence") 87 | parser.add_argument("--gt", help="path to gt file") 88 | parser.add_argument("--weights", default="droid.pth") 89 | parser.add_argument("--buffer", type=int, default=512) 90 | parser.add_argument("--image_size", default=[320,512]) 91 | parser.add_argument("--disable_vis", action="store_true") 92 | parser.add_argument("--stereo", action="store_true") 93 | 94 | parser.add_argument("--beta", type=float, default=0.3) 95 | parser.add_argument("--filter_thresh", type=float, default=2.4) 96 | parser.add_argument("--warmup", type=int, default=15) 97 | parser.add_argument("--keyframe_thresh", type=float, default=3.0) 98 | parser.add_argument("--frontend_thresh", type=float, default=17.5) 99 | parser.add_argument("--frontend_window", type=int, default=20) 100 | parser.add_argument("--frontend_radius", type=int, default=2) 101 | parser.add_argument("--frontend_nms", type=int, default=1) 102 | 103 | parser.add_argument("--backend_thresh", type=float, default=24.0) 104 | parser.add_argument("--backend_radius", type=int, default=2) 105 | parser.add_argument("--backend_nms", type=int, default=2) 106 | 107 | parser.add_argument("--upsample", action="store_true") 108 | parser.add_argument("--asynchronous", action="store_true") 109 | parser.add_argument("--frontend_device", type=str, default="cuda") 110 | parser.add_argument("--backend_device", type=str, default="cuda") 111 | args = parser.parse_args() 112 | 113 | torch.multiprocessing.set_start_method('spawn') 114 | 115 | print("Running evaluation on {}".format(args.datapath)) 116 | print(args) 117 | 118 | droid = DroidAsync(args) if args.asynchronous else Droid(args) 119 | scene = Path(args.datapath).name 120 | 121 | images = image_stream(args.datapath, stereo=args.stereo, stride=1) 122 | 123 | # run with stride 2 124 | for (t, image, intrinsics) in tqdm(images[::2], desc=scene): 125 | droid.track(t, image, intrinsics=intrinsics) 126 | 127 | # fill in missing poses with stride 1 128 | traj_est = droid.terminate(images) 129 | 130 | ### run evaluation ### 131 | 132 | import evo 133 | 
from evo.core.trajectory import PoseTrajectory3D 134 | from evo.tools import file_interface 135 | from evo.core import sync 136 | import evo.main_ape as main_ape 137 | from evo.core.metrics import PoseRelation 138 | 139 | images_list = sorted(glob.glob(os.path.join(args.datapath, 'mav0/cam0/data/*.png'))) 140 | tstamps = [float(x.split('/')[-1][:-4]) for x in images_list] 141 | 142 | traj_est = PoseTrajectory3D( 143 | positions_xyz=1.10 * traj_est[:,:3], 144 | orientations_quat_wxyz=traj_est[:,3:], 145 | timestamps=np.array(tstamps)) 146 | 147 | traj_ref = file_interface.read_tum_trajectory_file(args.gt) 148 | 149 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) 150 | 151 | result = main_ape.ape(traj_ref, traj_est, est_name='traj', 152 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True) 153 | 154 | print(result) 155 | 156 | 157 | -------------------------------------------------------------------------------- /evaluation_scripts/test_tartanair.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | sys.path.append('thirdparty/tartanair_tools') 4 | 5 | from evaluation.tartanair_evaluator import TartanAirEvaluator 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | import lietorch 11 | import cv2 12 | import os 13 | import glob 14 | import time 15 | import yaml 16 | import argparse 17 | 18 | from droid import Droid 19 | from droid_async import DroidAsync 20 | 21 | # camera baseline hardcoded to 0.1m 22 | STEREO_SCALE_FACTOR = 2.5 23 | 24 | MONO_TEST_SCENES = [f"M{s}{i:03d}" for s in ["E", "H"] for i in range(8)] 25 | STEREO_TEST_SCENES = [f"S{s}{i:03d}" for s in ["E", "H"] for i in range(8)] 26 | 27 | 28 | def image_stream(datapath, image_size=[384, 512], intrinsics_vec=[320.0, 320.0, 320.0, 240.0], stereo=False): 29 | """ image generator """ 30 | 31 | # read all png images in folder 32 | ht0, wd0 = [480, 640] 33 | 34 | if stereo: 35 | images_left = sorted(glob.glob(os.path.join(datapath, 'image_left/*.png'))) 36 | images_right = sorted(glob.glob(os.path.join(datapath, 'image_right/*.png'))) 37 | 38 | else: 39 | if os.path.exists(os.path.join(datapath, "image_left")): 40 | images_left = sorted(glob.glob(os.path.join(datapath, 'image_left/*.png'))) 41 | else: 42 | images_left = sorted(glob.glob(os.path.join(datapath, '*.png'))) 43 | 44 | data = [] 45 | for t in range(len(images_left)): 46 | images = [ cv2.resize(cv2.imread(images_left[t]), (image_size[1], image_size[0])) ] 47 | if stereo: 48 | images += [ cv2.resize(cv2.imread(images_right[t]), (image_size[1], image_size[0])) ] 49 | 50 | images = torch.from_numpy(np.stack(images, 0)).permute(0,3,1,2) 51 | intrinsics = .8 * torch.as_tensor(intrinsics_vec) 52 | 53 | data.append((t, images, intrinsics)) 54 | 55 | return data 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--datapath") 61 | parser.add_argument("--gt_path") 62 | 63 | parser.add_argument("--weights", default="droid.pth") 64 | parser.add_argument("--buffer", type=int, default=1000) 65 | parser.add_argument("--image_size", default=[384,512]) 66 | parser.add_argument("--stereo", action="store_true") 67 | parser.add_argument("--disable_vis", action="store_true") 68 | parser.add_argument("--plot_curve", action="store_true") 69 | parser.add_argument("--scene", type=str) 70 | 71 | parser.add_argument("--beta", type=float, default=0.3) 72 | parser.add_argument("--filter_thresh", 
type=float, default=2.5) 73 | parser.add_argument("--warmup", type=int, default=12) 74 | parser.add_argument("--keyframe_thresh", type=float, default=3.0) 75 | parser.add_argument("--frontend_thresh", type=float, default=15) 76 | parser.add_argument("--frontend_window", type=int, default=20) 77 | parser.add_argument("--frontend_radius", type=int, default=1) 78 | parser.add_argument("--frontend_nms", type=int, default=1) 79 | 80 | parser.add_argument("--backend_thresh", type=float, default=20.0) 81 | parser.add_argument("--backend_radius", type=int, default=2) 82 | parser.add_argument("--backend_nms", type=int, default=3) 83 | 84 | # damped linear velocity model 85 | parser.add_argument("--motion_damping", type=float, default=0.5) 86 | 87 | parser.add_argument("--upsample", action="store_true") 88 | parser.add_argument("--asynchronous", action="store_true") 89 | parser.add_argument("--frontend_device", type=str, default="cuda") 90 | parser.add_argument("--backend_device", type=str, default="cuda") 91 | 92 | args = parser.parse_args() 93 | torch.multiprocessing.set_start_method('spawn') 94 | 95 | 96 | if not os.path.isdir("figures"): 97 | os.mkdir("figures") 98 | 99 | test_scenes = STEREO_TEST_SCENES if args.stereo else MONO_TEST_SCENES 100 | 101 | # evaluate on a specific scene 102 | if args.scene is not None: 103 | test_scenes = [args.scene] 104 | 105 | ate_list = [] 106 | for scene in test_scenes: 107 | print("Performing evaluation on {}".format(scene)) 108 | torch.cuda.empty_cache() 109 | 110 | droid = DroidAsync(args) if args.asynchronous else Droid(args) 111 | 112 | scenedir = os.path.join(args.datapath, scene) 113 | gt_file = os.path.join(args.gt_path, f"{scene}.txt") 114 | 115 | for (tstamp, image, intrinsics) in tqdm(image_stream(scenedir, stereo=args.stereo), desc=scene): 116 | droid.track(tstamp, image, intrinsics=intrinsics) 117 | 118 | # fill in non-keyframe poses + global BA 119 | traj_est = droid.terminate(image_stream(scenedir)) 120 | 121 | if args.stereo: 122 | traj_est[:, :3] *= STEREO_SCALE_FACTOR 123 | 124 | ### do evaluation ### 125 | evaluator = TartanAirEvaluator() 126 | traj_ref = np.loadtxt(gt_file, delimiter=' ')[:, [1, 2, 0, 4, 5, 3, 6]] # ned -> xyz 127 | 128 | # usually stereo should not be scale corrected, but we are comparing monocular and stereo here 129 | results = evaluator.evaluate_one_trajectory( 130 | traj_ref, traj_est, scale=not args.stereo, title=scenedir[-20:].replace('/', '_')) 131 | 132 | print(results) 133 | ate_list.append(results["ate_score"]) 134 | 135 | print("Results") 136 | for (scene, ate) in zip(test_scenes, ate_list): 137 | print(f"{scene}: {ate}") 138 | 139 | print(ate_list) 140 | print("Mean ATE", np.mean(ate_list)) 141 | 142 | if args.plot_curve: 143 | import matplotlib.pyplot as plt 144 | ate = np.array(ate_list) 145 | xs = np.linspace(0.0, 1.0, 512) 146 | ys = [np.count_nonzero(ate < t) / ate.shape[0] for t in xs] 147 | 148 | plt.plot(xs, ys) 149 | plt.xlabel("ATE [m]") 150 | plt.ylabel("% runs") 151 | plt.show() 152 | 153 | -------------------------------------------------------------------------------- /evaluation_scripts/test_tum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | from pathlib import Path 14 | 15 | import torch.nn.functional as F 16 | from droid import Droid 17 
| from droid_async import DroidAsync 18 | 19 | def show_image(image): 20 | image = image.permute(1, 2, 0).cpu().numpy() 21 | cv2.imshow('image', image / 255.0) 22 | cv2.waitKey(1) 23 | 24 | def image_stream(datapath): 25 | """ image generator """ 26 | 27 | fx, fy, cx, cy = 517.3, 516.5, 318.6, 255.3 28 | 29 | K_l = np.array([fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0]).reshape(3,3) 30 | d_l = np.array([0.2624, -0.9531, -0.0054, 0.0026, 1.1633]) 31 | 32 | # read all png images in folder 33 | images_list = sorted(glob.glob(os.path.join(datapath, 'rgb', '*.png')))[::2] 34 | 35 | data_list = [] 36 | for t, imfile in enumerate(images_list): 37 | image = cv2.imread(imfile) 38 | ht0, wd0, _ = image.shape 39 | image = cv2.undistort(image, K_l, d_l) 40 | image = cv2.resize(image, (320+32, 240+16)) 41 | image = torch.from_numpy(image).permute(2,0,1) 42 | 43 | intrinsics = torch.as_tensor([fx, fy, cx, cy]) 44 | intrinsics[0] *= image.shape[2] / 640.0 45 | intrinsics[1] *= image.shape[1] / 480.0 46 | intrinsics[2] *= image.shape[2] / 640.0 47 | intrinsics[3] *= image.shape[1] / 480.0 48 | 49 | # crop image to remove distortion boundary 50 | intrinsics[2] -= 16 51 | intrinsics[3] -= 8 52 | image = image[:, 8:-8, 16:-16] 53 | 54 | data_list.append((t, image[None], intrinsics)) 55 | 56 | return data_list 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--datapath") 61 | parser.add_argument("--weights", default="droid.pth") 62 | parser.add_argument("--buffer", type=int, default=512) 63 | parser.add_argument("--image_size", default=[240, 320]) 64 | parser.add_argument("--disable_vis", action="store_true") 65 | 66 | parser.add_argument("--beta", type=float, default=0.3) 67 | parser.add_argument("--filter_thresh", type=float, default=1.5) 68 | parser.add_argument("--warmup", type=int, default=12) 69 | parser.add_argument("--keyframe_thresh", type=float, default=2.0) 70 | parser.add_argument("--frontend_thresh", type=float, default=12.0) 71 | parser.add_argument("--frontend_window", type=int, default=25) 72 | parser.add_argument("--frontend_radius", type=int, default=2) 73 | parser.add_argument("--frontend_nms", type=int, default=1) 74 | 75 | parser.add_argument("--backend_thresh", type=float, default=20.0) 76 | parser.add_argument("--backend_radius", type=int, default=2) 77 | parser.add_argument("--backend_nms", type=int, default=3) 78 | 79 | parser.add_argument("--upsample", action="store_true") 80 | 81 | parser.add_argument("--asynchronous", action="store_true") 82 | parser.add_argument("--frontend_device", type=str, default="cuda") 83 | parser.add_argument("--backend_device", type=str, default="cuda") 84 | parser.add_argument("--motion_damping", type=float, default=0.5) 85 | 86 | args = parser.parse_args() 87 | 88 | args.stereo = False 89 | torch.multiprocessing.set_start_method('spawn') 90 | 91 | print("Running evaluation on {}".format(args.datapath)) 92 | print(args) 93 | 94 | droid = DroidAsync(args) if args.asynchronous else Droid(args) 95 | scene = Path(args.datapath).name 96 | 97 | tstamps = [] 98 | images = image_stream(args.datapath) 99 | 100 | for (t, image, intrinsics) in tqdm(images, desc=scene): 101 | if not args.disable_vis: 102 | show_image(image) 103 | droid.track(t, image, intrinsics=intrinsics) 104 | 105 | traj_est = droid.terminate(images) 106 | 107 | ### run evaluation ### 108 | 109 | print("#"*20 + " Results...") 110 | 111 | import evo 112 | from evo.core.trajectory import PoseTrajectory3D 113 | from evo.tools import file_interface 114 
| from evo.core import sync 115 | import evo.main_ape as main_ape 116 | from evo.core.metrics import PoseRelation 117 | 118 | image_path = os.path.join(args.datapath, 'rgb') 119 | images_list = sorted(glob.glob(os.path.join(image_path, '*.png')))[::2] 120 | tstamps = [float(x.split('/')[-1][:-4]) for x in images_list] 121 | 122 | traj_est = PoseTrajectory3D( 123 | positions_xyz=traj_est[:,:3], 124 | orientations_quat_wxyz=traj_est[:,3:], 125 | timestamps=np.array(tstamps)) 126 | 127 | gt_file = os.path.join(args.datapath, 'groundtruth.txt') 128 | traj_ref = file_interface.read_tum_trajectory_file(gt_file) 129 | 130 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) 131 | result = main_ape.ape(traj_ref, traj_est, est_name='traj', 132 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True) 133 | 134 | 135 | print(result) 136 | 137 | -------------------------------------------------------------------------------- /evaluation_scripts/validate_tartanair.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | sys.path.append('thirdparty/tartanair_tools') 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import lietorch 9 | import cv2 10 | import os 11 | import glob 12 | import time 13 | import yaml 14 | import argparse 15 | 16 | from droid import Droid 17 | 18 | def image_stream(datapath, image_size=[384, 512], intrinsics_vec=[320.0, 320.0, 320.0, 240.0], stereo=False): 19 | """ image generator """ 20 | 21 | # read all png images in folder 22 | ht0, wd0 = [480, 640] 23 | images_left = sorted(glob.glob(os.path.join(datapath, 'image_left/*.png'))) 24 | images_right = sorted(glob.glob(os.path.join(datapath, 'image_right/*.png'))) 25 | 26 | data = [] 27 | for t in range(len(images_left)): 28 | images = [ cv2.resize(cv2.imread(images_left[t]), (image_size[1], image_size[0])) ] 29 | if stereo: 30 | images += [ cv2.resize(cv2.imread(images_right[t]), (image_size[1], image_size[0])) ] 31 | 32 | images = torch.from_numpy(np.stack(images, 0)).permute(0,3,1,2) 33 | intrinsics = .8 * torch.as_tensor(intrinsics_vec) 34 | 35 | data.append((t, images, intrinsics)) 36 | 37 | return data 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--datapath", default="datasets/TartanAir") 43 | parser.add_argument("--weights", default="droid.pth") 44 | parser.add_argument("--buffer", type=int, default=1000) 45 | parser.add_argument("--image_size", default=[384,512]) 46 | parser.add_argument("--stereo", action="store_true") 47 | parser.add_argument("--disable_vis", action="store_true") 48 | parser.add_argument("--plot_curve", action="store_true") 49 | parser.add_argument("--id", type=int, default=-1) 50 | 51 | parser.add_argument("--beta", type=float, default=0.3) 52 | parser.add_argument("--filter_thresh", type=float, default=2.4) 53 | parser.add_argument("--warmup", type=int, default=12) 54 | parser.add_argument("--keyframe_thresh", type=float, default=3.5) 55 | parser.add_argument("--frontend_thresh", type=float, default=15) 56 | parser.add_argument("--frontend_window", type=int, default=20) 57 | parser.add_argument("--frontend_radius", type=int, default=1) 58 | parser.add_argument("--frontend_nms", type=int, default=1) 59 | 60 | parser.add_argument("--backend_thresh", type=float, default=20.0) 61 | parser.add_argument("--backend_radius", type=int, default=2) 62 | parser.add_argument("--backend_nms", type=int, 
default=3) 63 | parser.add_argument("--upsample", action="store_true") 64 | 65 | args = parser.parse_args() 66 | torch.multiprocessing.set_start_method('spawn') 67 | 68 | from data_readers.tartan import test_split 69 | from evaluation.tartanair_evaluator import TartanAirEvaluator 70 | 71 | if not os.path.isdir("figures"): 72 | os.mkdir("figures") 73 | 74 | if args.id >= 0: 75 | test_split = [ test_split[args.id] ] 76 | 77 | ate_list = [] 78 | for scene in test_split: 79 | print("Performing evaluation on {}".format(scene)) 80 | torch.cuda.empty_cache() 81 | droid = Droid(args) 82 | 83 | scenedir = os.path.join(args.datapath, scene) 84 | 85 | for (tstamp, image, intrinsics) in tqdm(image_stream(scenedir, stereo=args.stereo)): 86 | droid.track(tstamp, image, intrinsics=intrinsics) 87 | 88 | # fill in non-keyframe poses + global BA 89 | traj_est = droid.terminate(image_stream(scenedir)) 90 | 91 | ### do evaluation ### 92 | evaluator = TartanAirEvaluator() 93 | gt_file = os.path.join(scenedir, "pose_left.txt") 94 | traj_ref = np.loadtxt(gt_file, delimiter=' ')[:, [1, 2, 0, 4, 5, 3, 6]] # ned -> xyz 95 | 96 | # usually stereo should not be scale corrected, but we are comparing monocular and stereo here 97 | results = evaluator.evaluate_one_trajectory( 98 | traj_ref, traj_est, scale=True, title=scenedir[-20:].replace('/', '_')) 99 | 100 | print(results) 101 | ate_list.append(results["ate_score"]) 102 | 103 | print("Results") 104 | print(ate_list) 105 | 106 | if args.plot_curve: 107 | import matplotlib.pyplot as plt 108 | ate = np.array(ate_list) 109 | xs = np.linspace(0.0, 1.0, 512) 110 | ys = [np.count_nonzero(ate < t) / ate.shape[0] for t in xs] 111 | 112 | plt.plot(xs, ys) 113 | plt.xlabel("ATE [m]") 114 | plt.ylabel("% runs") 115 | plt.show() 116 | 117 | -------------------------------------------------------------------------------- /misc/DROID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/misc/DROID.png -------------------------------------------------------------------------------- /misc/renderoption.json: -------------------------------------------------------------------------------- 1 | { 2 | "background_color" : [ 1, 1, 1 ], 3 | "class_name" : "RenderOption", 4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], 5 | "image_max_depth" : 3000, 6 | "image_stretch_option" : 0, 7 | "interpolation_option" : 0, 8 | "light0_color" : [ 1, 1, 1 ], 9 | "light0_diffuse_power" : 20, 10 | "light0_position" : [ 0, 0, 20 ], 11 | "light0_specular_power" : 2.20000000000000001, 12 | "light0_specular_shininess" : 100, 13 | "light1_color" : [ 1, 1, 1 ], 14 | "light1_diffuse_power" : 0.66000000000000003, 15 | "light1_position" : [ 0, 0, 2 ], 16 | "light1_specular_power" : 2.20000000000000001, 17 | "light1_specular_shininess" : 100, 18 | "light2_color" : [ 1, 1, 1 ], 19 | "light2_diffuse_power" : 20, 20 | "light2_position" : [ 0, 0, -20 ], 21 | "light2_specular_power" : 2.20000000000000001, 22 | "light2_specular_shininess" : 100, 23 | "light3_color" : [ 1, 1, 1 ], 24 | "light3_diffuse_power" : 20, 25 | "light3_position" : [ 0, 0, -20 ], 26 | "light3_specular_power" : 2.20000000000000001, 27 | "light3_specular_shininess" : 100, 28 | "light_ambient_color" : [ 0, 0, 0 ], 29 | "light_on" : true, 30 | "mesh_color_option" : 1, 31 | "mesh_shade_option" : 0, 32 | "mesh_show_back_face" : false, 33 | "mesh_show_wireframe" : false, 34 | 
"point_color_option" : 7, 35 | "point_show_normal" : false, 36 | "point_size" : 2, 37 | "show_coordinate_frame" : false, 38 | "version_major" : 1, 39 | "version_minor" : 0 40 | } 41 | -------------------------------------------------------------------------------- /misc/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/misc/screenshot.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | wheel 5 | tqdm 6 | evo 7 | scipy 8 | open3d 9 | gdown 10 | tensorboard 11 | opencv-python 12 | pyyaml -------------------------------------------------------------------------------- /requirements_frozen.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.2.2 2 | addict==2.4.0 3 | argcomplete==3.6.2 4 | asttokens==3.0.0 5 | attrs==25.3.0 6 | beautifulsoup4==4.13.4 7 | blinker==1.9.0 8 | certifi==2025.4.26 9 | charset-normalizer==3.4.2 10 | click==8.1.8 11 | colorama==0.4.6 12 | comm==0.2.2 13 | ConfigArgParse==1.7 14 | contourpy==1.3.2 15 | cycler==0.12.1 16 | dash==3.0.4 17 | decorator==5.2.1 18 | evo==1.31.1 19 | exceptiongroup==1.2.2 20 | executing==2.2.0 21 | fastjsonschema==2.21.1 22 | filelock==3.18.0 23 | Flask==3.0.3 24 | fonttools==4.57.0 25 | fsspec==2025.3.2 26 | gdown==5.2.0 27 | grpcio==1.71.0 28 | idna==3.10 29 | importlib_metadata==8.7.0 30 | ipython==8.36.0 31 | ipywidgets==8.1.6 32 | itsdangerous==2.2.0 33 | jedi==0.19.2 34 | Jinja2==3.1.6 35 | joblib==1.5.0 36 | jsonschema==4.23.0 37 | jsonschema-specifications==2025.4.1 38 | jupyter_core==5.7.2 39 | jupyterlab_widgets==3.0.14 40 | kiwisolver==1.4.8 41 | lz4==4.4.4 42 | Markdown==3.8 43 | MarkupSafe==3.0.2 44 | matplotlib==3.10.1 45 | matplotlib-inline==0.1.7 46 | mpmath==1.3.0 47 | narwhals==1.37.1 48 | natsort==8.4.0 49 | nbformat==5.10.4 50 | nest-asyncio==1.6.0 51 | networkx==3.4.2 52 | numexpr==2.10.2 53 | numpy==2.2.5 54 | nvidia-cublas-cu12==12.6.4.1 55 | nvidia-cuda-cupti-cu12==12.6.80 56 | nvidia-cuda-nvrtc-cu12==12.6.77 57 | nvidia-cuda-runtime-cu12==12.6.77 58 | nvidia-cudnn-cu12==9.5.1.17 59 | nvidia-cufft-cu12==11.3.0.4 60 | nvidia-cufile-cu12==1.11.1.6 61 | nvidia-curand-cu12==10.3.7.77 62 | nvidia-cusolver-cu12==11.7.1.2 63 | nvidia-cusparse-cu12==12.5.4.2 64 | nvidia-cusparselt-cu12==0.6.3 65 | nvidia-nccl-cu12==2.26.2 66 | nvidia-nvjitlink-cu12==12.6.85 67 | nvidia-nvtx-cu12==12.6.77 68 | open3d==0.19.0 69 | opencv-python==4.11.0.86 70 | packaging==25.0 71 | pandas==2.2.3 72 | parso==0.8.4 73 | pexpect==4.9.0 74 | pillow==11.2.1 75 | platformdirs==4.3.7 76 | plotly==6.0.1 77 | prompt_toolkit==3.0.51 78 | protobuf==6.30.2 79 | ptyprocess==0.7.0 80 | pure_eval==0.2.3 81 | Pygments==2.19.1 82 | pyparsing==3.2.3 83 | pyquaternion==0.9.9 84 | PySocks==1.7.1 85 | python-dateutil==2.9.0.post0 86 | pytz==2025.2 87 | PyYAML==6.0.2 88 | referencing==0.36.2 89 | requests==2.32.3 90 | retrying==1.3.4 91 | rosbags==0.10.9 92 | rpds-py==0.24.0 93 | ruamel.yaml==0.18.10 94 | ruamel.yaml.clib==0.2.12 95 | scikit-learn==1.6.1 96 | scipy==1.15.2 97 | seaborn==0.13.2 98 | six==1.17.0 99 | soupsieve==2.7 100 | stack-data==0.6.3 101 | sympy==1.14.0 102 | tensorboard==2.19.0 103 | tensorboard-data-server==0.7.2 104 | threadpoolctl==3.6.0 105 | torch==2.7.0 106 | torchaudio==2.7.0 
107 | torchvision==0.22.0 108 | tqdm==4.67.1 109 | traitlets==5.14.3 110 | triton==3.3.0 111 | typing_extensions==4.13.2 112 | tzdata==2025.2 113 | urllib3==2.4.0 114 | wcwidth==0.2.13 115 | Werkzeug==3.0.6 116 | widgetsnbextension==4.0.14 117 | zipp==3.21.0 118 | zstandard==0.23.0 119 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | import os.path as osp 5 | ROOT = osp.dirname(osp.abspath(__file__)) 6 | 7 | setup( 8 | name='droid_backends', 9 | ext_modules=[ 10 | CUDAExtension('droid_backends', 11 | include_dirs=[osp.join(ROOT, 'thirdparty/lietorch/eigen')], 12 | sources=[ 13 | 'src/droid.cpp', 14 | 'src/droid_kernels.cu', 15 | 'src/correlation_kernels.cu', 16 | 'src/altcorr_kernel.cu', 17 | ], 18 | extra_compile_args={ 19 | 'cxx': ['-O3'], 20 | 'nvcc': ['-O3'], 21 | }), 22 | ], 23 | cmdclass={ 'build_ext' : BuildExtension } 24 | ) 25 | -------------------------------------------------------------------------------- /src/correlation_kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #define BLOCK 16 14 | 15 | 16 | __forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) { 17 | return h >= 0 && h < H && w >= 0 && w < W; 18 | } 19 | 20 | template 21 | __global__ void corr_index_forward_kernel( 22 | const torch::PackedTensorAccessor32 volume, 23 | const torch::PackedTensorAccessor32 coords, 24 | torch::PackedTensorAccessor32 corr, 25 | int r) 26 | { 27 | // batch index 28 | const int x = blockIdx.x * blockDim.x + threadIdx.x; 29 | const int y = blockIdx.y * blockDim.y + threadIdx.y; 30 | const int n = blockIdx.z; 31 | 32 | const int h1 = volume.size(1); 33 | const int w1 = volume.size(2); 34 | const int h2 = volume.size(3); 35 | const int w2 = volume.size(4); 36 | 37 | if (!within_bounds(y, x, h1, w1)) { 38 | return; 39 | } 40 | 41 | float x0 = coords[n][0][y][x]; 42 | float y0 = coords[n][1][y][x]; 43 | 44 | float dx = x0 - floor(x0); 45 | float dy = y0 - floor(y0); 46 | 47 | int rd = 2*r + 1; 48 | for (int i=0; i(floor(x0)) - r + i; 51 | int y1 = static_cast(floor(y0)) - r + j; 52 | 53 | if (within_bounds(y1, x1, h2, w2)) { 54 | scalar_t s = volume[n][y][x][y1][x1]; 55 | 56 | if (i > 0 && j > 0) 57 | corr[n][i-1][j-1][y][x] += s * scalar_t(dx * dy); 58 | 59 | if (i > 0 && j < rd) 60 | corr[n][i-1][j][y][x] += s * scalar_t(dx * (1.0f-dy)); 61 | 62 | if (i < rd && j > 0) 63 | corr[n][i][j-1][y][x] += s * scalar_t((1.0f-dx) * dy); 64 | 65 | if (i < rd && j < rd) 66 | corr[n][i][j][y][x] += s * scalar_t((1.0f-dx) * (1.0f-dy)); 67 | 68 | } 69 | } 70 | } 71 | } 72 | 73 | 74 | template 75 | __global__ void corr_index_backward_kernel( 76 | const torch::PackedTensorAccessor32 coords, 77 | const torch::PackedTensorAccessor32 corr_grad, 78 | torch::PackedTensorAccessor32 volume_grad, 79 | int r) 80 | { 81 | // batch index 82 | const int x = blockIdx.x * blockDim.x + threadIdx.x; 83 | const int y = blockIdx.y * blockDim.y + threadIdx.y; 84 | const int n = blockIdx.z; 85 | 86 | const int h1 = volume_grad.size(1); 87 | const int w1 = volume_grad.size(2); 88 | const int h2 = volume_grad.size(3); 89 | const int w2 = volume_grad.size(4); 90 | 91 | if (!within_bounds(y, x, h1, 
w1)) { 92 | return; 93 | } 94 | 95 | float x0 = coords[n][0][y][x]; 96 | float y0 = coords[n][1][y][x]; 97 | 98 | float dx = x0 - floor(x0); 99 | float dy = y0 - floor(y0); 100 | 101 | int rd = 2*r + 1; 102 | for (int i=0; i(floor(x0)) - r + i; 105 | int y1 = static_cast(floor(y0)) - r + j; 106 | 107 | if (within_bounds(y1, x1, h2, w2)) { 108 | scalar_t g = 0.0; 109 | if (i > 0 && j > 0) 110 | g += corr_grad[n][i-1][j-1][y][x] * scalar_t(dx * dy); 111 | 112 | if (i > 0 && j < rd) 113 | g += corr_grad[n][i-1][j][y][x] * scalar_t(dx * (1.0f-dy)); 114 | 115 | if (i < rd && j > 0) 116 | g += corr_grad[n][i][j-1][y][x] * scalar_t((1.0f-dx) * dy); 117 | 118 | if (i < rd && j < rd) 119 | g += corr_grad[n][i][j][y][x] * scalar_t((1.0f-dx) * (1.0f-dy)); 120 | 121 | volume_grad[n][y][x][y1][x1] += g; 122 | } 123 | } 124 | } 125 | } 126 | 127 | std::vector corr_index_cuda_forward( 128 | torch::Tensor volume, 129 | torch::Tensor coords, 130 | int radius) 131 | { 132 | const auto batch_size = volume.size(0); 133 | const auto ht = volume.size(1); 134 | const auto wd = volume.size(2); 135 | 136 | const dim3 blocks((wd + BLOCK - 1) / BLOCK, 137 | (ht + BLOCK - 1) / BLOCK, 138 | batch_size); 139 | 140 | const dim3 threads(BLOCK, BLOCK); 141 | 142 | auto opts = volume.options(); 143 | torch::Tensor corr = torch::zeros( 144 | {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts); 145 | 146 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] { 147 | corr_index_forward_kernel<<>>( 148 | volume.packed_accessor32(), 149 | coords.packed_accessor32(), 150 | corr.packed_accessor32(), 151 | radius); 152 | })); 153 | 154 | return {corr}; 155 | 156 | } 157 | 158 | std::vector corr_index_cuda_backward( 159 | torch::Tensor volume, 160 | torch::Tensor coords, 161 | torch::Tensor corr_grad, 162 | int radius) 163 | { 164 | const auto batch_size = volume.size(0); 165 | const auto ht = volume.size(1); 166 | const auto wd = volume.size(2); 167 | 168 | auto volume_grad = torch::zeros_like(volume); 169 | 170 | const dim3 blocks((wd + BLOCK - 1) / BLOCK, 171 | (ht + BLOCK - 1) / BLOCK, 172 | batch_size); 173 | 174 | const dim3 threads(BLOCK, BLOCK); 175 | 176 | 177 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] { 178 | corr_index_backward_kernel<<>>( 179 | coords.packed_accessor32(), 180 | corr_grad.packed_accessor32(), 181 | volume_grad.packed_accessor32(), 182 | radius); 183 | })); 184 | 185 | return {volume_grad}; 186 | } -------------------------------------------------------------------------------- /src/droid.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector projective_transform_cuda( 6 | torch::Tensor poses, 7 | torch::Tensor disps, 8 | torch::Tensor intrinsics, 9 | torch::Tensor ii, 10 | torch::Tensor jj); 11 | 12 | 13 | 14 | torch::Tensor depth_filter_cuda( 15 | torch::Tensor poses, 16 | torch::Tensor disps, 17 | torch::Tensor intrinsics, 18 | torch::Tensor ix, 19 | torch::Tensor thresh); 20 | 21 | 22 | torch::Tensor frame_distance_cuda( 23 | torch::Tensor poses, 24 | torch::Tensor disps, 25 | torch::Tensor intrinsics, 26 | torch::Tensor ii, 27 | torch::Tensor jj, 28 | const float beta); 29 | 30 | std::vector projmap_cuda( 31 | torch::Tensor poses, 32 | torch::Tensor disps, 33 | torch::Tensor intrinsics, 34 | torch::Tensor ii, 35 | torch::Tensor jj); 36 | 37 | torch::Tensor iproj_cuda( 38 | torch::Tensor poses, 39 | 
torch::Tensor disps, 40 | torch::Tensor intrinsics); 41 | 42 | std::vector ba_cuda( 43 | torch::Tensor poses, 44 | torch::Tensor disps, 45 | torch::Tensor intrinsics, 46 | torch::Tensor disps_sens, 47 | torch::Tensor targets, 48 | torch::Tensor weights, 49 | torch::Tensor eta, 50 | torch::Tensor ii, 51 | torch::Tensor jj, 52 | const int t0, 53 | const int t1, 54 | const int iterations, 55 | const float lm, 56 | const float ep, 57 | const bool motion_only); 58 | 59 | std::vector corr_index_cuda_forward( 60 | torch::Tensor volume, 61 | torch::Tensor coords, 62 | int radius); 63 | 64 | std::vector corr_index_cuda_backward( 65 | torch::Tensor volume, 66 | torch::Tensor coords, 67 | torch::Tensor corr_grad, 68 | int radius); 69 | 70 | std::vector altcorr_cuda_forward( 71 | torch::Tensor fmap1, 72 | torch::Tensor fmap2, 73 | torch::Tensor coords, 74 | torch::Tensor ii, 75 | torch::Tensor jj, 76 | int radius); 77 | 78 | std::vector altcorr_cuda_backward( 79 | torch::Tensor fmap1, 80 | torch::Tensor fmap2, 81 | torch::Tensor coords, 82 | torch::Tensor corr_grad, 83 | torch::Tensor ii, 84 | torch::Tensor jj, 85 | int radius); 86 | 87 | 88 | 89 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 90 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 91 | 92 | 93 | std::vector ba( 94 | torch::Tensor poses, 95 | torch::Tensor disps, 96 | torch::Tensor intrinsics, 97 | torch::Tensor disps_sens, 98 | torch::Tensor targets, 99 | torch::Tensor weights, 100 | torch::Tensor eta, 101 | torch::Tensor ii, 102 | torch::Tensor jj, 103 | const int t0, 104 | const int t1, 105 | const int iterations, 106 | const float lm, 107 | const float ep, 108 | const bool motion_only) { 109 | 110 | CHECK_INPUT(targets); 111 | CHECK_INPUT(weights); 112 | CHECK_INPUT(poses); 113 | CHECK_INPUT(disps); 114 | CHECK_INPUT(intrinsics); 115 | CHECK_INPUT(disps_sens); 116 | CHECK_INPUT(ii); 117 | CHECK_INPUT(jj); 118 | 119 | return ba_cuda(poses, disps, intrinsics, disps_sens, targets, weights, 120 | eta, ii, jj, t0, t1, iterations, lm, ep, motion_only); 121 | 122 | } 123 | 124 | 125 | torch::Tensor frame_distance( 126 | torch::Tensor poses, 127 | torch::Tensor disps, 128 | torch::Tensor intrinsics, 129 | torch::Tensor ii, 130 | torch::Tensor jj, 131 | const float beta) { 132 | 133 | CHECK_INPUT(poses); 134 | CHECK_INPUT(disps); 135 | CHECK_INPUT(intrinsics); 136 | CHECK_INPUT(ii); 137 | CHECK_INPUT(jj); 138 | 139 | return frame_distance_cuda(poses, disps, intrinsics, ii, jj, beta); 140 | 141 | } 142 | 143 | 144 | std::vector projmap( 145 | torch::Tensor poses, 146 | torch::Tensor disps, 147 | torch::Tensor intrinsics, 148 | torch::Tensor ii, 149 | torch::Tensor jj) { 150 | 151 | CHECK_INPUT(poses); 152 | CHECK_INPUT(disps); 153 | CHECK_INPUT(intrinsics); 154 | CHECK_INPUT(ii); 155 | CHECK_INPUT(jj); 156 | 157 | return projmap_cuda(poses, disps, intrinsics, ii, jj); 158 | 159 | } 160 | 161 | 162 | torch::Tensor iproj( 163 | torch::Tensor poses, 164 | torch::Tensor disps, 165 | torch::Tensor intrinsics) { 166 | CHECK_INPUT(poses); 167 | CHECK_INPUT(disps); 168 | CHECK_INPUT(intrinsics); 169 | 170 | return iproj_cuda(poses, disps, intrinsics); 171 | } 172 | 173 | 174 | // c++ python binding 175 | std::vector corr_index_forward( 176 | torch::Tensor volume, 177 | torch::Tensor coords, 178 | int radius) { 179 | CHECK_INPUT(volume); 180 | CHECK_INPUT(coords); 181 | 182 | return corr_index_cuda_forward(volume, coords, radius); 183 | } 184 | 185 | std::vector corr_index_backward( 186 | torch::Tensor volume, 187 | 
torch::Tensor coords, 188 | torch::Tensor corr_grad, 189 | int radius) { 190 | CHECK_INPUT(volume); 191 | CHECK_INPUT(coords); 192 | CHECK_INPUT(corr_grad); 193 | 194 | auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius); 195 | return {volume_grad}; 196 | } 197 | 198 | std::vector altcorr_forward( 199 | torch::Tensor fmap1, 200 | torch::Tensor fmap2, 201 | torch::Tensor coords, 202 | torch::Tensor ii, 203 | torch::Tensor jj, 204 | int radius) { 205 | CHECK_INPUT(fmap1); 206 | CHECK_INPUT(fmap2); 207 | CHECK_INPUT(coords); 208 | 209 | return altcorr_cuda_forward(fmap1, fmap2, coords, ii, jj, radius); 210 | } 211 | 212 | std::vector altcorr_backward( 213 | torch::Tensor fmap1, 214 | torch::Tensor fmap2, 215 | torch::Tensor coords, 216 | torch::Tensor corr_grad, 217 | torch::Tensor ii, 218 | torch::Tensor jj, 219 | int radius) { 220 | CHECK_INPUT(fmap1); 221 | CHECK_INPUT(fmap2); 222 | CHECK_INPUT(coords); 223 | CHECK_INPUT(corr_grad); 224 | 225 | return altcorr_cuda_backward(fmap1, fmap2, coords, ii, jj, corr_grad, radius); 226 | } 227 | 228 | torch::Tensor depth_filter( 229 | torch::Tensor poses, 230 | torch::Tensor disps, 231 | torch::Tensor intrinsics, 232 | torch::Tensor ix, 233 | torch::Tensor thresh) { 234 | 235 | CHECK_INPUT(poses); 236 | CHECK_INPUT(disps); 237 | CHECK_INPUT(intrinsics); 238 | CHECK_INPUT(ix); 239 | CHECK_INPUT(thresh); 240 | 241 | return depth_filter_cuda(poses, disps, intrinsics, ix, thresh); 242 | } 243 | 244 | 245 | 246 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 247 | // bundle adjustment kernels 248 | m.def("ba", &ba, "bundle adjustment"); 249 | m.def("frame_distance", &frame_distance, "frame_distance"); 250 | m.def("projmap", &projmap, "projmap"); 251 | m.def("depth_filter", &depth_filter, "depth_filter"); 252 | m.def("iproj", &iproj, "back projection"); 253 | 254 | // correlation volume kernels 255 | m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward"); 256 | m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward"); 257 | m.def("corr_index_forward", &corr_index_forward, "INDEX forward"); 258 | m.def("corr_index_backward", &corr_index_backward, "INDEX backward"); 259 | } -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Carnegie Mellon University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/data_type.md: -------------------------------------------------------------------------------- 1 | ### RGB Image 2 | 3 | The color images are stored as 640x480 8-bit RGB images in PNG format. 4 | 5 | * Load the image using OpenCV: 6 | ``` 7 | import cv2 8 | img = cv2.imread(FILENAME) 9 | cv2.imshow('img', img) 10 | cv2.waitKey(0) 11 | ``` 12 | 13 | * Load the image using Pillow: 14 | ``` 15 | from PIL import Image 16 | img = Image.open(FILENAME) 17 | img.show() 18 | ``` 19 | 20 | ### Camera intrinsics 21 | ``` 22 | fx = 320.0 # focal length x 23 | fy = 320.0 # focal length y 24 | cx = 320.0 # optical center x 25 | cy = 240.0 # optical center y 26 | 27 | fov = 90 deg # field of view 28 | 29 | width = 640 30 | height = 480 31 | ``` 32 | 33 | ### Depth image 34 | 35 | The depth maps are stored as a 640x480 16-bit numpy array in NPY format. In the Unreal Engine, the environment usually has a sky sphere at a large distance, so infinitely distant objects such as the sky have a large depth value (e.g. 10000) instead of an infinite number. 36 | 37 | The unit of the depth value is meters. The baseline between the left and right cameras is 0.25m. 38 | 39 | * Load the depth image: 40 | ``` 41 | import numpy as np 42 | depth = np.load(FILENAME) 43 | 44 | # change to disparity image 45 | disparity = 80.0 / depth 46 | ``` 47 | 48 | ### Segmentation image 49 | 50 | The segmentation images are saved as a uint8 numpy array. AirSim assigns a value from 0 to 255 to each mesh available in the environment. 51 | 52 | [More details](https://github.com/microsoft/AirSim/blob/master/docs/image_apis.md#segmentation) 53 | 54 | * Load the segmentation image: 55 | ``` 56 | import numpy as np 57 | seg = np.load(FILENAME) 58 | ``` 59 | 60 | ### Optical flow 61 | 62 | The optical flow maps are saved as a float32 numpy array, which is calculated based on the ground truth depth and ground truth camera motion, using [this](https://github.com/huyaoyu/ImageFlow) code. Dynamic objects and occlusions are masked by the mask file, which is a uint8 numpy array. We currently provide the optical flow for the left camera. 63 | 64 | * Load the optical flow: 65 | ``` 66 | import numpy as np 67 | flow = np.load(FILENAME) 68 | 69 | # load the mask 70 | mask = np.load(MASKFILENAME) 71 | ``` 72 | 73 | ### Pose file 74 | 75 | The camera pose file is a text file containing the translation and orientation of the camera in a fixed coordinate frame. Note that our automatic evaluation tool expects both the ground truth trajectory and the estimated trajectory to be in this format. 76 | 77 | * Each line in the text file contains a single pose. 78 | 79 | * The number of lines/poses is the same as the number of image frames in that trajectory. 80 | 81 | * The format of each line is '**tx ty tz qx qy qz qw**'. 
82 | 83 | * **tx ty tz** (3 floats) give the position of the optical center of the color camera with respect to the world origin in the world frame. 84 | 85 | * **qx qy qz qw** (4 floats) give the orientation of the optical center of the color camera in the form of a unit quaternion with respect to the world frame.  86 | 87 | * The camera motion is defined in the NED frame. That is to say, the x-axis is pointing to the camera's forward, the y-axis is pointing to the camera's right, the z-axis is pointing to the camera's downward. 88 | 89 | * Load the pose file: 90 | ``` 91 | import numpy as np 92 | flow = np.loadtxt(FILENAME) 93 | ``` 94 | -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/download_cvpr_slam_test.txt: -------------------------------------------------------------------------------- 1 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-mono-release.tar.gz 2 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-stereo-release.tar.gz 3 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-release.tar.gz -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/download_training.py: -------------------------------------------------------------------------------- 1 | from os import system, mkdir 2 | import argparse 3 | from os.path import isdir, isfile 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser(description='TartanAir') 7 | 8 | parser.add_argument('--output-dir', default='./', 9 | help='root directory for downloaded files') 10 | 11 | parser.add_argument('--rgb', action='store_true', default=False, 12 | help='download rgb image') 13 | 14 | parser.add_argument('--depth', action='store_true', default=False, 15 | help='download depth image') 16 | 17 | parser.add_argument('--flow', action='store_true', default=False, 18 | help='download optical flow') 19 | 20 | parser.add_argument('--seg', action='store_true', default=False, 21 | help='download segmentation image') 22 | 23 | parser.add_argument('--only-easy', action='store_true', default=False, 24 | help='download only easy trajectories') 25 | 26 | parser.add_argument('--only-hard', action='store_true', default=False, 27 | help='download only hard trajectories') 28 | 29 | parser.add_argument('--only-left', action='store_true', default=False, 30 | help='download only left camera') 31 | 32 | parser.add_argument('--only-right', action='store_true', default=False, 33 | help='download only right camera') 34 | 35 | parser.add_argument('--only-flow', action='store_true', default=False, 36 | help='download only optical flow wo/ mask') 37 | 38 | parser.add_argument('--only-mask', action='store_true', default=False, 39 | help='download only mask wo/ flow') 40 | 41 | parser.add_argument('--azcopy', action='store_true', default=False, 42 | help='download the data with AzCopy, which is 10x faster in our test') 43 | 44 | args = parser.parse_args() 45 | 46 | return args 47 | 48 | def _help(): 49 | print '' 50 | 51 | if __name__ == '__main__': 52 | args = get_args() 53 | 54 | # output directory 55 | outdir = args.output_dir 56 | if not isdir(outdir): 57 | print('Output dir {} does not exists!'.format(outdir)) 58 | exit() 59 | 60 | # difficulty level 61 | levellist = ['Easy', 'Hard'] 62 | if args.only_easy: 63 | levellist = ['Easy'] 64 | if args.only_hard: 65 | levellist = ['Hard'] 66 | if args.only_easy and args.only_hard: 67 | print('--only-eazy and 
--only-hard tags can not be set at the same time!') 68 | exit() 69 | 70 | 71 | # filetype 72 | typelist = [] 73 | if args.rgb: 74 | typelist.append('image') 75 | if args.depth: 76 | typelist.append('depth') 77 | if args.seg: 78 | typelist.append('seg') 79 | if args.flow: 80 | typelist.append('flow') 81 | if len(typelist)==0: 82 | print('Specify the type of data you want to download by --rgb/depth/seg/flow') 83 | exit() 84 | 85 | # camera 86 | cameralist = ['left', 'right', 'flow', 'mask'] 87 | if args.only_left: 88 | cameralist.remove('right') 89 | if args.only_right: 90 | cameralist.remove('left') 91 | if args.only_flow: 92 | cameralist.remove('mask') 93 | if args.only_mask: 94 | cameralist.remove('flow') 95 | if args.only_left and args.only_right: 96 | print('--only-left and --only-right tags can not be set at the same time!') 97 | exit() 98 | if args.only_flow and args.only_mask: 99 | print('--only-flow and --only-mask tags can not be set at the same time!') 100 | exit() 101 | 102 | # read all the zip file urls 103 | with open('download_training_zipfiles.txt') as f: 104 | lines = f.readlines() 105 | ziplist = [ll.strip() for ll in lines if ll.strip().endswith('.zip')] 106 | 107 | downloadlist = [] 108 | for zipfile in ziplist: 109 | zf = zipfile.split('/') 110 | filename = zf[-1] 111 | difflevel = zf[-2] 112 | 113 | # image/depth/seg/flow 114 | filetype = filename.split('_')[0] 115 | # left/right/flow/mask 116 | cameratype = filename.split('.')[0].split('_')[-1] 117 | 118 | if (difflevel in levellist) and (filetype in typelist) and (cameratype in cameralist): 119 | downloadlist.append(zipfile) 120 | 121 | if len(downloadlist)==0: 122 | print('No file meets the condition!') 123 | exit() 124 | 125 | print('{} files are going to be downloaded...'.format(len(downloadlist))) 126 | for fileurl in downloadlist: 127 | print fileurl 128 | 129 | for fileurl in downloadlist: 130 | zf = fileurl.split('/') 131 | filename = zf[-1] 132 | difflevel = zf[-2] 133 | envname = zf[-3] 134 | 135 | envfolder = outdir + '/' + envname 136 | if not isdir(envfolder): 137 | mkdir(envfolder) 138 | print('Created a new env folder {}..'.format(envfolder)) 139 | # else: 140 | # print('Env folder {} already exists..'.format(envfolder)) 141 | 142 | levelfolder = envfolder + '/' + difflevel 143 | if not isdir(levelfolder): 144 | mkdir(levelfolder) 145 | print(' Created a new level folder {}..'.format(levelfolder)) 146 | # else: 147 | # print('Level folder {} already exists..'.format(levelfolder)) 148 | 149 | targetfile = levelfolder + '/' + filename 150 | if isfile(targetfile): 151 | print('Target file {} already exists..'.format(targetfile)) 152 | exit() 153 | 154 | if args.azcopy: 155 | cmd = 'azcopy copy ' + fileurl + ' ' + targetfile 156 | else: 157 | cmd = 'wget -r -O ' + targetfile + ' ' + fileurl 158 | print cmd 159 | system(cmd) -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/DROID-SLAM/2dfd39f0dcad44012ca7bbb8aa70b55edbfa9c99/thirdparty/tartanair_tools/evaluation/__init__.py -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/evaluation/evaluate_ate_scale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Modified by Wenshan Wang 4 | # Modified by Raul Mur-Artal 5 | # Automatically 
compute the optimal scale factor for monocular VO/SLAM. 6 | 7 | # Software License Agreement (BSD License) 8 | # 9 | # Copyright (c) 2013, Juergen Sturm, TUM 10 | # All rights reserved. 11 | # 12 | # Redistribution and use in source and binary forms, with or without 13 | # modification, are permitted provided that the following conditions 14 | # are met: 15 | # 16 | # * Redistributions of source code must retain the above copyright 17 | # notice, this list of conditions and the following disclaimer. 18 | # * Redistributions in binary form must reproduce the above 19 | # copyright notice, this list of conditions and the following 20 | # disclaimer in the documentation and/or other materials provided 21 | # with the distribution. 22 | # * Neither the name of TUM nor the names of its 23 | # contributors may be used to endorse or promote products derived 24 | # from this software without specific prior written permission. 25 | # 26 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 27 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 28 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 29 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 30 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 31 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 32 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 33 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 35 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 36 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 37 | # POSSIBILITY OF SUCH DAMAGE. 38 | # 39 | # Requirements: 40 | # sudo apt-get install python-argparse 41 | 42 | """ 43 | This script computes the absolute trajectory error from the ground truth 44 | trajectory and the estimated trajectory. 45 | """ 46 | 47 | import numpy 48 | 49 | def align(model,data,calc_scale=False): 50 | """Align two trajectories using the method of Horn (closed-form). 
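The rotation is recovered in closed form from the SVD of the 3x3 correlation matrix of the zero-centered trajectories (with a determinant check to avoid returning a reflection); when calc_scale is True, a single global scale factor between the two trajectories is estimated as well.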
51 | 52 | Input: 53 | model -- first trajectory (3xn) 54 | data -- second trajectory (3xn) 55 | 56 | Output: 57 | rot -- rotation matrix (3x3) 58 | trans -- translation vector (3x1) 59 | trans_error -- translational error per point (1xn) 60 | 61 | """ 62 | numpy.set_printoptions(precision=3,suppress=True) 63 | model_zerocentered = model - model.mean(1) 64 | data_zerocentered = data - data.mean(1) 65 | 66 | W = numpy.zeros( (3,3) ) 67 | for column in range(model.shape[1]): 68 | W += numpy.outer(model_zerocentered[:,column],data_zerocentered[:,column]) 69 | U,d,Vh = numpy.linalg.linalg.svd(W.transpose()) 70 | S = numpy.matrix(numpy.identity( 3 )) 71 | if(numpy.linalg.det(U) * numpy.linalg.det(Vh)<0): 72 | S[2,2] = -1 73 | rot = U*S*Vh 74 | 75 | if calc_scale: 76 | rotmodel = rot*model_zerocentered 77 | dots = 0.0 78 | norms = 0.0 79 | for column in range(data_zerocentered.shape[1]): 80 | dots += numpy.dot(data_zerocentered[:,column].transpose(),rotmodel[:,column]) 81 | normi = numpy.linalg.norm(model_zerocentered[:,column]) 82 | norms += normi*normi 83 | # s = float(dots/norms) 84 | s = float(norms/dots) 85 | else: 86 | s = 1.0 87 | 88 | # trans = data.mean(1) - s*rot * model.mean(1) 89 | # model_aligned = s*rot * model + trans 90 | # alignment_error = model_aligned - data 91 | 92 | # scale the est to the gt, otherwise the ATE could be very small if the est scale is small 93 | trans = s*data.mean(1) - rot * model.mean(1) 94 | model_aligned = rot * model + trans 95 | data_alingned = s * data 96 | alignment_error = model_aligned - data_alingned 97 | 98 | trans_error = numpy.sqrt(numpy.sum(numpy.multiply(alignment_error,alignment_error),0)).A[0] 99 | 100 | return rot,trans,trans_error, s 101 | 102 | def plot_traj(ax,stamps,traj,style,color,label): 103 | """ 104 | Plot a trajectory using matplotlib. 105 | 106 | Input: 107 | ax -- the plot 108 | stamps -- time stamps (1xn) 109 | traj -- trajectory (3xn) 110 | style -- line style 111 | color -- line color 112 | label -- plot legend 113 | 114 | """ 115 | stamps.sort() 116 | interval = numpy.median([s-t for s,t in zip(stamps[1:],stamps[:-1])]) 117 | x = [] 118 | y = [] 119 | last = stamps[0] 120 | for i in range(len(stamps)): 121 | if stamps[i]-last < 2*interval: 122 | x.append(traj[i][0]) 123 | y.append(traj[i][1]) 124 | elif len(x)>0: 125 | ax.plot(x,y,style,color=color,label=label) 126 | label="" 127 | x=[] 128 | y=[] 129 | last= stamps[i] 130 | if len(x)>0: 131 | ax.plot(x,y,style,color=color,label=label) 132 | 133 | 134 | -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/evaluation/evaluate_kitti.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory. 
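# In outline: for every start frame and every segment length, the relative motion over that
# segment is computed from both the ground-truth and the estimated poses, the two are compared,
# and the rotation/translation errors are normalized by the segment length before averaging
# (rotation reported in degrees per meter, translation as a unitless drift ratio).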
3 | # This is a python reinplementation of the KITTI metric: http://www.cvlibs.net/datasets/kitti/eval_odometry.php 4 | # Cridit: Xiangwei Wang https://github.com/TimingSpace 5 | 6 | import numpy as np 7 | import sys 8 | 9 | def trajectory_distances(poses): 10 | distances = [] 11 | distances.append(0) 12 | for i in range(1,len(poses)): 13 | p1 = poses[i-1] 14 | p2 = poses[i] 15 | delta = p1[0:3,3] - p2[0:3,3] 16 | distances.append(distances[i-1]+np.linalg.norm(delta)) 17 | return distances 18 | 19 | def last_frame_from_segment_length(dist,first_frame,length): 20 | for i in range(first_frame,len(dist)): 21 | if dist[i]>dist[first_frame]+length: 22 | return i 23 | return -1 24 | 25 | def rotation_error(pose_error): 26 | a = pose_error[0,0] 27 | b = pose_error[1,1] 28 | c = pose_error[2,2] 29 | d = 0.5*(a+b+c-1) 30 | rot_error = np.arccos(max(min(d,1.0),-1.0)) 31 | return rot_error 32 | 33 | def translation_error(pose_error): 34 | dx = pose_error[0,3] 35 | dy = pose_error[1,3] 36 | dz = pose_error[2,3] 37 | return np.sqrt(dx*dx+dy*dy+dz*dz) 38 | 39 | # def line2matrix(pose_line): 40 | # pose_line = np.array(pose_line) 41 | # pose_m = np.matrix(np.eye(4)) 42 | # pose_m[0:3,:] = pose_line.reshape(3,4) 43 | # return pose_m 44 | 45 | def calculate_sequence_error(poses_gt,poses_result,lengths=[10,20,30,40,50,60,70,80]): 46 | # error_vetor 47 | errors = [] 48 | 49 | # paramet 50 | step_size = 1 #10; # every second 51 | num_lengths = len(lengths) 52 | 53 | # import ipdb;ipdb.set_trace() 54 | # pre-compute distances (from ground truth as reference) 55 | dist = trajectory_distances(poses_gt) 56 | # for all start positions do 57 | for first_frame in range(0, len(poses_gt), step_size): 58 | # for all segment lengths do 59 | for i in range(0,num_lengths): 60 | # current length 61 | length = lengths[i]; 62 | 63 | # compute last frame 64 | last_frame = last_frame_from_segment_length(dist,first_frame,length); 65 | # continue, if sequence not long enough 66 | if (last_frame==-1): 67 | continue; 68 | 69 | # compute rotational and translational errors 70 | pose_delta_gt = np.linalg.inv(poses_gt[first_frame]).dot(poses_gt[last_frame]) 71 | pose_delta_result = np.linalg.inv(poses_result[first_frame]).dot(poses_result[last_frame]) 72 | pose_error = np.linalg.inv(pose_delta_result).dot(pose_delta_gt) 73 | r_err = rotation_error(pose_error); 74 | t_err = translation_error(pose_error); 75 | 76 | # compute speed 77 | num_frames = (float)(last_frame-first_frame+1); 78 | speed = length/(0.1*num_frames); 79 | 80 | # write to file 81 | error = [first_frame,r_err/length,t_err/length,length,speed] 82 | errors.append(error) 83 | # return error vector 84 | return errors 85 | 86 | def calculate_ave_errors(errors,lengths=[10,20,30,40,50,60,70,80]): 87 | rot_errors=[] 88 | tra_errors=[] 89 | for length in lengths: 90 | rot_error_each_length =[] 91 | tra_error_each_length =[] 92 | for error in errors: 93 | if abs(error[3]-length)<0.1: 94 | rot_error_each_length.append(error[1]) 95 | tra_error_each_length.append(error[2]) 96 | 97 | if len(rot_error_each_length)==0: 98 | # import ipdb;ipdb.set_trace() 99 | continue 100 | else: 101 | rot_errors.append(sum(rot_error_each_length)/len(rot_error_each_length)) 102 | tra_errors.append(sum(tra_error_each_length)/len(tra_error_each_length)) 103 | return np.array(rot_errors)*180/np.pi, tra_errors 104 | 105 | def evaluate(gt, data,rescale_=False): 106 | lens = [5,10,15,20,25,30,35,40] #[1,2,3,4,5,6] # 107 | errors = calculate_sequence_error(gt, data, lengths=lens) 108 | rot,tra = 
calculate_ave_errors(errors, lengths=lens) 109 | return np.mean(rot), np.mean(tra) 110 | 111 | def main(): 112 | # usage: python main.py path_to_ground_truth path_to_predict_pose 113 | # load and preprocess data 114 | ground_truth_data = np.loadtxt(sys.argv[1]) 115 | predict_pose__data = np.loadtxt(sys.argv[2]) 116 | errors = calculate_sequence_error(ground_truth_data,predict_pose__data) 117 | rot,tra = calculate_ave_errors(errors) 118 | print(rot,'\n',tra) 119 | #print(error) 120 | # evaluate the vo result 121 | # save and visualization the evaluatation result 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | 127 | -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/evaluation/evaluator_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory. 3 | 4 | import numpy as np 5 | from .trajectory_transform import trajectory_transform, rescale 6 | from .transformation import pos_quats2SE_matrices, SE2pos_quat 7 | 8 | 9 | np.set_printoptions(suppress=True, precision=2, threshold=100000) 10 | 11 | def transform_trajs(gt_traj, est_traj, cal_scale): 12 | gt_traj, est_traj = trajectory_transform(gt_traj, est_traj) 13 | if cal_scale : 14 | est_traj, s = rescale(gt_traj, est_traj) 15 | print(' Scale, {}'.format(s)) 16 | else: 17 | s = 1.0 18 | return gt_traj, est_traj, s 19 | 20 | def quats2SEs(gt_traj, est_traj): 21 | gt_SEs = pos_quats2SE_matrices(gt_traj) 22 | est_SEs = pos_quats2SE_matrices(est_traj) 23 | return gt_SEs, est_SEs 24 | 25 | from .evaluate_ate_scale import align, plot_traj 26 | 27 | 28 | class ATEEvaluator(object): 29 | def __init__(self): 30 | super(ATEEvaluator, self).__init__() 31 | 32 | 33 | def evaluate(self, gt_traj, est_traj, scale): 34 | gt_xyz = np.matrix(gt_traj[:,0:3].transpose()) 35 | est_xyz = np.matrix(est_traj[:, 0:3].transpose()) 36 | 37 | rot, trans, trans_error, s = align(gt_xyz, est_xyz, scale) 38 | print(' ATE scale: {}'.format(s)) 39 | error = np.sqrt(np.dot(trans_error,trans_error) / len(trans_error)) 40 | 41 | # align two trajs 42 | est_SEs = pos_quats2SE_matrices(est_traj) 43 | T = np.eye(4) 44 | T[:3,:3] = rot 45 | T[:3,3:] = trans 46 | T = np.linalg.inv(T) 47 | est_traj_aligned = [] 48 | for se in est_SEs: 49 | se[:3,3] = se[:3,3] * s 50 | se_new = T.dot(se) 51 | se_new = SE2pos_quat(se_new) 52 | est_traj_aligned.append(se_new) 53 | 54 | 55 | return error, gt_traj, est_traj_aligned 56 | 57 | # ======================= 58 | 59 | from .evaluate_rpe import evaluate_trajectory 60 | 61 | class RPEEvaluator(object): 62 | def __init__(self): 63 | super(RPEEvaluator, self).__init__() 64 | 65 | 66 | def evaluate(self, gt_SEs, est_SEs): 67 | result = evaluate_trajectory(gt_SEs, est_SEs) 68 | 69 | trans_error = np.array(result)[:,2] 70 | rot_error = np.array(result)[:,3] 71 | 72 | trans_error_mean = np.mean(trans_error) 73 | rot_error_mean = np.mean(rot_error) 74 | 75 | # import ipdb;ipdb.set_trace() 76 | 77 | return (rot_error_mean, trans_error_mean) 78 | 79 | # ======================= 80 | 81 | from .evaluate_kitti import evaluate as kittievaluate 82 | 83 | class KittiEvaluator(object): 84 | def __init__(self): 85 | super(KittiEvaluator, self).__init__() 86 | 87 | # return rot_error, tra_error 88 | def evaluate(self, gt_SEs, est_SEs): 89 | # trajectory_scale(est_SEs, 0.831984631412) 90 | error = kittievaluate(gt_SEs, 
est_SEs)
91 |         return error
92 | 
--------------------------------------------------------------------------------
/thirdparty/tartanair_tools/evaluation/tartanair_evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang
2 | # For License information please see the LICENSE file in the root directory.
3 | 
4 | import numpy as np
5 | from os.path import isdir, isfile
6 | 
7 | from .evaluator_base import ATEEvaluator, RPEEvaluator, KittiEvaluator, transform_trajs, quats2SEs
8 | 
9 | # from trajectory_transform import timestamp_associate
10 | 
11 | 
12 | def plot_traj(gtposes, estposes, vis=False, savefigname=None, title=''):
13 |     import matplotlib.pyplot as plt
14 |     fig = plt.figure(figsize=(4,4))
15 | 
16 | 
17 |     # cm = plt.cm.get_cmap('Spectral')  # unused; get_cmap is deprecated in newer matplotlib
18 | 
19 |     plt.subplot(111)
20 |     plt.plot(gtposes[:,2], gtposes[:,0], linestyle='dashed', c='k')
21 |     plt.plot(estposes[:, 2], estposes[:, 0], c='#ff7f0e')
22 |     plt.xlabel('x (m)')
23 |     plt.ylabel('y (m)')
24 |     plt.legend(['Ground Truth','Ours'])
25 |     plt.title(title)
26 | 
27 |     plt.axis('equal')
28 | 
29 |     if savefigname is not None:
30 |         plt.savefig(savefigname)
31 | 
32 |     if vis:
33 |         plt.show()
34 | 
35 |     plt.close(fig)
36 | 
37 | 
38 | #
39 | 
40 | class TartanAirEvaluator:
41 |     def __init__(self, scale = False, round=1):
42 |         self.ate_eval = ATEEvaluator()
43 |         self.rpe_eval = RPEEvaluator()
44 |         self.kitti_eval = KittiEvaluator()
45 | 
46 |     def evaluate_one_trajectory(self, gt_traj, est_traj, scale=False, title=''):
47 |         """
48 |         scale = True: calculate a global scale
49 |         """
50 | 
51 |         if gt_traj.shape[0] != est_traj.shape[0]:
52 |             raise Exception("POSEFILE_LENGTH_ILLEGAL")
53 | 
54 |         if gt_traj.shape[1] != 7 or est_traj.shape[1] != 7:
55 |             raise Exception("POSEFILE_FORMAT_ILLEGAL")
56 | 
57 |         gt_traj = gt_traj.astype(np.float64)
58 |         est_traj = est_traj.astype(np.float64)
59 | 
60 |         ate_score, gt_ate_aligned, est_ate_aligned = self.ate_eval.evaluate(gt_traj, est_traj, scale)
61 | 
62 |         plot_traj(np.matrix(gt_ate_aligned), np.matrix(est_ate_aligned), vis=False, savefigname="figures/%s.pdf"%title, title=title)
63 | 
64 |         est_ate_aligned = np.array(est_ate_aligned)
65 |         gt_SEs, est_SEs = quats2SEs(gt_ate_aligned, est_ate_aligned)
66 | 
67 | 
68 | 
69 |         rpe_score = self.rpe_eval.evaluate(gt_SEs, est_SEs)
70 |         kitti_score = self.kitti_eval.evaluate(gt_SEs, est_SEs)
71 | 
72 |         return {'ate_score': ate_score, 'rpe_score': rpe_score, 'kitti_score': kitti_score}
73 | 
74 | 
75 | if __name__ == "__main__":
76 |     # the evaluator expects N x 7 pose arrays, so load the text files first
77 |     # scale = True for monocular track, scale = False for stereo track
78 |     aicrowd_evaluator = TartanAirEvaluator()
79 |     result = aicrowd_evaluator.evaluate_one_trajectory(np.loadtxt('pose_gt.txt'), np.loadtxt('pose_est.txt'), scale=True)
80 |     print(result)
81 | 
--------------------------------------------------------------------------------
/thirdparty/tartanair_tools/evaluation/trajectory_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang
2 | # For License information please see the LICENSE file in the root directory.
3 | 
4 | import numpy as np
5 | from .
import transformation as tf 6 | 7 | def shift0(traj): 8 | ''' 9 | Traj: a list of [t + quat] 10 | Return: translate and rotate the traj 11 | ''' 12 | traj_ses = tf.pos_quats2SE_matrices(np.array(traj)) 13 | traj_init = traj_ses[0] 14 | traj_init_inv = np.linalg.inv(traj_init) 15 | new_traj = [] 16 | for tt in traj_ses: 17 | ttt=traj_init_inv.dot(tt) 18 | new_traj.append(tf.SE2pos_quat(ttt)) 19 | return np.array(new_traj) 20 | 21 | def ned2cam(traj): 22 | ''' 23 | transfer a ned traj to camera frame traj 24 | ''' 25 | T = np.array([[0,1,0,0], 26 | [0,0,1,0], 27 | [1,0,0,0], 28 | [0,0,0,1]], dtype=np.float32) 29 | T_inv = np.linalg.inv(T) 30 | new_traj = [] 31 | traj_ses = tf.pos_quats2SE_matrices(np.array(traj)) 32 | 33 | for tt in traj_ses: 34 | ttt=T.dot(tt).dot(T_inv) 35 | new_traj.append(tf.SE2pos_quat(ttt)) 36 | 37 | return np.array(new_traj) 38 | 39 | def cam2ned(traj): 40 | ''' 41 | transfer a camera traj to ned frame traj 42 | ''' 43 | T = np.array([[0,0,1,0], 44 | [1,0,0,0], 45 | [0,1,0,0], 46 | [0,0,0,1]], dtype=np.float32) 47 | T_inv = np.linalg.inv(T) 48 | new_traj = [] 49 | traj_ses = tf.pos_quats2SE_matrices(np.array(traj)) 50 | 51 | for tt in traj_ses: 52 | ttt=T.dot(tt).dot(T_inv) 53 | new_traj.append(tf.SE2pos_quat(ttt)) 54 | 55 | return np.array(new_traj) 56 | 57 | 58 | def trajectory_transform(gt_traj, est_traj): 59 | ''' 60 | 1. center the start frame to the axis origin 61 | 2. align the GT frame (NED) with estimation frame (camera) 62 | ''' 63 | gt_traj_trans = shift0(gt_traj) 64 | est_traj_trans = shift0(est_traj) 65 | 66 | # gt_traj_trans = ned2cam(gt_traj_trans) 67 | # est_traj_trans = cam2ned(est_traj_trans) 68 | 69 | return gt_traj_trans, est_traj_trans 70 | 71 | def rescale_bk(poses_gt, poses): 72 | motion_gt = tf.pose2motion(poses_gt) 73 | motion = tf.pose2motion(poses) 74 | 75 | speed_square_gt = np.sum(motion_gt[:,0:3,3]*motion_gt[:,0:3,3],1) 76 | speed_gt = np.sqrt(speed_square_gt) 77 | speed_square = np.sum(motion[:,0:3,3]*motion[:,0:3,3],1) 78 | speed = np.sqrt(speed_square) 79 | # when the speed is small, the scale could become very large 80 | # import ipdb;ipdb.set_trace() 81 | mask = (speed_gt>0.0001) # * (speed>0.00001) 82 | scale = np.mean((speed[mask])/speed_gt[mask]) 83 | scale = 1.0/scale 84 | motion[:,0:3,3] = motion[:,0:3,3]*scale 85 | pose_update = tf.motion2pose(motion) 86 | return pose_update, scale 87 | 88 | def pose2trans(pose_data): 89 | data_size = len(pose_data) 90 | trans = [] 91 | for i in range(0,data_size-1): 92 | tran = np.array(pose_data[i+1][:3]) - np.array(pose_data[i][:3]) # np.linalg.inv(data[i]).dot(data[i+1]) 93 | trans.append(tran) 94 | 95 | return np.array(trans) # N x 3 96 | 97 | 98 | def rescale(poses_gt, poses): 99 | ''' 100 | similar to rescale 101 | poses_gt/poses: N x 7 poselist in quaternion format 102 | ''' 103 | trans_gt = pose2trans(poses_gt) 104 | trans = pose2trans(poses) 105 | 106 | speed_square_gt = np.sum(trans_gt*trans_gt,1) 107 | speed_gt = np.sqrt(speed_square_gt) 108 | speed_square = np.sum(trans*trans,1) 109 | speed = np.sqrt(speed_square) 110 | # when the speed is small, the scale could become very large 111 | # import ipdb;ipdb.set_trace() 112 | mask = (speed_gt>0.0001) # * (speed>0.00001) 113 | scale = np.mean((speed[mask])/speed_gt[mask]) 114 | scale = 1.0/scale 115 | poses[:,0:3] = poses[:,0:3]*scale 116 | return poses, scale 117 | 118 | def trajectory_scale(traj, scale): 119 | for ttt in traj: 120 | ttt[0:3,3] = ttt[0:3,3]*scale 121 | return traj 122 | 123 | def timestamp_associate(first_list, 
second_list, max_difference): 124 | """ 125 | Associate two trajectory of [stamp,data]. As the time stamps never match exactly, we aim 126 | to find the closest match for every input tuple. 127 | 128 | Input: 129 | first_list -- first list of (stamp,data) 130 | second_list -- second list of (stamp,data) 131 | max_difference -- search radius for candidate generation 132 | 133 | Output: 134 | first_res: matched data from the first list 135 | second_res: matched data from the second list 136 | 137 | """ 138 | first_dict = dict([(l[0],l[1:]) for l in first_list if len(l)>1]) 139 | second_dict = dict([(l[0],l[1:]) for l in second_list if len(l)>1]) 140 | 141 | first_keys = first_dict.keys() 142 | second_keys = second_dict.keys() 143 | potential_matches = [(abs(a - b ), a, b) 144 | for a in first_keys 145 | for b in second_keys 146 | if abs(a - b) < max_difference] 147 | potential_matches.sort() 148 | matches = [] 149 | for diff, a, b in potential_matches: 150 | if a in first_keys and b in second_keys: 151 | first_keys.remove(a) 152 | second_keys.remove(b) 153 | matches.append((a, b)) 154 | 155 | matches.sort() 156 | 157 | first_res = [] 158 | second_res = [] 159 | for t1, t2 in matches: 160 | first_res.append(first_dict[t1]) 161 | second_res.append(second_dict[t2]) 162 | return np.array(first_res), np.array(second_res) 163 | -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/evaluation/transformation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory. 3 | # Cridit: Xiangwei Wang https://github.com/TimingSpace 4 | 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation as R 7 | 8 | def line2mat(line_data): 9 | mat = np.eye(4) 10 | mat[0:3,:] = line_data.reshape(3,4) 11 | return np.matrix(mat) 12 | 13 | def motion2pose(data): 14 | data_size = len(data) 15 | all_pose = [] # np.zeros((data_size+1, 4, 4)) 16 | all_pose.append(np.eye(4,4)) #[0,:] = np.eye(4,4) 17 | pose = np.eye(4,4) 18 | for i in range(0,data_size): 19 | pose = pose.dot(data[i]) 20 | all_pose.append(pose) 21 | return all_pose 22 | 23 | def pose2motion(data): 24 | data_size = len(data) 25 | all_motion = [] 26 | for i in range(0,data_size-1): 27 | motion = np.linalg.inv(data[i]).dot(data[i+1]) 28 | all_motion.append(motion) 29 | 30 | return np.array(all_motion) # N x 4 x 4 31 | 32 | def SE2se(SE_data): 33 | result = np.zeros((6)) 34 | result[0:3] = np.array(SE_data[0:3,3].T) 35 | result[3:6] = SO2so(SE_data[0:3,0:3]).T 36 | return result 37 | 38 | def SO2so(SO_data): 39 | return R.from_matrix(SO_data).as_rotvec() 40 | 41 | def so2SO(so_data): 42 | return R.from_rotvec(so_data).as_matrix() 43 | 44 | def se2SE(se_data): 45 | result_mat = np.matrix(np.eye(4)) 46 | result_mat[0:3,0:3] = so2SO(se_data[3:6]) 47 | result_mat[0:3,3] = np.matrix(se_data[0:3]).T 48 | return result_mat 49 | ### can get wrong result 50 | def se_mean(se_datas): 51 | all_SE = np.matrix(np.eye(4)) 52 | for i in range(se_datas.shape[0]): 53 | se = se_datas[i,:] 54 | SE = se2SE(se) 55 | all_SE = all_SE*SE 56 | all_se = SE2se(all_SE) 57 | mean_se = all_se/se_datas.shape[0] 58 | return mean_se 59 | 60 | def ses_mean(se_datas): 61 | se_datas = np.array(se_datas) 62 | se_datas = np.transpose(se_datas.reshape(se_datas.shape[0],se_datas.shape[1],se_datas.shape[2]*se_datas.shape[3]),(0,2,1)) 63 | se_result = 
np.zeros((se_datas.shape[0],se_datas.shape[2])) 64 | for i in range(0,se_datas.shape[0]): 65 | mean_se = se_mean(se_datas[i,:,:]) 66 | se_result[i,:] = mean_se 67 | return se_result 68 | 69 | def ses2poses(data): 70 | data_size = data.shape[0] 71 | all_pose = np.zeros((data_size+1,12)) 72 | temp = np.eye(4,4).reshape(1,16) 73 | all_pose[0,:] = temp[0,0:12] 74 | pose = np.matrix(np.eye(4,4)) 75 | for i in range(0,data_size): 76 | data_mat = se2SE(data[i,:]) 77 | pose = pose*data_mat 78 | pose_line = np.array(pose[0:3,:]).reshape(1,12) 79 | all_pose[i+1,:] = pose_line 80 | return all_pose 81 | 82 | def SEs2ses(motion_data): 83 | data_size = motion_data.shape[0] 84 | ses = np.zeros((data_size,6)) 85 | for i in range(0,data_size): 86 | SE = np.matrix(np.eye(4)) 87 | SE[0:3,:] = motion_data[i,:].reshape(3,4) 88 | ses[i,:] = SE2se(SE) 89 | return ses 90 | 91 | def so2quat(so_data): 92 | so_data = np.array(so_data) 93 | theta = np.sqrt(np.sum(so_data*so_data)) 94 | axis = so_data/theta 95 | quat=np.zeros(4) 96 | quat[0:3] = np.sin(theta/2)*axis 97 | quat[3] = np.cos(theta/2) 98 | return quat 99 | 100 | def quat2so(quat_data): 101 | quat_data = np.array(quat_data) 102 | sin_half_theta = np.sqrt(np.sum(quat_data[0:3]*quat_data[0:3])) 103 | axis = quat_data[0:3]/sin_half_theta 104 | cos_half_theta = quat_data[3] 105 | theta = 2*np.arctan2(sin_half_theta,cos_half_theta) 106 | so = theta*axis 107 | return so 108 | 109 | # input so_datas batch*channel*height*width 110 | # return quat_datas batch*numner*channel 111 | def sos2quats(so_datas,mean_std=[[1],[1]]): 112 | so_datas = np.array(so_datas) 113 | so_datas = so_datas.reshape(so_datas.shape[0],so_datas.shape[1],so_datas.shape[2]*so_datas.shape[3]) 114 | so_datas = np.transpose(so_datas,(0,2,1)) 115 | quat_datas = np.zeros((so_datas.shape[0],so_datas.shape[1],4)) 116 | for i_b in range(0,so_datas.shape[0]): 117 | for i_p in range(0,so_datas.shape[1]): 118 | so_data = so_datas[i_b,i_p,:] 119 | quat_data = so2quat(so_data) 120 | quat_datas[i_b,i_p,:] = quat_data 121 | return quat_datas 122 | 123 | def SO2quat(SO_data): 124 | rr = R.from_matrix(SO_data) 125 | return rr.as_quat() 126 | 127 | def quat2SO(quat_data): 128 | return R.from_quat(quat_data).as_matrix() 129 | 130 | 131 | def pos_quat2SE(quat_data): 132 | SO = R.from_quat(quat_data[3:7]).as_matrix() 133 | SE = np.matrix(np.eye(4)) 134 | SE[0:3,0:3] = np.matrix(SO) 135 | SE[0:3,3] = np.matrix(quat_data[0:3]).T 136 | SE = np.array(SE[0:3,:]).reshape(1,12) 137 | return SE 138 | 139 | 140 | def pos_quats2SEs(quat_datas): 141 | data_len = quat_datas.shape[0] 142 | SEs = np.zeros((data_len,12)) 143 | for i_data in range(0,data_len): 144 | SE = pos_quat2SE(quat_datas[i_data,:]) 145 | SEs[i_data,:] = SE 146 | return SEs 147 | 148 | 149 | def pos_quats2SE_matrices(quat_datas): 150 | data_len = quat_datas.shape[0] 151 | SEs = [] 152 | for quat in quat_datas: 153 | SO = R.from_quat(quat[3:7]).as_matrix() 154 | SE = np.eye(4) 155 | SE[0:3,0:3] = SO 156 | SE[0:3,3] = quat[0:3] 157 | SEs.append(SE) 158 | return SEs 159 | 160 | def SE2pos_quat(SE_data): 161 | pos_quat = np.zeros(7) 162 | pos_quat[3:] = SO2quat(SE_data[0:3,0:3]) 163 | pos_quat[:3] = SE_data[0:3,3].T 164 | return pos_quat -------------------------------------------------------------------------------- /thirdparty/tartanair_tools/seg_rgbs.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 2 | 153 108 6 3 | 112 105 191 4 | 89 121 72 5 | 190 225 64 6 | 206 190 59 7 | 81 13 36 8 | 115 176 195 9 | 161 171 27 
10 | 135 169 180 11 | 29 26 199 12 | 102 16 239 13 | 242 107 146 14 | 156 198 23 15 | 49 89 160 16 | 68 218 116 17 | 11 236 9 18 | 196 30 8 19 | 121 67 28 20 | 0 53 65 21 | 146 52 70 22 | 226 149 143 23 | 151 126 171 24 | 194 39 7 25 | 205 120 161 26 | 212 51 60 27 | 211 80 208 28 | 189 135 188 29 | 54 72 205 30 | 103 252 157 31 | 124 21 123 32 | 19 132 69 33 | 195 237 132 34 | 94 253 175 35 | 182 251 87 36 | 90 162 242 37 | 199 29 1 38 | 254 12 229 39 | 35 196 244 40 | 220 163 49 41 | 86 254 214 42 | 152 3 129 43 | 92 31 106 44 | 207 229 90 45 | 125 75 48 46 | 98 55 74 47 | 126 129 238 48 | 222 153 109 49 | 85 152 34 50 | 173 69 31 51 | 37 128 125 52 | 58 19 33 53 | 134 57 119 54 | 218 124 115 55 | 120 0 200 56 | 225 131 92 57 | 246 90 16 58 | 51 155 241 59 | 202 97 155 60 | 184 145 182 61 | 96 232 44 62 | 133 244 133 63 | 180 191 29 64 | 1 222 192 65 | 99 242 104 66 | 91 168 219 67 | 65 54 217 68 | 148 66 130 69 | 203 102 204 70 | 216 78 75 71 | 234 20 250 72 | 109 206 24 73 | 164 194 17 74 | 157 23 236 75 | 158 114 88 76 | 245 22 110 77 | 67 17 35 78 | 181 213 93 79 | 170 179 42 80 | 52 187 148 81 | 247 200 111 82 | 25 62 174 83 | 100 25 240 84 | 191 195 144 85 | 252 36 67 86 | 241 77 149 87 | 237 33 141 88 | 119 230 85 89 | 28 34 108 90 | 78 98 254 91 | 114 161 30 92 | 75 50 243 93 | 66 226 253 94 | 46 104 76 95 | 8 234 216 96 | 15 241 102 97 | 93 14 71 98 | 192 255 193 99 | 253 41 164 100 | 24 175 120 101 | 185 243 231 102 | 169 233 97 103 | 243 215 145 104 | 72 137 21 105 | 160 113 101 106 | 214 92 13 107 | 167 140 147 108 | 101 109 181 109 | 53 118 126 110 | 3 177 32 111 | 40 63 99 112 | 186 139 153 113 | 88 207 100 114 | 71 146 227 115 | 236 38 187 116 | 215 4 215 117 | 18 211 66 118 | 113 49 134 119 | 47 42 63 120 | 219 103 127 121 | 57 240 137 122 | 227 133 211 123 | 145 71 201 124 | 217 173 183 125 | 250 40 113 126 | 208 125 68 127 | 224 186 249 128 | 69 148 46 129 | 239 85 20 130 | 108 116 224 131 | 56 214 26 132 | 179 147 43 133 | 48 188 172 134 | 221 83 47 135 | 155 166 218 136 | 62 217 189 137 | 198 180 122 138 | 201 144 169 139 | 132 2 14 140 | 128 189 114 141 | 163 227 112 142 | 45 157 177 143 | 64 86 142 144 | 118 193 163 145 | 14 32 79 146 | 200 45 170 147 | 74 81 2 148 | 59 37 212 149 | 73 35 225 150 | 95 224 39 151 | 84 170 220 152 | 159 58 173 153 | 17 91 237 154 | 31 95 84 155 | 34 201 248 156 | 63 73 209 157 | 129 235 107 158 | 231 115 40 159 | 36 74 95 160 | 238 228 154 161 | 61 212 54 162 | 13 94 165 163 | 141 174 0 164 | 140 167 255 165 | 117 93 91 166 | 183 10 186 167 | 165 28 61 168 | 144 238 194 169 | 12 158 41 170 | 76 110 234 171 | 150 9 121 172 | 142 1 246 173 | 230 136 198 174 | 5 60 233 175 | 232 250 80 176 | 143 112 56 177 | 187 70 156 178 | 2 185 62 179 | 138 223 226 180 | 122 183 222 181 | 166 245 3 182 | 175 6 140 183 | 240 59 210 184 | 248 44 10 185 | 83 82 52 186 | 223 248 167 187 | 87 15 150 188 | 111 178 117 189 | 197 84 22 190 | 235 208 124 191 | 9 76 45 192 | 176 24 50 193 | 154 159 251 194 | 149 111 207 195 | 168 231 15 196 | 209 247 202 197 | 80 205 152 198 | 178 221 213 199 | 27 8 38 200 | 244 117 51 201 | 107 68 190 202 | 23 199 139 203 | 171 88 168 204 | 136 202 58 205 | 6 46 86 206 | 105 127 176 207 | 174 249 197 208 | 172 172 138 209 | 228 142 81 210 | 7 204 185 211 | 22 61 247 212 | 233 100 78 213 | 127 65 105 214 | 33 87 158 215 | 139 156 252 216 | 42 7 136 217 | 20 99 179 218 | 79 150 223 219 | 131 182 184 220 | 110 123 37 221 | 60 138 96 222 | 210 96 94 223 | 123 48 18 224 | 137 197 162 225 | 188 18 5 226 | 39 219 151 227 | 204 143 
135 228 | 249 79 73 229 | 77 64 178 230 | 41 246 77 231 | 16 154 4 232 | 116 134 19 233 | 4 122 235 234 | 177 106 230 235 | 21 119 12 236 | 104 5 98 237 | 50 130 53 238 | 30 192 25 239 | 26 165 166 240 | 10 160 82 241 | 106 43 131 242 | 44 216 103 243 | 255 101 221 244 | 32 151 196 245 | 213 220 89 246 | 70 209 228 247 | 97 184 83 248 | 82 239 232 249 | 251 164 128 250 | 193 11 245 251 | 38 27 159 252 | 229 141 203 253 | 130 56 55 254 | 147 210 11 255 | 162 203 118 256 | 255 255 255 -------------------------------------------------------------------------------- /tools/download_eth3d.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | # where to put everything 5 | ETH3D_PATH="datasets/ETH3D-SLAM" 6 | 7 | # list of all sequences 8 | evalset=( 9 | cables_1 10 | cables_2 11 | cables_3 12 | camera_shake_1 13 | camera_shake_2 14 | camera_shake_3 15 | ceiling_1 16 | ceiling_2 17 | desk_3 18 | desk_changing_1 19 | einstein_1 20 | einstein_2 21 | einstein_dark 22 | einstein_flashlight 23 | einstein_global_light_changes_1 24 | einstein_global_light_changes_2 25 | einstein_global_light_changes_3 26 | kidnap_1 27 | kidnap_dark 28 | large_loop_1 29 | mannequin_1 30 | mannequin_3 31 | mannequin_4 32 | mannequin_5 33 | mannequin_7 34 | mannequin_face_1 35 | mannequin_face_2 36 | mannequin_face_3 37 | mannequin_head 38 | motion_1 39 | planar_2 40 | planar_3 41 | plant_1 42 | plant_2 43 | plant_3 44 | plant_4 45 | plant_5 46 | plant_dark 47 | plant_scene_1 48 | plant_scene_2 49 | plant_scene_3 50 | reflective_1 51 | repetitive 52 | sfm_bench 53 | sfm_garden 54 | sfm_house_loop 55 | sfm_lab_room_1 56 | sfm_lab_room_2 57 | sofa_1 58 | sofa_2 59 | sofa_3 60 | sofa_4 61 | sofa_dark_1 62 | sofa_dark_2 63 | sofa_dark_3 64 | sofa_shake 65 | table_3 66 | table_4 67 | table_7 68 | vicon_light_1 69 | vicon_light_2 70 | ) 71 | 72 | data_modes=( 73 | mono 74 | rgbd 75 | ) 76 | 77 | # make sure base dir exists 78 | mkdir -p "${ETH3D_PATH}" 79 | 80 | for scene in "${evalset[@]}"; do 81 | 82 | for mode in "${data_modes[@]}"; do 83 | url=https://www.eth3d.net/data/slam/datasets/${scene}_${mode}.zip 84 | 85 | # local paths 86 | zipfile="${ETH3D_PATH}/${scene}.zip" 87 | outdir="${ETH3D_PATH}" 88 | 89 | mkdir -p "${outdir}" 90 | 91 | echo "Downloading ${scene}..." 92 | wget -c "${url}" -O "${zipfile}" 93 | 94 | echo " Unzipping into ${outdir}/..." 95 | unzip -o "${zipfile}" -d "${outdir}" 96 | 97 | echo " Cleaning up..." 
98 | rm "${zipfile}" 99 | 100 | done 101 | 102 | echo "✔ Done with ${scene}" 103 | 104 | done 105 | -------------------------------------------------------------------------------- /tools/download_euroc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | # where to put everything 5 | EUROC_PATH="datasets/EuRoC" 6 | 7 | # list of all sequences 8 | declare -a evalset=( 9 | MH_01_easy 10 | MH_02_easy 11 | MH_03_medium 12 | MH_04_difficult 13 | MH_05_difficult 14 | V1_01_easy 15 | V1_02_medium 16 | V1_03_difficult 17 | V2_01_easy 18 | V2_02_medium 19 | V2_03_difficult 20 | ) 21 | 22 | # make sure base dir exists 23 | mkdir -p "${EUROC_PATH}" 24 | 25 | for scene in "${evalset[@]}"; do 26 | # full URL still needs the right folder on the server 27 | if [[ "${scene}" == MH* ]]; then 28 | url="http://robotics.ethz.ch/~asl-datasets/ijrr_euroc_mav_dataset/machine_hall/${scene}/${scene}.zip" 29 | elif [[ "${scene}" == V1* ]]; then 30 | url="http://robotics.ethz.ch/~asl-datasets/ijrr_euroc_mav_dataset/vicon_room1/${scene}/${scene}.zip" 31 | else 32 | url="http://robotics.ethz.ch/~asl-datasets/ijrr_euroc_mav_dataset/vicon_room2/${scene}/${scene}.zip" 33 | fi 34 | 35 | # local paths 36 | zipfile="${EUROC_PATH}/${scene}.zip" 37 | outdir="${EUROC_PATH}/${scene}" 38 | 39 | mkdir -p "${outdir}" 40 | 41 | echo "Downloading ${scene}..." 42 | wget -c "${url}" -O "${zipfile}" 43 | 44 | echo " Unzipping into ${outdir}/..." 45 | unzip -o "${zipfile}" -d "${outdir}" 46 | 47 | echo " Cleaning up..." 48 | rm "${zipfile}" 49 | 50 | echo "✔ Done with ${scene}" 51 | done 52 | -------------------------------------------------------------------------------- /tools/download_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gdown 1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh -------------------------------------------------------------------------------- /tools/download_sample_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data && cd data 4 | 5 | wget https://www.eth3d.net/data/slam/datasets/sfm_bench_mono.zip 6 | unzip sfm_bench_mono.zip 7 | rm sfm_bench_mono.zip 8 | 9 | wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_cabinet.tgz 10 | tar -zxvf rgbd_dataset_freiburg3_cabinet.tgz 11 | rm rgbd_dataset_freiburg3_cabinet.tgz 12 | 13 | wget http://robotics.ethz.ch/~asl-datasets/ijrr_euroc_mav_dataset/machine_hall/MH_03_medium/MH_03_medium.zip 14 | unzip MH_03_medium.zip 15 | rm MH_03_medium.zip 16 | 17 | cd .. 
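# Usage sketch: once the sample sequences are in ./data, they can be fed to the
# monocular demo. The flags below (--imagedir, --calib, --t0) and the extracted
# folder names are assumptions mirroring the README-style commands; verify them
# against demo.py before relying on them.
#
#   ./tools/download_model.sh   # fetches the pretrained droid.pth weights
#   python demo.py --imagedir=data/sfm_bench/rgb --calib=calib/eth.txt
#   python demo.py --imagedir=data/rgbd_dataset_freiburg3_cabinet/rgb --calib=calib/tum3.txt
#   python demo.py --imagedir=data/mav0/cam0/data --calib=calib/euroc.txt --t0=150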
18 | -------------------------------------------------------------------------------- /tools/download_tartanair_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TARTANAIR_PATH="datasets/tartanair_test" 4 | 5 | mkdir -p "${TARTANAIR_PATH}" 6 | 7 | gdown "1N8qoU-oEjRKdaKSrHPWA-xsnRtofR_jJ" --output "${TARTANAIR_PATH}/images.tar.gz" 8 | wget -c "https://cmu.box.com/shared/static/3p1sf0eljfwrz4qgbpc6g95xtn2alyfk.zip" -O ${TARTANAIR_PATH}/groundtruth.zip 9 | 10 | unzip -o ${TARTANAIR_PATH}/groundtruth.zip -d "${TARTANAIR_PATH}" 11 | tar -zxvf "${TARTANAIR_PATH}/images.tar.gz" -C "${TARTANAIR_PATH}" 12 | 13 | 14 | # rm ${TARTANAIR_PATH}/groundtruth.zip 15 | # rm ${TARTANAIR_PATH}/images.zip 16 | -------------------------------------------------------------------------------- /tools/download_tum.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | # where to put everything 5 | TUM_PATH="datasets/TUM-RGBD" 6 | 7 | # list of all sequences 8 | declare -a evalset=( 9 | rgbd_dataset_freiburg1_360 10 | rgbd_dataset_freiburg1_desk 11 | rgbd_dataset_freiburg1_desk2 12 | rgbd_dataset_freiburg1_floor 13 | rgbd_dataset_freiburg1_plant 14 | rgbd_dataset_freiburg1_room 15 | rgbd_dataset_freiburg1_rpy 16 | rgbd_dataset_freiburg1_teddy 17 | rgbd_dataset_freiburg1_xyz 18 | ) 19 | 20 | # make sure base dir exists 21 | mkdir -p "${TUM_PATH}" 22 | 23 | for scene in "${evalset[@]}"; do 24 | # full URL still needs the right folder on the server 25 | url="https://cvg.cit.tum.de/rgbd/dataset/freiburg1/${scene}.tgz" 26 | 27 | # local paths 28 | tarfile="${TUM_PATH}/${scene}.tgz" 29 | outdir="${TUM_PATH}" 30 | 31 | mkdir -p "${outdir}" 32 | 33 | echo "Downloading ${scene}..." 34 | wget -c "${url}" -O "${tarfile}" 35 | 36 | echo " Unzipping into ${outdir}/..." 37 | tar -zxvf "${tarfile}" -C "${outdir}" 38 | 39 | echo " Cleaning up..." 
40 | rm "${tarfile}" 41 | 42 | echo "✔ Done with ${scene}" 43 | done 44 | -------------------------------------------------------------------------------- /tools/evaluate_eth3d.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ETH_PATH=datasets/ETH3D-SLAM 4 | 5 | # all "non-dark" training scenes 6 | evalset=( 7 | cables_1 8 | cables_2 9 | cables_3 10 | camera_shake_1 11 | camera_shake_2 12 | camera_shake_3 13 | ceiling_1 14 | ceiling_2 15 | desk_3 16 | desk_changing_1 17 | einstein_1 18 | einstein_2 19 | # einstein_dark 20 | einstein_flashlight 21 | einstein_global_light_changes_1 22 | einstein_global_light_changes_2 23 | einstein_global_light_changes_3 24 | kidnap_1 25 | # kidnap_dark 26 | large_loop_1 27 | mannequin_1 28 | mannequin_3 29 | mannequin_4 30 | mannequin_5 31 | mannequin_7 32 | mannequin_face_1 33 | mannequin_face_2 34 | mannequin_face_3 35 | mannequin_head 36 | motion_1 37 | planar_2 38 | planar_3 39 | plant_1 40 | plant_2 41 | plant_3 42 | plant_4 43 | plant_5 44 | # plant_dark 45 | plant_scene_1 46 | plant_scene_2 47 | plant_scene_3 48 | reflective_1 49 | repetitive 50 | sfm_bench 51 | sfm_garden 52 | sfm_house_loop 53 | sfm_lab_room_1 54 | sfm_lab_room_2 55 | sofa_1 56 | sofa_2 57 | sofa_3 58 | sofa_4 59 | # sofa_dark_1 60 | # sofa_dark_2 61 | # sofa_dark_3 62 | sofa_shake 63 | table_3 64 | table_4 65 | table_7 66 | vicon_light_1 67 | vicon_light_2 68 | ) 69 | 70 | for seq in ${evalset[@]}; do 71 | python evaluation_scripts/test_eth3d.py --datapath=$ETH_PATH/$seq --weights=droid.pth --disable_vis $@ 72 | done 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /tools/evaluate_euroc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | EUROC_PATH=datasets/EuRoC 5 | 6 | evalset=( 7 | MH_01_easy 8 | MH_02_easy 9 | MH_03_medium 10 | MH_04_difficult 11 | MH_05_difficult 12 | V1_01_easy 13 | V1_02_medium 14 | V1_03_difficult 15 | V2_01_easy 16 | V2_02_medium 17 | V2_03_difficult 18 | ) 19 | 20 | for seq in ${evalset[@]}; do 21 | python evaluation_scripts/test_euroc.py --datapath=$EUROC_PATH/$seq --gt=data/euroc_groundtruth/$seq.txt --disable_vis --weights=droid.pth $@ 22 | 23 | done 24 | 25 | -------------------------------------------------------------------------------- /tools/evaluate_tum.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | TUM_PATH=datasets/TUM-RGBD/ 5 | 6 | evalset=( 7 | rgbd_dataset_freiburg1_360 8 | rgbd_dataset_freiburg1_desk 9 | rgbd_dataset_freiburg1_desk2 10 | rgbd_dataset_freiburg1_floor 11 | rgbd_dataset_freiburg1_plant 12 | rgbd_dataset_freiburg1_room 13 | rgbd_dataset_freiburg1_rpy 14 | rgbd_dataset_freiburg1_teddy 15 | rgbd_dataset_freiburg1_xyz 16 | ) 17 | 18 | for seq in ${evalset[@]}; do 19 | python evaluation_scripts/test_tum.py --datapath=$TUM_PATH/$seq --weights=droid.pth --disable_vis $@ 20 | done 21 | 22 | -------------------------------------------------------------------------------- /tools/validate_tartanair.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | TARTANAIR_PATH=datasets/TartanAir 5 | 6 | python evaluation_scripts/validate_tartanair.py --datapath=$TARTANAIR_PATH --weights=droid.pth --disable_vis $@ 7 | 8 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | import cv2 5 | import numpy as np 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from data_readers.factory import dataset_factory 12 | 13 | from lietorch import SO3, SE3, Sim3 14 | from geom import losses 15 | from geom.losses import geodesic_loss, residual_loss, flow_loss 16 | from geom.graph_utils import build_frame_graph 17 | 18 | # network 19 | from droid_net import DroidNet 20 | from logger import Logger 21 | 22 | # DDP training 23 | import torch.multiprocessing as mp 24 | import torch.distributed as dist 25 | from torch.nn.parallel import DistributedDataParallel as DDP 26 | 27 | 28 | def setup_ddp(gpu, args): 29 | dist.init_process_group( 30 | backend='nccl', 31 | init_method='env://', 32 | world_size=args.world_size, 33 | rank=gpu) 34 | 35 | torch.manual_seed(0) 36 | torch.cuda.set_device(gpu) 37 | 38 | def show_image(image): 39 | image = image.permute(1, 2, 0).cpu().numpy() 40 | cv2.imshow('image', image / 255.0) 41 | cv2.waitKey() 42 | 43 | def train(gpu, args): 44 | """ Test to make sure project transform correctly maps points """ 45 | 46 | # coordinate multiple GPUs 47 | setup_ddp(gpu, args) 48 | rng = np.random.default_rng(12345) 49 | 50 | N = args.n_frames 51 | model = DroidNet() 52 | model.cuda() 53 | model.train() 54 | 55 | model = DDP(model, device_ids=[gpu], find_unused_parameters=False) 56 | 57 | if args.ckpt is not None: 58 | model.load_state_dict(torch.load(args.ckpt)) 59 | 60 | # fetch dataloader 61 | db = dataset_factory(['tartan'], datapath=args.datapath, n_frames=args.n_frames, fmin=args.fmin, fmax=args.fmax) 62 | 63 | train_sampler = torch.utils.data.distributed.DistributedSampler( 64 | db, shuffle=True, num_replicas=args.world_size, rank=gpu) 65 | 66 | train_loader = DataLoader(db, batch_size=args.batch, sampler=train_sampler, num_workers=2) 67 | 68 | # fetch optimizer 69 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) 70 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 71 | args.lr, args.steps, pct_start=0.01, cycle_momentum=False) 72 | 73 | logger = Logger(args.name, scheduler) 74 | should_keep_training = True 75 | total_steps = 0 76 | 77 | while should_keep_training: 78 | for i_batch, item in enumerate(train_loader): 79 | optimizer.zero_grad() 80 | 81 | images, poses, disps, intrinsics = [x.to('cuda') for x in item] 82 | 83 | # convert poses w2c -> c2w 84 | Ps = SE3(poses).inv() 85 | Gs = SE3.IdentityLike(Ps) 86 | 87 | # randomize frame graph 88 | if np.random.rand() < 0.5: 89 | graph = build_frame_graph(poses, disps, intrinsics, num=args.edges) 90 | 91 | else: 92 | graph = OrderedDict() 93 | for i in range(N): 94 | graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2] 95 | 96 | # fix first to camera poses 97 | Gs.data[:,0] = Ps.data[:,0].clone() 98 | Gs.data[:,1:] = Ps.data[:,[1]].clone() 99 | disp0 = torch.ones_like(disps[:,:,3::8,3::8]) 100 | 101 | # perform random restarts 102 | r = 0 103 | while r < args.restart_prob: 104 | r = rng.random() 105 | 106 | intrinsics0 = intrinsics / 8.0 107 | poses_est, disps_est, residuals = model(Gs, images, disp0, intrinsics0, 108 | graph, num_steps=args.iters, fixedp=2) 109 | 110 | geo_loss, geo_metrics = losses.geodesic_loss(Ps, poses_est, graph, do_scale=False) 111 | res_loss, res_metrics = losses.residual_loss(residuals) 112 | flo_loss, 
flo_metrics = losses.flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph) 113 | 114 | loss = args.w1 * geo_loss + args.w2 * res_loss + args.w3 * flo_loss 115 | loss.backward() 116 | 117 | Gs = poses_est[-1].detach() 118 | disp0 = disps_est[-1][:,:,3::8,3::8].detach() 119 | 120 | metrics = {} 121 | metrics.update(geo_metrics) 122 | metrics.update(res_metrics) 123 | metrics.update(flo_metrics) 124 | 125 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 126 | optimizer.step() 127 | scheduler.step() 128 | 129 | total_steps += 1 130 | 131 | if gpu == 0: 132 | logger.push(metrics) 133 | 134 | if total_steps % 10000 == 0 and gpu == 0: 135 | PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps) 136 | torch.save(model.state_dict(), PATH) 137 | 138 | if total_steps >= args.steps: 139 | should_keep_training = False 140 | break 141 | 142 | dist.destroy_process_group() 143 | 144 | 145 | if __name__ == '__main__': 146 | import argparse 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument('--name', default='bla', help='name your experiment') 149 | parser.add_argument('--ckpt', help='checkpoint to restore') 150 | parser.add_argument('--datasets', nargs='+', help='lists of datasets for training') 151 | parser.add_argument('--datapath', default='datasets/TartanAir', help="path to dataset directory") 152 | parser.add_argument('--gpus', type=int, default=4) 153 | 154 | parser.add_argument('--batch', type=int, default=1) 155 | parser.add_argument('--iters', type=int, default=15) 156 | parser.add_argument('--steps', type=int, default=250000) 157 | parser.add_argument('--lr', type=float, default=0.00025) 158 | parser.add_argument('--clip', type=float, default=2.5) 159 | parser.add_argument('--n_frames', type=int, default=7) 160 | 161 | parser.add_argument('--w1', type=float, default=10.0) 162 | parser.add_argument('--w2', type=float, default=0.01) 163 | parser.add_argument('--w3', type=float, default=0.05) 164 | 165 | parser.add_argument('--fmin', type=float, default=8.0) 166 | parser.add_argument('--fmax', type=float, default=96.0) 167 | parser.add_argument('--noise', action='store_true') 168 | parser.add_argument('--scale', action='store_true') 169 | parser.add_argument('--edges', type=int, default=24) 170 | parser.add_argument('--restart_prob', type=float, default=0.2) 171 | 172 | args = parser.parse_args() 173 | 174 | args.world_size = args.gpus 175 | print(args) 176 | 177 | import os 178 | if not os.path.isdir('checkpoints'): 179 | os.mkdir('checkpoints') 180 | 181 | args = parser.parse_args() 182 | args.world_size = args.gpus 183 | 184 | os.environ['MASTER_ADDR'] = 'localhost' 185 | os.environ['MASTER_PORT'] = '12356' 186 | mp.spawn(train, nprocs=args.gpus, args=(args,)) 187 | 188 | -------------------------------------------------------------------------------- /view_reconstruction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("droid_slam") 3 | 4 | import torch 5 | import argparse 6 | 7 | import droid_backends 8 | import argparse 9 | import open3d as o3d 10 | 11 | from visualization import create_camera_actor 12 | from lietorch import SE3 13 | 14 | from cuda_timer import CudaTimer 15 | 16 | def view_reconstruction(filename: str, filter_thresh = 0.005, filter_count=2): 17 | reconstruction_blob = torch.load(filename) 18 | images = reconstruction_blob["images"].cuda()[...,::2,::2] 19 | disps = reconstruction_blob["disps"].cuda()[...,::2,::2] 20 | poses = reconstruction_blob["poses"].cuda() 21 
|     intrinsics = 4 * reconstruction_blob["intrinsics"].cuda()
22 | 
23 |     disps = disps.contiguous()
24 | 
25 |     index = torch.arange(len(images), device="cuda")
26 |     thresh = filter_thresh * torch.ones_like(disps.mean(dim=[1,2]))
27 | 
28 |     with CudaTimer("iproj"):
29 |         points = droid_backends.iproj(SE3(poses).inv().data, disps, intrinsics[0])
30 |         colors = images[:,[2,1,0]].permute(0,2,3,1) / 255.0
31 | 
32 |     with CudaTimer("filter"):
33 |         counts = droid_backends.depth_filter(poses, disps, intrinsics[0], index, thresh)
34 | 
35 |     mask = (counts >= filter_count) & (disps > .25 * disps.mean())
36 |     points_np = points[mask].cpu().numpy()
37 |     colors_np = colors[mask].cpu().numpy()
38 | 
39 |     point_cloud = o3d.geometry.PointCloud()
40 |     point_cloud.points = o3d.utility.Vector3dVector(points_np)
41 |     point_cloud.colors = o3d.utility.Vector3dVector(colors_np)
42 | 
43 |     vis = o3d.visualization.Visualizer()
44 |     vis.create_window(height=960, width=960)
45 |     vis.get_render_option().load_from_json("misc/renderoption.json")
46 | 
47 |     vis.add_geometry(point_cloud)
48 | 
49 |     # get pose matrices as an Nx4x4 numpy array
50 |     pose_mats = SE3(poses).inv().matrix().cpu().numpy()
51 | 
52 |     ### add camera actors ###
53 |     for i in range(len(poses)):
54 |         cam_actor = create_camera_actor(False)
55 |         cam_actor.transform(pose_mats[i])
56 |         vis.add_geometry(cam_actor)
57 | 
58 |     vis.run()
59 |     vis.destroy_window()
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     parser = argparse.ArgumentParser()
64 |     parser.add_argument("filename", type=str, help="path to a saved reconstruction file (.pth)")
65 |     parser.add_argument("--filter_threshold", type=float, default=0.005)
66 |     parser.add_argument("--filter_count", type=int, default=3)
67 |     args = parser.parse_args()
68 | 
69 |     view_reconstruction(args.filename, args.filter_threshold, args.filter_count)
--------------------------------------------------------------------------------
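A minimal usage sketch for the TartanAir evaluation utilities above, under these assumptions: it is run from the repository root, the pose files "pose_gt.txt" and "pose_est.txt" are placeholders containing N x 7 rows in "x y z qx qy qz qw" format with matching length, and the figures/ output directory may not exist yet.

import os
import sys
import numpy as np

sys.path.append("thirdparty/tartanair_tools")   # assumption: script sits in the repo root
from evaluation.tartanair_evaluator import TartanAirEvaluator

os.makedirs("figures", exist_ok=True)           # evaluate_one_trajectory saves figures/<title>.pdf

gt_traj = np.loadtxt("pose_gt.txt")             # placeholder ground-truth trajectory, N x 7
est_traj = np.loadtxt("pose_est.txt")           # placeholder estimated trajectory, same length

evaluator = TartanAirEvaluator()
# scale=True for monocular results, where alignment also solves a global scale
result = evaluator.evaluate_one_trajectory(gt_traj, est_traj, scale=True, title="example")
print(result["ate_score"], result["rpe_score"], result["kitti_score"])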