├── .gitignore ├── LICENCE ├── README.md ├── figures ├── boxplot.jpg ├── methode.jpg ├── roe.jpg └── rpy.jpg ├── main_EUROC.py ├── main_TUMVI.py ├── requirements.txt └── src ├── dataset.py ├── learning.py ├── lie_algebra.py ├── losses.py ├── networks.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__ 2 | data 3 | results 4 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Martin Brossard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoising IMU Gyroscope with Deep Learning for Open-Loop Attitude Estimation 2 | 3 | ## Overview [[IEEE paper](https://ieeexplore.ieee.org/document/9119813), [preprint paper](https://hal.archives-ouvertes.fr/hal-02488923v4/document)] 4 | 5 | This repo contains a learning method for denoising the gyroscopes of Inertial Measurement Units (IMUs) using 6 | ground truth data. In attitude dead-reckoning, the resulting algorithm beats top-ranked 7 | visual-inertial odometry systems [3-5] at attitude estimation 8 | although it only uses signals from a low-cost IMU. This 9 | performance is achieved thanks to a well-chosen model and a 10 | proper loss function for orientation increments. Our approach builds upon a neural network based 11 | on dilated convolutions, without requiring any recurrent neural 12 | network. 13 | 14 | ## Code 15 | Our implementation is based on Python 3 and [PyTorch](https://pytorch.org/). We 16 | tested the code under Ubuntu 16.04, Python 3.5, and PyTorch 1.5. The codebase is licensed under the MIT License. 17 | 18 | ### Installation & Prerequisites 19 | 1. Install the correct version of [PyTorch](http://pytorch.org) 20 | ``` 21 | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html 22 | ``` 23 | 24 | 2. Clone this repo and create empty directories 25 | ``` 26 | git clone https://github.com/mbrossar/denoise-imu-gyro.git 27 | mkdir denoise-imu-gyro/data 28 | mkdir denoise-imu-gyro/results 29 | ``` 30 | 31 | 3. Install the required Python packages, e.g. with the pip command 32 | ``` 33 | pip install -r denoise-imu-gyro/requirements.txt 34 | ``` 35 | 36 | ### Testing 37 | 38 | 1. 
Download the reformatted pickle versions of the _EuRoC_ [1] and _TUM-VI_ [2] datasets at this [url](https://cloud.mines-paristech.fr/index.php/s/d2lHqzIk1PxzWmb/download), then extract and copy them into the `data` folder. 39 | ``` 40 | wget "https://cloud.mines-paristech.fr/index.php/s/d2lHqzIk1PxzWmb/download" 41 | unzip download -d denoise-imu-gyro/data 42 | rm download 43 | ``` 44 | These files can alternatively be generated from the original _EuRoC_ and 45 | _TUM-VI_ datasets: they are created when launching the main file once 46 | `data_dir` points to the downloaded dataset. 47 | 48 | 2. Download the optimized parameters at this [url](https://cloud.mines-paristech.fr/index.php/s/OLnj74YXtOLA7Hv/download), then extract and copy them into the `results` folder. 49 | ``` 50 | wget "https://cloud.mines-paristech.fr/index.php/s/OLnj74YXtOLA7Hv/download" 51 | unzip download -d denoise-imu-gyro/results 52 | rm download 53 | ``` 54 | 3. Test on the dataset of your choice! 55 | ``` 56 | cd denoise-imu-gyro 57 | python3 main_EUROC.py 58 | # or alternatively 59 | # python3 main_TUMVI.py 60 | ``` 61 | 62 | You can then compare results with the evaluation [toolbox](https://github.com/rpng/open_vins/) of [3]. 63 | 64 | ### Training 65 | You can train the method by 66 | uncommenting the two statements under the `# Train on training data set` comment in the main files. Then edit the 67 | configuration to obtain results with another set of parameters. Training takes 68 | roughly 5 minutes per dataset with a decent GPU. 69 | 70 | ## Schematic Illustration of the Proposed Method 71 | 72 |

73 | [Figure: schematic illustration of the proposed method] 74 | 75 |

76 | 77 | The convolutional neural network 78 | computes gyro corrections (based on past IMU measurements) that filter 79 | undesirable errors in the raw IMU signals. We 80 | then perform open-loop time integration on the noise-free measurements 81 | for regressing low-frequency errors between ground truth and estimated 82 | orientation increments. 83 | 84 | ## Results 85 | 86 |

87 | [Figure: Orientation estimates] 88 |

89 | 90 | Orientation estimates on the test sequence _MH 04 difficult_ of [1] (left), and 91 | _room 4_ of [2] (right). Our method removes errors of the IMU. 92 | 93 |
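The orientation estimates shown here come from integrating gyro rates in open loop. As a minimal sketch of that integration step only (not the repository's implementation, which relies on the `SO3` helpers imported from `src/lie_algebra.py` and on batched tensors), the pure-Python snippet below applies `R[n+1] = R[n] * exp(omega[n] * dt)` using Rodrigues' formula. The helper names `so3_exp`, `matmul`, `integrate_gyro` and `yaw` are illustrative assumptions; `dt = 0.005` s matches the sampling time set in the main files.

```python
import math


def matmul(A, B):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def so3_exp(phi):
    """Rodrigues' formula: rotation vector (rad) -> 3x3 rotation matrix."""
    x, y, z = phi
    angle = math.sqrt(x * x + y * y + z * z)
    # Skew-symmetric matrix K such that K v = phi x v
    K = [[0.0, -z, y],
         [z, 0.0, -x],
         [-y, x, 0.0]]
    if angle < 1e-8:
        a, b = 1.0, 0.5  # small-angle limits of sin(t)/t and (1 - cos(t))/t^2
    else:
        a = math.sin(angle) / angle
        b = (1.0 - math.cos(angle)) / angle ** 2
    K2 = matmul(K, K)
    return [[(1.0 if i == j else 0.0) + a * K[i][j] + b * K2[i][j]
             for j in range(3)] for i in range(3)]


def integrate_gyro(omegas, dt, R0=None):
    """Open-loop integration of body rates (rad/s); returns all orientations."""
    R = R0 if R0 is not None else [[1.0, 0.0, 0.0],
                                   [0.0, 1.0, 0.0],
                                   [0.0, 0.0, 1.0]]
    Rs = [R]
    for w in omegas:
        R = matmul(R, so3_exp([wi * dt for wi in w]))
        Rs.append(R)
    return Rs


def yaw(R):
    """Yaw angle (rad) of a rotation matrix, ZYX convention."""
    return math.atan2(R[1][0], R[0][0])


# 1 s of data at dt = 0.005 s: true yaw rate 0.1 rad/s, then the same
# signal corrupted by a constant 0.01 rad/s bias.
dt = 0.005
true_rates = [(0.0, 0.0, 0.10)] * 200
biased_rates = [(0.0, 0.0, 0.11)] * 200
print(yaw(integrate_gyro(true_rates, dt)[-1]))    # close to 0.10 rad
print(yaw(integrate_gyro(biased_rates, dt)[-1]))  # close to 0.11 rad
```

Because nothing corrects the estimate during integration, any residual bias in the rates accumulates as drift, which is why the gyro signal has to be corrected before this step.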

94 | [Figure: Relative Orientation Error] 95 |

96 | 97 | Relative Orientation Error (ROE) in terms of 3D orientation and 98 | yaw errors on the test sequences. Our method competes with VIO methods although it is based only on IMU signals. 99 | 100 | ## Paper 101 | The paper associated with this repo, M. Brossard, S. Bonnabel and A. Barrau, "Denoising IMU Gyroscopes With Deep Learning for Open-Loop Attitude Estimation," in _IEEE Robotics and Automation Letters_, vol. 5, no. 3, pp. 4796-4803, July 2020, doi: 10.1109/LRA.2020.3003256, is 102 | available at this [url](https://ieeexplore.ieee.org/document/9119813), and a preprint is available at this [url](https://hal.archives-ouvertes.fr/hal-02488923/document). 103 | 104 | 105 | ## Citation 106 | 107 | If you use this code in your research, please cite: 108 | 109 | ``` 110 | @article{brossard2020denoising, 111 | author={M. {Brossard} and S. {Bonnabel} and A. {Barrau}}, 112 | journal={IEEE Robotics and Automation Letters}, 113 | title={Denoising IMU Gyroscopes With Deep Learning for Open-Loop Attitude Estimation}, 114 | year={2020}, 115 | volume={5}, 116 | number={3}, 117 | pages={4796-4803}, 118 | } 119 | 120 | ``` 121 | 122 | ## Authors 123 | 124 | This code was written by the [Centre for Robotics](http://caor-mines-paristech.fr/en/home/) at 125 | MINES ParisTech, Paris, France. 126 | 127 | [Martin 128 | Brossard](mailto:martin.brossard@mines-paristech.fr)^, [Axel 129 | Barrau](mailto:axel.barrau@safrangroup.com)^ and [Silvère 130 | Bonnabel](mailto:silvere.bonnabel@mines-paristech.fr)^. 131 | 132 | ^[MINES ParisTech](http://www.mines-paristech.eu/), PSL Research University, 133 | Centre for Robotics, 60 Boulevard Saint-Michel, 75006 Paris, France. 134 | 135 | ## Biblio 136 | 137 | [1] M. Burri, J. Nikolic, P. Gohl, T. Schneider, J. Rehder, S. Omari, 138 | M. W. Achtelik, and R. Siegwart, ``_The EuRoC Micro Aerial Vehicle 139 | Datasets_", The International Journal of Robotics Research, vol. 35, 140 | no. 10, pp. 1157–1163, 2016. 141 | 142 | [2] D. Schubert, T. Goll, N. Demmel, V. 
Usenko, J. Stückler, and 143 | D. Cremers, ``_The TUM VI Benchmark for Evaluating Visual-Inertial 144 | Odometry_", in International Conference on Intelligent Robots and 145 | Systems (IROS). IEEE, pp. 1680–1687, 2018. 146 | 147 | [3] P. Geneva, K. Eckenhoff, W. Lee, Y. Yang, and G. Huang, ``_OpenVINS: 148 | A Research Platform for Visual-Inertial Estimation_", IROS Workshop 149 | on Visual-Inertial Navigation: Challenges and Applications, 2019. 150 | 151 | [4] T. Qin, P. Li, and S. Shen, ``_VINS-Mono: A Robust and Versatile 152 | Monocular Visual-Inertial State Estimator_", IEEE Transactions on 153 | Robotics, vol. 34, no. 4, pp. 1004–1020, 2018. 154 | 155 | [5] M. Bloesch, M. Burri, S. Omari, M. Hutter, and R. Siegwart, ``_Iterated 156 | Extended Kalman Filter Based Visual-Inertial Odometry Using Direct Photometric 157 | Feedback_", The International Journal of Robotics Research, vol. 36, no. 10, pp. 158 | 1053–1072, 2017. 159 | -------------------------------------------------------------------------------- /figures/boxplot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/boxplot.jpg -------------------------------------------------------------------------------- /figures/methode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/methode.jpg -------------------------------------------------------------------------------- /figures/roe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/roe.jpg -------------------------------------------------------------------------------- /figures/rpy.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbrossar/denoise-imu-gyro/9948b6646882c547d5ba6eda19a95ea4cba5e89e/figures/rpy.jpg -------------------------------------------------------------------------------- /main_EUROC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import src.learning as lr 4 | import src.networks as sn 5 | import src.losses as sl 6 | import src.dataset as ds 7 | import numpy as np 8 | 9 | base_dir = os.path.dirname(os.path.realpath(__file__)) 10 | data_dir = '/path/to/EUROC/dataset' 11 | # test a given network 12 | # address = os.path.join(base_dir, 'results/EUROC/2020_02_18_16_52_55/') 13 | # or test the last trained network 14 | address = "last" 15 | ################################################################################ 16 | # Network parameters 17 | ################################################################################ 18 | net_class = sn.GyroNet 19 | net_params = { 20 | 'in_dim': 6, 21 | 'out_dim': 3, 22 | 'c0': 16, 23 | 'dropout': 0.1, 24 | 'ks': [7, 7, 7, 7], 25 | 'ds': [4, 4, 4], 26 | 'momentum': 0.1, 27 | 'gyro_std': [1*np.pi/180, 2*np.pi/180, 5*np.pi/180], 28 | } 29 | ################################################################################ 30 | # Dataset parameters 31 | ################################################################################ 32 | dataset_class = ds.EUROCDataset 33 | dataset_params = { 34 | # where are raw data ? 35 | 'data_dir': data_dir, 36 | # where record preloaded data ? 
37 | 'predata_dir': os.path.join(base_dir, 'data/EUROC'), 38 | # set train, val and test sequence 39 | 'train_seqs': [ 40 | 'MH_01_easy', 41 | 'MH_03_medium', 42 | 'MH_05_difficult', 43 | 'V1_02_medium', 44 | 'V2_01_easy', 45 | 'V2_03_difficult' 46 | ], 47 | 'val_seqs': [ 48 | 'MH_01_easy', 49 | 'MH_03_medium', 50 | 'MH_05_difficult', 51 | 'V1_02_medium', 52 | 'V2_01_easy', 53 | 'V2_03_difficult', 54 | ], 55 | 'test_seqs': [ 56 | 'MH_02_easy', 57 | 'MH_04_difficult', 58 | 'V2_02_medium', 59 | 'V1_03_difficult', 60 | 'V1_01_easy', 61 | ], 62 | # size of trajectory during training 63 | 'N': 32 * 500, # should be integer * 'max_train_freq' 64 | 'min_train_freq': 16, 65 | 'max_train_freq': 32, 66 | } 67 | ################################################################################ 68 | # Training parameters 69 | ################################################################################ 70 | train_params = { 71 | 'optimizer_class': torch.optim.Adam, 72 | 'optimizer': { 73 | 'lr': 0.01, 74 | 'weight_decay': 1e-1, 75 | 'amsgrad': False, 76 | }, 77 | 'loss_class': sl.GyroLoss, 78 | 'loss': { 79 | 'min_N': int(np.log2(dataset_params['min_train_freq'])), 80 | 'max_N': int(np.log2(dataset_params['max_train_freq'])), 81 | 'w': 1e6, 82 | 'target': 'rotation matrix', 83 | 'huber': 0.005, 84 | 'dt': 0.005, 85 | }, 86 | 'scheduler_class': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, 87 | 'scheduler': { 88 | 'T_0': 600, 89 | 'T_mult': 2, 90 | 'eta_min': 1e-3, 91 | }, 92 | 'dataloader': { 93 | 'batch_size': 10, 94 | 'pin_memory': False, 95 | 'num_workers': 0, 96 | 'shuffle': False, 97 | }, 98 | # frequency of validation step 99 | 'freq_val': 600, 100 | # total number of epochs 101 | 'n_epochs': 1800, 102 | # where record results ? 103 | 'res_dir': os.path.join(base_dir, "results/EUROC"), 104 | # where record Tensorboard log ? 
105 | 'tb_dir': os.path.join(base_dir, "results/runs/EUROC"), 106 | } 107 | ################################################################################ 108 | # Train on training data set 109 | ################################################################################ 110 | # learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'], 111 | # train_params['tb_dir'], net_class, net_params, None, 112 | # train_params['loss']['dt']) 113 | # learning_process.train(dataset_class, dataset_params, train_params) 114 | ################################################################################ 115 | # Test on full data set 116 | ################################################################################ 117 | learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'], 118 | train_params['tb_dir'], net_class, net_params, address=address, 119 | dt=train_params['loss']['dt']) 120 | learning_process.test(dataset_class, dataset_params, ['test']) -------------------------------------------------------------------------------- /main_TUMVI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import src.learning as lr 4 | import src.networks as sn 5 | import src.losses as sl 6 | import src.dataset as ds 7 | import numpy as np 8 | 9 | base_dir = os.path.dirname(os.path.realpath(__file__)) 10 | data_dir = '/path/to/TUM/dataset' 11 | # test a given network 12 | # address = os.path.join(base_dir, 'results/TUM/2020_02_18_16_26_33') 13 | # or test the last trained network 14 | address = 'last' 15 | ################################################################################ 16 | # Network parameters 17 | ################################################################################ 18 | net_class = sn.GyroNet 19 | net_params = { 20 | 'in_dim': 6, 21 | 'out_dim': 3, 22 | 'c0': 16, 23 | 'dropout': 0.1, 24 | 'ks': [7, 7, 7, 7], 25 | 'ds': [4, 4, 4], 26 | 'momentum': 0.1, 
27 | 'gyro_std': [0.2*np.pi/180, 0.2*np.pi/180, 0.2*np.pi/180], 28 | } 29 | ################################################################################ 30 | # Dataset parameters 31 | ################################################################################ 32 | dataset_class = ds.TUMVIDataset 33 | dataset_params = { 34 | # where are raw data ? 35 | 'data_dir': data_dir, 36 | # where record preloaded data ? 37 | 'predata_dir': os.path.join(base_dir, 'data/TUM'), 38 | # set train, val and test sequence 39 | 'train_seqs': [ 40 | 'dataset-room1_512_16', 41 | 'dataset-room3_512_16', 42 | 'dataset-room5_512_16', 43 | ], 44 | 'val_seqs': [ 45 | 'dataset-room2_512_16', 46 | 'dataset-room4_512_16', 47 | 'dataset-room6_512_16', 48 | ], 49 | 'test_seqs': [ 50 | 'dataset-room2_512_16', 51 | 'dataset-room4_512_16', 52 | 'dataset-room6_512_16' 53 | ], 54 | # size of trajectory during training 55 | 'N': 32 * 500, # should be integer * 'max_train_freq' 56 | 'min_train_freq': 16, 57 | 'max_train_freq': 32, 58 | } 59 | ################################################################################ 60 | # Training parameters 61 | ################################################################################ 62 | train_params = { 63 | 'optimizer_class': torch.optim.Adam, 64 | 'optimizer': { 65 | 'lr': 0.01, 66 | 'weight_decay': 1e-1, 67 | 'amsgrad': False, 68 | }, 69 | 'loss_class': sl.GyroLoss, 70 | 'loss': { 71 | 'min_N': int(np.log2(dataset_params['min_train_freq'])), 72 | 'max_N': int(np.log2(dataset_params['max_train_freq'])), 73 | 'w': 1e6, 74 | 'target': 'rotation matrix mask', 75 | 'huber': 0.005, 76 | 'dt': 0.005, 77 | }, 78 | 'scheduler_class': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, 79 | 'scheduler': { 80 | 'T_0': 600, 81 | 'T_mult': 2, 82 | 'eta_min': 1e-3, 83 | }, 84 | 'dataloader': { 85 | 'batch_size': 10, 86 | 'pin_memory': False, 87 | 'num_workers': 0, 88 | 'shuffle': False, 89 | }, 90 | # frequency of validation step 91 | 'freq_val': 600, 
92 | # total number of epochs 93 | 'n_epochs': 1800, 94 | # where record results ? 95 | 'res_dir': os.path.join(base_dir, "results/TUM"), 96 | # where record Tensorboard log ? 97 | 'tb_dir': os.path.join(base_dir, "results/runs/TUM"), 98 | } 99 | ################################################################################ 100 | # Train on training data set 101 | ################################################################################ 102 | # learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'], 103 | # train_params['tb_dir'], net_class, net_params, None, 104 | # train_params['loss']['dt']) 105 | # learning_process.train(dataset_class, dataset_params, train_params) 106 | ################################################################################ 107 | # Test on full data set 108 | ################################################################################ 109 | learning_process = lr.GyroLearningBasedProcessing(train_params['res_dir'], 110 | train_params['tb_dir'], net_class, net_params, address=address, 111 | dt=train_params['loss']['dt']) 112 | learning_process.test(dataset_class, dataset_params, ['test']) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | datetime 4 | matplotlib 5 | termcolor 6 | pyyaml 7 | scipy 8 | tensorboard -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | from src.utils import pdump, pload, bmtv, bmtm 2 | from src.lie_algebra import SO3 3 | from termcolor import cprint 4 | from torch.utils.data.dataset import Dataset 5 | from scipy.interpolate import interp1d 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import pickle 9 | import os 10 | import torch 11 | import sys 12 | 13 | class BaseDataset(Dataset): 14 | 15 | def 
__init__(self, predata_dir, train_seqs, val_seqs, test_seqs, mode, N, 16 | min_train_freq=128, max_train_freq=512, dt=0.005): 17 | super().__init__() 18 | # where record pre loaded data 19 | self.predata_dir = predata_dir 20 | self.path_normalize_factors = os.path.join(predata_dir, 'nf.p') 21 | 22 | self.mode = mode 23 | # choose between training, validation or test sequences 24 | train_seqs, self.sequences = self.get_sequences(train_seqs, val_seqs, 25 | test_seqs) 26 | # get and compute value for normalizing inputs 27 | self.mean_u, self.std_u = self.init_normalize_factors(train_seqs) 28 | self.mode = mode # train, val or test 29 | self._train = False 30 | self._val = False 31 | # noise density 32 | self.imu_std = torch.Tensor([8e-5, 1e-3]).float() 33 | # bias repeatability (without in-run bias stability) 34 | self.imu_b0 = torch.Tensor([1e-3, 1e-3]).float() 35 | # IMU sampling time 36 | self.dt = dt # (s) 37 | # sequence size during training 38 | self.N = N # power of 2 39 | self.min_train_freq = min_train_freq 40 | self.max_train_freq = max_train_freq 41 | self.uni = torch.distributions.uniform.Uniform(-torch.ones(1), 42 | torch.ones(1)) 43 | 44 | def get_sequences(self, train_seqs, val_seqs, test_seqs): 45 | """Choose sequence list depending on dataset mode""" 46 | sequences_dict = { 47 | 'train': train_seqs, 48 | 'val': val_seqs, 49 | 'test': test_seqs, 50 | } 51 | return sequences_dict['train'], sequences_dict[self.mode] 52 | 53 | def __getitem__(self, i): 54 | mondict = self.load_seq(i) 55 | N_max = mondict['xs'].shape[0] 56 | if self._train: # random start 57 | n0 = torch.randint(0, self.max_train_freq, (1, )) 58 | nend = n0 + self.N 59 | elif self._val: # end sequence 60 | n0 = self.max_train_freq + self.N 61 | nend = N_max - ((N_max - n0) % self.max_train_freq) 62 | else: # full sequence 63 | n0 = 0 64 | nend = N_max - (N_max % self.max_train_freq) 65 | u = mondict['us'][n0: nend] 66 | x = mondict['xs'][n0: nend] 67 | return u, x 68 | 69 | def 
__len__(self): 70 | return len(self.sequences) 71 | 72 | def add_noise(self, u): 73 | """Add Gaussian noise and bias to input""" 74 | noise = torch.randn_like(u) 75 | noise[:, :, :3] = noise[:, :, :3] * self.imu_std[0] 76 | noise[:, :, 3:6] = noise[:, :, 3:6] * self.imu_std[1] 77 | 78 | # bias repeatability (without in run bias stability) 79 | b0 = self.uni.sample(u[:, 0].shape).cuda() 80 | b0[:, :, :3] = b0[:, :, :3] * self.imu_b0[0] 81 | b0[:, :, 3:6] = b0[:, :, 3:6] * self.imu_b0[1] 82 | u = u + noise + b0.transpose(1, 2) 83 | return u 84 | 85 | def init_train(self): 86 | self._train = True 87 | self._val = False 88 | 89 | def init_val(self): 90 | self._train = False 91 | self._val = True 92 | 93 | def length(self): 94 | return self._length 95 | 96 | def load_seq(self, i): 97 | return pload(self.predata_dir, self.sequences[i] + '.p') 98 | 99 | def load_gt(self, i): 100 | return pload(self.predata_dir, self.sequences[i] + '_gt.p') 101 | 102 | def init_normalize_factors(self, train_seqs): 103 | if os.path.exists(self.path_normalize_factors): 104 | mondict = pload(self.path_normalize_factors) 105 | return mondict['mean_u'], mondict['std_u'] 106 | 107 | path = os.path.join(self.predata_dir, train_seqs[0] + '.p') 108 | if not os.path.exists(path): 109 | print("init_normalize_factors not computed") 110 | return 0, 0 111 | 112 | print('Start computing normalizing factors ...') 113 | cprint("Do it only on training sequences, it is vital!", 'yellow') 114 | # first compute mean 115 | num_data = 0 116 | 117 | for i, sequence in enumerate(train_seqs): 118 | pickle_dict = pload(self.predata_dir, sequence + '.p') 119 | us = pickle_dict['us'] 120 | sms = pickle_dict['xs'] 121 | if i == 0: 122 | mean_u = us.sum(dim=0) 123 | num_positive = sms.sum(dim=0) 124 | num_negative = sms.shape[0] - sms.sum(dim=0) 125 | else: 126 | mean_u += us.sum(dim=0) 127 | num_positive += sms.sum(dim=0) 128 | num_negative += sms.shape[0] - sms.sum(dim=0) 129 | num_data += us.shape[0] 130 | mean_u = 
mean_u / num_data 131 | pos_weight = num_negative / num_positive 132 | 133 | # second compute standard deviation 134 | for i, sequence in enumerate(train_seqs): 135 | pickle_dict = pload(self.predata_dir, sequence + '.p') 136 | us = pickle_dict['us'] 137 | if i == 0: 138 | std_u = ((us - mean_u) ** 2).sum(dim=0) 139 | else: 140 | std_u += ((us - mean_u) ** 2).sum(dim=0) 141 | std_u = (std_u / num_data).sqrt() 142 | normalize_factors = { 143 | 'mean_u': mean_u, 144 | 'std_u': std_u, 145 | } 146 | print('... ended computing normalizing factors') 147 | print('pos_weight:', pos_weight) 148 | print('This values most be a training parameters !') 149 | print('mean_u :', mean_u) 150 | print('std_u :', std_u) 151 | print('num_data :', num_data) 152 | pdump(normalize_factors, self.path_normalize_factors) 153 | return mean_u, std_u 154 | 155 | def read_data(self, data_dir): 156 | raise NotImplementedError 157 | 158 | @staticmethod 159 | def interpolate(x, t, t_int): 160 | """ 161 | Interpolate ground truth at the sensor timestamps 162 | """ 163 | 164 | # vector interpolation 165 | x_int = np.zeros((t_int.shape[0], x.shape[1])) 166 | for i in range(x.shape[1]): 167 | if i in [4, 5, 6, 7]: 168 | continue 169 | x_int[:, i] = np.interp(t_int, t, x[:, i]) 170 | # quaternion interpolation 171 | t_int = torch.Tensor(t_int - t[0]) 172 | t = torch.Tensor(t - t[0]) 173 | qs = SO3.qnorm(torch.Tensor(x[:, 4:8])) 174 | x_int[:, 4:8] = SO3.qinterp(qs, t, t_int).numpy() 175 | return x_int 176 | 177 | 178 | class EUROCDataset(BaseDataset): 179 | """ 180 | Dataloader for the EUROC Data Set. 
181 | """ 182 | 183 | def __init__(self, data_dir, predata_dir, train_seqs, val_seqs, 184 | test_seqs, mode, N, min_train_freq, max_train_freq, dt=0.005): 185 | super().__init__(predata_dir, train_seqs, val_seqs, test_seqs, mode, N, min_train_freq, max_train_freq, dt) 186 | # convert raw data to pre loaded data 187 | self.read_data(data_dir) 188 | 189 | def read_data(self, data_dir): 190 | r"""Read the data from the dataset""" 191 | 192 | f = os.path.join(self.predata_dir, 'MH_01_easy.p') 193 | if True and os.path.exists(f): 194 | return 195 | 196 | print("Start read_data, be patient please") 197 | def set_path(seq): 198 | path_imu = os.path.join(data_dir, seq, "mav0", "imu0", "data.csv") 199 | path_gt = os.path.join(data_dir, seq, "mav0", "state_groundtruth_estimate0", "data.csv") 200 | return path_imu, path_gt 201 | 202 | sequences = os.listdir(data_dir) 203 | # read each sequence 204 | for sequence in sequences: 205 | print("\nSequence name: " + sequence) 206 | path_imu, path_gt = set_path(sequence) 207 | imu = np.genfromtxt(path_imu, delimiter=",", skip_header=1) 208 | gt = np.genfromtxt(path_gt, delimiter=",", skip_header=1) 209 | 210 | # time synchronization between IMU and ground truth 211 | t0 = np.max([gt[0, 0], imu[0, 0]]) 212 | t_end = np.min([gt[-1, 0], imu[-1, 0]]) 213 | 214 | # start index 215 | idx0_imu = np.searchsorted(imu[:, 0], t0) 216 | idx0_gt = np.searchsorted(gt[:, 0], t0) 217 | 218 | # end index 219 | idx_end_imu = np.searchsorted(imu[:, 0], t_end, 'right') 220 | idx_end_gt = np.searchsorted(gt[:, 0], t_end, 'right') 221 | 222 | # subsample 223 | imu = imu[idx0_imu: idx_end_imu] 224 | gt = gt[idx0_gt: idx_end_gt] 225 | ts = imu[:, 0]/1e9 226 | 227 | # interpolate 228 | gt = self.interpolate(gt, gt[:, 0]/1e9, ts) 229 | 230 | # take ground truth position 231 | p_gt = gt[:, 1:4] 232 | p_gt = p_gt - p_gt[0] 233 | 234 | # take ground true quaternion pose 235 | q_gt = torch.Tensor(gt[:, 4:8]).double() 236 | q_gt = q_gt / q_gt.norm(dim=1, 
keepdim=True) 237 | Rot_gt = SO3.from_quaternion(q_gt.cuda(), ordering='wxyz').cpu() 238 | 239 | # convert from numpy 240 | p_gt = torch.Tensor(p_gt).double() 241 | v_gt = torch.tensor(gt[:, 8:11]).double() 242 | imu = torch.Tensor(imu[:, 1:]).double() 243 | 244 | # compute pre-integration factors for all training 245 | mtf = self.min_train_freq 246 | dRot_ij = bmtm(Rot_gt[:-mtf], Rot_gt[mtf:]) 247 | dRot_ij = SO3.dnormalize(dRot_ij.cuda()) 248 | dxi_ij = SO3.log(dRot_ij).cpu() 249 | 250 | # save for all training 251 | mondict = { 252 | 'xs': dxi_ij.float(), 253 | 'us': imu.float(), 254 | } 255 | pdump(mondict, self.predata_dir, sequence + ".p") 256 | # save ground truth 257 | mondict = { 258 | 'ts': ts, 259 | 'qs': q_gt.float(), 260 | 'vs': v_gt.float(), 261 | 'ps': p_gt.float(), 262 | } 263 | pdump(mondict, self.predata_dir, sequence + "_gt.p") 264 | 265 | 266 | class TUMVIDataset(BaseDataset): 267 | """ 268 | Dataloader for the TUM-VI Data Set. 269 | """ 270 | 271 | def __init__(self, data_dir, predata_dir, train_seqs, val_seqs, 272 | test_seqs, mode, N, min_train_freq, max_train_freq, dt=0.005): 273 | super().__init__(predata_dir, train_seqs, val_seqs, test_seqs, mode, N, 274 | min_train_freq, max_train_freq, dt) 275 | # convert raw data to pre loaded data 276 | self.read_data(data_dir) 277 | # noise density 278 | self.imu_std = torch.Tensor([8e-5, 1e-3]).float() 279 | # bias repeatability (without in-run bias stability) 280 | self.imu_b0 = torch.Tensor([1e-3, 1e-3]).float() 281 | 282 | def read_data(self, data_dir): 283 | r"""Read the data from the dataset""" 284 | 285 | f = os.path.join(self.predata_dir, 'dataset-room1_512_16_gt.p') 286 | if True and os.path.exists(f): 287 | return 288 | 289 | print("Start read_data, be patient please") 290 | def set_path(seq): 291 | path_imu = os.path.join(data_dir, seq, "mav0", "imu0", "data.csv") 292 | path_gt = os.path.join(data_dir, seq, "mav0", "mocap0", "data.csv") 293 | return path_imu, path_gt 294 | 295 | sequences = 
os.listdir(data_dir) 296 | 297 | # read each sequence 298 | for sequence in sequences: 299 | print("\nSequence name: " + sequence) 300 | if 'room' not in sequence: 301 | continue 302 | 303 | path_imu, path_gt = set_path(sequence) 304 | imu = np.genfromtxt(path_imu, delimiter=",", skip_header=1) 305 | gt = np.genfromtxt(path_gt, delimiter=",", skip_header=1) 306 | 307 | # time synchronization between IMU and ground truth 308 | t0 = np.max([gt[0, 0], imu[0, 0]]) 309 | t_end = np.min([gt[-1, 0], imu[-1, 0]]) 310 | 311 | # start index 312 | idx0_imu = np.searchsorted(imu[:, 0], t0) 313 | idx0_gt = np.searchsorted(gt[:, 0], t0) 314 | 315 | # end index 316 | idx_end_imu = np.searchsorted(imu[:, 0], t_end, 'right') 317 | idx_end_gt = np.searchsorted(gt[:, 0], t_end, 'right') 318 | 319 | # subsample 320 | imu = imu[idx0_imu: idx_end_imu] 321 | gt = gt[idx0_gt: idx_end_gt] 322 | ts = imu[:, 0]/1e9 323 | 324 | # interpolate 325 | t_gt = gt[:, 0]/1e9 326 | gt = self.interpolate(gt, t_gt, ts) 327 | 328 | # take ground truth position 329 | p_gt = gt[:, 1:4] 330 | p_gt = p_gt - p_gt[0] 331 | 332 | # take ground true quaternion pose 333 | q_gt = SO3.qnorm(torch.Tensor(gt[:, 4:8]).double()) 334 | Rot_gt = SO3.from_quaternion(q_gt.cuda(), ordering='wxyz').cpu() 335 | 336 | # convert from numpy 337 | p_gt = torch.Tensor(p_gt).double() 338 | v_gt = torch.zeros_like(p_gt).double() 339 | v_gt[1:] = (p_gt[1:]-p_gt[:-1])/self.dt 340 | imu = torch.Tensor(imu[:, 1:]).double() 341 | 342 | # compute pre-integration factors for all training 343 | mtf = self.min_train_freq 344 | dRot_ij = bmtm(Rot_gt[:-mtf], Rot_gt[mtf:]) 345 | dRot_ij = SO3.dnormalize(dRot_ij.cuda()) 346 | dxi_ij = SO3.log(dRot_ij).cpu() 347 | 348 | # masks with 1 when ground truth is available, 0 otherwise 349 | masks = dxi_ij.new_ones(dxi_ij.shape[0]) 350 | tmp = np.searchsorted(t_gt, ts[:-mtf]) 351 | diff_t = ts[:-mtf] - t_gt[tmp] 352 | masks[np.abs(diff_t) > 0.01] = 0 353 | 354 | # save all the sequence 355 | mondict = { 
356 | 'xs': torch.cat((dxi_ij, masks.unsqueeze(1)), 1).float(), 357 | 'us': imu.float(), 358 | } 359 | pdump(mondict, self.predata_dir, sequence + ".p") 360 | 361 | # save ground truth 362 | mondict = { 363 | 'ts': ts, 364 | 'qs': q_gt.float(), 365 | 'vs': v_gt.float(), 366 | 'ps': p_gt.float(), 367 | } 368 | pdump(mondict, self.predata_dir, sequence + "_gt.p") 369 | -------------------------------------------------------------------------------- /src/learning.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import time 4 | import matplotlib.pyplot as plt 5 | plt.rcParams["legend.loc"] = "upper right" 6 | plt.rcParams['axes.titlesize'] = 'x-large' 7 | plt.rcParams['axes.labelsize'] = 'x-large' 8 | plt.rcParams['legend.fontsize'] = 'x-large' 9 | plt.rcParams['xtick.labelsize'] = 'x-large' 10 | plt.rcParams['ytick.labelsize'] = 'x-large' 11 | from termcolor import cprint 12 | import numpy as np 13 | import os 14 | from torch.utils.tensorboard import SummaryWriter 15 | from torch.utils.data import DataLoader 16 | from src.utils import pload, pdump, yload, ydump, mkdir, bmv 17 | from src.utils import bmtm, bmtv, bmmt 18 | from datetime import datetime 19 | from src.lie_algebra import SO3, CPUSO3 20 | 21 | 22 | class LearningBasedProcessing: 23 | def __init__(self, res_dir, tb_dir, net_class, net_params, address, dt): 24 | self.res_dir = res_dir 25 | self.tb_dir = tb_dir 26 | self.net_class = net_class 27 | self.net_params = net_params 28 | self._ready = False 29 | self.train_params = {} 30 | self.figsize = (20, 12) 31 | self.dt = dt # (s) 32 | self.address, self.tb_address = self.find_address(address) 33 | if address is None: # create new address 34 | pdump(self.net_params, self.address, 'net_params.p') 35 | ydump(self.net_params, self.address, 'net_params.yaml') 36 | else: # pick the network parameters 37 | self.net_params = pload(self.address, 'net_params.p') 38 | self.train_params = pload(self.address, 
'train_params.p') 39 | self._ready = True 40 | self.path_weights = os.path.join(self.address, 'weights.pt') 41 | self.net = self.net_class(**self.net_params) 42 | if self._ready: # fill network parameters 43 | self.load_weights() 44 | 45 | def find_address(self, address): 46 | """return path where net and training info are saved""" 47 | if address == 'last': 48 | addresses = sorted(os.listdir(self.res_dir)) 49 | tb_address = os.path.join(self.tb_dir, str(len(addresses))) 50 | address = os.path.join(self.res_dir, addresses[-1]) 51 | elif address is None: 52 | now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 53 | address = os.path.join(self.res_dir, now) 54 | mkdir(address) 55 | tb_address = os.path.join(self.tb_dir, now) 56 | else: 57 | tb_address = None 58 | return address, tb_address 59 | 60 | def load_weights(self): 61 | weights = torch.load(self.path_weights) 62 | self.net.load_state_dict(weights) 63 | self.net.cuda() 64 | 65 | def train(self, dataset_class, dataset_params, train_params): 66 | """train the neural network. 
GPU is assumed""" 67 | self.train_params = train_params 68 | pdump(self.train_params, self.address, 'train_params.p') 69 | ydump(self.train_params, self.address, 'train_params.yaml') 70 | 71 | hparams = self.get_hparams(dataset_class, dataset_params, train_params) 72 | ydump(hparams, self.address, 'hparams.yaml') 73 | 74 | # define datasets 75 | dataset_train = dataset_class(**dataset_params, mode='train') 76 | dataset_train.init_train() 77 | dataset_val = dataset_class(**dataset_params, mode='val') 78 | dataset_val.init_val() 79 | 80 | # get class 81 | Optimizer = train_params['optimizer_class'] 82 | Scheduler = train_params['scheduler_class'] 83 | Loss = train_params['loss_class'] 84 | 85 | # get parameters 86 | dataloader_params = train_params['dataloader'] 87 | optimizer_params = train_params['optimizer'] 88 | scheduler_params = train_params['scheduler'] 89 | loss_params = train_params['loss'] 90 | 91 | # define optimizer, scheduler and loss 92 | dataloader = DataLoader(dataset_train, **dataloader_params) 93 | optimizer = Optimizer(self.net.parameters(), **optimizer_params) 94 | scheduler = Scheduler(optimizer, **scheduler_params) 95 | criterion = Loss(**loss_params) 96 | 97 | # remaining training parameters 98 | freq_val = train_params['freq_val'] 99 | n_epochs = train_params['n_epochs'] 100 | 101 | # init net w.r.t dataset 102 | self.net = self.net.cuda() 103 | mean_u, std_u = dataset_train.mean_u, dataset_train.std_u 104 | self.net.set_normalized_factors(mean_u, std_u) 105 | 106 | # start tensorboard writer 107 | writer = SummaryWriter(self.tb_address) 108 | start_time = time.time() 109 | best_loss = torch.Tensor([float('Inf')]) 110 | 111 | # define some function for seeing evolution of training 112 | def write(epoch, loss_epoch): 113 | writer.add_scalar('loss/train', loss_epoch.item(), epoch) 114 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch) 115 | print('Train Epoch: {:2d} \tLoss: {:.4f}'.format( 116 | epoch, loss_epoch.item())) 117 | 
scheduler.step(epoch) 118 | 119 | def write_time(epoch, start_time): 120 | delta_t = time.time() - start_time 121 | print("Amount of time spent for epochs " + 122 | "{}-{}: {:.1f}s\n".format(epoch - freq_val, epoch, delta_t)) 123 | writer.add_scalar('time_spend', delta_t, epoch) 124 | 125 | def write_val(loss, best_loss): 126 | if 0.5*loss <= best_loss: 127 | msg = 'validation loss decreases! :) ' 128 | msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(), 129 | best_loss.item()) 130 | cprint(msg, 'green') 131 | best_loss = loss 132 | self.save_net() 133 | else: 134 | msg = 'validation loss increases! :( ' 135 | msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(), 136 | best_loss.item()) 137 | cprint(msg, 'yellow') 138 | writer.add_scalar('loss/val', loss.item(), epoch) 139 | return best_loss 140 | 141 | # training loop ! 142 | for epoch in range(1, n_epochs + 1): 143 | loss_epoch = self.loop_train(dataloader, optimizer, criterion) 144 | write(epoch, loss_epoch) 145 | scheduler.step(epoch) 146 | if epoch % freq_val == 0: 147 | loss = self.loop_val(dataset_val, criterion) 148 | write_time(epoch, start_time) 149 | best_loss = write_val(loss, best_loss) 150 | start_time = time.time() 151 | # training is over ! 
152 | 153 | # test on new data 154 | dataset_test = dataset_class(**dataset_params, mode='test') 155 | self.load_weights() 156 | test_loss = self.loop_val(dataset_test, criterion) 157 | dict_loss = { 158 | 'final_loss/val': best_loss.item(), 159 | 'final_loss/test': test_loss.item() 160 | } 161 | writer.add_hparams(hparams, dict_loss) 162 | ydump(dict_loss, self.address, 'final_loss.yaml') 163 | writer.close() 164 | 165 | def loop_train(self, dataloader, optimizer, criterion): 166 | """Forward-backward loop over training data""" 167 | loss_epoch = 0 168 | optimizer.zero_grad() 169 | for us, xs in dataloader: 170 | us = dataloader.dataset.add_noise(us.cuda()) 171 | hat_xs = self.net(us) 172 | loss = criterion(xs.cuda(), hat_xs)/len(dataloader) 173 | loss.backward() 174 | loss_epoch += loss.detach().cpu() 175 | optimizer.step() 176 | return loss_epoch 177 | 178 | def loop_val(self, dataset, criterion): 179 | """Forward loop over validation data""" 180 | loss_epoch = 0 181 | self.net.eval() 182 | with torch.no_grad(): 183 | for i in range(len(dataset)): 184 | us, xs = dataset[i] 185 | hat_xs = self.net(us.cuda().unsqueeze(0)) 186 | loss = criterion(xs.cuda().unsqueeze(0), hat_xs)/len(dataset) 187 | loss_epoch += loss.cpu() 188 | self.net.train() 189 | return loss_epoch 190 | 191 | def save_net(self): 192 | """save the weights on the net in CPU""" 193 | self.net.eval().cpu() 194 | torch.save(self.net.state_dict(), self.path_weights) 195 | self.net.train().cuda() 196 | 197 | def get_hparams(self, dataset_class, dataset_params, train_params): 198 | """return all training hyperparameters in a dict""" 199 | Optimizer = train_params['optimizer_class'] 200 | Scheduler = train_params['scheduler_class'] 201 | Loss = train_params['loss_class'] 202 | 203 | # get training class parameters 204 | dataloader_params = train_params['dataloader'] 205 | optimizer_params = train_params['optimizer'] 206 | scheduler_params = train_params['scheduler'] 207 | loss_params = 
train_params['loss'] 208 | 209 | # remaining training parameters 210 | freq_val = train_params['freq_val'] 211 | n_epochs = train_params['n_epochs'] 212 | 213 | dict_class = { 214 | 'Optimizer': str(Optimizer), 215 | 'Scheduler': str(Scheduler), 216 | 'Loss': str(Loss) 217 | } 218 | 219 | return {**dict_class, **dataloader_params, **optimizer_params, 220 | **loss_params, **scheduler_params, 221 | 'n_epochs': n_epochs, 'freq_val': freq_val} 222 | 223 | def test(self, dataset_class, dataset_params, modes): 224 | """test a network once training is over""" 225 | 226 | # get loss function 227 | Loss = self.train_params['loss_class'] 228 | loss_params = self.train_params['loss'] 229 | criterion = Loss(**loss_params) 230 | 231 | # test on each type of sequence 232 | for mode in modes: 233 | dataset = dataset_class(**dataset_params, mode=mode) 234 | self.loop_test(dataset, criterion) 235 | self.display_test(dataset, mode) 236 | 237 | def loop_test(self, dataset, criterion): 238 | """Forward loop over test data""" 239 | self.net.eval() 240 | for i in range(len(dataset)): 241 | seq = dataset.sequences[i] 242 | us, xs = dataset[i] 243 | with torch.no_grad(): 244 | hat_xs = self.net(us.cuda().unsqueeze(0)) 245 | loss = criterion(xs.cuda().unsqueeze(0), hat_xs) 246 | mkdir(self.address, seq) 247 | mondict = { 248 | 'hat_xs': hat_xs[0].cpu(), 249 | 'loss': loss.cpu().item(), 250 | } 251 | pdump(mondict, self.address, seq, 'results.p') 252 | 253 | def display_test(self, dataset, mode): 254 | raise NotImplementedError 255 | 256 | 257 | class GyroLearningBasedProcessing(LearningBasedProcessing): 258 | def __init__(self, res_dir, tb_dir, net_class, net_params, address, dt): 259 | super().__init__(res_dir, tb_dir, net_class, net_params, address, dt) 260 | self.roe_dist = [7, 14, 21, 28, 35] # m 261 | self.freq = 100 # subsampling frequency for RTE computation 262 | self.roes = { # relative trajectory errors 263 | 'Rots': [], 264 | 'yaws': [], 265 | } 266 | 267 | def 
display_test(self, dataset, mode): 268 | self.roes = { 269 | 'Rots': [], 270 | 'yaws': [], 271 | } 272 | self.to_open_vins(dataset) 273 | for i, seq in enumerate(dataset.sequences): 274 | print('\n', 'Results for sequence ' + seq) 275 | self.seq = seq 276 | # get ground truth 277 | self.gt = dataset.load_gt(i) 278 | Rots = SO3.from_quaternion(self.gt['qs'].cuda()) 279 | self.gt['Rots'] = Rots.cpu() 280 | self.gt['rpys'] = SO3.to_rpy(Rots).cpu() 281 | # get data and estimate 282 | self.net_us = pload(self.address, seq, 'results.p')['hat_xs'] 283 | self.raw_us, _ = dataset[i] 284 | N = self.net_us.shape[0] 285 | self.gyro_corrections = (self.raw_us[:, :3] - self.net_us[:N, :3]) 286 | self.ts = torch.linspace(0, N*self.dt, N) 287 | 288 | self.convert() 289 | self.plot_gyro() 290 | self.plot_gyro_correction() 291 | plt.show() 292 | 293 | def to_open_vins(self, dataset): 294 | """ 295 | Export results to Open-VINS format, for use with the eval toolbox 296 | available at https://github.com/rpng/open_vins/ 297 | """ 298 | 299 | for i, seq in enumerate(dataset.sequences): 300 | self.seq = seq 301 | # get ground truth 302 | self.gt = dataset.load_gt(i) 303 | raw_us, _ = dataset[i] 304 | net_us = pload(self.address, seq, 'results.p')['hat_xs'] 305 | N = net_us.shape[0] 306 | net_qs, imu_Rots, net_Rots = self.integrate_with_quaternions_superfast(N, raw_us, net_us) 307 | path = os.path.join(self.address, seq + '.txt') 308 | header = "timestamp(s) tx ty tz qx qy qz qw" 309 | x = np.zeros((net_qs.shape[0], 8)) 310 | x[:, 0] = self.gt['ts'][:net_qs.shape[0]] 311 | x[:, [7, 4, 5, 6]] = net_qs 312 | np.savetxt(path, x[::10], header=header, delimiter=" ", 313 | fmt='%1.9f') 314 | 315 | def convert(self): 316 | # s -> min 317 | l = 1/60 318 | self.ts *= l 319 | 320 | # rad -> deg 321 | l = 180/np.pi 322 | self.gyro_corrections *= l 323 | self.gt['rpys'] *= l 324 | 325 | def integrate_with_quaternions_superfast(self, N, raw_us, net_us): 326 | imu_qs = SO3.qnorm(SO3.qexp(raw_us[:,
:3].cuda().double()*self.dt)) 327 | net_qs = SO3.qnorm(SO3.qexp(net_us[:, :3].cuda().double()*self.dt)) 328 | Rot0 = SO3.qnorm(self.gt['qs'][:2].cuda().double()) 329 | imu_qs[0] = Rot0[0] 330 | net_qs[0] = Rot0[0] 331 | 332 | N = np.log2(imu_qs.shape[0]) 333 | for i in range(int(N)): 334 | k = 2**i 335 | imu_qs[k:] = SO3.qnorm(SO3.qmul(imu_qs[:-k], imu_qs[k:])) 336 | net_qs[k:] = SO3.qnorm(SO3.qmul(net_qs[:-k], net_qs[k:])) 337 | 338 | if int(N) < N: 339 | k = 2**int(N) 340 | k2 = imu_qs[k:].shape[0] 341 | imu_qs[k:] = SO3.qnorm(SO3.qmul(imu_qs[:k2], imu_qs[k:])) 342 | net_qs[k:] = SO3.qnorm(SO3.qmul(net_qs[:k2], net_qs[k:])) 343 | 344 | imu_Rots = SO3.from_quaternion(imu_qs).float() 345 | net_Rots = SO3.from_quaternion(net_qs).float() 346 | return net_qs.cpu(), imu_Rots, net_Rots 347 | 348 | def plot_gyro(self): 349 | N = self.raw_us.shape[0] 350 | raw_us = self.raw_us[:, :3] 351 | net_us = self.net_us[:, :3] 352 | 353 | net_qs, imu_Rots, net_Rots = self.integrate_with_quaternions_superfast(N, 354 | raw_us, net_us) 355 | imu_rpys = 180/np.pi*SO3.to_rpy(imu_Rots).cpu() 356 | net_rpys = 180/np.pi*SO3.to_rpy(net_Rots).cpu() 357 | self.plot_orientation(imu_rpys, net_rpys, N) 358 | self.plot_orientation_error(imu_Rots, net_Rots, N) 359 | 360 | def plot_orientation(self, imu_rpys, net_rpys, N): 361 | title = "Orientation estimation" 362 | gt = self.gt['rpys'][:N] 363 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=self.figsize) 364 | axs[0].set(ylabel='roll (deg)', title=title) 365 | axs[1].set(ylabel='pitch (deg)') 366 | axs[2].set(xlabel='$t$ (min)', ylabel='yaw (deg)') 367 | 368 | for i in range(3): 369 | axs[i].plot(self.ts, gt[:, i], color='black', label=r'ground truth') 370 | axs[i].plot(self.ts, imu_rpys[:, i], color='red', label=r'raw IMU') 371 | axs[i].plot(self.ts, net_rpys[:, i], color='blue', label=r'net IMU') 372 | axs[i].set_xlim(self.ts[0], self.ts[-1]) 373 | self.savefig(axs, fig, 'orientation') 374 | 375 | def plot_orientation_error(self, imu_Rots, 
net_Rots, N): 376 | gt = self.gt['Rots'][:N].cuda() 377 | raw_err = 180/np.pi*SO3.log(bmtm(imu_Rots, gt)).cpu() 378 | net_err = 180/np.pi*SO3.log(bmtm(net_Rots, gt)).cpu() 379 | title = "$SO(3)$ orientation error" 380 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=self.figsize) 381 | axs[0].set(ylabel='roll (deg)', title=title) 382 | axs[1].set(ylabel='pitch (deg)') 383 | axs[2].set(xlabel='$t$ (min)', ylabel='yaw (deg)') 384 | 385 | for i in range(3): 386 | axs[i].plot(self.ts, raw_err[:, i], color='red', label=r'raw IMU') 387 | axs[i].plot(self.ts, net_err[:, i], color='blue', label=r'net IMU') 388 | axs[i].set_ylim(-10, 10) 389 | axs[i].set_xlim(self.ts[0], self.ts[-1]) 390 | self.savefig(axs, fig, 'orientation_error') 391 | 392 | def plot_gyro_correction(self): 393 | title = "Gyro correction" + self.end_title 394 | ylabel = 'gyro correction (deg/s)' 395 | fig, ax = plt.subplots(figsize=self.figsize) 396 | ax.set(xlabel='$t$ (min)', ylabel=ylabel, title=title) 397 | plt.plot(self.ts, self.gyro_corrections, label=r'net IMU') 398 | ax.set_xlim(self.ts[0], self.ts[-1]) 399 | self.savefig(ax, fig, 'gyro_correction') 400 | 401 | @property 402 | def end_title(self): 403 | return " for sequence " + self.seq.replace("_", " ") 404 | 405 | def savefig(self, axs, fig, name): 406 | if isinstance(axs, np.ndarray): 407 | for i in range(len(axs)): 408 | axs[i].grid() 409 | axs[i].legend() 410 | else: 411 | axs.grid() 412 | axs.legend() 413 | fig.tight_layout() 414 | fig.savefig(os.path.join(self.address, self.seq, name + '.png')) 415 | 416 | -------------------------------------------------------------------------------- /src/lie_algebra.py: -------------------------------------------------------------------------------- 1 | from src.utils import * 2 | import numpy as np 3 | 4 | 5 | class SO3: 6 | #  tolerance criterion 7 | TOL = 1e-8 8 | Id = torch.eye(3).cuda().float() 9 | dId = torch.eye(3).cuda().double() 10 | 11 | @classmethod 12 | def exp(cls, phi): 13 | angle = 
phi.norm(dim=1, keepdim=True) 14 | mask = angle[:, 0] < cls.TOL 15 | dim_batch = phi.shape[0] 16 | Id = cls.Id.expand(dim_batch, 3, 3) 17 | 18 | axis = phi[~mask] / angle[~mask] 19 | c = angle[~mask].cos().unsqueeze(2) 20 | s = angle[~mask].sin().unsqueeze(2) 21 | 22 | Rot = phi.new_empty(dim_batch, 3, 3) 23 | Rot[mask] = Id[mask] + SO3.wedge(phi[mask]) 24 | Rot[~mask] = c*Id[~mask] + \ 25 | (1-c)*cls.bouter(axis, axis) + s*cls.wedge(axis) 26 | return Rot 27 | 28 | @classmethod 29 | def log(cls, Rot): 30 | dim_batch = Rot.shape[0] 31 | Id = cls.Id.expand(dim_batch, 3, 3) 32 | 33 | cos_angle = (0.5 * cls.btrace(Rot) - 0.5).clamp(-1., 1.) 34 | # Clip cos(angle) to its proper domain to avoid NaNs from rounding 35 | # errors 36 | angle = cos_angle.acos() 37 | mask = angle < cls.TOL 38 | if mask.sum() == 0: 39 | angle = angle.unsqueeze(1).unsqueeze(1) 40 | return cls.vee((0.5 * angle/angle.sin())*(Rot-Rot.transpose(1, 2))) 41 | elif mask.sum() == dim_batch: 42 | # If angle is close to zero, use first-order Taylor expansion 43 | return cls.vee(Rot - Id) 44 | phi = cls.vee(Rot - Id) 45 | angle = angle 46 | phi[~mask] = cls.vee((0.5 * angle[~mask]/angle[~mask].sin()).unsqueeze( 47 | 1).unsqueeze(2)*(Rot[~mask] - Rot[~mask].transpose(1, 2))) 48 | return phi 49 | 50 | @staticmethod 51 | def vee(Phi): 52 | return torch.stack((Phi[:, 2, 1], 53 | Phi[:, 0, 2], 54 | Phi[:, 1, 0]), dim=1) 55 | 56 | @staticmethod 57 | def wedge(phi): 58 | dim_batch = phi.shape[0] 59 | zero = phi.new_zeros(dim_batch) 60 | return torch.stack((zero, -phi[:, 2], phi[:, 1], 61 | phi[:, 2], zero, -phi[:, 0], 62 | -phi[:, 1], phi[:, 0], zero), 1).view(dim_batch, 63 | 3, 3) 64 | 65 | @classmethod 66 | def from_rpy(cls, roll, pitch, yaw): 67 | return cls.rotz(yaw).bmm(cls.roty(pitch).bmm(cls.rotx(roll))) 68 | 69 | @classmethod 70 | def rotx(cls, angle_in_radians): 71 | c = angle_in_radians.cos() 72 | s = angle_in_radians.sin() 73 | mat = c.new_zeros((c.shape[0], 3, 3)) 74 | mat[:, 0, 0] = 1 75 | mat[:, 1, 
1] = c 76 | mat[:, 2, 2] = c 77 | mat[:, 1, 2] = -s 78 | mat[:, 2, 1] = s 79 | return mat 80 | 81 | @classmethod 82 | def roty(cls, angle_in_radians): 83 | c = angle_in_radians.cos() 84 | s = angle_in_radians.sin() 85 | mat = c.new_zeros((c.shape[0], 3, 3)) 86 | mat[:, 1, 1] = 1 87 | mat[:, 0, 0] = c 88 | mat[:, 2, 2] = c 89 | mat[:, 0, 2] = s 90 | mat[:, 2, 0] = -s 91 | return mat 92 | 93 | @classmethod 94 | def rotz(cls, angle_in_radians): 95 | c = angle_in_radians.cos() 96 | s = angle_in_radians.sin() 97 | mat = c.new_zeros((c.shape[0], 3, 3)) 98 | mat[:, 2, 2] = 1 99 | mat[:, 0, 0] = c 100 | mat[:, 1, 1] = c 101 | mat[:, 0, 1] = -s 102 | mat[:, 1, 0] = s 103 | return mat 104 | 105 | @classmethod 106 | def isclose(cls, x, y): 107 | return (x-y).abs() < cls.TOL 108 | 109 | @classmethod 110 | def to_rpy(cls, Rots): 111 | """Convert a rotation matrix to RPY Euler angles.""" 112 | 113 | pitch = torch.atan2(-Rots[:, 2, 0], 114 | torch.sqrt(Rots[:, 0, 0]**2 + Rots[:, 1, 0]**2)) 115 | yaw = pitch.new_empty(pitch.shape) 116 | roll = pitch.new_empty(pitch.shape) 117 | 118 | near_pi_over_two_mask = cls.isclose(pitch, np.pi / 2.) 119 | near_neg_pi_over_two_mask = cls.isclose(pitch, -np.pi / 2.) 120 | 121 | remainder_inds = ~(near_pi_over_two_mask | near_neg_pi_over_two_mask) 122 | 123 | yaw[near_pi_over_two_mask] = 0 124 | roll[near_pi_over_two_mask] = torch.atan2( 125 | Rots[near_pi_over_two_mask, 0, 1], 126 | Rots[near_pi_over_two_mask, 1, 1]) 127 | 128 | yaw[near_neg_pi_over_two_mask] = 0. 
129 | roll[near_neg_pi_over_two_mask] = -torch.atan2( 130 | Rots[near_neg_pi_over_two_mask, 0, 1], 131 | Rots[near_neg_pi_over_two_mask, 1, 1]) 132 | 133 | sec_pitch = 1/pitch[remainder_inds].cos() 134 | remainder_mats = Rots[remainder_inds] 135 | yaw[remainder_inds] = torch.atan2(remainder_mats[:, 1, 0] * sec_pitch, 136 | remainder_mats[:, 0, 0] * sec_pitch) 137 | roll[remainder_inds] = torch.atan2(remainder_mats[:, 2, 1] * sec_pitch, 138 | remainder_mats[:, 2, 2] * sec_pitch) 139 | rpys = torch.cat([roll.unsqueeze(dim=1), 140 | pitch.unsqueeze(dim=1), 141 | yaw.unsqueeze(dim=1)], dim=1) 142 | return rpys 143 | 144 | @classmethod 145 | def from_quaternion(cls, quat, ordering='wxyz'): 146 | """Form a rotation matrix from a unit length quaternion. 147 | Valid orderings are 'xyzw' and 'wxyz'. 148 | """ 149 | if ordering == 'xyzw': 150 | qx = quat[:, 0] 151 | qy = quat[:, 1] 152 | qz = quat[:, 2] 153 | qw = quat[:, 3] 154 | elif ordering == 'wxyz': 155 | qw = quat[:, 0] 156 | qx = quat[:, 1] 157 | qy = quat[:, 2] 158 | qz = quat[:, 3] 159 | 160 | # Form the matrix 161 | mat = quat.new_empty(quat.shape[0], 3, 3) 162 | 163 | qx2 = qx * qx 164 | qy2 = qy * qy 165 | qz2 = qz * qz 166 | 167 | mat[:, 0, 0] = 1. - 2. * (qy2 + qz2) 168 | mat[:, 0, 1] = 2. * (qx * qy - qw * qz) 169 | mat[:, 0, 2] = 2. * (qw * qy + qx * qz) 170 | 171 | mat[:, 1, 0] = 2. * (qw * qz + qx * qy) 172 | mat[:, 1, 1] = 1. - 2. * (qx2 + qz2) 173 | mat[:, 1, 2] = 2. * (qy * qz - qw * qx) 174 | 175 | mat[:, 2, 0] = 2. * (qx * qz - qw * qy) 176 | mat[:, 2, 1] = 2. * (qw * qx + qy * qz) 177 | mat[:, 2, 2] = 1. - 2. * (qx2 + qy2) 178 | return mat 179 | 180 | @classmethod 181 | def to_quaternion(cls, Rots, ordering='wxyz'): 182 | """Convert a rotation matrix to a unit length quaternion. 183 | Valid orderings are 'xyzw' and 'wxyz'.
184 | """ 185 | tmp = 1 + Rots[:, 0, 0] + Rots[:, 1, 1] + Rots[:, 2, 2] 186 | tmp[tmp < 0] = 0 187 | qw = 0.5 * torch.sqrt(tmp) 188 | qx = qw.new_empty(qw.shape[0]) 189 | qy = qw.new_empty(qw.shape[0]) 190 | qz = qw.new_empty(qw.shape[0]) 191 | 192 | near_zero_mask = qw.abs() < cls.TOL 193 | 194 | if near_zero_mask.sum() > 0: 195 | cond1_mask = near_zero_mask * \ 196 | (Rots[:, 0, 0] > Rots[:, 1, 1])*(Rots[:, 0, 0] > Rots[:, 2, 2]) 197 | cond1_inds = cond1_mask.nonzero() 198 | 199 | if len(cond1_inds) > 0: 200 | cond1_inds = cond1_inds.squeeze() 201 | R_cond1 = Rots[cond1_inds].view(-1, 3, 3) 202 | d = 2. * torch.sqrt(1. + R_cond1[:, 0, 0] - 203 | R_cond1[:, 1, 1] - R_cond1[:, 2, 2]).view(-1) 204 | qw[cond1_inds] = (R_cond1[:, 2, 1] - R_cond1[:, 1, 2]) / d 205 | qx[cond1_inds] = 0.25 * d 206 | qy[cond1_inds] = (R_cond1[:, 1, 0] + R_cond1[:, 0, 1]) / d 207 | qz[cond1_inds] = (R_cond1[:, 0, 2] + R_cond1[:, 2, 0]) / d 208 | 209 | cond2_mask = near_zero_mask * (Rots[:, 1, 1] > Rots[:, 2, 2]) 210 | cond2_inds = cond2_mask.nonzero() 211 | 212 | if len(cond2_inds) > 0: 213 | cond2_inds = cond2_inds.squeeze() 214 | R_cond2 = Rots[cond2_inds].view(-1, 3, 3) 215 | d = 2. * torch.sqrt(1. + R_cond2[:, 1, 1] - 216 | R_cond2[:, 0, 0] - R_cond2[:, 2, 2]).squeeze() 217 | tmp = (R_cond2[:, 0, 2] - R_cond2[:, 2, 0]) / d 218 | qw[cond2_inds] = tmp 219 | qx[cond2_inds] = (R_cond2[:, 1, 0] + R_cond2[:, 0, 1]) / d 220 | qy[cond2_inds] = 0.25 * d 221 | qz[cond2_inds] = (R_cond2[:, 2, 1] + R_cond2[:, 1, 2]) / d 222 | 223 | cond3_mask = near_zero_mask & cond1_mask.logical_not() & cond2_mask.logical_not() 224 | cond3_inds = cond3_mask 225 | 226 | if len(cond3_inds) > 0: 227 | R_cond3 = Rots[cond3_inds].view(-1, 3, 3) 228 | d = 2. * \ 229 | torch.sqrt(1. 
+ R_cond3[:, 2, 2] - 230 | R_cond3[:, 0, 0] - R_cond3[:, 1, 1]).squeeze() 231 | qw[cond3_inds] = (R_cond3[:, 1, 0] - R_cond3[:, 0, 1]) / d 232 | qx[cond3_inds] = (R_cond3[:, 0, 2] + R_cond3[:, 2, 0]) / d 233 | qy[cond3_inds] = (R_cond3[:, 2, 1] + R_cond3[:, 1, 2]) / d 234 | qz[cond3_inds] = 0.25 * d 235 | 236 | far_zero_mask = near_zero_mask.logical_not() 237 | far_zero_inds = far_zero_mask 238 | if len(far_zero_inds) > 0: 239 | R_fz = Rots[far_zero_inds] 240 | d = 4. * qw[far_zero_inds] 241 | qx[far_zero_inds] = (R_fz[:, 2, 1] - R_fz[:, 1, 2]) / d 242 | qy[far_zero_inds] = (R_fz[:, 0, 2] - R_fz[:, 2, 0]) / d 243 | qz[far_zero_inds] = (R_fz[:, 1, 0] - R_fz[:, 0, 1]) / d 244 | 245 | # Check ordering last 246 | if ordering == 'xyzw': 247 | quat = torch.stack([qx, qy, qz, qw], dim=1) 248 | elif ordering == 'wxyz': 249 | quat = torch.stack([qw, qx, qy, qz], dim=1) 250 | return quat 251 | 252 | @classmethod 253 | def normalize(cls, Rots): 254 | U, _, V = torch.svd(Rots) 255 | S = cls.Id.clone().repeat(Rots.shape[0], 1, 1) 256 | S[:, 2, 2] = torch.det(U) * torch.det(V) 257 | return U.bmm(S).bmm(V.transpose(1, 2)) 258 | 259 | @classmethod 260 | def dnormalize(cls, Rots): 261 | U, _, V = torch.svd(Rots) 262 | S = cls.dId.clone().repeat(Rots.shape[0], 1, 1) 263 | S[:, 2, 2] = torch.det(U) * torch.det(V) 264 | return U.bmm(S).bmm(V.transpose(1, 2)) 265 | 266 | @classmethod 267 | def qmul(cls, q, r, ordering='wxyz'): 268 | """ 269 | Multiply quaternion(s) q with quaternion(s) r.
270 | """ 271 | terms = cls.bouter(r, q) 272 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 273 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 274 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 275 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 276 | xyz = torch.stack((x, y, z), dim=1) 277 | xyz[w < 0] *= -1 278 | w[w < 0] *= -1 279 | if ordering == 'wxyz': 280 | q = torch.cat((w.unsqueeze(1), xyz), dim=1) 281 | else: 282 | q = torch.cat((xyz, w.unsqueeze(1)), dim=1) 283 | return q / q.norm(dim=1, keepdim=True) 284 | 285 | @staticmethod 286 | def sinc(x): 287 | return x.sin() / x 288 | 289 | @classmethod 290 | def qexp(cls, xi, ordering='wxyz'): 291 | """ 292 | Convert exponential maps to quaternions. 293 | """ 294 | theta = xi.norm(dim=1, keepdim=True) 295 | w = (0.5*theta).cos() 296 | xyz = 0.5*cls.sinc(0.5*theta)*xi # sin(theta/2)/theta * xi; sinc above is unnormalized, so no division by pi 297 | return torch.cat((w, xyz), 1) 298 | 299 | @classmethod 300 | def qlog(cls, q, ordering='wxyz'): 301 | """ 302 | Applies the log map to quaternions.
303 | """ 304 | n = 0.5*torch.norm(q[:, 1:], p=2, dim=1, keepdim=True) 305 | n = torch.clamp(n, min=1e-8) 306 | q = q[:, 1:] * torch.acos(torch.clamp(q[:, :1], min=-1.0, max=1.0)) 307 | r = q / n 308 | return r 309 | 310 | @classmethod 311 | def qinv(cls, q, ordering='wxyz'): 312 | "Quaternion inverse" 313 | r = torch.empty_like(q) 314 | if ordering == 'wxyz': 315 | r[:, 1:4] = -q[:, 1:4] 316 | r[:, 0] = q[:, 0] 317 | else: 318 | r[:, :3] = -q[:, :3] 319 | r[:, 3] = q[:, 3] 320 | return r 321 | 322 | @classmethod 323 | def qnorm(cls, q): 324 | "Quaternion normalization" 325 | return q / q.norm(dim=1, keepdim=True) 326 | 327 | @classmethod 328 | def qinterp(cls, qs, t, t_int): 329 | idxs = np.searchsorted(t, t_int) 330 | idxs0 = idxs-1 331 | idxs0[idxs0 < 0] = 0 332 | idxs1 = idxs 333 | idxs1[idxs1 == t.shape[0]] = t.shape[0] - 1 334 | q0 = qs[idxs0] 335 | q1 = qs[idxs1] 336 | tau = torch.zeros_like(t_int) 337 | dt = (t[idxs1]-t[idxs0])[idxs0 != idxs1] 338 | tau[idxs0 != idxs1] = (t_int-t[idxs0])[idxs0 != idxs1]/dt 339 | return cls.slerp(q0, q1, tau) 340 | 341 | @classmethod 342 | def slerp(cls, q0, q1, tau, DOT_THRESHOLD = 0.9995): 343 | """Spherical linear interpolation.""" 344 | 345 | dot = (q0*q1).sum(dim=1) 346 | q1[dot < 0] = -q1[dot < 0] 347 | dot[dot < 0] = -dot[dot < 0] 348 | 349 | q = torch.zeros_like(q0) 350 | tmp = q0 + tau.unsqueeze(1) * (q1 - q0) 351 | tmp = tmp[dot > DOT_THRESHOLD] 352 | q[dot > DOT_THRESHOLD] = tmp / tmp.norm(dim=1, keepdim=True) 353 | 354 | theta_0 = dot.acos() 355 | sin_theta_0 = theta_0.sin() 356 | theta = theta_0 * tau 357 | sin_theta = theta.sin() 358 | s0 = (theta.cos() - dot * sin_theta / sin_theta_0).unsqueeze(1) 359 | s1 = (sin_theta / sin_theta_0).unsqueeze(1) 360 | q[dot < DOT_THRESHOLD] = ((s0 * q0) + (s1 * q1))[dot < DOT_THRESHOLD] 361 | return q / q.norm(dim=1, keepdim=True) 362 | 363 | @staticmethod 364 | def bouter(vec1, vec2): 365 | """batch outer product""" 366 | return torch.einsum('bi, bj -> bij', vec1, vec2) 367 
| 368 | @staticmethod 369 | def btrace(mat): 370 | """batch matrix trace""" 371 | return torch.einsum('bii -> b', mat) 372 | 373 | 374 | class CPUSO3: 375 | # tolerance criterion 376 | TOL = 1e-8 377 | Id = torch.eye(3) 378 | 379 | @classmethod 380 | def qmul(cls, q, r): 381 | """ 382 | Multiply quaternion(s) q with quaternion(s) r. 383 | """ 384 | # Compute outer product 385 | terms = cls.outer(r, q) 386 | w = terms[0, 0] - terms[1, 1] - terms[2, 2] - terms[3, 3] 387 | x = terms[0, 1] + terms[1, 0] - terms[2, 3] + terms[3, 2] 388 | y = terms[0, 2] + terms[1, 3] + terms[2, 0] - terms[3, 1] 389 | z = terms[0, 3] - terms[1, 2] + terms[2, 1] + terms[3, 0] 390 | return torch.stack((w, x, y, z)) 391 | 392 | @staticmethod 393 | def outer(a, b): 394 | return torch.einsum('i, j -> ij', a, b) -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from src.utils import bmmt, bmv, bmtv, bbmv, bmtm 4 | from src.lie_algebra import SO3 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class BaseLoss(torch.nn.Module): 9 | 10 | def __init__(self, min_N, max_N, dt): 11 | super().__init__() 12 | # windows sizes 13 | self.min_N = min_N 14 | self.max_N = max_N 15 | self.min_train_freq = 2 ** self.min_N 16 | self.max_train_freq = 2 ** self.max_N 17 | # sampling time 18 | self.dt = dt # (s) 19 | 20 | 21 | class GyroLoss(BaseLoss): 22 | """Loss for low-frequency orientation increment""" 23 | 24 | def __init__(self, w, min_N, max_N, dt, target, huber): 25 | super().__init__(min_N, max_N, dt) 26 | # weights on loss 27 | self.w = w 28 | self.sl = torch.nn.SmoothL1Loss() 29 | if target == 'rotation matrix': 30 | self.forward = self.forward_with_rotation_matrices 31 | elif target == 'quaternion': 32 | self.forward = self.forward_with_quaternions 33 | elif target == 'rotation matrix mask': 34 | self.forward = 
self.forward_with_rotation_matrices_mask 35 | elif target == 'quaternion mask': 36 | self.forward = self.forward_with_quaternion_mask 37 | self.huber = huber 38 | self.weight = torch.ones(1, 1, 39 | self.min_train_freq).cuda()/self.min_train_freq 40 | self.N0 = 5 # skip the first N0 increments in the loss (affected by padding) 41 | 42 | def f_huber(self, rs): 43 | """Huber loss function""" 44 | loss = self.w*self.sl(rs/self.huber, 45 | torch.zeros_like(rs))*(self.huber**2) 46 | return loss 47 | 48 | def forward_with_rotation_matrices(self, xs, hat_xs): 49 | """Forward errors with rotation matrices""" 50 | N = xs.shape[0] 51 | Xs = SO3.exp(xs[:, ::self.min_train_freq].reshape(-1, 3).double()) 52 | hat_xs = self.dt*hat_xs.reshape(-1, 3).double() 53 | Omegas = SO3.exp(hat_xs[:, :3]) 54 | # compute increments at min_train_freq by decimation 55 | for k in range(self.min_N): 56 | Omegas = Omegas[::2].bmm(Omegas[1::2]) 57 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:] 58 | loss = self.f_huber(rs) 59 | # compute increments from min_train_freq to max_train_freq 60 | for k in range(self.min_N, self.max_N): 61 | Omegas = Omegas[::2].bmm(Omegas[1::2]) 62 | Xs = Xs[::2].bmm(Xs[1::2]) 63 | rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:] 64 | loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1)) 65 | return loss 66 | 67 | def forward_with_quaternions(self, xs, hat_xs): 68 | """Forward errors with quaternions""" 69 | N = xs.shape[0] 70 | Xs = SO3.qexp(xs[:, ::self.min_train_freq].reshape(-1, 3).double()) 71 | hat_xs = self.dt*hat_xs.reshape(-1, 3).double() 72 | Omegas = SO3.qexp(hat_xs[:, :3]) 73 | # compute increments at min_train_freq by decimation 74 | for k in range(self.min_N): 75 | Omegas = SO3.qmul(Omegas[::2], Omegas[1::2]) 76 | rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N, 77 | -1, 3)[:, self.N0:] 78 | loss = self.f_huber(rs) 79 | # compute increments from min_train_freq to max_train_freq 80 | for k in range(self.min_N, self.max_N): 81 |
            Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
            Xs = SO3.qmul(Xs[::2], Xs[1::2])
            rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
            rs = rs.view(N, -1, 3)[:, self.N0:]
            loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1))
        return loss

    def forward_with_rotation_matrices_mask(self, xs, hat_xs):
        """Forward errors with rotation matrices"""
        N = xs.shape[0]
        # channel 3 of xs carries the validity mask of the ground truth
        masks = xs[:, :, 3].unsqueeze(1)
        masks = torch.nn.functional.conv1d(masks, self.weight, bias=None,
            stride=self.min_train_freq).double().transpose(1, 2)
        masks[masks < 1] = 0
        Xs = SO3.exp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
        hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
        Omegas = SO3.exp(hat_xs[:, :3])
        # compute increment at min_train_freq by decimation
        for k in range(self.min_N):
            Omegas = Omegas[::2].bmm(Omegas[1::2])
        rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
        loss = self.f_huber(rs)
        # compute increment from min_train_freq to max_train_freq
        for k in range(self.min_N, self.max_N):
            Omegas = Omegas[::2].bmm(Omegas[1::2])
            Xs = Xs[::2].bmm(Xs[1::2])
            masks = masks[:, ::2] * masks[:, 1::2]
            rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
            rs = rs[masks[:, self.N0:].squeeze(2) == 1]
            # note: only the third residual component is used here, unlike
            # the quaternion variant below, which uses all three
            loss = loss + self.f_huber(rs[:,2])/(2**(k - self.min_N + 1))
        return loss

    def forward_with_quaternion_mask(self, xs, hat_xs):
        """Forward errors with quaternion"""
        N = xs.shape[0]
        # channel 3 of xs carries the validity mask of the ground truth
        masks = xs[:, :, 3].unsqueeze(1)
        masks = torch.nn.functional.conv1d(masks, self.weight, bias=None,
            stride=self.min_train_freq).double().transpose(1, 2)
        masks[masks < 1] = 0
        Xs = SO3.qexp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
        hat_xs = self.dt*hat_xs.reshape(-1, 3).double()
        Omegas = SO3.qexp(hat_xs[:, :3])
        # compute increment at min_train_freq by decimation
        for k in range(self.min_N):
            Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
        rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N,
            -1, 3)[:, self.N0:]
        rs = rs[masks[:, self.N0:].squeeze(2) == 1]
        loss = self.f_huber(rs)
        # compute increment from min_train_freq to max_train_freq
        for k in range(self.min_N, self.max_N):
            Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
            Xs = SO3.qmul(Xs[::2], Xs[1::2])
            masks = masks[:, ::2] * masks[:, 1::2]
            rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs)).reshape(N,
                -1, 3)[:, self.N0:]
            rs = rs[masks[:, self.N0:].squeeze(2) == 1]
            loss = loss + self.f_huber(rs)/(2**(k - self.min_N + 1))
        return loss

--------------------------------------------------------------------------------
/src/networks.py:
--------------------------------------------------------------------------------
import torch
import matplotlib.pyplot as plt
import numpy as np
from src.utils import bmtm, bmtv, bmmt, bbmv
from src.lie_algebra import SO3


class BaseNet(torch.nn.Module):
    def __init__(self, in_dim, out_dim, c0, dropout, ks, ds, momentum):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # channel dimension
        c1 = 2*c0
        c2 = 2*c1
        c3 = 2*c2
        # kernel dimension (odd number)
        k0 = ks[0]
        k1 = ks[1]
        k2 = ks[2]
        k3 = ks[3]
        # dilation dimension
        d0 = ds[0]
        d1 = ds[1]
        d2 = ds[2]
        # padding: total length lost by the four unpadded convolutions
        p0 = (k0-1) + d0*(k1-1) + d0*d1*(k2-1) + d0*d1*d2*(k3-1)
        # nets
        self.cnn = torch.nn.Sequential(
            torch.nn.ReplicationPad1d((p0, 0)), # padding at start
            torch.nn.Conv1d(in_dim, c0, k0, dilation=1),
            torch.nn.BatchNorm1d(c0, momentum=momentum),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv1d(c0, c1, k1, dilation=d0),
            torch.nn.BatchNorm1d(c1, momentum=momentum),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv1d(c1, c2, k2, dilation=d0*d1),
            torch.nn.BatchNorm1d(c2, momentum=momentum),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv1d(c2, c3, k3, dilation=d0*d1*d2),
            torch.nn.BatchNorm1d(c3, momentum=momentum),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv1d(c3, out_dim, 1, dilation=1),
            torch.nn.ReplicationPad1d((0, 0)), # no padding at end
        )
        # for normalizing inputs
        self.mean_u = torch.nn.Parameter(torch.zeros(in_dim),
            requires_grad=False)
        self.std_u = torch.nn.Parameter(torch.ones(in_dim),
            requires_grad=False)

    def forward(self, us):
        u = self.norm(us).transpose(1, 2)
        y = self.cnn(u)
        return y

    def norm(self, us):
        return (us-self.mean_u)/self.std_u

    def set_normalized_factors(self, mean_u, std_u):
        self.mean_u = torch.nn.Parameter(mean_u.cuda(), requires_grad=False)
        self.std_u = torch.nn.Parameter(std_u.cuda(), requires_grad=False)


class GyroNet(BaseNet):
    def __init__(self, in_dim, out_dim, c0, dropout, ks, ds, momentum,
        gyro_std):
        super().__init__(in_dim, out_dim, c0, dropout, ks, ds, momentum)
        gyro_std = torch.Tensor(gyro_std)
        self.gyro_std = torch.nn.Parameter(gyro_std, requires_grad=False)

        # learnable 3x3 correction, initialized near zero and added to identity
        gyro_Rot = 0.05*torch.randn(3, 3).cuda()
        self.gyro_Rot = torch.nn.Parameter(gyro_Rot)
        self.Id3 = torch.eye(3).cuda()

    def forward(self, us):
        ys = super().forward(us)
        Rots = (self.Id3 + self.gyro_Rot).expand(us.shape[0], us.shape[1], 3, 3)
        Rot_us = bbmv(Rots, us[:, :, :3])
        return self.gyro_std*ys.transpose(1, 2) + Rot_us

--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import torch
import os
import pickle
import yaml


def pload(*f_names):
    """Pickle load"""
    f_name = os.path.join(*f_names)
    with open(f_name, "rb") as f:
        pickle_dict = pickle.load(f)
    return pickle_dict

def pdump(pickle_dict, *f_names):
    """Pickle dump"""
    f_name = os.path.join(*f_names)
    with open(f_name, "wb") as f:
        pickle.dump(pickle_dict, f)

def mkdir(*paths):
    '''Create a directory if not existing.'''
    path = os.path.join(*paths)
    if not os.path.exists(path):
        os.mkdir(path)

def yload(*f_names):
    """YAML load"""
    f_name = os.path.join(*f_names)
    with open(f_name, 'r') as f:
        # safe_load: yaml.load without a Loader fails on PyYAML >= 6
        yaml_dict = yaml.safe_load(f)
    return yaml_dict

def ydump(yaml_dict, *f_names):
    """YAML dump"""
    f_name = os.path.join(*f_names)
    with open(f_name, 'w') as f:
        yaml.dump(yaml_dict, f, default_flow_style=False)

def bmv(mat, vec):
    """batch matrix vector product"""
    return torch.einsum('bij, bj -> bi', mat, vec)

def bbmv(mat, vec):
    """double batch matrix vector product"""
    return torch.einsum('baij, baj -> bai', mat, vec)

def bmtv(mat, vec):
    """batch matrix transpose vector product"""
    return torch.einsum('bji, bj -> bi', mat, vec)

def bmtm(mat1, mat2):
    """batch matrix transpose matrix product"""
    return torch.einsum("bji, bjk -> bik", mat1, mat2)

def bmmt(mat1, mat2):
    """batch matrix matrix transpose product"""
    return torch.einsum("bij, bkj -> bik", mat1, mat2)

--------------------------------------------------------------------------------
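The padding `p0` computed in `BaseNet` (`src/networks.py`) is the total number of samples removed by the four unpadded dilated convolutions, so replicating `p0` samples at the start keeps the output as long as the input while each output depends only on current and past inputs. The sketch below checks that arithmetic in plain Python. The helper names `causal_padding` and `output_length` and the example values `ks=[7, 7, 7, 7]`, `ds=[4, 4, 4]` are illustrative assumptions, not part of the repository code shown here.

```python
def causal_padding(ks, ds):
    """Left padding that makes a stack of dilated 1-D convolutions length-preserving.

    Follows the p0 expression in BaseNet: layer i uses kernel size ks[i] and a
    dilation equal to the product of ds[:i], so the first layer has dilation 1.
    (Hypothetical helper, written for illustration only.)
    """
    padding, dilation = 0, 1
    for i, k in enumerate(ks):
        padding += dilation * (k - 1)
        if i < len(ds):
            dilation *= ds[i]
    return padding


def output_length(length, ks, ds):
    """Sequence length after left-padding and applying each unpadded convolution."""
    length += causal_padding(ks, ds)
    dilation = 1
    for i, k in enumerate(ks):
        # a stride-1 convolution without padding shortens the sequence
        # by dilation * (kernel_size - 1)
        length -= dilation * (k - 1)
        if i < len(ds):
            dilation *= ds[i]
    return length


# assumed example configuration: four kernels of size 7, dilation factors of 4
print(causal_padding([7, 7, 7, 7], [4, 4, 4]))       # 6 + 24 + 96 + 384 = 510
print(output_length(1000, [7, 7, 7, 7], [4, 4, 4]))  # 1000
```

With these assumed values, each output sample sees the current input plus the 510 preceding ones. The final `Conv1d` with kernel size 1 does not change the length, so it is left out of the calculation.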