├── .gitignore
├── LICENCE
├── README.md
├── figures
│   ├── boxplot.jpg
│   ├── methode.jpg
│   ├── roe.jpg
│   └── rpy.jpg
├── main_EUROC.py
├── main_TUMVI.py
├── requirements.txt
└── src
    ├── dataset.py
    ├── learning.py
    ├── lie_algebra.py
    ├── losses.py
    ├── networks.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__
2 | data
3 | results
4 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Martin Brossard
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Denoising IMU Gyroscope with Deep Learning for Open-Loop Attitude Estimation
2 |
3 | ## Overview [[IEEE paper](https://ieeexplore.ieee.org/document/9119813), [preprint paper](https://hal.archives-ouvertes.fr/hal-02488923v4/document)]
4 |
5 | This repo contains a learning method for denoising the gyroscopes of Inertial Measurement Units (IMUs) using
6 | ground truth data. For attitude estimation by dead reckoning, the resulting algorithm outperforms top-ranked
7 | visual-inertial odometry systems [3-5],
8 | although it only uses signals from a low-cost IMU. This
9 | performance is achieved thanks to a well-chosen model and a
10 | proper loss function for orientation increments. Our approach builds upon a neural network based
11 | on dilated convolutions, without requiring any recurrent neural
12 | network.
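
To give a feel for why dilated convolutions can replace a recurrent network here, the sketch below estimates how many past IMU samples a single network output can see. It uses the `ks` and `ds` values from the `net_params` of the main files. The way `ds` is turned into per-layer dilations (first layer undilated, each entry multiplying the previous gap) is an assumption, as is the helper name `receptive_field`; check `src/networks.py` for the actual layout.

```python
def receptive_field(kernel_sizes, dilations):
    """Input samples seen by one output of stacked 1-D dilated convolutions (stride 1)."""
    if len(kernel_sizes) != len(dilations):
        raise ValueError("one dilation per convolution layer is required")
    field = 1
    for k, d in zip(kernel_sizes, dilations):
        # each layer widens the window by d * (k - 1) samples
        field += d * (k - 1)
    return field


ks = [7, 7, 7, 7]  # 'ks' from net_params
ds = [4, 4, 4]     # 'ds' from net_params

# ASSUMPTION: the first layer has dilation 1 and each entry of `ds`
# multiplies the dilation of the previous layer.
dilations = [1]
for gap in ds:
    dilations.append(dilations[-1] * gap)

n_samples = receptive_field(ks, dilations)
dt = 0.005  # IMU sampling time in seconds, as in the 'loss' config
print(dilations)       # [1, 4, 16, 64]
print(n_samples)       # 511
print(n_samples * dt)  # window length in seconds
```

Under that assumption, four convolution layers already cover a window of about 2.5 s of IMU data at 200 Hz, which is the kind of temporal context a recurrent network would otherwise be used for.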
13 |
14 | ## Code
15 | Our implementation is based on Python 3 and [PyTorch](https://pytorch.org/). The code was
16 | tested under Ubuntu 16.04, Python 3.5, and PyTorch 1.5. The codebase is licensed under the MIT License.
17 |
18 | ### Installation & Prerequisites
19 | 1. Install the correct version of [PyTorch](http://pytorch.org)
20 | ```
21 | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
22 | ```
23 |
24 | 2. Clone this repo and create empty directories
25 | ```
26 | git clone https://github.com/mbrossar/denoise-imu-gyro.git
27 | mkdir denoise-imu-gyro/data
28 | mkdir denoise-imu-gyro/results
29 | ```
30 |
31 | 3. Install the following required Python packages, e.g. with the pip command
32 | ```
33 | pip install -r denoise-imu-gyro/requirements.txt
34 | ```
35 |
36 | ### Testing
37 |
38 | 1. Download the _EuRoC_ [1] and _TUM-VI_ [2] datasets, reformatted as pickle files, at this [url](https://cloud.mines-paristech.fr/index.php/s/d2lHqzIk1PxzWmb/download), extract them, and copy them into the `data` folder.
39 | ```
40 | wget "https://cloud.mines-paristech.fr/index.php/s/d2lHqzIk1PxzWmb/download"
41 | unzip download -d denoise-imu-gyro/data
42 | rm download
43 | ```
44 | These files can alternatively be generated after downloading the _EuRoC_ and
45 | _TUM-VI_ datasets. They are generated when a main file is launched with
46 | `data_dir` set to the path of the raw dataset.
47 |
48 | 2. Download the optimized parameters at this [url](https://cloud.mines-paristech.fr/index.php/s/OLnj74YXtOLA7Hv/download), extract them, and copy them into the `results` folder.
49 | ```
50 | wget "https://cloud.mines-paristech.fr/index.php/s/OLnj74YXtOLA7Hv/download"
51 | unzip download -d denoise-imu-gyro/results
52 | rm download
53 | ```
54 | 3. Test on the dataset of your choice!
55 | ```
56 | cd denoise-imu-gyro
57 | python3 main_EUROC.py
58 | # or alternatively
59 | # python3 main_TUMVI.py
60 | ```
61 |
62 | You can then compare results with the evaluation [toolbox](https://github.com/rpng/open_vins/) of [3].
63 |
64 | ### Training
65 | You can train the method by
66 | uncommenting the two statements under `Train on training data set` in the main files. Then edit the
67 | configuration to obtain results with another set of parameters. Training takes roughly
68 | 5 minutes per dataset with a decent GPU.
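
When editing the configuration, it helps to know what the default scheduler settings imply. The sketch below reproduces the cosine-annealing-with-warm-restarts formula for the values found in the main files (`lr=0.01`, `eta_min=1e-3`, `T_0=600`, `T_mult=2`) using only the standard library. The function name `warm_restart_lr` is hypothetical, and the sketch assumes the scheduler is stepped once per epoch, which is not confirmed by the code shown here.

```python
import math


def warm_restart_lr(epoch, base_lr=0.01, eta_min=1e-3, t_0=600, t_mult=2):
    """Learning rate at an integer epoch for cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    # find the current cycle: each cycle is t_mult times longer than the previous one
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    cosine = math.cos(math.pi * t_cur / t_i)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


# ASSUMPTION: one scheduler step per epoch.
for epoch in (0, 300, 599, 600, 1200, 1799):
    print(epoch, round(warm_restart_lr(epoch), 5))
```

With these defaults the first cycle lasts 600 epochs and the second 1200, so the default `n_epochs` of 1800 covers exactly two full cycles. If you change `n_epochs`, `T_0`, or `T_mult`, it may be worth keeping the end of training aligned with the end of a cycle, where the learning rate is close to `eta_min`.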
69 |
70 | ## Schematic Illustration of the Proposed Method
71 |
72 |
73 |
75 |
76 |
77 | The convolutional neural network
78 | computes gyro corrections (based on past IMU measurements) that filter
79 | undesirable errors in the raw IMU signals. We
80 | then perform open-loop time integration on the denoised measurements
81 | to regress low-frequency errors between ground truth and estimated
82 | orientation increments.
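
The open-loop integration step can be illustrated with a small standard-library sketch. It composes one SO(3) exponential per sample, `R_{n+1} = R_n exp((w_n + c_n) dt)`, where `c_n` stands in for the correction produced by the network. This is a simplification for illustration only: the correction is assumed to be purely additive, the helper names (`so3_exp`, `integrate_gyro`) are hypothetical, and the repository's own implementation in `src/lie_algebra.py` is not used.

```python
import math


def matmul(A, B):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def so3_exp(phi):
    """Rodrigues' formula: rotation vector (rad) -> 3x3 rotation matrix."""
    x, y, z = phi
    angle = math.sqrt(x * x + y * y + z * z)
    K = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]  # skew-symmetric matrix of phi
    if angle < 1e-10:
        a, b = 1.0, 0.5  # small-angle limits of sin(t)/t and (1 - cos t)/t^2
    else:
        a = math.sin(angle) / angle
        b = (1.0 - math.cos(angle)) / angle ** 2
    K2 = matmul(K, K)
    return [[(1.0 if i == j else 0.0) + a * K[i][j] + b * K2[i][j]
             for j in range(3)] for i in range(3)]


def integrate_gyro(R0, gyro, corrections, dt):
    """Open-loop integration R_{n+1} = R_n exp((w_n + c_n) dt)."""
    R = R0
    for w, c in zip(gyro, corrections):
        phi = [(wi + ci) * dt for wi, ci in zip(w, c)]
        R = matmul(R, so3_exp(phi))
    return R


# Toy data (made up): true yaw rate 0.5 rad/s, gyro biased by 0.02 rad/s,
# and an ideal correction that cancels the bias exactly.
dt, n = 0.005, 200  # 1 s of data at 200 Hz
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
gyro = [[0.0, 0.0, 0.52]] * n
no_correction = [[0.0, 0.0, 0.0]] * n
ideal_correction = [[0.0, 0.0, -0.02]] * n

R_raw = integrate_gyro(identity, gyro, no_correction, dt)
R_corrected = integrate_gyro(identity, gyro, ideal_correction, dt)
yaw_raw = math.atan2(R_raw[1][0], R_raw[0][0])
yaw_corrected = math.atan2(R_corrected[1][0], R_corrected[0][0])
print(yaw_raw, yaw_corrected)  # roughly 0.52 rad versus 0.5 rad
```

Because nothing closes the loop, any uncorrected bias accumulates linearly in the integrated orientation, which is why small gyro corrections translate into large gains over long sequences.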
83 |
84 | ## Results
85 |
86 |
87 |
88 |
89 |
90 | Orientation estimates on the test sequence _MH 04 difficult_ of [1] (left), and
91 | _room 4_ of [2] (right). Our method removes errors of the IMU.
92 |
93 |
94 |
95 |
96 |
97 | Relative Orientation Error (ROE) in terms of 3D orientation and
98 | yaw errors on the test sequences. Our method competes with VIO methods although it relies only on IMU signals.
99 |
100 | ## Paper
101 | The paper associated with this repo, M. Brossard, S. Bonnabel and A. Barrau, "Denoising IMU Gyroscopes With Deep Learning for Open-Loop Attitude Estimation," in _IEEE Robotics and Automation Letters_, vol. 5, no. 3, pp. 4796-4803, July 2020, doi: 10.1109/LRA.2020.3003256, is
102 | available at this [url](https://ieeexplore.ieee.org/document/9119813) and as a preprint at this [url](https://hal.archives-ouvertes.fr/hal-02488923/document).
103 |
104 |
105 | ## Citation
106 |
107 | If you use this code in your research, please cite:
108 |
109 | ```
110 | @article{brossard2020denoising,
111 | author={M. {Brossard} and S. {Bonnabel} and A. {Barrau}},
112 | journal={IEEE Robotics and Automation Letters},
113 | title={Denoising IMU Gyroscopes With Deep Learning for Open-Loop Attitude Estimation},
114 | year={2020},
115 | volume={5},
116 | number={3},
117 | pages={4796-4803},
118 | }
119 |
120 | ```
121 |
122 | ## Authors
123 |
124 | This code was written by the [Centre for Robotics](http://caor-mines-paristech.fr/en/home/) at
125 | MINES ParisTech, Paris, France.
126 |
127 | [Martin
128 | Brossard](mailto:martin.brossard@mines-paristech.fr)^, [Axel
129 | Barrau](mailto:axel.barrau@safrangroup.com)^ and [Silvère
130 | Bonnabel](mailto:silvere.bonnabel@mines-paristech.fr)^.
131 |
132 | ^[MINES ParisTech](http://www.mines-paristech.eu/), PSL Research University,
133 | Centre for Robotics, 60 Boulevard Saint-Michel, 75006 Paris, France.
134 |
135 | ## Biblio
136 |
137 | [1] M. Burri, J. Nikolic, P. Gohl, T. Schneider, J. Rehder, S. Omari,
138 | M. W. Achtelik, and R. Siegwart, "_The EuRoC Micro Aerial Vehicle
139 | Datasets_", The International Journal of Robotics Research, vol. 35,
140 | no. 10, pp. 1157–1163, 2016.
141 |
142 | [2] D. Schubert, T. Goll, N. Demmel, V. Usenko, J. Stuckler, and
143 | D. Cremers, "_The TUM VI Benchmark for Evaluating Visual-Inertial
144 | Odometry_", in International Conference on Intelligent Robots and
145 | Systems (IROS). IEEE, pp. 1680–1687, 2018.
146 |
147 | [3] P. Geneva, K. Eckenhoff, W. Lee, Y. Yang, and G. Huang, "_OpenVINS:
148 | A Research Platform for Visual-Inertial Estimation_", IROS Workshop
149 | on Visual-Inertial Navigation: Challenges and Applications, 2019.
150 |
151 | [4] T. Qin, P. Li, and S. Shen, "_VINS-Mono: A Robust and Versatile
152 | Monocular Visual-Inertial State Estimator_", IEEE Transactions on
153 | Robotics, vol. 34, no. 4, pp. 1004–1020, 2018.
154 |
155 | [5] M. Bloesch, M. Burri, S. Omari, M. Hutter, and R. Siegwart, "_Iterated
156 | Extended Kalman Filter Based Visual-Inertial Odometry Using Direct Photometric
157 | Feedback_", The International Journal of Robotics Research, vol. 36, no. 10, pp.
158 | 1053–1072, 2017.
159 |
--------------------------------------------------------------------------------
/figures/boxplot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/boxplot.jpg
--------------------------------------------------------------------------------
/figures/methode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/methode.jpg
--------------------------------------------------------------------------------
/figures/roe.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/roe.jpg
--------------------------------------------------------------------------------
/figures/rpy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/rpy.jpg
--------------------------------------------------------------------------------
/main_EUROC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import src.learning as lr
4 | import src.networks as sn
5 | import src.losses as sl
6 | import src.dataset as ds
7 | import numpy as np
8 |
9 | base_dir = os.path.dirname(os.path.realpath(__file__))
10 | data_dir = '/path/to/EUROC/dataset'
11 | # test a given network
12 | # address = os.path.join(base_dir, 'results/EUROC/2020_02_18_16_52_55/')
13 | # or test the last trained network
14 | address = "last"
15 | ################################################################################
16 | # Network parameters
17 | ################################################################################
18 | net_class = sn.GyroNet
19 | net_params = {
20 | 'in_dim': 6,
21 | 'out_dim': 3,
22 | 'c0': 16,
23 | 'dropout': 0.1,
24 | 'ks': [7, 7, 7, 7],
25 | 'ds': [4, 4, 4],
26 | 'momentum': 0.1,
27 | 'gyro_std': [1*np.pi/180, 2*np.pi/180, 5*np.pi/180],
28 | }
29 | ################################################################################
30 | # Dataset parameters
31 | ################################################################################
32 | dataset_class = ds.EUROCDataset
33 | dataset_params = {
34 |     # directory containing the raw dataset
35 |     'data_dir': data_dir,
36 |     # directory where the preprocessed (pickled) data are stored
37 |     'predata_dir': os.path.join(base_dir, 'data/EUROC'),
38 |     # train, validation and test sequences
39 | 'train_seqs': [
40 | 'MH_01_easy',
41 | 'MH_03_medium',
42 | 'MH_05_difficult',
43 | 'V1_02_medium',
44 | 'V2_01_easy',
45 | 'V2_03_difficult'
46 | ],
47 |     'val_seqs': [  # same sequences as training: validation uses the part of each sequence after the training window
48 | 'MH_01_easy',
49 | 'MH_03_medium',
50 | 'MH_05_difficult',
51 | 'V1_02_medium',
52 | 'V2_01_easy',
53 | 'V2_03_difficult',
54 | ],
55 | 'test_seqs': [
56 | 'MH_02_easy',
57 | 'MH_04_difficult',
58 | 'V2_02_medium',
59 | 'V1_03_difficult',
60 | 'V1_01_easy',
61 | ],
62 | # size of trajectory during training
63 | 'N': 32 * 500, # should be integer * 'max_train_freq'
64 | 'min_train_freq': 16,
65 | 'max_train_freq': 32,
66 | }
67 | ################################################################################
68 | # Training parameters
69 | ################################################################################
70 | train_params = {
71 | 'optimizer_class': torch.optim.Adam,
72 | 'optimizer': {
73 | 'lr': 0.01,
74 | 'weight_decay': 1e-1,
75 | 'amsgrad': False,
76 | },
77 | 'loss_class': sl.GyroLoss,
78 | 'loss': {
79 | 'min_N': int(np.log2(dataset_params['min_train_freq'])),
80 | 'max_N': int(np.log2(dataset_params['max_train_freq'])),
81 | 'w': 1e6,
82 | 'target': 'rotation matrix',
83 | 'huber': 0.005,
84 | 'dt': 0.005,
85 | },
86 | 'scheduler_class': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
87 | 'scheduler': {
88 | 'T_0': 600,
89 | 'T_mult': 2,
90 | 'eta_min': 1e-3,
91 | },
92 | 'dataloader': {
93 | 'batch_size': 10,
94 | 'pin_memory': False,
95 | 'num_workers': 0,
96 | 'shuffle': False,
97 | },
98 | # frequency of validation step
99 | 'freq_val': 600,
100 | # total number of epochs
101 | 'n_epochs': 1800,
102 |     # directory where results are saved
103 |     'res_dir': os.path.join(base_dir, "results/EUROC"),
104 |     # directory where the TensorBoard log is saved
105 |     'tb_dir': os.path.join(base_dir, "results/runs/EUROC"),
106 | }
107 | ################################################################################
108 | # Train on training data set
109 | ################################################################################
110 | # learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'],
111 | # train_params['tb_dir'], net_class, net_params, None,
112 | # train_params['loss']['dt'])
113 | # learning_process.train(dataset_class, dataset_params, train_params)
114 | ################################################################################
115 | # Test on full data set
116 | ################################################################################
117 | learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'],
118 | train_params['tb_dir'], net_class, net_params, address=address,
119 | dt=train_params['loss']['dt'])
120 | learning_process.test(dataset_class, dataset_params, ['test'])
--------------------------------------------------------------------------------
/main_TUMVI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import src.learning as lr
4 | import src.networks as sn
5 | import src.losses as sl
6 | import src.dataset as ds
7 | import numpy as np
8 |
9 | base_dir = os.path.dirname(os.path.realpath(__file__))
10 | data_dir = '/path/to/TUM/dataset'
11 | # test a given network
12 | # address = os.path.join(base_dir, 'results/TUM/2020_02_18_16_26_33')
13 | # or test the last trained network
14 | address = 'last'
15 | ################################################################################
16 | # Network parameters
17 | ################################################################################
18 | net_class = sn.GyroNet
19 | net_params = {
20 | 'in_dim': 6,
21 | 'out_dim': 3,
22 | 'c0': 16,
23 | 'dropout': 0.1,
24 | 'ks': [7, 7, 7, 7],
25 | 'ds': [4, 4, 4],
26 | 'momentum': 0.1,
27 | 'gyro_std': [0.2*np.pi/180, 0.2*np.pi/180, 0.2*np.pi/180],
28 | }
29 | ################################################################################
30 | # Dataset parameters
31 | ################################################################################
32 | dataset_class = ds.TUMVIDataset
33 | dataset_params = {
34 |     # directory containing the raw dataset
35 |     'data_dir': data_dir,
36 |     # directory where the preprocessed (pickled) data are stored
37 |     'predata_dir': os.path.join(base_dir, 'data/TUM'),
38 |     # train, validation and test sequences
39 | 'train_seqs': [
40 | 'dataset-room1_512_16',
41 | 'dataset-room3_512_16',
42 | 'dataset-room5_512_16',
43 | ],
44 | 'val_seqs': [
45 | 'dataset-room2_512_16',
46 | 'dataset-room4_512_16',
47 | 'dataset-room6_512_16',
48 | ],
49 | 'test_seqs': [
50 | 'dataset-room2_512_16',
51 | 'dataset-room4_512_16',
52 | 'dataset-room6_512_16'
53 | ],
54 | # size of trajectory during training
55 | 'N': 32 * 500, # should be integer * 'max_train_freq'
56 | 'min_train_freq': 16,
57 | 'max_train_freq': 32,
58 | }
59 | ################################################################################
60 | # Training parameters
61 | ################################################################################
62 | train_params = {
63 | 'optimizer_class': torch.optim.Adam,
64 | 'optimizer': {
65 | 'lr': 0.01,
66 | 'weight_decay': 1e-1,
67 | 'amsgrad': False,
68 | },
69 | 'loss_class': sl.GyroLoss,
70 | 'loss': {
71 | 'min_N': int(np.log2(dataset_params['min_train_freq'])),
72 | 'max_N': int(np.log2(dataset_params['max_train_freq'])),
73 | 'w': 1e6,
74 | 'target': 'rotation matrix mask',
75 | 'huber': 0.005,
76 | 'dt': 0.005,
77 | },
78 | 'scheduler_class': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
79 | 'scheduler': {
80 | 'T_0': 600,
81 | 'T_mult': 2,
82 | 'eta_min': 1e-3,
83 | },
84 | 'dataloader': {
85 | 'batch_size': 10,
86 | 'pin_memory': False,
87 | 'num_workers': 0,
88 | 'shuffle': False,
89 | },
90 | # frequency of validation step
91 | 'freq_val': 600,
92 | # total number of epochs
93 | 'n_epochs': 1800,
94 |     # directory where results are saved
95 |     'res_dir': os.path.join(base_dir, "results/TUM"),
96 |     # directory where the TensorBoard log is saved
97 |     'tb_dir': os.path.join(base_dir, "results/runs/TUM"),
98 | }
99 | ################################################################################
100 | # Train on training data set
101 | ################################################################################
102 | # learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'],
103 | #     train_params['tb_dir'], net_class, net_params, None,
104 | # train_params['loss']['dt'])
105 | # learning_process.train(dataset_class, dataset_params, train_params)
106 | ################################################################################
107 | # Test on full data set
108 | ################################################################################
109 | learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'],
110 | train_params['tb_dir'], net_class, net_params, address=address,
111 | dt=train_params['loss']['dt'])
112 | learning_process.test(dataset_class, dataset_params, ['test'])
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | scipy
4 | matplotlib
5 | termcolor
6 | pyyaml
7 | tensorboard
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | from src.utils import pdump, pload, bmtv, bmtm
2 | from src.lie_algebra import SO3
3 | from termcolor import cprint
4 | from torch.utils.data.dataset import Dataset
5 | from scipy.interpolate import interp1d
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import pickle
9 | import os
10 | import torch
11 | import sys
12 |
13 | class BaseDataset(Dataset):
14 |
15 | def __init__(self, predata_dir, train_seqs, val_seqs, test_seqs, mode, N,
16 | min_train_freq=128, max_train_freq=512, dt=0.005):
17 | super().__init__()
18 | # where record pre loaded data
19 | self.predata_dir = predata_dir
20 | self.path_normalize_factors = os.path.join(predata_dir, 'nf.p')
21 |
22 | self.mode = mode
23 | # choose between training, validation or test sequences
24 | train_seqs, self.sequences = self.get_sequences(train_seqs, val_seqs,
25 | test_seqs)
26 | # get and compute value for normalizing inputs
27 | self.mean_u, self.std_u = self.init_normalize_factors(train_seqs)
28 | self.mode = mode # train, val or test
29 | self._train = False
30 | self._val = False
31 | # noise density
32 | self.imu_std = torch.Tensor([8e-5, 1e-3]).float()
33 | # bias repeatability (without in-run bias stability)
34 | self.imu_b0 = torch.Tensor([1e-3, 1e-3]).float()
35 | # IMU sampling time
36 | self.dt = dt # (s)
37 | # sequence size during training
38 |         self.N = N # should be a multiple of max_train_freq
39 | self.min_train_freq = min_train_freq
40 | self.max_train_freq = max_train_freq
41 | self.uni = torch.distributions.uniform.Uniform(-torch.ones(1),
42 | torch.ones(1))
43 |
44 | def get_sequences(self, train_seqs, val_seqs, test_seqs):
45 | """Choose sequence list depending on dataset mode"""
46 | sequences_dict = {
47 | 'train': train_seqs,
48 | 'val': val_seqs,
49 | 'test': test_seqs,
50 | }
51 | return sequences_dict['train'], sequences_dict[self.mode]
52 |
53 | def __getitem__(self, i):
54 | mondict = self.load_seq(i)
55 | N_max = mondict['xs'].shape[0]
56 | if self._train: # random start
57 | n0 = torch.randint(0, self.max_train_freq, (1, ))
58 | nend = n0 + self.N
59 | elif self._val: # end sequence
60 | n0 = self.max_train_freq + self.N
61 | nend = N_max - ((N_max - n0) % self.max_train_freq)
62 | else: # full sequence
63 | n0 = 0
64 | nend = N_max - (N_max % self.max_train_freq)
65 | u = mondict['us'][n0: nend]
66 | x = mondict['xs'][n0: nend]
67 | return u, x
68 |
69 | def __len__(self):
70 | return len(self.sequences)
71 |
72 | def add_noise(self, u):
73 | """Add Gaussian noise and bias to input"""
74 | noise = torch.randn_like(u)
75 | noise[:, :, :3] = noise[:, :, :3] * self.imu_std[0]
76 | noise[:, :, 3:6] = noise[:, :, 3:6] * self.imu_std[1]
77 |
78 | # bias repeatability (without in run bias stability)
79 |         b0 = self.uni.sample(u[:, 0].shape).to(u.device)  # shape (batch, 6, 1)
80 |         b0[:, :3] = b0[:, :3] * self.imu_b0[0]
81 |         b0[:, 3:6] = b0[:, 3:6] * self.imu_b0[1]
82 | u = u + noise + b0.transpose(1, 2)
83 | return u
84 |
85 | def init_train(self):
86 | self._train = True
87 | self._val = False
88 |
89 | def init_val(self):
90 | self._train = False
91 | self._val = True
92 |
93 |     def length(self):
94 |         return len(self.sequences)  # `_length` was never defined
95 |
96 | def load_seq(self, i):
97 | return pload(self.predata_dir, self.sequences[i] + '.p')
98 |
99 | def load_gt(self, i):
100 | return pload(self.predata_dir, self.sequences[i] + '_gt.p')
101 |
102 | def init_normalize_factors(self, train_seqs):
103 | if os.path.exists(self.path_normalize_factors):
104 | mondict = pload(self.path_normalize_factors)
105 | return mondict['mean_u'], mondict['std_u']
106 |
107 | path = os.path.join(self.predata_dir, train_seqs[0] + '.p')
108 | if not os.path.exists(path):
109 | print("init_normalize_factors not computed")
110 | return 0, 0
111 |
112 | print('Start computing normalizing factors ...')
113 | cprint("Do it only on training sequences, it is vital!", 'yellow')
114 | # first compute mean
115 | num_data = 0
116 |
117 | for i, sequence in enumerate(train_seqs):
118 | pickle_dict = pload(self.predata_dir, sequence + '.p')
119 | us = pickle_dict['us']
120 | sms = pickle_dict['xs']
121 | if i == 0:
122 | mean_u = us.sum(dim=0)
123 | num_positive = sms.sum(dim=0)
124 | num_negative = sms.shape[0] - sms.sum(dim=0)
125 | else:
126 | mean_u += us.sum(dim=0)
127 | num_positive += sms.sum(dim=0)
128 | num_negative += sms.shape[0] - sms.sum(dim=0)
129 | num_data += us.shape[0]
130 | mean_u = mean_u / num_data
131 | pos_weight = num_negative / num_positive
132 |
133 | # second compute standard deviation
134 | for i, sequence in enumerate(train_seqs):
135 | pickle_dict = pload(self.predata_dir, sequence + '.p')
136 | us = pickle_dict['us']
137 | if i == 0:
138 | std_u = ((us - mean_u) ** 2).sum(dim=0)
139 | else:
140 | std_u += ((us - mean_u) ** 2).sum(dim=0)
141 | std_u = (std_u / num_data).sqrt()
142 | normalize_factors = {
143 | 'mean_u': mean_u,
144 | 'std_u': std_u,
145 | }
146 | print('... ended computing normalizing factors')
147 | print('pos_weight:', pos_weight)
148 |         print('These values must be training parameters!')
149 | print('mean_u :', mean_u)
150 | print('std_u :', std_u)
151 | print('num_data :', num_data)
152 | pdump(normalize_factors, self.path_normalize_factors)
153 | return mean_u, std_u
154 |
155 | def read_data(self, data_dir):
156 | raise NotImplementedError
157 |
158 | @staticmethod
159 | def interpolate(x, t, t_int):
160 | """
161 | Interpolate ground truth at the sensor timestamps
162 | """
163 |
164 | # vector interpolation
165 | x_int = np.zeros((t_int.shape[0], x.shape[1]))
166 | for i in range(x.shape[1]):
167 | if i in [4, 5, 6, 7]:
168 | continue
169 | x_int[:, i] = np.interp(t_int, t, x[:, i])
170 | # quaternion interpolation
171 | t_int = torch.Tensor(t_int - t[0])
172 | t = torch.Tensor(t - t[0])
173 | qs = SO3.qnorm(torch.Tensor(x[:, 4:8]))
174 | x_int[:, 4:8] = SO3.qinterp(qs, t, t_int).numpy()
175 | return x_int
176 |
177 |
178 | class EUROCDataset(BaseDataset):
179 | """
180 | Dataloader for the EUROC Data Set.
181 | """
182 |
183 | def __init__(self, data_dir, predata_dir, train_seqs, val_seqs,
184 | test_seqs, mode, N, min_train_freq, max_train_freq, dt=0.005):
185 | super().__init__(predata_dir, train_seqs, val_seqs, test_seqs, mode, N, min_train_freq, max_train_freq, dt)
186 | # convert raw data to pre loaded data
187 | self.read_data(data_dir)
188 |
189 | def read_data(self, data_dir):
190 | r"""Read the data from the dataset"""
191 |
192 | f = os.path.join(self.predata_dir, 'MH_01_easy.p')
193 |         if os.path.exists(f):  # preprocessed data already exist
194 | return
195 |
196 | print("Start read_data, be patient please")
197 | def set_path(seq):
198 | path_imu = os.path.join(data_dir, seq, "mav0", "imu0", "data.csv")
199 | path_gt = os.path.join(data_dir, seq, "mav0", "state_groundtruth_estimate0", "data.csv")
200 | return path_imu, path_gt
201 |
202 | sequences = os.listdir(data_dir)
203 | # read each sequence
204 | for sequence in sequences:
205 | print("\nSequence name: " + sequence)
206 | path_imu, path_gt = set_path(sequence)
207 | imu = np.genfromtxt(path_imu, delimiter=",", skip_header=1)
208 | gt = np.genfromtxt(path_gt, delimiter=",", skip_header=1)
209 |
210 | # time synchronization between IMU and ground truth
211 | t0 = np.max([gt[0, 0], imu[0, 0]])
212 | t_end = np.min([gt[-1, 0], imu[-1, 0]])
213 |
214 | # start index
215 | idx0_imu = np.searchsorted(imu[:, 0], t0)
216 | idx0_gt = np.searchsorted(gt[:, 0], t0)
217 |
218 | # end index
219 | idx_end_imu = np.searchsorted(imu[:, 0], t_end, 'right')
220 | idx_end_gt = np.searchsorted(gt[:, 0], t_end, 'right')
221 |
222 | # subsample
223 | imu = imu[idx0_imu: idx_end_imu]
224 | gt = gt[idx0_gt: idx_end_gt]
225 | ts = imu[:, 0]/1e9
226 |
227 | # interpolate
228 | gt = self.interpolate(gt, gt[:, 0]/1e9, ts)
229 |
230 | # take ground truth position
231 | p_gt = gt[:, 1:4]
232 | p_gt = p_gt - p_gt[0]
233 |
234 |             # take ground-truth orientation quaternion
235 | q_gt = torch.Tensor(gt[:, 4:8]).double()
236 | q_gt = q_gt / q_gt.norm(dim=1, keepdim=True)
237 | Rot_gt = SO3.from_quaternion(q_gt.cuda(), ordering='wxyz').cpu()
238 |
239 | # convert from numpy
240 | p_gt = torch.Tensor(p_gt).double()
241 | v_gt = torch.tensor(gt[:, 8:11]).double()
242 | imu = torch.Tensor(imu[:, 1:]).double()
243 |
244 | # compute pre-integration factors for all training
245 | mtf = self.min_train_freq
246 | dRot_ij = bmtm(Rot_gt[:-mtf], Rot_gt[mtf:])
247 | dRot_ij = SO3.dnormalize(dRot_ij.cuda())
248 | dxi_ij = SO3.log(dRot_ij).cpu()
249 |
250 | # save for all training
251 | mondict = {
252 | 'xs': dxi_ij.float(),
253 | 'us': imu.float(),
254 | }
255 | pdump(mondict, self.predata_dir, sequence + ".p")
256 | # save ground truth
257 | mondict = {
258 | 'ts': ts,
259 | 'qs': q_gt.float(),
260 | 'vs': v_gt.float(),
261 | 'ps': p_gt.float(),
262 | }
263 | pdump(mondict, self.predata_dir, sequence + "_gt.p")
264 |
265 |
266 | class TUMVIDataset(BaseDataset):
267 | """
268 | Dataloader for the TUM-VI Data Set.
269 | """
270 |
271 | def __init__(self, data_dir, predata_dir, train_seqs, val_seqs,
272 | test_seqs, mode, N, min_train_freq, max_train_freq, dt=0.005):
273 | super().__init__(predata_dir, train_seqs, val_seqs, test_seqs, mode, N,
274 | min_train_freq, max_train_freq, dt)
275 | # convert raw data to pre loaded data
276 | self.read_data(data_dir)
277 | # noise density
278 | self.imu_std = torch.Tensor([8e-5, 1e-3]).float()
279 | # bias repeatability (without in-run bias stability)
280 | self.imu_b0 = torch.Tensor([1e-3, 1e-3]).float()
281 |
282 | def read_data(self, data_dir):
283 | r"""Read the data from the dataset"""
284 |
285 | f = os.path.join(self.predata_dir, 'dataset-room1_512_16_gt.p')
286 |         if os.path.exists(f):  # preprocessed data already exist
287 | return
288 |
289 | print("Start read_data, be patient please")
290 | def set_path(seq):
291 | path_imu = os.path.join(data_dir, seq, "mav0", "imu0", "data.csv")
292 | path_gt = os.path.join(data_dir, seq, "mav0", "mocap0", "data.csv")
293 | return path_imu, path_gt
294 |
295 | sequences = os.listdir(data_dir)
296 |
297 | # read each sequence
298 | for sequence in sequences:
299 | print("\nSequence name: " + sequence)
300 | if 'room' not in sequence:
301 | continue
302 |
303 | path_imu, path_gt = set_path(sequence)
304 | imu = np.genfromtxt(path_imu, delimiter=",", skip_header=1)
305 | gt = np.genfromtxt(path_gt, delimiter=",", skip_header=1)
306 |
307 | # time synchronization between IMU and ground truth
308 | t0 = np.max([gt[0, 0], imu[0, 0]])
309 | t_end = np.min([gt[-1, 0], imu[-1, 0]])
310 |
311 | # start index
312 | idx0_imu = np.searchsorted(imu[:, 0], t0)
313 | idx0_gt = np.searchsorted(gt[:, 0], t0)
314 |
315 | # end index
316 | idx_end_imu = np.searchsorted(imu[:, 0], t_end, 'right')
317 | idx_end_gt = np.searchsorted(gt[:, 0], t_end, 'right')
318 |
319 | # subsample
320 | imu = imu[idx0_imu: idx_end_imu]
321 | gt = gt[idx0_gt: idx_end_gt]
322 | ts = imu[:, 0]/1e9
323 |
324 | # interpolate
325 | t_gt = gt[:, 0]/1e9
326 | gt = self.interpolate(gt, t_gt, ts)
327 |
328 | # take ground truth position
329 | p_gt = gt[:, 1:4]
330 | p_gt = p_gt - p_gt[0]
331 |
332 |             # take ground-truth orientation quaternion
333 | q_gt = SO3.qnorm(torch.Tensor(gt[:, 4:8]).double())
334 | Rot_gt = SO3.from_quaternion(q_gt.cuda(), ordering='wxyz').cpu()
335 |
336 | # convert from numpy
337 | p_gt = torch.Tensor(p_gt).double()
338 | v_gt = torch.zeros_like(p_gt).double()
339 | v_gt[1:] = (p_gt[1:]-p_gt[:-1])/self.dt
340 | imu = torch.Tensor(imu[:, 1:]).double()
341 |
342 | # compute pre-integration factors for all training
343 | mtf = self.min_train_freq
344 | dRot_ij = bmtm(Rot_gt[:-mtf], Rot_gt[mtf:])
345 | dRot_ij = SO3.dnormalize(dRot_ij.cuda())
346 | dxi_ij = SO3.log(dRot_ij).cpu()
347 |
348 | # masks with 1 when ground truth is available, 0 otherwise
349 | masks = dxi_ij.new_ones(dxi_ij.shape[0])
350 | tmp = np.searchsorted(t_gt, ts[:-mtf])
351 | diff_t = ts[:-mtf] - t_gt[tmp]
352 | masks[np.abs(diff_t) > 0.01] = 0
353 |
354 | # save all the sequence
355 | mondict = {
356 | 'xs': torch.cat((dxi_ij, masks.unsqueeze(1)), 1).float(),
357 | 'us': imu.float(),
358 | }
359 | pdump(mondict, self.predata_dir, sequence + ".p")
360 |
361 | # save ground truth
362 | mondict = {
363 | 'ts': ts,
364 | 'qs': q_gt.float(),
365 | 'vs': v_gt.float(),
366 | 'ps': p_gt.float(),
367 | }
368 | pdump(mondict, self.predata_dir, sequence + "_gt.p")
369 |
--------------------------------------------------------------------------------
/src/learning.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import time
4 | import matplotlib.pyplot as plt
5 | plt.rcParams["legend.loc"] = "upper right"
6 | plt.rcParams['axes.titlesize'] = 'x-large'
7 | plt.rcParams['axes.labelsize'] = 'x-large'
8 | plt.rcParams['legend.fontsize'] = 'x-large'
9 | plt.rcParams['xtick.labelsize'] = 'x-large'
10 | plt.rcParams['ytick.labelsize'] = 'x-large'
11 | from termcolor import cprint
12 | import numpy as np
13 | import os
14 | from torch.utils.tensorboard import SummaryWriter
15 | from torch.utils.data import DataLoader
16 | from src.utils import pload, pdump, yload, ydump, mkdir, bmv
17 | from src.utils import bmtm, bmtv, bmmt
18 | from datetime import datetime
19 | from src.lie_algebra import SO3, CPUSO3
20 |
21 |
22 | class LearningBasedProcessing:
23 | def __init__(self, res_dir, tb_dir, net_class, net_params, address, dt):
24 | self.res_dir = res_dir
25 | self.tb_dir = tb_dir
26 | self.net_class = net_class
27 | self.net_params = net_params
28 | self._ready = False
29 | self.train_params = {}
30 | self.figsize = (20, 12)
31 | self.dt = dt # (s)
32 | self.address, self.tb_address = self.find_address(address)
33 | if address is None: # create new address
34 | pdump(self.net_params, self.address, 'net_params.p')
35 | ydump(self.net_params, self.address, 'net_params.yaml')
36 | else: # pick the network parameters
37 | self.net_params = pload(self.address, 'net_params.p')
38 | self.train_params = pload(self.address, 'train_params.p')
39 | self._ready = True
40 | self.path_weights = os.path.join(self.address, 'weights.pt')
41 | self.net = self.net_class(**self.net_params)
42 | if self._ready: # fill network parameters
43 | self.load_weights()
44 |
45 | def find_address(self, address):
46 | """return path where net and training info are saved"""
47 | if address == 'last':
48 | addresses = sorted(os.listdir(self.res_dir))
49 | tb_address = os.path.join(self.tb_dir, str(len(addresses)))
50 | address = os.path.join(self.res_dir, addresses[-1])
51 | elif address is None:
52 | now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
53 | address = os.path.join(self.res_dir, now)
54 | mkdir(address)
55 | tb_address = os.path.join(self.tb_dir, now)
56 | else:
57 | tb_address = None
58 | return address, tb_address
59 |
60 | def load_weights(self):
61 | weights = torch.load(self.path_weights)
62 | self.net.load_state_dict(weights)
63 | self.net.cuda()
64 |
65 | def train(self, dataset_class, dataset_params, train_params):
66 | """train the neural network. GPU is assumed"""
67 | self.train_params = train_params
68 | pdump(self.train_params, self.address, 'train_params.p')
69 | ydump(self.train_params, self.address, 'train_params.yaml')
70 |
71 | hparams = self.get_hparams(dataset_class, dataset_params, train_params)
72 | ydump(hparams, self.address, 'hparams.yaml')
73 |
74 | # define datasets
75 | dataset_train = dataset_class(**dataset_params, mode='train')
76 | dataset_train.init_train()
77 | dataset_val = dataset_class(**dataset_params, mode='val')
78 | dataset_val.init_val()
79 |
80 | # get class
81 | Optimizer = train_params['optimizer_class']
82 | Scheduler = train_params['scheduler_class']
83 | Loss = train_params['loss_class']
84 |
85 | # get parameters
86 | dataloader_params = train_params['dataloader']
87 | optimizer_params = train_params['optimizer']
88 | scheduler_params = train_params['scheduler']
89 | loss_params = train_params['loss']
90 |
91 | # define optimizer, scheduler and loss
92 | dataloader = DataLoader(dataset_train, **dataloader_params)
93 | optimizer = Optimizer(self.net.parameters(), **optimizer_params)
94 | scheduler = Scheduler(optimizer, **scheduler_params)
95 | criterion = Loss(**loss_params)
96 |
97 | # remaining training parameters
98 | freq_val = train_params['freq_val']
99 | n_epochs = train_params['n_epochs']
100 |
101 | # init net w.r.t dataset
102 | self.net = self.net.cuda()
103 | mean_u, std_u = dataset_train.mean_u, dataset_train.std_u
104 | self.net.set_normalized_factors(mean_u, std_u)
105 |
106 | # start tensorboard writer
107 | writer = SummaryWriter(self.tb_address)
108 | start_time = time.time()
109 | best_loss = torch.Tensor([float('Inf')])
110 |
111 | # define helper functions to monitor training progress
112 | def write(epoch, loss_epoch):
113 | writer.add_scalar('loss/train', loss_epoch.item(), epoch)
114 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
115 | print('Train Epoch: {:2d} \tLoss: {:.4f}'.format(
116 | epoch, loss_epoch.item()))
117 | # the scheduler is stepped once per epoch in the training loop below
118 |
119 | def write_time(epoch, start_time):
120 | delta_t = time.time() - start_time
121 | print("Amount of time spent for epochs " +
122 | "{}-{}: {:.1f}s\n".format(epoch - freq_val, epoch, delta_t))
123 | writer.add_scalar('time_spend', delta_t, epoch)
124 |
125 | def write_val(loss, best_loss):
126 | if 0.5*loss <= best_loss:
127 | msg = 'validation loss decreases! :) '
128 | msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(),
129 | best_loss.item())
130 | cprint(msg, 'green')
131 | best_loss = loss
132 | self.save_net()
133 | else:
134 | msg = 'validation loss increases! :( '
135 | msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(),
136 | best_loss.item())
137 | cprint(msg, 'yellow')
138 | writer.add_scalar('loss/val', loss.item(), epoch)
139 | return best_loss
140 |
141 | # training loop !
142 | for epoch in range(1, n_epochs + 1):
143 | loss_epoch = self.loop_train(dataloader, optimizer, criterion)
144 | write(epoch, loss_epoch)
145 | scheduler.step(epoch)
146 | if epoch % freq_val == 0:
147 | loss = self.loop_val(dataset_val, criterion)
148 | write_time(epoch, start_time)
149 | best_loss = write_val(loss, best_loss)
150 | start_time = time.time()
151 | # training is over !
152 |
153 | # test on new data
154 | dataset_test = dataset_class(**dataset_params, mode='test')
155 | self.load_weights()
156 | test_loss = self.loop_val(dataset_test, criterion)
157 | dict_loss = {
158 | 'final_loss/val': best_loss.item(),
159 | 'final_loss/test': test_loss.item()
160 | }
161 | writer.add_hparams(hparams, dict_loss)
162 | ydump(dict_loss, self.address, 'final_loss.yaml')
163 | writer.close()
164 |
165 | def loop_train(self, dataloader, optimizer, criterion):
166 | """Forward-backward loop over training data"""
167 | loss_epoch = 0
168 | optimizer.zero_grad()
169 | for us, xs in dataloader:
170 | us = dataloader.dataset.add_noise(us.cuda())
171 | hat_xs = self.net(us)
172 | loss = criterion(xs.cuda(), hat_xs)/len(dataloader)
173 | loss.backward()
174 | loss_epoch += loss.detach().cpu()
175 | optimizer.step()
176 | return loss_epoch
177 |
178 | def loop_val(self, dataset, criterion):
179 | """Forward loop over validation data"""
180 | loss_epoch = 0
181 | self.net.eval()
182 | with torch.no_grad():
183 | for i in range(len(dataset)):
184 | us, xs = dataset[i]
185 | hat_xs = self.net(us.cuda().unsqueeze(0))
186 | loss = criterion(xs.cuda().unsqueeze(0), hat_xs)/len(dataset)
187 | loss_epoch += loss.cpu()
188 | self.net.train()
189 | return loss_epoch
190 |
191 | def save_net(self):
192 | """save the weights of the net on CPU"""
193 | self.net.eval().cpu()
194 | torch.save(self.net.state_dict(), self.path_weights)
195 | self.net.train().cuda()
196 |
197 | def get_hparams(self, dataset_class, dataset_params, train_params):
198 | """return all training hyperparameters in a dict"""
199 | Optimizer = train_params['optimizer_class']
200 | Scheduler = train_params['scheduler_class']
201 | Loss = train_params['loss_class']
202 |
203 | # get training class parameters
204 | dataloader_params = train_params['dataloader']
205 | optimizer_params = train_params['optimizer']
206 | scheduler_params = train_params['scheduler']
207 | loss_params = train_params['loss']
208 |
209 | # remaining training parameters
210 | freq_val = train_params['freq_val']
211 | n_epochs = train_params['n_epochs']
212 |
213 | dict_class = {
214 | 'Optimizer': str(Optimizer),
215 | 'Scheduler': str(Scheduler),
216 | 'Loss': str(Loss)
217 | }
218 |
219 | return {**dict_class, **dataloader_params, **optimizer_params,
220 | **loss_params, **scheduler_params,
221 | 'n_epochs': n_epochs, 'freq_val': freq_val}
222 |
223 | def test(self, dataset_class, dataset_params, modes):
224 | """test a network once training is over"""
225 |
226 | # get loss function
227 | Loss = self.train_params['loss_class']
228 | loss_params = self.train_params['loss']
229 | criterion = Loss(**loss_params)
230 |
231 | # test on each type of sequence
232 | for mode in modes:
233 | dataset = dataset_class(**dataset_params, mode=mode)
234 | self.loop_test(dataset, criterion)
235 | self.display_test(dataset, mode)
236 |
237 | def loop_test(self, dataset, criterion):
238 | """Forward loop over test data"""
239 | self.net.eval()
240 | for i in range(len(dataset)):
241 | seq = dataset.sequences[i]
242 | us, xs = dataset[i]
243 | with torch.no_grad():
244 | hat_xs = self.net(us.cuda().unsqueeze(0))
245 | loss = criterion(xs.cuda().unsqueeze(0), hat_xs)
246 | mkdir(self.address, seq)
247 | mondict = {
248 | 'hat_xs': hat_xs[0].cpu(),
249 | 'loss': loss.cpu().item(),
250 | }
251 | pdump(mondict, self.address, seq, 'results.p')
252 |
253 | def display_test(self, dataset, mode):
254 | raise NotImplementedError
255 |
256 |
257 | class GyroLearningBasedProcessing(LearningBasedProcessing):
258 | def __init__(self, res_dir, tb_dir, net_class, net_params, address, dt):
259 | super().__init__(res_dir, tb_dir, net_class, net_params, address, dt)
260 | self.roe_dist = [7, 14, 21, 28, 35] # m
261 | self.freq = 100 # subsampling frequency for ROE computation
262 | self.roes = { # relative orientation errors
263 | 'Rots': [],
264 | 'yaws': [],
265 | }
266 |
267 | def display_test(self, dataset, mode):
268 | self.roes = {
269 | 'Rots': [],
270 | 'yaws': [],
271 | }
272 | self.to_open_vins(dataset)
273 | for i, seq in enumerate(dataset.sequences):
274 | print('\n', 'Results for sequence ' + seq )
275 | self.seq = seq
276 | # get ground truth
277 | self.gt = dataset.load_gt(i)
278 | Rots = SO3.from_quaternion(self.gt['qs'].cuda())
279 | self.gt['Rots'] = Rots.cpu()
280 | self.gt['rpys'] = SO3.to_rpy(Rots).cpu()
281 | # get data and estimate
282 | self.net_us = pload(self.address, seq, 'results.p')['hat_xs']
283 | self.raw_us, _ = dataset[i]
284 | N = self.net_us.shape[0]
285 | self.gyro_corrections = (self.raw_us[:, :3] - self.net_us[:N, :3])
286 | self.ts = torch.linspace(0, N*self.dt, N)
287 |
288 | self.convert()
289 | self.plot_gyro()
290 | self.plot_gyro_correction()
291 | plt.show()
292 |
293 | def to_open_vins(self, dataset):
294 | """
295 | Export results to Open-VINS format. Use them with the eval toolbox available
296 | at https://github.com/rpng/open_vins/
297 | """
298 |
299 | for i, seq in enumerate(dataset.sequences):
300 | self.seq = seq
301 | # get ground truth
302 | self.gt = dataset.load_gt(i)
303 | raw_us, _ = dataset[i]
304 | net_us = pload(self.address, seq, 'results.p')['hat_xs']
305 | N = net_us.shape[0]
306 | net_qs, imu_Rots, net_Rots = self.integrate_with_quaternions_superfast(N, raw_us, net_us)
307 | path = os.path.join(self.address, seq + '.txt')
308 | header = "timestamp(s) tx ty tz qx qy qz qw"
309 | x = np.zeros((net_qs.shape[0], 8))
310 | x[:, 0] = self.gt['ts'][:net_qs.shape[0]]
311 | x[:, [7, 4, 5, 6]] = net_qs
312 | np.savetxt(path, x[::10], header=header, delimiter=" ",
313 | fmt='%1.9f')
314 |
315 | def convert(self):
316 | # s -> min
317 | l = 1/60
318 | self.ts *= l
319 |
320 | # rad -> deg
321 | l = 180/np.pi
322 | self.gyro_corrections *= l
323 | self.gt['rpys'] *= l
324 |
325 | def integrate_with_quaternions_superfast(self, N, raw_us, net_us):
326 | imu_qs = SO3.qnorm(SO3.qexp(raw_us[:, :3].cuda().double()*self.dt))
327 | net_qs = SO3.qnorm(SO3.qexp(net_us[:, :3].cuda().double()*self.dt))
328 | Rot0 = SO3.qnorm(self.gt['qs'][:2].cuda().double())
329 | imu_qs[0] = Rot0[0]
330 | net_qs[0] = Rot0[0]
331 |
332 | N = np.log2(imu_qs.shape[0])
333 | for i in range(int(N)):
334 | k = 2**i
335 | imu_qs[k:] = SO3.qnorm(SO3.qmul(imu_qs[:-k], imu_qs[k:]))
336 | net_qs[k:] = SO3.qnorm(SO3.qmul(net_qs[:-k], net_qs[k:]))
337 |
338 | if int(N) < N:
339 | k = 2**int(N)
340 | k2 = imu_qs[k:].shape[0]
341 | imu_qs[k:] = SO3.qnorm(SO3.qmul(imu_qs[:k2], imu_qs[k:]))
342 | net_qs[k:] = SO3.qnorm(SO3.qmul(net_qs[:k2], net_qs[k:]))
343 |
344 | imu_Rots = SO3.from_quaternion(imu_qs).float()
345 | net_Rots = SO3.from_quaternion(net_qs).float()
346 | return net_qs.cpu(), imu_Rots, net_Rots
347 |
348 | def plot_gyro(self):
349 | N = self.raw_us.shape[0]
350 | raw_us = self.raw_us[:, :3]
351 | net_us = self.net_us[:, :3]
352 |
353 | net_qs, imu_Rots, net_Rots = self.integrate_with_quaternions_superfast(N,
354 | raw_us, net_us)
355 | imu_rpys = 180/np.pi*SO3.to_rpy(imu_Rots).cpu()
356 | net_rpys = 180/np.pi*SO3.to_rpy(net_Rots).cpu()
357 | self.plot_orientation(imu_rpys, net_rpys, N)
358 | self.plot_orientation_error(imu_Rots, net_Rots, N)
359 |
360 | def plot_orientation(self, imu_rpys, net_rpys, N):
361 | title = "Orientation estimation"
362 | gt = self.gt['rpys'][:N]
363 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=self.figsize)
364 | axs[0].set(ylabel='roll (deg)', title=title)
365 | axs[1].set(ylabel='pitch (deg)')
366 | axs[2].set(xlabel='$t$ (min)', ylabel='yaw (deg)')
367 |
368 | for i in range(3):
369 | axs[i].plot(self.ts, gt[:, i], color='black', label=r'ground truth')
370 | axs[i].plot(self.ts, imu_rpys[:, i], color='red', label=r'raw IMU')
371 | axs[i].plot(self.ts, net_rpys[:, i], color='blue', label=r'net IMU')
372 | axs[i].set_xlim(self.ts[0], self.ts[-1])
373 | self.savefig(axs, fig, 'orientation')
374 |
375 | def plot_orientation_error(self, imu_Rots, net_Rots, N):
376 | gt = self.gt['Rots'][:N].cuda()
377 | raw_err = 180/np.pi*SO3.log(bmtm(imu_Rots, gt)).cpu()
378 | net_err = 180/np.pi*SO3.log(bmtm(net_Rots, gt)).cpu()
379 | title = "$SO(3)$ orientation error"
380 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=self.figsize)
381 | axs[0].set(ylabel='roll (deg)', title=title)
382 | axs[1].set(ylabel='pitch (deg)')
383 | axs[2].set(xlabel='$t$ (min)', ylabel='yaw (deg)')
384 |
385 | for i in range(3):
386 | axs[i].plot(self.ts, raw_err[:, i], color='red', label=r'raw IMU')
387 | axs[i].plot(self.ts, net_err[:, i], color='blue', label=r'net IMU')
388 | axs[i].set_ylim(-10, 10)
389 | axs[i].set_xlim(self.ts[0], self.ts[-1])
390 | self.savefig(axs, fig, 'orientation_error')
391 |
392 | def plot_gyro_correction(self):
393 | title = "Gyro correction" + self.end_title
394 | ylabel = 'gyro correction (deg/s)'
395 | fig, ax = plt.subplots(figsize=self.figsize)
396 | ax.set(xlabel='$t$ (min)', ylabel=ylabel, title=title)
397 | plt.plot(self.ts, self.gyro_corrections, label=r'net IMU')
398 | ax.set_xlim(self.ts[0], self.ts[-1])
399 | self.savefig(ax, fig, 'gyro_correction')
400 |
401 | @property
402 | def end_title(self):
403 | return " for sequence " + self.seq.replace("_", " ")
404 |
405 | def savefig(self, axs, fig, name):
406 | if isinstance(axs, np.ndarray):
407 | for i in range(len(axs)):
408 | axs[i].grid()
409 | axs[i].legend()
410 | else:
411 | axs.grid()
412 | axs.legend()
413 | fig.tight_layout()
414 | fig.savefig(os.path.join(self.address, self.seq, name + '.png'))
415 |
416 |
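Note on `integrate_with_quaternions_superfast` above: rather than composing the per-sample quaternion increments one after another, it repeatedly multiplies the sequence by a copy of itself shifted by 1, 2, 4, ... samples, so every cumulative product is ready after about log2(N) batched steps. The sketch below checks this doubling scheme against a sequential loop using plain tuples. It is a minimal illustration, not the repository's implementation: `qmul`, `scan_sequential`, `scan_doubling` and `axis_angle` are hypothetical names, and the renormalisation and sign handling done by `SO3.qmul` and `SO3.qnorm` are omitted.

```python
import math


def qmul(q, r):
    """Hamilton product q * r for (w, x, y, z) tuples."""
    qw, qx, qy, qz = q
    rw, rx, ry, rz = r
    return (qw*rw - qx*rx - qy*ry - qz*rz,
            qw*rx + qx*rw + qy*rz - qz*ry,
            qw*ry - qx*rz + qy*rw + qz*rx,
            qw*rz + qx*ry - qy*rx + qz*rw)


def scan_sequential(qs):
    """Cumulative products q0, q0*q1, q0*q1*q2, ... in N - 1 steps."""
    out = [qs[0]]
    for q in qs[1:]:
        out.append(qmul(out[-1], q))
    return out


def scan_doubling(qs):
    """Same cumulative products in about log2(N) passes over the list."""
    out = list(qs)
    k = 1
    while k < len(out):
        # element i >= k absorbs the partial product that ends at i - k
        out = out[:k] + [qmul(out[i - k], out[i])
                         for i in range(k, len(out))]
        k *= 2
    return out


def axis_angle(axis, angle):
    """Unit quaternion for a rotation of `angle` rad about a unit axis."""
    s = math.sin(0.5*angle)
    return (math.cos(0.5*angle), axis[0]*s, axis[1]*s, axis[2]*s)


axes = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
# 11 increments: deliberately not a power of two
qs = [axis_angle(axes[i % 3], 0.1*(i + 1)) for i in range(11)]
seq = scan_sequential(qs)
fast = scan_doubling(qs)
max_gap = max(abs(a - b) for s, f in zip(seq, fast) for a, b in zip(s, f))
print(max_gap)
```

The two results agree up to rounding because quaternion multiplication is associative. Each pass in the list version maps to one batched tensor operation, which is where the speed-up on a GPU would come from.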
--------------------------------------------------------------------------------
/src/lie_algebra.py:
--------------------------------------------------------------------------------
1 | from src.utils import *
2 | import numpy as np
3 |
4 |
5 | class SO3:
6 | # tolerance criterion
7 | TOL = 1e-8
8 | Id = torch.eye(3).cuda().float()
9 | dId = torch.eye(3).cuda().double()
10 |
11 | @classmethod
12 | def exp(cls, phi):
13 | angle = phi.norm(dim=1, keepdim=True)
14 | mask = angle[:, 0] < cls.TOL
15 | dim_batch = phi.shape[0]
16 | Id = cls.Id.expand(dim_batch, 3, 3)
17 |
18 | axis = phi[~mask] / angle[~mask]
19 | c = angle[~mask].cos().unsqueeze(2)
20 | s = angle[~mask].sin().unsqueeze(2)
21 |
22 | Rot = phi.new_empty(dim_batch, 3, 3)
23 | Rot[mask] = Id[mask] + SO3.wedge(phi[mask])
24 | Rot[~mask] = c*Id[~mask] + \
25 | (1-c)*cls.bouter(axis, axis) + s*cls.wedge(axis)
26 | return Rot
27 |
28 | @classmethod
29 | def log(cls, Rot):
30 | dim_batch = Rot.shape[0]
31 | Id = cls.Id.expand(dim_batch, 3, 3)
32 |
33 | cos_angle = (0.5 * cls.btrace(Rot) - 0.5).clamp(-1., 1.)
34 | # Clip cos(angle) to its proper domain to avoid NaNs from rounding
35 | # errors
36 | angle = cos_angle.acos()
37 | mask = angle < cls.TOL
38 | if mask.sum() == 0:
39 | angle = angle.unsqueeze(1).unsqueeze(1)
40 | return cls.vee((0.5 * angle/angle.sin())*(Rot-Rot.transpose(1, 2)))
41 | elif mask.sum() == dim_batch:
42 | # If angle is close to zero, use first-order Taylor expansion
43 | return cls.vee(Rot - Id)
44 | phi = cls.vee(Rot - Id)
45 | # mixed batch: keep the Taylor expansion where the angle is small
46 | phi[~mask] = cls.vee((0.5 * angle[~mask]/angle[~mask].sin()).unsqueeze(
47 | 1).unsqueeze(2)*(Rot[~mask] - Rot[~mask].transpose(1, 2)))
48 | return phi
49 |
50 | @staticmethod
51 | def vee(Phi):
52 | return torch.stack((Phi[:, 2, 1],
53 | Phi[:, 0, 2],
54 | Phi[:, 1, 0]), dim=1)
55 |
56 | @staticmethod
57 | def wedge(phi):
58 | dim_batch = phi.shape[0]
59 | zero = phi.new_zeros(dim_batch)
60 | return torch.stack((zero, -phi[:, 2], phi[:, 1],
61 | phi[:, 2], zero, -phi[:, 0],
62 | -phi[:, 1], phi[:, 0], zero), 1).view(dim_batch,
63 | 3, 3)
64 |
65 | @classmethod
66 | def from_rpy(cls, roll, pitch, yaw):
67 | return cls.rotz(yaw).bmm(cls.roty(pitch).bmm(cls.rotx(roll)))
68 |
69 | @classmethod
70 | def rotx(cls, angle_in_radians):
71 | c = angle_in_radians.cos()
72 | s = angle_in_radians.sin()
73 | mat = c.new_zeros((c.shape[0], 3, 3))
74 | mat[:, 0, 0] = 1
75 | mat[:, 1, 1] = c
76 | mat[:, 2, 2] = c
77 | mat[:, 1, 2] = -s
78 | mat[:, 2, 1] = s
79 | return mat
80 |
81 | @classmethod
82 | def roty(cls, angle_in_radians):
83 | c = angle_in_radians.cos()
84 | s = angle_in_radians.sin()
85 | mat = c.new_zeros((c.shape[0], 3, 3))
86 | mat[:, 1, 1] = 1
87 | mat[:, 0, 0] = c
88 | mat[:, 2, 2] = c
89 | mat[:, 0, 2] = s
90 | mat[:, 2, 0] = -s
91 | return mat
92 |
93 | @classmethod
94 | def rotz(cls, angle_in_radians):
95 | c = angle_in_radians.cos()
96 | s = angle_in_radians.sin()
97 | mat = c.new_zeros((c.shape[0], 3, 3))
98 | mat[:, 2, 2] = 1
99 | mat[:, 0, 0] = c
100 | mat[:, 1, 1] = c
101 | mat[:, 0, 1] = -s
102 | mat[:, 1, 0] = s
103 | return mat
104 |
105 | @classmethod
106 | def isclose(cls, x, y):
107 | return (x-y).abs() < cls.TOL
108 |
109 | @classmethod
110 | def to_rpy(cls, Rots):
111 | """Convert a rotation matrix to RPY Euler angles."""
112 |
113 | pitch = torch.atan2(-Rots[:, 2, 0],
114 | torch.sqrt(Rots[:, 0, 0]**2 + Rots[:, 1, 0]**2))
115 | yaw = pitch.new_empty(pitch.shape)
116 | roll = pitch.new_empty(pitch.shape)
117 |
118 | near_pi_over_two_mask = cls.isclose(pitch, np.pi / 2.)
119 | near_neg_pi_over_two_mask = cls.isclose(pitch, -np.pi / 2.)
120 |
121 | remainder_inds = ~(near_pi_over_two_mask | near_neg_pi_over_two_mask)
122 |
123 | yaw[near_pi_over_two_mask] = 0
124 | roll[near_pi_over_two_mask] = torch.atan2(
125 | Rots[near_pi_over_two_mask, 0, 1],
126 | Rots[near_pi_over_two_mask, 1, 1])
127 |
128 | yaw[near_neg_pi_over_two_mask] = 0.
129 | roll[near_neg_pi_over_two_mask] = -torch.atan2(
130 | Rots[near_neg_pi_over_two_mask, 0, 1],
131 | Rots[near_neg_pi_over_two_mask, 1, 1])
132 |
133 | sec_pitch = 1/pitch[remainder_inds].cos()
134 | remainder_mats = Rots[remainder_inds]
135 | yaw[remainder_inds] = torch.atan2(remainder_mats[:, 1, 0] * sec_pitch,
136 | remainder_mats[:, 0, 0] * sec_pitch)
137 | roll[remainder_inds] = torch.atan2(remainder_mats[:, 2, 1] * sec_pitch,
138 | remainder_mats[:, 2, 2] * sec_pitch)
139 | rpys = torch.cat([roll.unsqueeze(dim=1),
140 | pitch.unsqueeze(dim=1),
141 | yaw.unsqueeze(dim=1)], dim=1)
142 | return rpys
143 |
144 | @classmethod
145 | def from_quaternion(cls, quat, ordering='wxyz'):
146 | """Form a rotation matrix from a unit length quaternion.
147 | Valid orderings are 'xyzw' and 'wxyz'.
148 | """
149 | if ordering == 'xyzw':
150 | qx = quat[:, 0]
151 | qy = quat[:, 1]
152 | qz = quat[:, 2]
153 | qw = quat[:, 3]
154 | elif ordering == 'wxyz':
155 | qw = quat[:, 0]
156 | qx = quat[:, 1]
157 | qy = quat[:, 2]
158 | qz = quat[:, 3]
159 |
160 | # Form the matrix
161 | mat = quat.new_empty(quat.shape[0], 3, 3)
162 |
163 | qx2 = qx * qx
164 | qy2 = qy * qy
165 | qz2 = qz * qz
166 |
167 | mat[:, 0, 0] = 1. - 2. * (qy2 + qz2)
168 | mat[:, 0, 1] = 2. * (qx * qy - qw * qz)
169 | mat[:, 0, 2] = 2. * (qw * qy + qx * qz)
170 |
171 | mat[:, 1, 0] = 2. * (qw * qz + qx * qy)
172 | mat[:, 1, 1] = 1. - 2. * (qx2 + qz2)
173 | mat[:, 1, 2] = 2. * (qy * qz - qw * qx)
174 |
175 | mat[:, 2, 0] = 2. * (qx * qz - qw * qy)
176 | mat[:, 2, 1] = 2. * (qw * qx + qy * qz)
177 | mat[:, 2, 2] = 1. - 2. * (qx2 + qy2)
178 | return mat
179 |
180 | @classmethod
181 | def to_quaternion(cls, Rots, ordering='wxyz'):
182 | """Convert a rotation matrix to a unit length quaternion.
183 | Valid orderings are 'xyzw' and 'wxyz'.
184 | """
185 | tmp = 1 + Rots[:, 0, 0] + Rots[:, 1, 1] + Rots[:, 2, 2]
186 | tmp[tmp < 0] = 0
187 | qw = 0.5 * torch.sqrt(tmp)
188 | qx = qw.new_empty(qw.shape[0])
189 | qy = qw.new_empty(qw.shape[0])
190 | qz = qw.new_empty(qw.shape[0])
191 |
192 | near_zero_mask = qw.abs() < cls.TOL
193 |
194 | if near_zero_mask.sum() > 0:
195 | cond1_mask = near_zero_mask * \
196 | (Rots[:, 0, 0] > Rots[:, 1, 1])*(Rots[:, 0, 0] > Rots[:, 2, 2])
197 | cond1_inds = cond1_mask.nonzero()
198 |
199 | if len(cond1_inds) > 0:
200 | cond1_inds = cond1_inds.squeeze()
201 | R_cond1 = Rots[cond1_inds].view(-1, 3, 3)
202 | d = 2. * torch.sqrt(1. + R_cond1[:, 0, 0] -
203 | R_cond1[:, 1, 1] - R_cond1[:, 2, 2]).view(-1)
204 | qw[cond1_inds] = (R_cond1[:, 2, 1] - R_cond1[:, 1, 2]) / d
205 | qx[cond1_inds] = 0.25 * d
206 | qy[cond1_inds] = (R_cond1[:, 1, 0] + R_cond1[:, 0, 1]) / d
207 | qz[cond1_inds] = (R_cond1[:, 0, 2] + R_cond1[:, 2, 0]) / d
208 |
209 | cond2_mask = near_zero_mask * (Rots[:, 1, 1] > Rots[:, 2, 2])
210 | cond2_inds = cond2_mask.nonzero()
211 |
212 | if len(cond2_inds) > 0:
213 | cond2_inds = cond2_inds.squeeze()
214 | R_cond2 = Rots[cond2_inds].view(-1, 3, 3)
215 | d = 2. * torch.sqrt(1. + R_cond2[:, 1, 1] -
216 | R_cond2[:, 0, 0] - R_cond2[:, 2, 2]).squeeze()
217 | tmp = (R_cond2[:, 0, 2] - R_cond2[:, 2, 0]) / d
218 | qw[cond2_inds] = tmp
219 | qx[cond2_inds] = (R_cond2[:, 1, 0] + R_cond2[:, 0, 1]) / d
220 | qy[cond2_inds] = 0.25 * d
221 | qz[cond2_inds] = (R_cond2[:, 2, 1] + R_cond2[:, 1, 2]) / d
222 |
223 | cond3_mask = near_zero_mask & cond1_mask.logical_not() & cond2_mask.logical_not()
224 | cond3_inds = cond3_mask
225 |
226 | if len(cond3_inds) > 0:
227 | R_cond3 = Rots[cond3_inds].view(-1, 3, 3)
228 | d = 2. * \
229 | torch.sqrt(1. + R_cond3[:, 2, 2] -
230 | R_cond3[:, 0, 0] - R_cond3[:, 1, 1]).squeeze()
231 | qw[cond3_inds] = (R_cond3[:, 1, 0] - R_cond3[:, 0, 1]) / d
232 | qx[cond3_inds] = (R_cond3[:, 0, 2] + R_cond3[:, 2, 0]) / d
233 | qy[cond3_inds] = (R_cond3[:, 2, 1] + R_cond3[:, 1, 2]) / d
234 | qz[cond3_inds] = 0.25 * d
235 |
236 | far_zero_mask = near_zero_mask.logical_not()
237 | far_zero_inds = far_zero_mask
238 | if len(far_zero_inds) > 0:
239 | R_fz = Rots[far_zero_inds]
240 | d = 4. * qw[far_zero_inds]
241 | qx[far_zero_inds] = (R_fz[:, 2, 1] - R_fz[:, 1, 2]) / d
242 | qy[far_zero_inds] = (R_fz[:, 0, 2] - R_fz[:, 2, 0]) / d
243 | qz[far_zero_inds] = (R_fz[:, 1, 0] - R_fz[:, 0, 1]) / d
244 |
245 | # Check ordering last
246 | if ordering == 'xyzw':
247 | quat = torch.stack([qx, qy, qz, qw], dim=1)
248 | elif ordering == 'wxyz':
249 | quat = torch.stack([qw, qx, qy, qz], dim=1)
250 | return quat
251 |
252 | @classmethod
253 | def normalize(cls, Rots):
254 | U, _, V = torch.svd(Rots)
255 | S = cls.Id.clone().repeat(Rots.shape[0], 1, 1)
256 | S[:, 2, 2] = torch.det(U) * torch.det(V)
257 | return U.bmm(S).bmm(V.transpose(1, 2))
258 |
259 | @classmethod
260 | def dnormalize(cls, Rots):
261 | U, _, V = torch.svd(Rots)
262 | S = cls.dId.clone().repeat(Rots.shape[0], 1, 1)
263 | S[:, 2, 2] = torch.det(U) * torch.det(V)
264 | return U.bmm(S).bmm(V.transpose(1, 2))
265 |
266 | @classmethod
267 | def qmul(cls, q, r, ordering='wxyz'):
268 | """
269 | Multiply quaternion(s) q with quaternion(s) r.
270 | """
271 | terms = cls.bouter(r, q)
272 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
273 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
274 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
275 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
276 | xyz = torch.stack((x, y, z), dim=1)
277 | xyz[w < 0] *= -1
278 | w[w < 0] *= -1
279 | if ordering == 'wxyz':
280 | q = torch.cat((w.unsqueeze(1), xyz), dim=1)
281 | else:
282 | q = torch.cat((xyz, w.unsqueeze(1)), dim=1)
283 | return q / q.norm(dim=1, keepdim=True)
284 |
285 | @staticmethod
286 | def sinc(x):
287 | return x.sin() / x
288 |
289 | @classmethod
290 | def qexp(cls, xi, ordering='wxyz'):
291 | """
292 | Convert exponential maps to quaternions.
293 | """
294 | theta = xi.norm(dim=1, keepdim=True)
295 | w = (0.5*theta).cos()
296 | xyz = 0.5*cls.sinc(0.5*theta)*xi  # sin(theta/2)/theta * xi, sinc is unnormalized
297 | return torch.cat((w, xyz), 1)
298 |
299 | @classmethod
300 | def qlog(cls, q, ordering='wxyz'):
301 | """
302 | Applies the log map to quaternions.
303 | """
304 | n = 0.5*torch.norm(q[:, 1:], p=2, dim=1, keepdim=True)
305 | n = torch.clamp(n, min=1e-8)
306 | q = q[:, 1:] * torch.acos(torch.clamp(q[:, :1], min=-1.0, max=1.0))
307 | r = q / n
308 | return r
309 |
310 | @classmethod
311 | def qinv(cls, q, ordering='wxyz'):
312 | "Quaternion inverse"
313 | r = torch.empty_like(q)
314 | if ordering == 'wxyz':
315 | r[:, 1:4] = -q[:, 1:4]
316 | r[:, 0] = q[:, 0]
317 | else:
318 | r[:, :3] = -q[:, :3]
319 | r[:, 3] = q[:, 3]
320 | return r
321 |
322 | @classmethod
323 | def qnorm(cls, q):
324 | "Quaternion normalization"
325 | return q / q.norm(dim=1, keepdim=True)
326 |
327 | @classmethod
328 | def qinterp(cls, qs, t, t_int):
329 | idxs = np.searchsorted(t, t_int)
330 | idxs0 = idxs-1
331 | idxs0[idxs0 < 0] = 0
332 | idxs1 = idxs
333 | idxs1[idxs1 == t.shape[0]] = t.shape[0] - 1
334 | q0 = qs[idxs0]
335 | q1 = qs[idxs1]
336 | tau = torch.zeros_like(t_int)
337 | dt = (t[idxs1]-t[idxs0])[idxs0 != idxs1]
338 | tau[idxs0 != idxs1] = (t_int-t[idxs0])[idxs0 != idxs1]/dt
339 | return cls.slerp(q0, q1, tau)
340 |
341 | @classmethod
342 | def slerp(cls, q0, q1, tau, DOT_THRESHOLD = 0.9995):
343 | """Spherical linear interpolation."""
344 |
345 | dot = (q0*q1).sum(dim=1)
346 | q1[dot < 0] = -q1[dot < 0]
347 | dot[dot < 0] = -dot[dot < 0]
348 |
349 | q = torch.zeros_like(q0)
350 | tmp = q0 + tau.unsqueeze(1) * (q1 - q0)
351 | tmp = tmp[dot > DOT_THRESHOLD]
352 | q[dot > DOT_THRESHOLD] = tmp / tmp.norm(dim=1, keepdim=True)
353 |
354 | theta_0 = dot.acos()
355 | sin_theta_0 = theta_0.sin()
356 | theta = theta_0 * tau
357 | sin_theta = theta.sin()
358 | s0 = (theta.cos() - dot * sin_theta / sin_theta_0).unsqueeze(1)
359 | s1 = (sin_theta / sin_theta_0).unsqueeze(1)
360 | q[dot <= DOT_THRESHOLD] = ((s0 * q0) + (s1 * q1))[dot <= DOT_THRESHOLD]
361 | return q / q.norm(dim=1, keepdim=True)
362 |
363 | @staticmethod
364 | def bouter(vec1, vec2):
365 | """batch outer product"""
366 | return torch.einsum('bi, bj -> bij', vec1, vec2)
367 |
368 | @staticmethod
369 | def btrace(mat):
370 | """batch matrix trace"""
371 | return torch.einsum('bii -> b', mat)
372 |
373 |
374 | class CPUSO3:
375 | # tolerance criterion
376 | TOL = 1e-8
377 | Id = torch.eye(3)
378 |
379 | @classmethod
380 | def qmul(cls, q, r):
381 | """
382 | Multiply quaternion(s) q with quaternion(s) r.
383 | """
384 | # Compute outer product
385 | terms = cls.outer(r, q)
386 | w = terms[0, 0] - terms[1, 1] - terms[2, 2] - terms[3, 3]
387 | x = terms[0, 1] + terms[1, 0] - terms[2, 3] + terms[3, 2]
388 | y = terms[0, 2] + terms[1, 3] + terms[2, 0] - terms[3, 1]
389 | z = terms[0, 3] - terms[1, 2] + terms[2, 1] + terms[3, 0]
390 | return torch.stack((w, x, y, z))
391 |
392 | @staticmethod
393 | def outer(a, b):
394 | return torch.einsum('i, j -> ij', a, b)
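Note on the quaternion exponential in the file above: a rotation vector xi with angle theta = |xi| corresponds to the unit quaternion (cos(theta/2), sin(theta/2) * xi / theta). The sketch below is a scalar, standard-library version for checking values by hand. It is an illustration only: this standalone `qexp` and its `eps` threshold are assumptions of the sketch, and the small-angle branch is added here to avoid dividing by zero.

```python
import math


def qexp(xi, eps=1e-8):
    """Rotation vector (3 floats, rad) -> unit quaternion (w, x, y, z)."""
    theta = math.sqrt(sum(c*c for c in xi))
    if theta < eps:
        # first-order limit: sin(theta/2)/theta tends to 1/2
        scale = 0.5
    else:
        scale = math.sin(0.5*theta)/theta
    return (math.cos(0.5*theta),) + tuple(scale*c for c in xi)


# a quarter turn about z gives (cos 45 deg, 0, 0, sin 45 deg)
q = qexp((0.0, 0.0, math.pi/2))
q_norm = math.sqrt(sum(c*c for c in q))
print(q, q_norm)
```

With the unnormalized `sinc(x) = sin(x)/x`, the vector part can be written as `0.5 * sinc(0.5 * theta) * xi`, which is the same quantity as `scale * xi` in the sketch.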
--------------------------------------------------------------------------------
/src/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from src.utils import bmmt, bmv, bmtv, bbmv, bmtm
4 | from src.lie_algebra import SO3
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | class BaseLoss(torch.nn.Module):
9 |
10 | def __init__(self, min_N, max_N, dt):
11 | super().__init__()
12 | # windows sizes
13 | self.min_N = min_N
14 | self.max_N = max_N
15 | self.min_train_freq = 2 ** self.min_N
16 | self.max_train_freq = 2 ** self.max_N
17 | # sampling time
18 | self.dt = dt # (s)
19 |
20 |
21 | class GyroLoss(BaseLoss):
22 | """Loss for low-frequency orientation increment"""
23 |
24 | def __init__(self, w, min_N, max_N, dt, target, huber):
25 | super().__init__(min_N, max_N, dt)
26 | # weights on loss
27 | self.w = w
28 | self.sl = torch.nn.SmoothL1Loss()
29 | if target == 'rotation matrix':
30 | self.forward = self.forward_with_rotation_matrices
31 | elif target == 'quaternion':
32 | self.forward = self.forward_with_quaternions
33 | elif target == 'rotation matrix mask':
34 | self.forward = self.forward_with_rotation_matrices_mask
35 | elif target == 'quaternion mask':
36 | self.forward = self.forward_with_quaternion_mask
37 | self.huber = huber
38 | self.weight = torch.ones(1, 1,
39 | self.min_train_freq).cuda()/self.min_train_freq
40 | self.N0 = 5 # remove first N0 increments in loss to not account for padding
41 |
42 | def f_huber(self, rs):
43 | """Huber loss function"""
44 | loss = self.w*self.sl(rs/self.huber,
45 | torch.zeros_like(rs))*(self.huber**2)
46 | return loss
47 |
48 | def forward_with_rotation_matrices(self, xs, hat_xs):
49 | """Forward errors with rotation matrices"""
50 | N = xs.shape[0]
51 | Xs = SO3.exp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
52 | hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
53 | Omegas = SO3.exp(hat_xs[:, :3])
54 | # compute increment at min_train_freq by decimation
55 | for k in range(self.min_N):
56 | Omegas = Omegas[::2].bmm(Omegas[1::2])
57 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
58 | loss = self.f_huber(rs)
59 | # compute increment from min_train_freq to max_train_freq
60 | for k in range(self.min_N, self.max_N):
61 | Omegas = Omegas[::2].bmm(Omegas[1::2])
62 | Xs = Xs[::2].bmm(Xs[1::2])
63 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
64 | loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1))
65 | return loss
66 |
67 | def forward_with_quaternions(self, xs, hat_xs):
68 | """Forward errors with quaternion"""
69 | N = xs.shape[0]
70 | Xs = SO3.qexp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
71 | hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
72 | Omegas = SO3.qexp(hat_xs[:, :3])
73 | # compute increment at min_train_freq by decimation
74 | for k in range(self.min_N):
75 | Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
76 | rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N,
77 | -1, 3)[:, self.N0:]
78 | loss = self.f_huber(rs)
79 | # compute increment from min_train_freq to max_train_freq
80 | for k in range(self.min_N, self.max_N):
81 | Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
82 | Xs = SO3.qmul(Xs[::2], Xs[1::2])
83 | rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
84 | rs = rs.view(N, -1, 3)[:, self.N0:]
85 | loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1))
86 | return loss
87 |
88 | def forward_with_rotation_matrices_mask(self, xs, hat_xs):
89 | """Forward errors with rotation matrices, masking missing ground truth"""
90 | N = xs.shape[0]
91 | masks = xs[:, :, 3].unsqueeze(1)
92 | masks = torch.nn.functional.conv1d(masks, self.weight, bias=None,
93 | stride=self.min_train_freq).double().transpose(1, 2)
94 | masks[masks < 1] = 0
95 | Xs = SO3.exp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
96 | hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
97 | Omegas = SO3.exp(hat_xs[:, :3])
98 | # compute increment at min_train_freq by decimation
99 | for k in range(self.min_N):
100 | Omegas = Omegas[::2].bmm(Omegas[1::2])
101 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
102 | loss = self.f_huber(rs)
103 | # compute increment from min_train_freq to max_train_freq
104 | for k in range(self.min_N, self.max_N):
105 | Omegas = Omegas[::2].bmm(Omegas[1::2])
106 | Xs = Xs[::2].bmm(Xs[1::2])
107 | masks = masks[:, ::2] * masks[:, 1::2]
108 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
109 | rs = rs[masks[:, self.N0:].squeeze(2) == 1]
110 | loss = loss + self.f_huber(rs[:,2])/(2**(k - self.min_N + 1))
111 | return loss
112 |
113 |     def forward_with_quaternion_mask(self, xs, hat_xs):
114 |         """Forward errors with quaternions; xs[:, :, 3] is used as a mask"""
115 |         N = xs.shape[0]
116 |         masks = xs[:, :, 3].unsqueeze(1)
117 |         masks = torch.nn.functional.conv1d(masks, self.weight, bias=None,
118 |             stride=self.min_train_freq).double().transpose(1, 2)
119 |         masks[masks < 1] = 0
120 |         Xs = SO3.qexp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
121 |         hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
122 |         Omegas = SO3.qexp(hat_xs[:, :3])
123 |         # compute increment at min_train_freq by decimation
124 |         for k in range(self.min_N):
125 |             Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
126 |         rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N,
127 |             -1, 3)[:, self.N0:]
128 |         rs = rs[masks[:, self.N0:].squeeze(2) == 1]
129 |         loss = self.f_huber(rs)
130 |         # compute increment from min_train_freq to max_train_freq
131 |         for k in range(self.min_N, self.max_N):
132 |             Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
133 |             Xs = SO3.qmul(Xs[::2], Xs[1::2])
134 |             masks = masks[:, ::2] * masks[:, 1::2]
135 |             rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N,
136 |                 -1, 3)[:, self.N0:]
137 |             rs = rs[masks[:, self.N0:].squeeze(2) == 1]
138 |             loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1))
139 |         return loss
140 |
--------------------------------------------------------------------------------
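The loss methods in /src/losses.py above repeatedly compose neighbouring increments pairwise (`Omegas[::2]` with `Omegas[1::2]`) and weight each coarser scale by `1/2**(k - min_N + 1)`. The following is a minimal pure-Python sketch of that decimation idea, not the repository's `SO3` implementation: the helper names (`qexp_z`, `decimate`, `z_angle`), the `(w, x, y, z)` quaternion layout, and the numeric values are assumptions chosen for illustration, and rotations are restricted to the z axis so that angles simply add.

```python
import math


def qmul(q, r):
    """Hamilton product of two quaternions stored as (w, x, y, z)."""
    w0, x0, y0, z0 = q
    w1, x1, y1, z1 = r
    return (
        w0*w1 - x0*x1 - y0*y1 - z0*z1,
        w0*x1 + x0*w1 + y0*z1 - z0*y1,
        w0*y1 - x0*z1 + y0*w1 + z0*x1,
        w0*z1 + x0*y1 - y0*x1 + z0*w1,
    )


def qexp_z(angle):
    """Unit quaternion for a rotation of `angle` radians about the z axis."""
    return (math.cos(angle/2), 0.0, 0.0, math.sin(angle/2))


def decimate(qs):
    """Compose neighbours pairwise: [q0*q1, q2*q3, ...], halving the length."""
    return [qmul(a, b) for a, b in zip(qs[::2], qs[1::2])]


def z_angle(q):
    """Rotation angle of a quaternion built from z-axis rotations only."""
    return 2*math.atan2(q[3], q[0])


# hypothetical values: 16 increments of 0.01 rad each, min_N = 2, max_N = 4
step, min_N, max_N = 0.01, 2, 4
omegas = [qexp_z(step) for _ in range(16)]
for _ in range(min_N):
    omegas = decimate(omegas)  # 16 -> 8 -> 4 increments
angles = [z_angle(q) for q in omegas]  # each spans 2**min_N original steps

# weight applied to each coarser scale k, as in the loops above
weights = [1/2**(k - min_N + 1) for k in range(min_N, max_N)]
```

After `min_N` halvings each remaining increment covers `2**min_N` original samples, which is why the ground truth is subsampled with `::self.min_train_freq` before the two are compared.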
/src/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from src.utils import bmtm, bmtv, bmmt, bbmv
5 | from src.lie_algebra import SO3
6 |
7 |
8 | class BaseNet(torch.nn.Module):
9 |     def __init__(self, in_dim, out_dim, c0, dropout, ks, ds, momentum):
10 |         super().__init__()
11 |         self.in_dim = in_dim
12 |         self.out_dim = out_dim
13 |         # channel dimension
14 |         c1 = 2*c0
15 |         c2 = 2*c1
16 |         c3 = 2*c2
17 |         # kernel dimension (odd number)
18 |         k0 = ks[0]
19 |         k1 = ks[1]
20 |         k2 = ks[2]
21 |         k3 = ks[3]
22 |         # dilation dimension
23 |         d0 = ds[0]
24 |         d1 = ds[1]
25 |         d2 = ds[2]
26 |         # left padding: total length removed by the four convolutions below
27 |         p0 = (k0-1) + d0*(k1-1) + d0*d1*(k2-1) + d0*d1*d2*(k3-1)
28 |         # nets
29 |         self.cnn = torch.nn.Sequential(
30 |             torch.nn.ReplicationPad1d((p0, 0)), # padding at start
31 |             torch.nn.Conv1d(in_dim, c0, k0, dilation=1),
32 |             torch.nn.BatchNorm1d(c0, momentum=momentum),
33 |             torch.nn.GELU(),
34 |             torch.nn.Dropout(dropout),
35 |             torch.nn.Conv1d(c0, c1, k1, dilation=d0),
36 |             torch.nn.BatchNorm1d(c1, momentum=momentum),
37 |             torch.nn.GELU(),
38 |             torch.nn.Dropout(dropout),
39 |             torch.nn.Conv1d(c1, c2, k2, dilation=d0*d1),
40 |             torch.nn.BatchNorm1d(c2, momentum=momentum),
41 |             torch.nn.GELU(),
42 |             torch.nn.Dropout(dropout),
43 |             torch.nn.Conv1d(c2, c3, k3, dilation=d0*d1*d2),
44 |             torch.nn.BatchNorm1d(c3, momentum=momentum),
45 |             torch.nn.GELU(),
46 |             torch.nn.Dropout(dropout),
47 |             torch.nn.Conv1d(c3, out_dim, 1, dilation=1),
48 |             torch.nn.ReplicationPad1d((0, 0)), # no padding at end
49 |         )
50 |         # for normalizing inputs
51 |         self.mean_u = torch.nn.Parameter(torch.zeros(in_dim),
52 |             requires_grad=False)
53 |         self.std_u = torch.nn.Parameter(torch.ones(in_dim),
54 |             requires_grad=False)
55 |
56 |     def forward(self, us):
57 |         u = self.norm(us).transpose(1, 2)
58 |         y = self.cnn(u)
59 |         return y
60 |
61 |     def norm(self, us):
62 |         return (us-self.mean_u)/self.std_u
63 |
64 |     def set_normalized_factors(self, mean_u, std_u):
65 |         self.mean_u = torch.nn.Parameter(mean_u.cuda(), requires_grad=False)
66 |         self.std_u = torch.nn.Parameter(std_u.cuda(), requires_grad=False)
67 |
68 |
69 | class GyroNet(BaseNet):
70 |     def __init__(self, in_dim, out_dim, c0, dropout, ks, ds, momentum,
71 |         gyro_std):
72 |         super().__init__(in_dim, out_dim, c0, dropout, ks, ds, momentum)
73 |         gyro_std = torch.Tensor(gyro_std)
74 |         self.gyro_std = torch.nn.Parameter(gyro_std, requires_grad=False)
75 |
76 |         gyro_Rot = 0.05*torch.randn(3, 3).cuda()
77 |         self.gyro_Rot = torch.nn.Parameter(gyro_Rot)
78 |         self.Id3 = torch.eye(3).cuda()
79 |
80 |     def forward(self, us):
81 |         ys = super().forward(us)
82 |         Rots = (self.Id3 + self.gyro_Rot).expand(us.shape[0], us.shape[1], 3, 3)
83 |         Rot_us = bbmv(Rots, us[:, :, :3])
84 |         return self.gyro_std*ys.transpose(1, 2) + Rot_us
85 |
86 |
--------------------------------------------------------------------------------
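The padding `p0` in `BaseNet` (/src/networks.py above) equals the total number of samples removed by the four dilated convolutions, so padding only at the start keeps the output as long as the input. The sketch below checks that arithmetic in plain Python, assuming unpadded stride-1 convolutions as in the listing; the function names and the example values of `ks` and `ds` are illustrative assumptions, not settings read from this file.

```python
def layer_dilations(ds):
    """Dilations of the four convolutions: 1, d0, d0*d1, d0*d1*d2."""
    return [1, ds[0], ds[0]*ds[1], ds[0]*ds[1]*ds[2]]


def left_padding(ks, ds):
    """Same quantity as p0 in BaseNet: sum of dilation*(kernel - 1)."""
    return sum(d*(k - 1) for k, d in zip(ks, layer_dilations(ds)))


def output_length(n, ks, ds):
    """Sequence length after left-padding by p0 and four stride-1 convs."""
    n = n + left_padding(ks, ds)
    for k, d in zip(ks, layer_dilations(ds)):
        # unpadded Conv1d with stride 1: L_out = L_in - dilation*(kernel - 1)
        n = n - d*(k - 1)
    return n


# hypothetical hyper-parameters for illustration
ks, ds = [7, 7, 7, 7], [4, 4, 4]
p0 = left_padding(ks, ds)
n_out = output_length(1000, ks, ds)
```

Because all padding sits on the left, each output sample depends only on the current input and the `p0` samples before it.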
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import pickle
4 | import yaml
5 |
6 |
7 | def pload(*f_names):
8 |     """Pickle load"""
9 |     f_name = os.path.join(*f_names)
10 |     with open(f_name, "rb") as f:
11 |         pickle_dict = pickle.load(f)
12 |     return pickle_dict
13 |
14 | def pdump(pickle_dict, *f_names):
15 |     """Pickle dump"""
16 |     f_name = os.path.join(*f_names)
17 |     with open(f_name, "wb") as f:
18 |         pickle.dump(pickle_dict, f)
19 |
20 | def mkdir(*paths):
21 |     """Create a directory (and missing parents) if not existing."""
22 |     path = os.path.join(*paths)
23 |     # exist_ok avoids a race between checking and creating the directory
24 |     os.makedirs(path, exist_ok=True)
25 |
26 | def yload(*f_names):
27 |     """YAML load (explicit Loader: mandatory since PyYAML 6)"""
28 |     f_name = os.path.join(*f_names)
29 |     with open(f_name, 'r') as f:
30 |         yaml_dict = yaml.load(f, Loader=yaml.FullLoader)
31 |     return yaml_dict
32 |
33 | def ydump(yaml_dict, *f_names):
34 |     """YAML dump"""
35 |     f_name = os.path.join(*f_names)
36 |     with open(f_name, 'w') as f:
37 |         yaml.dump(yaml_dict, f, default_flow_style=False)
38 |
39 | def bmv(mat, vec):
40 |     """batch matrix vector product"""
41 |     return torch.einsum('bij, bj -> bi', mat, vec)
42 |
43 | def bbmv(mat, vec):
44 |     """double batch matrix vector product"""
45 |     return torch.einsum('baij, baj -> bai', mat, vec)
46 |
47 | def bmtv(mat, vec):
48 |     """batch matrix transpose vector product"""
49 |     return torch.einsum('bji, bj -> bi', mat, vec)
50 |
51 | def bmtm(mat1, mat2):
52 |     """batch matrix transpose matrix product"""
53 |     return torch.einsum("bji, bjk -> bik", mat1, mat2)
54 |
55 | def bmmt(mat1, mat2):
56 |     """batch matrix matrix transpose product"""
57 |     return torch.einsum("bij, bkj -> bik", mat1, mat2)
58 |
--------------------------------------------------------------------------------
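The einsum helpers in /src/utils.py above differ only in which index is summed. As a reading aid, here is a pure-Python reference for the pattern used by `bmtm` (`"bji, bjk -> bik"`), written on nested lists so it runs without PyTorch; `bmtm_ref` is a hypothetical name for this sketch and is not part of the repository.

```python
def bmtm_ref(mat1, mat2):
    """Per-batch mat1^T @ mat2, mirroring einsum('bji, bjk -> bik')."""
    return [
        [
            [sum(a[j][i]*b[j][k] for j in range(len(a)))
             for k in range(len(b[0]))]
            for i in range(len(a[0]))
        ]
        for a, b in zip(mat1, mat2)
    ]


# a 90 degree rotation about the z axis, used as a batch of size one
rot_z = [[0, -1, 0],
         [1, 0, 0],
         [0, 0, 1]]
relative = bmtm_ref([rot_z], [rot_z])
```

For a rotation matrix the transpose is the inverse, so `bmtm(Omegas, Xs)` in the loss code yields the relative rotation between prediction and ground truth, and that product is the identity when the two agree.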