├── examples
│   ├── __init__.py
│   ├── core
│   │   ├── geom
│   │   │   ├── __init__.py
│   │   │   ├── chol.py
│   │   │   ├── sampler_utils.py
│   │   │   ├── losses.py
│   │   │   ├── ba.py
│   │   │   ├── graph_utils.py
│   │   │   └── projective_ops.py
│   │   ├── networks
│   │   │   ├── __init__.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── clipping.py
│   │   │   │   ├── gru.py
│   │   │   │   ├── unet.py
│   │   │   │   ├── corr.py
│   │   │   │   └── extractor.py
│   │   │   ├── sim3_net.py
│   │   │   └── rslam.py
│   │   ├── data_readers
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   ├── stream.py
│   │   │   ├── nyu2.py
│   │   │   ├── augmentation.py
│   │   │   ├── tum.py
│   │   │   ├── eth3d.py
│   │   │   ├── scannet.py
│   │   │   ├── rgbd_utils.py
│   │   │   ├── base.py
│   │   │   └── tartan.py
│   │   └── logger.py
│   ├── rgbdslam
│   │   ├── rgbd_benchmark
│   │   │   ├── __init__.py
│   │   │   ├── associate.py
│   │   │   └── evaluate_ate.py
│   │   ├── assets
│   │   │   ├── room.png
│   │   │   ├── floor.png
│   │   │   └── renderoption.json
│   │   ├── reprojection_test.py
│   │   ├── demo.py
│   │   ├── evaluate.py
│   │   ├── readme.md
│   │   ├── train.py
│   │   └── viz.py
│   ├── registration
│   │   ├── assets
│   │   │   ├── depth1.npy
│   │   │   ├── depth2.npy
│   │   │   ├── depth3.npy
│   │   │   ├── depth4.npy
│   │   │   ├── image1.png
│   │   │   ├── image2.png
│   │   │   ├── image3.png
│   │   │   ├── image4.png
│   │   │   ├── registration.gif
│   │   │   └── renderoption.json
│   │   ├── readme.md
│   │   ├── demo.py
│   │   ├── viz.py
│   │   └── main.py
│   ├── readme.md
│   └── pgo
│       ├── readme.md
│       └── main.py
├── run_tests.sh
├── lietorch.png
├── .gitmodules
├── lietorch
│   ├── __init__.py
│   ├── include
│   │   ├── common.h
│   │   ├── lietorch_cpu.h
│   │   ├── lietorch_gpu.h
│   │   ├── dispatch.h
│   │   ├── sim3.h
│   │   ├── se3.h
│   │   └── so3.h
│   ├── broadcasting.py
│   ├── group_ops.py
│   ├── extras
│   │   ├── se3_solver.cu
│   │   ├── corr_index_kernel.cu
│   │   └── extras.cpp
│   └── groups.py
├── .gitignore
├── pyproject.toml
├── setup.py
├── LICENSE
└── README.md
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/core/geom/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/core/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/core/data_readers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/examples/core/networks/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/rgbdslam/rgbd_benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/run_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python lietorch/run_tests.py
4 |
--------------------------------------------------------------------------------
/lietorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/lietorch.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "eigen"]
2 | path = eigen
3 | url = https://gitlab.com/libeigen/eigen.git
4 |
--------------------------------------------------------------------------------
/examples/rgbdslam/assets/room.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/rgbdslam/assets/room.png
--------------------------------------------------------------------------------
/lietorch/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['groups']
2 | from .groups import LieGroupParameter, SO3, RxSO3, SE3, Sim3, cat, stack
3 |
--------------------------------------------------------------------------------
/examples/rgbdslam/assets/floor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/rgbdslam/assets/floor.png
--------------------------------------------------------------------------------
/examples/registration/assets/depth1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/depth1.npy
--------------------------------------------------------------------------------
/examples/registration/assets/depth2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/depth2.npy
--------------------------------------------------------------------------------
/examples/registration/assets/depth3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/depth3.npy
--------------------------------------------------------------------------------
/examples/registration/assets/depth4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/depth4.npy
--------------------------------------------------------------------------------
/examples/registration/assets/image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/image1.png
--------------------------------------------------------------------------------
/examples/registration/assets/image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/image2.png
--------------------------------------------------------------------------------
/examples/registration/assets/image3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/image3.png
--------------------------------------------------------------------------------
/examples/registration/assets/image4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/image4.png
--------------------------------------------------------------------------------
/examples/registration/assets/registration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/lietorch/HEAD/examples/registration/assets/registration.gif
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | build
3 | dist
4 | *.egg-info
5 | *.vscode/
6 | *.pth
7 | tests
8 | checkpoints
9 | datasets
10 | runs
11 | a.out
12 | cache
13 | *.g2o
14 |
15 |
--------------------------------------------------------------------------------
/lietorch/include/common.h:
--------------------------------------------------------------------------------
1 | #ifndef COMMON_H
2 | #define COMMON_H
3 |
4 | #define EIGEN_DEFAULT_DENSE_INDEX_TYPE int
5 | #define EIGEN_RUNTIME_NO_MALLOC
6 |
7 | #define EPS 1e-6
8 | #define PI 3.14159265358979323846
9 |
10 |
11 | #endif
12 |
13 |
--------------------------------------------------------------------------------
/examples/readme.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | Instructions for running demos and experiments can be found in each of the example directories:
4 | 1. [Pose Graph Optimization](pgo/readme.md) -> `pgo`
5 | 2. [Sim3 Registration](registration/readme.md) -> `registration`
6 | 3. [RGBD-SLAM](rgbdslam/readme.md) -> `rgbdslam`
7 | 4. [RAFT-3D (SceneFlow)]()
8 |
9 | `core` contains networks, data loaders, and other common utility functions.
10 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "torch", "wheel"]
3 | build-backend = "setuptools.build_meta:__legacy__"
4 |
5 | [project]
6 | name = "lietorch"
7 | version = "0.3"
8 | description = "Lie Groups for PyTorch"
9 | authors = [
10 | { name="Zachary Teed", email="zachteed@gmail.com" }
11 | ]
12 | license = { text = "BSD-3-Clause" }
13 | readme = "README.md"
14 | requires-python = ">=3.9"
15 | dependencies = [
16 | "torch",
17 | "numpy",
18 | ]
19 |
--------------------------------------------------------------------------------
/examples/core/networks/modules/clipping.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | GRAD_CLIP = .01
6 |
7 | class GradClip(torch.autograd.Function):
8 | @staticmethod
9 | def forward(ctx, x):
10 | return x
11 |
12 | @staticmethod
13 | def backward(ctx, grad_x):
14 | o = torch.zeros_like(grad_x)
15 |         grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x)  # zero out (rather than clamp) gradients whose magnitude exceeds GRAD_CLIP
16 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
17 | return grad_x
18 |
19 | class GradientClip(nn.Module):
20 | def __init__(self):
21 | super(GradientClip, self).__init__()
22 |
23 | def forward(self, x):
24 | return GradClip.apply(x)
--------------------------------------------------------------------------------
/examples/core/networks/modules/gru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class ConvGRU(nn.Module):
5 | def __init__(self, h_planes=128, i_planes=128):
6 | super(ConvGRU, self).__init__()
7 | self.do_checkpoint = False
8 | self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
9 | self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
10 | self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
11 |
12 | def forward(self, net, *inputs):
13 | inp = torch.cat(inputs, dim=1)
14 | net_inp = torch.cat([net, inp], dim=1)
15 |
16 | z = torch.sigmoid(self.convz(net_inp))
17 | r = torch.sigmoid(self.convr(net_inp))
18 | q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)))
19 |
20 | net = (1-z) * net + z * q
21 | return net
22 |
--------------------------------------------------------------------------------
/examples/pgo/readme.md:
--------------------------------------------------------------------------------
1 | ## Pose Graph Optimization / Rotation Averaging
2 |
3 | Pose Graph Optimization (PGO) is the problem of estimating the global trajectory from a set of relative pose measurements. PGO is typically performed using nonlinear least-squares algorithms (e.g., Levenberg-Marquardt) and requires a good initialization in order to converge.
4 |
5 | In this experiment, we implement Riemannian Gradient Descent with a reshaping function (Tron et al. 2012). The algorithm is implemented in the function `gradient_initializer` and runs on the GPU using lietorch.
6 |
7 | ### Running on a .g2o file
8 |
9 | Download a 3D problem from [datasets](https://lucacarlone.mit.edu/datasets/) (our implementation currently only supports uniform information matrices in Sphere-A, Torus, Cube, and Garage).
10 |
11 | Then run the `gradient_initializer` on the problem:
12 | ```bash
13 | python main.py --problem=torus3D.g2o --steps=500
14 | ```
15 |
16 | The output graph, `torus3D_rotavg.g2o`, can then be used as the initialization for non-linear least squares optimizers such as `ceres`, `g2o`, and `gtsam`.
17 |
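18 | ### Sketch: rotation averaging on the GPU
19 |
20 | The snippet below is a minimal sketch of the idea, not the repository's `gradient_initializer` (in particular it omits the reshaping function of Tron et al.): the absolute rotations are parameterized in the tangent space and optimized by plain gradient descent with lietorch. The edge indices `ii, jj` and relative rotations `R_ij` are assumed to come from the parsed `.g2o` file, with the convention `R_j ≈ R_ij * R_i`.
21 |
22 | ```python
23 | import torch
24 | from lietorch import SO3
25 |
26 | def rotation_averaging(R_ij, ii, jj, num_nodes, steps=500, lr=0.2):
27 |     # R_ij: SO3 of shape (num_edges,); ii, jj: LongTensors of edge indices (all on the GPU)
28 |     phi = torch.zeros(num_nodes, 3, device='cuda', requires_grad=True)
29 |     optimizer = torch.optim.SGD([phi], lr=lr)
30 |
31 |     for _ in range(steps):
32 |         optimizer.zero_grad()
33 |         R = SO3.exp(phi)                  # absolute rotations
34 |         dR = R[jj].inv() * R_ij * R[ii]   # residual rotation of each measurement
35 |         loss = dR.log().norm(dim=-1).pow(2).sum()
36 |         loss.backward()
37 |         optimizer.step()
38 |
39 |     return SO3.exp(phi.detach())
40 | ```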
--------------------------------------------------------------------------------
/lietorch/broadcasting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def check_broadcastable(x, y):
5 | assert len(x.shape) == len(y.shape)
6 | for (n, m) in zip(x.shape[:-1], y.shape[:-1]):
7 | assert n==m or n==1 or m==1
8 |
9 | def broadcast_inputs(x, y):
10 | """ Automatic broadcasting of missing dimensions """
11 | if y is None:
12 | xs, xd = x.shape[:-1], x.shape[-1]
13 | return (x.view(-1, xd).contiguous(), ), x.shape[:-1]
14 |
15 | check_broadcastable(x, y)
16 |
17 | xs, xd = x.shape[:-1], x.shape[-1]
18 | ys, yd = y.shape[:-1], y.shape[-1]
19 | out_shape = [max(n,m) for (n,m) in zip(xs,ys)]
20 |
21 |     if x.shape[:-1] == y.shape[:-1]:
22 | x1 = x.view(-1, xd)
23 | y1 = y.view(-1, yd)
24 |
25 | else:
26 | x_expand = [m if n==1 else 1 for (n,m) in zip(xs, ys)]
27 | y_expand = [n if m==1 else 1 for (n,m) in zip(xs, ys)]
28 | x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
29 | y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()
30 |
31 | return (x1, y1), tuple(out_shape)
32 |
--------------------------------------------------------------------------------
/examples/core/geom/chol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import geom.projective_ops as pops
4 |
5 | class CholeskySolver(torch.autograd.Function):
6 | @staticmethod
7 | def forward(ctx, H, b):
8 | # don't crash training if cholesky decomp fails
9 | try:
10 | U = torch.cholesky(H)
11 | xs = torch.cholesky_solve(b, U)
12 | ctx.save_for_backward(U, xs)
13 | ctx.failed = False
14 | except Exception as e:
15 | print(e)
16 | ctx.failed = True
17 | xs = torch.zeros_like(b)
18 |
19 | return xs
20 |
21 | @staticmethod
22 | def backward(ctx, grad_x):
23 | if ctx.failed:
24 | return None, None
25 |
26 | U, xs = ctx.saved_tensors
27 | dz = torch.cholesky_solve(grad_x, U)
28 | dH = -torch.matmul(xs, dz.transpose(-1,-2))
29 |
30 | return dH, dz
31 |
32 | def block_solve(H, b, ep=0.1, lm=0.0001):
33 | """ solve normal equations """
34 | B, N, _, D, _ = H.shape
35 | I = torch.eye(D).to(H.device)
36 | H = H + (ep + lm*H) * I
37 |
38 | H = H.permute(0,1,3,2,4)
39 | H = H.reshape(B, N*D, N*D)
40 | b = b.reshape(B, N*D, 1)
41 |
42 | x = CholeskySolver.apply(H,b)
43 | return x.reshape(B, N, D)
--------------------------------------------------------------------------------
/examples/core/geom/sampler_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def _bilinear_sampler(img, coords, mode='bilinear', mask=False):
5 | """ Wrapper for grid_sample, uses pixel coordinates """
6 | H, W = img.shape[-2:]
7 | xgrid, ygrid = coords.split([1,1], dim=-1)
8 | xgrid = 2*xgrid/(W-1) - 1
9 | ygrid = 2*ygrid/(H-1) - 1
10 |
11 | grid = torch.cat([xgrid, ygrid], dim=-1)
12 | img = F.grid_sample(img, grid, align_corners=True)
13 |
14 | if mask:
15 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
16 | return img, mask.float()
17 |
18 | return img
19 |
20 | def bilinear_sampler(img, coords):
21 | """ Wrapper for bilinear sampler for inputs with extra batch dimensions """
22 | unflatten = False
23 | if len(img.shape) == 5:
24 | unflatten = True
25 | b, n, c, h, w = img.shape
26 | img = img.view(b*n, c, h, w)
27 | coords = coords.view(b*n, h, w, 2)
28 |
29 | img1 = _bilinear_sampler(img, coords)
30 |
31 | if unflatten:
32 | return img1.view(b, n, c, h, w)
33 |
34 | return img1
35 |
36 | def sample_depths(depths, coords):
37 | batch, num, ht, wd = depths.shape
38 | depths = depths.view(batch, num, 1, ht, wd)
39 | coords = coords.view(batch, num, ht, wd, 2)
40 |
41 | depths_proj = bilinear_sampler(depths, coords)
42 | return depths_proj.view(batch, num, ht, wd, 1)
43 |
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import os
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | ROOT = os.path.dirname(os.path.abspath(__file__))
6 |
7 | setup(
8 | name="lietorch",
9 | version="0.3",
10 | description="Lie Groups for PyTorch",
11 | author="Zachary Teed",
12 | packages=["lietorch"],
13 | ext_modules=[
14 | CUDAExtension("lietorch_backends",
15 | include_dirs=[
16 | os.path.join(ROOT, "lietorch/include"),
17 | os.path.join(ROOT, "eigen")],
18 | sources=[
19 | "lietorch/src/lietorch.cpp",
20 | "lietorch/src/lietorch_gpu.cu",
21 | "lietorch/src/lietorch_cpu.cpp"],
22 | extra_compile_args={
23 | "cxx": ["-O2"],
24 | "nvcc": ["-O2"],
25 | }),
26 |
27 | CUDAExtension("lietorch_extras",
28 | sources=[
29 | "lietorch/extras/altcorr_kernel.cu",
30 | "lietorch/extras/corr_index_kernel.cu",
31 | "lietorch/extras/se3_builder.cu",
32 | "lietorch/extras/se3_inplace_builder.cu",
33 | "lietorch/extras/se3_solver.cu",
34 | "lietorch/extras/extras.cpp",
35 | ],
36 | extra_compile_args={
37 | "cxx": ["-O2"],
38 | "nvcc": ["-O2"],
39 | }),
40 | ],
41 | cmdclass={ "build_ext": BuildExtension }
42 | )
43 |
--------------------------------------------------------------------------------
/examples/rgbdslam/assets/renderoption.json:
--------------------------------------------------------------------------------
1 | {
2 | "background_color" : [ 1, 1, 1 ],
3 | "class_name" : "RenderOption",
4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ],
5 | "image_max_depth" : 3000,
6 | "image_stretch_option" : 0,
7 | "interpolation_option" : 0,
8 | "light0_color" : [ 1, 1, 1 ],
9 | "light0_diffuse_power" : 0.66000000000000003,
10 | "light0_position" : [ 0, 0, 2 ],
11 | "light0_specular_power" : 0.20000000000000001,
12 | "light0_specular_shininess" : 100,
13 | "light1_color" : [ 1, 1, 1 ],
14 | "light1_diffuse_power" : 0.66000000000000003,
15 | "light1_position" : [ 0, 0, 2 ],
16 | "light1_specular_power" : 0.20000000000000001,
17 | "light1_specular_shininess" : 100,
18 | "light2_color" : [ 1, 1, 1 ],
19 | "light2_diffuse_power" : 0.66000000000000003,
20 | "light2_position" : [ 0, 0, -2 ],
21 | "light2_specular_power" : 0.20000000000000001,
22 | "light2_specular_shininess" : 100,
23 | "light3_color" : [ 1, 1, 1 ],
24 | "light3_diffuse_power" : 0.66000000000000003,
25 | "light3_position" : [ 0, 0, -2 ],
26 | "light3_specular_power" : 0.20000000000000001,
27 | "light3_specular_shininess" : 100,
28 | "light_ambient_color" : [ 0, 0, 0 ],
29 | "light_on" : true,
30 | "mesh_color_option" : 1,
31 | "mesh_shade_option" : 0,
32 | "mesh_show_back_face" : false,
33 | "mesh_show_wireframe" : false,
34 | "point_color_option" : 9,
35 | "point_show_normal" : false,
36 | "point_size" : 2,
37 | "show_coordinate_frame" : false,
38 | "version_major" : 1,
39 | "version_minor" : 0
40 | }
41 |
--------------------------------------------------------------------------------
/examples/registration/assets/renderoption.json:
--------------------------------------------------------------------------------
1 | {
2 | "background_color" : [ 1, 1, 1 ],
3 | "class_name" : "RenderOption",
4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ],
5 | "image_max_depth" : 3000,
6 | "image_stretch_option" : 0,
7 | "interpolation_option" : 0,
8 | "light0_color" : [ 1, 1, 1 ],
9 | "light0_diffuse_power" : 0.66000000000000003,
10 | "light0_position" : [ 0, 0, 2 ],
11 | "light0_specular_power" : 0.20000000000000001,
12 | "light0_specular_shininess" : 100,
13 | "light1_color" : [ 1, 1, 1 ],
14 | "light1_diffuse_power" : 0.66000000000000003,
15 | "light1_position" : [ 0, 0, 2 ],
16 | "light1_specular_power" : 0.20000000000000001,
17 | "light1_specular_shininess" : 100,
18 | "light2_color" : [ 1, 1, 1 ],
19 | "light2_diffuse_power" : 0.66000000000000003,
20 | "light2_position" : [ 0, 0, -2 ],
21 | "light2_specular_power" : 0.20000000000000001,
22 | "light2_specular_shininess" : 100,
23 | "light3_color" : [ 1, 1, 1 ],
24 | "light3_diffuse_power" : 0.66000000000000003,
25 | "light3_position" : [ 0, 0, -2 ],
26 | "light3_specular_power" : 0.20000000000000001,
27 | "light3_specular_shininess" : 100,
28 | "light_ambient_color" : [ 0, 0, 0 ],
29 | "light_on" : true,
30 | "mesh_color_option" : 1,
31 | "mesh_shade_option" : 0,
32 | "mesh_show_back_face" : false,
33 | "mesh_show_wireframe" : false,
34 | "point_color_option" : 9,
35 | "point_show_normal" : false,
36 | "point_size" : 2,
37 | "show_coordinate_frame" : false,
38 | "version_major" : 1,
39 | "version_minor" : 0
40 | }
41 |
--------------------------------------------------------------------------------
/examples/core/logger.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.utils.tensorboard import SummaryWriter
4 |
5 |
6 | SUM_FREQ = 100
7 |
8 | class Logger:
9 | def __init__(self, name, scheduler):
10 | self.total_steps = 0
11 | self.running_loss = {}
12 | self.writer = None
13 | self.name = name
14 | self.scheduler = scheduler
15 |
16 | def _print_training_status(self):
17 | if self.writer is None:
18 | self.writer = SummaryWriter('runs/%s' % self.name)
19 | print([k for k in self.running_loss])
20 |
21 | lr = self.scheduler.get_lr().pop()
22 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()]
23 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr)
24 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
25 |
26 | # print the training status
27 | print(training_str + metrics_str)
28 |
29 | for key in self.running_loss:
30 | val = self.running_loss[key] / SUM_FREQ
31 | self.writer.add_scalar(key, val, self.total_steps)
32 | self.running_loss[key] = 0.0
33 |
34 | def push(self, metrics):
35 |
36 | for key in metrics:
37 | if key not in self.running_loss:
38 | self.running_loss[key] = 0.0
39 |
40 | self.running_loss[key] += metrics[key]
41 |
42 | if self.total_steps % SUM_FREQ == SUM_FREQ-1:
43 | self._print_training_status()
44 | self.running_loss = {}
45 |
46 | self.total_steps += 1
47 |
--------------------------------------------------------------------------------
/examples/registration/readme.md:
--------------------------------------------------------------------------------
1 | ## SE3 / Sim3 Registration
2 | Estimate the 3D transformation between two RGB-D frames.
3 |
4 |
5 |
6 | ### Models
7 | | Model | Rot. Acc. | Tr. Acc. | Scale Acc. |
8 | | --------- | --------- | -------- | ---------- |
9 | | [se3.pth](https://drive.google.com/file/d/17pgeY5m-GXnrY3oFLPRaIrTZYvae_l9u/view?usp=sharing) | 91.90 | 77.70 | - |
10 | | [sim3.pth](https://drive.google.com/file/d/1LMnKND_4DAmd9DMTSKdz_zAoCgja6X43/view?usp=sharing) | 93.45 | 76.05 | 98.70 |
11 |
12 | Accuracies are measured at thresholds of 0.1 deg. rotation error, 1 cm translation error, and 1% scale error.
13 |
14 | ### Demo
15 | Download one of the models to run the demo (requires Open3D)
16 | ```bash
17 | python demo.py --transformation=SE3 --ckpt=se3.pth
18 | python demo.py --transformation=Sim3 --ckpt=sim3.pth
19 | ```
20 |
21 | ### Training and Evaluation
22 | Training and evaluation are performed on the [TartanAir](https://theairlab.org/tartanair-dataset/) dataset (only depth_left and image_left need to be downloaded). Note: our dataloader computes the optical flow between every pair of frames, which can take several hours on the first run. However, this result is cached so that future loads only take a few seconds.
23 |
24 | The training script expects the dataset to be in the directory datasets/TartanAir.
25 |
26 | To train a Sim3 network:
27 | ```bash
28 | python main.py --train --transformation=Sim3 --name=sim3
29 | ```
30 | A trained model can then be evaluated:
31 | ```bash
32 | python main.py --transformation=Sim3 --ckpt=sim3.pth
33 | ```
34 |
35 |
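36 | ### Sketch: accuracy metrics
37 |
38 | A minimal sketch (assumed; it mirrors `pose_metrics` in `core/geom/losses.py` and is not necessarily the exact evaluation code) of how the accuracies in the table above can be computed from predicted and ground-truth transformations:
39 |
40 | ```python
41 | import numpy as np
42 | from lietorch import SO3
43 |
44 | def registration_accuracy(G_est, G_gt, r_thresh=0.1, t_thresh=0.01, s_thresh=0.01):
45 |     # G_est, G_gt: lietorch Sim3 objects with matching batch shape (metric units assumed)
46 |     dE = G_gt * G_est.inv()                # error transformation (convention assumed)
47 |     t, q, s = dE.data.split([3, 4, 1], -1) # Sim3 data layout: translation, quaternion, scale
48 |     r_err = (180 / np.pi) * SO3(q).log().norm(dim=-1)  # rotation error in degrees
49 |     t_err = t.norm(dim=-1)                              # translation error in meters
50 |     s_err = (s - 1.0).abs().squeeze(-1)                 # relative scale error
51 |     return {
52 |         'rot_acc': (r_err < r_thresh).float().mean().item(),
53 |         'tr_acc': (t_err < t_thresh).float().mean().item(),
54 |         'scale_acc': (s_err < s_thresh).float().mean().item(),
55 |     }
56 | ```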
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/examples/core/data_readers/factory.py:
--------------------------------------------------------------------------------
1 |
2 | import pickle
3 | import os
4 | import os.path as osp
5 |
6 | # RGBD-Dataset
7 | from .tartan import TartanAir
8 | from .nyu2 import NYUv2
9 | from .eth3d import ETH3D
10 | from .scannet import ScanNet
11 |
12 | # streaming datasets for inference
13 | from .eth3d import ETH3DStream
14 | from .tum import TUMStream
15 | from .tartan import TartanAirStream
16 |
17 |
18 | def dataset_factory(dataset_list, **kwargs):
19 | """ create a combined dataset """
20 |
21 | from torch.utils.data import ConcatDataset
22 |
23 | dataset_map = {
24 | 'tartan': (TartanAir, 1),
25 | 'nyu': (NYUv2, 2),
26 | 'eth': (ETH3D, 5),
27 | 'scannet': (ScanNet, 1)}
28 |
29 | db_list = []
30 | for key in dataset_list:
31 | # cache datasets for faster future loading
32 | db = dataset_map[key][0](**kwargs)
33 | db *= dataset_map[key][1]
34 |
35 | print("Dataset {} has {} images".format(key, len(db)))
36 | db_list.append(db)
37 |
38 | return ConcatDataset(db_list)
39 |
40 |
41 | def create_datastream(dataset_path, **kwargs):
42 | """ create data_loader to stream images 1 by 1 """
43 |
44 | from torch.utils.data import DataLoader
45 |
46 | if osp.isfile(osp.join(dataset_path, 'calibration.txt')):
47 | db = ETH3DStream(dataset_path, **kwargs)
48 |
49 | elif osp.isfile(osp.join(dataset_path, 'rgb.txt')):
50 | db = TUMStream(dataset_path, **kwargs)
51 |
52 | elif osp.isdir(osp.join(dataset_path, 'image_left')):
53 |         db = TartanAirStream(dataset_path, **kwargs)
54 |
55 | stream = DataLoader(db, shuffle=False, batch_size=1, num_workers=4)
56 | return stream
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/examples/rgbdslam/reprojection_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | import torch
5 | import cv2
6 | import numpy as np
7 |
8 | from torch.utils.data import DataLoader
9 | from data_readers.factory import dataset_factory
10 |
11 | from lietorch import SO3, SE3, Sim3
12 | import geom.projective_ops as pops
13 | from geom.sampler_utils import bilinear_sampler
14 |
15 | def show_image(image):
16 | if len(image.shape) == 3:
17 | image = image.permute(1, 2, 0)
18 | image = image.cpu().numpy()
19 | cv2.imshow('image', image / 255.0)
20 | cv2.waitKey()
21 |
22 | def reproj_test(args, N=2):
23 | """ Test to make sure project transform correctly maps points """
24 |
25 | db = dataset_factory(args.datasets, n_frames=N)
26 | train_loader = DataLoader(db, batch_size=1, shuffle=True, num_workers=0)
27 |
28 | for item in train_loader:
29 | images, poses, depths, intrinsics = [x.to('cuda') for x in item]
30 | poses = SE3(poses).inv()
31 | disps = 1.0 / depths
32 |
33 | coords, _ = pops.projective_transform(poses, disps, intrinsics, [0], [1])
34 | imagew = bilinear_sampler(images[:,[1]], coords[...,[0,1]])
35 |
36 |         # these two images should show the camera motion
37 | show_image(images[0,0])
38 | show_image(images[0,1])
39 |
40 | # these two images should show the camera motion removed by reprojection / warping
41 | show_image(images[0,0])
42 | show_image(imagew[0,0])
43 |
44 |
45 | if __name__ == '__main__':
46 | import argparse
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--datasets', nargs='+', help='lists of datasets for training')
49 | args = parser.parse_args()
50 |
51 | reproj_test(args)
52 |
--------------------------------------------------------------------------------
/lietorch/include/lietorch_cpu.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef LIETORCH_CPU_H_
3 | #define LIETORCH_CPU_H_
4 |
5 | #include <torch/extension.h>
6 | #include <vector>
7 |
8 |
9 | // unary operations
10 | torch::Tensor exp_forward_cpu(int, torch::Tensor);
11 | std::vector<torch::Tensor> exp_backward_cpu(int, torch::Tensor, torch::Tensor);
12 |
13 | torch::Tensor log_forward_cpu(int, torch::Tensor);
14 | std::vector<torch::Tensor> log_backward_cpu(int, torch::Tensor, torch::Tensor);
15 |
16 | torch::Tensor inv_forward_cpu(int, torch::Tensor);
17 | std::vector<torch::Tensor> inv_backward_cpu(int, torch::Tensor, torch::Tensor);
18 |
19 | // binary operations
20 | torch::Tensor mul_forward_cpu(int, torch::Tensor, torch::Tensor);
21 | std::vector<torch::Tensor> mul_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
22 |
23 | torch::Tensor adj_forward_cpu(int, torch::Tensor, torch::Tensor);
24 | std::vector<torch::Tensor> adj_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
25 |
26 | torch::Tensor adjT_forward_cpu(int, torch::Tensor, torch::Tensor);
27 | std::vector<torch::Tensor> adjT_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
28 |
29 | torch::Tensor act_forward_cpu(int, torch::Tensor, torch::Tensor);
30 | std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
31 |
32 | torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor);
33 | std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
34 |
35 |
36 | // conversion operations
37 | // std::vector<torch::Tensor> to_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
38 | // std::vector<torch::Tensor> from_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
39 |
40 | // utility operations
41 | torch::Tensor orthogonal_projector_cpu(int, torch::Tensor);
42 |
43 | torch::Tensor as_matrix_forward_cpu(int, torch::Tensor);
44 |
45 | torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor);
46 |
47 |
48 | #endif
49 |
50 |
51 |
--------------------------------------------------------------------------------
/lietorch/include/lietorch_gpu.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef LIETORCH_GPU_H_
3 | #define LIETORCH_GPU_H_
4 |
5 | #include <torch/extension.h>
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 | #include <vector>
9 |
10 |
11 | // unary operations
12 | torch::Tensor exp_forward_gpu(int, torch::Tensor);
13 | std::vector<torch::Tensor> exp_backward_gpu(int, torch::Tensor, torch::Tensor);
14 |
15 | torch::Tensor log_forward_gpu(int, torch::Tensor);
16 | std::vector<torch::Tensor> log_backward_gpu(int, torch::Tensor, torch::Tensor);
17 |
18 | torch::Tensor inv_forward_gpu(int, torch::Tensor);
19 | std::vector<torch::Tensor> inv_backward_gpu(int, torch::Tensor, torch::Tensor);
20 |
21 | // binary operations
22 | torch::Tensor mul_forward_gpu(int, torch::Tensor, torch::Tensor);
23 | std::vector<torch::Tensor> mul_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
24 |
25 | torch::Tensor adj_forward_gpu(int, torch::Tensor, torch::Tensor);
26 | std::vector<torch::Tensor> adj_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
27 |
28 | torch::Tensor adjT_forward_gpu(int, torch::Tensor, torch::Tensor);
29 | std::vector<torch::Tensor> adjT_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
30 |
31 | torch::Tensor act_forward_gpu(int, torch::Tensor, torch::Tensor);
32 | std::vector<torch::Tensor> act_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
33 |
34 | torch::Tensor act4_forward_gpu(int, torch::Tensor, torch::Tensor);
35 | std::vector<torch::Tensor> act4_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
36 |
37 | // conversion operations
38 | // std::vector<torch::Tensor> to_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
39 | // std::vector<torch::Tensor> from_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
40 |
41 | // utility operators
42 | torch::Tensor orthogonal_projector_gpu(int, torch::Tensor);
43 |
44 | torch::Tensor as_matrix_forward_gpu(int, torch::Tensor);
45 |
46 | torch::Tensor jleft_forward_gpu(int, torch::Tensor, torch::Tensor);
47 |
48 | #endif
49 |
50 |
51 |
--------------------------------------------------------------------------------
/examples/core/data_readers/stream.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 | import torch.nn.functional as F
6 |
7 | import csv
8 | import os
9 | import cv2
10 | import math
11 | import random
12 | import json
13 | import pickle
14 | import os.path as osp
15 |
16 | from .rgbd_utils import *
17 |
18 | class RGBDStream(data.Dataset):
19 | def __init__(self, datapath, frame_rate=-1, crop_size=[384,512]):
20 | self.datapath = datapath
21 | self.frame_rate = frame_rate
22 | self.crop_size = crop_size
23 | self._build_dataset_index()
24 |
25 | @staticmethod
26 | def image_read(image_file):
27 | return cv2.imread(image_file)
28 |
29 | @staticmethod
30 | def depth_read(depth_file):
31 | return np.load(depth_file)
32 |
33 | def __len__(self):
34 | return len(self.images)
35 |
36 | def __getitem__(self, index):
37 | """ return training video """
38 | image = self.__class__.image_read(self.images[index])
39 | image = torch.from_numpy(image).float()
40 | image = image.permute(2, 0, 1)
41 |
42 | depth = self.__class__.depth_read(self.depths[index])
43 | depth = torch.from_numpy(depth).float()
44 |
45 | pose = torch.from_numpy(self.poses[index]).float()
46 | intrinsic = torch.from_numpy(self.intrinsics[index]).float()
47 |
48 | sx = self.crop_size[1] / depth.shape[1]
49 | sy = self.crop_size[0] / depth.shape[0]
50 | image = F.interpolate(image[None], self.crop_size, mode='bilinear', align_corners=True)[0]
51 | depth = F.interpolate(depth[None,None], self.crop_size, mode='nearest')[0,0]
52 |
53 | image = image[..., 8:-8, 8:-8]
54 | depth = depth[..., 8:-8, 8:-8]
55 |
56 | fx, fy, cx, cy = intrinsic.unbind(dim=0)
57 | intrinsic = torch.stack([sx*fx, sy*fy, sx*cx - 8, sy*cy - 8])
58 |
59 | # intrinsic *= torch.as_tensor([sx, sy, sx, sy])
60 | return index, image, depth, pose, intrinsic
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/examples/core/data_readers/nyu2.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import glob
5 | import cv2
6 | import os
7 | import os.path as osp
8 |
9 | from .base import RGBDDataset
10 | from .augmentation import RGBDAugmentor
11 | from .rgbd_utils import all_pairs_distance_matrix, loadtum
12 |
13 | class NYUv2(RGBDDataset):
14 | def __init__(self, **kwargs):
15 | super(NYUv2, self).__init__(root='datasets/NYUv2', name='NYUv2', **kwargs)
16 |
17 | @staticmethod
18 | def is_test_scene(scene):
19 | return False
20 |
21 | def _build_dataset(self):
22 |
23 | from tqdm import tqdm
24 | print("Building NYUv2 dataset")
25 |
26 | scene_info = {}
27 | dataset_index = []
28 |
29 | scenes = os.listdir(self.root)
30 | for scene in tqdm(scenes):
31 | scene_path = osp.join(self.root, scene)
32 | images, depths, poses, intrinsics = loadtum(scene_path, frame_rate=10)
33 |
34 | # filter out some errors in dataset
35 | if images is None or len(images) < 8:
36 | continue
37 |
38 | intrinsic = NYUv2.calib_read()
39 | intrinsics = [intrinsic] * len(images)
40 |
41 | # graph of co-visible frames based on flow
42 | graph = self.build_frame_graph(poses, depths, intrinsics)
43 |
44 | scene_info[scene] = {'images': images, 'depths': depths,
45 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph}
46 |
47 | return scene_info
48 |
49 | @staticmethod
50 | def calib_read():
51 | fx = 5.1885790117450188e+02
52 | fy = 5.1946961112127485e+02
53 | cx = 3.2558244941119034e+02
54 | cy = 2.5373616633400465e+02
55 | return np.array([fx, fy, cx, cy])
56 |
57 | @staticmethod
58 | def image_read(image_file):
59 | return cv2.imread(image_file)
60 |
61 | @staticmethod
62 | def depth_read(depth_file):
63 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
64 | return depth.astype(np.float32) / 5000.0
65 |
66 |
--------------------------------------------------------------------------------
/examples/core/geom/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from lietorch import SO3, SE3, Sim3
4 | from .graph_utils import graph_to_edge_list
5 |
6 | def pose_metrics(dE):
7 | """ Translation/Rotation/Scaling metrics from Sim3 """
8 | t, q, s = dE.data.split([3, 4, 1], -1)
9 | ang = SO3(q).log().norm(dim=-1)
10 |
11 | # convert radians to degrees
12 | r_err = (180 / np.pi) * ang
13 | t_err = t.norm(dim=-1)
14 | s_err = (s - 1.0).abs()
15 | return r_err, t_err, s_err
16 |
17 | def geodesic_loss(Ps, Gs, graph, gamma=0.9):
18 | """ Loss function for training network """
19 |
20 | # relative pose
21 | ii, jj, kk = graph_to_edge_list(graph)
22 | dP = Ps[:,jj] * Ps[:,ii].inv()
23 |
24 | n = len(Gs)
25 | geodesic_loss = 0.0
26 |
27 | for i in range(n):
28 | w = gamma ** (n - i - 1)
29 | dG = Gs[i][:,jj] * Gs[i][:,ii].inv()
30 |
31 | # pose error
32 | d = (dG * dP.inv()).log()
33 |
34 | if isinstance(dG, SE3):
35 | tau, phi = d.split([3,3], dim=-1)
36 | geodesic_loss += w * (
37 | tau.norm(dim=-1).mean() +
38 | phi.norm(dim=-1).mean())
39 |
40 | elif isinstance(dG, Sim3):
41 | tau, phi, sig = d.split([3,3,1], dim=-1)
42 | geodesic_loss += w * (
43 | tau.norm(dim=-1).mean() +
44 | phi.norm(dim=-1).mean() +
45 | 0.05 * sig.norm(dim=-1).mean())
46 |
47 | dE = Sim3(dG * dP.inv()).detach()
48 | r_err, t_err, s_err = pose_metrics(dE)
49 |
50 | metrics = {
51 | 'r_error': r_err.mean().item(),
52 | 't_error': t_err.mean().item(),
53 | 's_error': s_err.mean().item(),
54 | }
55 |
56 | return geodesic_loss, metrics
57 |
58 |
59 | def residual_loss(residuals, gamma=0.9):
60 | """ loss on system residuals """
61 | residual_loss = 0.0
62 | n = len(residuals)
63 |
64 | for i in range(n):
65 | w = gamma ** (n - i - 1)
66 | residual_loss += w * residuals[i].abs().mean()
67 |
68 | return residual_loss, {'residual': residual_loss.item()}
69 |
--------------------------------------------------------------------------------
/examples/core/data_readers/augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 |
7 | class RGBDAugmentor:
8 | """ perform augmentation on RGB-D video """
9 |
10 | def __init__(self, crop_size):
11 | self.crop_size = crop_size
12 | self.augcolor = transforms.Compose([
13 | transforms.ToPILImage(),
14 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.4/3.14),
15 | transforms.ToTensor()])
16 |
17 | self.max_scale = 0.25
18 |
19 | def spatial_transform(self, images, depths, poses, intrinsics):
20 | """ cropping and resizing """
21 | ht, wd = images.shape[2:]
22 |
23 | max_scale = self.max_scale
24 | min_scale = np.log2(np.maximum(
25 | (self.crop_size[0] + 1) / float(ht),
26 | (self.crop_size[1] + 1) / float(wd)))
27 |
28 | scale = 2 ** np.random.uniform(min_scale, max_scale)
29 | intrinsics = scale * intrinsics
30 | depths = depths.unsqueeze(dim=1)
31 |
32 | images = F.interpolate(images, scale_factor=scale, mode='bilinear',
33 | align_corners=True, recompute_scale_factor=True)
34 |
35 | depths = F.interpolate(depths, scale_factor=scale, recompute_scale_factor=True)
36 |
37 | # always perform center crop (TODO: try non-center crops)
38 | y0 = (images.shape[2] - self.crop_size[0]) // 2
39 | x0 = (images.shape[3] - self.crop_size[1]) // 2
40 |
41 | intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0])
42 | images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
43 | depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
44 |
45 | depths = depths.squeeze(dim=1)
46 | return images, poses, depths, intrinsics
47 |
48 | def color_transform(self, images):
49 | """ color jittering """
50 | num, ch, ht, wd = images.shape
51 | images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num)
52 | images = 255 * self.augcolor(images[[2,1,0]] / 255.0)
53 | return images[[2,1,0]].reshape(ch, ht, wd, num).permute(3,0,1,2).contiguous()
54 |
55 | def __call__(self, images, poses, depths, intrinsics):
56 | images = self.color_transform(images)
57 | return self.spatial_transform(images, depths, poses, intrinsics)
58 |
--------------------------------------------------------------------------------
/examples/core/networks/modules/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Unet model from https://github.com/usuyama/pytorch-unet
5 |
6 |
7 |
8 | GRAD_CLIP = .01
9 |
10 | class GradClip(torch.autograd.Function):
11 | @staticmethod
12 | def forward(ctx, x):
13 | return x
14 |
15 | @staticmethod
16 | def backward(ctx, grad_x):
17 | o = torch.zeros_like(grad_x)
18 | grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x)
19 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
20 | return grad_x
21 |
22 |
23 | def double_conv(in_channels, out_channels):
24 | return nn.Sequential(
25 | nn.Conv2d(in_channels, out_channels, 5, padding=2),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(out_channels, out_channels, 5, padding=2),
28 | nn.ReLU(inplace=True)
29 | )
30 |
31 |
32 | class UNet(nn.Module):
33 |
34 | def __init__(self):
35 | super().__init__()
36 |
37 | self.dconv_down1 = double_conv(128, 128)
38 | self.dconv_down2 = double_conv(128, 256)
39 | self.dconv_down3 = double_conv(256, 256)
40 | # self.dconv_down4 = double_conv(256, 512)
41 |
42 | self.maxpool = nn.AvgPool2d(2)
43 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
44 |
45 | self.dconv_up3 = double_conv(256 + 256, 256)
46 | self.dconv_up2 = double_conv(256 + 256, 128)
47 | self.dconv_up1 = double_conv(128 + 128, 128)
48 |
49 | self.conv_r = nn.Conv2d(128, 3, 1)
50 | self.conv_w = nn.Conv2d(128, 3, 1)
51 |
52 |
53 | def forward(self, x):
54 | b, n, c, ht, wd = x.shape
55 | x = x.view(b*n, c, ht, wd)
56 |
57 | conv1 = self.dconv_down1(x)
58 | x = self.maxpool(conv1)
59 |
60 | conv2 = self.dconv_down2(x)
61 | x = self.maxpool(conv2)
62 |
63 | conv3 = self.dconv_down3(x)
64 | x = torch.cat([x, conv3], dim=1)
65 |
66 | x = self.dconv_up3(x)
67 | x = self.upsample(x)
68 | x = torch.cat([x, conv2], dim=1)
69 |
70 | x = self.dconv_up2(x)
71 | x = self.upsample(x)
72 | x = torch.cat([x, conv1], dim=1)
73 |
74 | x = self.dconv_up1(x)
75 | r = self.conv_r(x)
76 | w = self.conv_w(x)
77 |
78 | w = torch.sigmoid(w)
79 | w = w.view(b, n, 3, ht, wd).permute(0,1,3,4,2)
80 | r = r.view(b, n, 3, ht, wd).permute(0,1,3,4,2)
81 |
82 | # w = GradClip.apply(w)
83 | # r = GradClip.apply(r)
84 | return r, w
--------------------------------------------------------------------------------
/lietorch/include/dispatch.h:
--------------------------------------------------------------------------------
1 | #ifndef DISPATCH_H
2 | #define DISPATCH_H
3 |
4 | #include <torch/extension.h>
5 |
6 | #include "so3.h"
7 | #include "rxso3.h"
8 | #include "se3.h"
9 | #include "sim3.h"
10 |
11 |
12 | #define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...) \
13 | case enum_type: { \
14 | using scalar_t = type; \
15 | switch (group_index) { \
16 | case 1: { \
17 | using group_t = SO3; \
18 | return __VA_ARGS__(); \
19 | } \
20 | case 2: { \
21 | using group_t = RxSO3; \
22 | return __VA_ARGS__(); \
23 | } \
24 | case 3: { \
25 | using group_t = SE3; \
26 | return __VA_ARGS__(); \
27 | } \
28 | case 4: { \
29 | using group_t = Sim3; \
30 | return __VA_ARGS__(); \
31 | } \
32 | } \
33 | } \
34 |
35 | #define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...) \
36 | [&] { \
37 | const auto& the_type = TYPE; \
38 | /* don't use TYPE again in case it is an expensive or side-effect op */ \
39 | at::ScalarType _st = ::detail::scalar_type(the_type); \
40 | switch (_st) { \
41 | PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \
42 | PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \
43 | default: break; \
44 | } \
45 | }()
46 |
47 | #endif
48 |
49 |
--------------------------------------------------------------------------------
/examples/core/geom/ba.py:
--------------------------------------------------------------------------------
1 | import lietorch
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from .chol import block_solve
6 | import geom.projective_ops as pops
7 |
8 | # utility functions for scattering ops
9 | def safe_scatter_add_mat(H, data, ii, jj, B, M, D):
10 | v = (ii >= 0) & (jj >= 0)
11 | H.scatter_add_(1, (ii[v]*M + jj[v]).view(1,-1,1,1).repeat(B,1,D,D), data[:,v])
12 |
13 | def safe_scatter_add_vec(b, data, ii, B, M, D):
14 | v = ii >= 0
15 | b.scatter_add_(1, ii[v].view(1,-1,1).repeat(B,1,D), data[:,v])
16 |
17 | def MoBA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1, lm=0.0001, ep=0.1):
18 | """ MoBA: Motion Only Bundle Adjustment """
19 |
20 | B, M = poses.shape[:2]
21 | D = poses.manifold_dim
22 | N = ii.shape[0]
23 |
24 |     ### 1: compute jacobians and residuals ###
25 | coords, valid, (Ji, Jj) = pops.projective_transform(
26 | poses, disps, intrinsics, ii, jj, jacobian=True)
27 |
28 | r = (target - coords).view(B, N, -1, 1)
29 | w = (valid * weight).view(B, N, -1, 1)
30 |
31 | ### 2: construct linear system ###
32 | Ji = Ji.view(B, N, -1, D)
33 | Jj = Jj.view(B, N, -1, D)
34 | wJiT = (.001 * w * Ji).transpose(2,3)
35 | wJjT = (.001 * w * Jj).transpose(2,3)
36 |
37 | Hii = torch.matmul(wJiT, Ji)
38 | Hij = torch.matmul(wJiT, Jj)
39 | Hji = torch.matmul(wJjT, Ji)
40 | Hjj = torch.matmul(wJjT, Jj)
41 |
42 | vi = torch.matmul(wJiT, r).squeeze(-1)
43 | vj = torch.matmul(wJjT, r).squeeze(-1)
44 |
45 | # only optimize keyframe poses
46 | M = M - fixedp
47 | ii = ii - fixedp
48 | jj = jj - fixedp
49 |
50 | H = torch.zeros(B, M*M, D, D, device=target.device)
51 | safe_scatter_add_mat(H, Hii, ii, ii, B, M, D)
52 | safe_scatter_add_mat(H, Hij, ii, jj, B, M, D)
53 | safe_scatter_add_mat(H, Hji, jj, ii, B, M, D)
54 | safe_scatter_add_mat(H, Hjj, jj, jj, B, M, D)
55 | H = H.reshape(B, M, M, D, D)
56 |
57 | v = torch.zeros(B, M, D, device=target.device)
58 | safe_scatter_add_vec(v, vi, ii, B, M, D)
59 | safe_scatter_add_vec(v, vj, jj, B, M, D)
60 |
61 | ### 3: solve the system + apply retraction ###
62 | dx = block_solve(H, v, ep=ep, lm=lm)
63 |
64 | poses1, poses2 = poses[:,:fixedp], poses[:,fixedp:]
65 | poses2 = poses2.retr(dx)
66 |
67 | poses = lietorch.cat([poses1, poses2], dim=1)
68 | return poses
69 |
70 | def SLessBA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1):
71 | """ Structureless Bundle Adjustment """
72 | pass
73 |
74 |
75 | def BA(target, weight, poses, disps, intrinsics, ii, jj, fixedp=1):
76 | """ Full Bundle Adjustment """
77 | pass
--------------------------------------------------------------------------------
/examples/rgbdslam/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | import cv2
8 | import os
9 |
10 | from viz import SLAMFrontend
11 | from lietorch import SE3
12 | from networks.slam_system import SLAMSystem
13 | from data_readers import factory
14 |
15 |
16 | def show_image(image):
17 | image = image.permute(1, 2, 0).cpu().numpy()
18 | cv2.imshow('image', image / 255.0)
19 | cv2.waitKey(10)
20 |
21 | def evaluate(poses_gt, poses_est):
22 | from rgbd_benchmark.evaluate_ate import evaluate_ate
23 |
24 | poses_gt = poses_gt.cpu().numpy()
25 | poses_est = poses_est.cpu().numpy()
26 |
27 | N = poses_gt.shape[0]
28 | poses_gt = dict([(i, poses_gt[i]) for i in range(N)])
29 | poses_est = dict([(i, poses_est[i]) for i in range(N)])
30 |
31 | results = evaluate_ate(poses_gt, poses_est)
32 | print(results)
33 |
34 |
35 | @torch.no_grad()
36 | def run_slam(tracker, datapath, frame_rate=8.0):
37 | """ run slam over full sequence """
38 |
39 | torch.multiprocessing.set_sharing_strategy('file_system')
40 | stream = factory.create_datastream(args.datapath, frame_rate=frame_rate)
41 |
42 | # start the frontend thread
43 | if args.viz:
44 | frontend = SLAMFrontend().start()
45 | tracker.set_frontend(frontend)
46 |
47 | # store groundtruth poses for evaluation
48 | poses_gt = []
49 |
50 | for (tstamp, image, depth, pose, intrinsics) in tqdm(stream):
51 | tracker.track(tstamp, image[None].cuda(), depth.cuda(), intrinsics.cuda())
52 | poses_gt.append(pose)
53 |
54 | if args.viz:
55 | show_image(image[0])
56 | frontend.update_pose(tstamp, pose[0], gt=True)
57 |
58 | # global optimization / loop closure
59 | if args.go:
60 | tracker.global_refinement()
61 |
62 | poses_gt = torch.cat(poses_gt, 0)
63 | poses_est = tracker.raw_poses()
64 | evaluate(poses_gt, poses_est)
65 |
66 |
67 | if __name__ == '__main__':
68 | import argparse
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('--datapath', help='path to video for slam')
71 | parser.add_argument('--ckpt', help='saved network weights')
72 |     parser.add_argument('--viz', action='store_true', help='run visualization frontend')
73 | parser.add_argument('--go', action='store_true', help='use global optimization')
74 | parser.add_argument('--frame_rate', type=float, default=8.0, help='frame rate')
75 | args = parser.parse_args()
76 |
77 | # initialize tracker / load weights
78 | tracker = SLAMSystem(args)
79 | tracker.load_state_dict(torch.load(args.ckpt))
80 | tracker.eval()
81 | tracker.cuda()
82 |
83 | run_slam(tracker, args.datapath, args.frame_rate)
84 |
--------------------------------------------------------------------------------
/examples/core/data_readers/tum.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 |
5 | import csv
6 | import os
7 | import cv2
8 | import math
9 | import random
10 | import json
11 | import pickle
12 | import os.path as osp
13 |
14 | from lietorch import SE3
15 | from .stream import RGBDStream
16 | from .rgbd_utils import loadtum
17 |
18 | intrinsics_dict = {
19 | 'freiburg1': [517.3, 516.5, 318.6, 255.3],
20 | 'freiburg2': [520.9, 521.0, 325.1, 249.7],
21 | 'freiburg3': [535.4, 539.2, 320.1, 247.6],
22 | }
23 |
24 | distortion_dict = {
25 | 'freiburg1': [0.2624, -0.9531, -0.0054, 0.0026, 1.1633],
26 | 'freiburg2': [0.2312, -0.7849, -0.0033, -0.0001, 0.9172],
27 | 'freiburg3': [0, 0, 0, 0, 0],
28 | }
29 |
30 | def as_intrinsics_matrix(intrinsics):
31 | K = np.eye(3)
32 | K[0,0] = intrinsics[0]
33 | K[1,1] = intrinsics[1]
34 | K[0,2] = intrinsics[2]
35 | K[1,2] = intrinsics[3]
36 | return K
37 |
38 |
39 | class TUMStream(RGBDStream):
40 | def __init__(self, datapath, **kwargs):
41 | super(TUMStream, self).__init__(datapath=datapath, **kwargs)
42 |
43 | def _build_dataset_index(self):
44 | """ build list of images, poses, depths, and intrinsics """
45 | images, depths, poses, intrinsics = loadtum(self.datapath, self.frame_rate)
46 | intrinsic, _ = TUMStream.calib_read(self.datapath)
47 | intrinsics = np.tile(intrinsic[None], (len(images), 1))
48 |
49 | # set first pose to identity
50 | poses = SE3(torch.as_tensor(poses))
51 | poses = poses[[0]].inv() * poses
52 | poses = poses.data.cpu().numpy()
53 |
54 | self.images = images
55 | self.poses = poses
56 | self.depths = depths
57 | self.intrinsics = intrinsics
58 |
59 | @staticmethod
60 | def calib_read(datapath):
61 | if 'freiburg1' in datapath:
62 | intrinsic = intrinsics_dict['freiburg1']
63 | d_coef = distortion_dict['freiburg1']
64 |
65 | elif 'freiburg2' in datapath:
66 | intrinsic = intrinsics_dict['freiburg2']
67 | d_coef = distortion_dict['freiburg2']
68 |
69 | elif 'freiburg3' in datapath:
70 | intrinsic = intrinsics_dict['freiburg3']
71 | d_coef = distortion_dict['freiburg3']
72 |
73 | return np.array(intrinsic), np.array(d_coef)
74 |
75 | @staticmethod
76 | def image_read(image_file):
77 | intrinsics, d_coef = TUMStream.calib_read(image_file)
78 | K = as_intrinsics_matrix(intrinsics)
79 | image = cv2.imread(image_file)
80 | return cv2.undistort(image, K, d_coef)
81 |
82 | @staticmethod
83 | def depth_read(depth_file):
84 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
85 | return depth.astype(np.float32) / 5000.0
86 |
--------------------------------------------------------------------------------
/examples/registration/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | import argparse
5 | import torch
6 | import cv2
7 | import numpy as np
8 |
9 | from viz import sim3_visualization
10 | from lietorch import SO3, SE3, Sim3
11 | from networks.sim3_net import Sim3Net
12 |
13 | def normalize_images(images):
14 | images = images[:, :, [2,1,0]]
15 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
16 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
17 | return (images/255.0).sub_(mean[:, None, None]).div_(std[:, None, None])
18 |
19 | def load_example(i=0):
20 | """ get demo example """
21 | DEPTH_SCALE = 5.0
22 | if i==0:
23 | image1 = cv2.imread('assets/image1.png')
24 | image2 = cv2.imread('assets/image2.png')
25 | depth1 = np.load('assets/depth1.npy') / DEPTH_SCALE
26 | depth2 = np.load('assets/depth2.npy') / DEPTH_SCALE
27 |
28 | elif i==1:
29 | image1 = cv2.imread('assets/image3.png')
30 | image2 = cv2.imread('assets/image4.png')
31 | depth1 = np.load('assets/depth3.npy') / DEPTH_SCALE
32 | depth2 = np.load('assets/depth4.npy') / DEPTH_SCALE
33 |
34 | images = np.stack([image1, image2], 0)
35 | images = torch.from_numpy(images).permute(0,3,1,2)
36 |
37 | depths = np.stack([depth1, depth2], 0)
38 | depths = torch.from_numpy(depths).float()
39 |
40 | intrinsics = np.array([320.0, 320.0, 320.0, 240.0])
41 | intrinsics = np.tile(intrinsics[None], (2,1))
42 | intrinsics = torch.from_numpy(intrinsics).float()
43 |
44 | return images[None].cuda(), depths[None].cuda(), intrinsics[None].cuda()
45 |
46 |
47 | @torch.no_grad()
48 | def demo(model, index=0):
49 |
50 | images, depths, intrinsics = load_example(index)
51 |
52 | # initial transformation estimate
53 | if args.transformation == 'SE3':
54 | Gs = SE3.Identity(1, 2, device='cuda')
55 |
56 | elif args.transformation == 'Sim3':
57 | Gs = Sim3.Identity(1, 2, device='cuda')
58 | depths[:,0] *= 2**(2*torch.rand(1) - 1.0).cuda()
59 |
60 | images1 = normalize_images(images)
61 | ests, _ = model(Gs, images1, depths, intrinsics, num_steps=12)
62 |
63 | # only care about last transformation
64 | Gs = ests[-1]
65 | T = Gs[:,0] * Gs[:,1].inv()
66 |
67 | T = T[0].matrix().double().cpu().numpy()
68 | sim3_visualization(T, images, depths, intrinsics)
69 |
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser()
73 |     parser.add_argument('--transformation', default='SE3', help='transformation group to estimate: SE3 or Sim3')
74 | parser.add_argument('--ckpt', help='checkpoint to restore')
75 | args = parser.parse_args()
76 |
77 | model = Sim3Net(args)
78 | model.load_state_dict(torch.load(args.ckpt))
79 |
80 | model.cuda()
81 | model.eval()
82 |
83 | # run two demos
84 | demo(model, 0)
85 | demo(model, 1)
86 |
87 |
--------------------------------------------------------------------------------
/examples/registration/viz.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | import argparse
5 | import torch
6 | import scipy
7 | import numpy as np
8 |
9 | import geom.projective_ops as pops
10 | import open3d as o3d
11 |
12 | def make_point_cloud(image, depth, intrinsics, max_depth=5.0):
13 | """ create a point cloud """
14 | colors = image.permute(1,2,0).view(-1,3)
15 | colors = colors[...,[2,1,0]] / 255.0
16 | clr = colors.cpu().numpy()
17 |
18 | inv_depth = 1.0 / depth[None,None]
19 | points = pops.iproj(inv_depth, intrinsics[None,None])
20 | points = (points[..., :3] / points[..., 3:]).view(-1,3)
21 | pts = points.cpu().numpy()
22 |
23 | # open3d point cloud
24 | pc = o3d.geometry.PointCloud()
25 |
26 | keep = pts[:,2] < max_depth
27 | pc.points = o3d.utility.Vector3dVector(pts[keep])
28 | pc.colors = o3d.utility.Vector3dVector(clr[keep])
29 |
30 | return pc
31 |
32 | def set_camera_pose(vis):
33 | """ set initial camera position """
34 | cam = vis.get_view_control().convert_to_pinhole_camera_parameters()
35 |
36 | cam.extrinsic = np.array(
37 | [[ 0.91396544, 0.1462376, -0.37852575, 0.94374719],
38 | [-0.13923432, 0.98919177, 0.04597225, 1.01177687],
39 | [ 0.38115743, 0.01068673, 0.92444838, 3.35964868],
40 | [ 0., 0., 0., 1. ]])
41 |
42 | vis.get_view_control().convert_from_pinhole_camera_parameters(cam)
43 |
44 |
45 | def sim3_visualization(T, images, depths, intrinsics):
46 |     """ animate the estimated SE3/Sim3 alignment between two point clouds """
47 |
48 | images = images.squeeze(0)
49 | depths = depths.squeeze(0)
50 | intrinsics = intrinsics.squeeze(0)
51 |
52 | pc1 = make_point_cloud(images[0], depths[0], intrinsics[0])
53 | pc2 = make_point_cloud(images[1], depths[1], intrinsics[1])
54 |
55 | sim3_visualization.index = 1
56 | sim3_visualization.pc2 = pc2
57 |
58 | NUM_STEPS = 100
59 | dt = scipy.linalg.logm(T) / NUM_STEPS
60 | dT = scipy.linalg.expm(dt)
61 | sim3_visualization.transform = dT
62 |
63 | def animation_callback(vis):
64 | sim3_visualization.index += 1
65 |
66 | pc2 = sim3_visualization.pc2
67 | if sim3_visualization.index >= NUM_STEPS and \
68 | sim3_visualization.index < 2*NUM_STEPS:
69 | pc2.transform(sim3_visualization.transform)
70 |
71 | vis.update_geometry(pc2)
72 | vis.poll_events()
73 | vis.update_renderer()
74 |
75 | vis = o3d.visualization.Visualizer()
76 | vis.register_animation_callback(animation_callback)
77 | vis.create_window(height=540, width=960)
78 |
79 | vis.add_geometry(pc1)
80 | vis.add_geometry(pc2)
81 |
82 | vis.get_render_option().load_from_json("assets/renderoption.json")
83 | set_camera_pose(vis)
84 |
85 | print("Press q to move to next example")
86 | vis.run()
87 | vis.destroy_window()
88 |
--------------------------------------------------------------------------------
/examples/core/data_readers/eth3d.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 | import torch.nn.functional as F
6 |
7 | import csv
8 | import os
9 | import cv2
10 | import math
11 | import random
12 | import json
13 | import pickle
14 | import os.path as osp
15 |
16 | from lietorch import SE3
17 | from .base import RGBDDataset
18 | from .stream import RGBDStream
19 | from .augmentation import RGBDAugmentor
20 | from .rgbd_utils import loadtum, all_pairs_distance_matrix
21 |
22 | class ETH3D(RGBDDataset):
23 | def __init__(self, **kwargs):
24 | super(ETH3D, self).__init__(root='datasets/ETH3D', name='ETH3D', **kwargs)
25 |
26 | @staticmethod
27 | def is_test_scene(scene):
28 | return False
29 |
30 | def _build_dataset(self):
31 | from tqdm import tqdm
32 | print("Building ETH3D dataset")
33 |
34 | scene_info = {}
35 | dataset_index = []
36 |
37 | for scene in tqdm(os.listdir(self.root)):
38 | scene_path = osp.join(self.root, scene)
39 |
40 | if not osp.isdir(scene_path):
41 | continue
42 |
43 | # don't use scenes with no rgb info
44 | if 'dark' in scene or 'kidnap' in scene:
45 | continue
46 |
47 | scene_data, graph = {}, {}
48 | images, depths, poses, intrinsics = loadtum(scene_path, skip=2)
49 |
50 | # graph of co-visible frames based on flow
51 | graph = self.build_frame_graph(poses, depths, intrinsics)
52 |
53 | scene_info[scene] = {'images': images, 'depths': depths,
54 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph}
55 |
56 | return scene_info
57 |
58 | @staticmethod
59 | def image_read(image_file):
60 | return cv2.imread(image_file)
61 |
62 | @staticmethod
63 | def depth_read(depth_file):
64 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
65 | return depth.astype(np.float32) / 5000.0
66 |
67 |
68 | class ETH3DStream(RGBDStream):
69 | def __init__(self, datapath, **kwargs):
70 | super(ETH3DStream, self).__init__(datapath=datapath, **kwargs)
71 |
72 | def _build_dataset_index(self):
73 | """ build list of images, poses, depths, and intrinsics """
74 | images, depths, poses, intrinsics = loadtum(self.datapath, self.frame_rate)
75 |
76 | # set first pose to identity
77 | poses = SE3(torch.as_tensor(poses))
78 | poses = poses[[0]].inv() * poses
79 | poses = poses.data.cpu().numpy()
80 |
81 | self.images = images
82 | self.poses = poses
83 | self.depths = depths
84 | self.intrinsics = intrinsics
85 |
86 | @staticmethod
87 | def image_read(image_file):
88 | return cv2.imread(image_file)
89 |
90 | @staticmethod
91 | def depth_read(depth_file):
92 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
93 | return depth.astype(np.float32) / 5000.0
94 |
--------------------------------------------------------------------------------
/examples/rgbdslam/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | import cv2
8 | import os
9 |
10 | from lietorch import SE3
11 | from networks.slam_system import SLAMSystem
12 | from data_readers import factory
13 |
14 | def evaluate(poses_gt, poses_est):
15 | from rgbd_benchmark.evaluate_ate import evaluate_ate
16 |
17 | poses_gt = poses_gt.cpu().numpy()
18 | poses_est = poses_est.cpu().numpy()
19 |
20 | N = poses_gt.shape[0]
21 | poses_gt = dict([(i, poses_gt[i]) for i in range(N)])
22 | poses_est = dict([(i, poses_est[i]) for i in range(N)])
23 |
24 | results = evaluate_ate(poses_gt, poses_est)
25 | print(results)
26 | return results['absolute_translational_error.rmse']
27 |
28 | @torch.no_grad()
29 | def run_slam(tracker, datapath, global_optimization=False, frame_rate=3):
30 | """ run slam over full sequence """
31 |
32 | torch.multiprocessing.set_sharing_strategy('file_system')
33 | stream = factory.create_datastream(datapath, frame_rate=frame_rate)
34 |
35 |     # store groundtruth poses for evaluation
36 | poses_gt = []
37 | for (tstamp, image, depth, pose, intrinsics) in tqdm(stream):
38 | tracker.track(tstamp, image[None].cuda(), depth.cuda(), intrinsics.cuda())
39 | poses_gt.append(pose)
40 |
41 | if global_optimization:
42 | tracker.global_refinement()
43 |
44 | poses_gt = torch.cat(poses_gt, 0)
45 | poses_est = tracker.raw_poses()
46 |
47 | ate = evaluate(poses_gt, poses_est)
48 | return ate
49 |
50 | def run_evaluation(ckpt, frame_rate=8.0, global_optimization=False):
51 | validation_scenes = [
52 | 'rgbd_dataset_freiburg1_360',
53 | 'rgbd_dataset_freiburg1_desk',
54 | 'rgbd_dataset_freiburg1_desk2',
55 | 'rgbd_dataset_freiburg1_floor',
56 | 'rgbd_dataset_freiburg1_plant',
57 | 'rgbd_dataset_freiburg1_room',
58 | 'rgbd_dataset_freiburg1_rpy',
59 | 'rgbd_dataset_freiburg1_teddy',
60 | 'rgbd_dataset_freiburg1_xyz',
61 | ]
62 |
63 | results = {}
64 | for scene in validation_scenes:
65 | # initialize tracker / load weights
66 | tracker = SLAMSystem(None)
67 | tracker.load_state_dict(torch.load(ckpt))
68 | tracker.eval()
69 | tracker.cuda()
70 |
71 | datapath = os.path.join('datasets/TUM-RGBD', scene)
72 | results[scene] = run_slam(tracker, datapath,
73 |             global_optimization=global_optimization, frame_rate=frame_rate)
74 |
75 | print("Aggregate Results: ")
76 | for scene in results:
77 | print(scene, results[scene])
78 |
79 | print("MEAN: ", np.mean([results[key] for key in results]))
80 |
81 | if __name__ == '__main__':
82 | import argparse
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument('--ckpt', help='saved network weights')
85 | parser.add_argument('--frame_rate', type=float, default=8.0, help='frame rate')
86 | parser.add_argument('--go', action='store_true', help='use global optimization')
87 | args = parser.parse_args()
88 |
89 |     run_evaluation(args.ckpt, frame_rate=args.frame_rate, global_optimization=args.go)
90 |
--------------------------------------------------------------------------------
/examples/rgbdslam/readme.md:
--------------------------------------------------------------------------------
1 | ## RGB-D SLAM / VO
2 |
3 |
4 |
5 |
6 |
7 | ### Pretrained Model
8 |
9 | Absolute Trajectory Error (ATE) on all freiburg1 sequences. By default, the model acts as a visual odometry system (no loop closure). The `+ go` row additionally performs global optimization at the end of tracking to correct for drift.
10 |
11 | | Model | 360 | desk | desk2 | floor | plant | room | rpy | teddy | xyz | avg |
12 | | ----- | --- | ---- | ----- | ----- | ----- | ---- | --- | ----- | --- | --- |
13 | | DeepV2D | 0.072 | 0.069 | 0.074 | 0.317 | 0.046 | 0.213 | 0.082 | 0.114 | 0.028 | 0.113 |
14 | | [lietorch_rgbdslam.pth](https://drive.google.com/file/d/1SVQTFCchZuhFeSucS5jLeNbOWyff4BA8/view?usp=sharing) | 0.076 | 0.045 | 0.054 | 0.057 | 0.032 | 0.143 | 0.064 | 0.092 | 0.033 | 0.066 |
15 | | [lietorch_rgbdslam.pth](https://drive.google.com/file/d/1SVQTFCchZuhFeSucS5jLeNbOWyff4BA8/view?usp=sharing) + go | 0.047 | 0.018 | 0.023 | 0.017 | 0.015 | 0.029 | 0.019 | 0.030 | 0.009 | 0.023 |
16 |
17 | ### Demo
18 | Requires a GPU with at least 8 GB of memory. First download a sequence from the [TUM-RGBD dataset](https://vision.in.tum.de/data/datasets/rgbd-dataset/download), then run the demo. You can interact with the Open3D window during tracking.
19 |
20 | ```
21 | python demo.py --ckpt=lietorch_rgbdslam.pth --datapath= --frame_rate=8.0 --go --viz
22 | ```
23 |
24 | The `--frame_rate` flag determines the rate at which images are subsampled from the video (e.g. `--frame_rate=8.0` subsamples the video at 8 fps). With an RTX 3090 GPU and visualization disabled, `--frame_rate <= 8.0` gives real-time performance.
25 |
26 |
27 | ### Evaluation
28 | Assuming all TUM-RGBD sequences have been downloaded, a trained model can be evaluated on the TUM-RGBD dataset:
29 | ```
30 | python evaluate.py --ckpt=lietorch_rgbdslam.pth --go --frame_rate=8.0
31 | ```
32 |
33 | ### Training
34 | We provide data loaders for [NYUv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html), [ScanNet](http://www.scan-net.org/), [ETH3D-SLAM](https://www.eth3d.net/slam_datasets), and [TartanAir](https://theairlab.org/tartanair-dataset/). The data loaders work directly on ScanNet, ETH3D, and TartanAir. For NYUv2, you will need to first extract the depths and images from the raw format, then run ORB-SLAM2 to generate pseudo-groundtruth poses. Send me an email (Zachary Teed) if you need a link to the preprocessed NYU data.
35 |
36 | You can train on any subset of the datasets by listing their keys {`nyu`, `scannet`, `eth`, `tartan`}. The provided models are trained on `scannet` and `nyu`. Note: our data loader computes the optical flow between every pair of frames, which can take several hours on the first run; the result is cached, so subsequent loads take only a few seconds. The default training settings require a GPU with 24 GB of memory.
37 |
38 | ```
39 | python train.py --batch=3 --iters=12 --lr=0.00025 --name nyu_scannet_eth_v2 --datasets nyu scannet
40 | ```
41 |
42 | #### Training on your own dataset
43 | Additional datasets can be added by subclassing `RGBDDataset`; see `nyu2.py` or `scannet.py` for examples, or the sketch below. To check that data loading is correct, run the `reprojection_test.py` script and confirm that the warped images align.
44 |
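Below is a minimal, illustrative sketch of such a subclass. The dataset name, directory layout, pose-file format, and intrinsics are hypothetical placeholders to adapt to your data: `_build_dataset` must return, for each scene, the image/depth file lists, per-frame poses as 7-vectors `[tx ty tz qx qy qz qw]`, per-frame pinhole intrinsics `[fx fy cx cy]`, and the co-visibility graph produced by `build_frame_graph` (the base class caches the result).

```python
# sketch only -- root path, file layout, and intrinsics are placeholders
import os
import os.path as osp
import glob

import cv2
import numpy as np

from .base import RGBDDataset

class MyDataset(RGBDDataset):
    def __init__(self, **kwargs):
        super(MyDataset, self).__init__(root='datasets/MyDataset', name='MyDataset', **kwargs)

    @staticmethod
    def is_test_scene(scene):
        # use every scene for training
        return False

    def _build_dataset(self):
        scene_info = {}
        for scene in os.listdir(self.root):
            scene_path = osp.join(self.root, scene)

            images = sorted(glob.glob(osp.join(scene_path, 'rgb', '*.png')))
            depths = sorted(glob.glob(osp.join(scene_path, 'depth', '*.png')))

            # one [tx, ty, tz, qx, qy, qz, qw] vector per frame
            poses = list(np.loadtxt(osp.join(scene_path, 'poses.txt')).astype(np.float64))

            # pinhole intrinsics [fx, fy, cx, cy], here assumed constant over the scene
            intrinsics = [np.array([525.0, 525.0, 320.0, 240.0])] * len(images)

            # co-visibility graph based on the optical flow between frame pairs
            graph = self.build_frame_graph(poses, depths, intrinsics)

            scene_info[scene] = {'images': images, 'depths': depths,
                'poses': poses, 'intrinsics': intrinsics, 'graph': graph}

        return scene_info

    @staticmethod
    def image_read(image_file):
        return cv2.imread(image_file)

    @staticmethod
    def depth_read(depth_file):
        depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
        return depth.astype(np.float32) / 1000.0  # depth stored in millimeters
```

You will likely also need to add the new class to `factory.py` so that its key can be passed to `--datasets`.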
--------------------------------------------------------------------------------
/examples/core/geom/graph_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | import lietorch
7 | from data_readers.rgbd_utils import compute_distance_matrix_flow
8 |
9 |
10 | def graph_to_edge_list(graph):
11 | ii, jj, kk = [], [], []
12 | for s, u in enumerate(graph):
13 | for v in graph[u]:
14 | ii.append(u)
15 | jj.append(v)
16 | kk.append(s)
17 |
18 | ii = torch.as_tensor(ii).cuda()
19 | jj = torch.as_tensor(jj).cuda()
20 | kk = torch.as_tensor(kk).cuda()
21 | return ii, jj, kk
22 |
23 | def keyframe_indicies(graph):
24 | return torch.as_tensor([u for u in graph]).cuda()
25 |
26 |
27 | def meshgrid(m, n, device='cuda'):
28 | ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n))
29 | return ii.reshape(-1).to(device), jj.reshape(-1).to(device)
30 |
31 |
32 | class KeyframeGraph:
33 | def __init__(self, images, poses, depths, intrinsics):
34 | self.images = images.cpu()
35 | self.depths = depths.cpu()
36 | self.poses = poses
37 | self.intrinsics = intrinsics
38 |
39 | depths = depths[..., 3::8, 3::8].float().cuda()
40 | disps = torch.where(depths>0.1, 1.0/depths, depths)
41 |
42 | N = poses.shape[1]
43 | d = compute_distance_matrix_flow(poses, disps, intrinsics / 8.0)
44 |
45 | i, j = 0, 0
46 | ixs = [ i ]
47 |
48 | while j < N-1:
49 | if d[i, j+1] > 7.5:
50 | ixs += [ j ]
51 | i = j
52 | j += 1
53 |
54 |         # indices of keyframes
55 | self.distance_matrix = d[ixs][:,ixs]
56 | self.ixs = np.array(ixs)
57 | self.frame_graph = {}
58 |
59 | for i in range(N):
60 | k = np.argmin(np.abs(i - self.ixs))
61 | j = self.ixs[k]
62 | self.frame_graph[i] = (k, poses[:,i] * poses[:,j].inv())
63 |
64 | def get_keyframes(self):
65 | ix = torch.as_tensor(self.ixs).cuda()
66 | return self.images[:,ix], self.poses[:,ix], self.depths[:,ix], self.intrinsics[:,ix]
67 |
68 | def get_graph(self, num=-1, thresh=24.0, r=2):
69 | d = self.distance_matrix.copy()
70 |
71 | N = d.shape[0]
72 | if num < 0:
73 | num = N
74 |
75 | graph = OrderedDict()
76 | for i in range(N):
77 | graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]
78 |
79 | for i in range(N):
80 | for j in range(i-r, i+r+1):
81 | if j >= 0 and j < N:
82 | d[i,j] = np.inf
83 |
84 | for _ in range(num):
85 | ix = np.argmin(d)
86 | i, j = ix // N, ix % N
87 |
88 | if d[i,j] < thresh:
89 | graph[i].append(j)
90 | for ii in range(i-r, i+r+1):
91 | for jj in range(j-r, j+r+1):
92 |                         if ii>=0 and jj>=0 and ii<N and jj<N:
93 |                             d[ii,jj] = np.inf
94 | 
95 |         return graph
96 | 
--------------------------------------------------------------------------------
/examples/core/data_readers/scannet.py:
--------------------------------------------------------------------------------
24 |
25 | def _build_dataset_index(self):
26 | """ construct scene_info and dataset_index objects """
27 |
28 | from tqdm import tqdm
29 | print("Building ScanNet dataset")
30 |
31 | scene_info = {}
32 | dataset_index = []
33 |
34 | for scene in tqdm(os.listdir(self.root)):
35 | scene_path = osp.join(self.root, scene)
36 | depth_glob = osp.join(scene_path, 'depth', '*.png')
37 | depth_list = glob.glob(depth_glob)
38 |
39 | get_indicies = lambda x: int(osp.basename(x).split('.')[0])
40 | get_images = lambda i: osp.join(scene_path, 'color', '%d.jpg' % i)
41 | get_depths = lambda i: osp.join(scene_path, 'depth', '%d.png' % i)
42 | get_poses = lambda i: osp.join(scene_path, 'pose', '%d.txt' % i)
43 |
44 | indicies = sorted(map(get_indicies, depth_list))[::2]
45 | image_list = list(map(get_images, indicies))
46 | depth_list = list(map(get_depths, indicies))
47 |
48 | pose_list = map(get_poses, indicies)
49 | pose_list = list(map(ScanNet.pose_read, pose_list))
50 |
51 | # remove nan poses
52 | pvecs = np.stack(pose_list, 0)
53 | keep, = np.where(~np.any(np.isnan(pvecs) | np.isinf(pvecs), axis=1))
54 | images = [image_list[i] for i in keep]
55 | depths = [depth_list[i] for i in keep]
56 | poses = [pose_list[i] for i in keep]
57 |
58 | intrinsic = ScanNet.calib_read(scene_path)
59 | intrinsics = [intrinsic] * len(images)
60 |
61 | graph = self.build_frame_graph(poses, depths, intrinsics)
62 |
63 | scene_info[scene] = {'images': images, 'depths': depths,
64 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph}
65 |
66 | for i in range(len(images)):
67 | if len(graph[i][0]) > 1:
68 | dataset_index.append((scene, i))
69 |
70 | return scene_info, dataset_index
71 |
72 | @staticmethod
73 | def calib_read(scene_path):
74 | intrinsic_file = osp.join(scene_path, 'intrinsic', 'intrinsic_depth.txt')
75 | K = np.loadtxt(intrinsic_file, delimiter=' ')
76 | return np.array([K[0,0], K[1,1], K[0,2], K[1,2]])
77 |
78 | @staticmethod
79 | def pose_read(pose_file):
80 | pose = np.loadtxt(pose_file, delimiter=' ').astype(np.float64)
81 | return pose_matrix_to_quaternion(pose)
82 |
83 | @staticmethod
84 | def image_read(image_file):
85 | image = cv2.imread(image_file)
86 | return cv2.resize(image, (640, 480))
87 |
88 | @staticmethod
89 | def depth_read(depth_file):
90 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
91 | return depth.astype(np.float32) / 1000.0
--------------------------------------------------------------------------------
/lietorch/group_ops.py:
--------------------------------------------------------------------------------
1 | import lietorch_backends
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | class GroupOp(torch.autograd.Function):
8 | """ group operation base class """
9 |
10 | @classmethod
11 | def forward(cls, ctx, group_id, *inputs):
12 | ctx.group_id = group_id
13 | ctx.save_for_backward(*inputs)
14 | out = cls.forward_op(ctx.group_id, *inputs)
15 | return out
16 |
17 | @classmethod
18 | def backward(cls, ctx, grad):
19 | error_str = "Backward operation not implemented for {}".format(cls)
20 | assert cls.backward_op is not None, error_str
21 |
22 | inputs = ctx.saved_tensors
23 | grad = grad.contiguous()
24 | grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
25 | return (None, ) + tuple(grad_inputs)
26 |
27 |
28 | class Exp(GroupOp):
29 | """ exponential map """
30 | forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward
31 |
32 | class Log(GroupOp):
33 | """ logarithm map """
34 | forward_op, backward_op = lietorch_backends.logm, lietorch_backends.logm_backward
35 |
36 | class Inv(GroupOp):
37 | """ group inverse """
38 | forward_op, backward_op = lietorch_backends.inv, lietorch_backends.inv_backward
39 |
40 | class Mul(GroupOp):
41 | """ group multiplication """
42 | forward_op, backward_op = lietorch_backends.mul, lietorch_backends.mul_backward
43 |
44 | class Adj(GroupOp):
45 | """ adjoint operator """
46 | forward_op, backward_op = lietorch_backends.adj, lietorch_backends.adj_backward
47 |
48 | class AdjT(GroupOp):
49 |     """ dual adjoint operator """
50 | forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward
51 |
52 | class Act3(GroupOp):
53 | """ action on point """
54 | forward_op, backward_op = lietorch_backends.act, lietorch_backends.act_backward
55 |
56 | class Act4(GroupOp):
57 |     """ action on homogeneous point """
58 | forward_op, backward_op = lietorch_backends.act4, lietorch_backends.act4_backward
59 |
60 | class Jinv(GroupOp):
61 |     """ inverse left-Jacobian """
62 | forward_op, backward_op = lietorch_backends.Jinv, None
63 |
64 | class ToMatrix(GroupOp):
65 | """ convert to matrix representation """
66 | forward_op, backward_op = lietorch_backends.as_matrix, None
67 |
68 |
69 |
70 |
71 | ### conversion operations to/from Euclidean embeddings ###
72 |
73 | class FromVec(torch.autograd.Function):
74 | """ convert vector into group object """
75 |
76 | @classmethod
77 | def forward(cls, ctx, group_id, *inputs):
78 | ctx.group_id = group_id
79 | ctx.save_for_backward(*inputs)
80 | return inputs[0]
81 |
82 | @classmethod
83 | def backward(cls, ctx, grad):
84 | inputs = ctx.saved_tensors
85 | J = lietorch_backends.projector(ctx.group_id, *inputs)
86 | return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
87 |
88 | class ToVec(torch.autograd.Function):
89 | """ convert group object to vector """
90 |
91 | @classmethod
92 | def forward(cls, ctx, group_id, *inputs):
93 | ctx.group_id = group_id
94 | ctx.save_for_backward(*inputs)
95 | return inputs[0]
96 |
97 | @classmethod
98 | def backward(cls, ctx, grad):
99 | inputs = ctx.saved_tensors
100 | J = lietorch_backends.projector(ctx.group_id, *inputs)
101 | return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
102 |
103 |
--------------------------------------------------------------------------------
/examples/core/geom/projective_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from lietorch import SE3, Sim3
5 |
6 | MIN_DEPTH = 0.1
7 |
8 | def extract_intrinsics(intrinsics):
9 | return intrinsics[...,None,None,:].unbind(dim=-1)
10 |
11 | def iproj(disps, intrinsics):
12 | """ pinhole camera inverse projection """
13 | ht, wd = disps.shape[2:]
14 | fx, fy, cx, cy = extract_intrinsics(intrinsics)
15 |
16 | y, x = torch.meshgrid(
17 | torch.arange(ht).to(disps.device).float(),
18 | torch.arange(wd).to(disps.device).float())
19 |
20 | i = torch.ones_like(disps)
21 | X = (x - cx) / fx
22 | Y = (y - cy) / fy
23 | return torch.stack([X, Y, i, disps], dim=-1)
24 |
25 | def proj(Xs, intrinsics, jacobian=False):
26 | """ pinhole camera projection """
27 | fx, fy, cx, cy = extract_intrinsics(intrinsics)
28 | X, Y, Z, D = Xs.unbind(dim=-1)
29 | d = torch.where(Z.abs() < 0.001, torch.zeros_like(Z), 1.0/Z)
30 |
31 | x = fx * (X * d) + cx
32 | y = fy * (Y * d) + cy
33 | coords = torch.stack([x,y, D*d], dim=-1)
34 |
35 | if jacobian:
36 | B, N, H, W = d.shape
37 | o = torch.zeros_like(d)
38 | proj_jac = torch.stack([
39 | fx*d, o, -fx*X*d*d, o,
40 | o, fy*d, -fy*Y*d*d, o,
41 | o, o, -D*d*d, d,
42 | ], dim=-1).view(B, N, H, W, 3, 4)
43 |
44 | return coords, proj_jac
45 |
46 | return coords, None
47 |
48 | def actp(Gij, X0, jacobian=False):
49 | """ action on point cloud """
50 | X1 = Gij[:,:,None,None] * X0
51 |
52 | if jacobian:
53 | X, Y, Z, d = X1.unbind(dim=-1)
54 | o = torch.zeros_like(d)
55 | B, N, H, W = d.shape
56 |
57 | if isinstance(Gij, SE3):
58 | Ja = torch.stack([
59 | d, o, o, o, Z, -Y,
60 | o, d, o, -Z, o, X,
61 | o, o, d, Y, -X, o,
62 | o, o, o, o, o, o,
63 | ], dim=-1).view(B, N, H, W, 4, 6)
64 |
65 | elif isinstance(Gij, Sim3):
66 | Ja = torch.stack([
67 | d, o, o, o, Z, -Y, X,
68 | o, d, o, -Z, o, X, Y,
69 | o, o, d, Y, -X, o, Z,
70 | o, o, o, o, o, o, o
71 | ], dim=-1).view(B, N, H, W, 4, 7)
72 |
73 | return X1, Ja
74 |
75 | return X1, None
76 |
77 | def projective_transform(poses, depths, intrinsics, ii, jj, jacobian=False):
78 | """ map points from ii->jj """
79 |
80 | # inverse project (pinhole)
81 | X0 = iproj(depths[:,ii], intrinsics[:,ii])
82 |
83 | # transform
84 | Gij = poses[:,jj] * poses[:,ii].inv()
85 | X1, Ja = actp(Gij, X0, jacobian=jacobian)
86 |
87 | # project (pinhole)
88 | x1, Jp = proj(X1, intrinsics[:,jj], jacobian=jacobian)
89 |
90 | # exclude points too close to camera
91 | valid = ((X1[...,2] > MIN_DEPTH) & (X0[...,2] > MIN_DEPTH)).float()
92 | valid = valid.unsqueeze(-1)
93 |
94 | if jacobian:
95 | Jj = torch.matmul(Jp, Ja)
96 | Ji = -Gij[:,:,None,None,None].adjT(Jj)
97 | return x1, valid, (Ji, Jj)
98 |
99 | return x1, valid
100 |
101 |
102 | def induced_flow(poses, disps, intrinsics, ii, jj):
103 | """ optical flow induced by camera motion """
104 |
105 | ht, wd = disps.shape[2:]
106 | y, x = torch.meshgrid(
107 | torch.arange(ht).to(disps.device).float(),
108 | torch.arange(wd).to(disps.device).float())
109 |
110 | coords0 = torch.stack([x, y], dim=-1)
111 | coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj)
112 |
113 | return coords1[...,:2] - coords0, valid
114 |
115 |
--------------------------------------------------------------------------------
/examples/pgo/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lietorch import SO3, SE3, LieGroupParameter
3 |
4 | import argparse
5 | import numpy as np
6 | import time
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 |
10 |
11 | def draw(verticies):
12 | """ draw pose graph """
13 | import open3d as o3d
14 |
15 | n = len(verticies)
16 | points = np.array([x[1][:3] for x in verticies])
17 | lines = np.stack([np.arange(0,n-1), np.arange(1,n)], 1)
18 |
19 | line_set = o3d.geometry.LineSet(
20 | points=o3d.utility.Vector3dVector(points),
21 | lines=o3d.utility.Vector2iVector(lines),
22 | )
23 | o3d.visualization.draw_geometries([line_set])
24 |
25 | def info2mat(info):
26 | mat = np.zeros((6,6))
27 | ix = 0
28 | for i in range(mat.shape[0]):
29 | mat[i,i:] = info[ix:ix+(6-i)]
30 | mat[i:,i] = info[ix:ix+(6-i)]
31 | ix += (6-i)
32 |
33 | return mat
34 |
35 | def read_g2o(fn):
36 | verticies, edges = [], []
37 | with open(fn) as f:
38 | for line in f:
39 | line = line.split()
40 | if line[0] == 'VERTEX_SE3:QUAT':
41 | v = int(line[1])
42 | pose = np.array(line[2:], dtype=np.float32)
43 | verticies.append([v, pose])
44 |
45 | elif line[0] == 'EDGE_SE3:QUAT':
46 | u = int(line[1])
47 | v = int(line[2])
48 | pose = np.array(line[3:10], dtype=np.float32)
49 | info = np.array(line[10:], dtype=np.float32)
50 |
51 | info = info2mat(info)
52 | edges.append([u, v, pose, info, line])
53 |
54 | return verticies, edges
55 |
56 | def write_g2o(pose_graph, fn):
57 | import csv
58 | verticies, edges = pose_graph
59 | with open(fn, 'w') as f:
60 | writer = csv.writer(f, delimiter=' ')
61 | for (v, pose) in verticies:
62 | row = ['VERTEX_SE3:QUAT', v] + pose.tolist()
63 | writer.writerow(row)
64 | for edge in edges:
65 | writer.writerow(edge[-1])
66 |
67 | def reshaping_fn(dE, b=1.5):
68 |     """ reshaping function from "Intrinsic consensus on SO(3)", Tron et al. """
69 | ang = dE.log().norm(dim=-1)
70 | err = 1/b - (1/b + ang) * torch.exp(-b*ang)
71 | return err.sum()
72 |
73 | def gradient_initializer(pose_graph, n_steps=500, lr_init=0.2):
74 | """ Riemannian Gradient Descent """
75 |
76 | verticies, edges = pose_graph
77 |
78 |     # edge indices (ii, jj)
79 | ii = np.array([x[0] for x in edges])
80 | jj = np.array([x[1] for x in edges])
81 | ii = torch.from_numpy(ii).cuda()
82 | jj = torch.from_numpy(jj).cuda()
83 |
84 | Eij = np.stack([x[2][3:] for x in edges])
85 | Eij = SO3(torch.from_numpy(Eij).float().cuda())
86 |
87 | R = np.stack([x[1][3:] for x in verticies])
88 | R = SO3(torch.from_numpy(R).float().cuda())
89 | R = LieGroupParameter(R)
90 |
91 | # use gradient descent with momentum
92 | optimizer = optim.SGD([R], lr=lr_init, momentum=0.5)
93 |
94 | start = time.time()
95 | for i in range(n_steps):
96 | optimizer.zero_grad()
97 |
98 | for param_group in optimizer.param_groups:
99 | param_group['lr'] = lr_init * .995**i
100 |
101 | # rotation error
102 | dE = (R[ii].inv() * R[jj]) * Eij.inv()
103 | loss = reshaping_fn(dE)
104 |
105 | loss.backward()
106 | optimizer.step()
107 |
108 | if i%25 == 0:
109 | print(i, lr_init * .995**i, loss.item())
110 |
111 | # convert rotations to pose3
112 | quats = R.group.data.detach().cpu().numpy()
113 |
114 | for i in range(len(verticies)):
115 | verticies[i][1][3:] = quats[i]
116 |
117 | return verticies, edges
118 |
119 |
120 | if __name__ == '__main__':
121 | parser = argparse.ArgumentParser()
122 | parser.add_argument('--problem', help="input pose graph optimization file (.g2o format)")
123 | args = parser.parse_args()
124 |
125 | output_path = args.problem.replace('.g2o', '_rotavg.g2o')
126 | input_pose_graph = read_g2o(args.problem)
127 |
128 | rot_pose_graph = gradient_initializer(input_pose_graph)
129 | write_g2o(rot_pose_graph, output_path)
130 |
131 |
--------------------------------------------------------------------------------
/examples/rgbdslam/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | import cv2
5 | import numpy as np
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 | from data_readers.factory import dataset_factory
12 |
13 | from lietorch import SO3, SE3, Sim3
14 | from geom.losses import geodesic_loss, residual_loss
15 |
16 | # network
17 | from networks.rslam import RaftSLAM
18 | from logger import Logger
19 | from evaluate import run_evaluation
20 |
21 | def show_image(image):
22 | image = image.permute(1, 2, 0).cpu().numpy()
23 | cv2.imshow('image', image / 255.0)
24 | cv2.waitKey()
25 |
26 | def normalize_images(images):
27 | images = images[:, :, [2,1,0]]
28 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
29 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
30 | return (images/255.0).sub_(mean[:, None, None]).div_(std[:, None, None])
31 |
32 | def train(args):
33 |     """ main training loop """
34 |
35 | N = args.n_frames
36 | model = RaftSLAM(args)
37 | model.cuda()
38 | model.train()
39 |
40 | if args.ckpt is not None:
41 | model.load_state_dict(torch.load(args.ckpt))
42 |
43 | db = dataset_factory(args.datasets, n_frames=N, fmin=16.0, fmax=96.0)
44 | train_loader = DataLoader(db, batch_size=args.batch, shuffle=True, num_workers=4)
45 |
46 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
47 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
48 | args.lr, args.steps, pct_start=0.01, cycle_momentum=False)
49 |
50 | logger = Logger(args.name, scheduler)
51 | should_keep_training = True
52 | total_steps = 0
53 |
54 | while should_keep_training:
55 | for i_batch, item in enumerate(train_loader):
56 | optimizer.zero_grad()
57 |
58 | graph = OrderedDict()
59 | for i in range(N):
60 | graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]
61 |
62 | images, poses, depths, intrinsics = [x.to('cuda') for x in item]
63 |
64 | # convert poses w2c -> c2w
65 | Ps = SE3(poses).inv()
66 | Gs = SE3.Identity(Ps.shape, device='cuda')
67 |
68 | images = normalize_images(images)
69 | Gs, residuals = model(Gs, images, depths, intrinsics, graph, num_steps=args.iters)
70 |
71 | geo_loss, geo_metrics = geodesic_loss(Ps, Gs, graph)
72 | res_loss, res_metrics = residual_loss(residuals)
73 |
74 | metrics = {}
75 | metrics.update(geo_metrics)
76 | metrics.update(res_metrics)
77 |
78 | loss = args.w1 * geo_loss + args.w2 * res_loss
79 | loss.backward()
80 |
81 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
82 | optimizer.step()
83 | scheduler.step()
84 |
85 | logger.push(metrics)
86 | total_steps += 1
87 |
88 | if total_steps % 10000 == 0:
89 | PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
90 | torch.save(model.state_dict(), PATH)
91 |
92 | run_evaluation(PATH)
93 |
94 | if total_steps >= args.steps:
95 | should_keep_training = False
96 | break
97 |
98 | return model
99 |
100 |
101 | if __name__ == '__main__':
102 | import argparse
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument('--name', default='bla', help='name your experiment')
105 | parser.add_argument('--ckpt', help='checkpoint to restore')
106 | parser.add_argument('--datasets', nargs='+', help='lists of datasets for training')
107 |
108 | parser.add_argument('--batch', type=int, default=2)
109 | parser.add_argument('--iters', type=int, default=8)
110 | parser.add_argument('--steps', type=int, default=100000)
111 | parser.add_argument('--lr', type=float, default=0.0001)
112 | parser.add_argument('--clip', type=float, default=2.5)
113 | parser.add_argument('--n_frames', type=int, default=4)
114 |
115 | parser.add_argument('--w1', type=float, default=10.0)
116 | parser.add_argument('--w2', type=float, default=0.1)
117 |
118 | args = parser.parse_args()
119 |
120 | import os
121 | if not os.path.isdir('checkpoints'):
122 | os.mkdir('checkpoints')
123 |
124 | model = train(args)
125 |
126 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LieTorch: Tangent Space Backpropagation
2 |
3 |
4 | ## Introduction
5 |
6 | The LieTorch library generalizes PyTorch to 3D transformation groups. Just as `torch.Tensor` is a multi-dimensional matrix of scalar elements, `lietorch.SE3` is a multi-dimensional matrix of SE3 elements. We support common tensor manipulations such as indexing, reshaping, and broadcasting. Group operations can be composed into computation graphs and backpropagation is automatically performed in the tangent space of each element. For more details, please see our paper:
7 |
8 |
9 |
10 | [Tangent Space Backpropagation for 3D Transformation Groups](https://arxiv.org/pdf/2103.12032.pdf)
11 | Zachary Teed and Jia Deng, CVPR 2021
12 |
13 | ```
14 | @inproceedings{teed2021tangent,
15 | title={Tangent Space Backpropagation for 3D Transformation Groups},
16 | author={Teed, Zachary and Deng, Jia},
17 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
18 | year={2021},
19 | }
20 | ```
21 |
22 |
23 | ## Installation
24 |
25 |
26 | ### Installing (from source):
27 |
28 | Requires torch >= 2 and CUDA >= 11 (tested up to torch==2.7 and CUDA 12). Make sure the CUDA version of your PyTorch build matches the CUDA toolkit installed on your system, at least to the major version.
29 |
30 | ```bash
31 | git clone --recursive https://github.com/princeton-vl/lietorch.git
32 | cd lietorch
33 |
34 | python3 -m venv .venv
35 | source .venv/bin/activate
36 |
37 | # install requirements
38 | pip install torch torchvision torchaudio wheel
39 |
40 | # optional: specify GPU architectures
41 | export TORCH_CUDA_ARCH_LIST="7.5;8.6;8.9;9.0"
42 |
43 | # install lietorch
44 | pip install --no-build-isolation .
45 | ```
46 |
47 | ### Installing (with pip)
48 | ```bash
49 | # optional: specify GPU architectures
50 | export TORCH_CUDA_ARCH_LIST="7.5;8.6;8.9;9.0"
51 |
52 | pip install --no-build-isolation git+https://github.com/princeton-vl/lietorch.git
53 | ```
54 |
55 |
56 | To run the examples, you will need these additional libraries
57 | ```bash
58 | pip install opencv-python open3d scipy pyyaml
59 | ```
60 |
61 | ### Running Tests
62 |
63 | After building, you can run the tests
64 | ```bash
65 | ./run_tests.sh
66 | ```
67 |
68 |
69 |
70 | ## Overview
71 |
72 | LieTorch currently supports the following 3D transformation groups:
73 |
74 | | Group | Dimension | Action |
75 | | -------| --------- | ------------- |
76 | | SO3 | 3 | rotation |
77 | | RxSO3 | 4 | rotation + scaling |
78 | | SE3 | 6 | rotation + translation |
79 | | Sim3 | 7 | rotation + translation + scaling |
80 |
81 | Each group supports the following differentiable operations:
82 |
83 | | Operation | Map | Description |
84 | | -------| --------| ------------- |
85 | | exp | g -> G | exponential map |
86 | | log | G -> g | logarithm map |
87 | | inv | G -> G | group inverse |
88 | | mul | G x G -> G | group multiplication |
89 | | adj | G x g -> g | adjoint |
90 | | adjT | G x g*-> g* | dual adjoint |
91 | | act | G x R^3 -> R^3 | action on point (set) |
92 | | act4 | G x P^3 -> P^3 | action on homogeneous point (set) |
93 | | matrix | G -> R^{4x4} | convert to 4x4 matrix |
94 | | vec | G -> R^D | map to Euclidean embedding vector |
95 | | InitFromVec | R^D -> G | initialize group from Euclidean embedding |
96 |
97 |
98 |
99 |
100 | ### Simple Example:
101 | Compute the angles between all pairs of rotation matrices
102 |
103 | ```python
104 | import torch
105 | from lietorch import SO3
106 |
107 | phi = torch.randn(8000, 3, device='cuda', requires_grad=True)
108 | R = SO3.exp(phi)
109 |
110 | # relative rotation matrix, SO3 ^ {8000 x 8000}
111 | dR = R[:,None].inv() * R[None,:]
112 |
113 | # 8000x8000 matrix of angles
114 | ang = dR.log().norm(dim=-1)
115 |
116 | # backpropagation in tangent space
117 | loss = ang.sum()
118 | loss.backward()
119 | ```
120 |
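The group action can also be applied directly to point sets. Here is a small additional sketch using only `exp` and `act` from the table above; it transforms a batch of point clouds and backpropagates to the tangent vectors (drop `device='cuda'` to run on the CPU):

```python
import torch
from lietorch import SE3

# batch of 100 rigid transforms from random tangent vectors [translation, rotation] in R^6
xi = 0.1 * torch.randn(100, 6, device='cuda', requires_grad=True)
T = SE3.exp(xi)

# one set of 5000 points per transform
points = torch.randn(100, 5000, 3, device='cuda')

# broadcast each transform over its point set, SE3 ^ {100 x 1} acting on R^{100 x 5000 x 3}
points1 = T[:,None].act(points)

# mean displacement, differentiable w.r.t. the tangent vectors
loss = (points1 - points).norm(dim=-1).mean()
loss.backward()
```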
121 |
122 | ### Converting between Group Elements and Euclidean Embeddings
123 | We provide differentiable `InitFromVec` and `vec` functions which convert between group elements and their Euclidean vector embeddings. Additionally, the `.matrix()` function returns a 4x4 transformation matrix.
124 | ```python
125 |
126 | # random quaternion
127 | q = torch.randn(1, 4, requires_grad=True)
128 | q = q / q.norm(dim=-1, keepdim=True)
129 |
130 | # create SO3 object from quaternion (differentiable w.r.t q)
131 | R = SO3.InitFromVec(q)
132 |
133 | # 4x4 transformation matrix (differentiable w.r.t R)
134 | T = R.matrix()
135 |
136 | # map back to quaternion (differentiable w.r.t R)
137 | q = R.vec()
138 |
139 | ```
140 |
141 |
142 | ## Examples
143 | We provide real use cases in the examples directory:
144 | 1. Pose Graph Optimization
145 | 2. Deep SE3/Sim3 Registration
146 | 3. RGB-D SLAM / VO
147 |
148 | ### Acknowledgements
149 | Many of the Lie Group implementations are adapted from [Sophus](https://github.com/strasdat/Sophus).
150 |
--------------------------------------------------------------------------------
/examples/core/networks/sim3_net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | from networks.modules.extractor import BasicEncoder
8 | from networks.modules.corr import CorrBlock
9 | from networks.modules.gru import ConvGRU
10 | from networks.modules.clipping import GradientClip
11 |
12 | from lietorch import SE3, Sim3
13 | from geom.ba import MoBA
14 |
15 | import geom.projective_ops as pops
16 | from geom.sampler_utils import bilinear_sampler, sample_depths
17 | from geom.graph_utils import graph_to_edge_list, keyframe_indicies
18 |
19 |
20 | class UpdateModule(nn.Module):
21 | def __init__(self, args):
22 | super(UpdateModule, self).__init__()
23 | self.args = args
24 |
25 | cor_planes = 4 * (2*3 + 1)**2
26 |
27 | self.encoder = nn.Sequential(
28 | nn.Conv2d(cor_planes, 128, 1, padding=0),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(128, 128, 3, padding=1),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(128, 128, 3, padding=1),
33 | nn.ReLU(inplace=True))
34 |
35 | self.weight = nn.Sequential(
36 | nn.Conv2d(128, 128, 3, padding=1),
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(128, 3, 3, padding=1),
39 | GradientClip(),
40 | nn.Sigmoid())
41 |
42 | self.delta = nn.Sequential(
43 | nn.Conv2d(128, 128, 3, padding=1),
44 | nn.ReLU(inplace=True),
45 | nn.Conv2d(128, 3, 3, padding=1),
46 | GradientClip())
47 |
48 | self.gru = ConvGRU(128, 128+128+1)
49 |
50 | def forward(self, net, inp, corr, dz):
51 | """ update operator """
52 |
53 | batch, num, ch, ht, wd = net.shape
54 | output_dim = (batch, num, -1, ht, wd)
55 | net = net.view(batch*num, -1, ht, wd)
56 | inp = inp.view(batch*num, -1, ht, wd)
57 |
58 | corr = corr.view(batch*num, -1, ht, wd)
59 | dz = dz.view(batch*num, 1, ht, wd)
60 | corr = self.encoder(corr)
61 | net = self.gru(net, inp, corr, dz)
62 |
63 | ### update variables ###
64 | delta = self.delta(net).view(*output_dim)
65 | weight = self.weight(net).view(*output_dim)
66 |
67 | delta = delta.permute(0,1,3,4,2).contiguous()
68 | weight = weight.permute(0,1,3,4,2).contiguous()
69 |
70 | net = net.view(*output_dim)
71 | return net, delta, weight
72 |
73 |
74 | class Sim3Net(nn.Module):
75 | def __init__(self, args):
76 | super(Sim3Net, self).__init__()
77 | self.args = args
78 | self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
79 | self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
80 | self.update = UpdateModule(args)
81 |
82 | def extract_features(self, images):
83 |         """ run feature extraction networks """
84 | fmaps = self.fnet(images)
85 | net = self.cnet(images)
86 |
87 | net, inp = net.split([128,128], dim=2)
88 | net = torch.tanh(net)
89 | inp = torch.relu(inp)
90 | return fmaps, net, inp
91 |
92 | def forward(self, Gs, images, depths, intrinsics, graph=None, num_steps=12):
93 |         """ Estimate SE3 or Sim3 between a pair of frames """
94 |
95 | if graph is None:
96 | graph = OrderedDict()
97 | graph[0] = [1]
98 | graph[1] = [0]
99 |
100 | u = keyframe_indicies(graph)
101 | ii, jj, kk = graph_to_edge_list(graph)
102 |
103 | # use inverse depth parameterization
104 | depths = depths.clamp(min=0.1, max=1000.0)
105 | disps = 1.0 / depths[:, :, 3::8, 3::8]
106 | intrinsics = intrinsics / 8.0
107 |
108 | fmaps, net, inp = self.extract_features(images)
109 | corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
110 |
111 | Gs_list, coords_list, residual_list = [], [], []
112 | for step in range(num_steps):
113 | Gs = Gs.detach()
114 | coords1_xyz, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
115 |
116 | coords1, zinv_proj = coords1_xyz.split([2,1], dim=-1)
117 | zinv = sample_depths(disps[:,jj], coords1)
118 | dz = (zinv - zinv_proj).clamp(-1.0, 1.0)
119 |
120 | corr = corr_fn(coords1)
121 | net, delta, weight = self.update(net, inp, corr, dz)
122 |
123 | target = coords1_xyz + delta
124 | for i in range(3):
125 | Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)
126 |
127 | coords1_xyz, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
128 | residual = valid_mask * (target - coords1_xyz)
129 |
130 | Gs_list.append(Gs)
131 | coords_list.append(target)
132 | residual_list.append(residual)
133 |
134 | return Gs_list, residual_list
135 |
136 |
--------------------------------------------------------------------------------
/examples/core/data_readers/rgbd_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as osp
3 |
4 | import torch
5 | from lietorch import SE3
6 |
7 | import geom.projective_ops as pops
8 | from scipy.spatial.transform import Rotation
9 |
10 |
11 | def parse_list(filepath, skiprows=0):
12 | """ read list data """
13 |     data = np.loadtxt(filepath, delimiter=' ', dtype=str, skiprows=skiprows)
14 | return data
15 |
16 | def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
17 | """ pair images, depths, and poses """
18 | associations = []
19 | for i, t in enumerate(tstamp_image):
20 | if tstamp_pose is None:
21 | j = np.argmin(np.abs(tstamp_depth - t))
22 | if (np.abs(tstamp_depth[j] - t) < max_dt):
23 | associations.append((i, j))
24 |
25 | else:
26 | j = np.argmin(np.abs(tstamp_depth - t))
27 | k = np.argmin(np.abs(tstamp_pose - t))
28 |
29 | if (np.abs(tstamp_depth[j] - t) < max_dt) and \
30 | (np.abs(tstamp_pose[k] - t) < max_dt):
31 | associations.append((i, j, k))
32 |
33 | return associations
34 |
35 | def loadtum(datapath, frame_rate=-1):
36 | """ read video data in tum-rgbd format """
37 | if osp.isfile(osp.join(datapath, 'groundtruth.txt')):
38 | pose_list = osp.join(datapath, 'groundtruth.txt')
39 | elif osp.isfile(osp.join(datapath, 'pose.txt')):
40 | pose_list = osp.join(datapath, 'pose.txt')
41 |
42 | image_list = osp.join(datapath, 'rgb.txt')
43 | depth_list = osp.join(datapath, 'depth.txt')
44 |
45 | calib_path = osp.join(datapath, 'calibration.txt')
46 | intrinsic = None
47 | if osp.isfile(calib_path):
48 | intrinsic = np.loadtxt(calib_path, delimiter=' ')
49 | intrinsic = intrinsic.astype(np.float64)
50 |
51 | image_data = parse_list(image_list)
52 | depth_data = parse_list(depth_list)
53 | pose_data = parse_list(pose_list, skiprows=1)
54 | pose_vecs = pose_data[:,1:].astype(np.float64)
55 |
56 | tstamp_image = image_data[:,0].astype(np.float64)
57 | tstamp_depth = depth_data[:,0].astype(np.float64)
58 | tstamp_pose = pose_data[:,0].astype(np.float64)
59 | associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose)
60 |
61 | indicies = [ 0 ]
62 | for i in range(1, len(associations)):
63 | t0 = tstamp_image[associations[indicies[-1]][0]]
64 | t1 = tstamp_image[associations[i][0]]
65 | if t1 - t0 > 1.0 / frame_rate:
66 | indicies += [ i ]
67 |
68 | images, poses, depths, intrinsics = [], [], [], []
69 | for ix in indicies:
70 | (i, j, k) = associations[ix]
71 | images += [ osp.join(datapath, image_data[i,1]) ]
72 | depths += [ osp.join(datapath, depth_data[j,1]) ]
73 | poses += [ pose_vecs[k] ]
74 |
75 | if intrinsic is not None:
76 | intrinsics += [ intrinsic ]
77 |
78 | return images, depths, poses, intrinsics
79 |
80 |
81 | def all_pairs_distance_matrix(poses, beta=2.5):
82 | """ compute distance matrix between all pairs of poses """
83 | poses = np.array(poses, dtype=np.float32)
84 |     poses[:,:3] *= beta # scale to balance rot + trans
85 | poses = SE3(torch.from_numpy(poses))
86 |
87 | r = (poses[:,None].inv() * poses[None,:]).log()
88 | return r.norm(dim=-1).cpu().numpy()
89 |
90 | def pose_matrix_to_quaternion(pose):
91 | """ convert 4x4 pose matrix to (t, q) """
92 | q = Rotation.from_matrix(pose[:3, :3]).as_quat()
93 | return np.concatenate([pose[:3, 3], q], axis=0)
94 |
95 | def compute_distance_matrix_flow(poses, disps, intrinsics):
96 | """ compute flow magnitude between all pairs of frames """
97 | if not isinstance(poses, SE3):
98 | poses = torch.from_numpy(poses).float().cuda()[None]
99 | poses = SE3(poses).inv()
100 |
101 | disps = torch.from_numpy(disps).float().cuda()[None]
102 | intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
103 |
104 | N = poses.shape[1]
105 |
106 | ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
107 | ii = ii.reshape(-1).cuda()
108 | jj = jj.reshape(-1).cuda()
109 |
110 | MAX_FLOW = 100.0
111 | matrix = np.zeros((N, N), dtype=np.float32)
112 |
113 | s = 2048
114 | for i in range(0, ii.shape[0], s):
115 | flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
116 | flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
117 |
118 | flow = torch.stack([flow1, flow2], dim=2)
119 | val = torch.stack([val1, val2], dim=2)
120 |
121 | mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
122 | mag = mag.view(mag.shape[1], -1)
123 | val = val.view(val.shape[1], -1)
124 |
125 | mag = (mag * val).mean(-1) / val.mean(-1)
126 | mag[val.mean(-1) < 0.7] = np.inf
127 |
128 | i1 = ii[i:i+s].cpu().numpy()
129 | j1 = jj[i:i+s].cpu().numpy()
130 | matrix[i1, j1] = mag.cpu().numpy()
131 |
132 | return matrix
133 |
--------------------------------------------------------------------------------
/examples/core/networks/modules/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from geom.sampler_utils import bilinear_sampler
5 | import lietorch_extras
6 |
7 |
8 | class CorrSampler(torch.autograd.Function):
9 |
10 | @staticmethod
11 | def forward(ctx, volume, coords, radius):
12 | ctx.save_for_backward(volume,coords)
13 | ctx.radius = radius
14 | corr, = lietorch_extras.corr_index_forward(volume, coords, radius)
15 | return corr
16 |
17 | @staticmethod
18 | def backward(ctx, grad_output):
19 | volume, coords = ctx.saved_tensors
20 | grad_output = grad_output.contiguous()
21 | grad_volume, = lietorch_extras.corr_index_backward(volume, coords, grad_output, ctx.radius)
22 | return grad_volume, None, None
23 |
24 |
25 | class CorrBlock:
26 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
27 | self.num_levels = num_levels
28 | self.radius = radius
29 | self.corr_pyramid = []
30 |
31 | # all pairs correlation
32 | corr = CorrBlock.corr(fmap1, fmap2)
33 |
34 | batch, num, h1, w1, h2, w2 = corr.shape
35 | corr = corr.reshape(batch*num*h1*w1, 1, h2, w2)
36 |
37 | for i in range(self.num_levels):
38 | self.corr_pyramid.append(
39 | corr.view(batch*num, h1, w1, h2//2**i, w2//2**i))
40 | corr = F.avg_pool2d(corr, 2, stride=2)
41 |
42 | def __call__(self, coords):
43 | out_pyramid = []
44 | batch, num, ht, wd, _ = coords.shape
45 | coords = coords.permute(0,1,4,2,3)
46 | coords = coords.contiguous().view(batch*num, 2, ht, wd)
47 |
48 | for i in range(self.num_levels):
49 | corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius)
50 | out_pyramid.append(corr.view(batch, num, -1, ht, wd))
51 |
52 | return torch.cat(out_pyramid, dim=2)
53 |
54 | def append(self, other):
55 | for i in range(self.num_levels):
56 | self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0)
57 |
58 | def remove(self, ix):
59 | for i in range(self.num_levels):
60 | self.corr_pyramid[i] = self.corr_pyramid[i][ix].contiguous()
61 |
62 | @staticmethod
63 | def corr(fmap1, fmap2):
64 | """ all-pairs correlation """
65 | batch, num, dim, ht, wd = fmap1.shape
66 | fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0
67 | fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0
68 |
69 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
70 | return corr.view(batch, num, ht, wd, ht, wd)
71 |
72 |
73 | class CorrLayer(torch.autograd.Function):
74 | @staticmethod
75 | def forward(ctx, fmap1, fmap2, coords, r):
76 | ctx.r = r
77 | fmap1 = fmap1.contiguous()
78 | fmap2 = fmap2.contiguous()
79 | coords = coords.contiguous()
80 | ctx.save_for_backward(fmap1, fmap2, coords)
81 | corr, = lietorch_extras.altcorr_forward(fmap1, fmap2, coords, ctx.r)
82 | return corr
83 |
84 | @staticmethod
85 | def backward(ctx, grad_corr):
86 | fmap1, fmap2, coords = ctx.saved_tensors
87 | grad_corr = grad_corr.contiguous()
88 | fmap1_grad, fmap2_grad, coords_grad = \
89 | lietorch_extras.altcorr_backward(fmap1, fmap2, coords, grad_corr, ctx.r)
90 | return fmap1_grad, fmap2_grad, coords_grad, None
91 |
92 |
93 |
94 | class AltCorrBlock:
95 | def __init__(self, fmaps, inds, num_levels=4, radius=3):
96 | self.num_levels = num_levels
97 | self.radius = radius
98 | self.inds = inds
99 |
100 | B, N, C, H, W = fmaps.shape
101 | fmaps = fmaps.view(B*N, C, H, W)
102 |
103 | self.pyramid = []
104 | for i in range(self.num_levels):
105 | sz = (B, N, H//2**i, W//2**i, C)
106 | fmap_lvl = fmaps.permute(0, 2, 3, 1)
107 | self.pyramid.append(fmap_lvl.reshape(*sz))
108 | fmaps = F.avg_pool2d(fmaps, 2, stride=2)
109 |
110 | def corr_fn(self, coords, ii, jj):
111 | B, N, H, W, S, _ = coords.shape
112 | coords = coords.permute(0, 1, 4, 2, 3, 5)
113 |
114 | corr_list = []
115 | for i in range(self.num_levels):
116 | r = self.radius
117 | fmap1_i = self.pyramid[0][:, ii]
118 | fmap2_i = self.pyramid[i][:, jj]
119 |
120 | coords_i = (coords / 2**i).reshape(B*N, S, H, W, 2).contiguous()
121 | fmap1_i = fmap1_i.reshape((B*N,) + fmap1_i.shape[2:])
122 | fmap2_i = fmap2_i.reshape((B*N,) + fmap2_i.shape[2:])
123 |
124 | corr = CorrLayer.apply(fmap1_i, fmap2_i, coords_i, self.radius)
125 | corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2)
126 | corr_list.append(corr)
127 |
128 | corr = torch.cat(corr_list, dim=2)
129 | return corr / 16.0
130 |
131 |
132 | def __call__(self, coords, ii, jj):
133 | squeeze_output = False
134 | if len(coords.shape) == 5:
135 | coords = coords.unsqueeze(dim=-2)
136 | squeeze_output = True
137 |
138 | corr = self.corr_fn(coords, ii, jj)
139 |
140 | if squeeze_output:
141 | corr = corr.squeeze(dim=-1)
142 |
143 | return corr.contiguous()
144 |
145 |
--------------------------------------------------------------------------------
/examples/core/networks/rslam.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | from networks.modules.extractor import BasicEncoder
8 | from networks.modules.corr import CorrBlock
9 | from networks.modules.gru import ConvGRU
10 | from networks.modules.clipping import GradientClip
11 |
12 | from lietorch import SE3, Sim3
13 | from geom.ba import MoBA
14 |
15 | import geom.projective_ops as pops
16 | from geom.sampler_utils import bilinear_sampler, sample_depths
17 | from geom.graph_utils import graph_to_edge_list, keyframe_indicies
18 |
19 |
20 | class UpdateModule(nn.Module):
21 | def __init__(self, args):
22 | super(UpdateModule, self).__init__()
23 | self.args = args
24 |
25 | cor_planes = 4 * (2*3 + 1)**2 + 1
26 |
27 | self.corr_encoder = nn.Sequential(
28 | nn.Conv2d(cor_planes, 128, 1, padding=0),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(128, 128, 3, padding=1),
31 | nn.ReLU(inplace=True))
32 |
33 | self.flow_encoder = nn.Sequential(
34 | nn.Conv2d(3, 128, 7, padding=3),
35 | nn.ReLU(inplace=True),
36 | nn.Conv2d(128, 64, 3, padding=1),
37 | nn.ReLU(inplace=True))
38 |
39 | self.weight = nn.Sequential(
40 | nn.Conv2d(128, 128, 3, padding=1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(128, 3, 3, padding=1),
43 | GradientClip(),
44 | nn.Sigmoid())
45 |
46 | self.delta = nn.Sequential(
47 | nn.Conv2d(128, 128, 3, padding=1),
48 | nn.ReLU(inplace=True),
49 | nn.Conv2d(128, 3, 3, padding=1),
50 | GradientClip())
51 |
52 | self.gru = ConvGRU(128, 128+128+64)
53 |
54 | def forward(self, net, inp, corr, flow):
55 | """ RaftSLAM update operator """
56 |
57 | batch, num, ch, ht, wd = net.shape
58 | output_dim = (batch, num, -1, ht, wd)
59 | net = net.view(batch*num, -1, ht, wd)
60 | inp = inp.view(batch*num, -1, ht, wd)
61 | corr = corr.view(batch*num, -1, ht, wd)
62 | flow = flow.view(batch*num, -1, ht, wd)
63 |
64 | corr = self.corr_encoder(corr)
65 | flow = self.flow_encoder(flow)
66 | net = self.gru(net, inp, corr, flow)
67 |
68 | ### update variables ###
69 | delta = self.delta(net).view(*output_dim)
70 | weight = self.weight(net).view(*output_dim)
71 |
72 | delta = delta.permute(0,1,3,4,2).contiguous()
73 | weight = weight.permute(0,1,3,4,2).contiguous()
74 |
75 | net = net.view(*output_dim)
76 | return net, delta, weight
77 |
78 |
79 | class RaftSLAM(nn.Module):
80 | def __init__(self, args):
81 | super(RaftSLAM, self).__init__()
82 | self.args = args
83 | self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
84 | self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
85 | self.update = UpdateModule(args)
86 |
87 | def extract_features(self, images):
88 |         """ run feature extraction networks """
89 | fmaps = self.fnet(images)
90 | net = self.cnet(images)
91 |
92 | net, inp = net.split([128,128], dim=2)
93 | net = torch.tanh(net)
94 | inp = torch.relu(inp)
95 | return fmaps, net, inp
96 |
97 | def forward(self, Gs, images, depths, intrinsics, graph=None, num_steps=12):
98 |         """ Estimate SE3 poses over the frame graph """
99 |
100 | u = keyframe_indicies(graph)
101 | ii, jj, kk = graph_to_edge_list(graph)
102 |
103 | depths = depths[:, :, 3::8, 3::8]
104 | intrinsics = intrinsics / 8
105 | mask = (depths > 0.1).float()
106 | disps = torch.where(depths>0.1, 1.0/depths, depths)
107 |
108 | fmaps, net, inp = self.extract_features(images)
109 | net, inp = net[:,ii], inp[:,ii]
110 | corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
111 |
112 | coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
113 | residual = torch.zeros_like(coords[...,:2])
114 |
115 | Gs_list, coords_list, residual_list = [], [], []
116 | for step in range(num_steps):
117 | Gs = Gs.detach()
118 | coords = coords.detach()
119 | residual = residual.detach()
120 |
121 | corr = corr_fn(coords[...,:2])
122 | flow = residual.permute(0,1,4,2,3).clamp(-32.0, 32.0)
123 |
124 | corr = torch.cat([corr, mask[:,ii,None]], dim=2)
125 | flow = torch.cat([flow, mask[:,ii,None]], dim=2)
126 | net, delta, weight = self.update(net, inp, corr, flow)
127 |
128 | target = coords + delta
129 | weight[...,2] = 0.0
130 |
131 | for i in range(3):
132 | Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)
133 |
134 | coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
135 | residual = (target - coords)[...,:2]
136 |
137 | Gs_list.append(Gs)
138 | coords_list.append(target)
139 |
140 | valid_mask = valid_mask * mask[:,ii].unsqueeze(-1)
141 | residual_list.append(valid_mask * residual)
142 |
143 | return Gs_list, residual_list
144 |
--------------------------------------------------------------------------------
/examples/core/data_readers/base.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 | import torch.nn.functional as F
6 |
7 | import csv
8 | import os
9 | import cv2
10 | import math
11 | import random
12 | import json
13 | import pickle
14 | import os.path as osp
15 |
16 | from .augmentation import RGBDAugmentor
17 | from .rgbd_utils import *
18 |
19 | class RGBDDataset(data.Dataset):
20 | def __init__(self, root, name, n_frames=4, crop_size=[384,512], fmin=8.0, fmax=75.0, do_aug=True):
21 | """ Base class for RGBD dataset """
22 | self.aug = None
23 | self.root = root
24 | self.name = name
25 |
26 | self.n_frames = n_frames
27 | self.fmin = fmin # exclude very easy examples
28 | self.fmax = fmax # exclude very hard examples
29 |
30 | if do_aug:
31 | self.aug = RGBDAugmentor(crop_size=crop_size)
32 |
33 | # building dataset is expensive, cache so only needs to be performed once
34 | cur_path = osp.dirname(osp.abspath(__file__))
35 | cache_path = osp.join(cur_path, 'cache', '{}.pickle'.format(self.name))
36 |
37 | if not osp.isdir(osp.join(cur_path, 'cache')):
38 | os.mkdir(osp.join(cur_path, 'cache'))
39 |
40 | if osp.isfile(cache_path):
41 | scene_info = pickle.load(open(cache_path, 'rb'))[0]
42 | else:
43 | scene_info = self._build_dataset()
44 | with open(cache_path, 'wb') as cachefile:
45 | pickle.dump((scene_info,), cachefile)
46 |
47 | self.scene_info = scene_info
48 | self._build_dataset_index()
49 |
50 | def _build_dataset_index(self):
51 | self.dataset_index = []
52 | for scene in self.scene_info:
53 | if not self.__class__.is_test_scene(scene):
54 | graph = self.scene_info[scene]['graph']
55 | for i in graph:
56 | if len(graph[i][0]) > self.n_frames:
57 | self.dataset_index.append((scene, i))
58 |
59 | @staticmethod
60 | def image_read(image_file):
61 | return cv2.imread(image_file)
62 |
63 | @staticmethod
64 | def depth_read(depth_file):
65 | return np.load(depth_file)
66 |
67 | def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256):
68 | """ compute optical flow distance between all pairs of frames """
69 | def read_disp(fn):
70 | depth = self.__class__.depth_read(fn)[f//2::f, f//2::f]
71 | depth[depth < 0.01] = np.mean(depth)
72 | return 1.0 / depth
73 |
74 | poses = np.array(poses)
75 | intrinsics = np.array(intrinsics) / f
76 |
77 | disps = np.stack(list(map(read_disp, depths)), 0)
78 | d = f * compute_distance_matrix_flow(poses, disps, intrinsics)
79 |
80 | # uncomment for nice visualization
81 | # import matplotlib.pyplot as plt
82 | # plt.imshow(d)
83 | # plt.show()
84 |
85 | graph = {}
86 | for i in range(d.shape[0]):
87 | j, = np.where(d[i] < max_flow)
88 | graph[i] = (j, d[i,j])
89 |
90 | return graph
91 |
92 | def __getitem__(self, index):
93 | """ return training video """
94 |
95 | index = index % len(self.dataset_index)
96 | scene_id, ix = self.dataset_index[index]
97 |
98 | frame_graph = self.scene_info[scene_id]['graph']
99 | images_list = self.scene_info[scene_id]['images']
100 | depths_list = self.scene_info[scene_id]['depths']
101 | poses_list = self.scene_info[scene_id]['poses']
102 | intrinsics_list = self.scene_info[scene_id]['intrinsics']
103 |
104 | inds = [ ix ]
105 | while len(inds) < self.n_frames:
106 | # get other frames within flow threshold
107 | k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax)
108 | frames = frame_graph[ix][0][k]
109 |
110 | # prefer frames forward in time
111 | if np.count_nonzero(frames[frames > ix]):
112 | ix = np.random.choice(frames[frames > ix])
113 |
114 | elif np.count_nonzero(frames):
115 | ix = np.random.choice(frames)
116 |
117 | inds += [ ix ]
118 |
119 | images, depths, poses, intrinsics = [], [], [], []
120 | for i in inds:
121 | images.append(self.__class__.image_read(images_list[i]))
122 | depths.append(self.__class__.depth_read(depths_list[i]))
123 | poses.append(poses_list[i])
124 | intrinsics.append(intrinsics_list[i])
125 |
126 | images = np.stack(images).astype(np.float32)
127 | depths = np.stack(depths).astype(np.float32)
128 | poses = np.stack(poses).astype(np.float32)
129 | intrinsics = np.stack(intrinsics).astype(np.float32)
130 |
131 | images = torch.from_numpy(images).float()
132 | images = images.permute(0, 3, 1, 2)
133 |
134 | depths = torch.from_numpy(depths)
135 | poses = torch.from_numpy(poses)
136 | intrinsics = torch.from_numpy(intrinsics)
137 |
138 | if self.aug is not None:
139 | images, poses, depths, intrinsics = \
140 | self.aug(images, poses, depths, intrinsics)
141 |
142 | return images, poses, depths, intrinsics
143 |
144 | def __len__(self):
145 | return len(self.dataset_index)
146 |
147 | def __imul__(self, x):
148 | self.dataset_index *= x
149 | return self
150 |
--------------------------------------------------------------------------------
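A minimal sketch of what a concrete reader built on RGBDDataset above has to provide, modeled on the TartanAir reader later in this listing (the class name, root path, and scene name are placeholders):

from .base import RGBDDataset

class MyRGBD(RGBDDataset):
    TEST_SET = ['some_validation_scene']

    def __init__(self, **kwargs):
        super(MyRGBD, self).__init__(root='datasets/MyRGBD', name='MyRGBD', **kwargs)

    @staticmethod
    def is_test_scene(scene):
        return scene in MyRGBD.TEST_SET

    def _build_dataset(self):
        # must return {scene: {'images': [...], 'depths': [...], 'poses': [...],
        #                      'intrinsics': [...], 'graph': self.build_frame_graph(...)}}
        raise NotImplementedError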
/lietorch/extras/se3_solver.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 |
5 | #define NUM_THREADS 64
6 | #define EPS 1e-8
7 |
8 |
9 | template <int N>
10 | __device__ __forceinline__ void llt(const float A[N][N], float L[N][N])
11 | {
12 |
13 | for (int i=0; i<N; i++) {
14 | for (int j=0; j<=i; j++) {
15 |
16 | // dot product of the previously computed entries of rows i and j
17 | float s = 0.0;
18 | for (int k=0; k<j; k++)
19 | s += L[i][k] * L[j][k];
20 |
21 | // diagonal entry: guard against s exceeding the diagonal of A
22 | if (i == j) {
23 | s = s > A[i][i] ? A[i][i] + EPS : s;
28 | L[i][j] = sqrtf(A[i][i]-s);
29 | }
30 |
31 | else
32 | L[i][j] = (A[i][j] - s) / L[j][j];
33 | }
34 | }
35 | }
36 |
37 | template <int N>
38 | __device__ __forceinline__ void llt_solve(const float L[N][N], float x[N])
39 | {
40 | float s;
41 | for (int i=0; i<N; i++) {
42 | s = 0.0;
43 | for (int j=0; j<i; j++)
44 | s += L[i][j] * x[j];
45 | x[i] = (x[i] - s) / L[i][i];
46 | }
47 |
48 | // back-substitution with the transpose of L
49 | for (int i=N-1; i>=0; i--) {
50 | s = 0.0;
51 | for (int j=i+1; j<N; j++)
52 | s += L[j][i] * x[j];
53 | x[i] = (x[i] - s) / L[i][i];
54 | }
55 | }
56 |
57 |
58 |
59 | __global__ void cholesky_solve6x6_forward_kernel(
60 | const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> H_tensor,
61 | const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> b_tensor,
62 | torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> x_tensor) {
63 |
64 | /*Inputs: H [batch,6,6,ht,wd], b [batch,6,1,ht,wd]
65 | Outputs: x [batch,6,1,ht,wd]; Hx = b
66 | */
67 |
68 | int batch_id = blockIdx.x;
69 | const int dim = H_tensor.size(3) * H_tensor.size(4);
70 | int m = blockIdx.y * NUM_THREADS + threadIdx.x;
71 |
72 | const float* H_ptr = H_tensor[batch_id].data();
73 | const float* b_ptr = b_tensor[batch_id].data();
74 | float* x_ptr = x_tensor[batch_id].data();
75 |
76 | if (m < dim) {
77 | float H[6][6], L[6][6], x[6];
78 |
79 | for (int i=0; i<6; i++) {
80 | for (int j=0; j<6; j++) {
81 | H[i][j] = H_ptr[m + (6*i+j)*dim];
82 | }
83 | }
84 |
85 | for (int i=0; i<6; i++) {
86 | x[i] = b_ptr[m + i*dim];
87 | }
88 |
89 | llt<6>(H, L);
90 | llt_solve<6>(L, x);
91 |
92 | for (int i=0; i<6; i++) {
93 | x_ptr[m + i*dim] = x[i];
94 | }
95 | }
96 | }
97 |
98 |
99 | __global__ void cholesky_solve6x6_backward_kernel(
100 | const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> H_tensor,
101 | const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> b_tensor,
102 | const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> dx_tensor,
103 | torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> dH_tensor,
104 | torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> db_tensor) {
105 |
106 |
107 | int batch_id = blockIdx.x;
108 | const int dim = H_tensor.size(3) * H_tensor.size(4);
109 | int m = blockIdx.y * NUM_THREADS + threadIdx.x;
110 |
111 | const float* H_ptr = H_tensor[batch_id].data();
112 | const float* b_ptr = b_tensor[batch_id].data();
113 |
114 | const float* dx_ptr = dx_tensor[batch_id].data();
115 | float* dH_ptr = dH_tensor[batch_id].data();
116 | float* db_ptr = db_tensor[batch_id].data();
117 |
118 | if (m < dim) {
119 | float H[6][6], L[6][6], x[6], dz[6];
120 |
121 | for (int i=0; i<6; i++) {
122 | for (int j=0; j<6; j++) {
123 | H[i][j] = H_ptr[m + (6*i+j)*dim];
124 | }
125 | }
126 |
127 | for (int i=0; i<6; i++) {
128 | x[i] = b_ptr[m + i*dim];
129 | }
130 |
131 | for (int i=0; i<6; i++) {
132 | dz[i] = dx_ptr[m + i*dim];
133 | }
134 |
135 | // cholesky factorization
136 | llt<6>(H, L);
137 |
138 | llt_solve<6>(L, x);
139 | llt_solve<6>(L, dz);
140 |
141 | for (int i=0; i<6; i++) {
142 | for (int j=0; j<6; j++) {
143 | dH_ptr[m + (6*i+j)*dim] = -dz[i] * x[j];
144 | }
145 | }
146 |
147 | for (int i=0; i<6; i++) {
148 | db_ptr[m + i*dim] = dz[i];
149 | }
150 | }
151 | }
152 |
153 |
154 | std::vector<torch::Tensor> cholesky_solve6x6_forward_cuda(torch::Tensor H, torch::Tensor b) {
155 |
156 | const int batch_size = H.size(0);
157 | const int ht = H.size(3);
158 | const int wd = H.size(4);
159 |
160 | torch::Tensor x = torch::zeros_like(b);
161 | dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
162 |
163 | cholesky_solve6x6_forward_kernel<<<grid, NUM_THREADS>>>(
164 | H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
165 | b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
166 | x.packed_accessor32<float,5,torch::RestrictPtrTraits>());
167 |
168 | return {x};
169 | }
170 |
171 |
172 | std::vector<torch::Tensor> cholesky_solve6x6_backward_cuda(torch::Tensor H, torch::Tensor b, torch::Tensor dx) {
173 | const int batch_size = H.size(0);
174 | const int ht = H.size(3);
175 | const int wd = H.size(4);
176 |
177 | torch::Tensor dH = torch::zeros_like(H);
178 | torch::Tensor db = torch::zeros_like(b);
179 | dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
180 |
181 | cholesky_solve6x6_backward_kernel<<<grid, NUM_THREADS>>>(
182 | H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
183 | b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
184 | dx.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
185 | dH.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
186 | db.packed_accessor32<float,5,torch::RestrictPtrTraits>());
187 |
188 | return {dH, db};
189 | }
--------------------------------------------------------------------------------
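The two kernels above solve batched per-pixel 6x6 systems H x = b by Cholesky factorization, and backpropagate through the solve. A rough pure-PyTorch reference for the forward semantics, handy for sanity-checking shapes (the sizes and the way the SPD test matrices are built here are arbitrary):

import torch

B, ht, wd = 2, 30, 40
H = torch.randn(B, 6, 6, ht, wd)
H = torch.einsum('bikhw,bjkhw->bijhw', H, H) + 1e-3 * torch.eye(6).view(1, 6, 6, 1, 1)  # make SPD
b = torch.randn(B, 6, 1, ht, wd)

Hm = H.permute(0, 3, 4, 1, 2).reshape(-1, 6, 6)        # [B*ht*wd, 6, 6]
bm = b.permute(0, 3, 4, 1, 2).reshape(-1, 6, 1)        # [B*ht*wd, 6, 1]
x = torch.cholesky_solve(bm, torch.linalg.cholesky(Hm))
x = x.reshape(B, ht, wd, 6, 1).permute(0, 3, 4, 1, 2)  # back to [B, 6, 1, ht, wd]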
/examples/rgbdslam/viz.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import torch
4 | import scipy
5 | import numpy as np
6 | import open3d as o3d
7 |
8 | from queue import Empty
9 | from multiprocessing import Queue, Process
10 | from scipy.spatial.transform import Rotation
11 |
12 | def pose_matrix_from_quaternion(pvec):
13 | """ convert 4x4 pose matrix to (t, q) """
14 | pose = np.eye(4)
15 | pose[:3,:3] = Rotation.from_quat(pvec[3:]).as_matrix()
16 | pose[:3, 3] = pvec[:3]
17 | return pose
18 |
19 | def create_camera_actor(is_gt=False, scale=0.05):
20 | """ build open3d camera polydata """
21 |
22 | cam_points = scale * np.array([
23 | [ 0, 0, 0],
24 | [-1, -1, 1.5],
25 | [ 1, -1, 1.5],
26 | [ 1, 1, 1.5],
27 | [-1, 1, 1.5],
28 | [-0.5, 1, 1.5],
29 | [ 0.5, 1, 1.5],
30 | [ 0, 1.2, 1.5]])
31 |
32 | cam_lines = np.array([[1, 2], [2, 3], [3, 4], [4, 1],
33 | [1, 0], [0, 2], [3, 0], [0, 4], [5, 7], [7, 6]])
34 |
35 | camera_actor = o3d.geometry.LineSet(
36 | points=o3d.utility.Vector3dVector(cam_points),
37 | lines=o3d.utility.Vector2iVector(cam_lines))
38 |
39 | color = (0.0, 0.0, 0.0) if is_gt else (0.0, 0.8, 0.8)
40 | camera_actor.paint_uniform_color(color)
41 |
42 | return camera_actor
43 |
44 | def create_point_cloud_actor(points, colors):
45 | """ open3d point cloud from numpy array """
46 |
47 | point_cloud = o3d.geometry.PointCloud()
48 | point_cloud.points = o3d.utility.Vector3dVector(points)
49 | point_cloud.colors = o3d.utility.Vector3dVector(colors)
50 |
51 | return point_cloud
52 |
53 | def draw_trajectory(queue):
54 |
55 | draw_trajectory.queue = queue
56 | draw_trajectory.cameras = {}
57 | draw_trajectory.points = {}
58 | draw_trajectory.ix = 0
59 | draw_trajectory.warmup = 8
60 |
61 | def animation_callback(vis):
62 | cam = vis.get_view_control().convert_to_pinhole_camera_parameters()
63 | while True:
64 | try:
65 | data = draw_trajectory.queue.get_nowait()
66 | if data[0] == 'pose':
67 | i, pose, is_gt = data[1:]
68 |
69 | # convert to 4x4 matrix
70 | pose = pose_matrix_from_quaternion(pose)
71 |
72 | if i in draw_trajectory.cameras:
73 | cam_actor, pose_prev = draw_trajectory.cameras[i]
74 | pose_change = pose @ np.linalg.inv(pose_prev)
75 |
76 | cam_actor.transform(pose_change)
77 | vis.update_geometry(cam_actor)
78 |
79 | if i in draw_trajectory.points:
80 | pc = draw_trajectory.points[i]
81 | pc.transform(pose_change)
82 | vis.update_geometry(pc)
83 |
84 | else:
85 | cam_actor = create_camera_actor(is_gt)
86 | cam_actor.transform(pose)
87 | vis.add_geometry(cam_actor)
88 |
89 | if not is_gt:
90 | draw_trajectory.cameras[i] = (cam_actor, pose)
91 |
92 | elif data[0] == 'points':
93 | i, points, colors = data[1:]
94 | point_actor = create_point_cloud_actor(points, colors)
95 |
96 | pose = draw_trajectory.cameras[i][1]
97 | point_actor.transform(pose)
98 | vis.add_geometry(point_actor)
99 |
100 | draw_trajectory.points[i] = point_actor
101 |
102 | elif data[0] == 'reset':
103 | draw_trajectory.warmup = -1
104 |
105 | for i in draw_trajectory.points:
106 | vis.remove_geometry(draw_trajectory.points[i])
107 |
108 | for i in draw_trajectory.cameras:
109 | vis.remove_geometry(draw_trajectory.cameras[i][0])
110 |
111 | draw_trajectory.cameras = {}
112 | draw_trajectory.points = {}
113 |
114 | except Empty:
115 | break
116 |
117 | # hack to allow interacting with visualization during inference
118 | if len(draw_trajectory.cameras) >= draw_trajectory.warmup:
119 | cam = vis.get_view_control().convert_from_pinhole_camera_parameters(cam)
120 |
121 | vis.poll_events()
122 | vis.update_renderer()
123 |
124 | vis = o3d.visualization.Visualizer()
125 |
126 | vis.register_animation_callback(animation_callback)
127 | vis.create_window(height=540, width=960)
128 | vis.get_render_option().load_from_json("assets/renderoption.json")
129 |
130 | vis.run()
131 | vis.destroy_window()
132 |
133 |
134 | class SLAMFrontend:
135 | def __init__(self):
136 | self.queue = Queue()
137 | self.p = Process(target=draw_trajectory, args=(self.queue, ))
138 |
139 | def update_pose(self, index, pose, gt=False):
140 | if isinstance(pose, torch.Tensor):
141 | pose = pose.cpu().numpy()
142 | self.queue.put_nowait(('pose', index, pose, gt))
143 |
144 | def update_points(self, index, points, colors):
145 | if isinstance(points, torch.Tensor):
146 | points = points.cpu().numpy()
147 | self.queue.put_nowait(('points', index, points, colors))
148 |
149 | def reset(self):
150 | self.queue.put_nowait(('reset', ))
151 |
152 | def start(self):
153 | self.p.start()
154 | return self
155 |
156 | def join(self):
157 | self.p.join()
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
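A minimal sketch of driving the viewer above from another process (poses are (tx, ty, tz, qx, qy, qz, qw) vectors, matching pose_matrix_from_quaternion; run from examples/rgbdslam so assets/renderoption.json is found):

import numpy as np
from viz import SLAMFrontend

frontend = SLAMFrontend().start()

pose = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])   # identity pose
frontend.update_pose(0, pose, gt=False)

points = np.random.rand(1000, 3).astype(np.float32)    # illustrative point cloud
colors = np.random.rand(1000, 3).astype(np.float32)
frontend.update_points(0, points, colors)

frontend.join()   # blocks until the viewer window is closed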
/examples/rgbdslam/rgbd_benchmark/associate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Software License Agreement (BSD License)
3 | #
4 | # Copyright (c) 2013, Juergen Sturm, TUM
5 | # All rights reserved.
6 | #
7 | # Redistribution and use in source and binary forms, with or without
8 | # modification, are permitted provided that the following conditions
9 | # are met:
10 | #
11 | # * Redistributions of source code must retain the above copyright
12 | # notice, this list of conditions and the following disclaimer.
13 | # * Redistributions in binary form must reproduce the above
14 | # copyright notice, this list of conditions and the following
15 | # disclaimer in the documentation and/or other materials provided
16 | # with the distribution.
17 | # * Neither the name of TUM nor the names of its
18 | # contributors may be used to endorse or promote products derived
19 | # from this software without specific prior written permission.
20 | #
21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | # POSSIBILITY OF SUCH DAMAGE.
33 | #
34 | # Requirements:
35 | # sudo apt-get install python-argparse
36 |
37 | """
38 | The Kinect provides the color and depth images in an un-synchronized way. This means that the set of time stamps from the color images does not intersect with those of the depth images. Therefore, we need some way of associating color images with depth images.
39 |
40 | For this purpose, you can use the ''associate.py'' script. It reads the time stamps from the rgb.txt file and the depth.txt file, and joins them by finding the best matches.
41 | """
42 |
43 | import argparse
44 | import sys
45 | import os
46 | import numpy
47 |
48 |
49 | def read_file_list(filename):
50 | """
51 | Reads a trajectory from a text file.
52 |
53 | File format:
54 | The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched)
55 | and "d1 d2 d3.." is arbitary data (e.g., a 3D position and 3D orientation) associated to this timestamp.
56 |
57 | Input:
58 | filename -- File name
59 |
60 | Output:
61 | dict -- dictionary of (stamp,data) tuples
62 |
63 | """
64 | file = open(filename)
65 | data = file.read()
66 | lines = data.replace(","," ").replace("\t"," ").split("\n")
67 | list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"]
68 | list = [(float(l[0]),l[1:]) for l in list if len(l)>1]
69 | return dict(list)
70 |
71 | def associate(first_list, second_list,offset=0.0,max_difference=0.02):
72 | """
73 | Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim
74 | to find the closest match for every input tuple.
75 |
76 | Input:
77 | first_list -- first dictionary of (stamp,data) tuples
78 | second_list -- second dictionary of (stamp,data) tuples
79 | offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
80 | max_difference -- search radius for candidate generation
81 |
82 | Output:
83 | matches -- list of matched tuples ((stamp1,data1),(stamp2,data2))
84 |
85 | """
86 | first_keys = list(first_list.keys())
87 | second_keys = list(second_list.keys())
88 | potential_matches = [(abs(a - (b + offset)), a, b)
89 | for a in first_keys
90 | for b in second_keys
91 | if abs(a - (b + offset)) < max_difference]
92 | potential_matches.sort()
93 | matches = []
94 | for diff, a, b in potential_matches:
95 | if a in first_keys and b in second_keys:
96 | first_keys.remove(a)
97 | second_keys.remove(b)
98 | matches.append((a, b))
99 |
100 | matches.sort()
101 | return matches
102 |
103 | if __name__ == '__main__':
104 |
105 | # parse command line
106 | parser = argparse.ArgumentParser(description='''
107 | This script takes two data files with timestamps and associates them
108 | ''')
109 | parser.add_argument('first_file', help='first text file (format: timestamp data)')
110 | parser.add_argument('second_file', help='second text file (format: timestamp data)')
111 | parser.add_argument('--first_only', help='only output associated lines from first file', action='store_true')
112 | parser.add_argument('--offset', help='time offset added to the timestamps of the second file (default: 0.0)',default=0.0)
113 | parser.add_argument('--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)',default=0.02)
114 | args = parser.parse_args()
115 |
116 | first_list = read_file_list(args.first_file)
117 | second_list = read_file_list(args.second_file)
118 |
119 | matches = associate(first_list, second_list,float(args.offset),float(args.max_difference))
120 |
121 | if args.first_only:
122 | for a,b in matches:
123 | print("%f %s"%(a," ".join(first_list[a])))
124 | else:
125 | for a,b in matches:
126 | print("%f %s %f %s"%(a," ".join(first_list[a]),b-float(args.offset)," ".join(second_list[b])))
127 |
128 |
129 |
--------------------------------------------------------------------------------
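Example use of the association utilities above on a TUM-format sequence (the sequence path is a placeholder):

from associate import read_file_list, associate

rgb_list = read_file_list('rgbd_dataset_freiburg1_desk/rgb.txt')
depth_list = read_file_list('rgbd_dataset_freiburg1_desk/depth.txt')

matches = associate(rgb_list, depth_list, offset=0.0, max_difference=0.02)
for t_rgb, t_depth in matches[:5]:
    print(t_rgb, rgb_list[t_rgb][0], t_depth, depth_list[t_depth][0])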
/examples/core/data_readers/tartan.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import glob
5 | import cv2
6 | import os
7 | import os.path as osp
8 |
9 | from .base import RGBDDataset
10 | from .stream import RGBDStream
11 |
12 | class TartanAir(RGBDDataset):
13 |
14 | # scale depths to balance rot & trans
15 | DEPTH_SCALE = 5.0
16 | TEST_SET = ['westerndesert', 'seasidetown', 'seasonsforest_winter', 'office2', 'gascola']
17 |
18 | def __init__(self, mode='training', **kwargs):
19 | self.mode = mode
20 | self.n_frames = 2
21 | super(TartanAir, self).__init__(root='datasets/TartanAir', name='TartanAir', **kwargs)
22 |
23 | @staticmethod
24 | def is_test_scene(scene):
25 | return scene.split('/')[-3] in TartanAir.TEST_SET
26 |
27 | def _build_dataset(self):
28 | from tqdm import tqdm
29 | print("Building TartanAir dataset")
30 |
31 | scene_info = {}
32 | scenes = glob.glob(osp.join(self.root, '*/*/*/*'))
33 | for scene in tqdm(sorted(scenes)):
34 | images = sorted(glob.glob(osp.join(scene, 'image_left/*.png')))
35 | depths = sorted(glob.glob(osp.join(scene, 'depth_left/*.npy')))
36 |
37 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ')
38 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]]
39 | poses[:,:3] /= TartanAir.DEPTH_SCALE
40 | intrinsics = [TartanAir.calib_read()] * len(images)
41 |
42 | # graph of co-visible frames based on flow
43 | graph = self.build_frame_graph(poses, depths, intrinsics)
44 |
45 | scene = '/'.join(scene.split('/'))
46 | scene_info[scene] = {'images': images, 'depths': depths,
47 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph}
48 |
49 | return scene_info
50 |
51 | @staticmethod
52 | def calib_read():
53 | return np.array([320.0, 320.0, 320.0, 240.0])
54 |
55 | @staticmethod
56 | def image_read(image_file):
57 | return cv2.imread(image_file)
58 |
59 | @staticmethod
60 | def depth_read(depth_file):
61 | depth = np.load(depth_file) / TartanAir.DEPTH_SCALE
62 | depth[np.isnan(depth)] = 1.0
63 | depth[np.isinf(depth)] = 1.0
64 | return depth
65 |
66 |
67 | class TartanAirTest(torch.utils.data.Dataset):
68 | def __init__(self, root='datasets/Tartan'):
69 | self.root = root
70 | self.dataset_index = []
71 |
72 | self.scene_info = {}
73 | scenes = glob.glob(osp.join(self.root, '*/*/*/*'))
74 |
75 | for scene in sorted(scenes):
76 | image_glob = osp.join(scene, 'image_left/*.png')
77 | depth_glob = osp.join(scene, 'depth_left/*.npy')
78 | images = sorted(glob.glob(image_glob))
79 | depths = sorted(glob.glob(depth_glob))
80 |
81 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ')
82 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]]
83 | poses[:,:3] /= TartanAir.DEPTH_SCALE
84 | intrinsics = [TartanAir.calib_read()] * len(images)
85 |
86 | self.scene_info[scene] = {'images': images,
87 | 'depths': depths, 'poses': poses, 'intrinsics': intrinsics}
88 |
89 | with open('assets/tartan_test.txt') as f:
90 | self.dataset_index = f.readlines()
91 |
92 | def __getitem__(self, index):
93 | """ load test example """
94 |
95 | scene_id, ix1, ix2 = self.dataset_index[index].split()
96 | inds = [int(ix1), int(ix2)]
97 |
98 | images_list = self.scene_info[scene_id]['images']
99 | depths_list = self.scene_info[scene_id]['depths']
100 | poses_list = self.scene_info[scene_id]['poses']
101 | intrinsics_list = self.scene_info[scene_id]['intrinsics']
102 |
103 | images, depths, poses, intrinsics = [], [], [], []
104 | for i in inds:
105 | images.append(TartanAir.image_read(images_list[i]))
106 | depths.append(TartanAir.depth_read(depths_list[i]))
107 | poses.append(poses_list[i])
108 | intrinsics.append(intrinsics_list[i])
109 |
110 | images = np.stack(images).astype(np.float32)
111 | depths = np.stack(depths).astype(np.float32)
112 | poses = np.stack(poses).astype(np.float32)
113 | intrinsics = np.stack(intrinsics).astype(np.float32)
114 |
115 | images = torch.from_numpy(images).float()
116 | images = images.permute(0, 3, 1, 2)
117 |
118 | depths = torch.from_numpy(depths)
119 | poses = torch.from_numpy(poses)
120 | intrinsics = torch.from_numpy(intrinsics)
121 |
122 | return images, poses, depths, intrinsics
123 |
124 | def __len__(self):
125 | return len(self.dataset_index)
126 |
127 | class TartanAirStream(RGBDStream):
128 | def __init__(self, datapath, **kwargs):
129 | super(TartanAirStream, self).__init__(datapath=datapath, **kwargs)
130 |
131 | def _build_dataset_index(self):
132 | """ build list of images, poses, depths, and intrinsics """
133 | images, poses, depths, intrinsics = loadtum(self.datapath)
134 | intrinsic = TartanAirStream.calib_read(self.datapath)
135 | intrinsics = np.tile(intrinsic[None], (len(images), 1))
136 |
137 | self.images = images
138 | self.poses = poses
139 | self.depths = depths
140 | self.intrinsics = intrinsics
141 |
142 | @staticmethod
143 | def calib_read(datapath):
144 | return np.array([320.0, 320.0, 320.0, 240.0])
145 |
146 | @staticmethod
147 | def image_read(image_file):
148 | return cv2.imread(image_file)
149 |
150 | @staticmethod
151 | def depth_read(depth_file):
152 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
153 | return depth.astype(np.float32) / 5000.0
154 |
--------------------------------------------------------------------------------
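Illustrative only: the reader above plugs directly into a PyTorch DataLoader, as the training script in examples/registration/main.py does:

from torch.utils.data import DataLoader
from data_readers.tartan import TartanAir

db = TartanAir(mode='training', n_frames=2, do_aug=True, fmin=8.0, fmax=100.0)
loader = DataLoader(db, batch_size=4, shuffle=True, num_workers=4)

images, poses, depths, intrinsics = next(iter(loader))
print(images.shape)   # [batch, n_frames, 3, H, W] after augmentation/cropping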
/lietorch/extras/corr_index_kernel.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 |
8 |
9 | #include
10 | #include
11 | #include
12 |
13 | #define BLOCK 16
14 |
15 | __forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
16 | return h >= 0 && h < H && w >= 0 && w < W;
17 | }
18 |
19 | template <typename scalar_t>
20 | __global__ void corr_index_forward_kernel(
21 | const torch::PackedTensorAccessor32 volume,
22 | const torch::PackedTensorAccessor32 coords,
23 | torch::PackedTensorAccessor32 corr,
24 | int r)
25 | {
26 | // batch index
27 | const int x = blockIdx.x * blockDim.x + threadIdx.x;
28 | const int y = blockIdx.y * blockDim.y + threadIdx.y;
29 | const int n = blockIdx.z;
30 |
31 | const int h1 = volume.size(1);
32 | const int w1 = volume.size(2);
33 | const int h2 = volume.size(3);
34 | const int w2 = volume.size(4);
35 |
36 | if (!within_bounds(y, x, h1, w1)) {
37 | return;
38 | }
39 |
40 | float x0 = coords[n][0][y][x];
41 | float y0 = coords[n][1][y][x];
42 |
43 | float dx = x0 - floor(x0);
44 | float dy = y0 - floor(y0);
45 |
46 | int rd = 2*r + 1;
47 | for (int i=0; i<rd+1; i++) {
48 | for (int j=0; j<rd+1; j++) {
49 | int x1 = static_cast<int>(floor(x0)) - r + i;
50 | int y1 = static_cast<int>(floor(y0)) - r + j;
51 |
52 | if (within_bounds(y1, x1, h2, w2)) {
53 | scalar_t s = volume[n][y][x][y1][x1];
54 |
55 | if (i > 0 && j > 0)
56 | corr[n][i-1][j-1][y][x] += s * scalar_t(dx * dy);
57 |
58 | if (i > 0 && j < rd)
59 | corr[n][i-1][j][y][x] += s * scalar_t(dx * (1.0f-dy));
60 |
61 | if (i < rd && j > 0)
62 | corr[n][i][j-1][y][x] += s * scalar_t((1.0f-dx) * dy);
63 |
64 | if (i < rd && j < rd)
65 | corr[n][i][j][y][x] += s * scalar_t((1.0f-dx) * (1.0f-dy));
66 |
67 | }
68 | }
69 | }
70 | }
71 |
72 |
73 | template <typename scalar_t>
74 | __global__ void corr_index_backward_kernel(
75 | const torch::PackedTensorAccessor32 coords,
76 | const torch::PackedTensorAccessor32 corr_grad,
77 | torch::PackedTensorAccessor32 volume_grad,
78 | int r)
79 | {
80 | // batch index
81 | const int x = blockIdx.x * blockDim.x + threadIdx.x;
82 | const int y = blockIdx.y * blockDim.y + threadIdx.y;
83 | const int n = blockIdx.z;
84 |
85 | const int h1 = volume_grad.size(1);
86 | const int w1 = volume_grad.size(2);
87 | const int h2 = volume_grad.size(3);
88 | const int w2 = volume_grad.size(4);
89 |
90 | if (!within_bounds(y, x, h1, w1)) {
91 | return;
92 | }
93 |
94 | float x0 = coords[n][0][y][x];
95 | float y0 = coords[n][1][y][x];
96 |
97 | float dx = x0 - floor(x0);
98 | float dy = y0 - floor(y0);
99 |
100 | int rd = 2*r + 1;
101 | for (int i=0; i<rd+1; i++) {
102 | for (int j=0; j<rd+1; j++) {
103 | int x1 = static_cast<int>(floor(x0)) - r + i;
104 | int y1 = static_cast<int>(floor(y0)) - r + j;
105 |
106 | if (within_bounds(y1, x1, h2, w2)) {
107 | scalar_t g = 0.0;
108 | if (i > 0 && j > 0)
109 | g += corr_grad[n][i-1][j-1][y][x] * scalar_t(dx * dy);
110 |
111 | if (i > 0 && j < rd)
112 | g += corr_grad[n][i-1][j][y][x] * scalar_t(dx * (1.0f-dy));
113 |
114 | if (i < rd && j > 0)
115 | g += corr_grad[n][i][j-1][y][x] * scalar_t((1.0f-dx) * dy);
116 |
117 | if (i < rd && j < rd)
118 | g += corr_grad[n][i][j][y][x] * scalar_t((1.0f-dx) * (1.0f-dy));
119 |
120 | volume_grad[n][y][x][y1][x1] += g;
121 | }
122 | }
123 | }
124 | }
125 |
126 | std::vector corr_index_cuda_forward(
127 | torch::Tensor volume,
128 | torch::Tensor coords,
129 | int radius)
130 | {
131 | const auto batch_size = volume.size(0);
132 | const auto ht = volume.size(1);
133 | const auto wd = volume.size(2);
134 |
135 | const dim3 blocks((wd + BLOCK - 1) / BLOCK,
136 | (ht + BLOCK - 1) / BLOCK,
137 | batch_size);
138 |
139 | const dim3 threads(BLOCK, BLOCK);
140 |
141 | auto opts = volume.options();
142 | torch::Tensor corr = torch::zeros(
143 | {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);
144 |
145 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
146 | corr_index_forward_kernel<<>>(
147 | volume.packed_accessor32(),
148 | coords.packed_accessor32(),
149 | corr.packed_accessor32(),
150 | radius);
151 | }));
152 |
153 | return {corr};
154 |
155 | }
156 |
157 | std::vector corr_index_cuda_backward(
158 | torch::Tensor volume,
159 | torch::Tensor coords,
160 | torch::Tensor corr_grad,
161 | int radius)
162 | {
163 | const auto batch_size = volume.size(0);
164 | const auto ht = volume.size(1);
165 | const auto wd = volume.size(2);
166 |
167 | auto volume_grad = torch::zeros_like(volume);
168 |
169 | const dim3 blocks((wd + BLOCK - 1) / BLOCK,
170 | (ht + BLOCK - 1) / BLOCK,
171 | batch_size);
172 |
173 | const dim3 threads(BLOCK, BLOCK);
174 |
175 |
176 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
177 | corr_index_backward_kernel<<>>(
178 | coords.packed_accessor32(),
179 | corr_grad.packed_accessor32(),
180 | volume_grad.packed_accessor32(),
181 | radius);
182 | }));
183 |
184 | return {volume_grad};
185 | }
186 |
187 |
188 |
189 |
190 |
191 |
192 |
--------------------------------------------------------------------------------
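A rough PyTorch reference for the lookup implemented by the forward kernel above: for every pixel (y, x) of the first feature map, bilinearly sample the correlation volume around the given coords at all offsets within radius r. This is only a semantic sketch (border handling differs slightly from the CUDA scatter), and the function name is made up here:

import torch
import torch.nn.functional as F

def corr_lookup_reference(volume, coords, r):
    # volume: [N, h1, w1, h2, w2], coords: [N, 2, h1, w1] in (x, y) order
    N, h1, w1, h2, w2 = volume.shape
    vol = volume.view(N * h1 * w1, 1, h2, w2)
    cx, cy = coords[:, 0], coords[:, 1]

    out = torch.zeros(N, 2 * r + 1, 2 * r + 1, h1, w1, dtype=volume.dtype)
    for a in range(2 * r + 1):        # offset in x
        for b in range(2 * r + 1):    # offset in y
            gx = 2.0 * (cx + (a - r)) / (w2 - 1) - 1.0   # normalize for grid_sample
            gy = 2.0 * (cy + (b - r)) / (h2 - 1) - 1.0
            grid = torch.stack([gx, gy], dim=-1).view(N * h1 * w1, 1, 1, 2)
            sampled = F.grid_sample(vol, grid, align_corners=True)
            out[:, a, b] = sampled.view(N, h1, w1)
    return out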
/examples/registration/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../core')
3 |
4 | import argparse
5 | import torch
6 | import cv2
7 | import numpy as np
8 | from collections import OrderedDict
9 |
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader
13 | from data_readers.tartan import TartanAir, TartanAirTest
14 |
15 | from lietorch import SO3, SE3, Sim3
16 | from geom.losses import *
17 |
18 | # network
19 | from networks.sim3_net import Sim3Net
20 | from logger import Logger
21 |
22 |
23 | def show_image(image):
24 | image = image.permute(1, 2, 0).cpu().numpy()
25 | cv2.imshow('image', image / 255.0)
26 | cv2.waitKey()
27 |
28 | def normalize_images(images):
29 | images = images[:, :, [2,1,0]]
30 | mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
31 | std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
32 | return (images/255.0).sub_(mean[:, None, None]).div_(std[:, None, None])
33 |
34 | @torch.no_grad()
35 | def evaluate(model):
36 | """ evaluate trained model """
37 |
38 | model.cuda()
39 | model.eval()
40 |
41 | R_THRESHOLD = 0.1
42 | T_THRESHOLD = 0.01
43 | S_THRESHOLD = 0.01
44 |
45 | model.eval()
46 | db = TartanAirTest()
47 | test_loader = DataLoader(db, batch_size=1, shuffle=False, num_workers=4)
48 |
49 | # random scales, make sure they are the same every time
50 | from numpy.random import default_rng
51 | rng = default_rng(1234)
52 | scales = 2 ** rng.uniform(-1.0, 1.0, 2000)
53 | scales = scales.astype(np.float32)
54 |
55 | metrics = {'t': [], 'r': [], 's': []}
56 | for i_batch, item in enumerate(test_loader):
57 | images, poses, depths, intrinsics = [x.to('cuda') for x in item]
58 |
59 | # convert poses w2c -> c2w
60 | Ps = SE3(poses).inv()
61 | batch, num = images.shape[:2]
62 |
63 | if args.transformation == 'SE3':
64 | Gs = SE3.Identity(Ps.shape, device='cuda')
65 |
66 | elif args.transformation == 'Sim3':
67 | Ps = Sim3(Ps)
68 | Gs = Sim3.Identity(Ps.shape, device='cuda')
69 |
70 | s = torch.as_tensor(scales[i_batch]).cuda().unsqueeze(0)
71 | phi = torch.zeros(batch, num, 7, device='cuda')
72 | phi[:,0,6] = s.log()
73 |
74 | Ps = Sim3.exp(phi) * Ps
75 | depths[:,0] *= s[:,None,None]
76 |
77 | images = normalize_images(images)
78 | Gs, _ = model(Gs, images, depths, intrinsics, num_steps=16)
79 |
80 | Gs = Gs[-1]
81 | dP = Ps[:,1] * Ps[:,0].inv()
82 | dG = Gs[:,1] * Gs[:,0].inv()
83 |
84 | dE = Sim3(dP.inv() * dG)
85 | r_err, t_err, s_err = pose_metrics(dE)
86 |
87 | t_err = t_err * TartanAir.DEPTH_SCALE
88 |
89 | metrics['t'].append(t_err.item())
90 | metrics['r'].append(r_err.item())
91 | metrics['s'].append(s_err.item())
92 |
93 | rlist = np.array(metrics['r'])
94 | tlist = np.array(metrics['t'])
95 | slist = np.array(metrics['s'])
96 |
97 | r_all = np.count_nonzero(rlist < R_THRESHOLD) / len(metrics['r'])
98 | t_all = np.count_nonzero(tlist < T_THRESHOLD) / len(metrics['t'])
99 | s_all = np.count_nonzero(slist < S_THRESHOLD) / len(metrics['s'])
100 |
101 | print("Rotation Acc: ", r_all)
102 | print("Translation Acc: ", t_all)
103 | print("Scale Acc: ", s_all)
104 |
105 |
106 | def train(args):
107 | """ Test to make sure project transform correctly maps points """
108 |
109 | model = Sim3Net(args)
110 | model.cuda()
111 | model.train()
112 |
113 | if args.ckpt is not None:
114 | model.load_state_dict(torch.load(args.ckpt))
115 |
116 | db = TartanAir(mode='training', n_frames=2, do_aug=True, fmin=8.0, fmax=100.0)
117 | train_loader = DataLoader(db, batch_size=args.batch, shuffle=True, num_workers=4)
118 |
119 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
120 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
121 | args.lr, 100000, pct_start=0.01, cycle_momentum=False)
122 |
123 | from collections import OrderedDict
124 | graph = OrderedDict()
125 | graph[0] = [1]
126 | graph[1] = [0]
127 |
128 | logger = Logger(args.name, scheduler)
129 | should_keep_training = True
130 | total_steps = 0
131 |
132 | while should_keep_training:
133 | for i_batch, item in enumerate(train_loader):
134 | optimizer.zero_grad()
135 | images, poses, depths, intrinsics = [x.to('cuda') for x in item]
136 |
137 | # convert poses w2c -> c2w
138 | Ps = SE3(poses).inv()
139 | batch, num = images.shape[:2]
140 |
141 | if args.transformation == 'SE3':
142 | Gs = SE3.Identity(Ps.shape, device='cuda')
143 |
144 | elif args.transformation == 'Sim3':
145 | Ps = Sim3(Ps)
146 | Gs = Sim3.Identity(Ps.shape, device='cuda')
147 |
148 | s = 2**(2*torch.rand(batch) - 1.0).cuda()
149 | phi = torch.zeros(batch, num, 7, device='cuda')
150 | phi[:,0,6] = s.log()
151 |
152 | Ps = Sim3.exp(phi) * Ps
153 | depths[:,0] *= s[:,None,None]
154 |
155 | images = normalize_images(images)
156 | Gs, residuals = model(Gs, images, depths, intrinsics, num_steps=args.iters)
157 |
158 | geo_loss, geo_metrics = geodesic_loss(Ps, Gs, graph)
159 | res_loss, res_metrics = residual_loss(residuals)
160 |
161 | metrics = {}
162 | metrics.update(geo_metrics)
163 | metrics.update(res_metrics)
164 |
165 | loss = args.w1 * geo_loss + args.w2 * res_loss
166 | loss.backward()
167 |
168 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
169 | optimizer.step()
170 | scheduler.step()
171 |
172 | logger.push(metrics)
173 | total_steps += 1
174 |
175 | if total_steps % 5000 == 0:
176 | PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
177 | torch.save(model.state_dict(), PATH)
178 |
179 | model.train()
180 |
181 | return model
182 |
183 |
184 | if __name__ == '__main__':
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument('--name', default='bla', help='name your experiment')
187 |     parser.add_argument('--transformation', default='SE3', help='transformation group: SE3 or Sim3')
188 | parser.add_argument('--ckpt', help='checkpoint to restore')
189 | parser.add_argument('--train', action='store_true')
190 |
191 | parser.add_argument('--batch', type=int, default=4)
192 | parser.add_argument('--iters', type=int, default=8)
193 | parser.add_argument('--lr', type=float, default=0.00025)
194 | parser.add_argument('--clip', type=float, default=2.5)
195 |
196 | parser.add_argument('--w1', type=float, default=10.0)
197 | parser.add_argument('--w2', type=float, default=0.1)
198 |
199 |
200 | args = parser.parse_args()
201 |
202 | if args.train:
203 | import os
204 | if not os.path.isdir('checkpoints'):
205 | os.mkdir('checkpoints')
206 |
207 | model = train(args)
208 |
209 | else:
210 | model = Sim3Net(args)
211 | model.load_state_dict(torch.load(args.ckpt))
212 |
213 | evaluate(model)
214 |
215 |
--------------------------------------------------------------------------------
/lietorch/include/sim3.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef Sim3_HEADER
3 | #define Sim3_HEADER
4 |
5 | #include
6 | #include
7 |
8 | #include
9 | #include
10 |
11 | #include "common.h"
12 | #include "so3.h"
13 | #include "rxso3.h"
14 |
15 |
16 | template <typename Scalar>
17 | class Sim3 {
18 | public:
19 | const static int constexpr K = 7; // manifold dimension
20 | const static int constexpr N = 8; // embedding dimension
21 |
22 | using Vector3 = Eigen::Matrix;
23 | using Vector4 = Eigen::Matrix;
24 | using Matrix3 = Eigen::Matrix;
25 |
26 | using Tangent = Eigen::Matrix;
27 | using Point = Eigen::Matrix;
28 | using Point4 = Eigen::Matrix;
29 | using Data = Eigen::Matrix;
30 | using Transformation = Eigen::Matrix;
31 | using Adjoint = Eigen::Matrix;
32 |
33 | EIGEN_DEVICE_FUNC Sim3() {
34 | translation = Vector3::Zero();
35 | }
36 |
37 | EIGEN_DEVICE_FUNC Sim3(RxSO3 const& rxso3, Vector3 const& t)
38 | : rxso3(rxso3), translation(t) {};
39 |
40 | EIGEN_DEVICE_FUNC Sim3(const Scalar *data)
41 | : translation(data), rxso3(data+3) {};
42 |
43 | EIGEN_DEVICE_FUNC Sim3 inv() {
44 | return Sim3(rxso3.inv(), -(rxso3.inv() * translation));
45 | }
46 |
47 | EIGEN_DEVICE_FUNC Data data() const {
48 | Data data_vec; data_vec << translation, rxso3.data();
49 | return data_vec;
50 | }
51 |
52 | EIGEN_DEVICE_FUNC Sim3 operator*(Sim3 const& other) {
53 | return Sim3(rxso3 * other.rxso3, translation + rxso3 * other.translation);
54 | }
55 |
56 | EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
57 | return (rxso3 * p) + translation;
58 | }
59 |
60 | EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
61 | Point4 p1; p1 << rxso3 * p.template segment<3>(0) + p(3) * translation , p(3);
62 | return p1;
63 | }
64 |
65 | EIGEN_DEVICE_FUNC Transformation Matrix() const {
66 | Transformation T = Transformation::Identity();
67 | T.template block<3,3>(0,0) = rxso3.Matrix();
68 | T.template block<3,1>(0,3) = translation;
69 | return T;
70 | }
71 |
72 | EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
73 | Transformation T = Transformation::Identity();
74 | T.template block<3,3>(0,0) = rxso3.Matrix();
75 | T.template block<3,1>(0,3) = translation;
76 | return T;
77 | }
78 |
79 | EIGEN_DEVICE_FUNC Eigen::Matrix orthogonal_projector() const {
80 | // jacobian action on a point
81 | Eigen::Matrix J = Eigen::Matrix::Zero();
82 | J.template block<3,3>(0,0) = Matrix3::Identity();
83 | J.template block<3,3>(0,3) = SO3::hat(-translation);
84 | J.template block<3,1>(0,6) = translation;
85 | J.template block<5,5>(3,3) = rxso3.orthogonal_projector();
86 | return J;
87 | }
88 |
89 | EIGEN_DEVICE_FUNC Adjoint Adj() const {
90 | Adjoint Ad = Adjoint::Identity();
91 | Matrix3 sR = rxso3.Matrix();
92 | Matrix3 tx = SO3::hat(translation);
93 | Matrix3 R = rxso3.Rotation();
94 |
95 | Ad.template block<3,3>(0,0) = sR;
96 | Ad.template block<3,3>(0,3) = tx * R;
97 | Ad.template block<3,1>(0,6) = -translation;
98 | Ad.template block<3,3>(3,3) = R;
99 |
100 | return Ad;
101 | }
102 |
103 | EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
104 | return Adj() * a;
105 | }
106 |
107 | EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
108 | return Adj().transpose() * a;
109 | }
110 |
111 | EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi_sigma) {
112 | Vector3 tau = tau_phi_sigma.template segment<3>(0);
113 | Vector3 phi = tau_phi_sigma.template segment<3>(3);
114 | Scalar sigma = tau_phi_sigma(6);
115 |
116 | Matrix3 Phi = SO3::hat(phi);
117 | Matrix3 I = Matrix3::Identity();
118 |
119 | Transformation Omega = Transformation::Zero();
120 | Omega.template block<3,3>(0,0) = Phi + sigma * I;
121 | Omega.template block<3,1>(0,3) = tau;
122 |
123 | return Omega;
124 | }
125 |
126 | EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi_sigma) {
127 | Adjoint ad = Adjoint::Zero();
128 | Vector3 tau = tau_phi_sigma.template segment<3>(0);
129 | Vector3 phi = tau_phi_sigma.template segment<3>(3);
130 | Scalar sigma = tau_phi_sigma(6);
131 |
132 | Matrix3 Tau = SO3::hat(tau);
133 | Matrix3 Phi = SO3::hat(phi);
134 | Matrix3 I = Matrix3::Identity();
135 |
136 | ad.template block<3,3>(0,0) = Phi + sigma * I;
137 | ad.template block<3,3>(0,3) = Tau;
138 | ad.template block<3,1>(0,6) = -tau;
139 | ad.template block<3,3>(3,3) = Phi;
140 |
141 | return ad;
142 | }
143 |
144 |
145 | EIGEN_DEVICE_FUNC Tangent Log() const {
146 | // logarithm map
147 | Vector4 phi_sigma = rxso3.Log();
148 | Matrix3 W = RxSO3::calcW(phi_sigma);
149 |
150 | Tangent tau_phi_sigma;
151 | tau_phi_sigma << W.inverse() * translation, phi_sigma;
152 |
153 | return tau_phi_sigma;
154 | }
155 |
156 | EIGEN_DEVICE_FUNC static Sim3 Exp(Tangent const& tau_phi_sigma) {
157 | // exponential map
158 | Vector3 tau = tau_phi_sigma.template segment<3>(0);
159 | Vector4 phi_sigma = tau_phi_sigma.template segment<4>(3);
160 |
161 | RxSO3 rxso3 = RxSO3::Exp(phi_sigma);
162 | Matrix3 W = RxSO3::calcW(phi_sigma);
163 |
164 | return Sim3(rxso3, W*tau);
165 | }
166 |
167 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi_sigma) {
168 | // left jacobian
169 | Adjoint const Xi = adj(tau_phi_sigma);
170 | Adjoint const Xi2 = Xi * Xi;
171 | Adjoint const Xi4 = Xi2 * Xi2;
172 |
173 | return Adjoint::Identity()
174 | + Scalar(1.0/2.0)*Xi
175 | + Scalar(1.0/6.0)*Xi2
176 | + Scalar(1.0/24.0)*Xi*Xi2
177 | + Scalar(1.0/120.0)*Xi4
178 | + Scalar(1.0/720.0)*Xi*Xi4;
179 | }
180 |
181 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi_sigma) {
182 | // left jacobian inverse
183 | Adjoint const Xi = adj(tau_phi_sigma);
184 | Adjoint const Xi2 = Xi * Xi;
185 | Adjoint const Xi4 = Xi2 * Xi2;
186 |
187 | return Adjoint::Identity()
188 | - Scalar(1.0/2.0)*Xi
189 | + Scalar(1.0/12.0)*Xi2
190 | - Scalar(1.0/720.0)*Xi4;
191 | }
192 |
193 | EIGEN_DEVICE_FUNC static Eigen::Matrix act_jacobian(Point const& p) {
194 | // jacobian action on a point
195 | Eigen::Matrix J;
196 | J.template block<3,3>(0,0) = Matrix3::Identity();
197 | J.template block<3,3>(0,3) = SO3::hat(-p);
198 | J.template block<3,1>(0,6) = p;
199 | return J;
200 | }
201 |
202 | EIGEN_DEVICE_FUNC static Eigen::Matrix act4_jacobian(Point4 const& p) {
203 | // jacobian action on a point
204 | Eigen::Matrix J = Eigen::Matrix::Zero();
205 | J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
206 | J.template block<3,3>(0,3) = SO3::hat(-p.template segment<3>(0));
207 | J.template block<3,1>(0,6) = p.template segment<3>(0);
208 | return J;
209 | }
210 |
211 | private:
212 | Vector3 translation;
213 | RxSO3 rxso3;
214 | };
215 |
216 | #endif
217 |
218 |
--------------------------------------------------------------------------------
/lietorch/include/se3.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef SE3_HEADER
3 | #define SE3_HEADER
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | #include "common.h"
10 | #include "so3.h"
11 |
12 |
13 | template <typename Scalar>
14 | class SE3 {
15 | public:
16 | const static int constexpr K = 6; // manifold dimension
17 | const static int constexpr N = 7; // embedding dimension
18 |
19 | using Vector3 = Eigen::Matrix;
20 | using Vector4 = Eigen::Matrix;
21 | using Matrix3 = Eigen::Matrix;
22 |
23 | using Tangent = Eigen::Matrix;
24 | using Point = Eigen::Matrix;
25 | using Point4 = Eigen::Matrix;
26 | using Data = Eigen::Matrix;
27 | using Transformation = Eigen::Matrix;
28 | using Adjoint = Eigen::Matrix;
29 |
30 | EIGEN_DEVICE_FUNC SE3() { translation = Vector3::Zero(); }
31 |
32 | EIGEN_DEVICE_FUNC SE3(SO3 const& so3, Vector3 const& t) : so3(so3), translation(t) {};
33 |
34 | EIGEN_DEVICE_FUNC SE3(const Scalar *data) : translation(data), so3(data+3) {};
35 |
36 | EIGEN_DEVICE_FUNC SE3 inv() {
37 | return SE3(so3.inv(), -(so3.inv()*translation));
38 | }
39 |
40 | EIGEN_DEVICE_FUNC Data data() const {
41 | Data data_vec; data_vec << translation, so3.data();
42 | return data_vec;
43 | }
44 |
45 | EIGEN_DEVICE_FUNC SE3 operator*(SE3 const& other) {
46 | return SE3(so3 * other.so3, translation + so3 * other.translation);
47 | }
48 |
49 | EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
50 | return so3 * p + translation;
51 | }
52 |
53 | EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
54 | Point4 p1; p1 << so3 * p.template segment<3>(0) + translation * p(3), p(3);
55 | return p1;
56 | }
57 |
58 | EIGEN_DEVICE_FUNC Adjoint Adj() const {
59 | Matrix3 R = so3.Matrix();
60 | Matrix3 tx = SO3::hat(translation);
61 | Matrix3 Zer = Matrix3::Zero();
62 |
63 | Adjoint Ad;
64 | Ad << R, tx*R, Zer, R;
65 |
66 | return Ad;
67 | }
68 |
69 | EIGEN_DEVICE_FUNC Transformation Matrix() const {
70 | Transformation T = Transformation::Identity();
71 | T.template block<3,3>(0,0) = so3.Matrix();
72 | T.template block<3,1>(0,3) = translation;
73 | return T;
74 | }
75 |
76 | EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
77 | return Matrix();
78 | }
79 |
80 | EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
81 | return Adj() * a;
82 | }
83 |
84 | EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
85 | return Adj().transpose() * a;
86 | }
87 |
88 |
89 | EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi) {
90 | Vector3 tau = tau_phi.template segment<3>(0);
91 | Vector3 phi = tau_phi.template segment<3>(3);
92 |
93 | Transformation TauPhi = Transformation::Zero();
94 | TauPhi.template block<3,3>(0,0) = SO3::hat(phi);
95 | TauPhi.template block<3,1>(0,3) = tau;
96 |
97 | return TauPhi;
98 | }
99 |
100 | EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi) {
101 | Vector3 tau = tau_phi.template segment<3>(0);
102 | Vector3 phi = tau_phi.template segment<3>(3);
103 |
104 | Matrix3 Tau = SO3::hat(tau);
105 | Matrix3 Phi = SO3::hat(phi);
106 | Matrix3 Zer = Matrix3::Zero();
107 |
108 | Adjoint ad;
109 | ad << Phi, Tau, Zer, Phi;
110 |
111 | return ad;
112 | }
113 |
114 | EIGEN_DEVICE_FUNC Eigen::Matrix orthogonal_projector() const {
115 | // jacobian action on a point
116 | Eigen::Matrix J = Eigen::Matrix::Zero();
117 | J.template block<3,3>(0,0) = Matrix3::Identity();
118 | J.template block<3,3>(0,3) = SO3::hat(-translation);
119 | J.template block<4,4>(3,3) = so3.orthogonal_projector();
120 |
121 | return J;
122 | }
123 |
124 | EIGEN_DEVICE_FUNC Tangent Log() const {
125 | Vector3 phi = so3.Log();
126 | Matrix3 Vinv = SO3::left_jacobian_inverse(phi);
127 |
128 | Tangent tau_phi;
129 | tau_phi << Vinv * translation, phi;
130 |
131 | return tau_phi;
132 | }
133 |
134 | EIGEN_DEVICE_FUNC static SE3 Exp(Tangent const& tau_phi) {
135 | Vector3 tau = tau_phi.template segment<3>(0);
136 | Vector3 phi = tau_phi.template segment<3>(3);
137 |
138 | SO3 so3 = SO3::Exp(phi);
139 | Vector3 t = SO3::left_jacobian(phi) * tau;
140 |
141 | return SE3(so3, t);
142 | }
143 |
144 | EIGEN_DEVICE_FUNC static Matrix3 calcQ(Tangent const& tau_phi) {
145 | // Q matrix
146 | Vector3 tau = tau_phi.template segment<3>(0);
147 | Vector3 phi = tau_phi.template segment<3>(3);
148 | Matrix3 Tau = SO3::hat(tau);
149 | Matrix3 Phi = SO3::hat(phi);
150 |
151 | Scalar theta = phi.norm();
152 | Scalar theta_pow2 = theta * theta;
153 | Scalar theta_pow4 = theta_pow2 * theta_pow2;
154 |
155 | Scalar coef1 = (theta < EPS) ?
156 | Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta_pow2 :
157 | (theta - sin(theta)) / (theta_pow2 * theta);
158 |
159 | Scalar coef2 = (theta < EPS) ?
160 | Scalar(1.0/24.0) - Scalar(1.0/720.0) * theta_pow2 :
161 | (theta_pow2 + 2*cos(theta) - 2) / (2 * theta_pow4);
162 |
163 | Scalar coef3 = (theta < EPS) ?
164 | Scalar(1.0/120.0) - Scalar(1.0/2520.0) * theta_pow2 :
165 | (2*theta - 3*sin(theta) + theta*cos(theta)) / (2 * theta_pow4 * theta);
166 |
167 | Matrix3 Q = Scalar(0.5) * Tau +
168 | coef1 * (Phi*Tau + Tau*Phi + Phi*Tau*Phi) +
169 | coef2 * (Phi*Phi*Tau + Tau*Phi*Phi - 3*Phi*Tau*Phi) +
170 | coef3 * (Phi*Tau*Phi*Phi + Phi*Phi*Tau*Phi);
171 |
172 | return Q;
173 | }
174 |
175 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi) {
176 | // left jacobian
177 | Vector3 phi = tau_phi.template segment<3>(3);
178 | Matrix3 J = SO3::left_jacobian(phi);
179 | Matrix3 Q = SE3::calcQ(tau_phi);
180 | Matrix3 Zer = Matrix3::Zero();
181 |
182 | Adjoint J6x6;
183 | J6x6 << J, Q, Zer, J;
184 |
185 | return J6x6;
186 | }
187 |
188 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi) {
189 | // left jacobian inverse
190 | Vector3 tau = tau_phi.template segment<3>(0);
191 | Vector3 phi = tau_phi.template segment<3>(3);
192 | Matrix3 Jinv = SO3::left_jacobian_inverse(phi);
193 | Matrix3 Q = SE3::calcQ(tau_phi);
194 | Matrix3 Zer = Matrix3::Zero();
195 |
196 | Adjoint J6x6;
197 | J6x6 << Jinv, -Jinv * Q * Jinv, Zer, Jinv;
198 |
199 | return J6x6;
200 |
201 | }
202 |
203 | EIGEN_DEVICE_FUNC static Eigen::Matrix act_jacobian(Point const& p) {
204 | // jacobian action on a point
205 | Eigen::Matrix J;
206 | J.template block<3,3>(0,0) = Matrix3::Identity();
207 | J.template block<3,3>(0,3) = SO3::hat(-p);
208 | return J;
209 | }
210 |
211 | EIGEN_DEVICE_FUNC static Eigen::Matrix act4_jacobian(Point4 const& p) {
212 | // jacobian action on a point
213 | Eigen::Matrix J = Eigen::Matrix::Zero();
214 | J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
215 | J.template block<3,3>(0,3) = SO3::hat(-p.template segment<3>(0));
216 | return J;
217 | }
218 |
219 |
220 |
221 |
222 | private:
223 | SO3 so3;
224 | Vector3 translation;
225 |
226 | };
227 |
228 | #endif
229 |
230 |
--------------------------------------------------------------------------------
/lietorch/include/so3.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef SO3_HEADER
3 | #define SO3_HEADER
4 |
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #include "common.h"
11 |
12 | template <typename Scalar>
13 | class SO3 {
14 | public:
15 | const static int constexpr K = 3; // manifold dimension
16 | const static int constexpr N = 4; // embedding dimension
17 |
18 | using Vector3 = Eigen::Matrix;
19 | using Vector4 = Eigen::Matrix;
20 | using Matrix3 = Eigen::Matrix;
21 |
22 | using Tangent = Eigen::Matrix;
23 | using Data = Eigen::Matrix;
24 |
25 | using Point = Eigen::Matrix;
26 | using Point4 = Eigen::Matrix;
27 | using Transformation = Eigen::Matrix;
28 | using Adjoint = Eigen::Matrix;
29 | using Quaternion = Eigen::Quaternion;
30 |
31 | EIGEN_DEVICE_FUNC SO3(Quaternion const& q) : unit_quaternion(q) {
32 | unit_quaternion.normalize();
33 | };
34 |
35 | EIGEN_DEVICE_FUNC SO3(const Scalar *data) : unit_quaternion(data) {
36 | unit_quaternion.normalize();
37 | };
38 |
39 | EIGEN_DEVICE_FUNC SO3() {
40 | unit_quaternion = Quaternion::Identity();
41 | }
42 |
43 | EIGEN_DEVICE_FUNC SO3 inv() {
44 | return SO3(unit_quaternion.conjugate());
45 | }
46 |
47 | EIGEN_DEVICE_FUNC Data data() const {
48 | return unit_quaternion.coeffs();
49 | }
50 |
51 | EIGEN_DEVICE_FUNC SO3 operator*(SO3 const& other) {
52 | return SO3(unit_quaternion * other.unit_quaternion);
53 | }
54 |
55 | EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
56 | const Quaternion& q = unit_quaternion;
57 | Point uv = q.vec().cross(p);
58 | uv += uv;
59 | return p + q.w()*uv + q.vec().cross(uv);
60 | }
61 |
62 | EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
63 | Point4 p1; p1 << this->operator*(p.template segment<3>(0)), p(3);
64 | return p1;
65 | }
66 |
67 | EIGEN_DEVICE_FUNC Adjoint Adj() const {
68 | return unit_quaternion.toRotationMatrix();
69 | }
70 |
71 | EIGEN_DEVICE_FUNC Transformation Matrix() const {
72 | return unit_quaternion.toRotationMatrix();
73 | }
74 |
75 | EIGEN_DEVICE_FUNC Eigen::Matrix Matrix4x4() const {
76 | Eigen::Matrix T = Eigen::Matrix::Identity();
77 | T.template block<3,3>(0,0) = Matrix();
78 | return T;
79 | }
80 |
81 | EIGEN_DEVICE_FUNC Eigen::Matrix orthogonal_projector() const {
82 | // jacobian action on a point
83 | Eigen::Matrix J = Eigen::Matrix::Zero();
84 | J.template block<3,3>(0,0) = 0.5 * (
85 | unit_quaternion.w() * Matrix3::Identity() +
86 | SO3::hat(-unit_quaternion.vec())
87 | );
88 |
89 | J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
90 | return J;
91 | }
92 |
93 | EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
94 | return Adj() * a;
95 | }
96 |
97 | EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
98 | return Adj().transpose() * a;
99 | }
100 |
101 | EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& phi) {
102 | Transformation Phi;
103 | Phi <<
104 | 0.0, -phi(2), phi(1),
105 | phi(2), 0.0, -phi(0),
106 | -phi(1), phi(0), 0.0;
107 |
108 | return Phi;
109 | }
110 |
111 | EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& phi) {
112 | return SO3::hat(phi);
113 | }
114 |
115 | EIGEN_DEVICE_FUNC Tangent Log() const {
116 | using std::abs;
117 | using std::atan;
118 | using std::sqrt;
119 | Scalar squared_n = unit_quaternion.vec().squaredNorm();
120 | Scalar w = unit_quaternion.w();
121 |
122 | Scalar two_atan_nbyw_by_n;
123 |
124 | /// Atan-based log thanks to
125 | ///
126 | /// C. Hertzberg et al.:
127 | /// "Integrating Generic Sensor Fusion Algorithms with Sound State
128 | /// Representation through Encapsulation of Manifolds"
129 | /// Information Fusion, 2011
130 |
131 | if (squared_n < EPS * EPS) {
132 | // If quaternion is normalized and n=0, then w should be 1;
133 | // w=0 should never happen here!
134 | Scalar squared_w = w * w;
135 | two_atan_nbyw_by_n =
136 | Scalar(2) / w - Scalar(2.0/3.0) * (squared_n) / (w * squared_w);
137 | } else {
138 | Scalar n = sqrt(squared_n);
139 | if (abs(w) < EPS) {
140 | if (w > Scalar(0)) {
141 | two_atan_nbyw_by_n = Scalar(PI) / n;
142 | } else {
143 | two_atan_nbyw_by_n = -Scalar(PI) / n;
144 | }
145 | } else {
146 | two_atan_nbyw_by_n = Scalar(2) * atan(n / w) / n;
147 | }
148 | }
149 |
150 | return two_atan_nbyw_by_n * unit_quaternion.vec();
151 | }
152 |
153 | EIGEN_DEVICE_FUNC static SO3 Exp(Tangent const& phi) {
154 | Scalar theta2 = phi.squaredNorm();
155 | Scalar theta = sqrt(theta2);
156 | Scalar imag_factor;
157 | Scalar real_factor;
158 |
159 | if (theta < EPS) {
160 | Scalar theta4 = theta2 * theta2;
161 | imag_factor = Scalar(0.5) - Scalar(1.0/48.0) * theta2 + Scalar(1.0/3840.0) * theta4;
162 | real_factor = Scalar(1) - Scalar(1.0/8.0) * theta2 + Scalar(1.0/384.0) * theta4;
163 | } else {
164 | imag_factor = sin(.5 * theta) / theta;
165 | real_factor = cos(.5 * theta);
166 | }
167 |
168 | Quaternion q(real_factor, imag_factor*phi.x(), imag_factor*phi.y(), imag_factor*phi.z());
169 | return SO3(q);
170 | }
171 |
172 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& phi) {
173 | // left jacobian
174 | Matrix3 I = Matrix3::Identity();
175 | Matrix3 Phi = SO3::hat(phi);
176 | Matrix3 Phi2 = Phi * Phi;
177 |
178 | Scalar theta2 = phi.squaredNorm();
179 | Scalar theta = sqrt(theta2);
180 |
181 | Scalar coef1 = (theta < EPS) ?
182 | Scalar(1.0/2.0) - Scalar(1.0/24.0) * theta2 :
183 | (1.0 - cos(theta)) / theta2;
184 |
185 | Scalar coef2 = (theta < EPS) ?
186 | Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta2 :
187 | (theta - sin(theta)) / (theta2 * theta);
188 |
189 | return I + coef1 * Phi + coef2 * Phi2;
190 | }
191 |
192 | EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi) {
193 | // left jacobian inverse
194 | Matrix3 I = Matrix3::Identity();
195 | Matrix3 Phi = SO3::hat(phi);
196 | Matrix3 Phi2 = Phi * Phi;
197 |
198 | Scalar theta2 = phi.squaredNorm();
199 | Scalar theta = sqrt(theta2);
200 | Scalar half_theta = Scalar(.5) * theta ;
201 |
202 | Scalar coef2 = (theta < EPS) ? Scalar(1.0/12.0) :
203 | (Scalar(1) -
204 | theta * cos(half_theta) / (Scalar(2) * sin(half_theta))) /
205 | (theta * theta);
206 |
207 | return I + Scalar(-0.5) * Phi + coef2 * Phi2;
208 | }
209 |
210 | EIGEN_DEVICE_FUNC static Eigen::Matrix act_jacobian(Point const& p) {
211 | // jacobian action on a point
212 | return SO3::hat(-p);
213 | }
214 |
215 | EIGEN_DEVICE_FUNC static Eigen::Matrix act4_jacobian(Point4 const& p) {
216 | // jacobian action on a point
217 | Eigen::Matrix J = Eigen::Matrix::Zero();
218 | J.template block<3,3>(0,0) = SO3::hat(-p.template segment<3>(0));
219 | return J;
220 | }
221 |
222 | private:
223 | Quaternion unit_quaternion;
224 |
225 | };
226 |
227 | #endif
228 |
229 |
230 |
--------------------------------------------------------------------------------
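The exp/log maps declared in these headers back the Python groups; a quick illustrative round-trip check through the lietorch API (the tolerance is chosen arbitrarily):

import torch
from lietorch import SO3, SE3, Sim3

for Group, dim in [(SO3, 3), (SE3, 6), (Sim3, 7)]:
    phi = 0.1 * torch.randn(1, dim)
    X = Group.exp(phi)
    print(Group.__name__, torch.allclose(X.log(), phi, atol=1e-5))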
/lietorch/extras/extras.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 |
5 | // C++ interface
6 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
7 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
8 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
9 |
10 |
11 | // CUDA forward declarations
12 | std::vector corr_index_cuda_forward(
13 | torch::Tensor volume,
14 | torch::Tensor coords,
15 | int radius);
16 |
17 | std::vector corr_index_cuda_backward(
18 | torch::Tensor volume,
19 | torch::Tensor coords,
20 | torch::Tensor corr_grad,
21 | int radius);
22 |
23 | std::vector<torch::Tensor> altcorr_cuda_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius);
28 |
29 | std::vector<torch::Tensor> altcorr_cuda_backward(
30 | torch::Tensor fmap1,
31 | torch::Tensor fmap2,
32 | torch::Tensor coords,
33 | torch::Tensor corr_grad,
34 | int radius);
35 |
36 | std::vector<torch::Tensor> dense_se3_forward_cuda(
37 | torch::Tensor transforms,
38 | torch::Tensor embeddings,
39 | torch::Tensor points,
40 | torch::Tensor targets,
41 | torch::Tensor weights,
42 | torch::Tensor intrinsics,
43 | int radius);
44 |
45 | std::vector<torch::Tensor> dense_se3_backward_cuda(
46 | torch::Tensor transforms,
47 | torch::Tensor embeddings,
48 | torch::Tensor points,
49 | torch::Tensor targets,
50 | torch::Tensor weights,
51 | torch::Tensor intrinsics,
52 | torch::Tensor H_grad,
53 | torch::Tensor b_grad,
54 | int radius);
55 |
56 |
57 | std::vector<torch::Tensor> se3_build_cuda(
58 | torch::Tensor attention,
59 | torch::Tensor transforms,
60 | torch::Tensor points,
61 | torch::Tensor targets,
62 | torch::Tensor weights,
63 | torch::Tensor intrinsics,
64 | int radius);
65 |
66 |
67 | std::vector<torch::Tensor> se3_build_backward_cuda(
68 | torch::Tensor attention,
69 | torch::Tensor transforms,
70 | torch::Tensor points,
71 | torch::Tensor targets,
72 | torch::Tensor weights,
73 | torch::Tensor intrinsics,
74 | torch::Tensor H_grad,
75 | torch::Tensor b_grad,
76 | int radius);
77 |
78 |
79 | std::vector<torch::Tensor> cholesky_solve6x6_forward_cuda(
80 | torch::Tensor H, torch::Tensor b);
81 |
82 | std::vector<torch::Tensor> cholesky_solve6x6_backward_cuda(
83 | torch::Tensor H, torch::Tensor b, torch::Tensor dx);
84 |
85 | // c++ python binding
86 | std::vector<torch::Tensor> corr_index_forward(
87 | torch::Tensor volume,
88 | torch::Tensor coords,
89 | int radius) {
90 | CHECK_INPUT(volume);
91 | CHECK_INPUT(coords);
92 |
93 | return corr_index_cuda_forward(volume, coords, radius);
94 | }
95 |
96 | std::vector<torch::Tensor> corr_index_backward(
97 | torch::Tensor volume,
98 | torch::Tensor coords,
99 | torch::Tensor corr_grad,
100 | int radius) {
101 | CHECK_INPUT(volume);
102 | CHECK_INPUT(coords);
103 | CHECK_INPUT(corr_grad);
104 |
105 | auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius);
106 | return {volume_grad};
107 | }
108 |
109 | std::vector<torch::Tensor> altcorr_forward(
110 | torch::Tensor fmap1,
111 | torch::Tensor fmap2,
112 | torch::Tensor coords,
113 | int radius) {
114 | CHECK_INPUT(fmap1);
115 | CHECK_INPUT(fmap2);
116 | CHECK_INPUT(coords);
117 |
118 | return altcorr_cuda_forward(fmap1, fmap2, coords, radius);
119 | }
120 |
121 | std::vector<torch::Tensor> altcorr_backward(
122 | torch::Tensor fmap1,
123 | torch::Tensor fmap2,
124 | torch::Tensor coords,
125 | torch::Tensor corr_grad,
126 | int radius) {
127 | CHECK_INPUT(fmap1);
128 | CHECK_INPUT(fmap2);
129 | CHECK_INPUT(coords);
130 | CHECK_INPUT(corr_grad);
131 |
132 | return altcorr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
133 | }
134 |
135 |
136 | std::vector<torch::Tensor> se3_build(
137 | torch::Tensor attention,
138 | torch::Tensor transforms,
139 | torch::Tensor points,
140 | torch::Tensor targets,
141 | torch::Tensor weights,
142 | torch::Tensor intrinsics,
143 | int radius) {
144 |
145 | CHECK_INPUT(transforms);
146 | CHECK_INPUT(attention);
147 | CHECK_INPUT(points);
148 | CHECK_INPUT(targets);
149 | CHECK_INPUT(weights);
150 | CHECK_INPUT(intrinsics);
151 |
152 | return se3_build_cuda(attention, transforms,
153 | points, targets, weights, intrinsics, radius);
154 | }
155 |
156 | std::vector<torch::Tensor> se3_build_backward(
157 | torch::Tensor attention,
158 | torch::Tensor transforms,
159 | torch::Tensor points,
160 | torch::Tensor targets,
161 | torch::Tensor weights,
162 | torch::Tensor intrinsics,
163 | torch::Tensor H_grad,
164 | torch::Tensor b_grad,
165 | int radius) {
166 |
167 | CHECK_INPUT(transforms);
168 | CHECK_INPUT(attention);
169 | CHECK_INPUT(points);
170 | CHECK_INPUT(targets);
171 | CHECK_INPUT(weights);
172 | CHECK_INPUT(intrinsics);
173 |
174 | CHECK_INPUT(H_grad);
175 | CHECK_INPUT(b_grad);
176 |
177 | return se3_build_backward_cuda(attention, transforms, points,
178 | targets, weights, intrinsics, H_grad, b_grad, radius);
179 | }
180 |
181 | std::vector<torch::Tensor> se3_build_inplace(
182 | torch::Tensor transforms,
183 | torch::Tensor embeddings,
184 | torch::Tensor points,
185 | torch::Tensor targets,
186 | torch::Tensor weights,
187 | torch::Tensor intrinsics,
188 | int radius) {
189 |
190 | CHECK_INPUT(transforms);
191 | CHECK_INPUT(embeddings);
192 | CHECK_INPUT(points);
193 | CHECK_INPUT(targets);
194 | CHECK_INPUT(weights);
195 | CHECK_INPUT(intrinsics);
196 |
197 | return dense_se3_forward_cuda(transforms, embeddings,
198 | points, targets, weights, intrinsics, radius);
199 | }
200 |
201 | std::vector<torch::Tensor> se3_build_inplace_backward(
202 | torch::Tensor transforms,
203 | torch::Tensor embeddings,
204 | torch::Tensor points,
205 | torch::Tensor targets,
206 | torch::Tensor weights,
207 | torch::Tensor intrinsics,
208 | torch::Tensor H_grad,
209 | torch::Tensor b_grad,
210 | int radius) {
211 |
212 | CHECK_INPUT(transforms);
213 | CHECK_INPUT(embeddings);
214 | CHECK_INPUT(points);
215 | CHECK_INPUT(targets);
216 | CHECK_INPUT(weights);
217 | CHECK_INPUT(intrinsics);
218 |
219 | CHECK_INPUT(H_grad);
220 | CHECK_INPUT(b_grad);
221 |
222 | return dense_se3_backward_cuda(transforms, embeddings, points,
223 | targets, weights, intrinsics, H_grad, b_grad, radius);
224 | }
225 |
226 |
227 | std::vector<torch::Tensor> cholesky6x6_forward(
228 | torch::Tensor H,
229 | torch::Tensor b) {
230 | CHECK_INPUT(H);
231 | CHECK_INPUT(b);
232 |
233 | return cholesky_solve6x6_forward_cuda(H, b);
234 | }
235 |
236 | std::vector<torch::Tensor> cholesky6x6_backward(
237 | torch::Tensor H,
238 | torch::Tensor b,
239 | torch::Tensor dx) {
240 |
241 | CHECK_INPUT(H);
242 | CHECK_INPUT(b);
243 | CHECK_INPUT(dx);
244 |
245 | return cholesky_solve6x6_backward_cuda(H, b, dx);
246 | }
247 |
248 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
249 | m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward");
250 | m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward");
251 | m.def("corr_index_forward", &corr_index_forward, "INDEX forward");
252 | m.def("corr_index_backward", &corr_index_backward, "INDEX backward");
253 |
254 | // RAFT-3D functions
255 | m.def("se3_build", &se3_build, "build forward");
256 | m.def("se3_build_backward", &se3_build_backward, "build backward");
257 |
258 | m.def("se3_build_inplace", &se3_build_inplace, "build forward");
259 | m.def("se3_build_inplace_backward", &se3_build_inplace_backward, "build backward");
260 |
261 | m.def("cholesky6x6_forward", &cholesky6x6_forward, "solve forward");
262 | m.def("cholesky6x6_backward", &cholesky6x6_backward, "solve backward");
263 | }
264 |
265 |
--------------------------------------------------------------------------------
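On the Python side, each forward/backward pair bound above (for example corr_index_forward / corr_index_backward) is typically wrapped in a torch.autograd.Function. A minimal sketch of that pattern, assuming the extension is built under the name lietorch_extras (the build configuration is not shown here) and that the forward kernel returns a single tensor:

import torch
import lietorch_extras  # assumed name of the compiled extension

class CorrIndex(torch.autograd.Function):
    """Differentiable wrapper around the correlation-lookup kernels bound above."""

    @staticmethod
    def forward(ctx, volume, coords, radius):
        ctx.radius = radius
        ctx.save_for_backward(volume, coords)
        # a single output tensor is assumed here
        corr, = lietorch_extras.corr_index_forward(volume, coords, radius)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        volume, coords = ctx.saved_tensors
        # corr_index_backward returns {volume_grad}; coords and radius get no gradient
        volume_grad, = lietorch_extras.corr_index_backward(
            volume, coords, grad_corr.contiguous(), ctx.radius)
        return volume_grad, None, None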
/examples/rgbdslam/rgbd_benchmark/evaluate_ate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Software License Agreement (BSD License)
3 | #
4 | # Copyright (c) 2013, Juergen Sturm, TUM
5 | # All rights reserved.
6 | #
7 | # Redistribution and use in source and binary forms, with or without
8 | # modification, are permitted provided that the following conditions
9 | # are met:
10 | #
11 | # * Redistributions of source code must retain the above copyright
12 | # notice, this list of conditions and the following disclaimer.
13 | # * Redistributions in binary form must reproduce the above
14 | # copyright notice, this list of conditions and the following
15 | # disclaimer in the documentation and/or other materials provided
16 | # with the distribution.
17 | # * Neither the name of TUM nor the names of its
18 | # contributors may be used to endorse or promote products derived
19 | # from this software without specific prior written permission.
20 | #
21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | # POSSIBILITY OF SUCH DAMAGE.
33 | #
34 | # Requirements:
35 | # sudo apt-get install python-argparse
36 |
37 | """
38 | This script computes the absolute trajectory error from the ground truth
39 | trajectory and the estimated trajectory.
40 | """
41 |
42 | import sys
43 | import numpy
44 | import argparse
45 | if __name__=="__main__":
46 | import associate
47 | else:
48 | from . import associate
49 |
50 | def align(model,data):
51 | """Align two trajectories using the method of Horn (closed-form).
52 |
53 | Input:
54 | model -- first trajectory (3xn)
55 | data -- second trajectory (3xn)
56 |
57 | Output:
58 | rot -- rotation matrix (3x3)
59 | trans -- translation vector (3x1)
60 | trans_error -- translational error per point (1xn)
61 |
62 | """
63 | numpy.set_printoptions(precision=3,suppress=True)
64 | model_zerocentered = model - model.mean(1)
65 | data_zerocentered = data - data.mean(1)
66 |
67 | W = numpy.zeros( (3,3) )
68 | for column in range(model.shape[1]):
69 | W += numpy.outer(model_zerocentered[:,column],data_zerocentered[:,column])
70 |     U,d,Vh = numpy.linalg.svd(W.transpose())
71 | S = numpy.matrix(numpy.identity( 3 ))
72 | if(numpy.linalg.det(U) * numpy.linalg.det(Vh)<0):
73 | S[2,2] = -1
74 | rot = U*S*Vh
75 | trans = data.mean(1) - rot * model.mean(1)
76 |
77 | model_aligned = rot * model + trans
78 | alignment_error = model_aligned - data
79 |
80 | trans_error = numpy.sqrt(numpy.sum(numpy.multiply(alignment_error,alignment_error),0)).A[0]
81 |
82 | return rot,trans,trans_error
83 |
84 | def plot_traj(ax,stamps,traj,style,color,label):
85 | """
86 | Plot a trajectory using matplotlib.
87 |
88 | Input:
89 | ax -- the plot
90 | stamps -- time stamps (1xn)
91 | traj -- trajectory (3xn)
92 | style -- line style
93 | color -- line color
94 | label -- plot legend
95 |
96 | """
97 | stamps.sort()
98 | interval = numpy.median([s-t for s,t in zip(stamps[1:],stamps[:-1])])
99 | x = []
100 | y = []
101 | last = stamps[0]
102 | for i in range(len(stamps)):
103 | if stamps[i]-last < 2*interval:
104 | x.append(traj[i][0])
105 | y.append(traj[i][1])
106 | elif len(x)>0:
107 | ax.plot(x,y,style,color=color,label=label)
108 | label=""
109 | x=[]
110 | y=[]
111 | last= stamps[i]
112 | if len(x)>0:
113 | ax.plot(x,y,style,color=color,label=label)
114 |
115 |
116 | def evaluate_ate(first_list, second_list, _args=""):
117 | # parse command line
118 | parser = argparse.ArgumentParser(
119 | description='This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory.')
120 | # parser.add_argument('first_file', help='ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)')
121 | # parser.add_argument('second_file', help='estimated trajectory (format: timestamp tx ty tz qx qy qz qw)')
122 | parser.add_argument('--offset', help='time offset added to the timestamps of the second file (default: 0.0)',default=0.0)
123 | parser.add_argument('--scale', help='scaling factor for the second trajectory (default: 1.0)',default=1.0)
124 | parser.add_argument('--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)',default=0.02)
125 | parser.add_argument('--save', help='save aligned second trajectory to disk (format: stamp2 x2 y2 z2)')
126 | parser.add_argument('--save_associations', help='save associated first and aligned second trajectory to disk (format: stamp1 x1 y1 z1 stamp2 x2 y2 z2)')
127 | parser.add_argument('--plot', help='plot the first and the aligned second trajectory to an image (format: png)')
128 | parser.add_argument('--verbose', help='print all evaluation data (otherwise, only the RMSE absolute translational error in meters after alignment will be printed)', action='store_true')
129 | args = parser.parse_args(_args)
130 |
131 | # first_list = associate.read_file_list(args.first_file)
132 | # second_list = associate.read_file_list(args.second_file)
133 |
134 | matches = associate.associate(first_list, second_list,float(args.offset),float(args.max_difference))
135 | if len(matches)<2:
136 | raise ValueError("Couldn't find matching timestamp pairs between groundtruth and estimated trajectory! Did you choose the correct sequence?")
137 |
138 | first_xyz = numpy.matrix([[float(value) for value in first_list[a][0:3]] for a,b in matches]).transpose()
139 | second_xyz = numpy.matrix([[float(value)*float(args.scale) for value in second_list[b][0:3]] for a,b in matches]).transpose()
140 | rot,trans,trans_error = align(second_xyz,first_xyz)
141 |
142 | second_xyz_aligned = rot * second_xyz + trans
143 |
144 | first_stamps = list(first_list.keys())
145 | first_stamps.sort()
146 | first_xyz_full = numpy.matrix([[float(value) for value in first_list[b][0:3]] for b in first_stamps]).transpose()
147 |
148 | second_stamps = list(second_list.keys())
149 | second_stamps.sort()
150 | second_xyz_full = numpy.matrix([[float(value)*float(args.scale) for value in second_list[b][0:3]] for b in second_stamps]).transpose()
151 | second_xyz_full_aligned = rot * second_xyz_full + trans
152 |
153 | if args.verbose:
154 | print( "compared_pose_pairs %d pairs"%(len(trans_error)))
155 |
156 | print( "absolute_translational_error.rmse %f m"%numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)))
157 | print( "absolute_translational_error.mean %f m"%numpy.mean(trans_error))
158 | print( "absolute_translational_error.median %f m"%numpy.median(trans_error))
159 | print( "absolute_translational_error.std %f m"%numpy.std(trans_error))
160 | print( "absolute_translational_error.min %f m"%numpy.min(trans_error))
161 | print( "absolute_translational_error.max %f m"%numpy.max(trans_error))
162 |
163 |
164 | if args.save_associations:
165 | file = open(args.save_associations,"w")
166 | file.write("\n".join(["%f %f %f %f %f %f %f %f"%(a,x1,y1,z1,b,x2,y2,z2) for (a,b),(x1,y1,z1),(x2,y2,z2) in zip(matches,first_xyz.transpose().A,second_xyz_aligned.transpose().A)]))
167 | file.close()
168 |
169 | if args.save:
170 | file = open(args.save,"w")
171 | file.write("\n".join(["%f "%stamp+" ".join(["%f"%d for d in line]) for stamp,line in zip(second_stamps,second_xyz_full_aligned.transpose().A)]))
172 | file.close()
173 |
174 | if args.plot:
175 | import matplotlib
176 | matplotlib.use('Agg')
177 | import matplotlib.pyplot as plt
178 | import matplotlib.pylab as pylab
179 | from matplotlib.patches import Ellipse
180 | fig = plt.figure()
181 | ax = fig.add_subplot(111)
182 | plot_traj(ax,first_stamps,first_xyz_full.transpose().A,'-',"black","ground truth")
183 | plot_traj(ax,second_stamps,second_xyz_full_aligned.transpose().A,'-',"blue","estimated")
184 |
185 | label="difference"
186 | for (a,b),(x1,y1,z1),(x2,y2,z2) in zip(matches,first_xyz.transpose().A,second_xyz_aligned.transpose().A):
187 | ax.plot([x1,x2],[y1,y2],'-',color="red",label=label)
188 | label=""
189 |
190 | ax.legend()
191 |
192 | ax.set_xlabel('x [m]')
193 | ax.set_ylabel('y [m]')
194 | plt.savefig(args.plot,dpi=90)
195 |
196 | return {
197 | "compared_pose_pairs": (len(trans_error)),
198 | "absolute_translational_error.rmse": numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)),
199 | "absolute_translational_error.mean": numpy.mean(trans_error),
200 | "absolute_translational_error.median": numpy.median(trans_error),
201 | "absolute_translational_error.std": numpy.std(trans_error),
202 | "absolute_translational_error.min": numpy.min(trans_error),
203 | "absolute_translational_error.max": numpy.max(trans_error),
204 | }
205 |
206 |
207 |
--------------------------------------------------------------------------------
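A minimal sketch of calling evaluate_ate directly with two trajectory dictionaries in the TUM format ({timestamp: [tx, ty, tz, qx, qy, qz, qw]}), as produced by associate.read_file_list; the file names are placeholders and the import assumes the snippet is run from the examples/rgbdslam directory:

from rgbd_benchmark import associate
from rgbd_benchmark.evaluate_ate import evaluate_ate

# placeholder paths; files use the "timestamp tx ty tz qx qy qz qw" format
groundtruth = associate.read_file_list("groundtruth.txt")
estimated = associate.read_file_list("estimated.txt")

results = evaluate_ate(groundtruth, estimated)
print("ATE RMSE: %.3f m" % results["absolute_translational_error.rmse"])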
/examples/core/networks/modules/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 | def forward(self, x):
48 | y = x
49 | y = self.relu(self.norm1(self.conv1(y)))
50 | y = self.relu(self.norm2(self.conv2(y)))
51 |
52 | if self.downsample is not None:
53 | x = self.downsample(x)
54 |
55 | return self.relu(x+y)
56 |
57 |
58 | class BottleneckBlock(nn.Module):
59 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
60 | super(BottleneckBlock, self).__init__()
61 |
62 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
63 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
64 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
65 | self.relu = nn.ReLU(inplace=True)
66 |
67 | num_groups = planes // 8
68 |
69 | if norm_fn == 'group':
70 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
71 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
72 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
73 | if not stride == 1:
74 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 |
76 | elif norm_fn == 'batch':
77 | self.norm1 = nn.BatchNorm2d(planes//4)
78 | self.norm2 = nn.BatchNorm2d(planes//4)
79 | self.norm3 = nn.BatchNorm2d(planes)
80 | if not stride == 1:
81 | self.norm4 = nn.BatchNorm2d(planes)
82 |
83 | elif norm_fn == 'instance':
84 | self.norm1 = nn.InstanceNorm2d(planes//4)
85 | self.norm2 = nn.InstanceNorm2d(planes//4)
86 | self.norm3 = nn.InstanceNorm2d(planes)
87 | if not stride == 1:
88 | self.norm4 = nn.InstanceNorm2d(planes)
89 |
90 | elif norm_fn == 'none':
91 | self.norm1 = nn.Sequential()
92 | self.norm2 = nn.Sequential()
93 | self.norm3 = nn.Sequential()
94 | if not stride == 1:
95 | self.norm4 = nn.Sequential()
96 |
97 | if stride == 1:
98 | self.downsample = None
99 |
100 | else:
101 | self.downsample = nn.Sequential(
102 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
103 |
104 | def forward(self, x):
105 | y = x
106 | y = self.relu(self.norm1(self.conv1(y)))
107 | y = self.relu(self.norm2(self.conv2(y)))
108 | y = self.relu(self.norm3(self.conv3(y)))
109 |
110 | if self.downsample is not None:
111 | x = self.downsample(x)
112 |
113 | return self.relu(x+y)
114 |
115 |
116 | DIM=32
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 | self.multidim = multidim
123 |
124 | if self.norm_fn == 'group':
125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
126 |
127 | elif self.norm_fn == 'batch':
128 | self.norm1 = nn.BatchNorm2d(DIM)
129 |
130 | elif self.norm_fn == 'instance':
131 | self.norm1 = nn.InstanceNorm2d(DIM)
132 |
133 | elif self.norm_fn == 'none':
134 | self.norm1 = nn.Sequential()
135 |
136 | self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
137 | self.relu1 = nn.ReLU(inplace=True)
138 |
139 | self.in_planes = DIM
140 | self.layer1 = self._make_layer(DIM, stride=1)
141 | self.layer2 = self._make_layer(2*DIM, stride=2)
142 | self.layer3 = self._make_layer(4*DIM, stride=2)
143 |
144 | # output convolution
145 | self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
146 |
147 | if self.multidim:
148 | self.layer4 = self._make_layer(256, stride=2)
149 | self.layer5 = self._make_layer(512, stride=2)
150 |
151 | self.in_planes = 256
152 | self.layer6 = self._make_layer(256, stride=1)
153 |
154 | self.in_planes = 128
155 | self.layer7 = self._make_layer(128, stride=1)
156 |
157 | self.up1 = nn.Conv2d(512, 256, 1)
158 | self.up2 = nn.Conv2d(256, 128, 1)
159 | self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1)
160 |
161 | if dropout > 0:
162 | self.dropout = nn.Dropout2d(p=dropout)
163 | else:
164 | self.dropout = None
165 |
166 | for m in self.modules():
167 | if isinstance(m, nn.Conv2d):
168 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
169 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
170 | if m.weight is not None:
171 | nn.init.constant_(m.weight, 1)
172 | if m.bias is not None:
173 | nn.init.constant_(m.bias, 0)
174 |
175 | def _make_layer(self, dim, stride=1):
176 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
177 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
178 | layers = (layer1, layer2)
179 |
180 | self.in_planes = dim
181 | return nn.Sequential(*layers)
182 |
183 | def forward(self, x):
184 | b, n, c1, h1, w1 = x.shape
185 | x = x.view(b*n, c1, h1, w1)
186 |
187 | x = self.conv1(x)
188 | x = self.norm1(x)
189 | x = self.relu1(x)
190 |
191 | x = self.layer1(x)
192 | x = self.layer2(x)
193 | x = self.layer3(x)
194 |
195 | x = self.conv2(x)
196 |
197 | _, c2, h2, w2 = x.shape
198 | return x.view(b, n, c2, h2, w2)
199 |
200 |
201 |
202 | class BasicEncoder16(nn.Module):
203 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
204 | super(BasicEncoder16, self).__init__()
205 | self.norm_fn = norm_fn
206 | self.multidim = multidim
207 |
208 | if self.norm_fn == 'group':
209 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
210 |
211 | elif self.norm_fn == 'batch':
212 | self.norm1 = nn.BatchNorm2d(DIM)
213 |
214 | elif self.norm_fn == 'instance':
215 | self.norm1 = nn.InstanceNorm2d(DIM)
216 |
217 | elif self.norm_fn == 'none':
218 | self.norm1 = nn.Sequential()
219 |
220 | self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
221 | self.relu1 = nn.ReLU(inplace=True)
222 |
223 | self.in_planes = DIM
224 | self.layer1 = self._make_layer(DIM, stride=1)
225 | self.layer2 = self._make_layer(2*DIM, stride=2)
226 | self.layer3 = self._make_layer(4*DIM, stride=2)
227 | self.layer4 = self._make_layer(4*DIM, stride=2)
228 |
229 |
230 | # output convolution
231 | self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
232 |
233 | if dropout > 0:
234 | self.dropout = nn.Dropout2d(p=dropout)
235 | else:
236 | self.dropout = None
237 |
238 | for m in self.modules():
239 | if isinstance(m, nn.Conv2d):
240 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
241 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
242 | if m.weight is not None:
243 | nn.init.constant_(m.weight, 1)
244 | if m.bias is not None:
245 | nn.init.constant_(m.bias, 0)
246 |
247 | def _make_layer(self, dim, stride=1):
248 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
249 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
250 | layers = (layer1, layer2)
251 |
252 | self.in_planes = dim
253 | return nn.Sequential(*layers)
254 |
255 | def forward(self, x):
256 | b, n, c1, h1, w1 = x.shape
257 | x = x.view(b*n, c1, h1, w1)
258 |
259 | x = self.conv1(x)
260 | x = self.norm1(x)
261 | x = self.relu1(x)
262 |
263 | x = self.layer1(x)
264 | x = self.layer2(x)
265 | x = self.layer3(x)
266 | x = self.layer4(x)
267 |
268 | x = self.conv2(x)
269 |
270 | _, c2, h2, w2 = x.shape
271 | return x.view(b, n, c2, h2, w2)
272 |
--------------------------------------------------------------------------------
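A minimal sketch of running BasicEncoder on a batch of frame sequences; the (batch, num_frames, 3, H, W) layout follows the view call in forward, and the 1/8-resolution output follows the three stride-2 stages above (the import path is an assumption):

import torch
from extractor import BasicEncoder  # assuming the module is on the import path

fnet = BasicEncoder(output_dim=128, norm_fn='instance')
images = torch.randn(2, 2, 3, 384, 512)   # (batch, frames, channels, H, W)

with torch.no_grad():
    fmaps = fnet(images)

print(fmaps.shape)  # torch.Size([2, 2, 128, 48, 64]): features at 1/8 resolution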
/lietorch/groups.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | # group operations implemented in cuda
5 | from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToVec, FromVec
6 | from .broadcasting import broadcast_inputs
7 |
8 |
9 | class LieGroupParameter(torch.Tensor):
10 | """ Wrapper class for LieGroup """
11 |
12 | from torch._C import _disabled_torch_function_impl
13 | __torch_function__ = _disabled_torch_function_impl
14 |
15 | def __new__(cls, group, requires_grad=True):
16 | data = torch.zeros(group.tangent_shape,
17 | device=group.data.device,
18 | dtype=group.data.dtype,
19 | requires_grad=True)
20 |
21 | return torch.Tensor._make_subclass(cls, data, requires_grad)
22 |
23 | def __init__(self, group):
24 | self.group = group
25 |
26 | def retr(self):
27 | return self.group.retr(self)
28 |
29 | def log(self):
30 | return self.retr().log()
31 |
32 | def inv(self):
33 | return self.retr().inv()
34 |
35 | def adj(self, a):
36 | return self.retr().adj(a)
37 |
38 | def __mul__(self, other):
39 | if isinstance(other, LieGroupParameter):
40 | return self.retr() * other.retr()
41 | else:
42 | return self.retr() * other
43 |
44 | def add_(self, update, alpha):
45 | self.group = self.group.exp(alpha*update) * self.group
46 |
47 | def __getitem__(self, index):
48 | return self.retr().__getitem__(index)
49 |
50 |
51 | class LieGroup:
52 | """ Base class for Lie Group """
53 |
54 | def __init__(self, data):
55 | self.data = data
56 |
57 | def __repr__(self):
58 | return "{}: size={}, device={}, dtype={}".format(
59 | self.group_name, self.shape, self.device, self.dtype)
60 |
61 | @property
62 | def shape(self):
63 | return self.data.shape[:-1]
64 |
65 | @property
66 | def device(self):
67 | return self.data.device
68 |
69 | @property
70 | def dtype(self):
71 | return self.data.dtype
72 |
73 | def vec(self):
74 | return self.apply_op(ToVec, self.data)
75 |
76 | @property
77 | def tangent_shape(self):
78 | return self.data.shape[:-1] + (self.manifold_dim,)
79 |
80 | @classmethod
81 | def Identity(cls, *batch_shape, **kwargs):
82 | """ Construct identity element with batch shape """
83 |
84 | if isinstance(batch_shape[0], tuple):
85 | batch_shape = batch_shape[0]
86 |
87 | elif isinstance(batch_shape[0], list):
88 | batch_shape = tuple(batch_shape[0])
89 |
90 | numel = np.prod(batch_shape)
91 | data = cls.id_elem.reshape(1,-1)
92 |
93 | if 'device' in kwargs:
94 | data = data.to(kwargs['device'])
95 |
96 | if 'dtype' in kwargs:
97 | data = data.type(kwargs['dtype'])
98 |
99 | data = data.repeat(numel, 1)
100 | return cls(data).view(batch_shape)
101 |
102 | @classmethod
103 | def IdentityLike(cls, G):
104 | return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)
105 |
106 | @classmethod
107 | def InitFromVec(cls, data):
108 | return cls(cls.apply_op(FromVec, data))
109 |
110 | @classmethod
111 | def Random(cls, *batch_shape, sigma=1.0, **kwargs):
112 | """ Construct random element with batch_shape by random sampling in tangent space"""
113 |
114 | if isinstance(batch_shape[0], tuple):
115 | batch_shape = batch_shape[0]
116 |
117 | elif isinstance(batch_shape[0], list):
118 | batch_shape = tuple(batch_shape[0])
119 |
120 | tangent_shape = batch_shape + (cls.manifold_dim,)
121 | xi = torch.randn(tangent_shape, **kwargs)
122 | return cls.exp(sigma * xi)
123 |
124 | @classmethod
125 | def apply_op(cls, op, x, y=None):
126 | """ Apply group operator """
127 | inputs, out_shape = broadcast_inputs(x, y)
128 |
129 | data = op.apply(cls.group_id, *inputs)
130 | return data.view(out_shape + (-1,))
131 |
132 | @classmethod
133 | def exp(cls, x):
134 | """ exponential map: x -> X """
135 | return cls(cls.apply_op(Exp, x))
136 |
137 | def quaternion(self):
138 | """ extract quaternion """
139 | return self.apply_op(Quat, self.data)
140 |
141 | def log(self):
142 | """ logarithm map """
143 | return self.apply_op(Log, self.data)
144 |
145 | def inv(self):
146 | """ group inverse """
147 | return self.__class__(self.apply_op(Inv, self.data))
148 |
149 | def mul(self, other):
150 | """ group multiplication """
151 | return self.__class__(self.apply_op(Mul, self.data, other.data))
152 |
153 | def retr(self, a):
154 | """ retraction: Exp(a) * X """
155 | dX = self.__class__.apply_op(Exp, a)
156 | return self.__class__(self.apply_op(Mul, dX, self.data))
157 |
158 | def adj(self, a):
159 | """ adjoint operator: b = A(X) * a """
160 | return self.apply_op(Adj, self.data, a)
161 |
162 | def adjT(self, a):
163 | """ transposed adjoint operator: b = a * A(X) """
164 | return self.apply_op(AdjT, self.data, a)
165 |
166 | def Jinv(self, a):
167 | return self.apply_op(Jinv, self.data, a)
168 |
169 | def act(self, p):
170 | """ action on a point cloud """
171 |
172 | # action on point
173 | if p.shape[-1] == 3:
174 | return self.apply_op(Act3, self.data, p)
175 |
176 | # action on homogeneous point
177 | elif p.shape[-1] == 4:
178 | return self.apply_op(Act4, self.data, p)
179 |
180 | def matrix(self):
181 | """ convert element to 4x4 matrix """
182 | Id = torch.eye(4, dtype=self.dtype, device=self.device)
183 | Id = Id.view([1] * (len(self.data.shape) - 1) + [4, 4])
184 | return self.__class__(self.data[...,None,:]).act(Id).transpose(-1,-2)
185 |
186 | def translation(self):
187 | """ extract translation component """
188 | p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
189 | p = p.view([1] * (len(self.data.shape) - 1) + [4,])
190 | return self.apply_op(Act4, self.data, p)
191 |
192 | def detach(self):
193 | return self.__class__(self.data.detach())
194 |
195 | def view(self, dims):
196 | data_reshaped = self.data.view(dims + (self.embedded_dim,))
197 | return self.__class__(data_reshaped)
198 |
199 | def __mul__(self, other):
200 | # group multiplication
201 | if isinstance(other, LieGroup):
202 | return self.mul(other)
203 |
204 | # action on point
205 | elif isinstance(other, torch.Tensor):
206 | return self.act(other)
207 |
208 | def __getitem__(self, index):
209 | return self.__class__(self.data[index])
210 |
211 | def __setitem__(self, index, item):
212 | self.data[index] = item.data
213 |
214 | def to(self, *args, **kwargs):
215 | return self.__class__(self.data.to(*args, **kwargs))
216 |
217 | def cpu(self):
218 | return self.__class__(self.data.cpu())
219 |
220 | def cuda(self):
221 | return self.__class__(self.data.cuda())
222 |
223 | def float(self):
224 | return self.__class__(self.data.float())
225 |
226 | def double(self):
227 | return self.__class__(self.data.double())
228 |
229 | def unbind(self, dim=0):
230 | return [self.__class__(x) for x in self.data.unbind(dim=dim)]
231 |
232 |
233 | class SO3(LieGroup):
234 | group_name = 'SO3'
235 | group_id = 1
236 | manifold_dim = 3
237 | embedded_dim = 4
238 |
239 | # unit quaternion
240 | id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])
241 |
242 | def __init__(self, data):
243 | if isinstance(data, SE3):
244 | data = data.data[..., 3:7]
245 |
246 | super(SO3, self).__init__(data)
247 |
248 |
249 | class RxSO3(LieGroup):
250 | group_name = 'RxSO3'
251 | group_id = 2
252 | manifold_dim = 4
253 | embedded_dim = 5
254 |
255 | # unit quaternion
256 | id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])
257 |
258 | def __init__(self, data):
259 | if isinstance(data, Sim3):
260 | data = data.data[..., 3:8]
261 |
262 | super(RxSO3, self).__init__(data)
263 |
264 |
265 | class SE3(LieGroup):
266 | group_name = 'SE3'
267 | group_id = 3
268 | manifold_dim = 6
269 | embedded_dim = 7
270 |
271 | # translation, unit quaternion
272 | id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
273 |
274 | def __init__(self, data):
275 | if isinstance(data, SO3):
276 | translation = torch.zeros_like(data.data[...,:3])
277 | data = torch.cat([translation, data.data], -1)
278 |
279 | super(SE3, self).__init__(data)
280 |
281 | def scale(self, s):
282 | t, q = self.data.split([3,4], -1)
283 | t = t * s.unsqueeze(-1)
284 | return SE3(torch.cat([t, q], dim=-1))
285 |
286 |
287 | class Sim3(LieGroup):
288 | group_name = 'Sim3'
289 | group_id = 4
290 | manifold_dim = 7
291 | embedded_dim = 8
292 |
293 | # translation, unit quaternion, scale
294 | id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
295 |
296 | def __init__(self, data):
297 |
298 | if isinstance(data, SO3):
299 |             scale = torch.ones_like(data.data[...,:1])
300 |             translation = torch.zeros_like(data.data[...,:3])
301 |             data = torch.cat([translation, data.data, scale], -1)
302 |
303 | elif isinstance(data, SE3):
304 | scale = torch.ones_like(data.data[...,:1])
305 | data = torch.cat([data.data, scale], -1)
306 |
307 | elif isinstance(data, Sim3):
308 | data = data.data
309 |
310 | super(Sim3, self).__init__(data)
311 |
312 |
313 | def cat(group_list, dim):
314 | """ Concatenate groups along dimension """
315 | data = torch.cat([X.data for X in group_list], dim=dim)
316 | return group_list[0].__class__(data)
317 |
318 | def stack(group_list, dim):
319 |     """ Stack groups along a new dimension """
320 | data = torch.stack([X.data for X in group_list], dim=dim)
321 | return group_list[0].__class__(data)
322 |
--------------------------------------------------------------------------------
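A minimal sketch of the group API defined above (exponential/logarithm maps, composition, and action on points), assuming the compiled lietorch extension is installed and SE3 is re-exported from the package root:

import torch
from lietorch import SE3

T = SE3.Random(10, device='cuda')        # batch of 10 random SE3 elements
xi = torch.randn(10, 6, device='cuda')   # se(3) tangent vectors (manifold_dim = 6)

dT = SE3.exp(xi)       # exponential map: tangent -> group
T2 = dT * T            # group composition (__mul__ with another LieGroup)
pts = torch.randn(10, 3, device='cuda')
q = T2 * pts           # action on 3D points (__mul__ with a torch.Tensor)
xi2 = T2.log()         # logarithm map back to the tangent space
mats = T2.matrix()     # (10, 4, 4) homogeneous transformation matrices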